Minor cleanups from #114202.

This Cl adds a couple cleanups requested in 114202 as a followup.
Templates updated to have the EnableIf in the `template` block. The code
for `create` of a Splat or Composite is moved to a helper method.

Bug: tint:1718
Change-Id: Ib302d78633c6102cfbe17d63f0a4841ecf147472
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/116100
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
dan sinclair 2023-01-03 20:25:37 +00:00 committed by Dawn LUCI CQ
parent 3e449f2194
commit cf58122c58
2 changed files with 49 additions and 38 deletions

View File

@ -153,4 +153,39 @@ const ast::Function* ProgramBuilder::WrapInFunction(utils::VectorRef<const ast::
}); });
} }
const constant::Value* ProgramBuilder::createSplatOrComposite(
const type::Type* type,
utils::VectorRef<const constant::Value*> elements) {
if (elements.IsEmpty()) {
return nullptr;
}
bool any_zero = false;
bool all_zero = true;
bool all_equal = true;
auto* first = elements.Front();
for (auto* el : elements) {
if (!el) {
return nullptr;
}
if (!any_zero && el->AnyZero()) {
any_zero = true;
}
if (all_zero && !el->AllZero()) {
all_zero = false;
}
if (all_equal && el != first) {
if (!el->Equal(first)) {
all_equal = false;
}
}
}
if (all_equal) {
return create<constant::Splat>(type, elements[0], elements.Length());
}
return constant_nodes_.Create<constant::Composite>(type, std::move(elements), all_zero,
any_zero);
}
} // namespace tint } // namespace tint

View File

@ -493,42 +493,13 @@ class ProgramBuilder {
/// @param type the composite type /// @param type the composite type
/// @param elements the composite elements /// @param elements the composite elements
/// @returns the node pointer /// @returns the node pointer
template <typename T> template <typename T,
traits::EnableIf<traits::IsTypeOrDerived<T, constant::Composite> || typename = traits::EnableIf<traits::IsTypeOrDerived<T, constant::Composite> ||
traits::IsTypeOrDerived<T, constant::Splat>, traits::IsTypeOrDerived<T, constant::Splat>>>
const constant::Value>* const constant::Value* create(const type::Type* type,
create(const type::Type* type, utils::VectorRef<const constant::Value*> elements) { utils::VectorRef<const constant::Value*> elements) {
AssertNotMoved(); AssertNotMoved();
if (elements.IsEmpty()) { return createSplatOrComposite(type, elements);
return nullptr;
}
bool any_zero = false;
bool all_zero = true;
bool all_equal = true;
auto* first = elements.Front();
for (auto* el : elements) {
if (!el) {
return nullptr;
}
if (!any_zero && el->AnyZero()) {
any_zero = true;
}
if (all_zero && !el->AllZero()) {
all_zero = false;
}
if (all_equal && el != first) {
if (!el->Equal(first)) {
all_equal = false;
}
}
}
if (all_equal) {
return create<constant::Splat>(type, elements[0], elements.Length());
}
return constant_nodes_.Create<constant::Composite>(type, std::move(elements), all_zero,
any_zero);
} }
/// Constructs a splat constant. /// Constructs a splat constant.
@ -536,9 +507,10 @@ class ProgramBuilder {
/// @param element the splat element /// @param element the splat element
/// @param n the number of elements /// @param n the number of elements
/// @returns the node pointer /// @returns the node pointer
template <typename T> template <typename T, typename = traits::EnableIf<traits::IsTypeOrDerived<T, constant::Splat>>>
traits::EnableIf<traits::IsTypeOrDerived<T, constant::Splat>, const constant::Splat>* const constant::Splat* create(const type::Type* type,
create(const type::Type* type, const constant::Value* element, size_t n) { const constant::Value* element,
size_t n) {
AssertNotMoved(); AssertNotMoved();
return constant_nodes_.Create<constant::Splat>(type, element, n); return constant_nodes_.Create<constant::Splat>(type, element, n);
} }
@ -3351,6 +3323,10 @@ class ProgramBuilder {
void AssertNotMoved() const; void AssertNotMoved() const;
private: private:
const constant::Value* createSplatOrComposite(
const type::Type* type,
utils::VectorRef<const constant::Value*> elements);
ProgramID id_; ProgramID id_;
ast::NodeID last_ast_node_id_ = ast::NodeID{static_cast<decltype(ast::NodeID::value)>(0) - 1}; ast::NodeID last_ast_node_id_ = ast::NodeID{static_cast<decltype(ast::NodeID::value)>(0) - 1};
type::Manager types_; type::Manager types_;