tint: Cleanup of IntrinsicTable

Remove the ProgramBuilder from ClosedState and use a pointer for the
'overload' field instead of a reference. Let's the Candidate be
copy-assignable, which in turn, allows the Candidates vector to be
sorted directly, instead of jumping through hoops to use moves.

Replace random mix of 'int', 'uint8_t' with 'size_t' (externally to the
constant table data). Reduces fragile weak binding between distant code.

Swap the overload scoring order (high-best -> low-best). Remove the
'matched' field - we can now just check whether the 'score' is 0.
Further simplifies sorting.

Change-Id: I4a4b7934be337306202647d096c546eab5c8498f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90641
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-05-17 17:54:32 +00:00 committed by Dawn LUCI CQ
parent 5211b0b0fd
commit 661e33ca18
1 changed files with 85 additions and 122 deletions

View File

@ -104,12 +104,10 @@ const Number Number::invalid{Number::kInvalid};
/// Used by the MatchState.
class ClosedState {
public:
explicit ClosedState(ProgramBuilder& b) : builder(b) {}
/// If the type with index `idx` is open, then it is closed with type `ty` and
/// Type() returns true. If the type is closed, then `Type()` returns true iff
/// it is equal to `ty`.
bool Type(uint32_t idx, const sem::Type* ty) {
bool Type(size_t idx, const sem::Type* ty) {
auto res = types_.emplace(idx, ty);
return res.second || res.first->second == ty;
}
@ -117,33 +115,27 @@ class ClosedState {
/// If the number with index `idx` is open, then it is closed with number
/// `number` and Num() returns true. If the number is closed, then `Num()`
/// returns true iff it is equal to `ty`.
bool Num(uint32_t idx, Number number) {
bool Num(size_t idx, Number number) {
auto res = numbers_.emplace(idx, number.Value());
return res.second || res.first->second == number.Value();
}
/// Type returns the closed type with index `idx`, or nullptr if the type was not closed.
const sem::Type* Type(uint32_t idx) const {
const sem::Type* Type(size_t idx) const {
auto it = types_.find(idx);
return (it != types_.end()) ? it->second : nullptr;
}
/// Type returns the number type with index `idx`.
/// An ICE is raised if the number is not closed.
Number Num(uint32_t idx) const {
Number Num(size_t idx) const {
auto it = numbers_.find(idx);
if (it == numbers_.end()) {
TINT_ICE(Resolver, builder.Diagnostics())
<< "number with index " << idx << " is not closed";
return Number::invalid;
}
return Number(it->second);
return (it != numbers_.end()) ? Number(it->second) : Number::invalid;
}
private:
ProgramBuilder& builder;
std::unordered_map<uint32_t, const sem::Type*> types_;
std::unordered_map<uint32_t, uint32_t> numbers_;
std::unordered_map<size_t, const sem::Type*> types_;
std::unordered_map<size_t, uint32_t> numbers_;
};
/// Index type used for matcher indices
@ -158,7 +150,7 @@ class MatchState {
MatchState(ProgramBuilder& b,
ClosedState& c,
const Matchers& m,
const OverloadInfo& o,
const OverloadInfo* o,
MatcherIndex const* matcher_indices)
: builder(b), closed(c), matchers(m), overload(o), matcher_indices_(matcher_indices) {}
@ -169,7 +161,7 @@ class MatchState {
/// The type and number matchers
Matchers const& matchers;
/// The current overload being evaluated
OverloadInfo const& overload;
OverloadInfo const* overload;
/// Type uses the next TypeMatcher from the matcher indices to match the type
/// `ty`. If the type matches, the canonical expected type is returned. If the
@ -240,7 +232,7 @@ class NumberMatcher {
class OpenTypeMatcher : public TypeMatcher {
public:
/// Constructor
explicit OpenTypeMatcher(uint32_t index) : index_(index) {}
explicit OpenTypeMatcher(size_t index) : index_(index) {}
const sem::Type* Match(MatchState& state, const sem::Type* type) const override {
if (type->Is<Any>()) {
@ -252,7 +244,7 @@ class OpenTypeMatcher : public TypeMatcher {
std::string String(MatchState& state) const override;
private:
uint32_t index_;
size_t index_;
};
/// OpenNumberMatcher is a Matcher for an open number.
@ -260,7 +252,7 @@ class OpenTypeMatcher : public TypeMatcher {
/// consistent for the overload)
class OpenNumberMatcher : public NumberMatcher {
public:
explicit OpenNumberMatcher(uint32_t index) : index_(index) {}
explicit OpenNumberMatcher(size_t index) : index_(index) {}
Number Match(MatchState& state, Number number) const override {
if (number.IsAny()) {
@ -272,7 +264,7 @@ class OpenNumberMatcher : public NumberMatcher {
std::string String(MatchState& state) const override;
private:
uint32_t index_;
size_t index_;
};
////////////////////////////////////////////////////////////////////////////////
@ -879,16 +871,16 @@ class Impl : public IntrinsicTable {
/// Candidate holds information about an overload evaluated for resolution.
struct Candidate {
/// The candidate overload
const OverloadInfo& overload;
const OverloadInfo* overload;
/// The closed types and numbers
ClosedState closed;
/// The parameter types for the candidate overload
std::vector<IntrinsicPrototype::Parameter> parameters;
/// True if the candidate is a viable match for the call
bool matched;
/// The match-score of the candidate overload. Used for diagnostics when no overload
/// matches. Higher scores are displayed first (top-most).
int score;
/// The match-score of the candidate overload.
/// A score of zero indicates an exact match.
/// Non-zero scores are used for diagnostics when no overload matches.
/// Lower scores are displayed first (top-most).
size_t score;
};
/// A list of candidates
@ -922,7 +914,7 @@ class Impl : public IntrinsicTable {
/// arguments. For example `vec3<f32>()` would have the first template-type closed
/// as `f32`.
/// @returns the evaluated Candidate information.
Candidate ScoreOverload(const OverloadInfo& overload,
Candidate ScoreOverload(const OverloadInfo* overload,
const std::vector<const sem::Type*>& args,
ClosedState closed) const;
@ -931,12 +923,12 @@ class Impl : public IntrinsicTable {
/// @param overload the overload being evaluated
/// @param matcher_indices pointer to a list of matcher indices
MatchState Match(ClosedState& closed,
const OverloadInfo& overload,
const OverloadInfo* overload,
MatcherIndex const* matcher_indices) const;
// Prints the overload for emitting diagnostics
void PrintOverload(std::ostream& ss,
const OverloadInfo& overload,
const OverloadInfo* overload,
const char* intrinsic_name) const;
// Prints the list of candidates for emitting diagnostics
@ -945,7 +937,7 @@ class Impl : public IntrinsicTable {
const char* intrinsic_name) const;
/// Raises an ICE when multiple overload candidates match, as this should never happen.
void ErrMultipleOverloadsMatched(uint32_t num_matched,
void ErrMultipleOverloadsMatched(size_t num_matched,
const char* intrinsic_name,
const std::vector<const sem::Type*>& args,
ClosedState closed,
@ -988,11 +980,11 @@ std::string CallSignature(ProgramBuilder& builder,
}
std::string OpenTypeMatcher::String(MatchState& state) const {
return state.overload.open_types[index_].name;
return state.overload->open_types[index_].name;
}
std::string OpenNumberMatcher::String(MatchState& state) const {
return state.overload.open_numbers[index_].name;
return state.overload->open_numbers[index_].name;
}
Impl::Impl(ProgramBuilder& b) : builder(b) {}
@ -1016,8 +1008,8 @@ const sem::Builtin* Impl::Lookup(sem::BuiltinType builtin_type,
};
// Resolve the intrinsic overload
auto match = MatchIntrinsic(kBuiltins[static_cast<uint32_t>(builtin_type)], intrinsic_name,
args, ClosedState(builder), on_no_match);
auto match = MatchIntrinsic(kBuiltins[static_cast<size_t>(builtin_type)], intrinsic_name, args,
ClosedState{}, on_no_match);
if (!match.overload) {
return {};
}
@ -1050,7 +1042,7 @@ const sem::Builtin* Impl::Lookup(sem::BuiltinType builtin_type,
IntrinsicTable::UnaryOperator Impl::Lookup(ast::UnaryOp op,
const sem::Type* arg,
const Source& source) {
auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair<uint32_t, const char*> {
auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair<size_t, const char*> {
switch (op) {
case ast::UnaryOp::kComplement:
return {kUnaryOperatorComplement, "operator ~ "};
@ -1078,7 +1070,7 @@ IntrinsicTable::UnaryOperator Impl::Lookup(ast::UnaryOp op,
// Resolve the intrinsic overload
auto match = MatchIntrinsic(kUnaryOperators[intrinsic_index], intrinsic_name, {arg},
ClosedState(builder), on_no_match);
ClosedState{}, on_no_match);
if (!match.overload) {
return {};
}
@ -1091,7 +1083,7 @@ IntrinsicTable::BinaryOperator Impl::Lookup(ast::BinaryOp op,
const sem::Type* rhs,
const Source& source,
bool is_compound) {
auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair<uint32_t, const char*> {
auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair<size_t, const char*> {
switch (op) {
case ast::BinaryOp::kAnd:
return {kBinaryOperatorAnd, is_compound ? "operator &= " : "operator & "};
@ -1149,7 +1141,7 @@ IntrinsicTable::BinaryOperator Impl::Lookup(ast::BinaryOp op,
// Resolve the intrinsic overload
auto match = MatchIntrinsic(kBinaryOperators[intrinsic_index], intrinsic_name, {lhs, rhs},
ClosedState(builder), on_no_match);
ClosedState{}, on_no_match);
if (!match.overload) {
return {};
}
@ -1170,7 +1162,7 @@ const sem::CallTarget* Impl::Lookup(CtorConvIntrinsic type,
<< std::endl;
Candidates ctor, conv;
for (auto candidate : candidates) {
if (candidate.overload.flags.Contains(OverloadFlag::kIsConstructor)) {
if (candidate.overload->flags.Contains(OverloadFlag::kIsConstructor)) {
ctor.emplace_back(candidate);
} else {
conv.emplace_back(candidate);
@ -1192,13 +1184,13 @@ const sem::CallTarget* Impl::Lookup(CtorConvIntrinsic type,
};
// If a template type was provided, then close the 0'th type with this.
ClosedState closed(builder);
ClosedState closed;
if (template_arg) {
closed.Type(0, template_arg);
}
// Resolve the intrinsic overload
auto match = MatchIntrinsic(kConstructorsAndConverters[static_cast<uint32_t>(type)], name, args,
auto match = MatchIntrinsic(kConstructorsAndConverters[static_cast<size_t>(type)], name, args,
closed, on_no_match);
if (!match.overload) {
return {};
@ -1232,40 +1224,21 @@ IntrinsicPrototype Impl::MatchIntrinsic(const IntrinsicInfo& intrinsic,
const std::vector<const sem::Type*>& args,
ClosedState closed,
OnNoMatch on_no_match) const {
uint32_t num_matched = 0;
size_t num_matched = 0;
Candidates candidates;
candidates.reserve(intrinsic.num_overloads);
for (uint8_t overload_idx = 0; overload_idx < intrinsic.num_overloads; overload_idx++) {
auto candidate = ScoreOverload(intrinsic.overloads[overload_idx], args, closed);
if (candidate.matched) {
for (size_t overload_idx = 0; overload_idx < static_cast<size_t>(intrinsic.num_overloads);
overload_idx++) {
auto candidate = ScoreOverload(&intrinsic.overloads[overload_idx], args, closed);
if (candidate.score == 0) {
num_matched++;
}
candidates.emplace_back(std::move(candidate));
}
// Sort the candidates with the most promising first
{
std::vector<size_t> candidate_indices(candidates.size());
for (size_t i = 0; i < candidate_indices.size(); i++) {
candidate_indices[i] = i;
}
std::stable_sort(candidate_indices.begin(), candidate_indices.end(),
[&](size_t a, size_t b) {
if (candidates[a].matched && !candidates[b].matched) {
return true;
}
if (candidates[b].matched && !candidates[a].matched) {
return false;
}
return candidates[a].score > candidates[b].score;
});
Candidates candidates_sorted;
candidates_sorted.reserve(candidate_indices.size());
for (size_t idx : candidate_indices) {
candidates_sorted.emplace_back(std::move(candidates[idx]));
}
std::swap(candidates, candidates_sorted);
}
std::stable_sort(candidates.begin(), candidates.end(),
[&](const Candidate& a, const Candidate& b) { return a.score < b.score; });
// How many candidates matched?
switch (num_matched) {
@ -1282,7 +1255,7 @@ IntrinsicPrototype Impl::MatchIntrinsic(const IntrinsicInfo& intrinsic,
// Build the return type
const sem::Type* return_type = nullptr;
if (auto* indices = match.overload.return_matcher_indices) {
if (auto* indices = match.overload->return_matcher_indices) {
Any any;
return_type = Match(match.closed, match.overload, indices).Type(&any);
if (!return_type) {
@ -1293,101 +1266,91 @@ IntrinsicPrototype Impl::MatchIntrinsic(const IntrinsicInfo& intrinsic,
return_type = builder.create<sem::Void>();
}
return IntrinsicPrototype{&match.overload, return_type, std::move(match.parameters)};
return IntrinsicPrototype{match.overload, return_type, std::move(match.parameters)};
}
Impl::Candidate Impl::ScoreOverload(const OverloadInfo& overload,
Impl::Candidate Impl::ScoreOverload(const OverloadInfo* overload,
const std::vector<const sem::Type*>& args,
ClosedState closed) const {
// Score weight for argument <-> parameter count matches / mismatches
// Penalty weights for overload mismatching.
// This scoring is used to order the suggested overloads in diagnostic on overload mismatch, and
// has no impact for a correct program.
// The overloads with the highest score will be displayed first (top-most).
constexpr int kScorePerParamArgMismatch = -1;
constexpr int kScorePerMatchedParam = 2;
constexpr int kScorePerMatchedOpenType = 1;
constexpr int kScorePerMatchedOpenNumber = 1;
// The overloads with the lowest score will be displayed first (top-most).
constexpr int kMismatchedParamCountPenalty = 3;
constexpr int kMismatchedParamTypePenalty = 2;
constexpr int kMismatchedOpenTypePenalty = 1;
constexpr int kMismatchedOpenNumberPenalty = 1;
uint32_t num_parameters = static_cast<uint32_t>(overload.num_parameters);
uint32_t num_arguments = static_cast<uint32_t>(args.size());
size_t num_parameters = static_cast<size_t>(overload->num_parameters);
size_t num_arguments = static_cast<size_t>(args.size());
bool overload_matched = true;
int overload_score = 0;
if (static_cast<uint64_t>(args.size()) >
static_cast<uint64_t>(std::numeric_limits<uint32_t>::max())) {
overload_matched = false; // No overload has this number of arguments.
}
size_t score = 0;
if (num_parameters != num_arguments) {
overload_score += kScorePerParamArgMismatch * (std::max(num_parameters, num_arguments) -
std::min(num_parameters, num_arguments));
overload_matched = false;
score += kMismatchedParamCountPenalty * (std::max(num_parameters, num_arguments) -
std::min(num_parameters, num_arguments));
}
std::vector<IntrinsicPrototype::Parameter> parameters;
auto num_params = std::min(num_parameters, num_arguments);
for (uint32_t p = 0; p < num_params; p++) {
auto& parameter = overload.parameters[p];
for (size_t p = 0; p < num_params; p++) {
auto& parameter = overload->parameters[p];
auto* indices = parameter.matcher_indices;
auto* type = Match(closed, overload, indices).Type(args[p]->UnwrapRef());
if (type) {
parameters.emplace_back(IntrinsicPrototype::Parameter{type, parameter.usage});
overload_score += kScorePerMatchedParam;
} else {
overload_matched = false;
score += kMismatchedParamTypePenalty;
}
}
if (overload_matched) {
if (score == 0) {
// Check all constrained open types matched
for (uint32_t ot = 0; ot < overload.num_open_types; ot++) {
auto& open_type = overload.open_types[ot];
for (size_t ot = 0; ot < overload->num_open_types; ot++) {
auto& open_type = overload->open_types[ot];
if (open_type.matcher_index != kNoMatcher) {
auto* closed_type = closed.Type(ot);
auto* matcher_index = &open_type.matcher_index;
if (closed_type && Match(closed, overload, matcher_index).Type(closed_type)) {
overload_score += kScorePerMatchedOpenType;
} else {
overload_matched = false;
if (!closed_type || !Match(closed, overload, matcher_index).Type(closed_type)) {
score += kMismatchedOpenTypePenalty;
}
}
}
}
if (overload_matched) {
if (score == 0) {
// Check all constrained open numbers matched
for (uint32_t on = 0; on < overload.num_open_numbers; on++) {
auto& open_number = overload.open_numbers[on];
for (size_t on = 0; on < overload->num_open_numbers; on++) {
auto& open_number = overload->open_numbers[on];
if (open_number.matcher_index != kNoMatcher) {
auto closed_num = closed.Num(on);
auto* index = &open_number.matcher_index;
if (Match(closed, overload, index).Num(closed.Num(on)).IsValid()) {
overload_score += kScorePerMatchedOpenNumber;
} else {
overload_matched = false;
if (!closed_num.IsValid() ||
!Match(closed, overload, index).Num(closed_num).IsValid()) {
score += kMismatchedOpenNumberPenalty;
}
}
}
}
return Candidate{overload, closed, parameters, overload_matched, overload_score};
return Candidate{overload, closed, parameters, score};
}
MatchState Impl::Match(ClosedState& closed,
const OverloadInfo& overload,
const OverloadInfo* overload,
MatcherIndex const* matcher_indices) const {
return MatchState(builder, closed, matchers, overload, matcher_indices);
}
void Impl::PrintOverload(std::ostream& ss,
const OverloadInfo& overload,
const OverloadInfo* overload,
const char* intrinsic_name) const {
ClosedState closed(builder);
ClosedState closed;
ss << intrinsic_name << "(";
for (uint32_t p = 0; p < overload.num_parameters; p++) {
auto& parameter = overload.parameters[p];
for (size_t p = 0; p < overload->num_parameters; p++) {
auto& parameter = overload->parameters[p];
if (p > 0) {
ss << ", ";
}
@ -1398,9 +1361,9 @@ void Impl::PrintOverload(std::ostream& ss,
ss << Match(closed, overload, indices).TypeName();
}
ss << ")";
if (overload.return_matcher_indices) {
if (overload->return_matcher_indices) {
ss << " -> ";
auto* indices = overload.return_matcher_indices;
auto* indices = overload->return_matcher_indices;
ss << Match(closed, overload, indices).TypeName();
}
@ -1409,8 +1372,8 @@ void Impl::PrintOverload(std::ostream& ss,
ss << (first ? " where: " : ", ");
first = false;
};
for (uint32_t i = 0; i < overload.num_open_types; i++) {
auto& open_type = overload.open_types[i];
for (size_t i = 0; i < overload->num_open_types; i++) {
auto& open_type = overload->open_types[i];
if (open_type.matcher_index != kNoMatcher) {
separator();
ss << open_type.name;
@ -1418,8 +1381,8 @@ void Impl::PrintOverload(std::ostream& ss,
ss << " is " << Match(closed, overload, index).TypeName();
}
}
for (uint32_t i = 0; i < overload.num_open_numbers; i++) {
auto& open_number = overload.open_numbers[i];
for (size_t i = 0; i < overload->num_open_numbers; i++) {
auto& open_number = overload->open_numbers[i];
if (open_number.matcher_index != kNoMatcher) {
separator();
ss << open_number.name;
@ -1463,14 +1426,14 @@ std::string MatchState::NumName() {
return matcher->String(*this);
}
void Impl::ErrMultipleOverloadsMatched(uint32_t num_matched,
void Impl::ErrMultipleOverloadsMatched(size_t num_matched,
const char* intrinsic_name,
const std::vector<const sem::Type*>& args,
ClosedState closed,
Candidates candidates) const {
std::stringstream ss;
ss << num_matched << " overloads matched " << intrinsic_name;
for (uint32_t i = 0; i < 0xffffffffu; i++) {
for (size_t i = 0; i < std::numeric_limits<size_t>::max(); i++) {
if (auto* ty = closed.Type(i)) {
ss << ((i == 0) ? "<" : ", ") << ty->FriendlyName(builder.Symbols());
} else if (i > 0) {
@ -1489,7 +1452,7 @@ void Impl::ErrMultipleOverloadsMatched(uint32_t num_matched,
}
ss << "):\n";
for (auto& candidate : candidates) {
if (candidate.matched) {
if (candidate.score == 0) {
ss << " ";
PrintOverload(ss, candidate.overload, intrinsic_name);
ss << std::endl;