diff --git a/src/program_builder.h b/src/program_builder.h index 97ef17fdaf..c4f2a2fdec 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -1737,6 +1737,16 @@ class ProgramBuilder { Expr(std::forward(rhs))); } + /// @param lhs the left hand argument to the equal expression + /// @param rhs the right hand argument to the equal expression + /// @returns a `ast::BinaryExpression` comparing `lhs` equal to `rhs` + template + const ast::BinaryExpression* Equal(LHS&& lhs, RHS&& rhs) { + return create(ast::BinaryOp::kEqual, + Expr(std::forward(lhs)), + Expr(std::forward(rhs))); + } + /// @param source the source information /// @param arr the array argument for the array accessor expression /// @param idx the index argument for the array accessor expression diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 25bcd445e5..9ee50fe513 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -2166,6 +2166,7 @@ bool Resolver::Statement(const ast::Statement* stmt) { } return false; } + current_function_->SetHasDiscard(); return true; } if (stmt->Is()) { diff --git a/src/sem/function.h b/src/sem/function.h index d8b666281e..d8a58a890c 100644 --- a/src/sem/function.h +++ b/src/sem/function.h @@ -229,6 +229,13 @@ class Function : public Castable { /// @returns true if `sym` is an ancestor entry point of this function bool HasAncestorEntryPoint(Symbol sym) const; + /// Sets that this function has a discard statement + void SetHasDiscard() { has_discard_ = true; } + + /// Returns true if this function has a discard statement + /// @returns true if this function has a discard statement + bool HasDiscard() const { return has_discard_; } + private: VariableBindings TransitivelyReferencedSamplerVariablesImpl( ast::SamplerKind kind) const; @@ -245,6 +252,7 @@ class Function : public Castable { std::vector direct_calls_; std::vector callsites_; std::vector ancestor_entry_points_; + bool has_discard_ = false; }; } // namespace sem diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 92fe82e391..27b866667a 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -2339,12 +2339,53 @@ bool GeneratorImpl::EmitFunction(const ast::Function* func) { out << ") {"; } + if (sem->HasDiscard() && !sem->ReturnType()->Is()) { + // BUG(crbug.com/tint/1081): work around non-void functions with discard + // failing compilation sometimes + if (!EmitFunctionBodyWithDiscard(func)) { + return false; + } + } else { + if (!EmitStatementsWithIndent(func->body->statements)) { + return false; + } + } + + line() << "}"; + + return true; +} + +bool GeneratorImpl::EmitFunctionBodyWithDiscard(const ast::Function* func) { + // FXC sometimes fails to compile functions that discard with 'Not all control + // paths return a value'. We work around this by wrapping the function body + // within an "if (true) { } return ;" so that + // there is always an (unused) return statement. + + auto* sem = builder_.Sem().Get(func); + TINT_ASSERT(Writer, sem->HasDiscard() && !sem->ReturnType()->Is()); + + ScopedIndent si(this); + line() << "if (true) {"; + if (!EmitStatementsWithIndent(func->body->statements)) { return false; } line() << "}"; + // Return an unused result that matches the type of the return value + auto name = builder_.Symbols().NameFor(builder_.Symbols().New("unused")); + { + auto out = line(); + if (!EmitTypeAndName(out, sem->ReturnType(), ast::StorageClass::kNone, + ast::Access::kReadWrite, name)) { + return false; + } + out << ";"; + } + line() << "return " << name << ";"; + return true; } diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index 1945d4e6ba..03f1c0eb8a 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -240,7 +240,11 @@ class GeneratorImpl : public TextGenerator { /// @param func the function to generate /// @returns true if the function was emitted bool EmitFunction(const ast::Function* func); - + /// Handles emitting the function body if it discards to work around a FXC + /// compilation bug. + /// @param func the function with the body to emit + /// @returns true if the function was emitted + bool EmitFunctionBodyWithDiscard(const ast::Function* func); /// Handles emitting a global variable /// @param global the global variable /// @returns true on success diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index 41b2294203..3af0f4f344 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -809,6 +809,51 @@ my_func_ret my_func() { )"); } +TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithDiscardAndVoidReturn) { + Func("my_func", {Param("a", ty.i32())}, ty.void_(), + { + If(Equal("a", 0), // + Block(create())), + Return(), + }); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_EQ(gen.result(), R"(void my_func(int a) { + if ((a == 0)) { + discard; + } + return; +} +)"); +} + +TEST_F(HlslGeneratorImplTest_Function, + Emit_Function_WithDiscardAndNonVoidReturn) { + Func("my_func", {Param("a", ty.i32())}, ty.i32(), + { + If(Equal("a", 0), // + Block(create())), + Return(42), + }); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_EQ(gen.result(), R"(int my_func(int a) { + if (true) { + if ((a == 0)) { + discard; + } + return 42; + } + int unused; + return unused; +} +)"); +} + // https://crbug.com/tint/297 TEST_F(HlslGeneratorImplTest_Function, Emit_Multiple_EntryPoint_With_Same_ModuleVar) { diff --git a/test/bug/dawn/947.wgsl.expected.hlsl b/test/bug/dawn/947.wgsl.expected.hlsl index 5c18116c6a..da0ff6a983 100644 --- a/test/bug/dawn/947.wgsl.expected.hlsl +++ b/test/bug/dawn/947.wgsl.expected.hlsl @@ -46,12 +46,16 @@ struct tint_symbol_5 { }; float4 fs_main_inner(float2 texcoord) { - float2 clampedTexcoord = clamp(texcoord, float2(0.0f, 0.0f), float2(1.0f, 1.0f)); - if (!(all((clampedTexcoord == texcoord)))) { - discard; + if (true) { + float2 clampedTexcoord = clamp(texcoord, float2(0.0f, 0.0f), float2(1.0f, 1.0f)); + if (!(all((clampedTexcoord == texcoord)))) { + discard; + } + float4 srcColor = myTexture.Sample(mySampler, texcoord); + return srcColor; } - float4 srcColor = myTexture.Sample(mySampler, texcoord); - return srcColor; + float4 unused; + return unused; } tint_symbol_5 fs_main(tint_symbol_4 tint_symbol_3) { diff --git a/test/bug/tint/1081.wgsl b/test/bug/tint/1081.wgsl new file mode 100644 index 0000000000..7350e4cda2 --- /dev/null +++ b/test/bug/tint/1081.wgsl @@ -0,0 +1,18 @@ +fn f(x : i32) -> i32 { + if (x == 10) { + discard; + } + return x; +} + +[[stage(fragment)]] +fn main([[location(1)]] x: vec3) -> [[location(2)]] i32 { + var y = x.x; + loop { + let r = f(y); + if (r == 0) { + break; + } + } + return y; +} diff --git a/test/bug/tint/1081.wgsl.expected.hlsl b/test/bug/tint/1081.wgsl.expected.hlsl new file mode 100644 index 0000000000..056c225261 --- /dev/null +++ b/test/bug/tint/1081.wgsl.expected.hlsl @@ -0,0 +1,39 @@ +bug/tint/1081.wgsl:9:25 warning: integral user-defined fragment inputs must have a flat interpolation attribute +fn main([[location(1)]] x: vec3) -> [[location(2)]] i32 { + ^ + +int f(int x) { + if (true) { + if ((x == 10)) { + discard; + } + return x; + } + int unused; + return unused; +} + +struct tint_symbol_1 { + int3 x : TEXCOORD1; +}; +struct tint_symbol_2 { + int value : SV_Target2; +}; + +int main_inner(int3 x) { + int y = x.x; + while (true) { + const int r = f(y); + if ((r == 0)) { + break; + } + } + return y; +} + +tint_symbol_2 main(tint_symbol_1 tint_symbol) { + const int inner_result = main_inner(tint_symbol.x); + tint_symbol_2 wrapper_result = (tint_symbol_2)0; + wrapper_result.value = inner_result; + return wrapper_result; +} diff --git a/test/bug/tint/1081.wgsl.expected.msl b/test/bug/tint/1081.wgsl.expected.msl new file mode 100644 index 0000000000..0d57311f34 --- /dev/null +++ b/test/bug/tint/1081.wgsl.expected.msl @@ -0,0 +1,39 @@ +bug/tint/1081.wgsl:9:25 warning: integral user-defined fragment inputs must have a flat interpolation attribute +fn main([[location(1)]] x: vec3) -> [[location(2)]] i32 { + ^ + +#include + +using namespace metal; +struct tint_symbol_2 { + int3 x [[user(locn1)]]; +}; +struct tint_symbol_3 { + int value [[color(2)]]; +}; + +int f(int x) { + if ((x == 10)) { + discard_fragment(); + } + return x; +} + +int tint_symbol_inner(int3 x) { + int y = x[0]; + while (true) { + int const r = f(y); + if ((r == 0)) { + break; + } + } + return y; +} + +fragment tint_symbol_3 tint_symbol(tint_symbol_2 tint_symbol_1 [[stage_in]]) { + int const inner_result = tint_symbol_inner(tint_symbol_1.x); + tint_symbol_3 wrapper_result = {}; + wrapper_result.value = inner_result; + return wrapper_result; +} + diff --git a/test/bug/tint/1081.wgsl.expected.spvasm b/test/bug/tint/1081.wgsl.expected.spvasm new file mode 100644 index 0000000000..58810deb88 --- /dev/null +++ b/test/bug/tint/1081.wgsl.expected.spvasm @@ -0,0 +1,83 @@ +bug/tint/1081.wgsl:9:25 warning: integral user-defined fragment inputs must have a flat interpolation attribute +fn main([[location(1)]] x: vec3) -> [[location(2)]] i32 { + ^ + +; SPIR-V +; Version: 1.3 +; Generator: Google Tint Compiler; 0 +; Bound: 41 +; Schema: 0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %x_1 %value + OpExecutionMode %main OriginUpperLeft + OpName %x_1 "x_1" + OpName %value "value" + OpName %f "f" + OpName %x "x" + OpName %main_inner "main_inner" + OpName %x_0 "x" + OpName %y "y" + OpName %main "main" + OpDecorate %x_1 Location 1 + OpDecorate %x_1 Flat + OpDecorate %value Location 2 + %int = OpTypeInt 32 1 + %v3int = OpTypeVector %int 3 +%_ptr_Input_v3int = OpTypePointer Input %v3int + %x_1 = OpVariable %_ptr_Input_v3int Input +%_ptr_Output_int = OpTypePointer Output %int + %7 = OpConstantNull %int + %value = OpVariable %_ptr_Output_int Output %7 + %8 = OpTypeFunction %int %int + %int_10 = OpConstant %int 10 + %bool = OpTypeBool + %17 = OpTypeFunction %int %v3int +%_ptr_Function_int = OpTypePointer Function %int + %int_0 = OpConstant %int 0 + %void = OpTypeVoid + %35 = OpTypeFunction %void + %f = OpFunction %int None %8 + %x = OpFunctionParameter %int + %11 = OpLabel + %13 = OpIEqual %bool %x %int_10 + OpSelectionMerge %15 None + OpBranchConditional %13 %16 %15 + %16 = OpLabel + OpKill + %15 = OpLabel + OpReturnValue %x + OpFunctionEnd + %main_inner = OpFunction %int None %17 + %x_0 = OpFunctionParameter %v3int + %20 = OpLabel + %y = OpVariable %_ptr_Function_int Function %7 + %21 = OpCompositeExtract %int %x_0 0 + OpStore %y %21 + OpBranch %24 + %24 = OpLabel + OpLoopMerge %25 %26 None + OpBranch %27 + %27 = OpLabel + %29 = OpLoad %int %y + %28 = OpFunctionCall %int %f %29 + %31 = OpIEqual %bool %28 %int_0 + OpSelectionMerge %32 None + OpBranchConditional %31 %33 %32 + %33 = OpLabel + OpBranch %25 + %32 = OpLabel + OpBranch %26 + %26 = OpLabel + OpBranch %24 + %25 = OpLabel + %34 = OpLoad %int %y + OpReturnValue %34 + OpFunctionEnd + %main = OpFunction %void None %35 + %38 = OpLabel + %40 = OpLoad %v3int %x_1 + %39 = OpFunctionCall %int %main_inner %40 + OpStore %value %39 + OpReturn + OpFunctionEnd diff --git a/test/bug/tint/1081.wgsl.expected.wgsl b/test/bug/tint/1081.wgsl.expected.wgsl new file mode 100644 index 0000000000..ab2a38dbe8 --- /dev/null +++ b/test/bug/tint/1081.wgsl.expected.wgsl @@ -0,0 +1,22 @@ +bug/tint/1081.wgsl:9:25 warning: integral user-defined fragment inputs must have a flat interpolation attribute +fn main([[location(1)]] x: vec3) -> [[location(2)]] i32 { + ^ + +fn f(x : i32) -> i32 { + if ((x == 10)) { + discard; + } + return x; +} + +[[stage(fragment)]] +fn main([[location(1)]] x : vec3) -> [[location(2)]] i32 { + var y = x.x; + loop { + let r = f(y); + if ((r == 0)) { + break; + } + } + return y; +}