diff --git a/src/tint/constant/clone_context.h b/src/tint/constant/clone_context.h index 3f597ed6b9..5709baf3c4 100644 --- a/src/tint/constant/clone_context.h +++ b/src/tint/constant/clone_context.h @@ -16,11 +16,10 @@ #define SRC_TINT_CONSTANT_CLONE_CONTEXT_H_ #include "src/tint/type/clone_context.h" -#include "src/tint/utils/block_allocator.h" -// Forward Declarations +// Forward declarations namespace tint::constant { -class Value; +class Manager; } // namespace tint::constant namespace tint::constant { @@ -31,10 +30,7 @@ struct CloneContext { type::CloneContext type_ctx; /// Destination information - struct { - /// The constant allocator - utils::BlockAllocator* constants; - } dst; + constant::Manager& dst; }; } // namespace tint::constant diff --git a/src/tint/constant/composite.cc b/src/tint/constant/composite.cc index fee4baf186..488f6f0341 100644 --- a/src/tint/constant/composite.cc +++ b/src/tint/constant/composite.cc @@ -16,6 +16,8 @@ #include +#include "src/tint/constant/manager.h" + TINT_INSTANTIATE_TYPEINFO(tint::constant::Composite); namespace tint::constant { @@ -24,7 +26,9 @@ Composite::Composite(const type::Type* t, utils::VectorRef els, bool all_0, bool any_0) - : type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {} + : type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) { + TINT_ASSERT(Constant, !elements.IsEmpty()); +} Composite::~Composite() = default; @@ -34,7 +38,7 @@ const Composite* Composite::Clone(CloneContext& ctx) const { for (const auto* el : elements) { els.Push(el->Clone(ctx)); } - return ctx.dst.constants->Create(ty, els, all_zero, any_zero); + return ctx.dst.Get(ty, std::move(els), all_zero, any_zero); } } // namespace tint::constant diff --git a/src/tint/constant/composite_test.cc b/src/tint/constant/composite_test.cc index fd083db7a0..6a4c12d422 100644 --- a/src/tint/constant/composite_test.cc +++ b/src/tint/constant/composite_test.cc @@ -25,15 +25,15 @@ using namespace tint::number_suffixes; // NOLINT using ConstantTest_Composite = TestHelper; TEST_F(ConstantTest_Composite, AllZero) { - auto* f32 = create(); + auto* vec3f = create(create(), 3u); - auto* fPos0 = create>(f32, 0_f); - auto* fNeg0 = create>(f32, -0_f); - auto* fPos1 = create>(f32, 1_f); + auto* fPos0 = constants.Get(0_f); + auto* fNeg0 = constants.Get(-0_f); + auto* fPos1 = constants.Get(1_f); - auto* compositeAll = create(f32, utils::Vector{fPos0, fPos0}); - auto* compositeAny = create(f32, utils::Vector{fNeg0, fPos1, fPos0}); - auto* compositeNone = create(f32, utils::Vector{fNeg0, fNeg0}); + auto* compositeAll = constants.Composite(vec3f, utils::Vector{fPos0, fPos0}); + auto* compositeAny = constants.Composite(vec3f, utils::Vector{fNeg0, fPos1, fPos0}); + auto* compositeNone = constants.Composite(vec3f, utils::Vector{fNeg0, fNeg0}); EXPECT_TRUE(compositeAll->AllZero()); EXPECT_FALSE(compositeAny->AllZero()); @@ -41,15 +41,15 @@ TEST_F(ConstantTest_Composite, AllZero) { } TEST_F(ConstantTest_Composite, AnyZero) { - auto* f32 = create(); + auto* vec3f = create(create(), 3u); - auto* fPos0 = create>(f32, 0_f); - auto* fNeg0 = create>(f32, -0_f); - auto* fPos1 = create>(f32, 1_f); + auto* fPos0 = constants.Get(0_f); + auto* fNeg0 = constants.Get(-0_f); + auto* fPos1 = constants.Get(1_f); - auto* compositeAll = create(f32, utils::Vector{fPos0, fPos0}); - auto* compositeAny = create(f32, utils::Vector{fNeg0, fPos1, fPos0}); - auto* compositeNone = create(f32, utils::Vector{fNeg0, fNeg0}); + auto* compositeAll = constants.Composite(vec3f, utils::Vector{fPos0, fPos0}); + auto* compositeAny = constants.Composite(vec3f, utils::Vector{fNeg0, fPos1, fPos0}); + auto* compositeNone = constants.Composite(vec3f, utils::Vector{fNeg0, fNeg0}); EXPECT_TRUE(compositeAll->AnyZero()); EXPECT_TRUE(compositeAny->AnyZero()); @@ -57,12 +57,12 @@ TEST_F(ConstantTest_Composite, AnyZero) { } TEST_F(ConstantTest_Composite, Index) { - auto* f32 = create(); + auto* vec3f = create(create(), 3u); - auto* fPos0 = create>(f32, 0_f); - auto* fPos1 = create>(f32, 1_f); + auto* fPos0 = constants.Get(0_f); + auto* fPos1 = constants.Get(1_f); - auto* composite = create(f32, utils::Vector{fPos1, fPos0}); + auto* composite = constants.Composite(vec3f, utils::Vector{fPos1, fPos0}); ASSERT_NE(composite->Index(0), nullptr); ASSERT_NE(composite->Index(1), nullptr); @@ -75,20 +75,19 @@ TEST_F(ConstantTest_Composite, Index) { } TEST_F(ConstantTest_Composite, Clone) { - auto* f32 = create(); + auto* vec3f = create(create(), 3u); - auto* fPos0 = create>(f32, 0_f); - auto* fPos1 = create>(f32, 1_f); + auto* fPos0 = constants.Get(0_f); + auto* fPos1 = constants.Get(1_f); - auto* composite = create(f32, utils::Vector{fPos1, fPos0}); + auto* composite = constants.Composite(vec3f, utils::Vector{fPos1, fPos0}); - type::Manager mgr; - utils::BlockAllocator consts; - constant::CloneContext ctx{type::CloneContext{{nullptr}, {nullptr, &mgr}}, {&consts}}; + constant::Manager mgr; + constant::CloneContext ctx{type::CloneContext{{nullptr}, {nullptr, &mgr.types}}, mgr}; auto* r = composite->As()->Clone(ctx); ASSERT_NE(r, nullptr); - EXPECT_TRUE(r->type->Is()); + EXPECT_TRUE(r->type->Is()); EXPECT_FALSE(r->all_zero); EXPECT_TRUE(r->any_zero); ASSERT_EQ(r->elements.Length(), 2u); diff --git a/src/tint/constant/scalar.h b/src/tint/constant/scalar.h index 5a40cdf457..474c9ad4f8 100644 --- a/src/tint/constant/scalar.h +++ b/src/tint/constant/scalar.h @@ -15,6 +15,7 @@ #ifndef SRC_TINT_CONSTANT_SCALAR_H_ #define SRC_TINT_CONSTANT_SCALAR_H_ +#include "src/tint/constant/manager.h" #include "src/tint/constant/value.h" #include "src/tint/number.h" #include "src/tint/type/type.h" @@ -63,7 +64,7 @@ class Scalar : public utils::Castable, Value> { /// @returns the cloned node const Scalar* Clone(CloneContext& ctx) const override { auto* ty = type->Clone(ctx.type_ctx); - return ctx.dst.constants->Create>(ty, value); + return ctx.dst.Get>(ty, value); } /// @returns `value` if `T` is not a Number, otherwise ValueOf returns the inner value of the diff --git a/src/tint/constant/scalar_test.cc b/src/tint/constant/scalar_test.cc index 6dedf75b5f..49881e5c84 100644 --- a/src/tint/constant/scalar_test.cc +++ b/src/tint/constant/scalar_test.cc @@ -24,40 +24,34 @@ using namespace tint::number_suffixes; // NOLINT using ConstantTest_Scalar = TestHelper; TEST_F(ConstantTest_Scalar, AllZero) { - auto* i32 = create(); - auto* u32 = create(); - auto* f16 = create(); - auto* f32 = create(); - auto* bool_ = create(); + auto* i0 = constants.Get(0_i); + auto* iPos1 = constants.Get(1_i); + auto* iNeg1 = constants.Get(-1_i); - auto* i0 = create>(i32, 0_i); - auto* iPos1 = create>(i32, 1_i); - auto* iNeg1 = create>(i32, -1_i); + auto* u0 = constants.Get(0_u); + auto* u1 = constants.Get(1_u); - auto* u0 = create>(u32, 0_u); - auto* u1 = create>(u32, 1_u); + auto* fPos0 = constants.Get(0_f); + auto* fNeg0 = constants.Get(-0_f); + auto* fPos1 = constants.Get(1_f); + auto* fNeg1 = constants.Get(-1_f); - auto* fPos0 = create>(f32, 0_f); - auto* fNeg0 = create>(f32, -0_f); - auto* fPos1 = create>(f32, 1_f); - auto* fNeg1 = create>(f32, -1_f); + auto* f16Pos0 = constants.Get(0_h); + auto* f16Neg0 = constants.Get(-0_h); + auto* f16Pos1 = constants.Get(1_h); + auto* f16Neg1 = constants.Get(-1_h); - auto* f16Pos0 = create>(f16, 0_h); - auto* f16Neg0 = create>(f16, -0_h); - auto* f16Pos1 = create>(f16, 1_h); - auto* f16Neg1 = create>(f16, -1_h); + auto* bf = constants.Get(false); + auto* bt = constants.Get(true); - auto* bf = create>(bool_, false); - auto* bt = create>(bool_, true); + auto* afPos0 = constants.Get(0.0_a); + auto* afNeg0 = constants.Get(-0.0_a); + auto* afPos1 = constants.Get(1.0_a); + auto* afNeg1 = constants.Get(-1.0_a); - auto* afPos0 = create>(f32, 0.0_a); - auto* afNeg0 = create>(f32, -0.0_a); - auto* afPos1 = create>(f32, 1.0_a); - auto* afNeg1 = create>(f32, -1.0_a); - - auto* ai0 = create>(i32, 0_a); - auto* aiPos1 = create>(i32, 1_a); - auto* aiNeg1 = create>(i32, -1_a); + auto* ai0 = constants.Get(0_a); + auto* aiPos1 = constants.Get(1_a); + auto* aiNeg1 = constants.Get(-1_a); EXPECT_TRUE(i0->AllZero()); EXPECT_FALSE(iPos1->AllZero()); @@ -90,40 +84,34 @@ TEST_F(ConstantTest_Scalar, AllZero) { } TEST_F(ConstantTest_Scalar, AnyZero) { - auto* i32 = create(); - auto* u32 = create(); - auto* f16 = create(); - auto* f32 = create(); - auto* bool_ = create(); + auto* i0 = constants.Get(0_i); + auto* iPos1 = constants.Get(1_i); + auto* iNeg1 = constants.Get(-1_i); - auto* i0 = create>(i32, 0_i); - auto* iPos1 = create>(i32, 1_i); - auto* iNeg1 = create>(i32, -1_i); + auto* u0 = constants.Get(0_u); + auto* u1 = constants.Get(1_u); - auto* u0 = create>(u32, 0_u); - auto* u1 = create>(u32, 1_u); + auto* fPos0 = constants.Get(0_f); + auto* fNeg0 = constants.Get(-0_f); + auto* fPos1 = constants.Get(1_f); + auto* fNeg1 = constants.Get(-1_f); - auto* fPos0 = create>(f32, 0_f); - auto* fNeg0 = create>(f32, -0_f); - auto* fPos1 = create>(f32, 1_f); - auto* fNeg1 = create>(f32, -1_f); + auto* f16Pos0 = constants.Get(0_h); + auto* f16Neg0 = constants.Get(-0_h); + auto* f16Pos1 = constants.Get(1_h); + auto* f16Neg1 = constants.Get(-1_h); - auto* f16Pos0 = create>(f16, 0_h); - auto* f16Neg0 = create>(f16, -0_h); - auto* f16Pos1 = create>(f16, 1_h); - auto* f16Neg1 = create>(f16, -1_h); + auto* bf = constants.Get(false); + auto* bt = constants.Get(true); - auto* bf = create>(bool_, false); - auto* bt = create>(bool_, true); + auto* afPos0 = constants.Get(0.0_a); + auto* afNeg0 = constants.Get(-0.0_a); + auto* afPos1 = constants.Get(1.0_a); + auto* afNeg1 = constants.Get(-1.0_a); - auto* afPos0 = create>(f32, 0.0_a); - auto* afNeg0 = create>(f32, -0.0_a); - auto* afPos1 = create>(f32, 1.0_a); - auto* afNeg1 = create>(f32, -1.0_a); - - auto* ai0 = create>(i32, 0_a); - auto* aiPos1 = create>(i32, 1_a); - auto* aiNeg1 = create>(i32, -1_a); + auto* ai0 = constants.Get(0_a); + auto* aiPos1 = constants.Get(1_a); + auto* aiNeg1 = constants.Get(-1_a); EXPECT_TRUE(i0->AnyZero()); EXPECT_FALSE(iPos1->AnyZero()); @@ -156,20 +144,14 @@ TEST_F(ConstantTest_Scalar, AnyZero) { } TEST_F(ConstantTest_Scalar, ValueOf) { - auto* i32 = create(); - auto* u32 = create(); - auto* f16 = create(); - auto* f32 = create(); - auto* bool_ = create(); - - auto* i1 = create>(i32, 1_i); - auto* u1 = create>(u32, 1_u); - auto* f1 = create>(f32, 1_f); - auto* f16Pos1 = create>(f16, 1_h); - auto* bf = create>(bool_, false); - auto* bt = create>(bool_, true); - auto* af1 = create>(f32, 1.0_a); - auto* ai1 = create>(i32, 1_a); + auto* i1 = constants.Get(1_i); + auto* u1 = constants.Get(1_u); + auto* f1 = constants.Get(1_f); + auto* f16Pos1 = constants.Get(1_h); + auto* bf = constants.Get(false); + auto* bt = constants.Get(true); + auto* af1 = constants.Get(1.0_a); + auto* ai1 = constants.Get(1_a); EXPECT_EQ(i1->ValueOf(), 1); EXPECT_EQ(u1->ValueOf(), 1u); @@ -182,12 +164,10 @@ TEST_F(ConstantTest_Scalar, ValueOf) { } TEST_F(ConstantTest_Scalar, Clone) { - auto* i32 = create(); - auto* val = create>(i32, 12_i); + auto* val = constants.Get(12_i); - type::Manager mgr; - utils::BlockAllocator consts; - constant::CloneContext ctx{type::CloneContext{{nullptr}, {nullptr, &mgr}}, {&consts}}; + constant::Manager mgr; + constant::CloneContext ctx{type::CloneContext{{nullptr}, {nullptr, &mgr.types}}, mgr}; auto* r = val->Clone(ctx); ASSERT_NE(r, nullptr); diff --git a/src/tint/constant/splat.cc b/src/tint/constant/splat.cc index e1c9a90fea..8adbed7652 100644 --- a/src/tint/constant/splat.cc +++ b/src/tint/constant/splat.cc @@ -14,6 +14,8 @@ #include "src/tint/constant/splat.h" +#include "src/tint/constant/manager.h" + TINT_INSTANTIATE_TYPEINFO(tint::constant::Splat); namespace tint::constant { @@ -25,7 +27,7 @@ Splat::~Splat() = default; const Splat* Splat::Clone(CloneContext& ctx) const { auto* ty = type->Clone(ctx.type_ctx); auto* element = el->Clone(ctx); - return ctx.dst.constants->Create(ty, element, count); + return ctx.dst.Splat(ty, element, count); } } // namespace tint::constant diff --git a/src/tint/constant/splat_test.cc b/src/tint/constant/splat_test.cc index fe8aeda090..0e2f21ed5d 100644 --- a/src/tint/constant/splat_test.cc +++ b/src/tint/constant/splat_test.cc @@ -25,15 +25,15 @@ using namespace tint::number_suffixes; // NOLINT using ConstantTest_Splat = TestHelper; TEST_F(ConstantTest_Splat, AllZero) { - auto* f32 = create(); + auto* vec3f = create(create(), 3u); - auto* fPos0 = create>(f32, 0_f); - auto* fNeg0 = create>(f32, -0_f); - auto* fPos1 = create>(f32, 1_f); + auto* fPos0 = constants.Get(0_f); + auto* fNeg0 = constants.Get(-0_f); + auto* fPos1 = constants.Get(1_f); - auto* SpfPos0 = create(f32, fPos0, 2); - auto* SpfNeg0 = create(f32, fNeg0, 2); - auto* SpfPos1 = create(f32, fPos1, 2); + auto* SpfPos0 = constants.Splat(vec3f, fPos0, 2); + auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 2); + auto* SpfPos1 = constants.Splat(vec3f, fPos1, 2); EXPECT_TRUE(SpfPos0->AllZero()); EXPECT_FALSE(SpfNeg0->AllZero()); @@ -41,15 +41,15 @@ TEST_F(ConstantTest_Splat, AllZero) { } TEST_F(ConstantTest_Splat, AnyZero) { - auto* f32 = create(); + auto* vec3f = create(create(), 3u); - auto* fPos0 = create>(f32, 0_f); - auto* fNeg0 = create>(f32, -0_f); - auto* fPos1 = create>(f32, 1_f); + auto* fPos0 = constants.Get(0_f); + auto* fNeg0 = constants.Get(-0_f); + auto* fPos1 = constants.Get(1_f); - auto* SpfPos0 = create(f32, fPos0, 2); - auto* SpfNeg0 = create(f32, fNeg0, 2); - auto* SpfPos1 = create(f32, fPos1, 2); + auto* SpfPos0 = constants.Splat(vec3f, fPos0, 2); + auto* SpfNeg0 = constants.Splat(vec3f, fNeg0, 2); + auto* SpfPos1 = constants.Splat(vec3f, fPos1, 2); EXPECT_TRUE(SpfPos0->AnyZero()); EXPECT_FALSE(SpfNeg0->AnyZero()); @@ -57,10 +57,10 @@ TEST_F(ConstantTest_Splat, AnyZero) { } TEST_F(ConstantTest_Splat, Index) { - auto* f32 = create(); + auto* vec3f = create(create(), 3u); - auto* f1 = create>(f32, 1_f); - auto* sp = create(f32, f1, 2); + auto* f1 = constants.Get(1_f); + auto* sp = constants.Splat(vec3f, f1, 2); ASSERT_NE(sp->Index(0), nullptr); ASSERT_NE(sp->Index(1), nullptr); @@ -71,17 +71,16 @@ TEST_F(ConstantTest_Splat, Index) { } TEST_F(ConstantTest_Splat, Clone) { - auto* i32 = create(); - auto* val = create>(i32, 12_i); - auto* sp = create(i32, val, 2); + auto* vec3i = create(create(), 3u); + auto* val = constants.Get(12_i); + auto* sp = constants.Splat(vec3i, val, 2); - type::Manager mgr; - utils::BlockAllocator consts; - constant::CloneContext ctx{type::CloneContext{{nullptr}, {nullptr, &mgr}}, {&consts}}; + constant::Manager mgr; + constant::CloneContext ctx{type::CloneContext{{nullptr}, {nullptr, &mgr.types}}, mgr}; auto* r = sp->Clone(ctx); ASSERT_NE(r, nullptr); - EXPECT_TRUE(r->type->Is()); + EXPECT_TRUE(r->type->Is()); EXPECT_TRUE(r->el->Is>()); EXPECT_EQ(r->count, 2u); } diff --git a/src/tint/ir/binary_test.cc b/src/tint/ir/binary_test.cc index 2435936f15..281d04b1d3 100644 --- a/src/tint/ir/binary_test.cc +++ b/src/tint/ir/binary_test.cc @@ -27,7 +27,7 @@ TEST_F(IR_InstructionTest, CreateAnd) { Module mod; Builder b{mod}; - const auto* inst = b.And(b.ir.types.i32(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd); @@ -48,7 +48,7 @@ TEST_F(IR_InstructionTest, CreateOr) { Module mod; Builder b{mod}; - const auto* inst = b.Or(b.ir.types.i32(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.Or(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kOr); @@ -68,7 +68,7 @@ TEST_F(IR_InstructionTest, CreateXor) { Module mod; Builder b{mod}; - const auto* inst = b.Xor(b.ir.types.i32(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.Xor(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kXor); @@ -88,7 +88,7 @@ TEST_F(IR_InstructionTest, CreateEqual) { Module mod; Builder b{mod}; - const auto* inst = b.Equal(b.ir.types.bool_(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.Equal(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kEqual); @@ -108,7 +108,7 @@ TEST_F(IR_InstructionTest, CreateNotEqual) { Module mod; Builder b{mod}; - const auto* inst = b.NotEqual(b.ir.types.bool_(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.NotEqual(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kNotEqual); @@ -128,7 +128,7 @@ TEST_F(IR_InstructionTest, CreateLessThan) { Module mod; Builder b{mod}; - const auto* inst = b.LessThan(b.ir.types.bool_(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.LessThan(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kLessThan); @@ -148,7 +148,7 @@ TEST_F(IR_InstructionTest, CreateGreaterThan) { Module mod; Builder b{mod}; - const auto* inst = b.GreaterThan(b.ir.types.bool_(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.GreaterThan(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kGreaterThan); @@ -168,7 +168,7 @@ TEST_F(IR_InstructionTest, CreateLessThanEqual) { Module mod; Builder b{mod}; - const auto* inst = b.LessThanEqual(b.ir.types.bool_(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.LessThanEqual(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kLessThanEqual); @@ -188,7 +188,7 @@ TEST_F(IR_InstructionTest, CreateGreaterThanEqual) { Module mod; Builder b{mod}; - const auto* inst = b.GreaterThanEqual(b.ir.types.bool_(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.GreaterThanEqual(mod.Types().bool_(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kGreaterThanEqual); @@ -207,7 +207,7 @@ TEST_F(IR_InstructionTest, CreateGreaterThanEqual) { TEST_F(IR_InstructionTest, CreateNot) { Module mod; Builder b{mod}; - const auto* inst = b.Not(b.ir.types.bool_(), b.Constant(true)); + const auto* inst = b.Not(mod.Types().bool_(), b.Constant(true)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kEqual); @@ -227,7 +227,7 @@ TEST_F(IR_InstructionTest, CreateShiftLeft) { Module mod; Builder b{mod}; - const auto* inst = b.ShiftLeft(b.ir.types.i32(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.ShiftLeft(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kShiftLeft); @@ -247,7 +247,7 @@ TEST_F(IR_InstructionTest, CreateShiftRight) { Module mod; Builder b{mod}; - const auto* inst = b.ShiftRight(b.ir.types.i32(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.ShiftRight(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kShiftRight); @@ -267,7 +267,7 @@ TEST_F(IR_InstructionTest, CreateAdd) { Module mod; Builder b{mod}; - const auto* inst = b.Add(b.ir.types.i32(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.Add(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kAdd); @@ -287,7 +287,7 @@ TEST_F(IR_InstructionTest, CreateSubtract) { Module mod; Builder b{mod}; - const auto* inst = b.Subtract(b.ir.types.i32(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.Subtract(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kSubtract); @@ -307,7 +307,7 @@ TEST_F(IR_InstructionTest, CreateMultiply) { Module mod; Builder b{mod}; - const auto* inst = b.Multiply(b.ir.types.i32(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.Multiply(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kMultiply); @@ -327,7 +327,7 @@ TEST_F(IR_InstructionTest, CreateDivide) { Module mod; Builder b{mod}; - const auto* inst = b.Divide(b.ir.types.i32(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.Divide(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kDivide); @@ -347,7 +347,7 @@ TEST_F(IR_InstructionTest, CreateModulo) { Module mod; Builder b{mod}; - const auto* inst = b.Modulo(b.ir.types.i32(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.Modulo(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Binary::Kind::kModulo); @@ -366,7 +366,7 @@ TEST_F(IR_InstructionTest, CreateModulo) { TEST_F(IR_InstructionTest, Binary_Usage) { Module mod; Builder b{mod}; - const auto* inst = b.And(b.ir.types.i32(), b.Constant(4_i), b.Constant(2_i)); + const auto* inst = b.And(mod.Types().i32(), b.Constant(4_i), b.Constant(2_i)); EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd); @@ -383,7 +383,7 @@ TEST_F(IR_InstructionTest, Binary_Usage_DuplicateValue) { Module mod; Builder b{mod}; auto val = b.Constant(4_i); - const auto* inst = b.And(b.ir.types.i32(), val, val); + const auto* inst = b.And(mod.Types().i32(), val, val); EXPECT_EQ(inst->Kind(), Binary::Kind::kAnd); ASSERT_EQ(inst->LHS(), inst->RHS()); diff --git a/src/tint/ir/bitcast_test.cc b/src/tint/ir/bitcast_test.cc index 157769c039..bf66e39bb2 100644 --- a/src/tint/ir/bitcast_test.cc +++ b/src/tint/ir/bitcast_test.cc @@ -27,7 +27,7 @@ using IR_InstructionTest = TestHelper; TEST_F(IR_InstructionTest, Bitcast) { Module mod; Builder b{mod}; - const auto* inst = b.Bitcast(b.ir.types.i32(), b.Constant(4_i)); + const auto* inst = b.Bitcast(mod.Types().i32(), b.Constant(4_i)); ASSERT_TRUE(inst->Is()); ASSERT_NE(inst->Type(), nullptr); @@ -43,7 +43,7 @@ TEST_F(IR_InstructionTest, Bitcast) { TEST_F(IR_InstructionTest, Bitcast_Usage) { Module mod; Builder b{mod}; - const auto* inst = b.Bitcast(b.ir.types.i32(), b.Constant(4_i)); + const auto* inst = b.Bitcast(mod.Types().i32(), b.Constant(4_i)); const auto args = inst->Args(); ASSERT_EQ(args.Length(), 1u); diff --git a/src/tint/ir/builder.cc b/src/tint/ir/builder.cc index d395da6144..603f305717 100644 --- a/src/tint/ir/builder.cc +++ b/src/tint/ir/builder.cc @@ -170,7 +170,7 @@ Unary* Builder::Negation(const type::Type* type, Value* val) { } Binary* Builder::Not(const type::Type* type, Value* val) { - return Equal(type, val, Constant(create>(type, false))); + return Equal(type, val, Constant(false)); } ir::Bitcast* Builder::Bitcast(const type::Type* type, Value* val) { diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h index fdddc6ed07..02d2b369d3 100644 --- a/src/tint/ir/builder.h +++ b/src/tint/ir/builder.h @@ -110,50 +110,6 @@ class Builder { /// @returns the start block for the case flow node Block* CreateCase(Switch* s, utils::VectorRef selectors); - /// Creates a constant::Value - /// @param args the arguments - /// @returns the new constant value - template - utils::traits::EnableIf, const T>* create( - ARGS&&... args) { - return ir.constants_arena.Create(std::forward(args)...); - } - - /// @param v the value - /// @returns the constant value - const constant::Value* Bool(bool v) { - // TODO(dsinclair): Replace when constant::Value is uniqed by the arena. - return Constant(create>(ir.types.bool_(), v))->Value(); - } - - /// @param v the value - /// @returns the constant value - const constant::Value* U32(uint32_t v) { - // TODO(dsinclair): Replace when constant::Value is uniqed by the arena. - return Constant(create>(ir.types.u32(), u32(v)))->Value(); - } - - /// @param v the value - /// @returns the constant value - const constant::Value* I32(int32_t v) { - // TODO(dsinclair): Replace when constant::Value is uniqed by the arena. - return Constant(create>(ir.types.i32(), i32(v)))->Value(); - } - - /// @param v the value - /// @returns the constant value - const constant::Value* F16(float v) { - // TODO(dsinclair): Replace when constant::Value is uniqed by the arena. - return Constant(create>(ir.types.f16(), f16(v)))->Value(); - } - - /// @param v the value - /// @returns the constant value - const constant::Value* F32(float v) { - // TODO(dsinclair): Replace when constant::Value is uniqed by the arena. - return Constant(create>(ir.types.f32(), f32(v)))->Value(); - } - /// Creates a new ir::Constant /// @param val the constant value /// @returns the new constant @@ -164,37 +120,27 @@ class Builder { /// Creates a ir::Constant for an i32 Scalar /// @param v the value /// @returns the new constant - ir::Constant* Constant(i32 v) { - return Constant(create>(ir.types.i32(), v)); - } + ir::Constant* Constant(i32 v) { return Constant(ir.constant_values.Get(v)); } /// Creates a ir::Constant for a u32 Scalar /// @param v the value /// @returns the new constant - ir::Constant* Constant(u32 v) { - return Constant(create>(ir.types.u32(), v)); - } + ir::Constant* Constant(u32 v) { return Constant(ir.constant_values.Get(v)); } /// Creates a ir::Constant for a f32 Scalar /// @param v the value /// @returns the new constant - ir::Constant* Constant(f32 v) { - return Constant(create>(ir.types.f32(), v)); - } + ir::Constant* Constant(f32 v) { return Constant(ir.constant_values.Get(v)); } /// Creates a ir::Constant for a f16 Scalar /// @param v the value /// @returns the new constant - ir::Constant* Constant(f16 v) { - return Constant(create>(ir.types.f16(), v)); - } + ir::Constant* Constant(f16 v) { return Constant(ir.constant_values.Get(v)); } /// Creates a ir::Constant for a bool Scalar /// @param v the value /// @returns the new constant - ir::Constant* Constant(bool v) { - return Constant(create>(ir.types.bool_(), v)); - } + ir::Constant* Constant(bool v) { return Constant(ir.constant_values.Get(v)); } /// Creates an op for `lhs kind rhs` /// @param kind the kind of operation diff --git a/src/tint/ir/from_program.cc b/src/tint/ir/from_program.cc index ad9423581a..f64037a268 100644 --- a/src/tint/ir/from_program.cc +++ b/src/tint/ir/from_program.cc @@ -132,9 +132,9 @@ class Impl { constant::CloneContext clone_ctx_{ /* type_ctx */ type::CloneContext{ /* src */ {&program_->Symbols()}, - /* dst */ {&builder_.ir.symbols, &builder_.ir.types}, + /* dst */ {&builder_.ir.symbols, &builder_.ir.Types()}, }, - /* dst */ {&builder_.ir.constants_arena}, + /* dst */ {builder_.ir.constant_values}, }; /// The stack of control blocks. @@ -841,7 +841,7 @@ class Impl { var, [&](const ast::Var* v) { auto* ref = sem->Type()->As(); - auto* ty = builder_.ir.types.Get( + auto* ty = builder_.ir.Types().Get( ref->StoreType()->Clone(clone_ctx_.type_ctx), ref->AddressSpace(), ref->Access()); @@ -946,7 +946,7 @@ class Impl { auto* if_inst = builder_.CreateIf(lhs.Get()); current_flow_block_->Instructions().Push(if_inst); - auto* result = builder_.BlockParam(builder_.ir.types.bool_()); + auto* result = builder_.BlockParam(builder_.ir.Types().bool_()); if_inst->Merge()->SetParams(utils::Vector{result}); utils::Result rhs; diff --git a/src/tint/ir/load_test.cc b/src/tint/ir/load_test.cc index e920a6ce1d..72359e7bab 100644 --- a/src/tint/ir/load_test.cc +++ b/src/tint/ir/load_test.cc @@ -27,8 +27,8 @@ TEST_F(IR_InstructionTest, CreateLoad) { Module mod; Builder b{mod}; - auto* store_type = b.ir.types.i32(); - auto* var = b.Declare(b.ir.types.Get( + auto* store_type = mod.Types().i32(); + auto* var = b.Declare(mod.Types().Get( store_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite)); const auto* inst = b.Load(var); @@ -45,8 +45,8 @@ TEST_F(IR_InstructionTest, Load_Usage) { Module mod; Builder b{mod}; - auto* store_type = b.ir.types.i32(); - auto* var = b.Declare(b.ir.types.Get( + auto* store_type = mod.Types().i32(); + auto* var = b.Declare(mod.Types().Get( store_type, builtin::AddressSpace::kFunction, builtin::Access::kReadWrite)); const auto* inst = b.Load(var); diff --git a/src/tint/ir/module.h b/src/tint/ir/module.h index 9a9034b78f..cabbf11b45 100644 --- a/src/tint/ir/module.h +++ b/src/tint/ir/module.h @@ -17,7 +17,7 @@ #include -#include "src/tint/constant/value.h" +#include "src/tint/constant/manager.h" #include "src/tint/ir/constant.h" #include "src/tint/ir/flow_node.h" #include "src/tint/ir/function.h" @@ -66,10 +66,15 @@ class Module { /// @return the unique symbol of the given value. Symbol SetName(const Value* value, std::string_view name); + /// @return the type manager for the module + type::Manager& Types() { return constant_values.types; } + /// The flow node allocator utils::BlockAllocator flow_nodes; - /// The constant allocator - utils::BlockAllocator constants_arena; + + /// The constant value manager + constant::Manager constant_values; + /// The value allocator utils::BlockAllocator values; @@ -79,34 +84,11 @@ class Module { /// The block containing module level declarations, if any exist. Block* root_block = nullptr; - /// The type manager for the module - type::Manager types; - /// The symbol table for the module SymbolTable symbols{prog_id_}; - /// ConstantHasher provides a hash function for a constant::Value pointer, hashing the value - /// instead of the pointer itself. - struct ConstantHasher { - /// @param c the constant pointer to create a hash for - /// @return the hash value - inline std::size_t operator()(const constant::Value* c) const { return c->Hash(); } - }; - - /// ConstantEquals provides an equality function for two constant::Value pointers, comparing - /// their values instead of the pointers. - struct ConstantEquals { - /// @param a the first constant pointer to compare - /// @param b the second constant pointer to compare - /// @return the hash value - inline bool operator()(const constant::Value* a, const constant::Value* b) const { - return a->Equal(b); - } - }; - /// The map of constant::Value to their ir::Constant. - utils::Hashmap - constants; + utils::Hashmap constants; }; } // namespace tint::ir diff --git a/src/tint/ir/module_test.cc b/src/tint/ir/module_test.cc index c9343d9e9a..f15aa35a00 100644 --- a/src/tint/ir/module_test.cc +++ b/src/tint/ir/module_test.cc @@ -25,20 +25,20 @@ using IR_ModuleTest = TestHelper; TEST_F(IR_ModuleTest, NameOfUnnamed) { Module mod; - auto* v = mod.values.Create(mod.types.i32()); + auto* v = mod.values.Create(mod.Types().i32()); EXPECT_FALSE(mod.NameOf(v).IsValid()); } TEST_F(IR_ModuleTest, SetName) { Module mod; - auto* v = mod.values.Create(mod.types.i32()); + auto* v = mod.values.Create(mod.Types().i32()); EXPECT_EQ(mod.SetName(v, "a").Name(), "a"); EXPECT_EQ(mod.NameOf(v).Name(), "a"); } TEST_F(IR_ModuleTest, SetNameRename) { Module mod; - auto* v = mod.values.Create(mod.types.i32()); + auto* v = mod.values.Create(mod.Types().i32()); EXPECT_EQ(mod.SetName(v, "a").Name(), "a"); EXPECT_EQ(mod.SetName(v, "b").Name(), "b"); EXPECT_EQ(mod.NameOf(v).Name(), "b"); @@ -46,9 +46,9 @@ TEST_F(IR_ModuleTest, SetNameRename) { TEST_F(IR_ModuleTest, SetNameCollision) { Module mod; - auto* a = mod.values.Create(mod.types.i32()); - auto* b = mod.values.Create(mod.types.i32()); - auto* c = mod.values.Create(mod.types.i32()); + auto* a = mod.values.Create(mod.Types().i32()); + auto* b = mod.values.Create(mod.Types().i32()); + auto* c = mod.values.Create(mod.Types().i32()); EXPECT_EQ(mod.SetName(a, "x").Name(), "x"); EXPECT_EQ(mod.SetName(b, "x_1").Name(), "x_1"); EXPECT_EQ(mod.SetName(c, "x").Name(), "x_2"); diff --git a/src/tint/ir/transform/add_empty_entry_point.cc b/src/tint/ir/transform/add_empty_entry_point.cc index c69532698c..2372ef540b 100644 --- a/src/tint/ir/transform/add_empty_entry_point.cc +++ b/src/tint/ir/transform/add_empty_entry_point.cc @@ -35,7 +35,7 @@ void AddEmptyEntryPoint::Run(ir::Module* ir, const DataMap&, DataMap&) const { } ir::Builder builder(*ir); - auto* ep = builder.CreateFunction(ir->symbols.New("unused_entry_point"), ir->types.void_(), + auto* ep = builder.CreateFunction(ir->symbols.New("unused_entry_point"), ir->Types().void_(), Function::PipelineStage::kCompute, std::array{1u, 1u, 1u}); ep->StartTarget()->SetInstructions(utils::Vector{builder.Branch(ep->EndTarget())}); ir->functions.Push(ep); diff --git a/src/tint/ir/transform/add_empty_entry_point_test.cc b/src/tint/ir/transform/add_empty_entry_point_test.cc index a97ad7a2b8..fec41928dd 100644 --- a/src/tint/ir/transform/add_empty_entry_point_test.cc +++ b/src/tint/ir/transform/add_empty_entry_point_test.cc @@ -39,7 +39,7 @@ TEST_F(IR_AddEmptyEntryPointTest, EmptyModule) { } TEST_F(IR_AddEmptyEntryPointTest, ExistingEntryPoint) { - auto* ep = b.CreateFunction("main", mod.types.void_(), Function::PipelineStage::kFragment); + auto* ep = b.CreateFunction("main", mod.Types().void_(), Function::PipelineStage::kFragment); ep->StartTarget()->SetInstructions(utils::Vector{b.Branch(ep->EndTarget())}); mod.functions.Push(ep); diff --git a/src/tint/ir/unary_test.cc b/src/tint/ir/unary_test.cc index 14de464851..92ad24c1ed 100644 --- a/src/tint/ir/unary_test.cc +++ b/src/tint/ir/unary_test.cc @@ -26,7 +26,7 @@ using IR_InstructionTest = TestHelper; TEST_F(IR_InstructionTest, CreateComplement) { Module mod; Builder b{mod}; - auto* inst = b.Complement(b.ir.types.i32(), b.Constant(4_i)); + auto* inst = b.Complement(mod.Types().i32(), b.Constant(4_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Unary::Kind::kComplement); @@ -40,7 +40,7 @@ TEST_F(IR_InstructionTest, CreateComplement) { TEST_F(IR_InstructionTest, CreateNegation) { Module mod; Builder b{mod}; - auto* inst = b.Negation(b.ir.types.i32(), b.Constant(4_i)); + auto* inst = b.Negation(mod.Types().i32(), b.Constant(4_i)); ASSERT_TRUE(inst->Is()); EXPECT_EQ(inst->Kind(), Unary::Kind::kNegation); @@ -54,7 +54,7 @@ TEST_F(IR_InstructionTest, CreateNegation) { TEST_F(IR_InstructionTest, Unary_Usage) { Module mod; Builder b{mod}; - auto* inst = b.Negation(b.ir.types.i32(), b.Constant(4_i)); + auto* inst = b.Negation(mod.Types().i32(), b.Constant(4_i)); EXPECT_EQ(inst->Kind(), Unary::Kind::kNegation); diff --git a/src/tint/program.cc b/src/tint/program.cc index a1cf5c8b47..92a9f02ed5 100644 --- a/src/tint/program.cc +++ b/src/tint/program.cc @@ -37,10 +37,9 @@ Program::Program() = default; Program::Program(Program&& program) : id_(std::move(program.id_)), highest_node_id_(std::move(program.highest_node_id_)), - types_(std::move(program.types_)), + constants_(std::move(program.constants_)), ast_nodes_(std::move(program.ast_nodes_)), sem_nodes_(std::move(program.sem_nodes_)), - constant_nodes_(std::move(program.constant_nodes_)), ast_(std::move(program.ast_)), sem_(std::move(program.sem_)), symbols_(std::move(program.symbols_)), @@ -63,10 +62,9 @@ Program::Program(ProgramBuilder&& builder) { } // The above must be called *before* the calls to std::move() below - types_ = std::move(builder.Types()); + constants_ = std::move(builder.constants); ast_nodes_ = std::move(builder.ASTNodes()); sem_nodes_ = std::move(builder.SemNodes()); - constant_nodes_ = std::move(builder.ConstantNodes()); ast_ = &builder.AST(); // ast::Module is actually a heap allocation. sem_ = std::move(builder.Sem()); symbols_ = std::move(builder.Symbols()); @@ -89,10 +87,9 @@ Program& Program::operator=(Program&& program) { moved_ = false; id_ = std::move(program.id_); highest_node_id_ = std::move(program.highest_node_id_); - types_ = std::move(program.types_); + constants_ = std::move(program.constants_); ast_nodes_ = std::move(program.ast_nodes_); sem_nodes_ = std::move(program.sem_nodes_); - constant_nodes_ = std::move(program.constant_nodes_); ast_ = std::move(program.ast_); sem_ = std::move(program.sem_); symbols_ = std::move(program.symbols_); diff --git a/src/tint/program.h b/src/tint/program.h index ae46acdfbf..940a6ecfa0 100644 --- a/src/tint/program.h +++ b/src/tint/program.h @@ -19,7 +19,7 @@ #include #include "src/tint/ast/function.h" -#include "src/tint/constant/value.h" +#include "src/tint/constant/manager.h" #include "src/tint/program_id.h" #include "src/tint/sem/info.h" #include "src/tint/symbol_table.h" @@ -44,9 +44,6 @@ class Program { /// SemNodeAllocator is an alias to BlockAllocator using SemNodeAllocator = utils::BlockAllocator; - /// ConstantAllocator is an alias to BlockAllocator - using ConstantAllocator = utils::BlockAllocator; - /// Constructor Program(); @@ -72,10 +69,16 @@ class Program { /// @returns the last allocated (numerically highest) AST node identifier. ast::NodeID HighestASTNodeID() const { return highest_node_id_; } + /// @returns a reference to the program's constants + const constant::Manager& Constants() const { + AssertNotMoved(); + return constants_; + } + /// @returns a reference to the program's types const type::Manager& Types() const { AssertNotMoved(); - return types_; + return constants_.types; } /// @returns a reference to the program's AST nodes storage @@ -165,10 +168,9 @@ class Program { ProgramID id_; ast::NodeID highest_node_id_; - type::Manager types_; + constant::Manager constants_; ASTNodeAllocator ast_nodes_; SemNodeAllocator sem_nodes_; - ConstantAllocator constant_nodes_; ast::Module* ast_ = nullptr; sem::Info sem_; SymbolTable symbols_{id_}; diff --git a/src/tint/program_builder.cc b/src/tint/program_builder.cc index fc2e7956b9..20a3fc4b43 100644 --- a/src/tint/program_builder.cc +++ b/src/tint/program_builder.cc @@ -38,9 +38,9 @@ ProgramBuilder::ProgramBuilder() ast_(ast_nodes_.Create(id_, AllocateNodeID(), Source{})) {} ProgramBuilder::ProgramBuilder(ProgramBuilder&& rhs) - : id_(std::move(rhs.id_)), + : constants(std::move(rhs.constants)), + id_(std::move(rhs.id_)), last_ast_node_id_(std::move(rhs.last_ast_node_id_)), - types_(std::move(rhs.types_)), ast_nodes_(std::move(rhs.ast_nodes_)), sem_nodes_(std::move(rhs.sem_nodes_)), ast_(std::move(rhs.ast_)), @@ -57,7 +57,7 @@ ProgramBuilder& ProgramBuilder::operator=(ProgramBuilder&& rhs) { AssertNotMoved(); id_ = std::move(rhs.id_); last_ast_node_id_ = std::move(rhs.last_ast_node_id_); - types_ = std::move(rhs.types_); + constants = std::move(rhs.constants); ast_nodes_ = std::move(rhs.ast_nodes_); sem_nodes_ = std::move(rhs.sem_nodes_); ast_ = std::move(rhs.ast_); @@ -72,7 +72,7 @@ ProgramBuilder ProgramBuilder::Wrap(const Program* program) { ProgramBuilder builder; builder.id_ = program->ID(); builder.last_ast_node_id_ = program->HighestASTNodeID(); - builder.types_ = type::Manager::Wrap(program->Types()); + builder.constants = constant::Manager::Wrap(program->Constants()); builder.ast_ = builder.create(program->AST().source, program->AST().GlobalDeclarations()); builder.sem_ = sem::Info::Wrap(program->Sem()); @@ -136,39 +136,4 @@ const ast::Function* ProgramBuilder::WrapInFunction(utils::VectorRef elements) { - if (elements.IsEmpty()) { - return nullptr; - } - - bool any_zero = false; - bool all_zero = true; - bool all_equal = true; - auto* first = elements.Front(); - for (auto* el : elements) { - if (!el) { - return nullptr; - } - if (!any_zero && el->AnyZero()) { - any_zero = true; - } - if (all_zero && !el->AllZero()) { - all_zero = false; - } - if (all_equal && el != first) { - if (!el->Equal(first)) { - all_equal = false; - } - } - } - if (all_equal) { - return create(type, elements[0], elements.Length()); - } - - return constant_nodes_.Create(type, std::move(elements), all_zero, - any_zero); -} - } // namespace tint diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index b944e21177..459edaadd0 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -78,9 +78,7 @@ #include "src/tint/builtin/extension.h" #include "src/tint/builtin/interpolation_sampling.h" #include "src/tint/builtin/interpolation_type.h" -#include "src/tint/constant/composite.h" -#include "src/tint/constant/splat.h" -#include "src/tint/constant/value.h" +#include "src/tint/constant/manager.h" #include "src/tint/number.h" #include "src/tint/program.h" #include "src/tint/program_id.h" @@ -328,9 +326,6 @@ class ProgramBuilder { /// SemNodeAllocator is an alias to BlockAllocator using SemNodeAllocator = utils::BlockAllocator; - /// ConstantAllocator is an alias to BlockAllocator - using ConstantAllocator = utils::BlockAllocator; - /// Constructor ProgramBuilder(); @@ -364,13 +359,13 @@ class ProgramBuilder { /// @returns a reference to the program's types type::Manager& Types() { AssertNotMoved(); - return types_; + return constants.types; } /// @returns a reference to the program's types const type::Manager& Types() const { AssertNotMoved(); - return types_; + return constants.types; } /// @returns a reference to the program's AST nodes storage @@ -397,12 +392,6 @@ class ProgramBuilder { return sem_nodes_; } - /// @returns a reference to the program's semantic constant storage - ConstantAllocator& ConstantNodes() { - AssertNotMoved(); - return constant_nodes_; - } - /// @returns a reference to the program's AST root Module ast::Module& AST() { AssertNotMoved(); @@ -528,53 +517,6 @@ class ProgramBuilder { return sem_nodes_.Create(std::forward(args)...); } - /// Creates a new constant::Value owned by the ProgramBuilder. - /// When the ProgramBuilder is destructed, the sem::Node will also be destructed. - /// @param args the arguments to pass to the constructor - /// @returns the node pointer - template - utils::traits::EnableIf && - !utils::traits::IsTypeOrDerived && - !utils::traits::IsTypeOrDerived, - T>* - create(ARGS&&... args) { - AssertNotMoved(); - return constant_nodes_.Create(std::forward(args)...); - } - - /// Constructs a constant of a vector, matrix or array type. - /// - /// Examines the element values and will return either a constant::Composite or a - /// constant::Splat, depending on the element types and values. - /// - /// @param type the composite type - /// @param elements the composite elements - /// @returns the node pointer - template < - typename T, - typename = utils::traits::EnableIf || - utils::traits::IsTypeOrDerived>> - const constant::Value* create(const type::Type* type, - utils::VectorRef elements) { - AssertNotMoved(); - return createSplatOrComposite(type, elements); - } - - /// Constructs a splat constant. - /// @param type the splat type - /// @param element the splat element - /// @param n the number of elements - /// @returns the node pointer - template < - typename T, - typename = utils::traits::EnableIf>> - const constant::Splat* create(const type::Type* type, - const constant::Value* element, - size_t n) { - AssertNotMoved(); - return constant_nodes_.Create(type, element, n); - } - /// Creates a new type::Node owned by the ProgramBuilder. /// When the ProgramBuilder is destructed, owned ProgramBuilder and the returned node will also /// be destructed. If T derives from type::UniqueNode, then the calling create() for the same @@ -584,7 +526,7 @@ class ProgramBuilder { template utils::traits::EnableIfIsType* create(ARGS&&... args) { AssertNotMoved(); - return types_.Get(std::forward(args)...); + return constants.types.Get(std::forward(args)...); } /// Marks this builder as moved, preventing any further use of the builder. @@ -3953,6 +3895,9 @@ class ProgramBuilder { /// @returns the function const ast::Function* WrapInFunction(utils::VectorRef stmts); + /// The constants manager + constant::Manager constants; + /// The builder types TypesBuilder const ty{this}; @@ -3961,16 +3906,10 @@ class ProgramBuilder { void AssertNotMoved() const; private: - const constant::Value* createSplatOrComposite( - const type::Type* type, - utils::VectorRef elements); - ProgramID id_; ast::NodeID last_ast_node_id_ = ast::NodeID{static_cast(0) - 1}; - type::Manager types_; ASTNodeAllocator ast_nodes_; SemNodeAllocator sem_nodes_; - ConstantAllocator constant_nodes_; ast::Module* ast_; sem::Info sem_; SymbolTable symbols_{id_}; diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 35ea8bba31..c181f02387 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -261,13 +261,15 @@ ConstEval::Result ScalarConvert(const constant::Scalar* scalar, using FROM = T; if constexpr (std::is_same_v) { // [x -> bool] - return builder.create>(target_ty, !scalar->IsPositiveZero()); + return builder.constants.Get>(target_ty, + !scalar->IsPositiveZero()); } else if constexpr (std::is_same_v) { // [bool -> x] - return builder.create>(target_ty, TO(scalar->value ? 1 : 0)); + return builder.constants.Get>(target_ty, + TO(scalar->value ? 1 : 0)); } else if (auto conv = CheckedConvert(scalar->value)) { // Conversion success - return builder.create>(target_ty, conv.Get()); + return builder.constants.Get>(target_ty, conv.Get()); // --- Below this point are the failure cases --- } else if constexpr (IsAbstract) { // [abstract-numeric -> x] - materialization failure @@ -276,9 +278,10 @@ ConstEval::Result ScalarConvert(const constant::Scalar* scalar, builder.Diagnostics().add_warning(tint::diag::System::Resolver, msg, source); switch (conv.Failure()) { case ConversionFailure::kExceedsNegativeLimit: - return builder.create>(target_ty, TO::Lowest()); + return builder.constants.Get>(target_ty, TO::Lowest()); case ConversionFailure::kExceedsPositiveLimit: - return builder.create>(target_ty, TO::Highest()); + return builder.constants.Get>(target_ty, + TO::Highest()); } } else { builder.Diagnostics().add_error(tint::diag::System::Resolver, msg, source); @@ -292,9 +295,10 @@ ConstEval::Result ScalarConvert(const constant::Scalar* scalar, builder.Diagnostics().add_warning(tint::diag::System::Resolver, msg, source); switch (conv.Failure()) { case ConversionFailure::kExceedsNegativeLimit: - return builder.create>(target_ty, TO::Lowest()); + return builder.constants.Get>(target_ty, TO::Lowest()); case ConversionFailure::kExceedsPositiveLimit: - return builder.create>(target_ty, TO::Highest()); + return builder.constants.Get>(target_ty, + TO::Highest()); } } else { builder.Diagnostics().add_error(tint::diag::System::Resolver, msg, source); @@ -305,14 +309,15 @@ ConstEval::Result ScalarConvert(const constant::Scalar* scalar, // https://www.w3.org/TR/WGSL/#floating-point-conversion switch (conv.Failure()) { case ConversionFailure::kExceedsNegativeLimit: - return builder.create>(target_ty, TO::Lowest()); + return builder.constants.Get>(target_ty, TO::Lowest()); case ConversionFailure::kExceedsPositiveLimit: - return builder.create>(target_ty, TO::Highest()); + return builder.constants.Get>(target_ty, TO::Highest()); } } else if constexpr (IsIntegral) { // [integer -> integer] - number not exactly representable // Static cast - return builder.create>(target_ty, static_cast(scalar->value)); + return builder.constants.Get>(target_ty, + static_cast(scalar->value)); } return nullptr; // Expression is not constant. }); @@ -362,7 +367,7 @@ ConstEval::Result CompositeConvert(const constant::Value* value, } conv_els.Push(conv_el.Get()); } - return builder.create(target_ty, std::move(conv_els)); + return builder.constants.Composite(target_ty, std::move(conv_els)); } ConstEval::Result SplatConvert(const constant::Splat* splat, @@ -396,7 +401,7 @@ ConstEval::Result SplatConvert(const constant::Splat* splat, if (!conv_el.Get()) { return nullptr; } - return builder.create(target_ty, conv_el.Get(), splat->count); + return builder.constants.Splat(target_ty, conv_el.Get(), splat->count); } ConstEval::Result ConvertInternal(const constant::Value* c, @@ -466,7 +471,7 @@ ConstEval::Result TransformElements(ProgramBuilder& builder, return el.Failure(); } } - return builder.create(composite_ty, std::move(els)); + return builder.constants.Composite(composite_ty, std::move(els)); } } // namespace detail @@ -520,7 +525,7 @@ ConstEval::Result TransformBinaryElements(ProgramBuilder& builder, return el.Failure(); } } - return builder.create(composite_ty, std::move(els)); + return builder.constants.Composite(composite_ty, std::move(els)); } } // namespace @@ -542,7 +547,7 @@ ConstEval::Result ConstEval::CreateScalar(const Source& source, const type::Type } } } - return builder.create>(t, v); + return builder.constants.Get>(t, v); } const constant::Value* ConstEval::ZeroValue(const type::Type* type) { @@ -550,16 +555,16 @@ const constant::Value* ConstEval::ZeroValue(const type::Type* type) { type, // [&](const type::Vector* v) -> const constant::Value* { auto* zero_el = ZeroValue(v->type()); - return builder.create(type, zero_el, v->Width()); + return builder.constants.Splat(type, zero_el, v->Width()); }, [&](const type::Matrix* m) -> const constant::Value* { auto* zero_el = ZeroValue(m->ColumnType()); - return builder.create(type, zero_el, m->columns()); + return builder.constants.Splat(type, zero_el, m->columns()); }, [&](const type::Array* a) -> const constant::Value* { if (auto n = a->ConstantCount()) { if (auto* zero_el = ZeroValue(a->ElemType())) { - return builder.create(type, zero_el, n.value()); + return builder.constants.Splat(type, zero_el, n.value()); } } return nullptr; @@ -578,9 +583,9 @@ const constant::Value* ConstEval::ZeroValue(const type::Type* type) { } if (zero_by_type.Count() == 1) { // All members were of the same type, so the zero value is the same for all members. - return builder.create(type, zeros[0], s->Members().Length()); + return builder.constants.Splat(type, zeros[0], s->Members().Length()); } - return builder.create(s, std::move(zeros)); + return builder.constants.Composite(s, std::move(zeros)); }, [&](Default) -> const constant::Value* { return ZeroTypeDispatch(type, [&](auto zero) -> const constant::Value* { @@ -1260,7 +1265,7 @@ ConstEval::Result ConstEval::ArrayOrStructCtor(const type::Type* ty, } // Multiple arguments. Must be a value constructor. - return builder.create(ty, std::move(args)); + return builder.constants.Composite(ty, std::move(args)); } ConstEval::Result ConstEval::Conv(const type::Type* ty, @@ -1295,8 +1300,7 @@ ConstEval::Result ConstEval::VecSplat(const type::Type* ty, utils::VectorRef args, const Source&) { if (auto* arg = args[0]) { - return builder.create(ty, arg, - static_cast(ty)->Width()); + return builder.constants.Splat(ty, arg, static_cast(ty)->Width()); } return nullptr; } @@ -1304,7 +1308,7 @@ ConstEval::Result ConstEval::VecSplat(const type::Type* ty, ConstEval::Result ConstEval::VecInitS(const type::Type* ty, utils::VectorRef args, const Source&) { - return builder.create(ty, args); + return builder.constants.Composite(ty, args); } ConstEval::Result ConstEval::VecInitM(const type::Type* ty, @@ -1330,7 +1334,7 @@ ConstEval::Result ConstEval::VecInitM(const type::Type* ty, els.Push(val); } } - return builder.create(ty, std::move(els)); + return builder.constants.Composite(ty, std::move(els)); } ConstEval::Result ConstEval::MatInitS(const type::Type* ty, @@ -1345,15 +1349,15 @@ ConstEval::Result ConstEval::MatInitS(const type::Type* ty, auto i = r + c * m->rows(); column.Push(args[i]); } - els.Push(builder.create(m->ColumnType(), std::move(column))); + els.Push(builder.constants.Composite(m->ColumnType(), std::move(column))); } - return builder.create(ty, std::move(els)); + return builder.constants.Composite(ty, std::move(els)); } ConstEval::Result ConstEval::MatInitV(const type::Type* ty, utils::VectorRef args, const Source&) { - return builder.create(ty, args); + return builder.constants.Composite(ty, args); } ConstEval::Result ConstEval::Index(const type::Type* ty, @@ -1411,7 +1415,7 @@ ConstEval::Result ConstEval::Swizzle(const type::Type* ty, } auto values = utils::Transform<4>( indices, [&](uint32_t i) { return vec_val->Index(static_cast(i)); }); - return builder.create(ty, std::move(values)); + return builder.constants.Composite(ty, std::move(values)); } ConstEval::Result ConstEval::Bitcast(const type::Type* ty, @@ -1557,7 +1561,7 @@ ConstEval::Result ConstEval::OpMultiplyMatVec(const type::Type* ty, } result.Push(r.Get()); } - return builder.create(ty, result); + return builder.constants.Composite(ty, result); } ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty, utils::VectorRef args, @@ -1607,7 +1611,7 @@ ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty, } result.Push(r.Get()); } - return builder.create(ty, result); + return builder.constants.Composite(ty, result); } ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty, @@ -1669,9 +1673,9 @@ ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty, // Add column vector to matrix auto* col_vec_ty = ty->As()->ColumnType(); - result_mat.Push(builder.create(col_vec_ty, col_vec)); + result_mat.Push(builder.constants.Composite(col_vec_ty, col_vec)); } - return builder.create(ty, result_mat); + return builder.constants.Composite(ty, result_mat); } ConstEval::Result ConstEval::OpDivide(const type::Type* ty, @@ -2311,7 +2315,7 @@ ConstEval::Result ConstEval::cross(const type::Type* ty, return utils::Failure; } - return builder.create( + return builder.constants.Composite( ty, utils::Vector{x.Get(), y.Get(), z.Get()}); } @@ -2707,20 +2711,20 @@ ConstEval::Result ConstEval::frexp(const type::Type* ty, } auto fract_ty = builder.create(fract_els[0]->Type(), vec->Width()); auto exp_ty = builder.create(exp_els[0]->Type(), vec->Width()); - return builder.create( + return builder.constants.Composite( ty, utils::Vector{ - builder.create(fract_ty, std::move(fract_els)), - builder.create(exp_ty, std::move(exp_els)), + builder.constants.Composite(fract_ty, std::move(fract_els)), + builder.constants.Composite(exp_ty, std::move(exp_els)), }); } else { auto fe = scalar(arg); if (!fe.fract || !fe.exp) { return utils::Failure; } - return builder.create(ty, utils::Vector{ - fe.fract.Get(), - fe.exp.Get(), - }); + return builder.constants.Composite(ty, utils::Vector{ + fe.fract.Get(), + fe.exp.Get(), + }); } } @@ -3014,7 +3018,7 @@ ConstEval::Result ConstEval::modf(const type::Type* ty, return utils::Failure; } - return builder.create(ty, std::move(fields)); + return builder.constants.Composite(ty, std::move(fields)); } ConstEval::Result ConstEval::normalize(const type::Type* ty, @@ -3600,10 +3604,9 @@ ConstEval::Result ConstEval::transpose(const type::Type* ty, for (size_t c = 0; c < mat_ty->columns(); ++c) { new_col_vec.Push(me(r, c)); } - result_mat.Push( - builder.create(result_mat_ty->ColumnType(), new_col_vec)); + result_mat.Push(builder.constants.Composite(result_mat_ty->ColumnType(), new_col_vec)); } - return builder.create(ty, result_mat); + return builder.constants.Composite(ty, result_mat); } ConstEval::Result ConstEval::trunc(const type::Type* ty, @@ -3643,7 +3646,7 @@ ConstEval::Result ConstEval::unpack2x16float(const type::Type* ty, } els.Push(el.Get()); } - return builder.create(ty, std::move(els)); + return builder.constants.Composite(ty, std::move(els)); } ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty, @@ -3663,7 +3666,7 @@ ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty, } els.Push(el.Get()); } - return builder.create(ty, std::move(els)); + return builder.constants.Composite(ty, std::move(els)); } ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty, @@ -3682,7 +3685,7 @@ ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty, } els.Push(el.Get()); } - return builder.create(ty, std::move(els)); + return builder.constants.Composite(ty, std::move(els)); } ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty, @@ -3702,7 +3705,7 @@ ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty, } els.Push(el.Get()); } - return builder.create(ty, std::move(els)); + return builder.constants.Composite(ty, std::move(els)); } ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty, @@ -3721,7 +3724,7 @@ ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty, } els.Push(el.Get()); } - return builder.create(ty, std::move(els)); + return builder.constants.Composite(ty, std::move(els)); } ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty, diff --git a/src/tint/resolver/const_eval_conversion_test.cc b/src/tint/resolver/const_eval_conversion_test.cc index 350afb2226..468d2b45c2 100644 --- a/src/tint/resolver/const_eval_conversion_test.cc +++ b/src/tint/resolver/const_eval_conversion_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "src/tint/constant/splat.h" #include "src/tint/resolver/const_eval_test.h" #include "src/tint/sem/materialize.h" diff --git a/src/tint/resolver/const_eval_runtime_semantics_test.cc b/src/tint/resolver/const_eval_runtime_semantics_test.cc index eb0cf52e9f..2e41482559 100644 --- a/src/tint/resolver/const_eval_runtime_semantics_test.cc +++ b/src/tint/resolver/const_eval_runtime_semantics_test.cc @@ -37,29 +37,11 @@ class ResolverConstEvalRuntimeSemanticsTest : public ResolverConstEvalTest { diag::Formatter formatter{style}; return formatter.format(Diagnostics()); } - - /// Helper to make a scalar constant::Value from a value. - template - const constant::Value* Scalar(T value) { - if constexpr (IsAbstract) { - if constexpr (IsFloatingPoint) { - return create>(create(), value); - } else if constexpr (IsIntegral) { - return create>(create(), value); - } - } else if constexpr (IsFloatingPoint) { - return create>(create(), value); - } else if constexpr (IsSignedIntegral) { - return create>(create(), value); - } else if constexpr (IsUnsignedIntegral) { - return create>(create(), value); - } - } }; TEST_F(ResolverConstEvalRuntimeSemanticsTest, Add_AInt_Overflow) { - auto* a = Scalar(AInt::Highest()); - auto* b = Scalar(AInt(1)); + auto* a = constants.Get(AInt::Highest()); + auto* b = constants.Get(AInt(1)); auto result = const_eval.OpPlus(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0); @@ -68,8 +50,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Add_AInt_Overflow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Add_AFloat_Overflow) { - auto* a = Scalar(AFloat::Highest()); - auto* b = Scalar(AFloat::Highest()); + auto* a = constants.Get(AFloat::Highest()); + auto* b = constants.Get(AFloat::Highest()); auto result = const_eval.OpPlus(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -79,8 +61,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Add_AFloat_Overflow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Add_F32_Overflow) { - auto* a = Scalar(f32::Highest()); - auto* b = Scalar(f32::Highest()); + auto* a = constants.Get(f32::Highest()); + auto* b = constants.Get(f32::Highest()); auto result = const_eval.OpPlus(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -90,8 +72,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Add_F32_Overflow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sub_AInt_Overflow) { - auto* a = Scalar(AInt::Lowest()); - auto* b = Scalar(AInt(1)); + auto* a = constants.Get(AInt::Lowest()); + auto* b = constants.Get(AInt(1)); auto result = const_eval.OpMinus(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0); @@ -100,8 +82,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sub_AInt_Overflow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sub_AFloat_Overflow) { - auto* a = Scalar(AFloat::Lowest()); - auto* b = Scalar(AFloat::Highest()); + auto* a = constants.Get(AFloat::Lowest()); + auto* b = constants.Get(AFloat::Highest()); auto result = const_eval.OpMinus(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -111,8 +93,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sub_AFloat_Overflow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sub_F32_Overflow) { - auto* a = Scalar(f32::Lowest()); - auto* b = Scalar(f32::Highest()); + auto* a = constants.Get(f32::Lowest()); + auto* b = constants.Get(f32::Highest()); auto result = const_eval.OpMinus(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -122,8 +104,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sub_F32_Overflow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mul_AInt_Overflow) { - auto* a = Scalar(AInt::Highest()); - auto* b = Scalar(AInt(2)); + auto* a = constants.Get(AInt::Highest()); + auto* b = constants.Get(AInt(2)); auto result = const_eval.OpMultiply(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0); @@ -132,8 +114,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mul_AInt_Overflow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mul_AFloat_Overflow) { - auto* a = Scalar(AFloat::Highest()); - auto* b = Scalar(AFloat::Highest()); + auto* a = constants.Get(AFloat::Highest()); + auto* b = constants.Get(AFloat::Highest()); auto result = const_eval.OpMultiply(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -143,8 +125,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mul_AFloat_Overflow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mul_F32_Overflow) { - auto* a = Scalar(f32::Highest()); - auto* b = Scalar(f32::Highest()); + auto* a = constants.Get(f32::Highest()); + auto* b = constants.Get(f32::Highest()); auto result = const_eval.OpMultiply(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -154,8 +136,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mul_F32_Overflow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_AInt_ZeroDenominator) { - auto* a = Scalar(AInt(42)); - auto* b = Scalar(AInt(0)); + auto* a = constants.Get(AInt(42)); + auto* b = constants.Get(AInt(0)); auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 42); @@ -163,8 +145,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_AInt_ZeroDenominator) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_I32_ZeroDenominator) { - auto* a = Scalar(i32(42)); - auto* b = Scalar(i32(0)); + auto* a = constants.Get(i32(42)); + auto* b = constants.Get(i32(0)); auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 42); @@ -172,8 +154,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_I32_ZeroDenominator) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_U32_ZeroDenominator) { - auto* a = Scalar(u32(42)); - auto* b = Scalar(u32(0)); + auto* a = constants.Get(u32(42)); + auto* b = constants.Get(u32(0)); auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 42); @@ -181,8 +163,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_U32_ZeroDenominator) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_AFloat_ZeroDenominator) { - auto* a = Scalar(AFloat(42)); - auto* b = Scalar(AFloat(0)); + auto* a = constants.Get(AFloat(42)); + auto* b = constants.Get(AFloat(0)); auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 42.f); @@ -190,8 +172,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_AFloat_ZeroDenominator) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_F32_ZeroDenominator) { - auto* a = Scalar(f32(42)); - auto* b = Scalar(f32(0)); + auto* a = constants.Get(f32(42)); + auto* b = constants.Get(f32(0)); auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 42.f); @@ -199,8 +181,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_F32_ZeroDenominator) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_I32_MostNegativeByMinInt) { - auto* a = Scalar(i32::Lowest()); - auto* b = Scalar(i32(-1)); + auto* a = constants.Get(i32::Lowest()); + auto* b = constants.Get(i32(-1)); auto result = const_eval.OpDivide(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), i32::Lowest()); @@ -208,8 +190,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Div_I32_MostNegativeByMinInt) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_AInt_ZeroDenominator) { - auto* a = Scalar(AInt(42)); - auto* b = Scalar(AInt(0)); + auto* a = constants.Get(AInt(42)); + auto* b = constants.Get(AInt(0)); auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0); @@ -217,8 +199,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_AInt_ZeroDenominator) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_I32_ZeroDenominator) { - auto* a = Scalar(i32(42)); - auto* b = Scalar(i32(0)); + auto* a = constants.Get(i32(42)); + auto* b = constants.Get(i32(0)); auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0); @@ -226,8 +208,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_I32_ZeroDenominator) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_U32_ZeroDenominator) { - auto* a = Scalar(u32(42)); - auto* b = Scalar(u32(0)); + auto* a = constants.Get(u32(42)); + auto* b = constants.Get(u32(0)); auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0); @@ -235,8 +217,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_U32_ZeroDenominator) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_AFloat_ZeroDenominator) { - auto* a = Scalar(AFloat(42)); - auto* b = Scalar(AFloat(0)); + auto* a = constants.Get(AFloat(42)); + auto* b = constants.Get(AFloat(0)); auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -244,8 +226,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_AFloat_ZeroDenominator) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_F32_ZeroDenominator) { - auto* a = Scalar(f32(42)); - auto* b = Scalar(f32(0)); + auto* a = constants.Get(f32(42)); + auto* b = constants.Get(f32(0)); auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -253,8 +235,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_F32_ZeroDenominator) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_I32_MostNegativeByMinInt) { - auto* a = Scalar(i32::Lowest()); - auto* b = Scalar(i32(-1)); + auto* a = constants.Get(i32::Lowest()); + auto* b = constants.Get(i32(-1)); auto result = const_eval.OpModulo(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0); @@ -262,8 +244,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Mod_I32_MostNegativeByMinInt) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_AInt_SignChange) { - auto* a = Scalar(AInt(0x0FFFFFFFFFFFFFFFll)); - auto* b = Scalar(u32(9)); + auto* a = constants.Get(AInt(0x0FFFFFFFFFFFFFFFll)); + auto* b = constants.Get(u32(9)); auto result = const_eval.OpShiftLeft(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), static_cast(0x0FFFFFFFFFFFFFFFull << 9)); @@ -271,8 +253,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_AInt_SignChange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_I32_SignChange) { - auto* a = Scalar(i32(0x0FFFFFFF)); - auto* b = Scalar(u32(9)); + auto* a = constants.Get(i32(0x0FFFFFFF)); + auto* b = constants.Get(u32(9)); auto result = const_eval.OpShiftLeft(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), static_cast(0x0FFFFFFFu << 9)); @@ -280,8 +262,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_I32_SignChange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_I32_MoreThanBitWidth) { - auto* a = Scalar(i32(0x1)); - auto* b = Scalar(u32(33)); + auto* a = constants.Get(i32(0x1)); + auto* b = constants.Get(u32(33)); auto result = const_eval.OpShiftLeft(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 2); @@ -291,8 +273,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_I32_MoreThanBitWidth) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_U32_MoreThanBitWidth) { - auto* a = Scalar(u32(0x1)); - auto* b = Scalar(u32(33)); + auto* a = constants.Get(u32(0x1)); + auto* b = constants.Get(u32(33)); auto result = const_eval.OpShiftLeft(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 2); @@ -302,8 +284,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftLeft_U32_MoreThanBitWidth) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftRight_I32_MoreThanBitWidth) { - auto* a = Scalar(i32(0x2)); - auto* b = Scalar(u32(33)); + auto* a = constants.Get(i32(0x2)); + auto* b = constants.Get(u32(33)); auto result = const_eval.OpShiftRight(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 1); @@ -313,8 +295,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftRight_I32_MoreThanBitWidth) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftRight_U32_MoreThanBitWidth) { - auto* a = Scalar(u32(0x2)); - auto* b = Scalar(u32(33)); + auto* a = constants.Get(u32(0x2)); + auto* b = constants.Get(u32(33)); auto result = const_eval.OpShiftRight(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 1); @@ -324,7 +306,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, ShiftRight_U32_MoreThanBitWidth) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Acos_F32_OutOfRange) { - auto* a = Scalar(f32(2)); + auto* a = constants.Get(f32(2)); auto result = const_eval.acos(a->Type(), utils::Vector{a}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -333,7 +315,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Acos_F32_OutOfRange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Acosh_F32_OutOfRange) { - auto* a = Scalar(f32(-1)); + auto* a = constants.Get(f32(-1)); auto result = const_eval.acosh(a->Type(), utils::Vector{a}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -341,7 +323,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Acosh_F32_OutOfRange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Asin_F32_OutOfRange) { - auto* a = Scalar(f32(2)); + auto* a = constants.Get(f32(2)); auto result = const_eval.asin(a->Type(), utils::Vector{a}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -350,7 +332,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Asin_F32_OutOfRange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Atanh_F32_OutOfRange) { - auto* a = Scalar(f32(2)); + auto* a = constants.Get(f32(2)); auto result = const_eval.atanh(a->Type(), utils::Vector{a}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -359,7 +341,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Atanh_F32_OutOfRange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Exp_F32_Overflow) { - auto* a = Scalar(f32(1000)); + auto* a = constants.Get(f32(1000)); auto result = const_eval.exp(a->Type(), utils::Vector{a}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -367,7 +349,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Exp_F32_Overflow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Exp2_F32_Overflow) { - auto* a = Scalar(f32(1000)); + auto* a = constants.Get(f32(1000)); auto result = const_eval.exp2(a->Type(), utils::Vector{a}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -375,9 +357,9 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Exp2_F32_Overflow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, ExtractBits_I32_TooManyBits) { - auto* a = Scalar(i32(0x12345678)); - auto* offset = Scalar(u32(24)); - auto* count = Scalar(u32(16)); + auto* a = constants.Get(i32(0x12345678)); + auto* offset = constants.Get(u32(24)); + auto* count = constants.Get(u32(16)); auto result = const_eval.extractBits(a->Type(), utils::Vector{a, offset, count}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0x12); @@ -386,9 +368,9 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, ExtractBits_I32_TooManyBits) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, ExtractBits_U32_TooManyBits) { - auto* a = Scalar(u32(0x12345678)); - auto* offset = Scalar(u32(24)); - auto* count = Scalar(u32(16)); + auto* a = constants.Get(u32(0x12345678)); + auto* offset = constants.Get(u32(24)); + auto* count = constants.Get(u32(16)); auto result = const_eval.extractBits(a->Type(), utils::Vector{a, offset, count}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0x12); @@ -397,10 +379,10 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, ExtractBits_U32_TooManyBits) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, InsertBits_I32_TooManyBits) { - auto* a = Scalar(i32(0x99345678)); - auto* b = Scalar(i32(0x12)); - auto* offset = Scalar(u32(24)); - auto* count = Scalar(u32(16)); + auto* a = constants.Get(i32(0x99345678)); + auto* b = constants.Get(i32(0x12)); + auto* offset = constants.Get(u32(24)); + auto* count = constants.Get(u32(16)); auto result = const_eval.insertBits(a->Type(), utils::Vector{a, b, offset, count}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0x12345678); @@ -409,10 +391,10 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, InsertBits_I32_TooManyBits) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, InsertBits_U32_TooManyBits) { - auto* a = Scalar(u32(0x99345678)); - auto* b = Scalar(u32(0x12)); - auto* offset = Scalar(u32(24)); - auto* count = Scalar(u32(16)); + auto* a = constants.Get(u32(0x99345678)); + auto* b = constants.Get(u32(0x12)); + auto* offset = constants.Get(u32(24)); + auto* count = constants.Get(u32(16)); auto result = const_eval.insertBits(a->Type(), utils::Vector{a, b, offset, count}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0x12345678); @@ -421,7 +403,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, InsertBits_U32_TooManyBits) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, InverseSqrt_F32_OutOfRange) { - auto* a = Scalar(f32(-1)); + auto* a = constants.Get(f32(-1)); auto result = const_eval.inverseSqrt(a->Type(), utils::Vector{a}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -429,8 +411,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, InverseSqrt_F32_OutOfRange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, LDExpr_F32_OutOfRange) { - auto* a = Scalar(f32(42.f)); - auto* b = Scalar(f32(200)); + auto* a = constants.Get(f32(42.f)); + auto* b = constants.Get(f32(200)); auto result = const_eval.ldexp(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -438,7 +420,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, LDExpr_F32_OutOfRange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Log_F32_OutOfRange) { - auto* a = Scalar(f32(-1)); + auto* a = constants.Get(f32(-1)); auto result = const_eval.log(a->Type(), utils::Vector{a}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -446,7 +428,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Log_F32_OutOfRange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Log2_F32_OutOfRange) { - auto* a = Scalar(f32(-1)); + auto* a = constants.Get(f32(-1)); auto result = const_eval.log2(a->Type(), utils::Vector{a}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -454,7 +436,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Log2_F32_OutOfRange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Normalize_ZeroLength) { - auto* zero = Scalar(f32(0)); + auto* zero = constants.Get(f32(0)); auto* vec = const_eval.VecSplat(create(create(), 4u), utils::Vector{zero}, {}) .Get(); @@ -468,8 +450,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Normalize_ZeroLength) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Pack2x16Float_OutOfRange) { - auto* a = Scalar(f32(75250.f)); - auto* b = Scalar(f32(42.1f)); + auto* a = constants.Get(f32(75250.f)); + auto* b = constants.Get(f32(42.1f)); auto* vec = const_eval.VecInitS(create(create(), 2u), utils::Vector{a, b}, {}) .Get(); @@ -480,8 +462,8 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Pack2x16Float_OutOfRange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Pow_F32_Overflow) { - auto* a = Scalar(f32(2)); - auto* b = Scalar(f32(1000)); + auto* a = constants.Get(f32(2)); + auto* b = constants.Get(f32(1000)); auto result = const_eval.pow(a->Type(), utils::Vector{a, b}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -489,7 +471,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Pow_F32_Overflow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Unpack2x16Float_OutOfRange) { - auto* a = Scalar(u32(0x51437C00)); + auto* a = constants.Get(u32(0x51437C00)); auto result = const_eval.unpack2x16float(create(), utils::Vector{a}, {}); ASSERT_TRUE(result); EXPECT_FLOAT_EQ(result.Get()->Index(0)->ValueAs(), 0.f); @@ -498,7 +480,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Unpack2x16Float_OutOfRange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, QuantizeToF16_OutOfRange) { - auto* a = Scalar(f32(75250.f)); + auto* a = constants.Get(f32(75250.f)); auto result = const_eval.quantizeToF16(create(), utils::Vector{a}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0); @@ -506,7 +488,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, QuantizeToF16_OutOfRange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sqrt_F32_OutOfRange) { - auto* a = Scalar(f32(-1)); + auto* a = constants.Get(f32(-1)); auto result = const_eval.sqrt(a->Type(), utils::Vector{a}, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -514,7 +496,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Sqrt_F32_OutOfRange) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Bitcast_Infinity) { - auto* a = Scalar(u32(0x7F800000)); + auto* a = constants.Get(u32(0x7F800000)); auto result = const_eval.Bitcast(create(), a, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -522,7 +504,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Bitcast_Infinity) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Bitcast_NaN) { - auto* a = Scalar(u32(0x7FC00000)); + auto* a = constants.Get(u32(0x7FC00000)); auto result = const_eval.Bitcast(create(), a, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), 0.f); @@ -530,7 +512,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Bitcast_NaN) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F32_TooHigh) { - auto* a = Scalar(AFloat::Highest()); + auto* a = constants.Get(AFloat::Highest()); auto result = const_eval.Convert(create(), a, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), f32::kHighestValue); @@ -540,7 +522,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F32_TooHigh) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F32_TooLow) { - auto* a = Scalar(AFloat::Lowest()); + auto* a = constants.Get(AFloat::Lowest()); auto result = const_eval.Convert(create(), a, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), f32::kLowestValue); @@ -550,7 +532,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F32_TooLow) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F16_TooHigh) { - auto* a = Scalar(f32(1000000.0)); + auto* a = constants.Get(f32(1000000.0)); auto result = const_eval.Convert(create(), a, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), f16::kHighestValue); @@ -558,7 +540,7 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F16_TooHigh) { } TEST_F(ResolverConstEvalRuntimeSemanticsTest, Convert_F16_TooLow) { - auto* a = Scalar(f32(-1000000.0)); + auto* a = constants.Get(f32(-1000000.0)); auto result = const_eval.Convert(create(), a, {}); ASSERT_TRUE(result); EXPECT_EQ(result.Get()->ValueAs(), f16::kLowestValue); @@ -571,10 +553,10 @@ TEST_F(ResolverConstEvalRuntimeSemanticsTest, Vec_Overflow_SingleComponent) { auto* a = const_eval .VecInitS(vec4f, utils::Vector{ - Scalar(f32(1)), - Scalar(f32(4)), - Scalar(f32(-1)), - Scalar(f32(65536)), + constants.Get(f32(1)), + constants.Get(f32(4)), + constants.Get(f32(-1)), + constants.Get(f32(65536)), }, {}) .Get(); diff --git a/src/tint/transform/manager_test.cc b/src/tint/transform/manager_test.cc index d1f63337aa..75afa938f8 100644 --- a/src/tint/transform/manager_test.cc +++ b/src/tint/transform/manager_test.cc @@ -51,7 +51,7 @@ class IR_AddFunction final : public ir::transform::Transform { void Run(ir::Module* mod, const DataMap&, DataMap&) const override { ir::Builder builder(*mod); auto* func = - builder.CreateFunction(mod->symbols.New("ir_func"), mod->types.Get()); + builder.CreateFunction(mod->symbols.New("ir_func"), mod->Types().Get()); func->StartTarget()->SetInstructions(utils::Vector{builder.Branch(func->EndTarget())}); mod->functions.Push(func); } @@ -69,7 +69,7 @@ ir::Module MakeIR() { ir::Module mod; ir::Builder builder(mod); auto* func = - builder.CreateFunction(builder.ir.symbols.New("main"), builder.ir.types.Get()); + builder.CreateFunction(builder.ir.symbols.New("main"), mod.Types().Get()); func->StartTarget()->SetInstructions(utils::Vector{builder.Branch(func->EndTarget())}); builder.ir.functions.Push(func); return mod; diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc index a15f6a96b7..549f366541 100644 --- a/src/tint/writer/glsl/generator_impl.cc +++ b/src/tint/writer/glsl/generator_impl.cc @@ -52,6 +52,7 @@ #include "src/tint/ast/transform/unshadow.h" #include "src/tint/ast/transform/zero_init_workgroup_memory.h" #include "src/tint/ast/variable_decl_statement.h" +#include "src/tint/constant/splat.h" #include "src/tint/constant/value.h" #include "src/tint/debug.h" #include "src/tint/sem/block_statement.h" diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc index 95d2022589..eb07e8e42d 100644 --- a/src/tint/writer/hlsl/generator_impl.cc +++ b/src/tint/writer/hlsl/generator_impl.cc @@ -51,6 +51,7 @@ #include "src/tint/ast/transform/vectorize_scalar_matrix_initializers.h" #include "src/tint/ast/transform/zero_init_workgroup_memory.h" #include "src/tint/ast/variable_decl_statement.h" +#include "src/tint/constant/splat.h" #include "src/tint/constant/value.h" #include "src/tint/debug.h" #include "src/tint/sem/block_statement.h" diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index 55b9f8d88c..4c51b17137 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -49,6 +49,7 @@ #include "src/tint/ast/transform/vectorize_scalar_matrix_initializers.h" #include "src/tint/ast/transform/zero_init_workgroup_memory.h" #include "src/tint/ast/variable_decl_statement.h" +#include "src/tint/constant/splat.h" #include "src/tint/constant/value.h" #include "src/tint/sem/call.h" #include "src/tint/sem/function.h" diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc index fd1d41a135..264f619ad9 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc @@ -20,9 +20,9 @@ namespace tint::writer::spirv { namespace { TEST_F(SpvGeneratorImplTest, Binary_Add_I32) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); func->StartTarget()->SetInstructions(utils::Vector{ - b.Add(mod.types.i32(), b.Constant(1_i), b.Constant(2_i)), b.Branch(func->EndTarget())}); + b.Add(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i)), b.Branch(func->EndTarget())}); generator_.EmitFunction(func); EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" @@ -40,9 +40,9 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Binary_Add_U32) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); func->StartTarget()->SetInstructions(utils::Vector{ - b.Add(mod.types.u32(), b.Constant(1_u), b.Constant(2_u)), b.Branch(func->EndTarget())}); + b.Add(mod.Types().u32(), b.Constant(1_u), b.Constant(2_u)), b.Branch(func->EndTarget())}); generator_.EmitFunction(func); EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" @@ -60,9 +60,9 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Binary_Add_F32) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); func->StartTarget()->SetInstructions(utils::Vector{ - b.Add(mod.types.f32(), b.Constant(1_f), b.Constant(2_f)), b.Branch(func->EndTarget())}); + b.Add(mod.Types().f32(), b.Constant(1_f), b.Constant(2_f)), b.Branch(func->EndTarget())}); generator_.EmitFunction(func); EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" @@ -80,9 +80,9 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Binary_Sub_I32) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); func->StartTarget()->SetInstructions( - utils::Vector{b.Subtract(mod.types.i32(), b.Constant(1_i), b.Constant(2_i)), + utils::Vector{b.Subtract(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i)), b.Branch(func->EndTarget())}); generator_.EmitFunction(func); @@ -101,9 +101,9 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Binary_Sub_U32) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); func->StartTarget()->SetInstructions( - utils::Vector{b.Subtract(mod.types.u32(), b.Constant(1_u), b.Constant(2_u)), + utils::Vector{b.Subtract(mod.Types().u32(), b.Constant(1_u), b.Constant(2_u)), b.Branch(func->EndTarget())}); generator_.EmitFunction(func); @@ -122,9 +122,9 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Binary_Sub_F32) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); func->StartTarget()->SetInstructions( - utils::Vector{b.Subtract(mod.types.f32(), b.Constant(1_f), b.Constant(2_f)), + utils::Vector{b.Subtract(mod.Types().f32(), b.Constant(1_f), b.Constant(2_f)), b.Branch(func->EndTarget())}); generator_.EmitFunction(func); @@ -143,14 +143,15 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec2i) { - auto* func = b.CreateFunction("foo", mod.types.void_()); - auto* lhs = b.create(mod.types.vec2(mod.types.i32()), - utils::Vector{b.I32(42), b.I32(-1)}, false, false); - auto* rhs = b.create(mod.types.vec2(mod.types.i32()), - utils::Vector{b.I32(0), b.I32(-43)}, false, false); + auto const_i32 = [&](int val) { return b.ir.constant_values.Get(i32(val)); }; + auto* func = b.CreateFunction("foo", mod.Types().void_()); + auto* lhs = b.ir.constant_values.Composite(mod.Types().vec2(mod.Types().i32()), + utils::Vector{const_i32(42), const_i32(-1)}); + auto* rhs = b.ir.constant_values.Composite(mod.Types().vec2(mod.Types().i32()), + utils::Vector{const_i32(0), const_i32(-43)}); func->StartTarget()->SetInstructions( - utils::Vector{b.Subtract(mod.types.Get(mod.types.i32(), 2u), b.Constant(lhs), - b.Constant(rhs)), + utils::Vector{b.Subtract(mod.Types().Get(mod.Types().i32(), 2u), + b.Constant(lhs), b.Constant(rhs)), b.Branch(func->EndTarget())}); generator_.EmitFunction(func); @@ -174,16 +175,17 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec4f) { - auto* func = b.CreateFunction("foo", mod.types.void_()); - auto* lhs = b.create( - mod.types.vec4(mod.types.f32()), utils::Vector{b.F32(42), b.F32(-1), b.F32(0), b.F32(1.25)}, - false, false); - auto* rhs = b.create( - mod.types.vec4(mod.types.f32()), utils::Vector{b.F32(0), b.F32(1.25), b.F32(-42), b.F32(1)}, - false, false); + auto const_f32 = [&](float val) { return b.ir.constant_values.Get(f32(val)); }; + auto* func = b.CreateFunction("foo", mod.Types().void_()); + auto* lhs = b.ir.constant_values.Composite( + mod.Types().vec4(mod.Types().f32()), + utils::Vector{const_f32(42), const_f32(-1), const_f32(0), const_f32(1.25)}); + auto* rhs = b.ir.constant_values.Composite( + mod.Types().vec4(mod.Types().f32()), + utils::Vector{const_f32(0), const_f32(1.25), const_f32(-42), const_f32(1)}); func->StartTarget()->SetInstructions( - utils::Vector{b.Subtract(mod.types.Get(mod.types.f32(), 4u), b.Constant(lhs), - b.Constant(rhs)), + utils::Vector{b.Subtract(mod.Types().Get(mod.Types().f32(), 4u), + b.Constant(lhs), b.Constant(rhs)), b.Branch(func->EndTarget())}); generator_.EmitFunction(func); @@ -209,10 +211,10 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Binary_Chain) { - auto* func = b.CreateFunction("foo", mod.types.void_()); - auto* a = b.Subtract(mod.types.i32(), b.Constant(1_i), b.Constant(2_i)); + auto* func = b.CreateFunction("foo", mod.Types().void_()); + auto* a = b.Subtract(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i)); func->StartTarget()->SetInstructions( - utils::Vector{a, b.Add(mod.types.i32(), a, a), b.Branch(func->EndTarget())}); + utils::Vector{a, b.Add(mod.Types().i32(), a, a), b.Branch(func->EndTarget())}); generator_.EmitFunction(func); EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo" diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc index 7881c0a16e..7cb240ac62 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc @@ -63,9 +63,10 @@ TEST_F(SpvGeneratorImplTest, Constant_F16) { } TEST_F(SpvGeneratorImplTest, Constant_Vec4Bool) { - auto* v = b.create( - mod.types.vec4(mod.types.bool_()), - utils::Vector{b.Bool(true), b.Bool(false), b.Bool(false), b.Bool(true)}, false, true); + auto const_bool = [&](bool val) { return mod.constant_values.Get(val); }; + auto* v = mod.constant_values.Composite( + mod.Types().vec4(mod.Types().bool_()), + utils::Vector{const_bool(true), const_bool(false), const_bool(false), const_bool(true)}); generator_.Constant(b.Constant(v)); EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeBool @@ -77,8 +78,9 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec4Bool) { } TEST_F(SpvGeneratorImplTest, Constant_Vec2i) { - auto* v = b.create(mod.types.vec2(mod.types.i32()), - utils::Vector{b.I32(42), b.I32(-1)}, false, false); + auto const_i32 = [&](float val) { return mod.constant_values.Get(i32(val)); }; + auto* v = mod.constant_values.Composite(mod.Types().vec2(mod.Types().i32()), + utils::Vector{const_i32(42), const_i32(-1)}); generator_.Constant(b.Constant(v)); EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 1 %2 = OpTypeVector %3 2 @@ -89,9 +91,10 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec2i) { } TEST_F(SpvGeneratorImplTest, Constant_Vec3u) { - auto* v = b.create(mod.types.vec3(mod.types.u32()), - utils::Vector{b.U32(42), b.U32(0), b.U32(4000000000)}, - false, true); + auto const_u32 = [&](float val) { return mod.constant_values.Get(u32(val)); }; + auto* v = mod.constant_values.Composite( + mod.Types().vec3(mod.Types().u32()), + utils::Vector{const_u32(42), const_u32(0), const_u32(4000000000)}); generator_.Constant(b.Constant(v)); EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 0 %2 = OpTypeVector %3 3 @@ -103,9 +106,10 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec3u) { } TEST_F(SpvGeneratorImplTest, Constant_Vec4f) { - auto* v = b.create( - mod.types.vec4(mod.types.f32()), utils::Vector{b.F32(42), b.F32(0), b.F32(0.25), b.F32(-1)}, - false, true); + auto const_f32 = [&](float val) { return mod.constant_values.Get(f32(val)); }; + auto* v = mod.constant_values.Composite( + mod.Types().vec4(mod.Types().f32()), + utils::Vector{const_f32(42), const_f32(0), const_f32(0.25), const_f32(-1)}); generator_.Constant(b.Constant(v)); EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 32 %2 = OpTypeVector %3 4 @@ -118,8 +122,9 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec4f) { } TEST_F(SpvGeneratorImplTest, Constant_Vec2h) { - auto* v = b.create(mod.types.vec2(mod.types.f16()), - utils::Vector{b.F16(42), b.F16(0.25)}, false, false); + auto const_f16 = [&](float val) { return mod.constant_values.Get(f16(val)); }; + auto* v = mod.constant_values.Composite(mod.Types().vec2(mod.Types().f16()), + utils::Vector{const_f16(42), const_f16(0.25)}); generator_.Constant(b.Constant(v)); EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 16 %2 = OpTypeVector %3 2 @@ -130,18 +135,18 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec2h) { } TEST_F(SpvGeneratorImplTest, Constant_Mat2x3f) { - auto* f32 = mod.types.f32(); - auto* v = b.create( - mod.types.mat2x3(f32), + auto const_f32 = [&](float val) { return mod.constant_values.Get(f32(val)); }; + auto* f32 = mod.Types().f32(); + auto* v = mod.constant_values.Composite( + mod.Types().mat2x3(f32), utils::Vector{ - b.create(mod.types.vec3(f32), - utils::Vector{b.F32(42), b.F32(-1), b.F32(0.25)}, false, - false), - b.create(mod.types.vec3(f32), - utils::Vector{b.F32(-42), b.F32(0), b.F32(-0.25)}, false, - true), - }, - false, false); + mod.constant_values.Composite( + mod.Types().vec3(f32), + utils::Vector{const_f32(42), const_f32(-1), const_f32(0.25)}), + mod.constant_values.Composite( + mod.Types().vec3(f32), + utils::Vector{const_f32(-42), const_f32(0), const_f32(-0.25)}), + }); generator_.Constant(b.Constant(v)); EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 32 %3 = OpTypeVector %4 3 @@ -159,20 +164,20 @@ TEST_F(SpvGeneratorImplTest, Constant_Mat2x3f) { } TEST_F(SpvGeneratorImplTest, Constant_Mat4x2h) { - auto* f16 = mod.types.f16(); - auto* v = b.create( - mod.types.mat4x2(f16), + auto const_f16 = [&](float val) { return mod.constant_values.Get(f16(val)); }; + auto* f16 = mod.Types().f16(); + auto* v = mod.constant_values.Composite( + mod.Types().mat4x2(f16), utils::Vector{ - b.create(mod.types.vec2(f16), utils::Vector{b.F16(42), b.F16(-1)}, - false, false), - b.create(mod.types.vec2(f16), utils::Vector{b.F16(0), b.F16(0.25)}, - false, true), - b.create(mod.types.vec2(f16), utils::Vector{b.F16(-42), b.F16(1)}, - false, false), - b.create(mod.types.vec2(f16), utils::Vector{b.F16(0.5), b.F16(-0)}, - false, true), - }, - false, false); + mod.constant_values.Composite(mod.Types().vec2(f16), + utils::Vector{const_f16(42), const_f16(-1)}), + mod.constant_values.Composite(mod.Types().vec2(f16), + utils::Vector{const_f16(0), const_f16(0.25)}), + mod.constant_values.Composite(mod.Types().vec2(f16), + utils::Vector{const_f16(-42), const_f16(1)}), + mod.constant_values.Composite(mod.Types().vec2(f16), + utils::Vector{const_f16(0.5), const_f16(-0)}), + }); generator_.Constant(b.Constant(v)); EXPECT_EQ(DumpTypes(), R"(%4 = OpTypeFloat 16 %3 = OpTypeVector %4 2 diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc index 07f2086f5f..22b059425a 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir_function_test.cc @@ -18,7 +18,7 @@ namespace tint::writer::spirv { namespace { TEST_F(SpvGeneratorImplTest, Function_Empty) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); func->StartTarget()->SetInstructions(utils::Vector{b.Branch(func->EndTarget())}); generator_.EmitFunction(func); @@ -34,7 +34,7 @@ OpFunctionEnd // Test that we do not emit the same function type more than once. TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); func->StartTarget()->SetInstructions(utils::Vector{b.Branch(func->EndTarget())}); generator_.EmitFunction(func); @@ -46,8 +46,8 @@ TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) { } TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) { - auto* func = b.CreateFunction("main", mod.types.void_(), ir::Function::PipelineStage::kCompute, - {{32, 4, 1}}); + auto* func = b.CreateFunction("main", mod.Types().void_(), + ir::Function::PipelineStage::kCompute, {{32, 4, 1}}); func->StartTarget()->SetInstructions(utils::Vector{b.Branch(func->EndTarget())}); generator_.EmitFunction(func); @@ -65,7 +65,7 @@ OpFunctionEnd TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) { auto* func = - b.CreateFunction("main", mod.types.void_(), ir::Function::PipelineStage::kFragment); + b.CreateFunction("main", mod.Types().void_(), ir::Function::PipelineStage::kFragment); func->StartTarget()->SetInstructions(utils::Vector{b.Branch(func->EndTarget())}); generator_.EmitFunction(func); @@ -82,7 +82,8 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) { - auto* func = b.CreateFunction("main", mod.types.void_(), ir::Function::PipelineStage::kVertex); + auto* func = + b.CreateFunction("main", mod.Types().void_(), ir::Function::PipelineStage::kVertex); func->StartTarget()->SetInstructions(utils::Vector{b.Branch(func->EndTarget())}); generator_.EmitFunction(func); @@ -98,15 +99,16 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) { - auto* f1 = b.CreateFunction("main1", mod.types.void_(), ir::Function::PipelineStage::kCompute, + auto* f1 = b.CreateFunction("main1", mod.Types().void_(), ir::Function::PipelineStage::kCompute, {{32, 4, 1}}); f1->StartTarget()->SetInstructions(utils::Vector{b.Branch(f1->EndTarget())}); - auto* f2 = b.CreateFunction("main2", mod.types.void_(), ir::Function::PipelineStage::kCompute, + auto* f2 = b.CreateFunction("main2", mod.Types().void_(), ir::Function::PipelineStage::kCompute, {{8, 2, 16}}); f2->StartTarget()->SetInstructions(utils::Vector{b.Branch(f2->EndTarget())}); - auto* f3 = b.CreateFunction("main3", mod.types.void_(), ir::Function::PipelineStage::kFragment); + auto* f3 = + b.CreateFunction("main3", mod.Types().void_(), ir::Function::PipelineStage::kFragment); f3->StartTarget()->SetInstructions(utils::Vector{b.Branch(f3->EndTarget())}); generator_.EmitFunction(f1); @@ -139,7 +141,7 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Function_ReturnValue) { - auto* func = b.CreateFunction("foo", mod.types.i32()); + auto* func = b.CreateFunction("foo", mod.Types().i32()); func->StartTarget()->SetInstructions( utils::Vector{b.Branch(func->EndTarget(), utils::Vector{b.Constant(i32(42))})}); @@ -156,7 +158,7 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Function_Parameters) { - auto* i32 = mod.types.i32(); + auto* i32 = mod.Types().i32(); auto* x = b.FunctionParam(i32); auto* y = b.FunctionParam(i32); auto* result = b.Add(i32, x, y); @@ -184,7 +186,7 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Function_Call) { - auto* i32_ty = mod.types.i32(); + auto* i32_ty = mod.Types().i32(); auto* x = b.FunctionParam(i32_ty); auto* y = b.FunctionParam(i32_ty); auto* result = b.Add(i32_ty, x, y); @@ -193,7 +195,7 @@ TEST_F(SpvGeneratorImplTest, Function_Call) { foo->StartTarget()->SetInstructions( utils::Vector{result, b.Branch(foo->EndTarget(), utils::Vector{result})}); - auto* bar = b.CreateFunction("bar", mod.types.void_()); + auto* bar = b.CreateFunction("bar", mod.Types().void_()); bar->StartTarget()->SetInstructions( utils::Vector{b.UserCall(i32_ty, mod.symbols.Get("foo"), utils::Vector{b.Constant(i32(2)), b.Constant(i32(3))}), @@ -225,12 +227,12 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, Function_Call_Void) { - auto* foo = b.CreateFunction("foo", mod.types.void_()); + auto* foo = b.CreateFunction("foo", mod.Types().void_()); foo->StartTarget()->SetInstructions(utils::Vector{b.Branch(foo->EndTarget())}); - auto* bar = b.CreateFunction("bar", mod.types.void_()); + auto* bar = b.CreateFunction("bar", mod.Types().void_()); bar->StartTarget()->SetInstructions( - utils::Vector{b.UserCall(mod.types.void_(), mod.symbols.Get("foo"), utils::Empty), + utils::Vector{b.UserCall(mod.Types().void_(), mod.symbols.Get("foo"), utils::Empty), b.Branch(bar->EndTarget())}); generator_.EmitFunction(foo); diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc index 4c18ec1013..ff7defc320 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir_if_test.cc @@ -20,7 +20,7 @@ namespace tint::writer::spirv { namespace { TEST_F(SpvGeneratorImplTest, If_TrueEmpty_FalseEmpty) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); auto* i = b.CreateIf(b.Constant(true)); i->True()->SetInstructions(utils::Vector{b.Branch(i->Merge())}); @@ -46,7 +46,7 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, If_FalseEmpty) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); auto* i = b.CreateIf(b.Constant(true)); i->False()->SetInstructions(utils::Vector{b.Branch(i->Merge())}); @@ -54,7 +54,7 @@ TEST_F(SpvGeneratorImplTest, If_FalseEmpty) { auto* true_block = i->True(); true_block->SetInstructions(utils::Vector{ - b.Add(mod.types.i32(), b.Constant(1_i), b.Constant(1_i)), b.Branch(i->Merge())}); + b.Add(mod.Types().i32(), b.Constant(1_i), b.Constant(1_i)), b.Branch(i->Merge())}); func->StartTarget()->SetInstructions(utils::Vector{i}); @@ -80,7 +80,7 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, If_TrueEmpty) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); auto* i = b.CreateIf(b.Constant(true)); i->True()->SetInstructions(utils::Vector{b.Branch(i->Merge())}); @@ -88,7 +88,7 @@ TEST_F(SpvGeneratorImplTest, If_TrueEmpty) { auto* false_block = i->False(); false_block->SetInstructions(utils::Vector{ - b.Add(mod.types.i32(), b.Constant(1_i), b.Constant(1_i)), b.Branch(i->Merge())}); + b.Add(mod.Types().i32(), b.Constant(1_i), b.Constant(1_i)), b.Branch(i->Merge())}); func->StartTarget()->SetInstructions(utils::Vector{i}); @@ -114,7 +114,7 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, If_BothBranchesReturn) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); auto* i = b.CreateIf(b.Constant(true)); i->True()->SetInstructions(utils::Vector{b.Branch(func->EndTarget())}); diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc index 0a8dc904ca..80e69c3b89 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir_type_test.cc @@ -25,43 +25,43 @@ namespace tint::writer::spirv { namespace { TEST_F(SpvGeneratorImplTest, Type_Void) { - auto id = generator_.Type(mod.types.void_()); + auto id = generator_.Type(mod.Types().void_()); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), "%1 = OpTypeVoid\n"); } TEST_F(SpvGeneratorImplTest, Type_Bool) { - auto id = generator_.Type(mod.types.bool_()); + auto id = generator_.Type(mod.Types().bool_()); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), "%1 = OpTypeBool\n"); } TEST_F(SpvGeneratorImplTest, Type_I32) { - auto id = generator_.Type(mod.types.i32()); + auto id = generator_.Type(mod.Types().i32()); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), "%1 = OpTypeInt 32 1\n"); } TEST_F(SpvGeneratorImplTest, Type_U32) { - auto id = generator_.Type(mod.types.u32()); + auto id = generator_.Type(mod.Types().u32()); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), "%1 = OpTypeInt 32 0\n"); } TEST_F(SpvGeneratorImplTest, Type_F32) { - auto id = generator_.Type(mod.types.f32()); + auto id = generator_.Type(mod.Types().f32()); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), "%1 = OpTypeFloat 32\n"); } TEST_F(SpvGeneratorImplTest, Type_F16) { - auto id = generator_.Type(mod.types.f16()); + auto id = generator_.Type(mod.Types().f16()); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), "%1 = OpTypeFloat 16\n"); } TEST_F(SpvGeneratorImplTest, Type_Vec2i) { - auto* vec = b.ir.types.Get(b.ir.types.i32(), 2u); + auto* vec = mod.Types().Get(mod.Types().i32(), 2u); auto id = generator_.Type(vec); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), @@ -70,7 +70,7 @@ TEST_F(SpvGeneratorImplTest, Type_Vec2i) { } TEST_F(SpvGeneratorImplTest, Type_Vec3u) { - auto* vec = b.ir.types.Get(b.ir.types.u32(), 3u); + auto* vec = mod.Types().Get(mod.Types().u32(), 3u); auto id = generator_.Type(vec); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), @@ -79,7 +79,7 @@ TEST_F(SpvGeneratorImplTest, Type_Vec3u) { } TEST_F(SpvGeneratorImplTest, Type_Vec4f) { - auto* vec = b.ir.types.Get(b.ir.types.f32(), 4u); + auto* vec = mod.Types().Get(mod.Types().f32(), 4u); auto id = generator_.Type(vec); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), @@ -88,7 +88,7 @@ TEST_F(SpvGeneratorImplTest, Type_Vec4f) { } TEST_F(SpvGeneratorImplTest, Type_Vec4h) { - auto* vec = b.ir.types.Get(b.ir.types.f16(), 2u); + auto* vec = mod.Types().Get(mod.Types().f16(), 2u); auto id = generator_.Type(vec); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), @@ -97,7 +97,7 @@ TEST_F(SpvGeneratorImplTest, Type_Vec4h) { } TEST_F(SpvGeneratorImplTest, Type_Vec4Bool) { - auto* vec = b.ir.types.Get(b.ir.types.bool_(), 4u); + auto* vec = mod.Types().Get(mod.Types().bool_(), 4u); auto id = generator_.Type(vec); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), @@ -106,7 +106,7 @@ TEST_F(SpvGeneratorImplTest, Type_Vec4Bool) { } TEST_F(SpvGeneratorImplTest, Type_Mat2x3f) { - auto* vec = b.ir.types.mat2x3(b.ir.types.f32()); + auto* vec = mod.Types().mat2x3(mod.Types().f32()); auto id = generator_.Type(vec); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), @@ -116,7 +116,7 @@ TEST_F(SpvGeneratorImplTest, Type_Mat2x3f) { } TEST_F(SpvGeneratorImplTest, Type_Mat4x2h) { - auto* vec = b.ir.types.mat4x2(b.ir.types.f16()); + auto* vec = mod.Types().mat4x2(mod.Types().f16()); auto id = generator_.Type(vec); EXPECT_EQ(id, 1u); EXPECT_EQ(DumpTypes(), @@ -128,10 +128,10 @@ TEST_F(SpvGeneratorImplTest, Type_Mat4x2h) { // Test that we can emit multiple types. // Includes types with the same opcode but different parameters. TEST_F(SpvGeneratorImplTest, Type_Multiple) { - EXPECT_EQ(generator_.Type(mod.types.i32()), 1u); - EXPECT_EQ(generator_.Type(mod.types.u32()), 2u); - EXPECT_EQ(generator_.Type(mod.types.f32()), 3u); - EXPECT_EQ(generator_.Type(mod.types.f16()), 4u); + EXPECT_EQ(generator_.Type(mod.Types().i32()), 1u); + EXPECT_EQ(generator_.Type(mod.Types().u32()), 2u); + EXPECT_EQ(generator_.Type(mod.Types().f32()), 3u); + EXPECT_EQ(generator_.Type(mod.Types().f16()), 4u); EXPECT_EQ(DumpTypes(), R"(%1 = OpTypeInt 32 1 %2 = OpTypeInt 32 0 %3 = OpTypeFloat 32 @@ -141,7 +141,7 @@ TEST_F(SpvGeneratorImplTest, Type_Multiple) { // Test that we do not emit the same type more than once. TEST_F(SpvGeneratorImplTest, Type_Deduplicate) { - auto* i32 = mod.types.i32(); + auto* i32 = mod.Types().i32(); EXPECT_EQ(generator_.Type(i32), 1u); EXPECT_EQ(generator_.Type(i32), 1u); EXPECT_EQ(generator_.Type(i32), 1u); diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc index 814ef05295..4a42bdc664 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir_var_test.cc @@ -21,10 +21,10 @@ namespace tint::writer::spirv { namespace { TEST_F(SpvGeneratorImplTest, FunctionVar_NoInit) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); - auto* ty = mod.types.Get(mod.types.i32(), builtin::AddressSpace::kFunction, - builtin::Access::kReadWrite); + auto* ty = mod.Types().Get(mod.Types().i32(), builtin::AddressSpace::kFunction, + builtin::Access::kReadWrite); func->StartTarget()->SetInstructions(utils::Vector{b.Declare(ty), b.Branch(func->EndTarget())}); generator_.EmitFunction(func); @@ -42,10 +42,10 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, FunctionVar_WithInit) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); - auto* ty = mod.types.Get(mod.types.i32(), builtin::AddressSpace::kFunction, - builtin::Access::kReadWrite); + auto* ty = mod.Types().Get(mod.Types().i32(), builtin::AddressSpace::kFunction, + builtin::Access::kReadWrite); auto* v = b.Declare(ty); v->SetInitializer(b.Constant(42_i)); @@ -68,10 +68,10 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, FunctionVar_Name) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); - auto* ty = mod.types.Get(mod.types.i32(), builtin::AddressSpace::kFunction, - builtin::Access::kReadWrite); + auto* ty = mod.Types().Get(mod.Types().i32(), builtin::AddressSpace::kFunction, + builtin::Access::kReadWrite); auto* v = b.Declare(ty); func->StartTarget()->SetInstructions(utils::Vector{v, b.Branch(func->EndTarget())}); mod.SetName(v, "myvar"); @@ -92,10 +92,10 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, FunctionVar_DeclInsideBlock) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); - auto* ty = mod.types.Get(mod.types.i32(), builtin::AddressSpace::kFunction, - builtin::Access::kReadWrite); + auto* ty = mod.Types().Get(mod.Types().i32(), builtin::AddressSpace::kFunction, + builtin::Access::kReadWrite); auto* v = b.Declare(ty); v->SetInitializer(b.Constant(42_i)); @@ -132,11 +132,11 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, FunctionVar_Load) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); - auto* store_ty = mod.types.i32(); - auto* ty = mod.types.Get(store_ty, builtin::AddressSpace::kFunction, - builtin::Access::kReadWrite); + auto* store_ty = mod.Types().i32(); + auto* ty = mod.Types().Get(store_ty, builtin::AddressSpace::kFunction, + builtin::Access::kReadWrite); auto* v = b.Declare(ty); func->StartTarget()->SetInstructions(utils::Vector{v, b.Load(v), b.Branch(func->EndTarget())}); @@ -156,10 +156,10 @@ OpFunctionEnd } TEST_F(SpvGeneratorImplTest, FunctionVar_Store) { - auto* func = b.CreateFunction("foo", mod.types.void_()); + auto* func = b.CreateFunction("foo", mod.Types().void_()); - auto* ty = mod.types.Get(mod.types.i32(), builtin::AddressSpace::kFunction, - builtin::Access::kReadWrite); + auto* ty = mod.Types().Get(mod.Types().i32(), builtin::AddressSpace::kFunction, + builtin::Access::kReadWrite); auto* v = b.Declare(ty); func->StartTarget()->SetInstructions( utils::Vector{v, b.Store(v, b.Constant(42_i)), b.Branch(func->EndTarget())});