tint/ast: Derive off `ast::Variable`

Add the new classes:
* `ast::Let`
* `ast::Override`
* `ast::Parameter`
* `ast::Var`

Limit the fields to those that are only applicable for their type.

Note: The resolver and validator is a tangled mess for each of the
variable types. This CL tries to keep the functionality exactly the
same. I'll clean this up in another change.

Bug: tint:1582
Change-Id: Iee83324167ffd4d92ae3032b2134677629c90079
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/93780
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2022-06-17 12:48:51 +00:00 committed by Dawn LUCI CQ
parent be6fb2a8d2
commit dcdf66ed5b
66 changed files with 1652 additions and 1193 deletions

View File

@ -264,6 +264,8 @@ libtint_source_set("libtint_core_all_src") {
"ast/interpolate_attribute.h",
"ast/invariant_attribute.cc",
"ast/invariant_attribute.h",
"ast/let.cc",
"ast/let.h",
"ast/literal_expression.cc",
"ast/literal_expression.h",
"ast/location_attribute.cc",
@ -280,6 +282,10 @@ libtint_source_set("libtint_core_all_src") {
"ast/multisampled_texture.h",
"ast/node.cc",
"ast/node.h",
"ast/override.cc",
"ast/override.h",
"ast/parameter.cc",
"ast/parameter.h",
"ast/phony_expression.cc",
"ast/phony_expression.h",
"ast/pipeline_stage.cc",
@ -328,6 +334,8 @@ libtint_source_set("libtint_core_all_src") {
"ast/unary_op.h",
"ast/unary_op_expression.cc",
"ast/unary_op_expression.h",
"ast/var.cc",
"ast/var.h",
"ast/variable.cc",
"ast/variable.h",
"ast/variable_decl_statement.cc",

View File

@ -151,6 +151,8 @@ set(TINT_LIB_SRCS
ast/interpolate_attribute.h
ast/invariant_attribute.cc
ast/invariant_attribute.h
ast/let.cc
ast/let.h
ast/literal_expression.cc
ast/literal_expression.h
ast/location_attribute.cc
@ -167,6 +169,10 @@ set(TINT_LIB_SRCS
ast/multisampled_texture.h
ast/node.cc
ast/node.h
ast/override.cc
ast/override.h
ast/parameter.cc
ast/parameter.h
ast/phony_expression.cc
ast/phony_expression.h
ast/pipeline_stage.cc
@ -215,6 +221,8 @@ set(TINT_LIB_SRCS
ast/unary_op_expression.h
ast/unary_op.cc
ast/unary_op.h
ast/var.cc
ast/var.h
ast/variable_decl_statement.cc
ast/variable_decl_statement.h
ast/variable.cc

View File

@ -40,7 +40,7 @@ Function::Function(ProgramID pid,
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, symbol, program_id);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id);
for (auto* param : params) {
TINT_ASSERT(AST, param && param->is_const);
TINT_ASSERT(AST, tint::Is<Parameter>(param));
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, param, program_id);
}
TINT_ASSERT(AST, symbol.IsValid());

View File

@ -26,14 +26,11 @@
#include "src/tint/ast/builtin_attribute.h"
#include "src/tint/ast/group_attribute.h"
#include "src/tint/ast/location_attribute.h"
#include "src/tint/ast/parameter.h"
#include "src/tint/ast/pipeline_stage.h"
#include "src/tint/ast/variable.h"
namespace tint::ast {
/// ParameterList is a list of function parameters
using ParameterList = std::vector<const Variable*>;
/// A Function statement.
class Function final : public Castable<Function, Node> {
public:

View File

@ -122,18 +122,6 @@ TEST_F(FunctionTest, Assert_DifferentProgramID_ReturnAttr) {
"internal compiler error");
}
TEST_F(FunctionTest, Assert_NonConstParam) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
ParameterList params;
params.push_back(b.Var("var", b.ty.i32(), ast::StorageClass::kNone));
b.Func("f", params, b.ty.void_(), {});
},
"internal compiler error");
}
using FunctionListTest = TestHelper;
TEST_F(FunctionListTest, FindSymbol) {

46
src/tint/ast/let.cc Normal file
View File

@ -0,0 +1,46 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ast/let.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::ast::Let);
namespace tint::ast {
Let::Let(ProgramID pid,
const Source& src,
const Symbol& sym,
const ast::Type* ty,
const Expression* ctor,
AttributeList attrs)
: Base(pid, src, sym, ty, ctor, attrs) {
TINT_ASSERT(AST, ctor != nullptr);
}
Let::Let(Let&&) = default;
Let::~Let() = default;
const Let* Let::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source);
auto sym = ctx->Clone(symbol);
auto* ty = ctx->Clone(type);
auto* ctor = ctx->Clone(constructor);
auto attrs = ctx->Clone(attributes);
return ctx->dst->create<Let>(src, sym, ty, ctor, attrs);
}
} // namespace tint::ast

61
src/tint/ast/let.h Normal file
View File

@ -0,0 +1,61 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_AST_LET_H_
#define SRC_TINT_AST_LET_H_
#include "src/tint/ast/variable.h"
namespace tint::ast {
/// A "let" declaration is a name for a function-scoped runtime typed value.
///
/// Examples:
///
/// ```
/// let twice_depth : i32 = width + width; // Must have initializer
/// ```
/// @see https://www.w3.org/TR/WGSL/#let-decls
class Let final : public Castable<Let, Variable> {
public:
/// Create a 'let' variable
/// @param program_id the identifier of the program that owns this node
/// @param source the variable source
/// @param sym the variable symbol
/// @param type the declared variable type
/// @param constructor the constructor expression
/// @param attributes the variable attributes
Let(ProgramID program_id,
const Source& source,
const Symbol& sym,
const ast::Type* type,
const Expression* constructor,
AttributeList attributes);
/// Move constructor
Let(Let&&);
/// Destructor
~Let() override;
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const Let* Clone(CloneContext* ctx) const override;
};
} // namespace tint::ast
#endif // SRC_TINT_AST_LET_H_

View File

@ -77,6 +77,19 @@ class Module final : public Castable<Module, Node> {
/// @returns the global variables for the module
VariableList& GlobalVariables() { return global_variables_; }
/// @returns the global variable declarations of kind 'T' for the module
template <typename T, typename = traits::EnableIfIsType<T, ast::Variable>>
std::vector<const T*> Globals() const {
std::vector<const T*> out;
out.reserve(global_variables_.size());
for (auto* global : global_variables_) {
if (auto* var = global->As<T>()) {
out.emplace_back(var);
}
}
return out;
}
/// @returns the extension set for the module
const EnableList& Enables() const { return enables_; }

44
src/tint/ast/override.cc Normal file
View File

@ -0,0 +1,44 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ast/override.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::ast::Override);
namespace tint::ast {
Override::Override(ProgramID pid,
const Source& src,
const Symbol& sym,
const ast::Type* ty,
const Expression* ctor,
AttributeList attrs)
: Base(pid, src, sym, ty, ctor, attrs) {}
Override::Override(Override&&) = default;
Override::~Override() = default;
const Override* Override::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source);
auto sym = ctx->Clone(symbol);
auto* ty = ctx->Clone(type);
auto* ctor = ctx->Clone(constructor);
auto attrs = ctx->Clone(attributes);
return ctx->dst->create<Override>(src, sym, ty, ctor, attrs);
}
} // namespace tint::ast

62
src/tint/ast/override.h Normal file
View File

@ -0,0 +1,62 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_AST_OVERRIDE_H_
#define SRC_TINT_AST_OVERRIDE_H_
#include "src/tint/ast/variable.h"
namespace tint::ast {
/// An "override" declaration - a name for a pipeline-overridable constant.
/// Examples:
///
/// ```
/// override radius : i32 = 2; // Can be overridden by name.
/// @id(5) override width : i32 = 2; // Can be overridden by ID.
/// override scale : f32; // No default - must be overridden.
/// ```
/// @see https://www.w3.org/TR/WGSL/#override-decls
class Override final : public Castable<Override, Variable> {
public:
/// Create an 'override' pipeline-overridable constant.
/// @param program_id the identifier of the program that owns this node
/// @param source the variable source
/// @param sym the variable symbol
/// @param type the declared variable type
/// @param constructor the constructor expression
/// @param attributes the variable attributes
Override(ProgramID program_id,
const Source& source,
const Symbol& sym,
const ast::Type* type,
const Expression* constructor,
AttributeList attributes);
/// Move constructor
Override(Override&&);
/// Destructor
~Override() override;
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const Override* Clone(CloneContext* ctx) const override;
};
} // namespace tint::ast
#endif // SRC_TINT_AST_OVERRIDE_H_

42
src/tint/ast/parameter.cc Normal file
View File

@ -0,0 +1,42 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ast/parameter.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::ast::Parameter);
namespace tint::ast {
Parameter::Parameter(ProgramID pid,
const Source& src,
const Symbol& sym,
const ast::Type* ty,
AttributeList attrs)
: Base(pid, src, sym, ty, nullptr, attrs) {}
Parameter::Parameter(Parameter&&) = default;
Parameter::~Parameter() = default;
const Parameter* Parameter::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source);
auto sym = ctx->Clone(symbol);
auto* ty = ctx->Clone(type);
auto attrs = ctx->Clone(attributes);
return ctx->dst->create<Parameter>(src, sym, ty, attrs);
}
} // namespace tint::ast

66
src/tint/ast/parameter.h Normal file
View File

@ -0,0 +1,66 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_AST_PARAMETER_H_
#define SRC_TINT_AST_PARAMETER_H_
#include <vector>
#include "src/tint/ast/variable.h"
namespace tint::ast {
/// A formal parameter to a function - a name for a typed value to be passed into a function.
/// Example:
///
/// ```
/// fn twice(a: i32) -> i32 { // "a:i32" is the formal parameter
/// return a + a;
/// }
/// ```
///
/// @see https://www.w3.org/TR/WGSL/#creation-time-consts
class Parameter final : public Castable<Parameter, Variable> {
public:
/// Create a 'parameter' creation-time value variable.
/// @param program_id the identifier of the program that owns this node
/// @param source the variable source
/// @param sym the variable symbol
/// @param type the declared variable type
/// @param attributes the variable attributes
Parameter(ProgramID program_id,
const Source& source,
const Symbol& sym,
const ast::Type* type,
AttributeList attributes);
/// Move constructor
Parameter(Parameter&&);
/// Destructor
~Parameter() override;
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const Parameter* Clone(CloneContext* ctx) const override;
};
/// A list of parameters
using ParameterList = std::vector<const Parameter*>;
} // namespace tint::ast
#endif // SRC_TINT_AST_PARAMETER_H_

49
src/tint/ast/var.cc Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ast/var.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::ast::Var);
namespace tint::ast {
Var::Var(ProgramID pid,
const Source& src,
const Symbol& sym,
const ast::Type* ty,
StorageClass storage_class,
Access access,
const Expression* ctor,
AttributeList attrs)
: Base(pid, src, sym, ty, ctor, attrs),
declared_storage_class(storage_class),
declared_access(access) {}
Var::Var(Var&&) = default;
Var::~Var() = default;
const Var* Var::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source);
auto sym = ctx->Clone(symbol);
auto* ty = ctx->Clone(type);
auto* ctor = ctx->Clone(constructor);
auto attrs = ctx->Clone(attributes);
return ctx->dst->create<Var>(src, sym, ty, declared_storage_class, declared_access, ctor,
attrs);
}
} // namespace tint::ast

86
src/tint/ast/var.h Normal file
View File

@ -0,0 +1,86 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_AST_VAR_H_
#define SRC_TINT_AST_VAR_H_
#include <utility>
#include <vector>
#include "src/tint/ast/variable.h"
namespace tint::ast {
/// A "var" declaration is a name for typed storage.
///
/// Examples:
///
/// ```
/// // Declared outside a function, i.e. at module scope, requires
/// // a storage class.
/// var<workgroup> width : i32; // no initializer
/// var<private> height : i32 = 3; // with initializer
///
/// // A variable declared inside a function doesn't take a storage class,
/// // and maps to SPIR-V Function storage.
/// var computed_depth : i32;
/// var area : i32 = compute_area(width, height);
/// ```
///
/// @see https://www.w3.org/TR/WGSL/#var-decls
class Var final : public Castable<Var, Variable> {
public:
/// Create a 'var' variable
/// @param program_id the identifier of the program that owns this node
/// @param source the variable source
/// @param sym the variable symbol
/// @param type the declared variable type
/// @param declared_storage_class the declared storage class
/// @param declared_access the declared access control
/// @param constructor the constructor expression
/// @param attributes the variable attributes
Var(ProgramID program_id,
const Source& source,
const Symbol& sym,
const ast::Type* type,
StorageClass declared_storage_class,
Access declared_access,
const Expression* constructor,
AttributeList attributes);
/// Move constructor
Var(Var&&);
/// Destructor
~Var() override;
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const Var* Clone(CloneContext* ctx) const override;
/// The declared storage class
const StorageClass declared_storage_class;
/// The declared access control
const Access declared_access;
};
/// A list of `var` declarations
using VarList = std::vector<const Var*>;
} // namespace tint::ast
#endif // SRC_TINT_AST_VAR_H_

View File

@ -13,9 +13,8 @@
// limitations under the License.
#include "src/tint/ast/variable.h"
#include "src/tint/program_builder.h"
#include "src/tint/sem/variable.h"
#include "src/tint/ast/binding_attribute.h"
#include "src/tint/ast/group_attribute.h"
TINT_INSTANTIATE_TYPEINFO(tint::ast::Variable);
@ -24,24 +23,11 @@ namespace tint::ast {
Variable::Variable(ProgramID pid,
const Source& src,
const Symbol& sym,
StorageClass dsc,
Access da,
const ast::Type* ty,
bool constant,
bool overridable,
const Expression* ctor,
AttributeList attrs)
: Base(pid, src),
symbol(sym),
type(ty),
is_const(constant),
is_overridable(overridable),
constructor(ctor),
attributes(std::move(attrs)),
declared_storage_class(dsc),
declared_access(da) {
: Base(pid, src), symbol(sym), type(ty), constructor(ctor), attributes(std::move(attrs)) {
TINT_ASSERT(AST, symbol.IsValid());
TINT_ASSERT(AST, is_overridable ? is_const : true);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, symbol, program_id);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, constructor, program_id);
}
@ -54,23 +40,12 @@ VariableBindingPoint Variable::BindingPoint() const {
const GroupAttribute* group = nullptr;
const BindingAttribute* binding = nullptr;
for (auto* attr : attributes) {
if (auto* g = attr->As<GroupAttribute>()) {
group = g;
} else if (auto* b = attr->As<BindingAttribute>()) {
binding = b;
}
Switch(
attr, //
[&](const GroupAttribute* a) { group = a; },
[&](const BindingAttribute* a) { binding = a; });
}
return VariableBindingPoint{group, binding};
}
const Variable* Variable::Clone(CloneContext* ctx) const {
auto src = ctx->Clone(source);
auto sym = ctx->Clone(symbol);
auto* ty = ctx->Clone(type);
auto* ctor = ctx->Clone(constructor);
auto attrs = ctx->Clone(attributes);
return ctx->dst->create<Variable>(src, sym, declared_storage_class, declared_access, ty,
is_const, is_overridable, ctor, attrs);
}
} // namespace tint::ast

View File

@ -45,112 +45,38 @@ struct VariableBindingPoint {
inline operator bool() const { return group && binding; }
};
/// A Variable statement.
/// Variable is the base class for Var, Let, Const, Override and Parameter.
///
/// An instance of this class represents one of four constructs in WGSL: "var"
/// declaration, "let" declaration, "override" declaration, or formal parameter
/// to a function.
/// An instance of this class represents one of five constructs in WGSL: "var" declaration, "let"
/// declaration, "override" declaration, "const" declaration, or formal parameter to a function.
///
/// 1. A "var" declaration is a name for typed storage. Examples:
///
/// // Declared outside a function, i.e. at module scope, requires
/// // a storage class.
/// var<workgroup> width : i32; // no initializer
/// var<private> height : i32 = 3; // with initializer
///
/// // A variable declared inside a function doesn't take a storage class,
/// // and maps to SPIR-V Function storage.
/// var computed_depth : i32;
/// var area : i32 = compute_area(width, height);
///
/// 2. A "let" declaration is a name for a typed value. Examples:
///
/// let twice_depth : i32 = width + width; // Must have initializer
///
/// 3. An "override" declaration is a name for a pipeline-overridable constant.
/// Examples:
///
/// override radius : i32 = 2; // Can be overridden by name.
/// @id(5) override width : i32 = 2; // Can be overridden by ID.
/// override scale : f32; // No default - must be overridden.
///
/// 4. A formal parameter to a function is a name for a typed value to
/// be passed into a function. Example:
///
/// fn twice(a: i32) -> i32 { // "a:i32" is the formal parameter
/// return a + a;
/// }
///
/// From the WGSL draft, about "var"::
///
/// A variable is a named reference to storage that can contain a value of a
/// particular type.
///
/// Two types are associated with a variable: its store type (the type of
/// value that may be placed in the referenced storage) and its reference
/// type (the type of the variable itself). If a variable has store type T
/// and storage class S, then its reference type is pointer-to-T-in-S.
///
/// This class uses the term "type" to refer to:
/// the value type of a "let",
/// the value type of an "override",
/// the value type of the formal parameter,
/// or the store type of the "var".
//
/// Setting is_const:
/// - "var" gets false
/// - "let" gets true
/// - "override" gets true
/// - formal parameter gets true
///
/// Setting is_overrideable:
/// - "var" gets false
/// - "let" gets false
/// - "override" gets true
/// - formal parameter gets false
///
/// Setting storage class:
/// - "var" is StorageClass::kNone when using the
/// defaulting syntax for a "var" declared inside a function.
/// - "let" is always StorageClass::kNone.
/// - formal parameter is always StorageClass::kNone.
class Variable final : public Castable<Variable, Node> {
/// @see https://www.w3.org/TR/WGSL/#value-decls
class Variable : public Castable<Variable, Node> {
public:
/// Create a variable
/// Constructor
/// @param program_id the identifier of the program that owns this node
/// @param source the variable source
/// @param sym the variable symbol
/// @param declared_storage_class the declared storage class
/// @param declared_access the declared access control
/// @param type the declared variable type
/// @param is_const true if the variable is const
/// @param is_overridable true if the variable is pipeline-overridable
/// @param constructor the constructor expression
/// @param attributes the variable attributes
Variable(ProgramID program_id,
const Source& source,
const Symbol& sym,
StorageClass declared_storage_class,
Access declared_access,
const ast::Type* type,
bool is_const,
bool is_overridable,
const Expression* constructor,
AttributeList attributes);
/// Move constructor
Variable(Variable&&);
/// Destructor
~Variable() override;
/// @returns the binding point information for the variable
/// @returns the binding point information from the variable's attributes.
/// @note binding points should only be applied to Var and Parameter types.
VariableBindingPoint BindingPoint() const;
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const Variable* Clone(CloneContext* ctx) const override;
/// The variable symbol
const Symbol symbol;
@ -159,23 +85,11 @@ class Variable final : public Castable<Variable, Node> {
/// var i = 1;
const ast::Type* const type;
/// True if this is a constant, false otherwise
const bool is_const;
/// True if this is a pipeline-overridable constant, false otherwise
const bool is_overridable;
/// The constructor expression or nullptr if none set
const Expression* const constructor;
/// The attributes attached to this variable
const AttributeList attributes;
/// The declared storage class
const StorageClass declared_storage_class;
/// The declared access control
const Access declared_access;
};
/// A list of variables

View File

@ -183,7 +183,7 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
auto name = program_->Symbols().NameFor(decl->symbol);
auto* global = var->As<sem::GlobalVariable>();
if (global && global->IsOverridable()) {
if (global && global->Declaration()->Is<ast::Override>()) {
OverridableConstant overridable_constant;
overridable_constant.name = name;
overridable_constant.numeric_id = global->ConstantId();
@ -219,7 +219,7 @@ std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
std::map<uint32_t, Scalar> result;
for (auto* var : program_->AST().GlobalVariables()) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
if (!global || !global->IsOverridable()) {
if (!global || !global->Declaration()->Is<ast::Override>()) {
continue;
}
@ -276,7 +276,7 @@ std::map<std::string, uint32_t> Inspector::GetConstantNameToIdMap() {
std::map<std::string, uint32_t> result;
for (auto* var : program_->AST().GlobalVariables()) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
if (global && global->IsOverridable()) {
if (global && global->Declaration()->Is<ast::Override>()) {
auto name = program_->Symbols().NameFor(var->symbol);
result[name] = global->ConstantId();
}
@ -813,25 +813,24 @@ void Inspector::GenerateSamplerTargets() {
auto* t = c->args[texture_index];
auto* s = c->args[sampler_index];
GetOriginatingResources(std::array<const ast::Expression*, 2>{t, s},
[&](std::array<const sem::GlobalVariable*, 2> globals) {
auto* texture = globals[0];
sem::BindingPoint texture_binding_point = {
texture->Declaration()->BindingPoint().group->value,
texture->Declaration()->BindingPoint().binding->value};
GetOriginatingResources(
std::array<const ast::Expression*, 2>{t, s},
[&](std::array<const sem::GlobalVariable*, 2> globals) {
auto* texture = globals[0]->Declaration()->As<ast::Var>();
sem::BindingPoint texture_binding_point = {texture->BindingPoint().group->value,
texture->BindingPoint().binding->value};
auto* sampler = globals[1];
sem::BindingPoint sampler_binding_point = {
sampler->Declaration()->BindingPoint().group->value,
sampler->Declaration()->BindingPoint().binding->value};
auto* sampler = globals[1]->Declaration()->As<ast::Var>();
sem::BindingPoint sampler_binding_point = {sampler->BindingPoint().group->value,
sampler->BindingPoint().binding->value};
for (auto* entry_point : entry_points) {
const auto& ep_name = program_->Symbols().NameFor(
entry_point->Declaration()->symbol);
(*sampler_targets_)[ep_name].add(
{sampler_binding_point, texture_binding_point});
}
});
for (auto* entry_point : entry_points) {
const auto& ep_name =
program_->Symbols().NameFor(entry_point->Declaration()->symbol);
(*sampler_targets_)[ep_name].add(
{sampler_binding_point, texture_binding_point});
}
});
}
}

View File

@ -53,11 +53,14 @@
#include "src/tint/ast/index_accessor_expression.h"
#include "src/tint/ast/interpolate_attribute.h"
#include "src/tint/ast/invariant_attribute.h"
#include "src/tint/ast/let.h"
#include "src/tint/ast/loop_statement.h"
#include "src/tint/ast/matrix.h"
#include "src/tint/ast/member_accessor_expression.h"
#include "src/tint/ast/module.h"
#include "src/tint/ast/multisampled_texture.h"
#include "src/tint/ast/override.h"
#include "src/tint/ast/parameter.h"
#include "src/tint/ast/phony_expression.h"
#include "src/tint/ast/pointer.h"
#include "src/tint/ast/return_statement.h"
@ -73,6 +76,7 @@
#include "src/tint/ast/type_name.h"
#include "src/tint/ast/u32.h"
#include "src/tint/ast/unary_op_expression.h"
#include "src/tint/ast/var.h"
#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/ast/vector.h"
#include "src/tint/ast/void.h"
@ -1328,14 +1332,13 @@ class ProgramBuilder {
/// * ast::AttributeList - specifies the variable's attributes
/// Note that repeated arguments of the same type will use the last argument's
/// value.
/// @returns a `ast::Variable` with the given name, type and additional
/// @returns a `ast::Var` with the given name, type and additional
/// options
template <typename NAME, typename... OPTIONAL>
const ast::Variable* Var(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) {
const ast::Var* Var(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) {
VarOptionals opts(std::forward<OPTIONAL>(optional)...);
return create<ast::Variable>(Sym(std::forward<NAME>(name)), opts.storage, opts.access, type,
false /* is_const */, false /* is_overridable */,
opts.constructor, std::move(opts.attributes));
return create<ast::Var>(Sym(std::forward<NAME>(name)), type, opts.storage, opts.access,
opts.constructor, std::move(opts.attributes));
}
/// @param source the variable source
@ -1349,32 +1352,28 @@ class ProgramBuilder {
/// * ast::AttributeList - specifies the variable's attributes
/// Note that repeated arguments of the same type will use the last argument's
/// value.
/// @returns a `ast::Variable` with the given name, storage and type
/// @returns a `ast::Var` with the given name, storage and type
template <typename NAME, typename... OPTIONAL>
const ast::Variable* Var(const Source& source,
NAME&& name,
const ast::Type* type,
OPTIONAL&&... optional) {
const ast::Var* Var(const Source& source,
NAME&& name,
const ast::Type* type,
OPTIONAL&&... optional) {
VarOptionals opts(std::forward<OPTIONAL>(optional)...);
return create<ast::Variable>(source, Sym(std::forward<NAME>(name)), opts.storage,
opts.access, type, false /* is_const */,
false /* is_overridable */, opts.constructor,
std::move(opts.attributes));
return create<ast::Var>(source, Sym(std::forward<NAME>(name)), type, opts.storage,
opts.access, opts.constructor, std::move(opts.attributes));
}
/// @param name the variable name
/// @param type the variable type
/// @param constructor constructor expression
/// @param attributes optional variable attributes
/// @returns an immutable `ast::Variable` with the given name and type
/// @returns an `ast::Let` with the given name and type
template <typename NAME>
const ast::Variable* Let(NAME&& name,
const ast::Type* type,
const ast::Expression* constructor,
ast::AttributeList attributes = {}) {
return create<ast::Variable>(Sym(std::forward<NAME>(name)), ast::StorageClass::kNone,
ast::Access::kUndefined, type, true /* is_const */,
false /* is_overridable */, constructor, attributes);
const ast::Let* Let(NAME&& name,
const ast::Type* type,
const ast::Expression* constructor,
ast::AttributeList attributes = {}) {
return create<ast::Let>(Sym(std::forward<NAME>(name)), type, constructor, attributes);
}
/// @param source the variable source
@ -1382,46 +1381,39 @@ class ProgramBuilder {
/// @param type the variable type
/// @param constructor constructor expression
/// @param attributes optional variable attributes
/// @returns an immutable `ast::Variable` with the given name and type
/// @returns an `ast::Let` with the given name and type
template <typename NAME>
const ast::Variable* Let(const Source& source,
NAME&& name,
const ast::Type* type,
const ast::Expression* constructor,
ast::AttributeList attributes = {}) {
return create<ast::Variable>(source, Sym(std::forward<NAME>(name)),
ast::StorageClass::kNone, ast::Access::kUndefined, type,
true /* is_const */, false /* is_overridable */, constructor,
attributes);
const ast::Let* Let(const Source& source,
NAME&& name,
const ast::Type* type,
const ast::Expression* constructor,
ast::AttributeList attributes = {}) {
return create<ast::Let>(source, Sym(std::forward<NAME>(name)), type, constructor,
attributes);
}
/// @param name the parameter name
/// @param type the parameter type
/// @param attributes optional parameter attributes
/// @returns an immutable `ast::Variable` with the given name and type
/// @returns an `ast::Parameter` with the given name and type
template <typename NAME>
const ast::Variable* Param(NAME&& name,
const ast::Type* type,
ast::AttributeList attributes = {}) {
return create<ast::Variable>(Sym(std::forward<NAME>(name)), ast::StorageClass::kNone,
ast::Access::kUndefined, type, true /* is_const */,
false /* is_overridable */, nullptr, attributes);
const ast::Parameter* Param(NAME&& name,
const ast::Type* type,
ast::AttributeList attributes = {}) {
return create<ast::Parameter>(Sym(std::forward<NAME>(name)), type, attributes);
}
/// @param source the parameter source
/// @param name the parameter name
/// @param type the parameter type
/// @param attributes optional parameter attributes
/// @returns an immutable `ast::Variable` with the given name and type
/// @returns an `ast::Parameter` with the given name and type
template <typename NAME>
const ast::Variable* Param(const Source& source,
NAME&& name,
const ast::Type* type,
ast::AttributeList attributes = {}) {
return create<ast::Variable>(source, Sym(std::forward<NAME>(name)),
ast::StorageClass::kNone, ast::Access::kUndefined, type,
true /* is_const */, false /* is_overridable */, nullptr,
attributes);
const ast::Parameter* Param(const Source& source,
NAME&& name,
const ast::Type* type,
ast::AttributeList attributes = {}) {
return create<ast::Parameter>(source, Sym(std::forward<NAME>(name)), type, attributes);
}
/// @param name the variable name
@ -1434,10 +1426,10 @@ class ProgramBuilder {
/// * ast::AttributeList - specifies the variable's attributes
/// Note that repeated arguments of the same type will use the last argument's
/// value.
/// @returns a new `ast::Variable`, which is automatically registered as a
/// global variable with the ast::Module.
/// @returns a new `ast::Var`, which is automatically registered as a global variable with the
/// ast::Module.
template <typename NAME, typename... OPTIONAL, typename = DisableIfSource<NAME>>
const ast::Variable* Global(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) {
const ast::Var* Global(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) {
auto* var = Var(std::forward<NAME>(name), type, std::forward<OPTIONAL>(optional)...);
AST().AddGlobalVariable(var);
return var;
@ -1454,13 +1446,13 @@ class ProgramBuilder {
/// * ast::AttributeList - specifies the variable's attributes
/// Note that repeated arguments of the same type will use the last argument's
/// value.
/// @returns a new `ast::Variable`, which is automatically registered as a
/// global variable with the ast::Module.
/// @returns a new `ast::Var`, which is automatically registered as a global variable with the
/// ast::Module.
template <typename NAME, typename... OPTIONAL>
const ast::Variable* Global(const Source& source,
NAME&& name,
const ast::Type* type,
OPTIONAL&&... optional) {
const ast::Var* Global(const Source& source,
NAME&& name,
const ast::Type* type,
OPTIONAL&&... optional) {
auto* var =
Var(source, std::forward<NAME>(name), type, std::forward<OPTIONAL>(optional)...);
AST().AddGlobalVariable(var);
@ -1471,14 +1463,13 @@ class ProgramBuilder {
/// @param type the variable type
/// @param constructor constructor expression
/// @param attributes optional variable attributes
/// @returns a const `ast::Variable` constructed by calling Var() with the
/// arguments of `args`, which is automatically registered as a global
/// variable with the ast::Module.
/// @returns an `ast::Let` constructed by calling Let() with the arguments of `args`, which is
/// automatically registered as a global variable with the ast::Module.
template <typename NAME>
const ast::Variable* GlobalConst(NAME&& name,
const ast::Type* type,
const ast::Expression* constructor,
ast::AttributeList attributes = {}) {
const ast::Let* GlobalConst(NAME&& name,
const ast::Type* type,
const ast::Expression* constructor,
ast::AttributeList attributes = {}) {
auto* var = Let(std::forward<NAME>(name), type, constructor, std::move(attributes));
AST().AddGlobalVariable(var);
return var;
@ -1489,15 +1480,15 @@ class ProgramBuilder {
/// @param type the variable type
/// @param constructor constructor expression
/// @param attributes optional variable attributes
/// @returns a const `ast::Variable` constructed by calling Var() with the
/// @returns a const `ast::Let` constructed by calling Var() with the
/// arguments of `args`, which is automatically registered as a global
/// variable with the ast::Module.
template <typename NAME>
const ast::Variable* GlobalConst(const Source& source,
NAME&& name,
const ast::Type* type,
const ast::Expression* constructor,
ast::AttributeList attributes = {}) {
const ast::Let* GlobalConst(const Source& source,
NAME&& name,
const ast::Type* type,
const ast::Expression* constructor,
ast::AttributeList attributes = {}) {
auto* var = Let(source, std::forward<NAME>(name), type, constructor, std::move(attributes));
AST().AddGlobalVariable(var);
return var;
@ -1507,17 +1498,15 @@ class ProgramBuilder {
/// @param type the variable type
/// @param constructor optional constructor expression
/// @param attributes optional variable attributes
/// @returns an overridable const `ast::Variable` which is automatically
/// registered as a global variable with the ast::Module.
/// @returns an `ast::Override` which is automatically registered as a global variable with the
/// ast::Module.
template <typename NAME>
const ast::Variable* Override(NAME&& name,
const ast::Override* Override(NAME&& name,
const ast::Type* type,
const ast::Expression* constructor,
ast::AttributeList attributes = {}) {
auto* var =
create<ast::Variable>(source_, Sym(std::forward<NAME>(name)), ast::StorageClass::kNone,
ast::Access::kUndefined, type, true /* is_const */,
true /* is_overridable */, constructor, std::move(attributes));
auto* var = create<ast::Override>(source_, Sym(std::forward<NAME>(name)), type, constructor,
std::move(attributes));
AST().AddGlobalVariable(var);
return var;
}
@ -1527,19 +1516,16 @@ class ProgramBuilder {
/// @param type the variable type
/// @param constructor constructor expression
/// @param attributes optional variable attributes
/// @returns a const `ast::Variable` constructed by calling Var() with the
/// arguments of `args`, which is automatically registered as a global
/// variable with the ast::Module.
/// @returns an `ast::Override` constructed with the arguments of `args`, which is automatically
/// registered as a global variable with the ast::Module.
template <typename NAME>
const ast::Variable* Override(const Source& source,
const ast::Override* Override(const Source& source,
NAME&& name,
const ast::Type* type,
const ast::Expression* constructor,
ast::AttributeList attributes = {}) {
auto* var =
create<ast::Variable>(source, Sym(std::forward<NAME>(name)), ast::StorageClass::kNone,
ast::Access::kUndefined, type, true /* is_const */,
true /* is_overridable */, constructor, std::move(attributes));
auto* var = create<ast::Override>(source, Sym(std::forward<NAME>(name)), type, constructor,
std::move(attributes));
AST().AddGlobalVariable(var);
return var;
}

View File

@ -1253,12 +1253,12 @@ bool FunctionEmitter::EmitEntryPointAsWrapper() {
auto* sample_mask_array_type = store_type->UnwrapRef()->UnwrapAlias()->As<Array>();
TINT_ASSERT(Reader, sample_mask_array_type);
ok = EmitPipelineInput(var_name, store_type, &param_decos, {0},
sample_mask_array_type->type, forced_param_type, &(decl.params),
sample_mask_array_type->type, forced_param_type, &decl.params,
&stmts);
} else {
// The normal path.
ok = EmitPipelineInput(var_name, store_type, &param_decos, {}, store_type,
forced_param_type, &(decl.params), &stmts);
forced_param_type, &decl.params, &stmts);
}
if (!ok) {
return false;
@ -1404,8 +1404,7 @@ bool FunctionEmitter::ParseFunctionDeclaration(FunctionDeclaration* decl) {
auto* type = parser_impl_.ConvertType(param->type_id());
if (type != nullptr) {
auto* ast_param =
parser_impl_.MakeVariable(param->result_id(), ast::StorageClass::kNone, type, true,
false, nullptr, ast::AttributeList{});
parser_impl_.MakeParameter(param->result_id(), type, ast::AttributeList{});
// Parameters are treated as const declarations.
ast_params.emplace_back(ast_param);
// The value is accessible by name.
@ -2468,9 +2467,8 @@ bool FunctionEmitter::EmitFunctionVariables() {
return false;
}
}
auto* var =
parser_impl_.MakeVariable(inst.result_id(), ast::StorageClass::kNone, var_store_type,
false, false, constructor, ast::AttributeList{});
auto* var = parser_impl_.MakeVar(inst.result_id(), ast::StorageClass::kNone, var_store_type,
constructor, ast::AttributeList{});
auto* var_decl_stmt = create<ast::VariableDeclStatement>(Source{}, var);
AddStatement(var_decl_stmt);
auto* var_type = ty_.Reference(var_store_type, ast::StorageClass::kNone);
@ -3328,8 +3326,8 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
TINT_ASSERT(Reader, def_inst);
auto* storage_type = RemapStorageClass(parser_impl_.ConvertType(def_inst->type_id()), id);
AddStatement(create<ast::VariableDeclStatement>(
Source{}, parser_impl_.MakeVariable(id, ast::StorageClass::kNone, storage_type, false,
false, nullptr, ast::AttributeList{})));
Source{}, parser_impl_.MakeVar(id, ast::StorageClass::kNone, storage_type, nullptr,
ast::AttributeList{})));
auto* type = ty_.Reference(storage_type, ast::StorageClass::kNone);
identifier_types_.emplace(id, type);
}
@ -3396,13 +3394,11 @@ bool FunctionEmitter::EmitConstDefinition(const spvtools::opt::Instruction& inst
}
expr = AddressOfIfNeeded(expr, &inst);
auto* ast_const =
parser_impl_.MakeVariable(inst.result_id(), ast::StorageClass::kNone, expr.type, true,
false, expr.expr, ast::AttributeList{});
if (!ast_const) {
auto* let = parser_impl_.MakeLet(inst.result_id(), expr.type, expr.expr);
if (!let) {
return false;
}
AddStatement(create<ast::VariableDeclStatement>(Source{}, ast_const));
AddStatement(create<ast::VariableDeclStatement>(Source{}, let));
identifier_types_.emplace(inst.result_id(), expr.type);
return success();
}

View File

@ -1371,8 +1371,8 @@ bool ParserImpl::EmitScalarSpecConstants() {
break;
}
}
auto* ast_var = MakeVariable(inst.result_id(), ast::StorageClass::kNone, ast_type, true,
true, ast_expr, std::move(spec_id_decos));
auto* ast_var =
MakeOverride(inst.result_id(), ast_type, ast_expr, std::move(spec_id_decos));
if (ast_var) {
builder_.AST().AddGlobalVariable(ast_var);
scalar_spec_constants_.insert(inst.result_id());
@ -1489,8 +1489,8 @@ bool ParserImpl::EmitModuleScopeVariables() {
// here.)
ast_constructor = MakeConstantExpression(var.GetSingleWordInOperand(1)).expr;
}
auto* ast_var = MakeVariable(var.result_id(), ast_storage_class, ast_store_type, false,
false, ast_constructor, ast::AttributeList{});
auto* ast_var = MakeVar(var.result_id(), ast_storage_class, ast_store_type, ast_constructor,
ast::AttributeList{});
// TODO(dneto): initializers (a.k.a. constructor expression)
if (ast_var) {
builder_.AST().AddGlobalVariable(ast_var);
@ -1521,10 +1521,9 @@ bool ParserImpl::EmitModuleScopeVariables() {
}
}
auto* ast_var =
MakeVariable(builtin_position_.per_vertex_var_id,
enum_converter_.ToStorageClass(builtin_position_.storage_class),
ConvertType(builtin_position_.position_member_type_id), false, false,
ast_constructor, {});
MakeVar(builtin_position_.per_vertex_var_id,
enum_converter_.ToStorageClass(builtin_position_.storage_class),
ConvertType(builtin_position_.position_member_type_id), ast_constructor, {});
builder_.AST().AddGlobalVariable(ast_var);
}
@ -1554,13 +1553,11 @@ const spvtools::opt::analysis::IntConstant* ParserImpl::GetArraySize(uint32_t va
return size->AsIntConstant();
}
ast::Variable* ParserImpl::MakeVariable(uint32_t id,
ast::StorageClass sc,
const Type* storage_type,
bool is_const,
bool is_overridable,
const ast::Expression* constructor,
ast::AttributeList decorations) {
ast::Var* ParserImpl::MakeVar(uint32_t id,
ast::StorageClass sc,
const Type* storage_type,
const ast::Expression* constructor,
ast::AttributeList decorations) {
if (storage_type == nullptr) {
Fail() << "internal error: can't make ast::Variable for null type";
return nullptr;
@ -1588,15 +1585,37 @@ ast::Variable* ParserImpl::MakeVariable(uint32_t id,
return nullptr;
}
std::string name = namer_.Name(id);
auto sym = builder_.Symbols().Register(namer_.Name(id));
return create<ast::Var>(Source{}, sym, storage_type->Build(builder_), sc, access, constructor,
decorations);
}
// Note: we're constructing the variable here with the *storage* type,
// regardless of whether this is a `let`, `override`, or `var` declaration.
// `var` declarations will have a resolved type of ref<storage>, but at the
// AST level all three are declared with the same type.
return create<ast::Variable>(Source{}, builder_.Symbols().Register(name), sc, access,
storage_type->Build(builder_), is_const, is_overridable,
constructor, decorations);
ast::Let* ParserImpl::MakeLet(uint32_t id, const Type* type, const ast::Expression* constructor) {
auto sym = builder_.Symbols().Register(namer_.Name(id));
return create<ast::Let>(Source{}, sym, type->Build(builder_), constructor,
ast::AttributeList{});
}
ast::Override* ParserImpl::MakeOverride(uint32_t id,
const Type* type,
const ast::Expression* constructor,
ast::AttributeList decorations) {
if (!ConvertDecorationsForVariable(id, &type, &decorations, false)) {
return nullptr;
}
auto sym = builder_.Symbols().Register(namer_.Name(id));
return create<ast::Override>(Source{}, sym, type->Build(builder_), constructor, decorations);
}
ast::Parameter* ParserImpl::MakeParameter(uint32_t id,
const Type* type,
ast::AttributeList decorations) {
if (!ConvertDecorationsForVariable(id, &type, &decorations, false)) {
return nullptr;
}
auto sym = builder_.Symbols().Register(namer_.Name(id));
return create<ast::Parameter>(Source{}, sym, type->Build(builder_), decorations);
}
bool ParserImpl::ConvertDecorationsForVariable(uint32_t id,

View File

@ -411,25 +411,47 @@ class ParserImpl : Reader {
/// @returns a list of SPIR-V decorations.
DecorationList GetMemberPipelineDecorations(const Struct& struct_type, int member_index);
/// Creates an AST Variable node for a SPIR-V ID, including any attached
/// decorations, unless it's an ignorable builtin variable.
/// Creates an AST 'var' node for a SPIR-V ID, including any attached decorations, unless it's
/// an ignorable builtin variable.
/// @param id the SPIR-V result ID
/// @param sc the storage class, which cannot be ast::StorageClass::kNone
/// @param storage_type the storage type of the variable
/// @param is_const if true, the variable is const
/// @param is_overridable if true, the variable is pipeline-overridable
/// @param constructor the variable constructor
/// @param decorations the variable decorations
/// @returns a new Variable node, or null in the ignorable variable case and
/// in the error case
ast::Variable* MakeVariable(uint32_t id,
ast::StorageClass sc,
const Type* storage_type,
bool is_const,
bool is_overridable,
ast::Var* MakeVar(uint32_t id,
ast::StorageClass sc,
const Type* storage_type,
const ast::Expression* constructor,
ast::AttributeList decorations);
/// Creates an AST 'let' node for a SPIR-V ID, including any attached decorations,.
/// @param id the SPIR-V result ID
/// @param type the type of the variable
/// @param constructor the variable constructor
/// @returns the AST 'let' node
ast::Let* MakeLet(uint32_t id, const Type* type, const ast::Expression* constructor);
/// Creates an AST 'override' node for a SPIR-V ID, including any attached decorations.
/// @param id the SPIR-V result ID
/// @param type the type of the variable
/// @param constructor the variable constructor
/// @param decorations the variable decorations
/// @returns the AST 'override' node
ast::Override* MakeOverride(uint32_t id,
const Type* type,
const ast::Expression* constructor,
ast::AttributeList decorations);
/// Creates an AST parameter node for a SPIR-V ID, including any attached decorations, unless
/// it's an ignorable builtin variable.
/// @param id the SPIR-V result ID
/// @param type the type of the parameter
/// @param decorations the parameter decorations
/// @returns the AST parameter node
ast::Parameter* MakeParameter(uint32_t id, const Type* type, ast::AttributeList decorations);
/// Returns true if a constant expression can be generated.
/// @param id the SPIR-V ID of the value
/// @returns true if a constant expression can be generated

View File

@ -213,7 +213,11 @@ ParserImpl::FunctionHeader::FunctionHeader(Source src,
ast::ParameterList p,
const ast::Type* ret_ty,
ast::AttributeList ret_attrs)
: source(src), name(n), params(p), return_type(ret_ty), return_type_attributes(ret_attrs) {}
: source(src),
name(n),
params(std::move(p)),
return_type(ret_ty),
return_type_attributes(std::move(ret_attrs)) {}
ParserImpl::FunctionHeader::~FunctionHeader() = default;
@ -542,15 +546,13 @@ Maybe<const ast::Variable*> ParserImpl::global_variable_decl(ast::AttributeList&
constructor = expr.value;
}
return create<ast::Variable>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol
decl->storage_class, // storage class
decl->access, // access control
decl->type, // type
false, // is_const
false, // is_overridable
constructor, // constructor
std::move(attrs)); // attributes
return create<ast::Var>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol
decl->type, // type
decl->storage_class, // storage class
decl->access, // access control
constructor, // constructor
std::move(attrs)); // attributes
}
// global_constant_decl :
@ -564,7 +566,7 @@ Maybe<const ast::Variable*> ParserImpl::global_constant_decl(ast::AttributeList&
if (match(Token::Type::kLet)) {
use = "'let' declaration";
} else if (match(Token::Type::kOverride)) {
use = "override declaration";
use = "'override' declaration";
is_overridable = true;
} else {
return Failure::kNoMatch;
@ -594,15 +596,18 @@ Maybe<const ast::Variable*> ParserImpl::global_constant_decl(ast::AttributeList&
initializer = std::move(init.value);
}
return create<ast::Variable>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol
ast::StorageClass::kNone, // storage class
ast::Access::kUndefined, // access control
decl->type, // type
true, // is_const
is_overridable, // is_overridable
initializer, // constructor
std::move(attrs)); // attributes
if (is_overridable) {
return create<ast::Override>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol
decl->type, // type
initializer, // constructor
std::move(attrs)); // attributes
}
return create<ast::Let>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol
decl->type, // type
initializer, // constructor
std::move(attrs)); // attributes
}
// variable_decl
@ -1478,7 +1483,7 @@ Expect<ast::ParameterList> ParserImpl::expect_param_list() {
// param
// : attribute_list* variable_ident_decl
Expect<ast::Variable*> ParserImpl::expect_param() {
Expect<ast::Parameter*> ParserImpl::expect_param() {
auto attrs = attribute_list();
auto decl = expect_variable_ident_decl("parameter");
@ -1486,21 +1491,10 @@ Expect<ast::Variable*> ParserImpl::expect_param() {
return Failure::kErrored;
}
auto* var = create<ast::Variable>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol
ast::StorageClass::kNone, // storage class
ast::Access::kUndefined, // access control
decl->type, // type
true, // is_const
false, // is_overridable
nullptr, // constructor
std::move(attrs.value)); // attributes
// Formal parameters are treated like a const declaration where the
// initializer value is provided by the call's argument. The key point is
// that it's not updatable after initially set. This is unlike C or GLSL
// which treat formal parameters like local variables that can be updated.
return var;
return create<ast::Parameter>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol
decl->type, // type
std::move(attrs.value)); // attributes
}
// pipeline_stage
@ -1794,17 +1788,13 @@ Maybe<const ast::VariableDeclStatement*> ParserImpl::variable_stmt() {
return add_error(peek(), "missing constructor for 'let' declaration");
}
auto* var = create<ast::Variable>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol
ast::StorageClass::kNone, // storage class
ast::Access::kUndefined, // access control
decl->type, // type
true, // is_const
false, // is_overridable
constructor.value, // constructor
ast::AttributeList{}); // attributes
auto* let = create<ast::Let>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol
decl->type, // type
constructor.value, // constructor
ast::AttributeList{}); // attributes
return create<ast::VariableDeclStatement>(decl->source, var);
return create<ast::VariableDeclStatement>(decl->source, let);
}
auto decl = variable_decl(/*allow_inferred = */ true);
@ -1828,15 +1818,13 @@ Maybe<const ast::VariableDeclStatement*> ParserImpl::variable_stmt() {
constructor = constructor_expr.value;
}
auto* var = create<ast::Variable>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol
decl->storage_class, // storage class
decl->access, // access control
decl->type, // type
false, // is_const
false, // is_overridable
constructor, // constructor
ast::AttributeList{}); // attributes
auto* var = create<ast::Var>(decl->source, // source
builder_.Symbols().Register(decl->name), // symbol
decl->type, // type
decl->storage_class, // storage class
decl->access, // access control
constructor, // constructor
ast::AttributeList{}); // attributes
return create<ast::VariableDeclStatement>(var->source, var);
}

View File

@ -462,7 +462,7 @@ class ParserImpl {
Expect<ast::ParameterList> expect_param_list();
/// Parses a `param` grammar element, erroring on parse failure.
/// @returns the parsed variable
Expect<ast::Variable*> expect_param();
Expect<ast::Parameter*> expect_param();
/// Parses a `pipeline_stage` grammar element, erroring if the next token does
/// not match a stage name.
/// @returns the pipeline stage.

View File

@ -57,7 +57,7 @@ TEST_F(ForStmtTest, InitializerStatementDecl) {
ASSERT_TRUE(fl.matched);
ASSERT_TRUE(Is<ast::VariableDeclStatement>(fl->initializer));
auto* var = fl->initializer->As<ast::VariableDeclStatement>()->variable;
EXPECT_FALSE(var->is_const);
EXPECT_TRUE(var->Is<ast::Var>());
EXPECT_EQ(var->constructor, nullptr);
EXPECT_EQ(fl->condition, nullptr);
EXPECT_EQ(fl->continuing, nullptr);
@ -74,7 +74,7 @@ TEST_F(ForStmtTest, InitializerStatementDeclEqual) {
ASSERT_TRUE(fl.matched);
ASSERT_TRUE(Is<ast::VariableDeclStatement>(fl->initializer));
auto* var = fl->initializer->As<ast::VariableDeclStatement>()->variable;
EXPECT_FALSE(var->is_const);
EXPECT_TRUE(var->Is<ast::Var>());
EXPECT_NE(var->constructor, nullptr);
EXPECT_EQ(fl->condition, nullptr);
EXPECT_EQ(fl->continuing, nullptr);
@ -90,7 +90,7 @@ TEST_F(ForStmtTest, InitializerStatementConstDecl) {
ASSERT_TRUE(fl.matched);
ASSERT_TRUE(Is<ast::VariableDeclStatement>(fl->initializer));
auto* var = fl->initializer->As<ast::VariableDeclStatement>()->variable;
EXPECT_TRUE(var->is_const);
EXPECT_TRUE(var->Is<ast::Let>());
EXPECT_NE(var->constructor, nullptr);
EXPECT_EQ(fl->condition, nullptr);
EXPECT_EQ(fl->continuing, nullptr);

View File

@ -27,21 +27,20 @@ TEST_F(ParserImplTest, GlobalConstantDecl) {
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
auto* let = e.value->As<ast::Let>();
ASSERT_NE(let, nullptr);
EXPECT_TRUE(e->is_const);
EXPECT_FALSE(e->is_overridable);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->type, nullptr);
EXPECT_TRUE(e->type->Is<ast::F32>());
EXPECT_EQ(let->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(let->type, nullptr);
EXPECT_TRUE(let->type->Is<ast::F32>());
EXPECT_EQ(e->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 5u);
EXPECT_EQ(e->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 6u);
EXPECT_EQ(let->source.range.begin.line, 1u);
EXPECT_EQ(let->source.range.begin.column, 5u);
EXPECT_EQ(let->source.range.end.line, 1u);
EXPECT_EQ(let->source.range.end.column, 6u);
ASSERT_NE(e->constructor, nullptr);
EXPECT_TRUE(e->constructor->Is<ast::LiteralExpression>());
ASSERT_NE(let->constructor, nullptr);
EXPECT_TRUE(let->constructor->Is<ast::LiteralExpression>());
}
TEST_F(ParserImplTest, GlobalConstantDecl_Inferred) {
@ -53,20 +52,19 @@ TEST_F(ParserImplTest, GlobalConstantDecl_Inferred) {
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
auto* let = e.value->As<ast::Let>();
ASSERT_NE(let, nullptr);
EXPECT_TRUE(e->is_const);
EXPECT_FALSE(e->is_overridable);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a"));
EXPECT_EQ(e->type, nullptr);
EXPECT_EQ(let->symbol, p->builder().Symbols().Get("a"));
EXPECT_EQ(let->type, nullptr);
EXPECT_EQ(e->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 5u);
EXPECT_EQ(e->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 6u);
EXPECT_EQ(let->source.range.begin.line, 1u);
EXPECT_EQ(let->source.range.begin.column, 5u);
EXPECT_EQ(let->source.range.end.line, 1u);
EXPECT_EQ(let->source.range.end.column, 6u);
ASSERT_NE(e->constructor, nullptr);
EXPECT_TRUE(e->constructor->Is<ast::LiteralExpression>());
ASSERT_NE(let->constructor, nullptr);
EXPECT_TRUE(let->constructor->Is<ast::LiteralExpression>());
}
TEST_F(ParserImplTest, GlobalConstantDecl_InvalidExpression) {
@ -105,23 +103,22 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_WithId) {
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
auto* override = e.value->As<ast::Override>();
ASSERT_NE(override, nullptr);
EXPECT_TRUE(e->is_const);
EXPECT_TRUE(e->is_overridable);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->type, nullptr);
EXPECT_TRUE(e->type->Is<ast::F32>());
EXPECT_EQ(override->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(override->type, nullptr);
EXPECT_TRUE(override->type->Is<ast::F32>());
EXPECT_EQ(e->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 17u);
EXPECT_EQ(e->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 18u);
EXPECT_EQ(override->source.range.begin.line, 1u);
EXPECT_EQ(override->source.range.begin.column, 17u);
EXPECT_EQ(override->source.range.end.line, 1u);
EXPECT_EQ(override->source.range.end.column, 18u);
ASSERT_NE(e->constructor, nullptr);
EXPECT_TRUE(e->constructor->Is<ast::LiteralExpression>());
ASSERT_NE(override->constructor, nullptr);
EXPECT_TRUE(override->constructor->Is<ast::LiteralExpression>());
auto* override_attr = ast::GetAttribute<ast::IdAttribute>(e.value->attributes);
auto* override_attr = ast::GetAttribute<ast::IdAttribute>(override->attributes);
ASSERT_NE(override_attr, nullptr);
EXPECT_EQ(override_attr->value, 7u);
}
@ -136,23 +133,22 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_WithoutId) {
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
auto* override = e.value->As<ast::Override>();
ASSERT_NE(override, nullptr);
EXPECT_TRUE(e->is_const);
EXPECT_TRUE(e->is_overridable);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->type, nullptr);
EXPECT_TRUE(e->type->Is<ast::F32>());
EXPECT_EQ(override->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(override->type, nullptr);
EXPECT_TRUE(override->type->Is<ast::F32>());
EXPECT_EQ(e->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 10u);
EXPECT_EQ(e->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 11u);
EXPECT_EQ(override->source.range.begin.line, 1u);
EXPECT_EQ(override->source.range.begin.column, 10u);
EXPECT_EQ(override->source.range.end.line, 1u);
EXPECT_EQ(override->source.range.end.column, 11u);
ASSERT_NE(e->constructor, nullptr);
EXPECT_TRUE(e->constructor->Is<ast::LiteralExpression>());
ASSERT_NE(override->constructor, nullptr);
EXPECT_TRUE(override->constructor->Is<ast::LiteralExpression>());
auto* id_attr = ast::GetAttribute<ast::IdAttribute>(e.value->attributes);
auto* id_attr = ast::GetAttribute<ast::IdAttribute>(override->attributes);
ASSERT_EQ(id_attr, nullptr);
}
@ -165,7 +161,8 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_MissingId) {
auto e = p->global_constant_decl(attrs.value);
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
auto* override = e.value->As<ast::Override>();
ASSERT_NE(override, nullptr);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(p->error(), "1:5: expected signed integer literal for id attribute");
@ -180,7 +177,8 @@ TEST_F(ParserImplTest, GlobalConstantDec_Override_InvalidId) {
auto e = p->global_constant_decl(attrs.value);
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
auto* override = e.value->As<ast::Override>();
ASSERT_NE(override, nullptr);
EXPECT_TRUE(p->has_error());
EXPECT_EQ(p->error(), "1:5: id attribute must be positive");

View File

@ -26,18 +26,19 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithoutConstructor) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
auto* var = e.value->As<ast::Var>();
ASSERT_NE(var, nullptr);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a"));
EXPECT_TRUE(e->type->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class, ast::StorageClass::kPrivate);
EXPECT_EQ(var->symbol, p->builder().Symbols().Get("a"));
EXPECT_TRUE(var->type->Is<ast::F32>());
EXPECT_EQ(var->declared_storage_class, ast::StorageClass::kPrivate);
EXPECT_EQ(e->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 14u);
EXPECT_EQ(e->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 15u);
EXPECT_EQ(var->source.range.begin.line, 1u);
EXPECT_EQ(var->source.range.begin.column, 14u);
EXPECT_EQ(var->source.range.end.line, 1u);
EXPECT_EQ(var->source.range.end.column, 15u);
ASSERT_EQ(e->constructor, nullptr);
ASSERT_EQ(var->constructor, nullptr);
}
TEST_F(ParserImplTest, GlobalVariableDecl_WithConstructor) {
@ -49,19 +50,20 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithConstructor) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
auto* var = e.value->As<ast::Var>();
ASSERT_NE(var, nullptr);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a"));
EXPECT_TRUE(e->type->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class, ast::StorageClass::kPrivate);
EXPECT_EQ(var->symbol, p->builder().Symbols().Get("a"));
EXPECT_TRUE(var->type->Is<ast::F32>());
EXPECT_EQ(var->declared_storage_class, ast::StorageClass::kPrivate);
EXPECT_EQ(e->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 14u);
EXPECT_EQ(e->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 15u);
EXPECT_EQ(var->source.range.begin.line, 1u);
EXPECT_EQ(var->source.range.begin.column, 14u);
EXPECT_EQ(var->source.range.end.line, 1u);
EXPECT_EQ(var->source.range.end.column, 15u);
ASSERT_NE(e->constructor, nullptr);
ASSERT_TRUE(e->constructor->Is<ast::FloatLiteralExpression>());
ASSERT_NE(var->constructor, nullptr);
ASSERT_TRUE(var->constructor->Is<ast::FloatLiteralExpression>());
}
TEST_F(ParserImplTest, GlobalVariableDecl_WithAttribute) {
@ -73,21 +75,22 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithAttribute) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
auto* var = e.value->As<ast::Var>();
ASSERT_NE(var, nullptr);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->type, nullptr);
EXPECT_TRUE(e->type->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class, ast::StorageClass::kUniform);
EXPECT_EQ(var->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(var->type, nullptr);
EXPECT_TRUE(var->type->Is<ast::F32>());
EXPECT_EQ(var->declared_storage_class, ast::StorageClass::kUniform);
EXPECT_EQ(e->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 36u);
EXPECT_EQ(e->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 37u);
EXPECT_EQ(var->source.range.begin.line, 1u);
EXPECT_EQ(var->source.range.begin.column, 36u);
EXPECT_EQ(var->source.range.end.line, 1u);
EXPECT_EQ(var->source.range.end.column, 37u);
ASSERT_EQ(e->constructor, nullptr);
ASSERT_EQ(var->constructor, nullptr);
auto& attributes = e->attributes;
auto& attributes = var->attributes;
ASSERT_EQ(attributes.size(), 2u);
ASSERT_TRUE(attributes[0]->Is<ast::BindingAttribute>());
ASSERT_TRUE(attributes[1]->Is<ast::GroupAttribute>());
@ -103,21 +106,22 @@ TEST_F(ParserImplTest, GlobalVariableDecl_WithAttribute_MulitpleGroups) {
ASSERT_FALSE(p->has_error()) << p->error();
EXPECT_TRUE(e.matched);
EXPECT_FALSE(e.errored);
ASSERT_NE(e.value, nullptr);
auto* var = e.value->As<ast::Var>();
ASSERT_NE(var, nullptr);
EXPECT_EQ(e->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(e->type, nullptr);
EXPECT_TRUE(e->type->Is<ast::F32>());
EXPECT_EQ(e->declared_storage_class, ast::StorageClass::kUniform);
EXPECT_EQ(var->symbol, p->builder().Symbols().Get("a"));
ASSERT_NE(var->type, nullptr);
EXPECT_TRUE(var->type->Is<ast::F32>());
EXPECT_EQ(var->declared_storage_class, ast::StorageClass::kUniform);
EXPECT_EQ(e->source.range.begin.line, 1u);
EXPECT_EQ(e->source.range.begin.column, 36u);
EXPECT_EQ(e->source.range.end.line, 1u);
EXPECT_EQ(e->source.range.end.column, 37u);
EXPECT_EQ(var->source.range.begin.line, 1u);
EXPECT_EQ(var->source.range.begin.column, 36u);
EXPECT_EQ(var->source.range.end.line, 1u);
EXPECT_EQ(var->source.range.end.column, 37u);
ASSERT_EQ(e->constructor, nullptr);
ASSERT_EQ(var->constructor, nullptr);
auto& attributes = e->attributes;
auto& attributes = var->attributes;
ASSERT_EQ(attributes.size(), 2u);
ASSERT_TRUE(attributes[0]->Is<ast::BindingAttribute>());
ASSERT_TRUE(attributes[1]->Is<ast::GroupAttribute>());

View File

@ -27,7 +27,7 @@ TEST_F(ParserImplTest, ParamList_Single) {
EXPECT_EQ(e.value[0]->symbol, p->builder().Symbols().Get("a"));
EXPECT_TRUE(e.value[0]->type->Is<ast::I32>());
EXPECT_TRUE(e.value[0]->is_const);
EXPECT_TRUE(e.value[0]->Is<ast::Parameter>());
ASSERT_EQ(e.value[0]->source.range.begin.line, 1u);
ASSERT_EQ(e.value[0]->source.range.begin.column, 1u);
@ -45,7 +45,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) {
EXPECT_EQ(e.value[0]->symbol, p->builder().Symbols().Get("a"));
EXPECT_TRUE(e.value[0]->type->Is<ast::I32>());
EXPECT_TRUE(e.value[0]->is_const);
EXPECT_TRUE(e.value[0]->Is<ast::Parameter>());
ASSERT_EQ(e.value[0]->source.range.begin.line, 1u);
ASSERT_EQ(e.value[0]->source.range.begin.column, 1u);
@ -54,7 +54,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) {
EXPECT_EQ(e.value[1]->symbol, p->builder().Symbols().Get("b"));
EXPECT_TRUE(e.value[1]->type->Is<ast::F32>());
EXPECT_TRUE(e.value[1]->is_const);
EXPECT_TRUE(e.value[1]->Is<ast::Parameter>());
ASSERT_EQ(e.value[1]->source.range.begin.line, 1u);
ASSERT_EQ(e.value[1]->source.range.begin.column, 10u);
@ -65,7 +65,7 @@ TEST_F(ParserImplTest, ParamList_Multiple) {
ASSERT_TRUE(e.value[2]->type->Is<ast::Vector>());
ASSERT_TRUE(e.value[2]->type->As<ast::Vector>()->type->Is<ast::F32>());
EXPECT_EQ(e.value[2]->type->As<ast::Vector>()->width, 2u);
EXPECT_TRUE(e.value[2]->is_const);
EXPECT_TRUE(e.value[2]->Is<ast::Parameter>());
ASSERT_EQ(e.value[2]->source.range.begin.line, 1u);
ASSERT_EQ(e.value[2]->source.range.begin.column, 18u);
@ -101,7 +101,7 @@ TEST_F(ParserImplTest, ParamList_Attributes) {
ASSERT_TRUE(e.value[0]->type->Is<ast::Vector>());
EXPECT_TRUE(e.value[0]->type->As<ast::Vector>()->type->Is<ast::F32>());
EXPECT_EQ(e.value[0]->type->As<ast::Vector>()->width, 4u);
EXPECT_TRUE(e.value[0]->is_const);
EXPECT_TRUE(e.value[0]->Is<ast::Parameter>());
auto attrs_0 = e.value[0]->attributes;
ASSERT_EQ(attrs_0.size(), 1u);
EXPECT_TRUE(attrs_0[0]->Is<ast::BuiltinAttribute>());
@ -114,7 +114,7 @@ TEST_F(ParserImplTest, ParamList_Attributes) {
EXPECT_EQ(e.value[1]->symbol, p->builder().Symbols().Get("loc1"));
EXPECT_TRUE(e.value[1]->type->Is<ast::F32>());
EXPECT_TRUE(e.value[1]->is_const);
EXPECT_TRUE(e.value[1]->Is<ast::Parameter>());
auto attrs_1 = e.value[1]->attributes;
ASSERT_EQ(attrs_1.size(), 1u);
EXPECT_TRUE(attrs_1[0]->Is<ast::LocationAttribute>());

View File

@ -487,11 +487,13 @@ struct DependencyAnalysis {
/// declaration
std::string KindOf(const ast::Node* node) {
return Switch(
node, //
[&](const ast::Struct*) { return "struct"; },
[&](const ast::Alias*) { return "alias"; },
[&](const ast::Function*) { return "function"; },
[&](const ast::Variable* var) { return var->is_const ? "let" : "var"; },
node, //
[&](const ast::Struct*) { return "struct"; }, //
[&](const ast::Alias*) { return "alias"; }, //
[&](const ast::Function*) { return "function"; }, //
[&](const ast::Let*) { return "let"; }, //
[&](const ast::Var*) { return "var"; }, //
[&](const ast::Override*) { return "override"; }, //
[&](Default) {
UnhandledNode(diagnostics_, node);
return "<error>";

View File

@ -31,7 +31,7 @@ class ResolverPipelineOverridableConstantTest : public ResolverTest {
auto* sem = Sem().Get<sem::GlobalVariable>(var);
ASSERT_NE(sem, nullptr);
EXPECT_EQ(sem->Declaration(), var);
EXPECT_TRUE(sem->IsOverridable());
EXPECT_TRUE(sem->Declaration()->Is<ast::Override>());
EXPECT_EQ(sem->ConstantId(), id);
EXPECT_FALSE(sem->ConstantValue());
}
@ -45,7 +45,7 @@ TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) {
auto* sem_a = Sem().Get<sem::GlobalVariable>(a);
ASSERT_NE(sem_a, nullptr);
EXPECT_EQ(sem_a->Declaration(), a);
EXPECT_FALSE(sem_a->IsOverridable());
EXPECT_FALSE(sem_a->Declaration()->Is<ast::Override>());
EXPECT_TRUE(sem_a->ConstantValue());
}

View File

@ -303,59 +303,63 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
return s;
}
sem::Variable* Resolver::Variable(const ast::Variable* var,
VariableKind kind,
sem::Variable* Resolver::Variable(const ast::Variable* v,
bool is_global,
uint32_t index /* = 0 */) {
const sem::Type* storage_ty = nullptr;
// If the variable has a declared type, resolve it.
if (auto* ty = var->type) {
if (auto* ty = v->type) {
storage_ty = Type(ty);
if (!storage_ty) {
return nullptr;
}
}
auto* as_var = v->As<ast::Var>();
auto* as_let = v->As<ast::Let>();
auto* as_override = v->As<ast::Override>();
auto* as_param = v->As<ast::Parameter>();
const sem::Expression* rhs = nullptr;
// Does the variable have a constructor?
if (var->constructor) {
rhs = Materialize(Expression(var->constructor), storage_ty);
if (v->constructor) {
rhs = Materialize(Expression(v->constructor), storage_ty);
if (!rhs) {
return nullptr;
}
// If the variable has no declared type, infer it from the RHS
if (!storage_ty) {
if (!var->is_const && kind == VariableKind::kGlobal) {
AddError("module-scope 'var' declaration must specify a type", var->source);
if (as_var && is_global) {
AddError("module-scope 'var' declaration must specify a type", v->source);
return nullptr;
}
storage_ty = rhs->Type()->UnwrapRef(); // Implicit load of RHS
}
} else if (var->is_const && !var->is_overridable && kind != VariableKind::kParameter) {
AddError("'let' declaration must have an initializer", var->source);
} else if (as_let) {
AddError("'let' declaration must have an initializer", v->source);
return nullptr;
} else if (!var->type) {
AddError((kind == VariableKind::kGlobal)
? "module-scope 'var' declaration requires a type or initializer"
: "function-scope 'var' declaration requires a type or initializer",
var->source);
} else if (!v->type) {
AddError((is_global) ? "module-scope 'var' declaration requires a type or initializer"
: "function-scope 'var' declaration requires a type or initializer",
v->source);
return nullptr;
}
if (!storage_ty) {
TINT_ICE(Resolver, diagnostics_) << "failed to determine storage type for variable '" +
builder_->Symbols().NameFor(var->symbol) + "'\n"
<< "Source: " << var->source;
builder_->Symbols().NameFor(v->symbol) + "'\n"
<< "Source: " << v->source;
return nullptr;
}
auto storage_class = var->declared_storage_class;
if (storage_class == ast::StorageClass::kNone && !var->is_const) {
auto storage_class = as_var ? as_var->declared_storage_class : ast::StorageClass::kNone;
if (storage_class == ast::StorageClass::kNone && as_var) {
// No declared storage class. Infer from usage / type.
if (kind == VariableKind::kLocal) {
if (!is_global) {
storage_class = ast::StorageClass::kFunction;
} else if (storage_ty->UnwrapRef()->is_handle()) {
// https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
@ -366,93 +370,83 @@ sem::Variable* Resolver::Variable(const ast::Variable* var,
}
}
if (kind == VariableKind::kLocal && !var->is_const &&
storage_class != ast::StorageClass::kFunction &&
validator_.IsValidationEnabled(var->attributes,
if (!is_global && as_var && storage_class != ast::StorageClass::kFunction &&
validator_.IsValidationEnabled(v->attributes,
ast::DisabledValidation::kIgnoreStorageClass)) {
AddError("function-scope 'var' declaration must use 'function' storage class", var->source);
AddError("function-scope 'var' declaration must use 'function' storage class", v->source);
return nullptr;
}
auto access = var->declared_access;
auto access = as_var ? as_var->declared_access : ast::Access::kUndefined;
if (access == ast::Access::kUndefined) {
access = DefaultAccessForStorageClass(storage_class);
}
auto* var_ty = storage_ty;
if (!var->is_const) {
// Variable declaration. Unlike `let`, `var` has storage.
if (as_var) {
// Variable declaration. Unlike `let` and parameters, `var` has storage.
// Variables are always of a reference type to the declared storage type.
var_ty = builder_->create<sem::Reference>(storage_ty, storage_class, access);
}
if (rhs && !validator_.VariableConstructorOrCast(var, storage_class, storage_ty, rhs->Type())) {
if (rhs && !validator_.VariableConstructorOrCast(v, storage_class, storage_ty, rhs->Type())) {
return nullptr;
}
if (!ApplyStorageClassUsageToType(storage_class, const_cast<sem::Type*>(var_ty), var->source)) {
AddNote(std::string("while instantiating ") +
((kind == VariableKind::kParameter) ? "parameter " : "variable ") +
builder_->Symbols().NameFor(var->symbol),
var->source);
if (!ApplyStorageClassUsageToType(storage_class, const_cast<sem::Type*>(var_ty), v->source)) {
AddNote(std::string("while instantiating ") + ((as_param) ? "parameter " : "variable ") +
builder_->Symbols().NameFor(v->symbol),
v->source);
return nullptr;
}
if (kind == VariableKind::kParameter) {
if (as_param) {
if (auto* ptr = var_ty->As<sem::Pointer>()) {
// For MSL, we push module-scope variables into the entry point as pointer
// parameters, so we also need to handle their store type.
if (!ApplyStorageClassUsageToType(
ptr->StorageClass(), const_cast<sem::Type*>(ptr->StoreType()), var->source)) {
AddNote("while instantiating parameter " + builder_->Symbols().NameFor(var->symbol),
var->source);
ptr->StorageClass(), const_cast<sem::Type*>(ptr->StoreType()), v->source)) {
AddNote("while instantiating parameter " + builder_->Symbols().NameFor(v->symbol),
v->source);
return nullptr;
}
}
auto* param =
builder_->create<sem::Parameter>(as_param, index, var_ty, storage_class, access);
builder_->Sem().Add(as_param, param);
return param;
}
switch (kind) {
case VariableKind::kGlobal: {
sem::BindingPoint binding_point;
if (auto bp = var->BindingPoint()) {
if (is_global) {
sem::BindingPoint binding_point;
if (as_var) {
if (auto bp = as_var->BindingPoint()) {
binding_point = {bp.group->value, bp.binding->value};
}
}
bool has_const_val = rhs && var->is_const && !var->is_overridable;
auto* global = builder_->create<sem::GlobalVariable>(
var, var_ty, storage_class, access,
has_const_val ? rhs->ConstantValue() : sem::Constant{}, binding_point);
bool has_const_val = rhs && as_let && !as_override;
auto* global = builder_->create<sem::GlobalVariable>(
v, var_ty, storage_class, access,
has_const_val ? rhs->ConstantValue() : sem::Constant{}, binding_point);
if (var->is_overridable) {
global->SetIsOverridable();
if (auto* id = ast::GetAttribute<ast::IdAttribute>(var->attributes)) {
global->SetConstantId(static_cast<uint16_t>(id->value));
}
if (as_override) {
if (auto* id = ast::GetAttribute<ast::IdAttribute>(v->attributes)) {
global->SetConstantId(static_cast<uint16_t>(id->value));
}
}
global->SetConstructor(rhs);
builder_->Sem().Add(var, global);
return global;
}
case VariableKind::kLocal: {
auto* local = builder_->create<sem::LocalVariable>(
var, var_ty, storage_class, access, current_statement_,
(rhs && var->is_const) ? rhs->ConstantValue() : sem::Constant{});
builder_->Sem().Add(var, local);
local->SetConstructor(rhs);
return local;
}
case VariableKind::kParameter: {
auto* param =
builder_->create<sem::Parameter>(var, index, var_ty, storage_class, access);
builder_->Sem().Add(var, param);
return param;
}
global->SetConstructor(rhs);
builder_->Sem().Add(v, global);
return global;
}
TINT_UNREACHABLE(Resolver, diagnostics_) << "unhandled VariableKind " << static_cast<int>(kind);
return nullptr;
auto* local = builder_->create<sem::LocalVariable>(
v, var_ty, storage_class, access, current_statement_,
(rhs && as_let) ? rhs->ConstantValue() : sem::Constant{});
builder_->Sem().Add(v, local);
local->SetConstructor(rhs);
return local;
}
ast::Access Resolver::DefaultAccessForStorageClass(ast::StorageClass storage_class) {
@ -477,13 +471,13 @@ void Resolver::AllocateOverridableConstantIds() {
// TODO(crbug.com/tint/1192): If a transform changes the order or removes an
// unused constant, the allocation may change on the next Resolver pass.
for (auto* decl : builder_->AST().GlobalDeclarations()) {
auto* var = decl->As<ast::Variable>();
if (!var || !var->is_overridable) {
auto* override = decl->As<ast::Override>();
if (!override) {
continue;
}
uint16_t constant_id;
if (auto* id_attr = ast::GetAttribute<ast::IdAttribute>(var->attributes)) {
if (auto* id_attr = ast::GetAttribute<ast::IdAttribute>(override->attributes)) {
constant_id = static_cast<uint16_t>(id_attr->value);
} else {
// No ID was specified, so allocate the next available ID.
@ -499,7 +493,7 @@ void Resolver::AllocateOverridableConstantIds() {
next_constant_id = constant_id + 1;
}
auto* sem = sem_.Get<sem::GlobalVariable>(var);
auto* sem = sem_.Get<sem::GlobalVariable>(override);
const_cast<sem::GlobalVariable*>(sem)->SetConstantId(constant_id);
}
}
@ -513,25 +507,21 @@ void Resolver::SetShadows() {
}
}
sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* var) {
auto* sem = Variable(var, VariableKind::kGlobal);
sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) {
auto* sem = As<sem::GlobalVariable>(Variable(v, /* is_global */ true));
if (!sem) {
return nullptr;
}
const bool is_var = v->Is<ast::Var>();
auto storage_class = sem->StorageClass();
if (!var->is_const && storage_class == ast::StorageClass::kNone) {
AddError("module-scope 'var' declaration must have a storage class", var->source);
return nullptr;
}
if (var->is_const && storage_class != ast::StorageClass::kNone) {
AddError(var->is_overridable ? "'override' declaration must not have a storage class"
: "'let' declaration must not have a storage class",
var->source);
if (is_var && storage_class == ast::StorageClass::kNone) {
AddError("module-scope 'var' declaration must have a storage class", v->source);
return nullptr;
}
for (auto* attr : var->attributes) {
for (auto* attr : v->attributes) {
Mark(attr);
if (auto* id_attr = attr->As<ast::IdAttribute>()) {
@ -540,7 +530,7 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* var) {
}
}
if (!validator_.NoDuplicateAttributes(var->attributes)) {
if (!validator_.NoDuplicateAttributes(v->attributes)) {
return nullptr;
}
@ -576,9 +566,8 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
}
}
auto* var =
As<sem::Parameter>(Variable(param, VariableKind::kParameter, parameter_index++));
if (!var) {
auto* p = As<sem::Parameter>(Variable(param, false, parameter_index++));
if (!p) {
return nullptr;
}
@ -589,10 +578,10 @@ sem::Function* Resolver::Function(const ast::Function* decl) {
return nullptr;
}
parameters.emplace_back(var);
parameters.emplace_back(p);
auto* var_ty = const_cast<sem::Type*>(var->Type());
if (auto* str = var_ty->As<sem::Struct>()) {
auto* p_ty = const_cast<sem::Type*>(p->Type());
if (auto* str = p_ty->As<sem::Struct>()) {
switch (decl->PipelineStage()) {
case ast::PipelineStage::kVertex:
str->AddUsage(sem::PipelineStageUsage::kVertexInput);
@ -777,12 +766,12 @@ bool Resolver::WorkgroupSize(const ast::Function* func) {
if (auto* user = args[i]->As<sem::VariableUser>()) {
// We have an variable of a module-scope constant.
auto* decl = user->Variable()->Declaration();
if (!decl->is_const) {
if (!decl->IsAnyOf<ast::Let, ast::Override>()) {
AddError(kErrBadType, values[i]->source);
return false;
}
// Capture the constant if it is pipeline-overridable.
if (decl->is_overridable) {
if (decl->Is<ast::Override>()) {
ws[i].overridable_const = decl;
}
@ -2104,19 +2093,19 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
return nullptr;
}
constexpr const char* kErrInvalidExpr =
"array size identifier must be a literal or a module-scope 'let'";
if (auto* ident = count_expr->As<ast::IdentifierExpression>()) {
// Make sure the identifier is a non-overridable module-scope constant.
auto* var = sem_.ResolvedSymbol<sem::GlobalVariable>(ident);
if (!var || !var->Declaration()->is_const || var->IsOverridable()) {
AddError("array size identifier must be a literal or a module-scope 'let'",
size_source);
// Make sure the identifier is a non-overridable module-scope 'let'.
auto* global = sem_.ResolvedSymbol<sem::GlobalVariable>(ident);
if (!global || !global->Declaration()->Is<ast::Let>()) {
AddError(kErrInvalidExpr, size_source);
return nullptr;
}
count_expr = var->Declaration()->constructor;
count_expr = global->Declaration()->constructor;
} else if (!count_expr->Is<ast::LiteralExpression>()) {
AddError("array size identifier must be a literal or a module-scope 'let'",
size_source);
AddError(kErrInvalidExpr, size_source);
return nullptr;
}
@ -2437,7 +2426,7 @@ sem::Statement* Resolver::VariableDeclStatement(const ast::VariableDeclStatement
return StatementScope(stmt, sem, [&] {
Mark(stmt->variable);
auto* var = Variable(stmt->variable, VariableKind::kLocal);
auto* var = Variable(stmt->variable, /* is_global */ false);
if (!var) {
return false;
}

View File

@ -109,9 +109,6 @@ class Resolver {
const Validator* GetValidatorForTesting() const { return &validator_; }
private:
/// Describes the context in which a variable is declared
enum class VariableKind { kParameter, kLocal, kGlobal };
Validator::ValidTypeStorageLayouts valid_type_storage_layouts_;
/// Structure holding semantic information about a block (i.e. scope), such as
@ -298,9 +295,9 @@ class Resolver {
/// @note this method does not resolve the attributes as these are
/// context-dependent (global, local, parameter)
/// @param var the variable to create or return the `VariableInfo` for
/// @param kind what kind of variable we are declaring
/// @param is_global true if this is module scope, otherwise function scope
/// @param index the index of the parameter, if this variable is a parameter
sem::Variable* Variable(const ast::Variable* var, VariableKind kind, uint32_t index = 0);
sem::Variable* Variable(const ast::Variable* var, bool is_global, uint32_t index = 0);
/// Records the storage class usage for the given type, and any transient
/// dependencies of the type. Validates that the type can be used for the

View File

@ -87,26 +87,6 @@ TEST_F(ResolverTypeValidationTest, GlobalVariableWithStorageClass_Pass) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTypeValidationTest, GlobalLetWithStorageClass_Fail) {
// let<private> global_var: f32;
AST().AddGlobalVariable(create<ast::Variable>(
Source{{12, 34}}, Symbols().Register("global_let"), ast::StorageClass::kPrivate,
ast::Access::kUndefined, ty.f32(), true, false, Expr(1.23_f), ast::AttributeList{}));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: 'let' declaration must not have a storage class");
}
TEST_F(ResolverTypeValidationTest, OverrideWithStorageClass_Fail) {
// let<private> global_var: f32;
AST().AddGlobalVariable(create<ast::Variable>(
Source{{12, 34}}, Symbols().Register("global_override"), ast::StorageClass::kPrivate,
ast::Access::kUndefined, ty.f32(), true, true, Expr(1.23_f), ast::AttributeList{}));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: 'override' declaration must not have a storage class");
}
TEST_F(ResolverTypeValidationTest, GlobalConstNoStorageClass_Pass) {
// let global_var: f32;
GlobalConst(Source{{12, 34}}, "global_var", ty.f32(), Construct(ty.f32()));

View File

@ -949,7 +949,7 @@ class UniformityGraph {
}
current_function_->variables.Set(sem_.Get(decl->variable), node);
if (!decl->variable->is_const) {
if (decl->variable->Is<ast::Var>()) {
current_function_->local_var_decls.insert(
sem_.Get<sem::LocalVariable>(decl->variable));
}
@ -1018,7 +1018,8 @@ class UniformityGraph {
},
[&](const sem::GlobalVariable* global) {
if (global->Declaration()->is_const || global->Access() == ast::Access::kRead) {
if (!global->Declaration()->Is<ast::Var>() ||
global->Access() == ast::Access::kRead) {
node->AddEdge(cf);
} else {
node->AddEdge(current_function_->may_be_non_uniform);

View File

@ -297,7 +297,7 @@ bool Validator::Materialize(const sem::Materialize* m) const {
return true;
}
bool Validator::VariableConstructorOrCast(const ast::Variable* var,
bool Validator::VariableConstructorOrCast(const ast::Variable* v,
ast::StorageClass storage_class,
const sem::Type* storage_ty,
const sem::Type* rhs_ty) const {
@ -305,14 +305,14 @@ bool Validator::VariableConstructorOrCast(const ast::Variable* var,
// Value type has to match storage type
if (storage_ty != value_type) {
std::string decl = var->is_const ? "let" : "var";
std::string decl = v->Is<ast::Let>() ? "let" : "var";
AddError("cannot initialize " + decl + " of type '" + sem_.TypeNameOf(storage_ty) +
"' with value of type '" + sem_.TypeNameOf(rhs_ty) + "'",
var->source);
v->source);
return false;
}
if (!var->is_const) {
if (v->Is<ast::Var>()) {
switch (storage_class) {
case ast::StorageClass::kPrivate:
case ast::StorageClass::kFunction:
@ -325,7 +325,7 @@ bool Validator::VariableConstructorOrCast(const ast::Variable* var,
"' cannot have an initializer. var initializers are only "
"supported for the storage classes "
"'private' and 'function'",
var->source);
v->source);
return false;
}
}
@ -502,21 +502,22 @@ bool Validator::StorageClassLayout(const sem::Variable* var,
}
bool Validator::GlobalVariable(
const sem::Variable* var,
const sem::GlobalVariable* global,
std::unordered_map<uint32_t, const sem::Variable*> constant_ids,
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info) const {
auto* decl = var->Declaration();
auto* decl = global->Declaration();
if (!NoDuplicateAttributes(decl->attributes)) {
return false;
}
for (auto* attr : decl->attributes) {
if (decl->is_const) {
if (decl->is_overridable) {
bool ok = Switch(
decl, //
[&](const ast::Override*) {
for (auto* attr : decl->attributes) {
if (auto* id_attr = attr->As<ast::IdAttribute>()) {
uint32_t id = id_attr->value;
auto it = constant_ids.find(id);
if (it != constant_ids.end() && it->second != var) {
if (it != constant_ids.end() && it->second != global) {
AddError("pipeline constant IDs must be unique", attr->source);
AddNote("a pipeline constant with an ID of " + std::to_string(id) +
" was previously declared here:",
@ -533,32 +534,45 @@ bool Validator::GlobalVariable(
AddError("attribute is not valid for 'override' declaration", attr->source);
return false;
}
} else {
AddError("attribute is not valid for module-scope 'let' declaration", attr->source);
}
return true;
},
[&](const ast::Let*) {
if (!decl->attributes.empty()) {
AddError("attribute is not valid for module-scope 'let' declaration",
decl->attributes[0]->source);
return false;
}
} else {
bool is_shader_io_attribute =
attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute,
ast::InvariantAttribute, ast::LocationAttribute>();
bool has_io_storage_class = var->StorageClass() == ast::StorageClass::kInput ||
var->StorageClass() == ast::StorageClass::kOutput;
if (!(attr->IsAnyOf<ast::BindingAttribute, ast::GroupAttribute,
ast::InternalAttribute>()) &&
(!is_shader_io_attribute || !has_io_storage_class)) {
AddError("attribute is not valid for module-scope 'var'", attr->source);
return false;
return true;
},
[&](const ast::Var*) {
for (auto* attr : decl->attributes) {
bool is_shader_io_attribute =
attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute,
ast::InvariantAttribute, ast::LocationAttribute>();
bool has_io_storage_class = global->StorageClass() == ast::StorageClass::kInput ||
global->StorageClass() == ast::StorageClass::kOutput;
if (!attr->IsAnyOf<ast::BindingAttribute, ast::GroupAttribute,
ast::InternalAttribute>() &&
(!is_shader_io_attribute || !has_io_storage_class)) {
AddError("attribute is not valid for module-scope 'var'", attr->source);
return false;
}
}
}
return true;
});
if (!ok) {
return false;
}
if (var->StorageClass() == ast::StorageClass::kFunction) {
if (global->StorageClass() == ast::StorageClass::kFunction) {
AddError("module-scope 'var' must not use storage class 'function'", decl->source);
return false;
}
auto binding_point = decl->BindingPoint();
switch (var->StorageClass()) {
switch (global->StorageClass()) {
case ast::StorageClass::kUniform:
case ast::StorageClass::kStorage:
case ast::StorageClass::kHandle: {
@ -581,23 +595,23 @@ bool Validator::GlobalVariable(
}
}
// https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration
// The access mode always has a default, and except for variables in the
// storage storage class, must not be written.
if (var->StorageClass() != ast::StorageClass::kStorage &&
decl->declared_access != ast::Access::kUndefined) {
AddError("only variables in <storage> storage class may declare an access mode",
decl->source);
return false;
}
if (auto* var = decl->As<ast::Var>()) {
// https://gpuweb.github.io/gpuweb/wgsl/#variable-declaration
// The access mode always has a default, and except for variables in the
// storage storage class, must not be written.
if (global->StorageClass() != ast::StorageClass::kStorage &&
var->declared_access != ast::Access::kUndefined) {
AddError("only variables in <storage> storage class may declare an access mode",
var->source);
return false;
}
if (!decl->is_const) {
if (!AtomicVariable(var, atomic_composite_info)) {
if (!AtomicVariable(global, atomic_composite_info)) {
return false;
}
}
return Variable(var);
return Variable(global);
}
// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
@ -641,14 +655,17 @@ bool Validator::AtomicVariable(
return true;
}
bool Validator::Variable(const sem::Variable* var) const {
auto* decl = var->Declaration();
auto* storage_ty = var->Type()->UnwrapRef();
bool Validator::Variable(const sem::Variable* v) const {
auto* decl = v->Declaration();
auto* storage_ty = v->Type()->UnwrapRef();
if (var->Is<sem::GlobalVariable>()) {
auto* as_let = decl->As<ast::Let>();
auto* as_var = decl->As<ast::Var>();
if (v->Is<sem::GlobalVariable>()) {
auto name = symbols_.NameFor(decl->symbol);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
auto* kind = var->Declaration()->is_const ? "let" : "var";
auto* kind = as_let ? "let" : "var";
AddError(
"'" + name + "' is a builtin and cannot be redeclared as a module-scope " + kind,
decl->source);
@ -656,14 +673,13 @@ bool Validator::Variable(const sem::Variable* var) const {
}
}
if (!decl->is_const && !IsStorable(storage_ty)) {
if (as_var && !IsStorable(storage_ty)) {
AddError(sem_.TypeNameOf(storage_ty) + " cannot be used as the type of a var",
decl->source);
return false;
}
if (decl->is_const && !var->Is<sem::Parameter>() &&
!(storage_ty->IsConstructible() || storage_ty->Is<sem::Pointer>())) {
if (as_let && !(storage_ty->IsConstructible() || storage_ty->Is<sem::Pointer>())) {
AddError(sem_.TypeNameOf(storage_ty) + " cannot be used as the type of a let",
decl->source);
return false;
@ -688,16 +704,17 @@ bool Validator::Variable(const sem::Variable* var) const {
}
}
if (var->Is<sem::LocalVariable>() && !decl->is_const &&
if (v->Is<sem::LocalVariable>() && as_var &&
IsValidationEnabled(decl->attributes, ast::DisabledValidation::kIgnoreStorageClass)) {
if (!var->Type()->UnwrapRef()->IsConstructible()) {
if (!v->Type()->UnwrapRef()->IsConstructible()) {
AddError("function variable must have a constructible type",
decl->type ? decl->type->source : decl->source);
return false;
}
}
if (storage_ty->is_handle() && decl->declared_storage_class != ast::StorageClass::kNone) {
if (as_var && storage_ty->is_handle() &&
as_var->declared_storage_class != ast::StorageClass::kNone) {
// https://gpuweb.github.io/gpuweb/wgsl/#module-scope-variables
// If the store type is a texture type or a sampler type, then the
// variable declaration must not have a storage class attribute. The
@ -709,9 +726,10 @@ bool Validator::Variable(const sem::Variable* var) const {
}
if (IsValidationEnabled(decl->attributes, ast::DisabledValidation::kIgnoreStorageClass) &&
(decl->declared_storage_class == ast::StorageClass::kInput ||
decl->declared_storage_class == ast::StorageClass::kOutput)) {
AddError("invalid use of input/output storage class", decl->source);
as_var &&
(as_var->declared_storage_class == ast::StorageClass::kInput ||
as_var->declared_storage_class == ast::StorageClass::kOutput)) {
AddError("invalid use of input/output storage class", as_var->source);
return false;
}
return true;
@ -1223,12 +1241,12 @@ bool Validator::EntryPoint(const sem::Function* func, ast::PipelineStage stage)
// Validate there are no resource variable binding collisions
std::unordered_map<sem::BindingPoint, const ast::Variable*> binding_points;
for (auto* var : func->TransitivelyReferencedGlobals()) {
auto* var_decl = var->Declaration();
if (!var_decl->BindingPoint()) {
for (auto* global : func->TransitivelyReferencedGlobals()) {
auto* var_decl = global->Declaration()->As<ast::Var>();
if (!var_decl || !var_decl->BindingPoint()) {
continue;
}
auto bp = var->BindingPoint();
auto bp = global->BindingPoint();
auto res = binding_points.emplace(bp, var_decl);
if (!res.second &&
IsValidationEnabled(decl->attributes,
@ -1663,12 +1681,6 @@ bool Validator::FunctionCall(const sem::Call* call, sem::Statement* current_stat
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
return false;
}
if (var->Declaration()->is_const) {
TINT_ICE(Resolver, diagnostics_)
<< "Resolver::FunctionCall() encountered an address-of "
"expression of a constant identifier expression";
return false;
}
is_valid = true;
}
}
@ -2172,18 +2184,16 @@ bool Validator::Assignment(const ast::Statement* a, const sem::Type* rhs_ty) con
// https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement
auto const* lhs_ty = sem_.TypeOf(lhs);
if (auto* var = sem_.ResolvedSymbol<sem::Variable>(lhs)) {
auto* decl = var->Declaration();
if (var->Is<sem::Parameter>()) {
AddError("cannot assign to function parameter", lhs->source);
AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source);
return false;
}
if (decl->is_const) {
AddError(
decl->is_overridable ? "cannot assign to 'override'" : "cannot assign to 'let'",
lhs->source);
AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source);
if (auto* variable = sem_.ResolvedSymbol<sem::Variable>(lhs)) {
auto* v = variable->Declaration();
const char* err = Switch(
v, //
[&](const ast::Parameter*) { return "cannot assign to function parameter"; },
[&](const ast::Let*) { return "cannot assign to 'let'"; },
[&](const ast::Override*) { return "cannot assign to 'override'"; });
if (err) {
AddError(err, lhs->source);
AddNote("'" + symbols_.NameFor(v->symbol) + "' is declared here:", v->source);
return false;
}
}
@ -2222,17 +2232,16 @@ bool Validator::IncrementDecrementStatement(const ast::IncrementDecrementStateme
// https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement
if (auto* var = sem_.ResolvedSymbol<sem::Variable>(lhs)) {
auto* decl = var->Declaration();
if (var->Is<sem::Parameter>()) {
AddError("cannot modify function parameter", lhs->source);
AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source);
return false;
}
if (decl->is_const) {
AddError(decl->is_overridable ? "cannot modify 'override'" : "cannot modify 'let'",
lhs->source);
AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source);
if (auto* variable = sem_.ResolvedSymbol<sem::Variable>(lhs)) {
auto* v = variable->Declaration();
const char* err = Switch(
v, //
[&](const ast::Parameter*) { return "cannot modify function parameter"; },
[&](const ast::Let*) { return "cannot modify 'let'"; },
[&](const ast::Override*) { return "cannot modify 'override'"; });
if (err) {
AddError(err, lhs->source);
AddNote("'" + symbols_.NameFor(v->symbol) + "' is declared here:", v->source);
return false;
}
}

View File

@ -237,7 +237,7 @@ class Validator {
/// @param atomic_composite_info atomic composite info in the module
/// @returns true on success, false otherwise
bool GlobalVariable(
const sem::Variable* var,
const sem::GlobalVariable* var,
std::unordered_map<uint32_t, const sem::Variable*> constant_ids,
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info) const;
@ -345,12 +345,12 @@ class Validator {
bool Variable(const sem::Variable* var) const;
/// Validates a variable constructor or cast
/// @param var the variable to validate
/// @param v the variable to validate
/// @param storage_class the storage class of the variable
/// @param storage_type the type of the storage
/// @param rhs_type the right hand side of the expression
/// @returns true on succes, false otherwise
bool VariableConstructorOrCast(const ast::Variable* var,
bool VariableConstructorOrCast(const ast::Variable* v,
ast::StorageClass storage_class,
const sem::Type* storage_type,
const sem::Type* rhs_type) const;

View File

@ -24,22 +24,6 @@ namespace {
struct ResolverVarLetValidationTest : public resolver::TestHelper, public testing::Test {};
TEST_F(ResolverVarLetValidationTest, LetNoInitializer) {
// let a : i32;
WrapInFunction(Let(Source{{12, 34}}, "a", ty.i32(), nullptr));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: 'let' declaration must have an initializer");
}
TEST_F(ResolverVarLetValidationTest, GlobalLetNoInitializer) {
// let a : i32;
GlobalConst(Source{{12, 34}}, "a", ty.i32(), nullptr);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: 'let' declaration must have an initializer");
}
TEST_F(ResolverVarLetValidationTest, VarNoInitializerNoType) {
// var a;
WrapInFunction(Var(Source{{12, 34}}, "a", nullptr));

View File

@ -44,10 +44,10 @@ std::vector<std::pair<const Variable*, const ast::LocationAttribute*>>
Function::TransitivelyReferencedLocationVariables() const {
std::vector<std::pair<const Variable*, const ast::LocationAttribute*>> ret;
for (auto* var : TransitivelyReferencedGlobals()) {
for (auto* attr : var->Declaration()->attributes) {
for (auto* global : TransitivelyReferencedGlobals()) {
for (auto* attr : global->Declaration()->attributes) {
if (auto* location = attr->As<ast::LocationAttribute>()) {
ret.push_back({var, location});
ret.push_back({global, location});
break;
}
}
@ -58,13 +58,13 @@ Function::TransitivelyReferencedLocationVariables() const {
Function::VariableBindings Function::TransitivelyReferencedUniformVariables() const {
VariableBindings ret;
for (auto* var : TransitivelyReferencedGlobals()) {
if (var->StorageClass() != ast::StorageClass::kUniform) {
for (auto* global : TransitivelyReferencedGlobals()) {
if (global->StorageClass() != ast::StorageClass::kUniform) {
continue;
}
if (auto binding_point = var->Declaration()->BindingPoint()) {
ret.push_back({var, binding_point});
if (auto binding_point = global->Declaration()->BindingPoint()) {
ret.push_back({global, binding_point});
}
}
return ret;
@ -73,13 +73,13 @@ Function::VariableBindings Function::TransitivelyReferencedUniformVariables() co
Function::VariableBindings Function::TransitivelyReferencedStorageBufferVariables() const {
VariableBindings ret;
for (auto* var : TransitivelyReferencedGlobals()) {
if (var->StorageClass() != ast::StorageClass::kStorage) {
for (auto* global : TransitivelyReferencedGlobals()) {
if (global->StorageClass() != ast::StorageClass::kStorage) {
continue;
}
if (auto binding_point = var->Declaration()->BindingPoint()) {
ret.push_back({var, binding_point});
if (auto binding_point = global->Declaration()->BindingPoint()) {
ret.push_back({global, binding_point});
}
}
return ret;
@ -89,10 +89,10 @@ std::vector<std::pair<const Variable*, const ast::BuiltinAttribute*>>
Function::TransitivelyReferencedBuiltinVariables() const {
std::vector<std::pair<const Variable*, const ast::BuiltinAttribute*>> ret;
for (auto* var : TransitivelyReferencedGlobals()) {
for (auto* attr : var->Declaration()->attributes) {
for (auto* global : TransitivelyReferencedGlobals()) {
for (auto* attr : global->Declaration()->attributes) {
if (auto* builtin = attr->As<ast::BuiltinAttribute>()) {
ret.push_back({var, builtin});
ret.push_back({global, builtin});
break;
}
}
@ -119,11 +119,11 @@ Function::VariableBindings Function::TransitivelyReferencedMultisampledTextureVa
Function::VariableBindings Function::TransitivelyReferencedVariablesOfType(
const tint::TypeInfo* type) const {
VariableBindings ret;
for (auto* var : TransitivelyReferencedGlobals()) {
auto* unwrapped_type = var->Type()->UnwrapRef();
for (auto* global : TransitivelyReferencedGlobals()) {
auto* unwrapped_type = global->Type()->UnwrapRef();
if (unwrapped_type->TypeInfo().Is(type)) {
if (auto binding_point = var->Declaration()->BindingPoint()) {
ret.push_back({var, binding_point});
if (auto binding_point = global->Declaration()->BindingPoint()) {
ret.push_back({global, binding_point});
}
}
}
@ -143,15 +143,15 @@ Function::VariableBindings Function::TransitivelyReferencedSamplerVariablesImpl(
ast::SamplerKind kind) const {
VariableBindings ret;
for (auto* var : TransitivelyReferencedGlobals()) {
auto* unwrapped_type = var->Type()->UnwrapRef();
for (auto* global : TransitivelyReferencedGlobals()) {
auto* unwrapped_type = global->Type()->UnwrapRef();
auto* sampler = unwrapped_type->As<sem::Sampler>();
if (sampler == nullptr || sampler->kind() != kind) {
continue;
}
if (auto binding_point = var->Declaration()->BindingPoint()) {
ret.push_back({var, binding_point});
if (auto binding_point = global->Declaration()->BindingPoint()) {
ret.push_back({global, binding_point});
}
}
return ret;
@ -161,8 +161,8 @@ Function::VariableBindings Function::TransitivelyReferencedSampledTextureVariabl
bool multisampled) const {
VariableBindings ret;
for (auto* var : TransitivelyReferencedGlobals()) {
auto* unwrapped_type = var->Type()->UnwrapRef();
for (auto* global : TransitivelyReferencedGlobals()) {
auto* unwrapped_type = global->Type()->UnwrapRef();
auto* texture = unwrapped_type->As<sem::Texture>();
if (texture == nullptr) {
continue;
@ -175,8 +175,8 @@ Function::VariableBindings Function::TransitivelyReferencedSampledTextureVariabl
continue;
}
if (auto binding_point = var->Declaration()->BindingPoint()) {
ret.push_back({var, binding_point});
if (auto binding_point = global->Declaration()->BindingPoint()) {
ret.push_back({global, binding_point});
}
}

View File

@ -27,6 +27,7 @@ class Function;
class IfStatement;
class MemberAccessorExpression;
class Node;
class Override;
class Statement;
class Struct;
class StructMember;
@ -45,6 +46,7 @@ class Function;
class IfStatement;
class MemberAccessorExpression;
class Node;
class GlobalVariable;
class Statement;
class Struct;
class StructMember;
@ -69,6 +71,7 @@ struct TypeMappings {
IfStatement* operator()(ast::IfStatement*);
MemberAccessorExpression* operator()(ast::MemberAccessorExpression*);
Node* operator()(ast::Node*);
GlobalVariable* operator()(ast::Override*);
Statement* operator()(ast::Statement*);
Struct* operator()(ast::Struct*);
StructMember* operator()(ast::StructMember*);

View File

@ -62,7 +62,7 @@ GlobalVariable::GlobalVariable(const ast::Variable* declaration,
GlobalVariable::~GlobalVariable() = default;
Parameter::Parameter(const ast::Variable* declaration,
Parameter::Parameter(const ast::Parameter* declaration,
uint32_t index,
const sem::Type* type,
ast::StorageClass storage_class,

View File

@ -154,24 +154,14 @@ class GlobalVariable final : public Castable<GlobalVariable, Variable> {
sem::BindingPoint BindingPoint() const { return binding_point_; }
/// @param id the constant identifier to assign to this variable
void SetConstantId(uint16_t id) {
constant_id_ = id;
is_overridable_ = true;
}
void SetConstantId(uint16_t id) { constant_id_ = id; }
/// @returns the pipeline constant ID associated with the variable
uint16_t ConstantId() const { return constant_id_; }
/// @param is_overridable true if this is a pipeline overridable constant
void SetIsOverridable(bool is_overridable = true) { is_overridable_ = is_overridable; }
/// @returns true if this is pipeline overridable constant
bool IsOverridable() const { return is_overridable_; }
private:
const sem::BindingPoint binding_point_;
bool is_overridable_ = false;
uint16_t constant_id_ = 0;
};
@ -185,7 +175,7 @@ class Parameter final : public Castable<Parameter, Variable> {
/// @param storage_class the variable storage class
/// @param access the variable access control type
/// @param usage the semantic usage for the parameter
Parameter(const ast::Variable* declaration,
Parameter(const ast::Parameter* declaration,
uint32_t index,
const sem::Type* type,
ast::StorageClass storage_class,

View File

@ -54,8 +54,8 @@ void AddSpirvBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) co
// contains it in the destination program.
std::unordered_map<const sem::Type*, const ast::Struct*> wrapper_structs;
// Process global variables that are buffers.
for (auto* var : ctx.src->AST().GlobalVariables()) {
// Process global 'var' declarations that are buffers.
for (auto* var : ctx.src->AST().Globals<ast::Var>()) {
auto* sem_var = sem.Get<sem::GlobalVariable>(var);
if (var->declared_storage_class != ast::StorageClass::kStorage &&
var->declared_storage_class != ast::StorageClass::kUniform) {

View File

@ -67,8 +67,8 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
}
auto* func = ctx.src->Sem().Get(func_ast);
std::unordered_map<sem::BindingPoint, int> binding_point_counts;
for (auto* var : func->TransitivelyReferencedGlobals()) {
if (auto binding_point = var->Declaration()->BindingPoint()) {
for (auto* global : func->TransitivelyReferencedGlobals()) {
if (auto binding_point = global->Declaration()->BindingPoint()) {
BindingPoint from{binding_point.group->value, binding_point.binding->value};
auto bp_it = remappings->binding_points.find(from);
if (bp_it != remappings->binding_points.end()) {
@ -88,7 +88,7 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
}
}
for (auto* var : ctx.src->AST().GlobalVariables()) {
for (auto* var : ctx.src->AST().Globals<ast::Var>()) {
if (auto binding_point = var->BindingPoint()) {
// The original binding point
BindingPoint from{binding_point.group->value, binding_point.binding->value};
@ -130,10 +130,10 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
}
auto* ty = sem->Type()->UnwrapRef();
const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
auto* new_var = ctx.dst->create<ast::Variable>(
ctx.Clone(var->source), ctx.Clone(var->symbol), var->declared_storage_class, ac,
inner_ty, false, false, ctx.Clone(var->constructor),
ctx.Clone(var->attributes));
auto* new_var =
ctx.dst->Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty,
var->declared_storage_class, ac, ctx.Clone(var->constructor),
ctx.Clone(var->attributes));
ctx.Replace(var, new_var);
}

View File

@ -147,16 +147,16 @@ struct CombineSamplers::State {
// Remove all texture and sampler global variables. These will be replaced
// by combined samplers.
for (auto* var : ctx.src->AST().GlobalVariables()) {
auto* type = sem.Get(var->type);
if (type && type->IsAnyOf<sem::Texture, sem::Sampler>() &&
for (auto* global : ctx.src->AST().GlobalVariables()) {
auto* type = sem.Get(global->type);
if (tint::IsAnyOf<sem::Texture, sem::Sampler>(type) &&
!type->Is<sem::StorageTexture>()) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), var);
} else if (auto binding_point = var->BindingPoint()) {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), global);
} else if (auto binding_point = global->BindingPoint()) {
if (binding_point.group->value == 0 && binding_point.binding->value == 0) {
auto* attribute =
ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
ctx.InsertFront(var->attributes, attribute);
ctx.InsertFront(global->attributes, attribute);
}
}
}
@ -188,9 +188,8 @@ struct CombineSamplers::State {
} else {
// Either texture or sampler (or both) is a function parameter;
// add a new function parameter to represent the combined sampler.
const ast::Type* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
const ast::Variable* var =
ctx.dst->Param(ctx.dst->Symbols().New(name), type);
auto* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type);
params.push_back(var);
function_combined_texture_samplers_[func][pair] = var;
}

View File

@ -31,11 +31,11 @@ const ast::VariableDeclStatement* AsTrivialLetDecl(const ast::Statement* stmt) {
if (!var_decl) {
return nullptr;
}
auto* var = var_decl->variable;
if (!var->is_const) {
auto* let = var_decl->variable->As<ast::Let>();
if (!let) {
return nullptr;
}
auto* ctor = var->constructor;
auto* ctor = let->constructor;
if (!IsAnyOf<ast::IdentifierExpression, ast::LiteralExpression>(ctor)) {
return nullptr;
}

View File

@ -155,9 +155,13 @@ struct ModuleScopeVarToEntryPointParam::State {
return workgroup_parameter_symbol;
};
for (auto* var : func_sem->TransitivelyReferencedGlobals()) {
auto sc = var->StorageClass();
auto* ty = var->Type()->UnwrapRef();
for (auto* global : func_sem->TransitivelyReferencedGlobals()) {
auto* var = global->Declaration()->As<ast::Var>();
if (!var) {
continue;
}
auto sc = global->StorageClass();
auto* ty = global->Type()->UnwrapRef();
if (sc == ast::StorageClass::kNone) {
continue;
}
@ -182,12 +186,12 @@ struct ModuleScopeVarToEntryPointParam::State {
bool is_wrapped = false;
if (is_entry_point) {
if (var->Type()->UnwrapRef()->is_handle()) {
if (global->Type()->UnwrapRef()->is_handle()) {
// For a texture or sampler variable, redeclare it as an entry point
// parameter. Disable entry point parameter validation.
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
auto attrs = ctx.Clone(var->Declaration()->attributes);
auto attrs = ctx.Clone(var->attributes);
attrs.push_back(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
ctx.InsertFront(func_ast->params, param);
@ -195,7 +199,7 @@ struct ModuleScopeVarToEntryPointParam::State {
sc == ast::StorageClass::kUniform) {
// Variables into the Storage and Uniform storage classes are
// redeclared as entry point parameters with a pointer type.
auto attributes = ctx.Clone(var->Declaration()->attributes);
auto attributes = ctx.Clone(var->attributes);
attributes.push_back(
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter));
attributes.push_back(
@ -214,22 +218,22 @@ struct ModuleScopeVarToEntryPointParam::State {
is_wrapped = true;
}
param_type = ctx.dst->ty.pointer(param_type, sc,
var->Declaration()->declared_access);
param_type = ctx.dst->ty.pointer(param_type, sc, var->declared_access);
auto* param = ctx.dst->Param(new_var_symbol, param_type, attributes);
ctx.InsertFront(func_ast->params, param);
is_pointer = true;
} else if (sc == ast::StorageClass::kWorkgroup && ContainsMatrix(var->Type())) {
} else if (sc == ast::StorageClass::kWorkgroup &&
ContainsMatrix(global->Type())) {
// Due to a bug in the MSL compiler, we use a threadgroup memory
// argument for any workgroup allocation that contains a matrix.
// See crbug.com/tint/938.
// TODO(jrprice): Do this for all other workgroup variables too.
// Create a member in the workgroup parameter struct.
auto member = ctx.Clone(var->Declaration()->symbol);
auto member = ctx.Clone(var->symbol);
workgroup_parameter_members.push_back(
ctx.dst->Member(member, store_type()));
CloneStructTypes(var->Type()->UnwrapRef());
CloneStructTypes(global->Type()->UnwrapRef());
// Create a function-scope variable that is a pointer to the member.
auto* member_ptr = ctx.dst->AddressOf(
@ -246,7 +250,7 @@ struct ModuleScopeVarToEntryPointParam::State {
// this variable.
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kIgnoreStorageClass);
auto* constructor = ctx.Clone(var->Declaration()->constructor);
auto* constructor = ctx.Clone(var->constructor);
auto* local_var =
ctx.dst->Var(new_var_symbol, store_type(), sc, constructor,
ast::AttributeList{disable_validation});
@ -257,9 +261,8 @@ struct ModuleScopeVarToEntryPointParam::State {
// Use a pointer for non-handle types.
auto* param_type = store_type();
ast::AttributeList attributes;
if (!var->Type()->UnwrapRef()->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, sc,
var->Declaration()->declared_access);
if (!global->Type()->UnwrapRef()->is_handle()) {
param_type = ctx.dst->ty.pointer(param_type, sc, var->declared_access);
is_pointer = true;
// Disable validation of the parameter's storage class and of
@ -275,7 +278,7 @@ struct ModuleScopeVarToEntryPointParam::State {
// Replace all uses of the module-scope variable.
// For non-entry points, dereference non-handle pointer parameters.
for (auto* user : var->Users()) {
for (auto* user : global->Users()) {
if (user->Stmt()->Function()->Declaration() == func_ast) {
const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
if (is_pointer) {
@ -298,7 +301,7 @@ struct ModuleScopeVarToEntryPointParam::State {
}
}
var_to_newvar[var] = {new_var_symbol, is_pointer, is_wrapped};
var_to_newvar[global] = {new_var_symbol, is_pointer, is_wrapped};
}
if (!workgroup_parameter_members.empty()) {

View File

@ -86,8 +86,8 @@ struct MultiplanarExternalTexture::State {
// binding and create two additional bindings (one texture_2d<f32> to
// represent the secondary plane and one uniform buffer for the
// ExternalTextureParams struct).
for (auto* var : ctx.src->AST().GlobalVariables()) {
auto* sem_var = sem.Get(var);
for (auto* global : ctx.src->AST().GlobalVariables()) {
auto* sem_var = sem.Get(global);
if (!sem_var->Type()->UnwrapRef()->Is<sem::ExternalTexture>()) {
continue;
}
@ -95,7 +95,7 @@ struct MultiplanarExternalTexture::State {
// If the attributes are empty, then this must be a texture_external
// passed as a function parameter. These variables are transformed
// elsewhere.
if (var->attributes.empty()) {
if (global->attributes.empty()) {
continue;
}
@ -109,8 +109,8 @@ struct MultiplanarExternalTexture::State {
// provided to this transform. We fetch the new binding points by
// providing the original texture_external binding points into the
// passed map.
BindingPoint bp = {var->BindingPoint().group->value,
var->BindingPoint().binding->value};
BindingPoint bp = {global->BindingPoint().group->value,
global->BindingPoint().binding->value};
BindingsMap::const_iterator it = new_binding_points->bindings_map.find(bp);
if (it == new_binding_points->bindings_map.end()) {
@ -129,7 +129,7 @@ struct MultiplanarExternalTexture::State {
// corresponds with the new destination bindings.
// NewBindingSymbols new_binding_syms;
auto& syms = new_binding_symbols[sem_var];
syms.plane_0 = ctx.Clone(var->symbol);
syms.plane_0 = ctx.Clone(global->symbol);
syms.plane_1 = b.Symbols().New("ext_tex_plane_1");
b.Global(syms.plane_1, b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()),
b.GroupAndBinding(bps.plane_1.group, bps.plane_1.binding));
@ -140,13 +140,13 @@ struct MultiplanarExternalTexture::State {
// Replace the original texture_external binding with a texture_2d<f32>
// binding.
ast::AttributeList cloned_attributes = ctx.Clone(var->attributes);
const ast::Expression* cloned_constructor = ctx.Clone(var->constructor);
ast::AttributeList cloned_attributes = ctx.Clone(global->attributes);
const ast::Expression* cloned_constructor = ctx.Clone(global->constructor);
auto* replacement =
b.Var(syms.plane_0, b.ty.sampled_texture(ast::TextureDimension::k2d, b.ty.f32()),
cloned_constructor, cloned_attributes);
ctx.Replace(var, replacement);
ctx.Replace(global, replacement);
}
// We must update all the texture_external parameters for user declared

View File

@ -133,8 +133,8 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat
// plus 1, or group 0 if no resource bound.
group = 0;
for (auto* var : ctx.src->AST().GlobalVariables()) {
if (auto binding_point = var->BindingPoint()) {
for (auto* global : ctx.src->AST().GlobalVariables()) {
if (auto binding_point = global->BindingPoint()) {
if (binding_point.group->value >= group) {
group = binding_point.group->value + 1;
}

View File

@ -109,8 +109,8 @@ struct SimplifyPointers::State {
}
if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) {
auto* var = user->Variable();
if (var->Is<sem::LocalVariable>() && //
var->Declaration()->is_const && //
if (var->Is<sem::LocalVariable>() && //
var->Declaration()->Is<ast::Let>() && //
var->Type()->Is<sem::Pointer>()) {
op.expr = var->Declaration()->constructor;
continue;
@ -161,7 +161,7 @@ struct SimplifyPointers::State {
// permitted.
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* let = node->As<ast::VariableDeclStatement>()) {
if (!let->variable->is_const) {
if (!let->variable->Is<ast::Let>()) {
continue; // Not a `let` declaration. Ignore.
}

View File

@ -64,38 +64,43 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c
referenced_vars.emplace(var->Declaration());
}
// Clone any module-scope variables, types, and functions that are statically
// referenced by the target entry point.
// Clone any module-scope variables, types, and functions that are statically referenced by the
// target entry point.
for (auto* decl : ctx.src->AST().GlobalDeclarations()) {
if (auto* ty = decl->As<ast::TypeDecl>()) {
// TODO(jrprice): Strip unused types.
ctx.dst->AST().AddTypeDecl(ctx.Clone(ty));
} else if (auto* var = decl->As<ast::Variable>()) {
if (referenced_vars.count(var)) {
if (var->is_overridable) {
// It is an overridable constant
if (!ast::HasAttribute<ast::IdAttribute>(var->attributes)) {
Switch(
decl, //
[&](const ast::TypeDecl* ty) {
// TODO(jrprice): Strip unused types.
ctx.dst->AST().AddTypeDecl(ctx.Clone(ty));
},
[&](const ast::Override* override) {
if (referenced_vars.count(override)) {
if (!ast::HasAttribute<ast::IdAttribute>(override->attributes)) {
// If the constant doesn't already have an @id() attribute, add one
// so that its allocated ID so that it won't be affected by other
// stripped away constants
auto* global = sem.Get(var)->As<sem::GlobalVariable>();
auto* global = sem.Get(override);
const auto* id = ctx.dst->Id(global->ConstantId());
ctx.InsertFront(var->attributes, id);
ctx.InsertFront(override->attributes, id);
}
ctx.dst->AST().AddGlobalVariable(ctx.Clone(override));
}
ctx.dst->AST().AddGlobalVariable(ctx.Clone(var));
}
} else if (auto* func = decl->As<ast::Function>()) {
if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) {
ctx.dst->AST().AddFunction(ctx.Clone(func));
}
} else if (auto* ext = decl->As<ast::Enable>()) {
ctx.dst->AST().AddEnable(ctx.Clone(ext));
} else {
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
<< "unhandled global declaration: " << decl->TypeInfo().name;
return;
}
},
[&](const ast::Variable* v) { // var, let
if (referenced_vars.count(v)) {
ctx.dst->AST().AddGlobalVariable(ctx.Clone(v));
}
},
[&](const ast::Function* func) {
if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) {
ctx.dst->AST().AddFunction(ctx.Clone(func));
}
},
[&](const ast::Enable* ext) { ctx.dst->AST().AddEnable(ctx.Clone(ext)); },
[&](Default) {
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
<< "unhandled global declaration: " << decl->TypeInfo().name;
});
}
// Clone the entry point.

View File

@ -44,28 +44,42 @@ struct Unshadow::State {
// Maps a variable to its new name.
std::unordered_map<const sem::Variable*, Symbol> renamed_to;
auto rename = [&](const sem::Variable* var) -> const ast::Variable* {
auto* decl = var->Declaration();
auto rename = [&](const sem::Variable* v) -> const ast::Variable* {
auto* decl = v->Declaration();
auto name = ctx.src->Symbols().NameFor(decl->symbol);
auto symbol = ctx.dst->Symbols().New(name);
renamed_to.emplace(var, symbol);
renamed_to.emplace(v, symbol);
auto source = ctx.Clone(decl->source);
auto* type = ctx.Clone(decl->type);
auto* constructor = ctx.Clone(decl->constructor);
auto attributes = ctx.Clone(decl->attributes);
return ctx.dst->create<ast::Variable>(source, symbol, decl->declared_storage_class,
decl->declared_access, type, decl->is_const,
decl->is_overridable, constructor, attributes);
return Switch(
decl, //
[&](const ast::Var* var) {
return ctx.dst->Var(source, symbol, type, var->declared_storage_class,
var->declared_access, constructor, attributes);
},
[&](const ast::Let*) {
return ctx.dst->Let(source, symbol, type, constructor, attributes);
},
[&](const ast::Parameter*) {
return ctx.dst->Param(source, symbol, type, attributes);
},
[&](Default) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected variable type: " << decl->TypeInfo().name;
return nullptr;
});
};
ctx.ReplaceAll([&](const ast::Variable* var) -> const ast::Variable* {
if (auto* local = sem.Get<sem::LocalVariable>(var)) {
ctx.ReplaceAll([&](const ast::Variable* v) -> const ast::Variable* {
if (auto* local = sem.Get<sem::LocalVariable>(v)) {
if (local->Shadows()) {
return rename(local);
}
}
if (auto* param = sem.Get<sem::Parameter>(var)) {
if (auto* param = sem.Get<sem::Parameter>(v)) {
if (param->Shadows()) {
return rename(param);
}

View File

@ -189,18 +189,18 @@ class HoistToDeclBefore::State {
/// before `before_expr`.
/// @param before_expr expression to insert `expr` before
/// @param expr expression to hoist
/// @param as_const hoist to `let` if true, otherwise to `var`
/// @param as_let hoist to `let` if true, otherwise to `var`
/// @param decl_name optional name to use for the variable/constant name
/// @return true on success
bool Add(const sem::Expression* before_expr,
const ast::Expression* expr,
bool as_const,
bool as_let,
const char* decl_name) {
auto name = b.Symbols().New(decl_name);
// Construct the let/var that holds the hoisted expr
auto* v = as_const ? b.Let(name, nullptr, ctx.Clone(expr))
: b.Var(name, nullptr, ctx.Clone(expr));
auto* v = as_let ? static_cast<const ast::Variable*>(b.Let(name, nullptr, ctx.Clone(expr)))
: static_cast<const ast::Variable*>(b.Var(name, nullptr, ctx.Clone(expr)));
auto* decl = b.Decl(v);
if (!InsertBefore(before_expr->Stmt(), decl)) {
@ -330,9 +330,9 @@ HoistToDeclBefore::~HoistToDeclBefore() {}
bool HoistToDeclBefore::Add(const sem::Expression* before_expr,
const ast::Expression* expr,
bool as_const,
bool as_let,
const char* decl_name) {
return state_->Add(before_expr, expr, as_const, decl_name);
return state_->Add(before_expr, expr, as_let, decl_name);
}
bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt,

View File

@ -695,7 +695,7 @@ struct State {
/// vertex_index and instance_index builtins if present.
/// @param func the entry point function
/// @param param the parameter to process
void ProcessNonStructParameter(const ast::Function* func, const ast::Variable* param) {
void ProcessNonStructParameter(const ast::Function* func, const ast::Parameter* param) {
if (auto* location = ast::GetAttribute<ast::LocationAttribute>(param->attributes)) {
// Create a function-scope variable to replace the parameter.
auto func_var_sym = ctx.Clone(param->symbol);
@ -733,7 +733,7 @@ struct State {
/// @param param the parameter to process
/// @param struct_ty the structure type
void ProcessStructParameter(const ast::Function* func,
const ast::Variable* param,
const ast::Parameter* param,
const ast::Struct* struct_ty) {
auto param_sym = ctx.Clone(param->symbol);

View File

@ -416,8 +416,8 @@ ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
bool ZeroInitWorkgroupMemory::ShouldRun(const Program* program, const DataMap&) const {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (auto* var = decl->As<ast::Variable>()) {
for (auto* global : program->AST().GlobalVariables()) {
if (auto* var = global->As<ast::Var>()) {
if (var->declared_storage_class == ast::StorageClass::kWorkgroup) {
return true;
}

View File

@ -1904,42 +1904,47 @@ bool GeneratorImpl::EmitFunction(const ast::Function* func) {
}
bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) {
if (global->is_const) {
return EmitProgramConstVariable(global);
}
auto* sem = builder_.Sem().Get(global);
switch (sem->StorageClass()) {
case ast::StorageClass::kUniform:
return EmitUniformVariable(sem);
case ast::StorageClass::kStorage:
return EmitStorageVariable(sem);
case ast::StorageClass::kHandle:
return EmitHandleVariable(sem);
case ast::StorageClass::kPrivate:
return EmitPrivateVariable(sem);
case ast::StorageClass::kWorkgroup:
return EmitWorkgroupVariable(sem);
case ast::StorageClass::kInput:
case ast::StorageClass::kOutput:
return EmitIOVariable(sem);
default:
break;
}
TINT_ICE(Writer, diagnostics_) << "unhandled storage class " << sem->StorageClass();
return false;
return Switch(
global, //
[&](const ast::Var* var) {
auto* sem = builder_.Sem().Get(global);
switch (sem->StorageClass()) {
case ast::StorageClass::kUniform:
return EmitUniformVariable(var, sem);
case ast::StorageClass::kStorage:
return EmitStorageVariable(var, sem);
case ast::StorageClass::kHandle:
return EmitHandleVariable(var, sem);
case ast::StorageClass::kPrivate:
return EmitPrivateVariable(sem);
case ast::StorageClass::kWorkgroup:
return EmitWorkgroupVariable(sem);
case ast::StorageClass::kInput:
case ast::StorageClass::kOutput:
return EmitIOVariable(sem);
default:
TINT_ICE(Writer, diagnostics_)
<< "unhandled storage class " << sem->StorageClass();
return false;
}
},
[&](const ast::Let* let) { return EmitProgramConstVariable(let); },
[&](const ast::Override* override) { return EmitOverride(override); },
[&](Default) {
TINT_ICE(Writer, diagnostics_)
<< "unhandled global variable type " << global->TypeInfo().name;
return false;
});
}
bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) {
auto* decl = var->Declaration();
auto* type = var->Type()->UnwrapRef();
bool GeneratorImpl::EmitUniformVariable(const ast::Var* var, const sem::Variable* sem) {
auto* type = sem->Type()->UnwrapRef();
auto* str = type->As<sem::Struct>();
if (!str) {
TINT_ICE(Writer, builder_.Diagnostics()) << "storage variable must be of struct type";
return false;
}
ast::VariableBindingPoint bp = decl->BindingPoint();
ast::VariableBindingPoint bp = var->BindingPoint();
{
auto out = line();
out << "layout(binding = " << bp.binding->value;
@ -1949,36 +1954,34 @@ bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) {
out << ") uniform " << UniqueIdentifier(StructName(str)) << " {";
}
EmitStructMembers(current_buffer_, str, /* emit_offsets */ true);
auto name = builder_.Symbols().NameFor(decl->symbol);
auto name = builder_.Symbols().NameFor(var->symbol);
line() << "} " << name << ";";
line();
return true;
}
bool GeneratorImpl::EmitStorageVariable(const sem::Variable* var) {
auto* decl = var->Declaration();
auto* type = var->Type()->UnwrapRef();
bool GeneratorImpl::EmitStorageVariable(const ast::Var* var, const sem::Variable* sem) {
auto* type = sem->Type()->UnwrapRef();
auto* str = type->As<sem::Struct>();
if (!str) {
TINT_ICE(Writer, builder_.Diagnostics()) << "storage variable must be of struct type";
return false;
}
ast::VariableBindingPoint bp = decl->BindingPoint();
ast::VariableBindingPoint bp = var->BindingPoint();
line() << "layout(binding = " << bp.binding->value << ", std430) buffer "
<< UniqueIdentifier(StructName(str)) << " {";
EmitStructMembers(current_buffer_, str, /* emit_offsets */ true);
auto name = builder_.Symbols().NameFor(decl->symbol);
auto name = builder_.Symbols().NameFor(var->symbol);
line() << "} " << name << ";";
return true;
}
bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) {
auto* decl = var->Declaration();
bool GeneratorImpl::EmitHandleVariable(const ast::Var* var, const sem::Variable* sem) {
auto out = line();
auto name = builder_.Symbols().NameFor(decl->symbol);
auto* type = var->Type()->UnwrapRef();
auto name = builder_.Symbols().NameFor(var->symbol);
auto* type = sem->Type()->UnwrapRef();
if (type->Is<sem::Sampler>()) {
// GLSL ignores Sampler variables.
return true;
@ -1986,7 +1989,7 @@ bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) {
if (auto* storage = type->As<sem::StorageTexture>()) {
out << "layout(" << convert_texel_format_to_glsl(storage->texel_format()) << ") ";
}
if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), name)) {
return false;
}
@ -2138,7 +2141,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
if (wgsize[i].overridable_const) {
auto* global = builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const);
if (!global->IsOverridable()) {
if (!global->Declaration()->Is<ast::Override>()) {
TINT_ICE(Writer, builder_.Diagnostics())
<< "expected a pipeline-overridable constant";
}
@ -2652,7 +2655,15 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
return EmitSwitch(s);
}
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
return EmitVariable(v->variable);
return Switch(
v->variable, //
[&](const ast::Var* var) { return EmitVar(var); },
[&](const ast::Let* let) { return EmitLet(let); },
[&](Default) { //
TINT_ICE(Writer, diagnostics_)
<< "unknown variable type: " << v->variable->TypeInfo().name;
return false;
});
}
diagnostics_.add_error(diag::System::Writer,
@ -2934,18 +2945,11 @@ bool GeneratorImpl::EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression*
return true;
}
bool GeneratorImpl::EmitVariable(const ast::Variable* var) {
bool GeneratorImpl::EmitVar(const ast::Var* var) {
auto* sem = builder_.Sem().Get(var);
auto* type = sem->Type()->UnwrapRef();
// TODO(dsinclair): Handle variable attributes
if (!var->attributes.empty()) {
diagnostics_.add_error(diag::System::Writer, "Variable attributes are not handled yet");
return false;
}
auto out = line();
// TODO(senorblanco): handle const
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol))) {
return false;
@ -2967,58 +2971,74 @@ bool GeneratorImpl::EmitVariable(const ast::Variable* var) {
return true;
}
bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
for (auto* d : var->attributes) {
if (!d->Is<ast::IdAttribute>()) {
diagnostics_.add_error(diag::System::Writer, "Decorated const values not valid");
return false;
}
}
if (!var->is_const) {
diagnostics_.add_error(diag::System::Writer, "Expected a const value");
bool GeneratorImpl::EmitLet(const ast::Let* let) {
auto* sem = builder_.Sem().Get(let);
auto* type = sem->Type()->UnwrapRef();
auto out = line();
// TODO(senorblanco): handle const
if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
builder_.Symbols().NameFor(let->symbol))) {
return false;
}
out << " = ";
if (!EmitExpression(out, let->constructor)) {
return false;
}
out << ";";
return true;
}
bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
auto* sem = builder_.Sem().Get(var);
auto* type = sem->Type();
auto out = line();
out << "const ";
if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
builder_.Symbols().NameFor(var->symbol))) {
return false;
}
out << " = ";
if (!EmitExpression(out, var->constructor)) {
return false;
}
out << ";";
return true;
}
bool GeneratorImpl::EmitOverride(const ast::Override* override) {
auto* sem = builder_.Sem().Get(override);
auto* type = sem->Type();
auto* global = sem->As<sem::GlobalVariable>();
if (global && global->IsOverridable()) {
auto const_id = global->ConstantId();
auto const_id = global->ConstantId();
line() << "#ifndef " << kSpecConstantPrefix << const_id;
line() << "#ifndef " << kSpecConstantPrefix << const_id;
if (var->constructor != nullptr) {
auto out = line();
out << "#define " << kSpecConstantPrefix << const_id << " ";
if (!EmitExpression(out, var->constructor)) {
return false;
}
} else {
line() << "#error spec constant required for constant id " << const_id;
}
line() << "#endif";
{
auto out = line();
out << "const ";
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol))) {
return false;
}
out << " = " << kSpecConstantPrefix << const_id << ";";
if (override->constructor != nullptr) {
auto out = line();
out << "#define " << kSpecConstantPrefix << const_id << " ";
if (!EmitExpression(out, override->constructor)) {
return false;
}
} else {
line() << "#error spec constant required for constant id " << const_id;
}
line() << "#endif";
{
auto out = line();
out << "const ";
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol))) {
if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
builder_.Symbols().NameFor(override->symbol))) {
return false;
}
out << " = ";
if (!EmitExpression(out, var->constructor)) {
return false;
}
out << ";";
out << " = " << kSpecConstantPrefix << const_id << ";";
}
return true;

View File

@ -293,19 +293,22 @@ class GeneratorImpl : public TextGenerator {
bool EmitGlobalVariable(const ast::Variable* global);
/// Handles emitting a global variable with the uniform storage class
/// @param var the global variable
/// @param var the AST node for the 'var'
/// @param sem the semantic node for the 'var'
/// @returns true on success
bool EmitUniformVariable(const sem::Variable* var);
bool EmitUniformVariable(const ast::Var* var, const sem::Variable* sem);
/// Handles emitting a global variable with the storage storage class
/// @param var the global variable
/// @param var the AST node for the 'var'
/// @param sem the semantic node for the 'var'
/// @returns true on success
bool EmitStorageVariable(const sem::Variable* var);
bool EmitStorageVariable(const ast::Var* var, const sem::Variable* sem);
/// Handles emitting a global variable with the handle storage class
/// @param var the global variable
/// @param var the AST node for the 'var'
/// @param sem the semantic node for the 'var'
/// @returns true on success
bool EmitHandleVariable(const sem::Variable* var);
bool EmitHandleVariable(const ast::Var* var, const sem::Variable* sem);
/// Handles emitting a global variable with the private storage class
/// @param var the global variable
@ -437,14 +440,22 @@ class GeneratorImpl : public TextGenerator {
/// @param type the type to emit the value for
/// @returns true if the zero value was successfully emitted.
bool EmitZeroValue(std::ostream& out, const sem::Type* type);
/// Handles generating a variable
/// Handles generating a 'var' declaration
/// @param var the variable to generate
/// @returns true if the variable was emitted
bool EmitVariable(const ast::Variable* var);
/// Handles generating a program scope constant variable
/// @param var the variable to emit
bool EmitVar(const ast::Var* var);
/// Handles generating a function-scope 'let' declaration
/// @param let the variable to generate
/// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Variable* var);
bool EmitLet(const ast::Let* let);
/// Handles generating a module-scope 'let' declaration
/// @param let the 'let' to emit
/// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Variable* let);
/// Handles generating a module-scope 'override' declaration
/// @param override the 'override' to emit
/// @returns true if the variable was emitted
bool EmitOverride(const ast::Override* override);
/// Handles generating a builtin method name
/// @param builtin the semantic info for the builtin
/// @returns the name or "" if not valid

View File

@ -40,7 +40,7 @@ TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error();
ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23
#define WGSL_SPEC_CONSTANT_23 3.0f
#endif
@ -56,7 +56,7 @@ TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoConstructor) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error();
ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23
#error spec constant required for constant id 23
#endif
@ -73,8 +73,8 @@ TEST_F(GlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoId) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(a)) << gen.error();
ASSERT_TRUE(gen.EmitProgramConstVariable(b)) << gen.error();
ASSERT_TRUE(gen.EmitOverride(a)) << gen.error();
ASSERT_TRUE(gen.EmitOverride(b)) << gen.error();
EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_0
#define WGSL_SPEC_CONSTANT_0 3.0f
#endif

View File

@ -365,7 +365,7 @@ bool GeneratorImpl::EmitDynamicVectorAssignment(const ast::AssignmentStatement*
out << "vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;";
break;
default:
TINT_UNREACHABLE(Writer, builder_.Diagnostics())
TINT_UNREACHABLE(Writer, diagnostics_)
<< "invalid vector size " << vec->Width();
break;
}
@ -524,7 +524,7 @@ bool GeneratorImpl::EmitDynamicMatrixScalarAssignment(const ast::AssignmentState
<< vec_name << ";";
break;
default:
TINT_UNREACHABLE(Writer, builder_.Diagnostics())
TINT_UNREACHABLE(Writer, diagnostics_)
<< "invalid vector size " << vec->Width();
break;
}
@ -2861,41 +2861,46 @@ bool GeneratorImpl::EmitFunctionBodyWithDiscard(const ast::Function* func) {
}
bool GeneratorImpl::EmitGlobalVariable(const ast::Variable* global) {
if (global->is_const) {
return EmitProgramConstVariable(global);
}
auto* sem = builder_.Sem().Get(global);
switch (sem->StorageClass()) {
case ast::StorageClass::kUniform:
return EmitUniformVariable(sem);
case ast::StorageClass::kStorage:
return EmitStorageVariable(sem);
case ast::StorageClass::kHandle:
return EmitHandleVariable(sem);
case ast::StorageClass::kPrivate:
return EmitPrivateVariable(sem);
case ast::StorageClass::kWorkgroup:
return EmitWorkgroupVariable(sem);
default:
break;
}
TINT_ICE(Writer, diagnostics_) << "unhandled storage class " << sem->StorageClass();
return false;
return Switch(
global, //
[&](const ast::Var* var) {
auto* sem = builder_.Sem().Get(global);
switch (sem->StorageClass()) {
case ast::StorageClass::kUniform:
return EmitUniformVariable(var, sem);
case ast::StorageClass::kStorage:
return EmitStorageVariable(var, sem);
case ast::StorageClass::kHandle:
return EmitHandleVariable(var, sem);
case ast::StorageClass::kPrivate:
return EmitPrivateVariable(sem);
case ast::StorageClass::kWorkgroup:
return EmitWorkgroupVariable(sem);
default:
TINT_ICE(Writer, diagnostics_)
<< "unhandled storage class " << sem->StorageClass();
return false;
}
},
[&](const ast::Let* let) { return EmitProgramConstVariable(let); },
[&](const ast::Override* override) { return EmitOverride(override); },
[&](Default) {
TINT_ICE(Writer, diagnostics_)
<< "unhandled global variable type " << global->TypeInfo().name;
return false;
});
}
bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) {
auto* decl = var->Declaration();
auto binding_point = decl->BindingPoint();
auto* type = var->Type()->UnwrapRef();
auto name = builder_.Symbols().NameFor(decl->symbol);
bool GeneratorImpl::EmitUniformVariable(const ast::Var* var, const sem::Variable* sem) {
auto binding_point = var->BindingPoint();
auto* type = sem->Type()->UnwrapRef();
auto name = builder_.Symbols().NameFor(var->symbol);
line() << "cbuffer cbuffer_" << name << RegisterAndSpace('b', binding_point) << " {";
{
ScopedIndent si(this);
auto out = line();
if (!EmitTypeAndName(out, type, ast::StorageClass::kUniform, var->Access(), name)) {
if (!EmitTypeAndName(out, type, ast::StorageClass::kUniform, sem->Access(), name)) {
return false;
}
out << ";";
@ -2906,29 +2911,27 @@ bool GeneratorImpl::EmitUniformVariable(const sem::Variable* var) {
return true;
}
bool GeneratorImpl::EmitStorageVariable(const sem::Variable* var) {
auto* decl = var->Declaration();
auto* type = var->Type()->UnwrapRef();
bool GeneratorImpl::EmitStorageVariable(const ast::Var* var, const sem::Variable* sem) {
auto* type = sem->Type()->UnwrapRef();
auto out = line();
if (!EmitTypeAndName(out, type, ast::StorageClass::kStorage, var->Access(),
builder_.Symbols().NameFor(decl->symbol))) {
if (!EmitTypeAndName(out, type, ast::StorageClass::kStorage, sem->Access(),
builder_.Symbols().NameFor(var->symbol))) {
return false;
}
out << RegisterAndSpace(var->Access() == ast::Access::kRead ? 't' : 'u', decl->BindingPoint())
out << RegisterAndSpace(sem->Access() == ast::Access::kRead ? 't' : 'u', var->BindingPoint())
<< ";";
return true;
}
bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) {
auto* decl = var->Declaration();
auto* unwrapped_type = var->Type()->UnwrapRef();
bool GeneratorImpl::EmitHandleVariable(const ast::Var* var, const sem::Variable* sem) {
auto* unwrapped_type = sem->Type()->UnwrapRef();
auto out = line();
auto name = builder_.Symbols().NameFor(decl->symbol);
auto* type = var->Type()->UnwrapRef();
if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
auto name = builder_.Symbols().NameFor(var->symbol);
auto* type = sem->Type()->UnwrapRef();
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(), name)) {
return false;
}
@ -2944,7 +2947,7 @@ bool GeneratorImpl::EmitHandleVariable(const sem::Variable* var) {
}
if (register_space) {
auto bp = decl->BindingPoint();
auto bp = var->BindingPoint();
out << " : register(" << register_space << bp.binding->value << ", space" << bp.group->value
<< ")";
}
@ -3078,8 +3081,8 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
if (wgsize[i].overridable_const) {
auto* global =
builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const);
if (!global->IsOverridable()) {
TINT_ICE(Writer, builder_.Diagnostics())
if (!global->Declaration()->Is<ast::Override>()) {
TINT_ICE(Writer, diagnostics_)
<< "expected a pipeline-overridable constant";
}
out << kSpecConstantPrefix << global->ConstantId();
@ -3611,7 +3614,15 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
return EmitSwitch(s);
},
[&](const ast::VariableDeclStatement* v) { //
return EmitVariable(v->variable);
return Switch(
v->variable, //
[&](const ast::Var* var) { return EmitVar(var); },
[&](const ast::Let* let) { return EmitLet(let); },
[&](Default) { //
TINT_ICE(Writer, diagnostics_)
<< "unknown variable type: " << v->variable->TypeInfo().name;
return false;
});
},
[&](Default) { //
diagnostics_.add_error(diag::System::Writer,
@ -4018,20 +4029,11 @@ bool GeneratorImpl::EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression*
return true;
}
bool GeneratorImpl::EmitVariable(const ast::Variable* var) {
bool GeneratorImpl::EmitVar(const ast::Var* var) {
auto* sem = builder_.Sem().Get(var);
auto* type = sem->Type()->UnwrapRef();
// TODO(dsinclair): Handle variable attributes
if (!var->attributes.empty()) {
diagnostics_.add_error(diag::System::Writer, "Variable attributes are not handled yet");
return false;
}
auto out = line();
if (var->is_const) {
out << "const ";
}
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol))) {
return false;
@ -4053,60 +4055,71 @@ bool GeneratorImpl::EmitVariable(const ast::Variable* var) {
return true;
}
bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
for (auto* d : var->attributes) {
if (!d->Is<ast::IdAttribute>()) {
diagnostics_.add_error(diag::System::Writer, "Decorated const values not valid");
return false;
}
}
if (!var->is_const) {
diagnostics_.add_error(diag::System::Writer, "Expected a const value");
bool GeneratorImpl::EmitLet(const ast::Let* let) {
auto* sem = builder_.Sem().Get(let);
auto* type = sem->Type()->UnwrapRef();
auto out = line();
out << "const ";
if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
builder_.Symbols().NameFor(let->symbol))) {
return false;
}
out << " = ";
if (!EmitExpression(out, let->constructor)) {
return false;
}
out << ";";
auto* sem = builder_.Sem().Get(var);
return true;
}
bool GeneratorImpl::EmitProgramConstVariable(const ast::Let* let) {
auto* sem = builder_.Sem().Get(let);
auto* type = sem->Type();
auto* global = sem->As<sem::GlobalVariable>();
if (global && global->IsOverridable()) {
auto const_id = global->ConstantId();
auto out = line();
out << "static const ";
if (!EmitTypeAndName(out, type, ast::StorageClass::kNone, ast::Access::kUndefined,
builder_.Symbols().NameFor(let->symbol))) {
return false;
}
out << " = ";
if (!EmitExpression(out, let->constructor)) {
return false;
}
out << ";";
line() << "#ifndef " << kSpecConstantPrefix << const_id;
return true;
}
if (var->constructor != nullptr) {
auto out = line();
out << "#define " << kSpecConstantPrefix << const_id << " ";
if (!EmitExpression(out, var->constructor)) {
return false;
}
} else {
line() << "#error spec constant required for constant id " << const_id;
}
line() << "#endif";
{
auto out = line();
out << "static const ";
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol))) {
return false;
}
out << " = " << kSpecConstantPrefix << const_id << ";";
bool GeneratorImpl::EmitOverride(const ast::Override* override) {
auto* sem = builder_.Sem().Get(override);
auto* type = sem->Type();
auto const_id = sem->ConstantId();
line() << "#ifndef " << kSpecConstantPrefix << const_id;
if (override->constructor != nullptr) {
auto out = line();
out << "#define " << kSpecConstantPrefix << const_id << " ";
if (!EmitExpression(out, override->constructor)) {
return false;
}
} else {
line() << "#error spec constant required for constant id " << const_id;
}
line() << "#endif";
{
auto out = line();
out << "static const ";
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol))) {
builder_.Symbols().NameFor(override->symbol))) {
return false;
}
out << " = ";
if (!EmitExpression(out, var->constructor)) {
return false;
}
out << ";";
out << " = " << kSpecConstantPrefix << const_id << ";";
}
return true;
}

View File

@ -303,19 +303,22 @@ class GeneratorImpl : public TextGenerator {
bool EmitGlobalVariable(const ast::Variable* global);
/// Handles emitting a global variable with the uniform storage class
/// @param var the global variable
/// @param var the AST node for the 'var'
/// @param sem the semantic node for the 'var'
/// @returns true on success
bool EmitUniformVariable(const sem::Variable* var);
bool EmitUniformVariable(const ast::Var* var, const sem::Variable* sem);
/// Handles emitting a global variable with the storage storage class
/// @param var the global variable
/// @param var the AST node for the 'var'
/// @param sem the semantic node for the 'var'
/// @returns true on success
bool EmitStorageVariable(const sem::Variable* var);
bool EmitStorageVariable(const ast::Var* var, const sem::Variable* sem);
/// Handles emitting a global variable with the handle storage class
/// @param var the global variable
/// @param var the AST node for the 'var'
/// @param sem the semantic node for the 'var'
/// @returns true on success
bool EmitHandleVariable(const sem::Variable* var);
bool EmitHandleVariable(const ast::Var* var, const sem::Variable* sem);
/// Handles emitting a global variable with the private storage class
/// @param var the global variable
@ -437,14 +440,22 @@ class GeneratorImpl : public TextGenerator {
/// @param type the type to emit the value for
/// @returns true if the zero value was successfully emitted.
bool EmitZeroValue(std::ostream& out, const sem::Type* type);
/// Handles generating a variable
/// Handles generating a 'var' declaration
/// @param var the variable to generate
/// @returns true if the variable was emitted
bool EmitVariable(const ast::Variable* var);
/// Handles generating a program scope constant variable
/// @param var the variable to emit
bool EmitVar(const ast::Var* var);
/// Handles generating a function-scope 'let' declaration
/// @param let the variable to generate
/// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Variable* var);
bool EmitLet(const ast::Let* let);
/// Handles generating a module-scope 'let' declaration
/// @param let the 'let' to emit
/// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Let* let);
/// Handles generating a module-scope 'override' declaration
/// @param override the 'override' to emit
/// @returns true if the variable was emitted
bool EmitOverride(const ast::Override* override);
/// Emits call to a helper vector assignment function for the input assignment
/// statement and vector type. This is used to work around FXC issues where
/// assignments to vectors with dynamic indices cause compilation failures.

View File

@ -40,7 +40,7 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error();
ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23
#define WGSL_SPEC_CONSTANT_23 3.0f
#endif
@ -56,7 +56,7 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoConstructor) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error();
ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_23
#error spec constant required for constant id 23
#endif
@ -73,8 +73,8 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_SpecConstant_NoId) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(a)) << gen.error();
ASSERT_TRUE(gen.EmitProgramConstVariable(b)) << gen.error();
ASSERT_TRUE(gen.EmitOverride(a)) << gen.error();
ASSERT_TRUE(gen.EmitOverride(b)) << gen.error();
EXPECT_EQ(gen.result(), R"(#ifndef WGSL_SPEC_CONSTANT_0
#define WGSL_SPEC_CONSTANT_0 3.0f
#endif

View File

@ -253,16 +253,13 @@ bool GeneratorImpl::Generate() {
[&](const ast::Alias*) {
return true; // folded away by the writer
},
[&](const ast::Variable* var) {
if (var->is_const) {
TINT_DEFER(line());
return EmitProgramConstVariable(var);
}
// These are pushed into the entry point by sanitizer transforms.
TINT_ICE(Writer, diagnostics_)
<< "module-scope variables should have been handled by the MSL "
"sanitizer";
return false;
[&](const ast::Let* let) {
TINT_DEFER(line());
return EmitProgramConstVariable(let);
},
[&](const ast::Override* override) {
TINT_DEFER(line());
return EmitOverride(override);
},
[&](const ast::Function* func) {
TINT_DEFER(line());
@ -1866,8 +1863,8 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
// Returns the binding index of a variable, requiring that the group
// attribute have a value of zero.
const uint32_t kInvalidBindingIndex = std::numeric_limits<uint32_t>::max();
auto get_binding_index = [&](const ast::Variable* var) -> uint32_t {
auto bp = var->BindingPoint();
auto get_binding_index = [&](const ast::Parameter* param) -> uint32_t {
auto bp = param->BindingPoint();
if (bp.group == nullptr || bp.binding == nullptr) {
TINT_ICE(Writer, diagnostics_)
<< "missing binding attributes for entry point parameter";
@ -1890,15 +1887,15 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
// Emit entry point parameters.
bool first = true;
for (auto* var : func->params) {
for (auto* param : func->params) {
if (!first) {
out << ", ";
}
first = false;
auto* type = program_->Sem().Get(var)->Type()->UnwrapRef();
auto* type = program_->Sem().Get(param)->Type()->UnwrapRef();
auto param_name = program_->Symbols().NameFor(var->symbol);
auto param_name = program_->Symbols().NameFor(param->symbol);
if (!EmitType(out, type, param_name)) {
return false;
}
@ -1910,26 +1907,26 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
if (type->Is<sem::Struct>()) {
out << " [[stage_in]]";
} else if (type->is_handle()) {
uint32_t binding = get_binding_index(var);
uint32_t binding = get_binding_index(param);
if (binding == kInvalidBindingIndex) {
return false;
}
if (var->type->Is<ast::Sampler>()) {
if (param->type->Is<ast::Sampler>()) {
out << " [[sampler(" << binding << ")]]";
} else if (var->type->Is<ast::Texture>()) {
} else if (param->type->Is<ast::Texture>()) {
out << " [[texture(" << binding << ")]]";
} else {
TINT_ICE(Writer, diagnostics_) << "invalid handle type entry point parameter";
return false;
}
} else if (auto* ptr = var->type->As<ast::Pointer>()) {
} else if (auto* ptr = param->type->As<ast::Pointer>()) {
auto sc = ptr->storage_class;
if (sc == ast::StorageClass::kWorkgroup) {
auto& allocations = workgroup_allocations_[func_name];
out << " [[threadgroup(" << allocations.size() << ")]]";
allocations.push_back(program_->Sem().Get(ptr->type)->Size());
} else if (sc == ast::StorageClass::kStorage || sc == ast::StorageClass::kUniform) {
uint32_t binding = get_binding_index(var);
uint32_t binding = get_binding_index(param);
if (binding == kInvalidBindingIndex) {
return false;
}
@ -1940,7 +1937,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
return false;
}
} else {
auto& attrs = var->attributes;
auto& attrs = param->attributes;
bool builtin_found = false;
for (auto* attr : attrs) {
auto* builtin = attr->As<ast::BuiltinAttribute>();
@ -2340,8 +2337,15 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
return EmitSwitch(s);
},
[&](const ast::VariableDeclStatement* v) { //
auto* var = program_->Sem().Get(v->variable);
return EmitVariable(var);
return Switch(
v->variable, //
[&](const ast::Var* var) { return EmitVar(var); },
[&](const ast::Let* let) { return EmitLet(let); },
[&](Default) { //
TINT_ICE(Writer, diagnostics_)
<< "unknown statement type: " << stmt->TypeInfo().name;
return false;
});
},
[&](Default) {
diagnostics_.add_error(diag::System::Writer,
@ -2918,19 +2922,13 @@ bool GeneratorImpl::EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression*
return true;
}
bool GeneratorImpl::EmitVariable(const sem::Variable* var) {
auto* decl = var->Declaration();
for (auto* attr : decl->attributes) {
if (!attr->Is<ast::InternalAttribute>()) {
TINT_ICE(Writer, diagnostics_) << "unexpected variable attribute";
return false;
}
}
bool GeneratorImpl::EmitVar(const ast::Var* var) {
auto* sem = program_->Sem().Get(var);
auto* type = sem->Type()->UnwrapRef();
auto out = line();
switch (var->StorageClass()) {
switch (sem->StorageClass()) {
case ast::StorageClass::kFunction:
case ast::StorageClass::kHandle:
case ast::StorageClass::kNone:
@ -2946,12 +2944,7 @@ bool GeneratorImpl::EmitVariable(const sem::Variable* var) {
return false;
}
auto* type = var->Type()->UnwrapRef();
std::string name = program_->Symbols().NameFor(decl->symbol);
if (decl->is_const) {
name = "const " + name;
}
std::string name = program_->Symbols().NameFor(var->symbol);
if (!EmitType(out, type, name)) {
return false;
}
@ -2960,14 +2953,14 @@ bool GeneratorImpl::EmitVariable(const sem::Variable* var) {
out << " " << name;
}
if (decl->constructor != nullptr) {
if (var->constructor != nullptr) {
out << " = ";
if (!EmitExpression(out, decl->constructor)) {
if (!EmitExpression(out, var->constructor)) {
return false;
}
} else if (var->StorageClass() == ast::StorageClass::kPrivate ||
var->StorageClass() == ast::StorageClass::kFunction ||
var->StorageClass() == ast::StorageClass::kNone) {
} else if (sem->StorageClass() == ast::StorageClass::kPrivate ||
sem->StorageClass() == ast::StorageClass::kFunction ||
sem->StorageClass() == ast::StorageClass::kNone) {
out << " = ";
if (!EmitZeroValue(out, type)) {
return false;
@ -2978,34 +2971,63 @@ bool GeneratorImpl::EmitVariable(const sem::Variable* var) {
return true;
}
bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
for (auto* d : var->attributes) {
if (!d->Is<ast::IdAttribute>()) {
diagnostics_.add_error(diag::System::Writer, "Decorated const values not valid");
bool GeneratorImpl::EmitLet(const ast::Let* let) {
auto* sem = program_->Sem().Get(let);
auto* type = sem->Type();
auto out = line();
switch (sem->StorageClass()) {
case ast::StorageClass::kFunction:
case ast::StorageClass::kHandle:
case ast::StorageClass::kNone:
break;
case ast::StorageClass::kPrivate:
out << "thread ";
break;
case ast::StorageClass::kWorkgroup:
out << "threadgroup ";
break;
default:
TINT_ICE(Writer, diagnostics_) << "unhandled variable storage class";
return false;
}
}
if (!var->is_const) {
diagnostics_.add_error(diag::System::Writer, "Expected a const value");
std::string name = "const " + program_->Symbols().NameFor(let->symbol);
if (!EmitType(out, type, name)) {
return false;
}
// Variable name is output as part of the type for arrays and pointers.
if (!type->Is<sem::Array>() && !type->Is<sem::Pointer>()) {
out << " " << name;
}
out << " = ";
if (!EmitExpression(out, let->constructor)) {
return false;
}
out << ";";
return true;
}
bool GeneratorImpl::EmitProgramConstVariable(const ast::Let* let) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(let);
auto* type = global->Type();
auto out = line();
out << "constant ";
auto* type = program_->Sem().Get(var)->Type()->UnwrapRef();
if (!EmitType(out, type, program_->Symbols().NameFor(var->symbol))) {
if (!EmitType(out, type, program_->Symbols().NameFor(let->symbol))) {
return false;
}
if (!type->Is<sem::Array>()) {
out << " " << program_->Symbols().NameFor(var->symbol);
out << " " << program_->Symbols().NameFor(let->symbol);
}
auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
if (global && global->IsOverridable()) {
out << " [[function_constant(" << global->ConstantId() << ")]]";
} else if (var->constructor != nullptr) {
if (let->constructor != nullptr) {
out << " = ";
if (!EmitExpression(out, var->constructor)) {
if (!EmitExpression(out, let->constructor)) {
return false;
}
}
@ -3014,6 +3036,24 @@ bool GeneratorImpl::EmitProgramConstVariable(const ast::Variable* var) {
return true;
}
bool GeneratorImpl::EmitOverride(const ast::Override* override) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(override);
auto* type = global->Type();
auto out = line();
out << "constant ";
if (!EmitType(out, type, program_->Symbols().NameFor(override->symbol))) {
return false;
}
if (!type->Is<sem::Array>()) {
out << " " << program_->Symbols().NameFor(override->symbol);
}
out << " [[function_constant(" << global->ConstantId() << ")]];";
return true;
}
GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::Type* ty) {
return Switch(
ty,

View File

@ -348,14 +348,22 @@ class GeneratorImpl : public TextGenerator {
/// @param expr the expression to emit
/// @returns true if the expression was emitted
bool EmitUnaryOp(std::ostream& out, const ast::UnaryOpExpression* expr);
/// Handles generating a variable
/// Handles generating a 'var' declaration
/// @param var the variable to generate
/// @returns true if the variable was emitted
bool EmitVariable(const sem::Variable* var);
/// Handles generating a program scope constant variable
/// @param var the variable to emit
bool EmitVar(const ast::Var* var);
/// Handles generating a function-scope 'let' declaration
/// @param let the variable to generate
/// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Variable* var);
bool EmitLet(const ast::Let* let);
/// Handles generating a module-scope 'let' declaration
/// @param let the 'let' to emit
/// @returns true if the variable was emitted
bool EmitProgramConstVariable(const ast::Let* let);
/// Handles generating a module-scope 'override' declaration
/// @param override the 'override' to emit
/// @returns true if the variable was emitted
bool EmitOverride(const ast::Override* override);
/// Emits the zero value for the given type
/// @param out the output of the expression stream
/// @param type the type to emit the value for

View File

@ -39,7 +39,7 @@ TEST_F(MslGeneratorImplTest, Emit_SpecConstant) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var)) << gen.error();
ASSERT_TRUE(gen.EmitOverride(var)) << gen.error();
EXPECT_EQ(gen.result(), "constant float pos [[function_constant(23)]];\n");
}
@ -52,8 +52,8 @@ TEST_F(MslGeneratorImplTest, Emit_SpecConstant_NoId) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitProgramConstVariable(var_a)) << gen.error();
ASSERT_TRUE(gen.EmitProgramConstVariable(var_b)) << gen.error();
ASSERT_TRUE(gen.EmitOverride(var_a)) << gen.error();
ASSERT_TRUE(gen.EmitOverride(var_b)) << gen.error();
EXPECT_EQ(gen.result(), R"(constant float a [[function_constant(0)]];
constant float b [[function_constant(1)]];
)");

View File

@ -533,7 +533,7 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
// Make the constant specializable.
auto* sem_const =
builder_.Sem().Get<sem::GlobalVariable>(wgsize[i].overridable_const);
if (!sem_const->IsOverridable()) {
if (!sem_const->Declaration()->Is<ast::Override>()) {
TINT_ICE(Writer, builder_.Diagnostics())
<< "expected a pipeline-overridable constant";
}
@ -692,19 +692,19 @@ uint32_t Builder::GenerateFunctionTypeIfNeeded(const sem::Function* func) {
});
}
bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
bool Builder::GenerateFunctionVariable(const ast::Variable* v) {
uint32_t init_id = 0;
if (var->constructor) {
init_id = GenerateExpressionWithLoadIfNeeded(var->constructor);
if (v->constructor) {
init_id = GenerateExpressionWithLoadIfNeeded(v->constructor);
if (init_id == 0) {
return false;
}
}
auto* sem = builder_.Sem().Get(var);
auto* sem = builder_.Sem().Get(v);
if (var->is_const) {
if (!var->constructor) {
if (auto* let = v->As<ast::Let>()) {
if (!let->constructor) {
error_ = "missing constructor for constant";
return false;
}
@ -721,8 +721,7 @@ bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
return false;
}
push_debug(spv::Op::OpName,
{Operand(var_id), Operand(builder_.Symbols().NameFor(var->symbol))});
push_debug(spv::Op::OpName, {Operand(var_id), Operand(builder_.Symbols().NameFor(v->symbol))});
// TODO(dsinclair) We could detect if the constructor is fully const and emit
// an initializer value for the variable instead of doing the OpLoad.
@ -733,7 +732,7 @@ bool Builder::GenerateFunctionVariable(const ast::Variable* var) {
push_function_var(
{Operand(type_id), result, U32Operand(ConvertStorageClass(sc)), Operand(null_id)});
if (var->constructor) {
if (v->constructor) {
if (!GenerateStore(var_id, init_id)) {
return false;
}
@ -748,66 +747,61 @@ bool Builder::GenerateStore(uint32_t to, uint32_t from) {
return push_function_inst(spv::Op::OpStore, {Operand(to), Operand(from)});
}
bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
auto* sem = builder_.Sem().Get(var);
bool Builder::GenerateGlobalVariable(const ast::Variable* v) {
auto* sem = builder_.Sem().Get(v);
auto* type = sem->Type()->UnwrapRef();
uint32_t init_id = 0;
if (var->constructor) {
if (!var->is_overridable) {
auto* ctor = builder_.Sem().Get(var->constructor);
if (auto constant = ctor->ConstantValue()) {
if (auto* ctor = v->constructor) {
if (!v->Is<ast::Override>()) {
auto* ctor_sem = builder_.Sem().Get(ctor);
if (auto constant = ctor_sem->ConstantValue()) {
init_id = GenerateConstantIfNeeded(std::move(constant));
}
}
if (init_id == 0) {
init_id = GenerateConstructorExpression(var, var->constructor);
init_id = GenerateConstructorExpression(v, v->constructor);
}
if (init_id == 0) {
return false;
}
}
if (var->is_const) {
if (!var->constructor) {
// Constants must have an initializer unless they are overridable.
if (!var->is_overridable) {
error_ = "missing constructor for constant";
return false;
}
// SPIR-V requires specialization constants to have initializers.
init_id = Switch(
type, //
[&](const sem::F32*) {
ast::FloatLiteralExpression l(ProgramID{}, Source{}, 0,
ast::FloatLiteralExpression::Suffix::kF);
return GenerateLiteralIfNeeded(var, &l);
},
[&](const sem::U32*) {
ast::IntLiteralExpression l(ProgramID{}, Source{}, 0,
ast::IntLiteralExpression::Suffix::kU);
return GenerateLiteralIfNeeded(var, &l);
},
[&](const sem::I32*) {
ast::IntLiteralExpression l(ProgramID{}, Source{}, 0,
ast::IntLiteralExpression::Suffix::kI);
return GenerateLiteralIfNeeded(var, &l);
},
[&](const sem::Bool*) {
ast::BoolLiteralExpression l(ProgramID{}, Source{}, false);
return GenerateLiteralIfNeeded(var, &l);
},
[&](Default) {
error_ = "invalid type for pipeline constant ID, must be scalar";
return 0;
});
if (init_id == 0) {
if (auto* override = v->As<ast::Override>(); override && !override->constructor) {
// SPIR-V requires specialization constants to have initializers.
init_id = Switch(
type, //
[&](const sem::F32*) {
ast::FloatLiteralExpression l(ProgramID{}, Source{}, 0,
ast::FloatLiteralExpression::Suffix::kF);
return GenerateLiteralIfNeeded(override, &l);
},
[&](const sem::U32*) {
ast::IntLiteralExpression l(ProgramID{}, Source{}, 0,
ast::IntLiteralExpression::Suffix::kU);
return GenerateLiteralIfNeeded(override, &l);
},
[&](const sem::I32*) {
ast::IntLiteralExpression l(ProgramID{}, Source{}, 0,
ast::IntLiteralExpression::Suffix::kI);
return GenerateLiteralIfNeeded(override, &l);
},
[&](const sem::Bool*) {
ast::BoolLiteralExpression l(ProgramID{}, Source{}, false);
return GenerateLiteralIfNeeded(override, &l);
},
[&](Default) {
error_ = "invalid type for pipeline constant ID, must be scalar";
return 0;
}
});
if (init_id == 0) {
return 0;
}
}
if (v->IsAnyOf<ast::Let, ast::Override>()) {
push_debug(spv::Op::OpName,
{Operand(init_id), Operand(builder_.Symbols().NameFor(var->symbol))});
{Operand(init_id), Operand(builder_.Symbols().NameFor(v->symbol))});
RegisterVariable(sem, init_id);
return true;
@ -824,12 +818,11 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
return false;
}
push_debug(spv::Op::OpName,
{Operand(var_id), Operand(builder_.Symbols().NameFor(var->symbol))});
push_debug(spv::Op::OpName, {Operand(var_id), Operand(builder_.Symbols().NameFor(v->symbol))});
OperandList ops = {Operand(type_id), result, U32Operand(ConvertStorageClass(sc))};
if (var->constructor) {
if (v->constructor) {
ops.push_back(Operand(init_id));
} else {
auto* st = type->As<sem::StorageTexture>();
@ -871,7 +864,7 @@ bool Builder::GenerateGlobalVariable(const ast::Variable* var) {
push_type(spv::Op::OpVariable, std::move(ops));
for (auto* attr : var->attributes) {
for (auto* attr : v->attributes) {
bool ok = Switch(
attr,
[&](const ast::BuiltinAttribute* builtin) {
@ -1332,7 +1325,7 @@ uint32_t Builder::GenerateTypeConstructorOrConversion(const sem::Call* call,
// Generate the zero initializer if there are no values provided.
if (args.empty()) {
if (global_var && global_var->IsOverridable()) {
if (global_var && global_var->Declaration()->Is<ast::Override>()) {
auto constant_id = global_var->ConstantId();
if (result_type->Is<sem::I32>()) {
return GenerateConstantIfNeeded(ScalarConstant::I32(0).AsSpecOp(constant_id));
@ -1637,7 +1630,7 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
ScalarConstant constant;
auto* global = builder_.Sem().Get<sem::GlobalVariable>(var);
if (global && global->IsOverridable()) {
if (global && global->Declaration()->Is<ast::Override>()) {
constant.is_spec_op = true;
constant.constant_id = global->ConstantId();
}

View File

@ -635,46 +635,60 @@ bool GeneratorImpl::EmitStructType(const ast::Struct* str) {
return true;
}
bool GeneratorImpl::EmitVariable(std::ostream& out, const ast::Variable* var) {
if (!var->attributes.empty()) {
if (!EmitAttributes(out, var->attributes)) {
bool GeneratorImpl::EmitVariable(std::ostream& out, const ast::Variable* v) {
if (!v->attributes.empty()) {
if (!EmitAttributes(out, v->attributes)) {
return false;
}
out << " ";
}
if (var->is_overridable) {
out << "override";
} else if (var->is_const) {
out << "let";
} else {
out << "var";
auto sc = var->declared_storage_class;
auto ac = var->declared_access;
if (sc != ast::StorageClass::kNone || ac != ast::Access::kUndefined) {
out << "<" << sc;
if (ac != ast::Access::kUndefined) {
out << ", ";
if (!EmitAccess(out, ac)) {
return false;
bool ok = Switch(
v, //
[&](const ast::Let* ) {
out << "let";
return true;
},
[&](const ast::Override* ) {
out << "override";
return true;
},
[&](const ast::Var* var) {
out << "var";
auto sc = var->declared_storage_class;
auto ac = var->declared_access;
if (sc != ast::StorageClass::kNone || ac != ast::Access::kUndefined) {
out << "<" << sc;
if (ac != ast::Access::kUndefined) {
out << ", ";
if (!EmitAccess(out, ac)) {
return false;
}
}
out << ">";
}
out << ">";
}
return true;
},
[&](Default) {
TINT_ICE(Writer, diagnostics_) << "unhandled variable type " << v->TypeInfo().name;
return false;
});
if (!ok) {
return false;
}
out << " " << program_->Symbols().NameFor(var->symbol);
out << " " << program_->Symbols().NameFor(v->symbol);
if (auto* ty = var->type) {
if (auto* ty = v->type) {
out << " : ";
if (!EmitType(out, ty)) {
return false;
}
}
if (var->constructor != nullptr) {
if (v->constructor != nullptr) {
out << " = ";
if (!EmitExpression(out, var->constructor)) {
if (!EmitExpression(out, v->constructor)) {
return false;
}
}

View File

@ -1,4 +1,5 @@
static const uint width = 128u;
Texture2D tex : register(t0, space0);
RWByteAddressBuffer result : register(u1, space0);

View File

@ -45,6 +45,7 @@ static const uint ColPerThread = 4u;
static const uint TileAOuter = 64u;
static const uint TileBOuter = 64u;
static const uint TileInner = 64u;
groupshared float mm_Asub[64][64];
groupshared float mm_Bsub[64][64];