tint/ast: Derive off `ast::Variable`

Add the new classes:
* `ast::Let`
* `ast::Override`
* `ast::Parameter`
* `ast::Var`

Limit the fields to those that are only applicable for their type.

Note: The resolver and validator is a tangled mess for each of the
variable types. This CL tries to keep the functionality exactly the
same. I'll clean this up in another change.

Bug: tint:1582
Change-Id: Iee83324167ffd4d92ae3032b2134677629c90079
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/93780
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2022-06-17 12:48:51 +00:00 committed by Dawn LUCI CQ
parent be6fb2a8d2
commit dcdf66ed5b
66 changed files with 1652 additions and 1193 deletions

View File

@ -264,6 +264,8 @@ libtint_source_set("libtint_core_all_src") {
"ast/interpolate_attribute.h", "ast/interpolate_attribute.h",
"ast/invariant_attribute.cc", "ast/invariant_attribute.cc",
"ast/invariant_attribute.h", "ast/invariant_attribute.h",
"ast/let.cc",
"ast/let.h",
"ast/literal_expression.cc", "ast/literal_expression.cc",
"ast/literal_expression.h", "ast/literal_expression.h",
"ast/location_attribute.cc", "ast/location_attribute.cc",
@ -280,6 +282,10 @@ libtint_source_set("libtint_core_all_src") {
"ast/multisampled_texture.h", "ast/multisampled_texture.h",
"ast/node.cc", "ast/node.cc",
"ast/node.h", "ast/node.h",
"ast/override.cc",
"ast/override.h",
"ast/parameter.cc",
"ast/parameter.h",
"ast/phony_expression.cc", "ast/phony_expression.cc",
"ast/phony_expression.h", "ast/phony_expression.h",
"ast/pipeline_stage.cc", "ast/pipeline_stage.cc",
@ -328,6 +334,8 @@ libtint_source_set("libtint_core_all_src") {
"ast/unary_op.h", "ast/unary_op.h",
"ast/unary_op_expression.cc", "ast/unary_op_expression.cc",
"ast/unary_op_expression.h", "ast/unary_op_expression.h",
"ast/var.cc",
"ast/var.h",
"ast/variable.cc", "ast/variable.cc",
"ast/variable.h", "ast/variable.h",
"ast/variable_decl_statement.cc", "ast/variable_decl_statement.cc",

View File

@ -151,6 +151,8 @@ set(TINT_LIB_SRCS
ast/interpolate_attribute.h ast/interpolate_attribute.h
ast/invariant_attribute.cc ast/invariant_attribute.cc
ast/invariant_attribute.h ast/invariant_attribute.h
ast/let.cc
ast/let.h
ast/literal_expression.cc ast/literal_expression.cc
ast/literal_expression.h ast/literal_expression.h
ast/location_attribute.cc ast/location_attribute.cc
@ -167,6 +169,10 @@ set(TINT_LIB_SRCS
ast/multisampled_texture.h ast/multisampled_texture.h
ast/node.cc ast/node.cc
ast/node.h ast/node.h
ast/override.cc
ast/override.h
ast/parameter.cc
ast/parameter.h
ast/phony_expression.cc ast/phony_expression.cc
ast/phony_expression.h ast/phony_expression.h
ast/pipeline_stage.cc ast/pipeline_stage.cc
@ -215,6 +221,8 @@ set(TINT_LIB_SRCS
ast/unary_op_expression.h ast/unary_op_expression.h
ast/unary_op.cc ast/unary_op.cc
ast/unary_op.h ast/unary_op.h
ast/var.cc
ast/var.h
ast/variable_decl_statement.cc ast/variable_decl_statement.cc
ast/variable_decl_statement.h ast/variable_decl_statement.h
ast/variable.cc ast/variable.cc

View File

@ -40,7 +40,7 @@ Function::Function(ProgramID pid,
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, symbol, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, symbol, program_id);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id);
for (auto* param : params) { for (auto* param : params) {
TINT_ASSERT(AST, param && param->is_const); TINT_ASSERT(AST, tint::Is<Parameter>(param));
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, param, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, param, program_id);
} }
TINT_ASSERT(AST, symbol.IsValid()); TINT_ASSERT(AST, symbol.IsValid());

View File

@ -26,14 +26,11 @@
#include "src/tint/ast/builtin_attribute.h" #include "src/tint/ast/builtin_attribute.h"
#include "src/tint/ast/group_attribute.h" #include "src/tint/ast/group_attribute.h"
#include "src/tint/ast/location_attribute.h" #include "src/tint/ast/location_attribute.h"
#include "src/tint/ast/parameter.h"
#include "src/tint/ast/pipeline_stage.h" #include "src/tint/ast/pipeline_stage.h"
#include "src/tint/ast/variable.h"
namespace tint::ast { namespace tint::ast {
/// ParameterList is a list of function parameters
using ParameterList = std::vector<const Variable*>;
/// A Function statement. /// A Function statement.
class Function final : public Castable<Function, Node> { class Function final : public Castable<Function, Node> {
public: public:

View File

@ -122,18 +122,6 @@ TEST_F(FunctionTest, Assert_DifferentProgramID_ReturnAttr) {
"internal compiler error"); "internal compiler error");
} }
TEST_F(FunctionTest, Assert_NonConstParam) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
ParameterList params;
params.push_back(b.Var("var", b.ty.i32(), ast::StorageClass::kNone));
b.Func("f", params, b.ty.void_(), {});
},
"internal compiler error");
}
using FunctionListTest = TestHelper; using FunctionListTest = TestHelper;
TEST_F(FunctionListTest, FindSymbol) { TEST_F(FunctionListTest, FindSymbol) {

46
src/tint/ast/let.cc Normal file
View File

@ -0,0 +1,46 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ast/let.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::ast::Let);
namespace tint::ast {
Let::Let(ProgramID pid,
const Source& src,
const Symbol& sym,
const ast::Type* ty,
const Expression* ctor,
AttributeList attrs)
: Base(pid, src, sym, ty, ctor, attrs) {
TINT_ASSERT(AST, ctor != nullptr);
}
Let::Let(Let&&) = default;
Let::~Let() = default;
const Let* Let::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source);
auto sym = ctx->Clone(symbol);
auto* ty = ctx->Clone(type);
auto* ctor = ctx->Clone(constructor);
auto attrs = ctx->Clone(attributes);
return ctx->dst->create<Let>(src, sym, ty, ctor, attrs);
}
} // namespace tint::ast

61
src/tint/ast/let.h Normal file
View File

@ -0,0 +1,61 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_AST_LET_H_
#define SRC_TINT_AST_LET_H_
#include "src/tint/ast/variable.h"
namespace tint::ast {
/// A "let" declaration is a name for a function-scoped runtime typed value.
///
/// Examples:
///
/// ```
/// let twice_depth : i32 = width + width; // Must have initializer
/// ```
/// @see https://www.w3.org/TR/WGSL/#let-decls
class Let final : public Castable<Let, Variable> {
public:
/// Create a 'let' variable
/// @param program_id the identifier of the program that owns this node
/// @param source the variable source
/// @param sym the variable symbol
/// @param type the declared variable type
/// @param constructor the constructor expression
/// @param attributes the variable attributes
Let(ProgramID program_id,
const Source& source,
const Symbol& sym,
const ast::Type* type,
const Expression* constructor,
AttributeList attributes);
/// Move constructor
Let(Let&&);
/// Destructor
~Let() override;
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const Let* Clone(CloneContext* ctx) const override;
};
} // namespace tint::ast
#endif // SRC_TINT_AST_LET_H_

View File

@ -77,6 +77,19 @@ class Module final : public Castable<Module, Node> {
/// @returns the global variables for the module /// @returns the global variables for the module
VariableList& GlobalVariables() { return global_variables_; } VariableList& GlobalVariables() { return global_variables_; }
/// @returns the global variable declarations of kind 'T' for the module
template <typename T, typename = traits::EnableIfIsType<T, ast::Variable>>
std::vector<const T*> Globals() const {
std::vector<const T*> out;
out.reserve(global_variables_.size());
for (auto* global : global_variables_) {
if (auto* var = global->As<T>()) {
out.emplace_back(var);
}
}
return out;
}
/// @returns the extension set for the module /// @returns the extension set for the module
const EnableList& Enables() const { return enables_; } const EnableList& Enables() const { return enables_; }

44
src/tint/ast/override.cc Normal file
View File

@ -0,0 +1,44 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ast/override.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::ast::Override);
namespace tint::ast {
Override::Override(ProgramID pid,
const Source& src,
const Symbol& sym,
const ast::Type* ty,
const Expression* ctor,
AttributeList attrs)
: Base(pid, src, sym, ty, ctor, attrs) {}
Override::Override(Override&&) = default;
Override::~Override() = default;
const Override* Override::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source);
auto sym = ctx->Clone(symbol);
auto* ty = ctx->Clone(type);
auto* ctor = ctx->Clone(constructor);
auto attrs = ctx->Clone(attributes);
return ctx->dst->create<Override>(src, sym, ty, ctor, attrs);
}
} // namespace tint::ast

62
src/tint/ast/override.h Normal file
View File

@ -0,0 +1,62 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_AST_OVERRIDE_H_
#define SRC_TINT_AST_OVERRIDE_H_
#include "src/tint/ast/variable.h"
namespace tint::ast {
/// An "override" declaration - a name for a pipeline-overridable constant.
/// Examples:
///
/// ```
/// override radius : i32 = 2; // Can be overridden by name.
/// @id(5) override width : i32 = 2; // Can be overridden by ID.
/// override scale : f32; // No default - must be overridden.
/// ```
/// @see https://www.w3.org/TR/WGSL/#override-decls
class Override final : public Castable<Override, Variable> {
public:
/// Create an 'override' pipeline-overridable constant.
/// @param program_id the identifier of the program that owns this node
/// @param source the variable source
/// @param sym the variable symbol
/// @param type the declared variable type
/// @param constructor the constructor expression
/// @param attributes the variable attributes
Override(ProgramID program_id,
const Source& source,
const Symbol& sym,
const ast::Type* type,
const Expression* constructor,
AttributeList attributes);
/// Move constructor
Override(Override&&);
/// Destructor
~Override() override;
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const Override* Clone(CloneContext* ctx) const override;
};
} // namespace tint::ast
#endif // SRC_TINT_AST_OVERRIDE_H_

42
src/tint/ast/parameter.cc Normal file
View File

@ -0,0 +1,42 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ast/parameter.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::ast::Parameter);
namespace tint::ast {
Parameter::Parameter(ProgramID pid,
const Source& src,
const Symbol& sym,
const ast::Type* ty,
AttributeList attrs)
: Base(pid, src, sym, ty, nullptr, attrs) {}
Parameter::Parameter(Parameter&&) = default;
Parameter::~Parameter() = default;
const Parameter* Parameter::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source);
auto sym = ctx->Clone(symbol);
auto* ty = ctx->Clone(type);
auto attrs = ctx->Clone(attributes);
return ctx->dst->create<Parameter>(src, sym, ty, attrs);
}
} // namespace tint::ast

66
src/tint/ast/parameter.h Normal file
View File

@ -0,0 +1,66 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_AST_PARAMETER_H_
#define SRC_TINT_AST_PARAMETER_H_
#include <vector>
#include "src/tint/ast/variable.h"
namespace tint::ast {
/// A formal parameter to a function - a name for a typed value to be passed into a function.
/// Example:
///
/// ```
/// fn twice(a: i32) -> i32 { // "a:i32" is the formal parameter
/// return a + a;
/// }
/// ```
///
/// @see https://www.w3.org/TR/WGSL/#creation-time-consts
class Parameter final : public Castable<Parameter, Variable> {
public:
/// Create a 'parameter' creation-time value variable.
/// @param program_id the identifier of the program that owns this node
/// @param source the variable source
/// @param sym the variable symbol
/// @param type the declared variable type
/// @param attributes the variable attributes
Parameter(ProgramID program_id,
const Source& source,
const Symbol& sym,
const ast::Type* type,
AttributeList attributes);
/// Move constructor
Parameter(Parameter&&);
/// Destructor
~Parameter() override;
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const Parameter* Clone(CloneContext* ctx) const override;
};
/// A list of parameters
using ParameterList = std::vector<const Parameter*>;
} // namespace tint::ast
#endif // SRC_TINT_AST_PARAMETER_H_

49
src/tint/ast/var.cc Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ast/var.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::ast::Var);
namespace tint::ast {
Var::Var(ProgramID pid,
const Source& src,
const Symbol& sym,
const ast::Type* ty,
StorageClass storage_class,
Access access,
const Expression* ctor,
AttributeList attrs)
: Base(pid, src, sym, ty, ctor, attrs),
declared_storage_class(storage_class),
declared_access(access) {}
Var::Var(Var&&) = default;
Var::~Var() = default;
const Var* Var::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source);
auto sym = ctx->Clone(symbol);
auto* ty = ctx->Clone(type);
auto* ctor = ctx->Clone(constructor);
auto attrs = ctx->Clone(attributes);
return ctx->dst->create<Var>(src, sym, ty, declared_storage_class, declared_access, ctor,
attrs);
}
} // namespace tint::ast

86
src/tint/ast/var.h Normal file
View File

@ -0,0 +1,86 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_AST_VAR_H_
#define SRC_TINT_AST_VAR_H_
#include <utility>
#include <vector>
#include "src/tint/ast/variable.h"
namespace tint::ast {
/// A "var" declaration is a name for typed storage.
///
/// Examples:
///
/// ```
/// // Declared outside a function, i.e. at module scope, requires
/// // a storage class.
/// var<workgroup> width : i32; // no initializer
/// var<private> height : i32 = 3; // with initializer
///
/// // A variable declared inside a function doesn't take a storage class,
/// // and maps to SPIR-V Function storage.
/// var computed_depth : i32;
/// var area : i32 = compute_area(width, height);
/// ```
///
/// @see https://www.w3.org/TR/WGSL/#var-decls
class Var final : public Castable<Var, Variable> {
public:
/// Create a 'var' variable
/// @param program_id the identifier of the program that owns this node
/// @param source the variable source
/// @param sym the variable symbol
/// @param type the declared variable type
/// @param declared_storage_class the declared storage class
/// @param declared_access the declared access control
/// @param constructor the constructor expression
/// @param attributes the variable attributes
Var(ProgramID program_id,
const Source& source,
const Symbol& sym,
const ast::Type* type,
StorageClass declared_storage_class,
Access declared_access,
const Expression* constructor,
AttributeList attributes);
/// Move constructor
Var(Var&&);
/// Destructor
~Var() override;
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const Var* Clone(CloneContext* ctx) const override;
/// The declared storage class
const StorageClass declared_storage_class;
/// The declared access control
const Access declared_access;
};
/// A list of `var` declarations
using VarList = std::vector<const Var*>;
} // namespace tint::ast
#endif // SRC_TINT_AST_VAR_H_

View File

@ -13,9 +13,8 @@
// limitations under the License. // limitations under the License.
#include "src/tint/ast/variable.h" #include "src/tint/ast/variable.h"
#include "src/tint/ast/binding_attribute.h"
#include "src/tint/program_builder.h" #include "src/tint/ast/group_attribute.h"
#include "src/tint/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::ast::Variable); TINT_INSTANTIATE_TYPEINFO(tint::ast::Variable);
@ -24,24 +23,11 @@ namespace tint::ast {
Variable::Variable(ProgramID pid, Variable::Variable(ProgramID pid,
const Source& src, const Source& src,
const Symbol& sym, const Symbol& sym,
StorageClass dsc,
Access da,
const ast::Type* ty, const ast::Type* ty,
bool constant,
bool overridable,
const Expression* ctor, const Expression* ctor,
AttributeList attrs) AttributeList attrs)
: Base(pid, src), : Base(pid, src), symbol(sym), type(ty), constructor(ctor), attributes(std::move(attrs)) {
symbol(sym),
type(ty),
is_const(constant),
is_overridable(overridable),
constructor(ctor),
attributes(std::move(attrs)),
declared_storage_class(dsc),
declared_access(da) {
TINT_ASSERT(AST, symbol.IsValid()); TINT_ASSERT(AST, symbol.IsValid());
TINT_ASSERT(AST, is_overridable ? is_const : true);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, symbol, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, symbol, program_id);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, constructor, program_id); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, constructor, program_id);
} }
@ -54,23 +40,12 @@ VariableBindingPoint Variable::BindingPoint() const {
const GroupAttribute* group = nullptr; const GroupAttribute* group = nullptr;
const BindingAttribute* binding = nullptr; const BindingAttribute* binding = nullptr;
for (auto* attr : attributes) { for (auto* attr : attributes) {
if (auto* g = attr->As<GroupAttribute>()) { Switch(
group = g; attr, //
} else if (auto* b = attr->As<BindingAttribute>()) { [&](const GroupAttribute* a) { group = a; },
binding = b; [&](const BindingAttribute* a) { binding = a; });
}
} }
return VariableBindingPoint{group, binding}; return VariableBindingPoint{group, binding};
} }
const Variable* Variable::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source);
auto sym = ctx->Clone(symbol);
auto* ty = ctx->Clone(type);
auto* ctor = ctx->Clone(constructor);
auto attrs = ctx->Clone(attributes);
return ctx->dst->create<Variable>(src, sym, declared_storage_class, declared_access, ty,
is_const, is_overridable, ctor, attrs);
}
} // namespace tint::ast } // namespace tint::ast

View File

@ -45,112 +45,38 @@ struct VariableBindingPoint {
inline operator bool() const { return group && binding; } inline operator bool() const { return group && binding; }
}; };
/// A Variable statement. /// Variable is the base class for Var, Let, Const, Override and Parameter.
/// ///
/// An instance of this class represents one of four constructs in WGSL: "var" /// An instance of this class represents one of five constructs in WGSL: "var" declaration, "let"
/// declaration, "let" declaration, "override" declaration, or formal parameter /// declaration, "override" declaration, "const" declaration, or formal parameter to a function.
/// to a function.
/// ///
/// 1. A "var" declaration is a name for typed storage. Examples: /// @see https://www.w3.org/TR/WGSL/#value-decls
/// class Variable : public Castable<Variable, Node> {
/// // Declared outside a function, i.e. at module scope, requires
/// // a storage class.
/// var<workgroup> width : i32; // no initializer
/// var<private> height : i32 = 3; // with initializer
///
/// // A variable declared inside a function doesn't take a storage class,
/// // and maps to SPIR-V Function storage.
/// var computed_depth : i32;
/// var area : i32 = compute_area(width, height);
///
/// 2. A "let" declaration is a name for a typed value. Examples:
///
/// let twice_depth : i32 = width + width; // Must have initializer
///
/// 3. An "override" declaration is a name for a pipeline-overridable constant.
/// Examples:
///
/// override radius : i32 = 2; // Can be overridden by name.
/// @id(5) override width : i32 = 2; // Can be overridden by ID.
/// override scale : f32; // No default - must be overridden.
///
/// 4. A formal parameter to a function is a name for a typed value to
/// be passed into a function. Example:
///
/// fn twice(a: i32) -> i32 { // "a:i32" is the formal parameter
/// return a + a;
/// }
///
/// From the WGSL draft, about "var"::
///
/// A variable is a named reference to storage that can contain a value of a
/// particular type.
///
/// Two types are associated with a variable: its store type (the type of
/// value that may be placed in the referenced storage) and its reference
/// type (the type of the variable itself). If a variable has store type T
/// and storage class S, then its reference type is pointer-to-T-in-S.
///
/// This class uses the term "type" to refer to:
/// the value type of a "let",
/// the value type of an "override",
/// the value type of the formal parameter,
/// or the store type of the "var".
//
/// Setting is_const:
/// - "var" gets false
/// - "let" gets true
/// - "override" gets true
/// - formal parameter gets true
///
/// Setting is_overrideable:
/// - "var" gets false
/// - "let" gets false
/// - "override" gets true
/// - formal parameter gets false
///
/// Setting storage class:
/// - "var" is StorageClass::kNone when using the
/// defaulting syntax for a "var" declared inside a function.
/// - "let" is always StorageClass::kNone.
/// - formal parameter is always StorageClass::kNone.
class Variable final : public Castable<Variable, Node> {
public: public:
/// Create a variable /// Constructor
/// @param program_id the identifier of the program that owns this node /// @param program_id the identifier of the program that owns this node
/// @param source the variable source /// @param source the variable source
/// @param sym the variable symbol /// @param sym the variable symbol
/// @param declared_storage_class the declared storage class
/// @param declared_access the declared access control
/// @param type the declared variable type /// @param type the declared variable type
/// @param is_const true if the variable is const
/// @param is_overridable true if the variable is pipeline-overridable
/// @param constructor the constructor expression /// @param constructor the constructor expression
/// @param attributes the variable attributes /// @param attributes the variable attributes
Variable(ProgramID program_id, Variable(ProgramID program_id,
const Source& source, const Source& source,
const Symbol& sym, const Symbol& sym,
StorageClass declared_storage_class,
Access declared_access,
const ast::Type* type, const ast::Type* type,
bool is_const,
bool is_overridable,
const Expression* constructor, const Expression* constructor,
AttributeList attributes); AttributeList attributes);
/// Move constructor /// Move constructor
Variable(Variable&&); Variable(Variable&&);
/// Destructor
~Variable() override; ~Variable() override;
/// @returns the binding point information for the variable /// @returns the binding point information from the variable's attributes.
/// @note binding points should only be applied to Var and Parameter types.
VariableBindingPoint BindingPoint() const; VariableBindingPoint BindingPoint() const;
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const Variable* Clone(CloneContext* ctx) const override;
/// The variable symbol /// The variable symbol
const Symbol symbol; const Symbol symbol;
@ -159,23 +85,11 @@ class Variable final : public Castable<Variable, Node> {
/// var i = 1; /// var i = 1;
const ast::Type* const type; const ast::Type* const type;
/// True if this is a constant, false otherwise
const bool is_const;
/// True if this is a pipeline-overridable constant, false otherwise
const bool is_overridable;
/// The constructor expression or nullptr if none set /// The constructor expression or nullptr if none set
const Expression* const constructor; const Expression* const constructor;
/// The attributes attached to this variable /// The attributes attached to this variable
const AttributeList attributes; const AttributeList attributes;
/// The declared storage class
const StorageClass declared_storage_class;
/// The declared access control
const Access declared_access;
}; };
/// A list of variables /// A list of variables

View File

@ -183,7 +183,7 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
auto name = program_->Symbols().NameFor(decl->symbol); auto name = program_->Symbols().NameFor(decl->symbol);
auto* global = var->As<sem::GlobalVariable>(); auto* global = var->As<sem::GlobalVariable>();
if (global && global->IsOverridable()) { if (global && global->Declaration()->Is<ast::Override>()) {
OverridableConstant overridable_constant; OverridableConstant overridable_constant;
overridable_constant.name = name; overridable_constant.name = name;
overridable_constant.numeric_id = global->ConstantId(); overridable_constant.numeric_id = global->ConstantId();
@ -219,7 +219,7 @@ std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
std::map<uint32_t, Scalar> result; std::map<uint32_t, Scalar> result;
for (auto* var : program_->AST().GlobalVariables()) { for (auto* var : program_->AST().GlobalVariables()) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(var); auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
if (!global || !global->IsOverridable()) { if (!global || !global->Declaration()->Is<ast::Override>()) {
continue; continue;
} }
@ -276,7 +276,7 @@ std::map<std::string, uint32_t> Inspector::GetConstantNameToIdMap() {
std::map<std::string, uint32_t> result; std::map<std::string, uint32_t> result;
for (auto* var : program_->AST().GlobalVariables()) { for (auto* var : program_->AST().GlobalVariables()) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(var); auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
if (global && global->IsOverridable()) { if (global && global->Declaration()->Is<ast::Override>()) {
auto name = program_->Symbols().NameFor(var->symbol); auto name = program_->Symbols().NameFor(var->symbol);
result[name] = global->ConstantId(); result[name] = global->ConstantId();
} }
@ -813,25 +813,24 @@ void Inspector::GenerateSamplerTargets() {
auto* t = c->args[texture_index]; auto* t = c->args[texture_index];
auto* s = c->args[sampler_index]; auto* s = c->args[sampler_index];
GetOriginatingResources(std::array<const ast::Expression*, 2>{t, s}, GetOriginatingResources(
[&](std::array<const sem::GlobalVariable*, 2> globals) { std::array<const ast::Expression*, 2>{t, s},
auto* texture = globals[0]; [&](std::array<const sem::GlobalVariable*, 2> globals) {
sem::BindingPoint texture_binding_point = { auto* texture = globals[0]->Declaration()->As<ast::Var>();
texture->Declaration()->BindingPoint().group->value, sem::BindingPoint texture_binding_point = {texture->BindingPoint().group->value,
texture->Declaration()->BindingPoint().binding->value}; texture->BindingPoint().binding->value};
auto* sampler = globals[1]; auto* sampler = globals[1]->Declaration()->As<ast::Var>();
sem::BindingPoint sampler_binding_point = { sem::BindingPoint sampler_binding_point = {sampler->BindingPoint().group->value,
sampler->Declaration()->BindingPoint().group->value, sampler->BindingPoint().binding->value};
sampler->Declaration()->BindingPoint().binding->value};
for (auto* entry_point : entry_points) { for (auto* entry_point : entry_points) {
const auto& ep_name = program_->Symbols().NameFor( const auto& ep_name =
entry_point->Declaration()->symbol); program_->Symbols().NameFor(entry_point->Declaration()->symbol);
(*sampler_targets_)[ep_name].add( (*sampler_targets_)[ep_name].add(
{sampler_binding_point, texture_binding_point}); {sampler_binding_point, texture_binding_point});
} }
}); });
} }
} }

View File

@ -53,11 +53,14 @@
#include "src/tint/ast/index_accessor_expression.h" #include "src/tint/ast/index_accessor_expression.h"
#include "src/tint/ast/interpolate_attribute.h" #include "src/tint/ast/interpolate_attribute.h"
#include "src/tint/ast/invariant_attribute.h" #include "src/tint/ast/invariant_attribute.h"
#include "src/tint/ast/let.h"
#include "src/tint/ast/loop_statement.h" #include "src/tint/ast/loop_statement.h"
#include "src/tint/ast/matrix.h" #include "src/tint/ast/matrix.h"
#include "src/tint/ast/member_accessor_expression.h" #include "src/tint/ast/member_accessor_expression.h"
#include "src/tint/ast/module.h" #include "src/tint/ast/module.h"
#include "src/tint/ast/multisampled_texture.h" #include "src/tint/ast/multisampled_texture.h"
#include "src/tint/ast/override.h"
#include "src/tint/ast/parameter.h"
#include "src/tint/ast/phony_expression.h" #include "src/tint/ast/phony_expression.h"
#include "src/tint/ast/pointer.h" #include "src/tint/ast/pointer.h"
#include "src/tint/ast/return_statement.h" #include "src/tint/ast/return_statement.h"
@ -73,6 +76,7 @@
#include "src/tint/ast/type_name.h" #include "src/tint/ast/type_name.h"
#include "src/tint/ast/u32.h" #include "src/tint/ast/u32.h"
#include "src/tint/ast/unary_op_expression.h" #include "src/tint/ast/unary_op_expression.h"
#include "src/tint/ast/var.h"
#include "src/tint/ast/variable_decl_statement.h" #include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/ast/vector.h" #include "src/tint/ast/vector.h"
#include "src/tint/ast/void.h" #include "src/tint/ast/void.h"
@ -1328,14 +1332,13 @@ class ProgramBuilder {
/// * ast::AttributeList - specifies the variable's attributes /// * ast::AttributeList - specifies the variable's attributes
/// Note that repeated arguments of the same type will use the last argument's /// Note that repeated arguments of the same type will use the last argument's
/// value. /// value.
/// @returns a `ast::Variable` with the given name, type and additional /// @returns a `ast::Var` with the given name, type and additional
/// options /// options
template <typename NAME, typename... OPTIONAL> template <typename NAME, typename... OPTIONAL>
const ast::Variable* Var(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) { const ast::Var* Var(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) {
VarOptionals opts(std::forward<OPTIONAL>(optional)...); VarOptionals opts(std::forward<OPTIONAL>(optional)...);
return create<ast::Variable>(Sym(std::forward<NAME>(name)), opts.storage, opts.access, type, return create<ast::Var>(Sym(std::forward<NAME>(name)), type, opts.storage, opts.access,
false /* is_const */, false /* is_overridable */, opts.constructor, std::move(opts.attributes));
opts.constructor, std::move(opts.attributes));
} }
/// @param source the variable source /// @param source the variable source
@ -1349,32 +1352,28 @@ class ProgramBuilder {
/// * ast::AttributeList - specifies the variable's attributes /// * ast::AttributeList - specifies the variable's attributes
/// Note that repeated arguments of the same type will use the last argument's /// Note that repeated arguments of the same type will use the last argument's
/// value. /// value.
/// @returns a `ast::Variable` with the given name, storage and type /// @returns a `ast::Var` with the given name, storage and type
template <typename NAME, typename... OPTIONAL> template <typename NAME, typename... OPTIONAL>
const ast::Variable* Var(const Source& source, const ast::Var* Var(const Source& source,
NAME&& name, NAME&& name,
const ast::Type* type, const ast::Type* type,
OPTIONAL&&... optional) { OPTIONAL&&... optional) {
VarOptionals opts(std::forward<OPTIONAL>(optional)...); VarOptionals opts(std::forward<OPTIONAL>(optional)...);
return create<ast::Variable>(source, Sym(std::forward<NAME>(name)), opts.storage, return create<ast::Var>(source, Sym(std::forward<NAME>(name)), type, opts.storage,
opts.access, type, false /* is_const */, opts.access, opts.constructor, std::move(opts.attributes));
false /* is_overridable */, opts.constructor,
std::move(opts.attributes));
} }
/// @param name the variable name /// @param name the variable name
/// @param type the variable type /// @param type the variable type
/// @param constructor constructor expression /// @param constructor constructor expression
/// @param attributes optional variable attributes /// @param attributes optional variable attributes
/// @returns an immutable `ast::Variable` with the given name and type /// @returns an `ast::Let` with the given name and type
template <typename NAME> template <typename NAME>
const ast::Variable* Let(NAME&& name, const ast::Let* Let(NAME&& name,
const ast::Type* type, const ast::Type* type,
const ast::Expression* constructor, const ast::Expression* constructor,
ast::AttributeList attributes = {}) { ast::AttributeList attributes = {}) {
return create<ast::Variable>(Sym(std::forward<NAME>(name)), ast::StorageClass::kNone, return create<ast::Let>(Sym(std::forward<NAME>(name)), type, constructor, attributes);
ast::Access::kUndefined, type, true /* is_const */,
false /* is_overridable */, constructor, attributes);
} }
/// @param source the variable source /// @param source the variable source
@ -1382,46 +1381,39 @@ class ProgramBuilder {
/// @param type the variable type /// @param type the variable type
/// @param constructor constructor expression /// @param constructor constructor expression
/// @param attributes optional variable attributes /// @param attributes optional variable attributes
/// @returns an immutable `ast::Variable` with the given name and type /// @returns an `ast::Let` with the given name and type
template <typename NAME> template <typename NAME>
const ast::Variable* Let(const Source& source, const ast::Let* Let(const Source& source,
NAME&& name, NAME&& name,
const ast::Type* type, const ast::Type* type,
const ast::Expression* constructor, const ast::Expression* constructor,
ast::AttributeList attributes = {}) { ast::AttributeList attributes = {}) {
return create<ast::Variable>(source, Sym(std::forward<NAME>(name)), return create<ast::Let>(source, Sym(std::forward<NAME>(name)), type, constructor,
ast::StorageClass::kNone, ast::Access::kUndefined, type, attributes);
true /* is_const */, false /* is_overridable */, constructor,
attributes);
} }
/// @param name the parameter name /// @param name the parameter name
/// @param type the parameter type /// @param type the parameter type
/// @param attributes optional parameter attributes /// @param attributes optional parameter attributes
/// @returns an immutable `ast::Variable` with the given name and type /// @returns an `ast::Parameter` with the given name and type
template <typename NAME> template <typename NAME>
const ast::Variable* Param(NAME&& name, const ast::Parameter* Param(NAME&& name,
const ast::Type* type, const ast::Type* type,
ast::AttributeList attributes = {}) { ast::AttributeList attributes = {}) {
return create<ast::Variable>(Sym(std::forward<NAME>(name)), ast::StorageClass::kNone, return create<ast::Parameter>(Sym(std::forward<NAME>(name)), type, attributes);
ast::Access::kUndefined, type, true /* is_const */,
false /* is_overridable */, nullptr, attributes);
} }
/// @param source the parameter source /// @param source the parameter source
/// @param name the parameter name /// @param name the parameter name
/// @param type the parameter type /// @param type the parameter type
/// @param attributes optional parameter attributes /// @param attributes optional parameter attributes
/// @returns an immutable `ast::Variable` with the given name and type /// @returns an `ast::Parameter` with the given name and type
template <typename NAME> template <typename NAME>
const ast::Variable* Param(const Source& source, const ast::Parameter* Param(const Source& source,
NAME&& name, NAME&& name,
const ast::Type* type, const ast::Type* type,
ast::AttributeList attributes = {}) { ast::AttributeList attributes = {}) {
return create<ast::Variable>(source, Sym(std::forward<NAME>(name)), return create<ast::Parameter>(source, Sym(std::forward<NAME>(name)), type, attributes);
ast::StorageClass::kNone, ast::Access::kUndefined, type,
true /* is_const */, false /* is_overridable */, nullptr,
attributes);
} }
/// @param name the variable name /// @param name the variable name
@ -1434,10 +1426,10 @@ class ProgramBuilder {
/// * ast::AttributeList - specifies the variable's attributes /// * ast::AttributeList - specifies the variable's attributes
/// Note that repeated arguments of the same type will use the last argument's /// Note that repeated arguments of the same type will use the last argument's
/// value. /// value.
/// @returns a new `ast::Variable`, which is automatically registered as a /// @returns a new `ast::Var`, which is automatically registered as a global variable with the
/// global variable with the ast::Module. /// ast::Module.
template <typename NAME, typename... OPTIONAL, typename = DisableIfSource<NAME>> template <typename NAME, typename... OPTIONAL, typename = DisableIfSource<NAME>>
const ast::Variable* Global(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) { const ast::Var* Global(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) {
auto* var = Var(std::forward<NAME>(name), type, std::forward<OPTIONAL>(optional)...); auto* var = Var(std::forward<NAME>(name), type, std::forward<OPTIONAL>(optional)...);
AST().AddGlobalVariable(var); AST().AddGlobalVariable(var);
return var; return var;
@ -1454,13 +1446,13 @@ class ProgramBuilder {
/// * ast::AttributeList - specifies the variable's attributes /// * ast::AttributeList - specifies the variable's attributes
/// Note that repeated arguments of the same type will use the last argument's /// Note that repeated arguments of the same type will use the last argument's
/// value. /// value.
/// @returns a new `ast::Variable`, which is automatically registered as a /// @returns a new `ast::Var`, which is automatically registered as a global variable with the
/// global variable with the ast::Module. /// ast::Module.
template <typename NAME, typename... OPTIONAL> template <typename NAME, typename... OPTIONAL>
const ast::Variable* Global(const Source& source, const ast::Var* Global(const Source& source,
NAME&& name, NAME&& name,
const ast::Type* type, const ast::Type* type,
OPTIONAL&&... optional) { OPTIONAL&&... optional) {
auto* var = auto* var =
Var(source, std::forward<NAME>(name), type, std::forward<OPTIONAL>(optional)...); Var(source, std::forward<NAME>(name), type, std::forward<OPTIONAL>(optional)...);
AST().AddGlobalVariable(var); AST().AddGlobalVariable(var);
@ -1471,14 +1463,13 @@ class ProgramBuilder {
/// @param type the variable type /// @param type the variable type
/// @param constructor constructor expression /// @param constructor constructor expression
/// @param attributes optional variable attributes /// @param attributes optional variable attributes
/// @returns a const `ast::Variable` constructed by calling Var() with the /// @returns an `ast::Let` constructed by calling Let() with the arguments of `args`, which is
/// arguments of `args`, which is automatically registered as a global /// automatically registered as a global variable with the ast::Module.
/// variable with the ast::Module.
template <typename NAME> template <typename NAME>
const ast::Variable* GlobalConst(NAME&& name, const ast::Let* GlobalConst(NAME&& name,
const ast::Type* type, const ast::Type* type,
const ast::Expression* constructor, const ast::Expression* constructor,
ast::AttributeList attributes = {}) { ast::AttributeList attributes = {}) {
auto* var = Let(std::forward<NAME>(name), type, constructor, std::move(attributes)); auto* var = Let(std::forward<NAME>(name), type, constructor, std::move(attributes));
AST().AddGlobalVariable(var); AST().AddGlobalVariable(var);
return var; return var;
@ -1489,15 +1480,15 @@ class ProgramBuilder {
/// @param type the variable type /// @param type the variable type
/// @param constructor constructor expression /// @param constructor constructor expression
/// @param attributes optional variable attributes /// @param attributes optional variable attributes
/// @returns a const `ast::Variable` constructed by calling Var() with the /// @returns a const `ast::Let` constructed by calling Var() with the
/// arguments of `args`, which is automatically registered as a global /// arguments of `args`, which is automatically registered as a global
/// variable with the ast::Module. /// variable with the ast::Module.
template <typename NAME> template <typename NAME>
const ast::Variable* GlobalConst(const Source& source, const ast::Let* GlobalConst(const Source& source,
NAME&& name, NAME&& name,
const ast::Type* type, const ast::Type* type,
const ast::Expression* constructor, const ast::Expression* constructor,
ast::AttributeList attributes = {}) { ast::AttributeList attributes = {}) {
auto* var = Let(source, std::forward<NAME>(name), type, constructor, std::move(attributes)); auto* var = Let(source, std::forward<NAME>(name), type, constructor, std::move(attributes));
AST().AddGlobalVariable(var); AST().AddGlobalVariable(var);
return var; return var;
@ -1507,17 +1498,15 @@ class ProgramBuilder {
/// @param type the variable type /// @param type the variable type
/// @param constructor optional constructor expression /// @param constructor optional constructor expression
/// @param attributes optional variable attributes /// @param attributes optional variable attributes
/// @returns an overridable const `ast::Variable` which is automatically /// @returns an `ast::Override` which is automatically registered as a global variable with the
/// registered as a global variable with the ast::Module. /// ast::Module.
template <typename NAME> template <typename NAME>
const ast::Variable* Override(NAME&& name, const ast::Override* Override(NAME&& name,
const ast::Type* type, const ast::Type* type,
const ast::Expression* constructor, const ast::Expression* constructor,
ast::AttributeList attributes = {}) { ast::AttributeList attributes = {}) {
auto* var = auto* var = create<ast::Override>(source_, Sym(std::forward<NAME>(name)), type, constructor,
create<ast::Variable>(source_, Sym(std::forward<NAME>(name)), ast::StorageClass::kNone, std::move(attributes));
ast::Access::kUndefined, type, true /* is_const */,
true /* is_overridable */, constructor, std::move(attributes));
AST().AddGlobalVariable(var); AST().AddGlobalVariable(var);
return var; return var;
} }
@ -1527,19 +1516,16 @@ class ProgramBuilder {
/// @param type the variable type /// @param type the variable type
/// @param constructor constructor expression /// @param constructor constructor expression
/// @param attributes optional variable attributes /// @param attributes optional variable attributes
/// @returns a const `ast::Variable` constructed by calling Var() with the /// @returns an `ast::Override` constructed with the arguments of `args`, which is automatically
/// arguments of `args`, which is automatically registered as a global /// registered as a global variable with the ast::Module.
/// variable with the ast::Module.
template <typename NAME> template <typename NAME>
const ast::Variable* Override(const Source& source, const ast::Override* Override(const Source& source,
NAME&& name, NAME&& name,
const ast::Type* type, const ast::Type* type,
const ast::Expression* constructor, const ast::Expression* constructor,
ast::AttributeList attributes = {}) { ast::AttributeList attributes = {}) {
auto* var = auto* var = create<ast::Override>(source, Sym(std::forward<NAME>(name)), type, constructor,
create<ast::Variable>(source, Sym(std::forward<NAME>(name)), ast::StorageClass::kNone, std::move(attributes));
ast::Access::kUndefined, type, true /* is_const */,
true /* is_overridable */, constructor, std::move(attributes));
AST().AddGlobalVariable(var); AST().AddGlobalVariable(var);
return var; return var;
} }

View File

@ -1253,12 +1253,12 @@ bool FunctionEmitter::EmitEntryPointAsWrapper() {
auto* sample_mask_array_type = store_type->UnwrapRef()->UnwrapAlias()->As<Array>(); auto* sample_mask_array_type = store_type->UnwrapRef()->UnwrapAlias()->As<Array>();
TINT_ASSERT(Reader, sample_mask_array_type); TINT_ASSERT(Reader, sample_mask_array_type);
ok = EmitPipelineInput(var_name, store_type, &param_decos, {0}, ok = EmitPipelineInput(var_name, store_type, &param_decos, {0},
sample_mask_array_type->type, forced_param_type, &(decl.params), sample_mask_array_type->type, forced_param_type, &decl.params,
&stmts); &stmts);
} else { } else {
// The normal path. // The normal path.
ok = EmitPipelineInput(var_name, store_type, &param_decos, {}, store_type, ok = EmitPipelineInput(var_name, store_type, &param_decos, {}, store_type,
forced_param_type, &(decl.params), &stmts); forced_param_type, &decl.params, &stmts);
} }
if (!ok) { if (!ok) {
return false; return false;
@ -1404,8 +1404,7 @@ bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) {
auto* type = parser_impl_.ConvertType(param->type_id()); auto* type = parser_impl_.ConvertType(param->type_id());
if (type != nullptr) { if (type != nullptr) {
auto* ast_param = auto* ast_param =
parser_impl_.MakeVariable(param->result_id(), ast::StorageClass::kNone, type, true, parser_impl_.MakeParameter(param->result_id(), type, ast::AttributeList{});
false, nullptr, ast::AttributeList{});
// Parameters are treated as const declarations. // Parameters are treated as const declarations.
ast_params.emplace_back(ast_param); ast_params.emplace_back(ast_param);
// The value is accessible by name. // The value is accessible by name.
@ -2468,9 +2467,8 @@ bool FunctionEmitter::EmitFunctionVariables() {
return false; return false;
} }
} }
auto* var = auto* var = parser_impl_.MakeVar(inst.result_id(), ast::StorageClass::kNone, var_store_type,
parser_impl_.MakeVariable(inst.result_id(), ast::StorageClass::kNone, var_store_type, constructor, ast::AttributeList{});
false, false, constructor, ast::AttributeList{});
auto* var_decl_stmt = create<ast::VariableDeclStatement>(Source{}, var); auto* var_decl_stmt = create<ast::VariableDeclStatement>(Source{}, var);
AddStatement(var_decl_stmt); AddStatement(var_decl_stmt);
auto* var_type = ty_.Reference(var_store_type, ast::StorageClass::kNone); auto* var_type = ty_.Reference(var_store_type, ast::StorageClass::kNone);
@ -3328,8 +3326,8 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
TINT_ASSERT(Reader, def_inst); TINT_ASSERT(Reader, def_inst);
auto* storage_type = RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id); auto* storage_type = RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id);
AddStatement(create<ast::VariableDeclStatement>( AddStatement(create<ast::VariableDeclStatement>(
Source{}, parser_impl_.MakeVariable(id, ast::StorageClass::kNone, storage_type, false, Source{}, parser_impl_.MakeVar(id, ast::StorageClass::kNone, storage_type, nullptr,
false, nullptr, ast::AttributeList{}))); ast::AttributeList{})));
auto* type = ty_.Reference(storage_type, ast::StorageClass::kNone); auto* type = ty_.Reference(storage_type, ast::StorageClass::kNone);
identifier_types_.emplace(id, type); identifier_types_.emplace(id, type);
} }
@ -3396,13 +3394,11 @@ bool FunctionEmitter::EmitConstDefinition(const spvtools::opt::Instruction& inst
} }
expr = AddressOfIfNeeded(expr, &inst); expr = AddressOfIfNeeded(expr, &inst);
auto* ast_const = auto* let = parser_impl_.MakeLet(inst.result_id(), expr.type, expr.expr);
parser_impl_.MakeVariable(inst.result_id(), ast::StorageClass::kNone, expr.type, true, if (!let) {
false, expr.expr, ast::AttributeList{});
if (!ast_const) {
return false; return false;
} }
AddStatement(create<ast::VariableDeclStatement>(Source{}, ast_const)); AddStatement(create<ast::VariableDeclStatement>(Source{}, let));
identifier_types_.emplace(inst.result_id(), expr.type); identifier_types_.emplace(inst.result_id(), expr.type);
return success(); return success();
} }

View File

@ -1371,8 +1371,8 @@ bool ParserImpl::EmitScalarSpecConstants() {
break; break;
} }
} }
auto* ast_var = MakeVariable(inst.result_id(), ast::StorageClass::kNone, ast_type, true, auto* ast_var =
true, ast_expr, std::move(spec_id_decos)); MakeOverride(inst.result_id(), ast_type, ast_expr, std::move(spec_id_decos));
if (ast_var) { if (ast_var) {
builder_.AST().AddGlobalVariable(ast_var); builder_.AST().AddGlobalVariable(ast_var);
scalar_spec_constants_.insert(inst.result_id()); scalar_spec_constants_.insert(inst.result_id());
@ -1489,8 +1489,8 @@ bool ParserImpl::EmitModuleScopeVariables() {
// here.) // here.)
ast_constructor = MakeConstantExpression(var.GetSingleWordInOperand(1)).expr; ast_constructor = MakeConstantExpression(var.GetSingleWordInOperand(1)).expr;
} }
auto* ast_var = MakeVariable(var.result_id(), ast_storage_class, ast_store_type, false, auto* ast_var = MakeVar(var.result_id(), ast_storage_class, ast_store_type, ast_constructor,
false, ast_constructor, ast::AttributeList{}); ast::AttributeList{});
// TODO(dneto): initializers (a.k.a. constructor expression) // TODO(dneto): initializers (a.k.a. constructor expression)
if (ast_var) { if (ast_var) {
builder_.AST().AddGlobalVariable(ast_var); builder_.AST().AddGlobalVariable(ast_var);
@ -1521,10 +1521,9 @@ bool ParserImpl::EmitModuleScopeVariables() {
} }
} }
auto* ast_var = auto* ast_var =
MakeVariable(builtin_position_.per_vertex_var_id, MakeVar(builtin_position_.per_vertex_var_id,
enum_converter_.ToStorageClass(builtin_position_.storage_class), enum_converter_.ToStorageClass(builtin_position_.storage_class),
ConvertType(builtin_position_.position_member_type_id), false, false, ConvertType(builtin_position_.position_member_type_id), ast_constructor, {});
ast_constructor, {});
builder_.AST().AddGlobalVariable(ast_var); builder_.AST().AddGlobalVariable(ast_var);
} }
@ -1554,13 +1553,11 @@ const spvtools::opt::analysis::IntConstant* ParserImpl::GetArraySize(uint32_t va
return size->AsIntConstant(); return size->AsIntConstant();
} }
ast::Variable* ParserImpl::MakeVariable(uint32_t id, ast::Var* ParserImpl::MakeVar(uint32_t id,
ast::StorageClass sc, ast::StorageClass sc,
const Type* storage_type, const Type* storage_type,
bool is_const, const ast::Expression* constructor,
bool is_overridable, ast::AttributeList decorations) {
const ast::Expression* constructor,
ast::AttributeList decorations) {
if (storage_type == nullptr) { if (storage_type == nullptr) {
Fail() << "internal error: can't make ast::Variable for null type"; Fail() << "internal error: can't make ast::Variable for null type";
return nullptr; return nullptr;
@ -1588,15 +1585,37 @@ ast::Variable* ParserImpl::MakeVariable(uint32_t id,
return nullptr; return nullptr;
} }
std::string name = namer_.Name(id); auto sym = builder_.Symbols().Register(namer_.Name(id));
return create<ast::Var>(Source{}, sym, storage_type->Build(builder_), sc, access, constructor,
decorations);
}
// Note: we're constructing the variable here with the *storage* type, ast::Let* ParserImpl::MakeLet(uint32_t id, const Type* type, const ast::Expression* constructor) {
// regardless of whether this is a `let`, `override`, or `var` declaration. auto sym = builder_.Symbols().Register(namer_.Name(id));
// `var` declarations will have a resolved type of ref<storage>, but at the return create<ast::Let>(Source{}, sym, type->Build(builder_), constructor,
// AST level all three are declared with the same type. ast::AttributeList{});
return create<ast::Variable>(Source{}, builder_.Symbols().Register(name), sc, access, }
storage_type->Build(builder_), is_const, is_overridable,
constructor, decorations); ast::Override* ParserImpl::MakeOverride(uint32_t id,
const Type* type,
const ast::Expression* constructor,
ast::AttributeList decorations) {
if (!ConvertDecorationsForVariable(id, &type, &decorations, false)) {
return nullptr;
}
auto sym = builder_.Symbols().Register(namer_.Name(id));
return create<ast::Override>(Source{}, sym, type->Build(builder_), constructor, decorations);
}
ast::Parameter* ParserImpl::MakeParameter(uint32_t id,
const Type* type,
ast::AttributeList decorations) {
if (!ConvertDecorationsForVariable(id, &type, &decorations, false)) {
return nullptr;
}
auto sym = builder_.Symbols().Register(namer_.Name(id));
return create<ast::Parameter>(Source{}, sym, type->Build(builder_), decorations);
} }
bool ParserImpl::ConvertDecorationsForVariable(uint32_t id, bool ParserImpl::ConvertDecorationsForVariable(uint32_t id,

View File

@ -411,25 +411,47 @@ class ParserImpl : Reader {
/// @returns a list of SPIR-V decorations. /// @returns a list of SPIR-V decorations.
DecorationList GetMemberPipelineDecorations(const Struct& struct_type, int member_index); DecorationList GetMemberPipelineDecorations(const Struct& struct_type, int member_index);
/// Creates an AST Variable node for a SPIR-V ID, including any attached /// Creates an AST 'var' node for a SPIR-V ID, including any attached decorations, unless it's
/// decorations, unless it's an ignorable builtin variable. /// an ignorable builtin variable.
/// @param id the SPIR-V result ID /// @param id the SPIR-V result ID
/// @param sc the storage class, which cannot be ast::StorageClass::kNone /// @param sc the storage class, which cannot be ast::StorageClass::kNone
/// @param storage_type the storage type of the variable /// @param storage_type the storage type of the variable
/// @param is_const if true, the variable is const
/// @param is_overridable if true, the variable is pipeline-overridable
/// @param constructor the variable constructor /// @param constructor the variable constructor
/// @param decorations the variable decorations /// @param decorations the variable decorations
/// @returns a new Variable node, or null in the ignorable variable case and /// @returns a new Variable node, or null in the ignorable variable case and
/// in the error case /// in the error case
ast::Variable* MakeVariable(uint32_t id, ast::Var* MakeVar(uint32_t id,
ast::StorageClass sc, ast::StorageClass sc,
const Type* storage_type, const Type* storage_type,
bool is_const, const ast::Expression* constructor,
bool is_overridable, ast::AttributeList decorations);
/// Creates an AST 'let' node for a SPIR-V ID, including any attached decorations,.
/// @param id the SPIR-V result ID
/// @param type the type of the variable
/// @param constructor the variable constructor
/// @returns the AST 'let' node
ast::Let* MakeLet(uint32_t id, const Type* type, const ast::Expression* constructor);
/// Creates an AST 'override' node for a SPIR-V ID, including any attached decorations.
/// @param id the SPIR-V result ID
/// @param type the type of the variable
/// @param constructor the variable constructor
/// @param decorations the variable decorations
/// @returns the AST 'override' node
ast::Override* MakeOverride(uint32_t id,
const Type* type,
const ast::Expression* constructor, const ast::Expression* constructor,
ast::AttributeList decorations); ast::AttributeList decorations);
/// Creates an AST parameter node for a SPIR-V ID, including any attached decorations, unless
/// it's an ignorable builtin variable.
/// @param id the SPIR-V result ID
/// @param type the type of the parameter
/// @param decorations the parameter decorations
/// @returns the AST parameter node
ast::Parameter* MakeParameter(uint32_t id, const Type* type, ast::AttributeList decorations);
/// Returns true if a constant expression can be generated. /// Returns true if a constant expression can be generated.
/// @param id the SPIR-V ID of the value /// @param id the SPIR-V ID of the value
/// @returns true if a constant expression can be generated /// @returns true if a constant expression can be generated

View File

@ -213,7 +213,11 @@ ParserImpl::FunctionHeader::FunctionHeader(Source src,
ast::ParameterList p, ast::ParameterList p,
const ast::Type* ret_ty, const ast::Type* ret_ty,
ast::AttributeList ret_attrs) ast::AttributeList ret_attrs)
: source(src), name(n), params(p), return_type(ret_ty), return_type_attributes(ret_attrs) {} : source(src),
name(n),
params(std::move(p)),
return_type(ret_ty),
return_type_attributes(std::move(ret_attrs)) {}
ParserImpl::FunctionHeader::~FunctionHeader() = default; ParserImpl::FunctionHeader::~FunctionHeader() = default;
@ -542,15 +546,13 @@ Maybe<const ast::Variable*> ParserImpl::global_variable_decl(ast::AttributeList&
constructor = expr.value; constructor = expr.value;
} }
return create<ast::Variable>(decl->source, // source return create<ast::Var>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol builder_.Symbols().Register(decl->name), // symbol
decl->storage_class, // storage class decl->type, // type
decl->access, // access control decl->storage_class, // storage class
decl->type, // type decl->access, // access control
false, // is_const constructor, // constructor
false, // is_overridable std::move(attrs)); // attributes
constructor, // constructor
std::move(attrs)); // attributes
} }
// global_constant_decl : // global_constant_decl :
@ -564,7 +566,7 @@ Maybe<const ast::Variable*> ParserImpl::global_constant_decl(ast::AttributeList&
if (match(Token::Type::kLet)) { if (match(Token::Type::kLet)) {
use = "'let' declaration"; use = "'let' declaration";
} else if (match(Token::Type::kOverride)) { } else if (match(Token::Type::kOverride)) {
use = "override declaration"; use = "'override' declaration";
is_overridable = true; is_overridable = true;
} else { } else {
return Failure::kNoMatch; return Failure::kNoMatch;
@ -594,15 +596,18 @@ Maybe<const ast::Variable*> ParserImpl::global_constant_decl(ast::AttributeList&
initializer = std::move(init.value); initializer = std::move(init.value);
} }
return create<ast::Variable>(decl->source, // source if (is_overridable) {
builder_.Symbols().Register(decl->name), // symbol return create<ast::Override>(decl->source, // source
ast::StorageClass::kNone, // storage class builder_.Symbols().Register(decl->name), // symbol
ast::Access::kUndefined, // access control decl->type, // type
decl->type, // type initializer, // constructor
true, // is_const std::move(attrs)); // attributes
is_overridable, // is_overridable }
initializer, // constructor return create<ast::Let>(decl->source, // source
std::move(attrs)); // attributes builder_.Symbols().Register(decl->name), // symbol
decl->type, // type
initializer, // constructor
std::move(attrs)); // attributes
} }
// variable_decl // variable_decl
@ -1478,7 +1483,7 @@ Expect<ast::ParameterList> ParserImpl::expect_param_list() {
// param // param
// : attribute_list* variable_ident_decl // : attribute_list* variable_ident_decl
Expect<ast::Variable*> ParserImpl::expect_param() { Expect<ast::Parameter*> ParserImpl::expect_param() {
auto attrs = attribute_list(); auto attrs = attribute_list();
auto decl = expect_variable_ident_decl("parameter"); auto decl = expect_variable_ident_decl("parameter");
@ -1486,21 +1491,10 @@ Expect<ast::Variable*> ParserImpl::expect_param() {
return Failure::kErrored; return Failure::kErrored;
} }
auto* var = create<ast::Variable>(decl->source, // source return create<ast::Parameter>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol builder_.Symbols().Register(decl->name), // symbol
ast::StorageClass::kNone, // storage class decl->type, // type
ast::Access::kUndefined, // access control std::move(attrs.value)); // attributes
decl->type, // type
true, // is_const
false, // is_overridable
nullptr, // constructor
std::move(attrs.value)); // attributes
// Formal parameters are treated like a const declaration where the
// initializer value is provided by the call's argument. The key point is
// that it's not updatable after initially set. This is unlike C or GLSL
// which treat formal parameters like local variables that can be updated.
return var;
} }
// pipeline_stage // pipeline_stage
@ -1794,17 +1788,13 @@ Maybe<const ast::VariableDeclStatement*> ParserImpl::variable_stmt() {
return add_error(peek(), "missing constructor for 'let' declaration"); return add_error(peek(), "missing constructor for 'let' declaration");
} }
auto* var = create<ast::Variable>(decl->source, // source auto* let = create<ast::Let>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol builder_.Symbols().Register(decl->name), // symbol
ast::StorageClass::kNone, // storage class decl->type, // type
ast::Access::kUndefined, // access control constructor.value, // constructor
decl->type, // type ast::AttributeList{}); // attributes
true, // is_const
false, // is_overridable
constructor.value, // constructor
ast::AttributeList{}); // attributes
return create<ast::VariableDeclStatement>(decl->source, var); return create<ast::VariableDeclStatement>(decl->source, let);
} }
auto decl = variable_decl(/*allow_inferred = */ true); auto decl = variable_decl(/*allow_inferred = */ true);
@ -1828,15 +1818,13 @@ Maybe<const ast::VariableDeclStatement*> ParserImpl::variable_stmt() {
constructor = constructor_expr.value; constructor = constructor_expr.value;
} }
auto* var = create<ast::Variable>(decl->source, // source auto* var = create<ast::Var>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol builder_.Symbols().Register(decl->name), // symbol
decl->storage_class, // storage class decl->type, // type
decl->access, // access control decl->storage_class, // storage class
decl->type, // type decl->access, // access control
false, // is_const constructor, // constructor
false, // is_overridable ast::AttributeList{}); // attributes
constructor, // constructor
ast::AttributeList{}); // attributes
return create<ast::VariableDeclStatement>(var->source, var); return create<ast::VariableDeclStatement>(var->source, var);
} }

View File

@ -462,7 +462,7 @@ class ParserImpl {
Expect<ast::ParameterList> expect_param_list(); Expect<ast::ParameterList> expect_param_list();
/// Parses a `param` grammar element, erroring on parse failure. /// Parses a `param` grammar element, erroring on parse failure.
/// @returns the parsed variable /// @returns the parsed variable
Expect<ast::Variable*> expect_param(); Expect<ast::Parameter*> expect_param();
/// Parses a `pipeline_stage` grammar element, erroring if the next token does /// Parses a `pipeline_stage` grammar element, erroring if the next token does
/// not match a stage name. /// not match a stage name.
/// @returns the pipeline stage. /// @returns the pipeline stage.

View File

@ -57,7 +57,7 @@ TEST_F(ForStmtTest, InitializerStatementDecl) {
ASSERT_TRUE(fl.matched); ASSERT_TRUE(fl.matched);
ASSERT_TRUE(Is<ast::VariableDeclStatement>(fl->initializer)); ASSERT_TRUE(Is<ast::VariableDeclStatement>(fl->initializer));
auto* var = fl->initializer->As<ast::VariableDeclStatement>()->variable; auto* var = fl->initializer->As<ast::VariableDeclStatement>()->variable;
EXPECT_FALSE(var->is_const); EXPECT_TRUE(var->Is<ast::Var>());
EXPECT_EQ(var->constructor, nullptr); EXPECT_EQ(var->constructor, nullptr);
EXPECT_EQ(fl->condition, nullptr); EXPECT_EQ(fl->condition, nullptr);
EXPECT_EQ(fl->continuing, nullptr); EXPECT_EQ(fl->continuing, nullptr);
@ -74,7 +74,7 @@ TEST_F(ForStmtTest, InitializerStatementDeclEqual) {
ASSERT_TRUE(fl.matched); ASSERT_TRUE(fl.matched);
ASSERT_TRUE(Is<ast::VariableDeclStatement>(fl->initializer)); ASSERT_TRUE(Is<ast::VariableDeclStatement>(fl->initializer));
auto* var = fl->initializer->As<ast::VariableDeclStatement>()->variable; auto* var = fl->initializer->As<ast::VariableDeclStatement>()->variable;
EXPECT_FALSE(var->is_const); EXPECT_TRUE(var->Is<ast::Var>());
EXPECT_NE(var->constructor, nullptr); EXPECT_NE(var->constructor, nullptr);
EXPECT_EQ(fl->condition, nullptr); EXPECT_EQ(fl->condition, nullptr);
EXPECT_EQ(fl->continuing, nullptr); EXPECT_EQ(fl->continuing, nullptr);
@ -90,7 +90,7 @@ TEST_F(ForStmtTest, InitializerStatementConstDecl) {
ASSERT_TRUE(fl.matched); ASSERT_TRUE(fl.matched);
ASSERT_TRUE(Is<ast::VariableDeclStatement>(fl->initializer)); ASSERT_TRUE(Is<ast::VariableDeclStatement>(fl->initializer));
auto* var = fl->initializer->As<ast::VariableDeclStatement>()->variable; auto* var = fl->initializer->As<ast::VariableDeclStatement>()->variable;
EXPECT_TRUE(var->is_const); EXPECT_TRUE(var->Is<ast::Let>());
EXPECT_NE(var->constructor, nullptr); EXPECT_NE(var->constructor, nullptr);
EXPECT_EQ(fl->condition, nullptr); EXPECT_EQ(fl->condition, nullptr);
EXPECT_EQ(fl->continuing, nullptr); EXPECT_EQ(fl->continuing, nullptr);

View File

@ -27,21 +27,20 @@ TEST_F(ParserImplTest, GlobalConstantDecl) {
EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr); auto* let = e.value->As<ast::Let>();
ASSERT_NE(let, nullptr);
EXPECT_TRUE(e->is_const); EXPECT_EQ(let->symbol, p->builder().Symbols().Get("a"));
EXPECT_FALSE(e->is_overridable); ASSERT_NE(let->type, nullptr);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); EXPECT_TRUE(let->type->Is<ast::F32>());
ASSERT_NE(e->type, nullptr);
EXPECT_TRUE(e->type->Is<ast::F32>());
EXPECT_EQ(e->source.range.begin.line, 1u); EXPECT_EQ(let->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 5u); EXPECT_EQ(let->source.range.begin.column, 5u);
EXPECT_EQ(e->source.range.end.line, 1u); EXPECT_EQ(let->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 6u); EXPECT_EQ(let->source.range.end.column, 6u);
ASSERT_NE(e->constructor, nullptr); ASSERT_NE(let->constructor, nullptr);
EXPECT_TRUE(e->constructor->Is<ast::LiteralExpression>()); EXPECT_TRUE(let->constructor->Is<ast::LiteralExpression>());
} }
TEST_F(ParserImplTest, GlobalConstantDecl_Inferred) { TEST_F(ParserImplTest, GlobalConstantDecl_Inferred) {
@ -53,20 +52,19 @@ TEST_F(ParserImplTest, GlobalConstantDecl_Inferred) {
EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr); auto* let = e.value->As<ast::Let>();
ASSERT_NE(let, nullptr);
EXPECT_TRUE(e->is_const); EXPECT_EQ(let->symbol, p->builder().Symbols().Get("a"));
EXPECT_FALSE(e->is_overridable); EXPECT_EQ(let->type, nullptr);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a"));
EXPECT_EQ(e->type, nullptr);
EXPECT_EQ(e->source.range.begin.line, 1u); EXPECT_EQ(let->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 5u); EXPECT_EQ(let->source.range.begin.column, 5u);
EXPECT_EQ(e->source.range.end.line, 1u); EXPECT_EQ(let->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 6u); EXPECT_EQ(let->source.range.end.column, 6u);
ASSERT_NE(e->constructor, nullptr); ASSERT_NE(let->constructor, nullptr);
EXPECT_TRUE(e->constructor->Is<ast::LiteralExpression>()); EXPECT_TRUE(let->constructor->Is<ast::LiteralExpression>());
} }
TEST_F(ParserImplTest, GlobalConstantDecl_InvalidExpression) { TEST_F(ParserImplTest, GlobalConstantDecl_InvalidExpression) {
@ -105,23 +103,22 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_WithId) {
EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr); auto* override = e.value->As<ast::Override>();
ASSERT_NE(override, nullptr);
EXPECT_TRUE(e->is_const); EXPECT_EQ(override->symbol, p->builder().Symbols().Get("a"));
EXPECT_TRUE(e->is_overridable); ASSERT_NE(override->type, nullptr);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); EXPECT_TRUE(override->type->Is<ast::F32>());
ASSERT_NE(e->type, nullptr);
EXPECT_TRUE(e->type->Is<ast::F32>());
EXPECT_EQ(e->source.range.begin.line, 1u); EXPECT_EQ(override->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 17u); EXPECT_EQ(override->source.range.begin.column, 17u);
EXPECT_EQ(e->source.range.end.line, 1u); EXPECT_EQ(override->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 18u); EXPECT_EQ(override->source.range.end.column, 18u);
ASSERT_NE(e->constructor, nullptr); ASSERT_NE(override->constructor, nullptr);
EXPECT_TRUE(e->constructor->Is<ast::LiteralExpression>()); EXPECT_TRUE(override->constructor->Is<ast::LiteralExpression>());
auto* override_attr = ast::GetAttribute<ast::IdAttribute>(e.value->attributes); auto* override_attr = ast::GetAttribute<ast::IdAttribute>(override->attributes);
ASSERT_NE(override_attr, nullptr); ASSERT_NE(override_attr, nullptr);
EXPECT_EQ(override_attr->value, 7u); EXPECT_EQ(override_attr->value, 7u);
} }
@ -136,23 +133,22 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_WithoutId) {
EXPECT_FALSE(p->has_error()) << p->error(); EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr); auto* override = e.value->As<ast::Override>();
ASSERT_NE(override, nullptr);
EXPECT_TRUE(e->is_const); EXPECT_EQ(override->symbol, p->builder().Symbols().Get("a"));
EXPECT_TRUE(e->is_overridable); ASSERT_NE(override->type, nullptr);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); EXPECT_TRUE(override->type->Is<ast::F32>());
ASSERT_NE(e->type, nullptr);
EXPECT_TRUE(e->type->Is<ast::F32>());
EXPECT_EQ(e->source.range.begin.line, 1u); EXPECT_EQ(override->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 10u); EXPECT_EQ(override->source.range.begin.column, 10u);
EXPECT_EQ(e->source.range.end.line, 1u); EXPECT_EQ(override->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 11u); EXPECT_EQ(override->source.range.end.column, 11u);
ASSERT_NE(e->constructor, nullptr); ASSERT_NE(override->constructor, nullptr);
EXPECT_TRUE(e->constructor->Is<ast::LiteralExpression>()); EXPECT_TRUE(override->constructor->Is<ast::LiteralExpression>());
auto* id_attr = ast::GetAttribute<ast::IdAttribute>(e.value->attributes); auto* id_attr = ast::GetAttribute<ast::IdAttribute>(override->attributes);
ASSERT_EQ(id_attr, nullptr); ASSERT_EQ(id_attr, nullptr);
} }
@ -165,7 +161,8 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_MissingId) {
auto e = p->global_constant_decl(attrs.value); auto e = p->global_constant_decl(attrs.value);
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr); auto* override = e.value->As<ast::Override>();
ASSERT_NE(override, nullptr);
EXPECT_TRUE(p->has_error()); EXPECT_TRUE(p->has_error());
EXPECT_EQ(p->error(), "1:5: expected signed integer literal for id attribute"); EXPECT_EQ(p->error(), "1:5: expected signed integer literal for id attribute");
@ -180,7 +177,8 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_InvalidId) {
auto e = p->global_constant_decl(attrs.value); auto e = p->global_constant_decl(attrs.value);
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr); auto* override = e.value->As<ast::Override>();
ASSERT_NE(override, nullptr);
EXPECT_TRUE(p->has_error()); EXPECT_TRUE(p->has_error());
EXPECT_EQ(p->error(), "1:5: id attribute must be positive"); EXPECT_EQ(p->error(), "1:5: id attribute must be positive");

View File

@ -26,18 +26,19 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithoutConstructor) {
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr); auto* var = e.value->As<ast::Var>();
ASSERT_NE(var, nullptr);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(var->symbol, p->builder().Symbols().Get("a"));
EXPECT_TRUE(e->type->Is<ast::F32>()); EXPECT_TRUE(var->type->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class, ast::StorageClass::kPrivate); EXPECT_EQ(var->declared_storage_class, ast::StorageClass::kPrivate);
EXPECT_EQ(e->source.range.begin.line, 1u); EXPECT_EQ(var->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 14u); EXPECT_EQ(var->source.range.begin.column, 14u);
EXPECT_EQ(e->source.range.end.line, 1u); EXPECT_EQ(var->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 15u); EXPECT_EQ(var->source.range.end.column, 15u);
ASSERT_EQ(e->constructor, nullptr); ASSERT_EQ(var->constructor, nullptr);
} }
TEST_F(ParserImplTest, GlobalVariableDecl_WithConstructor) { TEST_F(ParserImplTest, GlobalVariableDecl_WithConstructor) {
@ -49,19 +50,20 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithConstructor) {
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr); auto* var = e.value->As<ast::Var>();
ASSERT_NE(var, nullptr);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(var->symbol, p->builder().Symbols().Get("a"));
EXPECT_TRUE(e->type->Is<ast::F32>()); EXPECT_TRUE(var->type->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class, ast::StorageClass::kPrivate); EXPECT_EQ(var->declared_storage_class, ast::StorageClass::kPrivate);
EXPECT_EQ(e->source.range.begin.line, 1u); EXPECT_EQ(var->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 14u); EXPECT_EQ(var->source.range.begin.column, 14u);
EXPECT_EQ(e->source.range.end.line, 1u); EXPECT_EQ(var->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 15u); EXPECT_EQ(var->source.range.end.column, 15u);
ASSERT_NE(e->constructor, nullptr); ASSERT_NE(var->constructor, nullptr);
ASSERT_TRUE(e->constructor->Is<ast::FloatLiteralExpression>()); ASSERT_TRUE(var->constructor->Is<ast::FloatLiteralExpression>());
} }
TEST_F(ParserImplTest, GlobalVariableDecl_WithAttribute) { TEST_F(ParserImplTest, GlobalVariableDecl_WithAttribute) {
@ -73,21 +75,22 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithAttribute) {
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr); auto* var = e.value->As<ast::Var>();
ASSERT_NE(var, nullptr);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(var->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->type, nullptr); ASSERT_NE(var->type, nullptr);
EXPECT_TRUE(e->type->Is<ast::F32>()); EXPECT_TRUE(var->type->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class, ast::StorageClass::kUniform); EXPECT_EQ(var->declared_storage_class, ast::StorageClass::kUniform);
EXPECT_EQ(e->source.range.begin.line, 1u); EXPECT_EQ(var->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 36u); EXPECT_EQ(var->source.range.begin.column, 36u);
EXPECT_EQ(e->source.range.end.line, 1u); EXPECT_EQ(var->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 37u); EXPECT_EQ(var->source.range.end.column, 37u);
ASSERT_EQ(e->constructor, nullptr); ASSERT_EQ(var->constructor, nullptr);
auto& attributes = e->attributes; auto& attributes = var->attributes;
ASSERT_EQ(attributes.size(), 2u); ASSERT_EQ(attributes.size(), 2u);
ASSERT_TRUE(attributes[0]->Is<ast::BindingAttribute>()); ASSERT_TRUE(attributes[0]->Is<ast::BindingAttribute>());
ASSERT_TRUE(attributes[1]->Is<ast::GroupAttribute>()); ASSERT_TRUE(attributes[1]->Is<ast::GroupAttribute>());
@ -103,21 +106,22 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithAttribute_MulitpleGroups) {
ASSERT_FALSE(p->has_error()) << p->error(); ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched); EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored); EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr); auto* var = e.value->As<ast::Var>();
ASSERT_NE(var, nullptr);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(var->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->type, nullptr); ASSERT_NE(var->type, nullptr);
EXPECT_TRUE(e->type->Is<ast::F32>()); EXPECT_TRUE(var->type->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class, ast::StorageClass::kUniform); EXPECT_EQ(var->declared_storage_class, ast::StorageClass::kUniform);
EXPECT_EQ(e->source.range.begin.line, 1u); EXPECT_EQ(var->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 36u); EXPECT_EQ(var->source.range.begin.column, 36u);
EXPECT_EQ(e->source.range.end.line, 1u); EXPECT_EQ(var->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 37u); EXPECT_EQ(var->source.range.end.column, 37u);
ASSERT_EQ(e->constructor, nullptr); ASSERT_EQ(var->constructor, nullptr);
auto& attributes = e->attributes; auto& attributes = var->attributes;
ASSERT_EQ(attributes.size(), 2u); ASSERT_EQ(attributes.size(), 2u);
ASSERT_TRUE(attributes[0]->Is<ast::BindingAttribute>()); ASSERT_TRUE(attributes[0]->Is<ast::BindingAttribute>());
ASSERT_TRUE(attributes[1]->Is<ast::GroupAttribute>()); ASSERT_TRUE(attributes[1]->Is<ast::GroupAttribute>());

View File

@ -27,7 +27,7 @@ TEST_F(ParserImplTest, ParamList_Single) {
EXPECT_EQ(e.value[0]->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(e.value[0]->symbol, p->builder().Symbols().Get("a"));
EXPECT_TRUE(e.value[0]->type->Is<ast::I32>()); EXPECT_TRUE(e.value[0]->type->Is<ast::I32>());
EXPECT_TRUE(e.value[0]->is_const); EXPECT_TRUE(e.value[0]->Is<ast::Parameter>());
ASSERT_EQ(e.value[0]->source.range.begin.line, 1u); ASSERT_EQ(e.value[0]->source.range.begin.line, 1u);
ASSERT_EQ(e.value[0]->source.range.begin.column, 1u); ASSERT_EQ(e.value[0]->source.range.begin.column, 1u);
@ -45,7 +45,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) {
EXPECT_EQ(e.value[0]->symbol, p->builder().Symbols().Get("a")); EXPECT_EQ(e.value[0]->symbol, p->builder().Symbols().Get("a"));
EXPECT_TRUE(e.value[0]->type->Is<ast::I32>()); EXPECT_TRUE(e.value[0]->type->Is<ast::I32>());
EXPECT_TRUE(e.value[0]->is_const); EXPECT_TRUE(e.value[0]->Is<ast::Parameter>());
ASSERT_EQ(e.value[0]->source.range.begin.line, 1u); ASSERT_EQ(e.value[0]->source.range.begin.line, 1u);
ASSERT_EQ(e.value[0]->source.range.begin.column, 1u); ASSERT_EQ(e.value[0]->source.range.begin.column, 1u);
@ -54,7 +54,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) {
EXPECT_EQ(e.value[1]->symbol, p->builder().Symbols().Get("b")); EXPECT_EQ(e.value[1]->symbol, p->builder().Symbols().Get("b"));
EXPECT_TRUE(e.value[1]->type->Is<ast::F32>()); EXPECT_TRUE(e.value[1]->type->Is<ast::F32>());
EXPECT_TRUE(e.value[1]->is_const); EXPECT_TRUE(e.value[1]->Is<ast::Parameter>());
ASSERT_EQ(e.value[1]->source.range.begin.line, 1u); ASSERT_EQ(e.value[1]->source.range.begin.line, 1u);
ASSERT_EQ(e.value[1]->source.range.begin.column, 10u); ASSERT_EQ(e.value[1]->source.range.begin.column, 10u);
@ -65,7 +65,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) {
ASSERT_TRUE(e.value[2]->type->Is<ast::Vector>()); ASSERT_TRUE(e.value[2]->type->Is<ast::Vector>());
ASSERT_TRUE(e.value[2]->type->As<ast::Vector>()->type->Is<ast::F32>()); ASSERT_TRUE(e.value[2]->type->As<ast::Vector>()->type->Is<ast::F32>());
EXPECT_EQ(e.value[2]->type->As<ast::Vector>()->width, 2u); EXPECT_EQ(e.value[2]->type->As<ast::Vector>()->width, 2u);
EXPECT_TRUE(e.value[2]->is_const); EXPECT_TRUE(e.value[2]->Is<ast::Parameter>());
ASSERT_EQ(e.value[2]->source.range.begin.line, 1u); ASSERT_EQ(e.value[2]->source.range.begin.line, 1u);
ASSERT_EQ(e.value[2]->source.range.begin.column, 18u); ASSERT_EQ(e.value[2]->source.range.begin.column, 18u);
@ -101,7 +101,7 @@ TEST_F(ParserImplTest, ParamList_Attributes) {
ASSERT_TRUE(e.value[0]->type->Is<ast::Vector>()); ASSERT_TRUE(e.value[0]->type->Is<ast::Vector>());
EXPECT_TRUE(e.value[0]->type->As<ast::Vector>()->type->Is<ast::F32>()); EXPECT_TRUE(e.value[0]->type->As<ast::Vector>()->type->Is<ast::F32>());
EXPECT_EQ(e.value[0]->type->As<ast::Vector>()->width, 4u); EXPECT_EQ(e.value[0]->type->As<ast::Vector>()->width, 4u);
EXPECT_TRUE(e.value[0]->is_const); EXPECT_TRUE(e.value[0]->Is<ast::Parameter>());
auto attrs_0 = e.value[0]->attributes; auto attrs_0 = e.value[0]->attributes;
ASSERT_EQ(attrs_0.size(), 1u); ASSERT_EQ(attrs_0.size(), 1u);
EXPECT_TRUE(attrs_0[0]->Is<ast::BuiltinAttribute>()); EXPECT_TRUE(attrs_0[0]->Is<ast::BuiltinAttribute>());
@ -114,7 +114,7 @@ TEST_F(ParserImplTest, ParamList_Attributes) {
EXPECT_EQ(e.value[1]->symbol, p->builder().Symbols().Get("loc1")); EXPECT_EQ(e.value[1]->symbol, p->builder().Symbols().Get("loc1"));
EXPECT_TRUE(e.value[1]->type->Is<ast::F32>()); EXPECT_TRUE(e.value[1]->type->Is<ast::F32>());
EXPECT_TRUE(e.value[1]->is_const); EXPECT_TRUE(e.value[1]->Is<ast::Parameter>());
auto attrs_1 = e.value[1]->attributes; auto attrs_1 = e.value[1]->attributes;
ASSERT_EQ(attrs_1.size(), 1u); ASSERT_EQ(attrs_1.size(), 1u);
EXPECT_TRUE(attrs_1[0]->Is<ast::LocationAttribute>()); EXPECT_TRUE(attrs_1[0]->Is<ast::LocationAttribute>());

View File

@ -487,11 +487,13 @@ struct DependencyAnalysis {
/// declaration /// declaration
std::string KindOf(const ast::Node* node) { std::string KindOf(const ast::Node* node) {
return Switch( return Switch(
node, // node, //
[&](const ast::Struct*) { return "struct"; }, [&](const ast::Struct*) { return "struct"; }, //
[&](const ast::Alias*) { return "alias"; }, [&](const ast::Alias*) { return "alias"; }, //
[&](const ast::Function*) { return "function"; }, [&](const ast::Function*) { return "function"; }, //
[&](const ast::Variable* var) { return var->is_const ? "let" : "var"; }, [&](const ast::Let*) { return "let"; }, //
[&](const ast::Var*) { return "var"; }, //
[&](const ast::Override*) { return "override"; }, //
[&](Default) { [&](Default) {
UnhandledNode(diagnostics_, node); UnhandledNode(diagnostics_, node);
return "<error>"; return "<error>";

View File

@ -31,7 +31,7 @@ class ResolverPipelineOverridableConstantTest : public ResolverTest {
auto* sem = Sem().Get<sem::GlobalVariable>(var); auto* sem = Sem().Get<sem::GlobalVariable>(var);
ASSERT_NE(sem, nullptr); ASSERT_NE(sem, nullptr);
EXPECT_EQ(sem->Declaration(), var); EXPECT_EQ(sem->Declaration(), var);
EXPECT_TRUE(sem->IsOverridable()); EXPECT_TRUE(sem->Declaration()->Is<ast::Override>());
EXPECT_EQ(sem->ConstantId(), id); EXPECT_EQ(sem->ConstantId(), id);
EXPECT_FALSE(sem->ConstantValue()); EXPECT_FALSE(sem->ConstantValue());
} }
@ -45,7 +45,7 @@ TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) {
auto* sem_a = Sem().Get<sem::GlobalVariable>(a); auto* sem_a = Sem().Get<sem::GlobalVariable>(a);
ASSERT_NE(sem_a, nullptr); ASSERT_NE(sem_a, nullptr);
EXPECT_EQ(sem_a->Declaration(), a); EXPECT_EQ(sem_a->Declaration(), a);
EXPECT_FALSE(sem_a->IsOverridable()); EXPECT_FALSE(sem_a->Declaration()->Is<ast::Override>());
EXPECT_TRUE(sem_a->ConstantValue()); EXPECT_TRUE(sem_a->ConstantValue());
} }

View File

@ -303,59 +303,63 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
return s; return s;
} }
sem::Variable* Resolver::Variable(const ast::Variable* var, sem::Variable* Resolver::Variable(const ast::Variable* v,
VariableKind kind, bool is_global,
uint32_t index /* = 0 */) { uint32_t index /* = 0 */) {
const sem::Type* storage_ty = nullptr; const sem::Type* storage_ty = nullptr;
// If the variable has a declared type, resolve it. // If the variable has a declared type, resolve it.
if (auto* ty = var->type) { if (auto* ty = v->type) {
storage_ty = Type(ty); storage_ty = Type(ty);
if (!storage_ty) { if (!storage_ty) {
return nullptr; return nullptr;
} }
} }
auto* as_var = v->As<ast::Var>();
auto* as_let = v->As<ast::Let>();
auto* as_override = v->As<ast::Override>();
auto* as_param = v->As<ast::Parameter>();
const sem::Expression* rhs = nullptr; const sem::Expression* rhs = nullptr;
// Does the variable have a constructor? // Does the variable have a constructor?
if (var->constructor) { if (v->constructor) {
rhs = Materialize(Expression(var->constructor), storage_ty); rhs = Materialize(Expression(v->constructor), storage_ty);
if (!rhs) { if (!rhs) {
return nullptr; return nullptr;
} }
// If the variable has no declared type, infer it from the RHS // If the variable has no declared type, infer it from the RHS
if (!storage_ty) { if (!storage_ty) {
if (!var->is_const && kind == VariableKind::kGlobal) { if (as_var && is_global) {
AddError("module-scope 'var' declaration must specify a type", var->source); AddError("module-scope 'var' declaration must specify a type", v->source);
return nullptr; return nullptr;
} }
storage_ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS storage_ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS
} }
} else if (var->is_const && !var->is_overridable && kind != VariableKind::kParameter) { } else if (as_let) {
AddError("'let' declaration must have an initializer", var->source); AddError("'let' declaration must have an initializer", v->source);
return nullptr; return nullptr;
} else if (!var->type) { } else if (!v->type) {
AddError((kind == VariableKind::kGlobal) AddError((is_global) ? "module-scope 'var' declaration requires a type or initializer"
? "module-scope 'var' declaration requires a type or initializer" : "function-scope 'var' declaration requires a type or initializer",
: "function-scope 'var' declaration requires a type or initializer", v->source);
var->source);
return nullptr; return nullptr;
} }
if (!storage_ty) { if (!storage_ty) {
TINT_ICE(Resolver, diagnostics_) << "failed to determine storage type for variable '" + TINT_ICE(Resolver, diagnostics_) << "failed to determine storage type for variable '" +
builder_->Symbols().NameFor(var->symbol) + "'\n" builder_->Symbols().NameFor(v->symbol) + "'\n"
<< "Source: " << var->source; << "Source: " << v->source;
return nullptr; return nullptr;
} }
auto storage_class = var->declared_storage_class; auto storage_class = as_var ? as_var->declared_storage_class : ast::StorageClass::kNone;
if (storage_class == ast::StorageClass::kNone && !var->is_const) { if (storage_class == ast::StorageClass::kNone && as_var) {
// No declared storage class. Infer from usage / type. // No declared storage class. Infer from usage / type.
if (kind == VariableKind::kLocal) { if (!is_global) {
storage_class = ast::StorageClass::kFunction; storage_class = ast::StorageClass::kFunction;
} else if (storage_ty->UnwrapRef()->is_handle()) { } else if (storage_ty->UnwrapRef()->is_handle()) {
// https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
@ -366,93 +370,83 @@ sem::Variable* Resolver::Variable(const ast::Variable* var,
} }
} }
if (kind == VariableKind::kLocal && !var->is_const && if (!is_global && as_var && storage_class != ast::StorageClass::kFunction &&
storage_class != ast::StorageClass::kFunction && validator_.IsValidationEnabled(v->attributes,
validator_.IsValidationEnabled(var->attributes,
ast::DisabledValidation::kIgnoreStorageClass)) { ast::DisabledValidation::kIgnoreStorageClass)) {
AddError("function-scope 'var' declaration must use 'function' storage class", var->source); AddError("function-scope 'var' declaration must use 'function' storage class", v->source);
return nullptr; return nullptr;
} }
auto access = var->declared_access; auto access = as_var ? as_var->declared_access : ast::Access::kUndefined;
if (access == ast::Access::kUndefined) { if (access == ast::Access::kUndefined) {
access = DefaultAccessForStorageClass(storage_class); access = DefaultAccessForStorageClass(storage_class);
} }
auto* var_ty = storage_ty; auto* var_ty = storage_ty;
if (!var->is_const) { if (as_var) {
// Variable declaration. Unlike `let`, `var` has storage. // Variable declaration. Unlike `let` and parameters, `var` has storage.
// Variables are always of a reference type to the declared storage type. // Variables are always of a reference type to the declared storage type.
var_ty = builder_->create<sem::Reference>(storage_ty, storage_class, access); var_ty = builder_->create<sem::Reference>(storage_ty, storage_class, access);
} }
if (rhs && !validator_.VariableConstructorOrCast(var, storage_class, storage_ty, rhs->Type())) { if (rhs && !validator_.VariableConstructorOrCast(v, storage_class, storage_ty, rhs->Type())) {
return nullptr; return nullptr;
} }
if (!ApplyStorageClassUsageToType(storage_class, const_cast<sem::Type*>(var_ty), var->source)) { if (!ApplyStorageClassUsageToType(storage_class, const_cast<sem::Type*>(var_ty), v->source)) {
AddNote(std::string("while instantiating ") + AddNote(std::string("while instantiating ") + ((as_param) ? "parameter " : "variable ") +
((kind == VariableKind::kParameter) ? "parameter " : "variable ") + builder_->Symbols().NameFor(v->symbol),
builder_->Symbols().NameFor(var->symbol), v->source);
var->source);
return nullptr; return nullptr;
} }
if (kind == VariableKind::kParameter) { if (as_param) {
if (auto* ptr = var_ty->As<sem::Pointer>()) { if (auto* ptr = var_ty->As<sem::Pointer>()) {
// For MSL, we push module-scope variables into the entry point as pointer // For MSL, we push module-scope variables into the entry point as pointer
// parameters, so we also need to handle their store type. // parameters, so we also need to handle their store type.
if (!ApplyStorageClassUsageToType( if (!ApplyStorageClassUsageToType(
ptr->StorageClass(), const_cast<sem::Type*>(ptr->StoreType()), var->source)) { ptr->StorageClass(), const_cast<sem::Type*>(ptr->StoreType()), v->source)) {
AddNote("while instantiating parameter " + builder_->Symbols().NameFor(var->symbol), AddNote("while instantiating parameter " + builder_->Symbols().NameFor(v->symbol),
var->source); v->source);
return nullptr; return nullptr;
} }
} }
auto* param =
builder_->create<sem::Parameter>(as_param, index, var_ty, storage_class, access);
builder_->Sem().Add(as_param, param);
return param;
} }
switch (kind) { if (is_global) {
case VariableKind::kGlobal: { sem::BindingPoint binding_point;
sem::BindingPoint binding_point; if (as_var) {
if (auto bp = var->BindingPoint()) { if (auto bp = as_var->BindingPoint()) {
binding_point = {bp.group->value, bp.binding->value}; binding_point = {bp.group->value, bp.binding->value};
} }
}
bool has_const_val = rhs && var->is_const && !var->is_overridable; bool has_const_val = rhs && as_let && !as_override;
auto* global = builder_->create<sem::GlobalVariable>( auto* global = builder_->create<sem::GlobalVariable>(
var, var_ty, storage_class, access, v, var_ty, storage_class, access,
has_const_val ? rhs->ConstantValue() : sem::Constant{}, binding_point); has_const_val ? rhs->ConstantValue() : sem::Constant{}, binding_point);
if (var->is_overridable) { if (as_override) {
global->SetIsOverridable(); if (auto* id = ast::GetAttribute<ast::IdAttribute>(v->attributes)) {
if (auto* id = ast::GetAttribute<ast::IdAttribute>(var->attributes)) { global->SetConstantId(static_cast<uint16_t>(id->value));
global->SetConstantId(static_cast<uint16_t>(id->value));
}
} }
}
global->SetConstructor(rhs); global->SetConstructor(rhs);
builder_->Sem().Add(v, global);
builder_->Sem().Add(var, global); return global;
return global;
}
case VariableKind::kLocal: {
auto* local = builder_->create<sem::LocalVariable>(
var, var_ty, storage_class, access, current_statement_,
(rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{});
builder_->Sem().Add(var, local);
local->SetConstructor(rhs);
return local;
}
case VariableKind::kParameter: {
auto* param =
builder_->create<sem::Parameter>(var, index, var_ty, storage_class, access);
builder_->Sem().Add(var, param);
return param;
}
} }
TINT_UNREACHABLE(Resolver, diagnostics_) << "unhandled VariableKind " << static_cast<int>(kind); auto* local = builder_->create<sem::LocalVariable>(
return nullptr; v, var_ty, storage_class, access, current_statement_,
(rhs && as_let) ? rhs->ConstantValue() : sem::Constant{});
builder_->Sem().Add(v, local);
local->SetConstructor(rhs);
return local;
} }
ast::Access Resolver::DefaultAccessForStorageClass(ast::StorageClass storage_class) { ast::Access Resolver::DefaultAccessForStorageClass(ast::StorageClass storage_class) {
@ -477,13 +471,13 @@ void Resolver::AllocateOverridableConstantIds() {
// TODO(crbug.com/tint/1192): If a transform changes the order or removes an // TODO(crbug.com/tint/1192): If a transform changes the order or removes an
// unused constant, the allocation may change on the next Resolver pass. // unused constant, the allocation may change on the next Resolver pass.
for (auto* decl : builder_->AST().GlobalDeclarations()) { for (auto* decl : builder_->AST().GlobalDeclarations()) {
auto* var = decl->As<ast::Variable>(); auto* override = decl->As<ast::Override>();
if (!var || !var->is_overridable) { if (!override) {
continue; continue;
} }
uint16_t constant_id; uint16_t constant_id;
if (auto* id_attr = ast::GetAttribute<ast::IdAttribute>(var->attributes)) { if (auto* id_attr = ast::GetAttribute<ast::IdAttribute>(override->attributes)) {
constant_id = static_cast<uint16_t>(id_attr->value); constant_id = static_cast<uint16_t>(id_attr->value);
} else { } else {
// No ID was specified, so allocate the next available ID. // No ID was specified, so allocate the next available ID.
@ -499,7 +493,7 @@ void Resolver::AllocateOverridableConstantIds() {
next_constant_id = constant_id + 1; next_constant_id = constant_id + 1;
} }
auto* sem = sem_.Get<sem::GlobalVariable>(var); auto* sem = sem_.Get<sem::GlobalVariable>(override);
const_cast<sem::GlobalVariable*>(sem)->SetConstantId(constant_id); const_cast<sem::GlobalVariable*>(sem)->SetConstantId(constant_id);
} }
} }
@ -513,25 +507,21 @@ void Resolver::SetShadows() {
} }
} }
sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* var) { sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) {
auto* sem = Variable(var, VariableKind::kGlobal); auto* sem = As<sem::GlobalVariable>(Variable(v, /* is_global */ true));
if (!sem) { if (!sem) {
return nullptr; return nullptr;
} }
const bool is_var = v->Is<ast::Var>();
auto storage_class = sem->StorageClass(); auto storage_class = sem->StorageClass();
if (!var->is_const && storage_class == ast::StorageClass::kNone) { if (is_var && storage_class == ast::StorageClass::kNone) {
AddError("module-scope 'var' declaration must have a storage class", var->source); AddError("module-scope 'var' declaration must have a storage class", v->source);
return nullptr;
}
if (var->is_const && storage_class != ast::StorageClass::kNone) {
AddError(var->is_overridable ? "'override' declaration must not have a storage class"
: "'let' declaration must not have a storage class",
var->source);
return nullptr; return nullptr;
} }
for (auto* attr : var->attributes) { for (auto* attr : v->attributes) {
Mark(attr); Mark(attr);
if (auto* id_attr = attr->As<ast::IdAttribute>()) { if (auto* id_attr = attr->As<ast::IdAttribute>()) {
@ -540,7 +530,7 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* var) {
} }
} }
if (!validator_.NoDuplicateAttributes(var->attributes)) { if (!validator_.NoDuplicateAttributes(v->attributes)) {
return nullptr; return nullptr;
} }
@ -576,9 +566,8 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
} }
} }
auto* var = auto* p = As<sem::Parameter>(Variable(param, false, parameter_index++));
As<sem::Parameter>(Variable(param, VariableKind::kParameter, parameter_index++)); if (!p) {
if (!var) {
return nullptr; return nullptr;
} }
@ -589,10 +578,10 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
return nullptr; return nullptr;
} }
parameters.emplace_back(var); parameters.emplace_back(p);
auto* var_ty = const_cast<sem::Type*>(var->Type()); auto* p_ty = const_cast<sem::Type*>(p->Type());
if (auto* str = var_ty->As<sem::Struct>()) { if (auto* str = p_ty->As<sem::Struct>()) {
switch (decl->PipelineStage()) { switch (decl->PipelineStage()) {
case ast::PipelineStage::kVertex: case ast::PipelineStage::kVertex:
str->AddUsage(sem::PipelineStageUsage::kVertexInput); str->AddUsage(sem::PipelineStageUsage::kVertexInput);
@ -777,12 +766,12 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
if (auto* user = args[i]->As<sem::VariableUser>()) { if (auto* user = args[i]->As<sem::VariableUser>()) {
// We have an variable of a module-scope constant. // We have an variable of a module-scope constant.
auto* decl = user->Variable()->Declaration(); auto* decl = user->Variable()->Declaration();
if (!decl->is_const) { if (!decl->IsAnyOf<ast::Let, ast::Override>()) {
AddError(kErrBadType, values[i]->source); AddError(kErrBadType, values[i]->source);
return false; return false;
} }
// Capture the constant if it is pipeline-overridable. // Capture the constant if it is pipeline-overridable.
if (decl->is_overridable) { if (decl->Is<ast::Override>()) {
ws[i].overridable_const = decl; ws[i].overridable_const = decl;
} }
@ -2104,19 +2093,19 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return nullptr; return nullptr;
} }
constexpr const char* kErrInvalidExpr =
"array size identifier must be a literal or a module-scope 'let'";
if (auto* ident = count_expr->As<ast::IdentifierExpression>()) { if (auto* ident = count_expr->As<ast::IdentifierExpression>()) {
// Make sure the identifier is a non-overridable module-scope constant. // Make sure the identifier is a non-overridable module-scope 'let'.
auto* var = sem_.ResolvedSymbol<sem::GlobalVariable>(ident); auto* global = sem_.ResolvedSymbol<sem::GlobalVariable>(ident);
if (!var || !var->Declaration()->is_const || var->IsOverridable()) { if (!global || !global->Declaration()->Is<ast::Let>()) {
AddError("array size identifier must be a literal or a module-scope 'let'", AddError(kErrInvalidExpr, size_source);
size_source);
return nullptr; return nullptr;
} }
count_expr = global->Declaration()->constructor;
count_expr = var->Declaration()->constructor;
} else if (!count_expr->Is<ast::LiteralExpression>()) { } else if (!count_expr->Is<ast::LiteralExpression>()) {
AddError("array size identifier must be a literal or a module-scope 'let'", AddError(kErrInvalidExpr, size_source);
size_source);
return nullptr; return nullptr;
} }
@ -2437,7 +2426,7 @@ sem::Statement* Resolver::VariableDeclStatement(const ast::VariableDeclStatement
return StatementScope(stmt, sem, [&] { return StatementScope(stmt, sem, [&] {
Mark(stmt->variable); Mark(stmt->variable);
auto* var = Variable(stmt->variable, VariableKind::kLocal); auto* var = Variable(stmt->variable, /* is_global */ false);
if (!var) { if (!var) {
return false; return false;
} }

View File

@ -109,9 +109,6 @@ class Resolver {
const Validator* GetValidatorForTesting() const { return &validator_; } const Validator* GetValidatorForTesting() const { return &validator_; }
private: private:
/// Describes the context in which a variable is declared
enum class VariableKind { kParameter, kLocal, kGlobal };
Validator::ValidTypeStorageLayouts valid_type_storage_layouts_; Validator::ValidTypeStorageLayouts valid_type_storage_layouts_;
/// Structure holding semantic information about a block (i.e. scope), such as /// Structure holding semantic information about a block (i.e. scope), such as
@ -298,9 +295,9 @@ class Resolver {
/// @note this method does not resolve the attributes as these are /// @note this method does not resolve the attributes as these are
/// context-dependent (global, local, parameter) /// context-dependent (global, local, parameter)
/// @param var the variable to create or return the `VariableInfo` for /// @param var the variable to create or return the `VariableInfo` for
/// @param kind what kind of variable we are declaring /// @param is_global true if this is module scope, otherwise function scope
/// @param index the index of the parameter, if this variable is a parameter /// @param index the index of the parameter, if this variable is a parameter
sem::Variable* Variable(const ast::Variable* var, VariableKind kind, uint32_t index = 0); sem::Variable* Variable(const ast::Variable* var, bool is_global, uint32_t index = 0);
/// Records the storage class usage for the given type, and any transient /// Records the storage class usage for the given type, and any transient
/// dependencies of the type. Validates that the type can be used for the /// dependencies of the type. Validates that the type can be used for the

View File

@ -87,26 +87,6 @@ TEST_F(ResolverTypeValidationTest, GlobalVariableWithStorageClass_Pass) {
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_F(ResolverTypeValidationTest, GlobalLetWithStorageClass_Fail) {
// let<private> global_var: f32;
AST().AddGlobalVariable(create<ast::Variable>(
Source{{12, 34}}, Symbols().Register("global_let"), ast::StorageClass::kPrivate,
ast::Access::kUndefined, ty.f32(), true, false, Expr(1.23_f), ast::AttributeList{}));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: 'let' declaration must not have a storage class");
}
TEST_F(ResolverTypeValidationTest, OverrideWithStorageClass_Fail) {
// let<private> global_var: f32;
AST().AddGlobalVariable(create<ast::Variable>(
Source{{12, 34}}, Symbols().Register("global_override"), ast::StorageClass::kPrivate,
ast::Access::kUndefined, ty.f32(), true, true, Expr(1.23_f), ast::AttributeList{}));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: 'override' declaration must not have a storage class");
}
TEST_F(ResolverTypeValidationTest, GlobalConstNoStorageClass_Pass) { TEST_F(ResolverTypeValidationTest, GlobalConstNoStorageClass_Pass) {
// let global_var: f32; // let global_var: f32;
GlobalConst(Source{{12, 34}}, "global_var", ty.f32(), Construct(ty.f32())); GlobalConst(Source{{12, 34}}, "global_var", ty.f32(), Construct(ty.f32()));

View File

@ -949,7 +949,7 @@ class UniformityGraph {
} }
current_function_->variables.Set(sem_.Get(decl->variable), node); current_function_->variables.Set(sem_.Get(decl->variable), node);
if (!decl->variable->is_const) { if (decl->variable->Is<ast::Var>()) {
current_function_->local_var_decls.insert( current_function_->local_var_decls.insert(
sem_.Get<sem::LocalVariable>(decl->variable)); sem_.Get<sem::LocalVariable>(decl->variable));
} }
@ -1018,7 +1018,8 @@ class UniformityGraph {
}, },
[&](const sem::GlobalVariable* global) { [&](const sem::GlobalVariable* global) {
if (global->Declaration()->is_const || global->Access() == ast::Access::kRead) { if (!global->Declaration()->Is<ast::Var>() ||
global->Access() == ast::Access::kRead) {
node->AddEdge(cf); node->AddEdge(cf);
} else { } else {
node->AddEdge(current_function_->may_be_non_uniform); node->AddEdge(current_function_->may_be_non_uniform);

View File

@ -297,7 +297,7 @@ bool Validator::Materialize(const sem::Materialize* m) const {
return true; return true;
} }
bool Validator::VariableConstructorOrCast(const ast::Variable* var, bool Validator::VariableConstructorOrCast(const ast::Variable* v,
ast::StorageClass storage_class, ast::StorageClass storage_class,
const sem::Type* storage_ty, const sem::Type* storage_ty,
const sem::Type* rhs_ty) const { const sem::Type* rhs_ty) const {
@ -305,14 +305,14 @@ bool Validator::VariableConstructorOrCast(const ast::Variable* var,
// Value type has to match storage type // Value type has to match storage type
if (storage_ty != value_type) { if (storage_ty != value_type) {
std::string decl = var->is_const ? "let" : "var"; std::string decl = v->Is<ast::Let>() ? "let" : "var";
AddError("cannot initialize " + decl + " of type '" + sem_.TypeNameOf(storage_ty) + AddError("cannot initialize " + decl + " of type '" + sem_.TypeNameOf(storage_ty) +
"' with value of type '" + sem_.TypeNameOf(rhs_ty) + "'", "' with value of type '" + sem_.TypeNameOf(rhs_ty) + "'",
var->source); v->source);
return false; return false;
} }
if (!var->is_const) { if (v->Is<ast::Var>()) {
switch (storage_class) { switch (storage_class) {
case ast::StorageClass::kPrivate: case ast::StorageClass::kPrivate:
case ast::StorageClass::kFunction: case ast::StorageClass::kFunction:
@ -325,7 +325,7 @@ bool Validator::VariableConstructorOrCast(const ast::Variable* var,
"' cannot have an initializer. var initializers are only " "' cannot have an initializer. var initializers are only "
"supported for the storage classes " "supported for the storage classes "
"'private' and 'function'", "'private' and 'function'",
var->source); v->source);
return false; return false;
} }
} }
@ -502,21 +502,22 @@ bool Validator::StorageClassLayout(const sem::Variable* var,
} }
bool Validator::GlobalVariable( bool Validator::GlobalVariable(
const sem::Variable* var, const sem::GlobalVariable* global,
std::unordered_map<uint32_t, const sem::Variable*> constant_ids, std::unordered_map<uint32_t, const sem::Variable*> constant_ids,
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info) const { std::unordered_map<const sem::Type*, const Source&> atomic_composite_info) const {
auto* decl = var->Declaration(); auto* decl = global->Declaration();
if (!NoDuplicateAttributes(decl->attributes)) { if (!NoDuplicateAttributes(decl->attributes)) {
return false; return false;
} }
for (auto* attr : decl->attributes) { bool ok = Switch(
if (decl->is_const) { decl, //
if (decl->is_overridable) { [&](const ast::Override*) {
for (auto* attr : decl->attributes) {
if (auto* id_attr = attr->As<ast::IdAttribute>()) { if (auto* id_attr = attr->As<ast::IdAttribute>()) {
uint32_t id = id_attr->value; uint32_t id = id_attr->value;
auto it = constant_ids.find(id); auto it = constant_ids.find(id);
if (it != constant_ids.end() && it->second != var) { if (it != constant_ids.end() && it->second != global) {
AddError("pipeline constant IDs must be unique", attr->source); AddError("pipeline constant IDs must be unique", attr->source);
AddNote("a pipeline constant with an ID of " + std::to_string(id) + AddNote("a pipeline constant with an ID of " + std::to_string(id) +
" was previously declared here:", " was previously declared here:",
@ -533,32 +534,45 @@ bool Validator::GlobalVariable(
AddError("attribute is not valid for 'override' declaration", attr->source); AddError("attribute is not valid for 'override' declaration", attr->source);
return false; return false;
} }
} else { }
AddError("attribute is not valid for module-scope 'let' declaration", attr->source); return true;
},
[&](const ast::Let*) {
if (!decl->attributes.empty()) {
AddError("attribute is not valid for module-scope 'let' declaration",
decl->attributes[0]->source);
return false; return false;
} }
} else { return true;
bool is_shader_io_attribute = },
attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute, [&](const ast::Var*) {
ast::InvariantAttribute, ast::LocationAttribute>(); for (auto* attr : decl->attributes) {
bool has_io_storage_class = var->StorageClass() == ast::StorageClass::kInput || bool is_shader_io_attribute =
var->StorageClass() == ast::StorageClass::kOutput; attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute,
if (!(attr->IsAnyOf<ast::BindingAttribute, ast::GroupAttribute, ast::InvariantAttribute, ast::LocationAttribute>();
ast::InternalAttribute>()) && bool has_io_storage_class = global->StorageClass() == ast::StorageClass::kInput ||
(!is_shader_io_attribute || !has_io_storage_class)) { global->StorageClass() == ast::StorageClass::kOutput;
AddError("attribute is not valid for module-scope 'var'", attr->source); if (!attr->IsAnyOf<ast::BindingAttribute, ast::GroupAttribute,
return false; ast::InternalAttribute>() &&
(!is_shader_io_attribute || !has_io_storage_class)) {
AddError("attribute is not valid for module-scope 'var'", attr->source);
return false;
}
} }
} return true;
});
if (!ok) {
return false;
} }
if (var->StorageClass() == ast::StorageClass::kFunction) { if (global->StorageClass() == ast::StorageClass::kFunction) {
AddError("module-scope 'var' must not use storage class 'function'", decl->source); AddError("module-scope 'var' must not use storage class 'function'", decl->source);
return false; return false;
} }
auto binding_point = decl->BindingPoint(); auto binding_point = decl->BindingPoint();
switch (var->StorageClass()) { switch (global->StorageClass()) {
case ast::StorageClass::kUniform: case ast::StorageClass::kUniform:
case ast::StorageClass::kStorage: case ast::StorageClass::kStorage:
case ast::StorageClass::kHandle: { case ast::StorageClass::kHandle: {
@ -581,23 +595,23 @@ bool Validator::GlobalVariable(
} }
} }
// https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration if (auto* var = decl->As<ast::Var>()) {
// The access mode always has a default, and except for variables in the // https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration
// storage storage class, must not be written. // The access mode always has a default, and except for variables in the
if (var->StorageClass() != ast::StorageClass::kStorage && // storage storage class, must not be written.
decl->declared_access != ast::Access::kUndefined) { if (global->StorageClass() != ast::StorageClass::kStorage &&
AddError("only variables in <storage> storage class may declare an access mode", var->declared_access != ast::Access::kUndefined) {
decl->source); AddError("only variables in <storage> storage class may declare an access mode",
return false; var->source);
} return false;
}
if (!decl->is_const) { if (!AtomicVariable(global, atomic_composite_info)) {
if (!AtomicVariable(var, atomic_composite_info)) {
return false; return false;
} }
} }
return Variable(var); return Variable(global);
} }
// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
@ -641,14 +655,17 @@ bool Validator::AtomicVariable(
return true; return true;
} }
bool Validator::Variable(const sem::Variable* var) const { bool Validator::Variable(const sem::Variable* v) const {
auto* decl = var->Declaration(); auto* decl = v->Declaration();
auto* storage_ty = var->Type()->UnwrapRef(); auto* storage_ty = v->Type()->UnwrapRef();
if (var->Is<sem::GlobalVariable>()) { auto* as_let = decl->As<ast::Let>();
auto* as_var = decl->As<ast::Var>();
if (v->Is<sem::GlobalVariable>()) {
auto name = symbols_.NameFor(decl->symbol); auto name = symbols_.NameFor(decl->symbol);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
auto* kind = var->Declaration()->is_const ? "let" : "var"; auto* kind = as_let ? "let" : "var";
AddError( AddError(
"'" + name + "' is a builtin and cannot be redeclared as a module-scope " + kind, "'" + name + "' is a builtin and cannot be redeclared as a module-scope " + kind,
decl->source); decl->source);
@ -656,14 +673,13 @@ bool Validator::Variable(const sem::Variable* var) const {
} }
} }
if (!decl->is_const && !IsStorable(storage_ty)) { if (as_var && !IsStorable(storage_ty)) {
AddError(sem_.TypeNameOf(storage_ty) + " cannot be used as the type of a var", AddError(sem_.TypeNameOf(storage_ty) + " cannot be used as the type of a var",
decl->source); decl->source);
return false; return false;
} }
if (decl->is_const && !var->Is<sem::Parameter>() && if (as_let && !(storage_ty->IsConstructible() || storage_ty->Is<sem::Pointer>())) {
!(storage_ty->IsConstructible() || storage_ty->Is<sem::Pointer>())) {
AddError(sem_.TypeNameOf(storage_ty) + " cannot be used as the type of a let", AddError(sem_.TypeNameOf(storage_ty) + " cannot be used as the type of a let",
decl->source); decl->source);
return false; return false;
@ -688,16 +704,17 @@ bool Validator::Variable(const sem::Variable* var) const {
} }
} }
if (var->Is<sem::LocalVariable>() && !decl->is_const && if (v->Is<sem::LocalVariable>() && as_var &&
IsValidationEnabled(decl->attributes, ast::DisabledValidation::kIgnoreStorageClass)) { IsValidationEnabled(decl->attributes, ast::DisabledValidation::kIgnoreStorageClass)) {
if (!var->Type()->UnwrapRef()->IsConstructible()) { if (!v->Type()->UnwrapRef()->IsConstructible()) {
AddError("function variable must have a constructible type", AddError("function variable must have a constructible type",
decl->type ? decl->type->source : decl->source); decl->type ? decl->type->source : decl->source);
return false; return false;
} }
} }
if (storage_ty->is_handle() && decl->declared_storage_class != ast::StorageClass::kNone) { if (as_var && storage_ty->is_handle() &&
as_var->declared_storage_class != ast::StorageClass::kNone) {
// https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables // https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
// If the store type is a texture type or a sampler type, then the // If the store type is a texture type or a sampler type, then the
// variable declaration must not have a storage class attribute. The // variable declaration must not have a storage class attribute. The
@ -709,9 +726,10 @@ bool Validator::Variable(const sem::Variable* var) const {
} }
if (IsValidationEnabled(decl->attributes, ast::DisabledValidation::kIgnoreStorageClass) && if (IsValidationEnabled(decl->attributes, ast::DisabledValidation::kIgnoreStorageClass) &&
(decl->declared_storage_class == ast::StorageClass::kInput || as_var &&
decl->declared_storage_class == ast::StorageClass::kOutput)) { (as_var->declared_storage_class == ast::StorageClass::kInput ||
AddError("invalid use of input/output storage class", decl->source); as_var->declared_storage_class == ast::StorageClass::kOutput)) {
AddError("invalid use of input/output storage class", as_var->source);
return false; return false;
} }
return true; return true;
@ -1223,12 +1241,12 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
// Validate there are no resource variable binding collisions // Validate there are no resource variable binding collisions
std::unordered_map<sem::BindingPoint, const ast::Variable*> binding_points; std::unordered_map<sem::BindingPoint, const ast::Variable*> binding_points;
for (auto* var : func->TransitivelyReferencedGlobals()) { for (auto* global : func->TransitivelyReferencedGlobals()) {
auto* var_decl = var->Declaration(); auto* var_decl = global->Declaration()->As<ast::Var>();
if (!var_decl->BindingPoint()) { if (!var_decl || !var_decl->BindingPoint()) {
continue; continue;
} }
auto bp = var->BindingPoint(); auto bp = global->BindingPoint();
auto res = binding_points.emplace(bp, var_decl); auto res = binding_points.emplace(bp, var_decl);
if (!res.second && if (!res.second &&
IsValidationEnabled(decl->attributes, IsValidationEnabled(decl->attributes,
@ -1663,12 +1681,6 @@ bool Validator::FunctionCall(const sem::Call* call, sem::Statement* current_stat
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
return false; return false;
} }
if (var->Declaration()->is_const) {
TINT_ICE(Resolver, diagnostics_)
<< "Resolver::FunctionCall() encountered an address-of "
"expression of a constant identifier expression";
return false;
}
is_valid = true; is_valid = true;
} }
} }
@ -2172,18 +2184,16 @@ bool Validator::Assignment(const ast::Statement* a, const sem::Type* rhs_ty) con
// https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement // https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement
auto const* lhs_ty = sem_.TypeOf(lhs); auto const* lhs_ty = sem_.TypeOf(lhs);
if (auto* var = sem_.ResolvedSymbol<sem::Variable>(lhs)) { if (auto* variable = sem_.ResolvedSymbol<sem::Variable>(lhs)) {
auto* decl = var->Declaration(); auto* v = variable->Declaration();
if (var->Is<sem::Parameter>()) { const char* err = Switch(
AddError("cannot assign to function parameter", lhs->source); v, //
AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source); [&](const ast::Parameter*) { return "cannot assign to function parameter"; },
return false; [&](const ast::Let*) { return "cannot assign to 'let'"; },
} [&](const ast::Override*) { return "cannot assign to 'override'"; });
if (decl->is_const) { if (err) {
AddError( AddError(err, lhs->source);
decl->is_overridable ? "cannot assign to 'override'" : "cannot assign to 'let'", AddNote("'" + symbols_.NameFor(v->symbol) + "' is declared here:", v->source);
lhs->source);
AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source);
return false; return false;
} }
} }
@ -2222,17 +2232,16 @@ bool Validator::IncrementDecrementStatement(const ast::IncrementDecrementStateme
// https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement // https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement
if (auto* var = sem_.ResolvedSymbol<sem::Variable>(lhs)) { if (auto* variable = sem_.ResolvedSymbol<sem::Variable>(lhs)) {
auto* decl = var->Declaration(); auto* v = variable->Declaration();
if (var->Is<sem::Parameter>()) { const char* err = Switch(
AddError("cannot modify function parameter", lhs->source); v, //
AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source); [&](const ast::Parameter*) { return "cannot modify function parameter"; },
return false; [&](const ast::Let*) { return "cannot modify 'let'"; },
} [&](const ast::Override*) { return "cannot modify 'override'"; });
if (decl->is_const) { if (err) {
AddError(decl->is_overridable ? "cannot modify 'override'" : "cannot modify 'let'", AddError(err, lhs->source);
lhs->source); AddNote("'" + symbols_.NameFor(v->symbol) + "' is declared here:", v->source);
AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source);
return false; return false;
} }
} }

View File

@ -237,7 +237,7 @@ class Validator {
/// @param atomic_composite_info atomic composite info in the module /// @param atomic_composite_info atomic composite info in the module
/// @returns true on success, false otherwise /// @returns true on success, false otherwise
bool GlobalVariable( bool GlobalVariable(
const sem::Variable* var, const sem::GlobalVariable* var,
std::unordered_map<uint32_t, const sem::Variable*> constant_ids, std::unordered_map<uint32_t, const sem::Variable*> constant_ids,
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info) const; std::unordered_map<const sem::Type*, const Source&> atomic_composite_info) const;
@ -345,12 +345,12 @@ class Validator {
bool Variable(const sem::Variable* var) const; bool Variable(const sem::Variable* var) const;
/// Validates a variable constructor or cast /// Validates a variable constructor or cast
/// @param var the variable to validate /// @param v the variable to validate
/// @param storage_class the storage class of the variable /// @param storage_class the storage class of the variable
/// @param storage_type the type of the storage /// @param storage_type the type of the storage
/// @param rhs_type the right hand side of the expression /// @param rhs_type the right hand side of the expression
/// @returns true on succes, false otherwise /// @returns true on succes, false otherwise
bool VariableConstructorOrCast(const ast::Variable* var, bool VariableConstructorOrCast(const ast::Variable* v,
ast::StorageClass storage_class, ast::StorageClass storage_class,
const sem::Type* storage_type, const sem::Type* storage_type,
const sem::Type* rhs_type) const; const sem::Type* rhs_type) const;

View File

@ -24,22 +24,6 @@ namespace {
struct ResolverVarLetValidationTest : public resolver::TestHelper, public testing::Test {}; struct ResolverVarLetValidationTest : public resolver::TestHelper, public testing::Test {};
TEST_F(ResolverVarLetValidationTest, LetNoInitializer) {
// let a : i32;
WrapInFunction(Let(Source{{12, 34}}, "a", ty.i32(), nullptr));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: 'let' declaration must have an initializer");
}
TEST_F(ResolverVarLetValidationTest, GlobalLetNoInitializer) {
// let a : i32;
GlobalConst(Source{{12, 34}}, "a", ty.i32(), nullptr);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: 'let' declaration must have an initializer");
}
TEST_F(ResolverVarLetValidationTest, VarNoInitializerNoType) { TEST_F(ResolverVarLetValidationTest, VarNoInitializerNoType) {
// var a; // var a;
WrapInFunction(Var(Source{{12, 34}}, "a", nullptr)); WrapInFunction(Var(Source{{12, 34}}, "a", nullptr));

View File

@ -44,10 +44,10 @@ std::vector<std::pair<const Variable*, const ast::LocationAttribute*>>
Function::TransitivelyReferencedLocationVariables() const { Function::TransitivelyReferencedLocationVariables() const {
std::vector<std::pair<const Variable*, const ast::LocationAttribute*>> ret; std::vector<std::pair<const Variable*, const ast::LocationAttribute*>> ret;
for (auto* var : TransitivelyReferencedGlobals()) { for (auto* global : TransitivelyReferencedGlobals()) {
for (auto* attr : var->Declaration()->attributes) { for (auto* attr : global->Declaration()->attributes) {
if (auto* location = attr->As<ast::LocationAttribute>()) { if (auto* location = attr->As<ast::LocationAttribute>()) {
ret.push_back({var, location}); ret.push_back({global, location});
break; break;
} }
} }
@ -58,13 +58,13 @@ Function::TransitivelyReferencedLocationVariables() const {
Function::VariableBindings Function::TransitivelyReferencedUniformVariables() const { Function::VariableBindings Function::TransitivelyReferencedUniformVariables() const {
VariableBindings ret; VariableBindings ret;
for (auto* var : TransitivelyReferencedGlobals()) { for (auto* global : TransitivelyReferencedGlobals()) {
if (var->StorageClass() != ast::StorageClass::kUniform) { if (global->StorageClass() != ast::StorageClass::kUniform) {
continue; continue;
} }
if (auto binding_point = var->Declaration()->BindingPoint()) { if (auto binding_point = global->Declaration()->BindingPoint()) {
ret.push_back({var, binding_point}); ret.push_back({global, binding_point});
} }
} }
return ret; return ret;
@ -73,13 +73,13 @@ Function::VariableBindings Function::TransitivelyReferencedUniformVariables() co
Function::VariableBindings Function::TransitivelyReferencedStorageBufferVariables() const { Function::VariableBindings Function::TransitivelyReferencedStorageBufferVariables() const {
VariableBindings ret; VariableBindings ret;
for (auto* var : TransitivelyReferencedGlobals()) { for (auto* global : TransitivelyReferencedGlobals()) {
if (var->StorageClass() != ast::StorageClass::kStorage) { if (global->StorageClass() != ast::StorageClass::kStorage) {
continue; continue;
} }
if (auto binding_point = var->Declaration()->BindingPoint()) { if (auto binding_point = global->Declaration()->BindingPoint()) {
ret.push_back({var, binding_point}); ret.push_back({global, binding_point});
} }
} }
return ret; return ret;
@ -89,10 +89,10 @@ std::vector<std::pair<const Variable*, const ast::BuiltinAttribute*>>
Function::TransitivelyReferencedBuiltinVariables() const { Function::TransitivelyReferencedBuiltinVariables() const {
std::vector<std::pair<const Variable*, const ast::BuiltinAttribute*>> ret; std::vector<std::pair<const Variable*, const ast::BuiltinAttribute*>> ret;
for (auto* var : TransitivelyReferencedGlobals()) { for (auto* global : TransitivelyReferencedGlobals()) {
for (auto* attr : var->Declaration()->attributes) { for (auto* attr : global->Declaration()->attributes) {
if (auto* builtin = attr->As<ast::BuiltinAttribute>()) { if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
ret.push_back({var, builtin}); ret.push_back({global, builtin});
break; break;
} }
} }
@ -119,11 +119,11 @@ Function::VariableBindings Function::TransitivelyReferencedMultisampledTextureVa
Function::VariableBindings Function::TransitivelyReferencedVariablesOfType( Function::VariableBindings Function::TransitivelyReferencedVariablesOfType(
const tint::TypeInfo* type) const { const tint::TypeInfo* type) const {
VariableBindings ret; VariableBindings ret;
for (auto* var : TransitivelyReferencedGlobals()) { for (auto* global : TransitivelyReferencedGlobals()) {
auto* unwrapped_type = var->Type()->UnwrapRef(); auto* unwrapped_type = global->Type()->UnwrapRef();
if (unwrapped_type->TypeInfo().Is(type)) { if (unwrapped_type->TypeInfo().Is(type)) {
if (auto binding_point = var->Declaration()->BindingPoint()) { if (auto binding_point = global->Declaration()->BindingPoint()) {
ret.push_back({var, binding_point}); ret.push_back({global, binding_point});
} }
} }
} }
@ -143,15 +143,15 @@ Function::VariableBindings Function::TransitivelyReferencedSamplerVariablesImpl(
ast::SamplerKind kind) const { ast::SamplerKind kind) const {
VariableBindings ret; VariableBindings ret;
for (auto* var : TransitivelyReferencedGlobals()) { for (auto* global : TransitivelyReferencedGlobals()) {
auto* unwrapped_type = var->Type()->UnwrapRef(); auto* unwrapped_type = global->Type()->UnwrapRef();
auto* sampler = unwrapped_type->As<sem::Sampler>(); auto* sampler = unwrapped_type->As<sem::Sampler>();
if (sampler == nullptr || sampler->kind() != kind) { if (sampler == nullptr || sampler->kind() != kind) {
continue; continue;
} }
if (auto binding_point = var->Declaration()->BindingPoint()) { if (auto binding_point = global->Declaration()->BindingPoint()) {
ret.push_back({var, binding_point}); ret.push_back({global, binding_point});
} }
} }
return ret; return ret;
@ -161,8 +161,8 @@ Function::VariableBindings Function::TransitivelyReferencedSampledTextureVariabl
bool multisampled) const { bool multisampled) const {
VariableBindings ret; VariableBindings ret;
for (auto* var : TransitivelyReferencedGlobals()) { for (auto* global : TransitivelyReferencedGlobals()) {
auto* unwrapped_type = var->Type()->UnwrapRef(); auto* unwrapped_type = global->Type()->UnwrapRef();
auto* texture = unwrapped_type->As<sem::Texture>(); auto* texture = unwrapped_type->As<sem::Texture>();
if (texture == nullptr) { if (texture == nullptr) {
continue; continue;
@ -175,8 +175,8 @@ Function::VariableBindings Function::TransitivelyReferencedSampledTextureVariabl
continue; continue;
} }
if (auto binding_point = var->Declaration()->BindingPoint()) { if (auto binding_point = global->Declaration()->BindingPoint()) {
ret.push_back({var, binding_point}); ret.push_back({global, binding_point});
} }
} }

View File

@ -27,6 +27,7 @@ class Function;
class IfStatement; class IfStatement;
class MemberAccessorExpression; class MemberAccessorExpression;
class Node; class Node;
class Override;
class Statement; class Statement;
class Struct; class Struct;
class StructMember; class StructMember;
@ -45,6 +46,7 @@ class Function;
class IfStatement; class IfStatement;
class MemberAccessorExpression; class MemberAccessorExpression;
class Node; class Node;
class GlobalVariable;
class Statement; class Statement;
class Struct; class Struct;
class StructMember; class StructMember;
@ -69,6 +71,7 @@ struct TypeMappings {
IfStatement* operator()(ast::IfStatement*); IfStatement* operator()(ast::IfStatement*);
MemberAccessorExpression* operator()(ast::MemberAccessorExpression*); MemberAccessorExpression* operator()(ast::MemberAccessorExpression*);
Node* operator()(ast::Node*); Node* operator()(ast::Node*);
GlobalVariable* operator()(ast::Override*);
Statement* operator()(ast::Statement*); Statement* operator()(ast::Statement*);
Struct* operator()(ast::Struct*); Struct* operator()(ast::Struct*);
StructMember* operator()(ast::StructMember*); StructMember* operator()(ast::StructMember*);

View File

@ -62,7 +62,7 @@ GlobalVariable::GlobalVariable(const ast::Variable* declaration,
GlobalVariable::~GlobalVariable() = default; GlobalVariable::~GlobalVariable() = default;
Parameter::Parameter(const ast::Variable* declaration, Parameter::Parameter(const ast::Parameter* declaration,
uint32_t index, uint32_t index,
const sem::Type* type, const sem::Type* type,
ast::StorageClass storage_class, ast::StorageClass storage_class,

View File

@ -154,24 +154,14 @@ class GlobalVariable final : public Castable<GlobalVariable, Variable> {
sem::BindingPoint BindingPoint() const { return binding_point_; } sem::BindingPoint BindingPoint() const { return binding_point_; }
/// @param id the constant identifier to assign to this variable /// @param id the constant identifier to assign to this variable
void SetConstantId(uint16_t id) { void SetConstantId(uint16_t id) { constant_id_ = id; }
constant_id_ = id;
is_overridable_ = true;
}
/// @returns the pipeline constant ID associated with the variable /// @returns the pipeline constant ID associated with the variable
uint16_t ConstantId() const { return constant_id_; } uint16_t ConstantId() const { return constant_id_; }
/// @param is_overridable true if this is a pipeline overridable constant
void SetIsOverridable(bool is_overridable = true) { is_overridable_ = is_overridable; }
/// @returns true if this is pipeline overridable constant
bool IsOverridable() const { return is_overridable_; }
private: private:
const sem::BindingPoint binding_point_; const sem::BindingPoint binding_point_;
bool is_overridable_ = false;
uint16_t constant_id_ = 0; uint16_t constant_id_ = 0;
}; };
@ -185,7 +175,7 @@ class Parameter final : public Castable<Parameter, Variable> {
/// @param storage_class the variable storage class /// @param storage_class the variable storage class
/// @param access the variable access control type /// @param access the variable access control type
/// @param usage the semantic usage for the parameter /// @param usage the semantic usage for the parameter
Parameter(const ast::Variable* declaration, Parameter(const ast::Parameter* declaration,
uint32_t index, uint32_t index,
const sem::Type* type, const sem::Type* type,
ast::StorageClass storage_class, ast::StorageClass storage_class,

View File

@ -54,8 +54,8 @@ void AddSpirvBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) co
// contains it in the destination program. // contains it in the destination program.
std::unordered_map<const sem::Type*, const ast::Struct*> wrapper_structs; std::unordered_map<const sem::Type*, const ast::Struct*> wrapper_structs;
// Process global variables that are buffers. // Process global 'var' declarations that are buffers.
for (auto* var : ctx.src->AST().GlobalVariables()) { for (auto* var : ctx.src->AST().Globals<ast::Var>()) {
auto* sem_var = sem.Get<sem::GlobalVariable>(var); auto* sem_var = sem.Get<sem::GlobalVariable>(var);
if (var->declared_storage_class != ast::StorageClass::kStorage && if (var->declared_storage_class != ast::StorageClass::kStorage &&
var->declared_storage_class != ast::StorageClass::kUniform) { var->declared_storage_class != ast::StorageClass::kUniform) {

View File

@ -67,8 +67,8 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
} }
auto* func = ctx.src->Sem().Get(func_ast); auto* func = ctx.src->Sem().Get(func_ast);
std::unordered_map<sem::BindingPoint, int> binding_point_counts; std::unordered_map<sem::BindingPoint, int> binding_point_counts;
for (auto* var : func->TransitivelyReferencedGlobals()) { for (auto* global : func->TransitivelyReferencedGlobals()) {
if (auto binding_point = var->Declaration()->BindingPoint()) { if (auto binding_point = global->Declaration()->BindingPoint()) {
BindingPoint from{binding_point.group->value, binding_point.binding->value}; BindingPoint from{binding_point.group->value, binding_point.binding->value};
auto bp_it = remappings->binding_points.find(from); auto bp_it = remappings->binding_points.find(from);
if (bp_it != remappings->binding_points.end()) { if (bp_it != remappings->binding_points.end()) {
@ -88,7 +88,7 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
} }
} }
for (auto* var : ctx.src->AST().GlobalVariables()) { for (auto* var : ctx.src->AST().Globals<ast::Var>()) {
if (auto binding_point = var->BindingPoint()) { if (auto binding_point = var->BindingPoint()) {
// The original binding point // The original binding point
BindingPoint from{binding_point.group->value, binding_point.binding->value}; BindingPoint from{binding_point.group->value, binding_point.binding->value};
@ -130,10 +130,10 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
} }
auto* ty = sem->Type()->UnwrapRef(); auto* ty = sem->Type()->UnwrapRef();
const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty); const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
auto* new_var = ctx.dst->create<ast::Variable>( auto* new_var =
ctx.Clone(var->source), ctx.Clone(var->symbol), var->declared_storage_class, ac, ctx.dst->Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty,
inner_ty, false, false, ctx.Clone(var->constructor), var->declared_storage_class, ac, ctx.Clone(var->constructor),
ctx.Clone(var->attributes)); ctx.Clone(var->attributes));
ctx.Replace(var, new_var); ctx.Replace(var, new_var);
} }

View File

@ -147,16 +147,16 @@ struct CombineSamplers::State {
// Remove all texture and sampler global variables. These will be replaced // Remove all texture and sampler global variables. These will be replaced
// by combined samplers. // by combined samplers.
for (auto* var : ctx.src->AST().GlobalVariables()) { for (auto* global : ctx.src->AST().GlobalVariables()) {
auto* type = sem.Get(var->type); auto* type = sem.Get(global->type);
if (type && type->IsAnyOf<sem::Texture, sem::Sampler>() && if (tint::IsAnyOf<sem::Texture, sem::Sampler>(type) &&
!type->Is<sem::StorageTexture>()) { !type->Is<sem::StorageTexture>()) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var); ctx.Remove(ctx.src->AST().GlobalDeclarations(), global);
} else if (auto binding_point = var->BindingPoint()) { } else if (auto binding_point = global->BindingPoint()) {
if (binding_point.group->value == 0 && binding_point.binding->value == 0) { if (binding_point.group->value == 0 && binding_point.binding->value == 0) {
auto* attribute = auto* attribute =
ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision); ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
ctx.InsertFront(var->attributes, attribute); ctx.InsertFront(global->attributes, attribute);
} }
} }
} }
@ -188,9 +188,8 @@ struct CombineSamplers::State {
} else { } else {
// Either texture or sampler (or both) is a function parameter; // Either texture or sampler (or both) is a function parameter;
// add a new function parameter to represent the combined sampler. // add a new function parameter to represent the combined sampler.
const ast::Type* type = CreateCombinedASTTypeFor(texture_var, sampler_var); auto* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
const ast::Variable* var = auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type);
ctx.dst->Param(ctx.dst->Symbols().New(name), type);
params.push_back(var); params.push_back(var);
function_combined_texture_samplers_[func][pair] = var; function_combined_texture_samplers_[func][pair] = var;
} }

View File

@ -31,11 +31,11 @@ const ast::VariableDeclStatement* AsTrivialLetDecl(const ast::Statement* stmt) {
if (!var_decl) { if (!var_decl) {
return nullptr; return nullptr;
} }
auto* var = var_decl->variable; auto* let = var_decl->variable->As<ast::Let>();
if (!var->is_const) { if (!let) {
return nullptr; return nullptr;
} }
auto* ctor = var->constructor; auto* ctor = let->constructor;
if (!IsAnyOf<ast::IdentifierExpression, ast::LiteralExpression>(ctor)) { if (!IsAnyOf<ast::IdentifierExpression, ast::LiteralExpression>(ctor)) {
return nullptr; return nullptr;
} }

View File

@ -155,9 +155,13 @@ struct ModuleScopeVarToEntryPointParam::State {
return workgroup_parameter_symbol; return workgroup_parameter_symbol;
}; };
for (auto* var : func_sem->TransitivelyReferencedGlobals()) { for (auto* global : func_sem->TransitivelyReferencedGlobals()) {
auto sc = var->StorageClass(); auto* var = global->Declaration()->As<ast::Var>();
auto* ty = var->Type()->UnwrapRef(); if (!var) {
continue;
}
auto sc = global->StorageClass();
auto* ty = global->Type()->UnwrapRef();
if (sc == ast::StorageClass::kNone) { if (sc == ast::StorageClass::kNone) {
continue; continue;
} }
@ -182,12 +186,12 @@ struct ModuleScopeVarToEntryPointParam::State {
bool is_wrapped = false; bool is_wrapped = false;
if (is_entry_point) { if (is_entry_point) {
if (var->Type()->UnwrapRef()->is_handle()) { if (global->Type()->UnwrapRef()->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point // For a texture or sampler variable, redeclare it as an entry point
// parameter. Disable entry point parameter validation. // parameter. Disable entry point parameter validation.
auto* disable_validation = auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter); ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
auto attrs = ctx.Clone(var->Declaration()->attributes); auto attrs = ctx.Clone(var->attributes);
attrs.push_back(disable_validation); attrs.push_back(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs); auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
ctx.InsertFront(func_ast->params, param); ctx.InsertFront(func_ast->params, param);
@ -195,7 +199,7 @@ struct ModuleScopeVarToEntryPointParam::State {
sc == ast::StorageClass::kUniform) { sc == ast::StorageClass::kUniform) {
// Variables into the Storage and Uniform storage classes are // Variables into the Storage and Uniform storage classes are
// redeclared as entry point parameters with a pointer type. // redeclared as entry point parameters with a pointer type.
auto attributes = ctx.Clone(var->Declaration()->attributes); auto attributes = ctx.Clone(var->attributes);
attributes.push_back( attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter)); ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter));
attributes.push_back( attributes.push_back(
@ -214,22 +218,22 @@ struct ModuleScopeVarToEntryPointParam::State {
is_wrapped = true; is_wrapped = true;
} }
param_type = ctx.dst->ty.pointer(param_type, sc, param_type = ctx.dst->ty.pointer(param_type, sc, var->declared_access);
var->Declaration()->declared_access);
auto* param = ctx.dst->Param(new_var_symbol, param_type, attributes); auto* param = ctx.dst->Param(new_var_symbol, param_type, attributes);
ctx.InsertFront(func_ast->params, param); ctx.InsertFront(func_ast->params, param);
is_pointer = true; is_pointer = true;
} else if (sc == ast::StorageClass::kWorkgroup && ContainsMatrix(var->Type())) { } else if (sc == ast::StorageClass::kWorkgroup &&
ContainsMatrix(global->Type())) {
// Due to a bug in the MSL compiler, we use a threadgroup memory // Due to a bug in the MSL compiler, we use a threadgroup memory
// argument for any workgroup allocation that contains a matrix. // argument for any workgroup allocation that contains a matrix.
// See crbug.com/tint/938. // See crbug.com/tint/938.
// TODO(jrprice): Do this for all other workgroup variables too. // TODO(jrprice): Do this for all other workgroup variables too.
// Create a member in the workgroup parameter struct. // Create a member in the workgroup parameter struct.
auto member = ctx.Clone(var->Declaration()->symbol); auto member = ctx.Clone(var->symbol);
workgroup_parameter_members.push_back( workgroup_parameter_members.push_back(
ctx.dst->Member(member, store_type())); ctx.dst->Member(member, store_type()));
CloneStructTypes(var->Type()->UnwrapRef()); CloneStructTypes(global->Type()->UnwrapRef());
// Create a function-scope variable that is a pointer to the member. // Create a function-scope variable that is a pointer to the member.
auto* member_ptr = ctx.dst->AddressOf( auto* member_ptr = ctx.dst->AddressOf(
@ -246,7 +250,7 @@ struct ModuleScopeVarToEntryPointParam::State {
// this variable. // this variable.
auto* disable_validation = auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass); ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor); auto* constructor = ctx.Clone(var->constructor);
auto* local_var = auto* local_var =
ctx.dst->Var(new_var_symbol, store_type(), sc, constructor, ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
ast::AttributeList{disable_validation}); ast::AttributeList{disable_validation});
@ -257,9 +261,8 @@ struct ModuleScopeVarToEntryPointParam::State {
// Use a pointer for non-handle types. // Use a pointer for non-handle types.
auto* param_type = store_type(); auto* param_type = store_type();
ast::AttributeList attributes; ast::AttributeList attributes;
if (!var->Type()->UnwrapRef()->is_handle()) { if (!global->Type()->UnwrapRef()->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, sc, param_type = ctx.dst->ty.pointer(param_type, sc, var->declared_access);
var->Declaration()->declared_access);
is_pointer = true; is_pointer = true;
// Disable validation of the parameter's storage class and of // Disable validation of the parameter's storage class and of
@ -275,7 +278,7 @@ struct ModuleScopeVarToEntryPointParam::State {
// Replace all uses of the module-scope variable. // Replace all uses of the module-scope variable.
// For non-entry points, dereference non-handle pointer parameters. // For non-entry points, dereference non-handle pointer parameters.
for (auto* user : var->Users()) { for (auto* user : global->Users()) {
if (user->Stmt()->Function()->Declaration() == func_ast) { if (user->Stmt()->Function()->Declaration() == func_ast) {
const ast::Expression* expr = ctx.dst->Expr(new_var_symbol); const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (is_pointer) { if (is_pointer) {
@ -298,7 +301,7 @@ struct ModuleScopeVarToEntryPointParam::State {
} }
} }
var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped}; var_to_newvar[global] = {new_var_symbol, is_pointer, is_wrapped};
} }
if (!workgroup_parameter_members.empty()) { if (!workgroup_parameter_members.empty()) {

View File

@ -86,8 +86,8 @@ struct MultiplanarExternalTexture::State {
// binding and create two additional bindings (one texture_2d<f32> to // binding and create two additional bindings (one texture_2d<f32> to
// represent the secondary plane and one uniform buffer for the // represent the secondary plane and one uniform buffer for the
// ExternalTextureParams struct). // ExternalTextureParams struct).
for (auto* var : ctx.src->AST().GlobalVariables()) { for (auto* global : ctx.src->AST().GlobalVariables()) {
auto* sem_var = sem.Get(var); auto* sem_var = sem.Get(global);
if (!sem_var->Type()->UnwrapRef()->Is<sem::ExternalTexture>()) { if (!sem_var->Type()->UnwrapRef()->Is<sem::ExternalTexture>()) {
continue; continue;
} }
@ -95,7 +95,7 @@ struct MultiplanarExternalTexture::State {
// If the attributes are empty, then this must be a texture_external // If the attributes are empty, then this must be a texture_external
// passed as a function parameter. These variables are transformed // passed as a function parameter. These variables are transformed
// elsewhere. // elsewhere.
if (var->attributes.empty()) { if (global->attributes.empty()) {
continue; continue;
} }
@ -109,8 +109,8 @@ struct MultiplanarExternalTexture::State {
// provided to this transform. We fetch the new binding points by // provided to this transform. We fetch the new binding points by
// providing the original texture_external binding points into the // providing the original texture_external binding points into the
// passed map. // passed map.
BindingPoint bp = {var->BindingPoint().group->value, BindingPoint bp = {global->BindingPoint().group->value,
var->BindingPoint().binding->value}; global->BindingPoint().binding->value};
BindingsMap::const_iterator it = new_binding_points->bindings_map.find(bp); BindingsMap::const_iterator it = new_binding_points->bindings_map.find(bp);
if (it == new_binding_points->bindings_map.end()) { if (it == new_binding_points->bindings_map.end()) {
@ -129,7 +129,7 @@ struct MultiplanarExternalTexture::State {
// corresponds with the new destination bindings. // corresponds with the new destination bindings.
// NewBindingSymbols new_binding_syms; // NewBindingSymbols new_binding_syms;
auto& syms = new_binding_symbols[sem_var]; auto& syms = new_binding_symbols[sem_var];
syms.plane_0 = ctx.Clone(var->symbol); syms.plane_0 = ctx.Clone(global->symbol);
syms.plane_1 = b.Symbols().New("ext_tex_plane_1"); syms.plane_1 = b.Symbols().New("ext_tex_plane_1");
b.Global(syms.plane_1, b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()), b.Global(syms.plane_1, b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()),
b.GroupAndBinding(bps.plane_1.group, bps.plane_1.binding)); b.GroupAndBinding(bps.plane_1.group, bps.plane_1.binding));
@ -140,13 +140,13 @@ struct MultiplanarExternalTexture::State {
// Replace the original texture_external binding with a texture_2d<f32> // Replace the original texture_external binding with a texture_2d<f32>
// binding. // binding.
ast::AttributeList cloned_attributes = ctx.Clone(var->attributes); ast::AttributeList cloned_attributes = ctx.Clone(global->attributes);
const ast::Expression* cloned_constructor = ctx.Clone(var->constructor); const ast::Expression* cloned_constructor = ctx.Clone(global->constructor);
auto* replacement = auto* replacement =
b.Var(syms.plane_0, b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()), b.Var(syms.plane_0, b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()),
cloned_constructor, cloned_attributes); cloned_constructor, cloned_attributes);
ctx.Replace(var, replacement); ctx.Replace(global, replacement);
} }
// We must update all the texture_external parameters for user declared // We must update all the texture_external parameters for user declared

View File

@ -133,8 +133,8 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat
// plus 1, or group 0 if no resource bound. // plus 1, or group 0 if no resource bound.
group = 0; group = 0;
for (auto* var : ctx.src->AST().GlobalVariables()) { for (auto* global : ctx.src->AST().GlobalVariables()) {
if (auto binding_point = var->BindingPoint()) { if (auto binding_point = global->BindingPoint()) {
if (binding_point.group->value >= group) { if (binding_point.group->value >= group) {
group = binding_point.group->value + 1; group = binding_point.group->value + 1;
} }

View File

@ -109,8 +109,8 @@ struct SimplifyPointers::State {
} }
if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) { if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) {
auto* var = user->Variable(); auto* var = user->Variable();
if (var->Is<sem::LocalVariable>() && // if (var->Is<sem::LocalVariable>() && //
var->Declaration()->is_const && // var->Declaration()->Is<ast::Let>() && //
var->Type()->Is<sem::Pointer>()) { var->Type()->Is<sem::Pointer>()) {
op.expr = var->Declaration()->constructor; op.expr = var->Declaration()->constructor;
continue; continue;
@ -161,7 +161,7 @@ struct SimplifyPointers::State {
// permitted. // permitted.
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* let = node->As<ast::VariableDeclStatement>()) { if (auto* let = node->As<ast::VariableDeclStatement>()) {
if (!let->variable->is_const) { if (!let->variable->Is<ast::Let>()) {
continue; // Not a `let` declaration. Ignore. continue; // Not a `let` declaration. Ignore.
} }

View File

@ -64,38 +64,43 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c
referenced_vars.emplace(var->Declaration()); referenced_vars.emplace(var->Declaration());
} }
// Clone any module-scope variables, types, and functions that are statically // Clone any module-scope variables, types, and functions that are statically referenced by the
// referenced by the target entry point. // target entry point.
for (auto* decl : ctx.src->AST().GlobalDeclarations()) { for (auto* decl : ctx.src->AST().GlobalDeclarations()) {
if (auto* ty = decl->As<ast::TypeDecl>()) { Switch(
// TODO(jrprice): Strip unused types. decl, //
ctx.dst->AST().AddTypeDecl(ctx.Clone(ty)); [&](const ast::TypeDecl* ty) {
} else if (auto* var = decl->As<ast::Variable>()) { // TODO(jrprice): Strip unused types.
if (referenced_vars.count(var)) { ctx.dst->AST().AddTypeDecl(ctx.Clone(ty));
if (var->is_overridable) { },
// It is an overridable constant [&](const ast::Override* override) {
if (!ast::HasAttribute<ast::IdAttribute>(var->attributes)) { if (referenced_vars.count(override)) {
if (!ast::HasAttribute<ast::IdAttribute>(override->attributes)) {
// If the constant doesn't already have an @id() attribute, add one // If the constant doesn't already have an @id() attribute, add one
// so that its allocated ID so that it won't be affected by other // so that its allocated ID so that it won't be affected by other
// stripped away constants // stripped away constants
auto* global = sem.Get(var)->As<sem::GlobalVariable>(); auto* global = sem.Get(override);
const auto* id = ctx.dst->Id(global->ConstantId()); const auto* id = ctx.dst->Id(global->ConstantId());
ctx.InsertFront(var->attributes, id); ctx.InsertFront(override->attributes, id);
} }
ctx.dst->AST().AddGlobalVariable(ctx.Clone(override));
} }
ctx.dst->AST().AddGlobalVariable(ctx.Clone(var)); },
} [&](const ast::Variable* v) { // var, let
} else if (auto* func = decl->As<ast::Function>()) { if (referenced_vars.count(v)) {
if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) { ctx.dst->AST().AddGlobalVariable(ctx.Clone(v));
ctx.dst->AST().AddFunction(ctx.Clone(func)); }
} },
} else if (auto* ext = decl->As<ast::Enable>()) { [&](const ast::Function* func) {
ctx.dst->AST().AddEnable(ctx.Clone(ext)); if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) {
} else { ctx.dst->AST().AddFunction(ctx.Clone(func));
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) }
<< "unhandled global declaration: " << decl->TypeInfo().name; },
return; [&](const ast::Enable* ext) { ctx.dst->AST().AddEnable(ctx.Clone(ext)); },
} [&](Default) {
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
<< "unhandled global declaration: " << decl->TypeInfo().name;
});
} }
// Clone the entry point. // Clone the entry point.

View File

@ -44,28 +44,42 @@ struct Unshadow::State {
// Maps a variable to its new name. // Maps a variable to its new name.
std::unordered_map<const sem::Variable*, Symbol> renamed_to; std::unordered_map<const sem::Variable*, Symbol> renamed_to;
auto rename = [&](const sem::Variable* var) -> const ast::Variable* { auto rename = [&](const sem::Variable* v) -> const ast::Variable* {
auto* decl = var->Declaration(); auto* decl = v->Declaration();
auto name = ctx.src->Symbols().NameFor(decl->symbol); auto name = ctx.src->Symbols().NameFor(decl->symbol);
auto symbol = ctx.dst->Symbols().New(name); auto symbol = ctx.dst->Symbols().New(name);
renamed_to.emplace(var, symbol); renamed_to.emplace(v, symbol);
auto source = ctx.Clone(decl->source); auto source = ctx.Clone(decl->source);
auto* type = ctx.Clone(decl->type); auto* type = ctx.Clone(decl->type);
auto* constructor = ctx.Clone(decl->constructor); auto* constructor = ctx.Clone(decl->constructor);
auto attributes = ctx.Clone(decl->attributes); auto attributes = ctx.Clone(decl->attributes);
return ctx.dst->create<ast::Variable>(source, symbol, decl->declared_storage_class, return Switch(
decl->declared_access, type, decl->is_const, decl, //
decl->is_overridable, constructor, attributes); [&](const ast::Var* var) {
return ctx.dst->Var(source, symbol, type, var->declared_storage_class,
var->declared_access, constructor, attributes);
},
[&](const ast::Let*) {
return ctx.dst->Let(source, symbol, type, constructor, attributes);
},
[&](const ast::Parameter*) {
return ctx.dst->Param(source, symbol, type, attributes);
},
[&](Default) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected variable type: " << decl->TypeInfo().name;
return nullptr;
});
}; };
ctx.ReplaceAll([&](const ast::Variable* var) -> const ast::Variable* { ctx.ReplaceAll([&](const ast::Variable* v) -> const ast::Variable* {
if (auto* local = sem.Get<sem::LocalVariable>(var)) { if (auto* local = sem.Get<sem::LocalVariable>(v)) {
if (local->Shadows()) { if (local->Shadows()) {
return rename(local); return rename(local);
} }
} }
if (auto* param = sem.Get<sem::Parameter>(var)) { if (auto* param = sem.Get<sem::Parameter>(v)) {
if (param->Shadows()) { if (param->Shadows()) {
return rename(param); return rename(param);
} }

View File

@ -189,18 +189,18 @@ class HoistToDeclBefore::State {
/// before `before_expr`. /// before `before_expr`.
/// @param before_expr expression to insert `expr` before /// @param before_expr expression to insert `expr` before
/// @param expr expression to hoist /// @param expr expression to hoist
/// @param as_const hoist to `let` if true, otherwise to `var` /// @param as_let hoist to `let` if true, otherwise to `var`
/// @param decl_name optional name to use for the variable/constant name /// @param decl_name optional name to use for the variable/constant name
/// @return true on success /// @return true on success
bool Add(const sem::Expression* before_expr, bool Add(const sem::Expression* before_expr,
const ast::Expression* expr, const ast::Expression* expr,
bool as_const, bool as_let,
const char* decl_name) { const char* decl_name) {
auto name = b.Symbols().New(decl_name); auto name = b.Symbols().New(decl_name);
// Construct the let/var that holds the hoisted expr // Construct the let/var that holds the hoisted expr
auto* v = as_const ? b.Let(name, nullptr, ctx.Clone(expr)) auto* v = as_let ? static_cast<const ast::Variable*>(b.Let(name, nullptr, ctx.Clone(expr)))
: b.Var(name, nullptr, ctx.Clone(expr)); : static_cast<const ast::Variable*>(b.Var(name, nullptr, ctx.Clone(expr)));
auto* decl = b.Decl(v); auto* decl = b.Decl(v);
if (!InsertBefore(before_expr->Stmt(), decl)) { if (!InsertBefore(before_expr->Stmt(), decl)) {
@ -330,9 +330,9 @@ HoistToDeclBefore::~HoistToDeclBefore() {}
bool HoistToDeclBefore::Add(const sem::Expression* before_expr, bool HoistToDeclBefore::Add(const sem::Expression* before_expr,
const ast::Expression* expr, const ast::Expression* expr,
bool as_const, bool as_let,
const char* decl_name) { const char* decl_name) {
return state_->Add(before_expr, expr, as_const, decl_name); return state_->Add(before_expr, expr, as_let, decl_name);
} }
bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt, bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt,

View File

@ -695,7 +695,7 @@ struct State {
/// vertex_index and instance_index builtins if present. /// vertex_index and instance_index builtins if present.
/// @param func the entry point function /// @param func the entry point function
/// @param param the parameter to process /// @param param the parameter to process
void ProcessNonStructParameter(const ast::Function* func, const ast::Variable* param) { void ProcessNonStructParameter(const ast::Function* func, const ast::Parameter* param) {
if (auto* location = ast::GetAttribute<ast::LocationAttribute>(param->attributes)) { if (auto* location = ast::GetAttribute<ast::LocationAttribute>(param->attributes)) {
// Create a function-scope variable to replace the parameter. // Create a function-scope variable to replace the parameter.
auto func_var_sym = ctx.Clone(param->symbol); auto func_var_sym = ctx.Clone(param->symbol);
@ -733,7 +733,7 @@ struct State {
/// @param param the parameter to process /// @param param the parameter to process
/// @param struct_ty the structure type /// @param struct_ty the structure type
void ProcessStructParameter(const ast::Function* func, void ProcessStructParameter(const ast::Function* func,
const ast::Variable* param, const ast::Parameter* param,
const ast::Struct* struct_ty) { const ast::Struct* struct_ty) {
auto param_sym = ctx.Clone(param->symbol); auto param_sym = ctx.Clone(param->symbol);

View File

@ -416,8 +416,8 @@ ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default; ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program, const DataMap&) const { bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program, const DataMap&) const {
for (auto* decl : program->AST().GlobalDeclarations()) { for (auto* global : program->AST().GlobalVariables()) {
if (auto* var = decl->As<ast::Variable>()) { if (auto* var = global->As<ast::Var>()) {
if (var->declared_storage_class == ast::StorageClass::kWorkgroup) { if (var->declared_storage_class == ast::StorageClass::kWorkgroup) {
return true; return true;
} }

View File

@ -1904,42 +1904,47 @@ bool GeneratorImpl::EmitFunction(const ast::Function* func) {
} }
bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) { bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) {
if (global->is_const) { return Switch(
return EmitProgramConstVariable(global); global, //
} [&](const ast::Var* var) {
auto* sem = builder_.Sem().Get(global);
auto* sem = builder_.Sem().Get(global); switch (sem->StorageClass()) {
switch (sem->StorageClass()) { case ast::StorageClass::kUniform:
case ast::StorageClass::kUniform: return EmitUniformVariable(var, sem);
return EmitUniformVariable(sem); case ast::StorageClass::kStorage:
case ast::StorageClass::kStorage: return EmitStorageVariable(var, sem);
return EmitStorageVariable(sem); case ast::StorageClass::kHandle:
case ast::StorageClass::kHandle: return EmitHandleVariable(var, sem);
return EmitHandleVariable(sem); case ast::StorageClass::kPrivate:
case ast::StorageClass::kPrivate: return EmitPrivateVariable(sem);
return EmitPrivateVariable(sem); case ast::StorageClass::kWorkgroup:
case ast::StorageClass::kWorkgroup: return EmitWorkgroupVariable(sem);
return EmitWorkgroupVariable(sem); case ast::StorageClass::kInput:
case ast::StorageClass::kInput: case ast::StorageClass::kOutput:
case ast::StorageClass::kOutput: return EmitIOVariable(sem);
return EmitIOVariable(sem); default:
default: TINT_ICE(Writer, diagnostics_)
break; << "unhandled storage class " << sem->StorageClass();
} return false;
}
TINT_ICE(Writer, diagnostics_) << "unhandled storage class " << sem->StorageClass(); },
return false; [&](const ast::Let* let) { return EmitProgramConstVariable(let); },
[&](const ast::Override* override) { return EmitOverride(override); },
[&](Default) {
TINT_ICE(Writer, diagnostics_)
<< "unhandled global variable type " << global->TypeInfo().name;
return false;
});
} }
bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) { bool GeneratorImpl::EmitUniformVariable(const ast::Var* var, const sem::Variable* sem) {
auto* decl = var->Declaration(); auto* type = sem->Type()->UnwrapRef();
auto* type = var->Type()->UnwrapRef();
auto* str = type->As<sem::Struct>(); auto* str = type->As<sem::Struct>();
if (!str) { if (!str) {
TINT_ICE(Writer, builder_.Diagnostics()) << "storage variable must be of struct type"; TINT_ICE(Writer, builder_.Diagnostics()) << "storage variable must be of struct type";
return false; return false;
} }
ast::VariableBindingPoint bp = decl->BindingPoint(); ast::VariableBindingPoint bp = var->BindingPoint();
{ {
auto out = line(); auto out = line();
out << "layout(binding = " << bp.binding->value; out << "layout(binding = " << bp.binding->value;
@ -1949,36 +1954,34 @@ bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) {
out << ") uniform " << UniqueIdentifier(StructName(str)) << " {"; out << ") uniform " << UniqueIdentifier(StructName(str)) << " {";
} }
EmitStructMembers(current_buffer_, str, /* emit_offsets */ true); EmitStructMembers(current_buffer_, str, /* emit_offsets */ true);
auto name = builder_.Symbols().NameFor(decl->symbol); auto name = builder_.Symbols().NameFor(var->symbol);
line() << "} " << name << ";"; line() << "} " << name << ";";
line(); line();
return true; return true;
} }
bool GeneratorImpl::EmitStorageVariable(const sem::Variable* var) { bool GeneratorImpl::EmitStorageVariable(const ast::Var* var, const sem::Variable* sem) {
auto* decl = var->Declaration(); auto* type = sem->Type()->UnwrapRef();
auto* type = var->Type()->UnwrapRef();
auto* str = type->As<sem::Struct>(); auto* str = type->As<sem::Struct>();
if (!str) { if (!str) {
TINT_ICE(Writer, builder_.Diagnostics()) << "storage variable must be of struct type"; TINT_ICE(Writer, builder_.Diagnostics()) << "storage variable must be of struct type";
return false; return false;
} }
ast::VariableBindingPoint bp = decl->BindingPoint(); ast::VariableBindingPoint bp = var->BindingPoint();
line() << "layout(binding = " << bp.binding->value << ", std430) buffer " line() << "layout(binding = " << bp.binding->value << ", std430) buffer "
<< UniqueIdentifier(StructName(str)) << " {"; << UniqueIdentifier(StructName(str)) << " {";
EmitStructMembers(current_buffer_, str, /* emit_offsets */ true); EmitStructMembers(current_buffer_, str, /* emit_offsets */ true);
auto name = builder_.Symbols().NameFor(decl->symbol); auto name = builder_.Symbols().NameFor(var->symbol);
line() << "} " << name << ";"; line() << "} " << name << ";";
return true; return true;
} }
bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) { bool GeneratorImpl::EmitHandleVariable(const ast::Var* var, const sem::Variable* sem) {
auto* decl = var->Declaration();
auto out = line(); auto out = line();
auto name = builder_.Symbols().NameFor(decl->symbol); auto name = builder_.Symbols().NameFor(var->symbol);
auto* type = var->Type()->UnwrapRef(); auto* type = sem->Type()->UnwrapRef();
if (type->Is<sem::Sampler>()) { if (type->Is<sem::Sampler>()) {
// GLSL ignores Sampler variables. // GLSL ignores Sampler variables.
return true; return true;
@ -1986,7 +1989,7 @@ bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) {
if (auto* storage = type->As<sem::StorageTexture>()) { if (auto* storage = type->As<sem::StorageTexture>()) {
out << "layout(" << convert_texel_format_to_glsl(storage->texel_format()) << ") "; out << "layout(" << convert_texel_format_to_glsl(storage->texel_format()) << ") ";
} }
if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) { if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), name)) {
return false; return false;
} }
@ -2138,7 +2141,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
if (wgsize[i].overridable_const) { if (wgsize[i].overridable_const) {
auto* global = builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const); auto* global = builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const);
if (!global->IsOverridable()) { if (!global->Declaration()->Is<ast::Override>()) {
TINT_ICE(Writer, builder_.Diagnostics()) TINT_ICE(Writer, builder_.Diagnostics())
<< "expected a pipeline-overridable constant"; << "expected a pipeline-overridable constant";
} }
@ -2652,7 +2655,15 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
return EmitSwitch(s); return EmitSwitch(s);
} }
if (auto* v = stmt->As<ast::VariableDeclStatement>()) { if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return EmitVariable(v->variable); return Switch(
v->variable, //
[&](const ast::Var* var) { return EmitVar(var); },
[&](const ast::Let* let) { return EmitLet(let); },
[&](Default) { //
TINT_ICE(Writer, diagnostics_)
<< "unknown variable type: " << v->variable->TypeInfo().name;
return false;
});
} }
diagnostics_.add_error(diag::System::Writer, diagnostics_.add_error(diag::System::Writer,
@ -2934,18 +2945,11 @@ bool GeneratorImpl::EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression*
return true; return true;
} }
bool GeneratorImpl::EmitVariable(const ast::Variable* var) { bool GeneratorImpl::EmitVar(const ast::Var* var) {
auto* sem = builder_.Sem().Get(var); auto* sem = builder_.Sem().Get(var);
auto* type = sem->Type()->UnwrapRef(); auto* type = sem->Type()->UnwrapRef();
// TODO(dsinclair): Handle variable attributes
if (!var->attributes.empty()) {
diagnostics_.add_error(diag::System::Writer, "Variable attributes are not handled yet");
return false;
}
auto out = line(); auto out = line();
// TODO(senorblanco): handle const
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol))) { builder_.Symbols().NameFor(var->symbol))) {
return false; return false;
@ -2967,58 +2971,74 @@ bool GeneratorImpl::EmitVariable(const ast::Variable* var) {
return true; return true;
} }
bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) { bool GeneratorImpl::EmitLet(const ast::Let* let) {
for (auto* d : var->attributes) { auto* sem = builder_.Sem().Get(let);
if (!d->Is<ast::IdAttribute>()) { auto* type = sem->Type()->UnwrapRef();
diagnostics_.add_error(diag::System::Writer, "Decorated const values not valid");
return false; auto out = line();
} // TODO(senorblanco): handle const
} if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
if (!var->is_const) { builder_.Symbols().NameFor(let->symbol))) {
diagnostics_.add_error(diag::System::Writer, "Expected a const value");
return false; return false;
} }
out << " = ";
if (!EmitExpression(out, let->constructor)) {
return false;
}
out << ";";
return true;
}
bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
auto* sem = builder_.Sem().Get(var); auto* sem = builder_.Sem().Get(var);
auto* type = sem->Type(); auto* type = sem->Type();
auto out = line();
out << "const ";
if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
builder_.Symbols().NameFor(var->symbol))) {
return false;
}
out << " = ";
if (!EmitExpression(out, var->constructor)) {
return false;
}
out << ";";
return true;
}
bool GeneratorImpl::EmitOverride(const ast::Override* override) {
auto* sem = builder_.Sem().Get(override);
auto* type = sem->Type();
auto* global = sem->As<sem::GlobalVariable>(); auto* global = sem->As<sem::GlobalVariable>();
if (global && global->IsOverridable()) { auto const_id = global->ConstantId();
auto const_id = global->ConstantId();
line() << "#ifndef " << kSpecConstantPrefix << const_id; line() << "#ifndef " << kSpecConstantPrefix << const_id;
if (var->constructor != nullptr) { if (override->constructor != nullptr) {
auto out = line(); auto out = line();
out << "#define " << kSpecConstantPrefix << const_id << " "; out << "#define " << kSpecConstantPrefix << const_id << " ";
if (!EmitExpression(out, var->constructor)) { if (!EmitExpression(out, override->constructor)) {
return false; return false;
}
} else {
line() << "#error spec constant required for constant id " << const_id;
}
line() << "#endif";
{
auto out = line();
out << "const ";
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol))) {
return false;
}
out << " = " << kSpecConstantPrefix << const_id << ";";
} }
} else { } else {
line() << "#error spec constant required for constant id " << const_id;
}
line() << "#endif";
{
auto out = line(); auto out = line();
out << "const "; out << "const ";
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
builder_.Symbols().NameFor(var->symbol))) { builder_.Symbols().NameFor(override->symbol))) {
return false; return false;
} }
out << " = "; out << " = " << kSpecConstantPrefix << const_id << ";";
if (!EmitExpression(out, var->constructor)) {
return false;
}
out << ";";
} }
return true; return true;

View File

@ -293,19 +293,22 @@ class GeneratorImpl : public TextGenerator {
bool EmitGlobalVariable(const ast::Variable* global); bool EmitGlobalVariable(const ast::Variable* global);
/// Handles emitting a global variable with the uniform storage class /// Handles emitting a global variable with the uniform storage class
/// @param var the global variable /// @param var the AST node for the 'var'
/// @param sem the semantic node for the 'var'
/// @returns true on success /// @returns true on success
bool EmitUniformVariable(const sem::Variable* var); bool EmitUniformVariable(const ast::Var* var, const sem::Variable* sem);
/// Handles emitting a global variable with the storage storage class /// Handles emitting a global variable with the storage storage class
/// @param var the global variable /// @param var the AST node for the 'var'
/// @param sem the semantic node for the 'var'
/// @returns true on success /// @returns true on success
bool EmitStorageVariable(const sem::Variable* var); bool EmitStorageVariable(const ast::Var* var, const sem::Variable* sem);
/// Handles emitting a global variable with the handle storage class /// Handles emitting a global variable with the handle storage class
/// @param var the global variable /// @param var the AST node for the 'var'
/// @param sem the semantic node for the 'var'
/// @returns true on success /// @returns true on success
bool EmitHandleVariable(const sem::Variable* var); bool EmitHandleVariable(const ast::Var* var, const sem::Variable* sem);
/// Handles emitting a global variable with the private storage class /// Handles emitting a global variable with the private storage class
/// @param var the global variable /// @param var the global variable
@ -437,14 +440,22 @@ class GeneratorImpl : public TextGenerator {
/// @param type the type to emit the value for /// @param type the type to emit the value for
/// @returns true if the zero value was successfully emitted. /// @returns true if the zero value was successfully emitted.
bool EmitZeroValue(std::ostream& out, const sem::Type* type); bool EmitZeroValue(std::ostream& out, const sem::Type* type);
/// Handles generating a variable /// Handles generating a 'var' declaration
/// @param var the variable to generate /// @param var the variable to generate
/// @returns true if the variable was emitted /// @returns true if the variable was emitted
bool EmitVariable(const ast::Variable* var); bool EmitVar(const ast::Var* var);
/// Handles generating a program scope constant variable /// Handles generating a function-scope 'let' declaration
/// @param var the variable to emit /// @param let the variable to generate
/// @returns true if the variable was emitted /// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Variable* var); bool EmitLet(const ast::Let* let);
/// Handles generating a module-scope 'let' declaration
/// @param let the 'let' to emit
/// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Variable* let);
/// Handles generating a module-scope 'override' declaration
/// @param override the 'override' to emit
/// @returns true if the variable was emitted
bool EmitOverride(const ast::Override* override);
/// Handles generating a builtin method name /// Handles generating a builtin method name
/// @param builtin the semantic info for the builtin /// @param builtin the semantic info for the builtin
/// @returns the name or "" if not valid /// @returns the name or "" if not valid

View File

@ -40,7 +40,7 @@ TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant) {
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error(); ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23 EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23
#define WGSL_SPEC_CONSTANT_23 3.0f #define WGSL_SPEC_CONSTANT_23 3.0f
#endif #endif
@ -56,7 +56,7 @@ TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoConstructor) {
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error(); ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23 EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23
#error spec constant required for constant id 23 #error spec constant required for constant id 23
#endif #endif
@ -73,8 +73,8 @@ TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoId) {
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(a)) << gen.error(); ASSERT_TRUE(gen.EmitOverride(a)) << gen.error();
ASSERT_TRUE(gen.EmitProgramConstVariable(b)) << gen.error(); ASSERT_TRUE(gen.EmitOverride(b)) << gen.error();
EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_0 EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_0
#define WGSL_SPEC_CONSTANT_0 3.0f #define WGSL_SPEC_CONSTANT_0 3.0f
#endif #endif

View File

@ -365,7 +365,7 @@ bool GeneratorImpl::EmitDynamicVectorAssignment(const ast::AssignmentStatement*
out << "vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;"; out << "vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;";
break; break;
default: default:
TINT_UNREACHABLE(Writer, builder_.Diagnostics()) TINT_UNREACHABLE(Writer, diagnostics_)
<< "invalid vector size " << vec->Width(); << "invalid vector size " << vec->Width();
break; break;
} }
@ -524,7 +524,7 @@ bool GeneratorImpl::EmitDynamicMatrixScalarAssignment(const ast::AssignmentState
<< vec_name << ";"; << vec_name << ";";
break; break;
default: default:
TINT_UNREACHABLE(Writer, builder_.Diagnostics()) TINT_UNREACHABLE(Writer, diagnostics_)
<< "invalid vector size " << vec->Width(); << "invalid vector size " << vec->Width();
break; break;
} }
@ -2861,41 +2861,46 @@ bool GeneratorImpl::EmitFunctionBodyWithDiscard(const ast::Function* func) {
} }
bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) { bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) {
if (global->is_const) { return Switch(
return EmitProgramConstVariable(global); global, //
} [&](const ast::Var* var) {
auto* sem = builder_.Sem().Get(global);
auto* sem = builder_.Sem().Get(global); switch (sem->StorageClass()) {
switch (sem->StorageClass()) { case ast::StorageClass::kUniform:
case ast::StorageClass::kUniform: return EmitUniformVariable(var, sem);
return EmitUniformVariable(sem); case ast::StorageClass::kStorage:
case ast::StorageClass::kStorage: return EmitStorageVariable(var, sem);
return EmitStorageVariable(sem); case ast::StorageClass::kHandle:
case ast::StorageClass::kHandle: return EmitHandleVariable(var, sem);
return EmitHandleVariable(sem); case ast::StorageClass::kPrivate:
case ast::StorageClass::kPrivate: return EmitPrivateVariable(sem);
return EmitPrivateVariable(sem); case ast::StorageClass::kWorkgroup:
case ast::StorageClass::kWorkgroup: return EmitWorkgroupVariable(sem);
return EmitWorkgroupVariable(sem); default:
default: TINT_ICE(Writer, diagnostics_)
break; << "unhandled storage class " << sem->StorageClass();
} return false;
}
TINT_ICE(Writer, diagnostics_) << "unhandled storage class " << sem->StorageClass(); },
return false; [&](const ast::Let* let) { return EmitProgramConstVariable(let); },
[&](const ast::Override* override) { return EmitOverride(override); },
[&](Default) {
TINT_ICE(Writer, diagnostics_)
<< "unhandled global variable type " << global->TypeInfo().name;
return false;
});
} }
bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) { bool GeneratorImpl::EmitUniformVariable(const ast::Var* var, const sem::Variable* sem) {
auto* decl = var->Declaration(); auto binding_point = var->BindingPoint();
auto binding_point = decl->BindingPoint(); auto* type = sem->Type()->UnwrapRef();
auto* type = var->Type()->UnwrapRef(); auto name = builder_.Symbols().NameFor(var->symbol);
auto name = builder_.Symbols().NameFor(decl->symbol);
line() << "cbuffer cbuffer_" << name << RegisterAndSpace('b', binding_point) << " {"; line() << "cbuffer cbuffer_" << name << RegisterAndSpace('b', binding_point) << " {";
{ {
ScopedIndent si(this); ScopedIndent si(this);
auto out = line(); auto out = line();
if (!EmitTypeAndName(out, type, ast::StorageClass::kUniform, var->Access(), name)) { if (!EmitTypeAndName(out, type, ast::StorageClass::kUniform, sem->Access(), name)) {
return false; return false;
} }
out << ";"; out << ";";
@ -2906,29 +2911,27 @@ bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) {
return true; return true;
} }
bool GeneratorImpl::EmitStorageVariable(const sem::Variable* var) { bool GeneratorImpl::EmitStorageVariable(const ast::Var* var, const sem::Variable* sem) {
auto* decl = var->Declaration(); auto* type = sem->Type()->UnwrapRef();
auto* type = var->Type()->UnwrapRef();
auto out = line(); auto out = line();
if (!EmitTypeAndName(out, type, ast::StorageClass::kStorage, var->Access(), if (!EmitTypeAndName(out, type, ast::StorageClass::kStorage, sem->Access(),
builder_.Symbols().NameFor(decl->symbol))) { builder_.Symbols().NameFor(var->symbol))) {
return false; return false;
} }
out << RegisterAndSpace(var->Access() == ast::Access::kRead ? 't' : 'u', decl->BindingPoint()) out << RegisterAndSpace(sem->Access() == ast::Access::kRead ? 't' : 'u', var->BindingPoint())
<< ";"; << ";";
return true; return true;
} }
bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) { bool GeneratorImpl::EmitHandleVariable(const ast::Var* var, const sem::Variable* sem) {
auto* decl = var->Declaration(); auto* unwrapped_type = sem->Type()->UnwrapRef();
auto* unwrapped_type = var->Type()->UnwrapRef();
auto out = line(); auto out = line();
auto name = builder_.Symbols().NameFor(decl->symbol); auto name = builder_.Symbols().NameFor(var->symbol);
auto* type = var->Type()->UnwrapRef(); auto* type = sem->Type()->UnwrapRef();
if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) { if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), name)) {
return false; return false;
} }
@ -2944,7 +2947,7 @@ bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) {
} }
if (register_space) { if (register_space) {
auto bp = decl->BindingPoint(); auto bp = var->BindingPoint();
out << " : register(" << register_space << bp.binding->value << ", space" << bp.group->value out << " : register(" << register_space << bp.binding->value << ", space" << bp.group->value
<< ")"; << ")";
} }
@ -3078,8 +3081,8 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
if (wgsize[i].overridable_const) { if (wgsize[i].overridable_const) {
auto* global = auto* global =
builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const); builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const);
if (!global->IsOverridable()) { if (!global->Declaration()->Is<ast::Override>()) {
TINT_ICE(Writer, builder_.Diagnostics()) TINT_ICE(Writer, diagnostics_)
<< "expected a pipeline-overridable constant"; << "expected a pipeline-overridable constant";
} }
out << kSpecConstantPrefix << global->ConstantId(); out << kSpecConstantPrefix << global->ConstantId();
@ -3611,7 +3614,15 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
return EmitSwitch(s); return EmitSwitch(s);
}, },
[&](const ast::VariableDeclStatement* v) { // [&](const ast::VariableDeclStatement* v) { //
return EmitVariable(v->variable); return Switch(
v->variable, //
[&](const ast::Var* var) { return EmitVar(var); },
[&](const ast::Let* let) { return EmitLet(let); },
[&](Default) { //
TINT_ICE(Writer, diagnostics_)
<< "unknown variable type: " << v->variable->TypeInfo().name;
return false;
});
}, },
[&](Default) { // [&](Default) { //
diagnostics_.add_error(diag::System::Writer, diagnostics_.add_error(diag::System::Writer,
@ -4018,20 +4029,11 @@ bool GeneratorImpl::EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression*
return true; return true;
} }
bool GeneratorImpl::EmitVariable(const ast::Variable* var) { bool GeneratorImpl::EmitVar(const ast::Var* var) {
auto* sem = builder_.Sem().Get(var); auto* sem = builder_.Sem().Get(var);
auto* type = sem->Type()->UnwrapRef(); auto* type = sem->Type()->UnwrapRef();
// TODO(dsinclair): Handle variable attributes
if (!var->attributes.empty()) {
diagnostics_.add_error(diag::System::Writer, "Variable attributes are not handled yet");
return false;
}
auto out = line(); auto out = line();
if (var->is_const) {
out << "const ";
}
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol))) { builder_.Symbols().NameFor(var->symbol))) {
return false; return false;
@ -4053,60 +4055,71 @@ bool GeneratorImpl::EmitVariable(const ast::Variable* var) {
return true; return true;
} }
bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) { bool GeneratorImpl::EmitLet(const ast::Let* let) {
for (auto* d : var->attributes) { auto* sem = builder_.Sem().Get(let);
if (!d->Is<ast::IdAttribute>()) { auto* type = sem->Type()->UnwrapRef();
diagnostics_.add_error(diag::System::Writer, "Decorated const values not valid");
return false; auto out = line();
} out << "const ";
} if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
if (!var->is_const) { builder_.Symbols().NameFor(let->symbol))) {
diagnostics_.add_error(diag::System::Writer, "Expected a const value");
return false; return false;
} }
out << " = ";
if (!EmitExpression(out, let->constructor)) {
return false;
}
out << ";";
auto* sem = builder_.Sem().Get(var); return true;
}
bool GeneratorImpl::EmitProgramConstVariable(const ast::Let* let) {
auto* sem = builder_.Sem().Get(let);
auto* type = sem->Type(); auto* type = sem->Type();
auto* global = sem->As<sem::GlobalVariable>(); auto out = line();
if (global && global->IsOverridable()) { out << "static const ";
auto const_id = global->ConstantId(); if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
builder_.Symbols().NameFor(let->symbol))) {
return false;
}
out << " = ";
if (!EmitExpression(out, let->constructor)) {
return false;
}
out << ";";
line() << "#ifndef " << kSpecConstantPrefix << const_id; return true;
}
if (var->constructor != nullptr) { bool GeneratorImpl::EmitOverride(const ast::Override* override) {
auto out = line(); auto* sem = builder_.Sem().Get(override);
out << "#define " << kSpecConstantPrefix << const_id << " "; auto* type = sem->Type();
if (!EmitExpression(out, var->constructor)) {
return false; auto const_id = sem->ConstantId();
}
} else { line() << "#ifndef " << kSpecConstantPrefix << const_id;
line() << "#error spec constant required for constant id " << const_id;
} if (override->constructor != nullptr) {
line() << "#endif"; auto out = line();
{ out << "#define " << kSpecConstantPrefix << const_id << " ";
auto out = line(); if (!EmitExpression(out, override->constructor)) {
out << "static const "; return false;
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol))) {
return false;
}
out << " = " << kSpecConstantPrefix << const_id << ";";
} }
} else { } else {
line() << "#error spec constant required for constant id " << const_id;
}
line() << "#endif";
{
auto out = line(); auto out = line();
out << "static const "; out << "static const ";
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol))) { builder_.Symbols().NameFor(override->symbol))) {
return false; return false;
} }
out << " = "; out << " = " << kSpecConstantPrefix << const_id << ";";
if (!EmitExpression(out, var->constructor)) {
return false;
}
out << ";";
} }
return true; return true;
} }

View File

@ -303,19 +303,22 @@ class GeneratorImpl : public TextGenerator {
bool EmitGlobalVariable(const ast::Variable* global); bool EmitGlobalVariable(const ast::Variable* global);
/// Handles emitting a global variable with the uniform storage class /// Handles emitting a global variable with the uniform storage class
/// @param var the global variable /// @param var the AST node for the 'var'
/// @param sem the semantic node for the 'var'
/// @returns true on success /// @returns true on success
bool EmitUniformVariable(const sem::Variable* var); bool EmitUniformVariable(const ast::Var* var, const sem::Variable* sem);
/// Handles emitting a global variable with the storage storage class /// Handles emitting a global variable with the storage storage class
/// @param var the global variable /// @param var the AST node for the 'var'
/// @param sem the semantic node for the 'var'
/// @returns true on success /// @returns true on success
bool EmitStorageVariable(const sem::Variable* var); bool EmitStorageVariable(const ast::Var* var, const sem::Variable* sem);
/// Handles emitting a global variable with the handle storage class /// Handles emitting a global variable with the handle storage class
/// @param var the global variable /// @param var the AST node for the 'var'
/// @param sem the semantic node for the 'var'
/// @returns true on success /// @returns true on success
bool EmitHandleVariable(const sem::Variable* var); bool EmitHandleVariable(const ast::Var* var, const sem::Variable* sem);
/// Handles emitting a global variable with the private storage class /// Handles emitting a global variable with the private storage class
/// @param var the global variable /// @param var the global variable
@ -437,14 +440,22 @@ class GeneratorImpl : public TextGenerator {
/// @param type the type to emit the value for /// @param type the type to emit the value for
/// @returns true if the zero value was successfully emitted. /// @returns true if the zero value was successfully emitted.
bool EmitZeroValue(std::ostream& out, const sem::Type* type); bool EmitZeroValue(std::ostream& out, const sem::Type* type);
/// Handles generating a variable /// Handles generating a 'var' declaration
/// @param var the variable to generate /// @param var the variable to generate
/// @returns true if the variable was emitted /// @returns true if the variable was emitted
bool EmitVariable(const ast::Variable* var); bool EmitVar(const ast::Var* var);
/// Handles generating a program scope constant variable /// Handles generating a function-scope 'let' declaration
/// @param var the variable to emit /// @param let the variable to generate
/// @returns true if the variable was emitted /// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Variable* var); bool EmitLet(const ast::Let* let);
/// Handles generating a module-scope 'let' declaration
/// @param let the 'let' to emit
/// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Let* let);
/// Handles generating a module-scope 'override' declaration
/// @param override the 'override' to emit
/// @returns true if the variable was emitted
bool EmitOverride(const ast::Override* override);
/// Emits call to a helper vector assignment function for the input assignment /// Emits call to a helper vector assignment function for the input assignment
/// statement and vector type. This is used to work around FXC issues where /// statement and vector type. This is used to work around FXC issues where
/// assignments to vectors with dynamic indices cause compilation failures. /// assignments to vectors with dynamic indices cause compilation failures.

View File

@ -40,7 +40,7 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant) {
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error(); ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23 EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23
#define WGSL_SPEC_CONSTANT_23 3.0f #define WGSL_SPEC_CONSTANT_23 3.0f
#endif #endif
@ -56,7 +56,7 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoConstructor) {
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error(); ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23 EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23
#error spec constant required for constant id 23 #error spec constant required for constant id 23
#endif #endif
@ -73,8 +73,8 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoId) {
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(a)) << gen.error(); ASSERT_TRUE(gen.EmitOverride(a)) << gen.error();
ASSERT_TRUE(gen.EmitProgramConstVariable(b)) << gen.error(); ASSERT_TRUE(gen.EmitOverride(b)) << gen.error();
EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_0 EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_0
#define WGSL_SPEC_CONSTANT_0 3.0f #define WGSL_SPEC_CONSTANT_0 3.0f
#endif #endif

View File

@ -253,16 +253,13 @@ bool GeneratorImpl::Generate() {
[&](const ast::Alias*) { [&](const ast::Alias*) {
return true; // folded away by the writer return true; // folded away by the writer
}, },
[&](const ast::Variable* var) { [&](const ast::Let* let) {
if (var->is_const) { TINT_DEFER(line());
TINT_DEFER(line()); return EmitProgramConstVariable(let);
return EmitProgramConstVariable(var); },
} [&](const ast::Override* override) {
// These are pushed into the entry point by sanitizer transforms. TINT_DEFER(line());
TINT_ICE(Writer, diagnostics_) return EmitOverride(override);
<< "module-scope variables should have been handled by the MSL "
"sanitizer";
return false;
}, },
[&](const ast::Function* func) { [&](const ast::Function* func) {
TINT_DEFER(line()); TINT_DEFER(line());
@ -1866,8 +1863,8 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
// Returns the binding index of a variable, requiring that the group // Returns the binding index of a variable, requiring that the group
// attribute have a value of zero. // attribute have a value of zero.
const uint32_t kInvalidBindingIndex = std::numeric_limits<uint32_t>::max(); const uint32_t kInvalidBindingIndex = std::numeric_limits<uint32_t>::max();
auto get_binding_index = [&](const ast::Variable* var) -> uint32_t { auto get_binding_index = [&](const ast::Parameter* param) -> uint32_t {
auto bp = var->BindingPoint(); auto bp = param->BindingPoint();
if (bp.group == nullptr || bp.binding == nullptr) { if (bp.group == nullptr || bp.binding == nullptr) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "missing binding attributes for entry point parameter"; << "missing binding attributes for entry point parameter";
@ -1890,15 +1887,15 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
// Emit entry point parameters. // Emit entry point parameters.
bool first = true; bool first = true;
for (auto* var : func->params) { for (auto* param : func->params) {
if (!first) { if (!first) {
out << ", "; out << ", ";
} }
first = false; first = false;
auto* type = program_->Sem().Get(var)->Type()->UnwrapRef(); auto* type = program_->Sem().Get(param)->Type()->UnwrapRef();
auto param_name = program_->Symbols().NameFor(var->symbol); auto param_name = program_->Symbols().NameFor(param->symbol);
if (!EmitType(out, type, param_name)) { if (!EmitType(out, type, param_name)) {
return false; return false;
} }
@ -1910,26 +1907,26 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
if (type->Is<sem::Struct>()) { if (type->Is<sem::Struct>()) {
out << " [[stage_in]]"; out << " [[stage_in]]";
} else if (type->is_handle()) { } else if (type->is_handle()) {
uint32_t binding = get_binding_index(var); uint32_t binding = get_binding_index(param);
if (binding == kInvalidBindingIndex) { if (binding == kInvalidBindingIndex) {
return false; return false;
} }
if (var->type->Is<ast::Sampler>()) { if (param->type->Is<ast::Sampler>()) {
out << " [[sampler(" << binding << ")]]"; out << " [[sampler(" << binding << ")]]";
} else if (var->type->Is<ast::Texture>()) { } else if (param->type->Is<ast::Texture>()) {
out << " [[texture(" << binding << ")]]"; out << " [[texture(" << binding << ")]]";
} else { } else {
TINT_ICE(Writer, diagnostics_) << "invalid handle type entry point parameter"; TINT_ICE(Writer, diagnostics_) << "invalid handle type entry point parameter";
return false; return false;
} }
} else if (auto* ptr = var->type->As<ast::Pointer>()) { } else if (auto* ptr = param->type->As<ast::Pointer>()) {
auto sc = ptr->storage_class; auto sc = ptr->storage_class;
if (sc == ast::StorageClass::kWorkgroup) { if (sc == ast::StorageClass::kWorkgroup) {
auto& allocations = workgroup_allocations_[func_name]; auto& allocations = workgroup_allocations_[func_name];
out << " [[threadgroup(" << allocations.size() << ")]]"; out << " [[threadgroup(" << allocations.size() << ")]]";
allocations.push_back(program_->Sem().Get(ptr->type)->Size()); allocations.push_back(program_->Sem().Get(ptr->type)->Size());
} else if (sc == ast::StorageClass::kStorage || sc == ast::StorageClass::kUniform) { } else if (sc == ast::StorageClass::kStorage || sc == ast::StorageClass::kUniform) {
uint32_t binding = get_binding_index(var); uint32_t binding = get_binding_index(param);
if (binding == kInvalidBindingIndex) { if (binding == kInvalidBindingIndex) {
return false; return false;
} }
@ -1940,7 +1937,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
return false; return false;
} }
} else { } else {
auto& attrs = var->attributes; auto& attrs = param->attributes;
bool builtin_found = false; bool builtin_found = false;
for (auto* attr : attrs) { for (auto* attr : attrs) {
auto* builtin = attr->As<ast::BuiltinAttribute>(); auto* builtin = attr->As<ast::BuiltinAttribute>();
@ -2340,8 +2337,15 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
return EmitSwitch(s); return EmitSwitch(s);
}, },
[&](const ast::VariableDeclStatement* v) { // [&](const ast::VariableDeclStatement* v) { //
auto* var = program_->Sem().Get(v->variable); return Switch(
return EmitVariable(var); v->variable, //
[&](const ast::Var* var) { return EmitVar(var); },
[&](const ast::Let* let) { return EmitLet(let); },
[&](Default) { //
TINT_ICE(Writer, diagnostics_)
<< "unknown statement type: " << stmt->TypeInfo().name;
return false;
});
}, },
[&](Default) { [&](Default) {
diagnostics_.add_error(diag::System::Writer, diagnostics_.add_error(diag::System::Writer,
@ -2918,19 +2922,13 @@ bool GeneratorImpl::EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression*
return true; return true;
} }
bool GeneratorImpl::EmitVariable(const sem::Variable* var) { bool GeneratorImpl::EmitVar(const ast::Var* var) {
auto* decl = var->Declaration(); auto* sem = program_->Sem().Get(var);
auto* type = sem->Type()->UnwrapRef();
for (auto* attr : decl->attributes) {
if (!attr->Is<ast::InternalAttribute>()) {
TINT_ICE(Writer, diagnostics_) << "unexpected variable attribute";
return false;
}
}
auto out = line(); auto out = line();
switch (var->StorageClass()) { switch (sem->StorageClass()) {
case ast::StorageClass::kFunction: case ast::StorageClass::kFunction:
case ast::StorageClass::kHandle: case ast::StorageClass::kHandle:
case ast::StorageClass::kNone: case ast::StorageClass::kNone:
@ -2946,12 +2944,7 @@ bool GeneratorImpl::EmitVariable(const sem::Variable* var) {
return false; return false;
} }
auto* type = var->Type()->UnwrapRef(); std::string name = program_->Symbols().NameFor(var->symbol);
std::string name = program_->Symbols().NameFor(decl->symbol);
if (decl->is_const) {
name = "const " + name;
}
if (!EmitType(out, type, name)) { if (!EmitType(out, type, name)) {
return false; return false;
} }
@ -2960,14 +2953,14 @@ bool GeneratorImpl::EmitVariable(const sem::Variable* var) {
out << " " << name; out << " " << name;
} }
if (decl->constructor != nullptr) { if (var->constructor != nullptr) {
out << " = "; out << " = ";
if (!EmitExpression(out, decl->constructor)) { if (!EmitExpression(out, var->constructor)) {
return false; return false;
} }
} else if (var->StorageClass() == ast::StorageClass::kPrivate || } else if (sem->StorageClass() == ast::StorageClass::kPrivate ||
var->StorageClass() == ast::StorageClass::kFunction || sem->StorageClass() == ast::StorageClass::kFunction ||
var->StorageClass() == ast::StorageClass::kNone) { sem->StorageClass() == ast::StorageClass::kNone) {
out << " = "; out << " = ";
if (!EmitZeroValue(out, type)) { if (!EmitZeroValue(out, type)) {
return false; return false;
@ -2978,34 +2971,63 @@ bool GeneratorImpl::EmitVariable(const sem::Variable* var) {
return true; return true;
} }
bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) { bool GeneratorImpl::EmitLet(const ast::Let* let) {
for (auto* d : var->attributes) { auto* sem = program_->Sem().Get(let);
if (!d->Is<ast::IdAttribute>()) { auto* type = sem->Type();
diagnostics_.add_error(diag::System::Writer, "Decorated const values not valid");
auto out = line();
switch (sem->StorageClass()) {
case ast::StorageClass::kFunction:
case ast::StorageClass::kHandle:
case ast::StorageClass::kNone:
break;
case ast::StorageClass::kPrivate:
out << "thread ";
break;
case ast::StorageClass::kWorkgroup:
out << "threadgroup ";
break;
default:
TINT_ICE(Writer, diagnostics_) << "unhandled variable storage class";
return false; return false;
}
} }
if (!var->is_const) {
diagnostics_.add_error(diag::System::Writer, "Expected a const value"); std::string name = "const " + program_->Symbols().NameFor(let->symbol);
if (!EmitType(out, type, name)) {
return false; return false;
} }
// Variable name is output as part of the type for arrays and pointers.
if (!type->Is<sem::Array>() && !type->Is<sem::Pointer>()) {
out << " " << name;
}
out << " = ";
if (!EmitExpression(out, let->constructor)) {
return false;
}
out << ";";
return true;
}
bool GeneratorImpl::EmitProgramConstVariable(const ast::Let* let) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(let);
auto* type = global->Type();
auto out = line(); auto out = line();
out << "constant "; out << "constant ";
auto* type = program_->Sem().Get(var)->Type()->UnwrapRef(); if (!EmitType(out, type, program_->Symbols().NameFor(let->symbol))) {
if (!EmitType(out, type, program_->Symbols().NameFor(var->symbol))) {
return false; return false;
} }
if (!type->Is<sem::Array>()) { if (!type->Is<sem::Array>()) {
out << " " << program_->Symbols().NameFor(var->symbol); out << " " << program_->Symbols().NameFor(let->symbol);
} }
auto* global = program_->Sem().Get<sem::GlobalVariable>(var); if (let->constructor != nullptr) {
if (global && global->IsOverridable()) {
out << " [[function_constant(" << global->ConstantId() << ")]]";
} else if (var->constructor != nullptr) {
out << " = "; out << " = ";
if (!EmitExpression(out, var->constructor)) { if (!EmitExpression(out, let->constructor)) {
return false; return false;
} }
} }
@ -3014,6 +3036,24 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
return true; return true;
} }
bool GeneratorImpl::EmitOverride(const ast::Override* override) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(override);
auto* type = global->Type();
auto out = line();
out << "constant ";
if (!EmitType(out, type, program_->Symbols().NameFor(override->symbol))) {
return false;
}
if (!type->Is<sem::Array>()) {
out << " " << program_->Symbols().NameFor(override->symbol);
}
out << " [[function_constant(" << global->ConstantId() << ")]];";
return true;
}
GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::Type* ty) { GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::Type* ty) {
return Switch( return Switch(
ty, ty,

View File

@ -348,14 +348,22 @@ class GeneratorImpl : public TextGenerator {
/// @param expr the expression to emit /// @param expr the expression to emit
/// @returns true if the expression was emitted /// @returns true if the expression was emitted
bool EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression* expr); bool EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression* expr);
/// Handles generating a variable /// Handles generating a 'var' declaration
/// @param var the variable to generate /// @param var the variable to generate
/// @returns true if the variable was emitted /// @returns true if the variable was emitted
bool EmitVariable(const sem::Variable* var); bool EmitVar(const ast::Var* var);
/// Handles generating a program scope constant variable /// Handles generating a function-scope 'let' declaration
/// @param var the variable to emit /// @param let the variable to generate
/// @returns true if the variable was emitted /// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Variable* var); bool EmitLet(const ast::Let* let);
/// Handles generating a module-scope 'let' declaration
/// @param let the 'let' to emit
/// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Let* let);
/// Handles generating a module-scope 'override' declaration
/// @param override the 'override' to emit
/// @returns true if the variable was emitted
bool EmitOverride(const ast::Override* override);
/// Emits the zero value for the given type /// Emits the zero value for the given type
/// @param out the output of the expression stream /// @param out the output of the expression stream
/// @param type the type to emit the value for /// @param type the type to emit the value for

View File

@ -39,7 +39,7 @@ TEST_F(MslGeneratorImplTest, Emit_SpecConstant) {
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error(); ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
EXPECT_EQ(gen.result(), "constant float pos [[function_constant(23)]];\n"); EXPECT_EQ(gen.result(), "constant float pos [[function_constant(23)]];\n");
} }
@ -52,8 +52,8 @@ TEST_F(MslGeneratorImplTest, Emit_SpecConstant_NoId) {
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var_a)) << gen.error(); ASSERT_TRUE(gen.EmitOverride(var_a)) << gen.error();
ASSERT_TRUE(gen.EmitProgramConstVariable(var_b)) << gen.error(); ASSERT_TRUE(gen.EmitOverride(var_b)) << gen.error();
EXPECT_EQ(gen.result(), R"(constant float a [[function_constant(0)]]; EXPECT_EQ(gen.result(), R"(constant float a [[function_constant(0)]];
constant float b [[function_constant(1)]]; constant float b [[function_constant(1)]];
)"); )");

View File

@ -533,7 +533,7 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
// Make the constant specializable. // Make the constant specializable.
auto* sem_const = auto* sem_const =
builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const); builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const);
if (!sem_const->IsOverridable()) { if (!sem_const->Declaration()->Is<ast::Override>()) {
TINT_ICE(Writer, builder_.Diagnostics()) TINT_ICE(Writer, builder_.Diagnostics())
<< "expected a pipeline-overridable constant"; << "expected a pipeline-overridable constant";
} }
@ -692,19 +692,19 @@ uint32_t Builder::GenerateFunctionTypeIfNeeded(const sem::Function* func) {
}); });
} }
bool Builder::GenerateFunctionVariable(const ast::Variable* var) { bool Builder::GenerateFunctionVariable(const ast::Variable* v) {
uint32_t init_id = 0; uint32_t init_id = 0;
if (var->constructor) { if (v->constructor) {
init_id = GenerateExpressionWithLoadIfNeeded(var->constructor); init_id = GenerateExpressionWithLoadIfNeeded(v->constructor);
if (init_id == 0) { if (init_id == 0) {
return false; return false;
} }
} }
auto* sem = builder_.Sem().Get(var); auto* sem = builder_.Sem().Get(v);
if (var->is_const) { if (auto* let = v->As<ast::Let>()) {
if (!var->constructor) { if (!let->constructor) {
error_ = "missing constructor for constant"; error_ = "missing constructor for constant";
return false; return false;
} }
@ -721,8 +721,7 @@ bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
return false; return false;
} }
push_debug(spv::Op::OpName, push_debug(spv::Op::OpName, {Operand(var_id), Operand(builder_.Symbols().NameFor(v->symbol))});
{Operand(var_id), Operand(builder_.Symbols().NameFor(var->symbol))});
// TODO(dsinclair) We could detect if the constructor is fully const and emit // TODO(dsinclair) We could detect if the constructor is fully const and emit
// an initializer value for the variable instead of doing the OpLoad. // an initializer value for the variable instead of doing the OpLoad.
@ -733,7 +732,7 @@ bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
push_function_var( push_function_var(
{Operand(type_id), result, U32Operand(ConvertStorageClass(sc)), Operand(null_id)}); {Operand(type_id), result, U32Operand(ConvertStorageClass(sc)), Operand(null_id)});
if (var->constructor) { if (v->constructor) {
if (!GenerateStore(var_id, init_id)) { if (!GenerateStore(var_id, init_id)) {
return false; return false;
} }
@ -748,66 +747,61 @@ bool Builder::GenerateStore(uint32_t to, uint32_t from) {
return push_function_inst(spv::Op::OpStore, {Operand(to), Operand(from)}); return push_function_inst(spv::Op::OpStore, {Operand(to), Operand(from)});
} }
bool Builder::GenerateGlobalVariable(const ast::Variable* var) { bool Builder::GenerateGlobalVariable(const ast::Variable* v) {
auto* sem = builder_.Sem().Get(var); auto* sem = builder_.Sem().Get(v);
auto* type = sem->Type()->UnwrapRef(); auto* type = sem->Type()->UnwrapRef();
uint32_t init_id = 0; uint32_t init_id = 0;
if (var->constructor) { if (auto* ctor = v->constructor) {
if (!var->is_overridable) { if (!v->Is<ast::Override>()) {
auto* ctor = builder_.Sem().Get(var->constructor); auto* ctor_sem = builder_.Sem().Get(ctor);
if (auto constant = ctor->ConstantValue()) { if (auto constant = ctor_sem->ConstantValue()) {
init_id = GenerateConstantIfNeeded(std::move(constant)); init_id = GenerateConstantIfNeeded(std::move(constant));
} }
} }
if (init_id == 0) { if (init_id == 0) {
init_id = GenerateConstructorExpression(var, var->constructor); init_id = GenerateConstructorExpression(v, v->constructor);
} }
if (init_id == 0) { if (init_id == 0) {
return false; return false;
} }
} }
if (var->is_const) { if (auto* override = v->As<ast::Override>(); override && !override->constructor) {
if (!var->constructor) { // SPIR-V requires specialization constants to have initializers.
// Constants must have an initializer unless they are overridable. init_id = Switch(
if (!var->is_overridable) { type, //
error_ = "missing constructor for constant"; [&](const sem::F32*) {
return false; ast::FloatLiteralExpression l(ProgramID{}, Source{}, 0,
} ast::FloatLiteralExpression::Suffix::kF);
return GenerateLiteralIfNeeded(override, &l);
// SPIR-V requires specialization constants to have initializers. },
init_id = Switch( [&](const sem::U32*) {
type, // ast::IntLiteralExpression l(ProgramID{}, Source{}, 0,
[&](const sem::F32*) { ast::IntLiteralExpression::Suffix::kU);
ast::FloatLiteralExpression l(ProgramID{}, Source{}, 0, return GenerateLiteralIfNeeded(override, &l);
ast::FloatLiteralExpression::Suffix::kF); },
return GenerateLiteralIfNeeded(var, &l); [&](const sem::I32*) {
}, ast::IntLiteralExpression l(ProgramID{}, Source{}, 0,
[&](const sem::U32*) { ast::IntLiteralExpression::Suffix::kI);
ast::IntLiteralExpression l(ProgramID{}, Source{}, 0, return GenerateLiteralIfNeeded(override, &l);
ast::IntLiteralExpression::Suffix::kU); },
return GenerateLiteralIfNeeded(var, &l); [&](const sem::Bool*) {
}, ast::BoolLiteralExpression l(ProgramID{}, Source{}, false);
[&](const sem::I32*) { return GenerateLiteralIfNeeded(override, &l);
ast::IntLiteralExpression l(ProgramID{}, Source{}, 0, },
ast::IntLiteralExpression::Suffix::kI); [&](Default) {
return GenerateLiteralIfNeeded(var, &l); error_ = "invalid type for pipeline constant ID, must be scalar";
},
[&](const sem::Bool*) {
ast::BoolLiteralExpression l(ProgramID{}, Source{}, false);
return GenerateLiteralIfNeeded(var, &l);
},
[&](Default) {
error_ = "invalid type for pipeline constant ID, must be scalar";
return 0;
});
if (init_id == 0) {
return 0; return 0;
} });
if (init_id == 0) {
return 0;
} }
}
if (v->IsAnyOf<ast::Let, ast::Override>()) {
push_debug(spv::Op::OpName, push_debug(spv::Op::OpName,
{Operand(init_id), Operand(builder_.Symbols().NameFor(var->symbol))}); {Operand(init_id), Operand(builder_.Symbols().NameFor(v->symbol))});
RegisterVariable(sem, init_id); RegisterVariable(sem, init_id);
return true; return true;
@ -824,12 +818,11 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
return false; return false;
} }
push_debug(spv::Op::OpName, push_debug(spv::Op::OpName, {Operand(var_id), Operand(builder_.Symbols().NameFor(v->symbol))});
{Operand(var_id), Operand(builder_.Symbols().NameFor(var->symbol))});
OperandList ops = {Operand(type_id), result, U32Operand(ConvertStorageClass(sc))}; OperandList ops = {Operand(type_id), result, U32Operand(ConvertStorageClass(sc))};
if (var->constructor) { if (v->constructor) {
ops.push_back(Operand(init_id)); ops.push_back(Operand(init_id));
} else { } else {
auto* st = type->As<sem::StorageTexture>(); auto* st = type->As<sem::StorageTexture>();
@ -871,7 +864,7 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
push_type(spv::Op::OpVariable, std::move(ops)); push_type(spv::Op::OpVariable, std::move(ops));
for (auto* attr : var->attributes) { for (auto* attr : v->attributes) {
bool ok = Switch( bool ok = Switch(
attr, attr,
[&](const ast::BuiltinAttribute* builtin) { [&](const ast::BuiltinAttribute* builtin) {
@ -1332,7 +1325,7 @@ uint32_t Builder::GenerateTypeConstructorOrConversion(const sem::Call* call,
// Generate the zero initializer if there are no values provided. // Generate the zero initializer if there are no values provided.
if (args.empty()) { if (args.empty()) {
if (global_var && global_var->IsOverridable()) { if (global_var && global_var->Declaration()->Is<ast::Override>()) {
auto constant_id = global_var->ConstantId(); auto constant_id = global_var->ConstantId();
if (result_type->Is<sem::I32>()) { if (result_type->Is<sem::I32>()) {
return GenerateConstantIfNeeded(ScalarConstant::I32(0).AsSpecOp(constant_id)); return GenerateConstantIfNeeded(ScalarConstant::I32(0).AsSpecOp(constant_id));
@ -1637,7 +1630,7 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
ScalarConstant constant; ScalarConstant constant;
auto* global = builder_.Sem().Get<sem::GlobalVariable>(var); auto* global = builder_.Sem().Get<sem::GlobalVariable>(var);
if (global && global->IsOverridable()) { if (global && global->Declaration()->Is<ast::Override>()) {
constant.is_spec_op = true; constant.is_spec_op = true;
constant.constant_id = global->ConstantId(); constant.constant_id = global->ConstantId();
} }

View File

@ -635,46 +635,60 @@ bool GeneratorImpl::EmitStructType(const ast::Struct* str) {
return true; return true;
} }
bool GeneratorImpl::EmitVariable(std::ostream& out, const ast::Variable* var) { bool GeneratorImpl::EmitVariable(std::ostream& out, const ast::Variable* v) {
if (!var->attributes.empty()) { if (!v->attributes.empty()) {
if (!EmitAttributes(out, var->attributes)) { if (!EmitAttributes(out, v->attributes)) {
return false; return false;
} }
out << " "; out << " ";
} }
if (var->is_overridable) { bool ok = Switch(
out << "override"; v, //
} else if (var->is_const) { [&](const ast::Let* ) {
out << "let"; out << "let";
} else { return true;
out << "var"; },
auto sc = var->declared_storage_class; [&](const ast::Override* ) {
auto ac = var->declared_access; out << "override";
if (sc != ast::StorageClass::kNone || ac != ast::Access::kUndefined) { return true;
out << "<" << sc; },
if (ac != ast::Access::kUndefined) { [&](const ast::Var* var) {
out << ", "; out << "var";
if (!EmitAccess(out, ac)) { auto sc = var->declared_storage_class;
return false; auto ac = var->declared_access;
if (sc != ast::StorageClass::kNone || ac != ast::Access::kUndefined) {
out << "<" << sc;
if (ac != ast::Access::kUndefined) {
out << ", ";
if (!EmitAccess(out, ac)) {
return false;
}
} }
out << ">";
} }
out << ">"; return true;
} },
[&](Default) {
TINT_ICE(Writer, diagnostics_) << "unhandled variable type " << v->TypeInfo().name;
return false;
});
if (!ok) {
return false;
} }
out << " " << program_->Symbols().NameFor(var->symbol); out << " " << program_->Symbols().NameFor(v->symbol);
if (auto* ty = var->type) { if (auto* ty = v->type) {
out << " : "; out << " : ";
if (!EmitType(out, ty)) { if (!EmitType(out, ty)) {
return false; return false;
} }
} }
if (var->constructor != nullptr) { if (v->constructor != nullptr) {
out << " = "; out << " = ";
if (!EmitExpression(out, var->constructor)) { if (!EmitExpression(out, v->constructor)) {
return false; return false;
} }
} }

View File

@ -1,4 +1,5 @@
static const uint width = 128u; static const uint width = 128u;
Texture2D tex : register(t0, space0); Texture2D tex : register(t0, space0);
RWByteAddressBuffer result : register(u1, space0); RWByteAddressBuffer result : register(u1, space0);

View File

@ -45,6 +45,7 @@ static const uint ColPerThread = 4u;
static const uint TileAOuter = 64u; static const uint TileAOuter = 64u;
static const uint TileBOuter = 64u; static const uint TileBOuter = 64u;
static const uint TileInner = 64u; static const uint TileInner = 64u;
groupshared float mm_Asub[64][64]; groupshared float mm_Asub[64][64];
groupshared float mm_Bsub[64][64]; groupshared float mm_Bsub[64][64];