From dcdf66ed5b286b8b67bfb633b86f298f8910f344 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Fri, 17 Jun 2022 12:48:51 +0000 Subject: [PATCH] tint/ast: Derive off `ast::Variable` Add the new classes: * `ast::Let` * `ast::Override` * `ast::Parameter` * `ast::Var` Limit the fields to those that are only applicable for their type. Note: The resolver and validator is a tangled mess for each of the variable types. This CL tries to keep the functionality exactly the same. I'll clean this up in another change. Bug: tint:1582 Change-Id: Iee83324167ffd4d92ae3032b2134677629c90079 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/93780 Reviewed-by: Dan Sinclair Commit-Queue: Ben Clayton Kokoro: Kokoro --- src/tint/BUILD.gn | 8 + src/tint/CMakeLists.txt | 8 + src/tint/ast/function.cc | 2 +- src/tint/ast/function.h | 5 +- src/tint/ast/function_test.cc | 12 -- src/tint/ast/let.cc | 46 ++++ src/tint/ast/let.h | 61 ++++++ src/tint/ast/module.h | 13 ++ src/tint/ast/override.cc | 44 ++++ src/tint/ast/override.h | 62 ++++++ src/tint/ast/parameter.cc | 42 ++++ src/tint/ast/parameter.h | 66 ++++++ src/tint/ast/var.cc | 49 +++++ src/tint/ast/var.h | 86 ++++++++ src/tint/ast/variable.cc | 39 +--- src/tint/ast/variable.h | 106 +-------- src/tint/inspector/inspector.cc | 39 ++-- src/tint/program_builder.h | 156 ++++++-------- src/tint/reader/spirv/function.cc | 24 +-- src/tint/reader/spirv/parser_impl.cc | 65 ++++-- src/tint/reader/spirv/parser_impl.h | 40 +++- src/tint/reader/wgsl/parser_impl.cc | 98 ++++----- src/tint/reader/wgsl/parser_impl.h | 2 +- .../reader/wgsl/parser_impl_for_stmt_test.cc | 6 +- .../parser_impl_global_constant_decl_test.cc | 100 +++++---- .../parser_impl_global_variable_decl_test.cc | 86 ++++---- .../wgsl/parser_impl_param_list_test.cc | 12 +- src/tint/resolver/dependency_graph.cc | 12 +- .../pipeline_overridable_constant_test.cc | 4 +- src/tint/resolver/resolver.cc | 197 ++++++++--------- src/tint/resolver/resolver.h | 7 +- src/tint/resolver/type_validation_test.cc | 20 -- src/tint/resolver/uniformity.cc | 5 +- src/tint/resolver/validator.cc | 181 ++++++++-------- src/tint/resolver/validator.h | 6 +- src/tint/resolver/var_let_validation_test.cc | 16 -- src/tint/sem/function.cc | 52 ++--- src/tint/sem/type_mappings.h | 3 + src/tint/sem/variable.cc | 2 +- src/tint/sem/variable.h | 14 +- .../transform/add_spirv_block_attribute.cc | 4 +- src/tint/transform/binding_remapper.cc | 14 +- src/tint/transform/combine_samplers.cc | 17 +- .../transform/fold_trivial_single_use_lets.cc | 6 +- .../module_scope_var_to_entry_point_param.cc | 37 ++-- .../transform/multiplanar_external_texture.cc | 18 +- .../transform/num_workgroups_from_uniform.cc | 4 +- src/tint/transform/simplify_pointers.cc | 6 +- src/tint/transform/single_entry_point.cc | 55 ++--- src/tint/transform/unshadow.cc | 32 ++- .../transform/utils/hoist_to_decl_before.cc | 12 +- src/tint/transform/vertex_pulling.cc | 4 +- .../transform/zero_init_workgroup_memory.cc | 4 +- src/tint/writer/glsl/generator_impl.cc | 194 +++++++++-------- src/tint/writer/glsl/generator_impl.h | 33 ++- .../generator_impl_module_constant_test.cc | 8 +- src/tint/writer/hlsl/generator_impl.cc | 203 ++++++++++-------- src/tint/writer/hlsl/generator_impl.h | 33 ++- .../generator_impl_module_constant_test.cc | 8 +- src/tint/writer/msl/generator_impl.cc | 158 +++++++++----- src/tint/writer/msl/generator_impl.h | 18 +- .../generator_impl_module_constant_test.cc | 6 +- src/tint/writer/spirv/builder.cc | 113 +++++----- src/tint/writer/wgsl/generator_impl.cc | 60 ++++-- test/tint/bug/tint/827.wgsl.expected.hlsl | 1 + test/tint/bug/tint/914.wgsl.expected.hlsl | 1 + 66 files changed, 1652 insertions(+), 1193 deletions(-) create mode 100644 src/tint/ast/let.cc create mode 100644 src/tint/ast/let.h create mode 100644 src/tint/ast/override.cc create mode 100644 src/tint/ast/override.h create mode 100644 src/tint/ast/parameter.cc create mode 100644 src/tint/ast/parameter.h create mode 100644 src/tint/ast/var.cc create mode 100644 src/tint/ast/var.h diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 648cf17c77..e058f5b817 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -264,6 +264,8 @@ libtint_source_set("libtint_core_all_src") { "ast/interpolate_attribute.h", "ast/invariant_attribute.cc", "ast/invariant_attribute.h", + "ast/let.cc", + "ast/let.h", "ast/literal_expression.cc", "ast/literal_expression.h", "ast/location_attribute.cc", @@ -280,6 +282,10 @@ libtint_source_set("libtint_core_all_src") { "ast/multisampled_texture.h", "ast/node.cc", "ast/node.h", + "ast/override.cc", + "ast/override.h", + "ast/parameter.cc", + "ast/parameter.h", "ast/phony_expression.cc", "ast/phony_expression.h", "ast/pipeline_stage.cc", @@ -328,6 +334,8 @@ libtint_source_set("libtint_core_all_src") { "ast/unary_op.h", "ast/unary_op_expression.cc", "ast/unary_op_expression.h", + "ast/var.cc", + "ast/var.h", "ast/variable.cc", "ast/variable.h", "ast/variable_decl_statement.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index a3767b56e2..f7fb14a8fc 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -151,6 +151,8 @@ set(TINT_LIB_SRCS ast/interpolate_attribute.h ast/invariant_attribute.cc ast/invariant_attribute.h + ast/let.cc + ast/let.h ast/literal_expression.cc ast/literal_expression.h ast/location_attribute.cc @@ -167,6 +169,10 @@ set(TINT_LIB_SRCS ast/multisampled_texture.h ast/node.cc ast/node.h + ast/override.cc + ast/override.h + ast/parameter.cc + ast/parameter.h ast/phony_expression.cc ast/phony_expression.h ast/pipeline_stage.cc @@ -215,6 +221,8 @@ set(TINT_LIB_SRCS ast/unary_op_expression.h ast/unary_op.cc ast/unary_op.h + ast/var.cc + ast/var.h ast/variable_decl_statement.cc ast/variable_decl_statement.h ast/variable.cc diff --git a/src/tint/ast/function.cc b/src/tint/ast/function.cc index 2fb64f1b95..84d80d798a 100644 --- a/src/tint/ast/function.cc +++ b/src/tint/ast/function.cc @@ -40,7 +40,7 @@ Function::Function(ProgramID pid, TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, symbol, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id); for (auto* param : params) { - TINT_ASSERT(AST, param && param->is_const); + TINT_ASSERT(AST, tint::Is(param)); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, param, program_id); } TINT_ASSERT(AST, symbol.IsValid()); diff --git a/src/tint/ast/function.h b/src/tint/ast/function.h index 585f1c80f1..d849486ee8 100644 --- a/src/tint/ast/function.h +++ b/src/tint/ast/function.h @@ -26,14 +26,11 @@ #include "src/tint/ast/builtin_attribute.h" #include "src/tint/ast/group_attribute.h" #include "src/tint/ast/location_attribute.h" +#include "src/tint/ast/parameter.h" #include "src/tint/ast/pipeline_stage.h" -#include "src/tint/ast/variable.h" namespace tint::ast { -/// ParameterList is a list of function parameters -using ParameterList = std::vector; - /// A Function statement. class Function final : public Castable { public: diff --git a/src/tint/ast/function_test.cc b/src/tint/ast/function_test.cc index 4fb6b9121e..f1de639b4e 100644 --- a/src/tint/ast/function_test.cc +++ b/src/tint/ast/function_test.cc @@ -122,18 +122,6 @@ TEST_F(FunctionTest, Assert_DifferentProgramID_ReturnAttr) { "internal compiler error"); } -TEST_F(FunctionTest, Assert_NonConstParam) { - EXPECT_FATAL_FAILURE( - { - ProgramBuilder b; - ParameterList params; - params.push_back(b.Var("var", b.ty.i32(), ast::StorageClass::kNone)); - - b.Func("f", params, b.ty.void_(), {}); - }, - "internal compiler error"); -} - using FunctionListTest = TestHelper; TEST_F(FunctionListTest, FindSymbol) { diff --git a/src/tint/ast/let.cc b/src/tint/ast/let.cc new file mode 100644 index 0000000000..e771c4d2ba --- /dev/null +++ b/src/tint/ast/let.cc @@ -0,0 +1,46 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/ast/let.h" + +#include "src/tint/program_builder.h" + +TINT_INSTANTIATE_TYPEINFO(tint::ast::Let); + +namespace tint::ast { + +Let::Let(ProgramID pid, + const Source& src, + const Symbol& sym, + const ast::Type* ty, + const Expression* ctor, + AttributeList attrs) + : Base(pid, src, sym, ty, ctor, attrs) { + TINT_ASSERT(AST, ctor != nullptr); +} + +Let::Let(Let&&) = default; + +Let::~Let() = default; + +const Let* Let::Clone(CloneContext* ctx) const { + auto src = ctx->Clone(source); + auto sym = ctx->Clone(symbol); + auto* ty = ctx->Clone(type); + auto* ctor = ctx->Clone(constructor); + auto attrs = ctx->Clone(attributes); + return ctx->dst->create(src, sym, ty, ctor, attrs); +} + +} // namespace tint::ast diff --git a/src/tint/ast/let.h b/src/tint/ast/let.h new file mode 100644 index 0000000000..2b0a6aaf87 --- /dev/null +++ b/src/tint/ast/let.h @@ -0,0 +1,61 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_AST_LET_H_ +#define SRC_TINT_AST_LET_H_ + +#include "src/tint/ast/variable.h" + +namespace tint::ast { + +/// A "let" declaration is a name for a function-scoped runtime typed value. +/// +/// Examples: +/// +/// ``` +/// let twice_depth : i32 = width + width; // Must have initializer +/// ``` +/// @see https://www.w3.org/TR/WGSL/#let-decls +class Let final : public Castable { + public: + /// Create a 'let' variable + /// @param program_id the identifier of the program that owns this node + /// @param source the variable source + /// @param sym the variable symbol + /// @param type the declared variable type + /// @param constructor the constructor expression + /// @param attributes the variable attributes + Let(ProgramID program_id, + const Source& source, + const Symbol& sym, + const ast::Type* type, + const Expression* constructor, + AttributeList attributes); + + /// Move constructor + Let(Let&&); + + /// Destructor + ~Let() override; + + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @param ctx the clone context + /// @return the newly cloned node + const Let* Clone(CloneContext* ctx) const override; +}; + +} // namespace tint::ast + +#endif // SRC_TINT_AST_LET_H_ diff --git a/src/tint/ast/module.h b/src/tint/ast/module.h index 45b1ec6ebe..17cf353e8f 100644 --- a/src/tint/ast/module.h +++ b/src/tint/ast/module.h @@ -77,6 +77,19 @@ class Module final : public Castable { /// @returns the global variables for the module VariableList& GlobalVariables() { return global_variables_; } + /// @returns the global variable declarations of kind 'T' for the module + template > + std::vector Globals() const { + std::vector out; + out.reserve(global_variables_.size()); + for (auto* global : global_variables_) { + if (auto* var = global->As()) { + out.emplace_back(var); + } + } + return out; + } + /// @returns the extension set for the module const EnableList& Enables() const { return enables_; } diff --git a/src/tint/ast/override.cc b/src/tint/ast/override.cc new file mode 100644 index 0000000000..f494bc83b7 --- /dev/null +++ b/src/tint/ast/override.cc @@ -0,0 +1,44 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/ast/override.h" + +#include "src/tint/program_builder.h" + +TINT_INSTANTIATE_TYPEINFO(tint::ast::Override); + +namespace tint::ast { + +Override::Override(ProgramID pid, + const Source& src, + const Symbol& sym, + const ast::Type* ty, + const Expression* ctor, + AttributeList attrs) + : Base(pid, src, sym, ty, ctor, attrs) {} + +Override::Override(Override&&) = default; + +Override::~Override() = default; + +const Override* Override::Clone(CloneContext* ctx) const { + auto src = ctx->Clone(source); + auto sym = ctx->Clone(symbol); + auto* ty = ctx->Clone(type); + auto* ctor = ctx->Clone(constructor); + auto attrs = ctx->Clone(attributes); + return ctx->dst->create(src, sym, ty, ctor, attrs); +} + +} // namespace tint::ast diff --git a/src/tint/ast/override.h b/src/tint/ast/override.h new file mode 100644 index 0000000000..168b9b213c --- /dev/null +++ b/src/tint/ast/override.h @@ -0,0 +1,62 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_AST_OVERRIDE_H_ +#define SRC_TINT_AST_OVERRIDE_H_ + +#include "src/tint/ast/variable.h" + +namespace tint::ast { + +/// An "override" declaration - a name for a pipeline-overridable constant. +/// Examples: +/// +/// ``` +/// override radius : i32 = 2; // Can be overridden by name. +/// @id(5) override width : i32 = 2; // Can be overridden by ID. +/// override scale : f32; // No default - must be overridden. +/// ``` +/// @see https://www.w3.org/TR/WGSL/#override-decls +class Override final : public Castable { + public: + /// Create an 'override' pipeline-overridable constant. + /// @param program_id the identifier of the program that owns this node + /// @param source the variable source + /// @param sym the variable symbol + /// @param type the declared variable type + /// @param constructor the constructor expression + /// @param attributes the variable attributes + Override(ProgramID program_id, + const Source& source, + const Symbol& sym, + const ast::Type* type, + const Expression* constructor, + AttributeList attributes); + + /// Move constructor + Override(Override&&); + + /// Destructor + ~Override() override; + + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @param ctx the clone context + /// @return the newly cloned node + const Override* Clone(CloneContext* ctx) const override; +}; + +} // namespace tint::ast + +#endif // SRC_TINT_AST_OVERRIDE_H_ diff --git a/src/tint/ast/parameter.cc b/src/tint/ast/parameter.cc new file mode 100644 index 0000000000..b7ea3b1152 --- /dev/null +++ b/src/tint/ast/parameter.cc @@ -0,0 +1,42 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/ast/parameter.h" + +#include "src/tint/program_builder.h" + +TINT_INSTANTIATE_TYPEINFO(tint::ast::Parameter); + +namespace tint::ast { + +Parameter::Parameter(ProgramID pid, + const Source& src, + const Symbol& sym, + const ast::Type* ty, + AttributeList attrs) + : Base(pid, src, sym, ty, nullptr, attrs) {} + +Parameter::Parameter(Parameter&&) = default; + +Parameter::~Parameter() = default; + +const Parameter* Parameter::Clone(CloneContext* ctx) const { + auto src = ctx->Clone(source); + auto sym = ctx->Clone(symbol); + auto* ty = ctx->Clone(type); + auto attrs = ctx->Clone(attributes); + return ctx->dst->create(src, sym, ty, attrs); +} + +} // namespace tint::ast diff --git a/src/tint/ast/parameter.h b/src/tint/ast/parameter.h new file mode 100644 index 0000000000..d3ecf8dfa1 --- /dev/null +++ b/src/tint/ast/parameter.h @@ -0,0 +1,66 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_AST_PARAMETER_H_ +#define SRC_TINT_AST_PARAMETER_H_ + +#include + +#include "src/tint/ast/variable.h" + +namespace tint::ast { + +/// A formal parameter to a function - a name for a typed value to be passed into a function. +/// Example: +/// +/// ``` +/// fn twice(a: i32) -> i32 { // "a:i32" is the formal parameter +/// return a + a; +/// } +/// ``` +/// +/// @see https://www.w3.org/TR/WGSL/#creation-time-consts +class Parameter final : public Castable { + public: + /// Create a 'parameter' creation-time value variable. + /// @param program_id the identifier of the program that owns this node + /// @param source the variable source + /// @param sym the variable symbol + /// @param type the declared variable type + /// @param attributes the variable attributes + Parameter(ProgramID program_id, + const Source& source, + const Symbol& sym, + const ast::Type* type, + AttributeList attributes); + + /// Move constructor + Parameter(Parameter&&); + + /// Destructor + ~Parameter() override; + + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @param ctx the clone context + /// @return the newly cloned node + const Parameter* Clone(CloneContext* ctx) const override; +}; + +/// A list of parameters +using ParameterList = std::vector; + +} // namespace tint::ast + +#endif // SRC_TINT_AST_PARAMETER_H_ diff --git a/src/tint/ast/var.cc b/src/tint/ast/var.cc new file mode 100644 index 0000000000..622aa03022 --- /dev/null +++ b/src/tint/ast/var.cc @@ -0,0 +1,49 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/ast/var.h" + +#include "src/tint/program_builder.h" + +TINT_INSTANTIATE_TYPEINFO(tint::ast::Var); + +namespace tint::ast { + +Var::Var(ProgramID pid, + const Source& src, + const Symbol& sym, + const ast::Type* ty, + StorageClass storage_class, + Access access, + const Expression* ctor, + AttributeList attrs) + : Base(pid, src, sym, ty, ctor, attrs), + declared_storage_class(storage_class), + declared_access(access) {} + +Var::Var(Var&&) = default; + +Var::~Var() = default; + +const Var* Var::Clone(CloneContext* ctx) const { + auto src = ctx->Clone(source); + auto sym = ctx->Clone(symbol); + auto* ty = ctx->Clone(type); + auto* ctor = ctx->Clone(constructor); + auto attrs = ctx->Clone(attributes); + return ctx->dst->create(src, sym, ty, declared_storage_class, declared_access, ctor, + attrs); +} + +} // namespace tint::ast diff --git a/src/tint/ast/var.h b/src/tint/ast/var.h new file mode 100644 index 0000000000..fd0358097e --- /dev/null +++ b/src/tint/ast/var.h @@ -0,0 +1,86 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_AST_VAR_H_ +#define SRC_TINT_AST_VAR_H_ + +#include +#include + +#include "src/tint/ast/variable.h" + +namespace tint::ast { + +/// A "var" declaration is a name for typed storage. +/// +/// Examples: +/// +/// ``` +/// // Declared outside a function, i.e. at module scope, requires +/// // a storage class. +/// var width : i32; // no initializer +/// var height : i32 = 3; // with initializer +/// +/// // A variable declared inside a function doesn't take a storage class, +/// // and maps to SPIR-V Function storage. +/// var computed_depth : i32; +/// var area : i32 = compute_area(width, height); +/// ``` +/// +/// @see https://www.w3.org/TR/WGSL/#var-decls +class Var final : public Castable { + public: + /// Create a 'var' variable + /// @param program_id the identifier of the program that owns this node + /// @param source the variable source + /// @param sym the variable symbol + /// @param type the declared variable type + /// @param declared_storage_class the declared storage class + /// @param declared_access the declared access control + /// @param constructor the constructor expression + /// @param attributes the variable attributes + Var(ProgramID program_id, + const Source& source, + const Symbol& sym, + const ast::Type* type, + StorageClass declared_storage_class, + Access declared_access, + const Expression* constructor, + AttributeList attributes); + + /// Move constructor + Var(Var&&); + + /// Destructor + ~Var() override; + + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @param ctx the clone context + /// @return the newly cloned node + const Var* Clone(CloneContext* ctx) const override; + + /// The declared storage class + const StorageClass declared_storage_class; + + /// The declared access control + const Access declared_access; +}; + +/// A list of `var` declarations +using VarList = std::vector; + +} // namespace tint::ast + +#endif // SRC_TINT_AST_VAR_H_ diff --git a/src/tint/ast/variable.cc b/src/tint/ast/variable.cc index 26991f252a..a9ae829ce7 100644 --- a/src/tint/ast/variable.cc +++ b/src/tint/ast/variable.cc @@ -13,9 +13,8 @@ // limitations under the License. #include "src/tint/ast/variable.h" - -#include "src/tint/program_builder.h" -#include "src/tint/sem/variable.h" +#include "src/tint/ast/binding_attribute.h" +#include "src/tint/ast/group_attribute.h" TINT_INSTANTIATE_TYPEINFO(tint::ast::Variable); @@ -24,24 +23,11 @@ namespace tint::ast { Variable::Variable(ProgramID pid, const Source& src, const Symbol& sym, - StorageClass dsc, - Access da, const ast::Type* ty, - bool constant, - bool overridable, const Expression* ctor, AttributeList attrs) - : Base(pid, src), - symbol(sym), - type(ty), - is_const(constant), - is_overridable(overridable), - constructor(ctor), - attributes(std::move(attrs)), - declared_storage_class(dsc), - declared_access(da) { + : Base(pid, src), symbol(sym), type(ty), constructor(ctor), attributes(std::move(attrs)) { TINT_ASSERT(AST, symbol.IsValid()); - TINT_ASSERT(AST, is_overridable ? is_const : true); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, symbol, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, constructor, program_id); } @@ -54,23 +40,12 @@ VariableBindingPoint Variable::BindingPoint() const { const GroupAttribute* group = nullptr; const BindingAttribute* binding = nullptr; for (auto* attr : attributes) { - if (auto* g = attr->As()) { - group = g; - } else if (auto* b = attr->As()) { - binding = b; - } + Switch( + attr, // + [&](const GroupAttribute* a) { group = a; }, + [&](const BindingAttribute* a) { binding = a; }); } return VariableBindingPoint{group, binding}; } -const Variable* Variable::Clone(CloneContext* ctx) const { - auto src = ctx->Clone(source); - auto sym = ctx->Clone(symbol); - auto* ty = ctx->Clone(type); - auto* ctor = ctx->Clone(constructor); - auto attrs = ctx->Clone(attributes); - return ctx->dst->create(src, sym, declared_storage_class, declared_access, ty, - is_const, is_overridable, ctor, attrs); -} - } // namespace tint::ast diff --git a/src/tint/ast/variable.h b/src/tint/ast/variable.h index 58022558f5..bc25753171 100644 --- a/src/tint/ast/variable.h +++ b/src/tint/ast/variable.h @@ -45,112 +45,38 @@ struct VariableBindingPoint { inline operator bool() const { return group && binding; } }; -/// A Variable statement. +/// Variable is the base class for Var, Let, Const, Override and Parameter. /// -/// An instance of this class represents one of four constructs in WGSL: "var" -/// declaration, "let" declaration, "override" declaration, or formal parameter -/// to a function. +/// An instance of this class represents one of five constructs in WGSL: "var" declaration, "let" +/// declaration, "override" declaration, "const" declaration, or formal parameter to a function. /// -/// 1. A "var" declaration is a name for typed storage. Examples: -/// -/// // Declared outside a function, i.e. at module scope, requires -/// // a storage class. -/// var width : i32; // no initializer -/// var height : i32 = 3; // with initializer -/// -/// // A variable declared inside a function doesn't take a storage class, -/// // and maps to SPIR-V Function storage. -/// var computed_depth : i32; -/// var area : i32 = compute_area(width, height); -/// -/// 2. A "let" declaration is a name for a typed value. Examples: -/// -/// let twice_depth : i32 = width + width; // Must have initializer -/// -/// 3. An "override" declaration is a name for a pipeline-overridable constant. -/// Examples: -/// -/// override radius : i32 = 2; // Can be overridden by name. -/// @id(5) override width : i32 = 2; // Can be overridden by ID. -/// override scale : f32; // No default - must be overridden. -/// -/// 4. A formal parameter to a function is a name for a typed value to -/// be passed into a function. Example: -/// -/// fn twice(a: i32) -> i32 { // "a:i32" is the formal parameter -/// return a + a; -/// } -/// -/// From the WGSL draft, about "var":: -/// -/// A variable is a named reference to storage that can contain a value of a -/// particular type. -/// -/// Two types are associated with a variable: its store type (the type of -/// value that may be placed in the referenced storage) and its reference -/// type (the type of the variable itself). If a variable has store type T -/// and storage class S, then its reference type is pointer-to-T-in-S. -/// -/// This class uses the term "type" to refer to: -/// the value type of a "let", -/// the value type of an "override", -/// the value type of the formal parameter, -/// or the store type of the "var". -// -/// Setting is_const: -/// - "var" gets false -/// - "let" gets true -/// - "override" gets true -/// - formal parameter gets true -/// -/// Setting is_overrideable: -/// - "var" gets false -/// - "let" gets false -/// - "override" gets true -/// - formal parameter gets false -/// -/// Setting storage class: -/// - "var" is StorageClass::kNone when using the -/// defaulting syntax for a "var" declared inside a function. -/// - "let" is always StorageClass::kNone. -/// - formal parameter is always StorageClass::kNone. -class Variable final : public Castable { +/// @see https://www.w3.org/TR/WGSL/#value-decls +class Variable : public Castable { public: - /// Create a variable + /// Constructor /// @param program_id the identifier of the program that owns this node /// @param source the variable source /// @param sym the variable symbol - /// @param declared_storage_class the declared storage class - /// @param declared_access the declared access control /// @param type the declared variable type - /// @param is_const true if the variable is const - /// @param is_overridable true if the variable is pipeline-overridable /// @param constructor the constructor expression /// @param attributes the variable attributes Variable(ProgramID program_id, const Source& source, const Symbol& sym, - StorageClass declared_storage_class, - Access declared_access, const ast::Type* type, - bool is_const, - bool is_overridable, const Expression* constructor, AttributeList attributes); + /// Move constructor Variable(Variable&&); + /// Destructor ~Variable() override; - /// @returns the binding point information for the variable + /// @returns the binding point information from the variable's attributes. + /// @note binding points should only be applied to Var and Parameter types. VariableBindingPoint BindingPoint() const; - /// Clones this node and all transitive child nodes using the `CloneContext` - /// `ctx`. - /// @param ctx the clone context - /// @return the newly cloned node - const Variable* Clone(CloneContext* ctx) const override; - /// The variable symbol const Symbol symbol; @@ -159,23 +85,11 @@ class Variable final : public Castable { /// var i = 1; const ast::Type* const type; - /// True if this is a constant, false otherwise - const bool is_const; - - /// True if this is a pipeline-overridable constant, false otherwise - const bool is_overridable; - /// The constructor expression or nullptr if none set const Expression* const constructor; /// The attributes attached to this variable const AttributeList attributes; - - /// The declared storage class - const StorageClass declared_storage_class; - - /// The declared access control - const Access declared_access; }; /// A list of variables diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc index 569a1042e6..df208f6278 100644 --- a/src/tint/inspector/inspector.cc +++ b/src/tint/inspector/inspector.cc @@ -183,7 +183,7 @@ std::vector Inspector::GetEntryPoints() { auto name = program_->Symbols().NameFor(decl->symbol); auto* global = var->As(); - if (global && global->IsOverridable()) { + if (global && global->Declaration()->Is()) { OverridableConstant overridable_constant; overridable_constant.name = name; overridable_constant.numeric_id = global->ConstantId(); @@ -219,7 +219,7 @@ std::map Inspector::GetConstantIDs() { std::map result; for (auto* var : program_->AST().GlobalVariables()) { auto* global = program_->Sem().Get(var); - if (!global || !global->IsOverridable()) { + if (!global || !global->Declaration()->Is()) { continue; } @@ -276,7 +276,7 @@ std::map Inspector::GetConstantNameToIdMap() { std::map result; for (auto* var : program_->AST().GlobalVariables()) { auto* global = program_->Sem().Get(var); - if (global && global->IsOverridable()) { + if (global && global->Declaration()->Is()) { auto name = program_->Symbols().NameFor(var->symbol); result[name] = global->ConstantId(); } @@ -813,25 +813,24 @@ void Inspector::GenerateSamplerTargets() { auto* t = c->args[texture_index]; auto* s = c->args[sampler_index]; - GetOriginatingResources(std::array{t, s}, - [&](std::array globals) { - auto* texture = globals[0]; - sem::BindingPoint texture_binding_point = { - texture->Declaration()->BindingPoint().group->value, - texture->Declaration()->BindingPoint().binding->value}; + GetOriginatingResources( + std::array{t, s}, + [&](std::array globals) { + auto* texture = globals[0]->Declaration()->As(); + sem::BindingPoint texture_binding_point = {texture->BindingPoint().group->value, + texture->BindingPoint().binding->value}; - auto* sampler = globals[1]; - sem::BindingPoint sampler_binding_point = { - sampler->Declaration()->BindingPoint().group->value, - sampler->Declaration()->BindingPoint().binding->value}; + auto* sampler = globals[1]->Declaration()->As(); + sem::BindingPoint sampler_binding_point = {sampler->BindingPoint().group->value, + sampler->BindingPoint().binding->value}; - for (auto* entry_point : entry_points) { - const auto& ep_name = program_->Symbols().NameFor( - entry_point->Declaration()->symbol); - (*sampler_targets_)[ep_name].add( - {sampler_binding_point, texture_binding_point}); - } - }); + for (auto* entry_point : entry_points) { + const auto& ep_name = + program_->Symbols().NameFor(entry_point->Declaration()->symbol); + (*sampler_targets_)[ep_name].add( + {sampler_binding_point, texture_binding_point}); + } + }); } } diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index 5f17d1defe..54928671a4 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -53,11 +53,14 @@ #include "src/tint/ast/index_accessor_expression.h" #include "src/tint/ast/interpolate_attribute.h" #include "src/tint/ast/invariant_attribute.h" +#include "src/tint/ast/let.h" #include "src/tint/ast/loop_statement.h" #include "src/tint/ast/matrix.h" #include "src/tint/ast/member_accessor_expression.h" #include "src/tint/ast/module.h" #include "src/tint/ast/multisampled_texture.h" +#include "src/tint/ast/override.h" +#include "src/tint/ast/parameter.h" #include "src/tint/ast/phony_expression.h" #include "src/tint/ast/pointer.h" #include "src/tint/ast/return_statement.h" @@ -73,6 +76,7 @@ #include "src/tint/ast/type_name.h" #include "src/tint/ast/u32.h" #include "src/tint/ast/unary_op_expression.h" +#include "src/tint/ast/var.h" #include "src/tint/ast/variable_decl_statement.h" #include "src/tint/ast/vector.h" #include "src/tint/ast/void.h" @@ -1328,14 +1332,13 @@ class ProgramBuilder { /// * ast::AttributeList - specifies the variable's attributes /// Note that repeated arguments of the same type will use the last argument's /// value. - /// @returns a `ast::Variable` with the given name, type and additional + /// @returns a `ast::Var` with the given name, type and additional /// options template - const ast::Variable* Var(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) { + const ast::Var* Var(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) { VarOptionals opts(std::forward(optional)...); - return create(Sym(std::forward(name)), opts.storage, opts.access, type, - false /* is_const */, false /* is_overridable */, - opts.constructor, std::move(opts.attributes)); + return create(Sym(std::forward(name)), type, opts.storage, opts.access, + opts.constructor, std::move(opts.attributes)); } /// @param source the variable source @@ -1349,32 +1352,28 @@ class ProgramBuilder { /// * ast::AttributeList - specifies the variable's attributes /// Note that repeated arguments of the same type will use the last argument's /// value. - /// @returns a `ast::Variable` with the given name, storage and type + /// @returns a `ast::Var` with the given name, storage and type template - const ast::Variable* Var(const Source& source, - NAME&& name, - const ast::Type* type, - OPTIONAL&&... optional) { + const ast::Var* Var(const Source& source, + NAME&& name, + const ast::Type* type, + OPTIONAL&&... optional) { VarOptionals opts(std::forward(optional)...); - return create(source, Sym(std::forward(name)), opts.storage, - opts.access, type, false /* is_const */, - false /* is_overridable */, opts.constructor, - std::move(opts.attributes)); + return create(source, Sym(std::forward(name)), type, opts.storage, + opts.access, opts.constructor, std::move(opts.attributes)); } /// @param name the variable name /// @param type the variable type /// @param constructor constructor expression /// @param attributes optional variable attributes - /// @returns an immutable `ast::Variable` with the given name and type + /// @returns an `ast::Let` with the given name and type template - const ast::Variable* Let(NAME&& name, - const ast::Type* type, - const ast::Expression* constructor, - ast::AttributeList attributes = {}) { - return create(Sym(std::forward(name)), ast::StorageClass::kNone, - ast::Access::kUndefined, type, true /* is_const */, - false /* is_overridable */, constructor, attributes); + const ast::Let* Let(NAME&& name, + const ast::Type* type, + const ast::Expression* constructor, + ast::AttributeList attributes = {}) { + return create(Sym(std::forward(name)), type, constructor, attributes); } /// @param source the variable source @@ -1382,46 +1381,39 @@ class ProgramBuilder { /// @param type the variable type /// @param constructor constructor expression /// @param attributes optional variable attributes - /// @returns an immutable `ast::Variable` with the given name and type + /// @returns an `ast::Let` with the given name and type template - const ast::Variable* Let(const Source& source, - NAME&& name, - const ast::Type* type, - const ast::Expression* constructor, - ast::AttributeList attributes = {}) { - return create(source, Sym(std::forward(name)), - ast::StorageClass::kNone, ast::Access::kUndefined, type, - true /* is_const */, false /* is_overridable */, constructor, - attributes); + const ast::Let* Let(const Source& source, + NAME&& name, + const ast::Type* type, + const ast::Expression* constructor, + ast::AttributeList attributes = {}) { + return create(source, Sym(std::forward(name)), type, constructor, + attributes); } /// @param name the parameter name /// @param type the parameter type /// @param attributes optional parameter attributes - /// @returns an immutable `ast::Variable` with the given name and type + /// @returns an `ast::Parameter` with the given name and type template - const ast::Variable* Param(NAME&& name, - const ast::Type* type, - ast::AttributeList attributes = {}) { - return create(Sym(std::forward(name)), ast::StorageClass::kNone, - ast::Access::kUndefined, type, true /* is_const */, - false /* is_overridable */, nullptr, attributes); + const ast::Parameter* Param(NAME&& name, + const ast::Type* type, + ast::AttributeList attributes = {}) { + return create(Sym(std::forward(name)), type, attributes); } /// @param source the parameter source /// @param name the parameter name /// @param type the parameter type /// @param attributes optional parameter attributes - /// @returns an immutable `ast::Variable` with the given name and type + /// @returns an `ast::Parameter` with the given name and type template - const ast::Variable* Param(const Source& source, - NAME&& name, - const ast::Type* type, - ast::AttributeList attributes = {}) { - return create(source, Sym(std::forward(name)), - ast::StorageClass::kNone, ast::Access::kUndefined, type, - true /* is_const */, false /* is_overridable */, nullptr, - attributes); + const ast::Parameter* Param(const Source& source, + NAME&& name, + const ast::Type* type, + ast::AttributeList attributes = {}) { + return create(source, Sym(std::forward(name)), type, attributes); } /// @param name the variable name @@ -1434,10 +1426,10 @@ class ProgramBuilder { /// * ast::AttributeList - specifies the variable's attributes /// Note that repeated arguments of the same type will use the last argument's /// value. - /// @returns a new `ast::Variable`, which is automatically registered as a - /// global variable with the ast::Module. + /// @returns a new `ast::Var`, which is automatically registered as a global variable with the + /// ast::Module. template > - const ast::Variable* Global(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) { + const ast::Var* Global(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) { auto* var = Var(std::forward(name), type, std::forward(optional)...); AST().AddGlobalVariable(var); return var; @@ -1454,13 +1446,13 @@ class ProgramBuilder { /// * ast::AttributeList - specifies the variable's attributes /// Note that repeated arguments of the same type will use the last argument's /// value. - /// @returns a new `ast::Variable`, which is automatically registered as a - /// global variable with the ast::Module. + /// @returns a new `ast::Var`, which is automatically registered as a global variable with the + /// ast::Module. template - const ast::Variable* Global(const Source& source, - NAME&& name, - const ast::Type* type, - OPTIONAL&&... optional) { + const ast::Var* Global(const Source& source, + NAME&& name, + const ast::Type* type, + OPTIONAL&&... optional) { auto* var = Var(source, std::forward(name), type, std::forward(optional)...); AST().AddGlobalVariable(var); @@ -1471,14 +1463,13 @@ class ProgramBuilder { /// @param type the variable type /// @param constructor constructor expression /// @param attributes optional variable attributes - /// @returns a const `ast::Variable` constructed by calling Var() with the - /// arguments of `args`, which is automatically registered as a global - /// variable with the ast::Module. + /// @returns an `ast::Let` constructed by calling Let() with the arguments of `args`, which is + /// automatically registered as a global variable with the ast::Module. template - const ast::Variable* GlobalConst(NAME&& name, - const ast::Type* type, - const ast::Expression* constructor, - ast::AttributeList attributes = {}) { + const ast::Let* GlobalConst(NAME&& name, + const ast::Type* type, + const ast::Expression* constructor, + ast::AttributeList attributes = {}) { auto* var = Let(std::forward(name), type, constructor, std::move(attributes)); AST().AddGlobalVariable(var); return var; @@ -1489,15 +1480,15 @@ class ProgramBuilder { /// @param type the variable type /// @param constructor constructor expression /// @param attributes optional variable attributes - /// @returns a const `ast::Variable` constructed by calling Var() with the + /// @returns a const `ast::Let` constructed by calling Var() with the /// arguments of `args`, which is automatically registered as a global /// variable with the ast::Module. template - const ast::Variable* GlobalConst(const Source& source, - NAME&& name, - const ast::Type* type, - const ast::Expression* constructor, - ast::AttributeList attributes = {}) { + const ast::Let* GlobalConst(const Source& source, + NAME&& name, + const ast::Type* type, + const ast::Expression* constructor, + ast::AttributeList attributes = {}) { auto* var = Let(source, std::forward(name), type, constructor, std::move(attributes)); AST().AddGlobalVariable(var); return var; @@ -1507,17 +1498,15 @@ class ProgramBuilder { /// @param type the variable type /// @param constructor optional constructor expression /// @param attributes optional variable attributes - /// @returns an overridable const `ast::Variable` which is automatically - /// registered as a global variable with the ast::Module. + /// @returns an `ast::Override` which is automatically registered as a global variable with the + /// ast::Module. template - const ast::Variable* Override(NAME&& name, + const ast::Override* Override(NAME&& name, const ast::Type* type, const ast::Expression* constructor, ast::AttributeList attributes = {}) { - auto* var = - create(source_, Sym(std::forward(name)), ast::StorageClass::kNone, - ast::Access::kUndefined, type, true /* is_const */, - true /* is_overridable */, constructor, std::move(attributes)); + auto* var = create(source_, Sym(std::forward(name)), type, constructor, + std::move(attributes)); AST().AddGlobalVariable(var); return var; } @@ -1527,19 +1516,16 @@ class ProgramBuilder { /// @param type the variable type /// @param constructor constructor expression /// @param attributes optional variable attributes - /// @returns a const `ast::Variable` constructed by calling Var() with the - /// arguments of `args`, which is automatically registered as a global - /// variable with the ast::Module. + /// @returns an `ast::Override` constructed with the arguments of `args`, which is automatically + /// registered as a global variable with the ast::Module. template - const ast::Variable* Override(const Source& source, + const ast::Override* Override(const Source& source, NAME&& name, const ast::Type* type, const ast::Expression* constructor, ast::AttributeList attributes = {}) { - auto* var = - create(source, Sym(std::forward(name)), ast::StorageClass::kNone, - ast::Access::kUndefined, type, true /* is_const */, - true /* is_overridable */, constructor, std::move(attributes)); + auto* var = create(source, Sym(std::forward(name)), type, constructor, + std::move(attributes)); AST().AddGlobalVariable(var); return var; } diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc index 72cc93b219..2a5169b620 100644 --- a/src/tint/reader/spirv/function.cc +++ b/src/tint/reader/spirv/function.cc @@ -1253,12 +1253,12 @@ bool FunctionEmitter::EmitEntryPointAsWrapper() { auto* sample_mask_array_type = store_type->UnwrapRef()->UnwrapAlias()->As(); TINT_ASSERT(Reader, sample_mask_array_type); ok = EmitPipelineInput(var_name, store_type, ¶m_decos, {0}, - sample_mask_array_type->type, forced_param_type, &(decl.params), + sample_mask_array_type->type, forced_param_type, &decl.params, &stmts); } else { // The normal path. ok = EmitPipelineInput(var_name, store_type, ¶m_decos, {}, store_type, - forced_param_type, &(decl.params), &stmts); + forced_param_type, &decl.params, &stmts); } if (!ok) { return false; @@ -1404,8 +1404,7 @@ bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) { auto* type = parser_impl_.ConvertType(param->type_id()); if (type != nullptr) { auto* ast_param = - parser_impl_.MakeVariable(param->result_id(), ast::StorageClass::kNone, type, true, - false, nullptr, ast::AttributeList{}); + parser_impl_.MakeParameter(param->result_id(), type, ast::AttributeList{}); // Parameters are treated as const declarations. ast_params.emplace_back(ast_param); // The value is accessible by name. @@ -2468,9 +2467,8 @@ bool FunctionEmitter::EmitFunctionVariables() { return false; } } - auto* var = - parser_impl_.MakeVariable(inst.result_id(), ast::StorageClass::kNone, var_store_type, - false, false, constructor, ast::AttributeList{}); + auto* var = parser_impl_.MakeVar(inst.result_id(), ast::StorageClass::kNone, var_store_type, + constructor, ast::AttributeList{}); auto* var_decl_stmt = create(Source{}, var); AddStatement(var_decl_stmt); auto* var_type = ty_.Reference(var_store_type, ast::StorageClass::kNone); @@ -3328,8 +3326,8 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info, TINT_ASSERT(Reader, def_inst); auto* storage_type = RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id); AddStatement(create( - Source{}, parser_impl_.MakeVariable(id, ast::StorageClass::kNone, storage_type, false, - false, nullptr, ast::AttributeList{}))); + Source{}, parser_impl_.MakeVar(id, ast::StorageClass::kNone, storage_type, nullptr, + ast::AttributeList{}))); auto* type = ty_.Reference(storage_type, ast::StorageClass::kNone); identifier_types_.emplace(id, type); } @@ -3396,13 +3394,11 @@ bool FunctionEmitter::EmitConstDefinition(const spvtools::opt::Instruction& inst } expr = AddressOfIfNeeded(expr, &inst); - auto* ast_const = - parser_impl_.MakeVariable(inst.result_id(), ast::StorageClass::kNone, expr.type, true, - false, expr.expr, ast::AttributeList{}); - if (!ast_const) { + auto* let = parser_impl_.MakeLet(inst.result_id(), expr.type, expr.expr); + if (!let) { return false; } - AddStatement(create(Source{}, ast_const)); + AddStatement(create(Source{}, let)); identifier_types_.emplace(inst.result_id(), expr.type); return success(); } diff --git a/src/tint/reader/spirv/parser_impl.cc b/src/tint/reader/spirv/parser_impl.cc index bdc52e695e..1ebf0b9770 100644 --- a/src/tint/reader/spirv/parser_impl.cc +++ b/src/tint/reader/spirv/parser_impl.cc @@ -1371,8 +1371,8 @@ bool ParserImpl::EmitScalarSpecConstants() { break; } } - auto* ast_var = MakeVariable(inst.result_id(), ast::StorageClass::kNone, ast_type, true, - true, ast_expr, std::move(spec_id_decos)); + auto* ast_var = + MakeOverride(inst.result_id(), ast_type, ast_expr, std::move(spec_id_decos)); if (ast_var) { builder_.AST().AddGlobalVariable(ast_var); scalar_spec_constants_.insert(inst.result_id()); @@ -1489,8 +1489,8 @@ bool ParserImpl::EmitModuleScopeVariables() { // here.) ast_constructor = MakeConstantExpression(var.GetSingleWordInOperand(1)).expr; } - auto* ast_var = MakeVariable(var.result_id(), ast_storage_class, ast_store_type, false, - false, ast_constructor, ast::AttributeList{}); + auto* ast_var = MakeVar(var.result_id(), ast_storage_class, ast_store_type, ast_constructor, + ast::AttributeList{}); // TODO(dneto): initializers (a.k.a. constructor expression) if (ast_var) { builder_.AST().AddGlobalVariable(ast_var); @@ -1521,10 +1521,9 @@ bool ParserImpl::EmitModuleScopeVariables() { } } auto* ast_var = - MakeVariable(builtin_position_.per_vertex_var_id, - enum_converter_.ToStorageClass(builtin_position_.storage_class), - ConvertType(builtin_position_.position_member_type_id), false, false, - ast_constructor, {}); + MakeVar(builtin_position_.per_vertex_var_id, + enum_converter_.ToStorageClass(builtin_position_.storage_class), + ConvertType(builtin_position_.position_member_type_id), ast_constructor, {}); builder_.AST().AddGlobalVariable(ast_var); } @@ -1554,13 +1553,11 @@ const spvtools::opt::analysis::IntConstant* ParserImpl::GetArraySize(uint32_t va return size->AsIntConstant(); } -ast::Variable* ParserImpl::MakeVariable(uint32_t id, - ast::StorageClass sc, - const Type* storage_type, - bool is_const, - bool is_overridable, - const ast::Expression* constructor, - ast::AttributeList decorations) { +ast::Var* ParserImpl::MakeVar(uint32_t id, + ast::StorageClass sc, + const Type* storage_type, + const ast::Expression* constructor, + ast::AttributeList decorations) { if (storage_type == nullptr) { Fail() << "internal error: can't make ast::Variable for null type"; return nullptr; @@ -1588,15 +1585,37 @@ ast::Variable* ParserImpl::MakeVariable(uint32_t id, return nullptr; } - std::string name = namer_.Name(id); + auto sym = builder_.Symbols().Register(namer_.Name(id)); + return create(Source{}, sym, storage_type->Build(builder_), sc, access, constructor, + decorations); +} - // Note: we're constructing the variable here with the *storage* type, - // regardless of whether this is a `let`, `override`, or `var` declaration. - // `var` declarations will have a resolved type of ref, but at the - // AST level all three are declared with the same type. - return create(Source{}, builder_.Symbols().Register(name), sc, access, - storage_type->Build(builder_), is_const, is_overridable, - constructor, decorations); +ast::Let* ParserImpl::MakeLet(uint32_t id, const Type* type, const ast::Expression* constructor) { + auto sym = builder_.Symbols().Register(namer_.Name(id)); + return create(Source{}, sym, type->Build(builder_), constructor, + ast::AttributeList{}); +} + +ast::Override* ParserImpl::MakeOverride(uint32_t id, + const Type* type, + const ast::Expression* constructor, + ast::AttributeList decorations) { + if (!ConvertDecorationsForVariable(id, &type, &decorations, false)) { + return nullptr; + } + auto sym = builder_.Symbols().Register(namer_.Name(id)); + return create(Source{}, sym, type->Build(builder_), constructor, decorations); +} + +ast::Parameter* ParserImpl::MakeParameter(uint32_t id, + const Type* type, + ast::AttributeList decorations) { + if (!ConvertDecorationsForVariable(id, &type, &decorations, false)) { + return nullptr; + } + + auto sym = builder_.Symbols().Register(namer_.Name(id)); + return create(Source{}, sym, type->Build(builder_), decorations); } bool ParserImpl::ConvertDecorationsForVariable(uint32_t id, diff --git a/src/tint/reader/spirv/parser_impl.h b/src/tint/reader/spirv/parser_impl.h index b91f1924a5..6addf6c412 100644 --- a/src/tint/reader/spirv/parser_impl.h +++ b/src/tint/reader/spirv/parser_impl.h @@ -411,25 +411,47 @@ class ParserImpl : Reader { /// @returns a list of SPIR-V decorations. DecorationList GetMemberPipelineDecorations(const Struct& struct_type, int member_index); - /// Creates an AST Variable node for a SPIR-V ID, including any attached - /// decorations, unless it's an ignorable builtin variable. + /// Creates an AST 'var' node for a SPIR-V ID, including any attached decorations, unless it's + /// an ignorable builtin variable. /// @param id the SPIR-V result ID /// @param sc the storage class, which cannot be ast::StorageClass::kNone /// @param storage_type the storage type of the variable - /// @param is_const if true, the variable is const - /// @param is_overridable if true, the variable is pipeline-overridable /// @param constructor the variable constructor /// @param decorations the variable decorations /// @returns a new Variable node, or null in the ignorable variable case and /// in the error case - ast::Variable* MakeVariable(uint32_t id, - ast::StorageClass sc, - const Type* storage_type, - bool is_const, - bool is_overridable, + ast::Var* MakeVar(uint32_t id, + ast::StorageClass sc, + const Type* storage_type, + const ast::Expression* constructor, + ast::AttributeList decorations); + + /// Creates an AST 'let' node for a SPIR-V ID, including any attached decorations,. + /// @param id the SPIR-V result ID + /// @param type the type of the variable + /// @param constructor the variable constructor + /// @returns the AST 'let' node + ast::Let* MakeLet(uint32_t id, const Type* type, const ast::Expression* constructor); + + /// Creates an AST 'override' node for a SPIR-V ID, including any attached decorations. + /// @param id the SPIR-V result ID + /// @param type the type of the variable + /// @param constructor the variable constructor + /// @param decorations the variable decorations + /// @returns the AST 'override' node + ast::Override* MakeOverride(uint32_t id, + const Type* type, const ast::Expression* constructor, ast::AttributeList decorations); + /// Creates an AST parameter node for a SPIR-V ID, including any attached decorations, unless + /// it's an ignorable builtin variable. + /// @param id the SPIR-V result ID + /// @param type the type of the parameter + /// @param decorations the parameter decorations + /// @returns the AST parameter node + ast::Parameter* MakeParameter(uint32_t id, const Type* type, ast::AttributeList decorations); + /// Returns true if a constant expression can be generated. /// @param id the SPIR-V ID of the value /// @returns true if a constant expression can be generated diff --git a/src/tint/reader/wgsl/parser_impl.cc b/src/tint/reader/wgsl/parser_impl.cc index 0c853bd0ad..854b276bb2 100644 --- a/src/tint/reader/wgsl/parser_impl.cc +++ b/src/tint/reader/wgsl/parser_impl.cc @@ -213,7 +213,11 @@ ParserImpl::FunctionHeader::FunctionHeader(Source src, ast::ParameterList p, const ast::Type* ret_ty, ast::AttributeList ret_attrs) - : source(src), name(n), params(p), return_type(ret_ty), return_type_attributes(ret_attrs) {} + : source(src), + name(n), + params(std::move(p)), + return_type(ret_ty), + return_type_attributes(std::move(ret_attrs)) {} ParserImpl::FunctionHeader::~FunctionHeader() = default; @@ -542,15 +546,13 @@ Maybe ParserImpl::global_variable_decl(ast::AttributeList& constructor = expr.value; } - return create(decl->source, // source - builder_.Symbols().Register(decl->name), // symbol - decl->storage_class, // storage class - decl->access, // access control - decl->type, // type - false, // is_const - false, // is_overridable - constructor, // constructor - std::move(attrs)); // attributes + return create(decl->source, // source + builder_.Symbols().Register(decl->name), // symbol + decl->type, // type + decl->storage_class, // storage class + decl->access, // access control + constructor, // constructor + std::move(attrs)); // attributes } // global_constant_decl : @@ -564,7 +566,7 @@ Maybe ParserImpl::global_constant_decl(ast::AttributeList& if (match(Token::Type::kLet)) { use = "'let' declaration"; } else if (match(Token::Type::kOverride)) { - use = "override declaration"; + use = "'override' declaration"; is_overridable = true; } else { return Failure::kNoMatch; @@ -594,15 +596,18 @@ Maybe ParserImpl::global_constant_decl(ast::AttributeList& initializer = std::move(init.value); } - return create(decl->source, // source - builder_.Symbols().Register(decl->name), // symbol - ast::StorageClass::kNone, // storage class - ast::Access::kUndefined, // access control - decl->type, // type - true, // is_const - is_overridable, // is_overridable - initializer, // constructor - std::move(attrs)); // attributes + if (is_overridable) { + return create(decl->source, // source + builder_.Symbols().Register(decl->name), // symbol + decl->type, // type + initializer, // constructor + std::move(attrs)); // attributes + } + return create(decl->source, // source + builder_.Symbols().Register(decl->name), // symbol + decl->type, // type + initializer, // constructor + std::move(attrs)); // attributes } // variable_decl @@ -1478,7 +1483,7 @@ Expect ParserImpl::expect_param_list() { // param // : attribute_list* variable_ident_decl -Expect ParserImpl::expect_param() { +Expect ParserImpl::expect_param() { auto attrs = attribute_list(); auto decl = expect_variable_ident_decl("parameter"); @@ -1486,21 +1491,10 @@ Expect ParserImpl::expect_param() { return Failure::kErrored; } - auto* var = create(decl->source, // source - builder_.Symbols().Register(decl->name), // symbol - ast::StorageClass::kNone, // storage class - ast::Access::kUndefined, // access control - decl->type, // type - true, // is_const - false, // is_overridable - nullptr, // constructor - std::move(attrs.value)); // attributes - // Formal parameters are treated like a const declaration where the - // initializer value is provided by the call's argument. The key point is - // that it's not updatable after initially set. This is unlike C or GLSL - // which treat formal parameters like local variables that can be updated. - - return var; + return create(decl->source, // source + builder_.Symbols().Register(decl->name), // symbol + decl->type, // type + std::move(attrs.value)); // attributes } // pipeline_stage @@ -1794,17 +1788,13 @@ Maybe ParserImpl::variable_stmt() { return add_error(peek(), "missing constructor for 'let' declaration"); } - auto* var = create(decl->source, // source - builder_.Symbols().Register(decl->name), // symbol - ast::StorageClass::kNone, // storage class - ast::Access::kUndefined, // access control - decl->type, // type - true, // is_const - false, // is_overridable - constructor.value, // constructor - ast::AttributeList{}); // attributes + auto* let = create(decl->source, // source + builder_.Symbols().Register(decl->name), // symbol + decl->type, // type + constructor.value, // constructor + ast::AttributeList{}); // attributes - return create(decl->source, var); + return create(decl->source, let); } auto decl = variable_decl(/*allow_inferred = */ true); @@ -1828,15 +1818,13 @@ Maybe ParserImpl::variable_stmt() { constructor = constructor_expr.value; } - auto* var = create(decl->source, // source - builder_.Symbols().Register(decl->name), // symbol - decl->storage_class, // storage class - decl->access, // access control - decl->type, // type - false, // is_const - false, // is_overridable - constructor, // constructor - ast::AttributeList{}); // attributes + auto* var = create(decl->source, // source + builder_.Symbols().Register(decl->name), // symbol + decl->type, // type + decl->storage_class, // storage class + decl->access, // access control + constructor, // constructor + ast::AttributeList{}); // attributes return create(var->source, var); } diff --git a/src/tint/reader/wgsl/parser_impl.h b/src/tint/reader/wgsl/parser_impl.h index 3f9b09017a..c6e696bc91 100644 --- a/src/tint/reader/wgsl/parser_impl.h +++ b/src/tint/reader/wgsl/parser_impl.h @@ -462,7 +462,7 @@ class ParserImpl { Expect expect_param_list(); /// Parses a `param` grammar element, erroring on parse failure. /// @returns the parsed variable - Expect expect_param(); + Expect expect_param(); /// Parses a `pipeline_stage` grammar element, erroring if the next token does /// not match a stage name. /// @returns the pipeline stage. diff --git a/src/tint/reader/wgsl/parser_impl_for_stmt_test.cc b/src/tint/reader/wgsl/parser_impl_for_stmt_test.cc index 3ae7a32f7c..4f6beb032b 100644 --- a/src/tint/reader/wgsl/parser_impl_for_stmt_test.cc +++ b/src/tint/reader/wgsl/parser_impl_for_stmt_test.cc @@ -57,7 +57,7 @@ TEST_F(ForStmtTest, InitializerStatementDecl) { ASSERT_TRUE(fl.matched); ASSERT_TRUE(Is(fl->initializer)); auto* var = fl->initializer->As()->variable; - EXPECT_FALSE(var->is_const); + EXPECT_TRUE(var->Is()); EXPECT_EQ(var->constructor, nullptr); EXPECT_EQ(fl->condition, nullptr); EXPECT_EQ(fl->continuing, nullptr); @@ -74,7 +74,7 @@ TEST_F(ForStmtTest, InitializerStatementDeclEqual) { ASSERT_TRUE(fl.matched); ASSERT_TRUE(Is(fl->initializer)); auto* var = fl->initializer->As()->variable; - EXPECT_FALSE(var->is_const); + EXPECT_TRUE(var->Is()); EXPECT_NE(var->constructor, nullptr); EXPECT_EQ(fl->condition, nullptr); EXPECT_EQ(fl->continuing, nullptr); @@ -90,7 +90,7 @@ TEST_F(ForStmtTest, InitializerStatementConstDecl) { ASSERT_TRUE(fl.matched); ASSERT_TRUE(Is(fl->initializer)); auto* var = fl->initializer->As()->variable; - EXPECT_TRUE(var->is_const); + EXPECT_TRUE(var->Is()); EXPECT_NE(var->constructor, nullptr); EXPECT_EQ(fl->condition, nullptr); EXPECT_EQ(fl->continuing, nullptr); diff --git a/src/tint/reader/wgsl/parser_impl_global_constant_decl_test.cc b/src/tint/reader/wgsl/parser_impl_global_constant_decl_test.cc index 7a5bb452d7..f5428fc13a 100644 --- a/src/tint/reader/wgsl/parser_impl_global_constant_decl_test.cc +++ b/src/tint/reader/wgsl/parser_impl_global_constant_decl_test.cc @@ -27,21 +27,20 @@ TEST_F(ParserImplTest, GlobalConstantDecl) { EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_TRUE(e.matched); EXPECT_FALSE(e.errored); - ASSERT_NE(e.value, nullptr); + auto* let = e.value->As(); + ASSERT_NE(let, nullptr); - EXPECT_TRUE(e->is_const); - EXPECT_FALSE(e->is_overridable); - EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); - ASSERT_NE(e->type, nullptr); - EXPECT_TRUE(e->type->Is()); + EXPECT_EQ(let->symbol, p->builder().Symbols().Get("a")); + ASSERT_NE(let->type, nullptr); + EXPECT_TRUE(let->type->Is()); - EXPECT_EQ(e->source.range.begin.line, 1u); - EXPECT_EQ(e->source.range.begin.column, 5u); - EXPECT_EQ(e->source.range.end.line, 1u); - EXPECT_EQ(e->source.range.end.column, 6u); + EXPECT_EQ(let->source.range.begin.line, 1u); + EXPECT_EQ(let->source.range.begin.column, 5u); + EXPECT_EQ(let->source.range.end.line, 1u); + EXPECT_EQ(let->source.range.end.column, 6u); - ASSERT_NE(e->constructor, nullptr); - EXPECT_TRUE(e->constructor->Is()); + ASSERT_NE(let->constructor, nullptr); + EXPECT_TRUE(let->constructor->Is()); } TEST_F(ParserImplTest, GlobalConstantDecl_Inferred) { @@ -53,20 +52,19 @@ TEST_F(ParserImplTest, GlobalConstantDecl_Inferred) { EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_TRUE(e.matched); EXPECT_FALSE(e.errored); - ASSERT_NE(e.value, nullptr); + auto* let = e.value->As(); + ASSERT_NE(let, nullptr); - EXPECT_TRUE(e->is_const); - EXPECT_FALSE(e->is_overridable); - EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); - EXPECT_EQ(e->type, nullptr); + EXPECT_EQ(let->symbol, p->builder().Symbols().Get("a")); + EXPECT_EQ(let->type, nullptr); - EXPECT_EQ(e->source.range.begin.line, 1u); - EXPECT_EQ(e->source.range.begin.column, 5u); - EXPECT_EQ(e->source.range.end.line, 1u); - EXPECT_EQ(e->source.range.end.column, 6u); + EXPECT_EQ(let->source.range.begin.line, 1u); + EXPECT_EQ(let->source.range.begin.column, 5u); + EXPECT_EQ(let->source.range.end.line, 1u); + EXPECT_EQ(let->source.range.end.column, 6u); - ASSERT_NE(e->constructor, nullptr); - EXPECT_TRUE(e->constructor->Is()); + ASSERT_NE(let->constructor, nullptr); + EXPECT_TRUE(let->constructor->Is()); } TEST_F(ParserImplTest, GlobalConstantDecl_InvalidExpression) { @@ -105,23 +103,22 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_WithId) { EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_TRUE(e.matched); EXPECT_FALSE(e.errored); - ASSERT_NE(e.value, nullptr); + auto* override = e.value->As(); + ASSERT_NE(override, nullptr); - EXPECT_TRUE(e->is_const); - EXPECT_TRUE(e->is_overridable); - EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); - ASSERT_NE(e->type, nullptr); - EXPECT_TRUE(e->type->Is()); + EXPECT_EQ(override->symbol, p->builder().Symbols().Get("a")); + ASSERT_NE(override->type, nullptr); + EXPECT_TRUE(override->type->Is()); - EXPECT_EQ(e->source.range.begin.line, 1u); - EXPECT_EQ(e->source.range.begin.column, 17u); - EXPECT_EQ(e->source.range.end.line, 1u); - EXPECT_EQ(e->source.range.end.column, 18u); + EXPECT_EQ(override->source.range.begin.line, 1u); + EXPECT_EQ(override->source.range.begin.column, 17u); + EXPECT_EQ(override->source.range.end.line, 1u); + EXPECT_EQ(override->source.range.end.column, 18u); - ASSERT_NE(e->constructor, nullptr); - EXPECT_TRUE(e->constructor->Is()); + ASSERT_NE(override->constructor, nullptr); + EXPECT_TRUE(override->constructor->Is()); - auto* override_attr = ast::GetAttribute(e.value->attributes); + auto* override_attr = ast::GetAttribute(override->attributes); ASSERT_NE(override_attr, nullptr); EXPECT_EQ(override_attr->value, 7u); } @@ -136,23 +133,22 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_WithoutId) { EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_TRUE(e.matched); EXPECT_FALSE(e.errored); - ASSERT_NE(e.value, nullptr); + auto* override = e.value->As(); + ASSERT_NE(override, nullptr); - EXPECT_TRUE(e->is_const); - EXPECT_TRUE(e->is_overridable); - EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); - ASSERT_NE(e->type, nullptr); - EXPECT_TRUE(e->type->Is()); + EXPECT_EQ(override->symbol, p->builder().Symbols().Get("a")); + ASSERT_NE(override->type, nullptr); + EXPECT_TRUE(override->type->Is()); - EXPECT_EQ(e->source.range.begin.line, 1u); - EXPECT_EQ(e->source.range.begin.column, 10u); - EXPECT_EQ(e->source.range.end.line, 1u); - EXPECT_EQ(e->source.range.end.column, 11u); + EXPECT_EQ(override->source.range.begin.line, 1u); + EXPECT_EQ(override->source.range.begin.column, 10u); + EXPECT_EQ(override->source.range.end.line, 1u); + EXPECT_EQ(override->source.range.end.column, 11u); - ASSERT_NE(e->constructor, nullptr); - EXPECT_TRUE(e->constructor->Is()); + ASSERT_NE(override->constructor, nullptr); + EXPECT_TRUE(override->constructor->Is()); - auto* id_attr = ast::GetAttribute(e.value->attributes); + auto* id_attr = ast::GetAttribute(override->attributes); ASSERT_EQ(id_attr, nullptr); } @@ -165,7 +161,8 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_MissingId) { auto e = p->global_constant_decl(attrs.value); EXPECT_TRUE(e.matched); EXPECT_FALSE(e.errored); - ASSERT_NE(e.value, nullptr); + auto* override = e.value->As(); + ASSERT_NE(override, nullptr); EXPECT_TRUE(p->has_error()); EXPECT_EQ(p->error(), "1:5: expected signed integer literal for id attribute"); @@ -180,7 +177,8 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_InvalidId) { auto e = p->global_constant_decl(attrs.value); EXPECT_TRUE(e.matched); EXPECT_FALSE(e.errored); - ASSERT_NE(e.value, nullptr); + auto* override = e.value->As(); + ASSERT_NE(override, nullptr); EXPECT_TRUE(p->has_error()); EXPECT_EQ(p->error(), "1:5: id attribute must be positive"); diff --git a/src/tint/reader/wgsl/parser_impl_global_variable_decl_test.cc b/src/tint/reader/wgsl/parser_impl_global_variable_decl_test.cc index 57f90b9573..11371e6997 100644 --- a/src/tint/reader/wgsl/parser_impl_global_variable_decl_test.cc +++ b/src/tint/reader/wgsl/parser_impl_global_variable_decl_test.cc @@ -26,18 +26,19 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithoutConstructor) { ASSERT_FALSE(p->has_error()) << p->error(); EXPECT_TRUE(e.matched); EXPECT_FALSE(e.errored); - ASSERT_NE(e.value, nullptr); + auto* var = e.value->As(); + ASSERT_NE(var, nullptr); - EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); - EXPECT_TRUE(e->type->Is()); - EXPECT_EQ(e->declared_storage_class, ast::StorageClass::kPrivate); + EXPECT_EQ(var->symbol, p->builder().Symbols().Get("a")); + EXPECT_TRUE(var->type->Is()); + EXPECT_EQ(var->declared_storage_class, ast::StorageClass::kPrivate); - EXPECT_EQ(e->source.range.begin.line, 1u); - EXPECT_EQ(e->source.range.begin.column, 14u); - EXPECT_EQ(e->source.range.end.line, 1u); - EXPECT_EQ(e->source.range.end.column, 15u); + EXPECT_EQ(var->source.range.begin.line, 1u); + EXPECT_EQ(var->source.range.begin.column, 14u); + EXPECT_EQ(var->source.range.end.line, 1u); + EXPECT_EQ(var->source.range.end.column, 15u); - ASSERT_EQ(e->constructor, nullptr); + ASSERT_EQ(var->constructor, nullptr); } TEST_F(ParserImplTest, GlobalVariableDecl_WithConstructor) { @@ -49,19 +50,20 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithConstructor) { ASSERT_FALSE(p->has_error()) << p->error(); EXPECT_TRUE(e.matched); EXPECT_FALSE(e.errored); - ASSERT_NE(e.value, nullptr); + auto* var = e.value->As(); + ASSERT_NE(var, nullptr); - EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); - EXPECT_TRUE(e->type->Is()); - EXPECT_EQ(e->declared_storage_class, ast::StorageClass::kPrivate); + EXPECT_EQ(var->symbol, p->builder().Symbols().Get("a")); + EXPECT_TRUE(var->type->Is()); + EXPECT_EQ(var->declared_storage_class, ast::StorageClass::kPrivate); - EXPECT_EQ(e->source.range.begin.line, 1u); - EXPECT_EQ(e->source.range.begin.column, 14u); - EXPECT_EQ(e->source.range.end.line, 1u); - EXPECT_EQ(e->source.range.end.column, 15u); + EXPECT_EQ(var->source.range.begin.line, 1u); + EXPECT_EQ(var->source.range.begin.column, 14u); + EXPECT_EQ(var->source.range.end.line, 1u); + EXPECT_EQ(var->source.range.end.column, 15u); - ASSERT_NE(e->constructor, nullptr); - ASSERT_TRUE(e->constructor->Is()); + ASSERT_NE(var->constructor, nullptr); + ASSERT_TRUE(var->constructor->Is()); } TEST_F(ParserImplTest, GlobalVariableDecl_WithAttribute) { @@ -73,21 +75,22 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithAttribute) { ASSERT_FALSE(p->has_error()) << p->error(); EXPECT_TRUE(e.matched); EXPECT_FALSE(e.errored); - ASSERT_NE(e.value, nullptr); + auto* var = e.value->As(); + ASSERT_NE(var, nullptr); - EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); - ASSERT_NE(e->type, nullptr); - EXPECT_TRUE(e->type->Is()); - EXPECT_EQ(e->declared_storage_class, ast::StorageClass::kUniform); + EXPECT_EQ(var->symbol, p->builder().Symbols().Get("a")); + ASSERT_NE(var->type, nullptr); + EXPECT_TRUE(var->type->Is()); + EXPECT_EQ(var->declared_storage_class, ast::StorageClass::kUniform); - EXPECT_EQ(e->source.range.begin.line, 1u); - EXPECT_EQ(e->source.range.begin.column, 36u); - EXPECT_EQ(e->source.range.end.line, 1u); - EXPECT_EQ(e->source.range.end.column, 37u); + EXPECT_EQ(var->source.range.begin.line, 1u); + EXPECT_EQ(var->source.range.begin.column, 36u); + EXPECT_EQ(var->source.range.end.line, 1u); + EXPECT_EQ(var->source.range.end.column, 37u); - ASSERT_EQ(e->constructor, nullptr); + ASSERT_EQ(var->constructor, nullptr); - auto& attributes = e->attributes; + auto& attributes = var->attributes; ASSERT_EQ(attributes.size(), 2u); ASSERT_TRUE(attributes[0]->Is()); ASSERT_TRUE(attributes[1]->Is()); @@ -103,21 +106,22 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithAttribute_MulitpleGroups) { ASSERT_FALSE(p->has_error()) << p->error(); EXPECT_TRUE(e.matched); EXPECT_FALSE(e.errored); - ASSERT_NE(e.value, nullptr); + auto* var = e.value->As(); + ASSERT_NE(var, nullptr); - EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); - ASSERT_NE(e->type, nullptr); - EXPECT_TRUE(e->type->Is()); - EXPECT_EQ(e->declared_storage_class, ast::StorageClass::kUniform); + EXPECT_EQ(var->symbol, p->builder().Symbols().Get("a")); + ASSERT_NE(var->type, nullptr); + EXPECT_TRUE(var->type->Is()); + EXPECT_EQ(var->declared_storage_class, ast::StorageClass::kUniform); - EXPECT_EQ(e->source.range.begin.line, 1u); - EXPECT_EQ(e->source.range.begin.column, 36u); - EXPECT_EQ(e->source.range.end.line, 1u); - EXPECT_EQ(e->source.range.end.column, 37u); + EXPECT_EQ(var->source.range.begin.line, 1u); + EXPECT_EQ(var->source.range.begin.column, 36u); + EXPECT_EQ(var->source.range.end.line, 1u); + EXPECT_EQ(var->source.range.end.column, 37u); - ASSERT_EQ(e->constructor, nullptr); + ASSERT_EQ(var->constructor, nullptr); - auto& attributes = e->attributes; + auto& attributes = var->attributes; ASSERT_EQ(attributes.size(), 2u); ASSERT_TRUE(attributes[0]->Is()); ASSERT_TRUE(attributes[1]->Is()); diff --git a/src/tint/reader/wgsl/parser_impl_param_list_test.cc b/src/tint/reader/wgsl/parser_impl_param_list_test.cc index 2d79a99da1..99cd1f2d8e 100644 --- a/src/tint/reader/wgsl/parser_impl_param_list_test.cc +++ b/src/tint/reader/wgsl/parser_impl_param_list_test.cc @@ -27,7 +27,7 @@ TEST_F(ParserImplTest, ParamList_Single) { EXPECT_EQ(e.value[0]->symbol, p->builder().Symbols().Get("a")); EXPECT_TRUE(e.value[0]->type->Is()); - EXPECT_TRUE(e.value[0]->is_const); + EXPECT_TRUE(e.value[0]->Is()); ASSERT_EQ(e.value[0]->source.range.begin.line, 1u); ASSERT_EQ(e.value[0]->source.range.begin.column, 1u); @@ -45,7 +45,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) { EXPECT_EQ(e.value[0]->symbol, p->builder().Symbols().Get("a")); EXPECT_TRUE(e.value[0]->type->Is()); - EXPECT_TRUE(e.value[0]->is_const); + EXPECT_TRUE(e.value[0]->Is()); ASSERT_EQ(e.value[0]->source.range.begin.line, 1u); ASSERT_EQ(e.value[0]->source.range.begin.column, 1u); @@ -54,7 +54,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) { EXPECT_EQ(e.value[1]->symbol, p->builder().Symbols().Get("b")); EXPECT_TRUE(e.value[1]->type->Is()); - EXPECT_TRUE(e.value[1]->is_const); + EXPECT_TRUE(e.value[1]->Is()); ASSERT_EQ(e.value[1]->source.range.begin.line, 1u); ASSERT_EQ(e.value[1]->source.range.begin.column, 10u); @@ -65,7 +65,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) { ASSERT_TRUE(e.value[2]->type->Is()); ASSERT_TRUE(e.value[2]->type->As()->type->Is()); EXPECT_EQ(e.value[2]->type->As()->width, 2u); - EXPECT_TRUE(e.value[2]->is_const); + EXPECT_TRUE(e.value[2]->Is()); ASSERT_EQ(e.value[2]->source.range.begin.line, 1u); ASSERT_EQ(e.value[2]->source.range.begin.column, 18u); @@ -101,7 +101,7 @@ TEST_F(ParserImplTest, ParamList_Attributes) { ASSERT_TRUE(e.value[0]->type->Is()); EXPECT_TRUE(e.value[0]->type->As()->type->Is()); EXPECT_EQ(e.value[0]->type->As()->width, 4u); - EXPECT_TRUE(e.value[0]->is_const); + EXPECT_TRUE(e.value[0]->Is()); auto attrs_0 = e.value[0]->attributes; ASSERT_EQ(attrs_0.size(), 1u); EXPECT_TRUE(attrs_0[0]->Is()); @@ -114,7 +114,7 @@ TEST_F(ParserImplTest, ParamList_Attributes) { EXPECT_EQ(e.value[1]->symbol, p->builder().Symbols().Get("loc1")); EXPECT_TRUE(e.value[1]->type->Is()); - EXPECT_TRUE(e.value[1]->is_const); + EXPECT_TRUE(e.value[1]->Is()); auto attrs_1 = e.value[1]->attributes; ASSERT_EQ(attrs_1.size(), 1u); EXPECT_TRUE(attrs_1[0]->Is()); diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc index cf6114950b..6b3b779752 100644 --- a/src/tint/resolver/dependency_graph.cc +++ b/src/tint/resolver/dependency_graph.cc @@ -487,11 +487,13 @@ struct DependencyAnalysis { /// declaration std::string KindOf(const ast::Node* node) { return Switch( - node, // - [&](const ast::Struct*) { return "struct"; }, - [&](const ast::Alias*) { return "alias"; }, - [&](const ast::Function*) { return "function"; }, - [&](const ast::Variable* var) { return var->is_const ? "let" : "var"; }, + node, // + [&](const ast::Struct*) { return "struct"; }, // + [&](const ast::Alias*) { return "alias"; }, // + [&](const ast::Function*) { return "function"; }, // + [&](const ast::Let*) { return "let"; }, // + [&](const ast::Var*) { return "var"; }, // + [&](const ast::Override*) { return "override"; }, // [&](Default) { UnhandledNode(diagnostics_, node); return ""; diff --git a/src/tint/resolver/pipeline_overridable_constant_test.cc b/src/tint/resolver/pipeline_overridable_constant_test.cc index 035936c3f3..583cc16881 100644 --- a/src/tint/resolver/pipeline_overridable_constant_test.cc +++ b/src/tint/resolver/pipeline_overridable_constant_test.cc @@ -31,7 +31,7 @@ class ResolverPipelineOverridableConstantTest : public ResolverTest { auto* sem = Sem().Get(var); ASSERT_NE(sem, nullptr); EXPECT_EQ(sem->Declaration(), var); - EXPECT_TRUE(sem->IsOverridable()); + EXPECT_TRUE(sem->Declaration()->Is()); EXPECT_EQ(sem->ConstantId(), id); EXPECT_FALSE(sem->ConstantValue()); } @@ -45,7 +45,7 @@ TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) { auto* sem_a = Sem().Get(a); ASSERT_NE(sem_a, nullptr); EXPECT_EQ(sem_a->Declaration(), a); - EXPECT_FALSE(sem_a->IsOverridable()); + EXPECT_FALSE(sem_a->Declaration()->Is()); EXPECT_TRUE(sem_a->ConstantValue()); } diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 3d58930f3c..3438a136f7 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -303,59 +303,63 @@ sem::Type* Resolver::Type(const ast::Type* ty) { return s; } -sem::Variable* Resolver::Variable(const ast::Variable* var, - VariableKind kind, +sem::Variable* Resolver::Variable(const ast::Variable* v, + bool is_global, uint32_t index /* = 0 */) { const sem::Type* storage_ty = nullptr; // If the variable has a declared type, resolve it. - if (auto* ty = var->type) { + if (auto* ty = v->type) { storage_ty = Type(ty); if (!storage_ty) { return nullptr; } } + auto* as_var = v->As(); + auto* as_let = v->As(); + auto* as_override = v->As(); + auto* as_param = v->As(); + const sem::Expression* rhs = nullptr; // Does the variable have a constructor? - if (var->constructor) { - rhs = Materialize(Expression(var->constructor), storage_ty); + if (v->constructor) { + rhs = Materialize(Expression(v->constructor), storage_ty); if (!rhs) { return nullptr; } // If the variable has no declared type, infer it from the RHS if (!storage_ty) { - if (!var->is_const && kind == VariableKind::kGlobal) { - AddError("module-scope 'var' declaration must specify a type", var->source); + if (as_var && is_global) { + AddError("module-scope 'var' declaration must specify a type", v->source); return nullptr; } storage_ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS } - } else if (var->is_const && !var->is_overridable && kind != VariableKind::kParameter) { - AddError("'let' declaration must have an initializer", var->source); + } else if (as_let) { + AddError("'let' declaration must have an initializer", v->source); return nullptr; - } else if (!var->type) { - AddError((kind == VariableKind::kGlobal) - ? "module-scope 'var' declaration requires a type or initializer" - : "function-scope 'var' declaration requires a type or initializer", - var->source); + } else if (!v->type) { + AddError((is_global) ? "module-scope 'var' declaration requires a type or initializer" + : "function-scope 'var' declaration requires a type or initializer", + v->source); return nullptr; } if (!storage_ty) { TINT_ICE(Resolver, diagnostics_) << "failed to determine storage type for variable '" + - builder_->Symbols().NameFor(var->symbol) + "'\n" - << "Source: " << var->source; + builder_->Symbols().NameFor(v->symbol) + "'\n" + << "Source: " << v->source; return nullptr; } - auto storage_class = var->declared_storage_class; - if (storage_class == ast::StorageClass::kNone && !var->is_const) { + auto storage_class = as_var ? as_var->declared_storage_class : ast::StorageClass::kNone; + if (storage_class == ast::StorageClass::kNone && as_var) { // No declared storage class. Infer from usage / type. - if (kind == VariableKind::kLocal) { + if (!is_global) { storage_class = ast::StorageClass::kFunction; } else if (storage_ty->UnwrapRef()->is_handle()) { // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables @@ -366,93 +370,83 @@ sem::Variable* Resolver::Variable(const ast::Variable* var, } } - if (kind == VariableKind::kLocal && !var->is_const && - storage_class != ast::StorageClass::kFunction && - validator_.IsValidationEnabled(var->attributes, + if (!is_global && as_var && storage_class != ast::StorageClass::kFunction && + validator_.IsValidationEnabled(v->attributes, ast::DisabledValidation::kIgnoreStorageClass)) { - AddError("function-scope 'var' declaration must use 'function' storage class", var->source); + AddError("function-scope 'var' declaration must use 'function' storage class", v->source); return nullptr; } - auto access = var->declared_access; + auto access = as_var ? as_var->declared_access : ast::Access::kUndefined; if (access == ast::Access::kUndefined) { access = DefaultAccessForStorageClass(storage_class); } auto* var_ty = storage_ty; - if (!var->is_const) { - // Variable declaration. Unlike `let`, `var` has storage. + if (as_var) { + // Variable declaration. Unlike `let` and parameters, `var` has storage. // Variables are always of a reference type to the declared storage type. var_ty = builder_->create(storage_ty, storage_class, access); } - if (rhs && !validator_.VariableConstructorOrCast(var, storage_class, storage_ty, rhs->Type())) { + if (rhs && !validator_.VariableConstructorOrCast(v, storage_class, storage_ty, rhs->Type())) { return nullptr; } - if (!ApplyStorageClassUsageToType(storage_class, const_cast(var_ty), var->source)) { - AddNote(std::string("while instantiating ") + - ((kind == VariableKind::kParameter) ? "parameter " : "variable ") + - builder_->Symbols().NameFor(var->symbol), - var->source); + if (!ApplyStorageClassUsageToType(storage_class, const_cast(var_ty), v->source)) { + AddNote(std::string("while instantiating ") + ((as_param) ? "parameter " : "variable ") + + builder_->Symbols().NameFor(v->symbol), + v->source); return nullptr; } - if (kind == VariableKind::kParameter) { + if (as_param) { if (auto* ptr = var_ty->As()) { // For MSL, we push module-scope variables into the entry point as pointer // parameters, so we also need to handle their store type. if (!ApplyStorageClassUsageToType( - ptr->StorageClass(), const_cast(ptr->StoreType()), var->source)) { - AddNote("while instantiating parameter " + builder_->Symbols().NameFor(var->symbol), - var->source); + ptr->StorageClass(), const_cast(ptr->StoreType()), v->source)) { + AddNote("while instantiating parameter " + builder_->Symbols().NameFor(v->symbol), + v->source); return nullptr; } } + auto* param = + builder_->create(as_param, index, var_ty, storage_class, access); + builder_->Sem().Add(as_param, param); + return param; } - switch (kind) { - case VariableKind::kGlobal: { - sem::BindingPoint binding_point; - if (auto bp = var->BindingPoint()) { + if (is_global) { + sem::BindingPoint binding_point; + if (as_var) { + if (auto bp = as_var->BindingPoint()) { binding_point = {bp.group->value, bp.binding->value}; } + } - bool has_const_val = rhs && var->is_const && !var->is_overridable; - auto* global = builder_->create( - var, var_ty, storage_class, access, - has_const_val ? rhs->ConstantValue() : sem::Constant{}, binding_point); + bool has_const_val = rhs && as_let && !as_override; + auto* global = builder_->create( + v, var_ty, storage_class, access, + has_const_val ? rhs->ConstantValue() : sem::Constant{}, binding_point); - if (var->is_overridable) { - global->SetIsOverridable(); - if (auto* id = ast::GetAttribute(var->attributes)) { - global->SetConstantId(static_cast(id->value)); - } + if (as_override) { + if (auto* id = ast::GetAttribute(v->attributes)) { + global->SetConstantId(static_cast(id->value)); } + } - global->SetConstructor(rhs); - - builder_->Sem().Add(var, global); - return global; - } - case VariableKind::kLocal: { - auto* local = builder_->create( - var, var_ty, storage_class, access, current_statement_, - (rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{}); - builder_->Sem().Add(var, local); - local->SetConstructor(rhs); - return local; - } - case VariableKind::kParameter: { - auto* param = - builder_->create(var, index, var_ty, storage_class, access); - builder_->Sem().Add(var, param); - return param; - } + global->SetConstructor(rhs); + builder_->Sem().Add(v, global); + return global; } - TINT_UNREACHABLE(Resolver, diagnostics_) << "unhandled VariableKind " << static_cast(kind); - return nullptr; + auto* local = builder_->create( + v, var_ty, storage_class, access, current_statement_, + (rhs && as_let) ? rhs->ConstantValue() : sem::Constant{}); + builder_->Sem().Add(v, local); + local->SetConstructor(rhs); + return local; } ast::Access Resolver::DefaultAccessForStorageClass(ast::StorageClass storage_class) { @@ -477,13 +471,13 @@ void Resolver::AllocateOverridableConstantIds() { // TODO(crbug.com/tint/1192): If a transform changes the order or removes an // unused constant, the allocation may change on the next Resolver pass. for (auto* decl : builder_->AST().GlobalDeclarations()) { - auto* var = decl->As(); - if (!var || !var->is_overridable) { + auto* override = decl->As(); + if (!override) { continue; } uint16_t constant_id; - if (auto* id_attr = ast::GetAttribute(var->attributes)) { + if (auto* id_attr = ast::GetAttribute(override->attributes)) { constant_id = static_cast(id_attr->value); } else { // No ID was specified, so allocate the next available ID. @@ -499,7 +493,7 @@ void Resolver::AllocateOverridableConstantIds() { next_constant_id = constant_id + 1; } - auto* sem = sem_.Get(var); + auto* sem = sem_.Get(override); const_cast(sem)->SetConstantId(constant_id); } } @@ -513,25 +507,21 @@ void Resolver::SetShadows() { } } -sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* var) { - auto* sem = Variable(var, VariableKind::kGlobal); +sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) { + auto* sem = As(Variable(v, /* is_global */ true)); if (!sem) { return nullptr; } + const bool is_var = v->Is(); + auto storage_class = sem->StorageClass(); - if (!var->is_const && storage_class == ast::StorageClass::kNone) { - AddError("module-scope 'var' declaration must have a storage class", var->source); - return nullptr; - } - if (var->is_const && storage_class != ast::StorageClass::kNone) { - AddError(var->is_overridable ? "'override' declaration must not have a storage class" - : "'let' declaration must not have a storage class", - var->source); + if (is_var && storage_class == ast::StorageClass::kNone) { + AddError("module-scope 'var' declaration must have a storage class", v->source); return nullptr; } - for (auto* attr : var->attributes) { + for (auto* attr : v->attributes) { Mark(attr); if (auto* id_attr = attr->As()) { @@ -540,7 +530,7 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* var) { } } - if (!validator_.NoDuplicateAttributes(var->attributes)) { + if (!validator_.NoDuplicateAttributes(v->attributes)) { return nullptr; } @@ -576,9 +566,8 @@ sem::Function* Resolver::Function(const ast::Function* decl) { } } - auto* var = - As(Variable(param, VariableKind::kParameter, parameter_index++)); - if (!var) { + auto* p = As(Variable(param, false, parameter_index++)); + if (!p) { return nullptr; } @@ -589,10 +578,10 @@ sem::Function* Resolver::Function(const ast::Function* decl) { return nullptr; } - parameters.emplace_back(var); + parameters.emplace_back(p); - auto* var_ty = const_cast(var->Type()); - if (auto* str = var_ty->As()) { + auto* p_ty = const_cast(p->Type()); + if (auto* str = p_ty->As()) { switch (decl->PipelineStage()) { case ast::PipelineStage::kVertex: str->AddUsage(sem::PipelineStageUsage::kVertexInput); @@ -777,12 +766,12 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { if (auto* user = args[i]->As()) { // We have an variable of a module-scope constant. auto* decl = user->Variable()->Declaration(); - if (!decl->is_const) { + if (!decl->IsAnyOf()) { AddError(kErrBadType, values[i]->source); return false; } // Capture the constant if it is pipeline-overridable. - if (decl->is_overridable) { + if (decl->Is()) { ws[i].overridable_const = decl; } @@ -2104,19 +2093,19 @@ sem::Array* Resolver::Array(const ast::Array* arr) { return nullptr; } + constexpr const char* kErrInvalidExpr = + "array size identifier must be a literal or a module-scope 'let'"; + if (auto* ident = count_expr->As()) { - // Make sure the identifier is a non-overridable module-scope constant. - auto* var = sem_.ResolvedSymbol(ident); - if (!var || !var->Declaration()->is_const || var->IsOverridable()) { - AddError("array size identifier must be a literal or a module-scope 'let'", - size_source); + // Make sure the identifier is a non-overridable module-scope 'let'. + auto* global = sem_.ResolvedSymbol(ident); + if (!global || !global->Declaration()->Is()) { + AddError(kErrInvalidExpr, size_source); return nullptr; } - - count_expr = var->Declaration()->constructor; + count_expr = global->Declaration()->constructor; } else if (!count_expr->Is()) { - AddError("array size identifier must be a literal or a module-scope 'let'", - size_source); + AddError(kErrInvalidExpr, size_source); return nullptr; } @@ -2437,7 +2426,7 @@ sem::Statement* Resolver::VariableDeclStatement(const ast::VariableDeclStatement return StatementScope(stmt, sem, [&] { Mark(stmt->variable); - auto* var = Variable(stmt->variable, VariableKind::kLocal); + auto* var = Variable(stmt->variable, /* is_global */ false); if (!var) { return false; } diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index e619322312..5f24169394 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -109,9 +109,6 @@ class Resolver { const Validator* GetValidatorForTesting() const { return &validator_; } private: - /// Describes the context in which a variable is declared - enum class VariableKind { kParameter, kLocal, kGlobal }; - Validator::ValidTypeStorageLayouts valid_type_storage_layouts_; /// Structure holding semantic information about a block (i.e. scope), such as @@ -298,9 +295,9 @@ class Resolver { /// @note this method does not resolve the attributes as these are /// context-dependent (global, local, parameter) /// @param var the variable to create or return the `VariableInfo` for - /// @param kind what kind of variable we are declaring + /// @param is_global true if this is module scope, otherwise function scope /// @param index the index of the parameter, if this variable is a parameter - sem::Variable* Variable(const ast::Variable* var, VariableKind kind, uint32_t index = 0); + sem::Variable* Variable(const ast::Variable* var, bool is_global, uint32_t index = 0); /// Records the storage class usage for the given type, and any transient /// dependencies of the type. Validates that the type can be used for the diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc index 27ee04d606..8f7ef1a961 100644 --- a/src/tint/resolver/type_validation_test.cc +++ b/src/tint/resolver/type_validation_test.cc @@ -87,26 +87,6 @@ TEST_F(ResolverTypeValidationTest, GlobalVariableWithStorageClass_Pass) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } -TEST_F(ResolverTypeValidationTest, GlobalLetWithStorageClass_Fail) { - // let global_var: f32; - AST().AddGlobalVariable(create( - Source{{12, 34}}, Symbols().Register("global_let"), ast::StorageClass::kPrivate, - ast::Access::kUndefined, ty.f32(), true, false, Expr(1.23_f), ast::AttributeList{})); - - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: 'let' declaration must not have a storage class"); -} - -TEST_F(ResolverTypeValidationTest, OverrideWithStorageClass_Fail) { - // let global_var: f32; - AST().AddGlobalVariable(create( - Source{{12, 34}}, Symbols().Register("global_override"), ast::StorageClass::kPrivate, - ast::Access::kUndefined, ty.f32(), true, true, Expr(1.23_f), ast::AttributeList{})); - - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: 'override' declaration must not have a storage class"); -} - TEST_F(ResolverTypeValidationTest, GlobalConstNoStorageClass_Pass) { // let global_var: f32; GlobalConst(Source{{12, 34}}, "global_var", ty.f32(), Construct(ty.f32())); diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc index 18fa425c15..15ca72f5bf 100644 --- a/src/tint/resolver/uniformity.cc +++ b/src/tint/resolver/uniformity.cc @@ -949,7 +949,7 @@ class UniformityGraph { } current_function_->variables.Set(sem_.Get(decl->variable), node); - if (!decl->variable->is_const) { + if (decl->variable->Is()) { current_function_->local_var_decls.insert( sem_.Get(decl->variable)); } @@ -1018,7 +1018,8 @@ class UniformityGraph { }, [&](const sem::GlobalVariable* global) { - if (global->Declaration()->is_const || global->Access() == ast::Access::kRead) { + if (!global->Declaration()->Is() || + global->Access() == ast::Access::kRead) { node->AddEdge(cf); } else { node->AddEdge(current_function_->may_be_non_uniform); diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index 39eb31a196..7562b797b8 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -297,7 +297,7 @@ bool Validator::Materialize(const sem::Materialize* m) const { return true; } -bool Validator::VariableConstructorOrCast(const ast::Variable* var, +bool Validator::VariableConstructorOrCast(const ast::Variable* v, ast::StorageClass storage_class, const sem::Type* storage_ty, const sem::Type* rhs_ty) const { @@ -305,14 +305,14 @@ bool Validator::VariableConstructorOrCast(const ast::Variable* var, // Value type has to match storage type if (storage_ty != value_type) { - std::string decl = var->is_const ? "let" : "var"; + std::string decl = v->Is() ? "let" : "var"; AddError("cannot initialize " + decl + " of type '" + sem_.TypeNameOf(storage_ty) + "' with value of type '" + sem_.TypeNameOf(rhs_ty) + "'", - var->source); + v->source); return false; } - if (!var->is_const) { + if (v->Is()) { switch (storage_class) { case ast::StorageClass::kPrivate: case ast::StorageClass::kFunction: @@ -325,7 +325,7 @@ bool Validator::VariableConstructorOrCast(const ast::Variable* var, "' cannot have an initializer. var initializers are only " "supported for the storage classes " "'private' and 'function'", - var->source); + v->source); return false; } } @@ -502,21 +502,22 @@ bool Validator::StorageClassLayout(const sem::Variable* var, } bool Validator::GlobalVariable( - const sem::Variable* var, + const sem::GlobalVariable* global, std::unordered_map constant_ids, std::unordered_map atomic_composite_info) const { - auto* decl = var->Declaration(); + auto* decl = global->Declaration(); if (!NoDuplicateAttributes(decl->attributes)) { return false; } - for (auto* attr : decl->attributes) { - if (decl->is_const) { - if (decl->is_overridable) { + bool ok = Switch( + decl, // + [&](const ast::Override*) { + for (auto* attr : decl->attributes) { if (auto* id_attr = attr->As()) { uint32_t id = id_attr->value; auto it = constant_ids.find(id); - if (it != constant_ids.end() && it->second != var) { + if (it != constant_ids.end() && it->second != global) { AddError("pipeline constant IDs must be unique", attr->source); AddNote("a pipeline constant with an ID of " + std::to_string(id) + " was previously declared here:", @@ -533,32 +534,45 @@ bool Validator::GlobalVariable( AddError("attribute is not valid for 'override' declaration", attr->source); return false; } - } else { - AddError("attribute is not valid for module-scope 'let' declaration", attr->source); + } + return true; + }, + [&](const ast::Let*) { + if (!decl->attributes.empty()) { + AddError("attribute is not valid for module-scope 'let' declaration", + decl->attributes[0]->source); return false; } - } else { - bool is_shader_io_attribute = - attr->IsAnyOf(); - bool has_io_storage_class = var->StorageClass() == ast::StorageClass::kInput || - var->StorageClass() == ast::StorageClass::kOutput; - if (!(attr->IsAnyOf()) && - (!is_shader_io_attribute || !has_io_storage_class)) { - AddError("attribute is not valid for module-scope 'var'", attr->source); - return false; + return true; + }, + [&](const ast::Var*) { + for (auto* attr : decl->attributes) { + bool is_shader_io_attribute = + attr->IsAnyOf(); + bool has_io_storage_class = global->StorageClass() == ast::StorageClass::kInput || + global->StorageClass() == ast::StorageClass::kOutput; + if (!attr->IsAnyOf() && + (!is_shader_io_attribute || !has_io_storage_class)) { + AddError("attribute is not valid for module-scope 'var'", attr->source); + return false; + } } - } + return true; + }); + + if (!ok) { + return false; } - if (var->StorageClass() == ast::StorageClass::kFunction) { + if (global->StorageClass() == ast::StorageClass::kFunction) { AddError("module-scope 'var' must not use storage class 'function'", decl->source); return false; } auto binding_point = decl->BindingPoint(); - switch (var->StorageClass()) { + switch (global->StorageClass()) { case ast::StorageClass::kUniform: case ast::StorageClass::kStorage: case ast::StorageClass::kHandle: { @@ -581,23 +595,23 @@ bool Validator::GlobalVariable( } } - // https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration - // The access mode always has a default, and except for variables in the - // storage storage class, must not be written. - if (var->StorageClass() != ast::StorageClass::kStorage && - decl->declared_access != ast::Access::kUndefined) { - AddError("only variables in storage class may declare an access mode", - decl->source); - return false; - } + if (auto* var = decl->As()) { + // https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration + // The access mode always has a default, and except for variables in the + // storage storage class, must not be written. + if (global->StorageClass() != ast::StorageClass::kStorage && + var->declared_access != ast::Access::kUndefined) { + AddError("only variables in storage class may declare an access mode", + var->source); + return false; + } - if (!decl->is_const) { - if (!AtomicVariable(var, atomic_composite_info)) { + if (!AtomicVariable(global, atomic_composite_info)) { return false; } } - return Variable(var); + return Variable(global); } // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types @@ -641,14 +655,17 @@ bool Validator::AtomicVariable( return true; } -bool Validator::Variable(const sem::Variable* var) const { - auto* decl = var->Declaration(); - auto* storage_ty = var->Type()->UnwrapRef(); +bool Validator::Variable(const sem::Variable* v) const { + auto* decl = v->Declaration(); + auto* storage_ty = v->Type()->UnwrapRef(); - if (var->Is()) { + auto* as_let = decl->As(); + auto* as_var = decl->As(); + + if (v->Is()) { auto name = symbols_.NameFor(decl->symbol); if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { - auto* kind = var->Declaration()->is_const ? "let" : "var"; + auto* kind = as_let ? "let" : "var"; AddError( "'" + name + "' is a builtin and cannot be redeclared as a module-scope " + kind, decl->source); @@ -656,14 +673,13 @@ bool Validator::Variable(const sem::Variable* var) const { } } - if (!decl->is_const && !IsStorable(storage_ty)) { + if (as_var && !IsStorable(storage_ty)) { AddError(sem_.TypeNameOf(storage_ty) + " cannot be used as the type of a var", decl->source); return false; } - if (decl->is_const && !var->Is() && - !(storage_ty->IsConstructible() || storage_ty->Is())) { + if (as_let && !(storage_ty->IsConstructible() || storage_ty->Is())) { AddError(sem_.TypeNameOf(storage_ty) + " cannot be used as the type of a let", decl->source); return false; @@ -688,16 +704,17 @@ bool Validator::Variable(const sem::Variable* var) const { } } - if (var->Is() && !decl->is_const && + if (v->Is() && as_var && IsValidationEnabled(decl->attributes, ast::DisabledValidation::kIgnoreStorageClass)) { - if (!var->Type()->UnwrapRef()->IsConstructible()) { + if (!v->Type()->UnwrapRef()->IsConstructible()) { AddError("function variable must have a constructible type", decl->type ? decl->type->source : decl->source); return false; } } - if (storage_ty->is_handle() && decl->declared_storage_class != ast::StorageClass::kNone) { + if (as_var && storage_ty->is_handle() && + as_var->declared_storage_class != ast::StorageClass::kNone) { // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables // If the store type is a texture type or a sampler type, then the // variable declaration must not have a storage class attribute. The @@ -709,9 +726,10 @@ bool Validator::Variable(const sem::Variable* var) const { } if (IsValidationEnabled(decl->attributes, ast::DisabledValidation::kIgnoreStorageClass) && - (decl->declared_storage_class == ast::StorageClass::kInput || - decl->declared_storage_class == ast::StorageClass::kOutput)) { - AddError("invalid use of input/output storage class", decl->source); + as_var && + (as_var->declared_storage_class == ast::StorageClass::kInput || + as_var->declared_storage_class == ast::StorageClass::kOutput)) { + AddError("invalid use of input/output storage class", as_var->source); return false; } return true; @@ -1223,12 +1241,12 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage) // Validate there are no resource variable binding collisions std::unordered_map binding_points; - for (auto* var : func->TransitivelyReferencedGlobals()) { - auto* var_decl = var->Declaration(); - if (!var_decl->BindingPoint()) { + for (auto* global : func->TransitivelyReferencedGlobals()) { + auto* var_decl = global->Declaration()->As(); + if (!var_decl || !var_decl->BindingPoint()) { continue; } - auto bp = var->BindingPoint(); + auto bp = global->BindingPoint(); auto res = binding_points.emplace(bp, var_decl); if (!res.second && IsValidationEnabled(decl->attributes, @@ -1663,12 +1681,6 @@ bool Validator::FunctionCall(const sem::Call* call, sem::Statement* current_stat TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; return false; } - if (var->Declaration()->is_const) { - TINT_ICE(Resolver, diagnostics_) - << "Resolver::FunctionCall() encountered an address-of " - "expression of a constant identifier expression"; - return false; - } is_valid = true; } } @@ -2172,18 +2184,16 @@ bool Validator::Assignment(const ast::Statement* a, const sem::Type* rhs_ty) con // https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement auto const* lhs_ty = sem_.TypeOf(lhs); - if (auto* var = sem_.ResolvedSymbol(lhs)) { - auto* decl = var->Declaration(); - if (var->Is()) { - AddError("cannot assign to function parameter", lhs->source); - AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source); - return false; - } - if (decl->is_const) { - AddError( - decl->is_overridable ? "cannot assign to 'override'" : "cannot assign to 'let'", - lhs->source); - AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source); + if (auto* variable = sem_.ResolvedSymbol(lhs)) { + auto* v = variable->Declaration(); + const char* err = Switch( + v, // + [&](const ast::Parameter*) { return "cannot assign to function parameter"; }, + [&](const ast::Let*) { return "cannot assign to 'let'"; }, + [&](const ast::Override*) { return "cannot assign to 'override'"; }); + if (err) { + AddError(err, lhs->source); + AddNote("'" + symbols_.NameFor(v->symbol) + "' is declared here:", v->source); return false; } } @@ -2222,17 +2232,16 @@ bool Validator::IncrementDecrementStatement(const ast::IncrementDecrementStateme // https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement - if (auto* var = sem_.ResolvedSymbol(lhs)) { - auto* decl = var->Declaration(); - if (var->Is()) { - AddError("cannot modify function parameter", lhs->source); - AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source); - return false; - } - if (decl->is_const) { - AddError(decl->is_overridable ? "cannot modify 'override'" : "cannot modify 'let'", - lhs->source); - AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source); + if (auto* variable = sem_.ResolvedSymbol(lhs)) { + auto* v = variable->Declaration(); + const char* err = Switch( + v, // + [&](const ast::Parameter*) { return "cannot modify function parameter"; }, + [&](const ast::Let*) { return "cannot modify 'let'"; }, + [&](const ast::Override*) { return "cannot modify 'override'"; }); + if (err) { + AddError(err, lhs->source); + AddNote("'" + symbols_.NameFor(v->symbol) + "' is declared here:", v->source); return false; } } diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h index e3912ba7d9..223f3cf684 100644 --- a/src/tint/resolver/validator.h +++ b/src/tint/resolver/validator.h @@ -237,7 +237,7 @@ class Validator { /// @param atomic_composite_info atomic composite info in the module /// @returns true on success, false otherwise bool GlobalVariable( - const sem::Variable* var, + const sem::GlobalVariable* var, std::unordered_map constant_ids, std::unordered_map atomic_composite_info) const; @@ -345,12 +345,12 @@ class Validator { bool Variable(const sem::Variable* var) const; /// Validates a variable constructor or cast - /// @param var the variable to validate + /// @param v the variable to validate /// @param storage_class the storage class of the variable /// @param storage_type the type of the storage /// @param rhs_type the right hand side of the expression /// @returns true on succes, false otherwise - bool VariableConstructorOrCast(const ast::Variable* var, + bool VariableConstructorOrCast(const ast::Variable* v, ast::StorageClass storage_class, const sem::Type* storage_type, const sem::Type* rhs_type) const; diff --git a/src/tint/resolver/var_let_validation_test.cc b/src/tint/resolver/var_let_validation_test.cc index eea5e72f59..67ecd53715 100644 --- a/src/tint/resolver/var_let_validation_test.cc +++ b/src/tint/resolver/var_let_validation_test.cc @@ -24,22 +24,6 @@ namespace { struct ResolverVarLetValidationTest : public resolver::TestHelper, public testing::Test {}; -TEST_F(ResolverVarLetValidationTest, LetNoInitializer) { - // let a : i32; - WrapInFunction(Let(Source{{12, 34}}, "a", ty.i32(), nullptr)); - - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: 'let' declaration must have an initializer"); -} - -TEST_F(ResolverVarLetValidationTest, GlobalLetNoInitializer) { - // let a : i32; - GlobalConst(Source{{12, 34}}, "a", ty.i32(), nullptr); - - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: 'let' declaration must have an initializer"); -} - TEST_F(ResolverVarLetValidationTest, VarNoInitializerNoType) { // var a; WrapInFunction(Var(Source{{12, 34}}, "a", nullptr)); diff --git a/src/tint/sem/function.cc b/src/tint/sem/function.cc index 790933b94a..cd34eb737d 100644 --- a/src/tint/sem/function.cc +++ b/src/tint/sem/function.cc @@ -44,10 +44,10 @@ std::vector> Function::TransitivelyReferencedLocationVariables() const { std::vector> ret; - for (auto* var : TransitivelyReferencedGlobals()) { - for (auto* attr : var->Declaration()->attributes) { + for (auto* global : TransitivelyReferencedGlobals()) { + for (auto* attr : global->Declaration()->attributes) { if (auto* location = attr->As()) { - ret.push_back({var, location}); + ret.push_back({global, location}); break; } } @@ -58,13 +58,13 @@ Function::TransitivelyReferencedLocationVariables() const { Function::VariableBindings Function::TransitivelyReferencedUniformVariables() const { VariableBindings ret; - for (auto* var : TransitivelyReferencedGlobals()) { - if (var->StorageClass() != ast::StorageClass::kUniform) { + for (auto* global : TransitivelyReferencedGlobals()) { + if (global->StorageClass() != ast::StorageClass::kUniform) { continue; } - if (auto binding_point = var->Declaration()->BindingPoint()) { - ret.push_back({var, binding_point}); + if (auto binding_point = global->Declaration()->BindingPoint()) { + ret.push_back({global, binding_point}); } } return ret; @@ -73,13 +73,13 @@ Function::VariableBindings Function::TransitivelyReferencedUniformVariables() co Function::VariableBindings Function::TransitivelyReferencedStorageBufferVariables() const { VariableBindings ret; - for (auto* var : TransitivelyReferencedGlobals()) { - if (var->StorageClass() != ast::StorageClass::kStorage) { + for (auto* global : TransitivelyReferencedGlobals()) { + if (global->StorageClass() != ast::StorageClass::kStorage) { continue; } - if (auto binding_point = var->Declaration()->BindingPoint()) { - ret.push_back({var, binding_point}); + if (auto binding_point = global->Declaration()->BindingPoint()) { + ret.push_back({global, binding_point}); } } return ret; @@ -89,10 +89,10 @@ std::vector> Function::TransitivelyReferencedBuiltinVariables() const { std::vector> ret; - for (auto* var : TransitivelyReferencedGlobals()) { - for (auto* attr : var->Declaration()->attributes) { + for (auto* global : TransitivelyReferencedGlobals()) { + for (auto* attr : global->Declaration()->attributes) { if (auto* builtin = attr->As()) { - ret.push_back({var, builtin}); + ret.push_back({global, builtin}); break; } } @@ -119,11 +119,11 @@ Function::VariableBindings Function::TransitivelyReferencedMultisampledTextureVa Function::VariableBindings Function::TransitivelyReferencedVariablesOfType( const tint::TypeInfo* type) const { VariableBindings ret; - for (auto* var : TransitivelyReferencedGlobals()) { - auto* unwrapped_type = var->Type()->UnwrapRef(); + for (auto* global : TransitivelyReferencedGlobals()) { + auto* unwrapped_type = global->Type()->UnwrapRef(); if (unwrapped_type->TypeInfo().Is(type)) { - if (auto binding_point = var->Declaration()->BindingPoint()) { - ret.push_back({var, binding_point}); + if (auto binding_point = global->Declaration()->BindingPoint()) { + ret.push_back({global, binding_point}); } } } @@ -143,15 +143,15 @@ Function::VariableBindings Function::TransitivelyReferencedSamplerVariablesImpl( ast::SamplerKind kind) const { VariableBindings ret; - for (auto* var : TransitivelyReferencedGlobals()) { - auto* unwrapped_type = var->Type()->UnwrapRef(); + for (auto* global : TransitivelyReferencedGlobals()) { + auto* unwrapped_type = global->Type()->UnwrapRef(); auto* sampler = unwrapped_type->As(); if (sampler == nullptr || sampler->kind() != kind) { continue; } - if (auto binding_point = var->Declaration()->BindingPoint()) { - ret.push_back({var, binding_point}); + if (auto binding_point = global->Declaration()->BindingPoint()) { + ret.push_back({global, binding_point}); } } return ret; @@ -161,8 +161,8 @@ Function::VariableBindings Function::TransitivelyReferencedSampledTextureVariabl bool multisampled) const { VariableBindings ret; - for (auto* var : TransitivelyReferencedGlobals()) { - auto* unwrapped_type = var->Type()->UnwrapRef(); + for (auto* global : TransitivelyReferencedGlobals()) { + auto* unwrapped_type = global->Type()->UnwrapRef(); auto* texture = unwrapped_type->As(); if (texture == nullptr) { continue; @@ -175,8 +175,8 @@ Function::VariableBindings Function::TransitivelyReferencedSampledTextureVariabl continue; } - if (auto binding_point = var->Declaration()->BindingPoint()) { - ret.push_back({var, binding_point}); + if (auto binding_point = global->Declaration()->BindingPoint()) { + ret.push_back({global, binding_point}); } } diff --git a/src/tint/sem/type_mappings.h b/src/tint/sem/type_mappings.h index 2b082a8b57..9041bdb73e 100644 --- a/src/tint/sem/type_mappings.h +++ b/src/tint/sem/type_mappings.h @@ -27,6 +27,7 @@ class Function; class IfStatement; class MemberAccessorExpression; class Node; +class Override; class Statement; class Struct; class StructMember; @@ -45,6 +46,7 @@ class Function; class IfStatement; class MemberAccessorExpression; class Node; +class GlobalVariable; class Statement; class Struct; class StructMember; @@ -69,6 +71,7 @@ struct TypeMappings { IfStatement* operator()(ast::IfStatement*); MemberAccessorExpression* operator()(ast::MemberAccessorExpression*); Node* operator()(ast::Node*); + GlobalVariable* operator()(ast::Override*); Statement* operator()(ast::Statement*); Struct* operator()(ast::Struct*); StructMember* operator()(ast::StructMember*); diff --git a/src/tint/sem/variable.cc b/src/tint/sem/variable.cc index 0ada5aeeed..de5cfa44b6 100644 --- a/src/tint/sem/variable.cc +++ b/src/tint/sem/variable.cc @@ -62,7 +62,7 @@ GlobalVariable::GlobalVariable(const ast::Variable* declaration, GlobalVariable::~GlobalVariable() = default; -Parameter::Parameter(const ast::Variable* declaration, +Parameter::Parameter(const ast::Parameter* declaration, uint32_t index, const sem::Type* type, ast::StorageClass storage_class, diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h index 7026ca7506..b9952339d1 100644 --- a/src/tint/sem/variable.h +++ b/src/tint/sem/variable.h @@ -154,24 +154,14 @@ class GlobalVariable final : public Castable { sem::BindingPoint BindingPoint() const { return binding_point_; } /// @param id the constant identifier to assign to this variable - void SetConstantId(uint16_t id) { - constant_id_ = id; - is_overridable_ = true; - } + void SetConstantId(uint16_t id) { constant_id_ = id; } /// @returns the pipeline constant ID associated with the variable uint16_t ConstantId() const { return constant_id_; } - /// @param is_overridable true if this is a pipeline overridable constant - void SetIsOverridable(bool is_overridable = true) { is_overridable_ = is_overridable; } - - /// @returns true if this is pipeline overridable constant - bool IsOverridable() const { return is_overridable_; } - private: const sem::BindingPoint binding_point_; - bool is_overridable_ = false; uint16_t constant_id_ = 0; }; @@ -185,7 +175,7 @@ class Parameter final : public Castable { /// @param storage_class the variable storage class /// @param access the variable access control type /// @param usage the semantic usage for the parameter - Parameter(const ast::Variable* declaration, + Parameter(const ast::Parameter* declaration, uint32_t index, const sem::Type* type, ast::StorageClass storage_class, diff --git a/src/tint/transform/add_spirv_block_attribute.cc b/src/tint/transform/add_spirv_block_attribute.cc index 38e0de66d5..a62f9c7bfc 100644 --- a/src/tint/transform/add_spirv_block_attribute.cc +++ b/src/tint/transform/add_spirv_block_attribute.cc @@ -54,8 +54,8 @@ void AddSpirvBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) co // contains it in the destination program. std::unordered_map wrapper_structs; - // Process global variables that are buffers. - for (auto* var : ctx.src->AST().GlobalVariables()) { + // Process global 'var' declarations that are buffers. + for (auto* var : ctx.src->AST().Globals()) { auto* sem_var = sem.Get(var); if (var->declared_storage_class != ast::StorageClass::kStorage && var->declared_storage_class != ast::StorageClass::kUniform) { diff --git a/src/tint/transform/binding_remapper.cc b/src/tint/transform/binding_remapper.cc index e3b7afdf89..3d1783f30b 100644 --- a/src/tint/transform/binding_remapper.cc +++ b/src/tint/transform/binding_remapper.cc @@ -67,8 +67,8 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co } auto* func = ctx.src->Sem().Get(func_ast); std::unordered_map binding_point_counts; - for (auto* var : func->TransitivelyReferencedGlobals()) { - if (auto binding_point = var->Declaration()->BindingPoint()) { + for (auto* global : func->TransitivelyReferencedGlobals()) { + if (auto binding_point = global->Declaration()->BindingPoint()) { BindingPoint from{binding_point.group->value, binding_point.binding->value}; auto bp_it = remappings->binding_points.find(from); if (bp_it != remappings->binding_points.end()) { @@ -88,7 +88,7 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co } } - for (auto* var : ctx.src->AST().GlobalVariables()) { + for (auto* var : ctx.src->AST().Globals()) { if (auto binding_point = var->BindingPoint()) { // The original binding point BindingPoint from{binding_point.group->value, binding_point.binding->value}; @@ -130,10 +130,10 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co } auto* ty = sem->Type()->UnwrapRef(); const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty); - auto* new_var = ctx.dst->create( - ctx.Clone(var->source), ctx.Clone(var->symbol), var->declared_storage_class, ac, - inner_ty, false, false, ctx.Clone(var->constructor), - ctx.Clone(var->attributes)); + auto* new_var = + ctx.dst->Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty, + var->declared_storage_class, ac, ctx.Clone(var->constructor), + ctx.Clone(var->attributes)); ctx.Replace(var, new_var); } diff --git a/src/tint/transform/combine_samplers.cc b/src/tint/transform/combine_samplers.cc index 8cd3a6d49a..cd8395c782 100644 --- a/src/tint/transform/combine_samplers.cc +++ b/src/tint/transform/combine_samplers.cc @@ -147,16 +147,16 @@ struct CombineSamplers::State { // Remove all texture and sampler global variables. These will be replaced // by combined samplers. - for (auto* var : ctx.src->AST().GlobalVariables()) { - auto* type = sem.Get(var->type); - if (type && type->IsAnyOf() && + for (auto* global : ctx.src->AST().GlobalVariables()) { + auto* type = sem.Get(global->type); + if (tint::IsAnyOf(type) && !type->Is()) { - ctx.Remove(ctx.src->AST().GlobalDeclarations(), var); - } else if (auto binding_point = var->BindingPoint()) { + ctx.Remove(ctx.src->AST().GlobalDeclarations(), global); + } else if (auto binding_point = global->BindingPoint()) { if (binding_point.group->value == 0 && binding_point.binding->value == 0) { auto* attribute = ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision); - ctx.InsertFront(var->attributes, attribute); + ctx.InsertFront(global->attributes, attribute); } } } @@ -188,9 +188,8 @@ struct CombineSamplers::State { } else { // Either texture or sampler (or both) is a function parameter; // add a new function parameter to represent the combined sampler. - const ast::Type* type = CreateCombinedASTTypeFor(texture_var, sampler_var); - const ast::Variable* var = - ctx.dst->Param(ctx.dst->Symbols().New(name), type); + auto* type = CreateCombinedASTTypeFor(texture_var, sampler_var); + auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type); params.push_back(var); function_combined_texture_samplers_[func][pair] = var; } diff --git a/src/tint/transform/fold_trivial_single_use_lets.cc b/src/tint/transform/fold_trivial_single_use_lets.cc index 5bcdaa4900..4dbc8c28f9 100644 --- a/src/tint/transform/fold_trivial_single_use_lets.cc +++ b/src/tint/transform/fold_trivial_single_use_lets.cc @@ -31,11 +31,11 @@ const ast::VariableDeclStatement* AsTrivialLetDecl(const ast::Statement* stmt) { if (!var_decl) { return nullptr; } - auto* var = var_decl->variable; - if (!var->is_const) { + auto* let = var_decl->variable->As(); + if (!let) { return nullptr; } - auto* ctor = var->constructor; + auto* ctor = let->constructor; if (!IsAnyOf(ctor)) { return nullptr; } diff --git a/src/tint/transform/module_scope_var_to_entry_point_param.cc b/src/tint/transform/module_scope_var_to_entry_point_param.cc index 22bcd5c1bc..cc89fde782 100644 --- a/src/tint/transform/module_scope_var_to_entry_point_param.cc +++ b/src/tint/transform/module_scope_var_to_entry_point_param.cc @@ -155,9 +155,13 @@ struct ModuleScopeVarToEntryPointParam::State { return workgroup_parameter_symbol; }; - for (auto* var : func_sem->TransitivelyReferencedGlobals()) { - auto sc = var->StorageClass(); - auto* ty = var->Type()->UnwrapRef(); + for (auto* global : func_sem->TransitivelyReferencedGlobals()) { + auto* var = global->Declaration()->As(); + if (!var) { + continue; + } + auto sc = global->StorageClass(); + auto* ty = global->Type()->UnwrapRef(); if (sc == ast::StorageClass::kNone) { continue; } @@ -182,12 +186,12 @@ struct ModuleScopeVarToEntryPointParam::State { bool is_wrapped = false; if (is_entry_point) { - if (var->Type()->UnwrapRef()->is_handle()) { + if (global->Type()->UnwrapRef()->is_handle()) { // For a texture or sampler variable, redeclare it as an entry point // parameter. Disable entry point parameter validation. auto* disable_validation = ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter); - auto attrs = ctx.Clone(var->Declaration()->attributes); + auto attrs = ctx.Clone(var->attributes); attrs.push_back(disable_validation); auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs); ctx.InsertFront(func_ast->params, param); @@ -195,7 +199,7 @@ struct ModuleScopeVarToEntryPointParam::State { sc == ast::StorageClass::kUniform) { // Variables into the Storage and Uniform storage classes are // redeclared as entry point parameters with a pointer type. - auto attributes = ctx.Clone(var->Declaration()->attributes); + auto attributes = ctx.Clone(var->attributes); attributes.push_back( ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter)); attributes.push_back( @@ -214,22 +218,22 @@ struct ModuleScopeVarToEntryPointParam::State { is_wrapped = true; } - param_type = ctx.dst->ty.pointer(param_type, sc, - var->Declaration()->declared_access); + param_type = ctx.dst->ty.pointer(param_type, sc, var->declared_access); auto* param = ctx.dst->Param(new_var_symbol, param_type, attributes); ctx.InsertFront(func_ast->params, param); is_pointer = true; - } else if (sc == ast::StorageClass::kWorkgroup && ContainsMatrix(var->Type())) { + } else if (sc == ast::StorageClass::kWorkgroup && + ContainsMatrix(global->Type())) { // Due to a bug in the MSL compiler, we use a threadgroup memory // argument for any workgroup allocation that contains a matrix. // See crbug.com/tint/938. // TODO(jrprice): Do this for all other workgroup variables too. // Create a member in the workgroup parameter struct. - auto member = ctx.Clone(var->Declaration()->symbol); + auto member = ctx.Clone(var->symbol); workgroup_parameter_members.push_back( ctx.dst->Member(member, store_type())); - CloneStructTypes(var->Type()->UnwrapRef()); + CloneStructTypes(global->Type()->UnwrapRef()); // Create a function-scope variable that is a pointer to the member. auto* member_ptr = ctx.dst->AddressOf( @@ -246,7 +250,7 @@ struct ModuleScopeVarToEntryPointParam::State { // this variable. auto* disable_validation = ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass); - auto* constructor = ctx.Clone(var->Declaration()->constructor); + auto* constructor = ctx.Clone(var->constructor); auto* local_var = ctx.dst->Var(new_var_symbol, store_type(), sc, constructor, ast::AttributeList{disable_validation}); @@ -257,9 +261,8 @@ struct ModuleScopeVarToEntryPointParam::State { // Use a pointer for non-handle types. auto* param_type = store_type(); ast::AttributeList attributes; - if (!var->Type()->UnwrapRef()->is_handle()) { - param_type = ctx.dst->ty.pointer(param_type, sc, - var->Declaration()->declared_access); + if (!global->Type()->UnwrapRef()->is_handle()) { + param_type = ctx.dst->ty.pointer(param_type, sc, var->declared_access); is_pointer = true; // Disable validation of the parameter's storage class and of @@ -275,7 +278,7 @@ struct ModuleScopeVarToEntryPointParam::State { // Replace all uses of the module-scope variable. // For non-entry points, dereference non-handle pointer parameters. - for (auto* user : var->Users()) { + for (auto* user : global->Users()) { if (user->Stmt()->Function()->Declaration() == func_ast) { const ast::Expression* expr = ctx.dst->Expr(new_var_symbol); if (is_pointer) { @@ -298,7 +301,7 @@ struct ModuleScopeVarToEntryPointParam::State { } } - var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped}; + var_to_newvar[global] = {new_var_symbol, is_pointer, is_wrapped}; } if (!workgroup_parameter_members.empty()) { diff --git a/src/tint/transform/multiplanar_external_texture.cc b/src/tint/transform/multiplanar_external_texture.cc index 3e3974cbf5..8b85573af2 100644 --- a/src/tint/transform/multiplanar_external_texture.cc +++ b/src/tint/transform/multiplanar_external_texture.cc @@ -86,8 +86,8 @@ struct MultiplanarExternalTexture::State { // binding and create two additional bindings (one texture_2d to // represent the secondary plane and one uniform buffer for the // ExternalTextureParams struct). - for (auto* var : ctx.src->AST().GlobalVariables()) { - auto* sem_var = sem.Get(var); + for (auto* global : ctx.src->AST().GlobalVariables()) { + auto* sem_var = sem.Get(global); if (!sem_var->Type()->UnwrapRef()->Is()) { continue; } @@ -95,7 +95,7 @@ struct MultiplanarExternalTexture::State { // If the attributes are empty, then this must be a texture_external // passed as a function parameter. These variables are transformed // elsewhere. - if (var->attributes.empty()) { + if (global->attributes.empty()) { continue; } @@ -109,8 +109,8 @@ struct MultiplanarExternalTexture::State { // provided to this transform. We fetch the new binding points by // providing the original texture_external binding points into the // passed map. - BindingPoint bp = {var->BindingPoint().group->value, - var->BindingPoint().binding->value}; + BindingPoint bp = {global->BindingPoint().group->value, + global->BindingPoint().binding->value}; BindingsMap::const_iterator it = new_binding_points->bindings_map.find(bp); if (it == new_binding_points->bindings_map.end()) { @@ -129,7 +129,7 @@ struct MultiplanarExternalTexture::State { // corresponds with the new destination bindings. // NewBindingSymbols new_binding_syms; auto& syms = new_binding_symbols[sem_var]; - syms.plane_0 = ctx.Clone(var->symbol); + syms.plane_0 = ctx.Clone(global->symbol); syms.plane_1 = b.Symbols().New("ext_tex_plane_1"); b.Global(syms.plane_1, b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()), b.GroupAndBinding(bps.plane_1.group, bps.plane_1.binding)); @@ -140,13 +140,13 @@ struct MultiplanarExternalTexture::State { // Replace the original texture_external binding with a texture_2d // binding. - ast::AttributeList cloned_attributes = ctx.Clone(var->attributes); - const ast::Expression* cloned_constructor = ctx.Clone(var->constructor); + ast::AttributeList cloned_attributes = ctx.Clone(global->attributes); + const ast::Expression* cloned_constructor = ctx.Clone(global->constructor); auto* replacement = b.Var(syms.plane_0, b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()), cloned_constructor, cloned_attributes); - ctx.Replace(var, replacement); + ctx.Replace(global, replacement); } // We must update all the texture_external parameters for user declared diff --git a/src/tint/transform/num_workgroups_from_uniform.cc b/src/tint/transform/num_workgroups_from_uniform.cc index 0bb1518544..bba1580c62 100644 --- a/src/tint/transform/num_workgroups_from_uniform.cc +++ b/src/tint/transform/num_workgroups_from_uniform.cc @@ -133,8 +133,8 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat // plus 1, or group 0 if no resource bound. group = 0; - for (auto* var : ctx.src->AST().GlobalVariables()) { - if (auto binding_point = var->BindingPoint()) { + for (auto* global : ctx.src->AST().GlobalVariables()) { + if (auto binding_point = global->BindingPoint()) { if (binding_point.group->value >= group) { group = binding_point.group->value + 1; } diff --git a/src/tint/transform/simplify_pointers.cc b/src/tint/transform/simplify_pointers.cc index b8f82bc49a..4190312f51 100644 --- a/src/tint/transform/simplify_pointers.cc +++ b/src/tint/transform/simplify_pointers.cc @@ -109,8 +109,8 @@ struct SimplifyPointers::State { } if (auto* user = ctx.src->Sem().Get(op.expr)) { auto* var = user->Variable(); - if (var->Is() && // - var->Declaration()->is_const && // + if (var->Is() && // + var->Declaration()->Is() && // var->Type()->Is()) { op.expr = var->Declaration()->constructor; continue; @@ -161,7 +161,7 @@ struct SimplifyPointers::State { // permitted. for (auto* node : ctx.src->ASTNodes().Objects()) { if (auto* let = node->As()) { - if (!let->variable->is_const) { + if (!let->variable->Is()) { continue; // Not a `let` declaration. Ignore. } diff --git a/src/tint/transform/single_entry_point.cc b/src/tint/transform/single_entry_point.cc index 82324c7745..ab7bed3da1 100644 --- a/src/tint/transform/single_entry_point.cc +++ b/src/tint/transform/single_entry_point.cc @@ -64,38 +64,43 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c referenced_vars.emplace(var->Declaration()); } - // Clone any module-scope variables, types, and functions that are statically - // referenced by the target entry point. + // Clone any module-scope variables, types, and functions that are statically referenced by the + // target entry point. for (auto* decl : ctx.src->AST().GlobalDeclarations()) { - if (auto* ty = decl->As()) { - // TODO(jrprice): Strip unused types. - ctx.dst->AST().AddTypeDecl(ctx.Clone(ty)); - } else if (auto* var = decl->As()) { - if (referenced_vars.count(var)) { - if (var->is_overridable) { - // It is an overridable constant - if (!ast::HasAttribute(var->attributes)) { + Switch( + decl, // + [&](const ast::TypeDecl* ty) { + // TODO(jrprice): Strip unused types. + ctx.dst->AST().AddTypeDecl(ctx.Clone(ty)); + }, + [&](const ast::Override* override) { + if (referenced_vars.count(override)) { + if (!ast::HasAttribute(override->attributes)) { // If the constant doesn't already have an @id() attribute, add one // so that its allocated ID so that it won't be affected by other // stripped away constants - auto* global = sem.Get(var)->As(); + auto* global = sem.Get(override); const auto* id = ctx.dst->Id(global->ConstantId()); - ctx.InsertFront(var->attributes, id); + ctx.InsertFront(override->attributes, id); } + ctx.dst->AST().AddGlobalVariable(ctx.Clone(override)); } - ctx.dst->AST().AddGlobalVariable(ctx.Clone(var)); - } - } else if (auto* func = decl->As()) { - if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) { - ctx.dst->AST().AddFunction(ctx.Clone(func)); - } - } else if (auto* ext = decl->As()) { - ctx.dst->AST().AddEnable(ctx.Clone(ext)); - } else { - TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) - << "unhandled global declaration: " << decl->TypeInfo().name; - return; - } + }, + [&](const ast::Variable* v) { // var, let + if (referenced_vars.count(v)) { + ctx.dst->AST().AddGlobalVariable(ctx.Clone(v)); + } + }, + [&](const ast::Function* func) { + if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) { + ctx.dst->AST().AddFunction(ctx.Clone(func)); + } + }, + [&](const ast::Enable* ext) { ctx.dst->AST().AddEnable(ctx.Clone(ext)); }, + [&](Default) { + TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) + << "unhandled global declaration: " << decl->TypeInfo().name; + }); } // Clone the entry point. diff --git a/src/tint/transform/unshadow.cc b/src/tint/transform/unshadow.cc index dcf90daa8b..6f0292b172 100644 --- a/src/tint/transform/unshadow.cc +++ b/src/tint/transform/unshadow.cc @@ -44,28 +44,42 @@ struct Unshadow::State { // Maps a variable to its new name. std::unordered_map renamed_to; - auto rename = [&](const sem::Variable* var) -> const ast::Variable* { - auto* decl = var->Declaration(); + auto rename = [&](const sem::Variable* v) -> const ast::Variable* { + auto* decl = v->Declaration(); auto name = ctx.src->Symbols().NameFor(decl->symbol); auto symbol = ctx.dst->Symbols().New(name); - renamed_to.emplace(var, symbol); + renamed_to.emplace(v, symbol); auto source = ctx.Clone(decl->source); auto* type = ctx.Clone(decl->type); auto* constructor = ctx.Clone(decl->constructor); auto attributes = ctx.Clone(decl->attributes); - return ctx.dst->create(source, symbol, decl->declared_storage_class, - decl->declared_access, type, decl->is_const, - decl->is_overridable, constructor, attributes); + return Switch( + decl, // + [&](const ast::Var* var) { + return ctx.dst->Var(source, symbol, type, var->declared_storage_class, + var->declared_access, constructor, attributes); + }, + [&](const ast::Let*) { + return ctx.dst->Let(source, symbol, type, constructor, attributes); + }, + [&](const ast::Parameter*) { + return ctx.dst->Param(source, symbol, type, attributes); + }, + [&](Default) { + TINT_ICE(Transform, ctx.dst->Diagnostics()) + << "unexpected variable type: " << decl->TypeInfo().name; + return nullptr; + }); }; - ctx.ReplaceAll([&](const ast::Variable* var) -> const ast::Variable* { - if (auto* local = sem.Get(var)) { + ctx.ReplaceAll([&](const ast::Variable* v) -> const ast::Variable* { + if (auto* local = sem.Get(v)) { if (local->Shadows()) { return rename(local); } } - if (auto* param = sem.Get(var)) { + if (auto* param = sem.Get(v)) { if (param->Shadows()) { return rename(param); } diff --git a/src/tint/transform/utils/hoist_to_decl_before.cc b/src/tint/transform/utils/hoist_to_decl_before.cc index f5bf24bfac..41f2287654 100644 --- a/src/tint/transform/utils/hoist_to_decl_before.cc +++ b/src/tint/transform/utils/hoist_to_decl_before.cc @@ -189,18 +189,18 @@ class HoistToDeclBefore::State { /// before `before_expr`. /// @param before_expr expression to insert `expr` before /// @param expr expression to hoist - /// @param as_const hoist to `let` if true, otherwise to `var` + /// @param as_let hoist to `let` if true, otherwise to `var` /// @param decl_name optional name to use for the variable/constant name /// @return true on success bool Add(const sem::Expression* before_expr, const ast::Expression* expr, - bool as_const, + bool as_let, const char* decl_name) { auto name = b.Symbols().New(decl_name); // Construct the let/var that holds the hoisted expr - auto* v = as_const ? b.Let(name, nullptr, ctx.Clone(expr)) - : b.Var(name, nullptr, ctx.Clone(expr)); + auto* v = as_let ? static_cast(b.Let(name, nullptr, ctx.Clone(expr))) + : static_cast(b.Var(name, nullptr, ctx.Clone(expr))); auto* decl = b.Decl(v); if (!InsertBefore(before_expr->Stmt(), decl)) { @@ -330,9 +330,9 @@ HoistToDeclBefore::~HoistToDeclBefore() {} bool HoistToDeclBefore::Add(const sem::Expression* before_expr, const ast::Expression* expr, - bool as_const, + bool as_let, const char* decl_name) { - return state_->Add(before_expr, expr, as_const, decl_name); + return state_->Add(before_expr, expr, as_let, decl_name); } bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt, diff --git a/src/tint/transform/vertex_pulling.cc b/src/tint/transform/vertex_pulling.cc index 31e4e7c9ce..a1627e2ab3 100644 --- a/src/tint/transform/vertex_pulling.cc +++ b/src/tint/transform/vertex_pulling.cc @@ -695,7 +695,7 @@ struct State { /// vertex_index and instance_index builtins if present. /// @param func the entry point function /// @param param the parameter to process - void ProcessNonStructParameter(const ast::Function* func, const ast::Variable* param) { + void ProcessNonStructParameter(const ast::Function* func, const ast::Parameter* param) { if (auto* location = ast::GetAttribute(param->attributes)) { // Create a function-scope variable to replace the parameter. auto func_var_sym = ctx.Clone(param->symbol); @@ -733,7 +733,7 @@ struct State { /// @param param the parameter to process /// @param struct_ty the structure type void ProcessStructParameter(const ast::Function* func, - const ast::Variable* param, + const ast::Parameter* param, const ast::Struct* struct_ty) { auto param_sym = ctx.Clone(param->symbol); diff --git a/src/tint/transform/zero_init_workgroup_memory.cc b/src/tint/transform/zero_init_workgroup_memory.cc index 6e84310992..f56dc617fd 100644 --- a/src/tint/transform/zero_init_workgroup_memory.cc +++ b/src/tint/transform/zero_init_workgroup_memory.cc @@ -416,8 +416,8 @@ ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default; ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default; bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program, const DataMap&) const { - for (auto* decl : program->AST().GlobalDeclarations()) { - if (auto* var = decl->As()) { + for (auto* global : program->AST().GlobalVariables()) { + if (auto* var = global->As()) { if (var->declared_storage_class == ast::StorageClass::kWorkgroup) { return true; } diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc index 441b0d41f2..66ff64a992 100644 --- a/src/tint/writer/glsl/generator_impl.cc +++ b/src/tint/writer/glsl/generator_impl.cc @@ -1904,42 +1904,47 @@ bool GeneratorImpl::EmitFunction(const ast::Function* func) { } bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) { - if (global->is_const) { - return EmitProgramConstVariable(global); - } - - auto* sem = builder_.Sem().Get(global); - switch (sem->StorageClass()) { - case ast::StorageClass::kUniform: - return EmitUniformVariable(sem); - case ast::StorageClass::kStorage: - return EmitStorageVariable(sem); - case ast::StorageClass::kHandle: - return EmitHandleVariable(sem); - case ast::StorageClass::kPrivate: - return EmitPrivateVariable(sem); - case ast::StorageClass::kWorkgroup: - return EmitWorkgroupVariable(sem); - case ast::StorageClass::kInput: - case ast::StorageClass::kOutput: - return EmitIOVariable(sem); - default: - break; - } - - TINT_ICE(Writer, diagnostics_) << "unhandled storage class " << sem->StorageClass(); - return false; + return Switch( + global, // + [&](const ast::Var* var) { + auto* sem = builder_.Sem().Get(global); + switch (sem->StorageClass()) { + case ast::StorageClass::kUniform: + return EmitUniformVariable(var, sem); + case ast::StorageClass::kStorage: + return EmitStorageVariable(var, sem); + case ast::StorageClass::kHandle: + return EmitHandleVariable(var, sem); + case ast::StorageClass::kPrivate: + return EmitPrivateVariable(sem); + case ast::StorageClass::kWorkgroup: + return EmitWorkgroupVariable(sem); + case ast::StorageClass::kInput: + case ast::StorageClass::kOutput: + return EmitIOVariable(sem); + default: + TINT_ICE(Writer, diagnostics_) + << "unhandled storage class " << sem->StorageClass(); + return false; + } + }, + [&](const ast::Let* let) { return EmitProgramConstVariable(let); }, + [&](const ast::Override* override) { return EmitOverride(override); }, + [&](Default) { + TINT_ICE(Writer, diagnostics_) + << "unhandled global variable type " << global->TypeInfo().name; + return false; + }); } -bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) { - auto* decl = var->Declaration(); - auto* type = var->Type()->UnwrapRef(); +bool GeneratorImpl::EmitUniformVariable(const ast::Var* var, const sem::Variable* sem) { + auto* type = sem->Type()->UnwrapRef(); auto* str = type->As(); if (!str) { TINT_ICE(Writer, builder_.Diagnostics()) << "storage variable must be of struct type"; return false; } - ast::VariableBindingPoint bp = decl->BindingPoint(); + ast::VariableBindingPoint bp = var->BindingPoint(); { auto out = line(); out << "layout(binding = " << bp.binding->value; @@ -1949,36 +1954,34 @@ bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) { out << ") uniform " << UniqueIdentifier(StructName(str)) << " {"; } EmitStructMembers(current_buffer_, str, /* emit_offsets */ true); - auto name = builder_.Symbols().NameFor(decl->symbol); + auto name = builder_.Symbols().NameFor(var->symbol); line() << "} " << name << ";"; line(); return true; } -bool GeneratorImpl::EmitStorageVariable(const sem::Variable* var) { - auto* decl = var->Declaration(); - auto* type = var->Type()->UnwrapRef(); +bool GeneratorImpl::EmitStorageVariable(const ast::Var* var, const sem::Variable* sem) { + auto* type = sem->Type()->UnwrapRef(); auto* str = type->As(); if (!str) { TINT_ICE(Writer, builder_.Diagnostics()) << "storage variable must be of struct type"; return false; } - ast::VariableBindingPoint bp = decl->BindingPoint(); + ast::VariableBindingPoint bp = var->BindingPoint(); line() << "layout(binding = " << bp.binding->value << ", std430) buffer " << UniqueIdentifier(StructName(str)) << " {"; EmitStructMembers(current_buffer_, str, /* emit_offsets */ true); - auto name = builder_.Symbols().NameFor(decl->symbol); + auto name = builder_.Symbols().NameFor(var->symbol); line() << "} " << name << ";"; return true; } -bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) { - auto* decl = var->Declaration(); +bool GeneratorImpl::EmitHandleVariable(const ast::Var* var, const sem::Variable* sem) { auto out = line(); - auto name = builder_.Symbols().NameFor(decl->symbol); - auto* type = var->Type()->UnwrapRef(); + auto name = builder_.Symbols().NameFor(var->symbol); + auto* type = sem->Type()->UnwrapRef(); if (type->Is()) { // GLSL ignores Sampler variables. return true; @@ -1986,7 +1989,7 @@ bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) { if (auto* storage = type->As()) { out << "layout(" << convert_texel_format_to_glsl(storage->texel_format()) << ") "; } - if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) { + if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), name)) { return false; } @@ -2138,7 +2141,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { if (wgsize[i].overridable_const) { auto* global = builder_.Sem().Get(wgsize[i].overridable_const); - if (!global->IsOverridable()) { + if (!global->Declaration()->Is()) { TINT_ICE(Writer, builder_.Diagnostics()) << "expected a pipeline-overridable constant"; } @@ -2652,7 +2655,15 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) { return EmitSwitch(s); } if (auto* v = stmt->As()) { - return EmitVariable(v->variable); + return Switch( + v->variable, // + [&](const ast::Var* var) { return EmitVar(var); }, + [&](const ast::Let* let) { return EmitLet(let); }, + [&](Default) { // + TINT_ICE(Writer, diagnostics_) + << "unknown variable type: " << v->variable->TypeInfo().name; + return false; + }); } diagnostics_.add_error(diag::System::Writer, @@ -2934,18 +2945,11 @@ bool GeneratorImpl::EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression* return true; } -bool GeneratorImpl::EmitVariable(const ast::Variable* var) { +bool GeneratorImpl::EmitVar(const ast::Var* var) { auto* sem = builder_.Sem().Get(var); auto* type = sem->Type()->UnwrapRef(); - // TODO(dsinclair): Handle variable attributes - if (!var->attributes.empty()) { - diagnostics_.add_error(diag::System::Writer, "Variable attributes are not handled yet"); - return false; - } - auto out = line(); - // TODO(senorblanco): handle const if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), builder_.Symbols().NameFor(var->symbol))) { return false; @@ -2967,58 +2971,74 @@ bool GeneratorImpl::EmitVariable(const ast::Variable* var) { return true; } -bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) { - for (auto* d : var->attributes) { - if (!d->Is()) { - diagnostics_.add_error(diag::System::Writer, "Decorated const values not valid"); - return false; - } - } - if (!var->is_const) { - diagnostics_.add_error(diag::System::Writer, "Expected a const value"); +bool GeneratorImpl::EmitLet(const ast::Let* let) { + auto* sem = builder_.Sem().Get(let); + auto* type = sem->Type()->UnwrapRef(); + + auto out = line(); + // TODO(senorblanco): handle const + if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined, + builder_.Symbols().NameFor(let->symbol))) { return false; } + out << " = "; + + if (!EmitExpression(out, let->constructor)) { + return false; + } + + out << ";"; + + return true; +} + +bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) { auto* sem = builder_.Sem().Get(var); auto* type = sem->Type(); + auto out = line(); + out << "const "; + if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined, + builder_.Symbols().NameFor(var->symbol))) { + return false; + } + out << " = "; + if (!EmitExpression(out, var->constructor)) { + return false; + } + out << ";"; + + return true; +} + +bool GeneratorImpl::EmitOverride(const ast::Override* override) { + auto* sem = builder_.Sem().Get(override); + auto* type = sem->Type(); + auto* global = sem->As(); - if (global && global->IsOverridable()) { - auto const_id = global->ConstantId(); + auto const_id = global->ConstantId(); - line() << "#ifndef " << kSpecConstantPrefix << const_id; + line() << "#ifndef " << kSpecConstantPrefix << const_id; - if (var->constructor != nullptr) { - auto out = line(); - out << "#define " << kSpecConstantPrefix << const_id << " "; - if (!EmitExpression(out, var->constructor)) { - return false; - } - } else { - line() << "#error spec constant required for constant id " << const_id; - } - line() << "#endif"; - { - auto out = line(); - out << "const "; - if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), - builder_.Symbols().NameFor(var->symbol))) { - return false; - } - out << " = " << kSpecConstantPrefix << const_id << ";"; + if (override->constructor != nullptr) { + auto out = line(); + out << "#define " << kSpecConstantPrefix << const_id << " "; + if (!EmitExpression(out, override->constructor)) { + return false; } } else { + line() << "#error spec constant required for constant id " << const_id; + } + line() << "#endif"; + { auto out = line(); out << "const "; - if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), - builder_.Symbols().NameFor(var->symbol))) { + if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined, + builder_.Symbols().NameFor(override->symbol))) { return false; } - out << " = "; - if (!EmitExpression(out, var->constructor)) { - return false; - } - out << ";"; + out << " = " << kSpecConstantPrefix << const_id << ";"; } return true; diff --git a/src/tint/writer/glsl/generator_impl.h b/src/tint/writer/glsl/generator_impl.h index ff1611c1b4..5b07fd39d2 100644 --- a/src/tint/writer/glsl/generator_impl.h +++ b/src/tint/writer/glsl/generator_impl.h @@ -293,19 +293,22 @@ class GeneratorImpl : public TextGenerator { bool EmitGlobalVariable(const ast::Variable* global); /// Handles emitting a global variable with the uniform storage class - /// @param var the global variable + /// @param var the AST node for the 'var' + /// @param sem the semantic node for the 'var' /// @returns true on success - bool EmitUniformVariable(const sem::Variable* var); + bool EmitUniformVariable(const ast::Var* var, const sem::Variable* sem); /// Handles emitting a global variable with the storage storage class - /// @param var the global variable + /// @param var the AST node for the 'var' + /// @param sem the semantic node for the 'var' /// @returns true on success - bool EmitStorageVariable(const sem::Variable* var); + bool EmitStorageVariable(const ast::Var* var, const sem::Variable* sem); /// Handles emitting a global variable with the handle storage class - /// @param var the global variable + /// @param var the AST node for the 'var' + /// @param sem the semantic node for the 'var' /// @returns true on success - bool EmitHandleVariable(const sem::Variable* var); + bool EmitHandleVariable(const ast::Var* var, const sem::Variable* sem); /// Handles emitting a global variable with the private storage class /// @param var the global variable @@ -437,14 +440,22 @@ class GeneratorImpl : public TextGenerator { /// @param type the type to emit the value for /// @returns true if the zero value was successfully emitted. bool EmitZeroValue(std::ostream& out, const sem::Type* type); - /// Handles generating a variable + /// Handles generating a 'var' declaration /// @param var the variable to generate /// @returns true if the variable was emitted - bool EmitVariable(const ast::Variable* var); - /// Handles generating a program scope constant variable - /// @param var the variable to emit + bool EmitVar(const ast::Var* var); + /// Handles generating a function-scope 'let' declaration + /// @param let the variable to generate /// @returns true if the variable was emitted - bool EmitProgramConstVariable(const ast::Variable* var); + bool EmitLet(const ast::Let* let); + /// Handles generating a module-scope 'let' declaration + /// @param let the 'let' to emit + /// @returns true if the variable was emitted + bool EmitProgramConstVariable(const ast::Variable* let); + /// Handles generating a module-scope 'override' declaration + /// @param override the 'override' to emit + /// @returns true if the variable was emitted + bool EmitOverride(const ast::Override* override); /// Handles generating a builtin method name /// @param builtin the semantic info for the builtin /// @returns the name or "" if not valid diff --git a/src/tint/writer/glsl/generator_impl_module_constant_test.cc b/src/tint/writer/glsl/generator_impl_module_constant_test.cc index 223122afb8..a97d4cd77f 100644 --- a/src/tint/writer/glsl/generator_impl_module_constant_test.cc +++ b/src/tint/writer/glsl/generator_impl_module_constant_test.cc @@ -40,7 +40,7 @@ TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error(); + ASSERT_TRUE(gen.EmitOverride(var)) << gen.error(); EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23 #define WGSL_SPEC_CONSTANT_23 3.0f #endif @@ -56,7 +56,7 @@ TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoConstructor) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error(); + ASSERT_TRUE(gen.EmitOverride(var)) << gen.error(); EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23 #error spec constant required for constant id 23 #endif @@ -73,8 +73,8 @@ TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoId) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitProgramConstVariable(a)) << gen.error(); - ASSERT_TRUE(gen.EmitProgramConstVariable(b)) << gen.error(); + ASSERT_TRUE(gen.EmitOverride(a)) << gen.error(); + ASSERT_TRUE(gen.EmitOverride(b)) << gen.error(); EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_0 #define WGSL_SPEC_CONSTANT_0 3.0f #endif diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc index 0906940325..3dfb74f0b0 100644 --- a/src/tint/writer/hlsl/generator_impl.cc +++ b/src/tint/writer/hlsl/generator_impl.cc @@ -365,7 +365,7 @@ bool GeneratorImpl::EmitDynamicVectorAssignment(const ast::AssignmentStatement* out << "vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;"; break; default: - TINT_UNREACHABLE(Writer, builder_.Diagnostics()) + TINT_UNREACHABLE(Writer, diagnostics_) << "invalid vector size " << vec->Width(); break; } @@ -524,7 +524,7 @@ bool GeneratorImpl::EmitDynamicMatrixScalarAssignment(const ast::AssignmentState << vec_name << ";"; break; default: - TINT_UNREACHABLE(Writer, builder_.Diagnostics()) + TINT_UNREACHABLE(Writer, diagnostics_) << "invalid vector size " << vec->Width(); break; } @@ -2861,41 +2861,46 @@ bool GeneratorImpl::EmitFunctionBodyWithDiscard(const ast::Function* func) { } bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) { - if (global->is_const) { - return EmitProgramConstVariable(global); - } - - auto* sem = builder_.Sem().Get(global); - switch (sem->StorageClass()) { - case ast::StorageClass::kUniform: - return EmitUniformVariable(sem); - case ast::StorageClass::kStorage: - return EmitStorageVariable(sem); - case ast::StorageClass::kHandle: - return EmitHandleVariable(sem); - case ast::StorageClass::kPrivate: - return EmitPrivateVariable(sem); - case ast::StorageClass::kWorkgroup: - return EmitWorkgroupVariable(sem); - default: - break; - } - - TINT_ICE(Writer, diagnostics_) << "unhandled storage class " << sem->StorageClass(); - return false; + return Switch( + global, // + [&](const ast::Var* var) { + auto* sem = builder_.Sem().Get(global); + switch (sem->StorageClass()) { + case ast::StorageClass::kUniform: + return EmitUniformVariable(var, sem); + case ast::StorageClass::kStorage: + return EmitStorageVariable(var, sem); + case ast::StorageClass::kHandle: + return EmitHandleVariable(var, sem); + case ast::StorageClass::kPrivate: + return EmitPrivateVariable(sem); + case ast::StorageClass::kWorkgroup: + return EmitWorkgroupVariable(sem); + default: + TINT_ICE(Writer, diagnostics_) + << "unhandled storage class " << sem->StorageClass(); + return false; + } + }, + [&](const ast::Let* let) { return EmitProgramConstVariable(let); }, + [&](const ast::Override* override) { return EmitOverride(override); }, + [&](Default) { + TINT_ICE(Writer, diagnostics_) + << "unhandled global variable type " << global->TypeInfo().name; + return false; + }); } -bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) { - auto* decl = var->Declaration(); - auto binding_point = decl->BindingPoint(); - auto* type = var->Type()->UnwrapRef(); - auto name = builder_.Symbols().NameFor(decl->symbol); +bool GeneratorImpl::EmitUniformVariable(const ast::Var* var, const sem::Variable* sem) { + auto binding_point = var->BindingPoint(); + auto* type = sem->Type()->UnwrapRef(); + auto name = builder_.Symbols().NameFor(var->symbol); line() << "cbuffer cbuffer_" << name << RegisterAndSpace('b', binding_point) << " {"; { ScopedIndent si(this); auto out = line(); - if (!EmitTypeAndName(out, type, ast::StorageClass::kUniform, var->Access(), name)) { + if (!EmitTypeAndName(out, type, ast::StorageClass::kUniform, sem->Access(), name)) { return false; } out << ";"; @@ -2906,29 +2911,27 @@ bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) { return true; } -bool GeneratorImpl::EmitStorageVariable(const sem::Variable* var) { - auto* decl = var->Declaration(); - auto* type = var->Type()->UnwrapRef(); +bool GeneratorImpl::EmitStorageVariable(const ast::Var* var, const sem::Variable* sem) { + auto* type = sem->Type()->UnwrapRef(); auto out = line(); - if (!EmitTypeAndName(out, type, ast::StorageClass::kStorage, var->Access(), - builder_.Symbols().NameFor(decl->symbol))) { + if (!EmitTypeAndName(out, type, ast::StorageClass::kStorage, sem->Access(), + builder_.Symbols().NameFor(var->symbol))) { return false; } - out << RegisterAndSpace(var->Access() == ast::Access::kRead ? 't' : 'u', decl->BindingPoint()) + out << RegisterAndSpace(sem->Access() == ast::Access::kRead ? 't' : 'u', var->BindingPoint()) << ";"; return true; } -bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) { - auto* decl = var->Declaration(); - auto* unwrapped_type = var->Type()->UnwrapRef(); +bool GeneratorImpl::EmitHandleVariable(const ast::Var* var, const sem::Variable* sem) { + auto* unwrapped_type = sem->Type()->UnwrapRef(); auto out = line(); - auto name = builder_.Symbols().NameFor(decl->symbol); - auto* type = var->Type()->UnwrapRef(); - if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) { + auto name = builder_.Symbols().NameFor(var->symbol); + auto* type = sem->Type()->UnwrapRef(); + if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), name)) { return false; } @@ -2944,7 +2947,7 @@ bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) { } if (register_space) { - auto bp = decl->BindingPoint(); + auto bp = var->BindingPoint(); out << " : register(" << register_space << bp.binding->value << ", space" << bp.group->value << ")"; } @@ -3078,8 +3081,8 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { if (wgsize[i].overridable_const) { auto* global = builder_.Sem().Get(wgsize[i].overridable_const); - if (!global->IsOverridable()) { - TINT_ICE(Writer, builder_.Diagnostics()) + if (!global->Declaration()->Is()) { + TINT_ICE(Writer, diagnostics_) << "expected a pipeline-overridable constant"; } out << kSpecConstantPrefix << global->ConstantId(); @@ -3611,7 +3614,15 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) { return EmitSwitch(s); }, [&](const ast::VariableDeclStatement* v) { // - return EmitVariable(v->variable); + return Switch( + v->variable, // + [&](const ast::Var* var) { return EmitVar(var); }, + [&](const ast::Let* let) { return EmitLet(let); }, + [&](Default) { // + TINT_ICE(Writer, diagnostics_) + << "unknown variable type: " << v->variable->TypeInfo().name; + return false; + }); }, [&](Default) { // diagnostics_.add_error(diag::System::Writer, @@ -4018,20 +4029,11 @@ bool GeneratorImpl::EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression* return true; } -bool GeneratorImpl::EmitVariable(const ast::Variable* var) { +bool GeneratorImpl::EmitVar(const ast::Var* var) { auto* sem = builder_.Sem().Get(var); auto* type = sem->Type()->UnwrapRef(); - // TODO(dsinclair): Handle variable attributes - if (!var->attributes.empty()) { - diagnostics_.add_error(diag::System::Writer, "Variable attributes are not handled yet"); - return false; - } - auto out = line(); - if (var->is_const) { - out << "const "; - } if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), builder_.Symbols().NameFor(var->symbol))) { return false; @@ -4053,60 +4055,71 @@ bool GeneratorImpl::EmitVariable(const ast::Variable* var) { return true; } -bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) { - for (auto* d : var->attributes) { - if (!d->Is()) { - diagnostics_.add_error(diag::System::Writer, "Decorated const values not valid"); - return false; - } - } - if (!var->is_const) { - diagnostics_.add_error(diag::System::Writer, "Expected a const value"); +bool GeneratorImpl::EmitLet(const ast::Let* let) { + auto* sem = builder_.Sem().Get(let); + auto* type = sem->Type()->UnwrapRef(); + + auto out = line(); + out << "const "; + if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined, + builder_.Symbols().NameFor(let->symbol))) { return false; } + out << " = "; + if (!EmitExpression(out, let->constructor)) { + return false; + } + out << ";"; - auto* sem = builder_.Sem().Get(var); + return true; +} + +bool GeneratorImpl::EmitProgramConstVariable(const ast::Let* let) { + auto* sem = builder_.Sem().Get(let); auto* type = sem->Type(); - auto* global = sem->As(); - if (global && global->IsOverridable()) { - auto const_id = global->ConstantId(); + auto out = line(); + out << "static const "; + if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined, + builder_.Symbols().NameFor(let->symbol))) { + return false; + } + out << " = "; + if (!EmitExpression(out, let->constructor)) { + return false; + } + out << ";"; - line() << "#ifndef " << kSpecConstantPrefix << const_id; + return true; +} - if (var->constructor != nullptr) { - auto out = line(); - out << "#define " << kSpecConstantPrefix << const_id << " "; - if (!EmitExpression(out, var->constructor)) { - return false; - } - } else { - line() << "#error spec constant required for constant id " << const_id; - } - line() << "#endif"; - { - auto out = line(); - out << "static const "; - if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), - builder_.Symbols().NameFor(var->symbol))) { - return false; - } - out << " = " << kSpecConstantPrefix << const_id << ";"; +bool GeneratorImpl::EmitOverride(const ast::Override* override) { + auto* sem = builder_.Sem().Get(override); + auto* type = sem->Type(); + + auto const_id = sem->ConstantId(); + + line() << "#ifndef " << kSpecConstantPrefix << const_id; + + if (override->constructor != nullptr) { + auto out = line(); + out << "#define " << kSpecConstantPrefix << const_id << " "; + if (!EmitExpression(out, override->constructor)) { + return false; } } else { + line() << "#error spec constant required for constant id " << const_id; + } + line() << "#endif"; + { auto out = line(); out << "static const "; if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), - builder_.Symbols().NameFor(var->symbol))) { + builder_.Symbols().NameFor(override->symbol))) { return false; } - out << " = "; - if (!EmitExpression(out, var->constructor)) { - return false; - } - out << ";"; + out << " = " << kSpecConstantPrefix << const_id << ";"; } - return true; } diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h index 02051423df..485b779c40 100644 --- a/src/tint/writer/hlsl/generator_impl.h +++ b/src/tint/writer/hlsl/generator_impl.h @@ -303,19 +303,22 @@ class GeneratorImpl : public TextGenerator { bool EmitGlobalVariable(const ast::Variable* global); /// Handles emitting a global variable with the uniform storage class - /// @param var the global variable + /// @param var the AST node for the 'var' + /// @param sem the semantic node for the 'var' /// @returns true on success - bool EmitUniformVariable(const sem::Variable* var); + bool EmitUniformVariable(const ast::Var* var, const sem::Variable* sem); /// Handles emitting a global variable with the storage storage class - /// @param var the global variable + /// @param var the AST node for the 'var' + /// @param sem the semantic node for the 'var' /// @returns true on success - bool EmitStorageVariable(const sem::Variable* var); + bool EmitStorageVariable(const ast::Var* var, const sem::Variable* sem); /// Handles emitting a global variable with the handle storage class - /// @param var the global variable + /// @param var the AST node for the 'var' + /// @param sem the semantic node for the 'var' /// @returns true on success - bool EmitHandleVariable(const sem::Variable* var); + bool EmitHandleVariable(const ast::Var* var, const sem::Variable* sem); /// Handles emitting a global variable with the private storage class /// @param var the global variable @@ -437,14 +440,22 @@ class GeneratorImpl : public TextGenerator { /// @param type the type to emit the value for /// @returns true if the zero value was successfully emitted. bool EmitZeroValue(std::ostream& out, const sem::Type* type); - /// Handles generating a variable + /// Handles generating a 'var' declaration /// @param var the variable to generate /// @returns true if the variable was emitted - bool EmitVariable(const ast::Variable* var); - /// Handles generating a program scope constant variable - /// @param var the variable to emit + bool EmitVar(const ast::Var* var); + /// Handles generating a function-scope 'let' declaration + /// @param let the variable to generate /// @returns true if the variable was emitted - bool EmitProgramConstVariable(const ast::Variable* var); + bool EmitLet(const ast::Let* let); + /// Handles generating a module-scope 'let' declaration + /// @param let the 'let' to emit + /// @returns true if the variable was emitted + bool EmitProgramConstVariable(const ast::Let* let); + /// Handles generating a module-scope 'override' declaration + /// @param override the 'override' to emit + /// @returns true if the variable was emitted + bool EmitOverride(const ast::Override* override); /// Emits call to a helper vector assignment function for the input assignment /// statement and vector type. This is used to work around FXC issues where /// assignments to vectors with dynamic indices cause compilation failures. diff --git a/src/tint/writer/hlsl/generator_impl_module_constant_test.cc b/src/tint/writer/hlsl/generator_impl_module_constant_test.cc index 2c96744623..19be5e6fc5 100644 --- a/src/tint/writer/hlsl/generator_impl_module_constant_test.cc +++ b/src/tint/writer/hlsl/generator_impl_module_constant_test.cc @@ -40,7 +40,7 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error(); + ASSERT_TRUE(gen.EmitOverride(var)) << gen.error(); EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23 #define WGSL_SPEC_CONSTANT_23 3.0f #endif @@ -56,7 +56,7 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoConstructor) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error(); + ASSERT_TRUE(gen.EmitOverride(var)) << gen.error(); EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23 #error spec constant required for constant id 23 #endif @@ -73,8 +73,8 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoId) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitProgramConstVariable(a)) << gen.error(); - ASSERT_TRUE(gen.EmitProgramConstVariable(b)) << gen.error(); + ASSERT_TRUE(gen.EmitOverride(a)) << gen.error(); + ASSERT_TRUE(gen.EmitOverride(b)) << gen.error(); EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_0 #define WGSL_SPEC_CONSTANT_0 3.0f #endif diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index e5ba3a8e1a..1181c02f12 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -253,16 +253,13 @@ bool GeneratorImpl::Generate() { [&](const ast::Alias*) { return true; // folded away by the writer }, - [&](const ast::Variable* var) { - if (var->is_const) { - TINT_DEFER(line()); - return EmitProgramConstVariable(var); - } - // These are pushed into the entry point by sanitizer transforms. - TINT_ICE(Writer, diagnostics_) - << "module-scope variables should have been handled by the MSL " - "sanitizer"; - return false; + [&](const ast::Let* let) { + TINT_DEFER(line()); + return EmitProgramConstVariable(let); + }, + [&](const ast::Override* override) { + TINT_DEFER(line()); + return EmitOverride(override); }, [&](const ast::Function* func) { TINT_DEFER(line()); @@ -1866,8 +1863,8 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { // Returns the binding index of a variable, requiring that the group // attribute have a value of zero. const uint32_t kInvalidBindingIndex = std::numeric_limits::max(); - auto get_binding_index = [&](const ast::Variable* var) -> uint32_t { - auto bp = var->BindingPoint(); + auto get_binding_index = [&](const ast::Parameter* param) -> uint32_t { + auto bp = param->BindingPoint(); if (bp.group == nullptr || bp.binding == nullptr) { TINT_ICE(Writer, diagnostics_) << "missing binding attributes for entry point parameter"; @@ -1890,15 +1887,15 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { // Emit entry point parameters. bool first = true; - for (auto* var : func->params) { + for (auto* param : func->params) { if (!first) { out << ", "; } first = false; - auto* type = program_->Sem().Get(var)->Type()->UnwrapRef(); + auto* type = program_->Sem().Get(param)->Type()->UnwrapRef(); - auto param_name = program_->Symbols().NameFor(var->symbol); + auto param_name = program_->Symbols().NameFor(param->symbol); if (!EmitType(out, type, param_name)) { return false; } @@ -1910,26 +1907,26 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { if (type->Is()) { out << " [[stage_in]]"; } else if (type->is_handle()) { - uint32_t binding = get_binding_index(var); + uint32_t binding = get_binding_index(param); if (binding == kInvalidBindingIndex) { return false; } - if (var->type->Is()) { + if (param->type->Is()) { out << " [[sampler(" << binding << ")]]"; - } else if (var->type->Is()) { + } else if (param->type->Is()) { out << " [[texture(" << binding << ")]]"; } else { TINT_ICE(Writer, diagnostics_) << "invalid handle type entry point parameter"; return false; } - } else if (auto* ptr = var->type->As()) { + } else if (auto* ptr = param->type->As()) { auto sc = ptr->storage_class; if (sc == ast::StorageClass::kWorkgroup) { auto& allocations = workgroup_allocations_[func_name]; out << " [[threadgroup(" << allocations.size() << ")]]"; allocations.push_back(program_->Sem().Get(ptr->type)->Size()); } else if (sc == ast::StorageClass::kStorage || sc == ast::StorageClass::kUniform) { - uint32_t binding = get_binding_index(var); + uint32_t binding = get_binding_index(param); if (binding == kInvalidBindingIndex) { return false; } @@ -1940,7 +1937,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { return false; } } else { - auto& attrs = var->attributes; + auto& attrs = param->attributes; bool builtin_found = false; for (auto* attr : attrs) { auto* builtin = attr->As(); @@ -2340,8 +2337,15 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) { return EmitSwitch(s); }, [&](const ast::VariableDeclStatement* v) { // - auto* var = program_->Sem().Get(v->variable); - return EmitVariable(var); + return Switch( + v->variable, // + [&](const ast::Var* var) { return EmitVar(var); }, + [&](const ast::Let* let) { return EmitLet(let); }, + [&](Default) { // + TINT_ICE(Writer, diagnostics_) + << "unknown statement type: " << stmt->TypeInfo().name; + return false; + }); }, [&](Default) { diagnostics_.add_error(diag::System::Writer, @@ -2918,19 +2922,13 @@ bool GeneratorImpl::EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression* return true; } -bool GeneratorImpl::EmitVariable(const sem::Variable* var) { - auto* decl = var->Declaration(); - - for (auto* attr : decl->attributes) { - if (!attr->Is()) { - TINT_ICE(Writer, diagnostics_) << "unexpected variable attribute"; - return false; - } - } +bool GeneratorImpl::EmitVar(const ast::Var* var) { + auto* sem = program_->Sem().Get(var); + auto* type = sem->Type()->UnwrapRef(); auto out = line(); - switch (var->StorageClass()) { + switch (sem->StorageClass()) { case ast::StorageClass::kFunction: case ast::StorageClass::kHandle: case ast::StorageClass::kNone: @@ -2946,12 +2944,7 @@ bool GeneratorImpl::EmitVariable(const sem::Variable* var) { return false; } - auto* type = var->Type()->UnwrapRef(); - - std::string name = program_->Symbols().NameFor(decl->symbol); - if (decl->is_const) { - name = "const " + name; - } + std::string name = program_->Symbols().NameFor(var->symbol); if (!EmitType(out, type, name)) { return false; } @@ -2960,14 +2953,14 @@ bool GeneratorImpl::EmitVariable(const sem::Variable* var) { out << " " << name; } - if (decl->constructor != nullptr) { + if (var->constructor != nullptr) { out << " = "; - if (!EmitExpression(out, decl->constructor)) { + if (!EmitExpression(out, var->constructor)) { return false; } - } else if (var->StorageClass() == ast::StorageClass::kPrivate || - var->StorageClass() == ast::StorageClass::kFunction || - var->StorageClass() == ast::StorageClass::kNone) { + } else if (sem->StorageClass() == ast::StorageClass::kPrivate || + sem->StorageClass() == ast::StorageClass::kFunction || + sem->StorageClass() == ast::StorageClass::kNone) { out << " = "; if (!EmitZeroValue(out, type)) { return false; @@ -2978,34 +2971,63 @@ bool GeneratorImpl::EmitVariable(const sem::Variable* var) { return true; } -bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) { - for (auto* d : var->attributes) { - if (!d->Is()) { - diagnostics_.add_error(diag::System::Writer, "Decorated const values not valid"); +bool GeneratorImpl::EmitLet(const ast::Let* let) { + auto* sem = program_->Sem().Get(let); + auto* type = sem->Type(); + + auto out = line(); + + switch (sem->StorageClass()) { + case ast::StorageClass::kFunction: + case ast::StorageClass::kHandle: + case ast::StorageClass::kNone: + break; + case ast::StorageClass::kPrivate: + out << "thread "; + break; + case ast::StorageClass::kWorkgroup: + out << "threadgroup "; + break; + default: + TINT_ICE(Writer, diagnostics_) << "unhandled variable storage class"; return false; - } } - if (!var->is_const) { - diagnostics_.add_error(diag::System::Writer, "Expected a const value"); + + std::string name = "const " + program_->Symbols().NameFor(let->symbol); + if (!EmitType(out, type, name)) { return false; } + // Variable name is output as part of the type for arrays and pointers. + if (!type->Is() && !type->Is()) { + out << " " << name; + } + + out << " = "; + if (!EmitExpression(out, let->constructor)) { + return false; + } + out << ";"; + + return true; +} + +bool GeneratorImpl::EmitProgramConstVariable(const ast::Let* let) { + auto* global = program_->Sem().Get(let); + auto* type = global->Type(); + auto out = line(); out << "constant "; - auto* type = program_->Sem().Get(var)->Type()->UnwrapRef(); - if (!EmitType(out, type, program_->Symbols().NameFor(var->symbol))) { + if (!EmitType(out, type, program_->Symbols().NameFor(let->symbol))) { return false; } if (!type->Is()) { - out << " " << program_->Symbols().NameFor(var->symbol); + out << " " << program_->Symbols().NameFor(let->symbol); } - auto* global = program_->Sem().Get(var); - if (global && global->IsOverridable()) { - out << " [[function_constant(" << global->ConstantId() << ")]]"; - } else if (var->constructor != nullptr) { + if (let->constructor != nullptr) { out << " = "; - if (!EmitExpression(out, var->constructor)) { + if (!EmitExpression(out, let->constructor)) { return false; } } @@ -3014,6 +3036,24 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) { return true; } +bool GeneratorImpl::EmitOverride(const ast::Override* override) { + auto* global = program_->Sem().Get(override); + auto* type = global->Type(); + + auto out = line(); + out << "constant "; + if (!EmitType(out, type, program_->Symbols().NameFor(override->symbol))) { + return false; + } + if (!type->Is()) { + out << " " << program_->Symbols().NameFor(override->symbol); + } + + out << " [[function_constant(" << global->ConstantId() << ")]];"; + + return true; +} + GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::Type* ty) { return Switch( ty, diff --git a/src/tint/writer/msl/generator_impl.h b/src/tint/writer/msl/generator_impl.h index a05f3b1c77..59af5c3bd6 100644 --- a/src/tint/writer/msl/generator_impl.h +++ b/src/tint/writer/msl/generator_impl.h @@ -348,14 +348,22 @@ class GeneratorImpl : public TextGenerator { /// @param expr the expression to emit /// @returns true if the expression was emitted bool EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression* expr); - /// Handles generating a variable + /// Handles generating a 'var' declaration /// @param var the variable to generate /// @returns true if the variable was emitted - bool EmitVariable(const sem::Variable* var); - /// Handles generating a program scope constant variable - /// @param var the variable to emit + bool EmitVar(const ast::Var* var); + /// Handles generating a function-scope 'let' declaration + /// @param let the variable to generate /// @returns true if the variable was emitted - bool EmitProgramConstVariable(const ast::Variable* var); + bool EmitLet(const ast::Let* let); + /// Handles generating a module-scope 'let' declaration + /// @param let the 'let' to emit + /// @returns true if the variable was emitted + bool EmitProgramConstVariable(const ast::Let* let); + /// Handles generating a module-scope 'override' declaration + /// @param override the 'override' to emit + /// @returns true if the variable was emitted + bool EmitOverride(const ast::Override* override); /// Emits the zero value for the given type /// @param out the output of the expression stream /// @param type the type to emit the value for diff --git a/src/tint/writer/msl/generator_impl_module_constant_test.cc b/src/tint/writer/msl/generator_impl_module_constant_test.cc index 2b70da48d3..f23e51aa77 100644 --- a/src/tint/writer/msl/generator_impl_module_constant_test.cc +++ b/src/tint/writer/msl/generator_impl_module_constant_test.cc @@ -39,7 +39,7 @@ TEST_F(MslGeneratorImplTest, Emit_SpecConstant) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error(); + ASSERT_TRUE(gen.EmitOverride(var)) << gen.error(); EXPECT_EQ(gen.result(), "constant float pos [[function_constant(23)]];\n"); } @@ -52,8 +52,8 @@ TEST_F(MslGeneratorImplTest, Emit_SpecConstant_NoId) { GeneratorImpl& gen = Build(); - ASSERT_TRUE(gen.EmitProgramConstVariable(var_a)) << gen.error(); - ASSERT_TRUE(gen.EmitProgramConstVariable(var_b)) << gen.error(); + ASSERT_TRUE(gen.EmitOverride(var_a)) << gen.error(); + ASSERT_TRUE(gen.EmitOverride(var_b)) << gen.error(); EXPECT_EQ(gen.result(), R"(constant float a [[function_constant(0)]]; constant float b [[function_constant(1)]]; )"); diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc index 33061eab2d..4368e3c006 100644 --- a/src/tint/writer/spirv/builder.cc +++ b/src/tint/writer/spirv/builder.cc @@ -533,7 +533,7 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) { // Make the constant specializable. auto* sem_const = builder_.Sem().Get(wgsize[i].overridable_const); - if (!sem_const->IsOverridable()) { + if (!sem_const->Declaration()->Is()) { TINT_ICE(Writer, builder_.Diagnostics()) << "expected a pipeline-overridable constant"; } @@ -692,19 +692,19 @@ uint32_t Builder::GenerateFunctionTypeIfNeeded(const sem::Function* func) { }); } -bool Builder::GenerateFunctionVariable(const ast::Variable* var) { +bool Builder::GenerateFunctionVariable(const ast::Variable* v) { uint32_t init_id = 0; - if (var->constructor) { - init_id = GenerateExpressionWithLoadIfNeeded(var->constructor); + if (v->constructor) { + init_id = GenerateExpressionWithLoadIfNeeded(v->constructor); if (init_id == 0) { return false; } } - auto* sem = builder_.Sem().Get(var); + auto* sem = builder_.Sem().Get(v); - if (var->is_const) { - if (!var->constructor) { + if (auto* let = v->As()) { + if (!let->constructor) { error_ = "missing constructor for constant"; return false; } @@ -721,8 +721,7 @@ bool Builder::GenerateFunctionVariable(const ast::Variable* var) { return false; } - push_debug(spv::Op::OpName, - {Operand(var_id), Operand(builder_.Symbols().NameFor(var->symbol))}); + push_debug(spv::Op::OpName, {Operand(var_id), Operand(builder_.Symbols().NameFor(v->symbol))}); // TODO(dsinclair) We could detect if the constructor is fully const and emit // an initializer value for the variable instead of doing the OpLoad. @@ -733,7 +732,7 @@ bool Builder::GenerateFunctionVariable(const ast::Variable* var) { push_function_var( {Operand(type_id), result, U32Operand(ConvertStorageClass(sc)), Operand(null_id)}); - if (var->constructor) { + if (v->constructor) { if (!GenerateStore(var_id, init_id)) { return false; } @@ -748,66 +747,61 @@ bool Builder::GenerateStore(uint32_t to, uint32_t from) { return push_function_inst(spv::Op::OpStore, {Operand(to), Operand(from)}); } -bool Builder::GenerateGlobalVariable(const ast::Variable* var) { - auto* sem = builder_.Sem().Get(var); +bool Builder::GenerateGlobalVariable(const ast::Variable* v) { + auto* sem = builder_.Sem().Get(v); auto* type = sem->Type()->UnwrapRef(); uint32_t init_id = 0; - if (var->constructor) { - if (!var->is_overridable) { - auto* ctor = builder_.Sem().Get(var->constructor); - if (auto constant = ctor->ConstantValue()) { + if (auto* ctor = v->constructor) { + if (!v->Is()) { + auto* ctor_sem = builder_.Sem().Get(ctor); + if (auto constant = ctor_sem->ConstantValue()) { init_id = GenerateConstantIfNeeded(std::move(constant)); } } if (init_id == 0) { - init_id = GenerateConstructorExpression(var, var->constructor); + init_id = GenerateConstructorExpression(v, v->constructor); } if (init_id == 0) { return false; } } - if (var->is_const) { - if (!var->constructor) { - // Constants must have an initializer unless they are overridable. - if (!var->is_overridable) { - error_ = "missing constructor for constant"; - return false; - } - - // SPIR-V requires specialization constants to have initializers. - init_id = Switch( - type, // - [&](const sem::F32*) { - ast::FloatLiteralExpression l(ProgramID{}, Source{}, 0, - ast::FloatLiteralExpression::Suffix::kF); - return GenerateLiteralIfNeeded(var, &l); - }, - [&](const sem::U32*) { - ast::IntLiteralExpression l(ProgramID{}, Source{}, 0, - ast::IntLiteralExpression::Suffix::kU); - return GenerateLiteralIfNeeded(var, &l); - }, - [&](const sem::I32*) { - ast::IntLiteralExpression l(ProgramID{}, Source{}, 0, - ast::IntLiteralExpression::Suffix::kI); - return GenerateLiteralIfNeeded(var, &l); - }, - [&](const sem::Bool*) { - ast::BoolLiteralExpression l(ProgramID{}, Source{}, false); - return GenerateLiteralIfNeeded(var, &l); - }, - [&](Default) { - error_ = "invalid type for pipeline constant ID, must be scalar"; - return 0; - }); - if (init_id == 0) { + if (auto* override = v->As(); override && !override->constructor) { + // SPIR-V requires specialization constants to have initializers. + init_id = Switch( + type, // + [&](const sem::F32*) { + ast::FloatLiteralExpression l(ProgramID{}, Source{}, 0, + ast::FloatLiteralExpression::Suffix::kF); + return GenerateLiteralIfNeeded(override, &l); + }, + [&](const sem::U32*) { + ast::IntLiteralExpression l(ProgramID{}, Source{}, 0, + ast::IntLiteralExpression::Suffix::kU); + return GenerateLiteralIfNeeded(override, &l); + }, + [&](const sem::I32*) { + ast::IntLiteralExpression l(ProgramID{}, Source{}, 0, + ast::IntLiteralExpression::Suffix::kI); + return GenerateLiteralIfNeeded(override, &l); + }, + [&](const sem::Bool*) { + ast::BoolLiteralExpression l(ProgramID{}, Source{}, false); + return GenerateLiteralIfNeeded(override, &l); + }, + [&](Default) { + error_ = "invalid type for pipeline constant ID, must be scalar"; return 0; - } + }); + if (init_id == 0) { + return 0; } + } + + if (v->IsAnyOf()) { push_debug(spv::Op::OpName, - {Operand(init_id), Operand(builder_.Symbols().NameFor(var->symbol))}); + {Operand(init_id), Operand(builder_.Symbols().NameFor(v->symbol))}); RegisterVariable(sem, init_id); return true; @@ -824,12 +818,11 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) { return false; } - push_debug(spv::Op::OpName, - {Operand(var_id), Operand(builder_.Symbols().NameFor(var->symbol))}); + push_debug(spv::Op::OpName, {Operand(var_id), Operand(builder_.Symbols().NameFor(v->symbol))}); OperandList ops = {Operand(type_id), result, U32Operand(ConvertStorageClass(sc))}; - if (var->constructor) { + if (v->constructor) { ops.push_back(Operand(init_id)); } else { auto* st = type->As(); @@ -871,7 +864,7 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) { push_type(spv::Op::OpVariable, std::move(ops)); - for (auto* attr : var->attributes) { + for (auto* attr : v->attributes) { bool ok = Switch( attr, [&](const ast::BuiltinAttribute* builtin) { @@ -1332,7 +1325,7 @@ uint32_t Builder::GenerateTypeConstructorOrConversion(const sem::Call* call, // Generate the zero initializer if there are no values provided. if (args.empty()) { - if (global_var && global_var->IsOverridable()) { + if (global_var && global_var->Declaration()->Is()) { auto constant_id = global_var->ConstantId(); if (result_type->Is()) { return GenerateConstantIfNeeded(ScalarConstant::I32(0).AsSpecOp(constant_id)); @@ -1637,7 +1630,7 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var, ScalarConstant constant; auto* global = builder_.Sem().Get(var); - if (global && global->IsOverridable()) { + if (global && global->Declaration()->Is()) { constant.is_spec_op = true; constant.constant_id = global->ConstantId(); } diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc index 35119ac3a4..7df4fdf33c 100644 --- a/src/tint/writer/wgsl/generator_impl.cc +++ b/src/tint/writer/wgsl/generator_impl.cc @@ -635,46 +635,60 @@ bool GeneratorImpl::EmitStructType(const ast::Struct* str) { return true; } -bool GeneratorImpl::EmitVariable(std::ostream& out, const ast::Variable* var) { - if (!var->attributes.empty()) { - if (!EmitAttributes(out, var->attributes)) { +bool GeneratorImpl::EmitVariable(std::ostream& out, const ast::Variable* v) { + if (!v->attributes.empty()) { + if (!EmitAttributes(out, v->attributes)) { return false; } out << " "; } - if (var->is_overridable) { - out << "override"; - } else if (var->is_const) { - out << "let"; - } else { - out << "var"; - auto sc = var->declared_storage_class; - auto ac = var->declared_access; - if (sc != ast::StorageClass::kNone || ac != ast::Access::kUndefined) { - out << "<" << sc; - if (ac != ast::Access::kUndefined) { - out << ", "; - if (!EmitAccess(out, ac)) { - return false; + bool ok = Switch( + v, // + [&](const ast::Let* ) { + out << "let"; + return true; + }, + [&](const ast::Override* ) { + out << "override"; + return true; + }, + [&](const ast::Var* var) { + out << "var"; + auto sc = var->declared_storage_class; + auto ac = var->declared_access; + if (sc != ast::StorageClass::kNone || ac != ast::Access::kUndefined) { + out << "<" << sc; + if (ac != ast::Access::kUndefined) { + out << ", "; + if (!EmitAccess(out, ac)) { + return false; + } } + out << ">"; } - out << ">"; - } + return true; + }, + [&](Default) { + TINT_ICE(Writer, diagnostics_) << "unhandled variable type " << v->TypeInfo().name; + return false; + }); + if (!ok) { + return false; } - out << " " << program_->Symbols().NameFor(var->symbol); + out << " " << program_->Symbols().NameFor(v->symbol); - if (auto* ty = var->type) { + if (auto* ty = v->type) { out << " : "; if (!EmitType(out, ty)) { return false; } } - if (var->constructor != nullptr) { + if (v->constructor != nullptr) { out << " = "; - if (!EmitExpression(out, var->constructor)) { + if (!EmitExpression(out, v->constructor)) { return false; } } diff --git a/test/tint/bug/tint/827.wgsl.expected.hlsl b/test/tint/bug/tint/827.wgsl.expected.hlsl index 89de07bccb..936812e743 100644 --- a/test/tint/bug/tint/827.wgsl.expected.hlsl +++ b/test/tint/bug/tint/827.wgsl.expected.hlsl @@ -1,4 +1,5 @@ static const uint width = 128u; + Texture2D tex : register(t0, space0); RWByteAddressBuffer result : register(u1, space0); diff --git a/test/tint/bug/tint/914.wgsl.expected.hlsl b/test/tint/bug/tint/914.wgsl.expected.hlsl index 0baa6c69b6..5652d5cec1 100644 --- a/test/tint/bug/tint/914.wgsl.expected.hlsl +++ b/test/tint/bug/tint/914.wgsl.expected.hlsl @@ -45,6 +45,7 @@ static const uint ColPerThread = 4u; static const uint TileAOuter = 64u; static const uint TileBOuter = 64u; static const uint TileInner = 64u; + groupshared float mm_Asub[64][64]; groupshared float mm_Bsub[64][64];