diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 8011e7446c..802e34a948 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -680,6 +680,9 @@ uint32_t Builder::GenerateConstructorExpression( } constructor_is_const = false; } + } + + for (const auto& e : init->values()) { uint32_t id = 0; if (constructor_is_const) { id = GenerateConstructorExpression(e->AsConstructor(), is_global_init); @@ -691,8 +694,32 @@ uint32_t Builder::GenerateConstructorExpression( return 0; } - out << "_" << id; - ops.push_back(Operand::Int(id)); + auto* result_type = e->result_type()->UnwrapPtrIfNeeded(); + + // If we're putting a vector into the constructed composite we need to + // extract each of the values and insert them individually + if (result_type->IsVector()) { + auto* vec = result_type->AsVector(); + auto result_type_id = GenerateTypeIfNeeded(vec->type()); + if (result_type_id == 0) { + return 0; + } + + for (uint32_t i = 0; i < vec->size(); ++i) { + auto extract = result_op(); + auto extract_id = extract.to_i(); + + push_function_inst(spv::Op::OpCompositeExtract, + {Operand::Int(result_type_id), extract, + Operand::Int(id), Operand::Int(i)}); + + out << "_" << extract_id; + ops.push_back(Operand::Int(extract_id)); + } + } else { + out << "_" << id; + ops.push_back(Operand::Int(id)); + } } auto str = out.str(); diff --git a/src/writer/spirv/builder_constructor_expression_test.cc b/src/writer/spirv/builder_constructor_expression_test.cc index efa3875a6d..b7af4c88f7 100644 --- a/src/writer/spirv/builder_constructor_expression_test.cc +++ b/src/writer/spirv/builder_constructor_expression_test.cc @@ -65,7 +65,11 @@ TEST_F(BuilderTest, Constructor_Type) { ast::TypeConstructorExpression t(&vec, std::move(vals)); + Context ctx; ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); + Builder b(&mod); EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 5u); ASSERT_FALSE(b.has_error()) << b.error(); @@ -120,6 +124,54 @@ TEST_F(BuilderTest, Constructor_Type_NonConstructorParam) { )"); } +TEST_F(BuilderTest, Constructor_Type_NonConstVector) { + ast::type::F32Type f32; + ast::type::VectorType vec2(&f32, 2); + ast::type::VectorType vec4(&f32, 4); + + auto var = std::make_unique( + "ident", ast::StorageClass::kFunction, &vec2); + + ast::ExpressionList vals; + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + vals.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + vals.push_back(std::make_unique("ident")); + + ast::TypeConstructorExpression t(&vec4, std::move(vals)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateFunctionVariable(var.get())) << b.error(); + + EXPECT_EQ(b.GenerateConstructorExpression(&t, false), 10u); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 2 +%2 = OpTypePointer Function %3 +%5 = OpTypeVector %4 4 +%6 = OpConstant %4 1 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].variables()), + R"(%1 = OpVariable %2 Function +)"); + + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%7 = OpLoad %3 %1 +%8 = OpCompositeExtract %4 %7 0 +%9 = OpCompositeExtract %4 %7 1 +%10 = OpCompositeConstruct %5 %6 %6 %8 %9 +)"); +} + TEST_F(BuilderTest, Constructor_Type_Dedups) { ast::type::F32Type f32; ast::type::VectorType vec(&f32, 3); @@ -134,7 +186,11 @@ TEST_F(BuilderTest, Constructor_Type_Dedups) { ast::TypeConstructorExpression t(&vec, std::move(vals)); + Context ctx; ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(&t)) << td.error(); + Builder b(&mod); EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 5u); EXPECT_EQ(b.GenerateConstructorExpression(&t, true), 5u); diff --git a/src/writer/spirv/builder_function_variable_test.cc b/src/writer/spirv/builder_function_variable_test.cc index 38e98bddaa..33ce0f133c 100644 --- a/src/writer/spirv/builder_function_variable_test.cc +++ b/src/writer/spirv/builder_function_variable_test.cc @@ -30,6 +30,8 @@ #include "src/ast/type_constructor_expression.h" #include "src/ast/variable.h" #include "src/ast/variable_decoration.h" +#include "src/context.h" +#include "src/type_determiner.h" #include "src/writer/spirv/builder.h" #include "src/writer/spirv/spv_dump.h" @@ -74,10 +76,16 @@ TEST_F(BuilderTest, FunctionVar_WithConstantConstructor) { auto init = std::make_unique(&vec, std::move(vals)); + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error(); + ast::Variable v("var", ast::StorageClass::kOutput, &f32); v.set_constructor(std::move(init)); - ast::Module mod; + td.RegisterVariableForTesting(&v); + Builder b(&mod); b.push_function(Function{}); EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error(); @@ -161,11 +169,17 @@ TEST_F(BuilderTest, FunctionVar_Const) { auto init = std::make_unique(&vec, std::move(vals)); + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error(); + ast::Variable v("var", ast::StorageClass::kOutput, &f32); v.set_constructor(std::move(init)); v.set_is_const(true); - ast::Module mod; + td.RegisterVariableForTesting(&v); + Builder b(&mod); EXPECT_TRUE(b.GenerateFunctionVariable(&v)) << b.error(); ASSERT_FALSE(b.has_error()) << b.error(); diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc index aad5670cea..8763d87151 100644 --- a/src/writer/spirv/builder_global_variable_test.cc +++ b/src/writer/spirv/builder_global_variable_test.cc @@ -29,6 +29,8 @@ #include "src/ast/type_constructor_expression.h" #include "src/ast/variable.h" #include "src/ast/variable_decoration.h" +#include "src/context.h" +#include "src/type_determiner.h" #include "src/writer/spirv/builder.h" #include "src/writer/spirv/spv_dump.h" @@ -84,10 +86,15 @@ TEST_F(BuilderTest, GlobalVar_WithConstructor) { auto init = std::make_unique(&vec, std::move(vals)); + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error(); + ast::Variable v("var", ast::StorageClass::kOutput, &f32); v.set_constructor(std::move(init)); + td.RegisterVariableForTesting(&v); - ast::Module mod; Builder b(&mod); EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); ASSERT_FALSE(b.has_error()) << b.error(); @@ -119,11 +126,16 @@ TEST_F(BuilderTest, GlobalVar_Const) { auto init = std::make_unique(&vec, std::move(vals)); + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error(); + ast::Variable v("var", ast::StorageClass::kOutput, &f32); v.set_constructor(std::move(init)); v.set_is_const(true); + td.RegisterVariableForTesting(&v); - ast::Module mod; Builder b(&mod); EXPECT_TRUE(b.GenerateGlobalVariable(&v)) << b.error(); ASSERT_FALSE(b.has_error()) << b.error(); diff --git a/src/writer/spirv/builder_ident_expression_test.cc b/src/writer/spirv/builder_ident_expression_test.cc index e69bb73cab..5fb39b59d7 100644 --- a/src/writer/spirv/builder_ident_expression_test.cc +++ b/src/writer/spirv/builder_ident_expression_test.cc @@ -52,13 +52,15 @@ TEST_F(BuilderTest, IdentifierExpression_GlobalConst) { auto init = std::make_unique(&vec, std::move(vals)); + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error(); + ast::Variable v("var", ast::StorageClass::kOutput, &f32); v.set_constructor(std::move(init)); v.set_is_const(true); - Context ctx; - ast::Module mod; - TypeDeterminer td(&ctx, &mod); td.RegisterVariableForTesting(&v); Builder b(&mod); @@ -117,13 +119,14 @@ TEST_F(BuilderTest, IdentifierExpression_FunctionConst) { auto init = std::make_unique(&vec, std::move(vals)); - ast::Variable v("var", ast::StorageClass::kOutput, &f32); - v.set_constructor(std::move(init)); - v.set_is_const(true); - Context ctx; ast::Module mod; TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(init.get())) << td.error(); + + ast::Variable v("var", ast::StorageClass::kOutput, &f32); + v.set_constructor(std::move(init)); + v.set_is_const(true); td.RegisterVariableForTesting(&v); Builder b(&mod); diff --git a/src/writer/spirv/builder_return_test.cc b/src/writer/spirv/builder_return_test.cc index e6ab54c848..48933a4a1c 100644 --- a/src/writer/spirv/builder_return_test.cc +++ b/src/writer/spirv/builder_return_test.cc @@ -21,6 +21,8 @@ #include "src/ast/type/f32_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" +#include "src/context.h" +#include "src/type_determiner.h" #include "src/writer/spirv/builder.h" #include "src/writer/spirv/spv_dump.h" @@ -61,7 +63,11 @@ TEST_F(BuilderTest, Return_WithValue) { ast::ReturnStatement ret(std::move(val)); + Context ctx; ast::Module mod; + TypeDeterminer td(&ctx, &mod); + EXPECT_TRUE(td.DetermineResultType(&ret)) << td.error(); + Builder b(&mod); b.push_function(Function{}); EXPECT_TRUE(b.GenerateReturnStatement(&ret));