diff --git a/src/tint/ir/builder.h b/src/tint/ir/builder.h index 976d86687e..78bcb9b4a9 100644 --- a/src/tint/ir/builder.h +++ b/src/tint/ir/builder.h @@ -46,6 +46,7 @@ #include "src/tint/type/f32.h" #include "src/tint/type/i32.h" #include "src/tint/type/u32.h" +#include "src/tint/type/vector.h" #include "src/tint/type/void.h" namespace tint::ir { @@ -119,6 +120,41 @@ class Builder { return ir.constants_arena.Create(std::forward(args)...); } + /// @param v the value + /// @returns the constant value + const constant::Value* Bool(bool v) { + // TODO(dsinclair): Replace when constant::Value is uniqed by the arena. + return Constant(create>(ir.types.Get(), v))->Value(); + } + + /// @param v the value + /// @returns the constant value + const constant::Value* U32(uint32_t v) { + // TODO(dsinclair): Replace when constant::Value is uniqed by the arena. + return Constant(create>(ir.types.Get(), u32(v)))->Value(); + } + + /// @param v the value + /// @returns the constant value + const constant::Value* I32(int32_t v) { + // TODO(dsinclair): Replace when constant::Value is uniqed by the arena. + return Constant(create>(ir.types.Get(), i32(v)))->Value(); + } + + /// @param v the value + /// @returns the constant value + const constant::Value* F16(float v) { + // TODO(dsinclair): Replace when constant::Value is uniqed by the arena. + return Constant(create>(ir.types.Get(), f16(v)))->Value(); + } + + /// @param v the value + /// @returns the constant value + const constant::Value* F32(float v) { + // TODO(dsinclair): Replace when constant::Value is uniqed by the arena. + return Constant(create>(ir.types.Get(), f32(v)))->Value(); + } + /// Creates a new ir::Constant /// @param val the constant value /// @returns the new constant diff --git a/src/tint/type/manager.h b/src/tint/type/manager.h index 0650f1bb4a..4eb48bb1dc 100644 --- a/src/tint/type/manager.h +++ b/src/tint/type/manager.h @@ -18,6 +18,7 @@ #include #include "src/tint/type/type.h" +#include "src/tint/type/vector.h" #include "src/tint/utils/hash.h" #include "src/tint/utils/unique_allocator.h" @@ -84,6 +85,23 @@ class Manager final { return types_.Find(std::forward(args)...); } + /// @param inner the inner type + /// @param size the vector size + /// @returns the vector type + type::Type* vec(type::Type* inner, uint32_t size) { return Get(inner, size); } + + /// @param inner the inner type + /// @returns the vector type + type::Type* vec2(type::Type* inner) { return vec(inner, 2); } + + /// @param inner the inner type + /// @returns the vector type + type::Type* vec3(type::Type* inner) { return vec(inner, 3); } + + /// @param inner the inner type + /// @returns the vector type + type::Type* vec4(type::Type* inner) { return vec(inner, 4); } + /// @returns an iterator to the beginning of the types TypeIterator begin() const { return types_.begin(); } /// @returns an iterator to the end of the types diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc index 6c131c5541..879223952d 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc @@ -147,12 +147,10 @@ OpFunctionEnd TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec2i) { auto* func = b.CreateFunction("foo", mod.types.Get()); - auto* lhs = mod.constants_arena.Create( - mod.types.Get(mod.types.Get(), 2u), - utils::Vector{b.Constant(42_i)->Value(), b.Constant(-1_i)->Value()}, false, false); - auto* rhs = mod.constants_arena.Create( - mod.types.Get(mod.types.Get(), 2u), - utils::Vector{b.Constant(0_i)->Value(), b.Constant(-43_i)->Value()}, false, false); + auto* lhs = b.create(mod.types.vec2(mod.types.Get()), + utils::Vector{b.I32(42), b.I32(-1)}, false, false); + auto* rhs = b.create(mod.types.vec2(mod.types.Get()), + utils::Vector{b.I32(0), b.I32(-43)}, false, false); func->StartTarget()->SetInstructions( utils::Vector{b.Subtract(mod.types.Get(mod.types.Get(), 2u), b.Constant(lhs), b.Constant(rhs)), @@ -180,16 +178,12 @@ OpFunctionEnd TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec4f) { auto* func = b.CreateFunction("foo", mod.types.Get()); - auto* lhs = mod.constants_arena.Create( - mod.types.Get(mod.types.Get(), 4u), - utils::Vector{b.Constant(42_f)->Value(), b.Constant(-1_f)->Value(), - b.Constant(0_f)->Value(), b.Constant(1.25_f)->Value()}, - false, false); - auto* rhs = mod.constants_arena.Create( - mod.types.Get(mod.types.Get(), 4u), - utils::Vector{b.Constant(0_f)->Value(), b.Constant(1.25_f)->Value(), - b.Constant(-42_f)->Value(), b.Constant(1_f)->Value()}, - false, false); + auto* lhs = b.create( + mod.types.vec4(mod.types.Get()), + utils::Vector{b.F32(42), b.F32(-1), b.F32(0), b.F32(1.25)}, false, false); + auto* rhs = b.create( + mod.types.vec4(mod.types.Get()), + utils::Vector{b.F32(0), b.F32(1.25), b.F32(-42), b.F32(1)}, false, false); func->StartTarget()->SetInstructions( utils::Vector{b.Subtract(mod.types.Get(mod.types.Get(), 4u), b.Constant(lhs), b.Constant(rhs)), diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc index 6ab48aa2a2..75775ff9a2 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir_constant_test.cc @@ -63,11 +63,10 @@ TEST_F(SpvGeneratorImplTest, Constant_F16) { } TEST_F(SpvGeneratorImplTest, Constant_Vec4Bool) { - auto* t = b.Constant(true); - auto* f = b.Constant(false); - auto* v = mod.constants_arena.Create( - mod.types.Get(mod.types.Get(), 4u), - utils::Vector{t->Value(), f->Value(), f->Value(), t->Value()}, false, true); + auto* v = b.create( + mod.types.vec4(mod.types.Get()), + utils::Vector{b.Bool(true), b.Bool(false), b.Bool(false), b.Bool(true)}, false, true); + generator_.Constant(b.Constant(v)); EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeBool %2 = OpTypeVector %3 4 @@ -78,12 +77,8 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec4Bool) { } TEST_F(SpvGeneratorImplTest, Constant_Vec2i) { - auto* i = mod.types.Get(); - auto* i_42 = b.Constant(i32(42)); - auto* i_n1 = b.Constant(i32(-1)); - auto* v = mod.constants_arena.Create( - mod.types.Get(i, 2u), utils::Vector{i_42->Value(), i_n1->Value()}, false, - false); + auto* v = b.create(mod.types.vec2(mod.types.Get()), + utils::Vector{b.I32(42), b.I32(-1)}, false, false); generator_.Constant(b.Constant(v)); EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 1 %2 = OpTypeVector %3 2 @@ -94,13 +89,9 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec2i) { } TEST_F(SpvGeneratorImplTest, Constant_Vec3u) { - auto* u = mod.types.Get(); - auto* u_42 = b.Constant(u32(42)); - auto* u_0 = b.Constant(u32(0)); - auto* u_4b = b.Constant(u32(4000000000)); - auto* v = mod.constants_arena.Create( - mod.types.Get(u, 3u), - utils::Vector{u_42->Value(), u_0->Value(), u_4b->Value()}, false, true); + auto* v = b.create(mod.types.vec3(mod.types.Get()), + utils::Vector{b.U32(42), b.U32(0), b.U32(4000000000)}, + false, true); generator_.Constant(b.Constant(v)); EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 0 %2 = OpTypeVector %3 3 @@ -112,14 +103,9 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec3u) { } TEST_F(SpvGeneratorImplTest, Constant_Vec4f) { - auto* f = mod.types.Get(); - auto* f_42 = b.Constant(f32(42)); - auto* f_0 = b.Constant(f32(0)); - auto* f_q = b.Constant(f32(0.25)); - auto* f_n1 = b.Constant(f32(-1)); - auto* v = mod.constants_arena.Create( - mod.types.Get(f, 4u), - utils::Vector{f_42->Value(), f_0->Value(), f_q->Value(), f_n1->Value()}, false, true); + auto* v = b.create( + mod.types.vec4(mod.types.Get()), + utils::Vector{b.F32(42), b.F32(0), b.F32(0.25), b.F32(-1)}, false, true); generator_.Constant(b.Constant(v)); EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 32 %2 = OpTypeVector %3 4 @@ -132,12 +118,8 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec4f) { } TEST_F(SpvGeneratorImplTest, Constant_Vec2h) { - auto* h = mod.types.Get(); - auto* h_42 = b.Constant(f16(42)); - auto* h_q = b.Constant(f16(0.25)); - auto* v = mod.constants_arena.Create( - mod.types.Get(h, 2u), utils::Vector{h_42->Value(), h_q->Value()}, false, - false); + auto* v = b.create(mod.types.vec2(mod.types.Get()), + utils::Vector{b.F16(42), b.F16(0.25)}, false, false); generator_.Constant(b.Constant(v)); EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 16 %2 = OpTypeVector %3 2