[ir] Cleanup composite creation in tests

This CL adds some helpers to make composites easier to use in tests.

Bug: tint:1718
Change-Id: I16a0e94978c43efa619b31b6815089c8fff6983f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/133920
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
dan sinclair 2023-05-24 00:27:17 +00:00 committed by Dawn LUCI CQ
parent 4e2be2d083
commit 07a1d65fbc
4 changed files with 78 additions and 48 deletions

View File

@ -46,6 +46,7 @@
#include "src/tint/type/f32.h" #include "src/tint/type/f32.h"
#include "src/tint/type/i32.h" #include "src/tint/type/i32.h"
#include "src/tint/type/u32.h" #include "src/tint/type/u32.h"
#include "src/tint/type/vector.h"
#include "src/tint/type/void.h" #include "src/tint/type/void.h"
namespace tint::ir { namespace tint::ir {
@ -119,6 +120,41 @@ class Builder {
return ir.constants_arena.Create<T>(std::forward<ARGS>(args)...); return ir.constants_arena.Create<T>(std::forward<ARGS>(args)...);
} }
/// @param v the value
/// @returns the constant value
const constant::Value* Bool(bool v) {
// TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
return Constant(create<constant::Scalar<bool>>(ir.types.Get<type::Bool>(), v))->Value();
}
/// @param v the value
/// @returns the constant value
const constant::Value* U32(uint32_t v) {
// TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
return Constant(create<constant::Scalar<u32>>(ir.types.Get<type::U32>(), u32(v)))->Value();
}
/// @param v the value
/// @returns the constant value
const constant::Value* I32(int32_t v) {
// TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
return Constant(create<constant::Scalar<i32>>(ir.types.Get<type::I32>(), i32(v)))->Value();
}
/// @param v the value
/// @returns the constant value
const constant::Value* F16(float v) {
// TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
return Constant(create<constant::Scalar<f16>>(ir.types.Get<type::F16>(), f16(v)))->Value();
}
/// @param v the value
/// @returns the constant value
const constant::Value* F32(float v) {
// TODO(dsinclair): Replace when constant::Value is uniqed by the arena.
return Constant(create<constant::Scalar<f32>>(ir.types.Get<type::F32>(), f32(v)))->Value();
}
/// Creates a new ir::Constant /// Creates a new ir::Constant
/// @param val the constant value /// @param val the constant value
/// @returns the new constant /// @returns the new constant

View File

@ -18,6 +18,7 @@
#include <utility> #include <utility>
#include "src/tint/type/type.h" #include "src/tint/type/type.h"
#include "src/tint/type/vector.h"
#include "src/tint/utils/hash.h" #include "src/tint/utils/hash.h"
#include "src/tint/utils/unique_allocator.h" #include "src/tint/utils/unique_allocator.h"
@ -84,6 +85,23 @@ class Manager final {
return types_.Find<TYPE>(std::forward<ARGS>(args)...); return types_.Find<TYPE>(std::forward<ARGS>(args)...);
} }
/// @param inner the inner type
/// @param size the vector size
/// @returns the vector type
type::Type* vec(type::Type* inner, uint32_t size) { return Get<type::Vector>(inner, size); }
/// @param inner the inner type
/// @returns the vector type
type::Type* vec2(type::Type* inner) { return vec(inner, 2); }
/// @param inner the inner type
/// @returns the vector type
type::Type* vec3(type::Type* inner) { return vec(inner, 3); }
/// @param inner the inner type
/// @returns the vector type
type::Type* vec4(type::Type* inner) { return vec(inner, 4); }
/// @returns an iterator to the beginning of the types /// @returns an iterator to the beginning of the types
TypeIterator begin() const { return types_.begin(); } TypeIterator begin() const { return types_.begin(); }
/// @returns an iterator to the end of the types /// @returns an iterator to the end of the types

View File

@ -147,12 +147,10 @@ OpFunctionEnd
TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec2i) { TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec2i) {
auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>()); auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
auto* lhs = mod.constants_arena.Create<constant::Composite>( auto* lhs = b.create<constant::Composite>(mod.types.vec2(mod.types.Get<type::I32>()),
mod.types.Get<type::Vector>(mod.types.Get<type::I32>(), 2u), utils::Vector{b.I32(42), b.I32(-1)}, false, false);
utils::Vector{b.Constant(42_i)->Value(), b.Constant(-1_i)->Value()}, false, false); auto* rhs = b.create<constant::Composite>(mod.types.vec2(mod.types.Get<type::I32>()),
auto* rhs = mod.constants_arena.Create<constant::Composite>( utils::Vector{b.I32(0), b.I32(-43)}, false, false);
mod.types.Get<type::Vector>(mod.types.Get<type::I32>(), 2u),
utils::Vector{b.Constant(0_i)->Value(), b.Constant(-43_i)->Value()}, false, false);
func->StartTarget()->SetInstructions( func->StartTarget()->SetInstructions(
utils::Vector{b.Subtract(mod.types.Get<type::Vector>(mod.types.Get<type::I32>(), 2u), utils::Vector{b.Subtract(mod.types.Get<type::Vector>(mod.types.Get<type::I32>(), 2u),
b.Constant(lhs), b.Constant(rhs)), b.Constant(lhs), b.Constant(rhs)),
@ -180,16 +178,12 @@ OpFunctionEnd
TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec4f) { TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec4f) {
auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>()); auto* func = b.CreateFunction("foo", mod.types.Get<type::Void>());
auto* lhs = mod.constants_arena.Create<constant::Composite>( auto* lhs = b.create<constant::Composite>(
mod.types.Get<type::Vector>(mod.types.Get<type::F32>(), 4u), mod.types.vec4(mod.types.Get<type::F32>()),
utils::Vector{b.Constant(42_f)->Value(), b.Constant(-1_f)->Value(), utils::Vector{b.F32(42), b.F32(-1), b.F32(0), b.F32(1.25)}, false, false);
b.Constant(0_f)->Value(), b.Constant(1.25_f)->Value()}, auto* rhs = b.create<constant::Composite>(
false, false); mod.types.vec4(mod.types.Get<type::F32>()),
auto* rhs = mod.constants_arena.Create<constant::Composite>( utils::Vector{b.F32(0), b.F32(1.25), b.F32(-42), b.F32(1)}, false, false);
mod.types.Get<type::Vector>(mod.types.Get<type::F32>(), 4u),
utils::Vector{b.Constant(0_f)->Value(), b.Constant(1.25_f)->Value(),
b.Constant(-42_f)->Value(), b.Constant(1_f)->Value()},
false, false);
func->StartTarget()->SetInstructions( func->StartTarget()->SetInstructions(
utils::Vector{b.Subtract(mod.types.Get<type::Vector>(mod.types.Get<type::F32>(), 4u), utils::Vector{b.Subtract(mod.types.Get<type::Vector>(mod.types.Get<type::F32>(), 4u),
b.Constant(lhs), b.Constant(rhs)), b.Constant(lhs), b.Constant(rhs)),

View File

@ -63,11 +63,10 @@ TEST_F(SpvGeneratorImplTest, Constant_F16) {
} }
TEST_F(SpvGeneratorImplTest, Constant_Vec4Bool) { TEST_F(SpvGeneratorImplTest, Constant_Vec4Bool) {
auto* t = b.Constant(true); auto* v = b.create<constant::Composite>(
auto* f = b.Constant(false); mod.types.vec4(mod.types.Get<type::Bool>()),
auto* v = mod.constants_arena.Create<constant::Composite>( utils::Vector{b.Bool(true), b.Bool(false), b.Bool(false), b.Bool(true)}, false, true);
mod.types.Get<type::Vector>(mod.types.Get<type::Bool>(), 4u),
utils::Vector{t->Value(), f->Value(), f->Value(), t->Value()}, false, true);
generator_.Constant(b.Constant(v)); generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeBool EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeBool
%2 = OpTypeVector %3 4 %2 = OpTypeVector %3 4
@ -78,12 +77,8 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec4Bool) {
} }
TEST_F(SpvGeneratorImplTest, Constant_Vec2i) { TEST_F(SpvGeneratorImplTest, Constant_Vec2i) {
auto* i = mod.types.Get<type::I32>(); auto* v = b.create<constant::Composite>(mod.types.vec2(mod.types.Get<type::I32>()),
auto* i_42 = b.Constant(i32(42)); utils::Vector{b.I32(42), b.I32(-1)}, false, false);
auto* i_n1 = b.Constant(i32(-1));
auto* v = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(i, 2u), utils::Vector{i_42->Value(), i_n1->Value()}, false,
false);
generator_.Constant(b.Constant(v)); generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 1 EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 1
%2 = OpTypeVector %3 2 %2 = OpTypeVector %3 2
@ -94,13 +89,9 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec2i) {
} }
TEST_F(SpvGeneratorImplTest, Constant_Vec3u) { TEST_F(SpvGeneratorImplTest, Constant_Vec3u) {
auto* u = mod.types.Get<type::U32>(); auto* v = b.create<constant::Composite>(mod.types.vec3(mod.types.Get<type::U32>()),
auto* u_42 = b.Constant(u32(42)); utils::Vector{b.U32(42), b.U32(0), b.U32(4000000000)},
auto* u_0 = b.Constant(u32(0)); false, true);
auto* u_4b = b.Constant(u32(4000000000));
auto* v = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(u, 3u),
utils::Vector{u_42->Value(), u_0->Value(), u_4b->Value()}, false, true);
generator_.Constant(b.Constant(v)); generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 0 EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 0
%2 = OpTypeVector %3 3 %2 = OpTypeVector %3 3
@ -112,14 +103,9 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec3u) {
} }
TEST_F(SpvGeneratorImplTest, Constant_Vec4f) { TEST_F(SpvGeneratorImplTest, Constant_Vec4f) {
auto* f = mod.types.Get<type::F32>(); auto* v = b.create<constant::Composite>(
auto* f_42 = b.Constant(f32(42)); mod.types.vec4(mod.types.Get<type::F32>()),
auto* f_0 = b.Constant(f32(0)); utils::Vector{b.F32(42), b.F32(0), b.F32(0.25), b.F32(-1)}, false, true);
auto* f_q = b.Constant(f32(0.25));
auto* f_n1 = b.Constant(f32(-1));
auto* v = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(f, 4u),
utils::Vector{f_42->Value(), f_0->Value(), f_q->Value(), f_n1->Value()}, false, true);
generator_.Constant(b.Constant(v)); generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 32 EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 32
%2 = OpTypeVector %3 4 %2 = OpTypeVector %3 4
@ -132,12 +118,8 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec4f) {
} }
TEST_F(SpvGeneratorImplTest, Constant_Vec2h) { TEST_F(SpvGeneratorImplTest, Constant_Vec2h) {
auto* h = mod.types.Get<type::F16>(); auto* v = b.create<constant::Composite>(mod.types.vec2(mod.types.Get<type::F16>()),
auto* h_42 = b.Constant(f16(42)); utils::Vector{b.F16(42), b.F16(0.25)}, false, false);
auto* h_q = b.Constant(f16(0.25));
auto* v = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(h, 2u), utils::Vector{h_42->Value(), h_q->Value()}, false,
false);
generator_.Constant(b.Constant(v)); generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 16 EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 16
%2 = OpTypeVector %3 2 %2 = OpTypeVector %3 2