Rework Resolver so that we construct semantic types in a single pass.

The semantic nodes cannot be fully immutable, as they contain cyclic
references. Remove Resolver::CreateSemanticNodes(), and instead
construct and mutate the semantic nodes in the single traversal pass.

Give up on trying to maintain the 'authored' type names (aliased names).
These are a nightmare to maintain, and provided limited use.

Significantly simplfies the Resolver, and allows us to generate more
semantic to semantic references, reducing sem -> ast -> sem hops.

Note: This change introduces constant value propagation across constant
variables. This is unlocked by the earlier construction of the
sem::Variable.

Change-Id: I592092fdc47fe24d30e512952511c9ab7c16d7a1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/68406
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton
2021-11-05 16:51:38 +00:00
committed by Tint LUCI CQ
parent 2423df3e04
commit a9156ff091
43 changed files with 1448 additions and 1550 deletions

View File

@@ -26,8 +26,9 @@ namespace tint {
namespace sem {
BlockStatement::BlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {}
BlockStatement::~BlockStatement() = default;
@@ -39,15 +40,19 @@ void BlockStatement::AddDecl(const ast::Variable* var) {
decls_.push_back(var);
}
FunctionBlockStatement::FunctionBlockStatement(const ast::Function* function)
: Base(function->body, nullptr), function_(function) {}
FunctionBlockStatement::FunctionBlockStatement(const sem::Function* function)
: Base(function->Declaration()->body, nullptr, function) {
TINT_ASSERT(Semantic, function);
}
FunctionBlockStatement::~FunctionBlockStatement() = default;
LoopBlockStatement::LoopBlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent)
: Base(declaration, parent) {
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function);
}
LoopBlockStatement::~LoopBlockStatement() = default;

View File

@@ -40,8 +40,10 @@ class BlockStatement : public Castable<BlockStatement, CompoundStatement> {
/// Constructor
/// @param declaration the AST node for this block statement
/// @param parent the owning statement
/// @param function the owning function
BlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~BlockStatement() override;
@@ -67,16 +69,10 @@ class FunctionBlockStatement
public:
/// Constructor
/// @param function the owning function
explicit FunctionBlockStatement(const ast::Function* function);
explicit FunctionBlockStatement(const sem::Function* function);
/// Destructor
~FunctionBlockStatement() override;
/// @returns the function owning this block
const ast::Function* Function() const { return function_; }
private:
ast::Function const* const function_;
};
/// Holds semantic information about a loop body block or for-loop body block
@@ -85,8 +81,10 @@ class LoopBlockStatement : public Castable<LoopBlockStatement, BlockStatement> {
/// Constructor
/// @param declaration the AST node for this block statement
/// @param parent the owning statement
/// @param function the owning function
LoopBlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~LoopBlockStatement() override;

View File

@@ -14,16 +14,21 @@
#include "src/sem/call.h"
#include <utility>
#include <vector>
TINT_INSTANTIATE_TYPEINFO(tint::sem::Call);
namespace tint {
namespace sem {
Call::Call(const ast::Expression* declaration,
Call::Call(const ast::CallExpression* declaration,
const CallTarget* target,
std::vector<const sem::Expression*> arguments,
Statement* statement)
: Base(declaration, target->ReturnType(), statement, Constant{}),
target_(target) {}
target_(target),
arguments_(std::move(arguments)) {}
Call::~Call() = default;

View File

@@ -15,6 +15,8 @@
#ifndef SRC_SEM_CALL_H_
#define SRC_SEM_CALL_H_
#include <vector>
#include "src/sem/expression.h"
#include "src/sem/intrinsic.h"
@@ -28,9 +30,11 @@ class Call : public Castable<Call, Expression> {
/// Constructor
/// @param declaration the AST node
/// @param target the call target
/// @param arguments the call arguments
/// @param statement the statement that owns this expression
Call(const ast::Expression* declaration,
Call(const ast::CallExpression* declaration,
const CallTarget* target,
std::vector<const sem::Expression*> arguments,
Statement* statement);
/// Destructor
@@ -39,8 +43,19 @@ class Call : public Castable<Call, Expression> {
/// @return the target of the call
const CallTarget* Target() const { return target_; }
/// @return the call arguments
const std::vector<const sem::Expression*>& Arguments() const {
return arguments_;
}
/// @returns the AST node
const ast::CallExpression* Declaration() const {
return static_cast<const ast::CallExpression*>(declaration_);
}
private:
CallTarget const* const target_;
std::vector<const sem::Expression*> arguments_;
};
} // namespace sem

View File

@@ -59,5 +59,7 @@ Constant::Constant(const Constant&) = default;
Constant::~Constant() = default;
Constant& Constant::operator=(const Constant& rhs) = default;
} // namespace sem
} // namespace tint

View File

@@ -77,6 +77,11 @@ class Constant {
/// Destructor
~Constant();
/// Copy assignment
/// @param other the Constant to copy
/// @returns this Constant
Constant& operator=(const Constant& other);
/// @returns true if the Constant has been initialized
bool IsValid() const { return type_ != nullptr; }

View File

@@ -53,8 +53,11 @@ class Expression : public Castable<Expression, Node> {
/// @returns the AST node
const ast::Expression* Declaration() const { return declaration_; }
private:
protected:
/// The AST expression node for this semantic expression
const ast::Expression* const declaration_;
private:
const sem::Type* const type_;
const Statement* const statement_;
const Constant constant_;

View File

@@ -22,8 +22,9 @@ namespace tint {
namespace sem {
ForLoopStatement::ForLoopStatement(const ast::ForLoopStatement* declaration,
CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {}
ForLoopStatement::~ForLoopStatement() = default;

View File

@@ -32,8 +32,10 @@ class ForLoopStatement : public Castable<ForLoopStatement, CompoundStatement> {
/// Constructor
/// @param declaration the AST node for this for-loop statement
/// @param parent the owning statement
/// @param function the owning function
ForLoopStatement(const ast::ForLoopStatement* declaration,
CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~ForLoopStatement() override;

View File

@@ -28,23 +28,13 @@ TINT_INSTANTIATE_TYPEINFO(tint::sem::Function);
namespace tint {
namespace sem {
Function::Function(
const ast::Function* declaration,
Type* return_type,
std::vector<Parameter*> parameters,
std::vector<const GlobalVariable*> transitively_referenced_globals,
std::vector<const GlobalVariable*> directly_referenced_globals,
std::vector<const ast::CallExpression*> callsites,
std::vector<Symbol> ancestor_entry_points,
sem::WorkgroupSize workgroup_size)
Function::Function(const ast::Function* declaration,
Type* return_type,
std::vector<Parameter*> parameters,
sem::WorkgroupSize workgroup_size)
: Base(return_type, utils::ToConstPtrVec(parameters)),
declaration_(declaration),
workgroup_size_(std::move(workgroup_size)),
directly_referenced_globals_(std::move(directly_referenced_globals)),
transitively_referenced_globals_(
std::move(transitively_referenced_globals)),
callsites_(callsites),
ancestor_entry_points_(std::move(ancestor_entry_points)) {
workgroup_size_(std::move(workgroup_size)) {
for (auto* parameter : parameters) {
parameter->SetOwner(this);
}
@@ -150,8 +140,8 @@ Function::VariableBindings Function::TransitivelyReferencedVariablesOfType(
}
bool Function::HasAncestorEntryPoint(Symbol symbol) const {
for (const auto& point : ancestor_entry_points_) {
if (point == symbol) {
for (const auto* point : ancestor_entry_points_) {
if (point->Declaration()->symbol == symbol) {
return true;
}
}

View File

@@ -62,18 +62,10 @@ class Function : public Castable<Function, CallTarget> {
/// @param declaration the ast::Function
/// @param return_type the return type of the function
/// @param parameters the parameters to the function
/// @param transitively_referenced_globals the referenced module variables
/// @param directly_referenced_globals the locally referenced module
/// @param callsites the callsites of the function
/// @param ancestor_entry_points the ancestor entry points
/// @param workgroup_size the workgroup size
Function(const ast::Function* declaration,
Type* return_type,
std::vector<Parameter*> parameters,
std::vector<const GlobalVariable*> transitively_referenced_globals,
std::vector<const GlobalVariable*> directly_referenced_globals,
std::vector<const ast::CallExpression*> callsites,
std::vector<Symbol> ancestor_entry_points,
sem::WorkgroupSize workgroup_size);
/// Destructor
@@ -85,22 +77,98 @@ class Function : public Castable<Function, CallTarget> {
/// @returns the workgroup size {x, y, z} for the function.
const sem::WorkgroupSize& WorkgroupSize() const { return workgroup_size_; }
/// @returns all directly referenced global variables
const utils::UniqueVector<const GlobalVariable*>& DirectlyReferencedGlobals()
const {
return directly_referenced_globals_;
}
/// Records that this function directly references the given global variable.
/// Note: Implicitly adds this global to the transtively-called globals.
/// @param global the module-scope variable
void AddDirectlyReferencedGlobal(const sem::GlobalVariable* global) {
directly_referenced_globals_.add(global);
transitively_referenced_globals_.add(global);
}
/// @returns all transitively referenced global variables
const utils::UniqueVector<const GlobalVariable*>&
TransitivelyReferencedGlobals() const {
return transitively_referenced_globals_;
}
/// @returns the list of callsites of this function
std::vector<const ast::CallExpression*> CallSites() const {
return callsites_;
/// Records that this function transitively references the given global
/// variable.
/// @param global the module-scoped variable
void AddTransitivelyReferencedGlobal(const sem::GlobalVariable* global) {
transitively_referenced_globals_.add(global);
}
/// @returns the names of the ancestor entry points
const std::vector<Symbol>& AncestorEntryPoints() const {
/// @returns the list of functions that this function transitively calls.
const utils::UniqueVector<const Function*>& TransitivelyCalledFunctions()
const {
return transitively_called_functions_;
}
/// Records that this function transitively calls `function`.
/// @param function the function this function transitively calls
void AddTransitivelyCalledFunction(const Function* function) {
transitively_called_functions_.add(function);
}
/// @returns the list of intrinsics that this function directly calls.
const utils::UniqueVector<const Intrinsic*>& DirectlyCalledIntrinsics()
const {
return directly_called_intrinsics_;
}
/// Records that this function transitively calls `intrinsic`.
/// @param intrinsic the intrinsic this function directly calls
void AddDirectlyCalledIntrinsic(const Intrinsic* intrinsic) {
directly_called_intrinsics_.add(intrinsic);
}
/// @returns the list of direct calls to functions / intrinsics made by this
/// function
std::vector<const Call*> DirectCallStatements() const {
return direct_calls_;
}
/// Adds a record of the direct function / intrinsic calls made by this
/// function
/// @param call the call
void AddDirectCall(const Call* call) { direct_calls_.emplace_back(call); }
/// @param target the target of a call
/// @returns the Call to the given CallTarget, or nullptr the target was not
/// called by this function.
const Call* FindDirectCallTo(const CallTarget* target) const {
for (auto* call : direct_calls_) {
if (call->Target() == target) {
return call;
}
}
return nullptr;
}
/// @returns the list of callsites of this function
std::vector<const Call*> CallSites() const { return callsites_; }
/// Adds a record of a callsite to this function
/// @param call the callsite
void AddCallSite(const Call* call) { callsites_.emplace_back(call); }
/// @returns the ancestor entry points
const std::vector<const Function*>& AncestorEntryPoints() const {
return ancestor_entry_points_;
}
/// Adds a record that the given entry point transitively calls this function
/// @param entry_point the entry point that transtively calls this function
void AddAncestorEntryPoint(const sem::Function* entry_point) {
ancestor_entry_points_.emplace_back(entry_point);
}
/// Retrieves any referenced location variables
/// @returns the <variable, decoration> pair.
std::vector<std::pair<const Variable*, const ast::LocationDecoration*>>
@@ -174,8 +242,9 @@ class Function : public Castable<Function, CallTarget> {
utils::UniqueVector<const GlobalVariable*> transitively_referenced_globals_;
utils::UniqueVector<const Function*> transitively_called_functions_;
utils::UniqueVector<const Intrinsic*> directly_called_intrinsics_;
std::vector<const ast::CallExpression*> callsites_;
std::vector<Symbol> ancestor_entry_points_;
std::vector<const Call*> direct_calls_;
std::vector<const Call*> callsites_;
std::vector<const Function*> ancestor_entry_points_;
};
} // namespace sem

View File

@@ -23,14 +23,16 @@ namespace tint {
namespace sem {
IfStatement::IfStatement(const ast::IfStatement* declaration,
CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {}
IfStatement::~IfStatement() = default;
ElseStatement::ElseStatement(const ast::ElseStatement* declaration,
CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {}
ElseStatement::~ElseStatement() = default;

View File

@@ -34,7 +34,10 @@ class IfStatement : public Castable<IfStatement, CompoundStatement> {
/// Constructor
/// @param declaration the AST node for this if statement
/// @param parent the owning statement
IfStatement(const ast::IfStatement* declaration, CompoundStatement* parent);
/// @param function the owning function
IfStatement(const ast::IfStatement* declaration,
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~IfStatement() override;
@@ -46,8 +49,10 @@ class ElseStatement : public Castable<ElseStatement, CompoundStatement> {
/// Constructor
/// @param declaration the AST node for this else statement
/// @param parent the owning statement
/// @param function the owning function
ElseStatement(const ast::ElseStatement* declaration,
CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~ElseStatement() override;

View File

@@ -27,10 +27,18 @@ namespace sem {
/// Info holds all the resolved semantic information for a Program.
class Info {
public:
/// Placeholder type used by Get() to provide a default value for EXPLICIT_SEM
using InferFromAST = std::nullptr_t;
public:
/// Resolves to the return type of the Get() method given the desired sementic
/// type and AST type.
template <typename SEM, typename AST_OR_TYPE>
using GetResultType =
std::conditional_t<std::is_same<SEM, InferFromAST>::value,
SemanticNodeTypeFor<AST_OR_TYPE>,
SEM>;
/// Constructor
Info();
@@ -50,10 +58,7 @@ class Info {
/// @returns a pointer to the semantic node if found, otherwise nullptr
template <typename SEM = InferFromAST,
typename AST_OR_TYPE = CastableBase,
typename RESULT =
std::conditional_t<std::is_same<SEM, InferFromAST>::value,
SemanticNodeTypeFor<AST_OR_TYPE>,
SEM>>
typename RESULT = GetResultType<SEM, AST_OR_TYPE>>
const RESULT* Get(const AST_OR_TYPE* node) const {
auto it = map.find(node);
if (it == map.end()) {

View File

@@ -23,16 +23,22 @@ namespace tint {
namespace sem {
LoopStatement::LoopStatement(const ast::LoopStatement* declaration,
CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function);
}
LoopStatement::~LoopStatement() = default;
LoopContinuingBlockStatement::LoopContinuingBlockStatement(
const ast::BlockStatement* declaration,
const CompoundStatement* parent)
: Base(declaration, parent) {
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function);
}
LoopContinuingBlockStatement::~LoopContinuingBlockStatement() = default;

View File

@@ -33,8 +33,10 @@ class LoopStatement : public Castable<LoopStatement, CompoundStatement> {
/// Constructor
/// @param declaration the AST node for this loop statement
/// @param parent the owning statement
/// @param function the owning function
LoopStatement(const ast::LoopStatement* declaration,
CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~LoopStatement() override;
@@ -47,8 +49,10 @@ class LoopContinuingBlockStatement
/// Constructor
/// @param declaration the AST node for this block statement
/// @param parent the owning statement
/// @param function the owning function
LoopContinuingBlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~LoopContinuingBlockStatement() override;

View File

@@ -27,23 +27,18 @@ namespace tint {
namespace sem {
Statement::Statement(const ast::Statement* declaration,
const CompoundStatement* parent)
: declaration_(declaration), parent_(parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: declaration_(declaration), parent_(parent), function_(function) {}
const BlockStatement* Statement::Block() const {
return FindFirstParent<BlockStatement>();
}
const ast::Function* Statement::Function() const {
if (auto* fbs = FindFirstParent<FunctionBlockStatement>()) {
return fbs->Function();
}
return nullptr;
}
CompoundStatement::CompoundStatement(const ast::Statement* declaration,
const CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {}
CompoundStatement::~CompoundStatement() = default;

View File

@@ -33,6 +33,7 @@ namespace sem {
/// Forward declaration
class CompoundStatement;
class Function;
namespace detail {
/// FindFirstParentReturn is a traits helper for determining the return type for
@@ -64,7 +65,10 @@ class Statement : public Castable<Statement, Node> {
/// Constructor
/// @param declaration the AST node for this statement
/// @param parent the owning statement
Statement(const ast::Statement* declaration, const CompoundStatement* parent);
/// @param function the owning function
Statement(const ast::Statement* declaration,
const CompoundStatement* parent,
const sem::Function* function);
/// @return the AST node for this statement
const ast::Statement* Declaration() const { return declaration_; }
@@ -90,11 +94,12 @@ class Statement : public Castable<Statement, Node> {
const BlockStatement* Block() const;
/// @returns the function that owns this statement
const ast::Function* Function() const;
const sem::Function* Function() const { return function_; }
private:
ast::Statement const* const declaration_;
CompoundStatement const* const parent_;
const ast::Statement* const declaration_;
const CompoundStatement* const parent_;
const sem::Function* const function_;
};
/// CompoundStatement is the base class of statements that can hold other
@@ -103,9 +108,11 @@ class CompoundStatement : public Castable<Statement, Statement> {
public:
/// Constructor
/// @param declaration the AST node for this statement
/// @param parent the owning statement
/// @param statement the owning statement
/// @param function the owning function
CompoundStatement(const ast::Statement* declaration,
const CompoundStatement* parent);
const CompoundStatement* statement,
const sem::Function* function);
/// Destructor
~CompoundStatement() override;

View File

@@ -23,16 +23,22 @@ namespace tint {
namespace sem {
SwitchStatement::SwitchStatement(const ast::SwitchStatement* declaration,
CompoundStatement* parent)
: Base(declaration, parent) {}
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function);
}
SwitchStatement::~SwitchStatement() = default;
SwitchCaseBlockStatement::SwitchCaseBlockStatement(
const ast::BlockStatement* declaration,
const CompoundStatement* parent)
: Base(declaration, parent) {
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {
TINT_ASSERT(Semantic, parent);
TINT_ASSERT(Semantic, function);
}
SwitchCaseBlockStatement::~SwitchCaseBlockStatement() = default;

View File

@@ -33,8 +33,10 @@ class SwitchStatement : public Castable<SwitchStatement, CompoundStatement> {
/// Constructor
/// @param declaration the AST node for this switch statement
/// @param parent the owning statement
/// @param function the owning function
SwitchStatement(const ast::SwitchStatement* declaration,
CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~SwitchStatement() override;
@@ -47,8 +49,10 @@ class SwitchCaseBlockStatement
/// Constructor
/// @param declaration the AST node for this block statement
/// @param parent the owning statement
/// @param function the owning function
SwitchCaseBlockStatement(const ast::BlockStatement* declaration,
const CompoundStatement* parent);
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~SwitchCaseBlockStatement() override;

View File

@@ -31,19 +31,26 @@ namespace sem {
Variable::Variable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access)
ast::Access access,
Constant constant_value)
: declaration_(declaration),
type_(type),
storage_class_(storage_class),
access_(access) {}
access_(access),
constant_value_(constant_value) {}
Variable::~Variable() = default;
LocalVariable::LocalVariable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access)
: Base(declaration, type, storage_class, access) {}
ast::Access access,
Constant constant_value)
: Base(declaration,
type,
storage_class,
access,
std::move(constant_value)) {}
LocalVariable::~LocalVariable() = default;
@@ -51,20 +58,12 @@ GlobalVariable::GlobalVariable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access,
Constant constant_value,
sem::BindingPoint binding_point)
: Base(declaration, type, storage_class, access),
: Base(declaration, type, storage_class, access, std::move(constant_value)),
binding_point_(binding_point),
is_pipeline_constant_(false) {}
GlobalVariable::GlobalVariable(const ast::Variable* declaration,
const sem::Type* type,
uint16_t constant_id)
: Base(declaration,
type,
ast::StorageClass::kNone,
ast::Access::kReadWrite),
is_pipeline_constant_(true),
constant_id_(constant_id) {}
GlobalVariable::~GlobalVariable() = default;
@@ -74,18 +73,16 @@ Parameter::Parameter(const ast::Variable* declaration,
ast::StorageClass storage_class,
ast::Access access,
const ParameterUsage usage /* = ParameterUsage::kNone */)
: Base(declaration, type, storage_class, access),
: Base(declaration, type, storage_class, access, Constant{}),
index_(index),
usage_(usage) {}
Parameter::~Parameter() = default;
VariableUser::VariableUser(const ast::IdentifierExpression* declaration,
const sem::Type* type,
Statement* statement,
sem::Variable* variable,
Constant constant_value)
: Base(declaration, type, statement, std::move(constant_value)),
sem::Variable* variable)
: Base(declaration, variable->Type(), statement, variable->ConstantValue()),
variable_(variable) {}
} // namespace sem

View File

@@ -47,10 +47,12 @@ class Variable : public Castable<Variable, Node> {
/// @param type the variable type
/// @param storage_class the variable storage class
/// @param access the variable access control type
/// @param constant_value the constant value for the variable. May be invalid
Variable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access);
ast::Access access,
Constant constant_value);
/// Destructor
~Variable() override;
@@ -67,6 +69,9 @@ class Variable : public Castable<Variable, Node> {
/// @returns the access control for the variable
ast::Access Access() const { return access_; }
/// @return the constant value of this expression
const Constant& ConstantValue() const { return constant_value_; }
/// @returns the expressions that use the variable
const std::vector<const VariableUser*>& Users() const { return users_; }
@@ -78,6 +83,7 @@ class Variable : public Castable<Variable, Node> {
const sem::Type* const type_;
ast::StorageClass const storage_class_;
ast::Access const access_;
const Constant constant_value_;
std::vector<const VariableUser*> users_;
};
@@ -89,10 +95,12 @@ class LocalVariable : public Castable<LocalVariable, Variable> {
/// @param type the variable type
/// @param storage_class the variable storage class
/// @param access the variable access control type
/// @param constant_value the constant value for the variable. May be invalid
LocalVariable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access);
ast::Access access,
Constant constant_value);
/// Destructor
~LocalVariable() override;
@@ -101,26 +109,20 @@ class LocalVariable : public Castable<LocalVariable, Variable> {
/// GlobalVariable is a module-scope variable
class GlobalVariable : public Castable<GlobalVariable, Variable> {
public:
/// Constructor for non-overridable constants
/// Constructor
/// @param declaration the AST declaration node
/// @param type the variable type
/// @param storage_class the variable storage class
/// @param access the variable access control type
/// @param constant_value the constant value for the variable. May be invalid
/// @param binding_point the optional resource binding point of the variable
GlobalVariable(const ast::Variable* declaration,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access,
Constant constant_value,
sem::BindingPoint binding_point = {});
/// Constructor for overridable pipeline constants
/// @param declaration the AST declaration node
/// @param type the variable type
/// @param constant_id the pipeline constant ID
GlobalVariable(const ast::Variable* declaration,
const sem::Type* type,
uint16_t constant_id);
/// Destructor
~GlobalVariable() override;
@@ -130,13 +132,20 @@ class GlobalVariable : public Castable<GlobalVariable, Variable> {
/// @returns the pipeline constant ID associated with the variable
uint16_t ConstantId() const { return constant_id_; }
/// @param id the constant identifier to assign to this variable
void SetConstantId(uint16_t id) {
constant_id_ = id;
is_pipeline_constant_ = true;
}
/// @returns true if this variable is an overridable pipeline constant
bool IsPipelineConstant() const { return is_pipeline_constant_; }
private:
sem::BindingPoint binding_point_;
const bool is_pipeline_constant_;
uint16_t const constant_id_ = 0;
const sem::BindingPoint binding_point_;
bool is_pipeline_constant_ = false;
uint16_t constant_id_ = 0;
};
/// Parameter is a function parameter
@@ -186,15 +195,11 @@ class VariableUser : public Castable<VariableUser, Expression> {
public:
/// Constructor
/// @param declaration the AST identifier node
/// @param type the resolved type of the expression
/// @param statement the statement that owns this expression
/// @param variable the semantic variable
/// @param constant_value the constant value for the variable. May be invalid
VariableUser(const ast::IdentifierExpression* declaration,
const sem::Type* type,
Statement* statement,
sem::Variable* variable,
Constant constant_value);
sem::Variable* variable);
/// @returns the variable that this expression refers to
const sem::Variable* Variable() const { return variable_; }