tint: Show the reason for a uniformity requirement
When producing an error from the uniformity analysis, add notes to show the underlying reason for the uniformity requirement. For function calls that are required-to-be-uniform, show the innermost builtin call that has the requirement. For function parameters that are required-to-be-uniform, recurse into that function to show where its requirement comes from. Add some new tests to specifically test the error messages. Bug: tint:880 Change-Id: Ib166fdeceaffb156a3afc50ebc5a4ad0860dc002 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/89722 Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
874b61f1ba
commit
9c03abfb55
|
@ -321,9 +321,9 @@ class UniformityGraph {
|
|||
// Look at which nodes are reachable from "RequiredToBeUniform".
|
||||
{
|
||||
utils::UniqueVector<Node*> reachable;
|
||||
Traverse(current_function_->required_to_be_uniform, reachable);
|
||||
Traverse(current_function_->required_to_be_uniform, &reachable);
|
||||
if (reachable.contains(current_function_->may_be_non_uniform)) {
|
||||
MakeError();
|
||||
MakeError(*current_function_, current_function_->may_be_non_uniform);
|
||||
return false;
|
||||
}
|
||||
if (reachable.contains(current_function_->cf_start)) {
|
||||
|
@ -343,7 +343,7 @@ class UniformityGraph {
|
|||
// Look at which nodes are reachable from "CF_return"
|
||||
{
|
||||
utils::UniqueVector<Node*> reachable;
|
||||
Traverse(current_function_->cf_return, reachable);
|
||||
Traverse(current_function_->cf_return, &reachable);
|
||||
if (reachable.contains(current_function_->may_be_non_uniform)) {
|
||||
current_function_->function_tag = SubsequentControlFlowMayBeNonUniform;
|
||||
}
|
||||
|
@ -362,7 +362,7 @@ class UniformityGraph {
|
|||
// If "Value_return" exists, look at which nodes are reachable from it
|
||||
if (current_function_->value_return) {
|
||||
utils::UniqueVector<Node*> reachable;
|
||||
Traverse(current_function_->value_return, reachable);
|
||||
Traverse(current_function_->value_return, &reachable);
|
||||
if (reachable.contains(current_function_->may_be_non_uniform)) {
|
||||
current_function_->function_tag = ReturnValueMayBeNonUniform;
|
||||
}
|
||||
|
@ -388,7 +388,7 @@ class UniformityGraph {
|
|||
current_function_->ResetVisited();
|
||||
|
||||
utils::UniqueVector<Node*> reachable;
|
||||
Traverse(current_function_->parameters[i].pointer_return_value, reachable);
|
||||
Traverse(current_function_->parameters[i].pointer_return_value, &reachable);
|
||||
if (reachable.contains(current_function_->may_be_non_uniform)) {
|
||||
current_function_->parameters[i].pointer_may_become_non_uniform = true;
|
||||
}
|
||||
|
@ -1234,9 +1234,11 @@ class UniformityGraph {
|
|||
/// Recursively traverse a graph starting at `node`, inserting all nodes that are reached into
|
||||
/// `reachable`.
|
||||
/// @param node the starting node
|
||||
/// @param reachable the set of reachable nodes to populate
|
||||
void Traverse(Node* node, utils::UniqueVector<Node*>& reachable) {
|
||||
reachable.add(node);
|
||||
/// @param reachable the set of reachable nodes to populate, if required
|
||||
void Traverse(Node* node, utils::UniqueVector<Node*>* reachable = nullptr) {
|
||||
if (reachable) {
|
||||
reachable->add(node);
|
||||
}
|
||||
for (auto* to : node->edges) {
|
||||
if (to->visited_from == nullptr) {
|
||||
to->visited_from = node;
|
||||
|
@ -1245,48 +1247,113 @@ class UniformityGraph {
|
|||
}
|
||||
}
|
||||
|
||||
/// Generate an error for a required_to_be_uniform->may_be_non_uniform path.
|
||||
void MakeError() {
|
||||
// Trace back to find a node that is required to be uniform that was reachable from a
|
||||
// non-uniform value or control flow node.
|
||||
Node* current = current_function_->may_be_non_uniform;
|
||||
while (current) {
|
||||
TINT_ASSERT(Resolver, current->visited_from);
|
||||
if (current->visited_from == current_function_->required_to_be_uniform) {
|
||||
break;
|
||||
/// Recursively descend through the function called by `call` and the functions that it calls in
|
||||
/// order to find a call to a builtin function that requires uniformity.
|
||||
const ast::CallExpression* FindBuiltinThatRequiresUniformity(const ast::CallExpression* call) {
|
||||
auto* target = sem_.Get(call)->Target();
|
||||
if (target->Is<sem::Builtin>()) {
|
||||
// This is a call to a builtin, so we must be done.
|
||||
return call;
|
||||
} else if (auto* user = target->As<sem::Function>()) {
|
||||
// This is a call to a user-defined function, so inspect the functions called by that
|
||||
// function and look for one whose node has an edge from the RequiredToBeUniform node.
|
||||
auto& target_info = functions_.at(user->Declaration());
|
||||
for (auto* call_node : target_info.required_to_be_uniform->edges) {
|
||||
if (call_node->arg_index == std::numeric_limits<uint32_t>::max()) {
|
||||
auto* child_call = call_node->ast->As<ast::CallExpression>();
|
||||
return FindBuiltinThatRequiresUniformity(child_call);
|
||||
}
|
||||
current = current->visited_from;
|
||||
}
|
||||
TINT_ASSERT(Resolver, false && "unable to find child call with uniformity requirement");
|
||||
} else {
|
||||
TINT_ASSERT(Resolver, false && "unexpected call expression type");
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// The node will always have an corresponding call expression.
|
||||
auto* call = current->ast->As<ast::CallExpression>();
|
||||
/// Generate an error message for a uniformity issue.
|
||||
/// @param function the function that the diagnostic is being produced for
|
||||
/// @param source_node the node that has caused a uniformity issue in `function`
|
||||
/// @param note `true` if the diagnostic should be emitted as a note
|
||||
void MakeError(FunctionInfo& function, Node* source_node, bool note = false) {
|
||||
// Helper to produce a diagnostic message with the severity required by this invocation of
|
||||
// the `MakeError` function.
|
||||
auto report = [&](Source source, std::string msg) {
|
||||
// TODO(jrprice): Switch to error instead of warning when feedback has settled.
|
||||
diag::Diagnostic error{};
|
||||
error.severity = note ? diag::Severity::Note : diag::Severity::Warning;
|
||||
error.system = diag::System::Resolver;
|
||||
error.source = source;
|
||||
error.message = msg;
|
||||
diagnostics_.add(std::move(error));
|
||||
};
|
||||
|
||||
// Traverse the graph to generate a path from RequiredToBeUniform to the source node.
|
||||
function.ResetVisited();
|
||||
Traverse(function.required_to_be_uniform);
|
||||
TINT_ASSERT(Resolver, source_node->visited_from);
|
||||
|
||||
// Trace back through the graph to find a node that is required to be uniform that has
|
||||
// a path to the source node.
|
||||
Node* cause = source_node;
|
||||
while (cause) {
|
||||
if (cause->visited_from == function.required_to_be_uniform) {
|
||||
break;
|
||||
}
|
||||
cause = cause->visited_from;
|
||||
}
|
||||
|
||||
// The node will always have a corresponding call expression.
|
||||
auto* call = cause->ast->As<ast::CallExpression>();
|
||||
TINT_ASSERT(Resolver, call);
|
||||
auto* target = sem_.Get(call)->Target();
|
||||
|
||||
std::string name;
|
||||
std::string func_name;
|
||||
if (auto* builtin = target->As<sem::Builtin>()) {
|
||||
name = builtin->str();
|
||||
func_name = builtin->str();
|
||||
} else if (auto* user = target->As<sem::Function>()) {
|
||||
name = builder_->Symbols().NameFor(user->Declaration()->symbol);
|
||||
func_name = builder_->Symbols().NameFor(user->Declaration()->symbol);
|
||||
}
|
||||
|
||||
// TODO(jrprice): Switch to error instead of warning when feedback has settled.
|
||||
if (current->arg_index != std::numeric_limits<uint32_t>::max()) {
|
||||
if (cause->arg_index != std::numeric_limits<uint32_t>::max()) {
|
||||
// The requirement was on a function parameter.
|
||||
auto param_name = builder_->Symbols().NameFor(
|
||||
target->Parameters()[current->arg_index]->Declaration()->symbol);
|
||||
diagnostics_.add_warning(
|
||||
diag::System::Resolver,
|
||||
"parameter '" + param_name + "' of '" + name + "' must be uniform",
|
||||
call->args[current->arg_index]->source);
|
||||
// TODO(jrprice): Show the reason why.
|
||||
target->Parameters()[cause->arg_index]->Declaration()->symbol);
|
||||
report(call->args[cause->arg_index]->source,
|
||||
"parameter '" + param_name + "' of '" + func_name + "' must be uniform");
|
||||
|
||||
// If this is a call to a user-defined function, add a note to show the reason that the
|
||||
// parameter is required to be uniform.
|
||||
if (auto* user = target->As<sem::Function>()) {
|
||||
auto& next_function = functions_.at(user->Declaration());
|
||||
Node* next_cause = next_function.parameters[cause->arg_index].init_value;
|
||||
MakeError(next_function, next_cause, true);
|
||||
}
|
||||
} else {
|
||||
// The requirement was on a function callsite.
|
||||
diagnostics_.add_warning(diag::System::Resolver,
|
||||
"'" + name + "' must only be called from uniform control flow",
|
||||
call->source);
|
||||
// TODO(jrprice): Show full call stack to the problematic builtin.
|
||||
report(call->source,
|
||||
"'" + func_name + "' must only be called from uniform control flow");
|
||||
|
||||
// If this is a call to a user-defined function, add a note to show the builtin that
|
||||
// causes the uniformity requirement.
|
||||
auto* innermost_call = FindBuiltinThatRequiresUniformity(call);
|
||||
if (innermost_call != call) {
|
||||
// Determine whether the builtin is being called directly or indirectly.
|
||||
bool indirect = false;
|
||||
if (sem_.Get(call)->Target()->As<sem::Function>() !=
|
||||
sem_.Get(innermost_call)->Stmt()->Function()) {
|
||||
indirect = true;
|
||||
}
|
||||
|
||||
auto* builtin = sem_.Get(innermost_call)->Target()->As<sem::Builtin>();
|
||||
diagnostics_.add_note(diag::System::Resolver,
|
||||
"'" + func_name + "' requires uniformity because it " +
|
||||
(indirect ? "indirectly " : "") + "calls " +
|
||||
builtin->str(),
|
||||
innermost_call->source);
|
||||
}
|
||||
}
|
||||
// TODO(jrprice): Show the source of non-uniformity.
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -466,6 +466,10 @@ fn bar() {
|
|||
R"(test:11:7 warning: parameter 'i' of 'foo' must be uniform
|
||||
foo(rw);
|
||||
^^
|
||||
|
||||
test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow
|
||||
workgroupBarrier();
|
||||
^^^^^^^^^^^^^^^^
|
||||
)");
|
||||
}
|
||||
|
||||
|
@ -3229,6 +3233,34 @@ fn foo() {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(UniformityAnalysisTest, LoadNonUniformThroughPointerParameter) {
|
||||
auto src = R"(
|
||||
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
|
||||
|
||||
fn bar(p : ptr<function, i32>) {
|
||||
if (*p == 0) {
|
||||
workgroupBarrier();
|
||||
}
|
||||
}
|
||||
|
||||
fn foo() {
|
||||
var v = non_uniform;
|
||||
bar(&v);
|
||||
}
|
||||
)";
|
||||
|
||||
RunTest(src, false);
|
||||
EXPECT_EQ(error_,
|
||||
R"(test:12:7 warning: parameter 'p' of 'bar' must be uniform
|
||||
bar(&v);
|
||||
^
|
||||
|
||||
test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow
|
||||
workgroupBarrier();
|
||||
^^^^^^^^^^^^^^^^
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(UniformityAnalysisTest, LoadUniformThroughPointer) {
|
||||
auto src = R"(
|
||||
fn foo() {
|
||||
|
@ -3256,6 +3288,23 @@ fn foo() {
|
|||
RunTest(src, true);
|
||||
}
|
||||
|
||||
TEST_F(UniformityAnalysisTest, LoadUniformThroughPointerParameter) {
|
||||
auto src = R"(
|
||||
fn bar(p : ptr<function, i32>) {
|
||||
if (*p == 0) {
|
||||
workgroupBarrier();
|
||||
}
|
||||
}
|
||||
|
||||
fn foo() {
|
||||
var v = 42;
|
||||
bar(&v);
|
||||
}
|
||||
)";
|
||||
|
||||
RunTest(src, true);
|
||||
}
|
||||
|
||||
TEST_F(UniformityAnalysisTest, StoreNonUniformAfterCapturingPointer) {
|
||||
auto src = R"(
|
||||
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
|
||||
|
@ -4884,5 +4933,114 @@ fn foo() {
|
|||
RunTest(src, true);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// Tests for the quality of the error messages produced by the analysis.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST_F(UniformityAnalysisTest, Error_CallUserThatCallsBuiltinDirectly) {
|
||||
auto src = R"(
|
||||
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
|
||||
|
||||
fn foo() {
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
fn main() {
|
||||
if (non_uniform == 42) {
|
||||
foo();
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
RunTest(src, false);
|
||||
EXPECT_EQ(error_,
|
||||
R"(test:10:5 warning: 'foo' must only be called from uniform control flow
|
||||
foo();
|
||||
^^^
|
||||
|
||||
test:5:3 note: 'foo' requires uniformity because it calls workgroupBarrier
|
||||
workgroupBarrier();
|
||||
^^^^^^^^^^^^^^^^
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(UniformityAnalysisTest, Error_CallUserThatCallsBuiltinIndirectly) {
|
||||
auto src = R"(
|
||||
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
|
||||
|
||||
fn zoo() {
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
fn bar() {
|
||||
zoo();
|
||||
}
|
||||
|
||||
fn foo() {
|
||||
bar();
|
||||
}
|
||||
|
||||
fn main() {
|
||||
if (non_uniform == 42) {
|
||||
foo();
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
RunTest(src, false);
|
||||
EXPECT_EQ(error_,
|
||||
R"(test:18:5 warning: 'foo' must only be called from uniform control flow
|
||||
foo();
|
||||
^^^
|
||||
|
||||
test:5:3 note: 'foo' requires uniformity because it indirectly calls workgroupBarrier
|
||||
workgroupBarrier();
|
||||
^^^^^^^^^^^^^^^^
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(UniformityAnalysisTest, Error_ParametersRequireUniformityInChain) {
|
||||
auto src = R"(
|
||||
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
|
||||
|
||||
fn zoo(a : i32) {
|
||||
if (a == 42) {
|
||||
workgroupBarrier();
|
||||
}
|
||||
}
|
||||
|
||||
fn bar(b : i32) {
|
||||
zoo(b);
|
||||
}
|
||||
|
||||
fn foo(c : i32) {
|
||||
bar(c);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
foo(non_uniform);
|
||||
}
|
||||
)";
|
||||
|
||||
RunTest(src, false);
|
||||
EXPECT_EQ(error_,
|
||||
R"(test:19:7 warning: parameter 'c' of 'foo' must be uniform
|
||||
foo(non_uniform);
|
||||
^^^^^^^^^^^
|
||||
|
||||
test:15:7 note: parameter 'b' of 'bar' must be uniform
|
||||
bar(c);
|
||||
^
|
||||
|
||||
test:11:7 note: parameter 'a' of 'zoo' must be uniform
|
||||
zoo(b);
|
||||
^
|
||||
|
||||
test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow
|
||||
workgroupBarrier();
|
||||
^^^^^^^^^^^^^^^^
|
||||
)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::resolver
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
|
||||
note: 'workgroupBarrier' must only be called from uniform control flow
|
||||
#version 310 es
|
||||
|
||||
struct Uniforms {
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
|
||||
note: 'workgroupBarrier' must only be called from uniform control flow
|
||||
static int dimAOuter_1 = 0;
|
||||
cbuffer cbuffer_x_48 : register(b3, space0) {
|
||||
uint4 x_48[5];
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
|
||||
note: 'workgroupBarrier' must only be called from uniform control flow
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
|
||||
note: 'workgroupBarrier' must only be called from uniform control flow
|
||||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Google Tint Compiler; 0
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
|
||||
note: 'workgroupBarrier' must only be called from uniform control flow
|
||||
struct Uniforms {
|
||||
NAN : f32,
|
||||
@size(12)
|
||||
|
|
Loading…
Reference in New Issue