From f05575bb21584fda1e8e174470348c8d1679cf80 Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Thu, 21 Apr 2022 13:40:16 +0000 Subject: [PATCH] [tint] Move validation code into a Validator class. This CL moves the Validate methods from the Resolver into a specific Validator class used by the Resolver. Bug: tint:1313 Change-Id: Ida21a0cc65f2679739c8499de7065ff8b58c4efc Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/87150 Reviewed-by: Ben Clayton Kokoro: Kokoro Commit-Queue: Dan Sinclair --- src/tint/BUILD.gn | 3 +- src/tint/CMakeLists.txt | 3 +- src/tint/resolver/resolver.cc | 207 +++----- src/tint/resolver/resolver.h | 139 +----- .../resolver/resolver_is_storeable_test.cc | 79 +++ src/tint/resolver/sem_helper.cc | 3 +- src/tint/resolver/sem_helper.h | 16 +- .../{resolver_validation.cc => validator.cc} | 449 ++++++++++------- src/tint/resolver/validator.h | 457 ++++++++++++++++++ .../resolver/validator_is_storeable_test.cc | 86 ++++ src/tint/resolver/validator_test_helper.cc | 27 ++ src/tint/resolver/validator_test_helper.h | 46 ++ 12 files changed, 1066 insertions(+), 449 deletions(-) create mode 100644 src/tint/resolver/resolver_is_storeable_test.cc rename src/tint/resolver/{resolver_validation.cc => validator.cc} (85%) create mode 100644 src/tint/resolver/validator.h create mode 100644 src/tint/resolver/validator_is_storeable_test.cc create mode 100644 src/tint/resolver/validator_test_helper.cc create mode 100644 src/tint/resolver/validator_test_helper.h diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 980b16d0ea..d7b2be3766 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -376,9 +376,10 @@ libtint_source_set("libtint_core_all_src") { "resolver/resolver.cc", "resolver/resolver.h", "resolver/resolver_constants.cc", - "resolver/resolver_validation.cc", "resolver/sem_helper.cc", "resolver/sem_helper.h", + "resolver/validator.cc", + "resolver/validator.h", "scope_stack.h", "sem/array.h", "sem/atomic_type.h", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index ae3f2240ff..e2bb332e4c 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -256,10 +256,11 @@ set(TINT_LIB_SRCS resolver/dependency_graph.h resolver/resolver.cc resolver/resolver_constants.cc - resolver/resolver_validation.cc resolver/resolver.h resolver/sem_helper.cc resolver/sem_helper.h + resolver/validator.cc + resolver/validator.h scope_stack.h sem/array.cc sem/array.h diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 4f7c08d2df..5069e5cf5e 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -85,7 +85,8 @@ Resolver::Resolver(ProgramBuilder* builder) : builder_(builder), diagnostics_(builder->Diagnostics()), builtin_table_(BuiltinTable::Create(*builder)), - sem_(builder) {} + sem_(builder, dependencies_), + validator_(builder, sem_) {} Resolver::~Resolver() = default; @@ -138,7 +139,7 @@ bool Resolver::ResolveInternal() { SetShadows(); - if (!ValidatePipelineStages()) { + if (!validator_.PipelineStages(entry_points_)) { return false; } @@ -172,7 +173,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) { } if (auto* el = Type(t->type)) { if (auto* vector = builder_->create(el, t->width)) { - if (ValidateVector(vector, t->source)) { + if (validator_.Vector(vector, t->source)) { return vector; } } @@ -188,7 +189,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) { if (auto* column_type = builder_->create(el, t->rows)) { if (auto* matrix = builder_->create(column_type, t->columns)) { - if (ValidateMatrix(matrix, t->source)) { + if (validator_.Matrix(matrix, t->source)) { return matrix; } } @@ -200,7 +201,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) { [&](const ast::Atomic* t) -> sem::Atomic* { if (auto* el = Type(t->type)) { auto* a = builder_->create(el); - if (!ValidateAtomic(t, a)) { + if (!validator_.Atomic(t, a)) { return nullptr; } return a; @@ -240,7 +241,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) { }, [&](const ast::StorageTexture* t) -> sem::StorageTexture* { if (auto* el = Type(t->type)) { - if (!ValidateStorageTexture(t)) { + if (!validator_.StorageTexture(t)) { return nullptr; } return builder_->create(t->dim, t->format, @@ -252,7 +253,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) { return builder_->create(); }, [&](Default) { - auto* resolved = ResolvedSymbol(ty); + auto* resolved = sem_.ResolvedSymbol(ty); return Switch( resolved, // [&](sem::Type* type) { return type; }, @@ -366,8 +367,8 @@ sem::Variable* Resolver::Variable(const ast::Variable* var, if (kind == VariableKind::kLocal && !var->is_const && storage_class != ast::StorageClass::kFunction && - IsValidationEnabled(var->attributes, - ast::DisabledValidation::kIgnoreStorageClass)) { + validator_.IsValidationEnabled( + var->attributes, ast::DisabledValidation::kIgnoreStorageClass)) { AddError("function variable has a non-function storage class", var->source); return nullptr; } @@ -385,8 +386,8 @@ sem::Variable* Resolver::Variable(const ast::Variable* var, builder_->create(storage_ty, storage_class, access); } - if (rhs && !ValidateVariableConstructorOrCast(var, storage_class, storage_ty, - rhs->Type())) { + if (rhs && !validator_.VariableConstructorOrCast(var, storage_class, + storage_ty, rhs->Type())) { return nullptr; } @@ -547,17 +548,17 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* var) { } } - if (!ValidateNoDuplicateAttributes(var->attributes)) { + if (!validator_.NoDuplicateAttributes(var->attributes)) { return nullptr; } - if (!ValidateGlobalVariable(sem)) { + if (!validator_.GlobalVariable(sem, constant_ids_, atomic_composite_info_)) { return nullptr; } // TODO(bclayton): Call this at the end of resolve on all uniform and storage // referenced structs - if (!ValidateStorageClassLayout(sem, valid_type_storage_layouts_)) { + if (!validator_.StorageClassLayout(sem, valid_type_storage_layouts_)) { return nullptr; } @@ -592,7 +593,7 @@ sem::Function* Resolver::Function(const ast::Function* decl) { for (auto* attr : param->attributes) { Mark(attr); } - if (!ValidateNoDuplicateAttributes(param->attributes)) { + if (!validator_.NoDuplicateAttributes(param->attributes)) { return nullptr; } @@ -691,21 +692,21 @@ sem::Function* Resolver::Function(const ast::Function* decl) { for (auto* attr : decl->attributes) { Mark(attr); } - if (!ValidateNoDuplicateAttributes(decl->attributes)) { + if (!validator_.NoDuplicateAttributes(decl->attributes)) { return nullptr; } for (auto* attr : decl->return_type_attributes) { Mark(attr); } - if (!ValidateNoDuplicateAttributes(decl->return_type_attributes)) { + if (!validator_.NoDuplicateAttributes(decl->return_type_attributes)) { return nullptr; } auto stage = current_function_ ? current_function_->Declaration()->PipelineStage() : ast::PipelineStage::kNone; - if (!ValidateFunction(func, stage)) { + if (!validator_.Function(func, stage)) { return nullptr; } @@ -809,7 +810,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { << "could not resolve constant workgroup_size constant value"; continue; } - // Validate and set the default value for this dimension. + // validator_.Validate and set the default value for this dimension. if (is_i32 ? value.Elements()[0].i32 < 1 : value.Elements()[0].u32 < 1) { AddError("workgroup_size argument must be at least 1", values[i]->source); return false; @@ -843,7 +844,7 @@ bool Resolver::Statements(const ast::StatementList& stmts) { current_statement_->Behaviors() = behaviors; - if (!ValidateStatements(stmts)) { + if (!validator_.Statements(stmts)) { return false; } @@ -958,7 +959,7 @@ sem::IfStatement* Resolver::IfStatement(const ast::IfStatement* stmt) { sem->Behaviors().Add(sem::Behavior::kNext); } - return ValidateIfStatement(sem); + return validator_.IfStatement(sem); }); } @@ -989,7 +990,7 @@ sem::ElseStatement* Resolver::ElseStatement(const ast::ElseStatement* stmt) { } sem->Behaviors().Add(body->Behaviors()); - return ValidateElseStatement(sem); + return validator_.ElseStatement(sem); }); } @@ -1039,7 +1040,7 @@ sem::LoopStatement* Resolver::LoopStatement(const ast::LoopStatement* stmt) { } behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue); - return ValidateLoopStatement(sem); + return validator_.LoopStatement(sem); }); }); } @@ -1095,7 +1096,7 @@ sem::ForLoopStatement* Resolver::ForLoopStatement( } behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue); - return ValidateForLoopStatement(sem); + return validator_.ForLoopStatement(sem); }); } @@ -1226,7 +1227,7 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) { sem->Behaviors() = inner->Behaviors(); - if (!ValidateBitcast(expr, ty)) { + if (!validator_.Bitcast(expr, ty)) { return nullptr; } @@ -1316,7 +1317,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { Mark(vec); auto* v = builder_->create( arg_el_ty, static_cast(vec->width)); - if (!ValidateVector(v, vec->source)) { + if (!validator_.Vector(v, vec->source)) { return nullptr; } builder_->Sem().Add(vec, v); @@ -1337,7 +1338,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { auto* column_type = builder_->create(arg_el_ty, mat->rows); auto* m = builder_->create(column_type, mat->columns); - if (!ValidateMatrix(m, mat->source)) { + if (!validator_.Matrix(m, mat->source)) { return nullptr; } builder_->Sem().Add(mat, m); @@ -1359,7 +1360,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { auto* ident = expr->target.name; Mark(ident); - auto* resolved = ResolvedSymbol(ident); + auto* resolved = sem_.ResolvedSymbol(ident); return Switch( resolved, // [&](sem::Type* type) { return type_ctor_or_conv(type); }, @@ -1414,7 +1415,7 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, current_function_->AddDirectlyCalledBuiltin(builtin); if (IsTextureBuiltin(builtin_type)) { - if (!ValidateTextureBuiltinFunction(call)) { + if (!validator_.TextureBuiltinFunction(call)) { return nullptr; } // Collect a texture/sampler pair for this builtin. @@ -1436,7 +1437,7 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, } } - if (!ValidateBuiltinCall(call)) { + if (!validator_.BuiltinCall(call)) { return nullptr; } @@ -1500,7 +1501,7 @@ sem::Call* Resolver::FunctionCall( call->Behaviors() = arg_behaviors + target->Behaviors(); - if (!ValidateFunctionCall(call)) { + if (!validator_.FunctionCall(call, current_statement_)) { return nullptr; } @@ -1527,23 +1528,23 @@ sem::Call* Resolver::TypeConversion(const ast::CallExpression* expr, bool ok = Switch( target, [&](const sem::Vector* vec_type) { - return ValidateVectorConstructorOrCast(expr, vec_type); + return validator_.VectorConstructorOrCast(expr, vec_type); }, [&](const sem::Matrix* mat_type) { // Note: Matrix types currently cannot be converted (the element // type must only be f32). We implement this for the day we // support other matrix element types. - return ValidateMatrixConstructorOrCast(expr, mat_type); + return validator_.MatrixConstructorOrCast(expr, mat_type); }, [&](const sem::Array* arr_type) { - return ValidateArrayConstructorOrCast(expr, arr_type); + return validator_.ArrayConstructorOrCast(expr, arr_type); }, [&](const sem::Struct* struct_type) { - return ValidateStructureConstructorOrCast(expr, struct_type); + return validator_.StructureConstructorOrCast(expr, struct_type); }, [&](Default) { if (target->is_scalar()) { - return ValidateScalarConstructorOrCast(expr, target); + return validator_.ScalarConstructorOrCast(expr, target); } AddError("type is not constructible", expr->source); return false; @@ -1593,20 +1594,20 @@ sem::Call* Resolver::TypeConstructor( bool ok = Switch( ty, [&](const sem::Vector* vec_type) { - return ValidateVectorConstructorOrCast(expr, vec_type); + return validator_.VectorConstructorOrCast(expr, vec_type); }, [&](const sem::Matrix* mat_type) { - return ValidateMatrixConstructorOrCast(expr, mat_type); + return validator_.MatrixConstructorOrCast(expr, mat_type); }, [&](const sem::Array* arr_type) { - return ValidateArrayConstructorOrCast(expr, arr_type); + return validator_.ArrayConstructorOrCast(expr, arr_type); }, [&](const sem::Struct* struct_type) { - return ValidateStructureConstructorOrCast(expr, struct_type); + return validator_.StructureConstructorOrCast(expr, struct_type); }, [&](Default) { if (ty->is_scalar()) { - return ValidateScalarConstructorOrCast(expr, ty); + return validator_.ScalarConstructorOrCast(expr, ty); } AddError("type is not constructible", expr->source); return false; @@ -1652,7 +1653,7 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) { sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) { auto symbol = expr->symbol; - auto* resolved = ResolvedSymbol(expr); + auto* resolved = sem_.ResolvedSymbol(expr); if (auto* var = As(resolved)) { auto* user = builder_->create(expr, current_statement_, var); @@ -2156,7 +2157,8 @@ sem::Array* Resolver::Array(const ast::Array* arr) { return nullptr; } - if (!IsPlain(elem_type)) { // Check must come before GetDefaultAlignAndSize() + if (!validator_.IsPlain( + elem_type)) { // Check must come before GetDefaultAlignAndSize() AddError(sem_.TypeNameOf(elem_type) + " cannot be used as an element type of an array", source); @@ -2166,7 +2168,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) { uint32_t el_align = elem_type->Align(); uint32_t el_size = elem_type->Size(); - if (!ValidateNoDuplicateAttributes(arr->attributes)) { + if (!validator_.NoDuplicateAttributes(arr->attributes)) { return nullptr; } @@ -2176,7 +2178,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) { Mark(attr); if (auto* sd = attr->As()) { explicit_stride = sd->stride; - if (!ValidateArrayStrideAttribute(sd, el_size, el_align, source)) { + if (!validator_.ArrayStrideAttribute(sd, el_size, el_align, source)) { return nullptr; } continue; @@ -2210,7 +2212,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) { if (auto* ident = count_expr->As()) { // Make sure the identifier is a non-overridable module-scope constant. - auto* var = ResolvedSymbol(ident); + auto* var = sem_.ResolvedSymbol(ident); if (!var || !var->Declaration()->is_const) { AddError("array size identifier must be a module-scope constant", size_source); @@ -2266,7 +2268,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) { elem_type, count, el_align, static_cast(size), static_cast(stride), static_cast(implicit_stride)); - if (!ValidateArray(out, source)) { + if (!validator_.Array(out, source)) { return nullptr; } @@ -2287,14 +2289,14 @@ sem::Type* Resolver::Alias(const ast::Alias* alias) { if (!ty) { return nullptr; } - if (!ValidateAlias(alias)) { + if (!validator_.Alias(alias)) { return nullptr; } return ty; } sem::Struct* Resolver::Structure(const ast::Struct* str) { - if (!ValidateNoDuplicateAttributes(str->attributes)) { + if (!validator_.NoDuplicateAttributes(str->attributes)) { return nullptr; } for (auto* attr : str->attributes) { @@ -2335,8 +2337,8 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { return nullptr; } - // Validate member type - if (!IsPlain(type)) { + // validator_.Validate member type + if (!validator_.IsPlain(type)) { AddError(sem_.TypeNameOf(type) + " cannot be used as the type of a structure member", member->source); @@ -2347,7 +2349,7 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { uint64_t align = type->Align(); uint64_t size = type->Size(); - if (!ValidateNoDuplicateAttributes(member->attributes)) { + if (!validator_.NoDuplicateAttributes(member->attributes)) { return nullptr; } @@ -2453,7 +2455,7 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { auto stage = current_function_ ? current_function_->Declaration()->PipelineStage() : ast::PipelineStage::kNone; - if (!ValidateStructure(out, stage)) { + if (!validator_.Structure(out, stage)) { return nullptr; } @@ -2479,7 +2481,8 @@ sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) { // is available for validation. auto* ret_type = stmt->value ? sem_.TypeOf(stmt->value)->UnwrapRef() : builder_->create(); - return ValidateReturn(stmt, current_function_->ReturnType(), ret_type); + return validator_.Return(stmt, current_function_->ReturnType(), ret_type, + current_statement_); }); } @@ -2510,7 +2513,7 @@ sem::SwitchStatement* Resolver::SwitchStatement( } behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kFallthrough); - return ValidateSwitch(stmt); + return validator_.SwitchStatement(stmt); }); } @@ -2542,7 +2545,7 @@ sem::Statement* Resolver::VariableDeclStatement( sem->Behaviors() = ctor->Behaviors(); } - return ValidateVariable(var); + return validator_.Variable(var); }); } @@ -2567,7 +2570,7 @@ sem::Statement* Resolver::AssignmentStatement( behaviors.Add(lhs->Behaviors()); } - return ValidateAssignment(stmt, sem_.TypeOf(stmt->rhs)); + return validator_.Assignment(stmt, sem_.TypeOf(stmt->rhs)); }); } @@ -2577,7 +2580,7 @@ sem::Statement* Resolver::BreakStatement(const ast::BreakStatement* stmt) { return StatementScope(stmt, sem, [&] { sem->Behaviors() = sem::Behavior::kBreak; - return ValidateBreakStatement(sem); + return validator_.BreakStatement(sem, current_statement_); }); } @@ -2620,7 +2623,7 @@ sem::Statement* Resolver::CompoundAssignmentStatement( stmt->source); return false; } - return ValidateAssignment(stmt, ty); + return validator_.Assignment(stmt, ty); }); } @@ -2639,7 +2642,7 @@ sem::Statement* Resolver::ContinueStatement( } } - return ValidateContinueStatement(sem); + return validator_.ContinueStatement(sem, current_statement_); }); } @@ -2650,7 +2653,7 @@ sem::Statement* Resolver::DiscardStatement(const ast::DiscardStatement* stmt) { sem->Behaviors() = sem::Behavior::kDiscard; current_function_->SetHasDiscard(); - return ValidateDiscardStatement(sem); + return validator_.DiscardStatement(sem, current_statement_); }); } @@ -2661,7 +2664,7 @@ sem::Statement* Resolver::FallthroughStatement( return StatementScope(stmt, sem, [&] { sem->Behaviors() = sem::Behavior::kFallthrough; - return ValidateFallthroughStatement(sem); + return validator_.FallthroughStatement(sem); }); } @@ -2676,7 +2679,7 @@ sem::Statement* Resolver::IncrementDecrementStatement( } sem->Behaviors() = lhs->Behaviors(); - return ValidateIncrementDecrementStatement(stmt); + return validator_.IncrementDecrementStatement(stmt); }); } @@ -2718,7 +2721,7 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, sc, const_cast(arr->ElemType()), usage); } - if (ast::IsHostShareable(sc) && !IsHostShareable(ty)) { + if (ast::IsHostShareable(sc) && !validator_.IsHostShareable(ty)) { std::stringstream err; err << "Type '" << sem_.TypeNameOf(ty) << "' cannot be used in storage class '" << sc @@ -2782,62 +2785,6 @@ void Resolver::AddNote(const std::string& msg, const Source& source) const { diagnostics_.add_note(diag::System::Resolver, msg, source); } -// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section -bool Resolver::IsPlain(const sem::Type* type) const { - return type->is_scalar() || - type->IsAnyOf(); -} - -// https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types -bool Resolver::IsFixedFootprint(const sem::Type* type) const { - return Switch( - type, // - [&](const sem::Vector*) { return true; }, // - [&](const sem::Matrix*) { return true; }, // - [&](const sem::Atomic*) { return true; }, - [&](const sem::Array* arr) { - return !arr->IsRuntimeSized() && IsFixedFootprint(arr->ElemType()); - }, - [&](const sem::Struct* str) { - for (auto* member : str->Members()) { - if (!IsFixedFootprint(member->Type())) { - return false; - } - } - return true; - }, - [&](Default) { return type->is_scalar(); }); -} - -// https://gpuweb.github.io/gpuweb/wgsl.html#storable-types -bool Resolver::IsStorable(const sem::Type* type) const { - return IsPlain(type) || type->IsAnyOf(); -} - -// https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types -bool Resolver::IsHostShareable(const sem::Type* type) const { - if (type->IsAnyOf()) { - return true; - } - return Switch( - type, // - [&](const sem::Vector* vec) { return IsHostShareable(vec->type()); }, - [&](const sem::Matrix* mat) { return IsHostShareable(mat->type()); }, - [&](const sem::Array* arr) { return IsHostShareable(arr->ElemType()); }, - [&](const sem::Struct* str) { - for (auto* member : str->Members()) { - if (!IsHostShareable(member->Type())) { - return false; - } - } - return true; - }, - [&](const sem::Atomic* atomic) { - return IsHostShareable(atomic->Type()); - }); -} - bool Resolver::IsBuiltin(Symbol symbol) const { std::string name = builder_->Symbols().NameFor(symbol); return sem::ParseBuiltinType(name) != sem::BuiltinType::kNone; @@ -2849,26 +2796,6 @@ bool Resolver::IsCallStatement(const ast::Expression* expr) const { [&](auto* stmt) { return stmt->expr == expr; }); } -const ast::Statement* Resolver::ClosestContinuing(bool stop_at_loop) const { - for (const auto* s = current_statement_; s != nullptr; s = s->Parent()) { - if (stop_at_loop && s->Is()) { - break; - } - if (s->Is()) { - return s->Declaration(); - } - if (auto* f = As(s->Parent())) { - if (f->Declaration()->continuing == s->Declaration()) { - return s->Declaration(); - } - if (stop_at_loop) { - break; - } - } - } - return nullptr; -} - //////////////////////////////////////////////////////////////////////////////// // Resolver::TypeConversionSig //////////////////////////////////////////////////////////////////////////////// diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 95dd5ffc68..6487d35a84 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -16,7 +16,6 @@ #define SRC_TINT_RESOLVER_RESOLVER_H_ #include -#include #include #include #include @@ -27,13 +26,13 @@ #include "src/tint/program_builder.h" #include "src/tint/resolver/dependency_graph.h" #include "src/tint/resolver/sem_helper.h" +#include "src/tint/resolver/validator.h" #include "src/tint/scope_stack.h" #include "src/tint/sem/binding_point.h" #include "src/tint/sem/block_statement.h" #include "src/tint/sem/constant.h" #include "src/tint/sem/function.h" #include "src/tint/sem/struct.h" -#include "src/tint/utils/map.h" #include "src/tint/utils/unique_vector.h" // Forward declarations @@ -89,27 +88,31 @@ class Resolver { /// @param type the given type /// @returns true if the given type is a plain type - bool IsPlain(const sem::Type* type) const; + bool IsPlain(const sem::Type* type) const { return validator_.IsPlain(type); } /// @param type the given type /// @returns true if the given type is a fixed-footprint type - bool IsFixedFootprint(const sem::Type* type) const; + bool IsFixedFootprint(const sem::Type* type) const { + return validator_.IsFixedFootprint(type); + } /// @param type the given type /// @returns true if the given type is storable - bool IsStorable(const sem::Type* type) const; + bool IsStorable(const sem::Type* type) const { + return validator_.IsStorable(type); + } /// @param type the given type /// @returns true if the given type is host-shareable - bool IsHostShareable(const sem::Type* type) const; + bool IsHostShareable(const sem::Type* type) const { + return validator_.IsHostShareable(type); + } private: /// Describes the context in which a variable is declared enum class VariableKind { kParameter, kLocal, kGlobal }; - using ValidTypeStorageLayouts = - std::set>; - ValidTypeStorageLayouts valid_type_storage_layouts_; + Validator::ValidTypeStorageLayouts valid_type_storage_layouts_; /// Structure holding semantic information about a block (i.e. scope), such as /// parent block and variables declared in the block. @@ -237,106 +240,6 @@ class Resolver { const sem::Type* rhs_ty, ast::BinaryOp op); - // AST and Type validation methods - // Each return true on success, false on failure. - bool ValidatePipelineStages() const; - bool ValidateAlias(const ast::Alias*) const; - bool ValidateArray(const sem::Array* arr, const Source& source) const; - bool ValidateArrayStrideAttribute(const ast::StrideAttribute* attr, - uint32_t el_size, - uint32_t el_align, - const Source& source) const; - bool ValidateAtomic(const ast::Atomic* a, const sem::Atomic* s) const; - bool ValidateAtomicVariable(const sem::Variable* var) const; - bool ValidateAssignment(const ast::Statement* a, - const sem::Type* rhs_ty) const; - bool ValidateBitcast(const ast::BitcastExpression* cast, - const sem::Type* to) const; - bool ValidateBreakStatement(const sem::Statement* stmt) const; - bool ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr, - const sem::Type* storage_type, - ast::PipelineStage stage, - const bool is_input) const; - bool ValidateContinueStatement(const sem::Statement* stmt) const; - bool ValidateDiscardStatement(const sem::Statement* stmt) const; - bool ValidateElseStatement(const sem::ElseStatement* stmt) const; - bool ValidateEntryPoint(const sem::Function* func, - ast::PipelineStage stage) const; - bool ValidateForLoopStatement(const sem::ForLoopStatement* stmt) const; - bool ValidateFallthroughStatement(const sem::Statement* stmt) const; - bool ValidateFunction(const sem::Function* func, - ast::PipelineStage stage) const; - bool ValidateFunctionCall(const sem::Call* call) const; - bool ValidateGlobalVariable(const sem::Variable* var) const; - bool ValidateIfStatement(const sem::IfStatement* stmt) const; - bool ValidateIncrementDecrementStatement( - const ast::IncrementDecrementStatement* stmt) const; - bool ValidateInterpolateAttribute(const ast::InterpolateAttribute* attr, - const sem::Type* storage_type) const; - bool ValidateBuiltinCall(const sem::Call* call) const; - bool ValidateLocationAttribute(const ast::LocationAttribute* location, - const sem::Type* type, - std::unordered_set& locations, - ast::PipelineStage stage, - const Source& source, - const bool is_input = false) const; - bool ValidateLoopStatement(const sem::LoopStatement* stmt) const; - bool ValidateMatrix(const sem::Matrix* ty, const Source& source) const; - bool ValidateFunctionParameter(const ast::Function* func, - const sem::Variable* var) const; - bool ValidateReturn(const ast::ReturnStatement* ret, - const sem::Type* func_type, - const sem::Type* ret_type) const; - bool ValidateStatements(const ast::StatementList& stmts) const; - bool ValidateStorageTexture(const ast::StorageTexture* t) const; - bool ValidateStructure(const sem::Struct* str, - ast::PipelineStage stage) const; - bool ValidateStructureConstructorOrCast(const ast::CallExpression* ctor, - const sem::Struct* struct_type) const; - bool ValidateSwitch(const ast::SwitchStatement* s); - bool ValidateVariable(const sem::Variable* var) const; - bool ValidateVariableConstructorOrCast(const ast::Variable* var, - ast::StorageClass storage_class, - const sem::Type* storage_type, - const sem::Type* rhs_type) const; - bool ValidateVector(const sem::Vector* ty, const Source& source) const; - bool ValidateVectorConstructorOrCast(const ast::CallExpression* ctor, - const sem::Vector* vec_type) const; - bool ValidateMatrixConstructorOrCast(const ast::CallExpression* ctor, - const sem::Matrix* matrix_type) const; - bool ValidateScalarConstructorOrCast(const ast::CallExpression* ctor, - const sem::Type* type) const; - bool ValidateArrayConstructorOrCast(const ast::CallExpression* ctor, - const sem::Array* arr_type) const; - bool ValidateTextureBuiltinFunction(const sem::Call* call) const; - bool ValidateNoDuplicateAttributes( - const ast::AttributeList& attributes) const; - bool ValidateStorageClassLayout(const sem::Type* type, - ast::StorageClass sc, - Source source, - ValidTypeStorageLayouts& layouts) const; - bool ValidateStorageClassLayout(const sem::Variable* var, - ValidTypeStorageLayouts& layouts) const; - - /// @returns true if the attribute list contains a - /// ast::DisableValidationAttribute with the validation mode equal to - /// `validation` - bool IsValidationDisabled(const ast::AttributeList& attributes, - ast::DisabledValidation validation) const; - - /// @returns true if the attribute list does not contains a - /// ast::DisableValidationAttribute with the validation mode equal to - /// `validation` - bool IsValidationEnabled(const ast::AttributeList& attributes, - ast::DisabledValidation validation) const; - - /// Returns a human-readable string representation of the vector type name - /// with the given parameters. - /// @param size the vector dimension - /// @param element_type scalar vector sub-element type - /// @return pretty string representation - std::string VectorPretty(uint32_t size, const sem::Type* element_type) const; - /// Resolves the WorkgroupSize for the given function, assigning it to /// current_function_ bool WorkgroupSize(const ast::Function*); @@ -457,23 +360,6 @@ class Resolver { /// @returns true if `expr` is the current CallStatement's CallExpression bool IsCallStatement(const ast::Expression* expr) const; - /// Searches the current statement and up through parents of the current - /// statement looking for a loop or for-loop continuing statement. - /// @returns the closest continuing statement to the current statement that - /// (transitively) owns the current statement. - /// @param stop_at_loop if true then the function will return nullptr if a - /// loop or for-loop was found before the continuing. - const ast::Statement* ClosestContinuing(bool stop_at_loop) const; - - /// @returns the resolved symbol (function, type or variable) for the given - /// ast::Identifier or ast::TypeName cast to the given semantic type. - template - SEM* ResolvedSymbol(const ast::Node* node) const { - auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node); - return resolved ? const_cast(builder_->Sem().Get(resolved)) - : nullptr; - } - struct TypeConversionSig { const sem::Type* target; const sem::Type* source; @@ -511,6 +397,7 @@ class Resolver { std::unique_ptr const builtin_table_; DependencyGraph dependencies_; SemHelper sem_; + Validator validator_; std::vector entry_points_; std::unordered_map atomic_composite_info_; std::unordered_set marked_; diff --git a/src/tint/resolver/resolver_is_storeable_test.cc b/src/tint/resolver/resolver_is_storeable_test.cc new file mode 100644 index 0000000000..de180a33e0 --- /dev/null +++ b/src/tint/resolver/resolver_is_storeable_test.cc @@ -0,0 +1,79 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/resolver/resolver.h" + +#include "gmock/gmock.h" +#include "src/tint/resolver/resolver_test_helper.h" +#include "src/tint/sem/atomic_type.h" + +namespace tint::resolver { +namespace { + +using ResolverIsStorableTest = ResolverTest; + +TEST_F(ResolverIsStorableTest, Struct_AllMembersStorable) { + Structure("S", { + Member("a", ty.i32()), + Member("b", ty.f32()), + }); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverIsStorableTest, Struct_SomeMembersNonStorable) { + Structure("S", { + Member("a", ty.i32()), + Member("b", ty.pointer(ast::StorageClass::kPrivate)), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + R"(error: ptr cannot be used as the type of a structure member)"); +} + +TEST_F(ResolverIsStorableTest, Struct_NestedStorable) { + auto* storable = Structure("Storable", { + Member("a", ty.i32()), + Member("b", ty.f32()), + }); + Structure("S", { + Member("a", ty.i32()), + Member("b", ty.Of(storable)), + }); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverIsStorableTest, Struct_NestedNonStorable) { + auto* non_storable = + Structure("nonstorable", + { + Member("a", ty.i32()), + Member("b", ty.pointer(ast::StorageClass::kPrivate)), + }); + Structure("S", { + Member("a", ty.i32()), + Member("b", ty.Of(non_storable)), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + R"(error: ptr cannot be used as the type of a structure member)"); +} + +} // namespace +} // namespace tint::resolver diff --git a/src/tint/resolver/sem_helper.cc b/src/tint/resolver/sem_helper.cc index 74b3c5b077..57fff2db17 100644 --- a/src/tint/resolver/sem_helper.cc +++ b/src/tint/resolver/sem_helper.cc @@ -18,7 +18,8 @@ namespace tint::resolver { -SemHelper::SemHelper(ProgramBuilder* builder) : builder_(builder) {} +SemHelper::SemHelper(ProgramBuilder* builder, DependencyGraph& dependencies) + : builder_(builder), dependencies_(dependencies) {} SemHelper::~SemHelper() = default; diff --git a/src/tint/resolver/sem_helper.h b/src/tint/resolver/sem_helper.h index 0e9539766e..58c2d5756b 100644 --- a/src/tint/resolver/sem_helper.h +++ b/src/tint/resolver/sem_helper.h @@ -19,6 +19,8 @@ #include "src/tint/diagnostic/diagnostic.h" #include "src/tint/program_builder.h" +#include "src/tint/resolver/dependency_graph.h" +#include "src/tint/utils/map.h" namespace tint::resolver { @@ -27,7 +29,8 @@ class SemHelper { public: /// Constructor /// @param builder the program builder - explicit SemHelper(ProgramBuilder* builder); + /// @param dependencies the program dependency graph + explicit SemHelper(ProgramBuilder* builder, DependencyGraph& dependencies); ~SemHelper(); /// Get is a helper for obtaining the semantic node for the given AST node. @@ -47,6 +50,16 @@ class SemHelper { return const_cast(As(sem)); } + /// @returns the resolved symbol (function, type or variable) for the given + /// ast::Identifier or ast::TypeName cast to the given semantic type. + /// @param node the node to retrieve + template + SEM* ResolvedSymbol(const ast::Node* node) const { + auto* resolved = utils::Lookup(dependencies_.resolved_symbols, node); + return resolved ? const_cast(builder_->Sem().Get(resolved)) + : nullptr; + } + /// @returns the resolved type of the ast::Expression `expr` /// @param expr the expression sem::Type* TypeOf(const ast::Expression* expr) const; @@ -67,6 +80,7 @@ class SemHelper { private: ProgramBuilder* builder_; + DependencyGraph& dependencies_; }; } // namespace tint::resolver diff --git a/src/tint/resolver/resolver_validation.cc b/src/tint/resolver/validator.cc similarity index 85% rename from src/tint/resolver/resolver_validation.cc rename to src/tint/resolver/validator.cc index 44e0a2f078..ffe855e86f 100644 --- a/src/tint/resolver/resolver_validation.cc +++ b/src/tint/resolver/validator.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "src/tint/resolver/resolver.h" +#include "src/tint/resolver/validator.h" #include #include @@ -149,8 +149,104 @@ void TraverseCallChain(diag::List& diagnostics, } // namespace -bool Resolver::ValidateAtomic(const ast::Atomic* a, - const sem::Atomic* s) const { +Validator::Validator(ProgramBuilder* builder, SemHelper& sem) + : symbols_(builder->Symbols()), + diagnostics_(builder->Diagnostics()), + sem_(sem) {} + +Validator::~Validator() = default; + +void Validator::AddError(const std::string& msg, const Source& source) const { + diagnostics_.add_error(diag::System::Resolver, msg, source); +} + +void Validator::AddWarning(const std::string& msg, const Source& source) const { + diagnostics_.add_warning(diag::System::Resolver, msg, source); +} + +void Validator::AddNote(const std::string& msg, const Source& source) const { + diagnostics_.add_note(diag::System::Resolver, msg, source); +} + +// https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section +bool Validator::IsPlain(const sem::Type* type) const { + return type->is_scalar() || + type->IsAnyOf(); +} + +// https://gpuweb.github.io/gpuweb/wgsl/#fixed-footprint-types +bool Validator::IsFixedFootprint(const sem::Type* type) const { + return Switch( + type, // + [&](const sem::Vector*) { return true; }, // + [&](const sem::Matrix*) { return true; }, // + [&](const sem::Atomic*) { return true; }, + [&](const sem::Array* arr) { + return !arr->IsRuntimeSized() && IsFixedFootprint(arr->ElemType()); + }, + [&](const sem::Struct* str) { + for (auto* member : str->Members()) { + if (!IsFixedFootprint(member->Type())) { + return false; + } + } + return true; + }, + [&](Default) { return type->is_scalar(); }); +} + +// https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types +bool Validator::IsHostShareable(const sem::Type* type) const { + if (type->IsAnyOf()) { + return true; + } + return Switch( + type, // + [&](const sem::Vector* vec) { return IsHostShareable(vec->type()); }, + [&](const sem::Matrix* mat) { return IsHostShareable(mat->type()); }, + [&](const sem::Array* arr) { return IsHostShareable(arr->ElemType()); }, + [&](const sem::Struct* str) { + for (auto* member : str->Members()) { + if (!IsHostShareable(member->Type())) { + return false; + } + } + return true; + }, + [&](const sem::Atomic* atomic) { + return IsHostShareable(atomic->Type()); + }); +} + +// https://gpuweb.github.io/gpuweb/wgsl.html#storable-types +bool Validator::IsStorable(const sem::Type* type) const { + return IsPlain(type) || type->IsAnyOf(); +} + +const ast::Statement* Validator::ClosestContinuing( + bool stop_at_loop, + sem::Statement* current_statement) const { + for (const auto* s = current_statement; s != nullptr; s = s->Parent()) { + if (stop_at_loop && s->Is()) { + break; + } + if (s->Is()) { + return s->Declaration(); + } + if (auto* f = As(s->Parent())) { + if (f->Declaration()->continuing == s->Declaration()) { + return s->Declaration(); + } + if (stop_at_loop) { + break; + } + } + } + return nullptr; +} + +bool Validator::Atomic(const ast::Atomic* a, const sem::Atomic* s) const { // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types // T must be either u32 or i32. if (!s->Type()->IsAnyOf()) { @@ -161,7 +257,7 @@ bool Resolver::ValidateAtomic(const ast::Atomic* a, return true; } -bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) const { +bool Validator::StorageTexture(const ast::StorageTexture* t) const { switch (t->access) { case ast::Access::kWrite: break; @@ -190,11 +286,10 @@ bool Resolver::ValidateStorageTexture(const ast::StorageTexture* t) const { return true; } -bool Resolver::ValidateVariableConstructorOrCast( - const ast::Variable* var, - ast::StorageClass storage_class, - const sem::Type* storage_ty, - const sem::Type* rhs_ty) const { +bool Validator::VariableConstructorOrCast(const ast::Variable* var, + ast::StorageClass storage_class, + const sem::Type* storage_ty, + const sem::Type* rhs_ty) const { auto* value_type = rhs_ty->UnwrapRef(); // Implicit load of RHS // Value type has to match storage type @@ -229,11 +324,10 @@ bool Resolver::ValidateVariableConstructorOrCast( return true; } -bool Resolver::ValidateStorageClassLayout( - const sem::Type* store_ty, - ast::StorageClass sc, - Source source, - ValidTypeStorageLayouts& layouts) const { +bool Validator::StorageClassLayout(const sem::Type* store_ty, + ast::StorageClass sc, + Source source, + ValidTypeStorageLayouts& layouts) const { // https://gpuweb.github.io/gpuweb/wgsl/#storage-class-layout-constraints auto is_uniform_struct_or_array = [sc](const sem::Type* ty) { @@ -255,7 +349,7 @@ bool Resolver::ValidateStorageClassLayout( }; auto member_name_of = [this](const sem::StructMember* sm) { - return builder_->Symbols().NameFor(sm->Declaration()->symbol); + return symbols_.NameFor(sm->Declaration()->symbol); }; // Cache result of type + storage class pair. @@ -273,9 +367,9 @@ bool Resolver::ValidateStorageClassLayout( uint32_t required_align = required_alignment_of(m->Type()); // Recurse into the member type. - if (!ValidateStorageClassLayout( - m->Type(), sc, m->Declaration()->type->source, layouts)) { - AddNote("see layout of struct:\n" + str->Layout(builder_->Symbols()), + if (!StorageClassLayout(m->Type(), sc, m->Declaration()->type->source, + layouts)) { + AddNote("see layout of struct:\n" + str->Layout(symbols_), str->Declaration()->source); return false; } @@ -283,7 +377,7 @@ bool Resolver::ValidateStorageClassLayout( // Validate that member is at a valid byte offset if (m->Offset() % required_align != 0) { AddError("the offset of a struct member of type '" + - m->Type()->UnwrapRef()->FriendlyName(builder_->Symbols()) + + m->Type()->UnwrapRef()->FriendlyName(symbols_) + "' in storage class '" + ast::ToString(sc) + "' must be a multiple of " + std::to_string(required_align) + " bytes, but '" + @@ -293,13 +387,13 @@ bool Resolver::ValidateStorageClassLayout( std::to_string(required_align) + ") on this member", m->Declaration()->source); - AddNote("see layout of struct:\n" + str->Layout(builder_->Symbols()), + AddNote("see layout of struct:\n" + str->Layout(symbols_), str->Declaration()->source); if (auto* member_str = m->Type()->As()) { - AddNote("and layout of struct member:\n" + - member_str->Layout(builder_->Symbols()), - member_str->Declaration()->source); + AddNote( + "and layout of struct member:\n" + member_str->Layout(symbols_), + member_str->Declaration()->source); } return false; @@ -322,12 +416,12 @@ bool Resolver::ValidateStorageClassLayout( "'. Consider setting @align(16) on this member", m->Declaration()->source); - AddNote("see layout of struct:\n" + str->Layout(builder_->Symbols()), + AddNote("see layout of struct:\n" + str->Layout(symbols_), str->Declaration()->source); auto* prev_member_str = prev_member->Type()->As(); AddNote("and layout of previous member struct:\n" + - prev_member_str->Layout(builder_->Symbols()), + prev_member_str->Layout(symbols_), prev_member_str->Declaration()->source); return false; } @@ -342,7 +436,7 @@ bool Resolver::ValidateStorageClassLayout( // TODO(crbug.com/tint/1388): Ideally we'd pass the source for nested // element type here, but we can't easily get that from the semantic node. // We should consider recursing through the AST type nodes instead. - if (!ValidateStorageClassLayout(arr->ElemType(), sc, source, layouts)) { + if (!StorageClassLayout(arr->ElemType(), sc, source, layouts)) { return false; } @@ -384,12 +478,11 @@ bool Resolver::ValidateStorageClassLayout( return true; } -bool Resolver::ValidateStorageClassLayout( - const sem::Variable* var, - ValidTypeStorageLayouts& layouts) const { +bool Validator::StorageClassLayout(const sem::Variable* var, + ValidTypeStorageLayouts& layouts) const { if (auto* str = var->Type()->UnwrapRef()->As()) { - if (!ValidateStorageClassLayout(str, var->StorageClass(), - str->Declaration()->source, layouts)) { + if (!StorageClassLayout(str, var->StorageClass(), + str->Declaration()->source, layouts)) { AddNote("see declaration of variable", var->Declaration()->source); return false; } @@ -398,8 +491,8 @@ bool Resolver::ValidateStorageClassLayout( if (var->Declaration()->type) { source = var->Declaration()->type->source; } - if (!ValidateStorageClassLayout(var->Type()->UnwrapRef(), - var->StorageClass(), source, layouts)) { + if (!StorageClassLayout(var->Type()->UnwrapRef(), var->StorageClass(), + source, layouts)) { return false; } } @@ -407,9 +500,13 @@ bool Resolver::ValidateStorageClassLayout( return true; } -bool Resolver::ValidateGlobalVariable(const sem::Variable* var) const { +bool Validator::GlobalVariable( + const sem::Variable* var, + std::unordered_map constant_ids, + std::unordered_map atomic_composite_info) + const { auto* decl = var->Declaration(); - if (!ValidateNoDuplicateAttributes(decl->attributes)) { + if (!NoDuplicateAttributes(decl->attributes)) { return false; } @@ -417,8 +514,8 @@ bool Resolver::ValidateGlobalVariable(const sem::Variable* var) const { if (decl->is_const) { if (auto* id_attr = attr->As()) { uint32_t id = id_attr->value; - auto it = constant_ids_.find(id); - if (it != constant_ids_.end() && it->second != var) { + auto it = constant_ids.find(id); + if (it != constant_ids.end() && it->second != var) { AddError("pipeline constant IDs must be unique", attr->source); AddNote("a pipeline constant with an ID of " + std::to_string(id) + " was previously declared " @@ -502,18 +599,21 @@ bool Resolver::ValidateGlobalVariable(const sem::Variable* var) const { } if (!decl->is_const) { - if (!ValidateAtomicVariable(var)) { + if (!AtomicVariable(var, atomic_composite_info)) { return false; } } - return ValidateVariable(var); + return Variable(var); } // https://gpuweb.github.io/gpuweb/wgsl/#atomic-types // Atomic types may only be instantiated by variables in the workgroup storage // class or by storage buffer variables with a read_write access mode. -bool Resolver::ValidateAtomicVariable(const sem::Variable* var) const { +bool Validator::AtomicVariable( + const sem::Variable* var, + std::unordered_map atomic_composite_info) + const { auto sc = var->StorageClass(); auto* decl = var->Declaration(); auto access = var->Access(); @@ -529,8 +629,8 @@ bool Resolver::ValidateAtomicVariable(const sem::Variable* var) const { return false; } } else if (type->IsAnyOf()) { - auto found = atomic_composite_info_.find(type); - if (found != atomic_composite_info_.end()) { + auto found = atomic_composite_info.find(type); + if (found != atomic_composite_info.end()) { if (sc != ast::StorageClass::kStorage && sc != ast::StorageClass::kWorkgroup) { AddError( @@ -557,12 +657,12 @@ bool Resolver::ValidateAtomicVariable(const sem::Variable* var) const { return true; } -bool Resolver::ValidateVariable(const sem::Variable* var) const { +bool Validator::Variable(const sem::Variable* var) const { auto* decl = var->Declaration(); auto* storage_ty = var->Type()->UnwrapRef(); if (var->Is()) { - auto name = builder_->Symbols().NameFor(decl->symbol); + auto name = symbols_.NameFor(decl->symbol); if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { auto* kind = var->Declaration()->is_const ? "let" : "var"; AddError( @@ -634,9 +734,9 @@ bool Resolver::ValidateVariable(const sem::Variable* var) const { return true; } -bool Resolver::ValidateFunctionParameter(const ast::Function* func, - const sem::Variable* var) const { - if (!ValidateVariable(var)) { +bool Validator::FunctionParameter(const ast::Function* func, + const sem::Variable* var) const { + if (!Variable(var)) { return false; } @@ -697,10 +797,10 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func, return true; } -bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr, - const sem::Type* storage_ty, - ast::PipelineStage stage, - const bool is_input) const { +bool Validator::BuiltinAttribute(const ast::BuiltinAttribute* attr, + const sem::Type* storage_ty, + ast::PipelineStage stage, + const bool is_input) const { auto* type = storage_ty->UnwrapRef(); std::stringstream stage_name; stage_name << stage; @@ -816,9 +916,8 @@ bool Resolver::ValidateBuiltinAttribute(const ast::BuiltinAttribute* attr, return true; } -bool Resolver::ValidateInterpolateAttribute( - const ast::InterpolateAttribute* attr, - const sem::Type* storage_ty) const { +bool Validator::InterpolateAttribute(const ast::InterpolateAttribute* attr, + const sem::Type* storage_ty) const { auto* type = storage_ty->UnwrapRef(); if (type->is_integer_scalar_or_vector() && @@ -839,11 +938,11 @@ bool Resolver::ValidateInterpolateAttribute( return true; } -bool Resolver::ValidateFunction(const sem::Function* func, - ast::PipelineStage stage) const { +bool Validator::Function(const sem::Function* func, + ast::PipelineStage stage) const { auto* decl = func->Declaration(); - auto name = builder_->Symbols().NameFor(decl->symbol); + auto name = symbols_.NameFor(decl->symbol); if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { AddError( "'" + name + "' is a builtin and cannot be redeclared as a function", @@ -873,7 +972,7 @@ bool Resolver::ValidateFunction(const sem::Function* func, } for (size_t i = 0; i < decl->params.size(); i++) { - if (!ValidateFunctionParameter(decl, func->Parameters()[i])) { + if (!FunctionParameter(decl, func->Parameters()[i])) { return false; } } @@ -898,8 +997,7 @@ bool Resolver::ValidateFunction(const sem::Function* func, decl->attributes, ast::DisabledValidation::kFunctionHasNoBody)) { TINT_ICE(Resolver, diagnostics_) - << "Function " << builder_->Symbols().NameFor(decl->symbol) - << " has no body"; + << "Function " << symbols_.NameFor(decl->symbol) << " has no body"; } for (auto* attr : decl->return_type_attributes) { @@ -925,7 +1023,7 @@ bool Resolver::ValidateFunction(const sem::Function* func, } if (decl->IsEntryPoint()) { - if (!ValidateEntryPoint(func, stage)) { + if (!EntryPoint(func, stage)) { return false; } } @@ -945,8 +1043,8 @@ bool Resolver::ValidateFunction(const sem::Function* func, return true; } -bool Resolver::ValidateEntryPoint(const sem::Function* func, - ast::PipelineStage stage) const { +bool Validator::EntryPoint(const sem::Function* func, + ast::PipelineStage stage) const { auto* decl = func->Declaration(); // Use a lambda to validate the entry point attributes for a type. @@ -994,7 +1092,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func, return false; } - if (!ValidateBuiltinAttribute( + if (!BuiltinAttribute( builtin, ty, stage, /* is_input */ param_or_ret == ParamOrRetType::kParameter)) { return false; @@ -1011,14 +1109,14 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func, bool is_input = param_or_ret == ParamOrRetType::kParameter; - if (!ValidateLocationAttribute(location, ty, locations, stage, source, - is_input)) { + if (!LocationAttribute(location, ty, locations, stage, source, + is_input)) { return false; } } else if (auto* interpolate = attr->As()) { if (decl->PipelineStage() == ast::PipelineStage::kCompute) { is_invalid_compute_shader_attribute = true; - } else if (!ValidateInterpolateAttribute(interpolate, ty)) { + } else if (!InterpolateAttribute(interpolate, ty)) { return false; } interpolate_attribute = interpolate; @@ -1122,7 +1220,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func, member->Declaration()->source, param_or_ret, /*is_struct_member*/ true)) { AddNote("while analysing entry point '" + - builder_->Symbols().NameFor(decl->symbol) + "'", + symbols_.NameFor(decl->symbol) + "'", decl->source); return false; } @@ -1206,7 +1304,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func, // variables in the resource interface of a given shader must not have // the same group and binding values, when considered as a pair of // values. - auto func_name = builder_->Symbols().NameFor(decl->symbol); + auto func_name = symbols_.NameFor(decl->symbol); AddError("entry point '" + func_name + "' references multiple variables that use the " "same resource binding @group(" + @@ -1222,7 +1320,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func, return true; } -bool Resolver::ValidateStatements(const ast::StatementList& stmts) const { +bool Validator::Statements(const ast::StatementList& stmts) const { for (auto* stmt : stmts) { if (!sem_.Get(stmt)->IsReachable()) { /// TODO(https://github.com/gpuweb/gpuweb/issues/2378): This may need to @@ -1234,8 +1332,8 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) const { return true; } -bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast, - const sem::Type* to) const { +bool Validator::Bitcast(const ast::BitcastExpression* cast, + const sem::Type* to) const { auto* from = sem_.TypeOf(cast->expr)->UnwrapRef(); if (!from->is_numeric_scalar_or_vector()) { AddError("'" + sem_.TypeNameOf(from) + "' cannot be bitcast", @@ -1265,13 +1363,15 @@ bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast, return true; } -bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) const { +bool Validator::BreakStatement(const sem::Statement* stmt, + sem::Statement* current_statement) const { if (!stmt->FindFirstParent()) { AddError("break statement must be in a loop or switch case", stmt->Declaration()->source); return false; } - if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) { + if (auto* continuing = + ClosestContinuing(/*stop_at_loop*/ true, current_statement)) { auto fail = [&](const char* note_msg, const Source& note_src) { constexpr const char* kErrorMsg = "break statement in a continuing block must be the single statement " @@ -1332,8 +1432,10 @@ bool Resolver::ValidateBreakStatement(const sem::Statement* stmt) const { return true; } -bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) const { - if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ true)) { +bool Validator::ContinueStatement(const sem::Statement* stmt, + sem::Statement* current_statement) const { + if (auto* continuing = + ClosestContinuing(/*stop_at_loop*/ true, current_statement)) { AddError("continuing blocks must not contain a continue statement", stmt->Declaration()->source); if (continuing != stmt->Declaration() && @@ -1352,8 +1454,10 @@ bool Resolver::ValidateContinueStatement(const sem::Statement* stmt) const { return true; } -bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) const { - if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) { +bool Validator::DiscardStatement(const sem::Statement* stmt, + sem::Statement* current_statement) const { + if (auto* continuing = + ClosestContinuing(/*stop_at_loop*/ false, current_statement)) { AddError("continuing blocks must not contain a discard statement", stmt->Declaration()->source); if (continuing != stmt->Declaration() && @@ -1365,7 +1469,7 @@ bool Resolver::ValidateDiscardStatement(const sem::Statement* stmt) const { return true; } -bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) const { +bool Validator::FallthroughStatement(const sem::Statement* stmt) const { if (auto* block = As(stmt->Parent())) { if (auto* c = As(block->Parent())) { if (block->Declaration()->Last() == stmt->Declaration()) { @@ -1388,7 +1492,7 @@ bool Resolver::ValidateFallthroughStatement(const sem::Statement* stmt) const { return false; } -bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) const { +bool Validator::ElseStatement(const sem::ElseStatement* stmt) const { if (auto* cond = stmt->Condition()) { auto* cond_ty = cond->Type()->UnwrapRef(); if (!cond_ty->Is()) { @@ -1401,7 +1505,7 @@ bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) const { return true; } -bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) const { +bool Validator::LoopStatement(const sem::LoopStatement* stmt) const { if (stmt->Behaviors().Empty()) { AddError("loop does not exit", stmt->Declaration()->source.Begin()); return false; @@ -1409,8 +1513,7 @@ bool Resolver::ValidateLoopStatement(const sem::LoopStatement* stmt) const { return true; } -bool Resolver::ValidateForLoopStatement( - const sem::ForLoopStatement* stmt) const { +bool Validator::ForLoopStatement(const sem::ForLoopStatement* stmt) const { if (stmt->Behaviors().Empty()) { AddError("for-loop does not exit", stmt->Declaration()->source.Begin()); return false; @@ -1427,7 +1530,7 @@ bool Resolver::ValidateForLoopStatement( return true; } -bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) const { +bool Validator::IfStatement(const sem::IfStatement* stmt) const { auto* cond_ty = stmt->Condition()->Type()->UnwrapRef(); if (!cond_ty->Is()) { AddError( @@ -1438,7 +1541,7 @@ bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) const { return true; } -bool Resolver::ValidateBuiltinCall(const sem::Call* call) const { +bool Validator::BuiltinCall(const sem::Call* call) const { if (call->Type()->Is()) { bool is_call_statement = false; if (auto* call_stmt = As(call->Stmt()->Declaration())) { @@ -1451,7 +1554,7 @@ bool Resolver::ValidateBuiltinCall(const sem::Call* call) const { // If the called function does not return a value, a function call // statement should be used instead. auto* ident = call->Declaration()->target.name; - auto name = builder_->Symbols().NameFor(ident->symbol); + auto name = symbols_.NameFor(ident->symbol); AddError("builtin '" + name + "' does not return a value", call->Declaration()->source); return false; @@ -1461,7 +1564,7 @@ bool Resolver::ValidateBuiltinCall(const sem::Call* call) const { return true; } -bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) const { +bool Validator::TextureBuiltinFunction(const sem::Call* call) const { auto* builtin = call->Target()->As(); if (!builtin) { return false; @@ -1533,11 +1636,12 @@ bool Resolver::ValidateTextureBuiltinFunction(const sem::Call* call) const { check_arg_is_constexpr(sem::ParameterUsage::kComponent, 0, 3); } -bool Resolver::ValidateFunctionCall(const sem::Call* call) const { +bool Validator::FunctionCall(const sem::Call* call, + sem::Statement* current_statement) const { auto* decl = call->Declaration(); auto* target = call->Target()->As(); auto sym = decl->target.name->symbol; - auto name = builder_->Symbols().NameFor(sym); + auto name = symbols_.NameFor(sym); if (target->Declaration()->IsEntryPoint()) { // https://www.w3.org/TR/WGSL/#function-restriction @@ -1575,7 +1679,7 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) const { if (param_type->Is()) { auto is_valid = false; if (auto* ident_expr = arg_expr->As()) { - auto* var = ResolvedSymbol(ident_expr); + auto* var = sem_.ResolvedSymbol(ident_expr); if (!var) { TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; return false; @@ -1587,7 +1691,7 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) const { if (unary->op == ast::UnaryOp::kAddressOf) { if (auto* ident_unary = unary->expr->As()) { - auto* var = ResolvedSymbol(ident_unary); + auto* var = sem_.ResolvedSymbol(ident_unary); if (!var) { TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; @@ -1634,7 +1738,8 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) const { } if (call->Behaviors().Contains(sem::Behavior::kDiscard)) { - if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) { + if (auto* continuing = + ClosestContinuing(/*stop_at_loop*/ false, current_statement)) { AddError( "cannot call a function that may discard inside a continuing block", call->Declaration()->source); @@ -1649,7 +1754,7 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) const { return true; } -bool Resolver::ValidateStructureConstructorOrCast( +bool Validator::StructureConstructorOrCast( const ast::CallExpression* ctor, const sem::Struct* struct_type) const { if (!struct_type->IsConstructible()) { @@ -1684,9 +1789,8 @@ bool Resolver::ValidateStructureConstructorOrCast( return true; } -bool Resolver::ValidateArrayConstructorOrCast( - const ast::CallExpression* ctor, - const sem::Array* array_type) const { +bool Validator::ArrayConstructorOrCast(const ast::CallExpression* ctor, + const sem::Array* array_type) const { auto& values = ctor->args; auto* elem_ty = array_type->ElemType(); for (auto* value : values) { @@ -1726,9 +1830,8 @@ bool Resolver::ValidateArrayConstructorOrCast( return true; } -bool Resolver::ValidateVectorConstructorOrCast( - const ast::CallExpression* ctor, - const sem::Vector* vec_type) const { +bool Validator::VectorConstructorOrCast(const ast::CallExpression* ctor, + const sem::Vector* vec_type) const { auto& values = ctor->args; auto* elem_ty = vec_type->type(); size_t value_cardinality_sum = 0; @@ -1790,8 +1893,7 @@ bool Resolver::ValidateVectorConstructorOrCast( return true; } -bool Resolver::ValidateVector(const sem::Vector* ty, - const Source& source) const { +bool Validator::Vector(const sem::Vector* ty, const Source& source) const { if (!ty->type()->is_scalar()) { AddError("vector element type must be 'bool', 'f32', 'i32' or 'u32'", source); @@ -1800,8 +1902,7 @@ bool Resolver::ValidateVector(const sem::Vector* ty, return true; } -bool Resolver::ValidateMatrix(const sem::Matrix* ty, - const Source& source) const { +bool Validator::Matrix(const sem::Matrix* ty, const Source& source) const { if (!ty->is_float_matrix()) { AddError("matrix element type must be 'f32'", source); return false; @@ -1809,16 +1910,15 @@ bool Resolver::ValidateMatrix(const sem::Matrix* ty, return true; } -bool Resolver::ValidateMatrixConstructorOrCast( - const ast::CallExpression* ctor, - const sem::Matrix* matrix_ty) const { +bool Validator::MatrixConstructorOrCast(const ast::CallExpression* ctor, + const sem::Matrix* matrix_ty) const { auto& values = ctor->args; // Zero Value expression if (values.empty()) { return true; } - if (!ValidateMatrix(matrix_ty, ctor->source)) { + if (!Matrix(matrix_ty, ctor->source)) { return false; } @@ -1844,7 +1944,7 @@ bool Resolver::ValidateMatrixConstructorOrCast( if (i > 0) { ss << ", "; } - ss << arg_tys[i]->FriendlyName(builder_->Symbols()); + ss << arg_tys[i]->FriendlyName(symbols_); } ss << ")" << std::endl << std::endl; ss << "3 candidates available:" << std::endl; @@ -1885,8 +1985,8 @@ bool Resolver::ValidateMatrixConstructorOrCast( return true; } -bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor, - const sem::Type* ty) const { +bool Validator::ScalarConstructorOrCast(const ast::CallExpression* ctor, + const sem::Type* ty) const { if (ctor->args.size() == 0) { return true; } @@ -1921,7 +2021,8 @@ bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor, return true; } -bool Resolver::ValidatePipelineStages() const { +bool Validator::PipelineStages( + const std::vector& entry_points) const { auto check_workgroup_storage = [&](const sem::Function* func, const sem::Function* entry_point) { auto stage = entry_point->Declaration()->PipelineStage(); @@ -1940,17 +2041,14 @@ bool Resolver::ValidatePipelineStages() const { } AddNote("variable is declared here", var->Declaration()->source); if (func != entry_point) { - TraverseCallChain(diagnostics_, entry_point, func, - [&](const sem::Function* f) { - AddNote("called by function '" + - builder_->Symbols().NameFor( - f->Declaration()->symbol) + - "'", - f->Declaration()->source); - }); + TraverseCallChain( + diagnostics_, entry_point, func, [&](const sem::Function* f) { + AddNote("called by function '" + + symbols_.NameFor(f->Declaration()->symbol) + "'", + f->Declaration()->source); + }); AddNote("called by entry point '" + - builder_->Symbols().NameFor( - entry_point->Declaration()->symbol) + + symbols_.NameFor(entry_point->Declaration()->symbol) + "'", entry_point->Declaration()->source); } @@ -1961,7 +2059,7 @@ bool Resolver::ValidatePipelineStages() const { return true; }; - for (auto* entry_point : entry_points_) { + for (auto* entry_point : entry_points) { if (!check_workgroup_storage(entry_point, entry_point)) { return false; } @@ -1985,15 +2083,12 @@ bool Resolver::ValidatePipelineStages() const { if (func != entry_point) { TraverseCallChain( diagnostics_, entry_point, func, [&](const sem::Function* f) { - AddNote( - "called by function '" + - builder_->Symbols().NameFor(f->Declaration()->symbol) + - "'", - f->Declaration()->source); + AddNote("called by function '" + + symbols_.NameFor(f->Declaration()->symbol) + "'", + f->Declaration()->source); }); AddNote("called by entry point '" + - builder_->Symbols().NameFor( - entry_point->Declaration()->symbol) + + symbols_.NameFor(entry_point->Declaration()->symbol) + "'", entry_point->Declaration()->source); } @@ -2003,7 +2098,7 @@ bool Resolver::ValidatePipelineStages() const { return true; }; - for (auto* entry_point : entry_points_) { + for (auto* entry_point : entry_points) { if (!check_builtin_calls(entry_point, entry_point)) { return false; } @@ -2016,8 +2111,7 @@ bool Resolver::ValidatePipelineStages() const { return true; } -bool Resolver::ValidateArray(const sem::Array* arr, - const Source& source) const { +bool Validator::Array(const sem::Array* arr, const Source& source) const { auto* el_ty = arr->ElemType(); if (!IsFixedFootprint(el_ty)) { @@ -2028,10 +2122,10 @@ bool Resolver::ValidateArray(const sem::Array* arr, return true; } -bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr, - uint32_t el_size, - uint32_t el_align, - const Source& source) const { +bool Validator::ArrayStrideAttribute(const ast::StrideAttribute* attr, + uint32_t el_size, + uint32_t el_align, + const Source& source) const { auto stride = attr->stride; bool is_valid_stride = (stride >= el_size) && (stride >= el_align) && (stride % el_align == 0); @@ -2050,8 +2144,8 @@ bool Resolver::ValidateArrayStrideAttribute(const ast::StrideAttribute* attr, return true; } -bool Resolver::ValidateAlias(const ast::Alias* alias) const { - auto name = builder_->Symbols().NameFor(alias->name); +bool Validator::Alias(const ast::Alias* alias) const { + auto name = symbols_.NameFor(alias->name); if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { AddError("'" + name + "' is a builtin and cannot be redeclared as an alias", alias->source); @@ -2061,9 +2155,9 @@ bool Resolver::ValidateAlias(const ast::Alias* alias) const { return true; } -bool Resolver::ValidateStructure(const sem::Struct* str, - ast::PipelineStage stage) const { - auto name = builder_->Symbols().NameFor(str->Declaration()->name); +bool Validator::Structure(const sem::Struct* str, + ast::PipelineStage stage) const { + auto name = symbols_.NameFor(str->Declaration()->name); if (sem::ParseBuiltinType(name) != sem::BuiltinType::kNone) { AddError("'" + name + "' is a builtin and cannot be redeclared as a struct", str->Declaration()->source); @@ -2122,13 +2216,13 @@ bool Resolver::ValidateStructure(const sem::Struct* str, invariant_attribute = invariant; } else if (auto* location = attr->As()) { has_location = true; - if (!ValidateLocationAttribute(location, member->Type(), locations, - stage, member->Declaration()->source)) { + if (!LocationAttribute(location, member->Type(), locations, stage, + member->Declaration()->source)) { return false; } } else if (auto* builtin = attr->As()) { - if (!ValidateBuiltinAttribute(builtin, member->Type(), stage, - /* is_input */ false)) { + if (!BuiltinAttribute(builtin, member->Type(), stage, + /* is_input */ false)) { return false; } if (builtin->builtin == ast::Builtin::kPosition) { @@ -2136,7 +2230,7 @@ bool Resolver::ValidateStructure(const sem::Struct* str, } } else if (auto* interpolate = attr->As()) { interpolate_attribute = interpolate; - if (!ValidateInterpolateAttribute(interpolate, member->Type())) { + if (!InterpolateAttribute(interpolate, member->Type())) { return false; } } @@ -2165,13 +2259,12 @@ bool Resolver::ValidateStructure(const sem::Struct* str, return true; } -bool Resolver::ValidateLocationAttribute( - const ast::LocationAttribute* location, - const sem::Type* type, - std::unordered_set& locations, - ast::PipelineStage stage, - const Source& source, - const bool is_input) const { +bool Validator::LocationAttribute(const ast::LocationAttribute* location, + const sem::Type* type, + std::unordered_set& locations, + ast::PipelineStage stage, + const Source& source, + const bool is_input) const { std::string inputs_or_output = is_input ? "inputs" : "output"; if (stage == ast::PipelineStage::kCompute) { AddError("attribute is not valid for compute shader " + inputs_or_output, @@ -2201,9 +2294,10 @@ bool Resolver::ValidateLocationAttribute( return true; } -bool Resolver::ValidateReturn(const ast::ReturnStatement* ret, - const sem::Type* func_type, - const sem::Type* ret_type) const { +bool Validator::Return(const ast::ReturnStatement* ret, + const sem::Type* func_type, + const sem::Type* ret_type, + sem::Statement* current_statement) const { if (func_type->UnwrapRef() != ret_type) { AddError( "return statement type must match its function " @@ -2215,7 +2309,8 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret, } auto* sem = sem_.Get(ret); - if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) { + if (auto* continuing = + ClosestContinuing(/*stop_at_loop*/ false, current_statement)) { AddError("continuing blocks must not contain a return statement", ret->source); if (continuing != sem->Declaration() && @@ -2228,7 +2323,7 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret, return true; } -bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) { +bool Validator::SwitchStatement(const ast::SwitchStatement* s) { auto* cond_ty = sem_.TypeOf(s->condition)->UnwrapRef(); if (!cond_ty->is_integer_scalar()) { AddError( @@ -2284,8 +2379,8 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) { return true; } -bool Resolver::ValidateAssignment(const ast::Statement* a, - const sem::Type* rhs_ty) const { +bool Validator::Assignment(const ast::Statement* a, + const sem::Type* rhs_ty) const { const ast::Expression* lhs; const ast::Expression* rhs; if (auto* assign = a->As()) { @@ -2317,19 +2412,17 @@ bool Resolver::ValidateAssignment(const ast::Statement* a, // https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement auto const* lhs_ty = sem_.TypeOf(lhs); - if (auto* var = ResolvedSymbol(lhs)) { + if (auto* var = sem_.ResolvedSymbol(lhs)) { auto* decl = var->Declaration(); if (var->Is()) { AddError("cannot assign to function parameter", lhs->source); - AddNote("'" + builder_->Symbols().NameFor(decl->symbol) + - "' is declared here:", + AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source); return false; } if (decl->is_const) { AddError("cannot assign to const", lhs->source); - AddNote("'" + builder_->Symbols().NameFor(decl->symbol) + - "' is declared here:", + AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source); return false; } @@ -2366,25 +2459,23 @@ bool Resolver::ValidateAssignment(const ast::Statement* a, return true; } -bool Resolver::ValidateIncrementDecrementStatement( +bool Validator::IncrementDecrementStatement( const ast::IncrementDecrementStatement* inc) const { const ast::Expression* lhs = inc->lhs; // https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement - if (auto* var = ResolvedSymbol(lhs)) { + if (auto* var = sem_.ResolvedSymbol(lhs)) { auto* decl = var->Declaration(); if (var->Is()) { AddError("cannot modify function parameter", lhs->source); - AddNote("'" + builder_->Symbols().NameFor(decl->symbol) + - "' is declared here:", + AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source); return false; } if (decl->is_const) { AddError("cannot modify constant value", lhs->source); - AddNote("'" + builder_->Symbols().NameFor(decl->symbol) + - "' is declared here:", + AddNote("'" + symbols_.NameFor(decl->symbol) + "' is declared here:", decl->source); return false; } @@ -2415,7 +2506,7 @@ bool Resolver::ValidateIncrementDecrementStatement( return true; } -bool Resolver::ValidateNoDuplicateAttributes( +bool Validator::NoDuplicateAttributes( const ast::AttributeList& attributes) const { std::unordered_map seen; for (auto* d : attributes) { @@ -2429,8 +2520,8 @@ bool Resolver::ValidateNoDuplicateAttributes( return true; } -bool Resolver::IsValidationDisabled(const ast::AttributeList& attributes, - ast::DisabledValidation validation) const { +bool Validator::IsValidationDisabled(const ast::AttributeList& attributes, + ast::DisabledValidation validation) const { for (auto* attribute : attributes) { if (auto* dv = attribute->As()) { if (dv->validation == validation) { @@ -2441,15 +2532,15 @@ bool Resolver::IsValidationDisabled(const ast::AttributeList& attributes, return false; } -bool Resolver::IsValidationEnabled(const ast::AttributeList& attributes, - ast::DisabledValidation validation) const { +bool Validator::IsValidationEnabled(const ast::AttributeList& attributes, + ast::DisabledValidation validation) const { return !IsValidationDisabled(attributes, validation); } -std::string Resolver::VectorPretty(uint32_t size, - const sem::Type* element_type) const { +std::string Validator::VectorPretty(uint32_t size, + const sem::Type* element_type) const { sem::Vector vec_type(element_type, size); - return vec_type.FriendlyName(builder_->Symbols()); + return vec_type.FriendlyName(symbols_); } } // namespace tint::resolver diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h new file mode 100644 index 0000000000..146f7ba052 --- /dev/null +++ b/src/tint/resolver/validator.h @@ -0,0 +1,457 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_RESOLVER_VALIDATOR_H_ +#define SRC_TINT_RESOLVER_VALIDATOR_H_ + +#include +#include +#include +#include +#include +#include + +#include "src/tint/ast/pipeline_stage.h" +#include "src/tint/program_builder.h" +#include "src/tint/resolver/sem_helper.h" +#include "src/tint/source.h" + +// Forward declarations +namespace tint::ast { +class IndexAccessorExpression; +class BinaryExpression; +class BitcastExpression; +class CallExpression; +class CallStatement; +class CaseStatement; +class ForLoopStatement; +class Function; +class IdentifierExpression; +class LoopStatement; +class MemberAccessorExpression; +class ReturnStatement; +class SwitchStatement; +class UnaryOpExpression; +class Variable; +} // namespace tint::ast +namespace tint::sem { +class Array; +class Atomic; +class BlockStatement; +class Builtin; +class CaseStatement; +class ElseStatement; +class ForLoopStatement; +class IfStatement; +class LoopStatement; +class Statement; +class SwitchStatement; +class TypeConstructor; +} // namespace tint::sem + +namespace tint::resolver { + +/// Validation logic for various ast nodes. The validations in general should +/// be shallow and depend on the resolver to call on children. The validations +/// also assume that sem changes have already been made. The validation checks +/// should not alter the AST or SEM trees. +class Validator { + public: + /// The valid type storage layouts typedef + using ValidTypeStorageLayouts = + std::set>; + + /// Constructor + /// @param builder the program builder + /// @param helper the SEM helper to validate with + Validator(ProgramBuilder* builder, SemHelper& helper); + ~Validator(); + + /// Adds the given error message to the diagnostics + /// @param msg the error message + /// @param source the error source + void AddError(const std::string& msg, const Source& source) const; + + /// Adds the given warning message to the diagnostics + /// @param msg the warning message + /// @param source the warning source + void AddWarning(const std::string& msg, const Source& source) const; + + /// Adds the given note message to the diagnostics + /// @param msg the note message + /// @param source the note source + void AddNote(const std::string& msg, const Source& source) const; + + /// @param type the given type + /// @returns true if the given type is a plain type + bool IsPlain(const sem::Type* type) const; + + /// @param type the given type + /// @returns true if the given type is a fixed-footprint type + bool IsFixedFootprint(const sem::Type* type) const; + + /// @param type the given type + /// @returns true if the given type is storable + bool IsStorable(const sem::Type* type) const; + + /// @param type the given type + /// @returns true if the given type is host-shareable + bool IsHostShareable(const sem::Type* type) const; + + /// Validates pipeline stages + /// @param entry_points the entry points to the module + /// @returns true on success, false otherwise. + bool PipelineStages(const std::vector& entry_points) const; + + /// Validates aliases + /// @param alias the alias to validate + /// @returns true on success, false otherwise. + bool Alias(const ast::Alias* alias) const; + + /// Validates the array + /// @param arr the array to validate + /// @param source the source of the array + /// @returns true on success, false otherwise. + bool Array(const sem::Array* arr, const Source& source) const; + + /// Validates an array stride attribute + /// @param attr the stride attribute to validate + /// @param el_size the element size + /// @param el_align the element alignment + /// @param source the source of the attribute + /// @returns true on success, false otherwise + bool ArrayStrideAttribute(const ast::StrideAttribute* attr, + uint32_t el_size, + uint32_t el_align, + const Source& source) const; + + /// Validates an atomic + /// @param a the atomic ast node to validate + /// @param s the atomic sem node + /// @returns true on success, false otherwise. + bool Atomic(const ast::Atomic* a, const sem::Atomic* s) const; + + /// Validates an atoic variable + /// @param var the variable to validate + /// @param atomic_composite_info store atomic information + /// @returns true on success, false otherwise. + bool AtomicVariable(const sem::Variable* var, + std::unordered_map + atomic_composite_info) const; + + /// Validates an assignment + /// @param a the assignment statement + /// @param rhs_ty the type of the right hand side + /// @returns true on success, false otherwise. + bool Assignment(const ast::Statement* a, const sem::Type* rhs_ty) const; + + /// Validates a bitcase + /// @param cast the bitcast expression + /// @param to the destination type + /// @returns true on success, false otherwise + bool Bitcast(const ast::BitcastExpression* cast, const sem::Type* to) const; + + /// Validates a break statement + /// @param stmt the break statement to validate + /// @param current_statement the current statement being resolved + /// @returns true on success, false otherwise. + bool BreakStatement(const sem::Statement* stmt, + sem::Statement* current_statement) const; + + /// Validates a builtin attribute + /// @param attr the attribute to validate + /// @param storage_type the attribute storage type + /// @param stage the current pipeline stage + /// @param is_input true if this is an input attribute + /// @returns true on success, false otherwise. + bool BuiltinAttribute(const ast::BuiltinAttribute* attr, + const sem::Type* storage_type, + ast::PipelineStage stage, + const bool is_input) const; + + /// Validates a continue statement + /// @param stmt the continue statement to validate + /// @param current_statement the current statement being resolved + /// @returns true on success, false otherwise + bool ContinueStatement(const sem::Statement* stmt, + sem::Statement* current_statement) const; + + /// Validates a discard statement + /// @param stmt the statement to validate + /// @param current_statement the current statement being resolved + /// @returns true on success, false otherwise + bool DiscardStatement(const sem::Statement* stmt, + sem::Statement* current_statement) const; + + /// Validates an else statement + /// @param stmt the else statement to validate + /// @returns true on success, false otherwise + bool ElseStatement(const sem::ElseStatement* stmt) const; + + /// Validates an entry point + /// @param func the entry point function to validate + /// @param stage the pipeline stage for the entry point + /// @returns true on success, false otherwise + bool EntryPoint(const sem::Function* func, ast::PipelineStage stage) const; + + /// Validates a for loop + /// @param stmt the for loop statement to validate + /// @returns true on success, false otherwise + bool ForLoopStatement(const sem::ForLoopStatement* stmt) const; + + /// Validates a fallthrough statement + /// @param stmt the fallthrough to validate + /// @returns true on success, false otherwise + bool FallthroughStatement(const sem::Statement* stmt) const; + + /// Validates a function + /// @param func the function to validate + /// @param stage the current pipeline stage + /// @returns true on success, false otherwise. + bool Function(const sem::Function* func, ast::PipelineStage stage) const; + + /// Validates a function call + /// @param call the function call to validate + /// @param current_statement the current statement being resolved + /// @returns true on success, false otherwise + bool FunctionCall(const sem::Call* call, + sem::Statement* current_statement) const; + + /// Validates a global variable + /// @param var the global variable to validate + /// @param constant_ids the set of constant ids in the module + /// @param atomic_composite_info atomic composite info in the module + /// @returns true on success, false otherwise + bool GlobalVariable( + const sem::Variable* var, + std::unordered_map constant_ids, + std::unordered_map atomic_composite_info) + const; + + /// Validates an if statement + /// @param stmt the statement to validate + /// @returns true on success, false otherwise + bool IfStatement(const sem::IfStatement* stmt) const; + + /// Validates an increment or decrement statement + /// @param stmt the statement to validate + /// @returns true on success, false otherwise + bool IncrementDecrementStatement( + const ast::IncrementDecrementStatement* stmt) const; + + /// Validates an interpolate attribute + /// @param attr the interpolation attribute to validate + /// @param storage_type the storage type of the attached variable + /// @returns true on succes, false otherwise + bool InterpolateAttribute(const ast::InterpolateAttribute* attr, + const sem::Type* storage_type) const; + + /// Validates a builtin call + /// @param call the builtin call to validate + /// @returns true on success, false otherwise. + bool BuiltinCall(const sem::Call* call) const; + + /// Validates a location attribute + /// @param location the location attribute to validate + /// @param type the variable type + /// @param locations the set of locations in the module + /// @param stage the current pipeline stage + /// @param source the source of the attribute + /// @param is_input true if this is an input variable + /// @returns true on success, false otherwise. + bool LocationAttribute(const ast::LocationAttribute* location, + const sem::Type* type, + std::unordered_set& locations, + ast::PipelineStage stage, + const Source& source, + const bool is_input = false) const; + + /// Validates a loop statement + /// @param stmt the loop statement + /// @returns true on success, false otherwise. + bool LoopStatement(const sem::LoopStatement* stmt) const; + + /// Validates a matrix + /// @param ty the matrix to validate + /// @param source the source of the matrix + /// @returns true on success, false otherwise + bool Matrix(const sem::Matrix* ty, const Source& source) const; + + /// Validates a function parameter + /// @param func the function the variable is for + /// @param var the variable to validate + /// @returns true on success, false otherwise + bool FunctionParameter(const ast::Function* func, + const sem::Variable* var) const; + + /// Validates a return + /// @param ret the return statement to validate + /// @param func_type the return type of the curreunt function + /// @param ret_type the return type + /// @param current_statement the current statement being resolved + /// @returns true on success, false otherwise + bool Return(const ast::ReturnStatement* ret, + const sem::Type* func_type, + const sem::Type* ret_type, + sem::Statement* current_statement) const; + + /// Validates a list of statements + /// @param stmts the statements to validate + /// @returns true on success, false otherwise + bool Statements(const ast::StatementList& stmts) const; + + /// Validates a storage texture + /// @param t the texture to validate + /// @returns true on success, false otherwise + bool StorageTexture(const ast::StorageTexture* t) const; + + /// Validates a structure + /// @param str the structure to validate + /// @param stage the current pipeline stage + /// @returns true on success, false otherwise. + bool Structure(const sem::Struct* str, ast::PipelineStage stage) const; + + /// Validates a structure constructor or cast + /// @param ctor the call expression to validate + /// @param struct_type the type of the structure + /// @returns true on success, false otherwise + bool StructureConstructorOrCast(const ast::CallExpression* ctor, + const sem::Struct* struct_type) const; + + /// Validates a switch statement + /// @param s the switch to validate + /// @returns true on success, false otherwise + bool SwitchStatement(const ast::SwitchStatement* s); + + /// Validates a variable + /// @param var the variable to validate + /// @returns true on success, false otherwise. + bool Variable(const sem::Variable* var) const; + + /// Validates a variable constructor or cast + /// @param var the variable to validate + /// @param storage_class the storage class of the variable + /// @param storage_type the type of the storage + /// @param rhs_type the right hand side of the expression + /// @returns true on succes, false otherwise + bool VariableConstructorOrCast(const ast::Variable* var, + ast::StorageClass storage_class, + const sem::Type* storage_type, + const sem::Type* rhs_type) const; + + /// Validates a vector + /// @param ty the vector to validate + /// @param source the source of the vector + /// @returns true on success, false otherwise + bool Vector(const sem::Vector* ty, const Source& source) const; + + /// Validates a vector constructor or cast + /// @param ctor the call expression to validate + /// @param vec_type the vector type + /// @returns true on success, false otherwise + bool VectorConstructorOrCast(const ast::CallExpression* ctor, + const sem::Vector* vec_type) const; + + /// Validates a matrix constructor or cast + /// @param ctor the call expression to validate + /// @param matrix_type the type of the matrix + /// @returns true on success, false otherwise + bool MatrixConstructorOrCast(const ast::CallExpression* ctor, + const sem::Matrix* matrix_type) const; + + /// Validates a scalar constructor or cast + /// @param ctor the call expression to validate + /// @param type the type of the scalar + /// @returns true on success, false otherwise. + bool ScalarConstructorOrCast(const ast::CallExpression* ctor, + const sem::Type* type) const; + + /// Validates an array constructor or cast + /// @param ctor the call expresion to validate + /// @param arr_type the type of the array + /// @returns true on success, false otherwise + bool ArrayConstructorOrCast(const ast::CallExpression* ctor, + const sem::Array* arr_type) const; + + /// Validates a texture builtin function + /// @param call the builtin call to validate + /// @returns true on success, false otherwise + bool TextureBuiltinFunction(const sem::Call* call) const; + + /// Validates there are no duplicate attributes + /// @param attributes the list of attributes to validate + /// @returns true on success, false otherwise. + bool NoDuplicateAttributes(const ast::AttributeList& attributes) const; + + /// Validates a storage class layout + /// @param type the type to validate + /// @param sc the storage class + /// @param source the source of the type + /// @param layouts previously validated storage layouts + /// @returns true on success, false otherwise + bool StorageClassLayout(const sem::Type* type, + ast::StorageClass sc, + Source source, + ValidTypeStorageLayouts& layouts) const; + + /// Validates a storage class layout + /// @param var the variable to validate + /// @param layouts previously validated storage layouts + /// @returns true on success, false otherwise. + bool StorageClassLayout(const sem::Variable* var, + ValidTypeStorageLayouts& layouts) const; + + /// @returns true if the attribute list contains a + /// ast::DisableValidationAttribute with the validation mode equal to + /// `validation` + /// @param attributes the attribute list to check + /// @param validation the validation mode to check + bool IsValidationDisabled(const ast::AttributeList& attributes, + ast::DisabledValidation validation) const; + + /// @returns true if the attribute list does not contains a + /// ast::DisableValidationAttribute with the validation mode equal to + /// `validation` + /// @param attributes the attribute list to check + /// @param validation the validation mode to check + bool IsValidationEnabled(const ast::AttributeList& attributes, + ast::DisabledValidation validation) const; + + private: + /// Searches the current statement and up through parents of the current + /// statement looking for a loop or for-loop continuing statement. + /// @returns the closest continuing statement to the current statement that + /// (transitively) owns the current statement. + /// @param stop_at_loop if true then the function will return nullptr if a + /// loop or for-loop was found before the continuing. + /// @param current_statement the current statement being resolved + const ast::Statement* ClosestContinuing( + bool stop_at_loop, + sem::Statement* current_statement) const; + + /// Returns a human-readable string representation of the vector type name + /// with the given parameters. + /// @param size the vector dimension + /// @param element_type scalar vector sub-element type + /// @return pretty string representation + std::string VectorPretty(uint32_t size, const sem::Type* element_type) const; + + SymbolTable& symbols_; + diag::List& diagnostics_; + SemHelper& sem_; +}; + +} // namespace tint::resolver + +#endif // SRC_TINT_RESOLVER_VALIDATOR_H_ diff --git a/src/tint/resolver/validator_is_storeable_test.cc b/src/tint/resolver/validator_is_storeable_test.cc new file mode 100644 index 0000000000..f180936e87 --- /dev/null +++ b/src/tint/resolver/validator_is_storeable_test.cc @@ -0,0 +1,86 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/resolver/validator.h" + +#include "gmock/gmock.h" +#include "src/tint/resolver/validator_test_helper.h" +#include "src/tint/sem/atomic_type.h" + +namespace tint::resolver { +namespace { + +using ValidatorIsStorableTest = ValidatorTest; + +TEST_F(ValidatorIsStorableTest, Void) { + EXPECT_FALSE(v()->IsStorable(create())); +} + +TEST_F(ValidatorIsStorableTest, Scalar) { + EXPECT_TRUE(v()->IsStorable(create())); + EXPECT_TRUE(v()->IsStorable(create())); + EXPECT_TRUE(v()->IsStorable(create())); + EXPECT_TRUE(v()->IsStorable(create())); +} + +TEST_F(ValidatorIsStorableTest, Vector) { + EXPECT_TRUE(v()->IsStorable(create(create(), 2u))); + EXPECT_TRUE(v()->IsStorable(create(create(), 3u))); + EXPECT_TRUE(v()->IsStorable(create(create(), 4u))); + EXPECT_TRUE(v()->IsStorable(create(create(), 2u))); + EXPECT_TRUE(v()->IsStorable(create(create(), 3u))); + EXPECT_TRUE(v()->IsStorable(create(create(), 4u))); + EXPECT_TRUE(v()->IsStorable(create(create(), 2u))); + EXPECT_TRUE(v()->IsStorable(create(create(), 3u))); + EXPECT_TRUE(v()->IsStorable(create(create(), 4u))); +} + +TEST_F(ValidatorIsStorableTest, Matrix) { + auto* vec2 = create(create(), 2u); + auto* vec3 = create(create(), 3u); + auto* vec4 = create(create(), 4u); + EXPECT_TRUE(v()->IsStorable(create(vec2, 2u))); + EXPECT_TRUE(v()->IsStorable(create(vec2, 3u))); + EXPECT_TRUE(v()->IsStorable(create(vec2, 4u))); + EXPECT_TRUE(v()->IsStorable(create(vec3, 2u))); + EXPECT_TRUE(v()->IsStorable(create(vec3, 3u))); + EXPECT_TRUE(v()->IsStorable(create(vec3, 4u))); + EXPECT_TRUE(v()->IsStorable(create(vec4, 2u))); + EXPECT_TRUE(v()->IsStorable(create(vec4, 3u))); + EXPECT_TRUE(v()->IsStorable(create(vec4, 4u))); +} + +TEST_F(ValidatorIsStorableTest, Pointer) { + auto* ptr = create( + create(), ast::StorageClass::kPrivate, ast::Access::kReadWrite); + EXPECT_FALSE(v()->IsStorable(ptr)); +} + +TEST_F(ValidatorIsStorableTest, Atomic) { + EXPECT_TRUE(v()->IsStorable(create(create()))); + EXPECT_TRUE(v()->IsStorable(create(create()))); +} + +TEST_F(ValidatorIsStorableTest, ArraySizedOfStorable) { + auto* arr = create(create(), 5u, 4u, 20u, 4u, 4u); + EXPECT_TRUE(v()->IsStorable(arr)); +} + +TEST_F(ValidatorIsStorableTest, ArrayUnsizedOfStorable) { + auto* arr = create(create(), 0u, 4u, 4u, 4u, 4u); + EXPECT_TRUE(v()->IsStorable(arr)); +} + +} // namespace +} // namespace tint::resolver diff --git a/src/tint/resolver/validator_test_helper.cc b/src/tint/resolver/validator_test_helper.cc new file mode 100644 index 0000000000..123d5c7c53 --- /dev/null +++ b/src/tint/resolver/validator_test_helper.cc @@ -0,0 +1,27 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/resolver/validator_test_helper.h" + +#include + +namespace tint::resolver { + +TestHelper::TestHelper() + : validator_( + std::make_unique(this->Symbols(), this->Diagnostics())) {} + +TestHelper::~TestHelper() = default; + +} // namespace tint::resolver diff --git a/src/tint/resolver/validator_test_helper.h b/src/tint/resolver/validator_test_helper.h new file mode 100644 index 0000000000..3dd3ea2a05 --- /dev/null +++ b/src/tint/resolver/validator_test_helper.h @@ -0,0 +1,46 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_RESOLVER_VALIDATOR_TEST_HELPER_H_ +#define SRC_TINT_RESOLVER_VALIDATOR_TEST_HELPER_H_ + +#include + +#include "gtest/gtest.h" +#include "src/tint/program_builder.h" +#include "src/tint/resolver/validator.h" + +namespace tint::resolver { + +/// Helper class for testing +class TestHelper : public ProgramBuilder { + public: + /// Constructor + TestHelper(); + + /// Destructor + ~TestHelper() override; + + /// @return a pointer to the Validator + Validator* v() const { return validator_.get(); } + + private: + std::unique_ptr validator_; +}; + +class ValidatorTest : public TestHelper, public testing::Test {}; + +} // namespace tint::resolver + +#endif // SRC_TINT_RESOLVER_VALIDATOR_TEST_HELPER_H_