tint/uniformity: Rework generation of diagnostics

Flip the diagnostics so that the trigger location is on the builtin
that requires uniformity.

We also now show the place at which control flow diverges regardless
of where it is in the function call stack.

Change-Id: Id739a137b9011c900649b74165a6600a95d87ca4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/116691
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2023-01-13 17:15:32 +00:00 committed by Dawn LUCI CQ
parent 48a49f3730
commit dd54f74de1
2 changed files with 406 additions and 333 deletions

View File

@ -15,6 +15,7 @@
#include "src/tint/resolver/uniformity.h" #include "src/tint/resolver/uniformity.h"
#include <limits> #include <limits>
#include <sstream>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -1651,9 +1652,9 @@ class UniformityGraph {
/// @param function the function being analyzed /// @param function the function being analyzed
/// @param required_to_be_uniform the node to traverse from /// @param required_to_be_uniform the node to traverse from
/// @param may_be_non_uniform the node to traverse to /// @param may_be_non_uniform the node to traverse to
void ShowCauseOfNonUniformity(FunctionInfo& function, void ShowControlFlowDivergence(FunctionInfo& function,
Node* required_to_be_uniform, Node* required_to_be_uniform,
Node* may_be_non_uniform) { Node* may_be_non_uniform) {
// Traverse the graph to generate a path from the node to the source of non-uniformity. // Traverse the graph to generate a path from the node to the source of non-uniformity.
function.ResetVisited(); function.ResetVisited();
Traverse(required_to_be_uniform); Traverse(required_to_be_uniform);
@ -1667,7 +1668,7 @@ class UniformityGraph {
non_uniform_source, [](Node* node) { return node->affects_control_flow; }); non_uniform_source, [](Node* node) { return node->affects_control_flow; });
if (control_flow) { if (control_flow) {
diagnostics_.add_note(diag::System::Resolver, diagnostics_.add_note(diag::System::Resolver,
"control flow depends on non-uniform value", "control flow depends on possibly non-uniform value",
control_flow->ast->source); control_flow->ast->source);
// TODO(jrprice): There are cases where the function with uniformity requirements is not // TODO(jrprice): There are cases where the function with uniformity requirements is not
// actually inside this control flow construct, for example: // actually inside this control flow construct, for example:
@ -1677,7 +1678,15 @@ class UniformityGraph {
// the actual cause of divergence. // the actual cause of divergence.
} }
auto get_var_type = [&](const sem::Variable* var) { ShowSourceOfNonUniformity(non_uniform_source);
}
/// Add a diagnostic note to show the origin of a non-uniform value.
/// @param non_uniform_source the node that represents a non-uniform value
void ShowSourceOfNonUniformity(Node* non_uniform_source) {
TINT_ASSERT(Resolver, non_uniform_source);
auto var_type = [&](const sem::Variable* var) {
switch (var->AddressSpace()) { switch (var->AddressSpace()) {
case ast::AddressSpace::kStorage: case ast::AddressSpace::kStorage:
return "read_write storage buffer "; return "read_write storage buffer ";
@ -1686,17 +1695,18 @@ class UniformityGraph {
case ast::AddressSpace::kPrivate: case ast::AddressSpace::kPrivate:
return "module-scope private variable "; return "module-scope private variable ";
default: default:
if (ast::HasAttribute<ast::BuiltinAttribute>(var->Declaration()->attributes)) { return "";
return "builtin "; }
} else if (ast::HasAttribute<ast::LocationAttribute>( };
var->Declaration()->attributes)) { auto param_type = [&](const sem::Parameter* param) {
return "user-defined input "; if (ast::HasAttribute<ast::BuiltinAttribute>(param->Declaration()->attributes)) {
} else { return "builtin ";
// TODO(jrprice): Provide more info for this case. } else if (ast::HasAttribute<ast::LocationAttribute>(
} param->Declaration()->attributes)) {
break; return "user-defined input ";
} else {
return "parameter ";
} }
return "";
}; };
// Show the source of the non-uniform value. // Show the source of the non-uniform value.
@ -1704,19 +1714,23 @@ class UniformityGraph {
non_uniform_source->ast, non_uniform_source->ast,
[&](const ast::IdentifierExpression* ident) { [&](const ast::IdentifierExpression* ident) {
auto* var = sem_.Get(ident)->UnwrapLoad()->As<sem::VariableUser>()->Variable(); auto* var = sem_.Get(ident)->UnwrapLoad()->As<sem::VariableUser>()->Variable();
std::string var_type = get_var_type(var); std::ostringstream ss;
diagnostics_.add_note(diag::System::Resolver, if (auto* param = var->As<sem::Parameter>()) {
"reading from " + var_type + "'" + NameFor(ident) + auto* func = param->Owner()->As<sem::Function>();
"' may result in a non-uniform value", ss << param_type(param) << "'" << NameFor(ident) << "' of '"
ident->source); << NameFor(func->Declaration()) << "' may be non-uniform";
} else {
ss << "reading from " << var_type(var) << "'" << NameFor(ident)
<< "' may result in a non-uniform value";
}
diagnostics_.add_note(diag::System::Resolver, ss.str(), ident->source);
}, },
[&](const ast::Variable* v) { [&](const ast::Variable* v) {
auto* var = sem_.Get(v); auto* var = sem_.Get(v);
std::string var_type = get_var_type(var); std::ostringstream ss;
diagnostics_.add_note(diag::System::Resolver, ss << "reading from " << var_type(var) << "'" << NameFor(v)
"reading from " + var_type + "'" + NameFor(v) + << "' may result in a non-uniform value";
"' may result in a non-uniform value", diagnostics_.add_note(diag::System::Resolver, ss.str(), v->source);
v->source);
}, },
[&](const ast::CallExpression* c) { [&](const ast::CallExpression* c) {
auto target_name = NameFor(c->target.name); auto target_name = NameFor(c->target.name);
@ -1730,11 +1744,10 @@ class UniformityGraph {
case Node::kFunctionCallArgumentContents: { case Node::kFunctionCallArgumentContents: {
auto* arg = c->args[non_uniform_source->arg_index]; auto* arg = c->args[non_uniform_source->arg_index];
auto* var = sem_.Get(arg)->RootIdentifier(); auto* var = sem_.Get(arg)->RootIdentifier();
std::string var_type = get_var_type(var); std::ostringstream ss;
diagnostics_.add_note(diag::System::Resolver, ss << "reading from " << var_type(var) << "'" << NameFor(var->Declaration())
"reading from " + var_type + "'" + << "' may result in a non-uniform value";
NameFor(var->Declaration()) + diagnostics_.add_note(diag::System::Resolver, ss.str(),
"' may result in a non-uniform value",
var->Declaration()->source); var->Declaration()->source);
break; break;
} }
@ -1750,7 +1763,7 @@ class UniformityGraph {
case Node::kFunctionCallPointerArgumentResult: { case Node::kFunctionCallPointerArgumentResult: {
diagnostics_.add_note( diagnostics_.add_note(
diag::System::Resolver, diag::System::Resolver,
"pointer contents may become non-uniform after calling '" + "contents of pointer may become non-uniform after calling '" +
target_name + "'", target_name + "'",
c->args[non_uniform_source->arg_index]->source); c->args[non_uniform_source->arg_index]->source);
break; break;
@ -1773,11 +1786,9 @@ class UniformityGraph {
/// Generate an error message for a uniformity issue. /// Generate an error message for a uniformity issue.
/// @param function the function that the diagnostic is being produced for /// @param function the function that the diagnostic is being produced for
/// @param source_node the node that has caused a uniformity issue in `function` /// @param source_node the node that has caused a uniformity issue in `function`
/// @param note `true` if the diagnostic should be emitted as a note void MakeError(FunctionInfo& function, Node* source_node) {
void MakeError(FunctionInfo& function, Node* source_node, bool note = false) { // Helper to produce a diagnostic message, as a note or with the global failure severity.
// Helper to produce a diagnostic message with the severity required by this invocation of auto report = [&](Source source, std::string msg, bool note) {
// the `MakeError` function.
auto report = [&](Source source, std::string msg) {
diag::Diagnostic error{}; diag::Diagnostic error{};
auto failureSeverity = auto failureSeverity =
kUniformityFailuresAsError ? diag::Severity::Error : diag::Severity::Warning; kUniformityFailuresAsError ? diag::Severity::Error : diag::Severity::Warning;
@ -1802,77 +1813,54 @@ class UniformityGraph {
auto* call = cause->ast->As<ast::CallExpression>(); auto* call = cause->ast->As<ast::CallExpression>();
TINT_ASSERT(Resolver, call); TINT_ASSERT(Resolver, call);
auto* target = SemCall(call)->Target(); auto* target = SemCall(call)->Target();
auto func_name = NameFor(call->target.name);
std::string func_name; if (cause->type == Node::kFunctionCallArgumentValue ||
if (auto* builtin = target->As<sem::Builtin>()) { cause->type == Node::kFunctionCallArgumentContents) {
func_name = builtin->str(); bool is_value = (cause->type == Node::kFunctionCallArgumentValue);
} else if (auto* user = target->As<sem::Function>()) {
func_name = NameFor(user->Declaration());
}
if (cause->type == Node::kFunctionCallArgumentValue) { auto* user_func = target->As<sem::Function>();
// The requirement was on a function parameter. if (user_func) {
auto* ast_param = target->Parameters()[cause->arg_index]->Declaration(); // Recurse into the called function to show the reason for the requirement.
std::string param_name; auto next_function = functions_.Find(user_func->Declaration());
if (ast_param) { auto& param_info = next_function->parameters[cause->arg_index];
param_name = " '" + NameFor(ast_param) + "'"; MakeError(*next_function,
is_value ? param_info.value : param_info.ptr_input_contents);
} }
report(call->args[cause->arg_index]->source,
"parameter" + param_name + " of '" + func_name + "' must be uniform");
// If this is a call to a user-defined function, add a note to show the reason that the // Show the place where the non-uniform argument was passed.
// parameter is required to be uniform. // If this is a builtin, this will be the trigger location for the failure.
if (auto* user = target->As<sem::Function>()) { std::ostringstream ss;
auto next_function = functions_.Find(user->Declaration()); ss << "possibly non-uniform value passed" << (is_value ? "" : " via pointer")
Node* next_cause = next_function->parameters[cause->arg_index].value; << " here";
MakeError(*next_function, next_cause, true); report(call->args[cause->arg_index]->source, ss.str(), /* note */ user_func != nullptr);
}
} else if (cause->type == Node::kFunctionCallArgumentContents) {
// The requirement was on the contents of a function parameter.
auto param_name = NameFor(target->Parameters()[cause->arg_index]->Declaration());
report(call->args[cause->arg_index]->source, "contents of parameter '" + param_name +
"' of '" + func_name +
"' must be uniform");
// If this is a call to a user-defined function, add a note to show the reason that the // Show the origin of non-uniformity for the value or data that is being passed.
// parameter is required to be uniform. ShowSourceOfNonUniformity(source_node->visited_from);
if (auto* user = target->As<sem::Function>()) {
auto next_function = functions_.Find(user->Declaration());
Node* next_cause = next_function->parameters[cause->arg_index].ptr_input_contents;
MakeError(*next_function, next_cause, true);
}
} else { } else {
// The requirement was on a function callsite. auto* builtin_call = FindBuiltinThatRequiresUniformity(call);
report(call->source, {
"'" + func_name + "' must only be called from uniform control flow"); // Show a builtin was reachable from this call (which may be the call itself).
// This will be the trigger location for the failure.
// If this is a call to a user-defined function, add a note to show the builtin that std::ostringstream ss;
// causes the uniformity requirement. ss << "'" << NameFor(builtin_call->target.name)
auto* innermost_call = FindBuiltinThatRequiresUniformity(call); << "' must only be called from uniform control flow";
if (innermost_call != call) { report(builtin_call->source, ss.str(), /* note */ false);
auto* sem_call = SemCall(call);
auto* sem_innermost_call = SemCall(innermost_call);
// Determine whether the builtin is being called directly or indirectly.
bool indirect = false;
if (sem_call->Target()->As<sem::Function>() !=
sem_innermost_call->Stmt()->Function()) {
indirect = true;
}
auto* builtin = sem_innermost_call->Target()->As<sem::Builtin>();
diagnostics_.add_note(diag::System::Resolver,
"'" + func_name + "' requires uniformity because it " +
(indirect ? "indirectly " : "") + "calls " +
builtin->str(),
innermost_call->source);
} }
}
// Show the cause of non-uniformity (starting at the top-level error). if (builtin_call != call) {
if (!note) { // The call was to a user function, so show that call too.
ShowCauseOfNonUniformity(function, function.required_to_be_uniform, std::ostringstream ss;
function.may_be_non_uniform); ss << "called ";
if (target->As<sem::Function>() != SemCall(builtin_call)->Stmt()->Function()) {
ss << "indirectly ";
}
ss << "by '" << func_name << "' from '" << function.name << "'";
report(call->source, ss.str(), /* note */ true);
}
// Show the point at which control-flow depends on a non-uniform value.
ShowControlFlowDivergence(function, cause, source_node);
} }
} }

File diff suppressed because it is too large Load Diff