Make all ast and sem pointers const

And remove a whole load of const_cast hackery.

Semantic nodes may contain internally mutable fields (although only ever modified during resolving), so these are always passed by `const` pointer.

While all AST nodes are internally immutable, we have decided that pointers to AST nodes should also be marked `const`, for consistency.

There's still a collection of const_cast calls in the Resolver. These will be fixed up in a later change.

Bug: tint:745
Change-Id: I046309b8e586772605fc0fe6b2d27f28806d40ef
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/66606
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
Ben Clayton
2021-10-19 18:38:54 +00:00
committed by Tint LUCI CQ
parent 7d0fc07b20
commit 8648120bbe
261 changed files with 2441 additions and 2258 deletions

View File

@@ -224,7 +224,8 @@ class ResolverIntrinsicTest_TextureOperation
/// @param dim dimensionality of the texture being sampled
/// @param scalar the scalar type
/// @returns a pointer to a type appropriate for the coord param
ast::Type* GetCoordsType(ast::TextureDimension dim, ast::Type* scalar) {
const ast::Type* GetCoordsType(ast::TextureDimension dim,
const ast::Type* scalar) {
switch (dim) {
case ast::TextureDimension::k1d:
return scalar;
@@ -257,7 +258,7 @@ class ResolverIntrinsicTest_TextureOperation
call_params->push_back(Expr(name));
}
ast::Type* subtype(Texture type) {
const ast::Type* subtype(Texture type) {
if (type == Texture::kF32) {
return ty.f32();
}

View File

@@ -304,8 +304,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
}
if (auto* t = ty->As<ast::Vector>()) {
if (auto* el = Type(t->type)) {
if (auto* vector = builder_->create<sem::Vector>(
const_cast<sem::Type*>(el), t->width)) {
if (auto* vector = builder_->create<sem::Vector>(el, t->width)) {
if (ValidateVector(vector, t->source)) {
return vector;
}
@@ -315,8 +314,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
}
if (auto* t = ty->As<ast::Matrix>()) {
if (auto* el = Type(t->type)) {
if (auto* column_type = builder_->create<sem::Vector>(
const_cast<sem::Type*>(el), t->rows)) {
if (auto* column_type = builder_->create<sem::Vector>(el, t->rows)) {
if (auto* matrix =
builder_->create<sem::Matrix>(column_type, t->columns)) {
if (ValidateMatrix(matrix, t->source)) {
@@ -332,7 +330,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
}
if (auto* t = ty->As<ast::Atomic>()) {
if (auto* el = Type(t->type)) {
auto* a = builder_->create<sem::Atomic>(const_cast<sem::Type*>(el));
auto* a = builder_->create<sem::Atomic>(el);
if (!ValidateAtomic(t, a)) {
return nullptr;
}
@@ -346,8 +344,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
if (access == ast::kUndefined) {
access = DefaultAccessForStorageClass(t->storage_class);
}
return builder_->create<sem::Pointer>(const_cast<sem::Type*>(el),
t->storage_class, access);
return builder_->create<sem::Pointer>(el, t->storage_class, access);
}
return nullptr;
}
@@ -356,15 +353,13 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
}
if (auto* t = ty->As<ast::SampledTexture>()) {
if (auto* el = Type(t->type)) {
return builder_->create<sem::SampledTexture>(
t->dim, const_cast<sem::Type*>(el));
return builder_->create<sem::SampledTexture>(t->dim, el);
}
return nullptr;
}
if (auto* t = ty->As<ast::MultisampledTexture>()) {
if (auto* el = Type(t->type)) {
return builder_->create<sem::MultisampledTexture>(
t->dim, const_cast<sem::Type*>(el));
return builder_->create<sem::MultisampledTexture>(t->dim, el);
}
return nullptr;
}
@@ -379,8 +374,8 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
if (!ValidateStorageTexture(t)) {
return nullptr;
}
return builder_->create<sem::StorageTexture>(
t->dim, t->format, t->access, const_cast<sem::Type*>(el));
return builder_->create<sem::StorageTexture>(t->dim, t->format,
t->access, el);
}
return nullptr;
}
@@ -447,7 +442,7 @@ bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) {
return true;
}
Resolver::VariableInfo* Resolver::Variable(ast::Variable* var,
Resolver::VariableInfo* Resolver::Variable(const ast::Variable* var,
VariableKind kind,
uint32_t index /* = 0 */) {
if (variable_to_info_.count(var)) {
@@ -651,7 +646,7 @@ bool Resolver::ValidateVariableConstructor(const ast::Variable* var,
return true;
}
bool Resolver::GlobalVariable(ast::Variable* var) {
bool Resolver::GlobalVariable(const ast::Variable* var) {
if (!ValidateNoDuplicateDefinition(var->symbol, var->source,
/* check_global_scope_only */ true)) {
return false;
@@ -1515,118 +1510,115 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
};
// Inner lambda that is applied to a type and all of its members.
auto validate_entry_point_decorations_inner =
[&](const ast::DecorationList& decos, sem::Type* ty, Source source,
ParamOrRetType param_or_ret, bool is_struct_member) {
// Scan decorations for pipeline IO attributes.
// Check for overlap with attributes that have been seen previously.
ast::Decoration* pipeline_io_attribute = nullptr;
ast::InvariantDecoration* invariant_attribute = nullptr;
for (auto* deco : decos) {
auto is_invalid_compute_shader_decoration = false;
if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
if (pipeline_io_attribute) {
AddError("multiple entry point IO attributes", deco->source);
AddNote(
"previously consumed " + deco_to_str(pipeline_io_attribute),
auto validate_entry_point_decorations_inner = [&](const ast::DecorationList&
decos,
sem::Type* ty,
Source source,
ParamOrRetType param_or_ret,
bool is_struct_member) {
// Scan decorations for pipeline IO attributes.
// Check for overlap with attributes that have been seen previously.
const ast::Decoration* pipeline_io_attribute = nullptr;
const ast::InvariantDecoration* invariant_attribute = nullptr;
for (auto* deco : decos) {
auto is_invalid_compute_shader_decoration = false;
if (auto* builtin = deco->As<ast::BuiltinDecoration>()) {
if (pipeline_io_attribute) {
AddError("multiple entry point IO attributes", deco->source);
AddNote("previously consumed " + deco_to_str(pipeline_io_attribute),
pipeline_io_attribute->source);
return false;
}
pipeline_io_attribute = deco;
return false;
}
pipeline_io_attribute = deco;
if (builtins.count(builtin->builtin)) {
AddError(
deco_to_str(builtin) +
" attribute appears multiple times as pipeline " +
(param_or_ret == ParamOrRetType::kParameter ? "input"
: "output"),
func->source);
return false;
}
if (!ValidateBuiltinDecoration(builtin, ty,
/* is_input */ param_or_ret ==
ParamOrRetType::kParameter)) {
return false;
}
builtins.emplace(builtin->builtin);
} else if (auto* location = deco->As<ast::LocationDecoration>()) {
if (pipeline_io_attribute) {
AddError("multiple entry point IO attributes", deco->source);
AddNote(
"previously consumed " + deco_to_str(pipeline_io_attribute),
pipeline_io_attribute->source);
return false;
}
pipeline_io_attribute = deco;
bool is_input = param_or_ret == ParamOrRetType::kParameter;
if (!ValidateLocationDecoration(location, ty, locations, source,
is_input)) {
return false;
}
} else if (auto* interpolate =
deco->As<ast::InterpolateDecoration>()) {
if (func->PipelineStage() == ast::PipelineStage::kCompute) {
is_invalid_compute_shader_decoration = true;
} else if (!ValidateInterpolateDecoration(interpolate, ty)) {
return false;
}
} else if (auto* invariant = deco->As<ast::InvariantDecoration>()) {
if (func->PipelineStage() == ast::PipelineStage::kCompute) {
is_invalid_compute_shader_decoration = true;
}
invariant_attribute = invariant;
}
if (is_invalid_compute_shader_decoration) {
std::string input_or_output =
param_or_ret == ParamOrRetType::kParameter ? "inputs"
: "output";
AddError(
"decoration is not valid for compute shader " + input_or_output,
deco->source);
return false;
}
if (builtins.count(builtin->builtin)) {
AddError(deco_to_str(builtin) +
" attribute appears multiple times as pipeline " +
(param_or_ret == ParamOrRetType::kParameter ? "input"
: "output"),
func->source);
return false;
}
if (IsValidationEnabled(
decos, ast::DisabledValidation::kEntryPointParameter)) {
if (is_struct_member && ty->Is<sem::Struct>()) {
AddError("nested structures cannot be used for entry point IO",
source);
return false;
}
if (!ValidateBuiltinDecoration(
builtin, ty,
/* is_input */ param_or_ret == ParamOrRetType::kParameter)) {
return false;
}
builtins.emplace(builtin->builtin);
} else if (auto* location = deco->As<ast::LocationDecoration>()) {
if (pipeline_io_attribute) {
AddError("multiple entry point IO attributes", deco->source);
AddNote("previously consumed " + deco_to_str(pipeline_io_attribute),
pipeline_io_attribute->source);
return false;
}
pipeline_io_attribute = deco;
if (!ty->Is<sem::Struct>() && !pipeline_io_attribute) {
std::string err = "missing entry point IO attribute";
if (!is_struct_member) {
err += (param_or_ret == ParamOrRetType::kParameter
? " on parameter"
: " on return type");
}
AddError(err, source);
return false;
}
bool is_input = param_or_ret == ParamOrRetType::kParameter;
if (!ValidateLocationDecoration(location, ty, locations, source,
is_input)) {
return false;
}
} else if (auto* interpolate = deco->As<ast::InterpolateDecoration>()) {
if (func->PipelineStage() == ast::PipelineStage::kCompute) {
is_invalid_compute_shader_decoration = true;
} else if (!ValidateInterpolateDecoration(interpolate, ty)) {
return false;
}
} else if (auto* invariant = deco->As<ast::InvariantDecoration>()) {
if (func->PipelineStage() == ast::PipelineStage::kCompute) {
is_invalid_compute_shader_decoration = true;
}
invariant_attribute = invariant;
}
if (is_invalid_compute_shader_decoration) {
std::string input_or_output =
param_or_ret == ParamOrRetType::kParameter ? "inputs" : "output";
AddError(
"decoration is not valid for compute shader " + input_or_output,
deco->source);
return false;
}
}
if (invariant_attribute) {
bool has_position = false;
if (pipeline_io_attribute) {
if (auto* builtin =
pipeline_io_attribute->As<ast::BuiltinDecoration>()) {
has_position = (builtin->builtin == ast::Builtin::kPosition);
}
}
if (!has_position) {
AddError(
"invariant attribute must only be applied to a position "
"builtin",
invariant_attribute->source);
return false;
}
if (IsValidationEnabled(decos,
ast::DisabledValidation::kEntryPointParameter)) {
if (is_struct_member && ty->Is<sem::Struct>()) {
AddError("nested structures cannot be used for entry point IO", source);
return false;
}
if (!ty->Is<sem::Struct>() && !pipeline_io_attribute) {
std::string err = "missing entry point IO attribute";
if (!is_struct_member) {
err +=
(param_or_ret == ParamOrRetType::kParameter ? " on parameter"
: " on return type");
}
AddError(err, source);
return false;
}
if (invariant_attribute) {
bool has_position = false;
if (pipeline_io_attribute) {
if (auto* builtin =
pipeline_io_attribute->As<ast::BuiltinDecoration>()) {
has_position = (builtin->builtin == ast::Builtin::kPosition);
}
}
return true;
};
if (!has_position) {
AddError(
"invariant attribute must only be applied to a position "
"builtin",
invariant_attribute->source);
return false;
}
}
}
return true;
};
// Outer lambda for validating the entry point decorations for a type.
auto validate_entry_point_decorations = [&](const ast::DecorationList& decos,
@@ -1742,7 +1734,7 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
return true;
}
bool Resolver::Function(ast::Function* func) {
bool Resolver::Function(const ast::Function* func) {
auto* info = function_infos_.Create<FunctionInfo>(func);
if (func->IsEntryPoint()) {
@@ -2019,7 +2011,7 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) {
return true;
}
bool Resolver::Statement(ast::Statement* stmt) {
bool Resolver::Statement(const ast::Statement* stmt) {
if (stmt->Is<ast::CaseStatement>()) {
AddError("case statement can only be used inside a switch statement",
stmt->source);
@@ -2129,7 +2121,7 @@ bool Resolver::Statement(ast::Statement* stmt) {
return false;
}
bool Resolver::CaseStatement(ast::CaseStatement* stmt) {
bool Resolver::CaseStatement(const ast::CaseStatement* stmt) {
auto* sem = builder_->create<sem::SwitchCaseBlockStatement>(
stmt->body, current_compound_statement_);
builder_->Sem().Add(stmt, sem);
@@ -2141,7 +2133,7 @@ bool Resolver::CaseStatement(ast::CaseStatement* stmt) {
return Scope(sem, [&] { return Statements(stmt->body->statements); });
}
bool Resolver::IfStatement(ast::IfStatement* stmt) {
bool Resolver::IfStatement(const ast::IfStatement* stmt) {
auto* sem =
builder_->create<sem::IfStatement>(stmt, current_compound_statement_);
builder_->Sem().Add(stmt, sem);
@@ -2177,7 +2169,7 @@ bool Resolver::IfStatement(ast::IfStatement* stmt) {
});
}
bool Resolver::ElseStatement(ast::ElseStatement* stmt) {
bool Resolver::ElseStatement(const ast::ElseStatement* stmt) {
auto* sem =
builder_->create<sem::ElseStatement>(stmt, current_compound_statement_);
builder_->Sem().Add(stmt, sem);
@@ -2205,14 +2197,14 @@ bool Resolver::ElseStatement(ast::ElseStatement* stmt) {
});
}
bool Resolver::BlockStatement(ast::BlockStatement* stmt) {
bool Resolver::BlockStatement(const ast::BlockStatement* stmt) {
auto* sem = builder_->create<sem::BlockStatement>(
stmt->As<ast::BlockStatement>(), current_compound_statement_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] { return Statements(stmt->statements); });
}
bool Resolver::LoopStatement(ast::LoopStatement* stmt) {
bool Resolver::LoopStatement(const ast::LoopStatement* stmt) {
auto* sem =
builder_->create<sem::LoopStatement>(stmt, current_compound_statement_);
builder_->Sem().Add(stmt, sem);
@@ -2245,7 +2237,7 @@ bool Resolver::LoopStatement(ast::LoopStatement* stmt) {
});
}
bool Resolver::ForLoopStatement(ast::ForLoopStatement* stmt) {
bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
auto* sem = builder_->create<sem::ForLoopStatement>(
stmt, current_compound_statement_);
builder_->Sem().Add(stmt, sem);
@@ -2287,12 +2279,12 @@ bool Resolver::ForLoopStatement(ast::ForLoopStatement* stmt) {
});
}
bool Resolver::TraverseExpressions(ast::Expression* root,
std::vector<ast::Expression*>& out) {
std::vector<ast::Expression*> to_visit;
bool Resolver::TraverseExpressions(const ast::Expression* root,
std::vector<const ast::Expression*>& out) {
std::vector<const ast::Expression*> to_visit;
to_visit.emplace_back(root);
auto add = [&](ast::Expression* e) {
auto add = [&](const ast::Expression* e) {
Mark(e);
to_visit.emplace_back(e);
};
@@ -2336,8 +2328,8 @@ bool Resolver::TraverseExpressions(ast::Expression* root,
return true;
}
bool Resolver::Expression(ast::Expression* root) {
std::vector<ast::Expression*> sorted;
bool Resolver::Expression(const ast::Expression* root) {
std::vector<const ast::Expression*> sorted;
if (!TraverseExpressions(root, sorted)) {
return false;
}
@@ -2373,7 +2365,7 @@ bool Resolver::Expression(ast::Expression* root) {
return true;
}
bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) {
bool Resolver::ArrayAccessor(const ast::ArrayAccessorExpression* expr) {
auto* idx = expr->index;
auto* res = TypeOf(expr->array);
auto* parent_type = res->UnwrapRef();
@@ -2420,7 +2412,7 @@ bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) {
return true;
}
bool Resolver::Bitcast(ast::BitcastExpression* expr) {
bool Resolver::Bitcast(const ast::BitcastExpression* expr) {
auto* ty = Type(expr->type);
if (!ty) {
return false;
@@ -2433,7 +2425,7 @@ bool Resolver::Bitcast(ast::BitcastExpression* expr) {
return true;
}
bool Resolver::Call(ast::CallExpression* call) {
bool Resolver::Call(const ast::CallExpression* call) {
Mark(call->func);
auto* ident = call->func;
auto name = builder_->Symbols().NameFor(ident->symbol);
@@ -2452,7 +2444,7 @@ bool Resolver::Call(ast::CallExpression* call) {
return ValidateCall(call);
}
bool Resolver::ValidateCall(ast::CallExpression* call) {
bool Resolver::ValidateCall(const ast::CallExpression* call) {
if (TypeOf(call)->Is<sem::Void>()) {
bool is_call_statement = false;
if (current_statement_) {
@@ -2483,7 +2475,7 @@ bool Resolver::ValidateCall(ast::CallExpression* call) {
return true;
}
bool Resolver::ValidateCallStatement(ast::CallStatement* stmt) {
bool Resolver::ValidateCallStatement(const ast::CallStatement* stmt) {
const sem::Type* return_type = TypeOf(stmt->expr);
if (!return_type->Is<sem::Void>()) {
// https://gpuweb.github.io/gpuweb/wgsl/#function-call-statement
@@ -2501,7 +2493,7 @@ bool Resolver::ValidateCallStatement(ast::CallStatement* stmt) {
return true;
}
bool Resolver::IntrinsicCall(ast::CallExpression* call,
bool Resolver::IntrinsicCall(const ast::CallExpression* call,
sem::IntrinsicType intrinsic_type) {
std::vector<const sem::Type*> arg_tys;
arg_tys.reserve(call->args.size());
@@ -2710,7 +2702,7 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
return true;
}
bool Resolver::Constructor(ast::ConstructorExpression* expr) {
bool Resolver::Constructor(const ast::ConstructorExpression* expr) {
if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
auto* type = Type(type_ctor->type);
if (!type) {
@@ -2997,7 +2989,7 @@ bool Resolver::ValidateScalarConstructor(
return true;
}
bool Resolver::Identifier(ast::IdentifierExpression* expr) {
bool Resolver::Identifier(const ast::IdentifierExpression* expr) {
auto symbol = expr->symbol;
VariableInfo* var;
if (variable_stack_.get(symbol, &var)) {
@@ -3061,11 +3053,11 @@ bool Resolver::Identifier(ast::IdentifierExpression* expr) {
return false;
}
bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
bool Resolver::MemberAccessor(const ast::MemberAccessorExpression* expr) {
auto* structure = TypeOf(expr->structure);
auto* storage_type = structure->UnwrapRef();
sem::Type* ret = nullptr;
const sem::Type* ret = nullptr;
std::vector<uint32_t> swizzle;
if (auto* str = storage_type->As<sem::Struct>()) {
@@ -3181,7 +3173,7 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
return true;
}
bool Resolver::Binary(ast::BinaryExpression* expr) {
bool Resolver::Binary(const ast::BinaryExpression* expr) {
using Bool = sem::Bool;
using F32 = sem::F32;
using I32 = sem::I32;
@@ -3189,8 +3181,8 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
using Matrix = sem::Matrix;
using Vector = sem::Vector;
auto* lhs_type = const_cast<sem::Type*>(TypeOf(expr->lhs)->UnwrapRef());
auto* rhs_type = const_cast<sem::Type*>(TypeOf(expr->rhs)->UnwrapRef());
auto* lhs_type = TypeOf(expr->lhs)->UnwrapRef();
auto* rhs_type = TypeOf(expr->rhs)->UnwrapRef();
auto* lhs_vec = lhs_type->As<Vector>();
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
@@ -3386,7 +3378,7 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
return false;
}
bool Resolver::UnaryOp(ast::UnaryOpExpression* unary) {
bool Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
auto* expr_type = TypeOf(unary->expr);
if (!expr_type) {
return false;
@@ -3466,7 +3458,7 @@ bool Resolver::UnaryOp(ast::UnaryOpExpression* unary) {
}
bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
ast::Variable* var = stmt->variable;
const ast::Variable* var = stmt->variable;
Mark(var);
if (!ValidateNoDuplicateDefinition(var->symbol, var->source)) {
@@ -3819,8 +3811,8 @@ void Resolver::CreateSemanticNodes() const {
}
auto* sem_func = builder_->create<sem::Function>(
info->declaration, const_cast<sem::Type*>(info->return_type),
parameters, remap_vars(info->referenced_module_vars),
info->declaration, info->return_type, parameters,
remap_vars(info->referenced_module_vars),
remap_vars(info->local_referenced_module_vars), info->return_statements,
info->callsites, ancestor_entry_points[func->symbol],
info->workgroup_size);
@@ -3845,8 +3837,7 @@ void Resolver::CreateSemanticNodes() const {
continue;
}
sem.Add(expr, builder_->create<sem::Expression>(
const_cast<ast::Expression*>(expr), info.type,
info.statement, info.constant_value));
expr, info.type, info.statement, info.constant_value));
}
}
@@ -4057,7 +4048,7 @@ bool Resolver::ValidateStructure(const sem::Struct* str) {
}
auto has_position = false;
ast::InvariantDecoration* invariant_attribute = nullptr;
const ast::InvariantDecoration* invariant_attribute = nullptr;
for (auto* deco : member->Declaration()->decorations) {
if (!deco->IsAnyOf<ast::BuiltinDecoration, //
ast::InternalDecoration, //
@@ -4187,7 +4178,7 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
// validation.
uint64_t struct_size = 0;
uint64_t struct_align = 1;
std::unordered_map<Symbol, ast::StructMember*> member_map;
std::unordered_map<Symbol, const ast::StructMember*> member_map;
for (auto* member : str->members) {
Mark(member);
@@ -4275,8 +4266,7 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
}
auto* sem_member = builder_->create<sem::StructMember>(
member, member->symbol, const_cast<sem::Type*>(type),
static_cast<uint32_t>(sem_members.size()),
member, member->symbol, type, static_cast<uint32_t>(sem_members.size()),
static_cast<uint32_t>(offset), static_cast<uint32_t>(align),
static_cast<uint32_t>(size));
builder_->Sem().Add(member, sem_member);
@@ -4360,7 +4350,7 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) {
return true;
}
bool Resolver::Return(ast::ReturnStatement* ret) {
bool Resolver::Return(const ast::ReturnStatement* ret) {
current_function_->return_statements.push_back(ret);
if (auto* value = ret->value) {
@@ -4443,7 +4433,7 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
return true;
}
bool Resolver::SwitchStatement(ast::SwitchStatement* stmt) {
bool Resolver::SwitchStatement(const ast::SwitchStatement* stmt) {
auto* sem =
builder_->create<sem::SwitchStatement>(stmt, current_compound_statement_);
builder_->Sem().Add(stmt, sem);
@@ -4465,7 +4455,7 @@ bool Resolver::SwitchStatement(ast::SwitchStatement* stmt) {
});
}
bool Resolver::Assignment(ast::AssignmentStatement* a) {
bool Resolver::Assignment(const ast::AssignmentStatement* a) {
Mark(a->lhs);
Mark(a->rhs);
@@ -4688,7 +4678,8 @@ Resolver::VariableInfo::VariableInfo(const ast::Variable* decl,
Resolver::VariableInfo::~VariableInfo() = default;
Resolver::FunctionInfo::FunctionInfo(ast::Function* decl) : declaration(decl) {}
Resolver::FunctionInfo::FunctionInfo(const ast::Function* decl)
: declaration(decl) {}
Resolver::FunctionInfo::~FunctionInfo() = default;
} // namespace resolver

View File

@@ -112,7 +112,7 @@ class Resolver {
std::string const type_name;
ast::StorageClass storage_class;
ast::Access const access;
std::vector<ast::IdentifierExpression*> users;
std::vector<const ast::IdentifierExpression*> users;
sem::BindingPoint binding_point;
VariableKind kind;
uint32_t index = 0; // Parameter index, if kind == kParameter
@@ -130,10 +130,10 @@ class Resolver {
/// Structure holding semantic information about a function.
/// Used to build the sem::Function nodes at the end of resolving.
struct FunctionInfo {
explicit FunctionInfo(ast::Function* decl);
explicit FunctionInfo(const ast::Function* decl);
~FunctionInfo();
ast::Function* const declaration;
const ast::Function* const declaration;
std::vector<VariableInfo*> parameters;
UniqueVector<VariableInfo*> referenced_module_vars;
UniqueVector<VariableInfo*> local_referenced_module_vars;
@@ -192,7 +192,7 @@ class Resolver {
}
ast::BlockStatement const* const block;
Type const type;
const Type type;
BlockInfo* const parent;
std::vector<const ast::Variable*> decls;
@@ -235,31 +235,31 @@ class Resolver {
// AST and Type traversal methods
// Each return true on success, false on failure.
bool ArrayAccessor(ast::ArrayAccessorExpression*);
bool Assignment(ast::AssignmentStatement* a);
bool Binary(ast::BinaryExpression*);
bool Bitcast(ast::BitcastExpression*);
bool BlockStatement(ast::BlockStatement*);
bool Call(ast::CallExpression*);
bool CaseStatement(ast::CaseStatement*);
bool Constructor(ast::ConstructorExpression*);
bool ElseStatement(ast::ElseStatement*);
bool Expression(ast::Expression*);
bool ForLoopStatement(ast::ForLoopStatement*);
bool Function(ast::Function*);
bool ArrayAccessor(const ast::ArrayAccessorExpression*);
bool Assignment(const ast::AssignmentStatement* a);
bool Binary(const ast::BinaryExpression*);
bool Bitcast(const ast::BitcastExpression*);
bool BlockStatement(const ast::BlockStatement*);
bool Call(const ast::CallExpression*);
bool CaseStatement(const ast::CaseStatement*);
bool Constructor(const ast::ConstructorExpression*);
bool ElseStatement(const ast::ElseStatement*);
bool Expression(const ast::Expression*);
bool ForLoopStatement(const ast::ForLoopStatement*);
bool Function(const ast::Function*);
bool FunctionCall(const ast::CallExpression* call);
bool GlobalVariable(ast::Variable* var);
bool Identifier(ast::IdentifierExpression*);
bool IfStatement(ast::IfStatement*);
bool IntrinsicCall(ast::CallExpression*, sem::IntrinsicType);
bool LoopStatement(ast::LoopStatement*);
bool MemberAccessor(ast::MemberAccessorExpression*);
bool Parameter(ast::Variable* param);
bool Return(ast::ReturnStatement* ret);
bool Statement(ast::Statement*);
bool GlobalVariable(const ast::Variable* var);
bool Identifier(const ast::IdentifierExpression*);
bool IfStatement(const ast::IfStatement*);
bool IntrinsicCall(const ast::CallExpression*, sem::IntrinsicType);
bool LoopStatement(const ast::LoopStatement*);
bool MemberAccessor(const ast::MemberAccessorExpression*);
bool Parameter(const ast::Variable* param);
bool Return(const ast::ReturnStatement* ret);
bool Statement(const ast::Statement*);
bool Statements(const ast::StatementList&);
bool SwitchStatement(ast::SwitchStatement* s);
bool UnaryOp(ast::UnaryOpExpression*);
bool SwitchStatement(const ast::SwitchStatement* s);
bool UnaryOp(const ast::UnaryOpExpression*);
bool VariableDeclStatement(const ast::VariableDeclStatement*);
/// Performs a depth-first traversal of the expression nodes from `root`,
@@ -268,8 +268,8 @@ class Resolver {
/// @param out the ordered list of visited expression nodes, starting with the
/// root node, and ending with leaf nodes
/// @return true on success, false on error
bool TraverseExpressions(ast::Expression* root,
std::vector<ast::Expression*>& out);
bool TraverseExpressions(const ast::Expression* root,
std::vector<const ast::Expression*>& out);
// AST and Type validation methods
// Each return true on success, false on failure.
@@ -284,8 +284,8 @@ class Resolver {
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type,
const bool is_input);
bool ValidateCall(ast::CallExpression* call);
bool ValidateCallStatement(ast::CallStatement* stmt);
bool ValidateCall(const ast::CallExpression* call);
bool ValidateCallStatement(const ast::CallStatement* stmt);
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunctionCall(const ast::CallExpression* call,
@@ -371,7 +371,7 @@ class Resolver {
/// @param var the variable to create or return the `VariableInfo` for
/// @param kind what kind of variable we are declaring
/// @param index the index of the parameter, if this variable is a parameter
VariableInfo* Variable(ast::Variable* var,
VariableInfo* Variable(const ast::Variable* var,
VariableKind kind,
uint32_t index = 0);

View File

@@ -1641,9 +1641,9 @@ TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply, All) {
uint32_t mat_rows = std::get<2>(GetParam());
uint32_t mat_cols = std::get<3>(GetParam());
ast::Type* lhs_type;
ast::Type* rhs_type;
sem::Type* result_type;
const ast::Type* lhs_type = nullptr;
const ast::Type* rhs_type = nullptr;
const sem::Type* result_type = nullptr;
bool is_valid_expr;
if (vec_by_mat) {

View File

@@ -45,16 +45,16 @@ class TestHelper : public ProgramBuilder {
/// @param expr the ast::Expression
/// @return the ast::Statement of the ast::Expression, or nullptr if the
/// expression is not owned by a statement.
const ast::Statement* StmtOf(ast::Expression* expr) {
const ast::Statement* StmtOf(const ast::Expression* expr) {
auto* sem_stmt = Sem().Get(expr)->Stmt();
return sem_stmt ? sem_stmt->Declaration() : nullptr;
}
/// Returns the BlockStatement that holds the given statement.
/// @param stmt the ast::Statment
/// @param stmt the ast::Statement
/// @return the ast::BlockStatement that holds the ast::Statement, or nullptr
/// if the statement is not owned by a BlockStatement.
const ast::BlockStatement* BlockOf(ast::Statement* stmt) {
const ast::BlockStatement* BlockOf(const ast::Statement* stmt) {
auto* sem_stmt = Sem().Get(stmt);
return sem_stmt ? sem_stmt->Block()->Declaration() : nullptr;
}
@@ -63,7 +63,7 @@ class TestHelper : public ProgramBuilder {
/// @param expr the ast::Expression
/// @return the ast::Statement of the ast::Expression, or nullptr if the
/// expression is not indirectly owned by a BlockStatement.
const ast::BlockStatement* BlockOf(ast::Expression* expr) {
const ast::BlockStatement* BlockOf(const ast::Expression* expr) {
auto* sem_stmt = Sem().Get(expr)->Stmt();
return sem_stmt ? sem_stmt->Block()->Declaration() : nullptr;
}
@@ -72,7 +72,7 @@ class TestHelper : public ProgramBuilder {
/// @param expr the identifier expression
/// @return the resolved sem::Variable of the identifier, or nullptr if
/// the expression did not resolve to a variable.
const sem::Variable* VarOf(ast::Expression* expr) {
const sem::Variable* VarOf(const ast::Expression* expr) {
auto* sem_ident = Sem().Get(expr);
auto* var_user = sem_ident ? sem_ident->As<sem::VariableUser>() : nullptr;
return var_user ? var_user->Variable() : nullptr;
@@ -82,8 +82,8 @@ class TestHelper : public ProgramBuilder {
/// @param var the variable to check
/// @param expected_users the expected users of the variable
/// @return true if all users are as expected
bool CheckVarUsers(ast::Variable* var,
std::vector<ast::Expression*>&& expected_users) {
bool CheckVarUsers(const ast::Variable* var,
std::vector<const ast::Expression*>&& expected_users) {
auto& var_users = Sem().Get(var)->Users();
if (var_users.size() != expected_users.size()) {
return false;
@@ -171,10 +171,10 @@ using alias2 = alias<TO, 2>;
template <typename TO>
using alias3 = alias<TO, 3>;
using ast_type_func_ptr = ast::Type* (*)(ProgramBuilder& b);
using ast_expr_func_ptr = ast::Expression* (*)(ProgramBuilder& b,
int elem_value);
using sem_type_func_ptr = sem::Type* (*)(ProgramBuilder& b);
using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b,
int elem_value);
using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
template <typename T>
struct DataType {};
@@ -187,16 +187,16 @@ struct DataType<bool> {
/// @param b the ProgramBuilder
/// @return a new AST bool type
static inline ast::Type* AST(ProgramBuilder& b) { return b.ty.bool_(); }
static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.bool_(); }
/// @param b the ProgramBuilder
/// @return the semantic bool type
static inline sem::Type* Sem(ProgramBuilder& b) {
static inline const sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::Bool>();
}
/// @param b the ProgramBuilder
/// @param elem_value the b
/// @return a new AST expression of the bool type
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Expr(elem_value == 0);
}
};
@@ -209,16 +209,16 @@ struct DataType<i32> {
/// @param b the ProgramBuilder
/// @return a new AST i32 type
static inline ast::Type* AST(ProgramBuilder& b) { return b.ty.i32(); }
static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.i32(); }
/// @param b the ProgramBuilder
/// @return the semantic i32 type
static inline sem::Type* Sem(ProgramBuilder& b) {
static inline const sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::I32>();
}
/// @param b the ProgramBuilder
/// @param elem_value the value i32 will be initialized with
/// @return a new AST i32 literal value expression
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Expr(static_cast<i32>(elem_value));
}
};
@@ -231,16 +231,16 @@ struct DataType<u32> {
/// @param b the ProgramBuilder
/// @return a new AST u32 type
static inline ast::Type* AST(ProgramBuilder& b) { return b.ty.u32(); }
static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.u32(); }
/// @param b the ProgramBuilder
/// @return the semantic u32 type
static inline sem::Type* Sem(ProgramBuilder& b) {
static inline const sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::U32>();
}
/// @param b the ProgramBuilder
/// @param elem_value the value u32 will be initialized with
/// @return a new AST u32 literal value expression
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Expr(static_cast<u32>(elem_value));
}
};
@@ -253,16 +253,16 @@ struct DataType<f32> {
/// @param b the ProgramBuilder
/// @return a new AST f32 type
static inline ast::Type* AST(ProgramBuilder& b) { return b.ty.f32(); }
static inline const ast::Type* AST(ProgramBuilder& b) { return b.ty.f32(); }
/// @param b the ProgramBuilder
/// @return the semantic f32 type
static inline sem::Type* Sem(ProgramBuilder& b) {
static inline const sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::F32>();
}
/// @param b the ProgramBuilder
/// @param elem_value the value f32 will be initialized with
/// @return a new AST f32 literal value expression
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Expr(static_cast<f32>(elem_value));
}
};
@@ -275,19 +275,19 @@ struct DataType<vec<N, T>> {
/// @param b the ProgramBuilder
/// @return a new AST vector type
static inline ast::Type* AST(ProgramBuilder& b) {
static inline const ast::Type* AST(ProgramBuilder& b) {
return b.ty.vec(DataType<T>::AST(b), N);
}
/// @param b the ProgramBuilder
/// @return the semantic vector type
static inline sem::Type* Sem(ProgramBuilder& b) {
static inline const sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::Vector>(DataType<T>::Sem(b), N);
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element in the vector will be initialized
/// with
/// @return a new AST vector value expression
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Construct(AST(b), ExprArgs(b, elem_value));
}
@@ -312,12 +312,12 @@ struct DataType<mat<N, M, T>> {
/// @param b the ProgramBuilder
/// @return a new AST matrix type
static inline ast::Type* AST(ProgramBuilder& b) {
static inline const ast::Type* AST(ProgramBuilder& b) {
return b.ty.mat(DataType<T>::AST(b), N, M);
}
/// @param b the ProgramBuilder
/// @return the semantic matrix type
static inline sem::Type* Sem(ProgramBuilder& b) {
static inline const sem::Type* Sem(ProgramBuilder& b) {
auto* column_type = b.create<sem::Vector>(DataType<T>::Sem(b), M);
return b.create<sem::Matrix>(column_type, N);
}
@@ -325,7 +325,7 @@ struct DataType<mat<N, M, T>> {
/// @param elem_value the value each element in the matrix will be initialized
/// with
/// @return a new AST matrix value expression
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Construct(AST(b), ExprArgs(b, elem_value));
}
@@ -350,7 +350,7 @@ struct DataType<alias<T, ID>> {
/// @param b the ProgramBuilder
/// @return a new AST alias type
static inline ast::Type* AST(ProgramBuilder& b) {
static inline const ast::Type* AST(ProgramBuilder& b) {
auto name = b.Symbols().Register("alias_" + std::to_string(ID));
if (!b.AST().LookupType(name)) {
auto* type = DataType<T>::AST(b);
@@ -360,7 +360,7 @@ struct DataType<alias<T, ID>> {
}
/// @param b the ProgramBuilder
/// @return the semantic aliased type
static inline sem::Type* Sem(ProgramBuilder& b) {
static inline const sem::Type* Sem(ProgramBuilder& b) {
return DataType<T>::Sem(b);
}
@@ -368,7 +368,7 @@ struct DataType<alias<T, ID>> {
/// @param elem_value the value nested elements will be initialized with
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
static inline traits::EnableIf<!IS_COMPOSITE, ast::Expression*> Expr(
static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr(
ProgramBuilder& b,
int elem_value) {
// Cast
@@ -379,7 +379,7 @@ struct DataType<alias<T, ID>> {
/// @param elem_value the value nested elements will be initialized with
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
static inline traits::EnableIf<IS_COMPOSITE, ast::Expression*> Expr(
static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr(
ProgramBuilder& b,
int elem_value) {
// Construct
@@ -395,19 +395,19 @@ struct DataType<array<N, T>> {
/// @param b the ProgramBuilder
/// @return a new AST array type
static inline ast::Type* AST(ProgramBuilder& b) {
static inline const ast::Type* AST(ProgramBuilder& b) {
return b.ty.array(DataType<T>::AST(b), N);
}
/// @param b the ProgramBuilder
/// @return the semantic array type
static inline sem::Type* Sem(ProgramBuilder& b) {
static inline const sem::Type* Sem(ProgramBuilder& b) {
return b.create<sem::Array>(DataType<T>::Sem(b), N);
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element in the array will be initialized
/// with
/// @return a new AST array value expression
static inline ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
static inline const ast::Expression* Expr(ProgramBuilder& b, int elem_value) {
return b.Construct(AST(b), ExprArgs(b, elem_value));
}