diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc index 6c4b53a5ee..f5463a1198 100644 --- a/src/tint/resolver/uniformity.cc +++ b/src/tint/resolver/uniformity.cc @@ -474,6 +474,8 @@ class UniformityGraph { // If "Value_return" exists, look at which nodes are reachable from it. if (current_function_->value_return) { + current_function_->ResetVisited(); + utils::UniqueVector reachable; Traverse(current_function_->value_return, &reachable); if (reachable.Contains(current_function_->may_be_non_uniform)) { diff --git a/src/tint/resolver/uniformity_test.cc b/src/tint/resolver/uniformity_test.cc index 334352de4f..b8636f9b7d 100644 --- a/src/tint/resolver/uniformity_test.cc +++ b/src/tint/resolver/uniformity_test.cc @@ -8166,6 +8166,46 @@ test:6:9 note: return value of 'dpdx' may be non-uniform )"); } +TEST_F(UniformityAnalysisDiagnosticFilterTest, + ParameterRequiredToBeUniform_With_ParameterRequiredToBeUniformForReturnValue) { + // Make sure that both requirements on parameters are captured. + std::string src = R"( +@diagnostic(info,derivative_uniformity) +fn foo(x : bool) -> bool { + if (x) { + _ = dpdx(1.0); // Should trigger an info + } + return x; +} + +var non_uniform: bool; + +@diagnostic(error,derivative_uniformity) +fn bar() { + let ret = foo(non_uniform); + if (ret) { + _ = dpdy(1.0); // Should trigger an error + } +} + +)"; + + RunTest(src, false); + EXPECT_EQ(error_, + R"(test:16:9 error: 'dpdy' must only be called from uniform control flow + _ = dpdy(1.0); // Should trigger an error + ^^^^ + +test:15:3 note: control flow depends on possibly non-uniform value + if (ret) { + ^^ + +test:14:17 note: reading from module-scope private variable 'non_uniform' may result in a non-uniform value + let ret = foo(non_uniform); + ^^^^^^^^^^^ +)"); +} + TEST_F(UniformityAnalysisDiagnosticFilterTest, BarriersNotAffected) { // Make sure that the diagnostic filter does not affect barriers. std::string src = R"(