diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index b9c0833c8e..cc454908b6 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -1163,7 +1163,6 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { // Resolve all of the arguments, their types and the set of behaviors. std::vector args(expr->args.size()); - std::vector arg_tys(expr->args.size()); sem::Behaviors arg_behaviors; for (size_t i = 0; i < expr->args.size(); i++) { auto* arg = sem_.Get(expr->args[i]); @@ -1171,7 +1170,6 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { return nullptr; } args[i] = arg; - arg_tys[i] = arg->Type(); arg_behaviors.Add(arg->Behaviors()); } arg_behaviors.Remove(sem::Behavior::kNext); @@ -1180,31 +1178,10 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { bool has_side_effects = std::any_of(args.begin(), args.end(), [](auto* e) { return e->HasSideEffects(); }); - // array_or_struct_ctor is a helper for building a sem::TypeConstructor call for an array or - // structure type. These types have constructors that are always explicitly typed (no - // inference), and do not support type conversion. As such, they do not use the IntrinsicTable. - auto array_or_struct_ctor = [&](const sem::Type* ty) -> sem::Call* { - auto* call_target = utils::GetOrCreate( - type_ctors_, TypeConstructorSig{ty, arg_tys}, [&]() -> sem::TypeConstructor* { - return builder_->create( - ty, utils::Transform( - arg_tys, [&](const sem::Type* t, size_t i) -> const sem::Parameter* { - return builder_->create( - nullptr, // declaration - static_cast(i), // index - t->UnwrapRef(), // type - ast::StorageClass::kNone, // storage_class - ast::Access::kUndefined); // access - })); - }); - auto value = EvaluateConstantValue(expr, ty); - return builder_->create(expr, call_target, std::move(args), current_statement_, - value, has_side_effects); - }; - // ct_ctor_or_conv is a helper for building either a sem::TypeConstructor or sem::TypeConversion // call for a CtorConvIntrinsic with an optional template argument type. auto ct_ctor_or_conv = [&](CtorConvIntrinsic ty, const sem::Type* template_arg) -> sem::Call* { + auto arg_tys = utils::Transform(args, [](auto* arg) { return arg->Type(); }); auto* call_target = intrinsic_table_->Lookup(ty, template_arg, arg_tys, expr->source); if (!call_target) { return nullptr; @@ -1229,8 +1206,44 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { [&](const sem::U32*) { return ct_ctor_or_conv(CtorConvIntrinsic::kU32, nullptr); }, [&](const sem::F32*) { return ct_ctor_or_conv(CtorConvIntrinsic::kF32, nullptr); }, [&](const sem::Bool*) { return ct_ctor_or_conv(CtorConvIntrinsic::kBool, nullptr); }, - [&](const sem::Array*) { return array_or_struct_ctor(ty); }, - [&](const sem::Struct*) { return array_or_struct_ctor(ty); }, + [&](const sem::Array* arr) -> sem::Call* { + auto* call_target = utils::GetOrCreate( + array_ctors_, ArrayConstructorSig{{arr, args.size()}}, + [&]() -> sem::TypeConstructor* { + sem::ParameterList params(args.size()); + for (size_t i = 0; i < args.size(); i++) { + params[i] = builder_->create( + nullptr, // declaration + static_cast(i), // index + arr->ElemType(), // type + ast::StorageClass::kNone, // storage_class + ast::Access::kUndefined); // access + } + return builder_->create(arr, std::move(params)); + }); + auto value = EvaluateConstantValue(expr, call_target->ReturnType()); + return builder_->create(expr, call_target, std::move(args), + current_statement_, value, has_side_effects); + }, + [&](const sem::Struct* str) -> sem::Call* { + auto* call_target = utils::GetOrCreate( + struct_ctors_, StructConstructorSig{{str, args.size()}}, + [&]() -> sem::TypeConstructor* { + sem::ParameterList params(std::min(args.size(), str->Members().size())); + for (size_t i = 0, n = params.size(); i < n; i++) { + params[i] = builder_->create( + nullptr, // declaration + static_cast(i), // index + str->Members()[i]->Type(), // type + ast::StorageClass::kNone, // storage_class + ast::Access::kUndefined); // access + } + return builder_->create(str, std::move(params)); + }); + auto value = EvaluateConstantValue(expr, call_target->ReturnType()); + return builder_->create(expr, call_target, std::move(args), + current_statement_, value, has_side_effects); + }, [&](Default) { AddError("type is not constructible", expr->source); return nullptr; @@ -1301,7 +1314,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { resolved, // [&](sem::Type* ty) { // A type constructor or conversions. - // Note: Unlike the codepath where we're resolving the call target from an + // Note: Unlike the code path where we're resolving the call target from an // ast::Type, all types must already have the element type explicitly specified, so // there's no need to infer element types. return ty_ctor_or_conv(ty); @@ -1319,7 +1332,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { auto name = builder_->Symbols().NameFor(ident->symbol); auto builtin_type = sem::ParseBuiltinType(name); if (builtin_type != sem::BuiltinType::kNone) { - return BuiltinCall(expr, builtin_type, std::move(args), std::move(arg_tys)); + return BuiltinCall(expr, builtin_type, std::move(args)); } TINT_ICE(Resolver, diagnostics_) @@ -1339,11 +1352,14 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, sem::BuiltinType builtin_type, - const std::vector args, - const std::vector arg_tys) { - auto* builtin = intrinsic_table_->Lookup(builtin_type, std::move(arg_tys), expr->source); - if (!builtin) { - return nullptr; + std::vector args) { + const sem::Builtin* builtin = nullptr; + { + auto arg_tys = utils::Transform(args, [](auto* arg) { return arg->Type(); }); + builtin = intrinsic_table_->Lookup(builtin_type, arg_tys, expr->source); + if (!builtin) { + return nullptr; + } } if (builtin->IsDeprecated()) { @@ -1366,21 +1382,7 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, if (!validator_.TextureBuiltinFunction(call)) { return nullptr; } - // Collect a texture/sampler pair for this builtin. - const auto& signature = builtin->Signature(); - int texture_index = signature.IndexOf(sem::ParameterUsage::kTexture); - if (texture_index == -1) { - TINT_ICE(Resolver, diagnostics_) << "texture builtin without texture parameter"; - } - - auto* texture = args[texture_index]->As()->Variable(); - if (!texture->Type()->UnwrapRef()->Is()) { - int sampler_index = signature.IndexOf(sem::ParameterUsage::kSampler); - const sem::Variable* sampler = - sampler_index != -1 ? args[sampler_index]->As()->Variable() - : nullptr; - current_function_->AddTextureSamplerPair(texture, sampler); - } + CollectTextureSamplerPairs(builtin, call->Arguments()); } if (!validator_.BuiltinCall(call)) { @@ -1392,9 +1394,27 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, return call; } +void Resolver::CollectTextureSamplerPairs(const sem::Builtin* builtin, + const std::vector& args) const { + // Collect a texture/sampler pair for this builtin. + const auto& signature = builtin->Signature(); + int texture_index = signature.IndexOf(sem::ParameterUsage::kTexture); + if (texture_index == -1) { + TINT_ICE(Resolver, diagnostics_) << "texture builtin without texture parameter"; + } + auto* texture = args[texture_index]->As()->Variable(); + if (!texture->Type()->UnwrapRef()->Is()) { + int sampler_index = signature.IndexOf(sem::ParameterUsage::kSampler); + const sem::Variable* sampler = + sampler_index != -1 ? args[sampler_index]->As()->Variable() + : nullptr; + current_function_->AddTextureSamplerPair(texture, sampler); + } +} + sem::Call* Resolver::FunctionCall(const ast::CallExpression* expr, sem::Function* target, - const std::vector args, + std::vector args, sem::Behaviors arg_behaviors) { auto sym = expr->target.name->symbol; auto name = builder_->Symbols().NameFor(sym); @@ -1420,25 +1440,7 @@ sem::Call* Resolver::FunctionCall(const ast::CallExpression* expr, current_function_->AddTransitivelyReferencedGlobal(var); } - // Map all texture/sampler pairs from the target function to the - // current function. These can only be global or parameter - // variables. Resolve any parameter variables to the corresponding - // argument passed to the current function. Leave global variables - // as-is. Then add the mapped pair to the current function's list of - // texture/sampler pairs. - for (sem::VariablePair pair : target->TextureSamplerPairs()) { - const sem::Variable* texture = pair.first; - const sem::Variable* sampler = pair.second; - if (auto* param = texture->As()) { - texture = args[param->Index()]->As()->Variable(); - } - if (sampler) { - if (auto* param = sampler->As()) { - sampler = args[param->Index()]->As()->Variable(); - } - } - current_function_->AddTextureSamplerPair(texture, sampler); - } + CollectTextureSamplerPairs(target, call->Arguments()); } target->AddCallSite(call); @@ -1452,6 +1454,29 @@ sem::Call* Resolver::FunctionCall(const ast::CallExpression* expr, return call; } +void Resolver::CollectTextureSamplerPairs(sem::Function* func, + const std::vector& args) const { + // Map all texture/sampler pairs from the target function to the + // current function. These can only be global or parameter + // variables. Resolve any parameter variables to the corresponding + // argument passed to the current function. Leave global variables + // as-is. Then add the mapped pair to the current function's list of + // texture/sampler pairs. + for (sem::VariablePair pair : func->TextureSamplerPairs()) { + const sem::Variable* texture = pair.first; + const sem::Variable* sampler = pair.second; + if (auto* param = texture->As()) { + texture = args[param->Index()]->As()->Variable(); + } + if (sampler) { + if (auto* param = sampler->As()) { + sampler = args[param->Index()]->As()->Variable(); + } + } + current_function_->AddTextureSamplerPair(texture, sampler); + } +} + sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) { auto* ty = Switch( literal, @@ -1667,27 +1692,27 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e } sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { - auto* lhs = sem_.Get(expr->lhs); - auto* rhs = sem_.Get(expr->rhs); + const auto* lhs = sem_.Get(expr->lhs); + const auto* rhs = sem_.Get(expr->rhs); auto* lhs_ty = lhs->Type()->UnwrapRef(); auto* rhs_ty = rhs->Type()->UnwrapRef(); - auto* ty = intrinsic_table_->Lookup(expr->op, lhs_ty, rhs_ty, expr->source, false).result; - if (!ty) { + auto op = intrinsic_table_->Lookup(expr->op, lhs_ty, rhs_ty, expr->source, false); + if (!op.result) { return nullptr; } - auto val = EvaluateConstantValue(expr, ty); + auto val = EvaluateConstantValue(expr, op.result); bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects(); - auto* sem = - builder_->create(expr, ty, current_statement_, val, has_side_effects); + auto* sem = builder_->create(expr, op.result, current_statement_, val, + has_side_effects); sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors(); return sem; } sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { - auto* expr = sem_.Get(unary->expr); + const auto* expr = sem_.Get(unary->expr); auto* expr_ty = expr->Type(); if (!expr_ty) { return nullptr; @@ -1740,6 +1765,7 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { if (!ty) { return nullptr; } + break; } } @@ -2076,19 +2102,21 @@ sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) { auto& behaviors = current_statement_->Behaviors(); behaviors = sem::Behavior::kReturn; + const sem::Type* value_ty = nullptr; if (auto* value = stmt->value) { auto* expr = Expression(value); if (!expr) { return false; } behaviors.Add(expr->Behaviors() - sem::Behavior::kNext); + value_ty = expr->Type()->UnwrapRef(); + } else { + value_ty = builder_->create(); } // Validate after processing the return value expression so that its type // is available for validation. - auto* ret_type = - stmt->value ? sem_.TypeOf(stmt->value)->UnwrapRef() : builder_->create(); - return validator_.Return(stmt, current_function_->ReturnType(), ret_type, + return validator_.Return(stmt, current_function_->ReturnType(), value_ty, current_statement_); }); } @@ -2379,20 +2407,4 @@ bool Resolver::IsBuiltin(Symbol symbol) const { return sem::ParseBuiltinType(name) != sem::BuiltinType::kNone; } -//////////////////////////////////////////////////////////////////////////////// -// Resolver::TypeConstructorSig -//////////////////////////////////////////////////////////////////////////////// -Resolver::TypeConstructorSig::TypeConstructorSig(const sem::Type* ty, - const std::vector params) - : type(ty), parameters(params) {} -Resolver::TypeConstructorSig::TypeConstructorSig(const TypeConstructorSig&) = default; -Resolver::TypeConstructorSig::~TypeConstructorSig() = default; - -bool Resolver::TypeConstructorSig::operator==(const TypeConstructorSig& rhs) const { - return type == rhs.type && parameters == rhs.parameters; -} -std::size_t Resolver::TypeConstructorSig::Hasher::operator()(const TypeConstructorSig& sig) const { - return utils::Hash(sig.type, sig.parameters); -} - } // namespace tint::resolver diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 865c243c00..5f7e4ca9f4 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -184,13 +185,12 @@ class Resolver { sem::Function* Function(const ast::Function*); sem::Call* FunctionCall(const ast::CallExpression*, sem::Function* target, - const std::vector args, + std::vector args, sem::Behaviors arg_behaviors); sem::Expression* Identifier(const ast::IdentifierExpression*); sem::Call* BuiltinCall(const ast::CallExpression*, sem::BuiltinType, - const std::vector args, - const std::vector arg_tys); + std::vector args); sem::Expression* Literal(const ast::LiteralExpression*); sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*); sem::Expression* UnaryOp(const ast::UnaryOpExpression*); @@ -218,6 +218,13 @@ class Resolver { sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*); bool Statements(const ast::StatementList&); + // CollectTextureSamplerPairs() collects all the texture/sampler pairs from the target function + // / builtin, and records these on the current function by calling AddTextureSamplerPair(). + void CollectTextureSamplerPairs(sem::Function* func, + const std::vector& args) const; + void CollectTextureSamplerPairs(const sem::Builtin* builtin, + const std::vector& args) const; + /// Resolves the WorkgroupSize for the given function, assigning it to /// current_function_ bool WorkgroupSize(const ast::Function*); @@ -332,22 +339,13 @@ class Resolver { /// @returns true if the symbol is the name of a builtin function. bool IsBuiltin(Symbol) const; - struct TypeConstructorSig { - const sem::Type* type; - const std::vector parameters; + // ArrayConstructorSig represents a unique array constructor signature. + // It is a tuple of the array type and number of arguments provided. + using ArrayConstructorSig = utils::UnorderedKeyWrapper>; - TypeConstructorSig(const sem::Type* ty, const std::vector params); - TypeConstructorSig(const TypeConstructorSig&); - ~TypeConstructorSig(); - bool operator==(const TypeConstructorSig&) const; - - /// Hasher provides a hash function for the TypeConstructorSig - struct Hasher { - /// @param sig the TypeConstructorSig to create a hash for - /// @return the hash value - std::size_t operator()(const TypeConstructorSig& sig) const; - }; - }; + // StructConstructorSig represents a unique structure constructor signature. + // It is a tuple of the structure type and number of arguments provided. + using StructConstructorSig = utils::UnorderedKeyWrapper>; ProgramBuilder* const builder_; diag::List& diagnostics_; @@ -360,8 +358,8 @@ class Resolver { std::unordered_map atomic_composite_info_; std::unordered_set marked_; std::unordered_map constant_ids_; - std::unordered_map - type_ctors_; + std::unordered_map array_ctors_; + std::unordered_map struct_ctors_; sem::Function* current_function_ = nullptr; sem::Statement* current_statement_ = nullptr;