diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 05a9c9e4e3..75276f2cd8 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -45,6 +45,7 @@ #include "src/ast/sint_literal.h" #include "src/ast/storage_class.h" #include "src/ast/switch_statement.h" +#include "src/ast/type/u32_type.h" #include "src/ast/type_constructor_expression.h" #include "src/ast/uint_literal.h" #include "src/ast/unary_op.h" @@ -2400,6 +2401,10 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( ast_type, std::move(operands))}; } + if (opcode == SpvOpCompositeExtract) { + return MakeCompositeExtract(inst); + } + // builtin readonly function // glsl.std.450 readonly function @@ -2462,7 +2467,7 @@ TypedExpression FunctionEmitter::MakeAccessChain( // A SPIR-V access chain is a single instruction with multiple indices // walking down into composites. The Tint AST represents this as - // ever-deeper nested indexing expresions. Start off with an expression + // ever-deeper nested indexing expressions. Start off with an expression // for the base, and then bury that inside nested indexing expressions. TypedExpression current_expr(MakeOperand(inst, 0)); @@ -2574,6 +2579,113 @@ TypedExpression FunctionEmitter::MakeAccessChain( return current_expr; } +TypedExpression FunctionEmitter::MakeCompositeExtract( + const spvtools::opt::Instruction& inst) { + // This is structurally similar to creating an access chain, but + // the SPIR-V instruction has literal indices instead of IDs for indices. + + // A SPIR-V composite extract is a single instruction with multiple + // literal indices walking down into composites. The Tint AST represents + // this as ever-deeper nested indexing expressions. Start off with an + // expression for the composite, and then bury that inside nested indexing + // expressions. + TypedExpression current_expr(MakeOperand(inst, 0)); + + auto make_index = [](uint32_t literal) { + ast::type::U32Type u32; + return std::make_unique( + std::make_unique(&u32, literal)); + }; + static const char* swizzles[] = {"x", "y", "z", "w"}; + + const auto composite = inst.GetSingleWordInOperand(0); + const auto composite_type_id = def_use_mgr_->GetDef(composite)->type_id(); + const auto* current_type = type_mgr_->GetType(composite_type_id); + const auto num_in_operands = inst.NumInOperands(); + for (uint32_t index = 1; index < num_in_operands; ++index) { + const uint32_t index_val = inst.GetSingleWordInOperand(index); + std::unique_ptr next_expr; + switch (current_type->kind()) { + case spvtools::opt::analysis::Type::kVector: { + // Try generating a MemberAccessor expression. That result in something + // like "foo.z", which is more idiomatic than "foo[2]". + if (current_type->AsVector()->element_count() <= index_val) { + Fail() << "CompositeExtract %" << inst.result_id() << " index value " + << index_val << " is out of bounds for vector of " + << current_type->AsVector()->element_count() << " elements"; + return {}; + } + if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) { + Fail() << "internal error: swizzle index " << index_val + << " is too big. Max handled index is " + << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1); + } + auto letter_index = + std::make_unique(swizzles[index_val]); + next_expr = std::make_unique( + std::move(current_expr.expr), std::move(letter_index)); + current_type = current_type->AsVector()->element_type(); + break; + } + case spvtools::opt::analysis::Type::kMatrix: + // Check bounds + if (current_type->AsMatrix()->element_count() <= index_val) { + Fail() << "CompositeExtract %" << inst.result_id() << " index value " + << index_val << " is out of bounds for matrix of " + << current_type->AsMatrix()->element_count() << " elements"; + return {}; + } + if (index_val >= sizeof(swizzles) / sizeof(swizzles[0])) { + Fail() << "internal error: swizzle index " << index_val + << " is too big. Max handled index is " + << ((sizeof(swizzles) / sizeof(swizzles[0])) - 1); + } + // Use array syntax. + next_expr = std::make_unique( + std::move(current_expr.expr), make_index(index_val)); + current_type = current_type->AsMatrix()->element_type(); + break; + case spvtools::opt::analysis::Type::kArray: + // The array size could be a spec constant, and so it's not always + // statically checkable. Instead, rely on a runtime index clamp + // or runtime check to keep this safe. + next_expr = std::make_unique( + std::move(current_expr.expr), make_index(index_val)); + current_type = current_type->AsArray()->element_type(); + break; + case spvtools::opt::analysis::Type::kRuntimeArray: + Fail() << "can't do OpCompositeExtract on a runtime array"; + return {}; + case spvtools::opt::analysis::Type::kStruct: { + if (current_type->AsStruct()->element_types().size() <= index_val) { + Fail() << "CompositeExtract %" << inst.result_id() << " index value " + << index_val << " is out of bounds for structure %" + << type_mgr_->GetId(current_type) << " having " + << current_type->AsStruct()->element_types().size() + << " elements"; + return {}; + } + auto member_access = + std::make_unique(namer_.GetMemberName( + type_mgr_->GetId(current_type), uint32_t(index_val))); + + next_expr = std::make_unique( + std::move(current_expr.expr), std::move(member_access)); + current_type = current_type->AsStruct()->element_types()[index_val]; + break; + } + default: + Fail() << "CompositeExtract with bad type %" + << type_mgr_->GetId(current_type) << " " << current_type->str(); + return {}; + } + current_expr.reset(TypedExpression( + parser_impl_.ConvertType(type_mgr_->GetId(current_type)), + std::move(next_expr))); + } + return current_expr; +} + } // namespace spirv } // namespace reader } // namespace tint diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index 0e9290726f..686c92482b 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -433,6 +433,11 @@ class FunctionEmitter { /// @returns an AST expression for the instruction, or nullptr. TypedExpression EmitGlslStd450ExtInst(const spvtools::opt::Instruction& inst); + /// Creates an expression for OpCompositeExtract + /// @param inst an OpCompositeExtract instruction. + /// @returns an AST expression for the instruction, or nullptr. + TypedExpression MakeCompositeExtract(const spvtools::opt::Instruction& inst); + /// Gets the block info for a block ID, if any exists /// @param id the SPIR-V ID of the OpLabel instruction starting the block /// @returns the block info for the given ID, if it exists, or nullptr diff --git a/src/reader/spirv/function_composite_test.cc b/src/reader/spirv/function_composite_test.cc index 6030a90869..944db61de9 100644 --- a/src/reader/spirv/function_composite_test.cc +++ b/src/reader/spirv/function_composite_test.cc @@ -27,6 +27,7 @@ namespace reader { namespace spirv { namespace { +using ::testing::Eq; using ::testing::HasSubstr; std::string Preamble() { @@ -54,6 +55,7 @@ std::string Preamble() { %v2float = OpTypeVector %float 2 %m3v2float = OpTypeMatrix %v2float 3 + %m3v2float_0 = OpConstantNull %m3v2float %s_v2f_u_i = OpTypeStruct %v2float %uint %int %a_u_5 = OpTypeArray %uint %uint_5 @@ -229,6 +231,283 @@ TEST_F(SpvParserTest_Composite_Construct, Struct) { << ToString(fe.ast_body()); } +using SpvParserTest_CompositeExtract = SpvParserTest; + +TEST_F(SpvParserTest_CompositeExtract, Vector) { + const auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCompositeExtract %float %v2float_50_60 1 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( + Variable{ + x_1 + none + __f32 + { + MemberAccessor{ + TypeConstructor{ + __vec_2__f32 + ScalarConstructor{50.000000} + ScalarConstructor{60.000000} + } + Identifier{y} + } + } + })")) << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest_CompositeExtract, Vector_IndexTooBigError) { + const auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCompositeExtract %float %v2float_50_60 900 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), Eq("CompositeExtract %1 index value 900 is out of " + "bounds for vector of 2 elements")); +} + +TEST_F(SpvParserTest_CompositeExtract, Matrix) { + const auto assembly = Preamble() + R"( + %ptr = OpTypePointer Function %m3v2float + %var = OpVariable %ptr Function + + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpLoad %m3v2float %var + %2 = OpCompositeExtract %v2float %1 2 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( + Variable{ + x_2 + none + __vec_2__f32 + { + ArrayAccessor{ + Identifier{x_1} + ScalarConstructor{2} + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest_CompositeExtract, Matrix_IndexTooBigError) { + const auto assembly = Preamble() + R"( + %ptr = OpTypePointer Function %m3v2float + %var = OpVariable %ptr Function + + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpLoad %m3v2float %var + %2 = OpCompositeExtract %v2float %1 3 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()) << p->error(); + EXPECT_THAT(p->error(), Eq("CompositeExtract %2 index value 3 is out of " + "bounds for matrix of 3 elements")); +} + +TEST_F(SpvParserTest_CompositeExtract, Matrix_Vector) { + const auto assembly = Preamble() + R"( + %ptr = OpTypePointer Function %m3v2float + %var = OpVariable %ptr Function + + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpLoad %m3v2float %var + %2 = OpCompositeExtract %float %1 2 1 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( + Variable{ + x_2 + none + __f32 + { + MemberAccessor{ + ArrayAccessor{ + Identifier{x_1} + ScalarConstructor{2} + } + Identifier{y} + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest_CompositeExtract, Array) { + const auto assembly = Preamble() + R"( + %ptr = OpTypePointer Function %a_u_5 + %var = OpVariable %ptr Function + + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpLoad %a_u_5 %var + %2 = OpCompositeExtract %uint %1 3 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( + Variable{ + x_2 + none + __u32 + { + ArrayAccessor{ + Identifier{x_1} + ScalarConstructor{3} + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest_CompositeExtract, RuntimeArray_IsError) { + const auto assembly = Preamble() + R"( + %rtarr = OpTypeRuntimeArray %uint + %ptr = OpTypePointer Function %rtarr + %var = OpVariable %ptr Function + + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpLoad %rtarr %var + %2 = OpCompositeExtract %uint %1 3 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()) << p->error(); + EXPECT_THAT(p->error(), Eq("can't do OpCompositeExtract on a runtime array")); +} + +TEST_F(SpvParserTest_CompositeExtract, Struct) { + const auto assembly = Preamble() + R"( + %ptr = OpTypePointer Function %s_v2f_u_i + %var = OpVariable %ptr Function + + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpLoad %s_v2f_u_i %var + %2 = OpCompositeExtract %int %1 2 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( + Variable{ + x_2 + none + __i32 + { + MemberAccessor{ + Identifier{x_1} + Identifier{field2} + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvParserTest_CompositeExtract, Struct_IndexTooBigError) { + const auto assembly = Preamble() + R"( + %ptr = OpTypePointer Function %s_v2f_u_i + %var = OpVariable %ptr Function + + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpLoad %s_v2f_u_i %var + %2 = OpCompositeExtract %int %1 40 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), Eq("CompositeExtract %2 index value 40 is out of " + "bounds for structure %23 having 3 elements")); +} + +TEST_F(SpvParserTest_CompositeExtract, Struct_Array_Matrix_Vector) { + const auto assembly = Preamble() + R"( + %a_mat = OpTypeArray %m3v2float %uint_3 + %s = OpTypeStruct %uint %a_mat + %ptr = OpTypePointer Function %s + %var = OpVariable %ptr Function + + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpLoad %s %var + %2 = OpCompositeExtract %float %1 1 2 0 1 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( + Variable{ + x_2 + none + __f32 + { + MemberAccessor{ + ArrayAccessor{ + ArrayAccessor{ + MemberAccessor{ + Identifier{x_1} + Identifier{field1} + } + ScalarConstructor{2} + } + ScalarConstructor{0} + } + Identifier{y} + } + } + })")) + << ToString(fe.ast_body()); +} + } // namespace } // namespace spirv } // namespace reader