diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 876e851ba5..74585131a2 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -657,8 +657,10 @@ if(${TINT_BUILD_TESTS}) writer/wgsl/generator_impl_constructor_test.cc writer/wgsl/generator_impl_continue_test.cc writer/wgsl/generator_impl_discard_test.cc + writer/wgsl/generator_impl_entry_point_test.cc writer/wgsl/generator_impl_fallthrough_test.cc writer/wgsl/generator_impl_function_test.cc + writer/wgsl/generator_impl_global_decl_test.cc writer/wgsl/generator_impl_identifier_test.cc writer/wgsl/generator_impl_if_test.cc writer/wgsl/generator_impl_loop_test.cc diff --git a/src/transform/bound_array_accessors_test.cc b/src/transform/bound_array_accessors_test.cc index 14ce0627b1..783bb98645 100644 --- a/src/transform/bound_array_accessors_test.cc +++ b/src/transform/bound_array_accessors_test.cc @@ -25,6 +25,7 @@ using BoundArrayAccessorsTest = TransformTest; TEST_F(BoundArrayAccessorsTest, Ptrs_Clamp) { auto* src = R"( var a : array; + const c : u32 = 1u; fn f() -> void { @@ -34,6 +35,7 @@ fn f() -> void { auto* expect = R"( var a : array; + const c : u32 = 1u; fn f() -> void { @@ -49,7 +51,9 @@ fn f() -> void { TEST_F(BoundArrayAccessorsTest, Array_Idx_Nested_Scalar) { auto* src = R"( var a : array; + var b : array; + var i : u32; fn f() -> void { @@ -59,7 +63,9 @@ fn f() -> void { auto* expect = R"( var a : array; + var b : array; + var i : u32; fn f() -> void { @@ -97,6 +103,7 @@ fn f() -> void { TEST_F(BoundArrayAccessorsTest, Array_Idx_Expr) { auto* src = R"( var a : array; + var c : u32; fn f() -> void { @@ -106,6 +113,7 @@ fn f() -> void { auto* expect = R"( var a : array; + var c : u32; fn f() -> void { @@ -187,6 +195,7 @@ fn f() -> void { TEST_F(BoundArrayAccessorsTest, Vector_Idx_Expr) { auto* src = R"( var a : vec3; + var c : u32; fn f() -> void { @@ -196,6 +205,7 @@ fn f() -> void { auto* expect = R"( var a : vec3; + var c : u32; fn f() -> void { @@ -277,6 +287,7 @@ fn f() -> void { TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Column) { auto* src = R"( var a : mat3x2; + var c : u32; fn f() -> void { @@ -286,6 +297,7 @@ fn f() -> void { auto* expect = R"( var a : mat3x2; + var c : u32; fn f() -> void { @@ -301,6 +313,7 @@ fn f() -> void { TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Row) { auto* src = R"( var a : mat3x2; + var c : u32; fn f() -> void { @@ -310,6 +323,7 @@ fn f() -> void { auto* expect = R"( var a : mat3x2; + var c : u32; fn f() -> void { diff --git a/src/transform/first_index_offset_test.cc b/src/transform/first_index_offset_test.cc index 0ef6417b8c..977a7166a7 100644 --- a/src/transform/first_index_offset_test.cc +++ b/src/transform/first_index_offset_test.cc @@ -76,15 +76,16 @@ fn entry() -> void { )"; auto* expect = R"( +[[builtin(vertex_index)]] var tint_first_index_offset_vert_idx : u32; + +[[binding(1), group(2)]] var tint_first_index_data : TintFirstIndexOffsetData; + [[block]] struct TintFirstIndexOffsetData { [[offset(0)]] tint_first_vertex_index : u32; }; -[[builtin(vertex_index)]] var tint_first_index_offset_vert_idx : u32; -[[binding(1), group(2)]] var tint_first_index_data : TintFirstIndexOffsetData; - fn test() -> u32 { const vert_idx : u32 = (tint_first_index_offset_vert_idx + tint_first_index_data.tint_first_vertex_index); return vert_idx; @@ -116,15 +117,16 @@ fn entry() -> void { )"; auto* expect = R"( +[[builtin(instance_index)]] var tint_first_index_offset_inst_idx : u32; + +[[binding(1), group(7)]] var tint_first_index_data : TintFirstIndexOffsetData; + [[block]] struct TintFirstIndexOffsetData { [[offset(0)]] tint_first_instance_index : u32; }; -[[builtin(instance_index)]] var tint_first_index_offset_inst_idx : u32; -[[binding(1), group(7)]] var tint_first_index_data : TintFirstIndexOffsetData; - fn test() -> u32 { const inst_idx : u32 = (tint_first_index_offset_inst_idx + tint_first_index_data.tint_first_instance_index); return inst_idx; @@ -157,6 +159,12 @@ fn entry() -> void { )"; auto* expect = R"( +[[builtin(instance_index)]] var tint_first_index_offset_instance_idx : u32; + +[[builtin(vertex_index)]] var tint_first_index_offset_vert_idx : u32; + +[[binding(1), group(2)]] var tint_first_index_data : TintFirstIndexOffsetData; + [[block]] struct TintFirstIndexOffsetData { [[offset(0)]] @@ -165,10 +173,6 @@ struct TintFirstIndexOffsetData { tint_first_instance_index : u32; }; -[[builtin(instance_index)]] var tint_first_index_offset_instance_idx : u32; -[[builtin(vertex_index)]] var tint_first_index_offset_vert_idx : u32; -[[binding(1), group(2)]] var tint_first_index_data : TintFirstIndexOffsetData; - fn test() -> u32 { const instance_idx : u32 = (tint_first_index_offset_instance_idx + tint_first_index_data.tint_first_instance_index); const vert_idx : u32 = (tint_first_index_offset_vert_idx + tint_first_index_data.tint_first_vertex_index); @@ -205,15 +209,16 @@ fn entry() -> void { )"; auto* expect = R"( +[[builtin(vertex_index)]] var tint_first_index_offset_vert_idx : u32; + +[[binding(1), group(2)]] var tint_first_index_data : TintFirstIndexOffsetData; + [[block]] struct TintFirstIndexOffsetData { [[offset(0)]] tint_first_vertex_index : u32; }; -[[builtin(vertex_index)]] var tint_first_index_offset_vert_idx : u32; -[[binding(1), group(2)]] var tint_first_index_data : TintFirstIndexOffsetData; - fn func1() -> u32 { const vert_idx : u32 = (tint_first_index_offset_vert_idx + tint_first_index_data.tint_first_vertex_index); return vert_idx; diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc index edc8abcf82..c81a00e931 100644 --- a/src/transform/vertex_pulling_test.cc +++ b/src/transform/vertex_pulling_test.cc @@ -127,14 +127,16 @@ fn main() -> void {} )"; auto* expect = R"( +[[builtin(vertex_index)]] var _tint_pulling_vertex_index : i32; + +[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; + [[block]] struct TintVertexData { [[offset(0)]] _tint_vertex_data : [[stride(4)]] array; }; -[[builtin(vertex_index)]] var _tint_pulling_vertex_index : i32; -[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; var var_a : f32; [[stage(vertex)]] @@ -166,14 +168,16 @@ fn main() -> void {} )"; auto* expect = R"( +[[builtin(instance_index)]] var _tint_pulling_instance_index : i32; + +[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; + [[block]] struct TintVertexData { [[offset(0)]] _tint_vertex_data : [[stride(4)]] array; }; -[[builtin(instance_index)]] var _tint_pulling_instance_index : i32; -[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; var var_a : f32; [[stage(vertex)]] @@ -205,14 +209,16 @@ fn main() -> void {} )"; auto* expect = R"( +[[builtin(vertex_index)]] var _tint_pulling_vertex_index : i32; + +[[binding(0), group(5)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; + [[block]] struct TintVertexData { [[offset(0)]] _tint_vertex_data : [[stride(4)]] array; }; -[[builtin(vertex_index)]] var _tint_pulling_vertex_index : i32; -[[binding(0), group(5)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; var var_a : f32; [[stage(vertex)]] @@ -249,17 +255,22 @@ fn main() -> void {} )"; auto* expect = R"( +[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; + +[[binding(1), group(4)]] var _tint_pulling_vertex_buffer_1 : TintVertexData; + [[block]] struct TintVertexData { [[offset(0)]] _tint_vertex_data : [[stride(4)]] array; }; -[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; -[[binding(1), group(4)]] var _tint_pulling_vertex_buffer_1 : TintVertexData; var var_a : f32; + var var_b : f32; + [[builtin(vertex_index)]] var custom_vertex_index : i32; + [[builtin(instance_index)]] var custom_instance_index : i32; [[stage(vertex)]] @@ -295,15 +306,18 @@ fn main() -> void {} )"; auto* expect = R"( +[[builtin(vertex_index)]] var _tint_pulling_vertex_index : i32; + +[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; + [[block]] struct TintVertexData { [[offset(0)]] _tint_vertex_data : [[stride(4)]] array; }; -[[builtin(vertex_index)]] var _tint_pulling_vertex_index : i32; -[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; var var_a : f32; + var var_b : array; [[stage(vertex)]] @@ -341,18 +355,24 @@ fn main() -> void {} )"; auto* expect = R"( +[[builtin(vertex_index)]] var _tint_pulling_vertex_index : i32; + +[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; + +[[binding(1), group(4)]] var _tint_pulling_vertex_buffer_1 : TintVertexData; + +[[binding(2), group(4)]] var _tint_pulling_vertex_buffer_2 : TintVertexData; + [[block]] struct TintVertexData { [[offset(0)]] _tint_vertex_data : [[stride(4)]] array; }; -[[builtin(vertex_index)]] var _tint_pulling_vertex_index : i32; -[[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; -[[binding(1), group(4)]] var _tint_pulling_vertex_buffer_1 : TintVertexData; -[[binding(2), group(4)]] var _tint_pulling_vertex_buffer_2 : TintVertexData; var var_a : array; + var var_b : array; + var var_c : array; [[stage(vertex)]] diff --git a/src/writer/wgsl/generator.cc b/src/writer/wgsl/generator.cc index 5fc91cc994..90ef18e370 100644 --- a/src/writer/wgsl/generator.cc +++ b/src/writer/wgsl/generator.cc @@ -26,7 +26,7 @@ Generator::Generator(const Program* program) Generator::~Generator() = default; bool Generator::Generate() { - auto ret = impl_->Generate(); + auto ret = impl_->Generate(nullptr); if (!ret) { error_ = impl_->error(); } diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index 4e08026cbc..828af6ad98 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -1,4 +1,4 @@ -// Copyright 2020 The Tint Authors. +// Copyright 2021 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -88,30 +88,44 @@ GeneratorImpl::GeneratorImpl(const Program* program) GeneratorImpl::~GeneratorImpl() = default; -bool GeneratorImpl::Generate() { - for (auto* const ty : program_->AST().ConstructedTypes()) { - if (!EmitConstructedType(ty)) { - return false; +bool GeneratorImpl::Generate(const ast::Function* entry) { + // Generate global declarations in the order they appear in the module. + for (auto* decl : program_->AST().GlobalDeclarations()) { + if (auto* ty = decl->As()) { + if (!EmitConstructedType(ty)) { + return false; + } + } else if (auto* func = decl->As()) { + if (entry && func != entry) { + // Skip functions that are not reachable by the target entry point. + auto* sem = program_->Sem().Get(func); + if (!sem->HasAncestorEntryPoint(entry->symbol())) { + continue; + } + } + if (!EmitFunction(func)) { + return false; + } + } else if (auto* var = decl->As()) { + if (entry && !var->is_const()) { + // Skip variables that are not referenced by the target entry point. + auto& refs = program_->Sem().Get(entry)->ReferencedModuleVariables(); + if (std::find(refs.begin(), refs.end(), program_->Sem().Get(var)) == + refs.end()) { + continue; + } + } + if (!EmitVariable(var)) { + return false; + } + } else { + assert(false /* unreachable */); } - } - if (!program_->AST().ConstructedTypes().empty()) - out_ << std::endl; - for (auto* var : program_->AST().GlobalVariables()) { - if (!EmitVariable(var)) { - return false; + if (decl != program_->AST().GlobalDeclarations().back()) { + out_ << std::endl; } } - if (!program_->AST().GlobalVariables().empty()) { - out_ << std::endl; - } - - for (auto* func : program_->AST().Functions()) { - if (!EmitFunction(func)) { - return false; - } - out_ << std::endl; - } return true; } @@ -124,58 +138,7 @@ bool GeneratorImpl::GenerateEntryPoint(ast::PipelineStage stage, error_ = "Unable to find requested entry point: " + name; return false; } - - // TODO(dsinclair): We always emit constructed types even if they aren't - // strictly needed - for (auto* const ty : program_->AST().ConstructedTypes()) { - if (!EmitConstructedType(ty)) { - return false; - } - } - if (!program_->AST().ConstructedTypes().empty()) { - out_ << std::endl; - } - - // TODO(dsinclair): This should be smarter and only emit needed const - // variables - for (auto* var : program_->AST().GlobalVariables()) { - if (!var->is_const()) { - continue; - } - if (!EmitVariable(var)) { - return false; - } - } - - bool found_func_variable = false; - for (auto* var : program_->Sem().Get(func)->ReferencedModuleVariables()) { - if (!EmitVariable(var->Declaration())) { - return false; - } - found_func_variable = true; - } - if (found_func_variable) { - out_ << std::endl; - } - - for (auto* f : program_->AST().Functions()) { - auto* f_sem = program_->Sem().Get(f); - if (!f_sem->HasAncestorEntryPoint(program_->Symbols().Get(name))) { - continue; - } - - if (!EmitFunction(f)) { - return false; - } - out_ << std::endl; - } - - if (!EmitFunction(func)) { - return false; - } - out_ << std::endl; - - return true; + return Generate(func); } bool GeneratorImpl::EmitConstructedType(const type::Type* ty) { diff --git a/src/writer/wgsl/generator_impl.h b/src/writer/wgsl/generator_impl.h index 34ad398261..9befd510d9 100644 --- a/src/writer/wgsl/generator_impl.h +++ b/src/writer/wgsl/generator_impl.h @@ -57,9 +57,10 @@ class GeneratorImpl : public TextGenerator { explicit GeneratorImpl(const Program* program); ~GeneratorImpl(); - /// Generates the result data + /// Generates the result data, optionally restricted to a single entry point + /// @param entry entry point to target, or nullptr /// @returns true on successful generation; false otherwise - bool Generate(); + bool Generate(const ast::Function* entry); /// Generates a single entry point /// @param stage the pipeline stage diff --git a/src/writer/wgsl/generator_impl_entry_point_test.cc b/src/writer/wgsl/generator_impl_entry_point_test.cc new file mode 100644 index 0000000000..057629d361 --- /dev/null +++ b/src/writer/wgsl/generator_impl_entry_point_test.cc @@ -0,0 +1,171 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "src/ast/assignment_statement.h" +#include "src/ast/call_statement.h" +#include "src/ast/stage_decoration.h" +#include "src/ast/variable.h" +#include "src/ast/variable_decl_statement.h" +#include "src/type/f32_type.h" +#include "src/writer/wgsl/generator_impl.h" +#include "src/writer/wgsl/test_helper.h" + +namespace tint { +namespace writer { +namespace wgsl { +namespace { + +using WgslGeneratorImplTest = TestHelper; + +TEST_F(WgslGeneratorImplTest, Emit_EntryPoint_UnusedFunction) { + Func("func_unused", ast::VariableList{}, ty.void_(), ast::StatementList{}, + ast::FunctionDecorationList{}); + + Func("func_used", ast::VariableList{}, ty.void_(), ast::StatementList{}, + ast::FunctionDecorationList{}); + + auto* call_func = Call("func_used"); + + Func("main", ast::VariableList{}, ty.void_(), + ast::StatementList{ + create(call_func), + }, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute), + }); + + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE( + gen.GenerateEntryPoint(tint::ast::PipelineStage::kCompute, "main")) + << gen.error(); + EXPECT_EQ(gen.result(), R"( fn func_used() -> void { + } + + [[stage(compute)]] + fn main() -> void { + func_used(); + } +)"); +} + +TEST_F(WgslGeneratorImplTest, Emit_EntryPoint_UnusedVariable) { + auto* global_unused = + Global("global_unused", ast::StorageClass::kInput, ty.f32()); + create(global_unused); + + auto* global_used = + Global("global_used", ast::StorageClass::kInput, ty.f32()); + create(global_used); + + Func("main", ast::VariableList{}, ty.void_(), + ast::StatementList{ + create(Expr("global_used"), Expr(1.f)), + }, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute), + }); + + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE( + gen.GenerateEntryPoint(tint::ast::PipelineStage::kCompute, "main")) + << gen.error(); + EXPECT_EQ(gen.result(), R"( var global_used : f32; + + [[stage(compute)]] + fn main() -> void { + global_used = 1.0; + } +)"); +} + +TEST_F(WgslGeneratorImplTest, Emit_EntryPoint_GlobalsInterleaved) { + auto* global0 = Global("a0", ast::StorageClass::kInput, ty.f32()); + create(global0); + + auto* str0 = create(ast::StructMemberList{Member("a", ty.i32())}, + ast::StructDecorationList{}); + auto* s0 = ty.struct_("S0", str0); + AST().AddConstructedType(s0); + + Func("func", ast::VariableList{}, ty.f32(), + ast::StatementList{ + create(Expr("a0")), + }, + ast::FunctionDecorationList{}); + + auto* global1 = Global("a1", ast::StorageClass::kOutput, ty.f32()); + create(global1); + + auto* str1 = create(ast::StructMemberList{Member("a", ty.i32())}, + ast::StructDecorationList{}); + auto* s1 = ty.struct_("S1", str1); + AST().AddConstructedType(s1); + + auto* call_func = Call("func"); + + Func("main", ast::VariableList{}, ty.void_(), + ast::StatementList{ + create( + Var("s0", ast::StorageClass::kFunction, s0)), + create( + Var("s1", ast::StorageClass::kFunction, s1)), + create(Expr("a1"), Expr(call_func)), + }, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute), + }); + + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE( + gen.GenerateEntryPoint(tint::ast::PipelineStage::kCompute, "main")) + << gen.error(); + EXPECT_EQ(gen.result(), R"( var a0 : f32; + + struct S0 { + a : i32; + }; + + fn func() -> f32 { + return a0; + } + + var a1 : f32; + + struct S1 { + a : i32; + }; + + [[stage(compute)]] + fn main() -> void { + var s0 : S0; + var s1 : S1; + a1 = func(); + } +)"); +} + +} // namespace +} // namespace wgsl +} // namespace writer +} // namespace tint diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc index 4fd3f02e5e..9414cde645 100644 --- a/src/writer/wgsl/generator_impl_function_test.cc +++ b/src/writer/wgsl/generator_impl_function_test.cc @@ -183,6 +183,7 @@ TEST_F(WgslGeneratorImplTest, auto* s = ty.struct_("Data", str); type::AccessControl ac(ast::AccessControl::kReadWrite, s); + AST().AddConstructedType(s); Global("data", ast::StorageClass::kStorage, &ac, nullptr, ast::VariableDecorationList{ @@ -190,8 +191,6 @@ TEST_F(WgslGeneratorImplTest, create(0), }); - AST().AddConstructedType(s); - { auto* var = Var("v", ast::StorageClass::kFunction, ty.f32(), @@ -226,7 +225,7 @@ TEST_F(WgslGeneratorImplTest, GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.Generate()) << gen.error(); + ASSERT_TRUE(gen.Generate(nullptr)) << gen.error(); EXPECT_EQ(gen.result(), R"([[block]] struct Data { [[offset(0)]] @@ -247,7 +246,6 @@ fn b() -> void { var v : f32 = data.d; return; } - )"); } diff --git a/src/writer/wgsl/generator_impl_global_decl_test.cc b/src/writer/wgsl/generator_impl_global_decl_test.cc new file mode 100644 index 0000000000..464ee70a94 --- /dev/null +++ b/src/writer/wgsl/generator_impl_global_decl_test.cc @@ -0,0 +1,123 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "src/ast/assignment_statement.h" +#include "src/ast/call_statement.h" +#include "src/ast/stage_decoration.h" +#include "src/ast/variable.h" +#include "src/ast/variable_decl_statement.h" +#include "src/type/f32_type.h" +#include "src/writer/wgsl/generator_impl.h" +#include "src/writer/wgsl/test_helper.h" + +namespace tint { +namespace writer { +namespace wgsl { +namespace { + +using WgslGeneratorImplTest = TestHelper; + +TEST_F(WgslGeneratorImplTest, Emit_GlobalDeclAfterFunction) { + auto* func_var = Var("a", ast::StorageClass::kFunction, ty.f32(), nullptr, + ast::VariableDecorationList{}); + WrapInFunction(create(func_var)); + + auto* global_var = Global("a", ast::StorageClass::kInput, ty.f32()); + create(global_var); + + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE(gen.Generate(nullptr)) << gen.error(); + EXPECT_EQ(gen.result(), R"( fn test_function() -> void { + var a : f32; + } + + var a : f32; +)"); +} + +TEST_F(WgslGeneratorImplTest, Emit_GlobalsInterleaved) { + auto* global0 = Global("a0", ast::StorageClass::kInput, ty.f32()); + create(global0); + + auto* str0 = create(ast::StructMemberList{Member("a", ty.i32())}, + ast::StructDecorationList{}); + auto* s0 = ty.struct_("S0", str0); + AST().AddConstructedType(s0); + + Func("func", ast::VariableList{}, ty.f32(), + ast::StatementList{ + create(Expr("a0")), + }, + ast::FunctionDecorationList{}); + + auto* global1 = Global("a1", ast::StorageClass::kOutput, ty.f32()); + create(global1); + + auto* str1 = create(ast::StructMemberList{Member("a", ty.i32())}, + ast::StructDecorationList{}); + auto* s1 = ty.struct_("S1", str1); + AST().AddConstructedType(s1); + + auto* call_func = Call("func"); + + Func("main", ast::VariableList{}, ty.void_(), + ast::StatementList{ + create( + Var("s0", ast::StorageClass::kFunction, s0)), + create( + Var("s1", ast::StorageClass::kFunction, s1)), + create(Expr("a1"), Expr(call_func)), + }, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute), + }); + + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE(gen.Generate(nullptr)) << gen.error(); + EXPECT_EQ(gen.result(), R"( var a0 : f32; + + struct S0 { + a : i32; + }; + + fn func() -> f32 { + return a0; + } + + var a1 : f32; + + struct S1 { + a : i32; + }; + + [[stage(compute)]] + fn main() -> void { + var s0 : S0; + var s1 : S1; + a1 = func(); + } +)"); +} + +} // namespace +} // namespace wgsl +} // namespace writer +} // namespace tint diff --git a/src/writer/wgsl/generator_impl_test.cc b/src/writer/wgsl/generator_impl_test.cc index 5619268b93..3d684c9606 100644 --- a/src/writer/wgsl/generator_impl_test.cc +++ b/src/writer/wgsl/generator_impl_test.cc @@ -35,10 +35,9 @@ TEST_F(WgslGeneratorImplTest, Generate) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.Generate()) << gen.error(); + ASSERT_TRUE(gen.Generate(nullptr)) << gen.error(); EXPECT_EQ(gen.result(), R"(fn my_func() -> void { } - )"); }