[tint] Make most Validate methods const.

This CL makes all the validation methods const except for
ValidateSwitch.

Bug: tint:1313
Change-Id: I19ce7beae5ab4591d525dad4dca45f03dd4733b3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/87148
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2022-04-19 13:53:36 +00:00 committed by Dawn LUCI CQ
parent c069710fb6
commit a1f13f8bad
3 changed files with 135 additions and 121 deletions

View File

@ -2143,19 +2143,19 @@ sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) {
return result; return result;
} }
sem::Type* Resolver::TypeOf(const ast::Expression* expr) { std::string Resolver::TypeNameOf(const sem::Type* ty) const {
auto* sem = Sem(expr);
return sem ? const_cast<sem::Type*>(sem->Type()) : nullptr;
}
std::string Resolver::TypeNameOf(const sem::Type* ty) {
return RawTypeNameOf(ty->UnwrapRef()); return RawTypeNameOf(ty->UnwrapRef());
} }
std::string Resolver::RawTypeNameOf(const sem::Type* ty) { std::string Resolver::RawTypeNameOf(const sem::Type* ty) const {
return ty->FriendlyName(builder_->Symbols()); return ty->FriendlyName(builder_->Symbols());
} }
sem::Type* Resolver::TypeOf(const ast::Expression* expr) const {
auto* sem = Sem(expr);
return sem ? const_cast<sem::Type*>(sem->Type()) : nullptr;
}
sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) { sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) {
return Switch( return Switch(
lit, lit,
@ -2782,12 +2782,6 @@ SEM* Resolver::StatementScope(const ast::Statement* ast,
return sem; return sem;
} }
std::string Resolver::VectorPretty(uint32_t size,
const sem::Type* element_type) {
sem::Vector vec_type(element_type, size);
return vec_type.FriendlyName(builder_->Symbols());
}
bool Resolver::Mark(const ast::Node* node) { bool Resolver::Mark(const ast::Node* node) {
if (node == nullptr) { if (node == nullptr) {
TINT_ICE(Resolver, diagnostics_) << "Resolver::Mark() called with nullptr"; TINT_ICE(Resolver, diagnostics_) << "Resolver::Mark() called with nullptr";

View File

@ -154,8 +154,6 @@ class Resolver {
/// @returns true on success, false on error /// @returns true on success, false on error
bool ResolveInternal(); bool ResolveInternal();
bool ValidatePipelineStages();
/// Creates the nodes and adds them to the sem::Info mappings of the /// Creates the nodes and adds them to the sem::Info mappings of the
/// ProgramBuilder. /// ProgramBuilder.
void CreateSemanticNodes() const; void CreateSemanticNodes() const;
@ -240,78 +238,84 @@ class Resolver {
// AST and Type validation methods // AST and Type validation methods
// Each return true on success, false on failure. // Each return true on success, false on failure.
bool ValidateAlias(const ast::Alias*); bool ValidatePipelineStages() const;
bool ValidateArray(const sem::Array* arr, const Source& source); bool ValidateAlias(const ast::Alias*) const;
bool ValidateArray(const sem::Array* arr, const Source& source) const;
bool ValidateArrayStrideAttribute(const ast::StrideAttribute* attr, bool ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
uint32_t el_size, uint32_t el_size,
uint32_t el_align, uint32_t el_align,
const Source& source); const Source& source) const;
bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s); bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) const;
bool ValidateAtomicVariable(const sem::Variable* var); bool ValidateAtomicVariable(const sem::Variable* var) const;
bool ValidateAssignment(const ast::Statement* a, const sem::Type* rhs_ty); bool ValidateAssignment(const ast::Statement* a,
bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to); const sem::Type* rhs_ty) const;
bool ValidateBreakStatement(const sem::Statement* stmt); bool ValidateBitcast(const ast::BitcastExpression* cast,
const sem::Type* to) const;
bool ValidateBreakStatement(const sem::Statement* stmt) const;
bool ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr, bool ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
const sem::Type* storage_type, const sem::Type* storage_type,
ast::PipelineStage stage, ast::PipelineStage stage,
const bool is_input); const bool is_input) const;
bool ValidateContinueStatement(const sem::Statement* stmt); bool ValidateContinueStatement(const sem::Statement* stmt) const;
bool ValidateDiscardStatement(const sem::Statement* stmt); bool ValidateDiscardStatement(const sem::Statement* stmt) const;
bool ValidateElseStatement(const sem::ElseStatement* stmt); bool ValidateElseStatement(const sem::ElseStatement* stmt) const;
bool ValidateEntryPoint(const sem::Function* func, ast::PipelineStage stage); bool ValidateEntryPoint(const sem::Function* func,
bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt); ast::PipelineStage stage) const;
bool ValidateFallthroughStatement(const sem::Statement* stmt); bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt) const;
bool ValidateFunction(const sem::Function* func, ast::PipelineStage stage); bool ValidateFallthroughStatement(const sem::Statement* stmt) const;
bool ValidateFunctionCall(const sem::Call* call); bool ValidateFunction(const sem::Function* func,
bool ValidateGlobalVariable(const sem::Variable* var); ast::PipelineStage stage) const;
bool ValidateIfStatement(const sem::IfStatement* stmt); bool ValidateFunctionCall(const sem::Call* call) const;
bool ValidateGlobalVariable(const sem::Variable* var) const;
bool ValidateIfStatement(const sem::IfStatement* stmt) const;
bool ValidateIncrementDecrementStatement( bool ValidateIncrementDecrementStatement(
const ast::IncrementDecrementStatement* stmt); const ast::IncrementDecrementStatement* stmt) const;
bool ValidateInterpolateAttribute(const ast::InterpolateAttribute* attr, bool ValidateInterpolateAttribute(const ast::InterpolateAttribute* attr,
const sem::Type* storage_type); const sem::Type* storage_type) const;
bool ValidateBuiltinCall(const sem::Call* call); bool ValidateBuiltinCall(const sem::Call* call) const;
bool ValidateLocationAttribute(const ast::LocationAttribute* location, bool ValidateLocationAttribute(const ast::LocationAttribute* location,
const sem::Type* type, const sem::Type* type,
std::unordered_set<uint32_t>& locations, std::unordered_set<uint32_t>& locations,
ast::PipelineStage stage, ast::PipelineStage stage,
const Source& source, const Source& source,
const bool is_input = false); const bool is_input = false) const;
bool ValidateLoopStatement(const sem::LoopStatement* stmt); bool ValidateLoopStatement(const sem::LoopStatement* stmt) const;
bool ValidateMatrix(const sem::Matrix* ty, const Source& source); bool ValidateMatrix(const sem::Matrix* ty, const Source& source) const;
bool ValidateFunctionParameter(const ast::Function* func, bool ValidateFunctionParameter(const ast::Function* func,
const sem::Variable* var); const sem::Variable* var) const;
bool ValidateParameter(const ast::Function* func, const sem::Variable* var);
bool ValidateReturn(const ast::ReturnStatement* ret, bool ValidateReturn(const ast::ReturnStatement* ret,
const sem::Type* func_type, const sem::Type* func_type,
const sem::Type* ret_type); const sem::Type* ret_type) const;
bool ValidateStatements(const ast::StatementList& stmts); bool ValidateStatements(const ast::StatementList& stmts) const;
bool ValidateStorageTexture(const ast::StorageTexture* t); bool ValidateStorageTexture(const ast::StorageTexture* t) const;
bool ValidateStructure(const sem::Struct* str, ast::PipelineStage stage); bool ValidateStructure(const sem::Struct* str,
ast::PipelineStage stage) const;
bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor, bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor,
const sem::Struct* struct_type); const sem::Struct* struct_type) const;
bool ValidateSwitch(const ast::SwitchStatement* s); bool ValidateSwitch(const ast::SwitchStatement* s);
bool ValidateVariable(const sem::Variable* var); bool ValidateVariable(const sem::Variable* var) const;
bool ValidateVariableConstructorOrCast(const ast::Variable* var, bool ValidateVariableConstructorOrCast(const ast::Variable* var,
ast::StorageClass storage_class, ast::StorageClass storage_class,
const sem::Type* storage_type, const sem::Type* storage_type,
const sem::Type* rhs_type); const sem::Type* rhs_type) const;
bool ValidateVector(const sem::Vector* ty, const Source& source); bool ValidateVector(const sem::Vector* ty, const Source& source) const;
bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor, bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
const sem::Vector* vec_type); const sem::Vector* vec_type) const;
bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor, bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
const sem::Matrix* matrix_type); const sem::Matrix* matrix_type) const;
bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor, bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
const sem::Type* type); const sem::Type* type) const;
bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor, bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
const sem::Array* arr_type); const sem::Array* arr_type) const;
bool ValidateTextureBuiltinFunction(const sem::Call* call); bool ValidateTextureBuiltinFunction(const sem::Call* call) const;
bool ValidateNoDuplicateAttributes(const ast::AttributeList& attributes); bool ValidateNoDuplicateAttributes(
const ast::AttributeList& attributes) const;
bool ValidateStorageClassLayout(const sem::Type* type, bool ValidateStorageClassLayout(const sem::Type* type,
ast::StorageClass sc, ast::StorageClass sc,
Source source, Source source,
ValidTypeStorageLayouts& layout); ValidTypeStorageLayouts& layouts) const;
bool ValidateStorageClassLayout(const sem::Variable* var, bool ValidateStorageClassLayout(const sem::Variable* var,
ValidTypeStorageLayouts& layout); ValidTypeStorageLayouts& layouts) const;
/// @returns true if the attribute list contains a /// @returns true if the attribute list contains a
/// ast::DisableValidationAttribute with the validation mode equal to /// ast::DisableValidationAttribute with the validation mode equal to
@ -325,6 +329,13 @@ class Resolver {
bool IsValidationEnabled(const ast::AttributeList& attributes, bool IsValidationEnabled(const ast::AttributeList& attributes,
ast::DisabledValidation validation) const; ast::DisabledValidation validation) const;
/// Returns a human-readable string representation of the vector type name
/// with the given parameters.
/// @param size the vector dimension
/// @param element_type scalar vector sub-element type
/// @return pretty string representation
std::string VectorPretty(uint32_t size, const sem::Type* element_type) const;
/// Resolves the WorkgroupSize for the given function, assigning it to /// Resolves the WorkgroupSize for the given function, assigning it to
/// current_function_ /// current_function_
bool WorkgroupSize(const ast::Function*); bool WorkgroupSize(const ast::Function*);
@ -397,15 +408,15 @@ class Resolver {
/// @returns the resolved type of the ast::Expression `expr` /// @returns the resolved type of the ast::Expression `expr`
/// @param expr the expression /// @param expr the expression
sem::Type* TypeOf(const ast::Expression* expr); sem::Type* TypeOf(const ast::Expression* expr) const;
/// @returns the type name of the given semantic type, unwrapping /// @returns the type name of the given semantic type, unwrapping
/// references. /// references.
std::string TypeNameOf(const sem::Type* ty); std::string TypeNameOf(const sem::Type* ty) const;
/// @returns the type name of the given semantic type, without unwrapping /// @returns the type name of the given semantic type, without unwrapping
/// references. /// references.
std::string RawTypeNameOf(const sem::Type* ty); std::string RawTypeNameOf(const sem::Type* ty) const;
/// @returns the semantic type of the AST literal `lit` /// @returns the semantic type of the AST literal `lit`
/// @param lit the literal /// @param lit the literal
@ -425,13 +436,6 @@ class Resolver {
template <typename SEM, typename F> template <typename SEM, typename F>
SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback); SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback);
/// Returns a human-readable string representation of the vector type name
/// with the given parameters.
/// @param size the vector dimension
/// @param element_type scalar vector sub-element type
/// @return pretty string representation
std::string VectorPretty(uint32_t size, const sem::Type* element_type);
/// Mark records that the given AST node has been visited, and asserts that /// Mark records that the given AST node has been visited, and asserts that
/// the given node has not already been seen. Diamonds in the AST are /// the given node has not already been seen. Diamonds in the AST are
/// illegal. /// illegal.
@ -466,7 +470,7 @@ class Resolver {
/// Sem is a helper for obtaining the semantic node for the given AST node. /// Sem is a helper for obtaining the semantic node for the given AST node.
template <typename SEM = sem::Info::InferFromAST, template <typename SEM = sem::Info::InferFromAST,
typename AST_OR_TYPE = CastableBase> typename AST_OR_TYPE = CastableBase>
auto* Sem(const AST_OR_TYPE* ast) { auto* Sem(const AST_OR_TYPE* ast) const {
using T = sem::Info::GetResultType<SEM, AST_OR_TYPE>; using T = sem::Info::GetResultType<SEM, AST_OR_TYPE>;
auto* sem = builder_->Sem().Get(ast); auto* sem = builder_->Sem().Get(ast);
if (!sem) { if (!sem) {
@ -495,7 +499,7 @@ class Resolver {
/// @returns the resolved symbol (function, type or variable) for the given /// @returns the resolved symbol (function, type or variable) for the given
/// ast::Identifier or ast::TypeName cast to the given semantic type. /// ast::Identifier or ast::TypeName cast to the given semantic type.
template <typename SEM = sem::Node> template <typename SEM = sem::Node>
SEM* ResolvedSymbol(const ast::Node* node) { SEM* ResolvedSymbol(const ast::Node* node) const {
auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node); auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node);
return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved)) return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(resolved))
: nullptr; : nullptr;

View File

@ -149,7 +149,8 @@ void TraverseCallChain(diag::List& diagnostics,
} // namespace } // namespace
bool Resolver::ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) { bool Resolver::ValidateAtomic(const ast::Atomic* a,
const sem::Atomic* s) const {
// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
// T must be either u32 or i32. // T must be either u32 or i32.
if (!s->Type()->IsAnyOf<sem::U32, sem::I32>()) { if (!s->Type()->IsAnyOf<sem::U32, sem::I32>()) {
@ -160,7 +161,7 @@ bool Resolver::ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) {
return true; return true;
} }
bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) { bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) const {
switch (t->access) { switch (t->access) {
case ast::Access::kWrite: case ast::Access::kWrite:
break; break;
@ -193,7 +194,7 @@ bool Resolver::ValidateVariableConstructorOrCast(
const ast::Variable* var, const ast::Variable* var,
ast::StorageClass storage_class, ast::StorageClass storage_class,
const sem::Type* storage_ty, const sem::Type* storage_ty,
const sem::Type* rhs_ty) { const sem::Type* rhs_ty) const {
auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS
// Value type has to match storage type // Value type has to match storage type
@ -228,10 +229,11 @@ bool Resolver::ValidateVariableConstructorOrCast(
return true; return true;
} }
bool Resolver::ValidateStorageClassLayout(const sem::Type* store_ty, bool Resolver::ValidateStorageClassLayout(
ast::StorageClass sc, const sem::Type* store_ty,
Source source, ast::StorageClass sc,
ValidTypeStorageLayouts& layouts) { Source source,
ValidTypeStorageLayouts& layouts) const {
// https://gpuweb.github.io/gpuweb/wgsl/#storage-class-layout-constraints // https://gpuweb.github.io/gpuweb/wgsl/#storage-class-layout-constraints
auto is_uniform_struct_or_array = [sc](const sem::Type* ty) { auto is_uniform_struct_or_array = [sc](const sem::Type* ty) {
@ -257,7 +259,7 @@ bool Resolver::ValidateStorageClassLayout(const sem::Type* store_ty,
}; };
// Cache result of type + storage class pair. // Cache result of type + storage class pair.
if (!valid_type_storage_layouts_.emplace(store_ty, sc).second) { if (!layouts.emplace(store_ty, sc).second) {
return true; return true;
} }
@ -382,8 +384,9 @@ bool Resolver::ValidateStorageClassLayout(const sem::Type* store_ty,
return true; return true;
} }
bool Resolver::ValidateStorageClassLayout(const sem::Variable* var, bool Resolver::ValidateStorageClassLayout(
ValidTypeStorageLayouts& layouts) { const sem::Variable* var,
ValidTypeStorageLayouts& layouts) const {
if (auto* str = var->Type()->UnwrapRef()->As<sem::Struct>()) { if (auto* str = var->Type()->UnwrapRef()->As<sem::Struct>()) {
if (!ValidateStorageClassLayout(str, var->StorageClass(), if (!ValidateStorageClassLayout(str, var->StorageClass(),
str->Declaration()->source, layouts)) { str->Declaration()->source, layouts)) {
@ -404,7 +407,7 @@ bool Resolver::ValidateStorageClassLayout(const sem::Variable* var,
return true; return true;
} }
bool Resolver::ValidateGlobalVariable(const sem::Variable* var) { bool Resolver::ValidateGlobalVariable(const sem::Variable* var) const {
auto* decl = var->Declaration(); auto* decl = var->Declaration();
if (!ValidateNoDuplicateAttributes(decl->attributes)) { if (!ValidateNoDuplicateAttributes(decl->attributes)) {
return false; return false;
@ -510,7 +513,7 @@ bool Resolver::ValidateGlobalVariable(const sem::Variable* var) {
// https://gpuweb.github.io/gpuweb/wgsl/#atomic-types // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types
// Atomic types may only be instantiated by variables in the workgroup storage // Atomic types may only be instantiated by variables in the workgroup storage
// class or by storage buffer variables with a read_write access mode. // class or by storage buffer variables with a read_write access mode.
bool Resolver::ValidateAtomicVariable(const sem::Variable* var) { bool Resolver::ValidateAtomicVariable(const sem::Variable* var) const {
auto sc = var->StorageClass(); auto sc = var->StorageClass();
auto* decl = var->Declaration(); auto* decl = var->Declaration();
auto access = var->Access(); auto access = var->Access();
@ -554,7 +557,7 @@ bool Resolver::ValidateAtomicVariable(const sem::Variable* var) {
return true; return true;
} }
bool Resolver::ValidateVariable(const sem::Variable* var) { bool Resolver::ValidateVariable(const sem::Variable* var) const {
auto* decl = var->Declaration(); auto* decl = var->Declaration();
auto* storage_ty = var->Type()->UnwrapRef(); auto* storage_ty = var->Type()->UnwrapRef();
@ -630,7 +633,7 @@ bool Resolver::ValidateVariable(const sem::Variable* var) {
} }
bool Resolver::ValidateFunctionParameter(const ast::Function* func, bool Resolver::ValidateFunctionParameter(const ast::Function* func,
const sem::Variable* var) { const sem::Variable* var) const {
if (!ValidateVariable(var)) { if (!ValidateVariable(var)) {
return false; return false;
} }
@ -695,7 +698,7 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func,
bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr, bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
const sem::Type* storage_ty, const sem::Type* storage_ty,
ast::PipelineStage stage, ast::PipelineStage stage,
const bool is_input) { const bool is_input) const {
auto* type = storage_ty->UnwrapRef(); auto* type = storage_ty->UnwrapRef();
std::stringstream stage_name; std::stringstream stage_name;
stage_name << stage; stage_name << stage;
@ -813,7 +816,7 @@ bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr,
bool Resolver::ValidateInterpolateAttribute( bool Resolver::ValidateInterpolateAttribute(
const ast::InterpolateAttribute* attr, const ast::InterpolateAttribute* attr,
const sem::Type* storage_ty) { const sem::Type* storage_ty) const {
auto* type = storage_ty->UnwrapRef(); auto* type = storage_ty->UnwrapRef();
if (type->is_integer_scalar_or_vector() && if (type->is_integer_scalar_or_vector() &&
@ -835,7 +838,7 @@ bool Resolver::ValidateInterpolateAttribute(
} }
bool Resolver::ValidateFunction(const sem::Function* func, bool Resolver::ValidateFunction(const sem::Function* func,
ast::PipelineStage stage) { ast::PipelineStage stage) const {
auto* decl = func->Declaration(); auto* decl = func->Declaration();
auto name = builder_->Symbols().NameFor(decl->symbol); auto name = builder_->Symbols().NameFor(decl->symbol);
@ -941,7 +944,7 @@ bool Resolver::ValidateFunction(const sem::Function* func,
} }
bool Resolver::ValidateEntryPoint(const sem::Function* func, bool Resolver::ValidateEntryPoint(const sem::Function* func,
ast::PipelineStage stage) { ast::PipelineStage stage) const {
auto* decl = func->Declaration(); auto* decl = func->Declaration();
// Use a lambda to validate the entry point attributes for a type. // Use a lambda to validate the entry point attributes for a type.
@ -1217,7 +1220,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func,
return true; return true;
} }
bool Resolver::ValidateStatements(const ast::StatementList& stmts) { bool Resolver::ValidateStatements(const ast::StatementList& stmts) const {
for (auto* stmt : stmts) { for (auto* stmt : stmts) {
if (!Sem(stmt)->IsReachable()) { if (!Sem(stmt)->IsReachable()) {
/// TODO(https://github.com/gpuweb/gpuweb/issues/2378): This may need to /// TODO(https://github.com/gpuweb/gpuweb/issues/2378): This may need to
@ -1230,7 +1233,7 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) {
} }
bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast, bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast,
const sem::Type* to) { const sem::Type* to) const {
auto* from = TypeOf(cast->expr)->UnwrapRef(); auto* from = TypeOf(cast->expr)->UnwrapRef();
if (!from->is_numeric_scalar_or_vector()) { if (!from->is_numeric_scalar_or_vector()) {
AddError("'" + TypeNameOf(from) + "' cannot be bitcast", AddError("'" + TypeNameOf(from) + "' cannot be bitcast",
@ -1259,7 +1262,7 @@ bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast,
return true; return true;
} }
bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) { bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) const {
if (!stmt->FindFirstParent<sem::LoopBlockStatement, sem::CaseStatement>()) { if (!stmt->FindFirstParent<sem::LoopBlockStatement, sem::CaseStatement>()) {
AddError("break statement must be in a loop or switch case", AddError("break statement must be in a loop or switch case",
stmt->Declaration()->source); stmt->Declaration()->source);
@ -1326,7 +1329,7 @@ bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) {
return true; return true;
} }
bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) { bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) const {
if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) { if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) {
AddError("continuing blocks must not contain a continue statement", AddError("continuing blocks must not contain a continue statement",
stmt->Declaration()->source); stmt->Declaration()->source);
@ -1346,7 +1349,7 @@ bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) {
return true; return true;
} }
bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) { bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) const {
if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) { if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) {
AddError("continuing blocks must not contain a discard statement", AddError("continuing blocks must not contain a discard statement",
stmt->Declaration()->source); stmt->Declaration()->source);
@ -1359,7 +1362,7 @@ bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) {
return true; return true;
} }
bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) { bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) const {
if (auto* block = As<sem::BlockStatement>(stmt->Parent())) { if (auto* block = As<sem::BlockStatement>(stmt->Parent())) {
if (auto* c = As<sem::CaseStatement>(block->Parent())) { if (auto* c = As<sem::CaseStatement>(block->Parent())) {
if (block->Declaration()->Last() == stmt->Declaration()) { if (block->Declaration()->Last() == stmt->Declaration()) {
@ -1382,7 +1385,7 @@ bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) {
return false; return false;
} }
bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) { bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) const {
if (auto* cond = stmt->Condition()) { if (auto* cond = stmt->Condition()) {
auto* cond_ty = cond->Type()->UnwrapRef(); auto* cond_ty = cond->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) { if (!cond_ty->Is<sem::Bool>()) {
@ -1395,7 +1398,7 @@ bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) {
return true; return true;
} }
bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) { bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) const {
if (stmt->Behaviors().Empty()) { if (stmt->Behaviors().Empty()) {
AddError("loop does not exit", stmt->Declaration()->source.Begin()); AddError("loop does not exit", stmt->Declaration()->source.Begin());
return false; return false;
@ -1403,7 +1406,8 @@ bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) {
return true; return true;
} }
bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) { bool Resolver::ValidateForLoopStatement(
const sem::ForLoopStatement* stmt) const {
if (stmt->Behaviors().Empty()) { if (stmt->Behaviors().Empty()) {
AddError("for-loop does not exit", stmt->Declaration()->source.Begin()); AddError("for-loop does not exit", stmt->Declaration()->source.Begin());
return false; return false;
@ -1419,7 +1423,7 @@ bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) {
return true; return true;
} }
bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) { bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) const {
auto* cond_ty = stmt->Condition()->Type()->UnwrapRef(); auto* cond_ty = stmt->Condition()->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) { if (!cond_ty->Is<sem::Bool>()) {
AddError("if statement condition must be bool, got " + TypeNameOf(cond_ty), AddError("if statement condition must be bool, got " + TypeNameOf(cond_ty),
@ -1429,7 +1433,7 @@ bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) {
return true; return true;
} }
bool Resolver::ValidateBuiltinCall(const sem::Call* call) { bool Resolver::ValidateBuiltinCall(const sem::Call* call) const {
if (call->Type()->Is<sem::Void>()) { if (call->Type()->Is<sem::Void>()) {
bool is_call_statement = false; bool is_call_statement = false;
if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) { if (auto* call_stmt = As<ast::CallStatement>(call->Stmt()->Declaration())) {
@ -1452,7 +1456,7 @@ bool Resolver::ValidateBuiltinCall(const sem::Call* call) {
return true; return true;
} }
bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) { bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) const {
auto* builtin = call->Target()->As<sem::Builtin>(); auto* builtin = call->Target()->As<sem::Builtin>();
if (!builtin) { if (!builtin) {
return false; return false;
@ -1524,7 +1528,7 @@ bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) {
check_arg_is_constexpr(sem::ParameterUsage::kComponent, 0, 3); check_arg_is_constexpr(sem::ParameterUsage::kComponent, 0, 3);
} }
bool Resolver::ValidateFunctionCall(const sem::Call* call) { bool Resolver::ValidateFunctionCall(const sem::Call* call) const {
auto* decl = call->Declaration(); auto* decl = call->Declaration();
auto* target = call->Target()->As<sem::Function>(); auto* target = call->Target()->As<sem::Function>();
auto sym = decl->target.name->symbol; auto sym = decl->target.name->symbol;
@ -1642,7 +1646,7 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) {
bool Resolver::ValidateStructureConstructorOrCast( bool Resolver::ValidateStructureConstructorOrCast(
const ast::CallExpression* ctor, const ast::CallExpression* ctor,
const sem::Struct* struct_type) { const sem::Struct* struct_type) const {
if (!struct_type->IsConstructible()) { if (!struct_type->IsConstructible()) {
AddError("struct constructor has non-constructible type", ctor->source); AddError("struct constructor has non-constructible type", ctor->source);
return false; return false;
@ -1675,8 +1679,9 @@ bool Resolver::ValidateStructureConstructorOrCast(
return true; return true;
} }
bool Resolver::ValidateArrayConstructorOrCast(const ast::CallExpression* ctor, bool Resolver::ValidateArrayConstructorOrCast(
const sem::Array* array_type) { const ast::CallExpression* ctor,
const sem::Array* array_type) const {
auto& values = ctor->args; auto& values = ctor->args;
auto* elem_ty = array_type->ElemType(); auto* elem_ty = array_type->ElemType();
for (auto* value : values) { for (auto* value : values) {
@ -1715,8 +1720,9 @@ bool Resolver::ValidateArrayConstructorOrCast(const ast::CallExpression* ctor,
return true; return true;
} }
bool Resolver::ValidateVectorConstructorOrCast(const ast::CallExpression* ctor, bool Resolver::ValidateVectorConstructorOrCast(
const sem::Vector* vec_type) { const ast::CallExpression* ctor,
const sem::Vector* vec_type) const {
auto& values = ctor->args; auto& values = ctor->args;
auto* elem_ty = vec_type->type(); auto* elem_ty = vec_type->type();
size_t value_cardinality_sum = 0; size_t value_cardinality_sum = 0;
@ -1776,7 +1782,8 @@ bool Resolver::ValidateVectorConstructorOrCast(const ast::CallExpression* ctor,
return true; return true;
} }
bool Resolver::ValidateVector(const sem::Vector* ty, const Source& source) { bool Resolver::ValidateVector(const sem::Vector* ty,
const Source& source) const {
if (!ty->type()->is_scalar()) { if (!ty->type()->is_scalar()) {
AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'", AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'",
source); source);
@ -1785,7 +1792,8 @@ bool Resolver::ValidateVector(const sem::Vector* ty, const Source& source) {
return true; return true;
} }
bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) { bool Resolver::ValidateMatrix(const sem::Matrix* ty,
const Source& source) const {
if (!ty->is_float_matrix()) { if (!ty->is_float_matrix()) {
AddError("matrix element type must be 'f32'", source); AddError("matrix element type must be 'f32'", source);
return false; return false;
@ -1793,8 +1801,9 @@ bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) {
return true; return true;
} }
bool Resolver::ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor, bool Resolver::ValidateMatrixConstructorOrCast(
const sem::Matrix* matrix_ty) { const ast::CallExpression* ctor,
const sem::Matrix* matrix_ty) const {
auto& values = ctor->args; auto& values = ctor->args;
// Zero Value expression // Zero Value expression
if (values.empty()) { if (values.empty()) {
@ -1869,7 +1878,7 @@ bool Resolver::ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor,
} }
bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor, bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
const sem::Type* ty) { const sem::Type* ty) const {
if (ctor->args.size() == 0) { if (ctor->args.size() == 0) {
return true; return true;
} }
@ -1904,7 +1913,7 @@ bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor,
return true; return true;
} }
bool Resolver::ValidatePipelineStages() { bool Resolver::ValidatePipelineStages() const {
auto check_workgroup_storage = [&](const sem::Function* func, auto check_workgroup_storage = [&](const sem::Function* func,
const sem::Function* entry_point) { const sem::Function* entry_point) {
auto stage = entry_point->Declaration()->PipelineStage(); auto stage = entry_point->Declaration()->PipelineStage();
@ -1999,7 +2008,8 @@ bool Resolver::ValidatePipelineStages() {
return true; return true;
} }
bool Resolver::ValidateArray(const sem::Array* arr, const Source& source) { bool Resolver::ValidateArray(const sem::Array* arr,
const Source& source) const {
auto* el_ty = arr->ElemType(); auto* el_ty = arr->ElemType();
if (!IsFixedFootprint(el_ty)) { if (!IsFixedFootprint(el_ty)) {
@ -2013,7 +2023,7 @@ bool Resolver::ValidateArray(const sem::Array* arr, const Source& source) {
bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr, bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
uint32_t el_size, uint32_t el_size,
uint32_t el_align, uint32_t el_align,
const Source& source) { const Source& source) const {
auto stride = attr->stride; auto stride = attr->stride;
bool is_valid_stride = bool is_valid_stride =
(stride >= el_size) && (stride >= el_align) && (stride % el_align == 0); (stride >= el_size) && (stride >= el_align) && (stride % el_align == 0);
@ -2032,7 +2042,7 @@ bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr,
return true; return true;
} }
bool Resolver::ValidateAlias(const ast::Alias* alias) { bool Resolver::ValidateAlias(const ast::Alias* alias) const {
auto name = builder_->Symbols().NameFor(alias->name); auto name = builder_->Symbols().NameFor(alias->name);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
AddError("'" + name + "' is a builtin and cannot be redeclared as an alias", AddError("'" + name + "' is a builtin and cannot be redeclared as an alias",
@ -2044,7 +2054,7 @@ bool Resolver::ValidateAlias(const ast::Alias* alias) {
} }
bool Resolver::ValidateStructure(const sem::Struct* str, bool Resolver::ValidateStructure(const sem::Struct* str,
ast::PipelineStage stage) { ast::PipelineStage stage) const {
auto name = builder_->Symbols().NameFor(str->Declaration()->name); auto name = builder_->Symbols().NameFor(str->Declaration()->name);
if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) {
AddError("'" + name + "' is a builtin and cannot be redeclared as a struct", AddError("'" + name + "' is a builtin and cannot be redeclared as a struct",
@ -2153,7 +2163,7 @@ bool Resolver::ValidateLocationAttribute(
std::unordered_set<uint32_t>& locations, std::unordered_set<uint32_t>& locations,
ast::PipelineStage stage, ast::PipelineStage stage,
const Source& source, const Source& source,
const bool is_input) { const bool is_input) const {
std::string inputs_or_output = is_input ? "inputs" : "output"; std::string inputs_or_output = is_input ? "inputs" : "output";
if (stage == ast::PipelineStage::kCompute) { if (stage == ast::PipelineStage::kCompute) {
AddError("attribute is not valid for compute shader " + inputs_or_output, AddError("attribute is not valid for compute shader " + inputs_or_output,
@ -2185,7 +2195,7 @@ bool Resolver::ValidateLocationAttribute(
bool Resolver::ValidateReturn(const ast::ReturnStatement* ret, bool Resolver::ValidateReturn(const ast::ReturnStatement* ret,
const sem::Type* func_type, const sem::Type* func_type,
const sem::Type* ret_type) { const sem::Type* ret_type) const {
if (func_type->UnwrapRef() != ret_type) { if (func_type->UnwrapRef() != ret_type) {
AddError( AddError(
"return statement type must match its function " "return statement type must match its function "
@ -2267,7 +2277,7 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) {
} }
bool Resolver::ValidateAssignment(const ast::Statement* a, bool Resolver::ValidateAssignment(const ast::Statement* a,
const sem::Type* rhs_ty) { const sem::Type* rhs_ty) const {
const ast::Expression* lhs; const ast::Expression* lhs;
const ast::Expression* rhs; const ast::Expression* rhs;
if (auto* assign = a->As<ast::AssignmentStatement>()) { if (auto* assign = a->As<ast::AssignmentStatement>()) {
@ -2349,7 +2359,7 @@ bool Resolver::ValidateAssignment(const ast::Statement* a,
} }
bool Resolver::ValidateIncrementDecrementStatement( bool Resolver::ValidateIncrementDecrementStatement(
const ast::IncrementDecrementStatement* inc) { const ast::IncrementDecrementStatement* inc) const {
const ast::Expression* lhs = inc->lhs; const ast::Expression* lhs = inc->lhs;
// https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement // https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement
@ -2397,7 +2407,7 @@ bool Resolver::ValidateIncrementDecrementStatement(
} }
bool Resolver::ValidateNoDuplicateAttributes( bool Resolver::ValidateNoDuplicateAttributes(
const ast::AttributeList& attributes) { const ast::AttributeList& attributes) const {
std::unordered_map<const TypeInfo*, Source> seen; std::unordered_map<const TypeInfo*, Source> seen;
for (auto* d : attributes) { for (auto* d : attributes) {
auto res = seen.emplace(&d->TypeInfo(), d->source); auto res = seen.emplace(&d->TypeInfo(), d->source);
@ -2427,4 +2437,10 @@ bool Resolver::IsValidationEnabled(const ast::AttributeList& attributes,
return !IsValidationDisabled(attributes, validation); return !IsValidationDisabled(attributes, validation);
} }
std::string Resolver::VectorPretty(uint32_t size,
const sem::Type* element_type) const {
sem::Vector vec_type(element_type, size);
return vec_type.FriendlyName(builder_->Symbols());
}
} // namespace tint::resolver } // namespace tint::resolver