tint: Show the reason for a uniformity requirement

When producing an error from the uniformity analysis, add notes to
show the underlying reason for the uniformity requirement.

For function calls that are required-to-be-uniform, show the innermost
builtin call that has the requirement.

For function parameters that are required-to-be-uniform, recurse into
that function to show where its requirement comes from.

Add some new tests to specifically test the error messages.

Bug: tint:880
Change-Id: Ib166fdeceaffb156a3afc50ebc5a4ad0860dc002
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/89722
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2022-05-11 22:05:15 +00:00
parent 874b61f1ba
commit 9c03abfb55
7 changed files with 264 additions and 34 deletions

View File

@ -321,9 +321,9 @@ class UniformityGraph {
// Look at which nodes are reachable from "RequiredToBeUniform".
{
utils::UniqueVector<Node*> reachable;
Traverse(current_function_->required_to_be_uniform, reachable);
Traverse(current_function_->required_to_be_uniform, &reachable);
if (reachable.contains(current_function_->may_be_non_uniform)) {
MakeError();
MakeError(*current_function_, current_function_->may_be_non_uniform);
return false;
}
if (reachable.contains(current_function_->cf_start)) {
@ -343,7 +343,7 @@ class UniformityGraph {
// Look at which nodes are reachable from "CF_return"
{
utils::UniqueVector<Node*> reachable;
Traverse(current_function_->cf_return, reachable);
Traverse(current_function_->cf_return, &reachable);
if (reachable.contains(current_function_->may_be_non_uniform)) {
current_function_->function_tag = SubsequentControlFlowMayBeNonUniform;
}
@ -362,7 +362,7 @@ class UniformityGraph {
// If "Value_return" exists, look at which nodes are reachable from it
if (current_function_->value_return) {
utils::UniqueVector<Node*> reachable;
Traverse(current_function_->value_return, reachable);
Traverse(current_function_->value_return, &reachable);
if (reachable.contains(current_function_->may_be_non_uniform)) {
current_function_->function_tag = ReturnValueMayBeNonUniform;
}
@ -388,7 +388,7 @@ class UniformityGraph {
current_function_->ResetVisited();
utils::UniqueVector<Node*> reachable;
Traverse(current_function_->parameters[i].pointer_return_value, reachable);
Traverse(current_function_->parameters[i].pointer_return_value, &reachable);
if (reachable.contains(current_function_->may_be_non_uniform)) {
current_function_->parameters[i].pointer_may_become_non_uniform = true;
}
@ -1234,9 +1234,11 @@ class UniformityGraph {
/// Recursively traverse a graph starting at `node`, inserting all nodes that are reached into
/// `reachable`.
/// @param node the starting node
/// @param reachable the set of reachable nodes to populate
void Traverse(Node* node, utils::UniqueVector<Node*>& reachable) {
reachable.add(node);
/// @param reachable the set of reachable nodes to populate, if required
void Traverse(Node* node, utils::UniqueVector<Node*>* reachable = nullptr) {
if (reachable) {
reachable->add(node);
}
for (auto* to : node->edges) {
if (to->visited_from == nullptr) {
to->visited_from = node;
@ -1245,48 +1247,113 @@ class UniformityGraph {
}
}
/// Generate an error for a required_to_be_uniform->may_be_non_uniform path.
void MakeError() {
// Trace back to find a node that is required to be uniform that was reachable from a
// non-uniform value or control flow node.
Node* current = current_function_->may_be_non_uniform;
while (current) {
TINT_ASSERT(Resolver, current->visited_from);
if (current->visited_from == current_function_->required_to_be_uniform) {
break;
/// Recursively descend through the function called by `call` and the functions that it calls in
/// order to find a call to a builtin function that requires uniformity.
const ast::CallExpression* FindBuiltinThatRequiresUniformity(const ast::CallExpression* call) {
auto* target = sem_.Get(call)->Target();
if (target->Is<sem::Builtin>()) {
// This is a call to a builtin, so we must be done.
return call;
} else if (auto* user = target->As<sem::Function>()) {
// This is a call to a user-defined function, so inspect the functions called by that
// function and look for one whose node has an edge from the RequiredToBeUniform node.
auto& target_info = functions_.at(user->Declaration());
for (auto* call_node : target_info.required_to_be_uniform->edges) {
if (call_node->arg_index == std::numeric_limits<uint32_t>::max()) {
auto* child_call = call_node->ast->As<ast::CallExpression>();
return FindBuiltinThatRequiresUniformity(child_call);
}
current = current->visited_from;
}
TINT_ASSERT(Resolver, false && "unable to find child call with uniformity requirement");
} else {
TINT_ASSERT(Resolver, false && "unexpected call expression type");
}
return nullptr;
}
// The node will always have an corresponding call expression.
auto* call = current->ast->As<ast::CallExpression>();
/// Generate an error message for a uniformity issue.
/// @param function the function that the diagnostic is being produced for
/// @param source_node the node that has caused a uniformity issue in `function`
/// @param note `true` if the diagnostic should be emitted as a note
void MakeError(FunctionInfo& function, Node* source_node, bool note = false) {
// Helper to produce a diagnostic message with the severity required by this invocation of
// the `MakeError` function.
auto report = [&](Source source, std::string msg) {
// TODO(jrprice): Switch to error instead of warning when feedback has settled.
diag::Diagnostic error{};
error.severity = note ? diag::Severity::Note : diag::Severity::Warning;
error.system = diag::System::Resolver;
error.source = source;
error.message = msg;
diagnostics_.add(std::move(error));
};
// Traverse the graph to generate a path from RequiredToBeUniform to the source node.
function.ResetVisited();
Traverse(function.required_to_be_uniform);
TINT_ASSERT(Resolver, source_node->visited_from);
// Trace back through the graph to find a node that is required to be uniform that has
// a path to the source node.
Node* cause = source_node;
while (cause) {
if (cause->visited_from == function.required_to_be_uniform) {
break;
}
cause = cause->visited_from;
}
// The node will always have a corresponding call expression.
auto* call = cause->ast->As<ast::CallExpression>();
TINT_ASSERT(Resolver, call);
auto* target = sem_.Get(call)->Target();
std::string name;
std::string func_name;
if (auto* builtin = target->As<sem::Builtin>()) {
name = builtin->str();
func_name = builtin->str();
} else if (auto* user = target->As<sem::Function>()) {
name = builder_->Symbols().NameFor(user->Declaration()->symbol);
func_name = builder_->Symbols().NameFor(user->Declaration()->symbol);
}
// TODO(jrprice): Switch to error instead of warning when feedback has settled.
if (current->arg_index != std::numeric_limits<uint32_t>::max()) {
if (cause->arg_index != std::numeric_limits<uint32_t>::max()) {
// The requirement was on a function parameter.
auto param_name = builder_->Symbols().NameFor(
target->Parameters()[current->arg_index]->Declaration()->symbol);
diagnostics_.add_warning(
diag::System::Resolver,
"parameter '" + param_name + "' of '" + name + "' must be uniform",
call->args[current->arg_index]->source);
// TODO(jrprice): Show the reason why.
target->Parameters()[cause->arg_index]->Declaration()->symbol);
report(call->args[cause->arg_index]->source,
"parameter '" + param_name + "' of '" + func_name + "' must be uniform");
// If this is a call to a user-defined function, add a note to show the reason that the
// parameter is required to be uniform.
if (auto* user = target->As<sem::Function>()) {
auto& next_function = functions_.at(user->Declaration());
Node* next_cause = next_function.parameters[cause->arg_index].init_value;
MakeError(next_function, next_cause, true);
}
} else {
// The requirement was on a function callsite.
diagnostics_.add_warning(diag::System::Resolver,
"'" + name + "' must only be called from uniform control flow",
call->source);
// TODO(jrprice): Show full call stack to the problematic builtin.
report(call->source,
"'" + func_name + "' must only be called from uniform control flow");
// If this is a call to a user-defined function, add a note to show the builtin that
// causes the uniformity requirement.
auto* innermost_call = FindBuiltinThatRequiresUniformity(call);
if (innermost_call != call) {
// Determine whether the builtin is being called directly or indirectly.
bool indirect = false;
if (sem_.Get(call)->Target()->As<sem::Function>() !=
sem_.Get(innermost_call)->Stmt()->Function()) {
indirect = true;
}
auto* builtin = sem_.Get(innermost_call)->Target()->As<sem::Builtin>();
diagnostics_.add_note(diag::System::Resolver,
"'" + func_name + "' requires uniformity because it " +
(indirect ? "indirectly " : "") + "calls " +
builtin->str(),
innermost_call->source);
}
}
// TODO(jrprice): Show the source of non-uniformity.
}
};

View File

@ -466,6 +466,10 @@ fn bar() {
R"(test:11:7 warning: parameter 'i' of 'foo' must be uniform
foo(rw);
^^
test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow
workgroupBarrier();
^^^^^^^^^^^^^^^^
)");
}
@ -3229,6 +3233,34 @@ fn foo() {
)");
}
TEST_F(UniformityAnalysisTest, LoadNonUniformThroughPointerParameter) {
auto src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn bar(p : ptr<function, i32>) {
if (*p == 0) {
workgroupBarrier();
}
}
fn foo() {
var v = non_uniform;
bar(&v);
}
)";
RunTest(src, false);
EXPECT_EQ(error_,
R"(test:12:7 warning: parameter 'p' of 'bar' must be uniform
bar(&v);
^
test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow
workgroupBarrier();
^^^^^^^^^^^^^^^^
)");
}
TEST_F(UniformityAnalysisTest, LoadUniformThroughPointer) {
auto src = R"(
fn foo() {
@ -3256,6 +3288,23 @@ fn foo() {
RunTest(src, true);
}
TEST_F(UniformityAnalysisTest, LoadUniformThroughPointerParameter) {
auto src = R"(
fn bar(p : ptr<function, i32>) {
if (*p == 0) {
workgroupBarrier();
}
}
fn foo() {
var v = 42;
bar(&v);
}
)";
RunTest(src, true);
}
TEST_F(UniformityAnalysisTest, StoreNonUniformAfterCapturingPointer) {
auto src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
@ -4884,5 +4933,114 @@ fn foo() {
RunTest(src, true);
}
////////////////////////////////////////////////////////////////////////////////
/// Tests for the quality of the error messages produced by the analysis.
////////////////////////////////////////////////////////////////////////////////
TEST_F(UniformityAnalysisTest, Error_CallUserThatCallsBuiltinDirectly) {
auto src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn foo() {
workgroupBarrier();
}
fn main() {
if (non_uniform == 42) {
foo();
}
}
)";
RunTest(src, false);
EXPECT_EQ(error_,
R"(test:10:5 warning: 'foo' must only be called from uniform control flow
foo();
^^^
test:5:3 note: 'foo' requires uniformity because it calls workgroupBarrier
workgroupBarrier();
^^^^^^^^^^^^^^^^
)");
}
TEST_F(UniformityAnalysisTest, Error_CallUserThatCallsBuiltinIndirectly) {
auto src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn zoo() {
workgroupBarrier();
}
fn bar() {
zoo();
}
fn foo() {
bar();
}
fn main() {
if (non_uniform == 42) {
foo();
}
}
)";
RunTest(src, false);
EXPECT_EQ(error_,
R"(test:18:5 warning: 'foo' must only be called from uniform control flow
foo();
^^^
test:5:3 note: 'foo' requires uniformity because it indirectly calls workgroupBarrier
workgroupBarrier();
^^^^^^^^^^^^^^^^
)");
}
TEST_F(UniformityAnalysisTest, Error_ParametersRequireUniformityInChain) {
auto src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn zoo(a : i32) {
if (a == 42) {
workgroupBarrier();
}
}
fn bar(b : i32) {
zoo(b);
}
fn foo(c : i32) {
bar(c);
}
fn main() {
foo(non_uniform);
}
)";
RunTest(src, false);
EXPECT_EQ(error_,
R"(test:19:7 warning: parameter 'c' of 'foo' must be uniform
foo(non_uniform);
^^^^^^^^^^^
test:15:7 note: parameter 'b' of 'bar' must be uniform
bar(c);
^
test:11:7 note: parameter 'a' of 'zoo' must be uniform
zoo(b);
^
test:6:5 note: 'workgroupBarrier' must only be called from uniform control flow
workgroupBarrier();
^^^^^^^^^^^^^^^^
)");
}
} // namespace
} // namespace tint::resolver

View File

@ -1,4 +1,5 @@
warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
note: 'workgroupBarrier' must only be called from uniform control flow
#version 310 es
struct Uniforms {

View File

@ -1,4 +1,5 @@
warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
note: 'workgroupBarrier' must only be called from uniform control flow
static int dimAOuter_1 = 0;
cbuffer cbuffer_x_48 : register(b3, space0) {
uint4 x_48[5];

View File

@ -1,4 +1,5 @@
warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
note: 'workgroupBarrier' must only be called from uniform control flow
#include <metal_stdlib>
using namespace metal;

View File

@ -1,4 +1,5 @@
warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
note: 'workgroupBarrier' must only be called from uniform control flow
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0

View File

@ -1,4 +1,5 @@
warning: parameter 'dimInner' of 'mm_matMul_i1_i1_i1_' must be uniform
note: 'workgroupBarrier' must only be called from uniform control flow
struct Uniforms {
NAN : f32,
@size(12)