diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 5d1e5c4b46..15286e85d0 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -172,6 +172,8 @@ namespace spirv { namespace { +constexpr uint32_t kMaxVectorLen = 4; + // Gets the AST unary opcode for the given SPIR-V opcode, if any // @param opcode SPIR-V opcode // @param ast_unary_op return parameter @@ -2874,6 +2876,16 @@ TypedExpression FunctionEmitter::EmitGlslStd450ExtInst( return {ast_type, call}; } +ast::IdentifierExpression* FunctionEmitter::Swizzle(uint32_t i) { + if (i >= kMaxVectorLen) { + Fail() << "vector component index is larger than " << kMaxVectorLen - 1 + << ": " << i; + return nullptr; + } + const char* names[] = {"x", "y", "z", "w"}; + return ast_module_.create(names[i & 3]); +} + TypedExpression FunctionEmitter::MakeAccessChain( const spvtools::opt::Instruction& inst) { if (inst.NumInOperands() < 1) { @@ -2888,7 +2900,6 @@ TypedExpression FunctionEmitter::MakeAccessChain( // for the base, and then bury that inside nested indexing expressions. TypedExpression current_expr(MakeOperand(inst, 0)); const auto constants = constant_mgr_->GetOperandConstants(&inst); - static const char* swizzles[] = {"x", "y", "z", "w"}; const auto base_id = inst.GetSingleWordInOperand(0); auto ptr_ty_id = def_use_mgr_->GetDef(base_id)->type_id(); @@ -2981,16 +2992,12 @@ TypedExpression FunctionEmitter::MakeAccessChain( << num_elems << " elements"; return {}; } - if (uint64_t(index_const_val) >= - sizeof(swizzles) / sizeof(swizzles[0])) { + if (uint64_t(index_const_val) >= kMaxVectorLen) { Fail() << "internal error: swizzle index " << index_const_val - << " is too big. Max handled index is " - << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1); + << " is too big. Max handled index is " << kMaxVectorLen - 1; } - auto* letter_index = - create(swizzles[index_const_val]); - next_expr = create(current_expr.expr, - letter_index); + next_expr = create( + current_expr.expr, Swizzle(uint32_t(index_const_val))); } else { // Non-constant index. Use array syntax next_expr = create( @@ -3072,7 +3079,6 @@ TypedExpression FunctionEmitter::MakeCompositeExtract( return create( create(&u32, literal)); }; - static const char* swizzles[] = {"x", "y", "z", "w"}; const auto composite = inst.GetSingleWordInOperand(0); auto current_type_id = def_use_mgr_->GetDef(composite)->type_id(); @@ -3102,15 +3108,12 @@ TypedExpression FunctionEmitter::MakeCompositeExtract( << " elements"; return {}; } - if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) { + if (index_val >= kMaxVectorLen) { Fail() << "internal error: swizzle index " << index_val - << " is too big. Max handled index is " - << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1); + << " is too big. Max handled index is " << kMaxVectorLen - 1; } - auto* letter_index = - create(swizzles[index_val]); next_expr = create(current_expr.expr, - letter_index); + Swizzle(index_val)); // All vector components are the same type. current_type_id = current_type_inst->GetSingleWordInOperand(0); break; @@ -3124,10 +3127,9 @@ TypedExpression FunctionEmitter::MakeCompositeExtract( << " elements"; return {}; } - if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) { + if (index_val >= kMaxVectorLen) { Fail() << "internal error: swizzle index " << index_val - << " is too big. Max handled index is " - << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1); + << " is too big. Max handled index is " << kMaxVectorLen - 1; } // Use array syntax. next_expr = create(current_expr.expr, @@ -3197,7 +3199,6 @@ TypedExpression FunctionEmitter::MakeVectorShuffle( type_mgr_->GetType(vec1.type_id())->AsVector()->element_count(); // Idiomatic vector accessors. - const char* swizzles[] = {"x", "y", "z", "w"}; // Generate an ast::TypeConstructor expression. // Assume the literal indices are valid, and there is a valid number of them. @@ -3207,16 +3208,13 @@ TypedExpression FunctionEmitter::MakeVectorShuffle( for (uint32_t i = 2; i < inst.NumInOperands(); ++i) { const auto index = inst.GetSingleWordInOperand(i); if (index < vec0_len) { - assert(index < sizeof(swizzles) / sizeof(swizzles[0])); values.emplace_back(create( - MakeExpression(vec0_id).expr, - create(swizzles[index]))); + MakeExpression(vec0_id).expr, Swizzle(index))); } else if (index < vec0_len + vec1_len) { const auto sub_index = index - vec0_len; - assert(sub_index < sizeof(swizzles) / sizeof(swizzles[0])); + assert(sub_index < kMaxVectorLen); values.emplace_back(create( - MakeExpression(vec1_id).expr, - create(swizzles[sub_index]))); + MakeExpression(vec1_id).expr, Swizzle(sub_index))); } else if (index == 0xFFFFFFFF) { // By rule, this maps to OpUndef. Instead, make it zero. values.emplace_back(parser_impl_.MakeNullValue(result_type->type())); diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index cfebe80b01..f5c3872550 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -672,6 +672,13 @@ class FunctionEmitter { /// @returns the associated loop construct, or nullptr const Construct* SiblingLoopConstruct(const Construct* c) const; + /// Returns an identifier expression for the swizzle name of the given + /// index into a vector. Emits an error and returns nullptr if the + /// index is out of range, i.e. 4 or higher. + /// @param i index of the subcomponent + /// @returns the identifier expression for the @p i'th component + ast::IdentifierExpression* Swizzle(uint32_t i); + private: /// @returns the store type for the OpVariable instruction, or /// null on failure. diff --git a/src/reader/spirv/function_misc_test.cc b/src/reader/spirv/function_misc_test.cc index f5d1bd94de..e215fdd7dd 100644 --- a/src/reader/spirv/function_misc_test.cc +++ b/src/reader/spirv/function_misc_test.cc @@ -16,6 +16,7 @@ #include #include "gmock/gmock.h" +#include "src/ast/identifier_expression.h" #include "src/reader/spirv/function.h" #include "src/reader/spirv/parser_impl.h" #include "src/reader/spirv/parser_impl_test_helper.h" @@ -295,6 +296,53 @@ TEST_F(SpvParserTestMiscInstruction, OpNop) { )")) << ToString(fe.ast_body()); } +// Test swizzle generation. + +struct SwizzleCase { + uint32_t index; + std::string expected_expr; + std::string expected_error; +}; +using SpvParserSwizzleTest = + SpvParserTestBase<::testing::TestWithParam>; + +TEST_P(SpvParserSwizzleTest, Sample) { + // We need a function so we can get a FunctionEmitter. + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + OpReturn + OpFunctionEnd +)"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + + auto* result = fe.Swizzle(GetParam().index); + if (GetParam().expected_error.empty()) { + EXPECT_TRUE(fe.success()); + ASSERT_NE(result, nullptr); + std::ostringstream ss; + result->to_str(ss, 0); + EXPECT_THAT(ss.str(), Eq(GetParam().expected_expr)); + } else { + EXPECT_EQ(result, nullptr); + EXPECT_FALSE(fe.success()); + EXPECT_THAT(p->error(), Eq(GetParam().expected_error)); + } +} + +INSTANTIATE_TEST_SUITE_P( + ValidIndex, + SpvParserSwizzleTest, + ::testing::ValuesIn(std::vector{ + {0, "Identifier[not set]{x}\n", ""}, + {1, "Identifier[not set]{y}\n", ""}, + {2, "Identifier[not set]{z}\n", ""}, + {3, "Identifier[not set]{w}\n", ""}, + {4, "", "vector component index is larger than 3: 4"}, + {99999, "", "vector component index is larger than 3: 99999"}})); + // TODO(dneto): OpSizeof : requires Kernel (OpenCL) } // namespace