diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 840836bbb6..ad68e73c0f 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -2877,8 +2877,30 @@ TypedExpression FunctionEmitter::MakeVectorShuffle( } void FunctionEmitter::RegisterValuesNeedingNamedDefinition() { - for (auto& block : function_) { - for (const auto& inst : block) { + // Maps a result ID to the block position where it is last used. + std::unordered_map<uint32_t, uint32_t> id_to_last_use_pos; + // List of pairs of (result id, block position of the definition). + std::vector<std::pair<uint32_t, uint32_t>> id_def_pos; + + for (auto block_id : block_order_) { + const auto* block_info = GetBlockInfo(block_id); + const auto block_pos = block_info->pos; + + for (const auto& inst : *(block_info->basic_block)) { + const auto result_id = inst.result_id(); + if (result_id != 0) { + id_def_pos.emplace_back( + std::pair<uint32_t, uint32_t>{result_id, block_pos}); + } + inst.ForEachInId( + [&id_to_last_use_pos, block_pos](const uint32_t* id_ptr) { + // If the id is not in the map already, this will create + // an entry with value 0. + auto& pos = id_to_last_use_pos[*id_ptr]; + // Update the entry. + pos = std::max(pos, block_pos); + }); + if (inst.opcode() == SpvOpVectorShuffle) { // We might access the vector operands multiple times. Make sure they // are evaluated only once. @@ -2896,6 +2918,27 @@ void FunctionEmitter::RegisterValuesNeedingNamedDefinition() { } } } + + // For an ID defined in this function, if it is used in a different construct + // than its definition, then it needs a named constant definition. Otherwise + // we might sink an expensive computation into control flow, and hence change + // performance. + for (const auto& id_and_pos : id_def_pos) { + const auto id = id_and_pos.first; + const auto def_pos = id_and_pos.second; + + auto last_use_where = id_to_last_use_pos.find(id); + if (last_use_where != id_to_last_use_pos.end()) { + const auto last_use_pos = last_use_where->second; + const auto* def_in_construct = + GetBlockInfo(block_order_[def_pos])->construct; + const auto* last_use_in_construct = + GetBlockInfo(block_order_[last_use_pos])->construct; + if (def_in_construct != last_use_in_construct) { + needs_named_const_def_.insert(id); + } + } + } } TypedExpression FunctionEmitter::MakeNumericConversion( diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index fbab63bad5..9a8ebada2b 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -112,7 +112,8 @@ struct BlockInfo { /// as its own continue target, and has branch to itself. bool is_single_block_loop = false; - /// The immediately enclosing structured construct. + /// The immediately enclosing structured construct. If this block is not + /// in the block order at all, then this is still nullptr. const Construct* construct = nullptr; /// Maps the ID of a successor block (in the CFG) to its edge classification. diff --git a/src/reader/spirv/function_var_test.cc b/src/reader/spirv/function_var_test.cc index dd9f283ae7..ca8e217e34 100644 --- a/src/reader/spirv/function_var_test.cc +++ b/src/reader/spirv/function_var_test.cc @@ -26,6 +26,7 @@ namespace reader { namespace spirv { namespace { +using ::testing::Eq; using ::testing::HasSubstr; /// @returns a SPIR-V assembly segment which assigns debug names @@ -38,8 +39,11 @@ std::string Names(std::vector<std::string> ids) { return outs.str(); } -std::string CommonTypes() { +std::string Preamble() { return R"( + OpCapability Shader + OpMemoryModel Logical Simple + %void = OpTypeVoid %voidfn = OpTypeFunction %void @@ -70,7 +74,7 @@ std::string CommonTypes() { } TEST_F(SpvParserTest, EmitFunctionVariables_AnonymousVars) { - auto* p = parser(test::Assemble(CommonTypes() + R"( + auto* p = parser(test::Assemble(Preamble() + R"( %100 = OpFunction %void None %voidfn %entry = OpLabel %1 = OpVariable %ptr_uint Function @@ -108,7 +112,7 @@ VariableDeclStatement{ } TEST_F(SpvParserTest, EmitFunctionVariables_NamedVars) { - auto* p = parser(test::Assemble(Names({"a", "b", "c"}) + CommonTypes() + R"( + auto* p = parser(test::Assemble(Names({"a", "b", "c"}) + Preamble() + R"( %100 = OpFunction %void None %voidfn %entry = OpLabel %a = OpVariable %ptr_uint Function @@ -146,7 +150,7 @@ VariableDeclStatement{ } TEST_F(SpvParserTest, EmitFunctionVariables_MixedTypes) { - auto* p = parser(test::Assemble(Names({"a", "b", "c"}) + CommonTypes() + R"( + auto* p = parser(test::Assemble(Names({"a", "b", "c"}) + Preamble() + R"( %100 = OpFunction %void None %voidfn %entry = OpLabel %a = OpVariable %ptr_uint Function @@ -184,8 +188,8 @@ VariableDeclStatement{ } TEST_F(SpvParserTest, EmitFunctionVariables_ScalarInitializers) { - auto* p = parser( - test::Assemble(Names({"a", "b", "c", "d", "e"}) + CommonTypes() + R"( + auto* p = + parser(test::Assemble(Names({"a", "b", "c", "d", "e"}) + Preamble() + R"( %100 = OpFunction %void None %voidfn %entry = OpLabel %a = OpVariable %ptr_bool Function %true @@ -254,8 +258,7 @@ VariableDeclStatement{ } TEST_F(SpvParserTest, EmitFunctionVariables_ScalarNullInitializers) { - auto* p = - parser(test::Assemble(Names({"a", "b", "c", "d"}) + CommonTypes() + R"( + auto* p = parser(test::Assemble(Names({"a", "b", "c", "d"}) + Preamble() + R"( %null_bool = OpConstantNull %bool %null_int = OpConstantNull %int %null_uint = OpConstantNull %uint @@ -318,7 +321,7 @@ VariableDeclStatement{ } TEST_F(SpvParserTest, EmitFunctionVariables_VectorInitializer) { - auto* p = parser(test::Assemble(CommonTypes() + R"( + auto* p = parser(test::Assemble(Preamble() + R"( %ptr = OpTypePointer Function %v2float %two = OpConstant %float 2.0 %const = OpConstantComposite %v2float %float_1p5 %two @@ -351,7 +354,7 @@ TEST_F(SpvParserTest, EmitFunctionVariables_VectorInitializer) { } TEST_F(SpvParserTest, EmitFunctionVariables_MatrixInitializer) { - auto* p = parser(test::Assemble(CommonTypes() + R"( + auto* p = parser(test::Assemble(Preamble() + R"( %ptr = OpTypePointer Function %m3v2float %two = OpConstant %float 2.0 %three = OpConstant %float 3.0 @@ -402,7 +405,7 @@ TEST_F(SpvParserTest, EmitFunctionVariables_MatrixInitializer) { } TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer) { - auto* p = parser(test::Assemble(CommonTypes() + R"( + auto* p = parser(test::Assemble(Preamble() + R"( %ptr = OpTypePointer Function %arr2uint %two = OpConstant %uint 2 %const = OpConstantComposite %arr2uint %uint_1 %two @@ -436,7 +439,7 @@ TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer) { TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_AliasType) { auto* p = parser(test::Assemble( - std::string("OpDecorate %arr2uint ArrayStride 16\n") + CommonTypes() + R"( + std::string("OpDecorate %arr2uint ArrayStride 16\n") + Preamble() + R"( %ptr = OpTypePointer Function %arr2uint %two = OpConstant %uint 2 %const = OpConstantComposite %arr2uint %uint_1 %two @@ -469,7 +472,7 @@ TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_AliasType) { } TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_Null) { - auto* p = parser(test::Assemble(CommonTypes() + R"( + auto* p = parser(test::Assemble(Preamble() + R"( %ptr = OpTypePointer Function %arr2uint %two = OpConstant %uint 2 %const = OpConstantNull %arr2uint @@ -503,7 +506,7 @@ TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_Null) { TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_AliasType_Null) { auto* p = parser(test::Assemble( - std::string("OpDecorate %arr2uint ArrayStride 16\n") + CommonTypes() + R"( + std::string("OpDecorate %arr2uint ArrayStride 16\n") + Preamble() + R"( %ptr = OpTypePointer Function %arr2uint %two = OpConstant %uint 2 %const = OpConstantNull %arr2uint @@ -536,7 +539,7 @@ TEST_F(SpvParserTest, EmitFunctionVariables_ArrayInitializer_AliasType_Null) { } TEST_F(SpvParserTest, EmitFunctionVariables_StructInitializer) { - auto* p = parser(test::Assemble(CommonTypes() + R"( + auto* p = parser(test::Assemble(Preamble() + R"( %ptr = OpTypePointer Function %strct %two = OpConstant %uint 2 %arrconst = OpConstantComposite %arr2uint %uint_1 %two @@ -575,7 +578,7 @@ TEST_F(SpvParserTest, EmitFunctionVariables_StructInitializer) { } TEST_F(SpvParserTest, EmitFunctionVariables_StructInitializer_Null) { - auto* p = parser(test::Assemble(CommonTypes() + R"( + auto* p = parser(test::Assemble(Preamble() + R"( %ptr = OpTypePointer Function %strct %two = OpConstant %uint 2 %arrconst = OpConstantComposite %arr2uint %uint_1 %two @@ -613,6 +616,184 @@ TEST_F(SpvParserTest, EmitFunctionVariables_StructInitializer_Null) { )")) << ToString(fe.ast_body()); } +TEST_F(SpvParserTest, + EmitStatement_CombinatorialValue_Defer_UsedOnceSameConstruct) { + auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + %25 = OpVariable %ptr_uint Function + %2 = OpIAdd %uint %uint_1 %uint_1 + OpStore %25 %uint_1 ; Do initial store to mark source location + OpBranch %20 + + %20 = OpLabel + OpStore %25 %2 ; defer emission of the addition until here. + OpReturn + + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{ + Variable{ + x_25 + function + __u32 + } +} +Assignment{ + Identifier{x_25} + ScalarConstructor{1} +} +Assignment{ + Identifier{x_25} + Binary{ + ScalarConstructor{1} + add + ScalarConstructor{1} + } +} +Return{} +)")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest, EmitStatement_CombinatorialValue_Immediate_UsedTwice) { + auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + %25 = OpVariable %ptr_uint Function + %2 = OpIAdd %uint %uint_1 %uint_1 + OpStore %25 %uint_1 ; Do initial store to mark source location + OpBranch %20 + + %20 = OpLabel + OpStore %25 %2 + OpStore %25 %2 + OpReturn + + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{ + Variable{ + x_25 + function + __u32 + } +} +VariableDeclStatement{ + Variable{ + x_2 + none + __u32 + { + Binary{ + ScalarConstructor{1} + add + ScalarConstructor{1} + } + } + } +} +Assignment{ + Identifier{x_25} + ScalarConstructor{1} +} +Assignment{ + Identifier{x_25} + Identifier{x_2} +} +Assignment{ + Identifier{x_25} + Identifier{x_2} +} +Return{} +)")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest, + EmitStatement_CombinatorialValue_Immediate_UsedOnceDifferentConstruct) { + // Translation should not sink expensive operations into or out of control + // flow. As a simple heuristic, don't move *any* combinatorial operation + // across any constrol flow. + auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + + %10 = OpLabel + %25 = OpVariable %ptr_uint Function + %2 = OpIAdd %uint %uint_1 %uint_1 + OpStore %25 %uint_1 ; Do initial store to mark source location + OpBranch %20 + + %20 = OpLabel ; Introduce a new construct + OpLoopMerge %99 %80 None + OpBranch %80 + + %80 = OpLabel + OpStore %25 %2 ; store combinatorial value %2, inside the loop + OpBranch %20 + + %99 = OpLabel ; merge block + OpStore %25 %uint_2 + OpReturn + + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + + EXPECT_THAT(ToString(fe.ast_body()), Eq(R"(VariableDeclStatement{ + Variable{ + x_25 + function + __u32 + } +} +VariableDeclStatement{ + Variable{ + x_2 + none + __u32 + { + Binary{ + ScalarConstructor{1} + add + ScalarConstructor{1} + } + } + } +} +Assignment{ + Identifier{x_25} + ScalarConstructor{1} +} +Loop{ + continuing { + Assignment{ + Identifier{x_25} + Identifier{x_2} + } + } +} +Assignment{ + Identifier{x_25} + ScalarConstructor{2} +} +Return{} +)")) << ToString(fe.ast_body()); +} + } // namespace } // namespace spirv } // namespace reader