diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 6bb1fdb950..8011e7446c 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -13,6 +13,7 @@ #include "src/writer/spirv/builder.h" +#include #include #include @@ -97,6 +98,28 @@ bool LastIsTerminator(const ast::StatementList& stmts) { last->IsKill() || last->IsFallthrough(); } +uint32_t IndexFromName(char name) { + switch (name) { + case 'x': + case 'r': + case 's': + return 0; + case 'y': + case 'g': + case 't': + return 1; + case 'z': + case 'b': + case 'p': + return 2; + case 'w': + case 'a': + case 'q': + return 3; + } + return std::numeric_limits::max(); +} + } // namespace Builder::Builder(ast::Module* mod) : mod_(mod), scope_stack_({}) {} @@ -502,7 +525,28 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { idx_list.insert(idx_list.begin(), Operand::Int(idx_id)); } else if (data_type->IsVector()) { - // TODO(dsinclair): Handle swizzle + auto swiz = mem_accessor->member()->name(); + if (swiz.size() == 1) { + // A single item swizzle is a simple access chain + auto val = IndexFromName(swiz[0]); + if (val == std::numeric_limits::max()) { + error_ = "invalid swizzle name: " + swiz; + return false; + } + + ast::type::U32Type u32; + ast::IntLiteral idx(&u32, val); + auto idx_id = GenerateLiteralIfNeeded(&idx); + if (idx_id == 0) { + return false; + } + idx_list.insert(idx_list.begin(), Operand::Int(idx_id)); + } else { + // A multi-item swizzle means we need to generate the access chain + // to the current point and then pull values out of it + // + // TODO(dsinclair): Handle multi-item swizzle + } } else { error_ = "invalid type for member accessor: " + data_type->type_name(); return 0; diff --git a/src/writer/spirv/builder_accessor_expression_test.cc b/src/writer/spirv/builder_accessor_expression_test.cc index f09f0cff21..2e1b85f155 100644 --- a/src/writer/spirv/builder_accessor_expression_test.cc +++ b/src/writer/spirv/builder_accessor_expression_test.cc @@ -472,7 +472,7 @@ OpStore %6 %11 )"); } -TEST_F(BuilderTest, DISABLED_MemberAccessor_Swizzle_Single) { +TEST_F(BuilderTest, MemberAccessor_Swizzle_Single) { ast::type::F32Type f32; ast::type::VectorType vec3(&f32, 3); @@ -492,20 +492,20 @@ TEST_F(BuilderTest, DISABLED_MemberAccessor_Swizzle_Single) { b.push_function(Function{}); ASSERT_TRUE(b.GenerateFunctionVariable(&var)) << b.error(); - EXPECT_EQ(b.GenerateAccessorExpression(&expr), 6u); + EXPECT_EQ(b.GenerateAccessorExpression(&expr), 5u); - EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32 -%2 = OpTypeVector %3 3 -%1 = OpTypePointer Function %2 -%5 = OpTypeInt 32 0 -%6 = OpConstant %5 1 -%7 = OpTypePointer Function %3 + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Function %3 +%6 = OpTypeInt 32 0 +%7 = OpConstant %6 1 +%8 = OpTypePointer Function %4 )"); EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), - R"(%1 = OpVariable %1 Function + R"(%1 = OpVariable %2 Function )"); EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%7 = OpAccessChain %7 %1 %6 + R"(%5 = OpAccessChain %8 %1 %7 )"); }