Move CreateComposite into ProgramBuilder.

This CL moves the CreateComposite helper into the ProgramBuilder.

Bug: tint:1718
Change-Id: I4aca7dc3d7192a7aa8b300f00529670aa9c09a27
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114202
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
dan sinclair 2022-12-15 21:56:32 +00:00 committed by Dawn LUCI CQ
parent 19ebcb2230
commit 02b466feb1
2 changed files with 100 additions and 77 deletions

View File

@ -87,6 +87,8 @@
#include "src/tint/ast/void.h"
#include "src/tint/ast/while_statement.h"
#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/constant/composite.h"
#include "src/tint/constant/splat.h"
#include "src/tint/constant/value.h"
#include "src/tint/number.h"
#include "src/tint/program.h"
@ -470,11 +472,73 @@ class ProgramBuilder {
/// @param args the arguments to pass to the constructor
/// @returns the node pointer
template <typename T, typename... ARGS>
traits::EnableIf<traits::IsTypeOrDerived<T, constant::Value>, T>* create(ARGS&&... args) {
traits::EnableIf<traits::IsTypeOrDerived<T, constant::Value> &&
!traits::IsTypeOrDerived<T, constant::Composite> &&
!traits::IsTypeOrDerived<T, constant::Splat>,
T>*
create(ARGS&&... args) {
AssertNotMoved();
return constant_nodes_.Create<T>(std::forward<ARGS>(args)...);
}
/// Constructs a constant of a vector, matrix or array type.
///
/// Examines the element values and will return either a constant::Composite or a
/// constant::Splat, depending on the element types and values.
///
/// @param type the composite type
/// @param elements the composite elements
/// @returns the node pointer
template <typename T>
traits::EnableIf<traits::IsTypeOrDerived<T, constant::Composite> ||
traits::IsTypeOrDerived<T, constant::Splat>,
const constant::Value>*
create(const type::Type* type, utils::VectorRef<const constant::Value*> elements) {
AssertNotMoved();
if (elements.IsEmpty()) {
return nullptr;
}
bool any_zero = false;
bool all_zero = true;
bool all_equal = true;
auto* first = elements.Front();
for (auto* el : elements) {
if (!el) {
return nullptr;
}
if (!any_zero && el->AnyZero()) {
any_zero = true;
}
if (all_zero && !el->AllZero()) {
all_zero = false;
}
if (all_equal && el != first) {
if (!el->Equal(first)) {
all_equal = false;
}
}
}
if (all_equal) {
return create<constant::Splat>(type, elements[0], elements.Length());
}
return constant_nodes_.Create<constant::Composite>(type, std::move(elements), all_zero,
any_zero);
}
/// Constructs a splat constant.
/// @param type the splat type
/// @param element the splat element
/// @param n the number of elements
/// @returns the node pointer
template <typename T>
traits::EnableIf<traits::IsTypeOrDerived<T, constant::Splat>, const constant::Splat>*
create(const type::Type* type, const constant::Value* element, size_t n) {
AssertNotMoved();
return constant_nodes_.Create<constant::Splat>(type, element, n);
}
/// Creates a new type::Type owned by the ProgramBuilder.
/// When the ProgramBuilder is destructed, owned ProgramBuilder and the
/// returned `Type` will also be destructed.

View File

@ -232,11 +232,6 @@ std::make_unsigned_t<T> CountTrailingBits(T e, T bit_value_to_count) {
return count;
}
// Forward declaration
const constant::Value* CreateComposite(ProgramBuilder& builder,
const type::Type* type,
utils::VectorRef<const constant::Value*> elements);
template <typename T>
ConstEval::Result ScalarConvert(const constant::Scalar<T>* scalar,
ProgramBuilder& builder,
@ -347,7 +342,7 @@ ConstEval::Result CompositeConvert(const constant::Composite* composite,
}
conv_els.Push(conv_el.Get());
}
return CreateComposite(builder, target_ty, std::move(conv_els));
return builder.create<constant::Composite>(target_ty, std::move(conv_els));
}
ConstEval::Result ConvertInternal(const constant::Value* c,
@ -438,7 +433,7 @@ const constant::Value* ZeroValue(ProgramBuilder& builder, const type::Type* type
// All members were of the same type, so the zero value is the same for all members.
return builder.create<constant::Splat>(type, zeros[0], s->Members().Length());
}
return CreateComposite(builder, s, std::move(zeros));
return builder.create<constant::Composite>(s, std::move(zeros));
},
[&](Default) -> const constant::Value* {
return ZeroTypeDispatch(type, [&](auto zero) -> const constant::Value* {
@ -449,42 +444,6 @@ const constant::Value* ZeroValue(ProgramBuilder& builder, const type::Type* type
});
}
/// CreateComposite is used to construct a constant of a vector, matrix or array type.
/// CreateComposite examines the element values and will return either a Composite or a Splat,
/// depending on the element types and values.
const constant::Value* CreateComposite(ProgramBuilder& builder,
const type::Type* type,
utils::VectorRef<const constant::Value*> elements) {
if (elements.IsEmpty()) {
return nullptr;
}
bool any_zero = false;
bool all_zero = true;
bool all_equal = true;
auto* first = elements.Front();
for (auto* el : elements) {
if (!el) {
return nullptr;
}
if (!any_zero && el->AnyZero()) {
any_zero = true;
}
if (all_zero && !el->AllZero()) {
all_zero = false;
}
if (all_equal && el != first) {
if (!el->Equal(first)) {
all_equal = false;
}
}
}
if (all_equal) {
return builder.create<constant::Splat>(type, elements[0], elements.Length());
} else {
return builder.create<constant::Composite>(type, std::move(elements), all_zero, any_zero);
}
}
namespace detail {
/// Implementation of TransformElements
template <typename F, typename... CONSTANTS>
@ -515,7 +474,7 @@ ConstEval::Result TransformElements(ProgramBuilder& builder,
return el.Failure();
}
}
return CreateComposite(builder, composite_ty, std::move(els));
return builder.create<constant::Composite>(composite_ty, std::move(els));
}
} // namespace detail
@ -569,7 +528,7 @@ ConstEval::Result TransformBinaryElements(ProgramBuilder& builder,
return el.Failure();
}
}
return CreateComposite(builder, composite_ty, std::move(els));
return builder.create<constant::Composite>(composite_ty, std::move(els));
}
} // namespace
@ -1211,7 +1170,7 @@ ConstEval::Result ConstEval::ArrayOrStructInit(const type::Type* ty,
for (auto* arg : args) {
els.Push(arg->ConstantValue());
}
return CreateComposite(builder, ty, std::move(els));
return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::Conv(const type::Type* ty,
@ -1255,7 +1214,7 @@ ConstEval::Result ConstEval::VecSplat(const type::Type* ty,
ConstEval::Result ConstEval::VecInitS(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source&) {
return CreateComposite(builder, ty, args);
return builder.create<constant::Composite>(ty, args);
}
ConstEval::Result ConstEval::VecInitM(const type::Type* ty,
@ -1281,7 +1240,7 @@ ConstEval::Result ConstEval::VecInitM(const type::Type* ty,
els.Push(val);
}
}
return CreateComposite(builder, ty, std::move(els));
return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::MatInitS(const type::Type* ty,
@ -1296,15 +1255,15 @@ ConstEval::Result ConstEval::MatInitS(const type::Type* ty,
auto i = r + c * m->rows();
column.Push(args[i]);
}
els.Push(CreateComposite(builder, m->ColumnType(), std::move(column)));
els.Push(builder.create<constant::Composite>(m->ColumnType(), std::move(column)));
}
return CreateComposite(builder, ty, std::move(els));
return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::MatInitV(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
const Source&) {
return CreateComposite(builder, ty, args);
return builder.create<constant::Composite>(ty, args);
}
ConstEval::Result ConstEval::Index(const sem::Expression* obj_expr,
@ -1357,7 +1316,7 @@ ConstEval::Result ConstEval::Swizzle(const type::Type* ty,
}
auto values = utils::Transform<4>(
indices, [&](uint32_t i) { return vec_val->Index(static_cast<size_t>(i)); });
return CreateComposite(builder, ty, std::move(values));
return builder.create<constant::Composite>(ty, std::move(values));
}
ConstEval::Result ConstEval::Bitcast(const type::Type*, const sem::Expression*) {
@ -1484,7 +1443,7 @@ ConstEval::Result ConstEval::OpMultiplyMatVec(const type::Type* ty,
}
result.Push(r.Get());
}
return CreateComposite(builder, ty, result);
return builder.create<constant::Composite>(ty, result);
}
ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty,
utils::VectorRef<const constant::Value*> args,
@ -1534,7 +1493,7 @@ ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty,
}
result.Push(r.Get());
}
return CreateComposite(builder, ty, result);
return builder.create<constant::Composite>(ty, result);
}
ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty,
@ -1596,9 +1555,9 @@ ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty,
// Add column vector to matrix
auto* col_vec_ty = ty->As<type::Matrix>()->ColumnType();
result_mat.Push(CreateComposite(builder, col_vec_ty, col_vec));
result_mat.Push(builder.create<constant::Composite>(col_vec_ty, col_vec));
}
return CreateComposite(builder, ty, result_mat);
return builder.create<constant::Composite>(ty, result_mat);
}
ConstEval::Result ConstEval::OpDivide(const type::Type* ty,
@ -2208,8 +2167,8 @@ ConstEval::Result ConstEval::cross(const type::Type* ty,
return utils::Failure;
}
return CreateComposite(builder, ty,
utils::Vector<const constant::Value*, 3>{x.Get(), y.Get(), z.Get()});
return builder.create<constant::Composite>(
ty, utils::Vector<const constant::Value*, 3>{x.Get(), y.Get(), z.Get()});
}
ConstEval::Result ConstEval::degrees(const type::Type* ty,
@ -2592,21 +2551,20 @@ ConstEval::Result ConstEval::frexp(const type::Type* ty,
}
auto fract_ty = builder.create<type::Vector>(fract_els[0]->Type(), vec->Width());
auto exp_ty = builder.create<type::Vector>(exp_els[0]->Type(), vec->Width());
return CreateComposite(builder, ty,
utils::Vector<const constant::Value*, 2>{
CreateComposite(builder, fract_ty, std::move(fract_els)),
CreateComposite(builder, exp_ty, std::move(exp_els)),
});
return builder.create<constant::Composite>(
ty, utils::Vector<const constant::Value*, 2>{
builder.create<constant::Composite>(fract_ty, std::move(fract_els)),
builder.create<constant::Composite>(exp_ty, std::move(exp_els)),
});
} else {
auto fe = scalar(arg);
if (!fe.fract || !fe.exp) {
return utils::Failure;
}
return CreateComposite(builder, ty,
utils::Vector<const constant::Value*, 2>{
fe.fract.Get(),
fe.exp.Get(),
});
return builder.create<constant::Composite>(ty, utils::Vector<const constant::Value*, 2>{
fe.fract.Get(),
fe.exp.Get(),
});
}
}
@ -2838,7 +2796,7 @@ ConstEval::Result ConstEval::modf(const type::Type* ty,
return utils::Failure;
}
return CreateComposite(builder, ty, std::move(fields));
return builder.create<constant::Composite>(ty, std::move(fields));
}
ConstEval::Result ConstEval::normalize(const type::Type* ty,
@ -3412,9 +3370,10 @@ ConstEval::Result ConstEval::transpose(const type::Type* ty,
for (size_t c = 0; c < mat_ty->columns(); ++c) {
new_col_vec.Push(me(r, c));
}
result_mat.Push(CreateComposite(builder, result_mat_ty->ColumnType(), new_col_vec));
result_mat.Push(
builder.create<constant::Composite>(result_mat_ty->ColumnType(), new_col_vec));
}
return CreateComposite(builder, ty, result_mat);
return builder.create<constant::Composite>(ty, result_mat);
}
ConstEval::Result ConstEval::trunc(const type::Type* ty,
@ -3450,7 +3409,7 @@ ConstEval::Result ConstEval::unpack2x16float(const type::Type* ty,
}
els.Push(el.Get());
}
return CreateComposite(builder, ty, std::move(els));
return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty,
@ -3470,7 +3429,7 @@ ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty,
}
els.Push(el.Get());
}
return CreateComposite(builder, ty, std::move(els));
return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty,
@ -3489,7 +3448,7 @@ ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty,
}
els.Push(el.Get());
}
return CreateComposite(builder, ty, std::move(els));
return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty,
@ -3509,7 +3468,7 @@ ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty,
}
els.Push(el.Get());
}
return CreateComposite(builder, ty, std::move(els));
return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty,
@ -3528,7 +3487,7 @@ ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty,
}
els.Push(el.Get());
}
return CreateComposite(builder, ty, std::move(els));
return builder.create<constant::Composite>(ty, std::move(els));
}
ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty,