diff --git a/src/ast/type/array_type.cc b/src/ast/type/array_type.cc index c9eb6b6d61..15ad533a9e 100644 --- a/src/ast/type/array_type.cc +++ b/src/ast/type/array_type.cc @@ -20,7 +20,7 @@ namespace type { ArrayType::ArrayType(Type* subtype) : subtype_(subtype) {} -ArrayType::ArrayType(Type* subtype, size_t size) +ArrayType::ArrayType(Type* subtype, uint32_t size) : subtype_(subtype), size_(size) {} ArrayType::~ArrayType() = default; diff --git a/src/ast/type/array_type.h b/src/ast/type/array_type.h index 28d25e0cd6..22962dd148 100644 --- a/src/ast/type/array_type.h +++ b/src/ast/type/array_type.h @@ -34,7 +34,7 @@ class ArrayType : public Type { /// Constructor /// @param subtype the type of the array elements /// @param size the number of elements in the array - ArrayType(Type* subtype, size_t size); + ArrayType(Type* subtype, uint32_t size); /// Move constructor ArrayType(ArrayType&&) = default; ~ArrayType() override; @@ -48,9 +48,9 @@ class ArrayType : public Type { /// @returns the array type Type* type() const { return subtype_; } /// @returns the array size. Size is 0 for a runtime array - size_t size() const { return size_; } + uint32_t size() const { return size_; } - /// @returns the name for th type + /// @returns the name for the type std::string type_name() const override { assert(subtype_); @@ -63,7 +63,7 @@ class ArrayType : public Type { private: Type* subtype_ = nullptr; - size_t size_ = 0; + uint32_t size_ = 0; }; } // namespace type diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index bca5b92b4c..fca923203b 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -22,8 +22,10 @@ #include "src/ast/struct.h" #include "src/ast/struct_member.h" #include "src/ast/struct_member_offset_decoration.h" +#include "src/ast/type/array_type.h" #include "src/ast/type/matrix_type.h" #include "src/ast/type/struct_type.h" +#include "src/ast/type/u32_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/uint_literal.h" @@ -266,7 +268,11 @@ uint32_t Builder::GenerateTypeIfNeeded(ast::type::Type* type) { auto result = result_op(); auto id = result.to_i(); - if (type->IsBool()) { + if (type->IsArray()) { + if (!GenerateArrayType(type->AsArray(), result)) { + return 0; + } + } else if (type->IsBool()) { push_type(spv::Op::OpTypeBool, {result}); } else if (type->IsF32()) { push_type(spv::Op::OpTypeFloat, {result, Operand::Int(32)}); @@ -297,6 +303,29 @@ uint32_t Builder::GenerateTypeIfNeeded(ast::type::Type* type) { return id; } +bool Builder::GenerateArrayType(ast::type::ArrayType* ary, + const Operand& result) { + auto elem_type = GenerateTypeIfNeeded(ary->type()); + if (elem_type == 0) { + return false; + } + + if (ary->IsRuntimeArray()) { + push_type(spv::Op::OpTypeRuntimeArray, {result, Operand::Int(elem_type)}); + } else { + ast::type::U32Type u32; + ast::IntLiteral ary_size(&u32, ary->size()); + auto len_id = GenerateLiteralIfNeeded(&ary_size); + if (len_id == 0) { + return false; + } + + push_type(spv::Op::OpTypeArray, + {result, Operand::Int(elem_type), Operand::Int(len_id)}); + } + return true; +} + bool Builder::GenerateMatrixType(ast::type::MatrixType* mat, const Operand& result) { ast::type::VectorType col_type(mat->type(), mat->rows()); diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 3045fe2793..727fbab185 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -145,6 +145,11 @@ class Builder { /// @param type the type to create /// @returns the ID to use for the given type. Returns 0 on unknown type. uint32_t GenerateTypeIfNeeded(ast::type::Type* type); + /// Generates an array type declaration + /// @param ary the array to generate + /// @param result the result operand + /// @returns true if the array was successfully generated + bool GenerateArrayType(ast::type::ArrayType* ary, const Operand& result); /// Generates a matrix type declaration /// @param mat the matrix to generate /// @param result the result operand diff --git a/src/writer/spirv/builder_type_test.cc b/src/writer/spirv/builder_type_test.cc index 10f960d894..da753f7a72 100644 --- a/src/writer/spirv/builder_type_test.cc +++ b/src/writer/spirv/builder_type_test.cc @@ -69,6 +69,66 @@ TEST_F(BuilderTest_Type, ReturnsGeneratedAlias) { ASSERT_FALSE(b.has_error()) << b.error(); } +TEST_F(BuilderTest_Type, GenerateRuntimeArray) { + ast::type::I32Type i32; + ast::type::ArrayType ary(&i32); + + Builder b; + auto id = b.GenerateTypeIfNeeded(&ary); + ASSERT_FALSE(b.has_error()) << b.error(); + EXPECT_EQ(1, id); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%1 = OpTypeRuntimeArray %2 +)"); +} + +TEST_F(BuilderTest_Type, ReturnsGeneratedRuntimeArray) { + ast::type::I32Type i32; + ast::type::ArrayType ary(&i32); + + Builder b; + EXPECT_EQ(b.GenerateTypeIfNeeded(&ary), 1); + EXPECT_EQ(b.GenerateTypeIfNeeded(&ary), 1); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%1 = OpTypeRuntimeArray %2 +)"); +} + +TEST_F(BuilderTest_Type, GenerateArray) { + ast::type::I32Type i32; + ast::type::ArrayType ary(&i32, 4); + + Builder b; + auto id = b.GenerateTypeIfNeeded(&ary); + ASSERT_FALSE(b.has_error()) << b.error(); + EXPECT_EQ(1, id); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 4 +%1 = OpTypeArray %2 %4 +)"); +} + +TEST_F(BuilderTest_Type, ReturnsGeneratedArray) { + ast::type::I32Type i32; + ast::type::ArrayType ary(&i32, 4); + + Builder b; + EXPECT_EQ(b.GenerateTypeIfNeeded(&ary), 1); + EXPECT_EQ(b.GenerateTypeIfNeeded(&ary), 1); + ASSERT_FALSE(b.has_error()) << b.error(); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%3 = OpTypeInt 32 0 +%4 = OpConstant %3 4 +%1 = OpTypeArray %2 %4 +)"); +} + TEST_F(BuilderTest_Type, GenerateBool) { ast::type::BoolType bool_type;