tint: Fix const eval of type conversions

These were quite spectacularly broken.

Also:
* Fix the definition of 'scalar' in `intrinsics.def`. This was in part why conversions were broken, as abstracts were materialized before reaching the converter builtin when they shouldn't have been.
* Implement `ScalarArgsFrom()` helper in `const_eval_test.cc`. This is used by the new conversion tests, and also implements part of the suggestion to improve tint:1709.

Fixed: tint:1707
Bug: tint:1709
Change-Id: Iab962b671305e868f92710912d2ed07e3338c680
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105261
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
This commit is contained in:
Ben Clayton 2022-10-11 20:36:48 +00:00 committed by Dawn LUCI CQ
parent feb447d9dc
commit 75bc93c0df
8 changed files with 737 additions and 406 deletions

View File

@ -10,6 +10,10 @@
* The `external_texture` overload of `textureSampleLevel()` has been deprecated. Use `textureSampleBaseClampToEdge()` instead. [tint:1671](crbug.com/tint/1671) * The `external_texture` overload of `textureSampleLevel()` has been deprecated. Use `textureSampleBaseClampToEdge()` instead. [tint:1671](crbug.com/tint/1671)
### Fixes
* Constant evaluation of type conversions where the value exceeds the limits of the target type have been fixed. [tint:1707](crbug.com/tint/1707)
## Changes for M107 ## Changes for M107
### New features ### New features

View File

@ -175,13 +175,13 @@ type __atomic_compare_exchange_result<T>
// A type matcher that can match one or more types. // // A type matcher that can match one or more types. //
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
match abstract_or_scalar: ia | fa | f32 | f16 | i32 | u32 | bool match scalar: ia | fa | f32 | f16 | i32 | u32 | bool
match scalar: f32 | f16 | i32 | u32 | bool match concrete_scalar: f32 | f16 | i32 | u32 | bool
match scalar_no_f32: i32 | f16 | u32 | bool match scalar_no_f32: ia | fa | i32 | f16 | u32 | bool
match scalar_no_f16: f32 | i32 | u32 | bool match scalar_no_f16: ia | fa | f32 | i32 | u32 | bool
match scalar_no_i32: f32 | f16 | u32 | bool match scalar_no_i32: ia | fa | f32 | f16 | u32 | bool
match scalar_no_u32: f32 | f16 | i32 | bool match scalar_no_u32: ia | fa | f32 | f16 | i32 | bool
match scalar_no_bool: f32 | f16 | i32 | u32 match scalar_no_bool: ia | fa | f32 | f16 | i32 | u32
match fia_fiu32_f16: fa | ia | f32 | i32 | u32 | f16 match fia_fiu32_f16: fa | ia | f32 | i32 | u32 | f16
match fia_fi32_f16: fa | ia | f32 | i32 | f16 match fia_fi32_f16: fa | ia | f32 | i32 | f16
match fia_fiu32: fa | ia | f32 | i32 | u32 match fia_fiu32: fa | ia | f32 | i32 | u32
@ -524,9 +524,9 @@ fn round<T: f32_f16>(T) -> T
fn round<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T> fn round<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn saturate<T: f32_f16>(T) -> T fn saturate<T: f32_f16>(T) -> T
fn saturate<T: f32_f16, N: num>(vec<N, T>) -> vec<N, T> fn saturate<T: f32_f16, N: num>(vec<N, T>) -> vec<N, T>
@const("select_bool") fn select<T: abstract_or_scalar>(T, T, bool) -> T @const("select_bool") fn select<T: scalar>(T, T, bool) -> T
@const("select_bool") fn select<T: abstract_or_scalar, N: num>(vec<N, T>, vec<N, T>, bool) -> vec<N, T> @const("select_bool") fn select<T: scalar, N: num>(vec<N, T>, vec<N, T>, bool) -> vec<N, T>
@const("select_boolvec") fn select<N: num, T: abstract_or_scalar>(vec<N, T>, vec<N, T>, vec<N, bool>) -> vec<N, T> @const("select_boolvec") fn select<N: num, T: scalar>(vec<N, T>, vec<N, T>, vec<N, bool>) -> vec<N, T>
fn sign<T: f32_f16>(T) -> T fn sign<T: f32_f16>(T) -> T
fn sign<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T> fn sign<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn sin<T: f32_f16>(T) -> T fn sin<T: f32_f16>(T) -> T
@ -720,9 +720,9 @@ fn textureLoad(texture: texture_external, coords: vec2<i32>) -> vec4<f32>
@const("Zero") ctor f32() -> f32 @const("Zero") ctor f32() -> f32
@const("Zero") ctor f16() -> f16 @const("Zero") ctor f16() -> f16
@const("Zero") ctor bool() -> bool @const("Zero") ctor bool() -> bool
@const("Zero") ctor vec2<T: scalar>() -> vec2<T> @const("Zero") ctor vec2<T: concrete_scalar>() -> vec2<T>
@const("Zero") ctor vec3<T: scalar>() -> vec3<T> @const("Zero") ctor vec3<T: concrete_scalar>() -> vec3<T>
@const("Zero") ctor vec4<T: scalar>() -> vec4<T> @const("Zero") ctor vec4<T: concrete_scalar>() -> vec4<T>
@const("Zero") ctor mat2x2<T: f32_f16>() -> mat2x2<T> @const("Zero") ctor mat2x2<T: f32_f16>() -> mat2x2<T>
@const("Zero") ctor mat2x3<T: f32_f16>() -> mat2x3<T> @const("Zero") ctor mat2x3<T: f32_f16>() -> mat2x3<T>
@const("Zero") ctor mat2x4<T: f32_f16>() -> mat2x4<T> @const("Zero") ctor mat2x4<T: f32_f16>() -> mat2x4<T>
@ -739,9 +739,9 @@ fn textureLoad(texture: texture_external, coords: vec2<i32>) -> vec4<f32>
@const("Identity") ctor f32(f32) -> f32 @const("Identity") ctor f32(f32) -> f32
@const("Identity") ctor f16(f16) -> f16 @const("Identity") ctor f16(f16) -> f16
@const("Identity") ctor bool(bool) -> bool @const("Identity") ctor bool(bool) -> bool
@const("Identity") ctor vec2<T: scalar>(vec2<T>) -> vec2<T> @const("Identity") ctor vec2<T: concrete_scalar>(vec2<T>) -> vec2<T>
@const("Identity") ctor vec3<T: scalar>(vec3<T>) -> vec3<T> @const("Identity") ctor vec3<T: concrete_scalar>(vec3<T>) -> vec3<T>
@const("Identity") ctor vec4<T: scalar>(vec4<T>) -> vec4<T> @const("Identity") ctor vec4<T: concrete_scalar>(vec4<T>) -> vec4<T>
@const("Identity") ctor mat2x2<T: f32_f16>(mat2x2<T>) -> mat2x2<T> @const("Identity") ctor mat2x2<T: f32_f16>(mat2x2<T>) -> mat2x2<T>
@const("Identity") ctor mat2x3<T: f32_f16>(mat2x3<T>) -> mat2x3<T> @const("Identity") ctor mat2x3<T: f32_f16>(mat2x3<T>) -> mat2x3<T>
@const("Identity") ctor mat2x4<T: f32_f16>(mat2x4<T>) -> mat2x4<T> @const("Identity") ctor mat2x4<T: f32_f16>(mat2x4<T>) -> mat2x4<T>
@ -753,24 +753,24 @@ fn textureLoad(texture: texture_external, coords: vec2<i32>) -> vec4<f32>
@const("Identity") ctor mat4x4<T: f32_f16>(mat4x4<T>) -> mat4x4<T> @const("Identity") ctor mat4x4<T: f32_f16>(mat4x4<T>) -> mat4x4<T>
// Vector constructors (splat) // Vector constructors (splat)
@const("VecSplat") ctor vec2<T: abstract_or_scalar>(T) -> vec2<T> @const("VecSplat") ctor vec2<T: scalar>(T) -> vec2<T>
@const("VecSplat") ctor vec3<T: abstract_or_scalar>(T) -> vec3<T> @const("VecSplat") ctor vec3<T: scalar>(T) -> vec3<T>
@const("VecSplat") ctor vec4<T: abstract_or_scalar>(T) -> vec4<T> @const("VecSplat") ctor vec4<T: scalar>(T) -> vec4<T>
// Vector constructors (scalar) // Vector constructors (scalar)
@const("VecCtorS") ctor vec2<T: abstract_or_scalar>(x: T, y: T) -> vec2<T> @const("VecCtorS") ctor vec2<T: scalar>(x: T, y: T) -> vec2<T>
@const("VecCtorS") ctor vec3<T: abstract_or_scalar>(x: T, y: T, z: T) -> vec3<T> @const("VecCtorS") ctor vec3<T: scalar>(x: T, y: T, z: T) -> vec3<T>
@const("VecCtorS") ctor vec4<T: abstract_or_scalar>(x: T, y: T, z: T, w: T) -> vec4<T> @const("VecCtorS") ctor vec4<T: scalar>(x: T, y: T, z: T, w: T) -> vec4<T>
// Vector constructors (mixed) // Vector constructors (mixed)
@const("VecCtorM") ctor vec3<T: abstract_or_scalar>(xy: vec2<T>, z: T) -> vec3<T> @const("VecCtorM") ctor vec3<T: scalar>(xy: vec2<T>, z: T) -> vec3<T>
@const("VecCtorM") ctor vec3<T: abstract_or_scalar>(x: T, yz: vec2<T>) -> vec3<T> @const("VecCtorM") ctor vec3<T: scalar>(x: T, yz: vec2<T>) -> vec3<T>
@const("VecCtorM") ctor vec4<T: abstract_or_scalar>(xy: vec2<T>, z: T, w: T) -> vec4<T> @const("VecCtorM") ctor vec4<T: scalar>(xy: vec2<T>, z: T, w: T) -> vec4<T>
@const("VecCtorM") ctor vec4<T: abstract_or_scalar>(x: T, yz: vec2<T>, w: T) -> vec4<T> @const("VecCtorM") ctor vec4<T: scalar>(x: T, yz: vec2<T>, w: T) -> vec4<T>
@const("VecCtorM") ctor vec4<T: abstract_or_scalar>(x: T, y: T, zw: vec2<T>) -> vec4<T> @const("VecCtorM") ctor vec4<T: scalar>(x: T, y: T, zw: vec2<T>) -> vec4<T>
@const("VecCtorM") ctor vec4<T: abstract_or_scalar>(xy: vec2<T>, zw: vec2<T>) -> vec4<T> @const("VecCtorM") ctor vec4<T: scalar>(xy: vec2<T>, zw: vec2<T>) -> vec4<T>
@const("VecCtorM") ctor vec4<T: abstract_or_scalar>(xyz: vec3<T>, w: T) -> vec4<T> @const("VecCtorM") ctor vec4<T: scalar>(xyz: vec3<T>, w: T) -> vec4<T>
@const("VecCtorM") ctor vec4<T: abstract_or_scalar>(x: T, zyw: vec3<T>) -> vec4<T> @const("VecCtorM") ctor vec4<T: scalar>(x: T, zyw: vec3<T>) -> vec4<T>
// Matrix constructors (scalar) // Matrix constructors (scalar)
@const("MatCtorS") @const("MatCtorS")
@ -952,11 +952,11 @@ op % <T: fiu32_f16, N: num> (T, vec<N, T>) -> vec<N, T>
op && (bool, bool) -> bool op && (bool, bool) -> bool
op || (bool, bool) -> bool op || (bool, bool) -> bool
@const op == <T: abstract_or_scalar>(T, T) -> bool @const op == <T: scalar>(T, T) -> bool
@const op == <T: abstract_or_scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool> @const op == <T: scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
@const op != <T: abstract_or_scalar>(T, T) -> bool @const op != <T: scalar>(T, T) -> bool
@const op != <T: abstract_or_scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool> @const op != <T: scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
@const op < <T: fia_fiu32_f16>(T, T) -> bool @const op < <T: fia_fiu32_f16>(T, T) -> bool
@const op < <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool> @const op < <T: fia_fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>

View File

@ -231,27 +231,29 @@ struct Element : ImplConstant {
return this; return this;
} }
return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult { return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult {
// `T` is the source type, `value` is the source value. // `value` is the source value.
// `FROM` is the source type.
// `TO` is the target type. // `TO` is the target type.
using TO = std::decay_t<decltype(zero_to)>; using TO = std::decay_t<decltype(zero_to)>;
using FROM = T;
if constexpr (std::is_same_v<TO, bool>) { if constexpr (std::is_same_v<TO, bool>) {
// [x -> bool] // [x -> bool]
return builder.create<Element<TO>>(target_ty, !IsPositiveZero(value)); return builder.create<Element<TO>>(target_ty, !IsPositiveZero(value));
} else if constexpr (std::is_same_v<T, bool>) { } else if constexpr (std::is_same_v<FROM, bool>) {
// [bool -> x] // [bool -> x]
return builder.create<Element<TO>>(target_ty, TO(value ? 1 : 0)); return builder.create<Element<TO>>(target_ty, TO(value ? 1 : 0));
} else if (auto conv = CheckedConvert<TO>(value)) { } else if (auto conv = CheckedConvert<TO>(value)) {
// Conversion success // Conversion success
return builder.create<Element<TO>>(target_ty, conv.Get()); return builder.create<Element<TO>>(target_ty, conv.Get());
// --- Below this point are the failure cases --- // --- Below this point are the failure cases ---
} else if constexpr (IsAbstract<T>) { } else if constexpr (IsAbstract<FROM>) {
// [abstract-numeric -> x] - materialization failure // [abstract-numeric -> x] - materialization failure
std::stringstream ss; std::stringstream ss;
ss << "value " << value << " cannot be represented as "; ss << "value " << value << " cannot be represented as ";
ss << "'" << builder.FriendlyName(target_ty) << "'"; ss << "'" << builder.FriendlyName(target_ty) << "'";
builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source); builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source);
return utils::Failure; return utils::Failure;
} else if constexpr (IsFloatingPoint<UnwrapNumber<TO>>) { } else if constexpr (IsFloatingPoint<TO>) {
// [x -> floating-point] - number not exactly representable // [x -> floating-point] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion // https://www.w3.org/TR/WGSL/#floating-point-conversion
switch (conv.Failure()) { switch (conv.Failure()) {
@ -260,8 +262,8 @@ struct Element : ImplConstant {
case ConversionFailure::kExceedsPositiveLimit: case ConversionFailure::kExceedsPositiveLimit:
return builder.create<Element<TO>>(target_ty, TO::Inf()); return builder.create<Element<TO>>(target_ty, TO::Inf());
} }
} else { } else if constexpr (IsFloatingPoint<FROM>) {
// [x -> integer] - number not exactly representable // [floating-point -> integer] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion // https://www.w3.org/TR/WGSL/#floating-point-conversion
switch (conv.Failure()) { switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit: case ConversionFailure::kExceedsNegativeLimit:
@ -269,6 +271,10 @@ struct Element : ImplConstant {
case ConversionFailure::kExceedsPositiveLimit: case ConversionFailure::kExceedsPositiveLimit:
return builder.create<Element<TO>>(target_ty, TO::Highest()); return builder.create<Element<TO>>(target_ty, TO::Highest());
} }
} else if constexpr (IsIntegral<FROM>) {
// [integer -> integer] - number not exactly representable
// Static cast
return builder.create<Element<TO>>(target_ty, static_cast<TO>(value));
} }
return nullptr; // Expression is not constant. return nullptr; // Expression is not constant.
}); });
@ -842,11 +848,7 @@ ConstEval::Result ConstEval::Conv(const sem::Type* ty,
return nullptr; // Single argument is not constant. return nullptr; // Single argument is not constant.
} }
if (auto conv = Convert(ty, args[0], source)) { return Convert(ty, args[0], source);
return conv.Get();
}
return nullptr;
} }
ConstEval::Result ConstEval::Zero(const sem::Type* ty, ConstEval::Result ConstEval::Zero(const sem::Type* ty,

View File

@ -44,6 +44,30 @@ const auto kPiOver4 = T(UnwrapNumber<T>(0.785398163397448309616));
template <typename T> template <typename T>
const auto k3PiOver4 = T(UnwrapNumber<T>(2.356194490192344928846)); const auto k3PiOver4 = T(UnwrapNumber<T>(2.356194490192344928846));
/// Walks the sem::Constant @p c, accumulating all the inner-most scalar values into @p args
void CollectScalarArgs(const sem::Constant* c, builder::ScalarArgs& args) {
Switch(
c->Type(), //
[&](const sem::Bool*) { args.values.Push(c->As<bool>()); },
[&](const sem::I32*) { args.values.Push(c->As<i32>()); },
[&](const sem::U32*) { args.values.Push(c->As<u32>()); },
[&](const sem::F32*) { args.values.Push(c->As<f32>()); },
[&](const sem::F16*) { args.values.Push(c->As<f16>()); },
[&](Default) {
size_t i = 0;
while (auto* child = c->Index(i++)) {
CollectScalarArgs(child, args);
}
});
}
/// Walks the sem::Constant @p c, returning all the inner-most scalar values.
builder::ScalarArgs ScalarArgsFrom(const sem::Constant* c) {
builder::ScalarArgs out;
CollectScalarArgs(c, out);
return out;
}
template <typename T> template <typename T>
constexpr auto Negate(const Number<T>& v) { constexpr auto Negate(const Number<T>& v) {
if constexpr (std::is_integral_v<T>) { if constexpr (std::is_integral_v<T>) {
@ -1211,283 +1235,6 @@ TEST_F(ResolverConstEvalTest, Vec3_MixConstruct_all_false) {
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<bool>(), false); EXPECT_EQ(sem->ConstantValue()->Index(2)->As<bool>(), false);
} }
TEST_F(ResolverConstEvalTest, Vec3_Convert_f32_to_i32) {
auto* expr = vec3<i32>(vec3<f32>(1.1_f, 2.2_f, 3.3_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::I32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AInt>(), 1);
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AInt>(), 2);
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AInt>(), 3);
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_u32_to_f32) {
auto* expr = vec3<f32>(vec3<u32>(10_u, 20_u, 30_u));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AFloat>(), 10.f);
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AFloat>(), 20.f);
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AFloat>(), 30.f);
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_f16_to_i32) {
Enable(ast::Extension::kF16);
auto* expr = vec3<i32>(vec3<f16>(1.1_h, 2.2_h, 3.3_h));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::I32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AInt>(), 1_i);
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AInt>(), 2_i);
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AInt>(), 3_i);
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_u32_to_f16) {
Enable(ast::Extension::kF16);
auto* expr = vec3<f16>(vec3<u32>(10_u, 20_u, 30_u));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F16>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AFloat>(), 10.f);
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AFloat>(), 20.f);
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AFloat>(), 30.f);
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_Large_f32_to_i32) {
auto* expr = vec3<i32>(vec3<f32>(1e10_f, -1e20_f, 1e30_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::I32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AInt>(), i32::Highest());
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AInt>(), i32::Lowest());
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AInt>(), i32::Highest());
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_Large_f32_to_u32) {
auto* expr = vec3<u32>(vec3<f32>(1e10_f, -1e20_f, 1e30_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::U32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AInt>(), u32::Highest());
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AInt>(), u32::Lowest());
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AInt>(), u32::Highest());
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_Large_f32_to_f16) {
Enable(ast::Extension::kF16);
auto* expr = vec3<f16>(vec3<f32>(1e10_f, -1e20_f, 1e30_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
constexpr auto kInfinity = std::numeric_limits<double>::infinity();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F16>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AFloat>(), kInfinity);
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AFloat>(), -kInfinity);
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AFloat>(), kInfinity);
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_Small_f32_to_f16) {
Enable(ast::Extension::kF16);
auto* expr = vec3<f16>(vec3<f32>(1e-20_f, -2e-30_f, 3e-40_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F16>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AFloat>(), 0.0);
EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(0)->As<AFloat>().value));
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AFloat>(), -0.0);
EXPECT_TRUE(std::signbit(sem->ConstantValue()->Index(1)->As<AFloat>().value));
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AFloat>(), 0.0);
EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(2)->As<AFloat>().value));
}
TEST_F(ResolverConstEvalTest, Mat2x3_ZeroInit_f32) { TEST_F(ResolverConstEvalTest, Mat2x3_ZeroInit_f32) {
auto* expr = mat2x3<f32>(); auto* expr = mat2x3<f32>();
WrapInFunction(expr); WrapInFunction(expr);
@ -2474,6 +2221,519 @@ TEST_F(ResolverConstEvalTest, Struct_Array_Construct) {
EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(2)->As<f32>(), 3_f); EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(2)->As<f32>(), 3_f);
} }
////////////////////////////////////////////////////////////////////////////////////////////////////
// Conversion
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace conv {
using Scalar = std::variant< //
builder::Value<AInt>,
builder::Value<AFloat>,
builder::Value<u32>,
builder::Value<i32>,
builder::Value<f32>,
builder::Value<f16>,
builder::Value<bool>>;
static std::ostream& operator<<(std::ostream& o, const Scalar& scalar) {
std::visit(
[&](auto&& v) {
using ValueType = std::decay_t<decltype(v)>;
o << ValueType::DataType::Name() << "(";
for (auto& a : v.args.values) {
o << std::get<typename ValueType::ElementType>(a);
if (&a != &v.args.values.Back()) {
o << ", ";
}
}
o << ")";
},
scalar);
return o;
}
enum class Kind {
kScalar,
kVector,
};
static std::ostream& operator<<(std::ostream& o, const Kind& k) {
switch (k) {
case Kind::kScalar:
return o << "scalar";
case Kind::kVector:
return o << "vector";
}
return o << "<unknown>";
}
struct Case {
Scalar input;
Scalar expected;
builder::CreatePtrs type;
bool unrepresentable = false;
};
static std::ostream& operator<<(std::ostream& o, const Case& c) {
if (c.unrepresentable) {
o << "[unrepresentable] input: " << c.input;
} else {
o << "input: " << c.input << ", expected: " << c.expected;
}
return o << ", type: " << c.type;
}
template <typename TO, typename FROM>
Case Success(FROM input, TO expected) {
return {builder::Val(input), builder::Val(expected), builder::CreatePtrsFor<TO>()};
}
template <typename TO, typename FROM>
Case Unrepresentable(FROM input) {
return {builder::Val(input), builder::Val(0_i), builder::CreatePtrsFor<TO>(),
/* unrepresentable */ true};
}
using ResolverConstEvalConvTest = ResolverTestWithParam<std::tuple<Kind, Case>>;
TEST_P(ResolverConstEvalConvTest, Test) {
const auto& kind = std::get<0>(GetParam());
const auto& input = std::get<1>(GetParam()).input;
const auto& expected = std::get<1>(GetParam()).expected;
const auto& type = std::get<1>(GetParam()).type;
const auto unrepresentable = std::get<1>(GetParam()).unrepresentable;
auto* input_val = std::visit([&](auto val) { return val.Expr(*this); }, input);
auto* expr = Construct(type.ast(*this), input_val);
if (kind == Kind::kVector) {
expr = Construct(ty.vec(nullptr, 3), expr);
}
WrapInFunction(expr);
auto* target_sem_ty = type.sem(*this);
if (kind == Kind::kVector) {
target_sem_ty = create<sem::Vector>(target_sem_ty, 3u);
}
if (unrepresentable) {
ASSERT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), testing::HasSubstr("cannot be represented as"));
} else {
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
EXPECT_TYPE(sem->Type(), target_sem_ty);
ASSERT_NE(sem->ConstantValue(), nullptr);
EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty);
auto expected_values = std::visit([&](auto&& val) { return val.args; }, expected);
if (kind == Kind::kVector) {
expected_values.values.Push(expected_values.values[0]);
expected_values.values.Push(expected_values.values[0]);
}
auto got_values = ScalarArgsFrom(sem->ConstantValue());
EXPECT_EQ(expected_values, got_values);
}
}
INSTANTIATE_TEST_SUITE_P(ScalarAndVector,
ResolverConstEvalConvTest,
testing::Combine(testing::Values(Kind::kScalar, Kind::kVector),
testing::ValuesIn({
// TODO(crbug.com/tint/1502): Add f16 tests
// i32 -> u32
Success(0_i, 0_u),
Success(1_i, 1_u),
Success(-1_i, 0xffffffff_u),
Success(2_i, 2_u),
Success(-2_i, 0xfffffffe_u),
// i32 -> f32
Success(0_i, 0_f),
Success(1_i, 1_f),
Success(-1_i, -1_f),
Success(2_i, 2_f),
Success(-2_i, -2_f),
// i32 -> bool
Success(0_i, false),
Success(1_i, true),
Success(-1_i, true),
Success(2_i, true),
Success(-2_i, true),
// u32 -> i32
Success(0_u, 0_i),
Success(1_u, 1_i),
Success(0xffffffff_u, -1_i),
Success(2_u, 2_i),
Success(0xfffffffe_u, -2_i),
// u32 -> f32
Success(0_u, 0_f),
Success(1_u, 1_f),
Success(2_u, 2_f),
Success(0xffffffff_u, 0xffffffff_f),
// u32 -> bool
Success(0_u, false),
Success(1_u, true),
Success(2_u, true),
Success(0xffffffff_u, true),
// f32 -> i32
Success(0_f, 0_i),
Success(1_f, 1_i),
Success(2_f, 2_i),
Success(1e20_f, i32::Highest()),
Success(-1e20_f, i32::Lowest()),
// f32 -> u32
Success(0_f, 0_i),
Success(1_f, 1_i),
Success(-1_f, u32::Lowest()),
Success(2_f, 2_i),
Success(1e20_f, u32::Highest()),
Success(-1e20_f, u32::Lowest()),
// f32 -> bool
Success(0_f, false),
Success(1_f, true),
Success(-1_f, true),
Success(2_f, true),
Success(1e20_f, true),
Success(-1e20_f, true),
// abstract-int -> i32
Success(0_a, 0_i),
Success(1_a, 1_i),
Success(-1_a, -1_i),
Success(0x7fffffff_a, i32::Highest()),
Success(-0x80000000_a, i32::Lowest()),
Unrepresentable<i32>(0x80000000_a),
// abstract-int -> u32
Success(0_a, 0_u),
Success(1_a, 1_u),
Success(0xffffffff_a, 0xffffffff_u),
Unrepresentable<u32>(0x100000000_a),
Unrepresentable<u32>(-1_a),
// abstract-int -> f32
Success(0_a, 0_f),
Success(1_a, 1_f),
Success(0xffffffff_a, 0xffffffff_f),
Success(0x100000000_a, 0x100000000_f),
Success(-0x100000000_a, -0x100000000_f),
Success(0x7fffffffffffffff_a, 0x7fffffffffffffff_f),
Success(-0x7fffffffffffffff_a, -0x7fffffffffffffff_f),
// abstract-int -> bool
Success(0_a, false),
Success(1_a, true),
Success(0xffffffff_a, true),
Success(0x100000000_a, true),
Success(-0x100000000_a, true),
Success(0x7fffffffffffffff_a, true),
Success(-0x7fffffffffffffff_a, true),
// abstract-float -> i32
Success(0.0_a, 0_i),
Success(1.0_a, 1_i),
Success(-1.0_a, -1_i),
Success(AFloat(0x7fffffff), i32::Highest()),
Success(-AFloat(0x80000000), i32::Lowest()),
Unrepresentable<i32>(0x80000000_a),
// abstract-float -> u32
Success(0.0_a, 0_u),
Success(1.0_a, 1_u),
Success(AFloat(0xffffffff), 0xffffffff_u),
Unrepresentable<u32>(AFloat(0x100000000)),
Unrepresentable<u32>(AFloat(-1)),
// abstract-float -> f32
Success(0.0_a, 0_f),
Success(1.0_a, 1_f),
Success(AFloat(0xffffffff), 0xffffffff_f),
Success(AFloat(0x100000000), 0x100000000_f),
Success(-AFloat(0x100000000), -0x100000000_f),
Unrepresentable<f32>(1e40_a),
Unrepresentable<f32>(-1e40_a),
// abstract-float -> bool
Success(0.0_a, false),
Success(1.0_a, true),
Success(AFloat(0xffffffff), true),
Success(AFloat(0x100000000), true),
Success(-AFloat(0x100000000), true),
Success(1e40_a, true),
Success(-1e40_a, true),
})));
TEST_F(ResolverConstEvalTest, Vec3_Convert_f32_to_i32) {
auto* expr = vec3<i32>(vec3<f32>(1.1_f, 2.2_f, 3.3_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::I32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AInt>(), 1);
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AInt>(), 2);
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AInt>(), 3);
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_u32_to_f32) {
auto* expr = vec3<f32>(vec3<u32>(10_u, 20_u, 30_u));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AFloat>(), 10.f);
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AFloat>(), 20.f);
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AFloat>(), 30.f);
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_f16_to_i32) {
Enable(ast::Extension::kF16);
auto* expr = vec3<i32>(vec3<f16>(1.1_h, 2.2_h, 3.3_h));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::I32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AInt>(), 1_i);
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AInt>(), 2_i);
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AInt>(), 3_i);
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_u32_to_f16) {
Enable(ast::Extension::kF16);
auto* expr = vec3<f16>(vec3<u32>(10_u, 20_u, 30_u));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F16>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AFloat>(), 10.f);
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AFloat>(), 20.f);
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AFloat>(), 30.f);
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_Large_f32_to_i32) {
auto* expr = vec3<i32>(vec3<f32>(1e10_f, -1e20_f, 1e30_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::I32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AInt>(), i32::Highest());
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AInt>(), i32::Lowest());
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AInt>(), i32::Highest());
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_Large_f32_to_u32) {
auto* expr = vec3<u32>(vec3<f32>(1e10_f, -1e20_f, 1e30_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::U32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AInt>(), u32::Highest());
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AInt>(), u32::Lowest());
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AInt>(), u32::Highest());
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_Large_f32_to_f16) {
Enable(ast::Extension::kF16);
auto* expr = vec3<f16>(vec3<f32>(1e10_f, -1e20_f, 1e30_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
constexpr auto kInfinity = std::numeric_limits<double>::infinity();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F16>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AFloat>(), kInfinity);
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AFloat>(), -kInfinity);
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AFloat>(), kInfinity);
}
TEST_F(ResolverConstEvalTest, Vec3_Convert_Small_f32_to_f16) {
Enable(ast::Extension::kF16);
auto* expr = vec3<f16>(vec3<f32>(1e-20_f, -2e-30_f, 3e-40_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F16>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
EXPECT_FALSE(sem->ConstantValue()->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->AllZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AnyZero());
EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(0)->As<AFloat>(), 0.0);
EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(0)->As<AFloat>().value));
EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero());
EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(1)->As<AFloat>(), -0.0);
EXPECT_TRUE(std::signbit(sem->ConstantValue()->Index(1)->As<AFloat>().value));
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual());
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AnyZero());
EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllZero());
EXPECT_EQ(sem->ConstantValue()->Index(2)->As<AFloat>(), 0.0);
EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(2)->As<AFloat>().value));
}
} // namespace conv
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// Indexing // Indexing
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1550,8 +1550,8 @@ std::string AtomicCompareExchangeResult::String(MatchState* state) const {
return "__atomic_compare_exchange_result<" + T + ">"; return "__atomic_compare_exchange_result<" + T + ">";
} }
/// TypeMatcher for 'match abstract_or_scalar' /// TypeMatcher for 'match scalar'
class AbstractOrScalar : public TypeMatcher { class Scalar : public TypeMatcher {
public: public:
/// Checks whether the given type matches the matcher rules, and returns the /// Checks whether the given type matches the matcher rules, and returns the
/// expected, canonicalized type on success. /// expected, canonicalized type on success.
@ -1566,7 +1566,7 @@ class AbstractOrScalar : public TypeMatcher {
std::string String(MatchState* state) const override; std::string String(MatchState* state) const override;
}; };
const sem::Type* AbstractOrScalar::Match(MatchState& state, const sem::Type* ty) const { const sem::Type* Scalar::Match(MatchState& state, const sem::Type* ty) const {
if (match_ia(state, ty)) { if (match_ia(state, ty)) {
return build_ia(state); return build_ia(state);
} }
@ -1591,7 +1591,7 @@ const sem::Type* AbstractOrScalar::Match(MatchState& state, const sem::Type* ty)
return nullptr; return nullptr;
} }
std::string AbstractOrScalar::String(MatchState*) const { std::string Scalar::String(MatchState*) const {
std::stringstream ss; std::stringstream ss;
// Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support
// template arguments, nor can they match sub-types. As such, they have no use for the MatchState. // template arguments, nor can they match sub-types. As such, they have no use for the MatchState.
@ -1599,8 +1599,8 @@ std::string AbstractOrScalar::String(MatchState*) const {
return ss.str(); return ss.str();
} }
/// TypeMatcher for 'match scalar' /// TypeMatcher for 'match concrete_scalar'
class Scalar : public TypeMatcher { class ConcreteScalar : public TypeMatcher {
public: public:
/// Checks whether the given type matches the matcher rules, and returns the /// Checks whether the given type matches the matcher rules, and returns the
/// expected, canonicalized type on success. /// expected, canonicalized type on success.
@ -1615,7 +1615,7 @@ class Scalar : public TypeMatcher {
std::string String(MatchState* state) const override; std::string String(MatchState* state) const override;
}; };
const sem::Type* Scalar::Match(MatchState& state, const sem::Type* ty) const { const sem::Type* ConcreteScalar::Match(MatchState& state, const sem::Type* ty) const {
if (match_i32(state, ty)) { if (match_i32(state, ty)) {
return build_i32(state); return build_i32(state);
} }
@ -1634,7 +1634,7 @@ const sem::Type* Scalar::Match(MatchState& state, const sem::Type* ty) const {
return nullptr; return nullptr;
} }
std::string Scalar::String(MatchState*) const { std::string ConcreteScalar::String(MatchState*) const {
std::stringstream ss; std::stringstream ss;
// Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support
// template arguments, nor can they match sub-types. As such, they have no use for the MatchState. // template arguments, nor can they match sub-types. As such, they have no use for the MatchState.
@ -1659,6 +1659,12 @@ class ScalarNoF32 : public TypeMatcher {
}; };
const sem::Type* ScalarNoF32::Match(MatchState& state, const sem::Type* ty) const { const sem::Type* ScalarNoF32::Match(MatchState& state, const sem::Type* ty) const {
if (match_ia(state, ty)) {
return build_ia(state);
}
if (match_fa(state, ty)) {
return build_fa(state);
}
if (match_i32(state, ty)) { if (match_i32(state, ty)) {
return build_i32(state); return build_i32(state);
} }
@ -1678,7 +1684,7 @@ std::string ScalarNoF32::String(MatchState*) const {
std::stringstream ss; std::stringstream ss;
// Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support
// template arguments, nor can they match sub-types. As such, they have no use for the MatchState. // template arguments, nor can they match sub-types. As such, they have no use for the MatchState.
ss << I32().String(nullptr) << ", " << F16().String(nullptr) << ", " << U32().String(nullptr) << " or " << Bool().String(nullptr); ss << Ia().String(nullptr) << ", " << Fa().String(nullptr) << ", " << I32().String(nullptr) << ", " << F16().String(nullptr) << ", " << U32().String(nullptr) << " or " << Bool().String(nullptr);
return ss.str(); return ss.str();
} }
@ -1699,6 +1705,12 @@ class ScalarNoF16 : public TypeMatcher {
}; };
const sem::Type* ScalarNoF16::Match(MatchState& state, const sem::Type* ty) const { const sem::Type* ScalarNoF16::Match(MatchState& state, const sem::Type* ty) const {
if (match_ia(state, ty)) {
return build_ia(state);
}
if (match_fa(state, ty)) {
return build_fa(state);
}
if (match_i32(state, ty)) { if (match_i32(state, ty)) {
return build_i32(state); return build_i32(state);
} }
@ -1718,7 +1730,7 @@ std::string ScalarNoF16::String(MatchState*) const {
std::stringstream ss; std::stringstream ss;
// Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support
// template arguments, nor can they match sub-types. As such, they have no use for the MatchState. // template arguments, nor can they match sub-types. As such, they have no use for the MatchState.
ss << F32().String(nullptr) << ", " << I32().String(nullptr) << ", " << U32().String(nullptr) << " or " << Bool().String(nullptr); ss << Ia().String(nullptr) << ", " << Fa().String(nullptr) << ", " << F32().String(nullptr) << ", " << I32().String(nullptr) << ", " << U32().String(nullptr) << " or " << Bool().String(nullptr);
return ss.str(); return ss.str();
} }
@ -1739,6 +1751,12 @@ class ScalarNoI32 : public TypeMatcher {
}; };
const sem::Type* ScalarNoI32::Match(MatchState& state, const sem::Type* ty) const { const sem::Type* ScalarNoI32::Match(MatchState& state, const sem::Type* ty) const {
if (match_ia(state, ty)) {
return build_ia(state);
}
if (match_fa(state, ty)) {
return build_fa(state);
}
if (match_u32(state, ty)) { if (match_u32(state, ty)) {
return build_u32(state); return build_u32(state);
} }
@ -1758,7 +1776,7 @@ std::string ScalarNoI32::String(MatchState*) const {
std::stringstream ss; std::stringstream ss;
// Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support
// template arguments, nor can they match sub-types. As such, they have no use for the MatchState. // template arguments, nor can they match sub-types. As such, they have no use for the MatchState.
ss << F32().String(nullptr) << ", " << F16().String(nullptr) << ", " << U32().String(nullptr) << " or " << Bool().String(nullptr); ss << Ia().String(nullptr) << ", " << Fa().String(nullptr) << ", " << F32().String(nullptr) << ", " << F16().String(nullptr) << ", " << U32().String(nullptr) << " or " << Bool().String(nullptr);
return ss.str(); return ss.str();
} }
@ -1779,6 +1797,12 @@ class ScalarNoU32 : public TypeMatcher {
}; };
const sem::Type* ScalarNoU32::Match(MatchState& state, const sem::Type* ty) const { const sem::Type* ScalarNoU32::Match(MatchState& state, const sem::Type* ty) const {
if (match_ia(state, ty)) {
return build_ia(state);
}
if (match_fa(state, ty)) {
return build_fa(state);
}
if (match_i32(state, ty)) { if (match_i32(state, ty)) {
return build_i32(state); return build_i32(state);
} }
@ -1798,7 +1822,7 @@ std::string ScalarNoU32::String(MatchState*) const {
std::stringstream ss; std::stringstream ss;
// Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support
// template arguments, nor can they match sub-types. As such, they have no use for the MatchState. // template arguments, nor can they match sub-types. As such, they have no use for the MatchState.
ss << F32().String(nullptr) << ", " << F16().String(nullptr) << ", " << I32().String(nullptr) << " or " << Bool().String(nullptr); ss << Ia().String(nullptr) << ", " << Fa().String(nullptr) << ", " << F32().String(nullptr) << ", " << F16().String(nullptr) << ", " << I32().String(nullptr) << " or " << Bool().String(nullptr);
return ss.str(); return ss.str();
} }
@ -1819,6 +1843,12 @@ class ScalarNoBool : public TypeMatcher {
}; };
const sem::Type* ScalarNoBool::Match(MatchState& state, const sem::Type* ty) const { const sem::Type* ScalarNoBool::Match(MatchState& state, const sem::Type* ty) const {
if (match_ia(state, ty)) {
return build_ia(state);
}
if (match_fa(state, ty)) {
return build_fa(state);
}
if (match_i32(state, ty)) { if (match_i32(state, ty)) {
return build_i32(state); return build_i32(state);
} }
@ -1838,7 +1868,7 @@ std::string ScalarNoBool::String(MatchState*) const {
std::stringstream ss; std::stringstream ss;
// Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support
// template arguments, nor can they match sub-types. As such, they have no use for the MatchState. // template arguments, nor can they match sub-types. As such, they have no use for the MatchState.
ss << F32().String(nullptr) << ", " << F16().String(nullptr) << ", " << I32().String(nullptr) << " or " << U32().String(nullptr); ss << Ia().String(nullptr) << ", " << Fa().String(nullptr) << ", " << F32().String(nullptr) << ", " << F16().String(nullptr) << ", " << I32().String(nullptr) << " or " << U32().String(nullptr);
return ss.str(); return ss.str();
} }
@ -2580,8 +2610,8 @@ class Matchers {
FrexpResult FrexpResult_; FrexpResult FrexpResult_;
FrexpResultVec FrexpResultVec_; FrexpResultVec FrexpResultVec_;
AtomicCompareExchangeResult AtomicCompareExchangeResult_; AtomicCompareExchangeResult AtomicCompareExchangeResult_;
AbstractOrScalar AbstractOrScalar_;
Scalar Scalar_; Scalar Scalar_;
ConcreteScalar ConcreteScalar_;
ScalarNoF32 ScalarNoF32_; ScalarNoF32 ScalarNoF32_;
ScalarNoF16 ScalarNoF16_; ScalarNoF16 ScalarNoF16_;
ScalarNoI32 ScalarNoI32_; ScalarNoI32 ScalarNoI32_;
@ -2666,8 +2696,8 @@ class Matchers {
/* [47] */ &FrexpResult_, /* [47] */ &FrexpResult_,
/* [48] */ &FrexpResultVec_, /* [48] */ &FrexpResultVec_,
/* [49] */ &AtomicCompareExchangeResult_, /* [49] */ &AtomicCompareExchangeResult_,
/* [50] */ &AbstractOrScalar_, /* [50] */ &Scalar_,
/* [51] */ &Scalar_, /* [51] */ &ConcreteScalar_,
/* [52] */ &ScalarNoF32_, /* [52] */ &ScalarNoF32_,
/* [53] */ &ScalarNoF16_, /* [53] */ &ScalarNoF16_,
/* [54] */ &ScalarNoI32_, /* [54] */ &ScalarNoI32_,
@ -14348,9 +14378,9 @@ constexpr IntrinsicInfo kBuiltins[] = {
}, },
{ {
/* [67] */ /* [67] */
/* fn select<T : abstract_or_scalar>(T, T, bool) -> T */ /* fn select<T : scalar>(T, T, bool) -> T */
/* fn select<T : abstract_or_scalar, N : num>(vec<N, T>, vec<N, T>, bool) -> vec<N, T> */ /* fn select<T : scalar, N : num>(vec<N, T>, vec<N, T>, bool) -> vec<N, T> */
/* fn select<N : num, T : abstract_or_scalar>(vec<N, T>, vec<N, T>, vec<N, bool>) -> vec<N, T> */ /* fn select<N : num, T : scalar>(vec<N, T>, vec<N, T>, vec<N, bool>) -> vec<N, T> */
/* num overloads */ 3, /* num overloads */ 3,
/* overloads */ &kOverloads[276], /* overloads */ &kOverloads[276],
}, },
@ -14876,15 +14906,15 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
}, },
{ {
/* [10] */ /* [10] */
/* op ==<T : abstract_or_scalar>(T, T) -> bool */ /* op ==<T : scalar>(T, T) -> bool */
/* op ==<T : abstract_or_scalar, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */ /* op ==<T : scalar, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
/* num overloads */ 2, /* num overloads */ 2,
/* overloads */ &kOverloads[408], /* overloads */ &kOverloads[408],
}, },
{ {
/* [11] */ /* [11] */
/* op !=<T : abstract_or_scalar>(T, T) -> bool */ /* op !=<T : scalar>(T, T) -> bool */
/* op !=<T : abstract_or_scalar, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */ /* op !=<T : scalar, N : num>(vec<N, T>, vec<N, T>) -> vec<N, bool> */
/* num overloads */ 2, /* num overloads */ 2,
/* overloads */ &kOverloads[394], /* overloads */ &kOverloads[394],
}, },
@ -14995,10 +15025,10 @@ constexpr IntrinsicInfo kConstructorsAndConverters[] = {
}, },
{ {
/* [5] */ /* [5] */
/* ctor vec2<T : scalar>() -> vec2<T> */ /* ctor vec2<T : concrete_scalar>() -> vec2<T> */
/* ctor vec2<T : scalar>(vec2<T>) -> vec2<T> */ /* ctor vec2<T : concrete_scalar>(vec2<T>) -> vec2<T> */
/* ctor vec2<T : abstract_or_scalar>(T) -> vec2<T> */ /* ctor vec2<T : scalar>(T) -> vec2<T> */
/* ctor vec2<T : abstract_or_scalar>(x: T, y: T) -> vec2<T> */ /* ctor vec2<T : scalar>(x: T, y: T) -> vec2<T> */
/* conv vec2<T : f32, U : scalar_no_f32>(vec2<U>) -> vec2<f32> */ /* conv vec2<T : f32, U : scalar_no_f32>(vec2<U>) -> vec2<f32> */
/* conv vec2<T : f16, U : scalar_no_f16>(vec2<U>) -> vec2<f16> */ /* conv vec2<T : f16, U : scalar_no_f16>(vec2<U>) -> vec2<f16> */
/* conv vec2<T : i32, U : scalar_no_i32>(vec2<U>) -> vec2<i32> */ /* conv vec2<T : i32, U : scalar_no_i32>(vec2<U>) -> vec2<i32> */
@ -15009,12 +15039,12 @@ constexpr IntrinsicInfo kConstructorsAndConverters[] = {
}, },
{ {
/* [6] */ /* [6] */
/* ctor vec3<T : scalar>() -> vec3<T> */ /* ctor vec3<T : concrete_scalar>() -> vec3<T> */
/* ctor vec3<T : scalar>(vec3<T>) -> vec3<T> */ /* ctor vec3<T : concrete_scalar>(vec3<T>) -> vec3<T> */
/* ctor vec3<T : abstract_or_scalar>(T) -> vec3<T> */ /* ctor vec3<T : scalar>(T) -> vec3<T> */
/* ctor vec3<T : abstract_or_scalar>(x: T, y: T, z: T) -> vec3<T> */ /* ctor vec3<T : scalar>(x: T, y: T, z: T) -> vec3<T> */
/* ctor vec3<T : abstract_or_scalar>(xy: vec2<T>, z: T) -> vec3<T> */ /* ctor vec3<T : scalar>(xy: vec2<T>, z: T) -> vec3<T> */
/* ctor vec3<T : abstract_or_scalar>(x: T, yz: vec2<T>) -> vec3<T> */ /* ctor vec3<T : scalar>(x: T, yz: vec2<T>) -> vec3<T> */
/* conv vec3<T : f32, U : scalar_no_f32>(vec3<U>) -> vec3<f32> */ /* conv vec3<T : f32, U : scalar_no_f32>(vec3<U>) -> vec3<f32> */
/* conv vec3<T : f16, U : scalar_no_f16>(vec3<U>) -> vec3<f16> */ /* conv vec3<T : f16, U : scalar_no_f16>(vec3<U>) -> vec3<f16> */
/* conv vec3<T : i32, U : scalar_no_i32>(vec3<U>) -> vec3<i32> */ /* conv vec3<T : i32, U : scalar_no_i32>(vec3<U>) -> vec3<i32> */
@ -15025,16 +15055,16 @@ constexpr IntrinsicInfo kConstructorsAndConverters[] = {
}, },
{ {
/* [7] */ /* [7] */
/* ctor vec4<T : scalar>() -> vec4<T> */ /* ctor vec4<T : concrete_scalar>() -> vec4<T> */
/* ctor vec4<T : scalar>(vec4<T>) -> vec4<T> */ /* ctor vec4<T : concrete_scalar>(vec4<T>) -> vec4<T> */
/* ctor vec4<T : abstract_or_scalar>(T) -> vec4<T> */ /* ctor vec4<T : scalar>(T) -> vec4<T> */
/* ctor vec4<T : abstract_or_scalar>(x: T, y: T, z: T, w: T) -> vec4<T> */ /* ctor vec4<T : scalar>(x: T, y: T, z: T, w: T) -> vec4<T> */
/* ctor vec4<T : abstract_or_scalar>(xy: vec2<T>, z: T, w: T) -> vec4<T> */ /* ctor vec4<T : scalar>(xy: vec2<T>, z: T, w: T) -> vec4<T> */
/* ctor vec4<T : abstract_or_scalar>(x: T, yz: vec2<T>, w: T) -> vec4<T> */ /* ctor vec4<T : scalar>(x: T, yz: vec2<T>, w: T) -> vec4<T> */
/* ctor vec4<T : abstract_or_scalar>(x: T, y: T, zw: vec2<T>) -> vec4<T> */ /* ctor vec4<T : scalar>(x: T, y: T, zw: vec2<T>) -> vec4<T> */
/* ctor vec4<T : abstract_or_scalar>(xy: vec2<T>, zw: vec2<T>) -> vec4<T> */ /* ctor vec4<T : scalar>(xy: vec2<T>, zw: vec2<T>) -> vec4<T> */
/* ctor vec4<T : abstract_or_scalar>(xyz: vec3<T>, w: T) -> vec4<T> */ /* ctor vec4<T : scalar>(xyz: vec3<T>, w: T) -> vec4<T> */
/* ctor vec4<T : abstract_or_scalar>(x: T, zyw: vec3<T>) -> vec4<T> */ /* ctor vec4<T : scalar>(x: T, zyw: vec3<T>) -> vec4<T> */
/* conv vec4<T : f32, U : scalar_no_f32>(vec4<U>) -> vec4<f32> */ /* conv vec4<T : f32, U : scalar_no_f32>(vec4<U>) -> vec4<f32> */
/* conv vec4<T : f16, U : scalar_no_f16>(vec4<U>) -> vec4<f16> */ /* conv vec4<T : f16, U : scalar_no_f16>(vec4<U>) -> vec4<f16> */
/* conv vec4<T : i32, U : scalar_no_i32>(vec4<U>) -> vec4<i32> */ /* conv vec4<T : i32, U : scalar_no_i32>(vec4<U>) -> vec4<i32> */

View File

@ -793,11 +793,11 @@ TEST_F(IntrinsicTableTest, MismatchTypeConstructorImplicit) {
vec3() -> vec3<T> where: T is f32, f16, i32, u32 or bool vec3() -> vec3<T> where: T is f32, f16, i32, u32 or bool
5 candidate conversions: 5 candidate conversions:
vec3(vec3<U>) -> vec3<f32> where: T is f32, U is i32, f16, u32 or bool vec3(vec3<U>) -> vec3<f32> where: T is f32, U is abstract-int, abstract-float, i32, f16, u32 or bool
vec3(vec3<U>) -> vec3<f16> where: T is f16, U is f32, i32, u32 or bool vec3(vec3<U>) -> vec3<f16> where: T is f16, U is abstract-int, abstract-float, f32, i32, u32 or bool
vec3(vec3<U>) -> vec3<i32> where: T is i32, U is f32, f16, u32 or bool vec3(vec3<U>) -> vec3<i32> where: T is i32, U is abstract-int, abstract-float, f32, f16, u32 or bool
vec3(vec3<U>) -> vec3<u32> where: T is u32, U is f32, f16, i32 or bool vec3(vec3<U>) -> vec3<u32> where: T is u32, U is abstract-int, abstract-float, f32, f16, i32 or bool
vec3(vec3<U>) -> vec3<bool> where: T is bool, U is f32, f16, i32 or u32 vec3(vec3<U>) -> vec3<bool> where: T is bool, U is abstract-int, abstract-float, f32, f16, i32 or u32
)"); )");
} }
@ -819,11 +819,11 @@ TEST_F(IntrinsicTableTest, MismatchTypeConstructorExplicit) {
vec3() -> vec3<T> where: T is f32, f16, i32, u32 or bool vec3() -> vec3<T> where: T is f32, f16, i32, u32 or bool
5 candidate conversions: 5 candidate conversions:
vec3(vec3<U>) -> vec3<f32> where: T is f32, U is i32, f16, u32 or bool vec3(vec3<U>) -> vec3<f32> where: T is f32, U is abstract-int, abstract-float, i32, f16, u32 or bool
vec3(vec3<U>) -> vec3<f16> where: T is f16, U is f32, i32, u32 or bool vec3(vec3<U>) -> vec3<f16> where: T is f16, U is abstract-int, abstract-float, f32, i32, u32 or bool
vec3(vec3<U>) -> vec3<i32> where: T is i32, U is f32, f16, u32 or bool vec3(vec3<U>) -> vec3<i32> where: T is i32, U is abstract-int, abstract-float, f32, f16, u32 or bool
vec3(vec3<U>) -> vec3<u32> where: T is u32, U is f32, f16, i32 or bool vec3(vec3<U>) -> vec3<u32> where: T is u32, U is abstract-int, abstract-float, f32, f16, i32 or bool
vec3(vec3<U>) -> vec3<bool> where: T is bool, U is f32, f16, i32 or u32 vec3(vec3<U>) -> vec3<bool> where: T is bool, U is abstract-int, abstract-float, f32, f16, i32 or u32
)"); )");
} }
@ -875,11 +875,11 @@ TEST_F(IntrinsicTableTest, MismatchTypeConversion) {
vec3(x: T, y: T, z: T) -> vec3<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool vec3(x: T, y: T, z: T) -> vec3<T> where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool
5 candidate conversions: 5 candidate conversions:
vec3(vec3<U>) -> vec3<f32> where: T is f32, U is i32, f16, u32 or bool vec3(vec3<U>) -> vec3<f32> where: T is f32, U is abstract-int, abstract-float, i32, f16, u32 or bool
vec3(vec3<U>) -> vec3<f16> where: T is f16, U is f32, i32, u32 or bool vec3(vec3<U>) -> vec3<f16> where: T is f16, U is abstract-int, abstract-float, f32, i32, u32 or bool
vec3(vec3<U>) -> vec3<i32> where: T is i32, U is f32, f16, u32 or bool vec3(vec3<U>) -> vec3<i32> where: T is i32, U is abstract-int, abstract-float, f32, f16, u32 or bool
vec3(vec3<U>) -> vec3<u32> where: T is u32, U is f32, f16, i32 or bool vec3(vec3<U>) -> vec3<u32> where: T is u32, U is abstract-int, abstract-float, f32, f16, i32 or bool
vec3(vec3<U>) -> vec3<bool> where: T is bool, U is f32, f16, i32 or u32 vec3(vec3<U>) -> vec3<bool> where: T is bool, U is abstract-int, abstract-float, f32, f16, i32 or u32
)"); )");
} }
@ -904,7 +904,7 @@ TEST_F(IntrinsicTableTest, OverloadResolution) {
ASSERT_NE(result.target, nullptr); ASSERT_NE(result.target, nullptr);
EXPECT_EQ(result.target->ReturnType(), i32); EXPECT_EQ(result.target->ReturnType(), i32);
EXPECT_EQ(result.target->Parameters().Length(), 1u); EXPECT_EQ(result.target->Parameters().Length(), 1u);
EXPECT_EQ(result.target->Parameters()[0]->Type(), i32); EXPECT_EQ(result.target->Parameters()[0]->Type(), ai);
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

View File

@ -17,6 +17,7 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <ostream>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
@ -174,13 +175,15 @@ template <typename TO>
struct ptr {}; struct ptr {};
/// Type used to accept scalars as arguments. Can be either a single value that gets splatted for /// Type used to accept scalars as arguments. Can be either a single value that gets splatted for
/// composite types, or all values requried by the composite type. /// composite types, or all values required by the composite type.
struct ScalarArgs { struct ScalarArgs {
/// Constructor
ScalarArgs() = default;
/// Constructor /// Constructor
/// @param single_value single value to initialize with /// @param single_value single value to initialize with
template <typename T> template <typename T>
ScalarArgs(T single_value) // NOLINT: implicit on purpose explicit ScalarArgs(T single_value) : values(utils::Vector<Storage, 1>{single_value}) {}
: values(utils::Vector<Storage, 1>{single_value}) {}
/// Constructor /// Constructor
/// @param all_values all values to initialize the composite type with /// @param all_values all values to initialize the composite type with
@ -192,6 +195,10 @@ struct ScalarArgs {
} }
} }
/// @param other the other ScalarArgs to compare against
/// @returns true if all values are equal to the values in @p other
bool operator==(const ScalarArgs& other) const { return values == other.values; }
/// Valid scalar types for args /// Valid scalar types for args
using Storage = std::variant<i32, u32, f32, f16, AInt, AFloat, bool>; using Storage = std::variant<i32, u32, f32, f16, AInt, AFloat, bool>;
@ -199,10 +206,28 @@ struct ScalarArgs {
utils::Vector<Storage, 16> values; utils::Vector<Storage, 16> values;
}; };
/// @param o the std::ostream to write to
/// @param args the ScalarArgs
/// @return the std::ostream so calls can be chained
inline std::ostream& operator<<(std::ostream& o, const ScalarArgs& args) {
o << "[";
bool first = true;
for (auto& val : args.values) {
if (!first) {
o << ", ";
}
first = false;
std::visit([&](auto&& v) { o << v; }, val);
}
o << "]";
return o;
}
using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b); using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, ScalarArgs args); using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, ScalarArgs args);
using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double v); using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double v);
using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b); using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
using type_name_func_ptr = std::string (*)();
template <typename T> template <typename T>
struct DataType {}; struct DataType {};
@ -241,7 +266,7 @@ struct DataType<bool> {
/// @param v arg of type double that will be cast to bool. /// @param v arg of type double that will be cast to bool.
/// @return a new AST expression of the bool type /// @return a new AST expression of the bool type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v)); return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
} }
/// @returns the WGSL name for the type /// @returns the WGSL name for the type
static inline std::string Name() { return "bool"; } static inline std::string Name() { return "bool"; }
@ -272,7 +297,7 @@ struct DataType<i32> {
/// @param v arg of type double that will be cast to i32. /// @param v arg of type double that will be cast to i32.
/// @return a new AST i32 literal value expression /// @return a new AST i32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v)); return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
} }
/// @returns the WGSL name for the type /// @returns the WGSL name for the type
static inline std::string Name() { return "i32"; } static inline std::string Name() { return "i32"; }
@ -303,7 +328,7 @@ struct DataType<u32> {
/// @param v arg of type double that will be cast to u32. /// @param v arg of type double that will be cast to u32.
/// @return a new AST u32 literal value expression /// @return a new AST u32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v)); return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
} }
/// @returns the WGSL name for the type /// @returns the WGSL name for the type
static inline std::string Name() { return "u32"; } static inline std::string Name() { return "u32"; }
@ -334,7 +359,7 @@ struct DataType<f32> {
/// @param v arg of type double that will be cast to f32. /// @param v arg of type double that will be cast to f32.
/// @return a new AST f32 literal value expression /// @return a new AST f32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<f32>(v)); return Expr(b, ScalarArgs{static_cast<f32>(v)});
} }
/// @returns the WGSL name for the type /// @returns the WGSL name for the type
static inline std::string Name() { return "f32"; } static inline std::string Name() { return "f32"; }
@ -365,7 +390,7 @@ struct DataType<f16> {
/// @param v arg of type double that will be cast to f16. /// @param v arg of type double that will be cast to f16.
/// @return a new AST f16 literal value expression /// @return a new AST f16 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v)); return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
} }
/// @returns the WGSL name for the type /// @returns the WGSL name for the type
static inline std::string Name() { return "f16"; } static inline std::string Name() { return "f16"; }
@ -395,7 +420,7 @@ struct DataType<AFloat> {
/// @param v arg of type double that will be cast to AFloat. /// @param v arg of type double that will be cast to AFloat.
/// @return a new AST abstract-float literal value expression /// @return a new AST abstract-float literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v)); return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
} }
/// @returns the WGSL name for the type /// @returns the WGSL name for the type
static inline std::string Name() { return "abstract-float"; } static inline std::string Name() { return "abstract-float"; }
@ -425,7 +450,7 @@ struct DataType<AInt> {
/// @param v arg of type double that will be cast to AInt. /// @param v arg of type double that will be cast to AInt.
/// @return a new AST abstract-int literal value expression /// @return a new AST abstract-int literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v)); return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
} }
/// @returns the WGSL name for the type /// @returns the WGSL name for the type
static inline std::string Name() { return "abstract-int"; } static inline std::string Name() { return "abstract-int"; }
@ -463,7 +488,7 @@ struct DataType<vec<N, T>> {
const bool one_value = args.values.Length() == 1; const bool one_value = args.values.Length() == 1;
utils::Vector<const ast::Expression*, N> r; utils::Vector<const ast::Expression*, N> r;
for (size_t i = 0; i < N; ++i) { for (size_t i = 0; i < N; ++i) {
r.Push(DataType<T>::Expr(b, one_value ? args.values[0] : args.values[i])); r.Push(DataType<T>::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]}));
} }
return r; return r;
} }
@ -471,7 +496,7 @@ struct DataType<vec<N, T>> {
/// @param v arg of type double that will be cast to ElementType /// @param v arg of type double that will be cast to ElementType
/// @return a new AST vector value expression /// @return a new AST vector value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v)); return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
} }
/// @returns the WGSL name for the type /// @returns the WGSL name for the type
static inline std::string Name() { static inline std::string Name() {
@ -514,7 +539,7 @@ struct DataType<mat<N, M, T>> {
utils::Vector<const ast::Expression*, N> r; utils::Vector<const ast::Expression*, N> r;
for (uint32_t i = 0; i < N; ++i) { for (uint32_t i = 0; i < N; ++i) {
if (one_value) { if (one_value) {
r.Push(DataType<vec<M, T>>::Expr(b, args.values[0])); r.Push(DataType<vec<M, T>>::Expr(b, ScalarArgs{args.values[0]}));
} else { } else {
utils::Vector<T, M> v; utils::Vector<T, M> v;
for (size_t j = 0; j < M; ++j) { for (size_t j = 0; j < M; ++j) {
@ -529,7 +554,7 @@ struct DataType<mat<N, M, T>> {
/// @param v arg of type double that will be cast to ElementType /// @param v arg of type double that will be cast to ElementType
/// @return a new AST matrix value expression /// @return a new AST matrix value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v)); return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
} }
/// @returns the WGSL name for the type /// @returns the WGSL name for the type
static inline std::string Name() { static inline std::string Name() {
@ -585,7 +610,7 @@ struct DataType<alias<T, ID>> {
/// @param v arg of type double that will be cast to ElementType /// @param v arg of type double that will be cast to ElementType
/// @return a new AST expression of the alias type /// @return a new AST expression of the alias type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v)); return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
} }
/// @returns the WGSL name for the type /// @returns the WGSL name for the type
@ -626,7 +651,7 @@ struct DataType<ptr<T>> {
/// @param v arg of type double that will be cast to ElementType /// @param v arg of type double that will be cast to ElementType
/// @return a new AST expression of the pointer type /// @return a new AST expression of the pointer type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v)); return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
} }
/// @returns the WGSL name for the type /// @returns the WGSL name for the type
@ -680,7 +705,7 @@ struct DataType<array<N, T>> {
const bool one_value = args.values.Length() == 1; const bool one_value = args.values.Length() == 1;
utils::Vector<const ast::Expression*, N> r; utils::Vector<const ast::Expression*, N> r;
for (uint32_t i = 0; i < N; i++) { for (uint32_t i = 0; i < N; i++) {
r.Push(DataType<T>::Expr(b, one_value ? args.values[0] : args.values[i])); r.Push(DataType<T>::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]}));
} }
return r; return r;
} }
@ -688,7 +713,7 @@ struct DataType<array<N, T>> {
/// @param v arg of type double that will be cast to ElementType /// @param v arg of type double that will be cast to ElementType
/// @return a new AST array value expression /// @return a new AST array value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, static_cast<ElementType>(v)); return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
} }
/// @returns the WGSL name for the type /// @returns the WGSL name for the type
static inline std::string Name() { static inline std::string Name() {
@ -706,13 +731,23 @@ struct CreatePtrs {
ast_expr_from_double_func_ptr expr_from_double; ast_expr_from_double_func_ptr expr_from_double;
/// sem type create function /// sem type create function
sem_type_func_ptr sem; sem_type_func_ptr sem;
/// type name function
type_name_func_ptr name;
}; };
/// @param o the std::ostream to write to
/// @param ptrs the CreatePtrs
/// @return the std::ostream so calls can be chained
inline std::ostream& operator<<(std::ostream& o, const CreatePtrs& ptrs) {
return o << (ptrs.name ? ptrs.name() : "<unknown>");
}
/// Returns a CreatePtrs struct instance with all creation pointer types for /// Returns a CreatePtrs struct instance with all creation pointer types for
/// type `T` /// type `T`
template <typename T> template <typename T>
constexpr CreatePtrs CreatePtrsFor() { constexpr CreatePtrs CreatePtrsFor() {
return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::ExprFromDouble, DataType<T>::Sem}; return {DataType<T>::AST, DataType<T>::Expr, DataType<T>::ExprFromDouble, DataType<T>::Sem,
DataType<T>::Name};
} }
/// Value<T> is an instance of a value of type DataType<T>. Useful for storing values to create /// Value<T> is an instance of a value of type DataType<T>. Useful for storing values to create
@ -729,15 +764,15 @@ struct Value {
/// Creates a Value<T> with `args` /// Creates a Value<T> with `args`
/// @param args the args that will be passed to the expression /// @param args the args that will be passed to the expression
/// @returns a Value<T> /// @returns a Value<T>
static Value Create(ScalarArgs args) { return Value{DataType::Expr, std::move(args)}; } static Value Create(ScalarArgs args) { return Value{CreatePtrsFor<T>(), std::move(args)}; }
/// Creates an `ast::Expression` for the type T passing in previously stored args /// Creates an `ast::Expression` for the type T passing in previously stored args
/// @param b the ProgramBuilder /// @param b the ProgramBuilder
/// @returns an expression node /// @returns an expression node
const ast::Expression* Expr(ProgramBuilder& b) const { return (*expr)(b, args); } const ast::Expression* Expr(ProgramBuilder& b) const { return (*create.expr)(b, args); }
/// ast expression type create function /// functions to create values / types of the value
ast_expr_func_ptr expr; CreatePtrs create;
/// args to create expression with /// args to create expression with
ScalarArgs args; ScalarArgs args;
}; };
@ -764,7 +799,7 @@ const char* FriendlyName() {
/// Creates a `Value<T>` from a scalar `v` /// Creates a `Value<T>` from a scalar `v`
template <typename T> template <typename T>
auto Val(T v) { auto Val(T v) {
return Value<T>::Create(v); return Value<T>::Create(ScalarArgs{v});
} }
/// Creates a `Value<vec<N, T>>` from N scalar `args` /// Creates a `Value<vec<N, T>>` from N scalar `args`

View File

@ -51,7 +51,7 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Cast_IntMin) {
std::stringstream out; std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error(); ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error();
EXPECT_EQ(out.str(), "0u"); EXPECT_EQ(out.str(), "2147483648u");
} }
} // namespace } // namespace