diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc index 5ca8ee2109..58ba84f189 100644 --- a/src/tint/resolver/intrinsic_table.cc +++ b/src/tint/resolver/intrinsic_table.cc @@ -104,12 +104,10 @@ const Number Number::invalid{Number::kInvalid}; /// Used by the MatchState. class ClosedState { public: - explicit ClosedState(ProgramBuilder& b) : builder(b) {} - /// If the type with index `idx` is open, then it is closed with type `ty` and /// Type() returns true. If the type is closed, then `Type()` returns true iff /// it is equal to `ty`. - bool Type(uint32_t idx, const sem::Type* ty) { + bool Type(size_t idx, const sem::Type* ty) { auto res = types_.emplace(idx, ty); return res.second || res.first->second == ty; } @@ -117,33 +115,27 @@ class ClosedState { /// If the number with index `idx` is open, then it is closed with number /// `number` and Num() returns true. If the number is closed, then `Num()` /// returns true iff it is equal to `ty`. - bool Num(uint32_t idx, Number number) { + bool Num(size_t idx, Number number) { auto res = numbers_.emplace(idx, number.Value()); return res.second || res.first->second == number.Value(); } /// Type returns the closed type with index `idx`, or nullptr if the type was not closed. - const sem::Type* Type(uint32_t idx) const { + const sem::Type* Type(size_t idx) const { auto it = types_.find(idx); return (it != types_.end()) ? it->second : nullptr; } /// Type returns the number type with index `idx`. /// An ICE is raised if the number is not closed. - Number Num(uint32_t idx) const { + Number Num(size_t idx) const { auto it = numbers_.find(idx); - if (it == numbers_.end()) { - TINT_ICE(Resolver, builder.Diagnostics()) - << "number with index " << idx << " is not closed"; - return Number::invalid; - } - return Number(it->second); + return (it != numbers_.end()) ? Number(it->second) : Number::invalid; } private: - ProgramBuilder& builder; - std::unordered_map types_; - std::unordered_map numbers_; + std::unordered_map types_; + std::unordered_map numbers_; }; /// Index type used for matcher indices @@ -158,7 +150,7 @@ class MatchState { MatchState(ProgramBuilder& b, ClosedState& c, const Matchers& m, - const OverloadInfo& o, + const OverloadInfo* o, MatcherIndex const* matcher_indices) : builder(b), closed(c), matchers(m), overload(o), matcher_indices_(matcher_indices) {} @@ -169,7 +161,7 @@ class MatchState { /// The type and number matchers Matchers const& matchers; /// The current overload being evaluated - OverloadInfo const& overload; + OverloadInfo const* overload; /// Type uses the next TypeMatcher from the matcher indices to match the type /// `ty`. If the type matches, the canonical expected type is returned. If the @@ -240,7 +232,7 @@ class NumberMatcher { class OpenTypeMatcher : public TypeMatcher { public: /// Constructor - explicit OpenTypeMatcher(uint32_t index) : index_(index) {} + explicit OpenTypeMatcher(size_t index) : index_(index) {} const sem::Type* Match(MatchState& state, const sem::Type* type) const override { if (type->Is()) { @@ -252,7 +244,7 @@ class OpenTypeMatcher : public TypeMatcher { std::string String(MatchState& state) const override; private: - uint32_t index_; + size_t index_; }; /// OpenNumberMatcher is a Matcher for an open number. @@ -260,7 +252,7 @@ class OpenTypeMatcher : public TypeMatcher { /// consistent for the overload) class OpenNumberMatcher : public NumberMatcher { public: - explicit OpenNumberMatcher(uint32_t index) : index_(index) {} + explicit OpenNumberMatcher(size_t index) : index_(index) {} Number Match(MatchState& state, Number number) const override { if (number.IsAny()) { @@ -272,7 +264,7 @@ class OpenNumberMatcher : public NumberMatcher { std::string String(MatchState& state) const override; private: - uint32_t index_; + size_t index_; }; //////////////////////////////////////////////////////////////////////////////// @@ -879,16 +871,16 @@ class Impl : public IntrinsicTable { /// Candidate holds information about an overload evaluated for resolution. struct Candidate { /// The candidate overload - const OverloadInfo& overload; + const OverloadInfo* overload; /// The closed types and numbers ClosedState closed; /// The parameter types for the candidate overload std::vector parameters; - /// True if the candidate is a viable match for the call - bool matched; - /// The match-score of the candidate overload. Used for diagnostics when no overload - /// matches. Higher scores are displayed first (top-most). - int score; + /// The match-score of the candidate overload. + /// A score of zero indicates an exact match. + /// Non-zero scores are used for diagnostics when no overload matches. + /// Lower scores are displayed first (top-most). + size_t score; }; /// A list of candidates @@ -922,7 +914,7 @@ class Impl : public IntrinsicTable { /// arguments. For example `vec3()` would have the first template-type closed /// as `f32`. /// @returns the evaluated Candidate information. - Candidate ScoreOverload(const OverloadInfo& overload, + Candidate ScoreOverload(const OverloadInfo* overload, const std::vector& args, ClosedState closed) const; @@ -931,12 +923,12 @@ class Impl : public IntrinsicTable { /// @param overload the overload being evaluated /// @param matcher_indices pointer to a list of matcher indices MatchState Match(ClosedState& closed, - const OverloadInfo& overload, + const OverloadInfo* overload, MatcherIndex const* matcher_indices) const; // Prints the overload for emitting diagnostics void PrintOverload(std::ostream& ss, - const OverloadInfo& overload, + const OverloadInfo* overload, const char* intrinsic_name) const; // Prints the list of candidates for emitting diagnostics @@ -945,7 +937,7 @@ class Impl : public IntrinsicTable { const char* intrinsic_name) const; /// Raises an ICE when multiple overload candidates match, as this should never happen. - void ErrMultipleOverloadsMatched(uint32_t num_matched, + void ErrMultipleOverloadsMatched(size_t num_matched, const char* intrinsic_name, const std::vector& args, ClosedState closed, @@ -988,11 +980,11 @@ std::string CallSignature(ProgramBuilder& builder, } std::string OpenTypeMatcher::String(MatchState& state) const { - return state.overload.open_types[index_].name; + return state.overload->open_types[index_].name; } std::string OpenNumberMatcher::String(MatchState& state) const { - return state.overload.open_numbers[index_].name; + return state.overload->open_numbers[index_].name; } Impl::Impl(ProgramBuilder& b) : builder(b) {} @@ -1016,8 +1008,8 @@ const sem::Builtin* Impl::Lookup(sem::BuiltinType builtin_type, }; // Resolve the intrinsic overload - auto match = MatchIntrinsic(kBuiltins[static_cast(builtin_type)], intrinsic_name, - args, ClosedState(builder), on_no_match); + auto match = MatchIntrinsic(kBuiltins[static_cast(builtin_type)], intrinsic_name, args, + ClosedState{}, on_no_match); if (!match.overload) { return {}; } @@ -1050,7 +1042,7 @@ const sem::Builtin* Impl::Lookup(sem::BuiltinType builtin_type, IntrinsicTable::UnaryOperator Impl::Lookup(ast::UnaryOp op, const sem::Type* arg, const Source& source) { - auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair { + auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair { switch (op) { case ast::UnaryOp::kComplement: return {kUnaryOperatorComplement, "operator ~ "}; @@ -1078,7 +1070,7 @@ IntrinsicTable::UnaryOperator Impl::Lookup(ast::UnaryOp op, // Resolve the intrinsic overload auto match = MatchIntrinsic(kUnaryOperators[intrinsic_index], intrinsic_name, {arg}, - ClosedState(builder), on_no_match); + ClosedState{}, on_no_match); if (!match.overload) { return {}; } @@ -1091,7 +1083,7 @@ IntrinsicTable::BinaryOperator Impl::Lookup(ast::BinaryOp op, const sem::Type* rhs, const Source& source, bool is_compound) { - auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair { + auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair { switch (op) { case ast::BinaryOp::kAnd: return {kBinaryOperatorAnd, is_compound ? "operator &= " : "operator & "}; @@ -1149,7 +1141,7 @@ IntrinsicTable::BinaryOperator Impl::Lookup(ast::BinaryOp op, // Resolve the intrinsic overload auto match = MatchIntrinsic(kBinaryOperators[intrinsic_index], intrinsic_name, {lhs, rhs}, - ClosedState(builder), on_no_match); + ClosedState{}, on_no_match); if (!match.overload) { return {}; } @@ -1170,7 +1162,7 @@ const sem::CallTarget* Impl::Lookup(CtorConvIntrinsic type, << std::endl; Candidates ctor, conv; for (auto candidate : candidates) { - if (candidate.overload.flags.Contains(OverloadFlag::kIsConstructor)) { + if (candidate.overload->flags.Contains(OverloadFlag::kIsConstructor)) { ctor.emplace_back(candidate); } else { conv.emplace_back(candidate); @@ -1192,13 +1184,13 @@ const sem::CallTarget* Impl::Lookup(CtorConvIntrinsic type, }; // If a template type was provided, then close the 0'th type with this. - ClosedState closed(builder); + ClosedState closed; if (template_arg) { closed.Type(0, template_arg); } // Resolve the intrinsic overload - auto match = MatchIntrinsic(kConstructorsAndConverters[static_cast(type)], name, args, + auto match = MatchIntrinsic(kConstructorsAndConverters[static_cast(type)], name, args, closed, on_no_match); if (!match.overload) { return {}; @@ -1232,40 +1224,21 @@ IntrinsicPrototype Impl::MatchIntrinsic(const IntrinsicInfo& intrinsic, const std::vector& args, ClosedState closed, OnNoMatch on_no_match) const { - uint32_t num_matched = 0; + size_t num_matched = 0; Candidates candidates; candidates.reserve(intrinsic.num_overloads); - for (uint8_t overload_idx = 0; overload_idx < intrinsic.num_overloads; overload_idx++) { - auto candidate = ScoreOverload(intrinsic.overloads[overload_idx], args, closed); - if (candidate.matched) { + for (size_t overload_idx = 0; overload_idx < static_cast(intrinsic.num_overloads); + overload_idx++) { + auto candidate = ScoreOverload(&intrinsic.overloads[overload_idx], args, closed); + if (candidate.score == 0) { num_matched++; } candidates.emplace_back(std::move(candidate)); } // Sort the candidates with the most promising first - { - std::vector candidate_indices(candidates.size()); - for (size_t i = 0; i < candidate_indices.size(); i++) { - candidate_indices[i] = i; - } - std::stable_sort(candidate_indices.begin(), candidate_indices.end(), - [&](size_t a, size_t b) { - if (candidates[a].matched && !candidates[b].matched) { - return true; - } - if (candidates[b].matched && !candidates[a].matched) { - return false; - } - return candidates[a].score > candidates[b].score; - }); - Candidates candidates_sorted; - candidates_sorted.reserve(candidate_indices.size()); - for (size_t idx : candidate_indices) { - candidates_sorted.emplace_back(std::move(candidates[idx])); - } - std::swap(candidates, candidates_sorted); - } + std::stable_sort(candidates.begin(), candidates.end(), + [&](const Candidate& a, const Candidate& b) { return a.score < b.score; }); // How many candidates matched? switch (num_matched) { @@ -1282,7 +1255,7 @@ IntrinsicPrototype Impl::MatchIntrinsic(const IntrinsicInfo& intrinsic, // Build the return type const sem::Type* return_type = nullptr; - if (auto* indices = match.overload.return_matcher_indices) { + if (auto* indices = match.overload->return_matcher_indices) { Any any; return_type = Match(match.closed, match.overload, indices).Type(&any); if (!return_type) { @@ -1293,101 +1266,91 @@ IntrinsicPrototype Impl::MatchIntrinsic(const IntrinsicInfo& intrinsic, return_type = builder.create(); } - return IntrinsicPrototype{&match.overload, return_type, std::move(match.parameters)}; + return IntrinsicPrototype{match.overload, return_type, std::move(match.parameters)}; } -Impl::Candidate Impl::ScoreOverload(const OverloadInfo& overload, +Impl::Candidate Impl::ScoreOverload(const OverloadInfo* overload, const std::vector& args, ClosedState closed) const { - // Score weight for argument <-> parameter count matches / mismatches + // Penalty weights for overload mismatching. // This scoring is used to order the suggested overloads in diagnostic on overload mismatch, and // has no impact for a correct program. - // The overloads with the highest score will be displayed first (top-most). - constexpr int kScorePerParamArgMismatch = -1; - constexpr int kScorePerMatchedParam = 2; - constexpr int kScorePerMatchedOpenType = 1; - constexpr int kScorePerMatchedOpenNumber = 1; + // The overloads with the lowest score will be displayed first (top-most). + constexpr int kMismatchedParamCountPenalty = 3; + constexpr int kMismatchedParamTypePenalty = 2; + constexpr int kMismatchedOpenTypePenalty = 1; + constexpr int kMismatchedOpenNumberPenalty = 1; - uint32_t num_parameters = static_cast(overload.num_parameters); - uint32_t num_arguments = static_cast(args.size()); + size_t num_parameters = static_cast(overload->num_parameters); + size_t num_arguments = static_cast(args.size()); - bool overload_matched = true; - int overload_score = 0; - - if (static_cast(args.size()) > - static_cast(std::numeric_limits::max())) { - overload_matched = false; // No overload has this number of arguments. - } + size_t score = 0; if (num_parameters != num_arguments) { - overload_score += kScorePerParamArgMismatch * (std::max(num_parameters, num_arguments) - - std::min(num_parameters, num_arguments)); - overload_matched = false; + score += kMismatchedParamCountPenalty * (std::max(num_parameters, num_arguments) - + std::min(num_parameters, num_arguments)); } std::vector parameters; auto num_params = std::min(num_parameters, num_arguments); - for (uint32_t p = 0; p < num_params; p++) { - auto& parameter = overload.parameters[p]; + for (size_t p = 0; p < num_params; p++) { + auto& parameter = overload->parameters[p]; auto* indices = parameter.matcher_indices; auto* type = Match(closed, overload, indices).Type(args[p]->UnwrapRef()); if (type) { parameters.emplace_back(IntrinsicPrototype::Parameter{type, parameter.usage}); - overload_score += kScorePerMatchedParam; } else { - overload_matched = false; + score += kMismatchedParamTypePenalty; } } - if (overload_matched) { + if (score == 0) { // Check all constrained open types matched - for (uint32_t ot = 0; ot < overload.num_open_types; ot++) { - auto& open_type = overload.open_types[ot]; + for (size_t ot = 0; ot < overload->num_open_types; ot++) { + auto& open_type = overload->open_types[ot]; if (open_type.matcher_index != kNoMatcher) { auto* closed_type = closed.Type(ot); auto* matcher_index = &open_type.matcher_index; - if (closed_type && Match(closed, overload, matcher_index).Type(closed_type)) { - overload_score += kScorePerMatchedOpenType; - } else { - overload_matched = false; + if (!closed_type || !Match(closed, overload, matcher_index).Type(closed_type)) { + score += kMismatchedOpenTypePenalty; } } } } - if (overload_matched) { + if (score == 0) { // Check all constrained open numbers matched - for (uint32_t on = 0; on < overload.num_open_numbers; on++) { - auto& open_number = overload.open_numbers[on]; + for (size_t on = 0; on < overload->num_open_numbers; on++) { + auto& open_number = overload->open_numbers[on]; if (open_number.matcher_index != kNoMatcher) { + auto closed_num = closed.Num(on); auto* index = &open_number.matcher_index; - if (Match(closed, overload, index).Num(closed.Num(on)).IsValid()) { - overload_score += kScorePerMatchedOpenNumber; - } else { - overload_matched = false; + if (!closed_num.IsValid() || + !Match(closed, overload, index).Num(closed_num).IsValid()) { + score += kMismatchedOpenNumberPenalty; } } } } - return Candidate{overload, closed, parameters, overload_matched, overload_score}; + return Candidate{overload, closed, parameters, score}; } MatchState Impl::Match(ClosedState& closed, - const OverloadInfo& overload, + const OverloadInfo* overload, MatcherIndex const* matcher_indices) const { return MatchState(builder, closed, matchers, overload, matcher_indices); } void Impl::PrintOverload(std::ostream& ss, - const OverloadInfo& overload, + const OverloadInfo* overload, const char* intrinsic_name) const { - ClosedState closed(builder); + ClosedState closed; ss << intrinsic_name << "("; - for (uint32_t p = 0; p < overload.num_parameters; p++) { - auto& parameter = overload.parameters[p]; + for (size_t p = 0; p < overload->num_parameters; p++) { + auto& parameter = overload->parameters[p]; if (p > 0) { ss << ", "; } @@ -1398,9 +1361,9 @@ void Impl::PrintOverload(std::ostream& ss, ss << Match(closed, overload, indices).TypeName(); } ss << ")"; - if (overload.return_matcher_indices) { + if (overload->return_matcher_indices) { ss << " -> "; - auto* indices = overload.return_matcher_indices; + auto* indices = overload->return_matcher_indices; ss << Match(closed, overload, indices).TypeName(); } @@ -1409,8 +1372,8 @@ void Impl::PrintOverload(std::ostream& ss, ss << (first ? " where: " : ", "); first = false; }; - for (uint32_t i = 0; i < overload.num_open_types; i++) { - auto& open_type = overload.open_types[i]; + for (size_t i = 0; i < overload->num_open_types; i++) { + auto& open_type = overload->open_types[i]; if (open_type.matcher_index != kNoMatcher) { separator(); ss << open_type.name; @@ -1418,8 +1381,8 @@ void Impl::PrintOverload(std::ostream& ss, ss << " is " << Match(closed, overload, index).TypeName(); } } - for (uint32_t i = 0; i < overload.num_open_numbers; i++) { - auto& open_number = overload.open_numbers[i]; + for (size_t i = 0; i < overload->num_open_numbers; i++) { + auto& open_number = overload->open_numbers[i]; if (open_number.matcher_index != kNoMatcher) { separator(); ss << open_number.name; @@ -1463,14 +1426,14 @@ std::string MatchState::NumName() { return matcher->String(*this); } -void Impl::ErrMultipleOverloadsMatched(uint32_t num_matched, +void Impl::ErrMultipleOverloadsMatched(size_t num_matched, const char* intrinsic_name, const std::vector& args, ClosedState closed, Candidates candidates) const { std::stringstream ss; ss << num_matched << " overloads matched " << intrinsic_name; - for (uint32_t i = 0; i < 0xffffffffu; i++) { + for (size_t i = 0; i < std::numeric_limits::max(); i++) { if (auto* ty = closed.Type(i)) { ss << ((i == 0) ? "<" : ", ") << ty->FriendlyName(builder.Symbols()); } else if (i > 0) { @@ -1489,7 +1452,7 @@ void Impl::ErrMultipleOverloadsMatched(uint32_t num_matched, } ss << "):\n"; for (auto& candidate : candidates) { - if (candidate.matched) { + if (candidate.score == 0) { ss << " "; PrintOverload(ss, candidate.overload, intrinsic_name); ss << std::endl;