// Copyright 2022 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/tint/resolver/uniformity.h" #include #include #include #include #include #include #include "src/tint/program_builder.h" #include "src/tint/resolver/dependency_graph.h" #include "src/tint/scope_stack.h" #include "src/tint/sem/block_statement.h" #include "src/tint/sem/for_loop_statement.h" #include "src/tint/sem/function.h" #include "src/tint/sem/if_statement.h" #include "src/tint/sem/info.h" #include "src/tint/sem/loop_statement.h" #include "src/tint/sem/statement.h" #include "src/tint/sem/switch_statement.h" #include "src/tint/sem/type_conversion.h" #include "src/tint/sem/type_initializer.h" #include "src/tint/sem/variable.h" #include "src/tint/sem/while_statement.h" #include "src/tint/utils/block_allocator.h" #include "src/tint/utils/map.h" #include "src/tint/utils/unique_vector.h" // Set to `1` to dump the uniformity graph for each function in graphviz format. #define TINT_DUMP_UNIFORMITY_GRAPH 0 namespace tint::resolver { namespace { /// Unwraps `u->expr`'s chain of indirect (*) and address-of (&) expressions, returning the first /// expression that is neither of these. /// E.g. If `u` is `*(&(*(&p)))`, returns `p`. const ast::Expression* UnwrapIndirectAndAddressOfChain(const ast::UnaryOpExpression* u) { auto* e = u->expr; while (true) { auto* unary = e->As(); if (unary && (unary->op == ast::UnaryOp::kIndirection || unary->op == ast::UnaryOp::kAddressOf)) { e = unary->expr; } else { break; } } return e; } /// CallSiteTag describes the uniformity requirements on the call sites of a function. enum CallSiteTag { CallSiteRequiredToBeUniform, CallSiteNoRestriction, }; /// FunctionTag describes a functions effects on uniformity. enum FunctionTag { ReturnValueMayBeNonUniform, NoRestriction, }; /// ParameterTag describes the uniformity requirements of values passed to a function parameter. enum ParameterTag { ParameterRequiredToBeUniform, ParameterRequiredToBeUniformForReturnValue, ParameterNoRestriction, }; /// Node represents a node in the graph of control flow and value nodes within the analysis of a /// single function. struct Node { /// Constructor /// @param a the corresponding AST node explicit Node(const ast::Node* a) : ast(a) {} #if TINT_DUMP_UNIFORMITY_GRAPH /// The node tag. std::string tag; #endif /// Type describes the type of the node, which is used to determine additional diagnostic /// information. enum Type { kRegular, kFunctionCallArgument, kFunctionCallPointerArgumentResult, kFunctionCallReturnValue, }; /// The type of the node. Type type = kRegular; /// `true` if this node represents a potential control flow change. bool affects_control_flow = false; /// The corresponding AST node, or nullptr. const ast::Node* ast = nullptr; /// The function call argument index, if applicable. uint32_t arg_index; /// The set of edges from this node to other nodes in the graph. utils::UniqueVector edges; /// The node that this node was visited from, or nullptr if not visited. Node* visited_from = nullptr; /// Add an edge to the `to` node. /// @param to the destination node void AddEdge(Node* to) { edges.Add(to); } }; /// ParameterInfo holds information about the uniformity requirements and effects for a particular /// function parameter. struct ParameterInfo { /// The semantic node in corresponds to this parameter. const sem::Parameter* sem; /// The parameter's uniformity requirements. ParameterTag tag = ParameterNoRestriction; /// Will be `true` if this function may cause the contents of this pointer parameter to become /// non-uniform. bool pointer_may_become_non_uniform = false; /// The parameters that are required to be uniform for the contents of this pointer parameter to /// be uniform at function exit. std::vector pointer_param_output_sources; /// The node in the graph that corresponds to this parameter's initial value. Node* init_value; /// The node in the graph that corresponds to this parameter's output value (or nullptr). Node* pointer_return_value = nullptr; }; /// FunctionInfo holds information about the uniformity requirements and effects for a particular /// function, as well as the control flow graph. struct FunctionInfo { /// Constructor /// @param func the AST function /// @param builder the program builder FunctionInfo(const ast::Function* func, const ProgramBuilder* builder) { name = builder->Symbols().NameFor(func->symbol); callsite_tag = CallSiteNoRestriction; function_tag = NoRestriction; // Create special nodes. required_to_be_uniform = CreateNode("RequiredToBeUniform"); may_be_non_uniform = CreateNode("MayBeNonUniform"); cf_start = CreateNode("CF_start"); if (func->return_type) { value_return = CreateNode("Value_return"); } // Create nodes for parameters. parameters.resize(func->params.Length()); for (size_t i = 0; i < func->params.Length(); i++) { auto* param = func->params[i]; auto param_name = builder->Symbols().NameFor(param->symbol); auto* sem = builder->Sem().Get(param); parameters[i].sem = sem; Node* node_init; if (sem->Type()->Is()) { node_init = CreateNode("ptrparam_" + name + "_init"); parameters[i].pointer_return_value = CreateNode("ptrparam_" + name + "_return"); local_var_decls.insert(sem); } else { node_init = CreateNode("param_" + name); } parameters[i].init_value = node_init; variables.Set(sem, node_init); } } /// The name of the function. std::string name; /// The call site uniformity requirements. CallSiteTag callsite_tag; /// The function's uniformity effects. FunctionTag function_tag; /// The uniformity requirements of the function's parameters. std::vector parameters; /// The control flow graph. utils::BlockAllocator nodes; /// Special `RequiredToBeUniform` node. Node* required_to_be_uniform; /// Special `MayBeNonUniform` node. Node* may_be_non_uniform; /// Special `CF_start` node. Node* cf_start; /// Special `Value_return` node. Node* value_return; /// Map from variables to their value nodes in the graph, scoped with respect to control flow. ScopeStack variables; /// The set of a local read-write vars that are in scope at any given point in the process. /// Includes pointer parameters. std::unordered_set local_var_decls; /// The set of partial pointer variables - pointers that point to a subobject (into an array or /// struct). std::unordered_set partial_ptrs; /// LoopSwitchInfo tracks information about the value of variables for a control flow construct. struct LoopSwitchInfo { /// The type of this control flow construct. std::string type; /// The input values for local variables at the start of this construct. std::unordered_map var_in_nodes; /// The exit values for local variables at the end of this construct. std::unordered_map var_exit_nodes; }; /// Map from control flow statements to the corresponding LoopSwitchInfo structure. std::unordered_map loop_switch_infos; /// Create a new node. /// @param tag a tag used to identify the node for debugging purposes /// @param ast the optional AST node that this node corresponds to /// @returns the new node Node* CreateNode([[maybe_unused]] std::string tag, const ast::Node* ast = nullptr) { auto* node = nodes.Create(ast); #if TINT_DUMP_UNIFORMITY_GRAPH // Make the tag unique and set it. // This only matters if we're dumping the graph. std::string unique_tag = tag; int suffix = 0; while (tags_.count(unique_tag)) { unique_tag = tag + "_$" + std::to_string(++suffix); } tags_.insert(unique_tag); node->tag = name + "." + unique_tag; #endif return node; } /// Reset the visited status of every node in the graph. void ResetVisited() { for (auto* node : nodes.Objects()) { node->visited_from = nullptr; } } private: /// A list of tags that have already been used within the current function. std::unordered_set tags_; }; /// UniformityGraph is used to analyze the uniformity requirements and effects of functions in a /// module. class UniformityGraph { public: /// Constructor. /// @param builder the program to analyze explicit UniformityGraph(ProgramBuilder* builder) : builder_(builder), sem_(builder->Sem()), diagnostics_(builder->Diagnostics()) {} /// Destructor. ~UniformityGraph() {} /// Build and analyze the graph to determine whether the program satisfies the uniformity /// constraints of WGSL. /// @param dependency_graph the dependency-ordered module-scope declarations /// @returns true if all uniformity constraints are satisfied, otherise false bool Build(const DependencyGraph& dependency_graph) { #if TINT_DUMP_UNIFORMITY_GRAPH std::cout << "digraph G {\n"; std::cout << "rankdir=BT\n"; #endif // Process all functions in the module. bool success = true; for (auto* decl : dependency_graph.ordered_globals) { if (auto* func = decl->As()) { if (!ProcessFunction(func)) { success = false; break; } } } #if TINT_DUMP_UNIFORMITY_GRAPH std::cout << "\n}\n"; #endif return success; } private: const ProgramBuilder* builder_; const sem::Info& sem_; diag::List& diagnostics_; /// Map of analyzed function results. std::unordered_map functions_; /// The function currently being analyzed. FunctionInfo* current_function_; /// Create a new node. /// @param tag a tag used to identify the node for debugging purposes. /// @param ast the optional AST node that this node corresponds to /// @returns the new node Node* CreateNode(std::string tag, const ast::Node* ast = nullptr) { return current_function_->CreateNode(std::move(tag), ast); } /// Process a function. /// @param func the function to process /// @returns true if there are no uniformity issues, false otherwise bool ProcessFunction(const ast::Function* func) { functions_.emplace(func, FunctionInfo(func, builder_)); current_function_ = &functions_.at(func); // Process function body. if (func->body) { ProcessStatement(current_function_->cf_start, func->body); } #if TINT_DUMP_UNIFORMITY_GRAPH // Dump the graph for this function as a subgraph. std::cout << "\nsubgraph cluster_" << current_function_->name << " {\n"; std::cout << " label=" << current_function_->name << ";"; for (auto* node : current_function_->nodes.Objects()) { std::cout << "\n \"" << node->tag << "\";"; for (auto* edge : node->edges) { std::cout << "\n \"" << node->tag << "\" -> \"" << edge->tag << "\";"; } } std::cout << "\n}\n"; #endif // Look at which nodes are reachable from "RequiredToBeUniform". { utils::UniqueVector reachable; Traverse(current_function_->required_to_be_uniform, &reachable); if (reachable.Contains(current_function_->may_be_non_uniform)) { MakeError(*current_function_, current_function_->may_be_non_uniform); return false; } if (reachable.Contains(current_function_->cf_start)) { current_function_->callsite_tag = CallSiteRequiredToBeUniform; } // Set the parameter tag to ParameterRequiredToBeUniform for each parameter node that // was reachable. for (size_t i = 0; i < func->params.Length(); i++) { auto* param = func->params[i]; if (reachable.Contains(current_function_->variables.Get(sem_.Get(param)))) { current_function_->parameters[i].tag = ParameterRequiredToBeUniform; } } } // If "Value_return" exists, look at which nodes are reachable from it if (current_function_->value_return) { utils::UniqueVector reachable; Traverse(current_function_->value_return, &reachable); if (reachable.Contains(current_function_->may_be_non_uniform)) { current_function_->function_tag = ReturnValueMayBeNonUniform; } // Set the parameter tag to ParameterRequiredToBeUniformForReturnValue for each // parameter node that was reachable. for (size_t i = 0; i < func->params.Length(); i++) { auto* param = func->params[i]; if (reachable.Contains(current_function_->variables.Get(sem_.Get(param)))) { current_function_->parameters[i].tag = ParameterRequiredToBeUniformForReturnValue; } } } // Traverse the graph for each pointer parameter. for (size_t i = 0; i < func->params.Length(); i++) { if (current_function_->parameters[i].pointer_return_value == nullptr) { continue; } // Reset "visited" state for all nodes. current_function_->ResetVisited(); utils::UniqueVector reachable; Traverse(current_function_->parameters[i].pointer_return_value, &reachable); if (reachable.Contains(current_function_->may_be_non_uniform)) { current_function_->parameters[i].pointer_may_become_non_uniform = true; } // Check every other parameter to see if they feed into this parameter's final value. for (size_t j = 0; j < func->params.Length(); j++) { auto* param_source = sem_.Get(func->params[j]); if (reachable.Contains(current_function_->parameters[j].init_value)) { current_function_->parameters[i].pointer_param_output_sources.push_back( param_source); } } } return true; } /// Process a statement, returning the new control flow node. /// @param cf the input control flow node /// @param stmt the statement to process d /// @returns the new control flow node Node* ProcessStatement(Node* cf, const ast::Statement* stmt) { return Switch( stmt, [&](const ast::AssignmentStatement* a) { auto [cf1, v1] = ProcessExpression(cf, a->rhs); if (a->lhs->Is()) { return cf1; } else { auto [cf2, l2] = ProcessLValueExpression(cf1, a->lhs); l2->AddEdge(v1); return cf2; } }, [&](const ast::BlockStatement* b) { std::unordered_map scoped_assignments; { // Push a new scope for variable assignments in the block. current_function_->variables.Push(); TINT_DEFER(current_function_->variables.Pop()); for (auto* s : b->statements) { cf = ProcessStatement(cf, s); if (!sem_.Get(s)->Behaviors().Contains(sem::Behavior::kNext)) { break; } } if (sem_.Get(b)) { // We've reached the end of the function body. // Add edges from pointer parameter outputs to their current value. for (auto param : current_function_->parameters) { if (param.pointer_return_value) { param.pointer_return_value->AddEdge( current_function_->variables.Get(param.sem)); } } } scoped_assignments = std::move(current_function_->variables.Top()); } // Propagate all variables assignments to the containing scope if the behavior is // either 'Next' or 'Fallthrough'. auto& behaviors = sem_.Get(b)->Behaviors(); if (behaviors.Contains(sem::Behavior::kNext) || behaviors.Contains(sem::Behavior::kFallthrough)) { for (auto var : scoped_assignments) { current_function_->variables.Set(var.first, var.second); } } // Remove any variables declared in this scope from the set of in-scope variables. for (auto decl : sem_.Get(b)->Decls()) { current_function_->local_var_decls.erase(decl.value.variable); } return cf; }, [&](const ast::BreakStatement* b) { // Find the loop or switch statement that we are in. auto* parent = sem_.Get(b) ->FindFirstParent(); TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent)); auto& info = current_function_->loop_switch_infos.at(parent); // Propagate variable values to the loop/switch exit nodes. for (auto* var : current_function_->local_var_decls) { // Skip variables that were declared inside this loop/switch. if (auto* lv = var->As(); lv && lv->Statement()->FindFirstParent([&](auto* s) { return s == parent; })) { continue; } // Add an edge from the variable exit node to its value at this point. auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() { auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); return CreateNode(name + "_value_" + info.type + "_exit"); }); exit_node->AddEdge(current_function_->variables.Get(var)); } return cf; }, [&](const ast::BreakIfStatement* b) { // This works very similar to the IfStatement uniformity below, execpt instead of // processing the body, we directly inline the BreakStatement uniformity from // above. auto [_, v_cond] = ProcessExpression(cf, b->condition); // Add a diagnostic node to capture the control flow change. auto* v = current_function_->CreateNode("break_if_stmt", b); v->affects_control_flow = true; v->AddEdge(v_cond); { auto* parent = sem_.Get(b)->FindFirstParent(); TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent)); auto& info = current_function_->loop_switch_infos.at(parent); // Propagate variable values to the loop exit nodes. for (auto* var : current_function_->local_var_decls) { // Skip variables that were declared inside this loop. if (auto* lv = var->As(); lv && lv->Statement()->FindFirstParent( [&](auto* s) { return s == parent; })) { continue; } // Add an edge from the variable exit node to its value at this point. auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() { auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); return CreateNode(name + "_value_" + info.type + "_exit"); }); exit_node->AddEdge(current_function_->variables.Get(var)); } } auto* sem_break_if = sem_.Get(b); if (sem_break_if->Behaviors() != sem::Behaviors{sem::Behavior::kNext}) { auto* cf_end = CreateNode("break_if_CFend"); cf_end->AddEdge(v); return cf_end; } return cf; }, [&](const ast::CallStatement* c) { auto [cf1, _] = ProcessCall(cf, c->expr); return cf1; }, [&](const ast::CompoundAssignmentStatement* c) { // The compound assignment statement `a += b` is equivalent to `a = a + b`. auto [cf1, v1] = ProcessExpression(cf, c->lhs); auto [cf2, v2] = ProcessExpression(cf1, c->rhs); auto* result = CreateNode("binary_expr_result"); result->AddEdge(v1); result->AddEdge(v2); auto [cf3, l3] = ProcessLValueExpression(cf2, c->lhs); l3->AddEdge(result); return cf3; }, [&](const ast::ContinueStatement* c) { // Find the loop statement that we are in. auto* parent = sem_.Get(c) ->FindFirstParent(); TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent)); auto& info = current_function_->loop_switch_infos.at(parent); // Propagate assignments to the loop input nodes. for (auto* var : current_function_->local_var_decls) { // Skip variables that were declared inside this loop. if (auto* lv = var->As(); lv && lv->Statement()->FindFirstParent([&](auto* s) { return s == parent; })) { continue; } // Add an edge from the variable's loop input node to its value at this point. TINT_ASSERT(Resolver, info.var_in_nodes.count(var)); auto* in_node = info.var_in_nodes.at(var); auto* out_node = current_function_->variables.Get(var); if (out_node != in_node) { in_node->AddEdge(out_node); } } return cf; }, [&](const ast::DiscardStatement*) { return cf; }, [&](const ast::FallthroughStatement*) { return cf; }, [&](const ast::ForLoopStatement* f) { auto* sem_loop = sem_.Get(f); auto* cfx = CreateNode("loop_start"); // Insert the initializer before the loop. auto* cf_init = cf; if (f->initializer) { cf_init = ProcessStatement(cf, f->initializer); } auto* cf_start = cf_init; auto& info = current_function_->loop_switch_infos[sem_loop]; info.type = "forloop"; // Create input nodes for any variables declared before this loop. for (auto* v : current_function_->local_var_decls) { auto name = builder_->Symbols().NameFor(v->Declaration()->symbol); auto* in_node = CreateNode(name + "_value_forloop_in"); in_node->AddEdge(current_function_->variables.Get(v)); info.var_in_nodes[v] = in_node; current_function_->variables.Set(v, in_node); } // Insert the condition at the start of the loop body. if (f->condition) { auto [cf_cond, v] = ProcessExpression(cfx, f->condition); auto* cf_condition_end = CreateNode("for_condition_CFend", f); cf_condition_end->affects_control_flow = true; cf_condition_end->AddEdge(v); cf_start = cf_condition_end; // Propagate assignments to the loop exit nodes. for (auto* var : current_function_->local_var_decls) { auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() { auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); return CreateNode(name + "_value_" + info.type + "_exit"); }); exit_node->AddEdge(current_function_->variables.Get(var)); } } auto* cf1 = ProcessStatement(cf_start, f->body); // Insert the continuing statement at the end of the loop body. if (f->continuing) { auto* cf2 = ProcessStatement(cf1, f->continuing); cfx->AddEdge(cf2); } else { cfx->AddEdge(cf1); } cfx->AddEdge(cf); // Add edges from variable loop input nodes to their values at the end of the loop. for (auto v : info.var_in_nodes) { auto* in_node = v.second; auto* out_node = current_function_->variables.Get(v.first); if (out_node != in_node) { in_node->AddEdge(out_node); } } // Set each variable's exit node as its value in the outer scope. for (auto v : info.var_exit_nodes) { current_function_->variables.Set(v.first, v.second); } current_function_->loop_switch_infos.erase(sem_loop); if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) { return cf; } else { return cfx; } }, [&](const ast::WhileStatement* w) { auto* sem_loop = sem_.Get(w); auto* cfx = CreateNode("loop_start"); auto* cf_start = cf; auto& info = current_function_->loop_switch_infos[sem_loop]; info.type = "whileloop"; // Create input nodes for any variables declared before this loop. for (auto* v : current_function_->local_var_decls) { auto name = builder_->Symbols().NameFor(v->Declaration()->symbol); auto* in_node = CreateNode(name + "_value_forloop_in"); in_node->AddEdge(current_function_->variables.Get(v)); info.var_in_nodes[v] = in_node; current_function_->variables.Set(v, in_node); } // Insert the condition at the start of the loop body. { auto [cf_cond, v] = ProcessExpression(cfx, w->condition); auto* cf_condition_end = CreateNode("while_condition_CFend", w); cf_condition_end->affects_control_flow = true; cf_condition_end->AddEdge(v); cf_start = cf_condition_end; } // Propagate assignments to the loop exit nodes. for (auto* var : current_function_->local_var_decls) { auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() { auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); return CreateNode(name + "_value_" + info.type + "_exit"); }); exit_node->AddEdge(current_function_->variables.Get(var)); } auto* cf1 = ProcessStatement(cf_start, w->body); cfx->AddEdge(cf1); cfx->AddEdge(cf); // Add edges from variable loop input nodes to their values at the end of the loop. for (auto v : info.var_in_nodes) { auto* in_node = v.second; auto* out_node = current_function_->variables.Get(v.first); if (out_node != in_node) { in_node->AddEdge(out_node); } } // Set each variable's exit node as its value in the outer scope. for (auto v : info.var_exit_nodes) { current_function_->variables.Set(v.first, v.second); } current_function_->loop_switch_infos.erase(sem_loop); if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) { return cf; } else { return cfx; } }, [&](const ast::IfStatement* i) { auto* sem_if = sem_.Get(i); auto [_, v_cond] = ProcessExpression(cf, i->condition); // Add a diagnostic node to capture the control flow change. auto* v = current_function_->CreateNode("if_stmt", i); v->affects_control_flow = true; v->AddEdge(v_cond); std::unordered_map true_vars; std::unordered_map false_vars; // Helper to process a statement with a new scope for variable assignments. // Populates `assigned_vars` with new nodes for any variables that are assigned in // this statement. auto process_in_scope = [&](Node* cf_in, const ast::Statement* s, std::unordered_map& assigned_vars) { // Push a new scope for variable assignments. current_function_->variables.Push(); // Process the statement. auto* cf_out = ProcessStatement(cf_in, s); assigned_vars = current_function_->variables.Top(); // Pop the scope and return. current_function_->variables.Pop(); return cf_out; }; auto* cf1 = process_in_scope(v, i->body, true_vars); bool true_has_next = sem_.Get(i->body)->Behaviors().Contains(sem::Behavior::kNext); bool false_has_next = true; Node* cf2 = nullptr; if (i->else_statement) { cf2 = process_in_scope(v, i->else_statement, false_vars); false_has_next = sem_.Get(i->else_statement)->Behaviors().Contains(sem::Behavior::kNext); } // Update values for any variables assigned in the if or else blocks. for (auto* var : current_function_->local_var_decls) { // Skip variables not assigned in either block. if (true_vars.count(var) == 0 && false_vars.count(var) == 0) { continue; } // Create an exit node for the variable. auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); auto* out_node = CreateNode(name + "_value_if_exit"); // Add edges to the assigned value or the initial value. // Only add edges if the behavior for that block contains 'Next'. if (true_has_next) { if (true_vars.count(var)) { out_node->AddEdge(true_vars.at(var)); } else { out_node->AddEdge(current_function_->variables.Get(var)); } } if (false_has_next) { if (false_vars.count(var)) { out_node->AddEdge(false_vars.at(var)); } else { out_node->AddEdge(current_function_->variables.Get(var)); } } current_function_->variables.Set(var, out_node); } if (sem_if->Behaviors() != sem::Behaviors{sem::Behavior::kNext}) { auto* cf_end = CreateNode("if_CFend"); cf_end->AddEdge(cf1); if (cf2) { cf_end->AddEdge(cf2); } return cf_end; } return cf; }, [&](const ast::IncrementDecrementStatement* i) { // The increment/decrement statement `i++` is equivalent to `i = i + 1`. auto [cf1, v1] = ProcessExpression(cf, i->lhs); auto* result = CreateNode("incdec_result"); result->AddEdge(v1); result->AddEdge(cf1); auto [cf2, l2] = ProcessLValueExpression(cf1, i->lhs); l2->AddEdge(result); return cf2; }, [&](const ast::LoopStatement* l) { auto* sem_loop = sem_.Get(l); auto* cfx = CreateNode("loop_start"); auto& info = current_function_->loop_switch_infos[sem_loop]; info.type = "loop"; // Create input nodes for any variables declared before this loop. for (auto* v : current_function_->local_var_decls) { auto name = builder_->Symbols().NameFor(v->Declaration()->symbol); auto* in_node = CreateNode(name + "_value_loop_in", v->Declaration()); in_node->AddEdge(current_function_->variables.Get(v)); info.var_in_nodes[v] = in_node; current_function_->variables.Set(v, in_node); } auto* cf1 = ProcessStatement(cfx, l->body); if (l->continuing) { auto* cf2 = ProcessStatement(cf1, l->continuing); cfx->AddEdge(cf2); } else { cfx->AddEdge(cf1); } cfx->AddEdge(cf); // Add edges from variable loop input nodes to their values at the end of the loop. for (auto v : info.var_in_nodes) { auto* in_node = v.second; auto* out_node = current_function_->variables.Get(v.first); if (out_node != in_node) { in_node->AddEdge(out_node); } } // Set each variable's exit node as its value in the outer scope. for (auto v : info.var_exit_nodes) { current_function_->variables.Set(v.first, v.second); } current_function_->loop_switch_infos.erase(sem_loop); if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) { return cf; } else { return cfx; } }, [&](const ast::ReturnStatement* r) { Node* cf_ret; if (r->value) { auto [cf1, v] = ProcessExpression(cf, r->value); current_function_->value_return->AddEdge(v); cf_ret = cf1; } else { TINT_ASSERT(Resolver, cf != nullptr); cf_ret = cf; } // Add edges from each pointer parameter output to its current value. for (auto param : current_function_->parameters) { if (param.pointer_return_value) { param.pointer_return_value->AddEdge( current_function_->variables.Get(param.sem)); } } return cf_ret; }, [&](const ast::SwitchStatement* s) { auto* sem_switch = sem_.Get(s); auto [cfx, v_cond] = ProcessExpression(cf, s->condition); // Add a diagnostic node to capture the control flow change. auto* v = current_function_->CreateNode("switch_stmt", s); v->affects_control_flow = true; v->AddEdge(v_cond); Node* cf_end = nullptr; if (sem_switch->Behaviors() != sem::Behaviors{sem::Behavior::kNext}) { cf_end = CreateNode("switch_CFend"); } auto& info = current_function_->loop_switch_infos[sem_switch]; info.type = "switch"; auto* cf_n = v; bool previous_case_has_fallthrough = false; for (auto* c : s->body) { auto* sem_case = sem_.Get(c); if (previous_case_has_fallthrough) { cf_n = ProcessStatement(cf_n, c->body); } else { current_function_->variables.Push(); cf_n = ProcessStatement(v, c->body); } if (cf_end) { cf_end->AddEdge(cf_n); } bool has_fallthrough = sem_case->Behaviors().Contains(sem::Behavior::kFallthrough); if (!has_fallthrough) { if (sem_case->Behaviors().Contains(sem::Behavior::kNext)) { // Propagate variable values to the switch exit nodes. for (auto* var : current_function_->local_var_decls) { // Skip variables that were declared inside the switch. if (auto* lv = var->As(); lv && lv->Statement()->FindFirstParent( [&](auto* st) { return st == sem_switch; })) { continue; } // Add an edge from the variable exit node to its new value. auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() { auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); return CreateNode(name + "_value_" + info.type + "_exit"); }); exit_node->AddEdge(current_function_->variables.Get(var)); } } current_function_->variables.Pop(); } previous_case_has_fallthrough = has_fallthrough; } // Update nodes for any variables assigned in the switch statement. for (auto var : info.var_exit_nodes) { current_function_->variables.Set(var.first, var.second); } return cf_end ? cf_end : cf; }, [&](const ast::VariableDeclStatement* decl) { Node* node; auto* sem_var = sem_.Get(decl->variable); if (decl->variable->initializer) { auto [cf1, v] = ProcessExpression(cf, decl->variable->initializer); cf = cf1; node = v; // Store if lhs is a partial pointer if (sem_var->Type()->Is()) { auto* init = sem_.Get(decl->variable->initializer); if (auto* unary_init = init->Declaration()->As()) { auto* e = UnwrapIndirectAndAddressOfChain(unary_init); if (e->IsAnyOf()) { current_function_->partial_ptrs.insert(sem_var); } } } } else { node = cf; } current_function_->variables.Set(sem_var, node); if (decl->variable->Is()) { current_function_->local_var_decls.insert( sem_.Get(decl->variable)); } return cf; }, [&](const ast::StaticAssert*) { return cf; // No impact on uniformity }, [&](Default) { TINT_ICE(Resolver, diagnostics_) << "unknown statement type: " << std::string(stmt->TypeInfo().name); return nullptr; }); } /// Process an identifier expression. /// @param cf the input control flow node /// @param ident the identifier expression to process /// @returns a pair of (control flow node, value node) std::pair ProcessIdentExpression(Node* cf, const ast::IdentifierExpression* ident) { // Helper to check if the entry point attribute of `obj` indicates non-uniformity. auto has_nonuniform_entry_point_attribute = [](auto* obj) { // Only the num_workgroups and workgroup_id builtins are uniform. if (auto* builtin = ast::GetAttribute(obj->attributes)) { if (builtin->builtin == ast::BuiltinValue::kNumWorkgroups || builtin->builtin == ast::BuiltinValue::kWorkgroupId) { return false; } } return true; }; auto name = builder_->Symbols().NameFor(ident->symbol); auto* sem = sem_.Get(ident)->UnwrapMaterialize()->As()->Variable(); auto* node = CreateNode(name + "_ident_expr", ident); return Switch( sem, [&](const sem::Parameter* param) { auto* user_func = param->Owner()->As(); if (user_func && user_func->Declaration()->IsEntryPoint()) { if (auto* str = param->Type()->As()) { // We consider the whole struct to be non-uniform if any one of its members // is non-uniform. bool uniform = true; for (auto* member : str->Members()) { if (has_nonuniform_entry_point_attribute(member->Declaration())) { uniform = false; } } node->AddEdge(uniform ? cf : current_function_->may_be_non_uniform); return std::make_pair(cf, node); } else { if (has_nonuniform_entry_point_attribute(param->Declaration())) { node->AddEdge(current_function_->may_be_non_uniform); } else { node->AddEdge(cf); } return std::make_pair(cf, node); } } else { auto* x = current_function_->variables.Get(param); node->AddEdge(cf); node->AddEdge(x); return std::make_pair(cf, node); } }, [&](const sem::GlobalVariable* global) { if (!global->Declaration()->Is() || global->Access() == ast::Access::kRead) { node->AddEdge(cf); } else { node->AddEdge(current_function_->may_be_non_uniform); } return std::make_pair(cf, node); }, [&](const sem::LocalVariable* local) { node->AddEdge(cf); if (auto* x = current_function_->variables.Get(local)) { node->AddEdge(x); } return std::make_pair(cf, node); }, [&](Default) { TINT_ICE(Resolver, diagnostics_) << "unknown identifier expression type: " << std::string(sem->TypeInfo().name); return std::pair(nullptr, nullptr); }); } /// Process an expression. /// @param cf the input control flow node /// @param expr the expression to process /// @returns a pair of (control flow node, value node) std::pair ProcessExpression(Node* cf, const ast::Expression* expr) { return Switch( expr, [&](const ast::BinaryExpression* b) { if (b->IsLogical()) { // Short-circuiting binary operators are a special case. auto [cf1, v1] = ProcessExpression(cf, b->lhs); // Add a diagnostic node to capture the control flow change. auto* v1_cf = current_function_->CreateNode("short_circuit_op", b); v1_cf->affects_control_flow = true; v1_cf->AddEdge(v1); auto [cf2, v2] = ProcessExpression(v1_cf, b->rhs); return std::pair(cf, v2); } else { auto [cf1, v1] = ProcessExpression(cf, b->lhs); auto [cf2, v2] = ProcessExpression(cf1, b->rhs); auto* result = CreateNode("binary_expr_result", b); result->AddEdge(v1); result->AddEdge(v2); return std::pair(cf2, result); } }, [&](const ast::BitcastExpression* b) { return ProcessExpression(cf, b->expr); }, [&](const ast::CallExpression* c) { return ProcessCall(cf, c); }, [&](const ast::IdentifierExpression* i) { return ProcessIdentExpression(cf, i); }, [&](const ast::IndexAccessorExpression* i) { auto [cf1, v1] = ProcessExpression(cf, i->object); auto [cf2, v2] = ProcessExpression(cf1, i->index); auto* result = CreateNode("index_accessor_result"); result->AddEdge(v1); result->AddEdge(v2); return std::pair(cf2, result); }, [&](const ast::LiteralExpression*) { return std::make_pair(cf, cf); }, [&](const ast::MemberAccessorExpression* m) { return ProcessExpression(cf, m->structure); }, [&](const ast::UnaryOpExpression* u) { if (u->op == ast::UnaryOp::kIndirection) { // Cut the analysis short, since we only need to know the originating variable // which is being accessed. auto* root_ident = sem_.Get(u)->RootIdentifier(); auto* value = current_function_->variables.Get(root_ident); if (!value) { value = cf; } return std::pair(cf, value); } return ProcessExpression(cf, u->expr); }, [&](Default) { TINT_ICE(Resolver, diagnostics_) << "unknown expression type: " << std::string(expr->TypeInfo().name); return std::pair(nullptr, nullptr); }); } /// @param u unary expression with op == kIndirection /// @returns true if `u` is an indirection unary expression that ultimately dereferences a /// partial pointer, false otherwise. bool IsDerefOfPartialPointer(const ast::UnaryOpExpression* u) { TINT_ASSERT(Resolver, u->op == ast::UnaryOp::kIndirection); // To determine if we're dereferencing a partial pointer, unwrap *& // chains; if the final expression is an identifier, see if it's a // partial pointer. If it's not an identifier, then it must be an // index/acessor expression, and thus a partial pointer. auto* e = UnwrapIndirectAndAddressOfChain(u); if (auto* var_user = sem_.Get(e)) { if (current_function_->partial_ptrs.count(var_user->Variable())) { return true; } } else { TINT_ASSERT( Resolver, (e->IsAnyOf())); return true; } return false; } /// Process an LValue expression. /// @param cf the input control flow node /// @param expr the expression to process /// @returns a pair of (control flow node, variable node) std::pair ProcessLValueExpression(Node* cf, const ast::Expression* expr, bool is_partial_reference = false) { return Switch( expr, [&](const ast::IdentifierExpression* i) { auto name = builder_->Symbols().NameFor(i->symbol); auto* sem = sem_.Get(i); if (sem->Variable()->Is()) { return std::make_pair(cf, current_function_->may_be_non_uniform); } else if (auto* local = sem->Variable()->As()) { // Create a new value node for this variable. auto* value = CreateNode(name + "_lvalue"); auto* old_value = current_function_->variables.Set(local, value); // If i is part of an expression that is a partial reference to a variable (e.g. // index or member access), we link back to the variable's previous value. If // the previous value was non-uniform, a partial assignment will not make it // uniform. if (is_partial_reference && old_value) { value->AddEdge(old_value); } return std::make_pair(cf, value); } else { TINT_ICE(Resolver, diagnostics_) << "unknown lvalue identifier expression type: " << std::string(sem->Variable()->TypeInfo().name); return std::pair(nullptr, nullptr); } }, [&](const ast::IndexAccessorExpression* i) { auto [cf1, l1] = ProcessLValueExpression(cf, i->object, /*is_partial_reference*/ true); auto [cf2, v2] = ProcessExpression(cf1, i->index); l1->AddEdge(v2); return std::pair(cf2, l1); }, [&](const ast::MemberAccessorExpression* m) { return ProcessLValueExpression(cf, m->structure, /*is_partial_reference*/ true); }, [&](const ast::UnaryOpExpression* u) { if (u->op == ast::UnaryOp::kIndirection) { // Cut the analysis short, since we only need to know the originating variable // that is being written to. auto* root_ident = sem_.Get(u)->RootIdentifier(); auto name = builder_->Symbols().NameFor(root_ident->Declaration()->symbol); auto* deref = CreateNode(name + "_deref"); auto* old_value = current_function_->variables.Set(root_ident, deref); if (old_value) { // If derefercing a partial reference or partial pointer, we link back to // the variable's previous value. If the previous value was non-uniform, a // partial assignment will not make it uniform. if (is_partial_reference || IsDerefOfPartialPointer(u)) { deref->AddEdge(old_value); } } return std::pair(cf, deref); } return ProcessLValueExpression(cf, u->expr, is_partial_reference); }, [&](Default) { TINT_ICE(Resolver, diagnostics_) << "unknown lvalue expression type: " << std::string(expr->TypeInfo().name); return std::pair(nullptr, nullptr); }); } /// Process a function call expression. /// @param cf the input control flow node /// @param call the function call to process /// @returns a pair of (control flow node, value node) std::pair ProcessCall(Node* cf, const ast::CallExpression* call) { std::string name; if (call->target.name) { name = builder_->Symbols().NameFor(call->target.name->symbol); } else { name = call->target.type->FriendlyName(builder_->Symbols()); } // Process call arguments Node* cf_last_arg = cf; std::vector args; for (size_t i = 0; i < call->args.Length(); i++) { auto [cf_i, arg_i] = ProcessExpression(cf_last_arg, call->args[i]); // Capture the index of this argument in a new node. // Note: This is an additional node that isn't described in the specification, for the // purpose of providing diagnostic information. Node* arg_node = CreateNode(name + "_arg_" + std::to_string(i), call); arg_node->type = Node::kFunctionCallArgument; arg_node->arg_index = static_cast(i); arg_node->AddEdge(arg_i); cf_last_arg = cf_i; args.push_back(arg_node); } // Note: This is an additional node that isn't described in the specification, for the // purpose of providing diagnostic information. Node* call_node = CreateNode(name + "_call", call); call_node->AddEdge(cf_last_arg); Node* result = CreateNode(name + "_return_value", call); result->type = Node::kFunctionCallReturnValue; Node* cf_after = CreateNode("CF_after_" + name, call); // Get tags for the callee. CallSiteTag callsite_tag = CallSiteNoRestriction; FunctionTag function_tag = NoRestriction; auto* sem = SemCall(call); const FunctionInfo* func_info = nullptr; Switch( sem->Target(), [&](const sem::Builtin* builtin) { // Most builtins have no restrictions. The exceptions are barriers, derivatives, and // some texture sampling builtins. if (builtin->IsBarrier()) { callsite_tag = CallSiteRequiredToBeUniform; } else if (builtin->IsDerivative() || builtin->Type() == sem::BuiltinType::kTextureSample || builtin->Type() == sem::BuiltinType::kTextureSampleBias || builtin->Type() == sem::BuiltinType::kTextureSampleCompare) { callsite_tag = CallSiteRequiredToBeUniform; function_tag = ReturnValueMayBeNonUniform; } else { callsite_tag = CallSiteNoRestriction; function_tag = NoRestriction; } }, [&](const sem::Function* func) { // We must have already analyzed the user-defined function since we process // functions in dependency order. TINT_ASSERT(Resolver, functions_.count(func->Declaration())); auto& info = functions_.at(func->Declaration()); callsite_tag = info.callsite_tag; function_tag = info.function_tag; func_info = &info; }, [&](const sem::TypeInitializer*) { callsite_tag = CallSiteNoRestriction; function_tag = NoRestriction; }, [&](const sem::TypeConversion*) { callsite_tag = CallSiteNoRestriction; function_tag = NoRestriction; }, [&](Default) { TINT_ICE(Resolver, diagnostics_) << "unhandled function call target: " << name; }); if (callsite_tag == CallSiteRequiredToBeUniform) { current_function_->required_to_be_uniform->AddEdge(call_node); } cf_after->AddEdge(call_node); if (function_tag == ReturnValueMayBeNonUniform) { result->AddEdge(current_function_->may_be_non_uniform); } result->AddEdge(cf_after); // For each argument, add edges based on parameter tags. for (size_t i = 0; i < args.size(); i++) { if (func_info) { switch (func_info->parameters[i].tag) { case ParameterRequiredToBeUniform: current_function_->required_to_be_uniform->AddEdge(args[i]); break; case ParameterRequiredToBeUniformForReturnValue: result->AddEdge(args[i]); break; case ParameterNoRestriction: break; } auto* sem_arg = sem_.Get(call->args[i]); if (sem_arg->Type()->Is()) { auto* ptr_result = CreateNode(name + "_ptrarg_" + std::to_string(i) + "_result", call); ptr_result->type = Node::kFunctionCallPointerArgumentResult; ptr_result->arg_index = static_cast(i); if (func_info->parameters[i].pointer_may_become_non_uniform) { ptr_result->AddEdge(current_function_->may_be_non_uniform); } else { // Add edge to the call to catch when it's called in non-uniform control // flow. ptr_result->AddEdge(call_node); // Add edges from the resulting pointer value to any other arguments that // feed it. for (auto* source : func_info->parameters[i].pointer_param_output_sources) { ptr_result->AddEdge(args[source->Index()]); } } // Update the current stored value for this pointer argument. auto* root_ident = sem_arg->RootIdentifier(); TINT_ASSERT(Resolver, root_ident); current_function_->variables.Set(root_ident, ptr_result); } } else { // All builtin function parameters are RequiredToBeUniformForReturnValue, as are // parameters for type initializers and type conversions. // The arrayLength() builtin is a special case, as there is currently no way for it // to have a non-uniform return value. auto* builtin = sem->Target()->As(); if (!builtin || builtin->Type() != sem::BuiltinType::kArrayLength) { result->AddEdge(args[i]); } } } return {cf_after, result}; } /// Traverse a graph starting at `source`, inserting all visited nodes into `reachable` and /// recording which node they were reached from. /// @param source the starting node /// @param reachable the set of reachable nodes to populate, if required void Traverse(Node* source, utils::UniqueVector* reachable = nullptr) { std::vector to_visit{source}; while (!to_visit.empty()) { auto* node = to_visit.back(); to_visit.pop_back(); if (reachable) { reachable->Add(node); } for (auto* to : node->edges) { if (to->visited_from == nullptr) { to->visited_from = node; to_visit.push_back(to); } } } } /// Trace back along a path from `start` until finding a node that matches a predicate. /// @param start the starting node /// @param pred the predicate function /// @returns the first node found that matches the predicate, or nullptr template Node* TraceBackAlongPathUntil(Node* start, F&& pred) { auto* current = start; while (current) { if (pred(current)) { break; } current = current->visited_from; } return current; } /// Recursively descend through the function called by `call` and the functions that it calls in /// order to find a call to a builtin function that requires uniformity. const ast::CallExpression* FindBuiltinThatRequiresUniformity(const ast::CallExpression* call) { auto* target = SemCall(call)->Target(); if (target->Is()) { // This is a call to a builtin, so we must be done. return call; } else if (auto* user = target->As()) { // This is a call to a user-defined function, so inspect the functions called by that // function and look for one whose node has an edge from the RequiredToBeUniform node. auto& target_info = functions_.at(user->Declaration()); for (auto* call_node : target_info.required_to_be_uniform->edges) { if (call_node->type == Node::kRegular) { auto* child_call = call_node->ast->As(); return FindBuiltinThatRequiresUniformity(child_call); } } TINT_ASSERT(Resolver, false && "unable to find child call with uniformity requirement"); } else { TINT_ASSERT(Resolver, false && "unexpected call expression type"); } return nullptr; } /// Add diagnostic notes to show where control flow became non-uniform on the way to a node. /// @param function the function being analyzed /// @param required_to_be_uniform the node to traverse from /// @param may_be_non_uniform the node to traverse to void ShowCauseOfNonUniformity(FunctionInfo& function, Node* required_to_be_uniform, Node* may_be_non_uniform) { // Traverse the graph to generate a path from the node to the source of non-uniformity. function.ResetVisited(); Traverse(required_to_be_uniform); // Get the source of the non-uniform value. auto* non_uniform_source = may_be_non_uniform->visited_from; TINT_ASSERT(Resolver, non_uniform_source); // Show where the non-uniform value results in non-uniform control flow. auto* control_flow = TraceBackAlongPathUntil( non_uniform_source, [](Node* node) { return node->affects_control_flow; }); if (control_flow) { diagnostics_.add_note(diag::System::Resolver, "control flow depends on non-uniform value", control_flow->ast->source); // TODO(jrprice): There are cases where the function with uniformity requirements is not // actually inside this control flow construct, for example: // - A conditional interrupt (e.g. break), with a barrier elsewhere in the loop // - A conditional assignment to a variable, which is later used to guard a barrier // In these cases, the diagnostics are not entirely accurate as they may not highlight // the actual cause of divergence. } auto get_var_type = [&](const sem::Variable* var) { switch (var->AddressSpace()) { case ast::AddressSpace::kStorage: return "read_write storage buffer "; case ast::AddressSpace::kWorkgroup: return "workgroup storage variable "; case ast::AddressSpace::kPrivate: return "module-scope private variable "; default: if (ast::HasAttribute(var->Declaration()->attributes)) { return "builtin "; } else if (ast::HasAttribute( var->Declaration()->attributes)) { return "user-defined input "; } else { // TODO(jrprice): Provide more info for this case. } break; } return ""; }; // Show the source of the non-uniform value. Switch( non_uniform_source->ast, [&](const ast::IdentifierExpression* ident) { auto* var = sem_.Get(ident)->Variable(); std::string var_type = get_var_type(var); diagnostics_.add_note(diag::System::Resolver, "reading from " + var_type + "'" + builder_->Symbols().NameFor(ident->symbol) + "' may result in a non-uniform value", ident->source); }, [&](const ast::Variable* v) { auto* var = sem_.Get(v); std::string var_type = get_var_type(var); diagnostics_.add_note(diag::System::Resolver, "reading from " + var_type + "'" + builder_->Symbols().NameFor(v->symbol) + "' may result in a non-uniform value", v->source); }, [&](const ast::CallExpression* c) { auto target_name = builder_->Symbols().NameFor( c->target.name->As()->symbol); switch (non_uniform_source->type) { case Node::kFunctionCallReturnValue: { diagnostics_.add_note( diag::System::Resolver, "return value of '" + target_name + "' may be non-uniform", c->source); break; } case Node::kFunctionCallPointerArgumentResult: { diagnostics_.add_note( diag::System::Resolver, "pointer contents may become non-uniform after calling '" + target_name + "'", c->args[non_uniform_source->arg_index]->source); break; } default: { TINT_ICE(Resolver, diagnostics_) << "unhandled source of non-uniformity"; break; } } }, [&](const ast::Expression* e) { diagnostics_.add_note(diag::System::Resolver, "result of expression may be non-uniform", e->source); }, [&](Default) { TINT_ICE(Resolver, diagnostics_) << "unhandled source of non-uniformity"; }); } /// Generate an error message for a uniformity issue. /// @param function the function that the diagnostic is being produced for /// @param source_node the node that has caused a uniformity issue in `function` /// @param note `true` if the diagnostic should be emitted as a note void MakeError(FunctionInfo& function, Node* source_node, bool note = false) { // Helper to produce a diagnostic message with the severity required by this invocation of // the `MakeError` function. auto report = [&](Source source, std::string msg) { diag::Diagnostic error{}; auto failureSeverity = kUniformityFailuresAsError ? diag::Severity::Error : diag::Severity::Warning; error.severity = note ? diag::Severity::Note : failureSeverity; error.system = diag::System::Resolver; error.source = source; error.message = msg; diagnostics_.add(std::move(error)); }; // Traverse the graph to generate a path from RequiredToBeUniform to the source node. function.ResetVisited(); Traverse(function.required_to_be_uniform); TINT_ASSERT(Resolver, source_node->visited_from); // Find a node that is required to be uniform that has a path to the source node. auto* cause = TraceBackAlongPathUntil(source_node, [&](Node* node) { return node->visited_from == function.required_to_be_uniform; }); // The node will always have a corresponding call expression. auto* call = cause->ast->As(); TINT_ASSERT(Resolver, call); auto* target = SemCall(call)->Target(); std::string func_name; if (auto* builtin = target->As()) { func_name = builtin->str(); } else if (auto* user = target->As()) { func_name = builder_->Symbols().NameFor(user->Declaration()->symbol); } if (cause->type == Node::kFunctionCallArgument) { // The requirement was on a function parameter. auto param_name = builder_->Symbols().NameFor( target->Parameters()[cause->arg_index]->Declaration()->symbol); report(call->args[cause->arg_index]->source, "parameter '" + param_name + "' of '" + func_name + "' must be uniform"); // If this is a call to a user-defined function, add a note to show the reason that the // parameter is required to be uniform. if (auto* user = target->As()) { auto& next_function = functions_.at(user->Declaration()); Node* next_cause = next_function.parameters[cause->arg_index].init_value; MakeError(next_function, next_cause, true); } } else { // The requirement was on a function callsite. report(call->source, "'" + func_name + "' must only be called from uniform control flow"); // If this is a call to a user-defined function, add a note to show the builtin that // causes the uniformity requirement. auto* innermost_call = FindBuiltinThatRequiresUniformity(call); if (innermost_call != call) { auto* sem_call = SemCall(call); auto* sem_innermost_call = SemCall(innermost_call); // Determine whether the builtin is being called directly or indirectly. bool indirect = false; if (sem_call->Target()->As() != sem_innermost_call->Stmt()->Function()) { indirect = true; } auto* builtin = sem_innermost_call->Target()->As(); diagnostics_.add_note(diag::System::Resolver, "'" + func_name + "' requires uniformity because it " + (indirect ? "indirectly " : "") + "calls " + builtin->str(), innermost_call->source); } } // Show the cause of non-uniformity (starting at the top-level error). if (!note) { ShowCauseOfNonUniformity(function, function.required_to_be_uniform, function.may_be_non_uniform); } } // Helper for obtaining the sem::Call node for the ast::CallExpression const sem::Call* SemCall(const ast::CallExpression* expr) const { return sem_.Get(expr)->UnwrapMaterialize()->As(); } }; } // namespace bool AnalyzeUniformity(ProgramBuilder* builder, const DependencyGraph& dependency_graph) { UniformityGraph graph(builder); return graph.Build(dependency_graph); } } // namespace tint::resolver