tint: Fix constant::Splat conversion of struct types
Conversion can happen for structure materialization (modf, frexp). If both structure members are the same type and value, then a constant::Splat will be constructed, which needs to handle conversion. Bug: chromium:1417515 Change-Id: Iadd14ce00b8d5c22226c601ec5af9a84e6c0c5cf Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/122900 Reviewed-by: James Price <jrprice@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
a92a8078c5
commit
574b4b1996
|
@ -25,31 +25,40 @@
|
|||
namespace tint::constant {
|
||||
|
||||
/// Composite holds a number of mixed child values.
|
||||
/// Composite may be of a vector, matrix or array type.
|
||||
/// Composite may be of a vector, matrix, array or structure type.
|
||||
/// If each element is the same type and value, then a Splat would be a more efficient constant
|
||||
/// implementation. Use CreateComposite() to create the appropriate type.
|
||||
class Composite : public Castable<Composite, constant::Value> {
|
||||
class Composite : public Castable<Composite, Value> {
|
||||
public:
|
||||
/// Constructor
|
||||
/// @param t the compsite type
|
||||
/// @param els the composite elements
|
||||
/// @param all_0 true if all elements are 0
|
||||
/// @param any_0 true if any element is 0
|
||||
Composite(const type::Type* t,
|
||||
utils::VectorRef<const constant::Value*> els,
|
||||
bool all_0,
|
||||
bool any_0);
|
||||
Composite(const type::Type* t, utils::VectorRef<const Value*> els, bool all_0, bool any_0);
|
||||
~Composite() override;
|
||||
|
||||
/// @copydoc Value::Type()
|
||||
const type::Type* Type() const override { return type; }
|
||||
|
||||
const constant::Value* Index(size_t i) const override {
|
||||
/// @copydoc Value::Index()
|
||||
const Value* Index(size_t i) const override {
|
||||
return i < elements.Length() ? elements[i] : nullptr;
|
||||
}
|
||||
|
||||
/// @copydoc Value::NumElements()
|
||||
size_t NumElements() const override { return elements.Length(); }
|
||||
|
||||
/// @copydoc Value::AllZero()
|
||||
bool AllZero() const override { return all_zero; }
|
||||
|
||||
/// @copydoc Value::AnyZero()
|
||||
bool AnyZero() const override { return any_zero; }
|
||||
|
||||
/// @copydoc Value::AllEqual()
|
||||
bool AllEqual() const override { return false; }
|
||||
|
||||
/// @copydoc Value::Hash()
|
||||
size_t Hash() const override { return hash; }
|
||||
|
||||
/// Clones the constant into the provided context
|
||||
|
@ -60,7 +69,7 @@ class Composite : public Castable<Composite, constant::Value> {
|
|||
/// The composite type
|
||||
type::Type const* const type;
|
||||
/// The composite elements
|
||||
const utils::Vector<const constant::Value*, 4> elements;
|
||||
const utils::Vector<const Value*, 4> elements;
|
||||
/// True if all elements are zero
|
||||
const bool all_zero;
|
||||
/// True if any element is zero
|
||||
|
@ -69,6 +78,7 @@ class Composite : public Castable<Composite, constant::Value> {
|
|||
const size_t hash;
|
||||
|
||||
protected:
|
||||
/// @copydoc Value::InternalValue()
|
||||
std::variant<std::monostate, AInt, AFloat> InternalValue() const override { return {}; }
|
||||
|
||||
private:
|
||||
|
|
|
@ -25,7 +25,7 @@ namespace tint::constant {
|
|||
|
||||
/// Scalar holds a single scalar or abstract-numeric value.
|
||||
template <typename T>
|
||||
class Scalar : public Castable<Scalar<T>, constant::Value> {
|
||||
class Scalar : public Castable<Scalar<T>, Value> {
|
||||
public:
|
||||
static_assert(!std::is_same_v<UnwrapNumber<T>, T> || std::is_same_v<T, bool>,
|
||||
"T must be a Number or bool");
|
||||
|
@ -40,13 +40,25 @@ class Scalar : public Castable<Scalar<T>, constant::Value> {
|
|||
}
|
||||
~Scalar() override = default;
|
||||
|
||||
/// @copydoc Value::Type()
|
||||
const type::Type* Type() const override { return type; }
|
||||
|
||||
const constant::Value* Index(size_t) const override { return nullptr; }
|
||||
/// @return nullptr, as Scalar does not hold any elements.
|
||||
const Value* Index(size_t) const override { return nullptr; }
|
||||
|
||||
/// @copydoc Value::NumElements()
|
||||
size_t NumElements() const override { return 1; }
|
||||
|
||||
/// @copydoc Value::AllZero()
|
||||
bool AllZero() const override { return IsPositiveZero(); }
|
||||
|
||||
/// @copydoc Value::AnyZero()
|
||||
bool AnyZero() const override { return IsPositiveZero(); }
|
||||
|
||||
/// @copydoc Value::AllEqual()
|
||||
bool AllEqual() const override { return true; }
|
||||
|
||||
/// @copydoc Value::Hash()
|
||||
size_t Hash() const override { return utils::Hash(type, ValueOf()); }
|
||||
|
||||
/// Clones the constant into the provided context
|
||||
|
@ -79,6 +91,7 @@ class Scalar : public Castable<Scalar<T>, constant::Value> {
|
|||
const T value;
|
||||
|
||||
protected:
|
||||
/// @copydoc Value::InternalValue()
|
||||
std::variant<std::monostate, AInt, AFloat> InternalValue() const override {
|
||||
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
|
||||
return static_cast<AFloat>(value);
|
||||
|
|
|
@ -25,14 +25,14 @@ namespace tint::constant {
|
|||
/// Splat holds a single value, duplicated as all children.
|
||||
///
|
||||
/// Splat is used for zero-initializers, 'splat' initializers, or initializers where each element is
|
||||
/// identical. Splat may be of a vector, matrix or array type.
|
||||
class Splat : public Castable<Splat, constant::Value> {
|
||||
/// identical. Splat may be of a vector, matrix, array or structure type.
|
||||
class Splat : public Castable<Splat, Value> {
|
||||
public:
|
||||
/// Constructor
|
||||
/// @param t the splat type
|
||||
/// @param e the splat element
|
||||
/// @param n the number of items in the splat
|
||||
Splat(const type::Type* t, const constant::Value* e, size_t n);
|
||||
Splat(const type::Type* t, const Value* e, size_t n);
|
||||
~Splat() override;
|
||||
|
||||
/// @returns the type of the splat
|
||||
|
@ -41,7 +41,10 @@ class Splat : public Castable<Splat, constant::Value> {
|
|||
/// Retrieve item at index @p i
|
||||
/// @param i the index to retrieve
|
||||
/// @returns the element, or nullptr if out of bounds
|
||||
const constant::Value* Index(size_t i) const override { return i < count ? el : nullptr; }
|
||||
const Value* Index(size_t i) const override { return i < count ? el : nullptr; }
|
||||
|
||||
/// @copydoc Value::NumElements()
|
||||
size_t NumElements() const override { return count; }
|
||||
|
||||
/// @returns true if the element is zero
|
||||
bool AllZero() const override { return el->AllZero(); }
|
||||
|
@ -61,7 +64,7 @@ class Splat : public Castable<Splat, constant::Value> {
|
|||
/// The type of the splat element
|
||||
type::Type const* const type;
|
||||
/// The element stored in the splat
|
||||
const constant::Value* el;
|
||||
const Value* el;
|
||||
/// The number of items in the splat
|
||||
const size_t count;
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ class Value : public Castable<Value, Node> {
|
|||
/// @returns the type of the value
|
||||
virtual const type::Type* Type() const = 0;
|
||||
|
||||
/// @param i the index of the element
|
||||
/// @returns the child element with the given index, or nullptr if there are no children, or
|
||||
/// the index is out of bounds.
|
||||
///
|
||||
|
@ -44,7 +45,10 @@ class Value : public Castable<Value, Node> {
|
|||
/// For vectors, this returns the i'th element of the vector.
|
||||
/// For matrices, this returns the i'th column vector of the matrix.
|
||||
/// For structures, this returns the i'th member field of the structure.
|
||||
virtual const Value* Index(size_t) const = 0;
|
||||
virtual const Value* Index(size_t i) const = 0;
|
||||
|
||||
/// @return the number of elements held by this Value
|
||||
virtual size_t NumElements() const = 0;
|
||||
|
||||
/// @returns true if child elements are positive-zero valued.
|
||||
virtual bool AllZero() const = 0;
|
||||
|
@ -74,7 +78,7 @@ class Value : public Castable<Value, Node> {
|
|||
|
||||
/// @param b the value to compare too
|
||||
/// @returns true if this value is equal to @p b
|
||||
bool Equal(const constant::Value* b) const;
|
||||
bool Equal(const Value* b) const;
|
||||
|
||||
/// Clones the constant into the provided context
|
||||
/// @param ctx the clone context
|
||||
|
|
|
@ -326,35 +326,20 @@ ConstEval::Result ConvertInternal(const constant::Value* c,
|
|||
const Source& source,
|
||||
bool use_runtime_semantics);
|
||||
|
||||
ConstEval::Result SplatConvert(const constant::Splat* splat,
|
||||
ConstEval::Result CompositeConvert(const constant::Value* value,
|
||||
ProgramBuilder& builder,
|
||||
const type::Type* target_ty,
|
||||
const Source& source,
|
||||
bool use_runtime_semantics) {
|
||||
// Convert the single splatted element type.
|
||||
auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source,
|
||||
use_runtime_semantics);
|
||||
if (!conv_el) {
|
||||
return utils::Failure;
|
||||
}
|
||||
if (!conv_el.Get()) {
|
||||
return nullptr;
|
||||
}
|
||||
return builder.create<constant::Splat>(target_ty, conv_el.Get(), splat->count);
|
||||
}
|
||||
const size_t el_count = value->NumElements();
|
||||
|
||||
ConstEval::Result CompositeConvert(const constant::Composite* composite,
|
||||
ProgramBuilder& builder,
|
||||
const type::Type* target_ty,
|
||||
const Source& source,
|
||||
bool use_runtime_semantics) {
|
||||
// Convert each of the composite element types.
|
||||
utils::Vector<const constant::Value*, 4> conv_els;
|
||||
conv_els.Reserve(composite->elements.Length());
|
||||
conv_els.Reserve(el_count);
|
||||
|
||||
std::function<const type::Type*(size_t idx)> target_el_ty;
|
||||
if (auto* str = target_ty->As<type::Struct>()) {
|
||||
if (TINT_UNLIKELY(str->Members().Length() != composite->elements.Length())) {
|
||||
if (TINT_UNLIKELY(str->Members().Length() != el_count)) {
|
||||
TINT_ICE(Resolver, builder.Diagnostics())
|
||||
<< "const-eval conversion of structure has mismatched element counts";
|
||||
return utils::Failure;
|
||||
|
@ -365,7 +350,8 @@ ConstEval::Result CompositeConvert(const constant::Composite* composite,
|
|||
target_el_ty = [el_ty](size_t) { return el_ty; };
|
||||
}
|
||||
|
||||
for (auto* el : composite->elements) {
|
||||
for (size_t i = 0; i < el_count; i++) {
|
||||
auto* el = value->Index(i);
|
||||
auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source,
|
||||
use_runtime_semantics);
|
||||
if (!conv_el) {
|
||||
|
@ -379,6 +365,40 @@ ConstEval::Result CompositeConvert(const constant::Composite* composite,
|
|||
return builder.create<constant::Composite>(target_ty, std::move(conv_els));
|
||||
}
|
||||
|
||||
ConstEval::Result SplatConvert(const constant::Splat* splat,
|
||||
ProgramBuilder& builder,
|
||||
const type::Type* target_ty,
|
||||
const Source& source,
|
||||
bool use_runtime_semantics) {
|
||||
const type::Type* target_el_ty = nullptr;
|
||||
if (auto* str = target_ty->As<type::Struct>()) {
|
||||
// Structure conversion.
|
||||
auto members = str->Members();
|
||||
target_el_ty = members[0]->Type();
|
||||
|
||||
// Structures can only be converted during materialization. The user cannot declare the
|
||||
// target structure type, so each member type must be the same default materialization type.
|
||||
for (size_t i = 1; i < members.Length(); i++) {
|
||||
if (members[i]->Type() != target_el_ty) {
|
||||
TINT_ICE(Resolver, builder.Diagnostics())
|
||||
<< "inconsistent target struct member types for SplatConvert";
|
||||
return utils::Failure;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
target_el_ty = type::Type::ElementOf(target_ty);
|
||||
}
|
||||
// Convert the single splatted element type.
|
||||
auto conv_el = ConvertInternal(splat->el, builder, target_el_ty, source, use_runtime_semantics);
|
||||
if (!conv_el) {
|
||||
return utils::Failure;
|
||||
}
|
||||
if (!conv_el.Get()) {
|
||||
return nullptr;
|
||||
}
|
||||
return builder.create<constant::Splat>(target_ty, conv_el.Get(), splat->count);
|
||||
}
|
||||
|
||||
ConstEval::Result ConvertInternal(const constant::Value* c,
|
||||
ProgramBuilder& builder,
|
||||
const type::Type* target_ty,
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include "src/tint/resolver/const_eval_test.h"
|
||||
#include "src/tint/sem/materialize.h"
|
||||
|
||||
using namespace tint::number_suffixes; // NOLINT
|
||||
|
||||
|
@ -472,5 +473,56 @@ TEST_F(ResolverConstEvalTest, Vec3_Convert_Small_f32_to_f16) {
|
|||
EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(2)->ValueAs<AFloat>().value));
|
||||
}
|
||||
|
||||
TEST_F(ResolverConstEvalTest, StructAbstractSplat_to_StructDifferentTypes) {
|
||||
// fn f() {
|
||||
// const c = modf(4.0);
|
||||
// var v = c;
|
||||
// }
|
||||
auto* expr_c = Call(builtin::Function::kModf, 0_a);
|
||||
auto* materialized = Expr("c");
|
||||
WrapInFunction(Decl(Const("c", expr_c)), Decl(Var("v", materialized)));
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
auto* c = Sem().Get(expr_c);
|
||||
ASSERT_NE(c, nullptr);
|
||||
EXPECT_TRUE(c->ConstantValue()->Is<constant::Splat>());
|
||||
EXPECT_TRUE(c->ConstantValue()->AllEqual());
|
||||
EXPECT_TRUE(c->ConstantValue()->AnyZero());
|
||||
EXPECT_TRUE(c->ConstantValue()->AllZero());
|
||||
|
||||
EXPECT_TRUE(c->ConstantValue()->Index(0)->AllEqual());
|
||||
EXPECT_TRUE(c->ConstantValue()->Index(0)->AnyZero());
|
||||
EXPECT_TRUE(c->ConstantValue()->Index(0)->AllZero());
|
||||
EXPECT_TRUE(c->ConstantValue()->Index(0)->Type()->Is<type::AbstractFloat>());
|
||||
EXPECT_EQ(c->ConstantValue()->Index(0)->ValueAs<AFloat>(), 0_f);
|
||||
|
||||
EXPECT_TRUE(c->ConstantValue()->Index(1)->AllEqual());
|
||||
EXPECT_TRUE(c->ConstantValue()->Index(1)->AnyZero());
|
||||
EXPECT_TRUE(c->ConstantValue()->Index(1)->AllZero());
|
||||
EXPECT_TRUE(c->ConstantValue()->Index(1)->Type()->Is<type::AbstractFloat>());
|
||||
EXPECT_EQ(c->ConstantValue()->Index(1)->ValueAs<AFloat>(), 0_a);
|
||||
|
||||
auto* v = Sem().GetVal(materialized);
|
||||
ASSERT_NE(v, nullptr);
|
||||
EXPECT_TRUE(v->Is<sem::Materialize>());
|
||||
EXPECT_TRUE(v->ConstantValue()->Is<constant::Splat>());
|
||||
EXPECT_TRUE(v->ConstantValue()->AllEqual());
|
||||
EXPECT_TRUE(v->ConstantValue()->AnyZero());
|
||||
EXPECT_TRUE(v->ConstantValue()->AllZero());
|
||||
|
||||
EXPECT_TRUE(v->ConstantValue()->Index(0)->AllEqual());
|
||||
EXPECT_TRUE(v->ConstantValue()->Index(0)->AnyZero());
|
||||
EXPECT_TRUE(v->ConstantValue()->Index(0)->AllZero());
|
||||
EXPECT_TRUE(v->ConstantValue()->Index(0)->Type()->Is<type::F32>());
|
||||
EXPECT_EQ(v->ConstantValue()->Index(0)->ValueAs<f32>(), 0_f);
|
||||
|
||||
EXPECT_TRUE(v->ConstantValue()->Index(1)->AllEqual());
|
||||
EXPECT_TRUE(v->ConstantValue()->Index(1)->AnyZero());
|
||||
EXPECT_TRUE(v->ConstantValue()->Index(1)->AllZero());
|
||||
EXPECT_TRUE(v->ConstantValue()->Index(1)->Type()->Is<type::F32>());
|
||||
EXPECT_EQ(v->ConstantValue()->Index(1)->ValueAs<f32>(), 0_f);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::resolver
|
||||
|
|
|
@ -29,6 +29,7 @@ class MockConstant : public constant::Value {
|
|||
~MockConstant() override {}
|
||||
const type::Type* Type() const override { return type; }
|
||||
const constant::Value* Index(size_t) const override { return {}; }
|
||||
size_t NumElements() const override { return 0; }
|
||||
bool AllZero() const override { return {}; }
|
||||
bool AnyZero() const override { return {}; }
|
||||
bool AllEqual() const override { return {}; }
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
fn foo(){
|
||||
let s1 = modf(0.0);
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
struct modf_result_f32 {
|
||||
float fract;
|
||||
float whole;
|
||||
};
|
||||
[numthreads(1, 1, 1)]
|
||||
void unused_entry_point() {
|
||||
return;
|
||||
}
|
||||
|
||||
void foo() {
|
||||
const modf_result_f32 s1 = (modf_result_f32)0;
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
struct modf_result_f32 {
|
||||
float fract;
|
||||
float whole;
|
||||
};
|
||||
[numthreads(1, 1, 1)]
|
||||
void unused_entry_point() {
|
||||
return;
|
||||
}
|
||||
|
||||
void foo() {
|
||||
const modf_result_f32 s1 = (modf_result_f32)0;
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
#version 310 es
|
||||
|
||||
struct modf_result_f32 {
|
||||
float fract;
|
||||
float whole;
|
||||
};
|
||||
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
void unused_entry_point() {
|
||||
return;
|
||||
}
|
||||
void foo() {
|
||||
modf_result_f32 s1 = modf_result_f32(0.0f, 0.0f);
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct modf_result_f32 {
|
||||
float fract;
|
||||
float whole;
|
||||
};
|
||||
void foo() {
|
||||
modf_result_f32 const s1 = modf_result_f32{};
|
||||
}
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Google Tint Compiler; 0
|
||||
; Bound: 10
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
|
||||
OpExecutionMode %unused_entry_point LocalSize 1 1 1
|
||||
OpName %unused_entry_point "unused_entry_point"
|
||||
OpName %foo "foo"
|
||||
OpName %__modf_result_f32 "__modf_result_f32"
|
||||
OpMemberName %__modf_result_f32 0 "fract"
|
||||
OpMemberName %__modf_result_f32 1 "whole"
|
||||
OpMemberDecorate %__modf_result_f32 0 Offset 0
|
||||
OpMemberDecorate %__modf_result_f32 1 Offset 4
|
||||
%void = OpTypeVoid
|
||||
%1 = OpTypeFunction %void
|
||||
%float = OpTypeFloat 32
|
||||
%__modf_result_f32 = OpTypeStruct %float %float
|
||||
%9 = OpConstantNull %__modf_result_f32
|
||||
%unused_entry_point = OpFunction %void None %1
|
||||
%4 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%foo = OpFunction %void None %1
|
||||
%6 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
|
@ -0,0 +1,3 @@
|
|||
fn foo() {
|
||||
let s1 = modf(0.0);
|
||||
}
|
Loading…
Reference in New Issue