diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def index a6792b13ae..28be2d9f85 100644 --- a/src/tint/intrinsics.def +++ b/src/tint/intrinsics.def @@ -208,13 +208,12 @@ match workgroup_or_storage: workgroup | storage // Matching algorithm for a single overload: // // ----------------------------------------- // // // -// The goal of matching is to compare a function call's arguments in the // -// program source against a possibly-templated overload declaration, and // -// determine if the call satisfies the form and type constraints of the // -// overload. Currently it is impossible for a call to match more than one // -// overload definition. In the event that more than one overload matches, an // -// ICE will be raised. Note that Tint may need to support multiple-overload // -// resolution in the future, depending on future overload definitions. // +// The goal of matching is to compare a function call's arguments and any // +// explicitly provided template types in the program source against an // +// overload declaration in this file, and determine if the call satisfies // +// the form and type constraints of the overload. If the call matches an // +// overload, then the overload is added to the list of 'overload candidates' // +// used for overload resolution (described below). // // // // Prior to matching an overload, all template types are undefined. // // // @@ -258,11 +257,11 @@ match workgroup_or_storage: workgroup | storage // need to be checked next. If the defined type does not match the // // 'match' constraint, then the overload is no longer considered. // // // -// This algorithm is less general than the overload resolution described in // -// the WGSL spec. But it makes the same decisions because the overloads // -// defined by WGSL are monotonic in the sense that once a template parameter // -// has been refined, there is never a need to backtrack and un-refine it to // -// match a later argument. // +// This algorithm for matching a single overload is less general than the // +// algorithm described in the WGSL spec. But it makes the same decisions // +// because the overloads defined by WGSL are monotonic in the sense that once // +// a template parameter has been refined, there is never a need to backtrack // +// and un-refine it to match a later argument. // // // // The algorithm for matching template numbers is similar to matching // // template types, except numbers need to exactly match across all uses - // @@ -270,7 +269,24 @@ match workgroup_or_storage: workgroup | storage // numbers or enumerators. // // // // // -// * More examples: // +// Overload resolution for candidate overloads // +// ------------------------------------------- // +// // +// If multiple candidate overloads match a given set of arguments, then a // +// final overload resolution pass needs to be performed. The arguments and // +// overload parameter types for each candidate overload are compared, // +// following the algorithm described at: // +// https://www.w3.org/TR/WGSL/#overload-resolution-section // +// // +// If the candidate list contains a single entry, then that single candidate // +// is picked, and no overload resolution needs to be performed. // +// // +// If the candidate list is empty, then the call fails to resolve and an // +// error diagnostic is raised. // +// // +// // +// More examples // +// ------------- // // // // fn F() // // - Function called F. // diff --git a/src/tint/resolver/intrinsic_table.cc b/src/tint/resolver/intrinsic_table.cc index 85711769f1..4d9ab7799b 100644 --- a/src/tint/resolver/intrinsic_table.cc +++ b/src/tint/resolver/intrinsic_table.cc @@ -944,6 +944,12 @@ class Impl : public IntrinsicTable { /// Callback function when no overloads match. using OnNoMatch = std::function; + /// Sorts the candidates based on their score, with the lowest (best-ranking) scores first. + static inline void SortCandidates(Candidates& candidates) { + std::stable_sort(candidates.begin(), candidates.end(), + [&](const Candidate& a, const Candidate& b) { return a.score < b.score; }); + } + /// Attempts to find a single intrinsic overload that matches the provided argument types. /// @param intrinsic the intrinsic being called /// @param intrinsic_name the name of the intrinsic @@ -962,7 +968,7 @@ class Impl : public IntrinsicTable { TemplateState templates, OnNoMatch on_no_match) const; - /// Evaluates the overload for the provided argument types. + /// Evaluates the single overload for the provided argument types. /// @param overload the overload being considered /// @param args the argument types /// @param templates initial template state. This may contain explicitly specified template @@ -973,6 +979,21 @@ class Impl : public IntrinsicTable { const std::vector& args, TemplateState templates) const; + /// Performs overload resolution given the list of candidates, by ranking the conversions of + /// arguments to the each of the candidate's parameter types. + /// @param candidates the list of candidate overloads + /// @param intrinsic_name the name of the intrinsic + /// @param args the argument types + /// @param templates initial template state. This may contain explicitly specified template + /// arguments. For example `vec3()` would have the first template-type + /// template as `f32`. + /// @see https://www.w3.org/TR/WGSL/#overload-resolution-section + /// @returns the resolved Candidate. + Candidate ResolveCandidate(Candidates&& candidates, + const char* intrinsic_name, + const std::vector& args, + TemplateState templates) const; + /// Match constructs a new MatchState /// @param templates the template state used for matcher evaluation /// @param overload the overload being evaluated @@ -991,12 +1012,11 @@ class Impl : public IntrinsicTable { const Candidates& candidates, const char* intrinsic_name) const; - /// Raises an ICE when multiple overload candidates match, as this should never happen. - void ErrMultipleOverloadsMatched(size_t num_matched, - const char* intrinsic_name, - const std::vector& args, - TemplateState templates, - Candidates candidates) const; + /// Raises an error when no overload is a clear winner of overload resolution + void ErrAmbiguousOverload(const char* intrinsic_name, + const std::vector& args, + TemplateState templates, + Candidates candidates) const; ProgramBuilder& builder; Matchers matchers; @@ -1280,38 +1300,38 @@ IntrinsicPrototype Impl::MatchIntrinsic(const IntrinsicInfo& intrinsic, TemplateState templates, OnNoMatch on_no_match) const { size_t num_matched = 0; + size_t match_idx = 0; Candidates candidates; candidates.reserve(intrinsic.num_overloads); for (size_t overload_idx = 0; overload_idx < static_cast(intrinsic.num_overloads); overload_idx++) { auto candidate = ScoreOverload(&intrinsic.overloads[overload_idx], args, templates); if (candidate.score == 0) { + match_idx = overload_idx; num_matched++; } candidates.emplace_back(std::move(candidate)); } - // Sort the candidates with the most promising first - std::stable_sort(candidates.begin(), candidates.end(), - [&](const Candidate& a, const Candidate& b) { return a.score < b.score; }); - // How many candidates matched? - switch (num_matched) { - case 0: - on_no_match(std::move(candidates)); - return {}; - case 1: - break; - default: - // Note: Currently the intrinsic table does not contain any overloads which may result - // in ambiguities, so here we call ErrMultipleOverloadsMatched() which will produce and - // ICE. If we end up in the situation where this is unavoidable, we'll need to perform - // further overload resolution as described in - // https://www.w3.org/TR/WGSL/#overload-resolution-section. - ErrMultipleOverloadsMatched(num_matched, intrinsic_name, args, templates, candidates); + if (num_matched == 0) { + // Sort the candidates with the most promising first + SortCandidates(candidates); + on_no_match(std::move(candidates)); + return {}; } - auto match = candidates[0]; + Candidate match; + + if (num_matched == 1) { + match = std::move(candidates[match_idx]); + } else { + match = ResolveCandidate(std::move(candidates), intrinsic_name, args, std::move(templates)); + if (!match.overload) { + // Ambiguous overload. ResolveCandidate() will have already raised an error diagnostic. + return {}; + } + } // Build the return type const sem::Type* return_type = nullptr; @@ -1423,6 +1443,70 @@ Impl::Candidate Impl::ScoreOverload(const OverloadInfo* overload, return Candidate{overload, templates, parameters, score}; } +Impl::Candidate Impl::ResolveCandidate(Impl::Candidates&& candidates, + const char* intrinsic_name, + const std::vector& args, + TemplateState templates) const { + std::vector best_ranks(args.size(), 0xffffffff); + size_t num_matched = 0; + Candidate* best = nullptr; + for (auto& candidate : candidates) { + if (candidate.score > 0) { + continue; // Candidate has already been ruled out. + } + bool some_won = false; // An argument ranked less than the 'best' overload's argument + bool some_lost = false; // An argument ranked more than the 'best' overload's argument + for (size_t i = 0; i < args.size(); i++) { + auto rank = sem::Type::ConversionRank(args[i], candidate.parameters[i].type); + if (best_ranks[i] > rank) { + best_ranks[i] = rank; + some_won = true; + } else if (best_ranks[i] < rank) { + some_lost = true; + } + } + // If no arguments of this candidate ranked worse than the previous best candidate, then + // this candidate becomes the new best candidate. + // If no arguments of this candidate ranked better than the previous best candidate, then + // this candidate is removed from the list of matches. + // If neither of the above apply, then we have two candidates with no clear winner, which + // results in an ambiguous overload error. In this situation the loop ends with + // `num_matched > 1`. + if (some_won) { + // One or more arguments of this candidate ranked better than the previous best + // candidate's argument(s). + num_matched++; + if (!some_lost) { + // All arguments were at as-good or better than the previous best. + if (best) { + // Mark the previous best candidate as no longer being in the running, by + // setting its score to a non-zero value. We pick 1 as this is the closest to 0 + // (match) as we can get. + best->score = 1; + num_matched--; + } + // This candidate is the new best. + best = &candidate; + } + } else { + // No arguments ranked better than the current best. + // Change the score of this candidate to a non-zero value, so that it's not considered a + // match. + candidate.score = 1; + } + } + + if (num_matched > 1) { + // Re-sort the candidates with the most promising first + SortCandidates(candidates); + // Raise an error + ErrAmbiguousOverload(intrinsic_name, args, templates, candidates); + return {}; + } + + return std::move(*best); +} + MatchState Impl::Match(TemplateState& templates, const OverloadInfo* overload, MatcherIndex const* matcher_indices) const { @@ -1512,13 +1596,12 @@ std::string MatchState::NumName() { return matcher->String(this); } -void Impl::ErrMultipleOverloadsMatched(size_t num_matched, - const char* intrinsic_name, - const std::vector& args, - TemplateState templates, - Candidates candidates) const { +void Impl::ErrAmbiguousOverload(const char* intrinsic_name, + const std::vector& args, + TemplateState templates, + Candidates candidates) const { std::stringstream ss; - ss << num_matched << " overloads matched " << intrinsic_name; + ss << "ambiguous overload while attempting to match " << intrinsic_name; for (size_t i = 0; i < std::numeric_limits::max(); i++) { if (auto* ty = templates.Type(i)) { ss << ((i == 0) ? "<" : ", ") << ty->FriendlyName(builder.Symbols()); diff --git a/src/tint/resolver/intrinsic_table_test.cc b/src/tint/resolver/intrinsic_table_test.cc index 4a5c1714b2..9c5de1e5a8 100644 --- a/src/tint/resolver/intrinsic_table_test.cc +++ b/src/tint/resolver/intrinsic_table_test.cc @@ -793,6 +793,20 @@ TEST_F(IntrinsicTableTest, Err257Arguments) { // crbug.com/1323605 ASSERT_THAT(Diagnostics().str(), HasSubstr("no matching call")); } +TEST_F(IntrinsicTableTest, OverloadResolution) { + // i32(abstract-int) produces candidates for both: + // ctor i32(i32) -> i32 + // conv i32(T) -> i32 + // The first should win overload resolution. + auto* ai = create(); + auto* i32 = create(); + auto result = table->Lookup(CtorConvIntrinsic::kI32, nullptr, {ai}, Source{}); + ASSERT_NE(result, nullptr); + EXPECT_EQ(result->ReturnType(), i32); + EXPECT_EQ(result->Parameters().size(), 1u); + EXPECT_EQ(result->Parameters()[0]->Type(), i32); +} + //////////////////////////////////////////////////////////////////////////////// // AbstractBinaryTests ////////////////////////////////////////////////////////////////////////////////