diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index ad68e73c0f..8b89d80673 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -624,7 +624,7 @@ bool FunctionEmitter::EmitBody() { // TODO(dneto): register phis // TODO(dneto): register SSA values which need to be hoisted - RegisterValuesNeedingNamedDefinition(); + RegisterValuesNeedingNamedOrHoistedDefinition(); if (!EmitFunctionVariables()) { return false; @@ -2361,6 +2361,19 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info, // Only emit this part of the basic block once. return true; } + + // Emit declarations of hoisted variables. + for (auto id : block_info.hoisted_ids) { + const auto* def_inst = def_use_mgr_->GetDef(id); + assert(def_inst); + AddStatement( + std::make_unique(parser_impl_.MakeVariable( + id, ast::StorageClass::kFunction, + parser_impl_.ConvertType(def_inst->type_id())))); + // Save this as an already-named value. + identifier_values_.insert(id); + } + const spvtools::opt::BasicBlock& bb = *(block_info.basic_block); const auto* terminator = bb.terminator(); const auto* merge = bb.GetMergeInst(); // Might be nullptr @@ -2399,22 +2412,38 @@ bool FunctionEmitter::EmitConstDefinition( return success(); } +bool FunctionEmitter::EmitConstDefOrWriteToHoistedVar( + const spvtools::opt::Instruction& inst, + TypedExpression ast_expr) { + const auto result_id = inst.result_id(); + if (needs_hoisted_def_.count(result_id) != 0) { + // Emit an assignment of the expression to the hoisted variable. + AddStatement(std::make_unique( + std::make_unique(namer_.Name(result_id)), + std::move(ast_expr.expr))); + return true; + } + return EmitConstDefinition(inst, std::move(ast_expr)); +} + bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { - // Handle combinatorial instructions first. + const auto result_id = inst.result_id(); + // Handle combinatorial instructions. auto combinatorial_expr = MaybeEmitCombinatorialValue(inst); if (combinatorial_expr.expr != nullptr) { - if ((needs_named_const_def_.count(inst.result_id()) == 0) && + if ((needs_hoisted_def_.count(result_id) == 0) && + (needs_named_const_def_.count(result_id) == 0) && (def_use_mgr_->NumUses(&inst) == 1)) { // If it's used once, and doesn't need a named constant definition, // then defer emitting the expression until it's used. Any supporting // statements have already been emitted. singly_used_values_.insert( - std::make_pair(inst.result_id(), std::move(combinatorial_expr))); + std::make_pair(result_id, std::move(combinatorial_expr))); return success(); } // Otherwise, generate a const definition for it now and later use // the const's name at the uses of the value. - return EmitConstDefinition(inst, std::move(combinatorial_expr)); + return EmitConstDefOrWriteToHoistedVar(inst, std::move(combinatorial_expr)); } if (failed()) { return false; @@ -2435,13 +2464,13 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { case SpvOpLoad: // Memory accesses must be issued in SPIR-V program order. // So represent a load by a new const definition. - return EmitConstDefinition( + return EmitConstDefOrWriteToHoistedVar( inst, MakeExpression(inst.GetSingleWordInOperand(0))); case SpvOpCopyObject: // Arguably, OpCopyObject is purely combinatorial. On the other hand, // it exists to make a new name for something. So we choose to make // a new named constant definition. - return EmitConstDefinition( + return EmitConstDefOrWriteToHoistedVar( inst, MakeExpression(inst.GetSingleWordInOperand(0))); case SpvOpFunctionCall: // TODO(dneto): Fill this out. Make this pass, for existing tests @@ -2876,7 +2905,7 @@ TypedExpression FunctionEmitter::MakeVectorShuffle( result_type, std::move(values))}; } -void FunctionEmitter::RegisterValuesNeedingNamedDefinition() { +void FunctionEmitter::RegisterValuesNeedingNamedOrHoistedDefinition() { // Maps a result ID to the block position where it is last used. std::unordered_map id_to_last_use_pos; // List of pairs of (result id, block position of the definition). @@ -2930,12 +2959,32 @@ void FunctionEmitter::RegisterValuesNeedingNamedDefinition() { auto last_use_where = id_to_last_use_pos.find(id); if (last_use_where != id_to_last_use_pos.end()) { const auto last_use_pos = last_use_where->second; - const auto* def_in_construct = + const auto* const def_in_construct = GetBlockInfo(block_order_[def_pos])->construct; - const auto* last_use_in_construct = + const auto* const construct_with_last_use = GetBlockInfo(block_order_[last_use_pos])->construct; - if (def_in_construct != last_use_in_construct) { - needs_named_const_def_.insert(id); + + // Find the smallest structured construct that encloses the definition + // and all its uses. + const auto* enclosing_construct = def_in_construct; + while (enclosing_construct && + !enclosing_construct->ContainsPos(last_use_pos)) { + enclosing_construct = enclosing_construct->parent; + } + // At worst, we go all the way out to the function construct. + assert(enclosing_construct != nullptr); + + if (def_in_construct != construct_with_last_use) { + if (enclosing_construct == def_in_construct) { + // We can use a plain 'const' definition. + needs_named_const_def_.insert(id); + } else { + // We need to make a hoisted variable definition. + // TODO(dneto): Handle non-storable types, particularly pointers. + needs_hoisted_def_.insert(id); + auto* hoist_to_block = GetBlockInfo(enclosing_construct->begin_id); + hoist_to_block->hoisted_ids.push_back(id); + } } } } diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index 9a8ebada2b..484df30244 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -157,6 +157,11 @@ struct BlockInfo { /// This occurs when a block in this selection has both an if-break edge, and /// also a different normal forward edge but without a merge instruction. std::string flow_guard_name = ""; + + /// The result IDs that this block is responsible for declaring as a + /// hoisted variable. See the |needs_hoisted_def_| member of + /// FunctionEmitter for an explanation. + std::vector hoisted_ids; }; inline std::ostream& operator<<(std::ostream& o, const BlockInfo& bi) { @@ -278,11 +283,18 @@ class FunctionEmitter { bool FindIfSelectionInternalHeaders(); /// Record the SPIR-V IDs of non-constants that should get a 'const' - /// definition in WGSL. This occurs when a SPIR-V instruction might use the - /// dynamically computed value only once, but the WGSL code might reference - /// it multiple times. For example, this occurs for the vector operands of - /// OpVectorShuffle. Populates |needs_named_const_def_| - void RegisterValuesNeedingNamedDefinition(); + /// definition in WGSL, or a 'var' definition at an outer scope. + /// This occurs in several cases: + /// - When a SPIR-V instruction might use the dynamically computed value + /// only once, but the WGSL code might reference it multiple times. + /// For example, this occurs for the vector operands of OpVectorShuffle. + /// In this case the definition is added to |needs_named_const_def_|. + /// - When a definition and at least one of its uses are not in the + /// same structured construct. + /// In this case the definition is added to |needs_named_const_def_|. + /// - When a definition is in a construct that does not enclose all the + /// uses. In this case the definition is added to |needs_hoisted_def_|. + void RegisterValuesNeedingNamedOrHoistedDefinition(); /// Emits declarations of function variables. /// @returns false if emission failed. @@ -431,6 +443,15 @@ class FunctionEmitter { bool EmitConstDefinition(const spvtools::opt::Instruction& inst, TypedExpression ast_expr); + /// Emits a write to a hoisted variable for the given SPIR-V id, + /// if that ID has a hoisted declaration. Otherwise, emits a const + /// definition instead. + /// @param inst the SPIR-V instruction defining the value + /// @param ast_expr the already-computed AST expression for the value + /// @returns false if emission failed. + bool EmitConstDefOrWriteToHoistedVar(const spvtools::opt::Instruction& inst, + TypedExpression ast_expr); + /// Makes an expression /// @param id the SPIR-V ID of the value /// @returns true if emission has not yet failed. @@ -603,6 +624,19 @@ class FunctionEmitter { std::unordered_map singly_used_values_; // Set of SPIR-V IDs which should get a named const definition. std::unordered_set needs_named_const_def_; + // The SPIR-V IDs that must be declared in WGSL before the corresponding + // location in SPIR-V. This compensates for the difference between dominance + // and scoping. An SSA definition can dominate all its uses, but the construct + // where it is defined does not enclose all the uses, and so if it were + // declared as a WGSL constant definition at the point of its SPIR-V + // definition, then the WGSL name would go out of scope too early. Fix that by + // creating a variable at the top of the smallest construct that encloses both + // the definition and all its uses. Then the original SPIR-V definition maps + // to a WGSL assignment to that variable, and each SPIR-V use becomes a WGSL + // read from the variable. + // TODO(dneto): This works for constants of storable type, but not, for + // example, pointers. + std::unordered_set needs_hoisted_def_; // The IDs of basic blocks, in reverse structured post-order (RSPO). // This is the output order for the basic blocks. diff --git a/src/reader/spirv/function_var_test.cc b/src/reader/spirv/function_var_test.cc index ca8e217e34..85ac938d50 100644 --- a/src/reader/spirv/function_var_test.cc +++ b/src/reader/spirv/function_var_test.cc @@ -61,9 +61,13 @@ std::string Preamble() { %false = OpConstantFalse %bool %float_0 = OpConstant %float 0.0 %float_1p5 = OpConstant %float 1.5 + %uint_0 = OpConstant %uint 0 %uint_1 = OpConstant %uint 1 %int_m1 = OpConstant %int -1 %uint_2 = OpConstant %uint 2 + %uint_3 = OpConstant %uint 3 + %uint_4 = OpConstant %uint 4 + %uint_5 = OpConstant %uint 5 %v2float = OpTypeVector %float 2 %m3v2float = OpTypeMatrix %v2float 3 @@ -794,6 +798,137 @@ Return{} )")) << ToString(fe.ast_body()); } +TEST_F( + SpvParserTest, + EmitStatement_CombinatorialNonPointer_DefConstruct_DoesNotEncloseAllUses) { + // Compensate for the difference between dominance and scoping. + // Exercise hoisting of the constant definition to before its natural + // location. + // + // The definition of %2 should be hoisted + auto assembly = Preamble() + R"( + %pty = OpTypePointer Private %uint + %1 = OpVariable %pty Private + + %100 = OpFunction %void None %voidfn + + %3 = OpLabel + OpStore %1 %uint_0 + OpBranch %5 + + %5 = OpLabel + OpStore %1 %uint_1 + OpLoopMerge %99 %80 None + OpBranchConditional %false %99 %20 + + %20 = OpLabel + OpStore %1 %uint_3 + OpSelectionMerge %50 None + OpBranchConditional %true %30 %40 + + %30 = OpLabel + ; This combinatorial definition in nested control flow dominates + ; the use in the merge block in %50 + %2 = OpIAdd %uint %uint_1 %uint_1 + OpBranch %50 + + %40 = OpLabel + OpReturn + + %50 = OpLabel ; merge block for if-selection + OpStore %1 %2 + OpBranch %80 + + %80 = OpLabel ; merge block + OpStore %1 %uint_4 + OpBranchConditional %false %99 %5 ; loop backedge + + %99 = OpLabel + OpStore %1 %uint_5 + OpReturn + + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(Assignment{ + Identifier{x_1} + ScalarConstructor{0} +} +Loop{ + VariableDeclStatement{ + Variable{ + x_2 + function + __u32 + } + } + Assignment{ + Identifier{x_1} + ScalarConstructor{1} + } + If{ + ( + ScalarConstructor{false} + ) + { + Break{} + } + } + Assignment{ + Identifier{x_1} + ScalarConstructor{3} + } + If{ + ( + ScalarConstructor{true} + ) + { + Assignment{ + Identifier{x_2} + Binary{ + ScalarConstructor{1} + add + ScalarConstructor{1} + } + } + } + } + Else{ + { + Return{} + } + } + Assignment{ + Identifier{x_1} + Identifier{x_2} + } + continuing { + Assignment{ + Identifier{x_1} + ScalarConstructor{4} + } + If{ + ( + ScalarConstructor{false} + ) + { + Break{} + } + } + } +} +Assignment{ + Identifier{x_1} + ScalarConstructor{5} +} +Return{} +)")) << ToString(fe.ast_body()); +} + } // namespace } // namespace spirv } // namespace reader