tint: Fix constant::Splat conversion of struct types

Conversion can happen for structure materialization (modf, frexp).
If both structure members are the same type and value, then a constant::Splat will be constructed, which needs to handle conversion.

Bug: chromium:1417515
Change-Id: Iadd14ce00b8d5c22226c601ec5af9a84e6c0c5cf
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/122900
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2023-03-09 23:22:27 +00:00 committed by Dawn LUCI CQ
parent a92a8078c5
commit 574b4b1996
14 changed files with 228 additions and 38 deletions

View File

@ -25,31 +25,40 @@
namespace tint::constant { namespace tint::constant {
/// Composite holds a number of mixed child values. /// Composite holds a number of mixed child values.
/// Composite may be of a vector, matrix or array type. /// Composite may be of a vector, matrix, array or structure type.
/// If each element is the same type and value, then a Splat would be a more efficient constant /// If each element is the same type and value, then a Splat would be a more efficient constant
/// implementation. Use CreateComposite() to create the appropriate type. /// implementation. Use CreateComposite() to create the appropriate type.
class Composite : public Castable<Composite, constant::Value> { class Composite : public Castable<Composite, Value> {
public: public:
/// Constructor /// Constructor
/// @param t the compsite type /// @param t the compsite type
/// @param els the composite elements /// @param els the composite elements
/// @param all_0 true if all elements are 0 /// @param all_0 true if all elements are 0
/// @param any_0 true if any element is 0 /// @param any_0 true if any element is 0
Composite(const type::Type* t, Composite(const type::Type* t, utils::VectorRef<const Value*> els, bool all_0, bool any_0);
utils::VectorRef<const constant::Value*> els,
bool all_0,
bool any_0);
~Composite() override; ~Composite() override;
/// @copydoc Value::Type()
const type::Type* Type() const override { return type; } const type::Type* Type() const override { return type; }
const constant::Value* Index(size_t i) const override { /// @copydoc Value::Index()
const Value* Index(size_t i) const override {
return i < elements.Length() ? elements[i] : nullptr; return i < elements.Length() ? elements[i] : nullptr;
} }
/// @copydoc Value::NumElements()
size_t NumElements() const override { return elements.Length(); }
/// @copydoc Value::AllZero()
bool AllZero() const override { return all_zero; } bool AllZero() const override { return all_zero; }
/// @copydoc Value::AnyZero()
bool AnyZero() const override { return any_zero; } bool AnyZero() const override { return any_zero; }
/// @copydoc Value::AllEqual()
bool AllEqual() const override { return false; } bool AllEqual() const override { return false; }
/// @copydoc Value::Hash()
size_t Hash() const override { return hash; } size_t Hash() const override { return hash; }
/// Clones the constant into the provided context /// Clones the constant into the provided context
@ -60,7 +69,7 @@ class Composite : public Castable<Composite, constant::Value> {
/// The composite type /// The composite type
type::Type const* const type; type::Type const* const type;
/// The composite elements /// The composite elements
const utils::Vector<const constant::Value*, 4> elements; const utils::Vector<const Value*, 4> elements;
/// True if all elements are zero /// True if all elements are zero
const bool all_zero; const bool all_zero;
/// True if any element is zero /// True if any element is zero
@ -69,6 +78,7 @@ class Composite : public Castable<Composite, constant::Value> {
const size_t hash; const size_t hash;
protected: protected:
/// @copydoc Value::InternalValue()
std::variant<std::monostate, AInt, AFloat> InternalValue() const override { return {}; } std::variant<std::monostate, AInt, AFloat> InternalValue() const override { return {}; }
private: private:

View File

@ -25,7 +25,7 @@ namespace tint::constant {
/// Scalar holds a single scalar or abstract-numeric value. /// Scalar holds a single scalar or abstract-numeric value.
template <typename T> template <typename T>
class Scalar : public Castable<Scalar<T>, constant::Value> { class Scalar : public Castable<Scalar<T>, Value> {
public: public:
static_assert(!std::is_same_v<UnwrapNumber<T>, T> || std::is_same_v<T, bool>, static_assert(!std::is_same_v<UnwrapNumber<T>, T> || std::is_same_v<T, bool>,
"T must be a Number or bool"); "T must be a Number or bool");
@ -40,13 +40,25 @@ class Scalar : public Castable<Scalar<T>, constant::Value> {
} }
~Scalar() override = default; ~Scalar() override = default;
/// @copydoc Value::Type()
const type::Type* Type() const override { return type; } const type::Type* Type() const override { return type; }
const constant::Value* Index(size_t) const override { return nullptr; } /// @return nullptr, as Scalar does not hold any elements.
const Value* Index(size_t) const override { return nullptr; }
/// @copydoc Value::NumElements()
size_t NumElements() const override { return 1; }
/// @copydoc Value::AllZero()
bool AllZero() const override { return IsPositiveZero(); } bool AllZero() const override { return IsPositiveZero(); }
/// @copydoc Value::AnyZero()
bool AnyZero() const override { return IsPositiveZero(); } bool AnyZero() const override { return IsPositiveZero(); }
/// @copydoc Value::AllEqual()
bool AllEqual() const override { return true; } bool AllEqual() const override { return true; }
/// @copydoc Value::Hash()
size_t Hash() const override { return utils::Hash(type, ValueOf()); } size_t Hash() const override { return utils::Hash(type, ValueOf()); }
/// Clones the constant into the provided context /// Clones the constant into the provided context
@ -79,6 +91,7 @@ class Scalar : public Castable<Scalar<T>, constant::Value> {
const T value; const T value;
protected: protected:
/// @copydoc Value::InternalValue()
std::variant<std::monostate, AInt, AFloat> InternalValue() const override { std::variant<std::monostate, AInt, AFloat> InternalValue() const override {
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) { if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
return static_cast<AFloat>(value); return static_cast<AFloat>(value);

View File

@ -25,14 +25,14 @@ namespace tint::constant {
/// Splat holds a single value, duplicated as all children. /// Splat holds a single value, duplicated as all children.
/// ///
/// Splat is used for zero-initializers, 'splat' initializers, or initializers where each element is /// Splat is used for zero-initializers, 'splat' initializers, or initializers where each element is
/// identical. Splat may be of a vector, matrix or array type. /// identical. Splat may be of a vector, matrix, array or structure type.
class Splat : public Castable<Splat, constant::Value> { class Splat : public Castable<Splat, Value> {
public: public:
/// Constructor /// Constructor
/// @param t the splat type /// @param t the splat type
/// @param e the splat element /// @param e the splat element
/// @param n the number of items in the splat /// @param n the number of items in the splat
Splat(const type::Type* t, const constant::Value* e, size_t n); Splat(const type::Type* t, const Value* e, size_t n);
~Splat() override; ~Splat() override;
/// @returns the type of the splat /// @returns the type of the splat
@ -41,7 +41,10 @@ class Splat : public Castable<Splat, constant::Value> {
/// Retrieve item at index @p i /// Retrieve item at index @p i
/// @param i the index to retrieve /// @param i the index to retrieve
/// @returns the element, or nullptr if out of bounds /// @returns the element, or nullptr if out of bounds
const constant::Value* Index(size_t i) const override { return i < count ? el : nullptr; } const Value* Index(size_t i) const override { return i < count ? el : nullptr; }
/// @copydoc Value::NumElements()
size_t NumElements() const override { return count; }
/// @returns true if the element is zero /// @returns true if the element is zero
bool AllZero() const override { return el->AllZero(); } bool AllZero() const override { return el->AllZero(); }
@ -61,7 +64,7 @@ class Splat : public Castable<Splat, constant::Value> {
/// The type of the splat element /// The type of the splat element
type::Type const* const type; type::Type const* const type;
/// The element stored in the splat /// The element stored in the splat
const constant::Value* el; const Value* el;
/// The number of items in the splat /// The number of items in the splat
const size_t count; const size_t count;

View File

@ -37,6 +37,7 @@ class Value : public Castable<Value, Node> {
/// @returns the type of the value /// @returns the type of the value
virtual const type::Type* Type() const = 0; virtual const type::Type* Type() const = 0;
/// @param i the index of the element
/// @returns the child element with the given index, or nullptr if there are no children, or /// @returns the child element with the given index, or nullptr if there are no children, or
/// the index is out of bounds. /// the index is out of bounds.
/// ///
@ -44,7 +45,10 @@ class Value : public Castable<Value, Node> {
/// For vectors, this returns the i'th element of the vector. /// For vectors, this returns the i'th element of the vector.
/// For matrices, this returns the i'th column vector of the matrix. /// For matrices, this returns the i'th column vector of the matrix.
/// For structures, this returns the i'th member field of the structure. /// For structures, this returns the i'th member field of the structure.
virtual const Value* Index(size_t) const = 0; virtual const Value* Index(size_t i) const = 0;
/// @return the number of elements held by this Value
virtual size_t NumElements() const = 0;
/// @returns true if child elements are positive-zero valued. /// @returns true if child elements are positive-zero valued.
virtual bool AllZero() const = 0; virtual bool AllZero() const = 0;
@ -74,7 +78,7 @@ class Value : public Castable<Value, Node> {
/// @param b the value to compare too /// @param b the value to compare too
/// @returns true if this value is equal to @p b /// @returns true if this value is equal to @p b
bool Equal(const constant::Value* b) const; bool Equal(const Value* b) const;
/// Clones the constant into the provided context /// Clones the constant into the provided context
/// @param ctx the clone context /// @param ctx the clone context

View File

@ -326,35 +326,20 @@ ConstEval::Result ConvertInternal(const constant::Value* c,
const Source& source, const Source& source,
bool use_runtime_semantics); bool use_runtime_semantics);
ConstEval::Result SplatConvert(const constant::Splat* splat, ConstEval::Result CompositeConvert(const constant::Value* value,
ProgramBuilder& builder,
const type::Type* target_ty,
const Source& source,
bool use_runtime_semantics) {
// Convert the single splatted element type.
auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source,
use_runtime_semantics);
if (!conv_el) {
return utils::Failure;
}
if (!conv_el.Get()) {
return nullptr;
}
return builder.create<constant::Splat>(target_ty, conv_el.Get(), splat->count);
}
ConstEval::Result CompositeConvert(const constant::Composite* composite,
ProgramBuilder& builder, ProgramBuilder& builder,
const type::Type* target_ty, const type::Type* target_ty,
const Source& source, const Source& source,
bool use_runtime_semantics) { bool use_runtime_semantics) {
const size_t el_count = value->NumElements();
// Convert each of the composite element types. // Convert each of the composite element types.
utils::Vector<const constant::Value*, 4> conv_els; utils::Vector<const constant::Value*, 4> conv_els;
conv_els.Reserve(composite->elements.Length()); conv_els.Reserve(el_count);
std::function<const type::Type*(size_t idx)> target_el_ty; std::function<const type::Type*(size_t idx)> target_el_ty;
if (auto* str = target_ty->As<type::Struct>()) { if (auto* str = target_ty->As<type::Struct>()) {
if (TINT_UNLIKELY(str->Members().Length() != composite->elements.Length())) { if (TINT_UNLIKELY(str->Members().Length() != el_count)) {
TINT_ICE(Resolver, builder.Diagnostics()) TINT_ICE(Resolver, builder.Diagnostics())
<< "const-eval conversion of structure has mismatched element counts"; << "const-eval conversion of structure has mismatched element counts";
return utils::Failure; return utils::Failure;
@ -365,7 +350,8 @@ ConstEval::Result CompositeConvert(const constant::Composite* composite,
target_el_ty = [el_ty](size_t) { return el_ty; }; target_el_ty = [el_ty](size_t) { return el_ty; };
} }
for (auto* el : composite->elements) { for (size_t i = 0; i < el_count; i++) {
auto* el = value->Index(i);
auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source, auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source,
use_runtime_semantics); use_runtime_semantics);
if (!conv_el) { if (!conv_el) {
@ -379,6 +365,40 @@ ConstEval::Result CompositeConvert(const constant::Composite* composite,
return builder.create<constant::Composite>(target_ty, std::move(conv_els)); return builder.create<constant::Composite>(target_ty, std::move(conv_els));
} }
ConstEval::Result SplatConvert(const constant::Splat* splat,
ProgramBuilder& builder,
const type::Type* target_ty,
const Source& source,
bool use_runtime_semantics) {
const type::Type* target_el_ty = nullptr;
if (auto* str = target_ty->As<type::Struct>()) {
// Structure conversion.
auto members = str->Members();
target_el_ty = members[0]->Type();
// Structures can only be converted during materialization. The user cannot declare the
// target structure type, so each member type must be the same default materialization type.
for (size_t i = 1; i < members.Length(); i++) {
if (members[i]->Type() != target_el_ty) {
TINT_ICE(Resolver, builder.Diagnostics())
<< "inconsistent target struct member types for SplatConvert";
return utils::Failure;
}
}
} else {
target_el_ty = type::Type::ElementOf(target_ty);
}
// Convert the single splatted element type.
auto conv_el = ConvertInternal(splat->el, builder, target_el_ty, source, use_runtime_semantics);
if (!conv_el) {
return utils::Failure;
}
if (!conv_el.Get()) {
return nullptr;
}
return builder.create<constant::Splat>(target_ty, conv_el.Get(), splat->count);
}
ConstEval::Result ConvertInternal(const constant::Value* c, ConstEval::Result ConvertInternal(const constant::Value* c,
ProgramBuilder& builder, ProgramBuilder& builder,
const type::Type* target_ty, const type::Type* target_ty,

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "src/tint/resolver/const_eval_test.h" #include "src/tint/resolver/const_eval_test.h"
#include "src/tint/sem/materialize.h"
using namespace tint::number_suffixes; // NOLINT using namespace tint::number_suffixes; // NOLINT
@ -472,5 +473,56 @@ TEST_F(ResolverConstEvalTest, Vec3_Convert_Small_f32_to_f16) {
EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(2)->ValueAs<AFloat>().value)); EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(2)->ValueAs<AFloat>().value));
} }
TEST_F(ResolverConstEvalTest, StructAbstractSplat_to_StructDifferentTypes) {
// fn f() {
// const c = modf(4.0);
// var v = c;
// }
auto* expr_c = Call(builtin::Function::kModf, 0_a);
auto* materialized = Expr("c");
WrapInFunction(Decl(Const("c", expr_c)), Decl(Var("v", materialized)));
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* c = Sem().Get(expr_c);
ASSERT_NE(c, nullptr);
EXPECT_TRUE(c->ConstantValue()->Is<constant::Splat>());
EXPECT_TRUE(c->ConstantValue()->AllEqual());
EXPECT_TRUE(c->ConstantValue()->AnyZero());
EXPECT_TRUE(c->ConstantValue()->AllZero());
EXPECT_TRUE(c->ConstantValue()->Index(0)->AllEqual());
EXPECT_TRUE(c->ConstantValue()->Index(0)->AnyZero());
EXPECT_TRUE(c->ConstantValue()->Index(0)->AllZero());
EXPECT_TRUE(c->ConstantValue()->Index(0)->Type()->Is<type::AbstractFloat>());
EXPECT_EQ(c->ConstantValue()->Index(0)->ValueAs<AFloat>(), 0_f);
EXPECT_TRUE(c->ConstantValue()->Index(1)->AllEqual());
EXPECT_TRUE(c->ConstantValue()->Index(1)->AnyZero());
EXPECT_TRUE(c->ConstantValue()->Index(1)->AllZero());
EXPECT_TRUE(c->ConstantValue()->Index(1)->Type()->Is<type::AbstractFloat>());
EXPECT_EQ(c->ConstantValue()->Index(1)->ValueAs<AFloat>(), 0_a);
auto* v = Sem().GetVal(materialized);
ASSERT_NE(v, nullptr);
EXPECT_TRUE(v->Is<sem::Materialize>());
EXPECT_TRUE(v->ConstantValue()->Is<constant::Splat>());
EXPECT_TRUE(v->ConstantValue()->AllEqual());
EXPECT_TRUE(v->ConstantValue()->AnyZero());
EXPECT_TRUE(v->ConstantValue()->AllZero());
EXPECT_TRUE(v->ConstantValue()->Index(0)->AllEqual());
EXPECT_TRUE(v->ConstantValue()->Index(0)->AnyZero());
EXPECT_TRUE(v->ConstantValue()->Index(0)->AllZero());
EXPECT_TRUE(v->ConstantValue()->Index(0)->Type()->Is<type::F32>());
EXPECT_EQ(v->ConstantValue()->Index(0)->ValueAs<f32>(), 0_f);
EXPECT_TRUE(v->ConstantValue()->Index(1)->AllEqual());
EXPECT_TRUE(v->ConstantValue()->Index(1)->AnyZero());
EXPECT_TRUE(v->ConstantValue()->Index(1)->AllZero());
EXPECT_TRUE(v->ConstantValue()->Index(1)->Type()->Is<type::F32>());
EXPECT_EQ(v->ConstantValue()->Index(1)->ValueAs<f32>(), 0_f);
}
} // namespace } // namespace
} // namespace tint::resolver } // namespace tint::resolver

View File

@ -29,6 +29,7 @@ class MockConstant : public constant::Value {
~MockConstant() override {} ~MockConstant() override {}
const type::Type* Type() const override { return type; } const type::Type* Type() const override { return type; }
const constant::Value* Index(size_t) const override { return {}; } const constant::Value* Index(size_t) const override { return {}; }
size_t NumElements() const override { return 0; }
bool AllZero() const override { return {}; } bool AllZero() const override { return {}; }
bool AnyZero() const override { return {}; } bool AnyZero() const override { return {}; }
bool AllEqual() const override { return {}; } bool AllEqual() const override { return {}; }

View File

@ -0,0 +1,3 @@
fn foo(){
let s1 = modf(0.0);
}

View File

@ -0,0 +1,12 @@
struct modf_result_f32 {
float fract;
float whole;
};
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
void foo() {
const modf_result_f32 s1 = (modf_result_f32)0;
}

View File

@ -0,0 +1,12 @@
struct modf_result_f32 {
float fract;
float whole;
};
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
void foo() {
const modf_result_f32 s1 = (modf_result_f32)0;
}

View File

@ -0,0 +1,16 @@
#version 310 es
struct modf_result_f32 {
float fract;
float whole;
};
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void unused_entry_point() {
return;
}
void foo() {
modf_result_f32 s1 = modf_result_f32(0.0f, 0.0f);
}

View File

@ -0,0 +1,12 @@
#include <metal_stdlib>
using namespace metal;
struct modf_result_f32 {
float fract;
float whole;
};
void foo() {
modf_result_f32 const s1 = modf_result_f32{};
}

View File

@ -0,0 +1,29 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 10
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %unused_entry_point "unused_entry_point"
OpName %foo "foo"
OpName %__modf_result_f32 "__modf_result_f32"
OpMemberName %__modf_result_f32 0 "fract"
OpMemberName %__modf_result_f32 1 "whole"
OpMemberDecorate %__modf_result_f32 0 Offset 0
OpMemberDecorate %__modf_result_f32 1 Offset 4
%void = OpTypeVoid
%1 = OpTypeFunction %void
%float = OpTypeFloat 32
%__modf_result_f32 = OpTypeStruct %float %float
%9 = OpConstantNull %__modf_result_f32
%unused_entry_point = OpFunction %void None %1
%4 = OpLabel
OpReturn
OpFunctionEnd
%foo = OpFunction %void None %1
%6 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,3 @@
fn foo() {
let s1 = modf(0.0);
}