diff --git a/src/ast/type/access_control_type.h b/src/ast/type/access_control_type.h index 9bd04c50cb..cf802b6cfa 100644 --- a/src/ast/type/access_control_type.h +++ b/src/ast/type/access_control_type.h @@ -44,6 +44,8 @@ class AccessControlType : public Type { /// @returns true if the access control is read/write bool IsReadWrite() const { return access_ == AccessControl::kReadWrite; } + /// @returns the access control value + AccessControl access_control() const { return access_; } /// @returns the subtype type Type* type() const { return subtype_; } diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 6f4184c056..1d7ad02224 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -49,6 +49,7 @@ #include "src/ast/struct_member.h" #include "src/ast/struct_member_offset_decoration.h" #include "src/ast/switch_statement.h" +#include "src/ast/type/access_control_type.h" #include "src/ast/type/array_type.h" #include "src/ast/type/depth_texture_type.h" #include "src/ast/type/f32_type.h" @@ -695,11 +696,13 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { push_debug(spv::Op::OpName, {Operand::Int(var_id), Operand::String(var->name())}); + auto* type = var->type()->UnwrapAll(); + OperandList ops = {Operand::Int(type_id), result, Operand::Int(ConvertStorageClass(sc))}; if (var->has_constructor()) { ops.push_back(Operand::Int(init_id)); - } else if (!var->type()->IsTexture() && !var->type()->IsSampler()) { + } else if (!type->IsTexture() && !type->IsSampler()) { // Certain cases require us to generate a constructor value. // // 1- ConstantId's must be attached to the OpConstant, if we have a @@ -707,7 +710,6 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { // one // 2- If we don't have a constructor and we're an Output or Private variable // then WGSL requires an initializer. - auto* type = var->type()->UnwrapPtrIfNeeded(); if (var->IsDecorated() && var->AsDecorated()->HasConstantIdDecoration()) { if (type->IsF32()) { ast::FloatLiteral l(type, 0.0f); @@ -2251,6 +2253,7 @@ uint32_t Builder::GenerateTypeIfNeeded(ast::type::Type* type) { return 0; } + // The alias is a wrapper around the subtype, so emit the subtype if (type->IsAlias()) { return GenerateTypeIfNeeded(type->AsAlias()->type()); } @@ -2263,7 +2266,18 @@ uint32_t Builder::GenerateTypeIfNeeded(ast::type::Type* type) { auto result = result_op(); auto id = result.to_i(); - if (type->IsArray()) { + if (type->IsAccessControl()) { + auto* ac = type->AsAccessControl(); + auto* subtype = ac->type()->UnwrapIfNeeded(); + if (!subtype->IsStruct()) { + error_ = "Access control attached to non-struct type."; + return 0; + } + if (!GenerateStructType(subtype->AsStruct(), ac->access_control(), + result)) { + return 0; + } + } else if (type->IsArray()) { if (!GenerateArrayType(type->AsArray(), result)) { return 0; } @@ -2282,7 +2296,8 @@ uint32_t Builder::GenerateTypeIfNeeded(ast::type::Type* type) { return 0; } } else if (type->IsStruct()) { - if (!GenerateStructType(type->AsStruct(), result)) { + if (!GenerateStructType(type->AsStruct(), ast::AccessControl::kReadWrite, + result)) { return 0; } } else if (type->IsU32()) { @@ -2449,6 +2464,7 @@ bool Builder::GeneratePointerType(ast::type::PointerType* ptr, } bool Builder::GenerateStructType(ast::type::StructType* struct_type, + ast::AccessControl access_control, const Operand& result) { auto struct_id = result.to_i(); auto* impl = struct_type->impl(); @@ -2473,6 +2489,19 @@ bool Builder::GenerateStructType(ast::type::StructType* struct_type, return false; } + // We're attaching the access control to the members of the struct instead + // of to the variable. The reason we do this is that WGSL models the access + // as part of the type. If we attach to the variable, it's no longer part + // of the type in the SPIR-V backend, but part of the variable. This differs + // from the modeling and other backends. Attaching to the struct members + // means the access control stays part of the type where it logically makes + // the most sense. + if (access_control == ast::AccessControl::kReadOnly) { + push_annot(spv::Op::OpMemberDecorate, + {Operand::Int(struct_id), Operand::Int(i), + Operand::Int(SpvDecorationNonWritable)}); + } + ops.push_back(Operand::Int(mem_id)); } diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 29b5711df8..d0bc9b2f84 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -26,6 +26,7 @@ #include "src/ast/literal.h" #include "src/ast/module.h" #include "src/ast/struct_member.h" +#include "src/ast/type/access_control_type.h" #include "src/ast/type/storage_texture_type.h" #include "src/ast/type_constructor_expression.h" #include "src/scope_stack.h" @@ -238,6 +239,13 @@ class Builder { /// @param func the function to generate for /// @returns the ID to use for the function type. Returns 0 on failure. uint32_t GenerateFunctionTypeIfNeeded(ast::Function* func); + /// Generates access control annotations if needed + /// @param type the type to generate for + /// @param struct_id the struct id + /// @param member_idx the member index + void GenerateMemberAccessControlIfNeeded(ast::type::Type* type, + uint32_t struct_id, + uint32_t member_idx); /// Generates a function variable /// @param var the variable /// @returns true if the variable was generated @@ -409,9 +417,11 @@ class Builder { bool GeneratePointerType(ast::type::PointerType* ptr, const Operand& result); /// Generates a vector type declaration /// @param struct_type the vector to generate + /// @param access_control the access controls to assign to the struct /// @param result the result operand /// @returns true if the vector was successfully generated bool GenerateStructType(ast::type::StructType* struct_type, + ast::AccessControl access_control, const Operand& result); /// Generates a struct member /// @param struct_id the id of the parent structure diff --git a/src/writer/spirv/builder_function_variable_test.cc b/src/writer/spirv/builder_function_variable_test.cc index e462ce1601..0f39a2e282 100644 --- a/src/writer/spirv/builder_function_variable_test.cc +++ b/src/writer/spirv/builder_function_variable_test.cc @@ -25,7 +25,10 @@ #include "src/ast/scalar_constructor_expression.h" #include "src/ast/set_decoration.h" #include "src/ast/storage_class.h" +#include "src/ast/struct.h" +#include "src/ast/type/access_control_type.h" #include "src/ast/type/f32_type.h" +#include "src/ast/type/struct_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" #include "src/ast/variable.h" diff --git a/src/writer/spirv/builder_global_variable_test.cc b/src/writer/spirv/builder_global_variable_test.cc index 14106b9b0e..62e9b16059 100644 --- a/src/writer/spirv/builder_global_variable_test.cc +++ b/src/writer/spirv/builder_global_variable_test.cc @@ -27,9 +27,12 @@ #include "src/ast/scalar_constructor_expression.h" #include "src/ast/set_decoration.h" #include "src/ast/storage_class.h" +#include "src/ast/struct.h" +#include "src/ast/type/access_control_type.h" #include "src/ast/type/bool_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" +#include "src/ast/type/struct_type.h" #include "src/ast/type/u32_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" @@ -520,6 +523,168 @@ INSTANTIATE_TEST_SUITE_P( BuiltinData{ast::Builtin::kGlobalInvocationId, SpvBuiltInGlobalInvocationId})); +TEST_F(BuilderTest, GlobalVar_DeclReadOnly) { + // struct A { + // a : i32; + // }; + // var b : [[access(read)]] A + + ast::type::I32Type i32; + + ast::StructMemberDecorationList decos; + ast::StructMemberList members; + members.push_back( + std::make_unique("a", &i32, std::move(decos))); + members.push_back( + std::make_unique("b", &i32, std::move(decos))); + + ast::type::StructType A("A", + std::make_unique(std::move(members))); + ast::type::AccessControlType ac{ast::AccessControl::kReadOnly, &A}; + + ast::Variable var("b", ast::StorageClass::kStorageBuffer, &ac); + + ast::Module mod; + Builder b(&mod); + EXPECT_TRUE(b.GenerateGlobalVariable(&var)) << b.error(); + + EXPECT_EQ(DumpInstructions(b.annots()), R"(OpMemberDecorate %3 0 NonWritable +OpMemberDecorate %3 1 NonWritable +)"); + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "A" +OpMemberName %3 0 "a" +OpMemberName %3 1 "b" +OpName %1 "b" +)"); + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1 +%3 = OpTypeStruct %4 %4 +%2 = OpTypePointer StorageBuffer %3 +%1 = OpVariable %2 StorageBuffer +)"); +} + +TEST_F(BuilderTest, GlobalVar_TypeAliasDeclReadOnly) { + // struct A { + // a : i32; + // }; + // type B = A; + // var b : [[access(read)]] B + + ast::type::I32Type i32; + + ast::StructMemberDecorationList decos; + ast::StructMemberList members; + members.push_back( + std::make_unique("a", &i32, std::move(decos))); + + ast::type::StructType A("A", + std::make_unique(std::move(members))); + ast::type::AliasType B("B", &A); + ast::type::AccessControlType ac{ast::AccessControl::kReadOnly, &B}; + + ast::Variable var("b", ast::StorageClass::kStorageBuffer, &ac); + + ast::Module mod; + Builder b(&mod); + EXPECT_TRUE(b.GenerateGlobalVariable(&var)) << b.error(); + + EXPECT_EQ(DumpInstructions(b.annots()), R"(OpMemberDecorate %3 0 NonWritable +)"); + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "A" +OpMemberName %3 0 "a" +OpName %1 "b" +)"); + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1 +%3 = OpTypeStruct %4 +%2 = OpTypePointer StorageBuffer %3 +%1 = OpVariable %2 StorageBuffer +)"); +} + +TEST_F(BuilderTest, GlobalVar_TypeAliasAssignReadOnly) { + // struct A { + // a : i32; + // }; + // type B = [[access(read)]] A; + // var b : B + + ast::type::I32Type i32; + + ast::StructMemberDecorationList decos; + ast::StructMemberList members; + members.push_back( + std::make_unique("a", &i32, std::move(decos))); + + ast::type::StructType A("A", + std::make_unique(std::move(members))); + ast::type::AccessControlType ac{ast::AccessControl::kReadOnly, &A}; + ast::type::AliasType B("B", &ac); + + ast::Variable var("b", ast::StorageClass::kStorageBuffer, &B); + + ast::Module mod; + Builder b(&mod); + EXPECT_TRUE(b.GenerateGlobalVariable(&var)) << b.error(); + + EXPECT_EQ(DumpInstructions(b.annots()), R"(OpMemberDecorate %3 0 NonWritable +)"); + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "A" +OpMemberName %3 0 "a" +OpName %1 "b" +)"); + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1 +%3 = OpTypeStruct %4 +%2 = OpTypePointer StorageBuffer %3 +%1 = OpVariable %2 StorageBuffer +)"); +} + +TEST_F(BuilderTest, GlobalVar_TwoVarDeclReadOnly) { + // struct A { + // a : i32; + // }; + // var b : [[access(read)]] A + // var c : [[access(read_write)]] A + + ast::type::I32Type i32; + + ast::StructMemberDecorationList decos; + ast::StructMemberList members; + members.push_back( + std::make_unique("a", &i32, std::move(decos))); + + ast::type::StructType A("A", + std::make_unique(std::move(members))); + ast::type::AccessControlType read{ast::AccessControl::kReadOnly, &A}; + ast::type::AccessControlType rw{ast::AccessControl::kReadWrite, &A}; + + ast::Variable var_b("b", ast::StorageClass::kStorageBuffer, &read); + ast::Variable var_c("c", ast::StorageClass::kStorageBuffer, &rw); + + ast::Module mod; + Builder b(&mod); + EXPECT_TRUE(b.GenerateGlobalVariable(&var_b)) << b.error(); + EXPECT_TRUE(b.GenerateGlobalVariable(&var_c)) << b.error(); + + EXPECT_EQ(DumpInstructions(b.annots()), R"(OpMemberDecorate %3 0 NonWritable +)"); + EXPECT_EQ(DumpInstructions(b.debug()), R"(OpName %3 "A" +OpMemberName %3 0 "a" +OpName %1 "b" +OpName %7 "A" +OpMemberName %7 0 "a" +OpName %5 "c" +)"); + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeInt 32 1 +%3 = OpTypeStruct %4 +%2 = OpTypePointer StorageBuffer %3 +%1 = OpVariable %2 StorageBuffer +%7 = OpTypeStruct %4 +%6 = OpTypePointer StorageBuffer %7 +%5 = OpVariable %6 StorageBuffer +)"); +} + } // namespace } // namespace spirv } // namespace writer diff --git a/src/writer/spirv/builder_type_test.cc b/src/writer/spirv/builder_type_test.cc index 92d67192ef..fe8710e7a4 100644 --- a/src/writer/spirv/builder_type_test.cc +++ b/src/writer/spirv/builder_type_test.cc @@ -20,6 +20,7 @@ #include "src/ast/struct.h" #include "src/ast/struct_member.h" #include "src/ast/struct_member_offset_decoration.h" +#include "src/ast/type/access_control_type.h" #include "src/ast/type/alias_type.h" #include "src/ast/type/array_type.h" #include "src/ast/type/bool_type.h"