From 02b466feb1897b6a6a1f6c44bf9f156411331b0e Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Thu, 15 Dec 2022 21:56:32 +0000 Subject: [PATCH] Move CreateComposite into ProgramBuilder. This CL moves the CreateComposite helper into the ProgramBuilder. Bug: tint:1718 Change-Id: I4aca7dc3d7192a7aa8b300f00529670aa9c09a27 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/114202 Kokoro: Kokoro Auto-Submit: Dan Sinclair Commit-Queue: Dan Sinclair Reviewed-by: Ben Clayton --- src/tint/program_builder.h | 66 ++++++++++++++++++- src/tint/resolver/const_eval.cc | 111 ++++++++++---------------------- 2 files changed, 100 insertions(+), 77 deletions(-) diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index 83ad287f8d..597463c582 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -87,6 +87,8 @@ #include "src/tint/ast/void.h" #include "src/tint/ast/while_statement.h" #include "src/tint/ast/workgroup_attribute.h" +#include "src/tint/constant/composite.h" +#include "src/tint/constant/splat.h" #include "src/tint/constant/value.h" #include "src/tint/number.h" #include "src/tint/program.h" @@ -470,11 +472,73 @@ class ProgramBuilder { /// @param args the arguments to pass to the constructor /// @returns the node pointer template - traits::EnableIf, T>* create(ARGS&&... args) { + traits::EnableIf && + !traits::IsTypeOrDerived && + !traits::IsTypeOrDerived, + T>* + create(ARGS&&... args) { AssertNotMoved(); return constant_nodes_.Create(std::forward(args)...); } + /// Constructs a constant of a vector, matrix or array type. + /// + /// Examines the element values and will return either a constant::Composite or a + /// constant::Splat, depending on the element types and values. + /// + /// @param type the composite type + /// @param elements the composite elements + /// @returns the node pointer + template + traits::EnableIf || + traits::IsTypeOrDerived, + const constant::Value>* + create(const type::Type* type, utils::VectorRef elements) { + AssertNotMoved(); + if (elements.IsEmpty()) { + return nullptr; + } + + bool any_zero = false; + bool all_zero = true; + bool all_equal = true; + auto* first = elements.Front(); + for (auto* el : elements) { + if (!el) { + return nullptr; + } + if (!any_zero && el->AnyZero()) { + any_zero = true; + } + if (all_zero && !el->AllZero()) { + all_zero = false; + } + if (all_equal && el != first) { + if (!el->Equal(first)) { + all_equal = false; + } + } + } + if (all_equal) { + return create(type, elements[0], elements.Length()); + } + + return constant_nodes_.Create(type, std::move(elements), all_zero, + any_zero); + } + + /// Constructs a splat constant. + /// @param type the splat type + /// @param element the splat element + /// @param n the number of elements + /// @returns the node pointer + template + traits::EnableIf, const constant::Splat>* + create(const type::Type* type, const constant::Value* element, size_t n) { + AssertNotMoved(); + return constant_nodes_.Create(type, element, n); + } + /// Creates a new type::Type owned by the ProgramBuilder. /// When the ProgramBuilder is destructed, owned ProgramBuilder and the /// returned `Type` will also be destructed. diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index a92b02a74d..628975b348 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -232,11 +232,6 @@ std::make_unsigned_t CountTrailingBits(T e, T bit_value_to_count) { return count; } -// Forward declaration -const constant::Value* CreateComposite(ProgramBuilder& builder, - const type::Type* type, - utils::VectorRef elements); - template ConstEval::Result ScalarConvert(const constant::Scalar* scalar, ProgramBuilder& builder, @@ -347,7 +342,7 @@ ConstEval::Result CompositeConvert(const constant::Composite* composite, } conv_els.Push(conv_el.Get()); } - return CreateComposite(builder, target_ty, std::move(conv_els)); + return builder.create(target_ty, std::move(conv_els)); } ConstEval::Result ConvertInternal(const constant::Value* c, @@ -438,7 +433,7 @@ const constant::Value* ZeroValue(ProgramBuilder& builder, const type::Type* type // All members were of the same type, so the zero value is the same for all members. return builder.create(type, zeros[0], s->Members().Length()); } - return CreateComposite(builder, s, std::move(zeros)); + return builder.create(s, std::move(zeros)); }, [&](Default) -> const constant::Value* { return ZeroTypeDispatch(type, [&](auto zero) -> const constant::Value* { @@ -449,42 +444,6 @@ const constant::Value* ZeroValue(ProgramBuilder& builder, const type::Type* type }); } -/// CreateComposite is used to construct a constant of a vector, matrix or array type. -/// CreateComposite examines the element values and will return either a Composite or a Splat, -/// depending on the element types and values. -const constant::Value* CreateComposite(ProgramBuilder& builder, - const type::Type* type, - utils::VectorRef elements) { - if (elements.IsEmpty()) { - return nullptr; - } - bool any_zero = false; - bool all_zero = true; - bool all_equal = true; - auto* first = elements.Front(); - for (auto* el : elements) { - if (!el) { - return nullptr; - } - if (!any_zero && el->AnyZero()) { - any_zero = true; - } - if (all_zero && !el->AllZero()) { - all_zero = false; - } - if (all_equal && el != first) { - if (!el->Equal(first)) { - all_equal = false; - } - } - } - if (all_equal) { - return builder.create(type, elements[0], elements.Length()); - } else { - return builder.create(type, std::move(elements), all_zero, any_zero); - } -} - namespace detail { /// Implementation of TransformElements template @@ -515,7 +474,7 @@ ConstEval::Result TransformElements(ProgramBuilder& builder, return el.Failure(); } } - return CreateComposite(builder, composite_ty, std::move(els)); + return builder.create(composite_ty, std::move(els)); } } // namespace detail @@ -569,7 +528,7 @@ ConstEval::Result TransformBinaryElements(ProgramBuilder& builder, return el.Failure(); } } - return CreateComposite(builder, composite_ty, std::move(els)); + return builder.create(composite_ty, std::move(els)); } } // namespace @@ -1211,7 +1170,7 @@ ConstEval::Result ConstEval::ArrayOrStructInit(const type::Type* ty, for (auto* arg : args) { els.Push(arg->ConstantValue()); } - return CreateComposite(builder, ty, std::move(els)); + return builder.create(ty, std::move(els)); } ConstEval::Result ConstEval::Conv(const type::Type* ty, @@ -1255,7 +1214,7 @@ ConstEval::Result ConstEval::VecSplat(const type::Type* ty, ConstEval::Result ConstEval::VecInitS(const type::Type* ty, utils::VectorRef args, const Source&) { - return CreateComposite(builder, ty, args); + return builder.create(ty, args); } ConstEval::Result ConstEval::VecInitM(const type::Type* ty, @@ -1281,7 +1240,7 @@ ConstEval::Result ConstEval::VecInitM(const type::Type* ty, els.Push(val); } } - return CreateComposite(builder, ty, std::move(els)); + return builder.create(ty, std::move(els)); } ConstEval::Result ConstEval::MatInitS(const type::Type* ty, @@ -1296,15 +1255,15 @@ ConstEval::Result ConstEval::MatInitS(const type::Type* ty, auto i = r + c * m->rows(); column.Push(args[i]); } - els.Push(CreateComposite(builder, m->ColumnType(), std::move(column))); + els.Push(builder.create(m->ColumnType(), std::move(column))); } - return CreateComposite(builder, ty, std::move(els)); + return builder.create(ty, std::move(els)); } ConstEval::Result ConstEval::MatInitV(const type::Type* ty, utils::VectorRef args, const Source&) { - return CreateComposite(builder, ty, args); + return builder.create(ty, args); } ConstEval::Result ConstEval::Index(const sem::Expression* obj_expr, @@ -1357,7 +1316,7 @@ ConstEval::Result ConstEval::Swizzle(const type::Type* ty, } auto values = utils::Transform<4>( indices, [&](uint32_t i) { return vec_val->Index(static_cast(i)); }); - return CreateComposite(builder, ty, std::move(values)); + return builder.create(ty, std::move(values)); } ConstEval::Result ConstEval::Bitcast(const type::Type*, const sem::Expression*) { @@ -1484,7 +1443,7 @@ ConstEval::Result ConstEval::OpMultiplyMatVec(const type::Type* ty, } result.Push(r.Get()); } - return CreateComposite(builder, ty, result); + return builder.create(ty, result); } ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty, utils::VectorRef args, @@ -1534,7 +1493,7 @@ ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty, } result.Push(r.Get()); } - return CreateComposite(builder, ty, result); + return builder.create(ty, result); } ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty, @@ -1596,9 +1555,9 @@ ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty, // Add column vector to matrix auto* col_vec_ty = ty->As()->ColumnType(); - result_mat.Push(CreateComposite(builder, col_vec_ty, col_vec)); + result_mat.Push(builder.create(col_vec_ty, col_vec)); } - return CreateComposite(builder, ty, result_mat); + return builder.create(ty, result_mat); } ConstEval::Result ConstEval::OpDivide(const type::Type* ty, @@ -2208,8 +2167,8 @@ ConstEval::Result ConstEval::cross(const type::Type* ty, return utils::Failure; } - return CreateComposite(builder, ty, - utils::Vector{x.Get(), y.Get(), z.Get()}); + return builder.create( + ty, utils::Vector{x.Get(), y.Get(), z.Get()}); } ConstEval::Result ConstEval::degrees(const type::Type* ty, @@ -2592,21 +2551,20 @@ ConstEval::Result ConstEval::frexp(const type::Type* ty, } auto fract_ty = builder.create(fract_els[0]->Type(), vec->Width()); auto exp_ty = builder.create(exp_els[0]->Type(), vec->Width()); - return CreateComposite(builder, ty, - utils::Vector{ - CreateComposite(builder, fract_ty, std::move(fract_els)), - CreateComposite(builder, exp_ty, std::move(exp_els)), - }); + return builder.create( + ty, utils::Vector{ + builder.create(fract_ty, std::move(fract_els)), + builder.create(exp_ty, std::move(exp_els)), + }); } else { auto fe = scalar(arg); if (!fe.fract || !fe.exp) { return utils::Failure; } - return CreateComposite(builder, ty, - utils::Vector{ - fe.fract.Get(), - fe.exp.Get(), - }); + return builder.create(ty, utils::Vector{ + fe.fract.Get(), + fe.exp.Get(), + }); } } @@ -2838,7 +2796,7 @@ ConstEval::Result ConstEval::modf(const type::Type* ty, return utils::Failure; } - return CreateComposite(builder, ty, std::move(fields)); + return builder.create(ty, std::move(fields)); } ConstEval::Result ConstEval::normalize(const type::Type* ty, @@ -3412,9 +3370,10 @@ ConstEval::Result ConstEval::transpose(const type::Type* ty, for (size_t c = 0; c < mat_ty->columns(); ++c) { new_col_vec.Push(me(r, c)); } - result_mat.Push(CreateComposite(builder, result_mat_ty->ColumnType(), new_col_vec)); + result_mat.Push( + builder.create(result_mat_ty->ColumnType(), new_col_vec)); } - return CreateComposite(builder, ty, result_mat); + return builder.create(ty, result_mat); } ConstEval::Result ConstEval::trunc(const type::Type* ty, @@ -3450,7 +3409,7 @@ ConstEval::Result ConstEval::unpack2x16float(const type::Type* ty, } els.Push(el.Get()); } - return CreateComposite(builder, ty, std::move(els)); + return builder.create(ty, std::move(els)); } ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty, @@ -3470,7 +3429,7 @@ ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty, } els.Push(el.Get()); } - return CreateComposite(builder, ty, std::move(els)); + return builder.create(ty, std::move(els)); } ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty, @@ -3489,7 +3448,7 @@ ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty, } els.Push(el.Get()); } - return CreateComposite(builder, ty, std::move(els)); + return builder.create(ty, std::move(els)); } ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty, @@ -3509,7 +3468,7 @@ ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty, } els.Push(el.Get()); } - return CreateComposite(builder, ty, std::move(els)); + return builder.create(ty, std::move(els)); } ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty, @@ -3528,7 +3487,7 @@ ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty, } els.Push(el.Get()); } - return CreateComposite(builder, ty, std::move(els)); + return builder.create(ty, std::move(els)); } ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty,