[ir] Deduplicate constants

This CL updates the IR builder to deduplicate constants such that for a
given constant value only a single `ir::Constant` will be created.

Bug: tint:1935
Change-Id: Ia743cdb7782cf7ea9918b913dac70b0a3dde4499
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/133241
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2023-05-18 14:51:54 +00:00 committed by Dawn LUCI CQ
parent fe58d80871
commit 97744832bc
7 changed files with 138 additions and 33 deletions

View File

@ -110,14 +110,14 @@ class Builder {
template <typename T, typename... ARGS>
utils::traits::EnableIf<utils::traits::IsTypeOrDerived<T, constant::Value>, const T>* create(
ARGS&&... args) {
return ir.constants.Create<T>(std::forward<ARGS>(args)...);
return ir.constants_arena.Create<T>(std::forward<ARGS>(args)...);
}
/// Creates a new ir::Constant
/// @param val the constant value
/// @returns the new constant
ir::Constant* Constant(const constant::Value* val) {
return ir.values.Create<ir::Constant>(val);
return ir.constants.GetOrCreate(val, [&]() { return ir.values.Create<ir::Constant>(val); });
}
/// Creates a ir::Constant for an i32 Scalar

View File

@ -146,7 +146,7 @@ class Impl {
/* src */ {&program_->Symbols()},
/* dst */ {&builder_.ir.symbols, &builder_.ir.types},
},
/* dst */ {&builder_.ir.constants},
/* dst */ {&builder_.ir.constants_arena},
};
/// The stack of flow control blocks.

View File

@ -70,6 +70,29 @@ TEST_F(IR_BuilderImplTest, EmitLiteral_Bool_False) {
EXPECT_FALSE(val->As<constant::Scalar<bool>>()->ValueAs<bool>());
}
TEST_F(IR_BuilderImplTest, EmitLiteral_Bool_Deduped) {
GlobalVar("a", ty.bool_(), builtin::AddressSpace::kPrivate, Expr(true));
GlobalVar("b", ty.bool_(), builtin::AddressSpace::kPrivate, Expr(false));
GlobalVar("c", ty.bool_(), builtin::AddressSpace::kPrivate, Expr(true));
GlobalVar("d", ty.bool_(), builtin::AddressSpace::kPrivate, Expr(false));
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
auto* var_a = m.Get().root_block->instructions[0]->As<ir::Var>();
ASSERT_NE(var_a, nullptr);
auto* var_b = m.Get().root_block->instructions[1]->As<ir::Var>();
ASSERT_NE(var_b, nullptr);
auto* var_c = m.Get().root_block->instructions[2]->As<ir::Var>();
ASSERT_NE(var_c, nullptr);
auto* var_d = m.Get().root_block->instructions[3]->As<ir::Var>();
ASSERT_NE(var_d, nullptr);
ASSERT_EQ(var_a->initializer, var_c->initializer);
ASSERT_EQ(var_b->initializer, var_d->initializer);
ASSERT_NE(var_a->initializer, var_b->initializer);
}
TEST_F(IR_BuilderImplTest, EmitLiteral_F32) {
auto* expr = Expr(1.2_f);
GlobalVar("a", ty.f32(), builtin::AddressSpace::kPrivate, expr);
@ -84,6 +107,25 @@ TEST_F(IR_BuilderImplTest, EmitLiteral_F32) {
EXPECT_EQ(1.2_f, val->As<constant::Scalar<f32>>()->ValueAs<f32>());
}
TEST_F(IR_BuilderImplTest, EmitLiteral_F32_Deduped) {
GlobalVar("a", ty.f32(), builtin::AddressSpace::kPrivate, Expr(1.2_f));
GlobalVar("b", ty.f32(), builtin::AddressSpace::kPrivate, Expr(1.25_f));
GlobalVar("c", ty.f32(), builtin::AddressSpace::kPrivate, Expr(1.2_f));
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
auto* var_a = m.Get().root_block->instructions[0]->As<ir::Var>();
ASSERT_NE(var_a, nullptr);
auto* var_b = m.Get().root_block->instructions[1]->As<ir::Var>();
ASSERT_NE(var_b, nullptr);
auto* var_c = m.Get().root_block->instructions[2]->As<ir::Var>();
ASSERT_NE(var_c, nullptr);
ASSERT_EQ(var_a->initializer, var_c->initializer);
ASSERT_NE(var_a->initializer, var_b->initializer);
}
TEST_F(IR_BuilderImplTest, EmitLiteral_F16) {
Enable(builtin::Extension::kF16);
auto* expr = Expr(1.2_h);
@ -99,6 +141,26 @@ TEST_F(IR_BuilderImplTest, EmitLiteral_F16) {
EXPECT_EQ(1.2_h, val->As<constant::Scalar<f16>>()->ValueAs<f32>());
}
TEST_F(IR_BuilderImplTest, EmitLiteral_F16_Deduped) {
Enable(builtin::Extension::kF16);
GlobalVar("a", ty.f16(), builtin::AddressSpace::kPrivate, Expr(1.2_h));
GlobalVar("b", ty.f16(), builtin::AddressSpace::kPrivate, Expr(1.25_h));
GlobalVar("c", ty.f16(), builtin::AddressSpace::kPrivate, Expr(1.2_h));
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
auto* var_a = m.Get().root_block->instructions[0]->As<ir::Var>();
ASSERT_NE(var_a, nullptr);
auto* var_b = m.Get().root_block->instructions[1]->As<ir::Var>();
ASSERT_NE(var_b, nullptr);
auto* var_c = m.Get().root_block->instructions[2]->As<ir::Var>();
ASSERT_NE(var_c, nullptr);
ASSERT_EQ(var_a->initializer, var_c->initializer);
ASSERT_NE(var_a->initializer, var_b->initializer);
}
TEST_F(IR_BuilderImplTest, EmitLiteral_I32) {
auto* expr = Expr(-2_i);
GlobalVar("a", ty.i32(), builtin::AddressSpace::kPrivate, expr);
@ -113,6 +175,25 @@ TEST_F(IR_BuilderImplTest, EmitLiteral_I32) {
EXPECT_EQ(-2_i, val->As<constant::Scalar<i32>>()->ValueAs<f32>());
}
TEST_F(IR_BuilderImplTest, EmitLiteral_I32_Deduped) {
GlobalVar("a", ty.i32(), builtin::AddressSpace::kPrivate, Expr(-2_i));
GlobalVar("b", ty.i32(), builtin::AddressSpace::kPrivate, Expr(2_i));
GlobalVar("c", ty.i32(), builtin::AddressSpace::kPrivate, Expr(-2_i));
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
auto* var_a = m.Get().root_block->instructions[0]->As<ir::Var>();
ASSERT_NE(var_a, nullptr);
auto* var_b = m.Get().root_block->instructions[1]->As<ir::Var>();
ASSERT_NE(var_b, nullptr);
auto* var_c = m.Get().root_block->instructions[2]->As<ir::Var>();
ASSERT_NE(var_c, nullptr);
ASSERT_EQ(var_a->initializer, var_c->initializer);
ASSERT_NE(var_a->initializer, var_b->initializer);
}
TEST_F(IR_BuilderImplTest, EmitLiteral_U32) {
auto* expr = Expr(2_u);
GlobalVar("a", ty.u32(), builtin::AddressSpace::kPrivate, expr);
@ -127,5 +208,24 @@ TEST_F(IR_BuilderImplTest, EmitLiteral_U32) {
EXPECT_EQ(2_u, val->As<constant::Scalar<u32>>()->ValueAs<f32>());
}
TEST_F(IR_BuilderImplTest, EmitLiteral_U32_Deduped) {
GlobalVar("a", ty.u32(), builtin::AddressSpace::kPrivate, Expr(2_u));
GlobalVar("b", ty.u32(), builtin::AddressSpace::kPrivate, Expr(3_u));
GlobalVar("c", ty.u32(), builtin::AddressSpace::kPrivate, Expr(2_u));
auto m = Build();
ASSERT_TRUE(m) << (!m ? m.Failure() : "");
auto* var_a = m.Get().root_block->instructions[0]->As<ir::Var>();
ASSERT_NE(var_a, nullptr);
auto* var_b = m.Get().root_block->instructions[1]->As<ir::Var>();
ASSERT_NE(var_b, nullptr);
auto* var_c = m.Get().root_block->instructions[2]->As<ir::Var>();
ASSERT_NE(var_c, nullptr);
ASSERT_EQ(var_a->initializer, var_c->initializer);
ASSERT_NE(var_a->initializer, var_b->initializer);
}
} // namespace
} // namespace tint::ir

View File

@ -18,6 +18,7 @@
#include <string>
#include "src/tint/constant/value.h"
#include "src/tint/ir/constant.h"
#include "src/tint/ir/function.h"
#include "src/tint/ir/instruction.h"
#include "src/tint/ir/value.h"
@ -67,7 +68,7 @@ class Module {
/// The flow node allocator
utils::BlockAllocator<FlowNode> flow_nodes;
/// The constant allocator
utils::BlockAllocator<constant::Value> constants;
utils::BlockAllocator<constant::Value> constants_arena;
/// The value allocator
utils::BlockAllocator<Value> values;
@ -82,6 +83,29 @@ class Module {
/// The symbol table for the module
SymbolTable symbols{prog_id_};
/// ConstantHasher provides a hash function for a constant::Value pointer, hashing the value
/// instead of the pointer itself.
struct ConstantHasher {
/// @param c the constant pointer to create a hash for
/// @return the hash value
inline std::size_t operator()(const constant::Value* c) const { return c->Hash(); }
};
/// ConstantEquals provides an equality function for two constant::Value pointers, comparing
/// their values instead of the pointers.
struct ConstantEquals {
/// @param a the first constant pointer to compare
/// @param b the second constant pointer to compare
/// @return the hash value
inline bool operator()(const constant::Value* a, const constant::Value* b) const {
return a->Equal(b);
}
};
/// The map of constant::Value to their ir::Constant.
utils::Hashmap<const constant::Value*, ir::Constant*, 16, ConstantHasher, ConstantEquals>
constants;
};
} // namespace tint::ir

View File

@ -157,25 +157,6 @@ class GeneratorImplIr {
}
};
/// ConstantHasher provides a hash function for a constant::Value pointer, hashing the value
/// instead of the pointer itself.
struct ConstantHasher {
/// @param c the constant::Value pointer to create a hash for
/// @return the hash value
inline std::size_t operator()(const constant::Value* c) const { return c->Hash(); }
};
/// ConstantEquals provides an equality function for two constant::Value pointers, comparing
/// their values instead of the pointers.
struct ConstantEquals {
/// @param a the first constant::Value pointer to compare
/// @param b the second constant::Value pointer to compare
/// @return the hash value
inline bool operator()(const constant::Value* a, const constant::Value* b) const {
return a->Equal(b);
}
};
/// The map of types to their result IDs.
utils::Hashmap<const type::Type*, uint32_t, 8> types_;
@ -183,7 +164,7 @@ class GeneratorImplIr {
utils::Hashmap<FunctionType, uint32_t, 8, FunctionType::Hasher> function_types_;
/// The map of constants to their result IDs.
utils::Hashmap<const constant::Value*, uint32_t, 16, ConstantHasher, ConstantEquals> constants_;
utils::Hashmap<const constant::Value*, uint32_t, 16> constants_;
/// The map of instructions to their result IDs.
utils::Hashmap<const ir::Instruction*, uint32_t, 8> instructions_;

View File

@ -155,10 +155,10 @@ TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec2i) {
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
b.Branch(func->start_target, func->end_target);
auto* lhs = mod.constants.Create<constant::Composite>(
auto* lhs = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(mod.types.Get<type::I32>(), 2u),
utils::Vector{b.Constant(42_i)->value, b.Constant(-1_i)->value}, false, false);
auto* rhs = mod.constants.Create<constant::Composite>(
auto* rhs = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(mod.types.Get<type::I32>(), 2u),
utils::Vector{b.Constant(0_i)->value, b.Constant(-43_i)->value}, false, false);
func->start_target->instructions.Push(
@ -189,12 +189,12 @@ TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec4f) {
auto* func = b.CreateFunction(mod.symbols.Register("foo"), mod.types.Get<type::Void>());
b.Branch(func->start_target, func->end_target);
auto* lhs = mod.constants.Create<constant::Composite>(
auto* lhs = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(mod.types.Get<type::F32>(), 4u),
utils::Vector{b.Constant(42_f)->value, b.Constant(-1_f)->value, b.Constant(0_f)->value,
b.Constant(1.25_f)->value},
false, false);
auto* rhs = mod.constants.Create<constant::Composite>(
auto* rhs = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(mod.types.Get<type::F32>(), 4u),
utils::Vector{b.Constant(0_f)->value, b.Constant(1.25_f)->value, b.Constant(-42_f)->value,
b.Constant(1_f)->value},

View File

@ -65,7 +65,7 @@ TEST_F(SpvGeneratorImplTest, Constant_F16) {
TEST_F(SpvGeneratorImplTest, Constant_Vec4Bool) {
auto* t = b.Constant(true);
auto* f = b.Constant(false);
auto* v = mod.constants.Create<constant::Composite>(
auto* v = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(mod.types.Get<type::Bool>(), 4u),
utils::Vector{t->value, f->value, f->value, t->value}, false, true);
generator_.Constant(b.Constant(v));
@ -81,7 +81,7 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec2i) {
auto* i = mod.types.Get<type::I32>();
auto* i_42 = b.Constant(i32(42));
auto* i_n1 = b.Constant(i32(-1));
auto* v = mod.constants.Create<constant::Composite>(
auto* v = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(i, 2u), utils::Vector{i_42->value, i_n1->value}, false, false);
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeInt 32 1
@ -97,7 +97,7 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec3u) {
auto* u_42 = b.Constant(u32(42));
auto* u_0 = b.Constant(u32(0));
auto* u_4b = b.Constant(u32(4000000000));
auto* v = mod.constants.Create<constant::Composite>(
auto* v = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(u, 3u), utils::Vector{u_42->value, u_0->value, u_4b->value},
false, true);
generator_.Constant(b.Constant(v));
@ -116,7 +116,7 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec4f) {
auto* f_0 = b.Constant(f32(0));
auto* f_q = b.Constant(f32(0.25));
auto* f_n1 = b.Constant(f32(-1));
auto* v = mod.constants.Create<constant::Composite>(
auto* v = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(f, 4u),
utils::Vector{f_42->value, f_0->value, f_q->value, f_n1->value}, false, true);
generator_.Constant(b.Constant(v));
@ -134,7 +134,7 @@ TEST_F(SpvGeneratorImplTest, Constant_Vec2h) {
auto* h = mod.types.Get<type::F16>();
auto* h_42 = b.Constant(f16(42));
auto* h_q = b.Constant(f16(0.25));
auto* v = mod.constants.Create<constant::Composite>(
auto* v = mod.constants_arena.Create<constant::Composite>(
mod.types.Get<type::Vector>(h, 2u), utils::Vector{h_42->value, h_q->value}, false, false);
generator_.Constant(b.Constant(v));
EXPECT_EQ(DumpTypes(), R"(%3 = OpTypeFloat 16