diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 0635a80309..117dcc93ee 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -15,6 +15,7 @@ #include "src/reader/spirv/function.h" #include +#include #include #include #include @@ -49,6 +50,7 @@ #include "src/ast/switch_statement.h" #include "src/ast/type/bool_type.h" #include "src/ast/type/u32_type.h" +#include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" #include "src/ast/uint_literal.h" #include "src/ast/unary_op.h" @@ -615,6 +617,7 @@ bool FunctionEmitter::EmitBody() { // TODO(dneto): register phis // TODO(dneto): register SSA values which need to be hoisted + RegisterValuesNeedingNamedDefinition(); if (!EmitFunctionVariables()) { return false; @@ -2394,9 +2397,11 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) { // Handle combinatorial instructions first. auto combinatorial_expr = MaybeEmitCombinatorialValue(inst); if (combinatorial_expr.expr != nullptr) { - if (def_use_mgr_->NumUses(&inst) == 1) { - // If it's used once, then defer emitting the expression until it's - // used. Any supporting statements have already been emitted. + if ((needs_named_const_def_.count(inst.result_id()) == 0) && + (def_use_mgr_->NumUses(&inst) == 1)) { + // If it's used once, and doesn't need a named constant definition, + // then defer emitting the expression until it's used. Any supporting + // statements have already been emitted. singly_used_values_.insert( std::make_pair(inst.result_id(), std::move(combinatorial_expr))); return success(); @@ -2525,6 +2530,10 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( return MakeCompositeExtract(inst); } + if (opcode == SpvOpVectorShuffle) { + return MakeVectorShuffle(inst); + } + // builtin readonly function // glsl.std.450 readonly function @@ -2817,6 +2826,73 @@ std::unique_ptr FunctionEmitter::MakeFalse() const { std::make_unique(parser_impl_.BoolType(), false)); } +TypedExpression FunctionEmitter::MakeVectorShuffle( + const spvtools::opt::Instruction& inst) { + const auto vec0_id = inst.GetSingleWordInOperand(0); + const auto vec1_id = inst.GetSingleWordInOperand(1); + const spvtools::opt::Instruction& vec0 = *(def_use_mgr_->GetDef(vec0_id)); + const spvtools::opt::Instruction& vec1 = *(def_use_mgr_->GetDef(vec1_id)); + const auto vec0_len = + type_mgr_->GetType(vec0.type_id())->AsVector()->element_count(); + const auto vec1_len = + type_mgr_->GetType(vec1.type_id())->AsVector()->element_count(); + + // Idiomatic vector accessors. + const char* swizzles[] = {"x", "y", "z", "w"}; + + // Generate an ast::TypeConstructor expression. + // Assume the literal indices are valid, and there is a valid number of them. + ast::type::VectorType* result_type = + parser_impl_.ConvertType(inst.type_id())->AsVector(); + ast::ExpressionList values; + for (uint32_t i = 2; i < inst.NumInOperands(); ++i) { + const auto index = inst.GetSingleWordInOperand(i); + if (index < vec0_len) { + assert(index < sizeof(swizzles) / sizeof(swizzles[0])); + values.emplace_back(std::make_unique( + MakeExpression(vec0_id).expr, + std::make_unique(swizzles[index]))); + } else if (index < vec0_len + vec1_len) { + const auto sub_index = index - vec0_len; + assert(sub_index < sizeof(swizzles) / sizeof(swizzles[0])); + values.emplace_back(std::make_unique( + MakeExpression(vec1_id).expr, + std::make_unique(swizzles[sub_index]))); + } else if (index == 0xFFFFFFFF) { + // By rule, this maps to OpUndef. Instead, make it zero. + values.emplace_back(parser_impl_.MakeNullValue(result_type->type())); + } else { + Fail() << "invalid vectorshuffle ID %" << inst.result_id() + << ": index too large: " << index; + return {}; + } + } + return {result_type, std::make_unique( + result_type, std::move(values))}; +} + +void FunctionEmitter::RegisterValuesNeedingNamedDefinition() { + for (auto& block : function_) { + for (const auto& inst : block) { + if (inst.opcode() == SpvOpVectorShuffle) { + // We might access the vector operands multiple times. Make sure they + // are evaluated only once. + for (auto index : std::array{0, 1}) { + auto id = inst.GetSingleWordInOperand(index); + if (constant_mgr_->FindDeclaredConstant(id)) { + // If it's constant, then avoid making a const definition + // in the wrong place; it would be wrong if it didn't + // dominate its uses. + continue; + } + // Othewrise, register it. + needs_named_const_def_.insert(id); + } + } + } + } +} + } // namespace spirv } // namespace reader } // namespace tint diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index 169a1d1d54..3d16aa90ab 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -276,6 +276,13 @@ class FunctionEmitter { /// @returns false if bad nesting has been detected. bool FindIfSelectionInternalHeaders(); + /// Record the SPIR-V IDs of non-constants that should get a 'const' + /// definition in WGSL. This occurs when a SPIR-V instruction might use the + /// dynamically computed value only once, but the WGSL code might reference + /// it multiple times. For example, this occurs for the vector operands of + /// OpVectorShuffle. Populates |needs_named_const_def_| + void RegisterValuesNeedingNamedDefinition(); + /// Emits declarations of function variables. /// @returns false if emission failed. bool EmitFunctionVariables(); @@ -451,6 +458,11 @@ class FunctionEmitter { /// @returns an AST expression for the instruction, or nullptr. TypedExpression MakeCompositeExtract(const spvtools::opt::Instruction& inst); + /// Creates an expression for OpVectorShuffle + /// @param inst an OpVectorShuffle instruction. + /// @returns an AST expression for the instruction, or nullptr. + TypedExpression MakeVectorShuffle(const spvtools::opt::Instruction& inst); + /// Gets the block info for a block ID, if any exists /// @param id the SPIR-V ID of the OpLabel instruction starting the block /// @returns the block info for the given ID, if it exists, or nullptr @@ -583,6 +595,8 @@ class FunctionEmitter { std::unordered_set identifier_values_; // Mapping from SPIR-V ID that is used at most once, to its AST expression. std::unordered_map singly_used_values_; + // Set of SPIR-V IDs which should get a named const definition. + std::unordered_set needs_named_const_def_; // The IDs of basic blocks, in reverse structured post-order (RSPO). // This is the output order for the basic blocks. diff --git a/src/reader/spirv/function_composite_test.cc b/src/reader/spirv/function_composite_test.cc index af6d400668..8d5cb70177 100644 --- a/src/reader/spirv/function_composite_test.cc +++ b/src/reader/spirv/function_composite_test.cc @@ -51,6 +51,8 @@ std::string Preamble() { %float_70 = OpConstant %float 70 %v2uint = OpTypeVector %uint 2 + %v3uint = OpTypeVector %uint 3 + %v4uint = OpTypeVector %uint 4 %v2int = OpTypeVector %int 2 %v2float = OpTypeVector %float 2 @@ -60,6 +62,8 @@ std::string Preamble() { %s_v2f_u_i = OpTypeStruct %v2float %uint %int %a_u_5 = OpTypeArray %uint %uint_5 + %v2uint_3_4 = OpConstantComposite %v2uint %uint_3 %uint_4 + %v2uint_4_3 = OpConstantComposite %v2uint %uint_4 %uint_3 %v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60 %v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50 %v2float_70_70 = OpConstantComposite %v2float %float_70 %float_70 @@ -464,7 +468,7 @@ TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) { FunctionEmitter fe(p, *spirv_function(100)); EXPECT_FALSE(fe.EmitBody()); EXPECT_THAT(p->error(), Eq("CompositeExtract %2 index value 40 is out of " - "bounds for structure %23 having 3 elements")); + "bounds for structure %25 having 3 elements")); } TEST_F(SpvParserTest_CompositeExtract, Struct_Array_Matrix_Vector) { @@ -584,6 +588,160 @@ VariableDeclStatement{ })")) << ToString(fe.ast_body()); } +using SpvParserTest_VectorShuffle = SpvParserTest; + +TEST_F(SpvParserTest_VectorShuffle, FunctionScopeOperands_UseBoth) { + // Note that variables are generated for the vector operands. + const auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCopyObject %v2uint %v2uint_3_4 + %2 = OpIAdd %v2uint %v2uint_4_3 %v2uint_3_4 + %10 = OpVectorShuffle %v4uint %1 %2 3 2 1 0 + OpReturn + OpFunctionEnd +)"; + + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_10 + none + __vec_4__u32 + { + TypeConstructor{ + __vec_4__u32 + MemberAccessor{ + Identifier{x_2} + Identifier{y} + } + MemberAccessor{ + Identifier{x_2} + Identifier{x} + } + MemberAccessor{ + Identifier{x_1} + Identifier{y} + } + MemberAccessor{ + Identifier{x_1} + Identifier{x} + } + } + } + } +})")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest_VectorShuffle, ConstantOperands_UseBoth) { + const auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %10 = OpVectorShuffle %v4uint %v2uint_3_4 %v2uint_4_3 3 2 1 0 + OpReturn + OpFunctionEnd +)"; + + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_10 + none + __vec_4__u32 + { + TypeConstructor{ + __vec_4__u32 + MemberAccessor{ + TypeConstructor{ + __vec_2__u32 + ScalarConstructor{4} + ScalarConstructor{3} + } + Identifier{y} + } + MemberAccessor{ + TypeConstructor{ + __vec_2__u32 + ScalarConstructor{4} + ScalarConstructor{3} + } + Identifier{x} + } + MemberAccessor{ + TypeConstructor{ + __vec_2__u32 + ScalarConstructor{3} + ScalarConstructor{4} + } + Identifier{y} + } + MemberAccessor{ + TypeConstructor{ + __vec_2__u32 + ScalarConstructor{3} + ScalarConstructor{4} + } + Identifier{x} + } + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest_VectorShuffle, ConstantOperands_AllOnesMapToNull) { + const auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCopyObject %v2uint %v2uint_4_3 + %10 = OpVectorShuffle %v2uint %1 %1 0xFFFFFFFF 1 + OpReturn + OpFunctionEnd +)"; + + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_10 + none + __vec_2__u32 + { + TypeConstructor{ + __vec_2__u32 + ScalarConstructor{0} + MemberAccessor{ + Identifier{x_1} + Identifier{y} + } + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest_VectorShuffle, IndexTooBig_IsError) { + const auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %10 = OpVectorShuffle %v4uint %v2uint_3_4 %v2uint_4_3 9 2 1 0 + OpReturn + OpFunctionEnd +)"; + + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()) << p->error(); + EXPECT_THAT(p->error(), + Eq("invalid vectorshuffle ID %10: index too large: 9")); +} + } // namespace } // namespace spirv } // namespace reader