From 75bc93c0df83671adc3a62c172bd33a8be939dad Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Tue, 11 Oct 2022 20:36:48 +0000 Subject: [PATCH] tint: Fix const eval of type conversions These were quite spectacularly broken. Also: * Fix the definition of 'scalar' in `intrinsics.def`. This was in part why conversions were broken, as abstracts were materialized before reaching the converter builtin when they shouldn't have been. * Implement `ScalarArgsFrom()` helper in `const_eval_test.cc`. This is used by the new conversion tests, and also implements part of the suggestion to improve tint:1709. Fixed: tint:1707 Bug: tint:1709 Change-Id: Iab962b671305e868f92710912d2ed07e3338c680 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105261 Commit-Queue: Ben Clayton Reviewed-by: Antonio Maiorano Kokoro: Kokoro Commit-Queue: Ben Clayton --- docs/tint/origin-trial-changes.md | 4 + src/tint/intrinsics.def | 68 +- src/tint/resolver/const_eval.cc | 24 +- src/tint/resolver/const_eval_test.cc | 814 ++++++++++++------ src/tint/resolver/intrinsic_table.inl | 116 ++- src/tint/resolver/intrinsic_table_test.cc | 32 +- src/tint/resolver/resolver_test_helper.h | 83 +- .../writer/msl/generator_impl_cast_test.cc | 2 +- 8 files changed, 737 insertions(+), 406 deletions(-) diff --git a/docs/tint/origin-trial-changes.md b/docs/tint/origin-trial-changes.md index bd27dd1582..ead20108e7 100644 --- a/docs/tint/origin-trial-changes.md +++ b/docs/tint/origin-trial-changes.md @@ -10,6 +10,10 @@ * The `external_texture` overload of `textureSampleLevel()` has been deprecated. Use `textureSampleBaseClampToEdge()` instead. [tint:1671](crbug.com/tint/1671) +### Fixes + +* Constant evaluation of type conversions where the value exceeds the limits of the target type have been fixed. [tint:1707](crbug.com/tint/1707) + ## Changes for M107 ### New features diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def index 28b78ffa66..493d22443d 100644 --- a/src/tint/intrinsics.def +++ b/src/tint/intrinsics.def @@ -175,13 +175,13 @@ type __atomic_compare_exchange_result // A type matcher that can match one or more types. // //////////////////////////////////////////////////////////////////////////////// -match abstract_or_scalar: ia | fa | f32 | f16 | i32 | u32 | bool -match scalar: f32 | f16 | i32 | u32 | bool -match scalar_no_f32: i32 | f16 | u32 | bool -match scalar_no_f16: f32 | i32 | u32 | bool -match scalar_no_i32: f32 | f16 | u32 | bool -match scalar_no_u32: f32 | f16 | i32 | bool -match scalar_no_bool: f32 | f16 | i32 | u32 +match scalar: ia | fa | f32 | f16 | i32 | u32 | bool +match concrete_scalar: f32 | f16 | i32 | u32 | bool +match scalar_no_f32: ia | fa | i32 | f16 | u32 | bool +match scalar_no_f16: ia | fa | f32 | i32 | u32 | bool +match scalar_no_i32: ia | fa | f32 | f16 | u32 | bool +match scalar_no_u32: ia | fa | f32 | f16 | i32 | bool +match scalar_no_bool: ia | fa | f32 | f16 | i32 | u32 match fia_fiu32_f16: fa | ia | f32 | i32 | u32 | f16 match fia_fi32_f16: fa | ia | f32 | i32 | f16 match fia_fiu32: fa | ia | f32 | i32 | u32 @@ -524,9 +524,9 @@ fn round(T) -> T fn round(vec) -> vec fn saturate(T) -> T fn saturate(vec) -> vec -@const("select_bool") fn select(T, T, bool) -> T -@const("select_bool") fn select(vec, vec, bool) -> vec -@const("select_boolvec") fn select(vec, vec, vec) -> vec +@const("select_bool") fn select(T, T, bool) -> T +@const("select_bool") fn select(vec, vec, bool) -> vec +@const("select_boolvec") fn select(vec, vec, vec) -> vec fn sign(T) -> T fn sign(vec) -> vec fn sin(T) -> T @@ -720,9 +720,9 @@ fn textureLoad(texture: texture_external, coords: vec2) -> vec4 @const("Zero") ctor f32() -> f32 @const("Zero") ctor f16() -> f16 @const("Zero") ctor bool() -> bool -@const("Zero") ctor vec2() -> vec2 -@const("Zero") ctor vec3() -> vec3 -@const("Zero") ctor vec4() -> vec4 +@const("Zero") ctor vec2() -> vec2 +@const("Zero") ctor vec3() -> vec3 +@const("Zero") ctor vec4() -> vec4 @const("Zero") ctor mat2x2() -> mat2x2 @const("Zero") ctor mat2x3() -> mat2x3 @const("Zero") ctor mat2x4() -> mat2x4 @@ -739,9 +739,9 @@ fn textureLoad(texture: texture_external, coords: vec2) -> vec4 @const("Identity") ctor f32(f32) -> f32 @const("Identity") ctor f16(f16) -> f16 @const("Identity") ctor bool(bool) -> bool -@const("Identity") ctor vec2(vec2) -> vec2 -@const("Identity") ctor vec3(vec3) -> vec3 -@const("Identity") ctor vec4(vec4) -> vec4 +@const("Identity") ctor vec2(vec2) -> vec2 +@const("Identity") ctor vec3(vec3) -> vec3 +@const("Identity") ctor vec4(vec4) -> vec4 @const("Identity") ctor mat2x2(mat2x2) -> mat2x2 @const("Identity") ctor mat2x3(mat2x3) -> mat2x3 @const("Identity") ctor mat2x4(mat2x4) -> mat2x4 @@ -753,24 +753,24 @@ fn textureLoad(texture: texture_external, coords: vec2) -> vec4 @const("Identity") ctor mat4x4(mat4x4) -> mat4x4 // Vector constructors (splat) -@const("VecSplat") ctor vec2(T) -> vec2 -@const("VecSplat") ctor vec3(T) -> vec3 -@const("VecSplat") ctor vec4(T) -> vec4 +@const("VecSplat") ctor vec2(T) -> vec2 +@const("VecSplat") ctor vec3(T) -> vec3 +@const("VecSplat") ctor vec4(T) -> vec4 // Vector constructors (scalar) -@const("VecCtorS") ctor vec2(x: T, y: T) -> vec2 -@const("VecCtorS") ctor vec3(x: T, y: T, z: T) -> vec3 -@const("VecCtorS") ctor vec4(x: T, y: T, z: T, w: T) -> vec4 +@const("VecCtorS") ctor vec2(x: T, y: T) -> vec2 +@const("VecCtorS") ctor vec3(x: T, y: T, z: T) -> vec3 +@const("VecCtorS") ctor vec4(x: T, y: T, z: T, w: T) -> vec4 // Vector constructors (mixed) -@const("VecCtorM") ctor vec3(xy: vec2, z: T) -> vec3 -@const("VecCtorM") ctor vec3(x: T, yz: vec2) -> vec3 -@const("VecCtorM") ctor vec4(xy: vec2, z: T, w: T) -> vec4 -@const("VecCtorM") ctor vec4(x: T, yz: vec2, w: T) -> vec4 -@const("VecCtorM") ctor vec4(x: T, y: T, zw: vec2) -> vec4 -@const("VecCtorM") ctor vec4(xy: vec2, zw: vec2) -> vec4 -@const("VecCtorM") ctor vec4(xyz: vec3, w: T) -> vec4 -@const("VecCtorM") ctor vec4(x: T, zyw: vec3) -> vec4 +@const("VecCtorM") ctor vec3(xy: vec2, z: T) -> vec3 +@const("VecCtorM") ctor vec3(x: T, yz: vec2) -> vec3 +@const("VecCtorM") ctor vec4(xy: vec2, z: T, w: T) -> vec4 +@const("VecCtorM") ctor vec4(x: T, yz: vec2, w: T) -> vec4 +@const("VecCtorM") ctor vec4(x: T, y: T, zw: vec2) -> vec4 +@const("VecCtorM") ctor vec4(xy: vec2, zw: vec2) -> vec4 +@const("VecCtorM") ctor vec4(xyz: vec3, w: T) -> vec4 +@const("VecCtorM") ctor vec4(x: T, zyw: vec3) -> vec4 // Matrix constructors (scalar) @const("MatCtorS") @@ -952,11 +952,11 @@ op % (T, vec) -> vec op && (bool, bool) -> bool op || (bool, bool) -> bool -@const op == (T, T) -> bool -@const op == (vec, vec) -> vec +@const op == (T, T) -> bool +@const op == (vec, vec) -> vec -@const op != (T, T) -> bool -@const op != (vec, vec) -> vec +@const op != (T, T) -> bool +@const op != (vec, vec) -> vec @const op < (T, T) -> bool @const op < (vec, vec) -> vec diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 9ddf24ff9a..d9fde25b23 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -231,27 +231,29 @@ struct Element : ImplConstant { return this; } return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult { - // `T` is the source type, `value` is the source value. + // `value` is the source value. + // `FROM` is the source type. // `TO` is the target type. using TO = std::decay_t; + using FROM = T; if constexpr (std::is_same_v) { // [x -> bool] return builder.create>(target_ty, !IsPositiveZero(value)); - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { // [bool -> x] return builder.create>(target_ty, TO(value ? 1 : 0)); } else if (auto conv = CheckedConvert(value)) { // Conversion success return builder.create>(target_ty, conv.Get()); // --- Below this point are the failure cases --- - } else if constexpr (IsAbstract) { + } else if constexpr (IsAbstract) { // [abstract-numeric -> x] - materialization failure std::stringstream ss; ss << "value " << value << " cannot be represented as "; ss << "'" << builder.FriendlyName(target_ty) << "'"; builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source); return utils::Failure; - } else if constexpr (IsFloatingPoint>) { + } else if constexpr (IsFloatingPoint) { // [x -> floating-point] - number not exactly representable // https://www.w3.org/TR/WGSL/#floating-point-conversion switch (conv.Failure()) { @@ -260,8 +262,8 @@ struct Element : ImplConstant { case ConversionFailure::kExceedsPositiveLimit: return builder.create>(target_ty, TO::Inf()); } - } else { - // [x -> integer] - number not exactly representable + } else if constexpr (IsFloatingPoint) { + // [floating-point -> integer] - number not exactly representable // https://www.w3.org/TR/WGSL/#floating-point-conversion switch (conv.Failure()) { case ConversionFailure::kExceedsNegativeLimit: @@ -269,6 +271,10 @@ struct Element : ImplConstant { case ConversionFailure::kExceedsPositiveLimit: return builder.create>(target_ty, TO::Highest()); } + } else if constexpr (IsIntegral) { + // [integer -> integer] - number not exactly representable + // Static cast + return builder.create>(target_ty, static_cast(value)); } return nullptr; // Expression is not constant. }); @@ -842,11 +848,7 @@ ConstEval::Result ConstEval::Conv(const sem::Type* ty, return nullptr; // Single argument is not constant. } - if (auto conv = Convert(ty, args[0], source)) { - return conv.Get(); - } - - return nullptr; + return Convert(ty, args[0], source); } ConstEval::Result ConstEval::Zero(const sem::Type* ty, diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc index 8bded4caab..b966097425 100644 --- a/src/tint/resolver/const_eval_test.cc +++ b/src/tint/resolver/const_eval_test.cc @@ -44,6 +44,30 @@ const auto kPiOver4 = T(UnwrapNumber(0.785398163397448309616)); template const auto k3PiOver4 = T(UnwrapNumber(2.356194490192344928846)); +/// Walks the sem::Constant @p c, accumulating all the inner-most scalar values into @p args +void CollectScalarArgs(const sem::Constant* c, builder::ScalarArgs& args) { + Switch( + c->Type(), // + [&](const sem::Bool*) { args.values.Push(c->As()); }, + [&](const sem::I32*) { args.values.Push(c->As()); }, + [&](const sem::U32*) { args.values.Push(c->As()); }, + [&](const sem::F32*) { args.values.Push(c->As()); }, + [&](const sem::F16*) { args.values.Push(c->As()); }, + [&](Default) { + size_t i = 0; + while (auto* child = c->Index(i++)) { + CollectScalarArgs(child, args); + } + }); +} + +/// Walks the sem::Constant @p c, returning all the inner-most scalar values. +builder::ScalarArgs ScalarArgsFrom(const sem::Constant* c) { + builder::ScalarArgs out; + CollectScalarArgs(c, out); + return out; +} + template constexpr auto Negate(const Number& v) { if constexpr (std::is_integral_v) { @@ -1211,283 +1235,6 @@ TEST_F(ResolverConstEvalTest, Vec3_MixConstruct_all_false) { EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), false); } -TEST_F(ResolverConstEvalTest, Vec3_Convert_f32_to_i32) { - auto* expr = vec3(vec3(1.1_f, 2.2_f, 3.3_f)); - WrapInFunction(expr); - - EXPECT_TRUE(r()->Resolve()) << r()->error(); - - auto* sem = Sem().Get(expr); - ASSERT_NE(sem, nullptr); - auto* vec = sem->Type()->As(); - ASSERT_NE(vec, nullptr); - EXPECT_TRUE(vec->type()->Is()); - EXPECT_EQ(vec->Width(), 3u); - EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); - EXPECT_FALSE(sem->ConstantValue()->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->AllZero()); - - EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), 1); - - EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), 2); - - EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), 3); -} - -TEST_F(ResolverConstEvalTest, Vec3_Convert_u32_to_f32) { - auto* expr = vec3(vec3(10_u, 20_u, 30_u)); - WrapInFunction(expr); - - EXPECT_TRUE(r()->Resolve()) << r()->error(); - - auto* sem = Sem().Get(expr); - ASSERT_NE(sem, nullptr); - auto* vec = sem->Type()->As(); - ASSERT_NE(vec, nullptr); - EXPECT_TRUE(vec->type()->Is()); - EXPECT_EQ(vec->Width(), 3u); - EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); - EXPECT_FALSE(sem->ConstantValue()->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->AllZero()); - - EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), 10.f); - - EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), 20.f); - - EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), 30.f); -} - -TEST_F(ResolverConstEvalTest, Vec3_Convert_f16_to_i32) { - Enable(ast::Extension::kF16); - - auto* expr = vec3(vec3(1.1_h, 2.2_h, 3.3_h)); - WrapInFunction(expr); - - EXPECT_TRUE(r()->Resolve()) << r()->error(); - - auto* sem = Sem().Get(expr); - EXPECT_NE(sem, nullptr); - auto* vec = sem->Type()->As(); - ASSERT_NE(vec, nullptr); - EXPECT_TRUE(vec->type()->Is()); - EXPECT_EQ(vec->Width(), 3u); - EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); - EXPECT_FALSE(sem->ConstantValue()->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->AllZero()); - - EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), 1_i); - - EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), 2_i); - - EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), 3_i); -} - -TEST_F(ResolverConstEvalTest, Vec3_Convert_u32_to_f16) { - Enable(ast::Extension::kF16); - - auto* expr = vec3(vec3(10_u, 20_u, 30_u)); - WrapInFunction(expr); - - EXPECT_TRUE(r()->Resolve()) << r()->error(); - - auto* sem = Sem().Get(expr); - EXPECT_NE(sem, nullptr); - auto* vec = sem->Type()->As(); - ASSERT_NE(vec, nullptr); - EXPECT_TRUE(vec->type()->Is()); - EXPECT_EQ(vec->Width(), 3u); - EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); - EXPECT_FALSE(sem->ConstantValue()->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->AllZero()); - - EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), 10.f); - - EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), 20.f); - - EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), 30.f); -} - -TEST_F(ResolverConstEvalTest, Vec3_Convert_Large_f32_to_i32) { - auto* expr = vec3(vec3(1e10_f, -1e20_f, 1e30_f)); - WrapInFunction(expr); - - EXPECT_TRUE(r()->Resolve()) << r()->error(); - - auto* sem = Sem().Get(expr); - ASSERT_NE(sem, nullptr); - auto* vec = sem->Type()->As(); - ASSERT_NE(vec, nullptr); - EXPECT_TRUE(vec->type()->Is()); - EXPECT_EQ(vec->Width(), 3u); - EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); - EXPECT_FALSE(sem->ConstantValue()->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->AllZero()); - - EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), i32::Highest()); - - EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), i32::Lowest()); - - EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), i32::Highest()); -} - -TEST_F(ResolverConstEvalTest, Vec3_Convert_Large_f32_to_u32) { - auto* expr = vec3(vec3(1e10_f, -1e20_f, 1e30_f)); - WrapInFunction(expr); - - EXPECT_TRUE(r()->Resolve()) << r()->error(); - - auto* sem = Sem().Get(expr); - ASSERT_NE(sem, nullptr); - auto* vec = sem->Type()->As(); - ASSERT_NE(vec, nullptr); - EXPECT_TRUE(vec->type()->Is()); - EXPECT_EQ(vec->Width(), 3u); - EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); - EXPECT_FALSE(sem->ConstantValue()->AllEqual()); - EXPECT_TRUE(sem->ConstantValue()->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->AllZero()); - - EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), u32::Highest()); - - EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); - EXPECT_TRUE(sem->ConstantValue()->Index(1)->AnyZero()); - EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), u32::Lowest()); - - EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), u32::Highest()); -} - -TEST_F(ResolverConstEvalTest, Vec3_Convert_Large_f32_to_f16) { - Enable(ast::Extension::kF16); - - auto* expr = vec3(vec3(1e10_f, -1e20_f, 1e30_f)); - WrapInFunction(expr); - - EXPECT_TRUE(r()->Resolve()) << r()->error(); - - constexpr auto kInfinity = std::numeric_limits::infinity(); - - auto* sem = Sem().Get(expr); - ASSERT_NE(sem, nullptr); - auto* vec = sem->Type()->As(); - ASSERT_NE(vec, nullptr); - EXPECT_TRUE(vec->type()->Is()); - EXPECT_EQ(vec->Width(), 3u); - EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); - EXPECT_FALSE(sem->ConstantValue()->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->AllZero()); - - EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), kInfinity); - - EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), -kInfinity); - - EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), kInfinity); -} - -TEST_F(ResolverConstEvalTest, Vec3_Convert_Small_f32_to_f16) { - Enable(ast::Extension::kF16); - - auto* expr = vec3(vec3(1e-20_f, -2e-30_f, 3e-40_f)); - WrapInFunction(expr); - - EXPECT_TRUE(r()->Resolve()) << r()->error(); - - auto* sem = Sem().Get(expr); - ASSERT_NE(sem, nullptr); - auto* vec = sem->Type()->As(); - ASSERT_NE(vec, nullptr); - EXPECT_TRUE(vec->type()->Is()); - EXPECT_EQ(vec->Width(), 3u); - EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); - EXPECT_FALSE(sem->ConstantValue()->AllEqual()); - EXPECT_TRUE(sem->ConstantValue()->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->AllZero()); - - EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); - EXPECT_TRUE(sem->ConstantValue()->Index(0)->AnyZero()); - EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), 0.0); - EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(0)->As().value)); - - EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); - EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), -0.0); - EXPECT_TRUE(std::signbit(sem->ConstantValue()->Index(1)->As().value)); - - EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); - EXPECT_TRUE(sem->ConstantValue()->Index(2)->AnyZero()); - EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllZero()); - EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), 0.0); - EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(2)->As().value)); -} - TEST_F(ResolverConstEvalTest, Mat2x3_ZeroInit_f32) { auto* expr = mat2x3(); WrapInFunction(expr); @@ -2474,6 +2221,519 @@ TEST_F(ResolverConstEvalTest, Struct_Array_Construct) { EXPECT_EQ(sem->ConstantValue()->Index(1)->Index(2)->As(), 3_f); } +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Conversion +//////////////////////////////////////////////////////////////////////////////////////////////////// +namespace conv { + +using Scalar = std::variant< // + builder::Value, + builder::Value, + builder::Value, + builder::Value, + builder::Value, + builder::Value, + builder::Value>; + +static std::ostream& operator<<(std::ostream& o, const Scalar& scalar) { + std::visit( + [&](auto&& v) { + using ValueType = std::decay_t; + o << ValueType::DataType::Name() << "("; + for (auto& a : v.args.values) { + o << std::get(a); + if (&a != &v.args.values.Back()) { + o << ", "; + } + } + o << ")"; + }, + scalar); + return o; +} + +enum class Kind { + kScalar, + kVector, +}; + +static std::ostream& operator<<(std::ostream& o, const Kind& k) { + switch (k) { + case Kind::kScalar: + return o << "scalar"; + case Kind::kVector: + return o << "vector"; + } + return o << ""; +} + +struct Case { + Scalar input; + Scalar expected; + builder::CreatePtrs type; + bool unrepresentable = false; +}; + +static std::ostream& operator<<(std::ostream& o, const Case& c) { + if (c.unrepresentable) { + o << "[unrepresentable] input: " << c.input; + } else { + o << "input: " << c.input << ", expected: " << c.expected; + } + return o << ", type: " << c.type; +} + +template +Case Success(FROM input, TO expected) { + return {builder::Val(input), builder::Val(expected), builder::CreatePtrsFor()}; +} + +template +Case Unrepresentable(FROM input) { + return {builder::Val(input), builder::Val(0_i), builder::CreatePtrsFor(), + /* unrepresentable */ true}; +} + +using ResolverConstEvalConvTest = ResolverTestWithParam>; + +TEST_P(ResolverConstEvalConvTest, Test) { + const auto& kind = std::get<0>(GetParam()); + const auto& input = std::get<1>(GetParam()).input; + const auto& expected = std::get<1>(GetParam()).expected; + const auto& type = std::get<1>(GetParam()).type; + const auto unrepresentable = std::get<1>(GetParam()).unrepresentable; + + auto* input_val = std::visit([&](auto val) { return val.Expr(*this); }, input); + auto* expr = Construct(type.ast(*this), input_val); + if (kind == Kind::kVector) { + expr = Construct(ty.vec(nullptr, 3), expr); + } + WrapInFunction(expr); + + auto* target_sem_ty = type.sem(*this); + if (kind == Kind::kVector) { + target_sem_ty = create(target_sem_ty, 3u); + } + + if (unrepresentable) { + ASSERT_FALSE(r()->Resolve()); + EXPECT_THAT(r()->error(), testing::HasSubstr("cannot be represented as")); + } else { + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + ASSERT_NE(sem, nullptr); + EXPECT_TYPE(sem->Type(), target_sem_ty); + ASSERT_NE(sem->ConstantValue(), nullptr); + EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty); + + auto expected_values = std::visit([&](auto&& val) { return val.args; }, expected); + if (kind == Kind::kVector) { + expected_values.values.Push(expected_values.values[0]); + expected_values.values.Push(expected_values.values[0]); + } + auto got_values = ScalarArgsFrom(sem->ConstantValue()); + EXPECT_EQ(expected_values, got_values); + } +} +INSTANTIATE_TEST_SUITE_P(ScalarAndVector, + ResolverConstEvalConvTest, + testing::Combine(testing::Values(Kind::kScalar, Kind::kVector), + testing::ValuesIn({ + // TODO(crbug.com/tint/1502): Add f16 tests + // i32 -> u32 + Success(0_i, 0_u), + Success(1_i, 1_u), + Success(-1_i, 0xffffffff_u), + Success(2_i, 2_u), + Success(-2_i, 0xfffffffe_u), + // i32 -> f32 + Success(0_i, 0_f), + Success(1_i, 1_f), + Success(-1_i, -1_f), + Success(2_i, 2_f), + Success(-2_i, -2_f), + // i32 -> bool + Success(0_i, false), + Success(1_i, true), + Success(-1_i, true), + Success(2_i, true), + Success(-2_i, true), + // u32 -> i32 + Success(0_u, 0_i), + Success(1_u, 1_i), + Success(0xffffffff_u, -1_i), + Success(2_u, 2_i), + Success(0xfffffffe_u, -2_i), + // u32 -> f32 + Success(0_u, 0_f), + Success(1_u, 1_f), + Success(2_u, 2_f), + Success(0xffffffff_u, 0xffffffff_f), + // u32 -> bool + Success(0_u, false), + Success(1_u, true), + Success(2_u, true), + Success(0xffffffff_u, true), + // f32 -> i32 + Success(0_f, 0_i), + Success(1_f, 1_i), + Success(2_f, 2_i), + Success(1e20_f, i32::Highest()), + Success(-1e20_f, i32::Lowest()), + // f32 -> u32 + Success(0_f, 0_i), + Success(1_f, 1_i), + Success(-1_f, u32::Lowest()), + Success(2_f, 2_i), + Success(1e20_f, u32::Highest()), + Success(-1e20_f, u32::Lowest()), + // f32 -> bool + Success(0_f, false), + Success(1_f, true), + Success(-1_f, true), + Success(2_f, true), + Success(1e20_f, true), + Success(-1e20_f, true), + // abstract-int -> i32 + Success(0_a, 0_i), + Success(1_a, 1_i), + Success(-1_a, -1_i), + Success(0x7fffffff_a, i32::Highest()), + Success(-0x80000000_a, i32::Lowest()), + Unrepresentable(0x80000000_a), + // abstract-int -> u32 + Success(0_a, 0_u), + Success(1_a, 1_u), + Success(0xffffffff_a, 0xffffffff_u), + Unrepresentable(0x100000000_a), + Unrepresentable(-1_a), + // abstract-int -> f32 + Success(0_a, 0_f), + Success(1_a, 1_f), + Success(0xffffffff_a, 0xffffffff_f), + Success(0x100000000_a, 0x100000000_f), + Success(-0x100000000_a, -0x100000000_f), + Success(0x7fffffffffffffff_a, 0x7fffffffffffffff_f), + Success(-0x7fffffffffffffff_a, -0x7fffffffffffffff_f), + // abstract-int -> bool + Success(0_a, false), + Success(1_a, true), + Success(0xffffffff_a, true), + Success(0x100000000_a, true), + Success(-0x100000000_a, true), + Success(0x7fffffffffffffff_a, true), + Success(-0x7fffffffffffffff_a, true), + // abstract-float -> i32 + Success(0.0_a, 0_i), + Success(1.0_a, 1_i), + Success(-1.0_a, -1_i), + Success(AFloat(0x7fffffff), i32::Highest()), + Success(-AFloat(0x80000000), i32::Lowest()), + Unrepresentable(0x80000000_a), + // abstract-float -> u32 + Success(0.0_a, 0_u), + Success(1.0_a, 1_u), + Success(AFloat(0xffffffff), 0xffffffff_u), + Unrepresentable(AFloat(0x100000000)), + Unrepresentable(AFloat(-1)), + // abstract-float -> f32 + Success(0.0_a, 0_f), + Success(1.0_a, 1_f), + Success(AFloat(0xffffffff), 0xffffffff_f), + Success(AFloat(0x100000000), 0x100000000_f), + Success(-AFloat(0x100000000), -0x100000000_f), + Unrepresentable(1e40_a), + Unrepresentable(-1e40_a), + // abstract-float -> bool + Success(0.0_a, false), + Success(1.0_a, true), + Success(AFloat(0xffffffff), true), + Success(AFloat(0x100000000), true), + Success(-AFloat(0x100000000), true), + Success(1e40_a, true), + Success(-1e40_a, true), + }))); + +TEST_F(ResolverConstEvalTest, Vec3_Convert_f32_to_i32) { + auto* expr = vec3(vec3(1.1_f, 2.2_f, 3.3_f)); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + ASSERT_NE(sem, nullptr); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); + EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); + EXPECT_FALSE(sem->ConstantValue()->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->AllZero()); + + EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), 1); + + EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), 2); + + EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), 3); +} + +TEST_F(ResolverConstEvalTest, Vec3_Convert_u32_to_f32) { + auto* expr = vec3(vec3(10_u, 20_u, 30_u)); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + ASSERT_NE(sem, nullptr); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); + EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); + EXPECT_FALSE(sem->ConstantValue()->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->AllZero()); + + EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), 10.f); + + EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), 20.f); + + EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), 30.f); +} + +TEST_F(ResolverConstEvalTest, Vec3_Convert_f16_to_i32) { + Enable(ast::Extension::kF16); + + auto* expr = vec3(vec3(1.1_h, 2.2_h, 3.3_h)); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + EXPECT_NE(sem, nullptr); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); + EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); + EXPECT_FALSE(sem->ConstantValue()->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->AllZero()); + + EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), 1_i); + + EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), 2_i); + + EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), 3_i); +} + +TEST_F(ResolverConstEvalTest, Vec3_Convert_u32_to_f16) { + Enable(ast::Extension::kF16); + + auto* expr = vec3(vec3(10_u, 20_u, 30_u)); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + EXPECT_NE(sem, nullptr); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); + EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); + EXPECT_FALSE(sem->ConstantValue()->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->AllZero()); + + EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), 10.f); + + EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), 20.f); + + EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), 30.f); +} + +TEST_F(ResolverConstEvalTest, Vec3_Convert_Large_f32_to_i32) { + auto* expr = vec3(vec3(1e10_f, -1e20_f, 1e30_f)); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + ASSERT_NE(sem, nullptr); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); + EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); + EXPECT_FALSE(sem->ConstantValue()->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->AllZero()); + + EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), i32::Highest()); + + EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), i32::Lowest()); + + EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), i32::Highest()); +} + +TEST_F(ResolverConstEvalTest, Vec3_Convert_Large_f32_to_u32) { + auto* expr = vec3(vec3(1e10_f, -1e20_f, 1e30_f)); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + ASSERT_NE(sem, nullptr); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); + EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); + EXPECT_FALSE(sem->ConstantValue()->AllEqual()); + EXPECT_TRUE(sem->ConstantValue()->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->AllZero()); + + EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), u32::Highest()); + + EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); + EXPECT_TRUE(sem->ConstantValue()->Index(1)->AnyZero()); + EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), u32::Lowest()); + + EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), u32::Highest()); +} + +TEST_F(ResolverConstEvalTest, Vec3_Convert_Large_f32_to_f16) { + Enable(ast::Extension::kF16); + + auto* expr = vec3(vec3(1e10_f, -1e20_f, 1e30_f)); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + constexpr auto kInfinity = std::numeric_limits::infinity(); + + auto* sem = Sem().Get(expr); + ASSERT_NE(sem, nullptr); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); + EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); + EXPECT_FALSE(sem->ConstantValue()->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->AllZero()); + + EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(0)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), kInfinity); + + EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), -kInfinity); + + EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(2)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), kInfinity); +} + +TEST_F(ResolverConstEvalTest, Vec3_Convert_Small_f32_to_f16) { + Enable(ast::Extension::kF16); + + auto* expr = vec3(vec3(1e-20_f, -2e-30_f, 3e-40_f)); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + ASSERT_NE(sem, nullptr); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); + EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); + EXPECT_FALSE(sem->ConstantValue()->AllEqual()); + EXPECT_TRUE(sem->ConstantValue()->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->AllZero()); + + EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllEqual()); + EXPECT_TRUE(sem->ConstantValue()->Index(0)->AnyZero()); + EXPECT_TRUE(sem->ConstantValue()->Index(0)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(0)->As(), 0.0); + EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(0)->As().value)); + + EXPECT_TRUE(sem->ConstantValue()->Index(1)->AllEqual()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AnyZero()); + EXPECT_FALSE(sem->ConstantValue()->Index(1)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(1)->As(), -0.0); + EXPECT_TRUE(std::signbit(sem->ConstantValue()->Index(1)->As().value)); + + EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllEqual()); + EXPECT_TRUE(sem->ConstantValue()->Index(2)->AnyZero()); + EXPECT_TRUE(sem->ConstantValue()->Index(2)->AllZero()); + EXPECT_EQ(sem->ConstantValue()->Index(2)->As(), 0.0); + EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(2)->As().value)); +} + +} // namespace conv + //////////////////////////////////////////////////////////////////////////////////////////////////// // Indexing //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/tint/resolver/intrinsic_table.inl b/src/tint/resolver/intrinsic_table.inl index 67cb23fead..426b78bc0c 100644 --- a/src/tint/resolver/intrinsic_table.inl +++ b/src/tint/resolver/intrinsic_table.inl @@ -1550,8 +1550,8 @@ std::string AtomicCompareExchangeResult::String(MatchState* state) const { return "__atomic_compare_exchange_result<" + T + ">"; } -/// TypeMatcher for 'match abstract_or_scalar' -class AbstractOrScalar : public TypeMatcher { +/// TypeMatcher for 'match scalar' +class Scalar : public TypeMatcher { public: /// Checks whether the given type matches the matcher rules, and returns the /// expected, canonicalized type on success. @@ -1566,7 +1566,7 @@ class AbstractOrScalar : public TypeMatcher { std::string String(MatchState* state) const override; }; -const sem::Type* AbstractOrScalar::Match(MatchState& state, const sem::Type* ty) const { +const sem::Type* Scalar::Match(MatchState& state, const sem::Type* ty) const { if (match_ia(state, ty)) { return build_ia(state); } @@ -1591,7 +1591,7 @@ const sem::Type* AbstractOrScalar::Match(MatchState& state, const sem::Type* ty) return nullptr; } -std::string AbstractOrScalar::String(MatchState*) const { +std::string Scalar::String(MatchState*) const { std::stringstream ss; // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // template arguments, nor can they match sub-types. As such, they have no use for the MatchState. @@ -1599,8 +1599,8 @@ std::string AbstractOrScalar::String(MatchState*) const { return ss.str(); } -/// TypeMatcher for 'match scalar' -class Scalar : public TypeMatcher { +/// TypeMatcher for 'match concrete_scalar' +class ConcreteScalar : public TypeMatcher { public: /// Checks whether the given type matches the matcher rules, and returns the /// expected, canonicalized type on success. @@ -1615,7 +1615,7 @@ class Scalar : public TypeMatcher { std::string String(MatchState* state) const override; }; -const sem::Type* Scalar::Match(MatchState& state, const sem::Type* ty) const { +const sem::Type* ConcreteScalar::Match(MatchState& state, const sem::Type* ty) const { if (match_i32(state, ty)) { return build_i32(state); } @@ -1634,7 +1634,7 @@ const sem::Type* Scalar::Match(MatchState& state, const sem::Type* ty) const { return nullptr; } -std::string Scalar::String(MatchState*) const { +std::string ConcreteScalar::String(MatchState*) const { std::stringstream ss; // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // template arguments, nor can they match sub-types. As such, they have no use for the MatchState. @@ -1659,6 +1659,12 @@ class ScalarNoF32 : public TypeMatcher { }; const sem::Type* ScalarNoF32::Match(MatchState& state, const sem::Type* ty) const { + if (match_ia(state, ty)) { + return build_ia(state); + } + if (match_fa(state, ty)) { + return build_fa(state); + } if (match_i32(state, ty)) { return build_i32(state); } @@ -1678,7 +1684,7 @@ std::string ScalarNoF32::String(MatchState*) const { std::stringstream ss; // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // template arguments, nor can they match sub-types. As such, they have no use for the MatchState. - ss << I32().String(nullptr) << ", " << F16().String(nullptr) << ", " << U32().String(nullptr) << " or " << Bool().String(nullptr); + ss << Ia().String(nullptr) << ", " << Fa().String(nullptr) << ", " << I32().String(nullptr) << ", " << F16().String(nullptr) << ", " << U32().String(nullptr) << " or " << Bool().String(nullptr); return ss.str(); } @@ -1699,6 +1705,12 @@ class ScalarNoF16 : public TypeMatcher { }; const sem::Type* ScalarNoF16::Match(MatchState& state, const sem::Type* ty) const { + if (match_ia(state, ty)) { + return build_ia(state); + } + if (match_fa(state, ty)) { + return build_fa(state); + } if (match_i32(state, ty)) { return build_i32(state); } @@ -1718,7 +1730,7 @@ std::string ScalarNoF16::String(MatchState*) const { std::stringstream ss; // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // template arguments, nor can they match sub-types. As such, they have no use for the MatchState. - ss << F32().String(nullptr) << ", " << I32().String(nullptr) << ", " << U32().String(nullptr) << " or " << Bool().String(nullptr); + ss << Ia().String(nullptr) << ", " << Fa().String(nullptr) << ", " << F32().String(nullptr) << ", " << I32().String(nullptr) << ", " << U32().String(nullptr) << " or " << Bool().String(nullptr); return ss.str(); } @@ -1739,6 +1751,12 @@ class ScalarNoI32 : public TypeMatcher { }; const sem::Type* ScalarNoI32::Match(MatchState& state, const sem::Type* ty) const { + if (match_ia(state, ty)) { + return build_ia(state); + } + if (match_fa(state, ty)) { + return build_fa(state); + } if (match_u32(state, ty)) { return build_u32(state); } @@ -1758,7 +1776,7 @@ std::string ScalarNoI32::String(MatchState*) const { std::stringstream ss; // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // template arguments, nor can they match sub-types. As such, they have no use for the MatchState. - ss << F32().String(nullptr) << ", " << F16().String(nullptr) << ", " << U32().String(nullptr) << " or " << Bool().String(nullptr); + ss << Ia().String(nullptr) << ", " << Fa().String(nullptr) << ", " << F32().String(nullptr) << ", " << F16().String(nullptr) << ", " << U32().String(nullptr) << " or " << Bool().String(nullptr); return ss.str(); } @@ -1779,6 +1797,12 @@ class ScalarNoU32 : public TypeMatcher { }; const sem::Type* ScalarNoU32::Match(MatchState& state, const sem::Type* ty) const { + if (match_ia(state, ty)) { + return build_ia(state); + } + if (match_fa(state, ty)) { + return build_fa(state); + } if (match_i32(state, ty)) { return build_i32(state); } @@ -1798,7 +1822,7 @@ std::string ScalarNoU32::String(MatchState*) const { std::stringstream ss; // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // template arguments, nor can they match sub-types. As such, they have no use for the MatchState. - ss << F32().String(nullptr) << ", " << F16().String(nullptr) << ", " << I32().String(nullptr) << " or " << Bool().String(nullptr); + ss << Ia().String(nullptr) << ", " << Fa().String(nullptr) << ", " << F32().String(nullptr) << ", " << F16().String(nullptr) << ", " << I32().String(nullptr) << " or " << Bool().String(nullptr); return ss.str(); } @@ -1819,6 +1843,12 @@ class ScalarNoBool : public TypeMatcher { }; const sem::Type* ScalarNoBool::Match(MatchState& state, const sem::Type* ty) const { + if (match_ia(state, ty)) { + return build_ia(state); + } + if (match_fa(state, ty)) { + return build_fa(state); + } if (match_i32(state, ty)) { return build_i32(state); } @@ -1838,7 +1868,7 @@ std::string ScalarNoBool::String(MatchState*) const { std::stringstream ss; // Note: We pass nullptr to the TypeMatcher::String() functions, as 'matcher's do not support // template arguments, nor can they match sub-types. As such, they have no use for the MatchState. - ss << F32().String(nullptr) << ", " << F16().String(nullptr) << ", " << I32().String(nullptr) << " or " << U32().String(nullptr); + ss << Ia().String(nullptr) << ", " << Fa().String(nullptr) << ", " << F32().String(nullptr) << ", " << F16().String(nullptr) << ", " << I32().String(nullptr) << " or " << U32().String(nullptr); return ss.str(); } @@ -2580,8 +2610,8 @@ class Matchers { FrexpResult FrexpResult_; FrexpResultVec FrexpResultVec_; AtomicCompareExchangeResult AtomicCompareExchangeResult_; - AbstractOrScalar AbstractOrScalar_; Scalar Scalar_; + ConcreteScalar ConcreteScalar_; ScalarNoF32 ScalarNoF32_; ScalarNoF16 ScalarNoF16_; ScalarNoI32 ScalarNoI32_; @@ -2666,8 +2696,8 @@ class Matchers { /* [47] */ &FrexpResult_, /* [48] */ &FrexpResultVec_, /* [49] */ &AtomicCompareExchangeResult_, - /* [50] */ &AbstractOrScalar_, - /* [51] */ &Scalar_, + /* [50] */ &Scalar_, + /* [51] */ &ConcreteScalar_, /* [52] */ &ScalarNoF32_, /* [53] */ &ScalarNoF16_, /* [54] */ &ScalarNoI32_, @@ -14348,9 +14378,9 @@ constexpr IntrinsicInfo kBuiltins[] = { }, { /* [67] */ - /* fn select(T, T, bool) -> T */ - /* fn select(vec, vec, bool) -> vec */ - /* fn select(vec, vec, vec) -> vec */ + /* fn select(T, T, bool) -> T */ + /* fn select(vec, vec, bool) -> vec */ + /* fn select(vec, vec, vec) -> vec */ /* num overloads */ 3, /* overloads */ &kOverloads[276], }, @@ -14876,15 +14906,15 @@ constexpr IntrinsicInfo kBinaryOperators[] = { }, { /* [10] */ - /* op ==(T, T) -> bool */ - /* op ==(vec, vec) -> vec */ + /* op ==(T, T) -> bool */ + /* op ==(vec, vec) -> vec */ /* num overloads */ 2, /* overloads */ &kOverloads[408], }, { /* [11] */ - /* op !=(T, T) -> bool */ - /* op !=(vec, vec) -> vec */ + /* op !=(T, T) -> bool */ + /* op !=(vec, vec) -> vec */ /* num overloads */ 2, /* overloads */ &kOverloads[394], }, @@ -14995,10 +15025,10 @@ constexpr IntrinsicInfo kConstructorsAndConverters[] = { }, { /* [5] */ - /* ctor vec2() -> vec2 */ - /* ctor vec2(vec2) -> vec2 */ - /* ctor vec2(T) -> vec2 */ - /* ctor vec2(x: T, y: T) -> vec2 */ + /* ctor vec2() -> vec2 */ + /* ctor vec2(vec2) -> vec2 */ + /* ctor vec2(T) -> vec2 */ + /* ctor vec2(x: T, y: T) -> vec2 */ /* conv vec2(vec2) -> vec2 */ /* conv vec2(vec2) -> vec2 */ /* conv vec2(vec2) -> vec2 */ @@ -15009,12 +15039,12 @@ constexpr IntrinsicInfo kConstructorsAndConverters[] = { }, { /* [6] */ - /* ctor vec3() -> vec3 */ - /* ctor vec3(vec3) -> vec3 */ - /* ctor vec3(T) -> vec3 */ - /* ctor vec3(x: T, y: T, z: T) -> vec3 */ - /* ctor vec3(xy: vec2, z: T) -> vec3 */ - /* ctor vec3(x: T, yz: vec2) -> vec3 */ + /* ctor vec3() -> vec3 */ + /* ctor vec3(vec3) -> vec3 */ + /* ctor vec3(T) -> vec3 */ + /* ctor vec3(x: T, y: T, z: T) -> vec3 */ + /* ctor vec3(xy: vec2, z: T) -> vec3 */ + /* ctor vec3(x: T, yz: vec2) -> vec3 */ /* conv vec3(vec3) -> vec3 */ /* conv vec3(vec3) -> vec3 */ /* conv vec3(vec3) -> vec3 */ @@ -15025,16 +15055,16 @@ constexpr IntrinsicInfo kConstructorsAndConverters[] = { }, { /* [7] */ - /* ctor vec4() -> vec4 */ - /* ctor vec4(vec4) -> vec4 */ - /* ctor vec4(T) -> vec4 */ - /* ctor vec4(x: T, y: T, z: T, w: T) -> vec4 */ - /* ctor vec4(xy: vec2, z: T, w: T) -> vec4 */ - /* ctor vec4(x: T, yz: vec2, w: T) -> vec4 */ - /* ctor vec4(x: T, y: T, zw: vec2) -> vec4 */ - /* ctor vec4(xy: vec2, zw: vec2) -> vec4 */ - /* ctor vec4(xyz: vec3, w: T) -> vec4 */ - /* ctor vec4(x: T, zyw: vec3) -> vec4 */ + /* ctor vec4() -> vec4 */ + /* ctor vec4(vec4) -> vec4 */ + /* ctor vec4(T) -> vec4 */ + /* ctor vec4(x: T, y: T, z: T, w: T) -> vec4 */ + /* ctor vec4(xy: vec2, z: T, w: T) -> vec4 */ + /* ctor vec4(x: T, yz: vec2, w: T) -> vec4 */ + /* ctor vec4(x: T, y: T, zw: vec2) -> vec4 */ + /* ctor vec4(xy: vec2, zw: vec2) -> vec4 */ + /* ctor vec4(xyz: vec3, w: T) -> vec4 */ + /* ctor vec4(x: T, zyw: vec3) -> vec4 */ /* conv vec4(vec4) -> vec4 */ /* conv vec4(vec4) -> vec4 */ /* conv vec4(vec4) -> vec4 */ diff --git a/src/tint/resolver/intrinsic_table_test.cc b/src/tint/resolver/intrinsic_table_test.cc index 34cec07f42..b99fc34c99 100644 --- a/src/tint/resolver/intrinsic_table_test.cc +++ b/src/tint/resolver/intrinsic_table_test.cc @@ -793,11 +793,11 @@ TEST_F(IntrinsicTableTest, MismatchTypeConstructorImplicit) { vec3() -> vec3 where: T is f32, f16, i32, u32 or bool 5 candidate conversions: - vec3(vec3) -> vec3 where: T is f32, U is i32, f16, u32 or bool - vec3(vec3) -> vec3 where: T is f16, U is f32, i32, u32 or bool - vec3(vec3) -> vec3 where: T is i32, U is f32, f16, u32 or bool - vec3(vec3) -> vec3 where: T is u32, U is f32, f16, i32 or bool - vec3(vec3) -> vec3 where: T is bool, U is f32, f16, i32 or u32 + vec3(vec3) -> vec3 where: T is f32, U is abstract-int, abstract-float, i32, f16, u32 or bool + vec3(vec3) -> vec3 where: T is f16, U is abstract-int, abstract-float, f32, i32, u32 or bool + vec3(vec3) -> vec3 where: T is i32, U is abstract-int, abstract-float, f32, f16, u32 or bool + vec3(vec3) -> vec3 where: T is u32, U is abstract-int, abstract-float, f32, f16, i32 or bool + vec3(vec3) -> vec3 where: T is bool, U is abstract-int, abstract-float, f32, f16, i32 or u32 )"); } @@ -819,11 +819,11 @@ TEST_F(IntrinsicTableTest, MismatchTypeConstructorExplicit) { vec3() -> vec3 where: T is f32, f16, i32, u32 or bool 5 candidate conversions: - vec3(vec3) -> vec3 where: T is f32, U is i32, f16, u32 or bool - vec3(vec3) -> vec3 where: T is f16, U is f32, i32, u32 or bool - vec3(vec3) -> vec3 where: T is i32, U is f32, f16, u32 or bool - vec3(vec3) -> vec3 where: T is u32, U is f32, f16, i32 or bool - vec3(vec3) -> vec3 where: T is bool, U is f32, f16, i32 or u32 + vec3(vec3) -> vec3 where: T is f32, U is abstract-int, abstract-float, i32, f16, u32 or bool + vec3(vec3) -> vec3 where: T is f16, U is abstract-int, abstract-float, f32, i32, u32 or bool + vec3(vec3) -> vec3 where: T is i32, U is abstract-int, abstract-float, f32, f16, u32 or bool + vec3(vec3) -> vec3 where: T is u32, U is abstract-int, abstract-float, f32, f16, i32 or bool + vec3(vec3) -> vec3 where: T is bool, U is abstract-int, abstract-float, f32, f16, i32 or u32 )"); } @@ -875,11 +875,11 @@ TEST_F(IntrinsicTableTest, MismatchTypeConversion) { vec3(x: T, y: T, z: T) -> vec3 where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool 5 candidate conversions: - vec3(vec3) -> vec3 where: T is f32, U is i32, f16, u32 or bool - vec3(vec3) -> vec3 where: T is f16, U is f32, i32, u32 or bool - vec3(vec3) -> vec3 where: T is i32, U is f32, f16, u32 or bool - vec3(vec3) -> vec3 where: T is u32, U is f32, f16, i32 or bool - vec3(vec3) -> vec3 where: T is bool, U is f32, f16, i32 or u32 + vec3(vec3) -> vec3 where: T is f32, U is abstract-int, abstract-float, i32, f16, u32 or bool + vec3(vec3) -> vec3 where: T is f16, U is abstract-int, abstract-float, f32, i32, u32 or bool + vec3(vec3) -> vec3 where: T is i32, U is abstract-int, abstract-float, f32, f16, u32 or bool + vec3(vec3) -> vec3 where: T is u32, U is abstract-int, abstract-float, f32, f16, i32 or bool + vec3(vec3) -> vec3 where: T is bool, U is abstract-int, abstract-float, f32, f16, i32 or u32 )"); } @@ -904,7 +904,7 @@ TEST_F(IntrinsicTableTest, OverloadResolution) { ASSERT_NE(result.target, nullptr); EXPECT_EQ(result.target->ReturnType(), i32); EXPECT_EQ(result.target->Parameters().Length(), 1u); - EXPECT_EQ(result.target->Parameters()[0]->Type(), i32); + EXPECT_EQ(result.target->Parameters()[0]->Type(), ai); } //////////////////////////////////////////////////////////////////////////////// diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h index 501cd0f583..923f3ff461 100644 --- a/src/tint/resolver/resolver_test_helper.h +++ b/src/tint/resolver/resolver_test_helper.h @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -174,13 +175,15 @@ template struct ptr {}; /// Type used to accept scalars as arguments. Can be either a single value that gets splatted for -/// composite types, or all values requried by the composite type. +/// composite types, or all values required by the composite type. struct ScalarArgs { + /// Constructor + ScalarArgs() = default; + /// Constructor /// @param single_value single value to initialize with template - ScalarArgs(T single_value) // NOLINT: implicit on purpose - : values(utils::Vector{single_value}) {} + explicit ScalarArgs(T single_value) : values(utils::Vector{single_value}) {} /// Constructor /// @param all_values all values to initialize the composite type with @@ -192,6 +195,10 @@ struct ScalarArgs { } } + /// @param other the other ScalarArgs to compare against + /// @returns true if all values are equal to the values in @p other + bool operator==(const ScalarArgs& other) const { return values == other.values; } + /// Valid scalar types for args using Storage = std::variant; @@ -199,10 +206,28 @@ struct ScalarArgs { utils::Vector values; }; +/// @param o the std::ostream to write to +/// @param args the ScalarArgs +/// @return the std::ostream so calls can be chained +inline std::ostream& operator<<(std::ostream& o, const ScalarArgs& args) { + o << "["; + bool first = true; + for (auto& val : args.values) { + if (!first) { + o << ", "; + } + first = false; + std::visit([&](auto&& v) { o << v; }, val); + } + o << "]"; + return o; +} + using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b); using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, ScalarArgs args); using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double v); using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b); +using type_name_func_ptr = std::string (*)(); template struct DataType {}; @@ -241,7 +266,7 @@ struct DataType { /// @param v arg of type double that will be cast to bool. /// @return a new AST expression of the bool type static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, static_cast(v)); + return Expr(b, ScalarArgs{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "bool"; } @@ -272,7 +297,7 @@ struct DataType { /// @param v arg of type double that will be cast to i32. /// @return a new AST i32 literal value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, static_cast(v)); + return Expr(b, ScalarArgs{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "i32"; } @@ -303,7 +328,7 @@ struct DataType { /// @param v arg of type double that will be cast to u32. /// @return a new AST u32 literal value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, static_cast(v)); + return Expr(b, ScalarArgs{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "u32"; } @@ -334,7 +359,7 @@ struct DataType { /// @param v arg of type double that will be cast to f32. /// @return a new AST f32 literal value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, static_cast(v)); + return Expr(b, ScalarArgs{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "f32"; } @@ -365,7 +390,7 @@ struct DataType { /// @param v arg of type double that will be cast to f16. /// @return a new AST f16 literal value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, static_cast(v)); + return Expr(b, ScalarArgs{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "f16"; } @@ -395,7 +420,7 @@ struct DataType { /// @param v arg of type double that will be cast to AFloat. /// @return a new AST abstract-float literal value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, static_cast(v)); + return Expr(b, ScalarArgs{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "abstract-float"; } @@ -425,7 +450,7 @@ struct DataType { /// @param v arg of type double that will be cast to AInt. /// @return a new AST abstract-int literal value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, static_cast(v)); + return Expr(b, ScalarArgs{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "abstract-int"; } @@ -463,7 +488,7 @@ struct DataType> { const bool one_value = args.values.Length() == 1; utils::Vector r; for (size_t i = 0; i < N; ++i) { - r.Push(DataType::Expr(b, one_value ? args.values[0] : args.values[i])); + r.Push(DataType::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]})); } return r; } @@ -471,7 +496,7 @@ struct DataType> { /// @param v arg of type double that will be cast to ElementType /// @return a new AST vector value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, static_cast(v)); + return Expr(b, ScalarArgs{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { @@ -514,7 +539,7 @@ struct DataType> { utils::Vector r; for (uint32_t i = 0; i < N; ++i) { if (one_value) { - r.Push(DataType>::Expr(b, args.values[0])); + r.Push(DataType>::Expr(b, ScalarArgs{args.values[0]})); } else { utils::Vector v; for (size_t j = 0; j < M; ++j) { @@ -529,7 +554,7 @@ struct DataType> { /// @param v arg of type double that will be cast to ElementType /// @return a new AST matrix value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, static_cast(v)); + return Expr(b, ScalarArgs{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { @@ -585,7 +610,7 @@ struct DataType> { /// @param v arg of type double that will be cast to ElementType /// @return a new AST expression of the alias type static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, static_cast(v)); + return Expr(b, ScalarArgs{static_cast(v)}); } /// @returns the WGSL name for the type @@ -626,7 +651,7 @@ struct DataType> { /// @param v arg of type double that will be cast to ElementType /// @return a new AST expression of the pointer type static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, static_cast(v)); + return Expr(b, ScalarArgs{static_cast(v)}); } /// @returns the WGSL name for the type @@ -680,7 +705,7 @@ struct DataType> { const bool one_value = args.values.Length() == 1; utils::Vector r; for (uint32_t i = 0; i < N; i++) { - r.Push(DataType::Expr(b, one_value ? args.values[0] : args.values[i])); + r.Push(DataType::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]})); } return r; } @@ -688,7 +713,7 @@ struct DataType> { /// @param v arg of type double that will be cast to ElementType /// @return a new AST array value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, static_cast(v)); + return Expr(b, ScalarArgs{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { @@ -706,13 +731,23 @@ struct CreatePtrs { ast_expr_from_double_func_ptr expr_from_double; /// sem type create function sem_type_func_ptr sem; + /// type name function + type_name_func_ptr name; }; +/// @param o the std::ostream to write to +/// @param ptrs the CreatePtrs +/// @return the std::ostream so calls can be chained +inline std::ostream& operator<<(std::ostream& o, const CreatePtrs& ptrs) { + return o << (ptrs.name ? ptrs.name() : ""); +} + /// Returns a CreatePtrs struct instance with all creation pointer types for /// type `T` template constexpr CreatePtrs CreatePtrsFor() { - return {DataType::AST, DataType::Expr, DataType::ExprFromDouble, DataType::Sem}; + return {DataType::AST, DataType::Expr, DataType::ExprFromDouble, DataType::Sem, + DataType::Name}; } /// Value is an instance of a value of type DataType. Useful for storing values to create @@ -729,15 +764,15 @@ struct Value { /// Creates a Value with `args` /// @param args the args that will be passed to the expression /// @returns a Value - static Value Create(ScalarArgs args) { return Value{DataType::Expr, std::move(args)}; } + static Value Create(ScalarArgs args) { return Value{CreatePtrsFor(), std::move(args)}; } /// Creates an `ast::Expression` for the type T passing in previously stored args /// @param b the ProgramBuilder /// @returns an expression node - const ast::Expression* Expr(ProgramBuilder& b) const { return (*expr)(b, args); } + const ast::Expression* Expr(ProgramBuilder& b) const { return (*create.expr)(b, args); } - /// ast expression type create function - ast_expr_func_ptr expr; + /// functions to create values / types of the value + CreatePtrs create; /// args to create expression with ScalarArgs args; }; @@ -764,7 +799,7 @@ const char* FriendlyName() { /// Creates a `Value` from a scalar `v` template auto Val(T v) { - return Value::Create(v); + return Value::Create(ScalarArgs{v}); } /// Creates a `Value>` from N scalar `args` diff --git a/src/tint/writer/msl/generator_impl_cast_test.cc b/src/tint/writer/msl/generator_impl_cast_test.cc index 4b9e3f28d9..74b3be6a8c 100644 --- a/src/tint/writer/msl/generator_impl_cast_test.cc +++ b/src/tint/writer/msl/generator_impl_cast_test.cc @@ -51,7 +51,7 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Cast_IntMin) { std::stringstream out; ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error(); - EXPECT_EQ(out.str(), "0u"); + EXPECT_EQ(out.str(), "2147483648u"); } } // namespace