[tint] Make most Validate methods const.

This CL makes all the validation methods const except for
ValidateSwitch.

Bug: tint:1313
Change-Id: I19ce7beae5ab4591d525dad4dca45f03dd4733b3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/87148
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2022-04-19 13:53:36 +00:00 committed by Dawn LUCI CQ
parent c069710fb6
commit a1f13f8bad
3 changed files with 135 additions and 121 deletions

View File

@ -2143,19 +2143,19 @@ sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) {
return result;
}
sem::Type* Resolver::TypeOf(const ast::Expression* expr) {
auto* sem = Sem(expr);
return sem ? const_cast<sem::Type*>(sem->Type()) : nullptr;
}
std::string Resolver::TypeNameOf(const sem::Type* ty) {
std::string Resolver::TypeNameOf(const sem::Type* ty) const {
return RawTypeNameOf(ty->UnwrapRef());
}
std::string Resolver::RawTypeNameOf(const sem::Type* ty) {
std::string Resolver::RawTypeNameOf(const sem::Type* ty) const {
return ty->FriendlyName(builder_->Symbols());
}
sem::Type* Resolver::TypeOf(const ast::Expression* expr) const {
auto* sem = Sem(expr);
return sem ? const_cast<sem::Type*>(sem->Type()) : nullptr;
}
sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) {
return Switch(
lit,
@ -2782,12 +2782,6 @@ SEM* Resolver::StatementScope(const ast::Statement* ast,
return sem;
}
std::string Resolver::VectorPretty(uint32_t size,
const sem::Type* element_type) {
sem::Vector vec_type(element_type, size);
return vec_type.FriendlyName(builder_->Symbols());
}
bool Resolver::Mark(const ast::Node* node) {
if (node == nullptr) {
TINT_ICE(Resolver, diagnostics_) << "Resolver::Mark() called with nullptr";

View File

@ -154,8 +154,6 @@ class Resolver {
/// @returns true on success, false on error
bool ResolveInternal();
bool ValidatePipelineStages();
/// Creates the nodes and adds them to the sem::Info mappings of the
/// ProgramBuilder.
void CreateSemanticNodes() const;
@ -240,78 +238,84 @@ class Resolver {
// AST and Type validation methods
// Each return true on success, false on failure.
bool ValidateAlias(const ast::Alias*);
bool ValidateArray(const sem::Array* arr, const Source& source);
bool ValidatePipelineStages() const;
bool ValidateAlias(const ast::Alias*) const;
bool ValidateArray(const sem::Array* arr, const Source& source) const;
bool ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
uint32_t el_size,
uint32_t el_align,
const Source& source);
bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s);
bool ValidateAtomicVariable(const sem::Variable* var);
bool ValidateAssignment(const ast::Statement* a, const sem::Type* rhs_ty);
bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to);
bool ValidateBreakStatement(const sem::Statement* stmt);
const Source& source) const;
bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) const;
bool ValidateAtomicVariable(const sem::Variable* var) const;
bool ValidateAssignment(const ast::Statement* a,
const sem::Type* rhs_ty) const;
bool ValidateBitcast(const ast::BitcastExpression* cast,
const sem::Type* to) const;
bool ValidateBreakStatement(const sem::Statement* stmt) const;
bool ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
const sem::Type* storage_type,
ast::PipelineStage stage,
const bool is_input);
bool ValidateContinueStatement(const sem::Statement* stmt);
bool ValidateDiscardStatement(const sem::Statement* stmt);
bool ValidateElseStatement(const sem::ElseStatement* stmt);
bool ValidateEntryPoint(const sem::Function* func, ast::PipelineStage stage);
bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt);
bool ValidateFallthroughStatement(const sem::Statement* stmt);
bool ValidateFunction(const sem::Function* func, ast::PipelineStage stage);
bool ValidateFunctionCall(const sem::Call* call);
bool ValidateGlobalVariable(const sem::Variable* var);
bool ValidateIfStatement(const sem::IfStatement* stmt);
const bool is_input) const;
bool ValidateContinueStatement(const sem::Statement* stmt) const;
bool ValidateDiscardStatement(const sem::Statement* stmt) const;
bool ValidateElseStatement(const sem::ElseStatement* stmt) const;
bool ValidateEntryPoint(const sem::Function* func,
ast::PipelineStage stage) const;
bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt) const;
bool ValidateFallthroughStatement(const sem::Statement* stmt) const;
bool ValidateFunction(const sem::Function* func,
ast::PipelineStage stage) const;
bool ValidateFunctionCall(const sem::Call* call) const;
bool ValidateGlobalVariable(const sem::Variable* var) const;
bool ValidateIfStatement(const sem::IfStatement* stmt) const;
bool ValidateIncrementDecrementStatement(
const ast::IncrementDecrementStatement* stmt);
const ast::IncrementDecrementStatement* stmt) const;
bool ValidateInterpolateAttribute(const ast::InterpolateAttribute* attr,
const sem::Type* storage_type);
bool ValidateBuiltinCall(const sem::Call* call);
const sem::Type* storage_type) const;
bool ValidateBuiltinCall(const sem::Call* call) const;
bool ValidateLocationAttribute(const ast::LocationAttribute* location,
const sem::Type* type,
std::unordered_set<uint32_t>& locations,
ast::PipelineStage stage,
const Source& source,
const bool is_input = false);
bool ValidateLoopStatement(const sem::LoopStatement* stmt);
bool ValidateMatrix(const sem::Matrix* ty, const Source& source);
const bool is_input = false) const;
bool ValidateLoopStatement(const sem::LoopStatement* stmt) const;
bool ValidateMatrix(const sem::Matrix* ty, const Source& source) const;
bool ValidateFunctionParameter(const ast::Function* func,
const sem::Variable* var);
bool ValidateParameter(const ast::Function* func, const sem::Variable* var);
const sem::Variable* var) const;
bool ValidateReturn(const ast::ReturnStatement* ret,
const sem::Type* func_type,
const sem::Type* ret_type);
bool ValidateStatements(const ast::StatementList& stmts);
bool ValidateStorageTexture(const ast::StorageTexture* t);
bool ValidateStructure(const sem::Struct* str, ast::PipelineStage stage);
const sem::Type* ret_type) const;
bool ValidateStatements(const ast::StatementList& stmts) const;
bool ValidateStorageTexture(const ast::StorageTexture* t) const;
bool ValidateStructure(const sem::Struct* str,
ast::PipelineStage stage) const;
bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor,
const sem::Struct* struct_type);
const sem::Struct* struct_type) const;
bool ValidateSwitch(const ast::SwitchStatement* s);
bool ValidateVariable(const sem::Variable* var);
bool ValidateVariable(const sem::Variable* var) const;
bool ValidateVariableConstructorOrCast(const ast::Variable* var,
ast::StorageClass storage_class,
const sem::Type* storage_type,
const sem::Type* rhs_type);
bool ValidateVector(const sem::Vector* ty, const Source& source);
const sem::Type* rhs_type) const;
bool ValidateVector(const sem::Vector* ty, const Source& source) const;
bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
const sem::Vector* vec_type);
const sem::Vector* vec_type) const;
bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
const sem::Matrix* matrix_type);
const sem::Matrix* matrix_type) const;
bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
const sem::Type* type);
const sem::Type* type) const;
bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
const sem::Array* arr_type);
bool ValidateTextureBuiltinFunction(const sem::Call* call);
bool ValidateNoDuplicateAttributes(const ast::AttributeList& attributes);
const sem::Array* arr_type) const;
bool ValidateTextureBuiltinFunction(const sem::Call* call) const;
bool ValidateNoDuplicateAttributes(
const ast::AttributeList& attributes) const;
bool ValidateStorageClassLayout(const sem::Type* type,
ast::StorageClass sc,
Source source,
ValidTypeStorageLayouts& layout);
ValidTypeStorageLayouts& layouts) const;
bool ValidateStorageClassLayout(const sem::Variable* var,
ValidTypeStorageLayouts& layout);
ValidTypeStorageLayouts& layouts) const;
/// @returns true if the attribute list contains a
/// ast::DisableValidationAttribute with the validation mode equal to
@ -325,6 +329,13 @@ class Resolver {
bool IsValidationEnabled(const ast::AttributeList& attributes,
ast::DisabledValidation validation) const;
/// Returns a human-readable string representation of the vector type name
/// with the given parameters.
/// @param size the vector dimension
/// @param element_type scalar vector sub-element type
/// @return pretty string representation
std::string VectorPretty(uint32_t size, const sem::Type* element_type) const;
/// Resolves the WorkgroupSize for the given function, assigning it to
/// current_function_
bool WorkgroupSize(const ast::Function*);
@ -397,15 +408,15 @@ class Resolver {
/// @returns the resolved type of the ast::Expression `expr`
/// @param expr the expression
sem::Type* TypeOf(const ast::Expression* expr);
sem::Type* TypeOf(const ast::Expression* expr) const;
/// @returns the type name of the given semantic type, unwrapping
/// references.
std::string TypeNameOf(const sem::Type* ty);
std::string TypeNameOf(const sem::Type* ty) const;
/// @returns the type name of the given semantic type, without unwrapping
/// references.
std::string RawTypeNameOf(const sem::Type* ty);
std::string RawTypeNameOf(const sem::Type* ty) const;
/// @returns the semantic type of the AST literal `lit`
/// @param lit the literal
@ -425,13 +436,6 @@ class Resolver {
template <typename SEM, typename F>
SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback);
/// Returns a human-readable string representation of the vector type name
/// with the given parameters.
/// @param size the vector dimension
/// @param element_type scalar vector sub-element type
/// @return pretty string representation
std::string VectorPretty(uint32_t size, const sem::Type* element_type);
/// Mark records that the given AST node has been visited, and asserts that
/// the given node has not already been seen. Diamonds in the AST are
/// illegal.
@ -466,7 +470,7 @@ class Resolver {
/// Sem is a helper for obtaining the semantic node for the given AST node.
template <typename SEM = sem::Info::InferFromAST,
typename AST_OR_TYPE = CastableBase>
auto* Sem(const AST_OR_TYPE* ast) {
auto* Sem(const AST_OR_TYPE* ast) const {
using T = sem::Info::GetResultType<SEM, AST_OR_TYPE>;
auto* sem = builder_->Sem().Get(ast);
if (!sem) {
@ -495,7 +499,7 @@ class Resolver {
/// @returns the resolved symbol (function, type or variable) for the given
/// ast::Identifier or ast::TypeName cast to the given semantic type.
template <typename SEM = sem::Node>
SEM* ResolvedSymbol(const ast::Node* node) {
SEM* ResolvedSymbol(const ast::Node* node) const {
auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node);
return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved))
: nullptr;

View File

@ -149,7 +149,8 @@ void TraverseCallChain(diag::List& diagnostics,
} // namespace
bool Resolver::ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) {
bool Resolver::ValidateAtomic(const ast::Atomic* a,
const sem::Atomic* s) const {
// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
// T must be either u32 or i32.
if (!s->Type()->IsAnyOf<sem::U32, sem::I32>()) {
@ -160,7 +161,7 @@ bool Resolver::ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) {
return true;
}
bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) {
bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) const {
switch (t->access) {
case ast::Access::kWrite:
break;
@ -193,7 +194,7 @@ bool Resolver::ValidateVariableConstructorOrCast(
const ast::Variable* var,
ast::StorageClass storage_class,
const sem::Type* storage_ty,
const sem::Type* rhs_ty) {
const sem::Type* rhs_ty) const {
auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS
// Value type has to match storage type
@ -228,10 +229,11 @@ bool Resolver::ValidateVariableConstructorOrCast(
return true;
}
bool Resolver::ValidateStorageClassLayout(const sem::Type* store_ty,
ast::StorageClass sc,
Source source,
ValidTypeStorageLayouts& layouts) {
bool Resolver::ValidateStorageClassLayout(
const sem::Type* store_ty,
ast::StorageClass sc,
Source source,
ValidTypeStorageLayouts& layouts) const {
// https://gpuweb.github.io/gpuweb/wgsl/#storage-class-layout-constraints
auto is_uniform_struct_or_array = [sc](const sem::Type* ty) {
@ -257,7 +259,7 @@ bool Resolver::ValidateStorageClassLayout(const sem::Type* store_ty,
};
// Cache result of type + storage class pair.
if (!valid_type_storage_layouts_.emplace(store_ty, sc).second) {
if (!layouts.emplace(store_ty, sc).second) {
return true;
}
@ -382,8 +384,9 @@ bool Resolver::ValidateStorageClassLayout(const sem::Type* store_ty,
return true;
}
bool Resolver::ValidateStorageClassLayout(const sem::Variable* var,
ValidTypeStorageLayouts& layouts) {
bool Resolver::ValidateStorageClassLayout(
const sem::Variable* var,
ValidTypeStorageLayouts& layouts) const {
if (auto* str = var->Type()->UnwrapRef()->As<sem::Struct>()) {
if (!ValidateStorageClassLayout(str, var->StorageClass(),
str->Declaration()->source, layouts)) {
@ -404,7 +407,7 @@ bool Resolver::ValidateStorageClassLayout(const sem::Variable* var,
return true;
}
bool Resolver::ValidateGlobalVariable(const sem::Variable* var) {
bool Resolver::ValidateGlobalVariable(const sem::Variable* var) const {
auto* decl = var->Declaration();
if (!ValidateNoDuplicateAttributes(decl->attributes)) {
return false;
@ -510,7 +513,7 @@ bool Resolver::ValidateGlobalVariable(const sem::Variable* var) {
// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
// Atomic types may only be instantiated by variables in the workgroup storage
// class or by storage buffer variables with a read_write access mode.
bool Resolver::ValidateAtomicVariable(const sem::Variable* var) {
bool Resolver::ValidateAtomicVariable(const sem::Variable* var) const {
auto sc = var->StorageClass();
auto* decl = var->Declaration();
auto access = var->Access();
@ -554,7 +557,7 @@ bool Resolver::ValidateAtomicVariable(const sem::Variable* var) {
return true;
}
bool Resolver::ValidateVariable(const sem::Variable* var) {
bool Resolver::ValidateVariable(const sem::Variable* var) const {
auto* decl = var->Declaration();
auto* storage_ty = var->Type()->UnwrapRef();
@ -630,7 +633,7 @@ bool Resolver::ValidateVariable(const sem::Variable* var) {
}
bool Resolver::ValidateFunctionParameter(const ast::Function* func,
const sem::Variable* var) {
const sem::Variable* var) const {
if (!ValidateVariable(var)) {
return false;
}
@ -695,7 +698,7 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func,
bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
const sem::Type* storage_ty,
ast::PipelineStage stage,
const bool is_input) {
const bool is_input) const {
auto* type = storage_ty->UnwrapRef();
std::stringstream stage_name;
stage_name << stage;
@ -813,7 +816,7 @@ bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
bool Resolver::ValidateInterpolateAttribute(
const ast::InterpolateAttribute* attr,
const sem::Type* storage_ty) {
const sem::Type* storage_ty) const {
auto* type = storage_ty->UnwrapRef();
if (type->is_integer_scalar_or_vector() &&
@ -835,7 +838,7 @@ bool Resolver::ValidateInterpolateAttribute(
}
bool Resolver::ValidateFunction(const sem::Function* func,
ast::PipelineStage stage) {
ast::PipelineStage stage) const {
auto* decl = func->Declaration();
auto name = builder_->Symbols().NameFor(decl->symbol);
@ -941,7 +944,7 @@ bool Resolver::ValidateFunction(const sem::Function* func,
}
bool Resolver::ValidateEntryPoint(const sem::Function* func,
ast::PipelineStage stage) {
ast::PipelineStage stage) const {
auto* decl = func->Declaration();
// Use a lambda to validate the entry point attributes for a type.
@ -1217,7 +1220,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func,
return true;
}
bool Resolver::ValidateStatements(const ast::StatementList& stmts) {
bool Resolver::ValidateStatements(const ast::StatementList& stmts) const {
for (auto* stmt : stmts) {
if (!Sem(stmt)->IsReachable()) {
/// TODO(https://github.com/gpuweb/gpuweb/issues/2378): This may need to
@ -1230,7 +1233,7 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) {
}
bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast,
const sem::Type* to) {
const sem::Type* to) const {
auto* from = TypeOf(cast->expr)->UnwrapRef();
if (!from->is_numeric_scalar_or_vector()) {
AddError("'" + TypeNameOf(from) + "' cannot be bitcast",
@ -1259,7 +1262,7 @@ bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast,
return true;
}
bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) const {
if (!stmt->FindFirstParent<sem::LoopBlockStatement, sem::CaseStatement>()) {
AddError("break statement must be in a loop or switch case",
stmt->Declaration()->source);
@ -1326,7 +1329,7 @@ bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
return true;
}
bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) {
bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) const {
if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) {
AddError("continuing blocks must not contain a continue statement",
stmt->Declaration()->source);
@ -1346,7 +1349,7 @@ bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) {
return true;
}
bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) {
bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) const {
if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) {
AddError("continuing blocks must not contain a discard statement",
stmt->Declaration()->source);
@ -1359,7 +1362,7 @@ bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) {
return true;
}
bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) {
bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) const {
if (auto* block = As<sem::BlockStatement>(stmt->Parent())) {
if (auto* c = As<sem::CaseStatement>(block->Parent())) {
if (block->Declaration()->Last() == stmt->Declaration()) {
@ -1382,7 +1385,7 @@ bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) {
return false;
}
bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) const {
if (auto* cond = stmt->Condition()) {
auto* cond_ty = cond->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
@ -1395,7 +1398,7 @@ bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
return true;
}
bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) {
bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) const {
if (stmt->Behaviors().Empty()) {
AddError("loop does not exit", stmt->Declaration()->source.Begin());
return false;
@ -1403,7 +1406,8 @@ bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) {
return true;
}
bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) {
bool Resolver::ValidateForLoopStatement(
const sem::ForLoopStatement* stmt) const {
if (stmt->Behaviors().Empty()) {
AddError("for-loop does not exit", stmt->Declaration()->source.Begin());
return false;
@ -1419,7 +1423,7 @@ bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) {
return true;
}
bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) {
bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) const {
auto* cond_ty = stmt->Condition()->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
AddError("if statement condition must be bool, got " + TypeNameOf(cond_ty),
@ -1429,7 +1433,7 @@ bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) {
return true;
}
bool Resolver::ValidateBuiltinCall(const sem::Call* call) {
bool Resolver::ValidateBuiltinCall(const sem::Call* call) const {
if (call->Type()->Is<sem::Void>()) {
bool is_call_statement = false;
if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
@ -1452,7 +1456,7 @@ bool Resolver::ValidateBuiltinCall(const sem::Call* call) {
return true;
}
bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) {
bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) const {
auto* builtin = call->Target()->As<sem::Builtin>();
if (!builtin) {
return false;
@ -1524,7 +1528,7 @@ bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) {
check_arg_is_constexpr(sem::ParameterUsage::kComponent, 0, 3);
}
bool Resolver::ValidateFunctionCall(const sem::Call* call) {
bool Resolver::ValidateFunctionCall(const sem::Call* call) const {
auto* decl = call->Declaration();
auto* target = call->Target()->As<sem::Function>();
auto sym = decl->target.name->symbol;
@ -1642,7 +1646,7 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) {
bool Resolver::ValidateStructureConstructorOrCast(
const ast::CallExpression* ctor,
const sem::Struct* struct_type) {
const sem::Struct* struct_type) const {
if (!struct_type->IsConstructible()) {
AddError("struct constructor has non-constructible type", ctor->source);
return false;
@ -1675,8 +1679,9 @@ bool Resolver::ValidateStructureConstructorOrCast(
return true;
}
bool Resolver::ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
const sem::Array* array_type) {
bool Resolver::ValidateArrayConstructorOrCast(
const ast::CallExpression* ctor,
const sem::Array* array_type) const {
auto& values = ctor->args;
auto* elem_ty = array_type->ElemType();
for (auto* value : values) {
@ -1715,8 +1720,9 @@ bool Resolver::ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
return true;
}
bool Resolver::ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
const sem::Vector* vec_type) {
bool Resolver::ValidateVectorConstructorOrCast(
const ast::CallExpression* ctor,
const sem::Vector* vec_type) const {
auto& values = ctor->args;
auto* elem_ty = vec_type->type();
size_t value_cardinality_sum = 0;
@ -1776,7 +1782,8 @@ bool Resolver::ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
return true;
}
bool Resolver::ValidateVector(const sem::Vector* ty, const Source& source) {
bool Resolver::ValidateVector(const sem::Vector* ty,
const Source& source) const {
if (!ty->type()->is_scalar()) {
AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'",
source);
@ -1785,7 +1792,8 @@ bool Resolver::ValidateVector(const sem::Vector* ty, const Source& source) {
return true;
}
bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) {
bool Resolver::ValidateMatrix(const sem::Matrix* ty,
const Source& source) const {
if (!ty->is_float_matrix()) {
AddError("matrix element type must be 'f32'", source);
return false;
@ -1793,8 +1801,9 @@ bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) {
return true;
}
bool Resolver::ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
const sem::Matrix* matrix_ty) {
bool Resolver::ValidateMatrixConstructorOrCast(
const ast::CallExpression* ctor,
const sem::Matrix* matrix_ty) const {
auto& values = ctor->args;
// Zero Value expression
if (values.empty()) {
@ -1869,7 +1878,7 @@ bool Resolver::ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
}
bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
const sem::Type* ty) {
const sem::Type* ty) const {
if (ctor->args.size() == 0) {
return true;
}
@ -1904,7 +1913,7 @@ bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
return true;
}
bool Resolver::ValidatePipelineStages() {
bool Resolver::ValidatePipelineStages() const {
auto check_workgroup_storage = [&](const sem::Function* func,
const sem::Function* entry_point) {
auto stage = entry_point->Declaration()->PipelineStage();
@ -1999,7 +2008,8 @@ bool Resolver::ValidatePipelineStages() {
return true;
}
bool Resolver::ValidateArray(const sem::Array* arr, const Source& source) {
bool Resolver::ValidateArray(const sem::Array* arr,
const Source& source) const {
auto* el_ty = arr->ElemType();
if (!IsFixedFootprint(el_ty)) {
@ -2013,7 +2023,7 @@ bool Resolver::ValidateArray(const sem::Array* arr, const Source& source) {
bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
uint32_t el_size,
uint32_t el_align,
const Source& source) {
const Source& source) const {
auto stride = attr->stride;
bool is_valid_stride =
(stride >= el_size) && (stride >= el_align) && (stride % el_align == 0);
@ -2032,7 +2042,7 @@ bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
return true;
}
bool Resolver::ValidateAlias(const ast::Alias* alias) {
bool Resolver::ValidateAlias(const ast::Alias* alias) const {
auto name = builder_->Symbols().NameFor(alias->name);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
AddError("'" + name + "' is a builtin and cannot be redeclared as an alias",
@ -2044,7 +2054,7 @@ bool Resolver::ValidateAlias(const ast::Alias* alias) {
}
bool Resolver::ValidateStructure(const sem::Struct* str,
ast::PipelineStage stage) {
ast::PipelineStage stage) const {
auto name = builder_->Symbols().NameFor(str->Declaration()->name);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
AddError("'" + name + "' is a builtin and cannot be redeclared as a struct",
@ -2153,7 +2163,7 @@ bool Resolver::ValidateLocationAttribute(
std::unordered_set<uint32_t>& locations,
ast::PipelineStage stage,
const Source& source,
const bool is_input) {
const bool is_input) const {
std::string inputs_or_output = is_input ? "inputs" : "output";
if (stage == ast::PipelineStage::kCompute) {
AddError("attribute is not valid for compute shader " + inputs_or_output,
@ -2185,7 +2195,7 @@ bool Resolver::ValidateLocationAttribute(
bool Resolver::ValidateReturn(const ast::ReturnStatement* ret,
const sem::Type* func_type,
const sem::Type* ret_type) {
const sem::Type* ret_type) const {
if (func_type->UnwrapRef() != ret_type) {
AddError(
"return statement type must match its function "
@ -2267,7 +2277,7 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
}
bool Resolver::ValidateAssignment(const ast::Statement* a,
const sem::Type* rhs_ty) {
const sem::Type* rhs_ty) const {
const ast::Expression* lhs;
const ast::Expression* rhs;
if (auto* assign = a->As<ast::AssignmentStatement>()) {
@ -2349,7 +2359,7 @@ bool Resolver::ValidateAssignment(const ast::Statement* a,
}
bool Resolver::ValidateIncrementDecrementStatement(
const ast::IncrementDecrementStatement* inc) {
const ast::IncrementDecrementStatement* inc) const {
const ast::Expression* lhs = inc->lhs;
// https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement
@ -2397,7 +2407,7 @@ bool Resolver::ValidateIncrementDecrementStatement(
}
bool Resolver::ValidateNoDuplicateAttributes(
const ast::AttributeList& attributes) {
const ast::AttributeList& attributes) const {
std::unordered_map<const TypeInfo*, Source> seen;
for (auto* d : attributes) {
auto res = seen.emplace(&d->TypeInfo(), d->source);
@ -2427,4 +2437,10 @@ bool Resolver::IsValidationEnabled(const ast::AttributeList& attributes,
return !IsValidationDisabled(attributes, validation);
}
std::string Resolver::VectorPretty(uint32_t size,
const sem::Type* element_type) const {
sem::Vector vec_type(element_type, size);
return vec_type.FriendlyName(builder_->Symbols());
}
} // namespace tint::resolver