diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index c771175229..7c3ed46390 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -35,6 +35,7 @@ #include "src/ast/sint_literal.h" #include "src/ast/struct.h" #include "src/ast/switch_statement.h" +#include "src/ast/type/access_control_type.h" #include "src/ast/type/alias_type.h" #include "src/ast/type/array_type.h" #include "src/ast/type/f32_type.h" @@ -1206,7 +1207,16 @@ bool GeneratorImpl::EmitEntryPointData(std::ostream& out, ast::Function* func) { auto* var = data.first; auto* binding = data.second.binding; - out << "RWByteAddressBuffer " << var->name() << " : register(u" + if (!var->type()->IsAccessControl()) { + error_ = "access control type required for storage buffer"; + return false; + } + auto* ac = var->type()->AsAccessControl(); + + if (ac->IsReadWrite()) { + out << "RW"; + } + out << "ByteAddressBuffer " << var->name() << " : register(u" << binding->value() << ");" << std::endl; emitted_storagebuffer = true; } diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index a5cae3e840..4e84f6ff74 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -32,6 +32,7 @@ #include "src/ast/stage_decoration.h" #include "src/ast/struct.h" #include "src/ast/struct_member_offset_decoration.h" +#include "src/ast/type/access_control_type.h" #include "src/ast/type/array_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" @@ -362,7 +363,7 @@ void frag_main() { } TEST_F(HlslGeneratorImplTest_Function, - Emit_FunctionDecoration_EntryPoint_With_StorageBuffer_Read) { + Emit_FunctionDecoration_EntryPoint_With_RW_StorageBuffer_Read) { ast::type::VoidType void_type; ast::type::F32Type f32; ast::type::I32Type i32; @@ -382,10 +383,11 @@ TEST_F(HlslGeneratorImplTest_Function, str->set_members(std::move(members)); ast::type::StructType s("Data", std::move(str)); + ast::type::AccessControlType ac(ast::type::AccessControl::kReadWrite, &s); auto coord_var = std::make_unique(std::make_unique( - "coord", ast::StorageClass::kStorageBuffer, &s)); + "coord", ast::StorageClass::kStorageBuffer, &ac)); ast::VariableDecorationList decos; decos.push_back(std::make_unique(0)); @@ -426,6 +428,72 @@ void frag_main() { )"); } +TEST_F(HlslGeneratorImplTest_Function, + Emit_FunctionDecoration_EntryPoint_With_RO_StorageBuffer_Read) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::I32Type i32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("b", &f32, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s("Data", std::move(str)); + ast::type::AccessControlType ac(ast::type::AccessControl::kReadOnly, &s); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kStorageBuffer, &ac)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + td().RegisterVariableForTesting(coord_var.get()); + mod()->AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("coord"), + std::make_unique("b"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(ByteAddressBuffer coord : register(u0); + +void frag_main() { + float v = asfloat(coord.Load(4)); + return; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_FunctionDecoration_EntryPoint_With_StorageBuffer_Store) { ast::type::VoidType void_type; @@ -447,10 +515,11 @@ TEST_F(HlslGeneratorImplTest_Function, str->set_members(std::move(members)); ast::type::StructType s("Data", std::move(str)); + ast::type::AccessControlType ac(ast::type::AccessControl::kReadWrite, &s); auto coord_var = std::make_unique(std::make_unique( - "coord", ast::StorageClass::kStorageBuffer, &s)); + "coord", ast::StorageClass::kStorageBuffer, &ac)); ast::VariableDecorationList decos; decos.push_back(std::make_unique(0)); @@ -830,10 +899,10 @@ TEST_F(HlslGeneratorImplTest_Function, ast::type::VoidType void_type; ast::type::F32Type f32; ast::type::VectorType vec4(&f32, 4); - + ast::type::AccessControlType ac(ast::type::AccessControl::kReadWrite, &vec4); auto coord_var = std::make_unique(std::make_unique( - "coord", ast::StorageClass::kStorageBuffer, &vec4)); + "coord", ast::StorageClass::kStorageBuffer, &ac)); ast::VariableDecorationList decos; decos.push_back(std::make_unique(0));