diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc index cf3ba179c9..521fec4300 100644 --- a/src/tint/resolver/uniformity.cc +++ b/src/tint/resolver/uniformity.cc @@ -321,9 +321,9 @@ class UniformityGraph { // Look at which nodes are reachable from "RequiredToBeUniform". { utils::UniqueVector reachable; - Traverse(current_function_->required_to_be_uniform, reachable); + Traverse(current_function_->required_to_be_uniform, &reachable); if (reachable.contains(current_function_->may_be_non_uniform)) { - MakeError(); + MakeError(*current_function_, current_function_->may_be_non_uniform); return false; } if (reachable.contains(current_function_->cf_start)) { @@ -343,7 +343,7 @@ class UniformityGraph { // Look at which nodes are reachable from "CF_return" { utils::UniqueVector reachable; - Traverse(current_function_->cf_return, reachable); + Traverse(current_function_->cf_return, &reachable); if (reachable.contains(current_function_->may_be_non_uniform)) { current_function_->function_tag = SubsequentControlFlowMayBeNonUniform; } @@ -362,7 +362,7 @@ class UniformityGraph { // If "Value_return" exists, look at which nodes are reachable from it if (current_function_->value_return) { utils::UniqueVector reachable; - Traverse(current_function_->value_return, reachable); + Traverse(current_function_->value_return, &reachable); if (reachable.contains(current_function_->may_be_non_uniform)) { current_function_->function_tag = ReturnValueMayBeNonUniform; } @@ -388,7 +388,7 @@ class UniformityGraph { current_function_->ResetVisited(); utils::UniqueVector reachable; - Traverse(current_function_->parameters[i].pointer_return_value, reachable); + Traverse(current_function_->parameters[i].pointer_return_value, &reachable); if (reachable.contains(current_function_->may_be_non_uniform)) { current_function_->parameters[i].pointer_may_become_non_uniform = true; } @@ -1234,9 +1234,11 @@ class UniformityGraph { /// Recursively traverse a graph starting at `node`, inserting all nodes that are reached into /// `reachable`. /// @param node the starting node - /// @param reachable the set of reachable nodes to populate - void Traverse(Node* node, utils::UniqueVector& reachable) { - reachable.add(node); + /// @param reachable the set of reachable nodes to populate, if required + void Traverse(Node* node, utils::UniqueVector* reachable = nullptr) { + if (reachable) { + reachable->add(node); + } for (auto* to : node->edges) { if (to->visited_from == nullptr) { to->visited_from = node; @@ -1245,48 +1247,113 @@ class UniformityGraph { } } - /// Generate an error for a required_to_be_uniform->may_be_non_uniform path. - void MakeError() { - // Trace back to find a node that is required to be uniform that was reachable from a - // non-uniform value or control flow node. - Node* current = current_function_->may_be_non_uniform; - while (current) { - TINT_ASSERT(Resolver, current->visited_from); - if (current->visited_from == current_function_->required_to_be_uniform) { + /// Recursively descend through the function called by `call` and the functions that it calls in + /// order to find a call to a builtin function that requires uniformity. + const ast::CallExpression* FindBuiltinThatRequiresUniformity(const ast::CallExpression* call) { + auto* target = sem_.Get(call)->Target(); + if (target->Is()) { + // This is a call to a builtin, so we must be done. + return call; + } else if (auto* user = target->As()) { + // This is a call to a user-defined function, so inspect the functions called by that + // function and look for one whose node has an edge from the RequiredToBeUniform node. + auto& target_info = functions_.at(user->Declaration()); + for (auto* call_node : target_info.required_to_be_uniform->edges) { + if (call_node->arg_index == std::numeric_limits::max()) { + auto* child_call = call_node->ast->As(); + return FindBuiltinThatRequiresUniformity(child_call); + } + } + TINT_ASSERT(Resolver, false && "unable to find child call with uniformity requirement"); + } else { + TINT_ASSERT(Resolver, false && "unexpected call expression type"); + } + return nullptr; + } + + /// Generate an error message for a uniformity issue. + /// @param function the function that the diagnostic is being produced for + /// @param source_node the node that has caused a uniformity issue in `function` + /// @param note `true` if the diagnostic should be emitted as a note + void MakeError(FunctionInfo& function, Node* source_node, bool note = false) { + // Helper to produce a diagnostic message with the severity required by this invocation of + // the `MakeError` function. + auto report = [&](Source source, std::string msg) { + // TODO(jrprice): Switch to error instead of warning when feedback has settled. + diag::Diagnostic error{}; + error.severity = note ? diag::Severity::Note : diag::Severity::Warning; + error.system = diag::System::Resolver; + error.source = source; + error.message = msg; + diagnostics_.add(std::move(error)); + }; + + // Traverse the graph to generate a path from RequiredToBeUniform to the source node. + function.ResetVisited(); + Traverse(function.required_to_be_uniform); + TINT_ASSERT(Resolver, source_node->visited_from); + + // Trace back through the graph to find a node that is required to be uniform that has + // a path to the source node. + Node* cause = source_node; + while (cause) { + if (cause->visited_from == function.required_to_be_uniform) { break; } - current = current->visited_from; + cause = cause->visited_from; } - // The node will always have an corresponding call expression. - auto* call = current->ast->As(); + // The node will always have a corresponding call expression. + auto* call = cause->ast->As(); TINT_ASSERT(Resolver, call); auto* target = sem_.Get(call)->Target(); - std::string name; + std::string func_name; if (auto* builtin = target->As()) { - name = builtin->str(); + func_name = builtin->str(); } else if (auto* user = target->As()) { - name = builder_->Symbols().NameFor(user->Declaration()->symbol); + func_name = builder_->Symbols().NameFor(user->Declaration()->symbol); } - // TODO(jrprice): Switch to error instead of warning when feedback has settled. - if (current->arg_index != std::numeric_limits::max()) { + if (cause->arg_index != std::numeric_limits::max()) { // The requirement was on a function parameter. auto param_name = builder_->Symbols().NameFor( - target->Parameters()[current->arg_index]->Declaration()->symbol); - diagnostics_.add_warning( - diag::System::Resolver, - "parameter '" + param_name + "' of '" + name + "' must be uniform", - call->args[current->arg_index]->source); - // TODO(jrprice): Show the reason why. + target->Parameters()[cause->arg_index]->Declaration()->symbol); + report(call->args[cause->arg_index]->source, + "parameter '" + param_name + "' of '" + func_name + "' must be uniform"); + + // If this is a call to a user-defined function, add a note to show the reason that the + // parameter is required to be uniform. + if (auto* user = target->As()) { + auto& next_function = functions_.at(user->Declaration()); + Node* next_cause = next_function.parameters[cause->arg_index].init_value; + MakeError(next_function, next_cause, true); + } } else { // The requirement was on a function callsite. - diagnostics_.add_warning(diag::System::Resolver, - "'" + name + "' must only be called from uniform control flow", - call->source); - // TODO(jrprice): Show full call stack to the problematic builtin. + report(call->source, + "'" + func_name + "' must only be called from uniform control flow"); + + // If this is a call to a user-defined function, add a note to show the builtin that + // causes the uniformity requirement. + auto* innermost_call = FindBuiltinThatRequiresUniformity(call); + if (innermost_call != call) { + // Determine whether the builtin is being called directly or indirectly. + bool indirect = false; + if (sem_.Get(call)->Target()->As() != + sem_.Get(innermost_call)->Stmt()->Function()) { + indirect = true; + } + + auto* builtin = sem_.Get(innermost_call)->Target()->As(); + diagnostics_.add_note(diag::System::Resolver, + "'" + func_name + "' requires uniformity because it " + + (indirect ? "indirectly " : "") + "calls " + + builtin->str(), + innermost_call->source); + } } + // TODO(jrprice): Show the source of non-uniformity. } }; diff --git a/src/tint/resolver/uniformity_test.cc b/src/tint/resolver/uniformity_test.cc index bb33f894d2..d81d7d6168 100644 --- a/src/tint/resolver/uniformity_test.cc +++ b/src/tint/resolver/uniformity_test.cc @@ -466,6 +466,10 @@ fn bar() { R"(test:11:7 warning: parameter 'i' of 'foo' must be uniform foo(rw); ^^ + +test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow + workgroupBarrier(); + ^^^^^^^^^^^^^^^^ )"); } @@ -3229,6 +3233,34 @@ fn foo() { )"); } +TEST_F(UniformityAnalysisTest, LoadNonUniformThroughPointerParameter) { + auto src = R"( +@group(0) @binding(0) var non_uniform : i32; + +fn bar(p : ptr) { + if (*p == 0) { + workgroupBarrier(); + } +} + +fn foo() { + var v = non_uniform; + bar(&v); +} +)"; + + RunTest(src, false); + EXPECT_EQ(error_, + R"(test:12:7 warning: parameter 'p' of 'bar' must be uniform + bar(&v); + ^ + +test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow + workgroupBarrier(); + ^^^^^^^^^^^^^^^^ +)"); +} + TEST_F(UniformityAnalysisTest, LoadUniformThroughPointer) { auto src = R"( fn foo() { @@ -3256,6 +3288,23 @@ fn foo() { RunTest(src, true); } +TEST_F(UniformityAnalysisTest, LoadUniformThroughPointerParameter) { + auto src = R"( +fn bar(p : ptr) { + if (*p == 0) { + workgroupBarrier(); + } +} + +fn foo() { + var v = 42; + bar(&v); +} +)"; + + RunTest(src, true); +} + TEST_F(UniformityAnalysisTest, StoreNonUniformAfterCapturingPointer) { auto src = R"( @group(0) @binding(0) var non_uniform : i32; @@ -4884,5 +4933,114 @@ fn foo() { RunTest(src, true); } +//////////////////////////////////////////////////////////////////////////////// +/// Tests for the quality of the error messages produced by the analysis. +//////////////////////////////////////////////////////////////////////////////// + +TEST_F(UniformityAnalysisTest, Error_CallUserThatCallsBuiltinDirectly) { + auto src = R"( +@group(0) @binding(0) var non_uniform : i32; + +fn foo() { + workgroupBarrier(); +} + +fn main() { + if (non_uniform == 42) { + foo(); + } +} +)"; + + RunTest(src, false); + EXPECT_EQ(error_, + R"(test:10:5 warning: 'foo' must only be called from uniform control flow + foo(); + ^^^ + +test:5:3 note: 'foo' requires uniformity because it calls workgroupBarrier + workgroupBarrier(); + ^^^^^^^^^^^^^^^^ +)"); +} + +TEST_F(UniformityAnalysisTest, Error_CallUserThatCallsBuiltinIndirectly) { + auto src = R"( +@group(0) @binding(0) var non_uniform : i32; + +fn zoo() { + workgroupBarrier(); +} + +fn bar() { + zoo(); +} + +fn foo() { + bar(); +} + +fn main() { + if (non_uniform == 42) { + foo(); + } +} +)"; + + RunTest(src, false); + EXPECT_EQ(error_, + R"(test:18:5 warning: 'foo' must only be called from uniform control flow + foo(); + ^^^ + +test:5:3 note: 'foo' requires uniformity because it indirectly calls workgroupBarrier + workgroupBarrier(); + ^^^^^^^^^^^^^^^^ +)"); +} + +TEST_F(UniformityAnalysisTest, Error_ParametersRequireUniformityInChain) { + auto src = R"( +@group(0) @binding(0) var non_uniform : i32; + +fn zoo(a : i32) { + if (a == 42) { + workgroupBarrier(); + } +} + +fn bar(b : i32) { + zoo(b); +} + +fn foo(c : i32) { + bar(c); +} + +fn main() { + foo(non_uniform); +} +)"; + + RunTest(src, false); + EXPECT_EQ(error_, + R"(test:19:7 warning: parameter 'c' of 'foo' must be uniform + foo(non_uniform); + ^^^^^^^^^^^ + +test:15:7 note: parameter 'b' of 'bar' must be uniform + bar(c); + ^ + +test:11:7 note: parameter 'a' of 'zoo' must be uniform + zoo(b); + ^ + +test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow + workgroupBarrier(); + ^^^^^^^^^^^^^^^^ +)"); +} + } // namespace } // namespace tint::resolver diff --git a/test/tint/bug/tint/943.spvasm.expected.glsl b/test/tint/bug/tint/943.spvasm.expected.glsl index f581b5d873..7edb269726 100644 --- a/test/tint/bug/tint/943.spvasm.expected.glsl +++ b/test/tint/bug/tint/943.spvasm.expected.glsl @@ -1,4 +1,5 @@ warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform +note: 'workgroupBarrier' must only be called from uniform control flow #version 310 es struct Uniforms { diff --git a/test/tint/bug/tint/943.spvasm.expected.hlsl b/test/tint/bug/tint/943.spvasm.expected.hlsl index 89accf2e8b..ece0c20678 100644 --- a/test/tint/bug/tint/943.spvasm.expected.hlsl +++ b/test/tint/bug/tint/943.spvasm.expected.hlsl @@ -1,4 +1,5 @@ warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform +note: 'workgroupBarrier' must only be called from uniform control flow static int dimAOuter_1 = 0; cbuffer cbuffer_x_48 : register(b3, space0) { uint4 x_48[5]; diff --git a/test/tint/bug/tint/943.spvasm.expected.msl b/test/tint/bug/tint/943.spvasm.expected.msl index 031bf63865..48f3434ea8 100644 --- a/test/tint/bug/tint/943.spvasm.expected.msl +++ b/test/tint/bug/tint/943.spvasm.expected.msl @@ -1,4 +1,5 @@ warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform +note: 'workgroupBarrier' must only be called from uniform control flow #include using namespace metal; diff --git a/test/tint/bug/tint/943.spvasm.expected.spvasm b/test/tint/bug/tint/943.spvasm.expected.spvasm index 228d7dfc20..e9fc8c5981 100644 --- a/test/tint/bug/tint/943.spvasm.expected.spvasm +++ b/test/tint/bug/tint/943.spvasm.expected.spvasm @@ -1,4 +1,5 @@ warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform +note: 'workgroupBarrier' must only be called from uniform control flow ; SPIR-V ; Version: 1.3 ; Generator: Google Tint Compiler; 0 diff --git a/test/tint/bug/tint/943.spvasm.expected.wgsl b/test/tint/bug/tint/943.spvasm.expected.wgsl index aa383dac12..08be16a8ee 100644 --- a/test/tint/bug/tint/943.spvasm.expected.wgsl +++ b/test/tint/bug/tint/943.spvasm.expected.wgsl @@ -1,4 +1,5 @@ warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform +note: 'workgroupBarrier' must only be called from uniform control flow struct Uniforms { NAN : f32, @size(12)