diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index fca923203b..5820bf2fc3 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -24,6 +24,7 @@ #include "src/ast/struct_member_offset_decoration.h" #include "src/ast/type/array_type.h" #include "src/ast/type/matrix_type.h" +#include "src/ast/type/pointer_type.h" #include "src/ast/type/struct_type.h" #include "src/ast/type/u32_type.h" #include "src/ast/type/vector_type.h" @@ -282,6 +283,10 @@ uint32_t Builder::GenerateTypeIfNeeded(ast::type::Type* type) { if (!GenerateMatrixType(type->AsMatrix(), result)) { return 0; } + } else if (type->IsPointer()) { + if (!GeneratePointerType(type->AsPointer(), result)) { + return 0; + } } else if (type->IsStruct()) { if (!GenerateStructType(type->AsStruct(), result)) { return 0; @@ -339,6 +344,24 @@ bool Builder::GenerateMatrixType(ast::type::MatrixType* mat, return true; } +bool Builder::GeneratePointerType(ast::type::PointerType* ptr, + const Operand& result) { + auto pointee_id = GenerateTypeIfNeeded(ptr->type()); + if (pointee_id == 0) { + return false; + } + + auto stg_class = ConvertStorageClass(ptr->storage_class()); + if (stg_class == SpvStorageClassMax) { + return false; + } + + push_type(spv::Op::OpTypePointer, + {result, Operand::Int(stg_class), Operand::Int(pointee_id)}); + + return true; +} + bool Builder::GenerateStructType(ast::type::StructType* struct_type, const Operand& result) { auto struct_id = result.to_i(); @@ -409,6 +432,34 @@ bool Builder::GenerateVectorType(ast::type::VectorType* vec, return true; } +SpvStorageClass Builder::ConvertStorageClass(ast::StorageClass klass) const { + switch (klass) { + case ast::StorageClass::kInput: + return SpvStorageClassInput; + case ast::StorageClass::kOutput: + return SpvStorageClassOutput; + case ast::StorageClass::kUniform: + return SpvStorageClassUniform; + case ast::StorageClass::kWorkgroup: + return SpvStorageClassWorkgroup; + case ast::StorageClass::kUniformConstant: + return SpvStorageClassUniformConstant; + case ast::StorageClass::kStorageBuffer: + return SpvStorageClassStorageBuffer; + case ast::StorageClass::kImage: + return SpvStorageClassImage; + case ast::StorageClass::kPushConstant: + return SpvStorageClassPushConstant; + case ast::StorageClass::kPrivate: + return SpvStorageClassPrivate; + case ast::StorageClass::kFunction: + return SpvStorageClassFunction; + case ast::StorageClass::kNone: + break; + } + return SpvStorageClassMax; +} + } // namespace spirv } // namespace writer } // namespace tint diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 727fbab185..8aefe564e5 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -20,6 +20,7 @@ #include #include +#include "spirv/unified1/spirv.h" #include "src/ast/literal.h" #include "src/ast/module.h" #include "src/ast/struct_member.h" @@ -122,6 +123,11 @@ class Builder { /// @returns the annotations const std::vector& annots() const { return annotations_; } + /// Converts a storage class to a SPIR-V storage class. + /// @param klass the storage class to convert + /// @returns the SPIR-V storage class or SpvStorageClassMax on error. + SpvStorageClass ConvertStorageClass(ast::StorageClass klass) const; + /// Generates an entry point instruction /// @param ep the entry point /// @returns true if the instruction was generated, false otherwise @@ -155,6 +161,11 @@ class Builder { /// @param result the result operand /// @returns true if the matrix was successfully generated bool GenerateMatrixType(ast::type::MatrixType* mat, const Operand& result); + /// Generates a pointer type declaration + /// @param ptr the pointer type to generate + /// @param result the result operand + /// @returns true if the pointer was successfully generated + bool GeneratePointerType(ast::type::PointerType* ptr, const Operand& result); /// Generates a vector type declaration /// @param struct_type the vector to generate /// @param result the result operand diff --git a/src/writer/spirv/builder_type_test.cc b/src/writer/spirv/builder_type_test.cc index da753f7a72..064ce920ab 100644 --- a/src/writer/spirv/builder_type_test.cc +++ b/src/writer/spirv/builder_type_test.cc @@ -236,6 +236,29 @@ TEST_F(BuilderTest_Type, ReturnsGeneratedMatrix) { ASSERT_FALSE(b.has_error()) << b.error(); } +TEST_F(BuilderTest_Type, GeneratePtr) { + ast::type::I32Type i32; + ast::type::PointerType ptr(&i32, ast::StorageClass::kOutput); + + Builder b; + auto id = b.GenerateTypeIfNeeded(&ptr); + ASSERT_FALSE(b.has_error()) << b.error(); + EXPECT_EQ(1, id); + + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1 +%1 = OpTypePointer Output %2 +)"); +} + +TEST_F(BuilderTest_Type, ReturnsGeneratedPtr) { + ast::type::I32Type i32; + ast::type::PointerType ptr(&i32, ast::StorageClass::kOutput); + + Builder b; + EXPECT_EQ(b.GenerateTypeIfNeeded(&ptr), 1); + EXPECT_EQ(b.GenerateTypeIfNeeded(&ptr), 1); +} + TEST_F(BuilderTest_Type, GenerateStruct_Empty) { auto s = std::make_unique(); ast::type::StructType s_type(std::move(s)); @@ -419,6 +442,39 @@ TEST_F(BuilderTest_Type, ReturnsGeneratedVoid) { ASSERT_FALSE(b.has_error()) << b.error(); } +struct PtrData { + ast::StorageClass ast_class; + SpvStorageClass result; +}; +inline std::ostream& operator<<(std::ostream& out, PtrData data) { + out << data.ast_class; + return out; +} +using PtrDataTest = testing::TestWithParam; +TEST_P(PtrDataTest, ConvertStorageClass) { + auto params = GetParam(); + + Builder b; + EXPECT_EQ(b.ConvertStorageClass(params.ast_class), params.result); +} +INSTANTIATE_TEST_SUITE_P( + BuilderTest_Type, + PtrDataTest, + testing::Values( + PtrData{ast::StorageClass::kNone, SpvStorageClassMax}, + PtrData{ast::StorageClass::kInput, SpvStorageClassInput}, + PtrData{ast::StorageClass::kOutput, SpvStorageClassOutput}, + PtrData{ast::StorageClass::kUniform, SpvStorageClassUniform}, + PtrData{ast::StorageClass::kWorkgroup, SpvStorageClassWorkgroup}, + PtrData{ast::StorageClass::kUniformConstant, + SpvStorageClassUniformConstant}, + PtrData{ast::StorageClass::kStorageBuffer, + SpvStorageClassStorageBuffer}, + PtrData{ast::StorageClass::kImage, SpvStorageClassImage}, + PtrData{ast::StorageClass::kPushConstant, SpvStorageClassPushConstant}, + PtrData{ast::StorageClass::kPrivate, SpvStorageClassPrivate}, + PtrData{ast::StorageClass::kFunction, SpvStorageClassFunction})); + } // namespace } // namespace spirv } // namespace writer