tint: Add binary-ops to the intrinsic table

• Declare all the binary ops in the intrinsics.def file.
• Reimplement Resolver::BinaryOpType() with the IntrinsicTable.

This will simplify maintenance of the operators, and will greatly
simplify the [AbstractInt -> i32|u32] [AbstractFloat -> f32|f16] logic.

Bug: tint:1504
Change-Id: Ie028602e05b59916c3f2168c92f200f10e402b96
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/89027
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-05-09 20:00:13 +00:00 committed by Dawn LUCI CQ
parent bcdb6e9da8
commit 9fb29a364e
10 changed files with 5226 additions and 3679 deletions

View File

@ -558,3 +558,91 @@ fn textureLoad(texture: texture_external, coords: vec2<i32>) -> vec4<f32>
[[stage("fragment", "compute")]] fn atomicXor<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T) -> T [[stage("fragment", "compute")]] fn atomicXor<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T) -> T
[[stage("fragment", "compute")]] fn atomicExchange<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T) -> T [[stage("fragment", "compute")]] fn atomicExchange<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T) -> T
[[stage("fragment", "compute")]] fn atomicCompareExchangeWeak<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T, T) -> vec2<T> [[stage("fragment", "compute")]] fn atomicCompareExchangeWeak<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T, T) -> vec2<T>
////////////////////////////////////////////////////////////////////////////////
// Operators //
// //
// The operator declarations below declare all the unary and binary operators //
// supported by the WGSL language (with exception for address-of and //
// dereference unary operators). //
// //
// The syntax is almost identical to builtin functions, except we use 'op' //
// instead of 'fn'. The resolving rules are identical to builtins, which is //
// described in detail above. //
// //
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// Binary Operators //
////////////////////////////////////////////////////////////////////////////////
op + <T: fiu32>(T, T) -> T
op + <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op + <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
op + <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
op + <N: num, M: num> (mat<N, M, f32>, mat<N, M, f32>) -> mat<N, M, f32>
op - <T: fiu32>(T, T) -> T
op - <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op - <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
op - <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
op - <N: num, M: num> (mat<N, M, f32>, mat<N, M, f32>) -> mat<N, M, f32>
op * <T: fiu32>(T, T) -> T
op * <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op * <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
op * <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
op * <N: num, M: num> (f32, mat<N, M, f32>) -> mat<N, M, f32>
op * <N: num, M: num> (mat<N, M, f32>, f32) -> mat<N, M, f32>
op * <C: num, R: num> (mat<C, R, f32>, vec<C, f32>) -> vec<R, f32>
op * <C: num, R: num> (vec<R, f32>, mat<C, R, f32>) -> vec<C, f32>
op * <K: num, C: num, R: num> (mat<K, R, f32>, mat<C, K, f32>) -> mat<C, R, f32>
op / <T: fiu32>(T, T) -> T
op / <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op / <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
op / <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
op % <T: fiu32>(T, T) -> T
op % <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op % <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
op % <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
op ^ <T: iu32>(T, T) -> T
op ^ <T: iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op & (bool, bool) -> bool
op & <N: num> (vec<N, bool>, vec<N, bool>) -> vec<N, bool>
op & <T: iu32>(T, T) -> T
op & <T: iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op | (bool, bool) -> bool
op | <N: num> (vec<N, bool>, vec<N, bool>) -> vec<N, bool>
op | <T: iu32>(T, T) -> T
op | <T: iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op && (bool, bool) -> bool
op || (bool, bool) -> bool
op == <T: scalar>(T, T) -> bool
op == <T: scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op != <T: scalar>(T, T) -> bool
op != <T: scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op < <T: fiu32>(T, T) -> bool
op < <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op > <T: fiu32>(T, T) -> bool
op > <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op <= <T: fiu32>(T, T) -> bool
op <= <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op >= <T: fiu32>(T, T) -> bool
op >= <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op << <T: iu32>(T, u32) -> T
op << <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
op >> <T: iu32>(T, u32) -> T
op >> <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>

View File

@ -18,6 +18,8 @@
#include "src/tint/resolver/resolver_test_helper.h" #include "src/tint/resolver/resolver_test_helper.h"
#include "src/tint/sem/storage_texture.h" #include "src/tint/sem/storage_texture.h"
using ::testing::HasSubstr;
using namespace tint::number_suffixes; // NOLINT using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver { namespace tint::resolver {
@ -71,9 +73,8 @@ TEST_F(ResolverCompoundAssignmentValidationTest, IncompatibleTypes) {
ASSERT_FALSE(r()->Resolve()); ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_THAT(r()->error(),
"12:34 error: compound assignment operand types are invalid: i32 " HasSubstr("12:34 error: no matching overload for operator += (i32, f32)"));
"add f32");
} }
TEST_F(ResolverCompoundAssignmentValidationTest, IncompatibleOp) { TEST_F(ResolverCompoundAssignmentValidationTest, IncompatibleOp) {
@ -89,8 +90,8 @@ TEST_F(ResolverCompoundAssignmentValidationTest, IncompatibleOp) {
ASSERT_FALSE(r()->Resolve()); ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_THAT(r()->error(),
"12:34 error: compound assignment operand types are invalid: f32 or f32"); HasSubstr("12:34 error: no matching overload for operator |= (f32, f32)"));
} }
TEST_F(ResolverCompoundAssignmentValidationTest, VectorScalar_Pass) { TEST_F(ResolverCompoundAssignmentValidationTest, VectorScalar_Pass) {
@ -180,9 +181,9 @@ TEST_F(ResolverCompoundAssignmentValidationTest, VectorMatrix_ColumnMismatch) {
ASSERT_FALSE(r()->Resolve()); ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_THAT(
"12:34 error: compound assignment operand types are invalid: " r()->error(),
"vec4<f32> multiply mat4x2<f32>"); HasSubstr("12:34 error: no matching overload for operator *= (vec4<f32>, mat4x2<f32>)"));
} }
TEST_F(ResolverCompoundAssignmentValidationTest, VectorMatrix_ResultMismatch) { TEST_F(ResolverCompoundAssignmentValidationTest, VectorMatrix_ResultMismatch) {
@ -223,9 +224,8 @@ TEST_F(ResolverCompoundAssignmentValidationTest, Phony) {
// } // }
WrapInFunction(CompoundAssign(Source{{56, 78}}, Phony(), 1_i, ast::BinaryOp::kAdd)); WrapInFunction(CompoundAssign(Source{{56, 78}}, Phony(), 1_i, ast::BinaryOp::kAdd));
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_THAT(r()->error(),
"56:78 error: compound assignment operand types are invalid: void " HasSubstr("56:78 error: no matching overload for operator += (void, i32)"));
"add i32");
} }
TEST_F(ResolverCompoundAssignmentValidationTest, ReadOnlyBuffer) { TEST_F(ResolverCompoundAssignmentValidationTest, ReadOnlyBuffer) {
@ -239,8 +239,7 @@ TEST_F(ResolverCompoundAssignmentValidationTest, ReadOnlyBuffer) {
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"56:78 error: cannot store into a read-only type 'ref<storage, " "56:78 error: cannot store into a read-only type 'ref<storage, i32, read>'");
"i32, read>'");
} }
TEST_F(ResolverCompoundAssignmentValidationTest, LhsConstant) { TEST_F(ResolverCompoundAssignmentValidationTest, LhsConstant) {
@ -269,9 +268,9 @@ TEST_F(ResolverCompoundAssignmentValidationTest, LhsAtomic) {
WrapInFunction(CompoundAssign(Source{{56, 78}}, "a", "a", ast::BinaryOp::kAdd)); WrapInFunction(CompoundAssign(Source{{56, 78}}, "a", "a", ast::BinaryOp::kAdd));
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_THAT(
"56:78 error: compound assignment operand types are invalid: " r()->error(),
"atomic<i32> add atomic<i32>"); HasSubstr("error: no matching overload for operator += (atomic<i32>, atomic<i32>)"));
} }
} // namespace } // namespace

View File

@ -762,7 +762,7 @@ struct OverloadInfo {
bool is_deprecated; bool is_deprecated;
}; };
/// IntrinsicInfo describes a builtin function /// IntrinsicInfo describes a builtin function or operator overload
struct IntrinsicInfo { struct IntrinsicInfo {
/// Number of overloads of the intrinsic /// Number of overloads of the intrinsic
const uint8_t num_overloads; const uint8_t num_overloads;
@ -791,11 +791,11 @@ struct IntrinsicPrototype {
for (auto& p : i.parameters) { for (auto& p : i.parameters) {
utils::HashCombine(&hash, p.type, p.usage); utils::HashCombine(&hash, p.type, p.usage);
} }
return utils::Hash(hash, i.type, i.return_type, i.supported_stages, i.is_deprecated); return utils::Hash(hash, i.index, i.return_type, i.supported_stages, i.is_deprecated);
} }
}; };
sem::BuiltinType type = sem::BuiltinType::kNone; uint32_t index = 0; // Index of the intrinsic (builtin or operator)
std::vector<Parameter> parameters; std::vector<Parameter> parameters;
sem::Type const* return_type = nullptr; sem::Type const* return_type = nullptr;
PipelineStageSet supported_stages; PipelineStageSet supported_stages;
@ -804,7 +804,7 @@ struct IntrinsicPrototype {
/// Equality operator for IntrinsicPrototype /// Equality operator for IntrinsicPrototype
bool operator==(const IntrinsicPrototype& a, const IntrinsicPrototype& b) { bool operator==(const IntrinsicPrototype& a, const IntrinsicPrototype& b) {
if (a.type != b.type || a.supported_stages != b.supported_stages || if (a.index != b.index || a.supported_stages != b.supported_stages ||
a.return_type != b.return_type || a.is_deprecated != b.is_deprecated || a.return_type != b.return_type || a.is_deprecated != b.is_deprecated ||
a.parameters.size() != b.parameters.size()) { a.parameters.size() != b.parameters.size()) {
return false; return false;
@ -828,8 +828,22 @@ class Impl : public IntrinsicTable {
const std::vector<const sem::Type*>& args, const std::vector<const sem::Type*>& args,
const Source& source) override; const Source& source) override;
BinaryOperator Lookup(ast::BinaryOp op,
const sem::Type* lhs,
const sem::Type* rhs,
const Source& source,
bool is_compound) override;
private: private:
const sem::Builtin* Match(sem::BuiltinType builtin_type, // Candidate holds information about a mismatched overload that could be what the user intended
// to call.
struct Candidate {
const OverloadInfo* overload;
int score;
};
const IntrinsicPrototype Match(const char* intrinsic_name,
uint32_t intrinsic_index,
const OverloadInfo& overload, const OverloadInfo& overload,
const std::vector<const sem::Type*>& args, const std::vector<const sem::Type*>& args,
int& match_score); int& match_score);
@ -838,9 +852,7 @@ class Impl : public IntrinsicTable {
const OverloadInfo& overload, const OverloadInfo& overload,
MatcherIndex const* matcher_indices) const; MatcherIndex const* matcher_indices) const;
void PrintOverload(std::ostream& ss, void PrintOverload(std::ostream& ss, const OverloadInfo& overload, const char* name) const;
const OverloadInfo& overload,
sem::BuiltinType builtin_type) const;
ProgramBuilder& builder; ProgramBuilder& builder;
Matchers matchers; Matchers matchers;
@ -850,10 +862,10 @@ class Impl : public IntrinsicTable {
/// @return a string representing a call to a builtin with the given argument /// @return a string representing a call to a builtin with the given argument
/// types. /// types.
std::string CallSignature(ProgramBuilder& builder, std::string CallSignature(ProgramBuilder& builder,
sem::BuiltinType builtin_type, const char* intrinsic_name,
const std::vector<const sem::Type*>& args) { const std::vector<const sem::Type*>& args) {
std::stringstream ss; std::stringstream ss;
ss << sem::str(builtin_type) << "("; ss << intrinsic_name << "(";
{ {
bool first = true; bool first = true;
for (auto* arg : args) { for (auto* arg : args) {
@ -882,22 +894,30 @@ Impl::Impl(ProgramBuilder& b) : builder(b) {}
const sem::Builtin* Impl::Lookup(sem::BuiltinType builtin_type, const sem::Builtin* Impl::Lookup(sem::BuiltinType builtin_type,
const std::vector<const sem::Type*>& args, const std::vector<const sem::Type*>& args,
const Source& source) { const Source& source) {
// Candidate holds information about a mismatched overload that could be what
// the user intended to call.
struct Candidate {
const OverloadInfo* overload;
int score;
};
// The list of failed matches that had promise. // The list of failed matches that had promise.
std::vector<Candidate> candidates; std::vector<Candidate> candidates;
auto& builtin = kBuiltins[static_cast<uint32_t>(builtin_type)]; uint32_t intrinsic_index = static_cast<uint32_t>(builtin_type);
const char* intrinsic_name = sem::str(builtin_type);
auto& builtin = kBuiltins[intrinsic_index];
for (uint32_t o = 0; o < builtin.num_overloads; o++) { for (uint32_t o = 0; o < builtin.num_overloads; o++) {
int match_score = 1000; int match_score = 1000;
auto& overload = builtin.overloads[o]; auto& overload = builtin.overloads[o];
if (auto* match = Match(builtin_type, overload, args, match_score)) { auto match = Match(intrinsic_name, intrinsic_index, overload, args, match_score);
return match; if (match.return_type) {
// De-duplicate builtins that are identical.
return utils::GetOrCreate(builtins, match, [&] {
std::vector<sem::Parameter*> params;
params.reserve(match.parameters.size());
for (auto& p : match.parameters) {
params.emplace_back(builder.create<sem::Parameter>(
nullptr, static_cast<uint32_t>(params.size()), p.type,
ast::StorageClass::kNone, ast::Access::kUndefined, p.usage));
}
return builder.create<sem::Builtin>(builtin_type, match.return_type,
std::move(params), match.supported_stages,
match.is_deprecated);
});
} }
if (match_score > 0) { if (match_score > 0) {
candidates.emplace_back(Candidate{&overload, match_score}); candidates.emplace_back(Candidate{&overload, match_score});
@ -910,14 +930,14 @@ const sem::Builtin* Impl::Lookup(sem::BuiltinType builtin_type,
// Generate an error message // Generate an error message
std::stringstream ss; std::stringstream ss;
ss << "no matching call to " << CallSignature(builder, builtin_type, args) << std::endl; ss << "no matching call to " << CallSignature(builder, intrinsic_name, args) << std::endl;
if (!candidates.empty()) { if (!candidates.empty()) {
ss << std::endl; ss << std::endl;
ss << candidates.size() << " candidate function" << (candidates.size() > 1 ? "s:" : ":") ss << candidates.size() << " candidate function" << (candidates.size() > 1 ? "s:" : ":")
<< std::endl; << std::endl;
for (auto& candidate : candidates) { for (auto& candidate : candidates) {
ss << " "; ss << " ";
PrintOverload(ss, *candidate.overload, builtin_type); PrintOverload(ss, *candidate.overload, intrinsic_name);
ss << std::endl; ss << std::endl;
} }
} }
@ -925,7 +945,95 @@ const sem::Builtin* Impl::Lookup(sem::BuiltinType builtin_type,
return nullptr; return nullptr;
} }
const sem::Builtin* Impl::Match(sem::BuiltinType builtin_type, IntrinsicTable::BinaryOperator Impl::Lookup(ast::BinaryOp op,
const sem::Type* lhs,
const sem::Type* rhs,
const Source& source,
bool is_compound) {
// The list of failed matches that had promise.
std::vector<Candidate> candidates;
auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair<uint32_t, const char*> {
switch (op) {
case ast::BinaryOp::kAnd:
return {kOperatorAnd, is_compound ? "operator &= " : "operator & "};
case ast::BinaryOp::kOr:
return {kOperatorOr, is_compound ? "operator |= " : "operator | "};
case ast::BinaryOp::kXor:
return {kOperatorXor, is_compound ? "operator ^= " : "operator ^ "};
case ast::BinaryOp::kLogicalAnd:
return {kOperatorLogicalAnd, "operator && "};
case ast::BinaryOp::kLogicalOr:
return {kOperatorLogicalOr, "operator || "};
case ast::BinaryOp::kEqual:
return {kOperatorEqual, "operator == "};
case ast::BinaryOp::kNotEqual:
return {kOperatorNotEqual, "operator != "};
case ast::BinaryOp::kLessThan:
return {kOperatorLessThan, "operator < "};
case ast::BinaryOp::kGreaterThan:
return {kOperatorGreaterThan, "operator > "};
case ast::BinaryOp::kLessThanEqual:
return {kOperatorLessThanEqual, "operator <= "};
case ast::BinaryOp::kGreaterThanEqual:
return {kOperatorGreaterThanEqual, "operator >= "};
case ast::BinaryOp::kShiftLeft:
return {kOperatorShiftLeft, is_compound ? "operator <<= " : "operator << "};
case ast::BinaryOp::kShiftRight:
return {kOperatorShiftRight, is_compound ? "operator >>= " : "operator >> "};
case ast::BinaryOp::kAdd:
return {kOperatorPlus, is_compound ? "operator += " : "operator + "};
case ast::BinaryOp::kSubtract:
return {kOperatorMinus, is_compound ? "operator -= " : "operator - "};
case ast::BinaryOp::kMultiply:
return {kOperatorStar, is_compound ? "operator *= " : "operator * "};
case ast::BinaryOp::kDivide:
return {kOperatorDivide, is_compound ? "operator /= " : "operator / "};
case ast::BinaryOp::kModulo:
return {kOperatorModulo, is_compound ? "operator %= " : "operator % "};
default:
return {0, "<unknown>"};
}
}();
auto& builtin = kOperators[intrinsic_index];
for (uint32_t o = 0; o < builtin.num_overloads; o++) {
int match_score = 1000;
auto& overload = builtin.overloads[o];
auto match = Match(intrinsic_name, intrinsic_index, overload, {lhs, rhs}, match_score);
if (match.return_type) {
return BinaryOperator{match.return_type, match.parameters[0].type,
match.parameters[1].type};
}
if (match_score > 0) {
candidates.emplace_back(Candidate{&overload, match_score});
}
}
// Sort the candidates with the most promising first
std::stable_sort(candidates.begin(), candidates.end(),
[](const Candidate& a, const Candidate& b) { return a.score > b.score; });
// Generate an error message
std::stringstream ss;
ss << "no matching overload for " << CallSignature(builder, intrinsic_name, {lhs, rhs})
<< std::endl;
if (!candidates.empty()) {
ss << std::endl;
ss << candidates.size() << " candidate operator" << (candidates.size() > 1 ? "s:" : ":")
<< std::endl;
for (auto& candidate : candidates) {
ss << " ";
PrintOverload(ss, *candidate.overload, intrinsic_name);
ss << std::endl;
}
}
builder.Diagnostics().add_error(diag::System::Resolver, ss.str(), source);
return {};
}
const IntrinsicPrototype Impl::Match(const char* intrinsic_name,
uint32_t intrinsic_index,
const OverloadInfo& overload, const OverloadInfo& overload,
const std::vector<const sem::Type*>& args, const std::vector<const sem::Type*>& args,
int& match_score) { int& match_score) {
@ -994,7 +1102,7 @@ const sem::Builtin* Impl::Match(sem::BuiltinType builtin_type,
} }
if (!overload_matched) { if (!overload_matched) {
return nullptr; return {};
} }
// Build the return type // Build the return type
@ -1004,34 +1112,22 @@ const sem::Builtin* Impl::Match(sem::BuiltinType builtin_type,
return_type = Match(closed, overload, indices).Type(&any); return_type = Match(closed, overload, indices).Type(&any);
if (!return_type) { if (!return_type) {
std::stringstream ss; std::stringstream ss;
PrintOverload(ss, overload, builtin_type); PrintOverload(ss, overload, intrinsic_name);
TINT_ICE(Resolver, builder.Diagnostics()) TINT_ICE(Resolver, builder.Diagnostics())
<< "MatchState.Match() returned null for " << ss.str(); << "MatchState.Match() returned null for " << ss.str();
return nullptr; return {};
} }
} else { } else {
return_type = builder.create<sem::Void>(); return_type = builder.create<sem::Void>();
} }
IntrinsicPrototype builtin; IntrinsicPrototype builtin;
builtin.type = builtin_type; builtin.index = intrinsic_index;
builtin.return_type = return_type; builtin.return_type = return_type;
builtin.parameters = std::move(parameters); builtin.parameters = std::move(parameters);
builtin.supported_stages = overload.supported_stages; builtin.supported_stages = overload.supported_stages;
builtin.is_deprecated = overload.is_deprecated; builtin.is_deprecated = overload.is_deprecated;
return builtin;
// De-duplicate builtins that are identical.
return utils::GetOrCreate(builtins, builtin, [&] {
std::vector<sem::Parameter*> params;
params.reserve(builtin.parameters.size());
for (auto& p : builtin.parameters) {
params.emplace_back(builder.create<sem::Parameter>(
nullptr, static_cast<uint32_t>(params.size()), p.type, ast::StorageClass::kNone,
ast::Access::kUndefined, p.usage));
}
return builder.create<sem::Builtin>(builtin.type, builtin.return_type, std::move(params),
builtin.supported_stages, builtin.is_deprecated);
});
} }
MatchState Impl::Match(ClosedState& closed, MatchState Impl::Match(ClosedState& closed,
@ -1040,12 +1136,10 @@ MatchState Impl::Match(ClosedState& closed,
return MatchState(builder, closed, matchers, overload, matcher_indices); return MatchState(builder, closed, matchers, overload, matcher_indices);
} }
void Impl::PrintOverload(std::ostream& ss, void Impl::PrintOverload(std::ostream& ss, const OverloadInfo& overload, const char* name) const {
const OverloadInfo& overload,
sem::BuiltinType builtin_type) const {
ClosedState closed(builder); ClosedState closed(builder);
ss << builtin_type << "("; ss << name << "(";
for (uint32_t p = 0; p < overload.num_parameters; p++) { for (uint32_t p = 0; p < overload.num_parameters; p++) {
auto& parameter = overload.parameters[p]; auto& parameter = overload.parameters[p];
if (p > 0) { if (p > 0) {

View File

@ -28,7 +28,7 @@ class ProgramBuilder;
namespace tint { namespace tint {
/// IntrinsicTable is a lookup table of all the WGSL builtin functions /// IntrinsicTable is a lookup table of all the WGSL builtin functions and intrinsic operators
class IntrinsicTable { class IntrinsicTable {
public: public:
/// @param builder the program builder /// @param builder the program builder
@ -38,8 +38,18 @@ class IntrinsicTable {
/// Destructor /// Destructor
virtual ~IntrinsicTable(); virtual ~IntrinsicTable();
/// Lookup looks for the builtin overload with the given signature, raising /// BinaryOperator describes a resolved binary operator
/// an error diagnostic if the builtin was not found. struct BinaryOperator {
/// The result type of the binary operator
const sem::Type* result;
/// The type of LHS of the binary operator
const sem::Type* lhs;
/// The type of RHS of the binary operator
const sem::Type* rhs;
};
/// Lookup looks for the builtin overload with the given signature, raising an error diagnostic
/// if the builtin was not found.
/// @param type the builtin type /// @param type the builtin type
/// @param args the argument types passed to the builtin function /// @param args the argument types passed to the builtin function
/// @param source the source of the builtin call /// @param source the source of the builtin call
@ -47,6 +57,21 @@ class IntrinsicTable {
virtual const sem::Builtin* Lookup(sem::BuiltinType type, virtual const sem::Builtin* Lookup(sem::BuiltinType type,
const std::vector<const sem::Type*>& args, const std::vector<const sem::Type*>& args,
const Source& source) = 0; const Source& source) = 0;
/// Lookup looks for the binary op overload with the given signature, raising an error
/// diagnostic if the operator was not found.
/// @param op the binary operator
/// @param lhs the LHS value type passed to the operator
/// @param rhs the RHS value type passed to the operator
/// @param source the source of the operator call
/// @param is_compound true if the binary operator is being used as a compound assignment
/// @return the operator call target signature. If the operator was not found
/// BinaryOperator::result will be nullptr.
virtual BinaryOperator Lookup(ast::BinaryOp op,
const sem::Type* lhs,
const sem::Type* rhs,
const Source& source,
bool is_compound) = 0;
}; };
} // namespace tint } // namespace tint

File diff suppressed because it is too large Load Diff

View File

@ -119,6 +119,23 @@ constexpr IntrinsicInfo kBuiltins[] = {
{{- end }} {{- end }}
}; };
constexpr IntrinsicInfo kOperators[] = {
{{- range $i, $o := .Operators }}
{
/* [{{$i}}] */
{{- range $o.OverloadDescriptions }}
/* {{.}} */
{{- end }}
/* num overloads */ {{$o.NumOverloads}},
/* overloads */ &kOverloads[{{$o.OverloadsOffset}}],
},
{{- end }}
};
{{- range $i, $o := .Operators }}
constexpr uint8_t kOperator{{template "OperatorName" $o.Name}} = {{$i}};
{{- end }}
// clang-format on // clang-format on
{{ end -}} {{ end -}}
@ -399,3 +416,29 @@ Matchers::~Matchers() = default;
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}
{{- /* ------------------------------------------------------------------ */ -}}
{{- define "OperatorName" -}}
{{- /* ------------------------------------------------------------------ */ -}}
{{- if eq . "<<" -}}ShiftLeft
{{- else if eq . "&" -}}And
{{- else if eq . "|" -}}Or
{{- else if eq . "^" -}}Xor
{{- else if eq . "&&" -}}LogicalAnd
{{- else if eq . "||" -}}LogicalOr
{{- else if eq . "==" -}}Equal
{{- else if eq . "!=" -}}NotEqual
{{- else if eq . "<" -}}LessThan
{{- else if eq . ">" -}}GreaterThan
{{- else if eq . "<=" -}}LessThanEqual
{{- else if eq . ">=" -}}GreaterThanEqual
{{- else if eq . "<<" -}}ShiftLeft
{{- else if eq . ">>" -}}ShiftRight
{{- else if eq . "+" -}}Plus
{{- else if eq . "-" -}}Minus
{{- else if eq . "*" -}}Star
{{- else if eq . "/" -}}Divide
{{- else if eq . "%" -}}Modulo
{{- else -}}<unknown-{{.}}>
{{- end -}}
{{- end -}}

View File

@ -576,5 +576,69 @@ TEST_F(IntrinsicTableTest, SameOverloadReturnsSameBuiltinPointer) {
EXPECT_NE(b, c); EXPECT_NE(b, c);
} }
TEST_F(IntrinsicTableTest, MatchBinaryOp) {
auto* i32 = create<sem::I32>();
auto* vec3_i32 = create<sem::Vector>(i32, 3u);
auto result = table->Lookup(ast::BinaryOp::kMultiply, i32, vec3_i32, Source{{12, 34}},
/* is_compound */ false);
EXPECT_EQ(result.result, vec3_i32);
EXPECT_EQ(result.lhs, i32);
EXPECT_EQ(result.rhs, vec3_i32);
EXPECT_EQ(Diagnostics().str(), "");
}
TEST_F(IntrinsicTableTest, MismatchBinaryOp) {
auto* f32 = create<sem::F32>();
auto* bool_ = create<sem::Bool>();
auto result = table->Lookup(ast::BinaryOp::kMultiply, f32, bool_, Source{{12, 34}},
/* is_compound */ false);
ASSERT_EQ(result.result, nullptr);
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator * (f32, bool)
9 candidate operators:
operator * (T, T) -> T where: T is f32, i32 or u32
operator * (vecN<T>, T) -> vecN<T> where: T is f32, i32 or u32
operator * (T, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
operator * (f32, matNxM<f32>) -> matNxM<f32>
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
operator * (matNxM<f32>, f32) -> matNxM<f32>
operator * (matCxR<f32>, vecC<f32>) -> vecR<f32>
operator * (vecR<f32>, matCxR<f32>) -> vecC<f32>
operator * (matKxR<f32>, matCxK<f32>) -> matCxR<f32>
)");
}
TEST_F(IntrinsicTableTest, MatchCompoundOp) {
auto* i32 = create<sem::I32>();
auto* vec3_i32 = create<sem::Vector>(i32, 3u);
auto result = table->Lookup(ast::BinaryOp::kMultiply, i32, vec3_i32, Source{{12, 34}},
/* is_compound */ true);
EXPECT_EQ(result.result, vec3_i32);
EXPECT_EQ(result.lhs, i32);
EXPECT_EQ(result.rhs, vec3_i32);
EXPECT_EQ(Diagnostics().str(), "");
}
TEST_F(IntrinsicTableTest, MismatchCompoundOp) {
auto* f32 = create<sem::F32>();
auto* bool_ = create<sem::Bool>();
auto result = table->Lookup(ast::BinaryOp::kMultiply, f32, bool_, Source{{12, 34}},
/* is_compound */ true);
ASSERT_EQ(result.result, nullptr);
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator *= (f32, bool)
9 candidate operators:
operator *= (T, T) -> T where: T is f32, i32 or u32
operator *= (vecN<T>, T) -> vecN<T> where: T is f32, i32 or u32
operator *= (T, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
operator *= (f32, matNxM<f32>) -> matNxM<f32>
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
operator *= (matNxM<f32>, f32) -> matNxM<f32>
operator *= (matCxR<f32>, vecC<f32>) -> vecR<f32>
operator *= (vecR<f32>, matCxR<f32>) -> vecC<f32>
operator *= (matKxR<f32>, matCxK<f32>) -> matCxR<f32>
)");
}
} // namespace } // namespace
} // namespace tint } // namespace tint

View File

@ -1746,12 +1746,8 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
auto* lhs_ty = lhs->Type()->UnwrapRef(); auto* lhs_ty = lhs->Type()->UnwrapRef();
auto* rhs_ty = rhs->Type()->UnwrapRef(); auto* rhs_ty = rhs->Type()->UnwrapRef();
auto* ty = BinaryOpType(lhs_ty, rhs_ty, expr->op); auto* ty = intrinsic_table_->Lookup(expr->op, lhs_ty, rhs_ty, expr->source, false).result;
if (!ty) { if (!ty) {
AddError("Binary expression operand types are invalid for this operation: " +
sem_.TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
sem_.TypeNameOf(rhs_ty),
expr->source);
return nullptr; return nullptr;
} }
@ -1764,160 +1760,6 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
return sem; return sem;
} }
const sem::Type* Resolver::BinaryOpType(const sem::Type* lhs_ty,
const sem::Type* rhs_ty,
ast::BinaryOp op) {
using Bool = sem::Bool;
using F32 = sem::F32;
using I32 = sem::I32;
using U32 = sem::U32;
using Matrix = sem::Matrix;
using Vector = sem::Vector;
auto* lhs_vec = lhs_ty->As<Vector>();
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
auto* rhs_vec = rhs_ty->As<Vector>();
auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type &&
(lhs_vec_elem_type == rhs_vec_elem_type) &&
(lhs_vec->Width() == rhs_vec->Width());
const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty);
// Binary logical expressions
if (op == ast::BinaryOp::kLogicalAnd || op == ast::BinaryOp::kLogicalOr) {
if (matching_types && lhs_ty->Is<Bool>()) {
return lhs_ty;
}
}
if (op == ast::BinaryOp::kOr || op == ast::BinaryOp::kAnd) {
if (matching_types && lhs_ty->Is<Bool>()) {
return lhs_ty;
}
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
return lhs_ty;
}
}
// Arithmetic expressions
if (ast::IsArithmetic(op)) {
// Binary arithmetic expressions over scalars
if (matching_types && lhs_ty->is_numeric_scalar()) {
return lhs_ty;
}
// Binary arithmetic expressions over vectors
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->is_numeric_scalar()) {
return lhs_ty;
}
// Binary arithmetic expressions with mixed scalar and vector operands
if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty) && rhs_ty->is_numeric_scalar()) {
return lhs_ty;
}
if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty) && lhs_ty->is_numeric_scalar()) {
return rhs_ty;
}
}
// Matrix arithmetic
auto* lhs_mat = lhs_ty->As<Matrix>();
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
auto* rhs_mat = rhs_ty->As<Matrix>();
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
// Addition and subtraction of float matrices
if ((op == ast::BinaryOp::kAdd || op == ast::BinaryOp::kSubtract) && lhs_mat_elem_type &&
lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->columns()) && (lhs_mat->rows() == rhs_mat->rows())) {
return rhs_ty;
}
if (op == ast::BinaryOp::kMultiply) {
// Multiplication of a matrix and a scalar
if (lhs_ty->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
return rhs_ty;
}
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_ty->Is<F32>()) {
return lhs_ty;
}
// Vector times matrix
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>() && (lhs_vec->Width() == rhs_mat->rows())) {
return builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns());
}
// Matrix times vector
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_vec_elem_type &&
rhs_vec_elem_type->Is<F32>() && (lhs_mat->columns() == rhs_vec->Width())) {
return builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows());
}
// Matrix times matrix
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>() && (lhs_mat->columns() == rhs_mat->rows())) {
return builder_->create<sem::Matrix>(
builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()),
rhs_mat->columns());
}
}
// Comparison expressions
if (ast::IsComparison(op)) {
if (matching_types) {
// Special case for bools: only == and !=
if (lhs_ty->Is<Bool>() &&
(op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
return builder_->create<sem::Bool>();
}
// For the rest, we can compare i32, u32, and f32
if (lhs_ty->IsAnyOf<I32, U32, F32>()) {
return builder_->create<sem::Bool>();
}
}
// Same for vectors
if (matching_vec_elem_types) {
if (lhs_vec_elem_type->Is<Bool>() &&
(op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
lhs_vec->Width());
}
if (lhs_vec_elem_type->is_numeric_scalar()) {
return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
lhs_vec->Width());
}
}
}
// Binary bitwise operations
if (ast::IsBitwise(op)) {
if (matching_types && lhs_ty->is_integer_scalar_or_vector()) {
return lhs_ty;
}
}
// Bit shift expressions
if (ast::IsBitshift(op)) {
// Type validation rules are the same for left or right shift, despite
// differences in computation rules (i.e. right shift can be arithmetic or
// logical depending on lhs type).
if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) {
return lhs_ty;
}
if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() && rhs_vec_elem_type &&
rhs_vec_elem_type->Is<U32>()) {
return lhs_ty;
}
}
return nullptr;
}
sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
auto* expr = sem_.Get(unary->expr); auto* expr = sem_.Get(unary->expr);
auto* expr_ty = expr->Type(); auto* expr_ty = expr->Type();
@ -2472,11 +2314,8 @@ sem::Statement* Resolver::CompoundAssignmentStatement(
auto* lhs_ty = lhs->Type()->UnwrapRef(); auto* lhs_ty = lhs->Type()->UnwrapRef();
auto* rhs_ty = rhs->Type()->UnwrapRef(); auto* rhs_ty = rhs->Type()->UnwrapRef();
auto* ty = BinaryOpType(lhs_ty, rhs_ty, stmt->op); auto* ty = intrinsic_table_->Lookup(stmt->op, lhs_ty, rhs_ty, stmt->source, true).result;
if (!ty) { if (!ty) {
AddError("compound assignment operand types are invalid: " + sem_.TypeNameOf(lhs_ty) +
" " + FriendlyName(stmt->op) + " " + sem_.TypeNameOf(rhs_ty),
stmt->source);
return false; return false;
} }
return validator_.Assignment(stmt, ty); return validator_.Assignment(stmt, ty);

View File

@ -226,11 +226,6 @@ class Resolver {
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*); sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
bool Statements(const ast::StatementList&); bool Statements(const ast::StatementList&);
// Resolve the result type of a binary operator.
// Returns nullptr if the types are not valid for this operator.
const sem::Type* BinaryOpType(const sem::Type* lhs_ty,
const sem::Type* rhs_ty,
ast::BinaryOp op);
/// Resolves the WorkgroupSize for the given function, assigning it to /// Resolves the WorkgroupSize for the given function, assigning it to
/// current_function_ /// current_function_

View File

@ -1570,11 +1570,7 @@ TEST_P(Expr_Binary_Test_Invalid, All) {
WrapInFunction(expr); WrapInFunction(expr);
ASSERT_FALSE(r()->Resolve()); ASSERT_FALSE(r()->Resolve());
ASSERT_EQ(r()->error(), EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching overload for operator "));
"12:34 error: Binary expression operand types are invalid for "
"this operation: " +
FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op) + " " +
FriendlyName(rhs_type));
} }
INSTANTIATE_TEST_SUITE_P(ResolverTest, INSTANTIATE_TEST_SUITE_P(ResolverTest,
Expr_Binary_Test_Invalid, Expr_Binary_Test_Invalid,
@ -1618,11 +1614,7 @@ TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply, All) {
ASSERT_TRUE(TypeOf(expr) == result_type); ASSERT_TRUE(TypeOf(expr) == result_type);
} else { } else {
ASSERT_FALSE(r()->Resolve()); ASSERT_FALSE(r()->Resolve());
ASSERT_EQ(r()->error(), EXPECT_THAT(r()->error(), HasSubstr("no matching overload for operator *"));
"12:34 error: Binary expression operand types are invalid for "
"this operation: " +
FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op) + " " +
FriendlyName(rhs_type));
} }
} }
auto all_dimension_values = testing::Values(2u, 3u, 4u); auto all_dimension_values = testing::Values(2u, 3u, 4u);
@ -1660,11 +1652,7 @@ TEST_P(Expr_Binary_Test_Invalid_MatrixMatrixMultiply, All) {
ASSERT_TRUE(TypeOf(expr) == result_type); ASSERT_TRUE(TypeOf(expr) == result_type);
} else { } else {
ASSERT_FALSE(r()->Resolve()); ASSERT_FALSE(r()->Resolve());
ASSERT_EQ(r()->error(), EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching overload for operator * "));
"12:34 error: Binary expression operand types are invalid for "
"this operation: " +
FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op) + " " +
FriendlyName(rhs_type));
} }
} }
INSTANTIATE_TEST_SUITE_P(ResolverTest, INSTANTIATE_TEST_SUITE_P(ResolverTest,