diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 5594bf87da..980b16d0ea 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -377,6 +377,8 @@ libtint_source_set("libtint_core_all_src") { "resolver/resolver.h", "resolver/resolver_constants.cc", "resolver/resolver_validation.cc", + "resolver/sem_helper.cc", + "resolver/sem_helper.h", "scope_stack.h", "sem/array.h", "sem/atomic_type.h", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 527eeafa42..ae3f2240ff 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -258,6 +258,8 @@ set(TINT_LIB_SRCS resolver/resolver_constants.cc resolver/resolver_validation.cc resolver/resolver.h + resolver/sem_helper.cc + resolver/sem_helper.h scope_stack.h sem/array.cc sem/array.h diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 56e59009a6..4f7c08d2df 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -84,7 +84,8 @@ namespace tint::resolver { Resolver::Resolver(ProgramBuilder* builder) : builder_(builder), diagnostics_(builder->Diagnostics()), - builtin_table_(BuiltinTable::Create(*builder)) {} + builtin_table_(BuiltinTable::Create(*builder)), + sem_(builder) {} Resolver::~Resolver() = default; @@ -505,7 +506,7 @@ void Resolver::AllocateOverridableConstantIds() { next_constant_id = constant_id + 1; } - auto* sem = Sem(var); + auto* sem = sem_.Get(var); const_cast(sem)->SetConstantId(constant_id); } } @@ -513,9 +514,11 @@ void Resolver::AllocateOverridableConstantIds() { void Resolver::SetShadows() { for (auto it : dependencies_.shadows) { Switch( - Sem(it.first), // - [&](sem::LocalVariable* local) { local->SetShadows(Sem(it.second)); }, - [&](sem::Parameter* param) { param->SetShadows(Sem(it.second)); }); + sem_.Get(it.first), // + [&](sem::LocalVariable* local) { + local->SetShadows(sem_.Get(it.second)); + }, + [&](sem::Parameter* param) { param->SetShadows(sem_.Get(it.second)); }); } } @@ -755,7 +758,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { "workgroup_size arguments must be of the same type, either i32 " "or u32"; - auto* ty = TypeOf(expr); + auto* ty = sem_.TypeOf(expr); bool is_i32 = ty->UnwrapRef()->Is(); bool is_u32 = ty->UnwrapRef()->Is(); if (!is_i32 && !is_u32) { @@ -772,7 +775,7 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { sem::Constant value; - if (auto* user = Sem(expr)->As()) { + if (auto* user = sem_.Get(expr)->As()) { // We have an variable of a module-scope constant. auto* decl = user->Variable()->Declaration(); if (!decl->is_const) { @@ -785,14 +788,14 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { } if (decl->constructor) { - value = Sem(decl->constructor)->ConstantValue(); + value = sem_.Get(decl->constructor)->ConstantValue(); } else { // No constructor means this value must be overriden by the user. ws[i].value = 0; continue; } } else if (expr->Is()) { - value = Sem(expr)->ConstantValue(); + value = sem_.Get(expr)->ConstantValue(); } else { AddError( "workgroup_size argument must be either a literal or a " @@ -1168,8 +1171,8 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) { sem::Expression* Resolver::IndexAccessor( const ast::IndexAccessorExpression* expr) { - auto* idx = Sem(expr->index); - auto* obj = Sem(expr->object); + auto* idx = sem_.Get(expr->index); + auto* obj = sem_.Get(expr->object); auto* obj_raw_ty = obj->Type(); auto* obj_ty = obj_raw_ty->UnwrapRef(); auto* ty = Switch( @@ -1180,7 +1183,7 @@ sem::Expression* Resolver::IndexAccessor( return builder_->create(mat->type(), mat->rows()); }, [&](Default) { - AddError("cannot index type '" + TypeNameOf(obj_ty) + "'", + AddError("cannot index type '" + sem_.TypeNameOf(obj_ty) + "'", expr->source); return nullptr; }); @@ -1191,7 +1194,7 @@ sem::Expression* Resolver::IndexAccessor( auto* idx_ty = idx->Type()->UnwrapRef(); if (!idx_ty->IsAnyOf()) { AddError("index must be of type 'i32' or 'u32', found: '" + - TypeNameOf(idx_ty) + "'", + sem_.TypeNameOf(idx_ty) + "'", idx->Declaration()->source); return nullptr; } @@ -1211,7 +1214,7 @@ sem::Expression* Resolver::IndexAccessor( } sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) { - auto* inner = Sem(expr->expr); + auto* inner = sem_.Get(expr->expr); auto* ty = Type(expr->type); if (!ty) { return nullptr; @@ -1240,7 +1243,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { const sem::Type* arg_el_ty = nullptr; for (size_t i = 0; i < expr->args.size(); i++) { - auto* arg = Sem(expr->args[i]); + auto* arg = sem_.Get(expr->args[i]); if (!arg) { return nullptr; } @@ -1637,7 +1640,7 @@ sem::Call* Resolver::TypeConstructor( } sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) { - auto* ty = TypeOf(literal); + auto* ty = sem_.TypeOf(literal); if (!ty) { return nullptr; } @@ -1725,14 +1728,14 @@ sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) { sem::Expression* Resolver::MemberAccessor( const ast::MemberAccessorExpression* expr) { - auto* structure = TypeOf(expr->structure); + auto* structure = sem_.TypeOf(expr->structure); auto* storage_ty = structure->UnwrapRef(); const sem::Type* ret = nullptr; std::vector swizzle; // Structure may be a side-effecting expression (e.g. function call). - auto* sem_structure = Sem(expr->structure); + auto* sem_structure = sem_.Get(expr->structure); bool has_side_effects = sem_structure && sem_structure->HasSideEffects(); if (auto* str = storage_ty->As()) { @@ -1840,14 +1843,14 @@ sem::Expression* Resolver::MemberAccessor( AddError( "invalid member accessor expression. Expected vector or struct, got '" + - TypeNameOf(storage_ty) + "'", + sem_.TypeNameOf(storage_ty) + "'", expr->structure->source); return nullptr; } sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { - auto* lhs = Sem(expr->lhs); - auto* rhs = Sem(expr->rhs); + auto* lhs = sem_.Get(expr->lhs); + auto* rhs = sem_.Get(expr->rhs); auto* lhs_ty = lhs->Type()->UnwrapRef(); auto* rhs_ty = rhs->Type()->UnwrapRef(); @@ -1855,8 +1858,8 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { if (!ty) { AddError( "Binary expression operand types are invalid for this operation: " + - TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " + - TypeNameOf(rhs_ty), + sem_.TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " + + sem_.TypeNameOf(rhs_ty), expr->source); return nullptr; } @@ -2036,7 +2039,7 @@ const sem::Type* Resolver::BinaryOpType(const sem::Type* lhs_ty, } sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { - auto* expr = Sem(unary->expr); + auto* expr = sem_.Get(unary->expr); auto* expr_ty = expr->Type(); if (!expr_ty) { return nullptr; @@ -2049,9 +2052,9 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { // Result type matches the deref'd inner type. ty = expr_ty->UnwrapRef(); if (!ty->Is() && !ty->is_bool_vector()) { - AddError( - "cannot logical negate expression of type '" + TypeNameOf(expr_ty), - unary->expr->source); + AddError("cannot logical negate expression of type '" + + sem_.TypeNameOf(expr_ty), + unary->expr->source); return nullptr; } break; @@ -2061,7 +2064,7 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { ty = expr_ty->UnwrapRef(); if (!ty->is_integer_scalar_or_vector()) { AddError("cannot bitwise complement expression of type '" + - TypeNameOf(expr_ty), + sem_.TypeNameOf(expr_ty), unary->expr->source); return nullptr; } @@ -2072,8 +2075,9 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { ty = expr_ty->UnwrapRef(); if (!(ty->IsAnyOf() || ty->is_signed_integer_vector() || ty->is_float_vector())) { - AddError("cannot negate expression of type '" + TypeNameOf(expr_ty), - unary->expr->source); + AddError( + "cannot negate expression of type '" + sem_.TypeNameOf(expr_ty), + unary->expr->source); return nullptr; } break; @@ -2089,9 +2093,10 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { auto* array = unary->expr->As(); auto* member = unary->expr->As(); - if ((array && TypeOf(array->object)->UnwrapRef()->Is()) || + if ((array && + sem_.TypeOf(array->object)->UnwrapRef()->Is()) || (member && - TypeOf(member->structure)->UnwrapRef()->Is())) { + sem_.TypeOf(member->structure)->UnwrapRef()->Is())) { AddError("cannot take the address of a vector component", unary->expr->source); return nullptr; @@ -2111,7 +2116,7 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { ptr->StoreType(), ptr->StorageClass(), ptr->Access()); } else { AddError("cannot dereference expression of type '" + - TypeNameOf(expr_ty) + "'", + sem_.TypeNameOf(expr_ty) + "'", unary->expr->source); return nullptr; } @@ -2143,41 +2148,6 @@ sem::Type* Resolver::TypeDecl(const ast::TypeDecl* named_type) { return result; } -std::string Resolver::TypeNameOf(const sem::Type* ty) const { - return RawTypeNameOf(ty->UnwrapRef()); -} - -std::string Resolver::RawTypeNameOf(const sem::Type* ty) const { - return ty->FriendlyName(builder_->Symbols()); -} - -sem::Type* Resolver::TypeOf(const ast::Expression* expr) const { - auto* sem = Sem(expr); - return sem ? const_cast(sem->Type()) : nullptr; -} - -sem::Type* Resolver::TypeOf(const ast::LiteralExpression* lit) { - return Switch( - lit, - [&](const ast::SintLiteralExpression*) { - return builder_->create(); - }, - [&](const ast::UintLiteralExpression*) { - return builder_->create(); - }, - [&](const ast::FloatLiteralExpression*) { - return builder_->create(); - }, - [&](const ast::BoolLiteralExpression*) { - return builder_->create(); - }, - [&](Default) { - TINT_UNREACHABLE(Resolver, diagnostics_) - << "Unhandled literal type: " << lit->TypeInfo().name; - return nullptr; - }); -} - sem::Array* Resolver::Array(const ast::Array* arr) { auto source = arr->source; @@ -2187,7 +2157,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) { } if (!IsPlain(elem_type)) { // Check must come before GetDefaultAlignAndSize() - AddError(TypeNameOf(elem_type) + + AddError(sem_.TypeNameOf(elem_type) + " cannot be used as an element type of an array", source); return nullptr; @@ -2367,7 +2337,7 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) { // Validate member type if (!IsPlain(type)) { - AddError(TypeNameOf(type) + + AddError(sem_.TypeNameOf(type) + " cannot be used as the type of a structure member", member->source); return nullptr; @@ -2507,7 +2477,7 @@ sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) { // Validate after processing the return value expression so that its type // is available for validation. - auto* ret_type = stmt->value ? TypeOf(stmt->value)->UnwrapRef() + auto* ret_type = stmt->value ? sem_.TypeOf(stmt->value)->UnwrapRef() : builder_->create(); return ValidateReturn(stmt, current_function_->ReturnType(), ret_type); }); @@ -2597,7 +2567,7 @@ sem::Statement* Resolver::AssignmentStatement( behaviors.Add(lhs->Behaviors()); } - return ValidateAssignment(stmt, TypeOf(stmt->rhs)); + return ValidateAssignment(stmt, sem_.TypeOf(stmt->rhs)); }); } @@ -2645,8 +2615,8 @@ sem::Statement* Resolver::CompoundAssignmentStatement( auto* ty = BinaryOpType(lhs_ty, rhs_ty, stmt->op); if (!ty) { AddError("compound assignment operand types are invalid: " + - TypeNameOf(lhs_ty) + " " + FriendlyName(stmt->op) + " " + - TypeNameOf(rhs_ty), + sem_.TypeNameOf(lhs_ty) + " " + FriendlyName(stmt->op) + + " " + sem_.TypeNameOf(rhs_ty), stmt->source); return false; } @@ -2725,7 +2695,8 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, for (auto* member : str->Members()) { if (!ApplyStorageClassUsageToType(sc, member->Type(), usage)) { std::stringstream err; - err << "while analysing structure member " << TypeNameOf(str) << "." + err << "while analysing structure member " << sem_.TypeNameOf(str) + << "." << builder_->Symbols().NameFor(member->Declaration()->symbol); AddNote(err.str(), member->Declaration()->source); return false; @@ -2749,8 +2720,9 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, if (ast::IsHostShareable(sc) && !IsHostShareable(ty)) { std::stringstream err; - err << "Type '" << TypeNameOf(ty) << "' cannot be used in storage class '" - << sc << "' as it is non-host-shareable"; + err << "Type '" << sem_.TypeNameOf(ty) + << "' cannot be used in storage class '" << sc + << "' as it is non-host-shareable"; AddError(err.str(), usage); return false; } diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 7531585880..95dd5ffc68 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -26,6 +26,7 @@ #include "src/tint/builtin_table.h" #include "src/tint/program_builder.h" #include "src/tint/resolver/dependency_graph.h" +#include "src/tint/resolver/sem_helper.h" #include "src/tint/scope_stack.h" #include "src/tint/sem/binding_point.h" #include "src/tint/sem/block_statement.h" @@ -405,23 +406,6 @@ class Resolver { /// Set the shadowing information on variable declarations. /// @note this method must only be called after all semantic nodes are built. void SetShadows(); - - /// @returns the resolved type of the ast::Expression `expr` - /// @param expr the expression - sem::Type* TypeOf(const ast::Expression* expr) const; - - /// @returns the type name of the given semantic type, unwrapping - /// references. - std::string TypeNameOf(const sem::Type* ty) const; - - /// @returns the type name of the given semantic type, without unwrapping - /// references. - std::string RawTypeNameOf(const sem::Type* ty) const; - - /// @returns the semantic type of the AST literal `lit` - /// @param lit the literal - sem::Type* TypeOf(const ast::LiteralExpression* lit); - /// StatementScope() does the following: /// * Creates the AST -> SEM mapping. /// * Assigns `sem` to #current_statement_ @@ -467,21 +451,6 @@ class Resolver { sem::Constant EvaluateConstantValue(const ast::CallExpression* call, const sem::Type* type); - /// Sem is a helper for obtaining the semantic node for the given AST node. - template - auto* Sem(const AST_OR_TYPE* ast) const { - using T = sem::Info::GetResultType; - auto* sem = builder_->Sem().Get(ast); - if (!sem) { - TINT_ICE(Resolver, diagnostics_) - << "AST node '" << ast->TypeInfo().name << "' had no semantic info\n" - << "At: " << ast->source << "\n" - << "Pointer: " << ast; - } - return const_cast(As(sem)); - } - /// @returns true if the symbol is the name of a builtin function. bool IsBuiltin(Symbol) const; @@ -541,6 +510,7 @@ class Resolver { diag::List& diagnostics_; std::unique_ptr const builtin_table_; DependencyGraph dependencies_; + SemHelper sem_; std::vector entry_points_; std::unordered_map atomic_composite_info_; std::unordered_set marked_; diff --git a/src/tint/resolver/resolver_validation.cc b/src/tint/resolver/resolver_validation.cc index 04b9925f0a..44e0a2f078 100644 --- a/src/tint/resolver/resolver_validation.cc +++ b/src/tint/resolver/resolver_validation.cc @@ -201,8 +201,8 @@ bool Resolver::ValidateVariableConstructorOrCast( if (storage_ty != value_type) { std::string decl = var->is_const ? "let" : "var"; AddError("cannot initialize " + decl + " of type '" + - TypeNameOf(storage_ty) + "' with value of type '" + - TypeNameOf(rhs_ty) + "'", + sem_.TypeNameOf(storage_ty) + "' with value of type '" + + sem_.TypeNameOf(rhs_ty) + "'", var->source); return false; } @@ -536,9 +536,9 @@ bool Resolver::ValidateAtomicVariable(const sem::Variable* var) const { AddError( "atomic variables must have or storage class", source); - AddNote( - "atomic sub-type of '" + TypeNameOf(type) + "' is declared here", - found->second); + AddNote("atomic sub-type of '" + sem_.TypeNameOf(type) + + "' is declared here", + found->second); return false; } else if (sc == ast::StorageClass::kStorage && access != ast::Access::kReadWrite) { @@ -546,9 +546,9 @@ bool Resolver::ValidateAtomicVariable(const sem::Variable* var) const { "atomic variables in storage class must have read_write " "access mode", source); - AddNote( - "atomic sub-type of '" + TypeNameOf(type) + "' is declared here", - found->second); + AddNote("atomic sub-type of '" + sem_.TypeNameOf(type) + + "' is declared here", + found->second); return false; } } @@ -575,15 +575,17 @@ bool Resolver::ValidateVariable(const sem::Variable* var) const { } if (!decl->is_const && !IsStorable(storage_ty)) { - AddError(TypeNameOf(storage_ty) + " cannot be used as the type of a var", - decl->source); + AddError( + sem_.TypeNameOf(storage_ty) + " cannot be used as the type of a var", + decl->source); return false; } if (decl->is_const && !var->Is() && !(storage_ty->IsConstructible() || storage_ty->Is())) { - AddError(TypeNameOf(storage_ty) + " cannot be used as the type of a let", - decl->source); + AddError( + sem_.TypeNameOf(storage_ty) + " cannot be used as the type of a let", + decl->source); return false; } @@ -616,7 +618,7 @@ bool Resolver::ValidateVariable(const sem::Variable* var) const { // If the store type is a texture type or a sampler type, then the // variable declaration must not have a storage class attribute. The // storage class will always be handle. - AddError("variables of type '" + TypeNameOf(storage_ty) + + AddError("variables of type '" + sem_.TypeNameOf(storage_ty) + "' must not have a storage class", decl->source); return false; @@ -686,9 +688,9 @@ bool Resolver::ValidateFunctionParameter(const ast::Function* func, } } else if (!var->Type() ->IsAnyOf()) { - AddError( - "store type of function parameter cannot be " + TypeNameOf(var->Type()), - decl->source); + AddError("store type of function parameter cannot be " + + sem_.TypeNameOf(var->Type()), + decl->source); return false; } @@ -886,7 +888,7 @@ bool Resolver::ValidateFunction(const sem::Function* func, if (decl->body) { sem::Behaviors behaviors{sem::Behavior::kNext}; if (auto* last = decl->body->Last()) { - behaviors = Sem(last)->Behaviors(); + behaviors = sem_.Get(last)->Behaviors(); } if (behaviors.Contains(sem::Behavior::kNext)) { AddError("missing return at end of function", decl->source); @@ -1222,7 +1224,7 @@ bool Resolver::ValidateEntryPoint(const sem::Function* func, bool Resolver::ValidateStatements(const ast::StatementList& stmts) const { for (auto* stmt : stmts) { - if (!Sem(stmt)->IsReachable()) { + if (!sem_.Get(stmt)->IsReachable()) { /// TODO(https://github.com/gpuweb/gpuweb/issues/2378): This may need to /// become an error. AddWarning("code is unreachable", stmt->source); @@ -1234,14 +1236,15 @@ bool Resolver::ValidateStatements(const ast::StatementList& stmts) const { bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to) const { - auto* from = TypeOf(cast->expr)->UnwrapRef(); + auto* from = sem_.TypeOf(cast->expr)->UnwrapRef(); if (!from->is_numeric_scalar_or_vector()) { - AddError("'" + TypeNameOf(from) + "' cannot be bitcast", + AddError("'" + sem_.TypeNameOf(from) + "' cannot be bitcast", cast->expr->source); return false; } if (!to->is_numeric_scalar_or_vector()) { - AddError("cannot bitcast to '" + TypeNameOf(to) + "'", cast->type->source); + AddError("cannot bitcast to '" + sem_.TypeNameOf(to) + "'", + cast->type->source); return false; } @@ -1253,8 +1256,8 @@ bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast, }; if (width(from) != width(to)) { - AddError("cannot bitcast from '" + TypeNameOf(from) + "' to '" + - TypeNameOf(to) + "'", + AddError("cannot bitcast from '" + sem_.TypeNameOf(from) + "' to '" + + sem_.TypeNameOf(to) + "'", cast->source); return false; } @@ -1389,9 +1392,9 @@ bool Resolver::ValidateElseStatement(const sem::ElseStatement* stmt) const { if (auto* cond = stmt->Condition()) { auto* cond_ty = cond->Type()->UnwrapRef(); if (!cond_ty->Is()) { - AddError( - "else statement condition must be bool, got " + TypeNameOf(cond_ty), - stmt->Condition()->Declaration()->source); + AddError("else statement condition must be bool, got " + + sem_.TypeNameOf(cond_ty), + stmt->Condition()->Declaration()->source); return false; } } @@ -1415,8 +1418,9 @@ bool Resolver::ValidateForLoopStatement( if (auto* cond = stmt->Condition()) { auto* cond_ty = cond->Type()->UnwrapRef(); if (!cond_ty->Is()) { - AddError("for-loop condition must be bool, got " + TypeNameOf(cond_ty), - stmt->Condition()->Declaration()->source); + AddError( + "for-loop condition must be bool, got " + sem_.TypeNameOf(cond_ty), + stmt->Condition()->Declaration()->source); return false; } } @@ -1426,8 +1430,9 @@ bool Resolver::ValidateForLoopStatement( bool Resolver::ValidateIfStatement(const sem::IfStatement* stmt) const { auto* cond_ty = stmt->Condition()->Type()->UnwrapRef(); if (!cond_ty->Is()) { - AddError("if statement condition must be bool, got " + TypeNameOf(cond_ty), - stmt->Condition()->Declaration()->source); + AddError( + "if statement condition must be bool, got " + sem_.TypeNameOf(cond_ty), + stmt->Condition()->Declaration()->source); return false; } return true; @@ -1556,13 +1561,13 @@ bool Resolver::ValidateFunctionCall(const sem::Call* call) const { const sem::Variable* param = target->Parameters()[i]; const ast::Expression* arg_expr = decl->args[i]; auto* param_type = param->Type(); - auto* arg_type = TypeOf(arg_expr)->UnwrapRef(); + auto* arg_type = sem_.TypeOf(arg_expr)->UnwrapRef(); if (param_type != arg_type) { AddError("type mismatch for argument " + std::to_string(i + 1) + " in call to '" + name + "', expected '" + - TypeNameOf(param_type) + "', got '" + TypeNameOf(arg_type) + - "'", + sem_.TypeNameOf(param_type) + "', got '" + + sem_.TypeNameOf(arg_type) + "'", arg_expr->source); return false; } @@ -1664,13 +1669,13 @@ bool Resolver::ValidateStructureConstructorOrCast( } for (auto* member : struct_type->Members()) { auto* value = ctor->args[member->Index()]; - auto* value_ty = TypeOf(value); + auto* value_ty = sem_.TypeOf(value); if (member->Type() != value_ty->UnwrapRef()) { AddError( "type in struct constructor does not match struct member type: " "expected '" + - TypeNameOf(member->Type()) + "', found '" + - TypeNameOf(value_ty) + "'", + sem_.TypeNameOf(member->Type()) + "', found '" + + sem_.TypeNameOf(value_ty) + "'", value->source); return false; } @@ -1685,12 +1690,13 @@ bool Resolver::ValidateArrayConstructorOrCast( auto& values = ctor->args; auto* elem_ty = array_type->ElemType(); for (auto* value : values) { - auto* value_ty = TypeOf(value)->UnwrapRef(); + auto* value_ty = sem_.TypeOf(value)->UnwrapRef(); if (value_ty != elem_ty) { AddError( "type in array constructor does not match array type: " "expected '" + - TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_ty) + "'", + sem_.TypeNameOf(elem_ty) + "', found '" + + sem_.TypeNameOf(value_ty) + "'", value->source); return false; } @@ -1727,13 +1733,14 @@ bool Resolver::ValidateVectorConstructorOrCast( auto* elem_ty = vec_type->type(); size_t value_cardinality_sum = 0; for (auto* value : values) { - auto* value_ty = TypeOf(value)->UnwrapRef(); + auto* value_ty = sem_.TypeOf(value)->UnwrapRef(); if (value_ty->is_scalar()) { if (elem_ty != value_ty) { AddError( "type in vector constructor does not match vector type: " "expected '" + - TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_ty) + "'", + sem_.TypeNameOf(elem_ty) + "', found '" + + sem_.TypeNameOf(value_ty) + "'", value->source); return false; } @@ -1748,8 +1755,8 @@ bool Resolver::ValidateVectorConstructorOrCast( AddError( "type in vector constructor does not match vector type: " "expected '" + - TypeNameOf(elem_ty) + "', found '" + TypeNameOf(value_elem_ty) + - "'", + sem_.TypeNameOf(elem_ty) + "', found '" + + sem_.TypeNameOf(value_elem_ty) + "'", value->source); return false; } @@ -1758,7 +1765,7 @@ bool Resolver::ValidateVectorConstructorOrCast( } else { // A vector constructor can only accept vectors and scalars. AddError("expected vector or scalar type in vector constructor; found: " + - TypeNameOf(value_ty), + sem_.TypeNameOf(value_ty), value->source); return false; } @@ -1774,8 +1781,9 @@ bool Resolver::ValidateVectorConstructorOrCast( } const Source& values_start = values[0]->source; const Source& values_end = values[values.size() - 1]->source; - AddError("attempted to construct '" + TypeNameOf(vec_type) + "' with " + - std::to_string(value_cardinality_sum) + " component(s)", + AddError("attempted to construct '" + sem_.TypeNameOf(vec_type) + + "' with " + std::to_string(value_cardinality_sum) + + " component(s)", Source::Combine(values_start, values_end)); return false; } @@ -1817,7 +1825,7 @@ bool Resolver::ValidateMatrixConstructorOrCast( std::vector arg_tys; arg_tys.reserve(values.size()); for (auto* value : values) { - arg_tys.emplace_back(TypeOf(value)->UnwrapRef()); + arg_tys.emplace_back(sem_.TypeOf(value)->UnwrapRef()); } auto* elem_type = matrix_ty->type(); @@ -1828,8 +1836,8 @@ bool Resolver::ValidateMatrixConstructorOrCast( auto print_error = [&]() { const Source& values_start = values[0]->source; const Source& values_end = values[values.size() - 1]->source; - auto type_name = TypeNameOf(matrix_ty); - auto elem_type_name = TypeNameOf(elem_type); + auto type_name = sem_.TypeNameOf(matrix_ty); + auto elem_type_name = sem_.TypeNameOf(elem_type); std::stringstream ss; ss << "no matching constructor " + type_name << "("; for (size_t i = 0; i < values.size(); i++) { @@ -1891,7 +1899,7 @@ bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor, // Validate constructor auto* value = ctor->args[0]; - auto* value_ty = TypeOf(value)->UnwrapRef(); + auto* value_ty = sem_.TypeOf(value)->UnwrapRef(); using Bool = sem::Bool; using I32 = sem::I32; @@ -1903,8 +1911,8 @@ bool Resolver::ValidateScalarConstructorOrCast(const ast::CallExpression* ctor, (ty->Is() && value_ty->is_scalar()) || (ty->Is() && value_ty->is_scalar()); if (!is_valid) { - AddError("cannot construct '" + TypeNameOf(ty) + - "' with a value of type '" + TypeNameOf(value_ty) + "'", + AddError("cannot construct '" + sem_.TypeNameOf(ty) + + "' with a value of type '" + sem_.TypeNameOf(value_ty) + "'", ctor->source); return false; @@ -2172,7 +2180,7 @@ bool Resolver::ValidateLocationAttribute( } if (!type->is_numeric_scalar_or_vector()) { - std::string invalid_type = TypeNameOf(type); + std::string invalid_type = sem_.TypeNameOf(type); AddError("cannot apply 'location' attribute to declaration of type '" + invalid_type + "'", source); @@ -2200,13 +2208,13 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret, AddError( "return statement type must match its function " "return type, returned '" + - TypeNameOf(ret_type) + "', expected '" + TypeNameOf(func_type) + - "'", + sem_.TypeNameOf(ret_type) + "', expected '" + + sem_.TypeNameOf(func_type) + "'", ret->source); return false; } - auto* sem = Sem(ret); + auto* sem = sem_.Get(ret); if (auto* continuing = ClosestContinuing(/*stop_at_loop*/ false)) { AddError("continuing blocks must not contain a return statement", ret->source); @@ -2221,7 +2229,7 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret, } bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) { - auto* cond_ty = TypeOf(s->condition)->UnwrapRef(); + auto* cond_ty = sem_.TypeOf(s->condition)->UnwrapRef(); if (!cond_ty->is_integer_scalar()) { AddError( "switch statement selector expression must be of a " @@ -2245,7 +2253,7 @@ bool Resolver::ValidateSwitch(const ast::SwitchStatement* s) { } for (auto* selector : case_stmt->selectors) { - if (cond_ty != TypeOf(selector)) { + if (cond_ty != sem_.TypeOf(selector)) { AddError( "the case selector values must have the same " "type as the selector expression.", @@ -2297,7 +2305,7 @@ bool Resolver::ValidateAssignment(const ast::Statement* a, if (!ty->IsConstructible() && !ty->IsAnyOf()) { AddError( - "cannot assign '" + TypeNameOf(rhs_ty) + + "cannot assign '" + sem_.TypeNameOf(rhs_ty) + "' to '_'. '_' can only be assigned a constructible, pointer, " "texture or sampler type", rhs->source); @@ -2307,7 +2315,7 @@ bool Resolver::ValidateAssignment(const ast::Statement* a, } // https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement - auto const* lhs_ty = TypeOf(lhs); + auto const* lhs_ty = sem_.TypeOf(lhs); if (auto* var = ResolvedSymbol(lhs)) { auto* decl = var->Declaration(); @@ -2330,7 +2338,7 @@ bool Resolver::ValidateAssignment(const ast::Statement* a, auto* lhs_ref = lhs_ty->As(); if (!lhs_ref) { // LHS is not a reference, so it has no storage. - AddError("cannot assign to value of type '" + TypeNameOf(lhs_ty) + "'", + AddError("cannot assign to value of type '" + sem_.TypeNameOf(lhs_ty) + "'", lhs->source); return false; } @@ -2340,8 +2348,8 @@ bool Resolver::ValidateAssignment(const ast::Statement* a, // Value type has to match storage type if (storage_ty != value_type) { - AddError("cannot assign '" + TypeNameOf(rhs_ty) + "' to '" + - TypeNameOf(lhs_ty) + "'", + AddError("cannot assign '" + sem_.TypeNameOf(rhs_ty) + "' to '" + + sem_.TypeNameOf(lhs_ty) + "'", a->source); return false; } @@ -2350,9 +2358,9 @@ bool Resolver::ValidateAssignment(const ast::Statement* a, return false; } if (lhs_ref->Access() == ast::Access::kRead) { - AddError( - "cannot store into a read-only type '" + RawTypeNameOf(lhs_ty) + "'", - a->source); + AddError("cannot store into a read-only type '" + + sem_.RawTypeNameOf(lhs_ty) + "'", + a->source); return false; } return true; @@ -2382,11 +2390,11 @@ bool Resolver::ValidateIncrementDecrementStatement( } } - auto const* lhs_ty = TypeOf(lhs); + auto const* lhs_ty = sem_.TypeOf(lhs); auto* lhs_ref = lhs_ty->As(); if (!lhs_ref) { // LHS is not a reference, so it has no storage. - AddError("cannot modify value of type '" + TypeNameOf(lhs_ty) + "'", + AddError("cannot modify value of type '" + sem_.TypeNameOf(lhs_ty) + "'", lhs->source); return false; } @@ -2399,8 +2407,9 @@ bool Resolver::ValidateIncrementDecrementStatement( } if (lhs_ref->Access() == ast::Access::kRead) { - AddError("cannot modify read-only type '" + RawTypeNameOf(lhs_ty) + "'", - inc->source); + AddError( + "cannot modify read-only type '" + sem_.RawTypeNameOf(lhs_ty) + "'", + inc->source); return false; } return true; diff --git a/src/tint/resolver/sem_helper.cc b/src/tint/resolver/sem_helper.cc new file mode 100644 index 0000000000..74b3c5b077 --- /dev/null +++ b/src/tint/resolver/sem_helper.cc @@ -0,0 +1,60 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/resolver/sem_helper.h" + +#include "src/tint/sem/expression.h" + +namespace tint::resolver { + +SemHelper::SemHelper(ProgramBuilder* builder) : builder_(builder) {} + +SemHelper::~SemHelper() = default; + +std::string SemHelper::TypeNameOf(const sem::Type* ty) const { + return RawTypeNameOf(ty->UnwrapRef()); +} + +std::string SemHelper::RawTypeNameOf(const sem::Type* ty) const { + return ty->FriendlyName(builder_->Symbols()); +} + +sem::Type* SemHelper::TypeOf(const ast::Expression* expr) const { + auto* sem = Get(expr); + return sem ? const_cast(sem->Type()) : nullptr; +} + +sem::Type* SemHelper::TypeOf(const ast::LiteralExpression* lit) { + return Switch( + lit, + [&](const ast::SintLiteralExpression*) { + return builder_->create(); + }, + [&](const ast::UintLiteralExpression*) { + return builder_->create(); + }, + [&](const ast::FloatLiteralExpression*) { + return builder_->create(); + }, + [&](const ast::BoolLiteralExpression*) { + return builder_->create(); + }, + [&](Default) { + TINT_UNREACHABLE(Resolver, builder_->Diagnostics()) + << "Unhandled literal type: " << lit->TypeInfo().name; + return nullptr; + }); +} + +} // namespace tint::resolver diff --git a/src/tint/resolver/sem_helper.h b/src/tint/resolver/sem_helper.h new file mode 100644 index 0000000000..a9400fd33b --- /dev/null +++ b/src/tint/resolver/sem_helper.h @@ -0,0 +1,66 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_RESOLVER_SEM_HELPER_H_ +#define SRC_TINT_RESOLVER_SEM_HELPER_H_ + +#include + +#include "src/tint/diagnostic/diagnostic.h" +#include "src/tint/program_builder.h" + +namespace tint::resolver { +class SemHelper { + public: + explicit SemHelper(ProgramBuilder* builder); + ~SemHelper(); + + /// Get is a helper for obtaining the semantic node for the given AST node. + template + auto* Get(const AST_OR_TYPE* ast) const { + using T = sem::Info::GetResultType; + auto* sem = builder_->Sem().Get(ast); + if (!sem) { + TINT_ICE(Resolver, builder_->Diagnostics()) + << "AST node '" << ast->TypeInfo().name << "' had no semantic info\n" + << "At: " << ast->source << "\n" + << "Pointer: " << ast; + } + return const_cast(As(sem)); + } + + /// @returns the resolved type of the ast::Expression `expr` + /// @param expr the expression + sem::Type* TypeOf(const ast::Expression* expr) const; + + /// @returns the semantic type of the AST literal `lit` + /// @param lit the literal + sem::Type* TypeOf(const ast::LiteralExpression* lit); + + /// @returns the type name of the given semantic type, unwrapping + /// references. + std::string TypeNameOf(const sem::Type* ty) const; + + /// @returns the type name of the given semantic type, without unwrapping + /// references. + std::string RawTypeNameOf(const sem::Type* ty) const; + + private: + ProgramBuilder* builder_; +}; + +} // namespace tint::resolver + +#endif // SRC_TINT_RESOLVER_SEM_HELPER_H_