From 795bf4c716374b99a428b35ce5a50fac15e7c6de Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Thu, 5 Nov 2020 14:52:32 +0000 Subject: [PATCH] Fixup emitting duplicate globals in HLSL. This CL fixes the issue with duplicate globals being emitted in HLSL if used in multiple entry points. Tests are added for the other backends to verify the issue does not exist there. Bug: tint:297 Change-Id: I16d7504e8458fd375c6e1896758fe180ad963871 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/31880 Commit-Queue: Ryan Harrison Reviewed-by: Ryan Harrison --- src/writer/hlsl/generator_impl.cc | 22 ++- src/writer/hlsl/generator_impl.h | 6 +- ...tor_impl_function_entry_point_data_test.cc | 36 +++- .../hlsl/generator_impl_function_test.cc | 116 +++++++++++++ .../msl/generator_impl_function_test.cc | 121 ++++++++++++++ src/writer/spirv/builder_function_test.cc | 158 ++++++++++++++++++ src/writer/wgsl/generator_impl.cc | 22 ++- .../wgsl/generator_impl_function_test.cc | 133 +++++++++++++++ 8 files changed, 602 insertions(+), 12 deletions(-) diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 151ec79462..763bb66b68 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -135,13 +135,14 @@ bool GeneratorImpl::Generate(std::ostream& out) { } } + std::unordered_set emitted_globals; // Make sure all entry point data is emitted before the entry point functions for (const auto& func : module_->functions()) { if (!func->IsEntryPoint()) { continue; } - if (!EmitEntryPointData(out, func.get())) { + if (!EmitEntryPointData(out, func.get(), emitted_globals)) { return false; } } @@ -1136,7 +1137,10 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, return true; } -bool GeneratorImpl::EmitEntryPointData(std::ostream& out, ast::Function* func) { +bool GeneratorImpl::EmitEntryPointData( + std::ostream& out, + ast::Function* func, + std::unordered_set& emitted_globals) { std::vector> in_variables; std::vector> outvariables; for (auto data : func->referenced_location_variables()) { @@ -1174,6 +1178,13 @@ bool GeneratorImpl::EmitEntryPointData(std::ostream& out, ast::Function* func) { } // auto* set = data.second.set; + // If the global has already been emitted we skip it, it's been emitted by + // a previous entry point. + if (emitted_globals.count(var->name()) != 0) { + continue; + } + emitted_globals.insert(var->name()); + auto* type = var->type()->UnwrapIfNeeded(); if (type->IsStruct()) { auto* strct = type->AsStruct(); @@ -1210,6 +1221,13 @@ bool GeneratorImpl::EmitEntryPointData(std::ostream& out, ast::Function* func) { auto* var = data.first; auto* binding = data.second.binding; + // If the global has already been emitted we skip it, it's been emitted by + // a previous entry point. + if (emitted_globals.count(var->name()) != 0) { + continue; + } + emitted_globals.insert(var->name()); + if (!var->type()->IsAccessControl()) { error_ = "access control type required for storage buffer"; return false; diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index d106b66822..1bb9a32cd3 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -17,6 +17,7 @@ #include #include +#include #include "src/ast/intrinsic.h" #include "src/ast/literal.h" @@ -193,8 +194,11 @@ class GeneratorImpl { /// Handles emitting information for an entry point /// @param out the output stream /// @param func the entry point + /// @param emitted_globals the set of globals emitted over all entry points /// @returns true if the entry point data was emitted - bool EmitEntryPointData(std::ostream& out, ast::Function* func); + bool EmitEntryPointData(std::ostream& out, + ast::Function* func, + std::unordered_set& emitted_globals); /// Handles emitting the entry point function /// @param out the output stream /// @param func the entry point diff --git a/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc b/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc index 8a81489bfa..e089fca6a0 100644 --- a/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc +++ b/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include "src/ast/assignment_statement.h" #include "src/ast/decorated_variable.h" @@ -88,8 +89,11 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, mod()->AddFunction(std::move(func)); + std::unordered_set globals; + ASSERT_TRUE(td().Determine()) << td().error(); - ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr, globals)) + << gen().error(); EXPECT_EQ(result(), R"(struct vtx_main_in { float foo : TEXCOORD0; int bar : TEXCOORD1; @@ -147,8 +151,11 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, mod()->AddFunction(std::move(func)); + std::unordered_set globals; + ASSERT_TRUE(td().Determine()) << td().error(); - ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr, globals)) + << gen().error(); EXPECT_EQ(result(), R"(struct vtx_main_out { float foo : TEXCOORD0; int bar : TEXCOORD1; @@ -205,8 +212,11 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, mod()->AddFunction(std::move(func)); + std::unordered_set globals; + ASSERT_TRUE(td().Determine()) << td().error(); - ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr, globals)) + << gen().error(); EXPECT_EQ(result(), R"(struct main_in { float foo : TEXCOORD0; int bar : TEXCOORD1; @@ -263,8 +273,11 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, mod()->AddFunction(std::move(func)); + std::unordered_set globals; + ASSERT_TRUE(td().Determine()) << td().error(); - ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr, globals)) + << gen().error(); EXPECT_EQ(result(), R"(struct main_out { float foo : SV_Target0; int bar : SV_Target1; @@ -318,8 +331,11 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, mod()->AddFunction(std::move(func)); + std::unordered_set globals; + ASSERT_TRUE(td().Determine()) << td().error(); - ASSERT_FALSE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + ASSERT_FALSE(gen().EmitEntryPointData(out(), func_ptr, globals)) + << gen().error(); EXPECT_EQ(gen().error(), R"(invalid location variable for pipeline stage)"); } @@ -368,8 +384,11 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, mod()->AddFunction(std::move(func)); + std::unordered_set globals; + ASSERT_TRUE(td().Determine()) << td().error(); - ASSERT_FALSE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + ASSERT_FALSE(gen().EmitEntryPointData(out(), func_ptr, globals)) + << gen().error(); EXPECT_EQ(gen().error(), R"(invalid location variable for pipeline stage)"); } @@ -429,8 +448,11 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, mod()->AddFunction(std::move(func)); + std::unordered_set globals; + ASSERT_TRUE(td().Determine()) << td().error(); - ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr)) << gen().error(); + ASSERT_TRUE(gen().EmitEntryPointData(out(), func_ptr, globals)) + << gen().error(); EXPECT_EQ(result(), R"(struct main_in { vector coord : SV_Position; }; diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index 5edfe2b4ef..5f02e3a013 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -31,6 +31,7 @@ #include "src/ast/sint_literal.h" #include "src/ast/stage_decoration.h" #include "src/ast/struct.h" +#include "src/ast/struct_block_decoration.h" #include "src/ast/struct_member_offset_decoration.h" #include "src/ast/type/access_control_type.h" #include "src/ast/type/array_type.h" @@ -1133,6 +1134,121 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { )"); } +// https://crbug.com/tint/297 +TEST_F(HlslGeneratorImplTest_Function, + Emit_Multiple_EntryPoint_With_Same_ModuleVar) { + // [[block]] struct Data { + // [[offset(0)]] d : f32; + // }; + // [[binding(0), set(0)]] var data : Data; + // + // [[stage(compute)]] + // fn a() -> void { + // return; + // } + // + // [[stage(compute)]] + // fn b() -> void { + // return; + // } + + ast::type::VoidType void_type; + ast::type::F32Type f32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back( + std::make_unique(0, Source{})); + members.push_back( + std::make_unique("d", &f32, std::move(a_deco))); + + ast::StructDecorationList s_decos; + s_decos.push_back(std::make_unique(Source{})); + + auto str = + std::make_unique(std::move(s_decos), std::move(members)); + + ast::type::StructType s("Data", std::move(str)); + ast::type::AccessControlType ac(ast::AccessControl::kReadWrite, &s); + + auto data_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &ac)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0, Source{})); + decos.push_back(std::make_unique(0, Source{})); + data_var->set_decorations(std::move(decos)); + + mod()->AddConstructedType(&s); + td().RegisterVariableForTesting(data_var.get()); + mod()->AddGlobalVariable(std::move(data_var)); + + { + ast::VariableList params; + auto func = + std::make_unique("a", std::move(params), &void_type); + func->add_decoration(std::make_unique( + ast::PipelineStage::kCompute, Source{})); + + auto var = std::make_unique( + "v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("data"), + std::make_unique("d"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + } + + { + ast::VariableList params; + auto func = + std::make_unique("b", std::move(params), &void_type); + func->add_decoration(std::make_unique( + ast::PipelineStage::kCompute, Source{})); + + auto var = std::make_unique( + "v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("data"), + std::make_unique("d"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + } + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"(struct Data { + float d; +}; + +RWByteAddressBuffer data : register(u0); + +[numthreads(1, 1, 1)] +void a() { + float v = asfloat(data.Load(0)); + return; +} + +[numthreads(1, 1, 1)] +void b() { + float v = asfloat(data.Load(0)); + return; +} + +)"); +} + } // namespace } // namespace hlsl } // namespace writer diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index 5cbe0d2cbb..a5d871b8d5 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -32,6 +32,7 @@ #include "src/ast/sint_literal.h" #include "src/ast/stage_decoration.h" #include "src/ast/struct.h" +#include "src/ast/struct_block_decoration.h" #include "src/ast/struct_member.h" #include "src/ast/struct_member_decoration.h" #include "src/ast/struct_member_offset_decoration.h" @@ -1168,6 +1169,126 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithArrayParams) { )"); } +// https://crbug.com/tint/297 +TEST_F(MslGeneratorImplTest, + Emit_Function_Multiple_EntryPoint_With_Same_ModuleVar) { + // [[block]] struct Data { + // [[offset(0)]] d : f32; + // }; + // [[binding(0), set(0)]] var data : Data; + // + // [[stage(compute)]] + // fn a() -> void { + // return; + // } + // + // [[stage(compute)]] + // fn b() -> void { + // return; + // } + + ast::type::VoidType void_type; + ast::type::F32Type f32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back( + std::make_unique(0, Source{})); + members.push_back( + std::make_unique("d", &f32, std::move(a_deco))); + + ast::StructDecorationList s_decos; + s_decos.push_back(std::make_unique(Source{})); + + auto str = + std::make_unique(std::move(s_decos), std::move(members)); + + ast::type::StructType s("Data", std::move(str)); + ast::type::AccessControlType ac(ast::AccessControl::kReadWrite, &s); + + auto data_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &ac)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0, Source{})); + decos.push_back(std::make_unique(0, Source{})); + data_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + + mod.AddConstructedType(&s); + + td.RegisterVariableForTesting(data_var.get()); + mod.AddGlobalVariable(std::move(data_var)); + + { + ast::VariableList params; + auto func = + std::make_unique("a", std::move(params), &void_type); + func->add_decoration(std::make_unique( + ast::PipelineStage::kCompute, Source{})); + + auto var = std::make_unique( + "v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("data"), + std::make_unique("d"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + } + + { + ast::VariableList params; + auto func = + std::make_unique("b", std::move(params), &void_type); + func->add_decoration(std::make_unique( + ast::PipelineStage::kCompute, Source{})); + + auto var = std::make_unique( + "v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("data"), + std::make_unique("d"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + } + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct Data { + float d; +}; + +kernel void a(device Data& data [[buffer(0)]]) { + float v = data.d; + return; +} + +kernel void b(device Data& data [[buffer(0)]]) { + float v = data.d; + return; +} + +)"); +} + } // namespace } // namespace msl } // namespace writer diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc index 266b7b75fd..ba4ad5922c 100644 --- a/src/writer/spirv/builder_function_test.cc +++ b/src/writer/spirv/builder_function_test.cc @@ -17,13 +17,22 @@ #include "gtest/gtest.h" #include "spirv/unified1/spirv.h" #include "spirv/unified1/spirv.hpp11" +#include "src/ast/decorated_variable.h" #include "src/ast/function.h" #include "src/ast/identifier_expression.h" +#include "src/ast/member_accessor_expression.h" #include "src/ast/return_statement.h" +#include "src/ast/stage_decoration.h" +#include "src/ast/struct.h" +#include "src/ast/struct_block_decoration.h" +#include "src/ast/struct_member_offset_decoration.h" +#include "src/ast/type/access_control_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" +#include "src/ast/type/struct_type.h" #include "src/ast/type/void_type.h" #include "src/ast/variable.h" +#include "src/ast/variable_decl_statement.h" #include "src/context.h" #include "src/type_determiner.h" #include "src/writer/spirv/builder.h" @@ -150,6 +159,155 @@ TEST_F(BuilderTest, FunctionType_DeDuplicate) { )"); } +// https://crbug.com/tint/297 +TEST_F(BuilderTest, Emit_Multiple_EntryPoint_With_Same_ModuleVar) { + // [[block]] struct Data { + // [[offset(0)]] d : f32; + // }; + // [[binding(0), set(0)]] var data : Data; + // + // [[stage(compute)]] + // fn a() -> void { + // return; + // } + // + // [[stage(compute)]] + // fn b() -> void { + // return; + // } + + ast::type::VoidType void_type; + ast::type::F32Type f32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back( + std::make_unique(0, Source{})); + members.push_back( + std::make_unique("d", &f32, std::move(a_deco))); + + ast::StructDecorationList s_decos; + s_decos.push_back(std::make_unique(Source{})); + + auto str = + std::make_unique(std::move(s_decos), std::move(members)); + + ast::type::StructType s("Data", std::move(str)); + ast::type::AccessControlType ac(ast::AccessControl::kReadWrite, &s); + + auto data_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &ac)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0, Source{})); + decos.push_back(std::make_unique(0, Source{})); + data_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + + mod.AddConstructedType(&s); + + td.RegisterVariableForTesting(data_var.get()); + mod.AddGlobalVariable(std::move(data_var)); + + { + ast::VariableList params; + auto func = + std::make_unique("a", std::move(params), &void_type); + func->add_decoration(std::make_unique( + ast::PipelineStage::kCompute, Source{})); + + auto var = std::make_unique( + "v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("data"), + std::make_unique("d"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + } + + { + ast::VariableList params; + auto func = + std::make_unique("b", std::move(params), &void_type); + func->add_decoration(std::make_unique( + ast::PipelineStage::kCompute, Source{})); + + auto var = std::make_unique( + "v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("data"), + std::make_unique("d"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + } + + ASSERT_TRUE(td.Determine()) << td.error(); + + Builder b(&mod); + ASSERT_TRUE(b.Build()); + EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %7 "a" +OpEntryPoint GLCompute %17 "b" +OpExecutionMode %7 LocalSize 1 1 1 +OpExecutionMode %17 LocalSize 1 1 1 +OpName %3 "Data" +OpMemberName %3 0 "d" +OpName %1 "data" +OpName %7 "a" +OpName %13 "v" +OpName %17 "b" +OpName %20 "v" +OpDecorate %3 Block +OpMemberDecorate %3 0 Offset 0 +OpDecorate %1 Binding 0 +OpDecorate %1 DescriptorSet 0 +%4 = OpTypeFloat 32 +%3 = OpTypeStruct %4 +%2 = OpTypePointer StorageBuffer %3 +%1 = OpVariable %2 StorageBuffer +%6 = OpTypeVoid +%5 = OpTypeFunction %6 +%9 = OpTypeInt 32 0 +%10 = OpConstant %9 0 +%11 = OpTypePointer StorageBuffer %4 +%14 = OpTypePointer Function %4 +%15 = OpConstantNull %4 +%7 = OpFunction %6 None %5 +%8 = OpLabel +%13 = OpVariable %14 Function %15 +%12 = OpAccessChain %11 %1 %10 +%16 = OpLoad %4 %12 +OpStore %13 %16 +OpReturn +OpFunctionEnd +%17 = OpFunction %6 None %5 +%18 = OpLabel +%20 = OpVariable %14 Function %15 +%19 = OpAccessChain %11 %1 %10 +%21 = OpLoad %4 %19 +OpStore %20 %21 +OpReturn +OpFunctionEnd +)"); +} + } // namespace } // namespace spirv } // namespace writer diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index eee6edb8ca..8266900358 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -51,6 +51,7 @@ #include "src/ast/struct_member.h" #include "src/ast/struct_member_offset_decoration.h" #include "src/ast/switch_statement.h" +#include "src/ast/type/access_control_type.h" #include "src/ast/type/array_type.h" #include "src/ast/type/depth_texture_type.h" #include "src/ast/type/matrix_type.h" @@ -401,7 +402,24 @@ bool GeneratorImpl::EmitImageFormat(const ast::type::ImageFormat fmt) { } bool GeneratorImpl::EmitType(ast::type::Type* type) { - if (type->IsAlias()) { + if (type->IsAccessControl()) { + auto* ac = type->AsAccessControl(); + // TODO(dsinclair): Access control isn't supported in WGSL yet, so this + // is disabled for now. + // + // out_ << "[[access("; + // if (ac->IsReadOnly()) { + // out_ << "read"; + // } else if (ac->IsWriteOnly()) { + // out_ << "write"; + // } else { + // out_ << "read_write"; + // } + // out_ << ")]]" << std::endl; + if (!EmitType(ac->type())) { + return false; + } + } else if (type->IsAlias()) { out_ << type->AsAlias()->name(); } else if (type->IsArray()) { auto* ary = type->AsArray(); @@ -544,7 +562,7 @@ bool GeneratorImpl::EmitType(ast::type::Type* type) { } else if (type->IsVoid()) { out_ << "void"; } else { - error_ = "unknown type in EmitType"; + error_ = "unknown type in EmitType: " + type->type_name(); return false; } diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc index 92f4b626ff..f7ba33b010 100644 --- a/src/writer/wgsl/generator_impl_function_test.cc +++ b/src/writer/wgsl/generator_impl_function_test.cc @@ -13,16 +13,25 @@ // limitations under the License. #include "gtest/gtest.h" +#include "src/ast/decorated_variable.h" #include "src/ast/discard_statement.h" #include "src/ast/function.h" +#include "src/ast/member_accessor_expression.h" +#include "src/ast/module.h" #include "src/ast/pipeline_stage.h" #include "src/ast/return_statement.h" #include "src/ast/stage_decoration.h" +#include "src/ast/struct_block_decoration.h" +#include "src/ast/struct_member_offset_decoration.h" +#include "src/ast/type/access_control_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" #include "src/ast/type/void_type.h" #include "src/ast/variable.h" +#include "src/ast/variable_decl_statement.h" #include "src/ast/workgroup_decoration.h" +#include "src/context.h" +#include "src/type_determiner.h" #include "src/writer/wgsl/generator_impl.h" namespace tint { @@ -152,6 +161,130 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Multiple) { )"); } +// https://crbug.com/tint/297 +TEST_F(WgslGeneratorImplTest, + Emit_Function_Multiple_EntryPoint_With_Same_ModuleVar) { + // [[block]] struct Data { + // [[offset(0)]] d : f32; + // }; + // [[binding(0), set(0)]] var data : Data; + // + // [[stage(compute)]] + // fn a() -> void { + // return; + // } + // + // [[stage(compute)]] + // fn b() -> void { + // return; + // } + + ast::type::VoidType void_type; + ast::type::F32Type f32; + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back( + std::make_unique(0, Source{})); + members.push_back( + std::make_unique("d", &f32, std::move(a_deco))); + + ast::StructDecorationList s_decos; + s_decos.push_back(std::make_unique(Source{})); + + auto str = + std::make_unique(std::move(s_decos), std::move(members)); + + ast::type::StructType s("Data", std::move(str)); + ast::type::AccessControlType ac(ast::AccessControl::kReadWrite, &s); + + auto data_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &ac)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0, Source{})); + decos.push_back(std::make_unique(0, Source{})); + data_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + + mod.AddConstructedType(&s); + + td.RegisterVariableForTesting(data_var.get()); + mod.AddGlobalVariable(std::move(data_var)); + + { + ast::VariableList params; + auto func = + std::make_unique("a", std::move(params), &void_type); + func->add_decoration(std::make_unique( + ast::PipelineStage::kCompute, Source{})); + + auto var = std::make_unique( + "v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("data"), + std::make_unique("d"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + } + + { + ast::VariableList params; + auto func = + std::make_unique("b", std::move(params), &void_type); + func->add_decoration(std::make_unique( + ast::PipelineStage::kCompute, Source{})); + + auto var = std::make_unique( + "v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("data"), + std::make_unique("d"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + } + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(mod)) << g.error(); + EXPECT_EQ(g.result(), R"([[block]] +struct Data { + [[offset(0)]] + d : f32; +}; + +[[binding(0), set(0)]] var data : Data; + +[[stage(compute)]] +fn a() -> void { + var v : f32 = data.d; + return; +} + +[[stage(compute)]] +fn b() -> void { + var v : f32 = data.d; + return; +} + +)"); +} + } // namespace } // namespace wgsl } // namespace writer