[wgsl-writer] Emit globals in declared order

Instead of emitting all global variables and then functions, emit
global declarations in the order they were added to the AST.

This fixes issues where the reording might generate an invalid WGSL
program from a valid input (e.g. when declaring a global variable with
the same name as a variable inside a function that precedes it).

This also unifies the implementation of Generate() and
GenerateEntryPoint(), to avoid implementing the same logic twice.

Change-Id: I60a4e5ed4a054562cdcc3d028f8d577434a6d713
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/41303
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
James Price 2021-02-10 15:34:37 +00:00 committed by Commit Bot service account
parent 0c7f97626f
commit dfd1714174
11 changed files with 405 additions and 109 deletions

View File

@ -657,8 +657,10 @@ if(${TINT_BUILD_TESTS})
writer/wgsl/generator_impl_constructor_test.cc writer/wgsl/generator_impl_constructor_test.cc
writer/wgsl/generator_impl_continue_test.cc writer/wgsl/generator_impl_continue_test.cc
writer/wgsl/generator_impl_discard_test.cc writer/wgsl/generator_impl_discard_test.cc
writer/wgsl/generator_impl_entry_point_test.cc
writer/wgsl/generator_impl_fallthrough_test.cc writer/wgsl/generator_impl_fallthrough_test.cc
writer/wgsl/generator_impl_function_test.cc writer/wgsl/generator_impl_function_test.cc
writer/wgsl/generator_impl_global_decl_test.cc
writer/wgsl/generator_impl_identifier_test.cc writer/wgsl/generator_impl_identifier_test.cc
writer/wgsl/generator_impl_if_test.cc writer/wgsl/generator_impl_if_test.cc
writer/wgsl/generator_impl_loop_test.cc writer/wgsl/generator_impl_loop_test.cc

View File

@ -25,6 +25,7 @@ using BoundArrayAccessorsTest = TransformTest;
TEST_F(BoundArrayAccessorsTest, Ptrs_Clamp) { TEST_F(BoundArrayAccessorsTest, Ptrs_Clamp) {
auto* src = R"( auto* src = R"(
var a : array<f32, 3>; var a : array<f32, 3>;
const c : u32 = 1u; const c : u32 = 1u;
fn f() -> void { fn f() -> void {
@ -34,6 +35,7 @@ fn f() -> void {
auto* expect = R"( auto* expect = R"(
var a : array<f32, 3>; var a : array<f32, 3>;
const c : u32 = 1u; const c : u32 = 1u;
fn f() -> void { fn f() -> void {
@ -49,7 +51,9 @@ fn f() -> void {
TEST_F(BoundArrayAccessorsTest, Array_Idx_Nested_Scalar) { TEST_F(BoundArrayAccessorsTest, Array_Idx_Nested_Scalar) {
auto* src = R"( auto* src = R"(
var a : array<f32, 3>; var a : array<f32, 3>;
var b : array<f32, 5>; var b : array<f32, 5>;
var i : u32; var i : u32;
fn f() -> void { fn f() -> void {
@ -59,7 +63,9 @@ fn f() -> void {
auto* expect = R"( auto* expect = R"(
var a : array<f32, 3>; var a : array<f32, 3>;
var b : array<f32, 5>; var b : array<f32, 5>;
var i : u32; var i : u32;
fn f() -> void { fn f() -> void {
@ -97,6 +103,7 @@ fn f() -> void {
TEST_F(BoundArrayAccessorsTest, Array_Idx_Expr) { TEST_F(BoundArrayAccessorsTest, Array_Idx_Expr) {
auto* src = R"( auto* src = R"(
var a : array<f32, 3>; var a : array<f32, 3>;
var c : u32; var c : u32;
fn f() -> void { fn f() -> void {
@ -106,6 +113,7 @@ fn f() -> void {
auto* expect = R"( auto* expect = R"(
var a : array<f32, 3>; var a : array<f32, 3>;
var c : u32; var c : u32;
fn f() -> void { fn f() -> void {
@ -187,6 +195,7 @@ fn f() -> void {
TEST_F(BoundArrayAccessorsTest, Vector_Idx_Expr) { TEST_F(BoundArrayAccessorsTest, Vector_Idx_Expr) {
auto* src = R"( auto* src = R"(
var a : vec3<f32>; var a : vec3<f32>;
var c : u32; var c : u32;
fn f() -> void { fn f() -> void {
@ -196,6 +205,7 @@ fn f() -> void {
auto* expect = R"( auto* expect = R"(
var a : vec3<f32>; var a : vec3<f32>;
var c : u32; var c : u32;
fn f() -> void { fn f() -> void {
@ -277,6 +287,7 @@ fn f() -> void {
TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Column) { TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Column) {
auto* src = R"( auto* src = R"(
var a : mat3x2<f32>; var a : mat3x2<f32>;
var c : u32; var c : u32;
fn f() -> void { fn f() -> void {
@ -286,6 +297,7 @@ fn f() -> void {
auto* expect = R"( auto* expect = R"(
var a : mat3x2<f32>; var a : mat3x2<f32>;
var c : u32; var c : u32;
fn f() -> void { fn f() -> void {
@ -301,6 +313,7 @@ fn f() -> void {
TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Row) { TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Row) {
auto* src = R"( auto* src = R"(
var a : mat3x2<f32>; var a : mat3x2<f32>;
var c : u32; var c : u32;
fn f() -> void { fn f() -> void {
@ -310,6 +323,7 @@ fn f() -> void {
auto* expect = R"( auto* expect = R"(
var a : mat3x2<f32>; var a : mat3x2<f32>;
var c : u32; var c : u32;
fn f() -> void { fn f() -> void {

View File

@ -76,15 +76,16 @@ fn entry() -> void {
)"; )";
auto* expect = R"( auto* expect = R"(
[[builtin(vertex_index)]] var<in> tint_first_index_offset_vert_idx : u32;
[[binding(1), group(2)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
[[block]] [[block]]
struct TintFirstIndexOffsetData { struct TintFirstIndexOffsetData {
[[offset(0)]] [[offset(0)]]
tint_first_vertex_index : u32; tint_first_vertex_index : u32;
}; };
[[builtin(vertex_index)]] var<in> tint_first_index_offset_vert_idx : u32;
[[binding(1), group(2)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
fn test() -> u32 { fn test() -> u32 {
const vert_idx : u32 = (tint_first_index_offset_vert_idx + tint_first_index_data.tint_first_vertex_index); const vert_idx : u32 = (tint_first_index_offset_vert_idx + tint_first_index_data.tint_first_vertex_index);
return vert_idx; return vert_idx;
@ -116,15 +117,16 @@ fn entry() -> void {
)"; )";
auto* expect = R"( auto* expect = R"(
[[builtin(instance_index)]] var<in> tint_first_index_offset_inst_idx : u32;
[[binding(1), group(7)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
[[block]] [[block]]
struct TintFirstIndexOffsetData { struct TintFirstIndexOffsetData {
[[offset(0)]] [[offset(0)]]
tint_first_instance_index : u32; tint_first_instance_index : u32;
}; };
[[builtin(instance_index)]] var<in> tint_first_index_offset_inst_idx : u32;
[[binding(1), group(7)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
fn test() -> u32 { fn test() -> u32 {
const inst_idx : u32 = (tint_first_index_offset_inst_idx + tint_first_index_data.tint_first_instance_index); const inst_idx : u32 = (tint_first_index_offset_inst_idx + tint_first_index_data.tint_first_instance_index);
return inst_idx; return inst_idx;
@ -157,6 +159,12 @@ fn entry() -> void {
)"; )";
auto* expect = R"( auto* expect = R"(
[[builtin(instance_index)]] var<in> tint_first_index_offset_instance_idx : u32;
[[builtin(vertex_index)]] var<in> tint_first_index_offset_vert_idx : u32;
[[binding(1), group(2)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
[[block]] [[block]]
struct TintFirstIndexOffsetData { struct TintFirstIndexOffsetData {
[[offset(0)]] [[offset(0)]]
@ -165,10 +173,6 @@ struct TintFirstIndexOffsetData {
tint_first_instance_index : u32; tint_first_instance_index : u32;
}; };
[[builtin(instance_index)]] var<in> tint_first_index_offset_instance_idx : u32;
[[builtin(vertex_index)]] var<in> tint_first_index_offset_vert_idx : u32;
[[binding(1), group(2)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
fn test() -> u32 { fn test() -> u32 {
const instance_idx : u32 = (tint_first_index_offset_instance_idx + tint_first_index_data.tint_first_instance_index); const instance_idx : u32 = (tint_first_index_offset_instance_idx + tint_first_index_data.tint_first_instance_index);
const vert_idx : u32 = (tint_first_index_offset_vert_idx + tint_first_index_data.tint_first_vertex_index); const vert_idx : u32 = (tint_first_index_offset_vert_idx + tint_first_index_data.tint_first_vertex_index);
@ -205,15 +209,16 @@ fn entry() -> void {
)"; )";
auto* expect = R"( auto* expect = R"(
[[builtin(vertex_index)]] var<in> tint_first_index_offset_vert_idx : u32;
[[binding(1), group(2)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
[[block]] [[block]]
struct TintFirstIndexOffsetData { struct TintFirstIndexOffsetData {
[[offset(0)]] [[offset(0)]]
tint_first_vertex_index : u32; tint_first_vertex_index : u32;
}; };
[[builtin(vertex_index)]] var<in> tint_first_index_offset_vert_idx : u32;
[[binding(1), group(2)]] var<uniform> tint_first_index_data : TintFirstIndexOffsetData;
fn func1() -> u32 { fn func1() -> u32 {
const vert_idx : u32 = (tint_first_index_offset_vert_idx + tint_first_index_data.tint_first_vertex_index); const vert_idx : u32 = (tint_first_index_offset_vert_idx + tint_first_index_data.tint_first_vertex_index);
return vert_idx; return vert_idx;

View File

@ -127,14 +127,16 @@ fn main() -> void {}
)"; )";
auto* expect = R"( auto* expect = R"(
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
[[block]] [[block]]
struct TintVertexData { struct TintVertexData {
[[offset(0)]] [[offset(0)]]
_tint_vertex_data : [[stride(4)]] array<u32>; _tint_vertex_data : [[stride(4)]] array<u32>;
}; };
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
var<private> var_a : f32; var<private> var_a : f32;
[[stage(vertex)]] [[stage(vertex)]]
@ -166,14 +168,16 @@ fn main() -> void {}
)"; )";
auto* expect = R"( auto* expect = R"(
[[builtin(instance_index)]] var<in> _tint_pulling_instance_index : i32;
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
[[block]] [[block]]
struct TintVertexData { struct TintVertexData {
[[offset(0)]] [[offset(0)]]
_tint_vertex_data : [[stride(4)]] array<u32>; _tint_vertex_data : [[stride(4)]] array<u32>;
}; };
[[builtin(instance_index)]] var<in> _tint_pulling_instance_index : i32;
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
var<private> var_a : f32; var<private> var_a : f32;
[[stage(vertex)]] [[stage(vertex)]]
@ -205,14 +209,16 @@ fn main() -> void {}
)"; )";
auto* expect = R"( auto* expect = R"(
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
[[binding(0), group(5)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
[[block]] [[block]]
struct TintVertexData { struct TintVertexData {
[[offset(0)]] [[offset(0)]]
_tint_vertex_data : [[stride(4)]] array<u32>; _tint_vertex_data : [[stride(4)]] array<u32>;
}; };
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
[[binding(0), group(5)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
var<private> var_a : f32; var<private> var_a : f32;
[[stage(vertex)]] [[stage(vertex)]]
@ -249,17 +255,22 @@ fn main() -> void {}
)"; )";
auto* expect = R"( auto* expect = R"(
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
[[binding(1), group(4)]] var<storage> _tint_pulling_vertex_buffer_1 : TintVertexData;
[[block]] [[block]]
struct TintVertexData { struct TintVertexData {
[[offset(0)]] [[offset(0)]]
_tint_vertex_data : [[stride(4)]] array<u32>; _tint_vertex_data : [[stride(4)]] array<u32>;
}; };
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
[[binding(1), group(4)]] var<storage> _tint_pulling_vertex_buffer_1 : TintVertexData;
var<private> var_a : f32; var<private> var_a : f32;
var<private> var_b : f32; var<private> var_b : f32;
[[builtin(vertex_index)]] var<in> custom_vertex_index : i32; [[builtin(vertex_index)]] var<in> custom_vertex_index : i32;
[[builtin(instance_index)]] var<in> custom_instance_index : i32; [[builtin(instance_index)]] var<in> custom_instance_index : i32;
[[stage(vertex)]] [[stage(vertex)]]
@ -295,15 +306,18 @@ fn main() -> void {}
)"; )";
auto* expect = R"( auto* expect = R"(
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
[[block]] [[block]]
struct TintVertexData { struct TintVertexData {
[[offset(0)]] [[offset(0)]]
_tint_vertex_data : [[stride(4)]] array<u32>; _tint_vertex_data : [[stride(4)]] array<u32>;
}; };
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
var<private> var_a : f32; var<private> var_a : f32;
var<private> var_b : array<f32, 4>; var<private> var_b : array<f32, 4>;
[[stage(vertex)]] [[stage(vertex)]]
@ -341,18 +355,24 @@ fn main() -> void {}
)"; )";
auto* expect = R"( auto* expect = R"(
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
[[binding(1), group(4)]] var<storage> _tint_pulling_vertex_buffer_1 : TintVertexData;
[[binding(2), group(4)]] var<storage> _tint_pulling_vertex_buffer_2 : TintVertexData;
[[block]] [[block]]
struct TintVertexData { struct TintVertexData {
[[offset(0)]] [[offset(0)]]
_tint_vertex_data : [[stride(4)]] array<u32>; _tint_vertex_data : [[stride(4)]] array<u32>;
}; };
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
[[binding(1), group(4)]] var<storage> _tint_pulling_vertex_buffer_1 : TintVertexData;
[[binding(2), group(4)]] var<storage> _tint_pulling_vertex_buffer_2 : TintVertexData;
var<private> var_a : array<f32, 2>; var<private> var_a : array<f32, 2>;
var<private> var_b : array<f32, 3>; var<private> var_b : array<f32, 3>;
var<private> var_c : array<f32, 4>; var<private> var_c : array<f32, 4>;
[[stage(vertex)]] [[stage(vertex)]]

View File

@ -26,7 +26,7 @@ Generator::Generator(const Program* program)
Generator::~Generator() = default; Generator::~Generator() = default;
bool Generator::Generate() { bool Generator::Generate() {
auto ret = impl_->Generate(); auto ret = impl_->Generate(nullptr);
if (!ret) { if (!ret) {
error_ = impl_->error(); error_ = impl_->error();
} }

View File

@ -1,4 +1,4 @@
// Copyright 2020 The Tint Authors. // Copyright 2021 The Tint Authors.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -88,30 +88,44 @@ GeneratorImpl::GeneratorImpl(const Program* program)
GeneratorImpl::~GeneratorImpl() = default; GeneratorImpl::~GeneratorImpl() = default;
bool GeneratorImpl::Generate() { bool GeneratorImpl::Generate(const ast::Function* entry) {
for (auto* const ty : program_->AST().ConstructedTypes()) { // Generate global declarations in the order they appear in the module.
for (auto* decl : program_->AST().GlobalDeclarations()) {
if (auto* ty = decl->As<type::Type>()) {
if (!EmitConstructedType(ty)) { if (!EmitConstructedType(ty)) {
return false; return false;
} }
} } else if (auto* func = decl->As<ast::Function>()) {
if (!program_->AST().ConstructedTypes().empty()) if (entry && func != entry) {
out_ << std::endl; // Skip functions that are not reachable by the target entry point.
auto* sem = program_->Sem().Get(func);
for (auto* var : program_->AST().GlobalVariables()) { if (!sem->HasAncestorEntryPoint(entry->symbol())) {
if (!EmitVariable(var)) { continue;
return false;
} }
} }
if (!program_->AST().GlobalVariables().empty()) {
out_ << std::endl;
}
for (auto* func : program_->AST().Functions()) {
if (!EmitFunction(func)) { if (!EmitFunction(func)) {
return false; return false;
} }
} else if (auto* var = decl->As<ast::Variable>()) {
if (entry && !var->is_const()) {
// Skip variables that are not referenced by the target entry point.
auto& refs = program_->Sem().Get(entry)->ReferencedModuleVariables();
if (std::find(refs.begin(), refs.end(), program_->Sem().Get(var)) ==
refs.end()) {
continue;
}
}
if (!EmitVariable(var)) {
return false;
}
} else {
assert(false /* unreachable */);
}
if (decl != program_->AST().GlobalDeclarations().back()) {
out_ << std::endl; out_ << std::endl;
} }
}
return true; return true;
} }
@ -124,58 +138,7 @@ bool GeneratorImpl::GenerateEntryPoint(ast::PipelineStage stage,
error_ = "Unable to find requested entry point: " + name; error_ = "Unable to find requested entry point: " + name;
return false; return false;
} }
return Generate(func);
// TODO(dsinclair): We always emit constructed types even if they aren't
// strictly needed
for (auto* const ty : program_->AST().ConstructedTypes()) {
if (!EmitConstructedType(ty)) {
return false;
}
}
if (!program_->AST().ConstructedTypes().empty()) {
out_ << std::endl;
}
// TODO(dsinclair): This should be smarter and only emit needed const
// variables
for (auto* var : program_->AST().GlobalVariables()) {
if (!var->is_const()) {
continue;
}
if (!EmitVariable(var)) {
return false;
}
}
bool found_func_variable = false;
for (auto* var : program_->Sem().Get(func)->ReferencedModuleVariables()) {
if (!EmitVariable(var->Declaration())) {
return false;
}
found_func_variable = true;
}
if (found_func_variable) {
out_ << std::endl;
}
for (auto* f : program_->AST().Functions()) {
auto* f_sem = program_->Sem().Get(f);
if (!f_sem->HasAncestorEntryPoint(program_->Symbols().Get(name))) {
continue;
}
if (!EmitFunction(f)) {
return false;
}
out_ << std::endl;
}
if (!EmitFunction(func)) {
return false;
}
out_ << std::endl;
return true;
} }
bool GeneratorImpl::EmitConstructedType(const type::Type* ty) { bool GeneratorImpl::EmitConstructedType(const type::Type* ty) {

View File

@ -57,9 +57,10 @@ class GeneratorImpl : public TextGenerator {
explicit GeneratorImpl(const Program* program); explicit GeneratorImpl(const Program* program);
~GeneratorImpl(); ~GeneratorImpl();
/// Generates the result data /// Generates the result data, optionally restricted to a single entry point
/// @param entry entry point to target, or nullptr
/// @returns true on successful generation; false otherwise /// @returns true on successful generation; false otherwise
bool Generate(); bool Generate(const ast::Function* entry);
/// Generates a single entry point /// Generates a single entry point
/// @param stage the pipeline stage /// @param stage the pipeline stage

View File

@ -0,0 +1,171 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/call_statement.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
#include "src/type/f32_type.h"
#include "src/writer/wgsl/generator_impl.h"
#include "src/writer/wgsl/test_helper.h"
namespace tint {
namespace writer {
namespace wgsl {
namespace {
using WgslGeneratorImplTest = TestHelper;
TEST_F(WgslGeneratorImplTest, Emit_EntryPoint_UnusedFunction) {
Func("func_unused", ast::VariableList{}, ty.void_(), ast::StatementList{},
ast::FunctionDecorationList{});
Func("func_used", ast::VariableList{}, ty.void_(), ast::StatementList{},
ast::FunctionDecorationList{});
auto* call_func = Call("func_used");
Func("main", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::CallStatement>(call_func),
},
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute),
});
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(
gen.GenerateEntryPoint(tint::ast::PipelineStage::kCompute, "main"))
<< gen.error();
EXPECT_EQ(gen.result(), R"( fn func_used() -> void {
}
[[stage(compute)]]
fn main() -> void {
func_used();
}
)");
}
TEST_F(WgslGeneratorImplTest, Emit_EntryPoint_UnusedVariable) {
auto* global_unused =
Global("global_unused", ast::StorageClass::kInput, ty.f32());
create<ast::VariableDeclStatement>(global_unused);
auto* global_used =
Global("global_used", ast::StorageClass::kInput, ty.f32());
create<ast::VariableDeclStatement>(global_used);
Func("main", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::AssignmentStatement>(Expr("global_used"), Expr(1.f)),
},
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute),
});
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(
gen.GenerateEntryPoint(tint::ast::PipelineStage::kCompute, "main"))
<< gen.error();
EXPECT_EQ(gen.result(), R"( var<in> global_used : f32;
[[stage(compute)]]
fn main() -> void {
global_used = 1.0;
}
)");
}
TEST_F(WgslGeneratorImplTest, Emit_EntryPoint_GlobalsInterleaved) {
auto* global0 = Global("a0", ast::StorageClass::kInput, ty.f32());
create<ast::VariableDeclStatement>(global0);
auto* str0 = create<ast::Struct>(ast::StructMemberList{Member("a", ty.i32())},
ast::StructDecorationList{});
auto* s0 = ty.struct_("S0", str0);
AST().AddConstructedType(s0);
Func("func", ast::VariableList{}, ty.f32(),
ast::StatementList{
create<ast::ReturnStatement>(Expr("a0")),
},
ast::FunctionDecorationList{});
auto* global1 = Global("a1", ast::StorageClass::kOutput, ty.f32());
create<ast::VariableDeclStatement>(global1);
auto* str1 = create<ast::Struct>(ast::StructMemberList{Member("a", ty.i32())},
ast::StructDecorationList{});
auto* s1 = ty.struct_("S1", str1);
AST().AddConstructedType(s1);
auto* call_func = Call("func");
Func("main", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::VariableDeclStatement>(
Var("s0", ast::StorageClass::kFunction, s0)),
create<ast::VariableDeclStatement>(
Var("s1", ast::StorageClass::kFunction, s1)),
create<ast::AssignmentStatement>(Expr("a1"), Expr(call_func)),
},
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute),
});
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(
gen.GenerateEntryPoint(tint::ast::PipelineStage::kCompute, "main"))
<< gen.error();
EXPECT_EQ(gen.result(), R"( var<in> a0 : f32;
struct S0 {
a : i32;
};
fn func() -> f32 {
return a0;
}
var<out> a1 : f32;
struct S1 {
a : i32;
};
[[stage(compute)]]
fn main() -> void {
var s0 : S0;
var s1 : S1;
a1 = func();
}
)");
}
} // namespace
} // namespace wgsl
} // namespace writer
} // namespace tint

View File

@ -183,6 +183,7 @@ TEST_F(WgslGeneratorImplTest,
auto* s = ty.struct_("Data", str); auto* s = ty.struct_("Data", str);
type::AccessControl ac(ast::AccessControl::kReadWrite, s); type::AccessControl ac(ast::AccessControl::kReadWrite, s);
AST().AddConstructedType(s);
Global("data", ast::StorageClass::kStorage, &ac, nullptr, Global("data", ast::StorageClass::kStorage, &ac, nullptr,
ast::VariableDecorationList{ ast::VariableDecorationList{
@ -190,8 +191,6 @@ TEST_F(WgslGeneratorImplTest,
create<ast::GroupDecoration>(0), create<ast::GroupDecoration>(0),
}); });
AST().AddConstructedType(s);
{ {
auto* var = auto* var =
Var("v", ast::StorageClass::kFunction, ty.f32(), Var("v", ast::StorageClass::kFunction, ty.f32(),
@ -226,7 +225,7 @@ TEST_F(WgslGeneratorImplTest,
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error(); ASSERT_TRUE(gen.Generate(nullptr)) << gen.error();
EXPECT_EQ(gen.result(), R"([[block]] EXPECT_EQ(gen.result(), R"([[block]]
struct Data { struct Data {
[[offset(0)]] [[offset(0)]]
@ -247,7 +246,6 @@ fn b() -> void {
var v : f32 = data.d; var v : f32 = data.d;
return; return;
} }
)"); )");
} }

View File

@ -0,0 +1,123 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/call_statement.h"
#include "src/ast/stage_decoration.h"
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
#include "src/type/f32_type.h"
#include "src/writer/wgsl/generator_impl.h"
#include "src/writer/wgsl/test_helper.h"
namespace tint {
namespace writer {
namespace wgsl {
namespace {
using WgslGeneratorImplTest = TestHelper;
TEST_F(WgslGeneratorImplTest, Emit_GlobalDeclAfterFunction) {
auto* func_var = Var("a", ast::StorageClass::kFunction, ty.f32(), nullptr,
ast::VariableDecorationList{});
WrapInFunction(create<ast::VariableDeclStatement>(func_var));
auto* global_var = Global("a", ast::StorageClass::kInput, ty.f32());
create<ast::VariableDeclStatement>(global_var);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.Generate(nullptr)) << gen.error();
EXPECT_EQ(gen.result(), R"( fn test_function() -> void {
var a : f32;
}
var<in> a : f32;
)");
}
TEST_F(WgslGeneratorImplTest, Emit_GlobalsInterleaved) {
auto* global0 = Global("a0", ast::StorageClass::kInput, ty.f32());
create<ast::VariableDeclStatement>(global0);
auto* str0 = create<ast::Struct>(ast::StructMemberList{Member("a", ty.i32())},
ast::StructDecorationList{});
auto* s0 = ty.struct_("S0", str0);
AST().AddConstructedType(s0);
Func("func", ast::VariableList{}, ty.f32(),
ast::StatementList{
create<ast::ReturnStatement>(Expr("a0")),
},
ast::FunctionDecorationList{});
auto* global1 = Global("a1", ast::StorageClass::kOutput, ty.f32());
create<ast::VariableDeclStatement>(global1);
auto* str1 = create<ast::Struct>(ast::StructMemberList{Member("a", ty.i32())},
ast::StructDecorationList{});
auto* s1 = ty.struct_("S1", str1);
AST().AddConstructedType(s1);
auto* call_func = Call("func");
Func("main", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::VariableDeclStatement>(
Var("s0", ast::StorageClass::kFunction, s0)),
create<ast::VariableDeclStatement>(
Var("s1", ast::StorageClass::kFunction, s1)),
create<ast::AssignmentStatement>(Expr("a1"), Expr(call_func)),
},
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute),
});
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.Generate(nullptr)) << gen.error();
EXPECT_EQ(gen.result(), R"( var<in> a0 : f32;
struct S0 {
a : i32;
};
fn func() -> f32 {
return a0;
}
var<out> a1 : f32;
struct S1 {
a : i32;
};
[[stage(compute)]]
fn main() -> void {
var s0 : S0;
var s1 : S1;
a1 = func();
}
)");
}
} // namespace
} // namespace wgsl
} // namespace writer
} // namespace tint

View File

@ -35,10 +35,9 @@ TEST_F(WgslGeneratorImplTest, Generate) {
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error(); ASSERT_TRUE(gen.Generate(nullptr)) << gen.error();
EXPECT_EQ(gen.result(), R"(fn my_func() -> void { EXPECT_EQ(gen.result(), R"(fn my_func() -> void {
} }
)"); )");
} }