diff --git a/src/tint/constant/composite.h b/src/tint/constant/composite.h index 3bd49735f5..5bfa2cb658 100644 --- a/src/tint/constant/composite.h +++ b/src/tint/constant/composite.h @@ -25,31 +25,40 @@ namespace tint::constant { /// Composite holds a number of mixed child values. -/// Composite may be of a vector, matrix or array type. +/// Composite may be of a vector, matrix, array or structure type. /// If each element is the same type and value, then a Splat would be a more efficient constant /// implementation. Use CreateComposite() to create the appropriate type. -class Composite : public Castable { +class Composite : public Castable { public: /// Constructor /// @param t the compsite type /// @param els the composite elements /// @param all_0 true if all elements are 0 /// @param any_0 true if any element is 0 - Composite(const type::Type* t, - utils::VectorRef els, - bool all_0, - bool any_0); + Composite(const type::Type* t, utils::VectorRef els, bool all_0, bool any_0); ~Composite() override; + /// @copydoc Value::Type() const type::Type* Type() const override { return type; } - const constant::Value* Index(size_t i) const override { + /// @copydoc Value::Index() + const Value* Index(size_t i) const override { return i < elements.Length() ? elements[i] : nullptr; } + /// @copydoc Value::NumElements() + size_t NumElements() const override { return elements.Length(); } + + /// @copydoc Value::AllZero() bool AllZero() const override { return all_zero; } + + /// @copydoc Value::AnyZero() bool AnyZero() const override { return any_zero; } + + /// @copydoc Value::AllEqual() bool AllEqual() const override { return false; } + + /// @copydoc Value::Hash() size_t Hash() const override { return hash; } /// Clones the constant into the provided context @@ -60,7 +69,7 @@ class Composite : public Castable { /// The composite type type::Type const* const type; /// The composite elements - const utils::Vector elements; + const utils::Vector elements; /// True if all elements are zero const bool all_zero; /// True if any element is zero @@ -69,6 +78,7 @@ class Composite : public Castable { const size_t hash; protected: + /// @copydoc Value::InternalValue() std::variant InternalValue() const override { return {}; } private: diff --git a/src/tint/constant/scalar.h b/src/tint/constant/scalar.h index a412d910f4..2a6f6a79fb 100644 --- a/src/tint/constant/scalar.h +++ b/src/tint/constant/scalar.h @@ -25,7 +25,7 @@ namespace tint::constant { /// Scalar holds a single scalar or abstract-numeric value. template -class Scalar : public Castable, constant::Value> { +class Scalar : public Castable, Value> { public: static_assert(!std::is_same_v, T> || std::is_same_v, "T must be a Number or bool"); @@ -40,13 +40,25 @@ class Scalar : public Castable, constant::Value> { } ~Scalar() override = default; + /// @copydoc Value::Type() const type::Type* Type() const override { return type; } - const constant::Value* Index(size_t) const override { return nullptr; } + /// @return nullptr, as Scalar does not hold any elements. + const Value* Index(size_t) const override { return nullptr; } + /// @copydoc Value::NumElements() + size_t NumElements() const override { return 1; } + + /// @copydoc Value::AllZero() bool AllZero() const override { return IsPositiveZero(); } + + /// @copydoc Value::AnyZero() bool AnyZero() const override { return IsPositiveZero(); } + + /// @copydoc Value::AllEqual() bool AllEqual() const override { return true; } + + /// @copydoc Value::Hash() size_t Hash() const override { return utils::Hash(type, ValueOf()); } /// Clones the constant into the provided context @@ -79,6 +91,7 @@ class Scalar : public Castable, constant::Value> { const T value; protected: + /// @copydoc Value::InternalValue() std::variant InternalValue() const override { if constexpr (IsFloatingPoint>) { return static_cast(value); diff --git a/src/tint/constant/splat.h b/src/tint/constant/splat.h index 5494c8f098..480512424a 100644 --- a/src/tint/constant/splat.h +++ b/src/tint/constant/splat.h @@ -25,14 +25,14 @@ namespace tint::constant { /// Splat holds a single value, duplicated as all children. /// /// Splat is used for zero-initializers, 'splat' initializers, or initializers where each element is -/// identical. Splat may be of a vector, matrix or array type. -class Splat : public Castable { +/// identical. Splat may be of a vector, matrix, array or structure type. +class Splat : public Castable { public: /// Constructor /// @param t the splat type /// @param e the splat element /// @param n the number of items in the splat - Splat(const type::Type* t, const constant::Value* e, size_t n); + Splat(const type::Type* t, const Value* e, size_t n); ~Splat() override; /// @returns the type of the splat @@ -41,7 +41,10 @@ class Splat : public Castable { /// Retrieve item at index @p i /// @param i the index to retrieve /// @returns the element, or nullptr if out of bounds - const constant::Value* Index(size_t i) const override { return i < count ? el : nullptr; } + const Value* Index(size_t i) const override { return i < count ? el : nullptr; } + + /// @copydoc Value::NumElements() + size_t NumElements() const override { return count; } /// @returns true if the element is zero bool AllZero() const override { return el->AllZero(); } @@ -61,7 +64,7 @@ class Splat : public Castable { /// The type of the splat element type::Type const* const type; /// The element stored in the splat - const constant::Value* el; + const Value* el; /// The number of items in the splat const size_t count; diff --git a/src/tint/constant/value.h b/src/tint/constant/value.h index d5b98769ee..594091dfec 100644 --- a/src/tint/constant/value.h +++ b/src/tint/constant/value.h @@ -37,6 +37,7 @@ class Value : public Castable { /// @returns the type of the value virtual const type::Type* Type() const = 0; + /// @param i the index of the element /// @returns the child element with the given index, or nullptr if there are no children, or /// the index is out of bounds. /// @@ -44,7 +45,10 @@ class Value : public Castable { /// For vectors, this returns the i'th element of the vector. /// For matrices, this returns the i'th column vector of the matrix. /// For structures, this returns the i'th member field of the structure. - virtual const Value* Index(size_t) const = 0; + virtual const Value* Index(size_t i) const = 0; + + /// @return the number of elements held by this Value + virtual size_t NumElements() const = 0; /// @returns true if child elements are positive-zero valued. virtual bool AllZero() const = 0; @@ -74,7 +78,7 @@ class Value : public Castable { /// @param b the value to compare too /// @returns true if this value is equal to @p b - bool Equal(const constant::Value* b) const; + bool Equal(const Value* b) const; /// Clones the constant into the provided context /// @param ctx the clone context diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 0ce4f3b81e..b90ea00a92 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -326,35 +326,20 @@ ConstEval::Result ConvertInternal(const constant::Value* c, const Source& source, bool use_runtime_semantics); -ConstEval::Result SplatConvert(const constant::Splat* splat, - ProgramBuilder& builder, - const type::Type* target_ty, - const Source& source, - bool use_runtime_semantics) { - // Convert the single splatted element type. - auto conv_el = ConvertInternal(splat->el, builder, type::Type::ElementOf(target_ty), source, - use_runtime_semantics); - if (!conv_el) { - return utils::Failure; - } - if (!conv_el.Get()) { - return nullptr; - } - return builder.create(target_ty, conv_el.Get(), splat->count); -} - -ConstEval::Result CompositeConvert(const constant::Composite* composite, +ConstEval::Result CompositeConvert(const constant::Value* value, ProgramBuilder& builder, const type::Type* target_ty, const Source& source, bool use_runtime_semantics) { + const size_t el_count = value->NumElements(); + // Convert each of the composite element types. utils::Vector conv_els; - conv_els.Reserve(composite->elements.Length()); + conv_els.Reserve(el_count); std::function target_el_ty; if (auto* str = target_ty->As()) { - if (TINT_UNLIKELY(str->Members().Length() != composite->elements.Length())) { + if (TINT_UNLIKELY(str->Members().Length() != el_count)) { TINT_ICE(Resolver, builder.Diagnostics()) << "const-eval conversion of structure has mismatched element counts"; return utils::Failure; @@ -365,7 +350,8 @@ ConstEval::Result CompositeConvert(const constant::Composite* composite, target_el_ty = [el_ty](size_t) { return el_ty; }; } - for (auto* el : composite->elements) { + for (size_t i = 0; i < el_count; i++) { + auto* el = value->Index(i); auto conv_el = ConvertInternal(el, builder, target_el_ty(conv_els.Length()), source, use_runtime_semantics); if (!conv_el) { @@ -379,6 +365,40 @@ ConstEval::Result CompositeConvert(const constant::Composite* composite, return builder.create(target_ty, std::move(conv_els)); } +ConstEval::Result SplatConvert(const constant::Splat* splat, + ProgramBuilder& builder, + const type::Type* target_ty, + const Source& source, + bool use_runtime_semantics) { + const type::Type* target_el_ty = nullptr; + if (auto* str = target_ty->As()) { + // Structure conversion. + auto members = str->Members(); + target_el_ty = members[0]->Type(); + + // Structures can only be converted during materialization. The user cannot declare the + // target structure type, so each member type must be the same default materialization type. + for (size_t i = 1; i < members.Length(); i++) { + if (members[i]->Type() != target_el_ty) { + TINT_ICE(Resolver, builder.Diagnostics()) + << "inconsistent target struct member types for SplatConvert"; + return utils::Failure; + } + } + } else { + target_el_ty = type::Type::ElementOf(target_ty); + } + // Convert the single splatted element type. + auto conv_el = ConvertInternal(splat->el, builder, target_el_ty, source, use_runtime_semantics); + if (!conv_el) { + return utils::Failure; + } + if (!conv_el.Get()) { + return nullptr; + } + return builder.create(target_ty, conv_el.Get(), splat->count); +} + ConstEval::Result ConvertInternal(const constant::Value* c, ProgramBuilder& builder, const type::Type* target_ty, diff --git a/src/tint/resolver/const_eval_conversion_test.cc b/src/tint/resolver/const_eval_conversion_test.cc index 92af2ded45..331eb9207a 100644 --- a/src/tint/resolver/const_eval_conversion_test.cc +++ b/src/tint/resolver/const_eval_conversion_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "src/tint/resolver/const_eval_test.h" +#include "src/tint/sem/materialize.h" using namespace tint::number_suffixes; // NOLINT @@ -472,5 +473,56 @@ TEST_F(ResolverConstEvalTest, Vec3_Convert_Small_f32_to_f16) { EXPECT_FALSE(std::signbit(sem->ConstantValue()->Index(2)->ValueAs().value)); } +TEST_F(ResolverConstEvalTest, StructAbstractSplat_to_StructDifferentTypes) { + // fn f() { + // const c = modf(4.0); + // var v = c; + // } + auto* expr_c = Call(builtin::Function::kModf, 0_a); + auto* materialized = Expr("c"); + WrapInFunction(Decl(Const("c", expr_c)), Decl(Var("v", materialized))); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* c = Sem().Get(expr_c); + ASSERT_NE(c, nullptr); + EXPECT_TRUE(c->ConstantValue()->Is()); + EXPECT_TRUE(c->ConstantValue()->AllEqual()); + EXPECT_TRUE(c->ConstantValue()->AnyZero()); + EXPECT_TRUE(c->ConstantValue()->AllZero()); + + EXPECT_TRUE(c->ConstantValue()->Index(0)->AllEqual()); + EXPECT_TRUE(c->ConstantValue()->Index(0)->AnyZero()); + EXPECT_TRUE(c->ConstantValue()->Index(0)->AllZero()); + EXPECT_TRUE(c->ConstantValue()->Index(0)->Type()->Is()); + EXPECT_EQ(c->ConstantValue()->Index(0)->ValueAs(), 0_f); + + EXPECT_TRUE(c->ConstantValue()->Index(1)->AllEqual()); + EXPECT_TRUE(c->ConstantValue()->Index(1)->AnyZero()); + EXPECT_TRUE(c->ConstantValue()->Index(1)->AllZero()); + EXPECT_TRUE(c->ConstantValue()->Index(1)->Type()->Is()); + EXPECT_EQ(c->ConstantValue()->Index(1)->ValueAs(), 0_a); + + auto* v = Sem().GetVal(materialized); + ASSERT_NE(v, nullptr); + EXPECT_TRUE(v->Is()); + EXPECT_TRUE(v->ConstantValue()->Is()); + EXPECT_TRUE(v->ConstantValue()->AllEqual()); + EXPECT_TRUE(v->ConstantValue()->AnyZero()); + EXPECT_TRUE(v->ConstantValue()->AllZero()); + + EXPECT_TRUE(v->ConstantValue()->Index(0)->AllEqual()); + EXPECT_TRUE(v->ConstantValue()->Index(0)->AnyZero()); + EXPECT_TRUE(v->ConstantValue()->Index(0)->AllZero()); + EXPECT_TRUE(v->ConstantValue()->Index(0)->Type()->Is()); + EXPECT_EQ(v->ConstantValue()->Index(0)->ValueAs(), 0_f); + + EXPECT_TRUE(v->ConstantValue()->Index(1)->AllEqual()); + EXPECT_TRUE(v->ConstantValue()->Index(1)->AnyZero()); + EXPECT_TRUE(v->ConstantValue()->Index(1)->AllZero()); + EXPECT_TRUE(v->ConstantValue()->Index(1)->Type()->Is()); + EXPECT_EQ(v->ConstantValue()->Index(1)->ValueAs(), 0_f); +} + } // namespace } // namespace tint::resolver diff --git a/src/tint/sem/value_expression_test.cc b/src/tint/sem/value_expression_test.cc index 17588941ee..5386ffca3a 100644 --- a/src/tint/sem/value_expression_test.cc +++ b/src/tint/sem/value_expression_test.cc @@ -29,6 +29,7 @@ class MockConstant : public constant::Value { ~MockConstant() override {} const type::Type* Type() const override { return type; } const constant::Value* Index(size_t) const override { return {}; } + size_t NumElements() const override { return 0; } bool AllZero() const override { return {}; } bool AnyZero() const override { return {}; } bool AllEqual() const override { return {}; } diff --git a/test/tint/bug/chromium/1417515.wgsl b/test/tint/bug/chromium/1417515.wgsl new file mode 100644 index 0000000000..ee34253ec4 --- /dev/null +++ b/test/tint/bug/chromium/1417515.wgsl @@ -0,0 +1,3 @@ +fn foo(){ + let s1 = modf(0.0); +} diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.dxc.hlsl b/test/tint/bug/chromium/1417515.wgsl.expected.dxc.hlsl new file mode 100644 index 0000000000..8916eb5c8b --- /dev/null +++ b/test/tint/bug/chromium/1417515.wgsl.expected.dxc.hlsl @@ -0,0 +1,12 @@ +struct modf_result_f32 { + float fract; + float whole; +}; +[numthreads(1, 1, 1)] +void unused_entry_point() { + return; +} + +void foo() { + const modf_result_f32 s1 = (modf_result_f32)0; +} diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.fxc.hlsl b/test/tint/bug/chromium/1417515.wgsl.expected.fxc.hlsl new file mode 100644 index 0000000000..8916eb5c8b --- /dev/null +++ b/test/tint/bug/chromium/1417515.wgsl.expected.fxc.hlsl @@ -0,0 +1,12 @@ +struct modf_result_f32 { + float fract; + float whole; +}; +[numthreads(1, 1, 1)] +void unused_entry_point() { + return; +} + +void foo() { + const modf_result_f32 s1 = (modf_result_f32)0; +} diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.glsl b/test/tint/bug/chromium/1417515.wgsl.expected.glsl new file mode 100644 index 0000000000..75b85f1322 --- /dev/null +++ b/test/tint/bug/chromium/1417515.wgsl.expected.glsl @@ -0,0 +1,16 @@ +#version 310 es + +struct modf_result_f32 { + float fract; + float whole; +}; + + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +void unused_entry_point() { + return; +} +void foo() { + modf_result_f32 s1 = modf_result_f32(0.0f, 0.0f); +} + diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.msl b/test/tint/bug/chromium/1417515.wgsl.expected.msl new file mode 100644 index 0000000000..3433613b89 --- /dev/null +++ b/test/tint/bug/chromium/1417515.wgsl.expected.msl @@ -0,0 +1,12 @@ +#include + +using namespace metal; + +struct modf_result_f32 { + float fract; + float whole; +}; +void foo() { + modf_result_f32 const s1 = modf_result_f32{}; +} + diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.spvasm b/test/tint/bug/chromium/1417515.wgsl.expected.spvasm new file mode 100644 index 0000000000..09a7770023 --- /dev/null +++ b/test/tint/bug/chromium/1417515.wgsl.expected.spvasm @@ -0,0 +1,29 @@ +; SPIR-V +; Version: 1.3 +; Generator: Google Tint Compiler; 0 +; Bound: 10 +; Schema: 0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %unused_entry_point "unused_entry_point" + OpExecutionMode %unused_entry_point LocalSize 1 1 1 + OpName %unused_entry_point "unused_entry_point" + OpName %foo "foo" + OpName %__modf_result_f32 "__modf_result_f32" + OpMemberName %__modf_result_f32 0 "fract" + OpMemberName %__modf_result_f32 1 "whole" + OpMemberDecorate %__modf_result_f32 0 Offset 0 + OpMemberDecorate %__modf_result_f32 1 Offset 4 + %void = OpTypeVoid + %1 = OpTypeFunction %void + %float = OpTypeFloat 32 +%__modf_result_f32 = OpTypeStruct %float %float + %9 = OpConstantNull %__modf_result_f32 +%unused_entry_point = OpFunction %void None %1 + %4 = OpLabel + OpReturn + OpFunctionEnd + %foo = OpFunction %void None %1 + %6 = OpLabel + OpReturn + OpFunctionEnd diff --git a/test/tint/bug/chromium/1417515.wgsl.expected.wgsl b/test/tint/bug/chromium/1417515.wgsl.expected.wgsl new file mode 100644 index 0000000000..0750da8fab --- /dev/null +++ b/test/tint/bug/chromium/1417515.wgsl.expected.wgsl @@ -0,0 +1,3 @@ +fn foo() { + let s1 = modf(0.0); +}