diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc index 200221a602..7a4d25ef57 100644 --- a/src/tint/resolver/uniformity.cc +++ b/src/tint/resolver/uniformity.cc @@ -557,10 +557,13 @@ class UniformityGraph { auto [cf_r, _] = ProcessExpression(cf, a->rhs); return cf_r; } - auto [cf_l, v_l, apply] = ProcessLValueExpression(cf, a->lhs); + auto [cf_l, v_l, ident] = ProcessLValueExpression(cf, a->lhs); auto [cf_r, v_r] = ProcessExpression(cf_l, a->rhs); v_l->AddEdge(v_r); - apply(); + + // Update the variable node for the LHS variable. + current_function_->variables.Set(ident, v_l); + return cf_r; }, @@ -706,18 +709,28 @@ class UniformityGraph { // The compound assignment statement `a += b` is equivalent to: // let p = &a; // *p = *p + b; - // Note: we set load_rule=true when evaluating the LHS, as the resolver does not add - // a load node for it. - auto [cf1, l1, apply] = ProcessLValueExpression(cf, c->lhs); - auto [cf2, v2] = ProcessExpression(cf1, c->lhs, /* load_rule */ true); - auto [cf3, v3] = ProcessExpression(cf2, c->rhs); + + // Evaluate the LHS. + auto [cf1, l1, ident] = ProcessLValueExpression(cf, c->lhs); + + // Get the current value loaded from the LHS reference before evaluating the RHS. + auto* lhs_load = current_function_->variables.Get(ident); + + // Evaluate the RHS. + auto [cf2, v2] = ProcessExpression(cf1, c->rhs); + + // Create a node for the resulting value. auto* result = CreateNode({"binary_expr_result"}); result->AddEdge(v2); - result->AddEdge(v3); + if (lhs_load) { + result->AddEdge(lhs_load); + } + // Update the variable node for the LHS variable. l1->AddEdge(result); - apply(); - return cf3; + current_function_->variables.Set(ident, l1); + + return cf2; }, [&](const ast::ContinueStatement* c) { @@ -968,17 +981,25 @@ class UniformityGraph { [&](const ast::IncrementDecrementStatement* i) { // The increment/decrement statement `i++` is equivalent to `i = i + 1`. - // Note: we set load_rule=true when evaluating the LHS the first time, as the - // resolver does not add a load node for it. - auto [cf1, v1] = ProcessExpression(cf, i->lhs, /* load_rule */ true); - auto* result = CreateNode({"incdec_result"}); - result->AddEdge(v1); - result->AddEdge(cf1); - auto [cf2, l2, apply] = ProcessLValueExpression(cf1, i->lhs); - l2->AddEdge(result); - apply(); - return cf2; + // Evaluate the LHS. + auto [cf1, l1, ident] = ProcessLValueExpression(cf, i->lhs); + + // Get the current value loaded from the LHS reference. + auto* lhs_load = current_function_->variables.Get(ident); + + // Create a node for the resulting value. + auto* result = CreateNode({"incdec_result"}); + result->AddEdge(cf1); + if (lhs_load) { + result->AddEdge(lhs_load); + } + + // Update the variable node for the LHS variable. + l1->AddEdge(result); + current_function_->variables.Set(ident, l1); + + return cf1; }, [&](const ast::LoopStatement* l) { @@ -1384,8 +1405,8 @@ class UniformityGraph { /// The new value node for an LValue expression Node* new_val = nullptr; - /// Updates the value node of the LValue expression to be #new_val. - std::function apply; + /// The root identifier for an LValue expression. + const sem::Variable* root_identifier = nullptr; }; /// Process an LValue expression. @@ -1401,13 +1422,11 @@ class UniformityGraph { [&](const ast::IdentifierExpression* i) { auto* sem = sem_.GetVal(i)->UnwrapLoad()->As(); if (sem->Variable()->Is()) { - return LValue{cf, current_function_->may_be_non_uniform, [] {}}; + return LValue{cf, current_function_->may_be_non_uniform, nullptr}; } else if (auto* local = sem->Variable()->As()) { // Create a new value node for this variable. auto* value = CreateNode({NameFor(i), "_lvalue"}); - auto apply = [=] { current_function_->variables.Set(local, value); }; - // If i is part of an expression that is a partial reference to a variable (e.g. // index or member access), we link back to the variable's previous value. If // the previous value was non-uniform, a partial assignment will not make it @@ -1417,7 +1436,7 @@ class UniformityGraph { value->AddEdge(old_value); } - return LValue{cf, value, apply}; + return LValue{cf, value, local}; } else { TINT_ICE(Resolver, diagnostics_) << "unknown lvalue identifier expression type: " @@ -1427,11 +1446,11 @@ class UniformityGraph { }, [&](const ast::IndexAccessorExpression* i) { - auto [cf1, l1, apply] = + auto [cf1, l1, root_ident] = ProcessLValueExpression(cf, i->object, /*is_partial_reference*/ true); auto [cf2, v2] = ProcessExpression(cf1, i->index); l1->AddEdge(v2); - return LValue{cf2, l1, apply}; + return LValue{cf2, l1, root_ident}; }, [&](const ast::MemberAccessorExpression* m) { @@ -1445,8 +1464,6 @@ class UniformityGraph { auto* root_ident = sem_.Get(u)->RootIdentifier(); auto* deref = CreateNode({NameFor(root_ident), "_deref"}); - auto apply = [=] { current_function_->variables.Set(root_ident, deref); }; - if (auto* old_value = current_function_->variables.Get(root_ident)) { // If dereferencing a partial reference or partial pointer, we link back to // the variable's previous value. If the previous value was non-uniform, a @@ -1455,7 +1472,7 @@ class UniformityGraph { deref->AddEdge(old_value); } } - return LValue{cf, deref, apply}; + return LValue{cf, deref, root_ident}; } return ProcessLValueExpression(cf, u->expr, is_partial_reference); }, diff --git a/src/tint/resolver/uniformity_test.cc b/src/tint/resolver/uniformity_test.cc index a55d259d8a..fe36da28dc 100644 --- a/src/tint/resolver/uniformity_test.cc +++ b/src/tint/resolver/uniformity_test.cc @@ -7402,6 +7402,128 @@ test:5:11 note: reading from read_write storage buffer 'rw' may result in a non- )"); } +TEST_F(UniformityAnalysisTest, CompoundAssignment_Global) { + // Use compound assignment on a global variable. + // Tests that we do not assume there is always a variable node for the LHS, but we still process + // the expression. + std::string src = R"( +@group(0) @binding(0) var rw : i32; + +var v : array; + +fn bar(p : ptr) -> i32 { + if (*p == 0) { + workgroupBarrier(); + } + return 0; +} + +fn foo() { + var f = rw; + v[bar(&f)] += 1; +} +)"; + + RunTest(src, false); + EXPECT_EQ(error_, + R"(test:8:5 error: 'workgroupBarrier' must only be called from uniform control flow + workgroupBarrier(); + ^^^^^^^^^^^^^^^^ + +test:7:3 note: control flow depends on possibly non-uniform value + if (*p == 0) { + ^^ + +test:7:8 note: parameter 'p' of 'bar' may be non-uniform + if (*p == 0) { + ^ + +test:15:9 note: possibly non-uniform value passed via pointer here + v[bar(&f)] += 1; + ^ + +test:14:11 note: reading from read_write storage buffer 'rw' may result in a non-uniform value + var f = rw; + ^^ +)"); +} + +TEST_F(UniformityAnalysisTest, IncDec_StillNonUniform) { + // Use increment on a variable that is already non-uniform. + std::string src = R"( +@group(0) @binding(0) var rw : i32; + +fn foo() { + var v = rw; + v++; + if (v == 0) { + workgroupBarrier(); + } +} +)"; + + RunTest(src, false); + EXPECT_EQ(error_, + R"(test:8:5 error: 'workgroupBarrier' must only be called from uniform control flow + workgroupBarrier(); + ^^^^^^^^^^^^^^^^ + +test:7:3 note: control flow depends on possibly non-uniform value + if (v == 0) { + ^^ + +test:5:11 note: reading from read_write storage buffer 'rw' may result in a non-uniform value + var v = rw; + ^^ +)"); +} + +TEST_F(UniformityAnalysisTest, IncDec_Global) { + // Use increment on a global variable. + // Tests that we do not assume there is always a variable node for the LHS, but we still process + // the expression. + std::string src = R"( +@group(0) @binding(0) var rw : i32; + +var v : array; + +fn bar(p : ptr) -> i32 { + if (*p == 0) { + workgroupBarrier(); + } + return 0; +} + +fn foo() { + var f = rw; + v[bar(&f)]++; +} +)"; + + RunTest(src, false); + EXPECT_EQ(error_, + R"(test:8:5 error: 'workgroupBarrier' must only be called from uniform control flow + workgroupBarrier(); + ^^^^^^^^^^^^^^^^ + +test:7:3 note: control flow depends on possibly non-uniform value + if (*p == 0) { + ^^ + +test:7:8 note: parameter 'p' of 'bar' may be non-uniform + if (*p == 0) { + ^ + +test:15:9 note: possibly non-uniform value passed via pointer here + v[bar(&f)]++; + ^ + +test:14:11 note: reading from read_write storage buffer 'rw' may result in a non-uniform value + var f = rw; + ^^ +)"); +} + TEST_F(UniformityAnalysisTest, ShortCircuiting_UniformLHS) { std::string src = R"( @group(0) @binding(0) var uniform_global : i32; @@ -8649,5 +8771,108 @@ test:19:9 note: contents of pointer may become non-uniform after calling 'a' )"); } +TEST_F(UniformityAnalysisTest, CompoundAssignmentEval_RHS_Makes_LHS_NonUniform_After_Load) { + // Test that the LHS is loaded from before the RHS makes is evaluated. + std::string src = R"( +@group(0) @binding(0) var non_uniform : i32; + +fn bar(p : ptr) -> i32 { + *p = non_uniform; + return 0; +} + +fn foo() { + var i = 0; + var arr : array; + i += arr[bar(&i)]; + if (i == 0) { + workgroupBarrier(); + } +} +)"; + + RunTest(src, true); +} + +TEST_F(UniformityAnalysisTest, CompoundAssignmentEval_RHS_Makes_LHS_Uniform_After_Load) { + // Test that the LHS is loaded from before the RHS makes is evaluated. + std::string src = R"( +@group(0) @binding(0) var non_uniform : i32; + +fn bar(p : ptr) -> i32 { + *p = 0; + return 0; +} + +fn foo() { + var i = non_uniform; + var arr : array; + i += arr[bar(&i)]; + if (i == 0) { + workgroupBarrier(); + } +} +)"; + + RunTest(src, false); + EXPECT_EQ(error_, + R"(test:14:5 error: 'workgroupBarrier' must only be called from uniform control flow + workgroupBarrier(); + ^^^^^^^^^^^^^^^^ + +test:13:3 note: control flow depends on possibly non-uniform value + if (i == 0) { + ^^ + +test:10:11 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value + var i = non_uniform; + ^^^^^^^^^^^ +)"); +} + +TEST_F(UniformityAnalysisTest, CompoundAssignmentEval_LHS_OnlyOnce) { + std::string src = R"( +@group(0) @binding(0) var non_uniform : i32; + +fn bar(p : ptr) -> i32 { + if (*p == 0) { + workgroupBarrier(); + } + *p = non_uniform; + return 0; +} + +fn foo(){ + var f : i32 = 0; + var arr : array; + arr[bar(&f)] += 1; +} +)"; + + RunTest(src, true); +} + +TEST_F(UniformityAnalysisTest, IncDec_LHS_OnlyOnce) { + std::string src = R"( +@group(0) @binding(0) var non_uniform : i32; + +fn bar(p : ptr) -> i32 { + if (*p == 0) { + workgroupBarrier(); + } + *p = non_uniform; + return 0; +} + +fn foo(){ + var f : i32 = 0; + var arr : array; + arr[bar(&f)]++; +} +)"; + + RunTest(src, true); +} + } // namespace } // namespace tint::resolver