diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 5084d221e0..56e59009a6 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -2143,19 +2143,19 @@ sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) { return result; } -sem::Type* Resolver::TypeOf(const ast::Expression* expr) { - auto* sem = Sem(expr); - return sem ? const_cast(sem->Type()) : nullptr; -} - -std::string Resolver::TypeNameOf(const sem::Type* ty) { +std::string Resolver::TypeNameOf(const sem::Type* ty) const { return RawTypeNameOf(ty->UnwrapRef()); } -std::string Resolver::RawTypeNameOf(const sem::Type* ty) { +std::string Resolver::RawTypeNameOf(const sem::Type* ty) const { return ty->FriendlyName(builder_->Symbols()); } +sem::Type* Resolver::TypeOf(const ast::Expression* expr) const { + auto* sem = Sem(expr); + return sem ? const_cast(sem->Type()) : nullptr; +} + sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) { return Switch( lit, @@ -2782,12 +2782,6 @@ SEM* Resolver::StatementScope(const ast::Statement* ast, return sem; } -std::string Resolver::VectorPretty(uint32_t size, - const sem::Type* element_type) { - sem::Vector vec_type(element_type, size); - return vec_type.FriendlyName(builder_->Symbols()); -} - bool Resolver::Mark(const ast::Node* node) { if (node == nullptr) { TINT_ICE(Resolver, diagnostics_) << "Resolver::Mark() called with nullptr"; diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 1e7711fc5b..7531585880 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -154,8 +154,6 @@ class Resolver { /// @returns true on success, false on error bool ResolveInternal(); - bool ValidatePipelineStages(); - /// Creates the nodes and adds them to the sem::Info mappings of the /// ProgramBuilder. void CreateSemanticNodes() const; @@ -240,78 +238,84 @@ class Resolver { // AST and Type validation methods // Each return true on success, false on failure. - bool ValidateAlias(const ast::Alias*); - bool ValidateArray(const sem::Array* arr, const Source& source); + bool ValidatePipelineStages() const; + bool ValidateAlias(const ast::Alias*) const; + bool ValidateArray(const sem::Array* arr, const Source& source) const; bool ValidateArrayStrideAttribute(const ast::StrideAttribute* attr, uint32_t el_size, uint32_t el_align, - const Source& source); - bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s); - bool ValidateAtomicVariable(const sem::Variable* var); - bool ValidateAssignment(const ast::Statement* a, const sem::Type* rhs_ty); - bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to); - bool ValidateBreakStatement(const sem::Statement* stmt); + const Source& source) const; + bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) const; + bool ValidateAtomicVariable(const sem::Variable* var) const; + bool ValidateAssignment(const ast::Statement* a, + const sem::Type* rhs_ty) const; + bool ValidateBitcast(const ast::BitcastExpression* cast, + const sem::Type* to) const; + bool ValidateBreakStatement(const sem::Statement* stmt) const; bool ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr, const sem::Type* storage_type, ast::PipelineStage stage, - const bool is_input); - bool ValidateContinueStatement(const sem::Statement* stmt); - bool ValidateDiscardStatement(const sem::Statement* stmt); - bool ValidateElseStatement(const sem::ElseStatement* stmt); - bool ValidateEntryPoint(const sem::Function* func, ast::PipelineStage stage); - bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt); - bool ValidateFallthroughStatement(const sem::Statement* stmt); - bool ValidateFunction(const sem::Function* func, ast::PipelineStage stage); - bool ValidateFunctionCall(const sem::Call* call); - bool ValidateGlobalVariable(const sem::Variable* var); - bool ValidateIfStatement(const sem::IfStatement* stmt); + const bool is_input) const; + bool ValidateContinueStatement(const sem::Statement* stmt) const; + bool ValidateDiscardStatement(const sem::Statement* stmt) const; + bool ValidateElseStatement(const sem::ElseStatement* stmt) const; + bool ValidateEntryPoint(const sem::Function* func, + ast::PipelineStage stage) const; + bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt) const; + bool ValidateFallthroughStatement(const sem::Statement* stmt) const; + bool ValidateFunction(const sem::Function* func, + ast::PipelineStage stage) const; + bool ValidateFunctionCall(const sem::Call* call) const; + bool ValidateGlobalVariable(const sem::Variable* var) const; + bool ValidateIfStatement(const sem::IfStatement* stmt) const; bool ValidateIncrementDecrementStatement( - const ast::IncrementDecrementStatement* stmt); + const ast::IncrementDecrementStatement* stmt) const; bool ValidateInterpolateAttribute(const ast::InterpolateAttribute* attr, - const sem::Type* storage_type); - bool ValidateBuiltinCall(const sem::Call* call); + const sem::Type* storage_type) const; + bool ValidateBuiltinCall(const sem::Call* call) const; bool ValidateLocationAttribute(const ast::LocationAttribute* location, const sem::Type* type, std::unordered_set& locations, ast::PipelineStage stage, const Source& source, - const bool is_input = false); - bool ValidateLoopStatement(const sem::LoopStatement* stmt); - bool ValidateMatrix(const sem::Matrix* ty, const Source& source); + const bool is_input = false) const; + bool ValidateLoopStatement(const sem::LoopStatement* stmt) const; + bool ValidateMatrix(const sem::Matrix* ty, const Source& source) const; bool ValidateFunctionParameter(const ast::Function* func, - const sem::Variable* var); - bool ValidateParameter(const ast::Function* func, const sem::Variable* var); + const sem::Variable* var) const; bool ValidateReturn(const ast::ReturnStatement* ret, const sem::Type* func_type, - const sem::Type* ret_type); - bool ValidateStatements(const ast::StatementList& stmts); - bool ValidateStorageTexture(const ast::StorageTexture* t); - bool ValidateStructure(const sem::Struct* str, ast::PipelineStage stage); + const sem::Type* ret_type) const; + bool ValidateStatements(const ast::StatementList& stmts) const; + bool ValidateStorageTexture(const ast::StorageTexture* t) const; + bool ValidateStructure(const sem::Struct* str, + ast::PipelineStage stage) const; bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor, - const sem::Struct* struct_type); + const sem::Struct* struct_type) const; bool ValidateSwitch(const ast::SwitchStatement* s); - bool ValidateVariable(const sem::Variable* var); + bool ValidateVariable(const sem::Variable* var) const; bool ValidateVariableConstructorOrCast(const ast::Variable* var, ast::StorageClass storage_class, const sem::Type* storage_type, - const sem::Type* rhs_type); - bool ValidateVector(const sem::Vector* ty, const Source& source); + const sem::Type* rhs_type) const; + bool ValidateVector(const sem::Vector* ty, const Source& source) const; bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor, - const sem::Vector* vec_type); + const sem::Vector* vec_type) const; bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor, - const sem::Matrix* matrix_type); + const sem::Matrix* matrix_type) const; bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor, - const sem::Type* type); + const sem::Type* type) const; bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor, - const sem::Array* arr_type); - bool ValidateTextureBuiltinFunction(const sem::Call* call); - bool ValidateNoDuplicateAttributes(const ast::AttributeList& attributes); + const sem::Array* arr_type) const; + bool ValidateTextureBuiltinFunction(const sem::Call* call) const; + bool ValidateNoDuplicateAttributes( + const ast::AttributeList& attributes) const; bool ValidateStorageClassLayout(const sem::Type* type, ast::StorageClass sc, Source source, - ValidTypeStorageLayouts& layout); + ValidTypeStorageLayouts& layouts) const; bool ValidateStorageClassLayout(const sem::Variable* var, - ValidTypeStorageLayouts& layout); + ValidTypeStorageLayouts& layouts) const; /// @returns true if the attribute list contains a /// ast::DisableValidationAttribute with the validation mode equal to @@ -325,6 +329,13 @@ class Resolver { bool IsValidationEnabled(const ast::AttributeList& attributes, ast::DisabledValidation validation) const; + /// Returns a human-readable string representation of the vector type name + /// with the given parameters. + /// @param size the vector dimension + /// @param element_type scalar vector sub-element type + /// @return pretty string representation + std::string VectorPretty(uint32_t size, const sem::Type* element_type) const; + /// Resolves the WorkgroupSize for the given function, assigning it to /// current_function_ bool WorkgroupSize(const ast::Function*); @@ -397,15 +408,15 @@ class Resolver { /// @returns the resolved type of the ast::Expression `expr` /// @param expr the expression - sem::Type* TypeOf(const ast::Expression* expr); + sem::Type* TypeOf(const ast::Expression* expr) const; /// @returns the type name of the given semantic type, unwrapping /// references. - std::string TypeNameOf(const sem::Type* ty); + std::string TypeNameOf(const sem::Type* ty) const; /// @returns the type name of the given semantic type, without unwrapping /// references. - std::string RawTypeNameOf(const sem::Type* ty); + std::string RawTypeNameOf(const sem::Type* ty) const; /// @returns the semantic type of the AST literal `lit` /// @param lit the literal @@ -425,13 +436,6 @@ class Resolver { template SEM* StatementScope(const ast::Statement* ast, SEM* sem, F&& callback); - /// Returns a human-readable string representation of the vector type name - /// with the given parameters. - /// @param size the vector dimension - /// @param element_type scalar vector sub-element type - /// @return pretty string representation - std::string VectorPretty(uint32_t size, const sem::Type* element_type); - /// Mark records that the given AST node has been visited, and asserts that /// the given node has not already been seen. Diamonds in the AST are /// illegal. @@ -466,7 +470,7 @@ class Resolver { /// Sem is a helper for obtaining the semantic node for the given AST node. template - auto* Sem(const AST_OR_TYPE* ast) { + auto* Sem(const AST_OR_TYPE* ast) const { using T = sem::Info::GetResultType; auto* sem = builder_->Sem().Get(ast); if (!sem) { @@ -495,7 +499,7 @@ class Resolver { /// @returns the resolved symbol (function, type or variable) for the given /// ast::Identifier or ast::TypeName cast to the given semantic type. template - SEM* ResolvedSymbol(const ast::Node* node) { + SEM* ResolvedSymbol(const ast::Node* node) const { auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node); return resolved ? const_cast(builder_->Sem().Get(resolved)) : nullptr; diff --git a/src/tint/resolver/resolver_validation.cc b/src/tint/resolver/resolver_validation.cc index 161204903c..04b9925f0a 100644 --- a/src/tint/resolver/resolver_validation.cc +++ b/src/tint/resolver/resolver_validation.cc @@ -149,7 +149,8 @@ void TraverseCallChain(diag::List& diagnostics, } // namespace -bool Resolver::ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) { +bool Resolver::ValidateAtomic(const ast::Atomic* a, + const sem::Atomic* s) const { // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types // T must be either u32 or i32. if (!s->Type()->IsAnyOf()) { @@ -160,7 +161,7 @@ bool Resolver::ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) { return true; } -bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) { +bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) const { switch (t->access) { case ast::Access::kWrite: break; @@ -193,7 +194,7 @@ bool Resolver::ValidateVariableConstructorOrCast( const ast::Variable* var, ast::StorageClass storage_class, const sem::Type* storage_ty, - const sem::Type* rhs_ty) { + const sem::Type* rhs_ty) const { auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS // Value type has to match storage type @@ -228,10 +229,11 @@ bool Resolver::ValidateVariableConstructorOrCast( return true; } -bool Resolver::ValidateStorageClassLayout(const sem::Type* store_ty, - ast::StorageClass sc, - Source source, - ValidTypeStorageLayouts& layouts) { +bool Resolver::ValidateStorageClassLayout( + const sem::Type* store_ty, + ast::StorageClass sc, + Source source, + ValidTypeStorageLayouts& layouts) const { // https://gpuweb.github.io/gpuweb/wgsl/#storage-class-layout-constraints auto is_uniform_struct_or_array = [sc](const sem::Type* ty) { @@ -257,7 +259,7 @@ bool Resolver::ValidateStorageClassLayout(const sem::Type* store_ty, }; // Cache result of type + storage class pair. - if (!valid_type_storage_layouts_.emplace(store_ty, sc).second) { + if (!layouts.emplace(store_ty, sc).second) { return true; } @@ -382,8 +384,9 @@ bool Resolver::ValidateStorageClassLayout(const sem::Type* store_ty, return true; } -bool Resolver::ValidateStorageClassLayout(const sem::Variable* var, - ValidTypeStorageLayouts& layouts) { +bool Resolver::ValidateStorageClassLayout( + const sem::Variable* var, + ValidTypeStorageLayouts& layouts) const { if (auto* str = var->Type()->UnwrapRef()->As()) { if (!ValidateStorageClassLayout(str, var->StorageClass(), str->Declaration()->source, layouts)) { @@ -404,7 +407,7 @@ bool Resolver::ValidateStorageClassLayout(const sem::Variable* var, return true; } -bool Resolver::ValidateGlobalVariable(const sem::Variable* var) { +bool Resolver::ValidateGlobalVariable(const sem::Variable* var) const { auto* decl = var->Declaration(); if (!ValidateNoDuplicateAttributes(decl->attributes)) { return false; @@ -510,7 +513,7 @@ bool Resolver::ValidateGlobalVariable(const sem::Variable* var) { // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types // Atomic types may only be instantiated by variables in the workgroup storage // class or by storage buffer variables with a read_write access mode. -bool Resolver::ValidateAtomicVariable(const sem::Variable* var) { +bool Resolver::ValidateAtomicVariable(const sem::Variable* var) const { auto sc = var->StorageClass(); auto* decl = var->Declaration(); auto access = var->Access(); @@ -554,7 +557,7 @@ bool Resolver::ValidateAtomicVariable(const sem::Variable* var) { return true; } -bool Resolver::ValidateVariable(const sem::Variable* var) { +bool Resolver::ValidateVariable(const sem::Variable* var) const { auto* decl = var->Declaration(); auto* storage_ty = var->Type()->UnwrapRef(); @@ -630,7 +633,7 @@ bool Resolver::ValidateVariable(const sem::Variable* var) { } bool Resolver::ValidateFunctionParameter(const ast::Function* func, - const sem::Variable* var) { + const sem::Variable* var) const { if (!ValidateVariable(var)) { return false; } @@ -695,7 +698,7 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func, bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr, const sem::Type* storage_ty, ast::PipelineStage stage, - const bool is_input) { + const bool is_input) const { auto* type = storage_ty->UnwrapRef(); std::stringstream stage_name; stage_name << stage; @@ -813,7 +816,7 @@ bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr, bool Resolver::ValidateInterpolateAttribute( const ast::InterpolateAttribute* attr, - const sem::Type* storage_ty) { + const sem::Type* storage_ty) const { auto* type = storage_ty->UnwrapRef(); if (type->is_integer_scalar_or_vector() && @@ -835,7 +838,7 @@ bool Resolver::ValidateInterpolateAttribute( } bool Resolver::ValidateFunction(const sem::Function* func, - ast::PipelineStage stage) { + ast::PipelineStage stage) const { auto* decl = func->Declaration(); auto name = builder_->Symbols().NameFor(decl->symbol); @@ -941,7 +944,7 @@ bool Resolver::ValidateFunction(const sem::Function* func, } bool Resolver::ValidateEntryPoint(const sem::Function* func, - ast::PipelineStage stage) { + ast::PipelineStage stage) const { auto* decl = func->Declaration(); // Use a lambda to validate the entry point attributes for a type. @@ -1217,7 +1220,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func, return true; } -bool Resolver::ValidateStatements(const ast::StatementList& stmts) { +bool Resolver::ValidateStatements(const ast::StatementList& stmts) const { for (auto* stmt : stmts) { if (!Sem(stmt)->IsReachable()) { /// TODO(https://github.com/gpuweb/gpuweb/issues/2378): This may need to @@ -1230,7 +1233,7 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) { } bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast, - const sem::Type* to) { + const sem::Type* to) const { auto* from = TypeOf(cast->expr)->UnwrapRef(); if (!from->is_numeric_scalar_or_vector()) { AddError("'" + TypeNameOf(from) + "' cannot be bitcast", @@ -1259,7 +1262,7 @@ bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast, return true; } -bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) { +bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) const { if (!stmt->FindFirstParent()) { AddError("break statement must be in a loop or switch case", stmt->Declaration()->source); @@ -1326,7 +1329,7 @@ bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) { return true; } -bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) { +bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) const { if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) { AddError("continuing blocks must not contain a continue statement", stmt->Declaration()->source); @@ -1346,7 +1349,7 @@ bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) { return true; } -bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) { +bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) const { if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) { AddError("continuing blocks must not contain a discard statement", stmt->Declaration()->source); @@ -1359,7 +1362,7 @@ bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) { return true; } -bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) { +bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) const { if (auto* block = As(stmt->Parent())) { if (auto* c = As(block->Parent())) { if (block->Declaration()->Last() == stmt->Declaration()) { @@ -1382,7 +1385,7 @@ bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) { return false; } -bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) { +bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) const { if (auto* cond = stmt->Condition()) { auto* cond_ty = cond->Type()->UnwrapRef(); if (!cond_ty->Is()) { @@ -1395,7 +1398,7 @@ bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) { return true; } -bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) { +bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) const { if (stmt->Behaviors().Empty()) { AddError("loop does not exit", stmt->Declaration()->source.Begin()); return false; @@ -1403,7 +1406,8 @@ bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) { return true; } -bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) { +bool Resolver::ValidateForLoopStatement( + const sem::ForLoopStatement* stmt) const { if (stmt->Behaviors().Empty()) { AddError("for-loop does not exit", stmt->Declaration()->source.Begin()); return false; @@ -1419,7 +1423,7 @@ bool Resolver::ValidateForLoopStatement(const sem::ForLoopStatement* stmt) { return true; } -bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) { +bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) const { auto* cond_ty = stmt->Condition()->Type()->UnwrapRef(); if (!cond_ty->Is()) { AddError("if statement condition must be bool, got " + TypeNameOf(cond_ty), @@ -1429,7 +1433,7 @@ bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) { return true; } -bool Resolver::ValidateBuiltinCall(const sem::Call* call) { +bool Resolver::ValidateBuiltinCall(const sem::Call* call) const { if (call->Type()->Is()) { bool is_call_statement = false; if (auto* call_stmt = As(call->Stmt()->Declaration())) { @@ -1452,7 +1456,7 @@ bool Resolver::ValidateBuiltinCall(const sem::Call* call) { return true; } -bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) { +bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) const { auto* builtin = call->Target()->As(); if (!builtin) { return false; @@ -1524,7 +1528,7 @@ bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) { check_arg_is_constexpr(sem::ParameterUsage::kComponent, 0, 3); } -bool Resolver::ValidateFunctionCall(const sem::Call* call) { +bool Resolver::ValidateFunctionCall(const sem::Call* call) const { auto* decl = call->Declaration(); auto* target = call->Target()->As(); auto sym = decl->target.name->symbol; @@ -1642,7 +1646,7 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) { bool Resolver::ValidateStructureConstructorOrCast( const ast::CallExpression* ctor, - const sem::Struct* struct_type) { + const sem::Struct* struct_type) const { if (!struct_type->IsConstructible()) { AddError("struct constructor has non-constructible type", ctor->source); return false; @@ -1675,8 +1679,9 @@ bool Resolver::ValidateStructureConstructorOrCast( return true; } -bool Resolver::ValidateArrayConstructorOrCast(const ast::CallExpression* ctor, - const sem::Array* array_type) { +bool Resolver::ValidateArrayConstructorOrCast( + const ast::CallExpression* ctor, + const sem::Array* array_type) const { auto& values = ctor->args; auto* elem_ty = array_type->ElemType(); for (auto* value : values) { @@ -1715,8 +1720,9 @@ bool Resolver::ValidateArrayConstructorOrCast(const ast::CallExpression* ctor, return true; } -bool Resolver::ValidateVectorConstructorOrCast(const ast::CallExpression* ctor, - const sem::Vector* vec_type) { +bool Resolver::ValidateVectorConstructorOrCast( + const ast::CallExpression* ctor, + const sem::Vector* vec_type) const { auto& values = ctor->args; auto* elem_ty = vec_type->type(); size_t value_cardinality_sum = 0; @@ -1776,7 +1782,8 @@ bool Resolver::ValidateVectorConstructorOrCast(const ast::CallExpression* ctor, return true; } -bool Resolver::ValidateVector(const sem::Vector* ty, const Source& source) { +bool Resolver::ValidateVector(const sem::Vector* ty, + const Source& source) const { if (!ty->type()->is_scalar()) { AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'", source); @@ -1785,7 +1792,8 @@ bool Resolver::ValidateVector(const sem::Vector* ty, const Source& source) { return true; } -bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) { +bool Resolver::ValidateMatrix(const sem::Matrix* ty, + const Source& source) const { if (!ty->is_float_matrix()) { AddError("matrix element type must be 'f32'", source); return false; @@ -1793,8 +1801,9 @@ bool Resolver::ValidateMatrix(const sem::Matrix* ty, const Source& source) { return true; } -bool Resolver::ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor, - const sem::Matrix* matrix_ty) { +bool Resolver::ValidateMatrixConstructorOrCast( + const ast::CallExpression* ctor, + const sem::Matrix* matrix_ty) const { auto& values = ctor->args; // Zero Value expression if (values.empty()) { @@ -1869,7 +1878,7 @@ bool Resolver::ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor, } bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor, - const sem::Type* ty) { + const sem::Type* ty) const { if (ctor->args.size() == 0) { return true; } @@ -1904,7 +1913,7 @@ bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor, return true; } -bool Resolver::ValidatePipelineStages() { +bool Resolver::ValidatePipelineStages() const { auto check_workgroup_storage = [&](const sem::Function* func, const sem::Function* entry_point) { auto stage = entry_point->Declaration()->PipelineStage(); @@ -1999,7 +2008,8 @@ bool Resolver::ValidatePipelineStages() { return true; } -bool Resolver::ValidateArray(const sem::Array* arr, const Source& source) { +bool Resolver::ValidateArray(const sem::Array* arr, + const Source& source) const { auto* el_ty = arr->ElemType(); if (!IsFixedFootprint(el_ty)) { @@ -2013,7 +2023,7 @@ bool Resolver::ValidateArray(const sem::Array* arr, const Source& source) { bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr, uint32_t el_size, uint32_t el_align, - const Source& source) { + const Source& source) const { auto stride = attr->stride; bool is_valid_stride = (stride >= el_size) && (stride >= el_align) && (stride % el_align == 0); @@ -2032,7 +2042,7 @@ bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr, return true; } -bool Resolver::ValidateAlias(const ast::Alias* alias) { +bool Resolver::ValidateAlias(const ast::Alias* alias) const { auto name = builder_->Symbols().NameFor(alias->name); if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { AddError("'" + name + "' is a builtin and cannot be redeclared as an alias", @@ -2044,7 +2054,7 @@ bool Resolver::ValidateAlias(const ast::Alias* alias) { } bool Resolver::ValidateStructure(const sem::Struct* str, - ast::PipelineStage stage) { + ast::PipelineStage stage) const { auto name = builder_->Symbols().NameFor(str->Declaration()->name); if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { AddError("'" + name + "' is a builtin and cannot be redeclared as a struct", @@ -2153,7 +2163,7 @@ bool Resolver::ValidateLocationAttribute( std::unordered_set& locations, ast::PipelineStage stage, const Source& source, - const bool is_input) { + const bool is_input) const { std::string inputs_or_output = is_input ? "inputs" : "output"; if (stage == ast::PipelineStage::kCompute) { AddError("attribute is not valid for compute shader " + inputs_or_output, @@ -2185,7 +2195,7 @@ bool Resolver::ValidateLocationAttribute( bool Resolver::ValidateReturn(const ast::ReturnStatement* ret, const sem::Type* func_type, - const sem::Type* ret_type) { + const sem::Type* ret_type) const { if (func_type->UnwrapRef() != ret_type) { AddError( "return statement type must match its function " @@ -2267,7 +2277,7 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) { } bool Resolver::ValidateAssignment(const ast::Statement* a, - const sem::Type* rhs_ty) { + const sem::Type* rhs_ty) const { const ast::Expression* lhs; const ast::Expression* rhs; if (auto* assign = a->As()) { @@ -2349,7 +2359,7 @@ bool Resolver::ValidateAssignment(const ast::Statement* a, } bool Resolver::ValidateIncrementDecrementStatement( - const ast::IncrementDecrementStatement* inc) { + const ast::IncrementDecrementStatement* inc) const { const ast::Expression* lhs = inc->lhs; // https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement @@ -2397,7 +2407,7 @@ bool Resolver::ValidateIncrementDecrementStatement( } bool Resolver::ValidateNoDuplicateAttributes( - const ast::AttributeList& attributes) { + const ast::AttributeList& attributes) const { std::unordered_map seen; for (auto* d : attributes) { auto res = seen.emplace(&d->TypeInfo(), d->source); @@ -2427,4 +2437,10 @@ bool Resolver::IsValidationEnabled(const ast::AttributeList& attributes, return !IsValidationDisabled(attributes, validation); } +std::string Resolver::VectorPretty(uint32_t size, + const sem::Type* element_type) const { + sem::Vector vec_type(element_type, size); + return vec_type.FriendlyName(builder_->Symbols()); +} + } // namespace tint::resolver