tint: Add binary-ops to the intrinsic table
• Declare all the binary ops in the intrinsics.def file. • Reimplement Resolver::BinaryOpType() with the IntrinsicTable. This will simplify maintenance of the operators, and will greatly simplify the [AbstractInt -> i32|u32] [AbstractFloat -> f32|f16] logic. Bug: tint:1504 Change-Id: Ie028602e05b59916c3f2168c92f200f10e402b96 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/89027 Reviewed-by: James Price <jrprice@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
bcdb6e9da8
commit
9fb29a364e
|
@ -558,3 +558,91 @@ fn textureLoad(texture: texture_external, coords: vec2<i32>) -> vec4<f32>
|
||||||
[[stage("fragment", "compute")]] fn atomicXor<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T) -> T
|
[[stage("fragment", "compute")]] fn atomicXor<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T) -> T
|
||||||
[[stage("fragment", "compute")]] fn atomicExchange<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T) -> T
|
[[stage("fragment", "compute")]] fn atomicExchange<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T) -> T
|
||||||
[[stage("fragment", "compute")]] fn atomicCompareExchangeWeak<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T, T) -> vec2<T>
|
[[stage("fragment", "compute")]] fn atomicCompareExchangeWeak<T: iu32, S: workgroup_or_storage>(ptr<S, atomic<T>, read_write>, T, T) -> vec2<T>
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Operators //
|
||||||
|
// //
|
||||||
|
// The operator declarations below declare all the unary and binary operators //
|
||||||
|
// supported by the WGSL language (with exception for address-of and //
|
||||||
|
// dereference unary operators). //
|
||||||
|
// //
|
||||||
|
// The syntax is almost identical to builtin functions, except we use 'op' //
|
||||||
|
// instead of 'fn'. The resolving rules are identical to builtins, which is //
|
||||||
|
// described in detail above. //
|
||||||
|
// //
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Binary Operators //
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
op + <T: fiu32>(T, T) -> T
|
||||||
|
op + <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||||
|
op + <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||||
|
op + <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||||
|
op + <N: num, M: num> (mat<N, M, f32>, mat<N, M, f32>) -> mat<N, M, f32>
|
||||||
|
|
||||||
|
op - <T: fiu32>(T, T) -> T
|
||||||
|
op - <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||||
|
op - <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||||
|
op - <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||||
|
op - <N: num, M: num> (mat<N, M, f32>, mat<N, M, f32>) -> mat<N, M, f32>
|
||||||
|
|
||||||
|
op * <T: fiu32>(T, T) -> T
|
||||||
|
op * <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||||
|
op * <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||||
|
op * <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||||
|
op * <N: num, M: num> (f32, mat<N, M, f32>) -> mat<N, M, f32>
|
||||||
|
op * <N: num, M: num> (mat<N, M, f32>, f32) -> mat<N, M, f32>
|
||||||
|
op * <C: num, R: num> (mat<C, R, f32>, vec<C, f32>) -> vec<R, f32>
|
||||||
|
op * <C: num, R: num> (vec<R, f32>, mat<C, R, f32>) -> vec<C, f32>
|
||||||
|
op * <K: num, C: num, R: num> (mat<K, R, f32>, mat<C, K, f32>) -> mat<C, R, f32>
|
||||||
|
|
||||||
|
op / <T: fiu32>(T, T) -> T
|
||||||
|
op / <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||||
|
op / <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||||
|
op / <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||||
|
|
||||||
|
op % <T: fiu32>(T, T) -> T
|
||||||
|
op % <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||||
|
op % <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||||
|
op % <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||||
|
|
||||||
|
op ^ <T: iu32>(T, T) -> T
|
||||||
|
op ^ <T: iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||||
|
|
||||||
|
op & (bool, bool) -> bool
|
||||||
|
op & <N: num> (vec<N, bool>, vec<N, bool>) -> vec<N, bool>
|
||||||
|
op & <T: iu32>(T, T) -> T
|
||||||
|
op & <T: iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||||
|
|
||||||
|
op | (bool, bool) -> bool
|
||||||
|
op | <N: num> (vec<N, bool>, vec<N, bool>) -> vec<N, bool>
|
||||||
|
op | <T: iu32>(T, T) -> T
|
||||||
|
op | <T: iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||||
|
|
||||||
|
op && (bool, bool) -> bool
|
||||||
|
op || (bool, bool) -> bool
|
||||||
|
|
||||||
|
op == <T: scalar>(T, T) -> bool
|
||||||
|
op == <T: scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||||
|
|
||||||
|
op != <T: scalar>(T, T) -> bool
|
||||||
|
op != <T: scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||||
|
|
||||||
|
op < <T: fiu32>(T, T) -> bool
|
||||||
|
op < <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||||
|
|
||||||
|
op > <T: fiu32>(T, T) -> bool
|
||||||
|
op > <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||||
|
|
||||||
|
op <= <T: fiu32>(T, T) -> bool
|
||||||
|
op <= <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||||
|
|
||||||
|
op >= <T: fiu32>(T, T) -> bool
|
||||||
|
op >= <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||||
|
|
||||||
|
op << <T: iu32>(T, u32) -> T
|
||||||
|
op << <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
|
||||||
|
|
||||||
|
op >> <T: iu32>(T, u32) -> T
|
||||||
|
op >> <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
|
||||||
|
|
|
@ -18,6 +18,8 @@
|
||||||
#include "src/tint/resolver/resolver_test_helper.h"
|
#include "src/tint/resolver/resolver_test_helper.h"
|
||||||
#include "src/tint/sem/storage_texture.h"
|
#include "src/tint/sem/storage_texture.h"
|
||||||
|
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
|
||||||
using namespace tint::number_suffixes; // NOLINT
|
using namespace tint::number_suffixes; // NOLINT
|
||||||
|
|
||||||
namespace tint::resolver {
|
namespace tint::resolver {
|
||||||
|
@ -71,9 +73,8 @@ TEST_F(ResolverCompoundAssignmentValidationTest, IncompatibleTypes) {
|
||||||
|
|
||||||
ASSERT_FALSE(r()->Resolve());
|
ASSERT_FALSE(r()->Resolve());
|
||||||
|
|
||||||
EXPECT_EQ(r()->error(),
|
EXPECT_THAT(r()->error(),
|
||||||
"12:34 error: compound assignment operand types are invalid: i32 "
|
HasSubstr("12:34 error: no matching overload for operator += (i32, f32)"));
|
||||||
"add f32");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverCompoundAssignmentValidationTest, IncompatibleOp) {
|
TEST_F(ResolverCompoundAssignmentValidationTest, IncompatibleOp) {
|
||||||
|
@ -89,8 +90,8 @@ TEST_F(ResolverCompoundAssignmentValidationTest, IncompatibleOp) {
|
||||||
|
|
||||||
ASSERT_FALSE(r()->Resolve());
|
ASSERT_FALSE(r()->Resolve());
|
||||||
|
|
||||||
EXPECT_EQ(r()->error(),
|
EXPECT_THAT(r()->error(),
|
||||||
"12:34 error: compound assignment operand types are invalid: f32 or f32");
|
HasSubstr("12:34 error: no matching overload for operator |= (f32, f32)"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverCompoundAssignmentValidationTest, VectorScalar_Pass) {
|
TEST_F(ResolverCompoundAssignmentValidationTest, VectorScalar_Pass) {
|
||||||
|
@ -180,9 +181,9 @@ TEST_F(ResolverCompoundAssignmentValidationTest, VectorMatrix_ColumnMismatch) {
|
||||||
|
|
||||||
ASSERT_FALSE(r()->Resolve());
|
ASSERT_FALSE(r()->Resolve());
|
||||||
|
|
||||||
EXPECT_EQ(r()->error(),
|
EXPECT_THAT(
|
||||||
"12:34 error: compound assignment operand types are invalid: "
|
r()->error(),
|
||||||
"vec4<f32> multiply mat4x2<f32>");
|
HasSubstr("12:34 error: no matching overload for operator *= (vec4<f32>, mat4x2<f32>)"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverCompoundAssignmentValidationTest, VectorMatrix_ResultMismatch) {
|
TEST_F(ResolverCompoundAssignmentValidationTest, VectorMatrix_ResultMismatch) {
|
||||||
|
@ -223,9 +224,8 @@ TEST_F(ResolverCompoundAssignmentValidationTest, Phony) {
|
||||||
// }
|
// }
|
||||||
WrapInFunction(CompoundAssign(Source{{56, 78}}, Phony(), 1_i, ast::BinaryOp::kAdd));
|
WrapInFunction(CompoundAssign(Source{{56, 78}}, Phony(), 1_i, ast::BinaryOp::kAdd));
|
||||||
EXPECT_FALSE(r()->Resolve());
|
EXPECT_FALSE(r()->Resolve());
|
||||||
EXPECT_EQ(r()->error(),
|
EXPECT_THAT(r()->error(),
|
||||||
"56:78 error: compound assignment operand types are invalid: void "
|
HasSubstr("56:78 error: no matching overload for operator += (void, i32)"));
|
||||||
"add i32");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverCompoundAssignmentValidationTest, ReadOnlyBuffer) {
|
TEST_F(ResolverCompoundAssignmentValidationTest, ReadOnlyBuffer) {
|
||||||
|
@ -239,8 +239,7 @@ TEST_F(ResolverCompoundAssignmentValidationTest, ReadOnlyBuffer) {
|
||||||
|
|
||||||
EXPECT_FALSE(r()->Resolve());
|
EXPECT_FALSE(r()->Resolve());
|
||||||
EXPECT_EQ(r()->error(),
|
EXPECT_EQ(r()->error(),
|
||||||
"56:78 error: cannot store into a read-only type 'ref<storage, "
|
"56:78 error: cannot store into a read-only type 'ref<storage, i32, read>'");
|
||||||
"i32, read>'");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverCompoundAssignmentValidationTest, LhsConstant) {
|
TEST_F(ResolverCompoundAssignmentValidationTest, LhsConstant) {
|
||||||
|
@ -269,9 +268,9 @@ TEST_F(ResolverCompoundAssignmentValidationTest, LhsAtomic) {
|
||||||
WrapInFunction(CompoundAssign(Source{{56, 78}}, "a", "a", ast::BinaryOp::kAdd));
|
WrapInFunction(CompoundAssign(Source{{56, 78}}, "a", "a", ast::BinaryOp::kAdd));
|
||||||
|
|
||||||
EXPECT_FALSE(r()->Resolve());
|
EXPECT_FALSE(r()->Resolve());
|
||||||
EXPECT_EQ(r()->error(),
|
EXPECT_THAT(
|
||||||
"56:78 error: compound assignment operand types are invalid: "
|
r()->error(),
|
||||||
"atomic<i32> add atomic<i32>");
|
HasSubstr("error: no matching overload for operator += (atomic<i32>, atomic<i32>)"));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -762,7 +762,7 @@ struct OverloadInfo {
|
||||||
bool is_deprecated;
|
bool is_deprecated;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// IntrinsicInfo describes a builtin function
|
/// IntrinsicInfo describes a builtin function or operator overload
|
||||||
struct IntrinsicInfo {
|
struct IntrinsicInfo {
|
||||||
/// Number of overloads of the intrinsic
|
/// Number of overloads of the intrinsic
|
||||||
const uint8_t num_overloads;
|
const uint8_t num_overloads;
|
||||||
|
@ -791,11 +791,11 @@ struct IntrinsicPrototype {
|
||||||
for (auto& p : i.parameters) {
|
for (auto& p : i.parameters) {
|
||||||
utils::HashCombine(&hash, p.type, p.usage);
|
utils::HashCombine(&hash, p.type, p.usage);
|
||||||
}
|
}
|
||||||
return utils::Hash(hash, i.type, i.return_type, i.supported_stages, i.is_deprecated);
|
return utils::Hash(hash, i.index, i.return_type, i.supported_stages, i.is_deprecated);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
sem::BuiltinType type = sem::BuiltinType::kNone;
|
uint32_t index = 0; // Index of the intrinsic (builtin or operator)
|
||||||
std::vector<Parameter> parameters;
|
std::vector<Parameter> parameters;
|
||||||
sem::Type const* return_type = nullptr;
|
sem::Type const* return_type = nullptr;
|
||||||
PipelineStageSet supported_stages;
|
PipelineStageSet supported_stages;
|
||||||
|
@ -804,7 +804,7 @@ struct IntrinsicPrototype {
|
||||||
|
|
||||||
/// Equality operator for IntrinsicPrototype
|
/// Equality operator for IntrinsicPrototype
|
||||||
bool operator==(const IntrinsicPrototype& a, const IntrinsicPrototype& b) {
|
bool operator==(const IntrinsicPrototype& a, const IntrinsicPrototype& b) {
|
||||||
if (a.type != b.type || a.supported_stages != b.supported_stages ||
|
if (a.index != b.index || a.supported_stages != b.supported_stages ||
|
||||||
a.return_type != b.return_type || a.is_deprecated != b.is_deprecated ||
|
a.return_type != b.return_type || a.is_deprecated != b.is_deprecated ||
|
||||||
a.parameters.size() != b.parameters.size()) {
|
a.parameters.size() != b.parameters.size()) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -828,8 +828,22 @@ class Impl : public IntrinsicTable {
|
||||||
const std::vector<const sem::Type*>& args,
|
const std::vector<const sem::Type*>& args,
|
||||||
const Source& source) override;
|
const Source& source) override;
|
||||||
|
|
||||||
|
BinaryOperator Lookup(ast::BinaryOp op,
|
||||||
|
const sem::Type* lhs,
|
||||||
|
const sem::Type* rhs,
|
||||||
|
const Source& source,
|
||||||
|
bool is_compound) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const sem::Builtin* Match(sem::BuiltinType builtin_type,
|
// Candidate holds information about a mismatched overload that could be what the user intended
|
||||||
|
// to call.
|
||||||
|
struct Candidate {
|
||||||
|
const OverloadInfo* overload;
|
||||||
|
int score;
|
||||||
|
};
|
||||||
|
|
||||||
|
const IntrinsicPrototype Match(const char* intrinsic_name,
|
||||||
|
uint32_t intrinsic_index,
|
||||||
const OverloadInfo& overload,
|
const OverloadInfo& overload,
|
||||||
const std::vector<const sem::Type*>& args,
|
const std::vector<const sem::Type*>& args,
|
||||||
int& match_score);
|
int& match_score);
|
||||||
|
@ -838,9 +852,7 @@ class Impl : public IntrinsicTable {
|
||||||
const OverloadInfo& overload,
|
const OverloadInfo& overload,
|
||||||
MatcherIndex const* matcher_indices) const;
|
MatcherIndex const* matcher_indices) const;
|
||||||
|
|
||||||
void PrintOverload(std::ostream& ss,
|
void PrintOverload(std::ostream& ss, const OverloadInfo& overload, const char* name) const;
|
||||||
const OverloadInfo& overload,
|
|
||||||
sem::BuiltinType builtin_type) const;
|
|
||||||
|
|
||||||
ProgramBuilder& builder;
|
ProgramBuilder& builder;
|
||||||
Matchers matchers;
|
Matchers matchers;
|
||||||
|
@ -850,10 +862,10 @@ class Impl : public IntrinsicTable {
|
||||||
/// @return a string representing a call to a builtin with the given argument
|
/// @return a string representing a call to a builtin with the given argument
|
||||||
/// types.
|
/// types.
|
||||||
std::string CallSignature(ProgramBuilder& builder,
|
std::string CallSignature(ProgramBuilder& builder,
|
||||||
sem::BuiltinType builtin_type,
|
const char* intrinsic_name,
|
||||||
const std::vector<const sem::Type*>& args) {
|
const std::vector<const sem::Type*>& args) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << sem::str(builtin_type) << "(";
|
ss << intrinsic_name << "(";
|
||||||
{
|
{
|
||||||
bool first = true;
|
bool first = true;
|
||||||
for (auto* arg : args) {
|
for (auto* arg : args) {
|
||||||
|
@ -882,22 +894,30 @@ Impl::Impl(ProgramBuilder& b) : builder(b) {}
|
||||||
const sem::Builtin* Impl::Lookup(sem::BuiltinType builtin_type,
|
const sem::Builtin* Impl::Lookup(sem::BuiltinType builtin_type,
|
||||||
const std::vector<const sem::Type*>& args,
|
const std::vector<const sem::Type*>& args,
|
||||||
const Source& source) {
|
const Source& source) {
|
||||||
// Candidate holds information about a mismatched overload that could be what
|
|
||||||
// the user intended to call.
|
|
||||||
struct Candidate {
|
|
||||||
const OverloadInfo* overload;
|
|
||||||
int score;
|
|
||||||
};
|
|
||||||
|
|
||||||
// The list of failed matches that had promise.
|
// The list of failed matches that had promise.
|
||||||
std::vector<Candidate> candidates;
|
std::vector<Candidate> candidates;
|
||||||
|
|
||||||
auto& builtin = kBuiltins[static_cast<uint32_t>(builtin_type)];
|
uint32_t intrinsic_index = static_cast<uint32_t>(builtin_type);
|
||||||
|
const char* intrinsic_name = sem::str(builtin_type);
|
||||||
|
auto& builtin = kBuiltins[intrinsic_index];
|
||||||
for (uint32_t o = 0; o < builtin.num_overloads; o++) {
|
for (uint32_t o = 0; o < builtin.num_overloads; o++) {
|
||||||
int match_score = 1000;
|
int match_score = 1000;
|
||||||
auto& overload = builtin.overloads[o];
|
auto& overload = builtin.overloads[o];
|
||||||
if (auto* match = Match(builtin_type, overload, args, match_score)) {
|
auto match = Match(intrinsic_name, intrinsic_index, overload, args, match_score);
|
||||||
return match;
|
if (match.return_type) {
|
||||||
|
// De-duplicate builtins that are identical.
|
||||||
|
return utils::GetOrCreate(builtins, match, [&] {
|
||||||
|
std::vector<sem::Parameter*> params;
|
||||||
|
params.reserve(match.parameters.size());
|
||||||
|
for (auto& p : match.parameters) {
|
||||||
|
params.emplace_back(builder.create<sem::Parameter>(
|
||||||
|
nullptr, static_cast<uint32_t>(params.size()), p.type,
|
||||||
|
ast::StorageClass::kNone, ast::Access::kUndefined, p.usage));
|
||||||
|
}
|
||||||
|
return builder.create<sem::Builtin>(builtin_type, match.return_type,
|
||||||
|
std::move(params), match.supported_stages,
|
||||||
|
match.is_deprecated);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
if (match_score > 0) {
|
if (match_score > 0) {
|
||||||
candidates.emplace_back(Candidate{&overload, match_score});
|
candidates.emplace_back(Candidate{&overload, match_score});
|
||||||
|
@ -910,14 +930,14 @@ const sem::Builtin* Impl::Lookup(sem::BuiltinType builtin_type,
|
||||||
|
|
||||||
// Generate an error message
|
// Generate an error message
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "no matching call to " << CallSignature(builder, builtin_type, args) << std::endl;
|
ss << "no matching call to " << CallSignature(builder, intrinsic_name, args) << std::endl;
|
||||||
if (!candidates.empty()) {
|
if (!candidates.empty()) {
|
||||||
ss << std::endl;
|
ss << std::endl;
|
||||||
ss << candidates.size() << " candidate function" << (candidates.size() > 1 ? "s:" : ":")
|
ss << candidates.size() << " candidate function" << (candidates.size() > 1 ? "s:" : ":")
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
for (auto& candidate : candidates) {
|
for (auto& candidate : candidates) {
|
||||||
ss << " ";
|
ss << " ";
|
||||||
PrintOverload(ss, *candidate.overload, builtin_type);
|
PrintOverload(ss, *candidate.overload, intrinsic_name);
|
||||||
ss << std::endl;
|
ss << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -925,7 +945,95 @@ const sem::Builtin* Impl::Lookup(sem::BuiltinType builtin_type,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
const sem::Builtin* Impl::Match(sem::BuiltinType builtin_type,
|
IntrinsicTable::BinaryOperator Impl::Lookup(ast::BinaryOp op,
|
||||||
|
const sem::Type* lhs,
|
||||||
|
const sem::Type* rhs,
|
||||||
|
const Source& source,
|
||||||
|
bool is_compound) {
|
||||||
|
// The list of failed matches that had promise.
|
||||||
|
std::vector<Candidate> candidates;
|
||||||
|
|
||||||
|
auto [intrinsic_index, intrinsic_name] = [&]() -> std::pair<uint32_t, const char*> {
|
||||||
|
switch (op) {
|
||||||
|
case ast::BinaryOp::kAnd:
|
||||||
|
return {kOperatorAnd, is_compound ? "operator &= " : "operator & "};
|
||||||
|
case ast::BinaryOp::kOr:
|
||||||
|
return {kOperatorOr, is_compound ? "operator |= " : "operator | "};
|
||||||
|
case ast::BinaryOp::kXor:
|
||||||
|
return {kOperatorXor, is_compound ? "operator ^= " : "operator ^ "};
|
||||||
|
case ast::BinaryOp::kLogicalAnd:
|
||||||
|
return {kOperatorLogicalAnd, "operator && "};
|
||||||
|
case ast::BinaryOp::kLogicalOr:
|
||||||
|
return {kOperatorLogicalOr, "operator || "};
|
||||||
|
case ast::BinaryOp::kEqual:
|
||||||
|
return {kOperatorEqual, "operator == "};
|
||||||
|
case ast::BinaryOp::kNotEqual:
|
||||||
|
return {kOperatorNotEqual, "operator != "};
|
||||||
|
case ast::BinaryOp::kLessThan:
|
||||||
|
return {kOperatorLessThan, "operator < "};
|
||||||
|
case ast::BinaryOp::kGreaterThan:
|
||||||
|
return {kOperatorGreaterThan, "operator > "};
|
||||||
|
case ast::BinaryOp::kLessThanEqual:
|
||||||
|
return {kOperatorLessThanEqual, "operator <= "};
|
||||||
|
case ast::BinaryOp::kGreaterThanEqual:
|
||||||
|
return {kOperatorGreaterThanEqual, "operator >= "};
|
||||||
|
case ast::BinaryOp::kShiftLeft:
|
||||||
|
return {kOperatorShiftLeft, is_compound ? "operator <<= " : "operator << "};
|
||||||
|
case ast::BinaryOp::kShiftRight:
|
||||||
|
return {kOperatorShiftRight, is_compound ? "operator >>= " : "operator >> "};
|
||||||
|
case ast::BinaryOp::kAdd:
|
||||||
|
return {kOperatorPlus, is_compound ? "operator += " : "operator + "};
|
||||||
|
case ast::BinaryOp::kSubtract:
|
||||||
|
return {kOperatorMinus, is_compound ? "operator -= " : "operator - "};
|
||||||
|
case ast::BinaryOp::kMultiply:
|
||||||
|
return {kOperatorStar, is_compound ? "operator *= " : "operator * "};
|
||||||
|
case ast::BinaryOp::kDivide:
|
||||||
|
return {kOperatorDivide, is_compound ? "operator /= " : "operator / "};
|
||||||
|
case ast::BinaryOp::kModulo:
|
||||||
|
return {kOperatorModulo, is_compound ? "operator %= " : "operator % "};
|
||||||
|
default:
|
||||||
|
return {0, "<unknown>"};
|
||||||
|
}
|
||||||
|
}();
|
||||||
|
|
||||||
|
auto& builtin = kOperators[intrinsic_index];
|
||||||
|
for (uint32_t o = 0; o < builtin.num_overloads; o++) {
|
||||||
|
int match_score = 1000;
|
||||||
|
auto& overload = builtin.overloads[o];
|
||||||
|
auto match = Match(intrinsic_name, intrinsic_index, overload, {lhs, rhs}, match_score);
|
||||||
|
if (match.return_type) {
|
||||||
|
return BinaryOperator{match.return_type, match.parameters[0].type,
|
||||||
|
match.parameters[1].type};
|
||||||
|
}
|
||||||
|
if (match_score > 0) {
|
||||||
|
candidates.emplace_back(Candidate{&overload, match_score});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort the candidates with the most promising first
|
||||||
|
std::stable_sort(candidates.begin(), candidates.end(),
|
||||||
|
[](const Candidate& a, const Candidate& b) { return a.score > b.score; });
|
||||||
|
|
||||||
|
// Generate an error message
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "no matching overload for " << CallSignature(builder, intrinsic_name, {lhs, rhs})
|
||||||
|
<< std::endl;
|
||||||
|
if (!candidates.empty()) {
|
||||||
|
ss << std::endl;
|
||||||
|
ss << candidates.size() << " candidate operator" << (candidates.size() > 1 ? "s:" : ":")
|
||||||
|
<< std::endl;
|
||||||
|
for (auto& candidate : candidates) {
|
||||||
|
ss << " ";
|
||||||
|
PrintOverload(ss, *candidate.overload, intrinsic_name);
|
||||||
|
ss << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
builder.Diagnostics().add_error(diag::System::Resolver, ss.str(), source);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
const IntrinsicPrototype Impl::Match(const char* intrinsic_name,
|
||||||
|
uint32_t intrinsic_index,
|
||||||
const OverloadInfo& overload,
|
const OverloadInfo& overload,
|
||||||
const std::vector<const sem::Type*>& args,
|
const std::vector<const sem::Type*>& args,
|
||||||
int& match_score) {
|
int& match_score) {
|
||||||
|
@ -994,7 +1102,7 @@ const sem::Builtin* Impl::Match(sem::BuiltinType builtin_type,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!overload_matched) {
|
if (!overload_matched) {
|
||||||
return nullptr;
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build the return type
|
// Build the return type
|
||||||
|
@ -1004,34 +1112,22 @@ const sem::Builtin* Impl::Match(sem::BuiltinType builtin_type,
|
||||||
return_type = Match(closed, overload, indices).Type(&any);
|
return_type = Match(closed, overload, indices).Type(&any);
|
||||||
if (!return_type) {
|
if (!return_type) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
PrintOverload(ss, overload, builtin_type);
|
PrintOverload(ss, overload, intrinsic_name);
|
||||||
TINT_ICE(Resolver, builder.Diagnostics())
|
TINT_ICE(Resolver, builder.Diagnostics())
|
||||||
<< "MatchState.Match() returned null for " << ss.str();
|
<< "MatchState.Match() returned null for " << ss.str();
|
||||||
return nullptr;
|
return {};
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return_type = builder.create<sem::Void>();
|
return_type = builder.create<sem::Void>();
|
||||||
}
|
}
|
||||||
|
|
||||||
IntrinsicPrototype builtin;
|
IntrinsicPrototype builtin;
|
||||||
builtin.type = builtin_type;
|
builtin.index = intrinsic_index;
|
||||||
builtin.return_type = return_type;
|
builtin.return_type = return_type;
|
||||||
builtin.parameters = std::move(parameters);
|
builtin.parameters = std::move(parameters);
|
||||||
builtin.supported_stages = overload.supported_stages;
|
builtin.supported_stages = overload.supported_stages;
|
||||||
builtin.is_deprecated = overload.is_deprecated;
|
builtin.is_deprecated = overload.is_deprecated;
|
||||||
|
return builtin;
|
||||||
// De-duplicate builtins that are identical.
|
|
||||||
return utils::GetOrCreate(builtins, builtin, [&] {
|
|
||||||
std::vector<sem::Parameter*> params;
|
|
||||||
params.reserve(builtin.parameters.size());
|
|
||||||
for (auto& p : builtin.parameters) {
|
|
||||||
params.emplace_back(builder.create<sem::Parameter>(
|
|
||||||
nullptr, static_cast<uint32_t>(params.size()), p.type, ast::StorageClass::kNone,
|
|
||||||
ast::Access::kUndefined, p.usage));
|
|
||||||
}
|
|
||||||
return builder.create<sem::Builtin>(builtin.type, builtin.return_type, std::move(params),
|
|
||||||
builtin.supported_stages, builtin.is_deprecated);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MatchState Impl::Match(ClosedState& closed,
|
MatchState Impl::Match(ClosedState& closed,
|
||||||
|
@ -1040,12 +1136,10 @@ MatchState Impl::Match(ClosedState& closed,
|
||||||
return MatchState(builder, closed, matchers, overload, matcher_indices);
|
return MatchState(builder, closed, matchers, overload, matcher_indices);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Impl::PrintOverload(std::ostream& ss,
|
void Impl::PrintOverload(std::ostream& ss, const OverloadInfo& overload, const char* name) const {
|
||||||
const OverloadInfo& overload,
|
|
||||||
sem::BuiltinType builtin_type) const {
|
|
||||||
ClosedState closed(builder);
|
ClosedState closed(builder);
|
||||||
|
|
||||||
ss << builtin_type << "(";
|
ss << name << "(";
|
||||||
for (uint32_t p = 0; p < overload.num_parameters; p++) {
|
for (uint32_t p = 0; p < overload.num_parameters; p++) {
|
||||||
auto& parameter = overload.parameters[p];
|
auto& parameter = overload.parameters[p];
|
||||||
if (p > 0) {
|
if (p > 0) {
|
||||||
|
|
|
@ -28,7 +28,7 @@ class ProgramBuilder;
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
|
|
||||||
/// IntrinsicTable is a lookup table of all the WGSL builtin functions
|
/// IntrinsicTable is a lookup table of all the WGSL builtin functions and intrinsic operators
|
||||||
class IntrinsicTable {
|
class IntrinsicTable {
|
||||||
public:
|
public:
|
||||||
/// @param builder the program builder
|
/// @param builder the program builder
|
||||||
|
@ -38,8 +38,18 @@ class IntrinsicTable {
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~IntrinsicTable();
|
virtual ~IntrinsicTable();
|
||||||
|
|
||||||
/// Lookup looks for the builtin overload with the given signature, raising
|
/// BinaryOperator describes a resolved binary operator
|
||||||
/// an error diagnostic if the builtin was not found.
|
struct BinaryOperator {
|
||||||
|
/// The result type of the binary operator
|
||||||
|
const sem::Type* result;
|
||||||
|
/// The type of LHS of the binary operator
|
||||||
|
const sem::Type* lhs;
|
||||||
|
/// The type of RHS of the binary operator
|
||||||
|
const sem::Type* rhs;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Lookup looks for the builtin overload with the given signature, raising an error diagnostic
|
||||||
|
/// if the builtin was not found.
|
||||||
/// @param type the builtin type
|
/// @param type the builtin type
|
||||||
/// @param args the argument types passed to the builtin function
|
/// @param args the argument types passed to the builtin function
|
||||||
/// @param source the source of the builtin call
|
/// @param source the source of the builtin call
|
||||||
|
@ -47,6 +57,21 @@ class IntrinsicTable {
|
||||||
virtual const sem::Builtin* Lookup(sem::BuiltinType type,
|
virtual const sem::Builtin* Lookup(sem::BuiltinType type,
|
||||||
const std::vector<const sem::Type*>& args,
|
const std::vector<const sem::Type*>& args,
|
||||||
const Source& source) = 0;
|
const Source& source) = 0;
|
||||||
|
|
||||||
|
/// Lookup looks for the binary op overload with the given signature, raising an error
|
||||||
|
/// diagnostic if the operator was not found.
|
||||||
|
/// @param op the binary operator
|
||||||
|
/// @param lhs the LHS value type passed to the operator
|
||||||
|
/// @param rhs the RHS value type passed to the operator
|
||||||
|
/// @param source the source of the operator call
|
||||||
|
/// @param is_compound true if the binary operator is being used as a compound assignment
|
||||||
|
/// @return the operator call target signature. If the operator was not found
|
||||||
|
/// BinaryOperator::result will be nullptr.
|
||||||
|
virtual BinaryOperator Lookup(ast::BinaryOp op,
|
||||||
|
const sem::Type* lhs,
|
||||||
|
const sem::Type* rhs,
|
||||||
|
const Source& source,
|
||||||
|
bool is_compound) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tint
|
} // namespace tint
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -119,6 +119,23 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||||
{{- end }}
|
{{- end }}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
constexpr IntrinsicInfo kOperators[] = {
|
||||||
|
{{- range $i, $o := .Operators }}
|
||||||
|
{
|
||||||
|
/* [{{$i}}] */
|
||||||
|
{{- range $o.OverloadDescriptions }}
|
||||||
|
/* {{.}} */
|
||||||
|
{{- end }}
|
||||||
|
/* num overloads */ {{$o.NumOverloads}},
|
||||||
|
/* overloads */ &kOverloads[{{$o.OverloadsOffset}}],
|
||||||
|
},
|
||||||
|
{{- end }}
|
||||||
|
};
|
||||||
|
|
||||||
|
{{- range $i, $o := .Operators }}
|
||||||
|
constexpr uint8_t kOperator{{template "OperatorName" $o.Name}} = {{$i}};
|
||||||
|
{{- end }}
|
||||||
|
|
||||||
// clang-format on
|
// clang-format on
|
||||||
{{ end -}}
|
{{ end -}}
|
||||||
|
|
||||||
|
@ -399,3 +416,29 @@ Matchers::~Matchers() = default;
|
||||||
{{- end -}}
|
{{- end -}}
|
||||||
{{- end -}}
|
{{- end -}}
|
||||||
{{- end -}}
|
{{- end -}}
|
||||||
|
|
||||||
|
{{- /* ------------------------------------------------------------------ */ -}}
|
||||||
|
{{- define "OperatorName" -}}
|
||||||
|
{{- /* ------------------------------------------------------------------ */ -}}
|
||||||
|
{{- if eq . "<<" -}}ShiftLeft
|
||||||
|
{{- else if eq . "&" -}}And
|
||||||
|
{{- else if eq . "|" -}}Or
|
||||||
|
{{- else if eq . "^" -}}Xor
|
||||||
|
{{- else if eq . "&&" -}}LogicalAnd
|
||||||
|
{{- else if eq . "||" -}}LogicalOr
|
||||||
|
{{- else if eq . "==" -}}Equal
|
||||||
|
{{- else if eq . "!=" -}}NotEqual
|
||||||
|
{{- else if eq . "<" -}}LessThan
|
||||||
|
{{- else if eq . ">" -}}GreaterThan
|
||||||
|
{{- else if eq . "<=" -}}LessThanEqual
|
||||||
|
{{- else if eq . ">=" -}}GreaterThanEqual
|
||||||
|
{{- else if eq . "<<" -}}ShiftLeft
|
||||||
|
{{- else if eq . ">>" -}}ShiftRight
|
||||||
|
{{- else if eq . "+" -}}Plus
|
||||||
|
{{- else if eq . "-" -}}Minus
|
||||||
|
{{- else if eq . "*" -}}Star
|
||||||
|
{{- else if eq . "/" -}}Divide
|
||||||
|
{{- else if eq . "%" -}}Modulo
|
||||||
|
{{- else -}}<unknown-{{.}}>
|
||||||
|
{{- end -}}
|
||||||
|
{{- end -}}
|
||||||
|
|
|
@ -576,5 +576,69 @@ TEST_F(IntrinsicTableTest, SameOverloadReturnsSameBuiltinPointer) {
|
||||||
EXPECT_NE(b, c);
|
EXPECT_NE(b, c);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(IntrinsicTableTest, MatchBinaryOp) {
|
||||||
|
auto* i32 = create<sem::I32>();
|
||||||
|
auto* vec3_i32 = create<sem::Vector>(i32, 3u);
|
||||||
|
auto result = table->Lookup(ast::BinaryOp::kMultiply, i32, vec3_i32, Source{{12, 34}},
|
||||||
|
/* is_compound */ false);
|
||||||
|
EXPECT_EQ(result.result, vec3_i32);
|
||||||
|
EXPECT_EQ(result.lhs, i32);
|
||||||
|
EXPECT_EQ(result.rhs, vec3_i32);
|
||||||
|
EXPECT_EQ(Diagnostics().str(), "");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(IntrinsicTableTest, MismatchBinaryOp) {
|
||||||
|
auto* f32 = create<sem::F32>();
|
||||||
|
auto* bool_ = create<sem::Bool>();
|
||||||
|
auto result = table->Lookup(ast::BinaryOp::kMultiply, f32, bool_, Source{{12, 34}},
|
||||||
|
/* is_compound */ false);
|
||||||
|
ASSERT_EQ(result.result, nullptr);
|
||||||
|
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator * (f32, bool)
|
||||||
|
|
||||||
|
9 candidate operators:
|
||||||
|
operator * (T, T) -> T where: T is f32, i32 or u32
|
||||||
|
operator * (vecN<T>, T) -> vecN<T> where: T is f32, i32 or u32
|
||||||
|
operator * (T, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
|
||||||
|
operator * (f32, matNxM<f32>) -> matNxM<f32>
|
||||||
|
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
|
||||||
|
operator * (matNxM<f32>, f32) -> matNxM<f32>
|
||||||
|
operator * (matCxR<f32>, vecC<f32>) -> vecR<f32>
|
||||||
|
operator * (vecR<f32>, matCxR<f32>) -> vecC<f32>
|
||||||
|
operator * (matKxR<f32>, matCxK<f32>) -> matCxR<f32>
|
||||||
|
)");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(IntrinsicTableTest, MatchCompoundOp) {
|
||||||
|
auto* i32 = create<sem::I32>();
|
||||||
|
auto* vec3_i32 = create<sem::Vector>(i32, 3u);
|
||||||
|
auto result = table->Lookup(ast::BinaryOp::kMultiply, i32, vec3_i32, Source{{12, 34}},
|
||||||
|
/* is_compound */ true);
|
||||||
|
EXPECT_EQ(result.result, vec3_i32);
|
||||||
|
EXPECT_EQ(result.lhs, i32);
|
||||||
|
EXPECT_EQ(result.rhs, vec3_i32);
|
||||||
|
EXPECT_EQ(Diagnostics().str(), "");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(IntrinsicTableTest, MismatchCompoundOp) {
|
||||||
|
auto* f32 = create<sem::F32>();
|
||||||
|
auto* bool_ = create<sem::Bool>();
|
||||||
|
auto result = table->Lookup(ast::BinaryOp::kMultiply, f32, bool_, Source{{12, 34}},
|
||||||
|
/* is_compound */ true);
|
||||||
|
ASSERT_EQ(result.result, nullptr);
|
||||||
|
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator *= (f32, bool)
|
||||||
|
|
||||||
|
9 candidate operators:
|
||||||
|
operator *= (T, T) -> T where: T is f32, i32 or u32
|
||||||
|
operator *= (vecN<T>, T) -> vecN<T> where: T is f32, i32 or u32
|
||||||
|
operator *= (T, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
|
||||||
|
operator *= (f32, matNxM<f32>) -> matNxM<f32>
|
||||||
|
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
|
||||||
|
operator *= (matNxM<f32>, f32) -> matNxM<f32>
|
||||||
|
operator *= (matCxR<f32>, vecC<f32>) -> vecR<f32>
|
||||||
|
operator *= (vecR<f32>, matCxR<f32>) -> vecC<f32>
|
||||||
|
operator *= (matKxR<f32>, matCxK<f32>) -> matCxR<f32>
|
||||||
|
)");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tint
|
} // namespace tint
|
||||||
|
|
|
@ -1746,12 +1746,8 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
|
||||||
auto* lhs_ty = lhs->Type()->UnwrapRef();
|
auto* lhs_ty = lhs->Type()->UnwrapRef();
|
||||||
auto* rhs_ty = rhs->Type()->UnwrapRef();
|
auto* rhs_ty = rhs->Type()->UnwrapRef();
|
||||||
|
|
||||||
auto* ty = BinaryOpType(lhs_ty, rhs_ty, expr->op);
|
auto* ty = intrinsic_table_->Lookup(expr->op, lhs_ty, rhs_ty, expr->source, false).result;
|
||||||
if (!ty) {
|
if (!ty) {
|
||||||
AddError("Binary expression operand types are invalid for this operation: " +
|
|
||||||
sem_.TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
|
|
||||||
sem_.TypeNameOf(rhs_ty),
|
|
||||||
expr->source);
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1764,160 +1760,6 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
|
||||||
return sem;
|
return sem;
|
||||||
}
|
}
|
||||||
|
|
||||||
const sem::Type* Resolver::BinaryOpType(const sem::Type* lhs_ty,
|
|
||||||
const sem::Type* rhs_ty,
|
|
||||||
ast::BinaryOp op) {
|
|
||||||
using Bool = sem::Bool;
|
|
||||||
using F32 = sem::F32;
|
|
||||||
using I32 = sem::I32;
|
|
||||||
using U32 = sem::U32;
|
|
||||||
using Matrix = sem::Matrix;
|
|
||||||
using Vector = sem::Vector;
|
|
||||||
|
|
||||||
auto* lhs_vec = lhs_ty->As<Vector>();
|
|
||||||
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
|
|
||||||
auto* rhs_vec = rhs_ty->As<Vector>();
|
|
||||||
auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
|
|
||||||
|
|
||||||
const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type &&
|
|
||||||
(lhs_vec_elem_type == rhs_vec_elem_type) &&
|
|
||||||
(lhs_vec->Width() == rhs_vec->Width());
|
|
||||||
|
|
||||||
const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty);
|
|
||||||
|
|
||||||
// Binary logical expressions
|
|
||||||
if (op == ast::BinaryOp::kLogicalAnd || op == ast::BinaryOp::kLogicalOr) {
|
|
||||||
if (matching_types && lhs_ty->Is<Bool>()) {
|
|
||||||
return lhs_ty;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (op == ast::BinaryOp::kOr || op == ast::BinaryOp::kAnd) {
|
|
||||||
if (matching_types && lhs_ty->Is<Bool>()) {
|
|
||||||
return lhs_ty;
|
|
||||||
}
|
|
||||||
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
|
|
||||||
return lhs_ty;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Arithmetic expressions
|
|
||||||
if (ast::IsArithmetic(op)) {
|
|
||||||
// Binary arithmetic expressions over scalars
|
|
||||||
if (matching_types && lhs_ty->is_numeric_scalar()) {
|
|
||||||
return lhs_ty;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Binary arithmetic expressions over vectors
|
|
||||||
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->is_numeric_scalar()) {
|
|
||||||
return lhs_ty;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Binary arithmetic expressions with mixed scalar and vector operands
|
|
||||||
if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty) && rhs_ty->is_numeric_scalar()) {
|
|
||||||
return lhs_ty;
|
|
||||||
}
|
|
||||||
if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty) && lhs_ty->is_numeric_scalar()) {
|
|
||||||
return rhs_ty;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Matrix arithmetic
|
|
||||||
auto* lhs_mat = lhs_ty->As<Matrix>();
|
|
||||||
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
|
|
||||||
auto* rhs_mat = rhs_ty->As<Matrix>();
|
|
||||||
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
|
|
||||||
// Addition and subtraction of float matrices
|
|
||||||
if ((op == ast::BinaryOp::kAdd || op == ast::BinaryOp::kSubtract) && lhs_mat_elem_type &&
|
|
||||||
lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
|
|
||||||
(lhs_mat->columns() == rhs_mat->columns()) && (lhs_mat->rows() == rhs_mat->rows())) {
|
|
||||||
return rhs_ty;
|
|
||||||
}
|
|
||||||
if (op == ast::BinaryOp::kMultiply) {
|
|
||||||
// Multiplication of a matrix and a scalar
|
|
||||||
if (lhs_ty->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
|
|
||||||
return rhs_ty;
|
|
||||||
}
|
|
||||||
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_ty->Is<F32>()) {
|
|
||||||
return lhs_ty;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Vector times matrix
|
|
||||||
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() && rhs_mat_elem_type &&
|
|
||||||
rhs_mat_elem_type->Is<F32>() && (lhs_vec->Width() == rhs_mat->rows())) {
|
|
||||||
return builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Matrix times vector
|
|
||||||
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_vec_elem_type &&
|
|
||||||
rhs_vec_elem_type->Is<F32>() && (lhs_mat->columns() == rhs_vec->Width())) {
|
|
||||||
return builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Matrix times matrix
|
|
||||||
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
|
|
||||||
rhs_mat_elem_type->Is<F32>() && (lhs_mat->columns() == rhs_mat->rows())) {
|
|
||||||
return builder_->create<sem::Matrix>(
|
|
||||||
builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()),
|
|
||||||
rhs_mat->columns());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Comparison expressions
|
|
||||||
if (ast::IsComparison(op)) {
|
|
||||||
if (matching_types) {
|
|
||||||
// Special case for bools: only == and !=
|
|
||||||
if (lhs_ty->Is<Bool>() &&
|
|
||||||
(op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
|
|
||||||
return builder_->create<sem::Bool>();
|
|
||||||
}
|
|
||||||
|
|
||||||
// For the rest, we can compare i32, u32, and f32
|
|
||||||
if (lhs_ty->IsAnyOf<I32, U32, F32>()) {
|
|
||||||
return builder_->create<sem::Bool>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Same for vectors
|
|
||||||
if (matching_vec_elem_types) {
|
|
||||||
if (lhs_vec_elem_type->Is<Bool>() &&
|
|
||||||
(op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
|
|
||||||
return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
|
|
||||||
lhs_vec->Width());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (lhs_vec_elem_type->is_numeric_scalar()) {
|
|
||||||
return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
|
|
||||||
lhs_vec->Width());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Binary bitwise operations
|
|
||||||
if (ast::IsBitwise(op)) {
|
|
||||||
if (matching_types && lhs_ty->is_integer_scalar_or_vector()) {
|
|
||||||
return lhs_ty;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bit shift expressions
|
|
||||||
if (ast::IsBitshift(op)) {
|
|
||||||
// Type validation rules are the same for left or right shift, despite
|
|
||||||
// differences in computation rules (i.e. right shift can be arithmetic or
|
|
||||||
// logical depending on lhs type).
|
|
||||||
|
|
||||||
if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) {
|
|
||||||
return lhs_ty;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() && rhs_vec_elem_type &&
|
|
||||||
rhs_vec_elem_type->Is<U32>()) {
|
|
||||||
return lhs_ty;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
|
sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
|
||||||
auto* expr = sem_.Get(unary->expr);
|
auto* expr = sem_.Get(unary->expr);
|
||||||
auto* expr_ty = expr->Type();
|
auto* expr_ty = expr->Type();
|
||||||
|
@ -2472,11 +2314,8 @@ sem::Statement* Resolver::CompoundAssignmentStatement(
|
||||||
|
|
||||||
auto* lhs_ty = lhs->Type()->UnwrapRef();
|
auto* lhs_ty = lhs->Type()->UnwrapRef();
|
||||||
auto* rhs_ty = rhs->Type()->UnwrapRef();
|
auto* rhs_ty = rhs->Type()->UnwrapRef();
|
||||||
auto* ty = BinaryOpType(lhs_ty, rhs_ty, stmt->op);
|
auto* ty = intrinsic_table_->Lookup(stmt->op, lhs_ty, rhs_ty, stmt->source, true).result;
|
||||||
if (!ty) {
|
if (!ty) {
|
||||||
AddError("compound assignment operand types are invalid: " + sem_.TypeNameOf(lhs_ty) +
|
|
||||||
" " + FriendlyName(stmt->op) + " " + sem_.TypeNameOf(rhs_ty),
|
|
||||||
stmt->source);
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return validator_.Assignment(stmt, ty);
|
return validator_.Assignment(stmt, ty);
|
||||||
|
|
|
@ -226,11 +226,6 @@ class Resolver {
|
||||||
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
|
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
|
||||||
bool Statements(const ast::StatementList&);
|
bool Statements(const ast::StatementList&);
|
||||||
|
|
||||||
// Resolve the result type of a binary operator.
|
|
||||||
// Returns nullptr if the types are not valid for this operator.
|
|
||||||
const sem::Type* BinaryOpType(const sem::Type* lhs_ty,
|
|
||||||
const sem::Type* rhs_ty,
|
|
||||||
ast::BinaryOp op);
|
|
||||||
|
|
||||||
/// Resolves the WorkgroupSize for the given function, assigning it to
|
/// Resolves the WorkgroupSize for the given function, assigning it to
|
||||||
/// current_function_
|
/// current_function_
|
||||||
|
|
|
@ -1570,11 +1570,7 @@ TEST_P(Expr_Binary_Test_Invalid, All) {
|
||||||
WrapInFunction(expr);
|
WrapInFunction(expr);
|
||||||
|
|
||||||
ASSERT_FALSE(r()->Resolve());
|
ASSERT_FALSE(r()->Resolve());
|
||||||
ASSERT_EQ(r()->error(),
|
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching overload for operator "));
|
||||||
"12:34 error: Binary expression operand types are invalid for "
|
|
||||||
"this operation: " +
|
|
||||||
FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op) + " " +
|
|
||||||
FriendlyName(rhs_type));
|
|
||||||
}
|
}
|
||||||
INSTANTIATE_TEST_SUITE_P(ResolverTest,
|
INSTANTIATE_TEST_SUITE_P(ResolverTest,
|
||||||
Expr_Binary_Test_Invalid,
|
Expr_Binary_Test_Invalid,
|
||||||
|
@ -1618,11 +1614,7 @@ TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply, All) {
|
||||||
ASSERT_TRUE(TypeOf(expr) == result_type);
|
ASSERT_TRUE(TypeOf(expr) == result_type);
|
||||||
} else {
|
} else {
|
||||||
ASSERT_FALSE(r()->Resolve());
|
ASSERT_FALSE(r()->Resolve());
|
||||||
ASSERT_EQ(r()->error(),
|
EXPECT_THAT(r()->error(), HasSubstr("no matching overload for operator *"));
|
||||||
"12:34 error: Binary expression operand types are invalid for "
|
|
||||||
"this operation: " +
|
|
||||||
FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op) + " " +
|
|
||||||
FriendlyName(rhs_type));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto all_dimension_values = testing::Values(2u, 3u, 4u);
|
auto all_dimension_values = testing::Values(2u, 3u, 4u);
|
||||||
|
@ -1660,11 +1652,7 @@ TEST_P(Expr_Binary_Test_Invalid_MatrixMatrixMultiply, All) {
|
||||||
ASSERT_TRUE(TypeOf(expr) == result_type);
|
ASSERT_TRUE(TypeOf(expr) == result_type);
|
||||||
} else {
|
} else {
|
||||||
ASSERT_FALSE(r()->Resolve());
|
ASSERT_FALSE(r()->Resolve());
|
||||||
ASSERT_EQ(r()->error(),
|
EXPECT_THAT(r()->error(), HasSubstr("12:34 error: no matching overload for operator * "));
|
||||||
"12:34 error: Binary expression operand types are invalid for "
|
|
||||||
"this operation: " +
|
|
||||||
FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op) + " " +
|
|
||||||
FriendlyName(rhs_type));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
INSTANTIATE_TEST_SUITE_P(ResolverTest,
|
INSTANTIATE_TEST_SUITE_P(ResolverTest,
|
||||||
|
|
Loading…
Reference in New Issue