tint/uniformity: Add derivative_uniformity filter

Add additional `RequiredToBeUniform` nodes for each severity
level. When processing a call to a derivative builtin, look up the
severity from the semantic info for that AST node, and add an edge to
the corresponding `RequiredToBeUniform` node.

Propagate the severities to the callsite and parameter tags for a
function that contains a builtin.

Traverse that graph from each `RequiredToBeUniform` node to look for
violations at each severity level, starting with the most severe. Only
stop the analysis if an error is found, otherwise report the violation
and keep going.

Bug: tint:1809
Change-Id: I4ac838e85da3f4fb3d63f4892dce7f12b096f74b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/117602
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
James Price 2023-01-25 01:24:46 +00:00
parent 5853205342
commit e657c470bd
9 changed files with 428 additions and 80 deletions

View File

@ -93,6 +93,9 @@ DiagnosticRule ParseDiagnosticRule(std::string_view str) {
if (str == "chromium_unreachable_code") {
return DiagnosticRule::kChromiumUnreachableCode;
}
if (str == "derivative_uniformity") {
return DiagnosticRule::kDerivativeUniformity;
}
return DiagnosticRule::kUndefined;
}
@ -102,6 +105,8 @@ std::ostream& operator<<(std::ostream& out, DiagnosticRule value) {
return out << "undefined";
case DiagnosticRule::kChromiumUnreachableCode:
return out << "chromium_unreachable_code";
case DiagnosticRule::kDerivativeUniformity:
return out << "derivative_uniformity";
}
return out << "<unknown>";
}

View File

@ -66,6 +66,7 @@ constexpr const char* kDiagnosticSeverityStrings[] = {
enum class DiagnosticRule {
kUndefined,
kChromiumUnreachableCode,
kDerivativeUniformity,
};
/// @param out the std::ostream to write to
@ -80,6 +81,7 @@ DiagnosticRule ParseDiagnosticRule(std::string_view str);
constexpr const char* kDiagnosticRuleStrings[] = {
"chromium_unreachable_code",
"derivative_uniformity",
};
/// Convert a DiagnosticSeverity to the corresponding diag::Severity.

View File

@ -50,7 +50,9 @@ void DiagnosticRuleParser(::benchmark::State& state) {
std::array kStrings{
"hromium_unyeachable_code", "chrorrillmGunnreachable_c77de", "chromium_unreachable4cod00",
"chromium_unreachable_code", "chromium_unracaboo_code", "chromium_unrzzchabl_code",
"ciipp11ium_unreachable_cod",
"ciipp11ium_unreachable_cod", "derivXXtive_uniformity", "55erivativeIIunifonn99ity",
"derirratHHaae_YniforSSity", "derivative_uniformity", "erivtive_unHkkormit",
"jerivaive_uniforRgty", "derivatbve_unformiy",
};
for (auto _ : state) {
for (auto& str : kStrings) {

View File

@ -115,12 +115,16 @@ inline std::ostream& operator<<(std::ostream& out, Case c) {
static constexpr Case kValidCases[] = {
{"chromium_unreachable_code", DiagnosticRule::kChromiumUnreachableCode},
{"derivative_uniformity", DiagnosticRule::kDerivativeUniformity},
};
static constexpr Case kInvalidCases[] = {
{"cXromggum_unreachable_cde", DiagnosticRule::kUndefined},
{"chroVium_unruchble_codX", DiagnosticRule::kUndefined},
{"chromium_3nreachable_code", DiagnosticRule::kUndefined},
{"derivatEve_uniformity", DiagnosticRule::kUndefined},
{"deTTPivative_uniformit", DiagnosticRule::kUndefined},
{"derivtive_uddxxformity", DiagnosticRule::kUndefined},
};
using DiagnosticRuleParseTest = testing::TestWithParam<Case>;

View File

@ -42,6 +42,8 @@ enum builtin_value {
// https://gpuweb.github.io/gpuweb/wgsl/#filterable-triggering-rules
enum diagnostic_rule {
// Rules defined in the spec.
derivative_uniformity
// Chromium specific rules not defined in the spec.
chromium_unreachable_code
}

View File

@ -142,6 +142,17 @@ bool Resolver::Resolve() {
ApplyDiagnosticSeverities(mod);
builder_->Sem().SetModule(mod);
if (result) {
// Run the uniformity analysis, which requires a complete semantic module.
if (!enabled_extensions_.Contains(ast::Extension::kChromiumDisableUniformityAnalysis)) {
if (!AnalyzeUniformity(builder_, dependencies_)) {
if (kUniformityFailuresAsError) {
return false;
}
}
}
}
return result;
}
@ -182,14 +193,6 @@ bool Resolver::ResolveInternal() {
return false;
}
if (!enabled_extensions_.Contains(ast::Extension::kChromiumDisableUniformityAnalysis)) {
if (!AnalyzeUniformity(builder_, dependencies_)) {
if (kUniformityFailuresAsError) {
return false;
}
}
}
bool result = true;
for (auto* node : builder_->ASTNodes().Objects()) {
if (TINT_UNLIKELY(!marked_[node->node_id.value])) {

View File

@ -66,9 +66,12 @@ const ast::Expression* UnwrapIndirectAndAddressOfChain(const ast::UnaryOpExpress
}
/// CallSiteTag describes the uniformity requirements on the call sites of a function.
enum CallSiteTag {
CallSiteRequiredToBeUniform,
CallSiteNoRestriction,
struct CallSiteTag {
enum {
CallSiteRequiredToBeUniform,
CallSiteNoRestriction,
} tag;
ast::DiagnosticSeverity severity = ast::DiagnosticSeverity::kUndefined;
};
/// FunctionTag describes a functions effects on uniformity.
@ -78,10 +81,13 @@ enum FunctionTag {
};
/// ParameterTag describes the uniformity requirements of values passed to a function parameter.
enum ParameterTag {
ParameterValueRequiredToBeUniform,
ParameterContentsRequiredToBeUniform,
ParameterNoRestriction,
struct ParameterTag {
enum {
ParameterValueRequiredToBeUniform,
ParameterContentsRequiredToBeUniform,
ParameterNoRestriction,
} tag;
ast::DiagnosticSeverity severity = ast::DiagnosticSeverity::kUndefined;
};
/// Node represents a node in the graph of control flow and value nodes within the analysis of a
@ -138,9 +144,9 @@ struct ParameterInfo {
/// The semantic node in corresponds to this parameter.
const sem::Parameter* sem;
/// The parameter's direct uniformity requirements.
ParameterTag tag_direct = ParameterNoRestriction;
ParameterTag tag_direct = {ParameterTag::ParameterNoRestriction};
/// The parameter's uniformity requirements that affect the function return value.
ParameterTag tag_retval = ParameterNoRestriction;
ParameterTag tag_retval = {ParameterTag::ParameterNoRestriction};
/// Will be `true` if this function may cause the contents of this pointer parameter to become
/// non-uniform.
bool pointer_may_become_non_uniform = false;
@ -166,11 +172,13 @@ struct FunctionInfo {
/// @param builder the program builder
FunctionInfo(const ast::Function* func, const ProgramBuilder* builder) {
name = builder->Symbols().NameFor(func->symbol);
callsite_tag = CallSiteNoRestriction;
callsite_tag = {CallSiteTag::CallSiteNoRestriction};
function_tag = NoRestriction;
// Create special nodes.
required_to_be_uniform = CreateNode({"RequiredToBeUniform"});
required_to_be_uniform_error = CreateNode({"RequiredToBeUniform_Error"});
required_to_be_uniform_warning = CreateNode({"RequiredToBeUniform_Warning"});
required_to_be_uniform_info = CreateNode({"RequiredToBeUniform_Info"});
may_be_non_uniform = CreateNode({"MayBeNonUniform"});
cf_start = CreateNode({"CF_start"});
if (func->return_type) {
@ -214,14 +222,16 @@ struct FunctionInfo {
/// The control flow graph.
utils::BlockAllocator<Node> nodes;
/// Special `RequiredToBeUniform` node.
Node* required_to_be_uniform;
/// Special `RequiredToBeUniform` nodes.
Node* required_to_be_uniform_error = nullptr;
Node* required_to_be_uniform_warning = nullptr;
Node* required_to_be_uniform_info = nullptr;
/// Special `MayBeNonUniform` node.
Node* may_be_non_uniform;
Node* may_be_non_uniform = nullptr;
/// Special `CF_start` node.
Node* cf_start;
Node* cf_start = nullptr;
/// Special `Value_return` node.
Node* value_return;
Node* value_return = nullptr;
/// Map from variables to their value nodes in the graph, scoped with respect to control flow.
ScopeStack<const sem::Variable*, Node*> variables;
@ -246,6 +256,21 @@ struct FunctionInfo {
utils::Hashmap<const sem::Variable*, Node*, 4> var_exit_nodes;
};
/// @returns the RequiredToBeUniform node that corresponds to `severity`
Node* RequiredToBeUniform(ast::DiagnosticSeverity severity) {
switch (severity) {
case ast::DiagnosticSeverity::kError:
return required_to_be_uniform_error;
case ast::DiagnosticSeverity::kWarning:
return required_to_be_uniform_warning;
case ast::DiagnosticSeverity::kInfo:
return required_to_be_uniform_info;
default:
TINT_ASSERT(Resolver, false && "unhandled severity");
return nullptr;
}
}
/// @returns a LoopSwitchInfo for the given statement, allocating the LoopSwitchInfo if this is
/// the first call with the given statement.
LoopSwitchInfo& LoopSwitchInfoFor(const sem::Statement* stmt) {
@ -401,32 +426,49 @@ class UniformityGraph {
// For pointers, we distinguish between requiring uniformity of the contents versus
// the pointer itself.
if (reachable.Contains(param_info.ptr_input_contents)) {
return ParameterContentsRequiredToBeUniform;
return ParameterTag::ParameterContentsRequiredToBeUniform;
} else if (reachable.Contains(param_info.value)) {
return ParameterValueRequiredToBeUniform;
return ParameterTag::ParameterValueRequiredToBeUniform;
}
} else if (reachable.Contains(current_function_->variables.Get(param))) {
// For non-pointers, the requirement is always on the value.
return ParameterValueRequiredToBeUniform;
return ParameterTag::ParameterValueRequiredToBeUniform;
}
return ParameterNoRestriction;
return ParameterTag::ParameterNoRestriction;
};
// Look at which nodes are reachable from "RequiredToBeUniform".
{
utils::UniqueVector<Node*, 4> reachable;
Traverse(current_function_->required_to_be_uniform, &reachable);
if (reachable.Contains(current_function_->may_be_non_uniform)) {
MakeError(*current_function_, current_function_->may_be_non_uniform);
return false;
}
if (reachable.Contains(current_function_->cf_start)) {
current_function_->callsite_tag = CallSiteRequiredToBeUniform;
}
auto traverse = [&](ast::DiagnosticSeverity severity) {
Traverse(current_function_->RequiredToBeUniform(severity), &reachable);
if (reachable.Contains(current_function_->may_be_non_uniform)) {
MakeError(*current_function_, current_function_->may_be_non_uniform, severity);
return false;
}
if (reachable.Contains(current_function_->cf_start)) {
if (current_function_->callsite_tag.tag == CallSiteTag::CallSiteNoRestriction) {
current_function_->callsite_tag = {CallSiteTag::CallSiteRequiredToBeUniform,
severity};
}
}
// Set the tags to capture the direct uniformity requirements of each parameter.
for (size_t i = 0; i < func->params.Length(); i++) {
current_function_->parameters[i].tag_direct = get_param_tag(reachable, i);
// Set the tags to capture the direct uniformity requirements of each parameter.
for (size_t i = 0; i < func->params.Length(); i++) {
if (current_function_->parameters[i].tag_direct.tag ==
ParameterTag::ParameterNoRestriction) {
current_function_->parameters[i].tag_direct = {get_param_tag(reachable, i),
severity};
}
}
return true;
};
if (!traverse(ast::DiagnosticSeverity::kError)) {
return false;
} else {
if (traverse(ast::DiagnosticSeverity::kWarning)) {
traverse(ast::DiagnosticSeverity::kInfo);
}
}
}
@ -441,7 +483,7 @@ class UniformityGraph {
// Set the tags to capture the uniformity requirements of each parameter with respect to
// the function return value.
for (size_t i = 0; i < func->params.Length(); i++) {
current_function_->parameters[i].tag_retval = get_param_tag(reachable, i);
current_function_->parameters[i].tag_retval = {get_param_tag(reachable, i)};
}
}
@ -467,9 +509,9 @@ class UniformityGraph {
for (size_t j = 0; j < func->params.Length(); j++) {
auto tag = get_param_tag(reachable, j);
auto* source_param = sem_.Get<sem::Parameter>(func->params[j]);
if (tag == ParameterContentsRequiredToBeUniform) {
if (tag == ParameterTag::ParameterContentsRequiredToBeUniform) {
param_info.ptr_output_source_param_contents.Push(source_param);
} else if (tag == ParameterValueRequiredToBeUniform) {
} else if (tag == ParameterTag::ParameterValueRequiredToBeUniform) {
param_info.ptr_output_source_param_values.Push(source_param);
}
}
@ -1444,8 +1486,11 @@ class UniformityGraph {
result->type = Node::kFunctionCallReturnValue;
Node* cf_after = CreateNode({"CF_after_", name}, call);
auto default_severity = kUniformityFailuresAsError ? ast::DiagnosticSeverity::kError
: ast::DiagnosticSeverity::kWarning;
// Get tags for the callee.
CallSiteTag callsite_tag = CallSiteNoRestriction;
CallSiteTag callsite_tag = {CallSiteTag::CallSiteNoRestriction};
FunctionTag function_tag = NoRestriction;
auto* sem = SemCall(call);
const FunctionInfo* func_info = nullptr;
@ -1455,21 +1500,23 @@ class UniformityGraph {
// Most builtins have no restrictions. The exceptions are barriers, derivatives,
// some texture sampling builtins, and atomics.
if (builtin->IsBarrier()) {
callsite_tag = CallSiteRequiredToBeUniform;
callsite_tag = {CallSiteTag::CallSiteRequiredToBeUniform, default_severity};
} else if (builtin->Type() == sem::BuiltinType::kWorkgroupUniformLoad) {
callsite_tag = CallSiteRequiredToBeUniform;
callsite_tag = {CallSiteTag::CallSiteRequiredToBeUniform, default_severity};
} else if (builtin->IsDerivative() ||
builtin->Type() == sem::BuiltinType::kTextureSample ||
builtin->Type() == sem::BuiltinType::kTextureSampleBias ||
builtin->Type() == sem::BuiltinType::kTextureSampleCompare) {
callsite_tag = CallSiteRequiredToBeUniform;
function_tag = ReturnValueMayBeNonUniform;
// Get the severity of derivative uniformity violations in this context.
auto severity =
sem_.DiagnosticSeverity(call, ast::DiagnosticRule::kDerivativeUniformity);
if (severity != ast::DiagnosticSeverity::kOff) {
callsite_tag = {CallSiteTag::CallSiteRequiredToBeUniform, severity};
function_tag = ReturnValueMayBeNonUniform;
}
} else if (builtin->IsAtomic()) {
callsite_tag = CallSiteNoRestriction;
callsite_tag = {CallSiteTag::CallSiteNoRestriction};
function_tag = ReturnValueMayBeNonUniform;
} else {
callsite_tag = CallSiteNoRestriction;
function_tag = NoRestriction;
}
},
[&](const sem::Function* func) {
@ -1482,11 +1529,11 @@ class UniformityGraph {
func_info = info;
},
[&](const sem::TypeInitializer*) {
callsite_tag = CallSiteNoRestriction;
callsite_tag = {CallSiteTag::CallSiteNoRestriction};
function_tag = NoRestriction;
},
[&](const sem::TypeConversion*) {
callsite_tag = CallSiteNoRestriction;
callsite_tag = {CallSiteTag::CallSiteNoRestriction};
function_tag = NoRestriction;
},
[&](Default) {
@ -1507,27 +1554,29 @@ class UniformityGraph {
auto& param_info = func_info->parameters[i];
// Capture the direct uniformity requirements.
switch (param_info.tag_direct) {
case ParameterValueRequiredToBeUniform:
current_function_->required_to_be_uniform->AddEdge(args[i]);
switch (param_info.tag_direct.tag) {
case ParameterTag::ParameterValueRequiredToBeUniform:
current_function_->RequiredToBeUniform(param_info.tag_direct.severity)
->AddEdge(args[i]);
break;
case ParameterContentsRequiredToBeUniform: {
current_function_->required_to_be_uniform->AddEdge(ptrarg_contents[i]);
case ParameterTag::ParameterContentsRequiredToBeUniform: {
current_function_->RequiredToBeUniform(param_info.tag_direct.severity)
->AddEdge(ptrarg_contents[i]);
break;
}
case ParameterNoRestriction:
case ParameterTag::ParameterNoRestriction:
break;
}
// Capture the effects of this parameter on the return value.
switch (param_info.tag_retval) {
case ParameterValueRequiredToBeUniform:
switch (param_info.tag_retval.tag) {
case ParameterTag::ParameterValueRequiredToBeUniform:
result->AddEdge(args[i]);
break;
case ParameterContentsRequiredToBeUniform: {
case ParameterTag::ParameterContentsRequiredToBeUniform: {
result->AddEdge(ptrarg_contents[i]);
break;
}
case ParameterNoRestriction:
case ParameterTag::ParameterNoRestriction:
break;
}
@ -1566,7 +1615,7 @@ class UniformityGraph {
auto* builtin = sem->Target()->As<sem::Builtin>();
if (builtin && builtin->Type() == sem::BuiltinType::kWorkgroupUniformLoad) {
// The workgroupUniformLoad builtin requires its parameter to be uniform.
current_function_->required_to_be_uniform->AddEdge(args[i]);
current_function_->RequiredToBeUniform(default_severity)->AddEdge(args[i]);
} else {
// All other builtin function parameters are RequiredToBeUniformForReturnValue,
// as are parameters for type initializers and type conversions.
@ -1578,8 +1627,8 @@ class UniformityGraph {
// Add the callsite requirement last.
// We traverse edges in reverse order, so this makes the callsite requirement take highest
// priority when reporting violations.
if (callsite_tag == CallSiteRequiredToBeUniform) {
current_function_->required_to_be_uniform->AddEdge(call_node);
if (callsite_tag.tag == CallSiteTag::CallSiteRequiredToBeUniform) {
current_function_->RequiredToBeUniform(callsite_tag.severity)->AddEdge(call_node);
}
return {cf_after, result};
@ -1625,8 +1674,9 @@ class UniformityGraph {
}
/// Recursively descend through the function called by `call` and the functions that it calls in
/// order to find a call to a builtin function that requires uniformity.
const ast::CallExpression* FindBuiltinThatRequiresUniformity(const ast::CallExpression* call) {
/// order to find a call to a builtin function that requires uniformity with the given severity.
const ast::CallExpression* FindBuiltinThatRequiresUniformity(const ast::CallExpression* call,
ast::DiagnosticSeverity severity) {
auto* target = SemCall(call)->Target();
if (target->Is<sem::Builtin>()) {
// This is a call to a builtin, so we must be done.
@ -1635,10 +1685,10 @@ class UniformityGraph {
// This is a call to a user-defined function, so inspect the functions called by that
// function and look for one whose node has an edge from the RequiredToBeUniform node.
auto target_info = functions_.Find(user->Declaration());
for (auto* call_node : target_info->required_to_be_uniform->edges) {
for (auto* call_node : target_info->RequiredToBeUniform(severity)->edges) {
if (call_node->type == Node::kRegular) {
auto* child_call = call_node->ast->As<ast::CallExpression>();
return FindBuiltinThatRequiresUniformity(child_call);
return FindBuiltinThatRequiresUniformity(child_call, severity);
}
}
TINT_ASSERT(Resolver, false && "unable to find child call with uniformity requirement");
@ -1783,16 +1833,15 @@ class UniformityGraph {
});
}
/// Generate an error message for a uniformity issue.
/// Generate a diagnostic message for a uniformity issue.
/// @param function the function that the diagnostic is being produced for
/// @param source_node the node that has caused a uniformity issue in `function`
void MakeError(FunctionInfo& function, Node* source_node) {
/// @param severity the severity of the diagnostic
void MakeError(FunctionInfo& function, Node* source_node, ast::DiagnosticSeverity severity) {
// Helper to produce a diagnostic message, as a note or with the global failure severity.
auto report = [&](Source source, std::string msg, bool note) {
diag::Diagnostic error{};
auto failureSeverity =
kUniformityFailuresAsError ? diag::Severity::Error : diag::Severity::Warning;
error.severity = note ? diag::Severity::Note : failureSeverity;
error.severity = note ? diag::Severity::Note : ast::ToSeverity(severity);
error.system = diag::System::Resolver;
error.source = source;
error.message = msg;
@ -1801,12 +1850,12 @@ class UniformityGraph {
// Traverse the graph to generate a path from RequiredToBeUniform to the source node.
function.ResetVisited();
Traverse(function.required_to_be_uniform);
Traverse(function.RequiredToBeUniform(severity));
TINT_ASSERT(Resolver, source_node->visited_from);
// Find a node that is required to be uniform that has a path to the source node.
auto* cause = TraceBackAlongPathUntil(source_node, [&](Node* node) {
return node->visited_from == function.required_to_be_uniform;
return node->visited_from == function.RequiredToBeUniform(severity);
});
// The node will always have a corresponding call expression.
@ -1825,7 +1874,7 @@ class UniformityGraph {
auto next_function = functions_.Find(user_func->Declaration());
auto& param_info = next_function->parameters[cause->arg_index];
MakeError(*next_function,
is_value ? param_info.value : param_info.ptr_input_contents);
is_value ? param_info.value : param_info.ptr_input_contents, severity);
}
// Show the place where the non-uniform argument was passed.
@ -1838,7 +1887,7 @@ class UniformityGraph {
// Show the origin of non-uniformity for the value or data that is being passed.
ShowSourceOfNonUniformity(source_node->visited_from);
} else {
auto* builtin_call = FindBuiltinThatRequiresUniformity(call);
auto* builtin_call = FindBuiltinThatRequiresUniformity(call, severity);
{
// Show a builtin was reachable from this call (which may be the call itself).
// This will be the trigger location for the failure.

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
@ -7868,6 +7869,284 @@ note: control flow depends on possibly non-uniform value
note: reading from module-scope private variable 'v0' may result in a non-uniform value)");
}
////////////////////////////////////////////////////////////////////////////////
/// Tests for the derivative_uniformity diagnostic filter.
////////////////////////////////////////////////////////////////////////////////
class UniformityAnalysisDiagnosticFilterTest
: public UniformityAnalysisTestBase,
public ::testing::TestWithParam<ast::DiagnosticSeverity> {
protected:
// TODO(jrprice): Remove this in favour of utils::ToString() when we change "note" to "info".
const char* ToStr(ast::DiagnosticSeverity severity) {
switch (severity) {
case ast::DiagnosticSeverity::kError:
return "error";
case ast::DiagnosticSeverity::kWarning:
return "warning";
case ast::DiagnosticSeverity::kInfo:
return "note";
default:
return "<undefined>";
}
}
};
TEST_P(UniformityAnalysisDiagnosticFilterTest, Directive) {
auto& param = GetParam();
std::ostringstream ss;
ss << "diagnostic(" << param << ", derivative_uniformity);"
<< R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
@group(0) @binding(1) var t : texture_2d<f32>;
@group(0) @binding(2) var s : sampler;
fn foo() {
if (non_uniform == 42) {
let color = textureSample(t, s, vec2(0, 0));
}
}
)";
RunTest(ss.str(), param == ast::DiagnosticSeverity::kOff);
if (param == ast::DiagnosticSeverity::kOff) {
EXPECT_TRUE(error_.empty());
} else {
std::ostringstream err;
err << ToStr(param) << ": 'textureSample' must only be called";
EXPECT_THAT(error_, ::testing::HasSubstr(err.str()));
}
}
TEST_P(UniformityAnalysisDiagnosticFilterTest, AttributeOnFunction) {
auto& param = GetParam();
std::ostringstream ss;
ss << R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
@group(0) @binding(1) var t : texture_2d<f32>;
@group(0) @binding(2) var s : sampler;
)"
<< "@diagnostic(" << param << ", derivative_uniformity)"
<<
R"(fn foo() {
if (non_uniform == 42) {
let color = textureSample(t, s, vec2(0, 0));
}
}
)";
RunTest(ss.str(), param == ast::DiagnosticSeverity::kOff);
if (param == ast::DiagnosticSeverity::kOff) {
EXPECT_TRUE(error_.empty());
} else {
std::ostringstream err;
err << ToStr(param) << ": 'textureSample' must only be called";
EXPECT_THAT(error_, ::testing::HasSubstr(err.str()));
}
}
INSTANTIATE_TEST_SUITE_P(UniformityAnalysisTest,
UniformityAnalysisDiagnosticFilterTest,
::testing::Values(ast::DiagnosticSeverity::kError,
ast::DiagnosticSeverity::kWarning,
ast::DiagnosticSeverity::kInfo,
ast::DiagnosticSeverity::kOff));
TEST_F(UniformityAnalysisDiagnosticFilterTest, AttributeOnFunction_CalledByAnotherFunction) {
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
@diagnostic(info, derivative_uniformity)
fn bar() {
dpdx(1.0);
}
fn foo() {
if (non_uniform == 42) {
bar();
}
}
)";
RunTest(src, false);
EXPECT_THAT(error_, ::testing::HasSubstr("note: 'dpdx' must only be called"));
}
TEST_F(UniformityAnalysisDiagnosticFilterTest, AttributeOnFunction_RequirementOnParameter) {
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
@diagnostic(info, derivative_uniformity)
fn bar(x : i32) {
if (x == 0) {
dpdx(1.0);
}
}
fn foo() {
bar(non_uniform);
}
)";
RunTest(src, false);
EXPECT_THAT(error_, ::testing::HasSubstr("note: 'dpdx' must only be called"));
}
TEST_F(UniformityAnalysisDiagnosticFilterTest, AttributeOnFunction_BuiltinInChildCall) {
// Make sure that the diagnostic filter does not descend into functions called by the function
// with the attribute.
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn bar() {
dpdx(1.0);
}
@diagnostic(off, derivative_uniformity)
fn foo() {
if (non_uniform == 42) {
bar();
}
}
)";
RunTest(src, false);
EXPECT_THAT(error_, ::testing::HasSubstr(": 'dpdx' must only be called"));
}
TEST_F(UniformityAnalysisDiagnosticFilterTest, MixOfGlobalAndLocalFilters) {
// Test that a global filter is overridden by a local attribute, and that we find multiple
// violations until an error is found.
std::string src = R"(
diagnostic(info, derivative_uniformity);
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn a() {
if (non_uniform == 42) {
dpdx(1.0);
}
}
@diagnostic(off, derivative_uniformity)
fn b() {
if (non_uniform == 42) {
dpdx(1.0);
}
}
@diagnostic(info, derivative_uniformity)
fn c() {
if (non_uniform == 42) {
dpdx(1.0);
}
}
@diagnostic(warning, derivative_uniformity)
fn d() {
if (non_uniform == 42) {
dpdx(1.0);
}
}
@diagnostic(error, derivative_uniformity)
fn e() {
if (non_uniform == 42) {
dpdx(1.0);
}
}
)";
RunTest(src, false);
EXPECT_EQ(error_,
R"(test:8:5 note: 'dpdx' must only be called from uniform control flow
dpdx(1.0);
^^^^
test:7:3 note: control flow depends on possibly non-uniform value
if (non_uniform == 42) {
^^
test:7:7 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
if (non_uniform == 42) {
^^^^^^^^^^^
test:22:5 note: 'dpdx' must only be called from uniform control flow
dpdx(1.0);
^^^^
test:21:3 note: control flow depends on possibly non-uniform value
if (non_uniform == 42) {
^^
test:21:7 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
if (non_uniform == 42) {
^^^^^^^^^^^
test:29:5 warning: 'dpdx' must only be called from uniform control flow
dpdx(1.0);
^^^^
test:28:3 note: control flow depends on possibly non-uniform value
if (non_uniform == 42) {
^^
test:28:7 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
if (non_uniform == 42) {
^^^^^^^^^^^
test:36:5 error: 'dpdx' must only be called from uniform control flow
dpdx(1.0);
^^^^
test:35:3 note: control flow depends on possibly non-uniform value
if (non_uniform == 42) {
^^
test:35:7 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
if (non_uniform == 42) {
^^^^^^^^^^^
)");
}
TEST_F(UniformityAnalysisDiagnosticFilterTest, BarriersNotAffected) {
// Make sure that the diagnostic filter does not affect barriers.
std::string src = R"(
diagnostic(off, derivative_uniformity);
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn foo() {
if (non_uniform == 42) {
dpdx(1.0);
}
}
fn bar() {
if (non_uniform == 42) {
workgroupBarrier();
}
}
)";
RunTest(src, false);
EXPECT_EQ(error_,
R"(test:14:5 warning: 'workgroupBarrier' must only be called from uniform control flow
workgroupBarrier();
^^^^^^^^^^^^^^^^
test:13:3 note: control flow depends on possibly non-uniform value
if (non_uniform == 42) {
^^
test:13:7 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
if (non_uniform == 42) {
^^^^^^^^^^^
)");
}
////////////////////////////////////////////////////////////////////////////////
/// Tests for the quality of the error messages produced by the analysis.
////////////////////////////////////////////////////////////////////////////////

View File

@ -168,6 +168,8 @@ Validator::Validator(
atomic_composite_info_(atomic_composite_info),
valid_type_storage_layouts_(valid_type_storage_layouts) {
// Set default severities for filterable diagnostic rules.
diagnostic_filters_.Set(ast::DiagnosticRule::kDerivativeUniformity,
ast::DiagnosticSeverity::kWarning);
diagnostic_filters_.Set(ast::DiagnosticRule::kChromiumUnreachableCode,
ast::DiagnosticSeverity::kWarning);
}