diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index af37a52a95..87eb8521c8 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -367,10 +367,10 @@ libtint_source_set("libtint_core_all_src") { "clone_context.cc", "clone_context.h", "constant/composite.h", - "constant/constant.h", "constant/node.h", "constant/scalar.h", "constant/splat.h", + "constant/value.h", "debug.cc", "debug.h", "demangler.cc", @@ -767,14 +767,14 @@ libtint_source_set("libtint_constant_src") { sources = [ "constant/composite.cc", "constant/composite.h", - "constant/constant.cc", - "constant/constant.h", "constant/node.cc", "constant/node.h", "constant/scalar.cc", "constant/scalar.h", "constant/splat.cc", "constant/splat.h", + "constant/value.cc", + "constant/value.h", ] public_deps = [ ":libtint_core_all_src" ] } diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index e0248b75e2..ee350f167f 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -256,14 +256,14 @@ list(APPEND TINT_LIB_SRCS clone_context.h constant/composite.cc constant/composite.h - constant/constant.cc - constant/constant.h constant/scalar.cc constant/scalar.h constant/splat.cc constant/splat.h constant/node.cc constant/node.h + constant/value.cc + constant/value.h demangler.cc demangler.h inspector/entry_point.cc diff --git a/src/tint/constant/composite.cc b/src/tint/constant/composite.cc index a0f0c9f601..76d2adb4ce 100644 --- a/src/tint/constant/composite.cc +++ b/src/tint/constant/composite.cc @@ -21,7 +21,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::constant::Composite); namespace tint::constant { Composite::Composite(const type::Type* t, - utils::VectorRef els, + utils::VectorRef els, bool all_0, bool any_0) : type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {} diff --git a/src/tint/constant/composite.h b/src/tint/constant/composite.h index 598a4af7ca..aacc2ef924 100644 --- a/src/tint/constant/composite.h +++ b/src/tint/constant/composite.h @@ -16,7 +16,7 @@ #define SRC_TINT_CONSTANT_COMPOSITE_H_ #include "src/tint/castable.h" -#include "src/tint/constant/constant.h" +#include "src/tint/constant/value.h" #include "src/tint/number.h" #include "src/tint/type/type.h" #include "src/tint/utils/hash.h" @@ -24,12 +24,11 @@ namespace tint::constant { -/// Composite holds a number of mixed child Constant values. +/// Composite holds a number of mixed child values. /// Composite may be of a vector, matrix or array type. /// If each element is the same type and value, then a Splat would be a more efficient constant -/// implementation. Use CreateComposite() to create the appropriate Constant type. -/// Composite implements the Constant interface. -class Composite : public Castable { +/// implementation. Use CreateComposite() to create the appropriate type. +class Composite : public Castable { public: /// Constructor /// @param t the compsite type @@ -37,15 +36,14 @@ class Composite : public Castable { /// @param all_0 true if all elements are 0 /// @param any_0 true if any element is 0 Composite(const type::Type* t, - utils::VectorRef els, + utils::VectorRef els, bool all_0, bool any_0); ~Composite() override; const type::Type* Type() const override { return type; } - std::variant Value() const override { return {}; } - const constant::Constant* Index(size_t i) const override { + const constant::Value* Index(size_t i) const override { return i < elements.Length() ? elements[i] : nullptr; } @@ -57,7 +55,7 @@ class Composite : public Castable { /// The composite type type::Type const* const type; /// The composite elements - const utils::Vector elements; + const utils::Vector elements; /// True if all elements are zero const bool all_zero; /// True if any element is zero @@ -65,6 +63,9 @@ class Composite : public Castable { /// The hash of the composite const size_t hash; + protected: + std::variant InternalValue() const override { return {}; } + private: size_t CalcHash() { auto h = utils::Hash(type, all_zero, any_zero); diff --git a/src/tint/constant/constant.cc b/src/tint/constant/constant.cc deleted file mode 100644 index a5b0caf16a..0000000000 --- a/src/tint/constant/constant.cc +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2021 The Tint Authors. -// -// Licensed under the Apache License, Version 2.0(the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "src/tint/constant/constant.h" - -TINT_INSTANTIATE_TYPEINFO(tint::constant::Constant); - -namespace tint::constant { - -Constant::Constant() = default; - -Constant::~Constant() = default; - -} // namespace tint::constant diff --git a/src/tint/constant/scalar.h b/src/tint/constant/scalar.h index bba4fbba86..3cf705ecde 100644 --- a/src/tint/constant/scalar.h +++ b/src/tint/constant/scalar.h @@ -16,7 +16,7 @@ #define SRC_TINT_CONSTANT_SCALAR_H_ #include "src/tint/castable.h" -#include "src/tint/constant/constant.h" +#include "src/tint/constant/value.h" #include "src/tint/number.h" #include "src/tint/type/type.h" #include "src/tint/utils/hash.h" @@ -24,9 +24,8 @@ namespace tint::constant { /// Scalar holds a single scalar or abstract-numeric value. -/// Scalar implements the Constant interface. template -class Scalar : public Castable, constant::Constant> { +class Scalar : public Castable, constant::Value> { public: static_assert(!std::is_same_v, T> || std::is_same_v, "T must be a Number or bool"); @@ -43,14 +42,7 @@ class Scalar : public Castable, constant::Constant> { const type::Type* Type() const override { return type; } - std::variant Value() const override { - if constexpr (IsFloatingPoint>) { - return static_cast(value); - } else { - return static_cast(value); - } - } - const constant::Constant* Index(size_t) const override { return nullptr; } + const constant::Value* Index(size_t) const override { return nullptr; } bool AllZero() const override { return IsPositiveZero(); } bool AnyZero() const override { return IsPositiveZero(); } @@ -77,6 +69,15 @@ class Scalar : public Castable, constant::Constant> { type::Type const* const type; /// The scalar value const T value; + + protected: + std::variant InternalValue() const override { + if constexpr (IsFloatingPoint>) { + return static_cast(value); + } else { + return static_cast(value); + } + } }; } // namespace tint::constant diff --git a/src/tint/constant/splat.cc b/src/tint/constant/splat.cc index 2b0766e273..9ef68d46f7 100644 --- a/src/tint/constant/splat.cc +++ b/src/tint/constant/splat.cc @@ -18,8 +18,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::constant::Splat); namespace tint::constant { -Splat::Splat(const type::Type* t, const constant::Constant* e, size_t n) - : type(t), el(e), count(n) {} +Splat::Splat(const type::Type* t, const constant::Value* e, size_t n) : type(t), el(e), count(n) {} Splat::~Splat() = default; diff --git a/src/tint/constant/splat.h b/src/tint/constant/splat.h index 5393927f4d..9fa7b13b10 100644 --- a/src/tint/constant/splat.h +++ b/src/tint/constant/splat.h @@ -22,29 +22,26 @@ namespace tint::constant { -/// Splat holds a single Constant value, duplicated as all children. +/// Splat holds a single value, duplicated as all children. +/// /// Splat is used for zero-initializers, 'splat' initializers, or initializers where each element is /// identical. Splat may be of a vector, matrix or array type. -/// Splat implements the Constant interface. -class Splat : public Castable { +class Splat : public Castable { public: /// Constructor /// @param t the splat type /// @param e the splat element /// @param n the number of items in the splat - Splat(const type::Type* t, const constant::Constant* e, size_t n); + Splat(const type::Type* t, const constant::Value* e, size_t n); ~Splat() override; /// @returns the type of the splat const type::Type* Type() const override { return type; } - /// @returns a monostate variant. - std::variant Value() const override { return {}; } - /// Retrieve item at index @p i /// @param i the index to retrieve /// @returns the element, or nullptr if out of bounds - const constant::Constant* Index(size_t i) const override { return i < count ? el : nullptr; } + const constant::Value* Index(size_t i) const override { return i < count ? el : nullptr; } /// @returns true if the element is zero bool AllZero() const override { return el->AllZero(); } @@ -59,9 +56,13 @@ class Splat : public Castable { /// The type of the splat element type::Type const* const type; /// The element stored in the splat - const constant::Constant* el; + const constant::Value* el; /// The number of items in the splat const size_t count; + + protected: + /// @returns a monostate variant. + std::variant InternalValue() const override { return {}; } }; } // namespace tint::constant diff --git a/src/tint/constant/value.cc b/src/tint/constant/value.cc new file mode 100644 index 0000000000..c41ca3493a --- /dev/null +++ b/src/tint/constant/value.cc @@ -0,0 +1,86 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0(the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/constant/value.h" + +#include "src/tint/type/array.h" +#include "src/tint/type/matrix.h" +#include "src/tint/type/struct.h" +#include "src/tint/type/vector.h" + +TINT_INSTANTIATE_TYPEINFO(tint::constant::Value); + +namespace tint::constant { + +Value::Value() = default; + +Value::~Value() = default; + +/// Equal returns true if the constants `a` and `b` are of the same type and value. +bool Value::Equal(const constant::Value* b) const { + if (Hash() != b->Hash()) { + return false; + } + if (Type() != b->Type()) { + return false; + } + return Switch( + Type(), // + [&](const type::Vector* vec) { + for (size_t i = 0; i < vec->Width(); i++) { + if (!Index(i)->Equal(b->Index(i))) { + return false; + } + } + return true; + }, + [&](const type::Matrix* mat) { + for (size_t i = 0; i < mat->columns(); i++) { + if (!Index(i)->Equal(b->Index(i))) { + return false; + } + } + return true; + }, + [&](const type::Array* arr) { + if (auto count = arr->ConstantCount()) { + for (size_t i = 0; i < count; i++) { + if (!Index(i)->Equal(b->Index(i))) { + return false; + } + } + return true; + } + + return false; + }, + [&](const type::Struct* str) { + auto count = str->Members().Length(); + for (size_t i = 0; i < count; i++) { + if (!Index(i)->Equal(b->Index(i))) { + return false; + } + } + return true; + }, + [&](Default) { + auto va = InternalValue(); + auto vb = b->InternalValue(); + TINT_ASSERT(Resolver, !std::holds_alternative(va)); + TINT_ASSERT(Resolver, !std::holds_alternative(vb)); + return va == vb; + }); +} + +} // namespace tint::constant diff --git a/src/tint/constant/constant.h b/src/tint/constant/value.h similarity index 58% rename from src/tint/constant/constant.h rename to src/tint/constant/value.h index f9124efc51..c16fe7623a 100644 --- a/src/tint/constant/constant.h +++ b/src/tint/constant/value.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SRC_TINT_CONSTANT_CONSTANT_H_ -#define SRC_TINT_CONSTANT_CONSTANT_H_ +#ifndef SRC_TINT_CONSTANT_VALUE_H_ +#define SRC_TINT_CONSTANT_VALUE_H_ #include @@ -24,43 +24,40 @@ namespace tint::constant { -/// Constant is the interface to a compile-time evaluated expression value. -class Constant : public Castable { +/// Value is the interface to a compile-time evaluated expression value. +class Value : public Castable { public: /// Constructor - Constant(); + Value(); /// Destructor - ~Constant() override; + ~Value() override; - /// @returns the type of the constant + /// @returns the type of the value virtual const type::Type* Type() const = 0; - /// @returns the value of this Constant, if this constant is of a scalar value or abstract - /// numeric, otherwise std::monostate. - virtual std::variant Value() const = 0; - - /// @returns the child constant element with the given index, or nullptr if the constant has no - /// children, or the index is out of bounds. + /// @returns the child element with the given index, or nullptr if there are no children, or + /// the index is out of bounds. + /// /// For arrays, this returns the i'th element of the array. /// For vectors, this returns the i'th element of the vector. /// For matrices, this returns the i'th column vector of the matrix. /// For structures, this returns the i'th member field of the structure. - virtual const Constant* Index(size_t) const = 0; + virtual const Value* Index(size_t) const = 0; - /// @returns true if child elements of this constant are positive-zero valued. + /// @returns true if child elements are positive-zero valued. virtual bool AllZero() const = 0; - /// @returns true if any child elements of this constant are positive-zero valued. + /// @returns true if any child elements are positive-zero valued. virtual bool AnyZero() const = 0; - /// @returns true if all child elements of this constant have the same value and type. + /// @returns true if all child elements have the same value and type. virtual bool AllEqual() const = 0; - /// @returns a hash of the constant. + /// @returns a hash of the value. virtual size_t Hash() const = 0; - /// @returns the value of the constant as the given scalar or abstract value. + /// @returns the value as the given scalar or abstract value. template T ValueAs() const { return std::visit( @@ -71,10 +68,19 @@ class Constant : public Castable { return static_cast(v); } }, - Value()); + InternalValue()); } + + /// @param b the value to compare too + /// @returns true if this value is equal to @p b + bool Equal(const constant::Value* b) const; + + protected: + /// @returns the value, if this is of a scalar value or abstract numeric, otherwise + /// std::monostate. + virtual std::variant InternalValue() const = 0; }; } // namespace tint::constant -#endif // SRC_TINT_CONSTANT_CONSTANT_H_ +#endif // SRC_TINT_CONSTANT_VALUE_H_ diff --git a/src/tint/program.h b/src/tint/program.h index a906041b81..873a0c4a63 100644 --- a/src/tint/program.h +++ b/src/tint/program.h @@ -19,7 +19,7 @@ #include #include "src/tint/ast/function.h" -#include "src/tint/constant/constant.h" +#include "src/tint/constant/value.h" #include "src/tint/program_id.h" #include "src/tint/sem/info.h" #include "src/tint/symbol_table.h" @@ -44,8 +44,8 @@ class Program { /// SemNodeAllocator is an alias to BlockAllocator using SemNodeAllocator = utils::BlockAllocator; - /// ConstantAllocator is an alias to BlockAllocator - using ConstantAllocator = utils::BlockAllocator; + /// ConstantAllocator is an alias to BlockAllocator + using ConstantAllocator = utils::BlockAllocator; /// Constructor Program(); diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index 92fdf31a5d..83ad287f8d 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -87,7 +87,7 @@ #include "src/tint/ast/void.h" #include "src/tint/ast/while_statement.h" #include "src/tint/ast/workgroup_attribute.h" -#include "src/tint/constant/constant.h" +#include "src/tint/constant/value.h" #include "src/tint/number.h" #include "src/tint/program.h" #include "src/tint/program_id.h" @@ -265,8 +265,8 @@ class ProgramBuilder { /// SemNodeAllocator is an alias to BlockAllocator using SemNodeAllocator = utils::BlockAllocator; - /// ConstantAllocator is an alias to BlockAllocator - using ConstantAllocator = utils::BlockAllocator; + /// ConstantAllocator is an alias to BlockAllocator + using ConstantAllocator = utils::BlockAllocator; /// Constructor ProgramBuilder(); @@ -465,12 +465,12 @@ class ProgramBuilder { return sem_nodes_.Create(std::forward(args)...); } - /// Creates a new constant::Constant owned by the ProgramBuilder. + /// Creates a new constant::Value owned by the ProgramBuilder. /// When the ProgramBuilder is destructed, the sem::Node will also be destructed. /// @param args the arguments to pass to the constructor /// @returns the node pointer template - traits::EnableIf, T>* create(ARGS&&... args) { + traits::EnableIf, T>* create(ARGS&&... args) { AssertNotMoved(); return constant_nodes_.Create(std::forward(args)...); } diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 434a7659ad..a92b02a74d 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -23,9 +23,9 @@ #include #include "src/tint/constant/composite.h" -#include "src/tint/constant/constant.h" #include "src/tint/constant/scalar.h" #include "src/tint/constant/splat.h" +#include "src/tint/constant/value.h" #include "src/tint/number.h" #include "src/tint/program_builder.h" #include "src/tint/sem/member_accessor_expression.h" @@ -233,9 +233,9 @@ std::make_unsigned_t CountTrailingBits(T e, T bit_value_to_count) { } // Forward declaration -const constant::Constant* CreateComposite(ProgramBuilder& builder, - const type::Type* type, - utils::VectorRef elements); +const constant::Value* CreateComposite(ProgramBuilder& builder, + const type::Type* type, + utils::VectorRef elements); template ConstEval::Result ScalarConvert(const constant::Scalar* scalar, @@ -296,7 +296,7 @@ ConstEval::Result ScalarConvert(const constant::Scalar* scalar, } // Forward declare -ConstEval::Result ConvertInternal(const constant::Constant* c, +ConstEval::Result ConvertInternal(const constant::Value* c, ProgramBuilder& builder, const type::Type* target_ty, const Source& source); @@ -321,7 +321,7 @@ ConstEval::Result CompositeConvert(const constant::Composite* composite, const type::Type* target_ty, const Source& source) { // Convert each of the composite element types. - utils::Vector conv_els; + utils::Vector conv_els; conv_els.Reserve(composite->elements.Length()); std::function target_el_ty; @@ -350,7 +350,7 @@ ConstEval::Result CompositeConvert(const constant::Composite* composite, return CreateComposite(builder, target_ty, std::move(conv_els)); } -ConstEval::Result ConvertInternal(const constant::Constant* c, +ConstEval::Result ConvertInternal(const constant::Value* c, ProgramBuilder& builder, const type::Type* target_ty, const Source& source) { @@ -403,18 +403,18 @@ ConstEval::Result CreateScalar(ProgramBuilder& builder, } /// ZeroValue returns a Constant for the zero-value of the type `type`. -const constant::Constant* ZeroValue(ProgramBuilder& builder, const type::Type* type) { +const constant::Value* ZeroValue(ProgramBuilder& builder, const type::Type* type) { return Switch( type, // - [&](const type::Vector* v) -> const constant::Constant* { + [&](const type::Vector* v) -> const constant::Value* { auto* zero_el = ZeroValue(builder, v->type()); return builder.create(type, zero_el, v->Width()); }, - [&](const type::Matrix* m) -> const constant::Constant* { + [&](const type::Matrix* m) -> const constant::Value* { auto* zero_el = ZeroValue(builder, m->ColumnType()); return builder.create(type, zero_el, m->columns()); }, - [&](const type::Array* a) -> const constant::Constant* { + [&](const type::Array* a) -> const constant::Value* { if (auto n = a->ConstantCount()) { if (auto* zero_el = ZeroValue(builder, a->ElemType())) { return builder.create(type, zero_el, n.value()); @@ -422,9 +422,9 @@ const constant::Constant* ZeroValue(ProgramBuilder& builder, const type::Type* t } return nullptr; }, - [&](const type::Struct* s) -> const constant::Constant* { - utils::Hashmap zero_by_type; - utils::Vector zeros; + [&](const type::Struct* s) -> const constant::Value* { + utils::Hashmap zero_by_type; + utils::Vector zeros; zeros.Reserve(s->Members().Length()); for (auto* member : s->Members()) { auto* zero = zero_by_type.GetOrCreate( @@ -440,8 +440,8 @@ const constant::Constant* ZeroValue(ProgramBuilder& builder, const type::Type* t } return CreateComposite(builder, s, std::move(zeros)); }, - [&](Default) -> const constant::Constant* { - return ZeroTypeDispatch(type, [&](auto zero) -> const constant::Constant* { + [&](Default) -> const constant::Value* { + return ZeroTypeDispatch(type, [&](auto zero) -> const constant::Value* { auto el = CreateScalar(builder, Source{}, type, zero); TINT_ASSERT(Resolver, el); return el.Get(); @@ -449,68 +449,12 @@ const constant::Constant* ZeroValue(ProgramBuilder& builder, const type::Type* t }); } -/// Equal returns true if the constants `a` and `b` are of the same type and value. -bool Equal(const constant::Constant* a, const constant::Constant* b) { - if (a->Hash() != b->Hash()) { - return false; - } - if (a->Type() != b->Type()) { - return false; - } - return Switch( - a->Type(), // - [&](const type::Vector* vec) { - for (size_t i = 0; i < vec->Width(); i++) { - if (!Equal(a->Index(i), b->Index(i))) { - return false; - } - } - return true; - }, - [&](const type::Matrix* mat) { - for (size_t i = 0; i < mat->columns(); i++) { - if (!Equal(a->Index(i), b->Index(i))) { - return false; - } - } - return true; - }, - [&](const type::Array* arr) { - if (auto count = arr->ConstantCount()) { - for (size_t i = 0; i < count; i++) { - if (!Equal(a->Index(i), b->Index(i))) { - return false; - } - } - return true; - } - - return false; - }, - [&](const type::Struct* str) { - auto count = str->Members().Length(); - for (size_t i = 0; i < count; i++) { - if (!Equal(a->Index(i), b->Index(i))) { - return false; - } - } - return true; - }, - [&](Default) { - auto va = a->Value(); - auto vb = b->Value(); - TINT_ASSERT(Resolver, !std::holds_alternative(va)); - TINT_ASSERT(Resolver, !std::holds_alternative(vb)); - return va == vb; - }); -} - /// CreateComposite is used to construct a constant of a vector, matrix or array type. /// CreateComposite examines the element values and will return either a Composite or a Splat, /// depending on the element types and values. -const constant::Constant* CreateComposite(ProgramBuilder& builder, - const type::Type* type, - utils::VectorRef elements) { +const constant::Value* CreateComposite(ProgramBuilder& builder, + const type::Type* type, + utils::VectorRef elements) { if (elements.IsEmpty()) { return nullptr; } @@ -529,7 +473,7 @@ const constant::Constant* CreateComposite(ProgramBuilder& builder, all_zero = false; } if (all_equal && el != first) { - if (!Equal(el, first)) { + if (!el->Equal(first)) { all_equal = false; } } @@ -560,7 +504,7 @@ ConstEval::Result TransformElements(ProgramBuilder& builder, return f(cs...); } } - utils::Vector els; + utils::Vector els; els.Reserve(n); for (uint32_t i = 0; i < n; i++) { if (auto el = detail::TransformElements(builder, type::Type::ElementOf(composite_ty), @@ -596,8 +540,8 @@ template ConstEval::Result TransformBinaryElements(ProgramBuilder& builder, const type::Type* composite_ty, F&& f, - const constant::Constant* c0, - const constant::Constant* c1) { + const constant::Value* c0, + const constant::Value* c1) { uint32_t n0 = 0; type::Type::ElementOf(c0->Type(), &n0); uint32_t n1 = 0; @@ -608,7 +552,7 @@ ConstEval::Result TransformBinaryElements(ProgramBuilder& builder, return f(c0, c1); } - utils::Vector els; + utils::Vector els; els.Reserve(max_n); for (uint32_t i = 0; i < max_n; i++) { auto nested_or_self = [&](auto* c, uint32_t num_elems) { @@ -1123,8 +1067,8 @@ auto ConstEval::Dot4Func(const Source& source, const type::Type* elem_ty) { } ConstEval::Result ConstEval::Dot(const Source& source, - const constant::Constant* v1, - const constant::Constant* v2) { + const constant::Value* v1, + const constant::Value* v2) { auto* vec_ty = v1->Type()->As(); TINT_ASSERT(Resolver, vec_ty); auto* elem_ty = vec_ty->type(); @@ -1151,7 +1095,7 @@ ConstEval::Result ConstEval::Dot(const Source& source, ConstEval::Result ConstEval::Length(const Source& source, const type::Type* ty, - const constant::Constant* c0) { + const constant::Value* c0) { auto* vec_ty = c0->Type()->As(); // Evaluates to the absolute value of e if T is scalar. if (vec_ty == nullptr) { @@ -1172,9 +1116,9 @@ ConstEval::Result ConstEval::Length(const Source& source, ConstEval::Result ConstEval::Mul(const Source& source, const type::Type* ty, - const constant::Constant* v1, - const constant::Constant* v2) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + const constant::Value* v1, + const constant::Value* v2) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { return Dispatch_fia_fiu32_f16(MulFunc(source, c0->Type()), c0, c1); }; return TransformBinaryElements(builder, ty, transform, v1, v2); @@ -1182,9 +1126,9 @@ ConstEval::Result ConstEval::Mul(const Source& source, ConstEval::Result ConstEval::Sub(const Source& source, const type::Type* ty, - const constant::Constant* v1, - const constant::Constant* v2) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + const constant::Value* v1, + const constant::Value* v2) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { return Dispatch_fia_fiu32_f16(SubFunc(source, c0->Type()), c0, c1); }; return TransformBinaryElements(builder, ty, transform, v1, v2); @@ -1262,7 +1206,7 @@ ConstEval::Result ConstEval::ArrayOrStructInit(const type::Type* ty, } // Multiple arguments. Must be a type initializer. - utils::Vector els; + utils::Vector els; els.Reserve(args.Length()); for (auto* arg : args) { els.Push(arg->ConstantValue()); @@ -1271,7 +1215,7 @@ ConstEval::Result ConstEval::ArrayOrStructInit(const type::Type* ty, } ConstEval::Result ConstEval::Conv(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { uint32_t el_count = 0; auto* el_ty = type::Type::ElementOf(ty, &el_count); @@ -1287,19 +1231,19 @@ ConstEval::Result ConstEval::Conv(const type::Type* ty, } ConstEval::Result ConstEval::Zero(const type::Type* ty, - utils::VectorRef, + utils::VectorRef, const Source&) { return ZeroValue(builder, ty); } ConstEval::Result ConstEval::Identity(const type::Type*, - utils::VectorRef args, + utils::VectorRef args, const Source&) { return args[0]; } ConstEval::Result ConstEval::VecSplat(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source&) { if (auto* arg = args[0]) { return builder.create(ty, arg, @@ -1309,15 +1253,15 @@ ConstEval::Result ConstEval::VecSplat(const type::Type* ty, } ConstEval::Result ConstEval::VecInitS(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source&) { return CreateComposite(builder, ty, args); } ConstEval::Result ConstEval::VecInitM(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source&) { - utils::Vector els; + utils::Vector els; for (auto* arg : args) { auto* val = arg; if (!val) { @@ -1341,13 +1285,13 @@ ConstEval::Result ConstEval::VecInitM(const type::Type* ty, } ConstEval::Result ConstEval::MatInitS(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source&) { auto* m = static_cast(ty); - utils::Vector els; + utils::Vector els; for (uint32_t c = 0; c < m->columns(); c++) { - utils::Vector column; + utils::Vector column; for (uint32_t r = 0; r < m->rows(); r++) { auto i = r + c * m->rows(); column.Push(args[i]); @@ -1358,7 +1302,7 @@ ConstEval::Result ConstEval::MatInitS(const type::Type* ty, } ConstEval::Result ConstEval::MatInitV(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source&) { return CreateComposite(builder, ty, args); } @@ -1422,9 +1366,9 @@ ConstEval::Result ConstEval::Bitcast(const type::Type*, const sem::Expression*) } ConstEval::Result ConstEval::OpComplement(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c) { + auto transform = [&](const constant::Value* c) { auto create = [&](auto i) { return CreateScalar(builder, source, c->Type(), decltype(i)(~i.value)); }; @@ -1434,9 +1378,9 @@ ConstEval::Result ConstEval::OpComplement(const type::Type* ty, } ConstEval::Result ConstEval::OpUnaryMinus(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c) { + auto transform = [&](const constant::Value* c) { auto create = [&](auto i) { // For signed integrals, avoid C++ UB by not negating the // smallest negative number. In WGSL, this operation is well @@ -1459,9 +1403,9 @@ ConstEval::Result ConstEval::OpUnaryMinus(const type::Type* ty, } ConstEval::Result ConstEval::OpNot(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c) { + auto transform = [&](const constant::Value* c) { auto create = [&](auto i) { return CreateScalar(builder, source, c->Type(), decltype(i)(!i)); }; @@ -1471,9 +1415,9 @@ ConstEval::Result ConstEval::OpNot(const type::Type* ty, } ConstEval::Result ConstEval::OpPlus(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { return Dispatch_fia_fiu32_f16(AddFunc(source, c0->Type()), c0, c1); }; @@ -1481,25 +1425,25 @@ ConstEval::Result ConstEval::OpPlus(const type::Type* ty, } ConstEval::Result ConstEval::OpMinus(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { return Sub(source, ty, args[0], args[1]); } ConstEval::Result ConstEval::OpMultiply(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { return Mul(source, ty, args[0], args[1]); } ConstEval::Result ConstEval::OpMultiplyMatVec(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto* mat_ty = args[0]->Type()->As(); auto* vec_ty = args[1]->Type()->As(); auto* elem_ty = vec_ty->type(); - auto dot = [&](const constant::Constant* m, size_t row, const constant::Constant* v) { + auto dot = [&](const constant::Value* m, size_t row, const constant::Value* v) { ConstEval::Result result; switch (mat_ty->columns()) { case 2: @@ -1532,7 +1476,7 @@ ConstEval::Result ConstEval::OpMultiplyMatVec(const type::Type* ty, return result; }; - utils::Vector result; + utils::Vector result; for (size_t i = 0; i < mat_ty->rows(); ++i) { auto r = dot(args[0], i, args[1]); // matrix row i * vector if (!r) { @@ -1543,13 +1487,13 @@ ConstEval::Result ConstEval::OpMultiplyMatVec(const type::Type* ty, return CreateComposite(builder, ty, result); } ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto* vec_ty = args[0]->Type()->As(); auto* mat_ty = args[1]->Type()->As(); auto* elem_ty = vec_ty->type(); - auto dot = [&](const constant::Constant* v, const constant::Constant* m, size_t col) { + auto dot = [&](const constant::Value* v, const constant::Value* m, size_t col) { ConstEval::Result result; switch (mat_ty->rows()) { case 2: @@ -1582,7 +1526,7 @@ ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty, return result; }; - utils::Vector result; + utils::Vector result; for (size_t i = 0; i < mat_ty->columns(); ++i) { auto r = dot(args[0], args[1], i); // vector * matrix col i if (!r) { @@ -1594,7 +1538,7 @@ ConstEval::Result ConstEval::OpMultiplyVecMat(const type::Type* ty, } ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto* mat1 = args[0]; auto* mat2 = args[1]; @@ -1602,8 +1546,7 @@ ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty, auto* mat2_ty = mat2->Type()->As(); auto* elem_ty = mat1_ty->type(); - auto dot = [&](const constant::Constant* m1, size_t row, const constant::Constant* m2, - size_t col) { + auto dot = [&](const constant::Value* m1, size_t row, const constant::Value* m2, size_t col) { auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); }; auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); }; @@ -1640,9 +1583,9 @@ ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty, return result; }; - utils::Vector result_mat; + utils::Vector result_mat; for (size_t c = 0; c < mat2_ty->columns(); ++c) { - utils::Vector col_vec; + utils::Vector col_vec; for (size_t r = 0; r < mat1_ty->rows(); ++r) { auto v = dot(mat1, r, mat2, c); // mat1 row r * mat2 col c if (!v) { @@ -1659,9 +1602,9 @@ ConstEval::Result ConstEval::OpMultiplyMatMat(const type::Type* ty, } ConstEval::Result ConstEval::OpDivide(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { return Dispatch_fia_fiu32_f16(DivFunc(source, c0->Type()), c0, c1); }; @@ -1669,9 +1612,9 @@ ConstEval::Result ConstEval::OpDivide(const type::Type* ty, } ConstEval::Result ConstEval::OpModulo(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { return Dispatch_fia_fiu32_f16(ModFunc(source, c0->Type()), c0, c1); }; @@ -1679,9 +1622,9 @@ ConstEval::Result ConstEval::OpModulo(const type::Type* ty, } ConstEval::Result ConstEval::OpEqual(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i == j); }; @@ -1692,9 +1635,9 @@ ConstEval::Result ConstEval::OpEqual(const type::Type* ty, } ConstEval::Result ConstEval::OpNotEqual(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i != j); }; @@ -1705,9 +1648,9 @@ ConstEval::Result ConstEval::OpNotEqual(const type::Type* ty, } ConstEval::Result ConstEval::OpLessThan(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i < j); }; @@ -1718,9 +1661,9 @@ ConstEval::Result ConstEval::OpLessThan(const type::Type* ty, } ConstEval::Result ConstEval::OpGreaterThan(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i > j); }; @@ -1731,9 +1674,9 @@ ConstEval::Result ConstEval::OpGreaterThan(const type::Type* ty, } ConstEval::Result ConstEval::OpLessThanEqual(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i <= j); }; @@ -1744,9 +1687,9 @@ ConstEval::Result ConstEval::OpLessThanEqual(const type::Type* ty, } ConstEval::Result ConstEval::OpGreaterThanEqual(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), i >= j); }; @@ -1757,7 +1700,7 @@ ConstEval::Result ConstEval::OpGreaterThanEqual(const type::Type* ty, } ConstEval::Result ConstEval::OpLogicalAnd(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { // Note: Due to short-circuiting, this function is only called if lhs is true, so we could // technically only return the value of the rhs. @@ -1765,7 +1708,7 @@ ConstEval::Result ConstEval::OpLogicalAnd(const type::Type* ty, } ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { // Note: Due to short-circuiting, this function is only called if lhs is false, so we could // technically only return the value of the rhs. @@ -1773,9 +1716,9 @@ ConstEval::Result ConstEval::OpLogicalOr(const type::Type* ty, } ConstEval::Result ConstEval::OpAnd(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { using T = decltype(i); T result; @@ -1793,9 +1736,9 @@ ConstEval::Result ConstEval::OpAnd(const type::Type* ty, } ConstEval::Result ConstEval::OpOr(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { using T = decltype(i); T result; @@ -1813,9 +1756,9 @@ ConstEval::Result ConstEval::OpOr(const type::Type* ty, } ConstEval::Result ConstEval::OpXor(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) -> ConstEval::Result { return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), decltype(i){i ^ j}); @@ -1827,9 +1770,9 @@ ConstEval::Result ConstEval::OpXor(const type::Type* ty, } ConstEval::Result ConstEval::OpShiftLeft(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto e1, auto e2) -> ConstEval::Result { using NumberT = decltype(e1); using T = UnwrapNumber; @@ -1913,9 +1856,9 @@ ConstEval::Result ConstEval::OpShiftLeft(const type::Type* ty, } ConstEval::Result ConstEval::OpShiftRight(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto e1, auto e2) -> ConstEval::Result { using NumberT = decltype(e1); using T = UnwrapNumber; @@ -1977,9 +1920,9 @@ ConstEval::Result ConstEval::OpShiftRight(const type::Type* ty, } ConstEval::Result ConstEval::abs(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { using NumberT = decltype(e); NumberT result; @@ -2002,9 +1945,9 @@ ConstEval::Result ConstEval::abs(const type::Type* ty, } ConstEval::Result ConstEval::acos(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); if (i < NumberT(-1.0) || i > NumberT(1.0)) { @@ -2020,9 +1963,9 @@ ConstEval::Result ConstEval::acos(const type::Type* ty, } ConstEval::Result ConstEval::acosh(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); if (i < NumberT(1.0)) { @@ -2038,21 +1981,21 @@ ConstEval::Result ConstEval::acosh(const type::Type* ty, } ConstEval::Result ConstEval::all(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { return CreateScalar(builder, source, ty, !args[0]->AnyZero()); } ConstEval::Result ConstEval::any(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { return CreateScalar(builder, source, ty, !args[0]->AllZero()); } ConstEval::Result ConstEval::asin(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); if (i < NumberT(-1.0) || i > NumberT(1.0)) { @@ -2068,9 +2011,9 @@ ConstEval::Result ConstEval::asin(const type::Type* ty, } ConstEval::Result ConstEval::asinh(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) { return CreateScalar(builder, source, c0->Type(), decltype(i)(std::asinh(i.value))); }; @@ -2081,9 +2024,9 @@ ConstEval::Result ConstEval::asinh(const type::Type* ty, } ConstEval::Result ConstEval::atan(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) { return CreateScalar(builder, source, c0->Type(), decltype(i)(std::atan(i.value))); }; @@ -2093,9 +2036,9 @@ ConstEval::Result ConstEval::atan(const type::Type* ty, } ConstEval::Result ConstEval::atanh(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); if (i <= NumberT(-1.0) || i >= NumberT(1.0)) { @@ -2112,9 +2055,9 @@ ConstEval::Result ConstEval::atanh(const type::Type* ty, } ConstEval::Result ConstEval::atan2(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto i, auto j) { return CreateScalar(builder, source, c0->Type(), decltype(i)(std::atan2(i.value, j.value))); @@ -2125,9 +2068,9 @@ ConstEval::Result ConstEval::atan2(const type::Type* ty, } ConstEval::Result ConstEval::ceil(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { return CreateScalar(builder, source, c0->Type(), decltype(e)(std::ceil(e))); }; @@ -2137,19 +2080,19 @@ ConstEval::Result ConstEval::ceil(const type::Type* ty, } ConstEval::Result ConstEval::clamp(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1, - const constant::Constant* c2) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1, + const constant::Value* c2) { return Dispatch_fia_fiu32_f16(ClampFunc(source, c0->Type()), c0, c1, c2); }; return TransformElements(builder, ty, transform, args[0], args[1], args[2]); } ConstEval::Result ConstEval::cos(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); return CreateScalar(builder, source, c0->Type(), NumberT(std::cos(i.value))); @@ -2160,9 +2103,9 @@ ConstEval::Result ConstEval::cos(const type::Type* ty, } ConstEval::Result ConstEval::cosh(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); return CreateScalar(builder, source, c0->Type(), NumberT(std::cosh(i.value))); @@ -2173,9 +2116,9 @@ ConstEval::Result ConstEval::cosh(const type::Type* ty, } ConstEval::Result ConstEval::countLeadingZeros(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { using NumberT = decltype(e); using T = UnwrapNumber; @@ -2188,9 +2131,9 @@ ConstEval::Result ConstEval::countLeadingZeros(const type::Type* ty, } ConstEval::Result ConstEval::countOneBits(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { using NumberT = decltype(e); using T = UnwrapNumber; @@ -2212,9 +2155,9 @@ ConstEval::Result ConstEval::countOneBits(const type::Type* ty, } ConstEval::Result ConstEval::countTrailingZeros(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { using NumberT = decltype(e); using T = UnwrapNumber; @@ -2227,7 +2170,7 @@ ConstEval::Result ConstEval::countTrailingZeros(const type::Type* ty, } ConstEval::Result ConstEval::cross(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto* u = args[0]; auto* v = args[1]; @@ -2266,13 +2209,13 @@ ConstEval::Result ConstEval::cross(const type::Type* ty, } return CreateComposite(builder, ty, - utils::Vector{x.Get(), y.Get(), z.Get()}); + utils::Vector{x.Get(), y.Get(), z.Get()}); } ConstEval::Result ConstEval::degrees(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) -> ConstEval::Result { using NumberT = decltype(e); using T = UnwrapNumber; @@ -2296,7 +2239,7 @@ ConstEval::Result ConstEval::degrees(const type::Type* ty, } ConstEval::Result ConstEval::determinant(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto calculate = [&]() -> ConstEval::Result { auto* m = args[0]; @@ -2332,7 +2275,7 @@ ConstEval::Result ConstEval::determinant(const type::Type* ty, } ConstEval::Result ConstEval::distance(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto err = [&]() -> ConstEval::Result { AddNote("when calculating distance", source); @@ -2352,7 +2295,7 @@ ConstEval::Result ConstEval::distance(const type::Type* ty, } ConstEval::Result ConstEval::dot(const type::Type*, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto r = Dot(source, args[0], args[1]); if (!r) { @@ -2362,9 +2305,9 @@ ConstEval::Result ConstEval::dot(const type::Type*, } ConstEval::Result ConstEval::exp(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e0) -> ConstEval::Result { using NumberT = decltype(e0); auto val = NumberT(std::exp(e0)); @@ -2380,9 +2323,9 @@ ConstEval::Result ConstEval::exp(const type::Type* ty, } ConstEval::Result ConstEval::exp2(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e0) -> ConstEval::Result { using NumberT = decltype(e0); auto val = NumberT(std::exp2(e0)); @@ -2398,9 +2341,9 @@ ConstEval::Result ConstEval::exp2(const type::Type* ty, } ConstEval::Result ConstEval::extractBits(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto in_e) -> ConstEval::Result { using NumberT = decltype(in_e); using T = UnwrapNumber; @@ -2453,7 +2396,7 @@ ConstEval::Result ConstEval::extractBits(const type::Type* ty, } ConstEval::Result ConstEval::faceForward(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { // Returns e1 if dot(e2, e3) is negative, and -e1 otherwise. auto* e1 = args[0]; @@ -2472,9 +2415,9 @@ ConstEval::Result ConstEval::faceForward(const type::Type* ty, } ConstEval::Result ConstEval::firstLeadingBit(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { using NumberT = decltype(e); using T = UnwrapNumber; @@ -2516,9 +2459,9 @@ ConstEval::Result ConstEval::firstLeadingBit(const type::Type* ty, } ConstEval::Result ConstEval::firstTrailingBit(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { using NumberT = decltype(e); using T = UnwrapNumber; @@ -2542,9 +2485,9 @@ ConstEval::Result ConstEval::firstTrailingBit(const type::Type* ty, } ConstEval::Result ConstEval::floor(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { return CreateScalar(builder, source, c0->Type(), decltype(e)(std::floor(e))); }; @@ -2554,10 +2497,10 @@ ConstEval::Result ConstEval::floor(const type::Type* ty, } ConstEval::Result ConstEval::fma(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c1, const constant::Constant* c2, - const constant::Constant* c3) { + auto transform = [&](const constant::Value* c1, const constant::Value* c2, + const constant::Value* c3) { auto create = [&](auto e1, auto e2, auto e3) -> ConstEval::Result { auto err_msg = [&] { AddNote("when calculating fma", source); @@ -2581,9 +2524,9 @@ ConstEval::Result ConstEval::fma(const type::Type* ty, } ConstEval::Result ConstEval::fract(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c1) { + auto transform = [&](const constant::Value* c1) { auto create = [&](auto e) -> ConstEval::Result { using NumberT = decltype(e); auto r = e - std::floor(e); @@ -2595,7 +2538,7 @@ ConstEval::Result ConstEval::fract(const type::Type* ty, } ConstEval::Result ConstEval::frexp(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto* arg = args[0]; @@ -2604,7 +2547,7 @@ ConstEval::Result ConstEval::frexp(const type::Type* ty, ConstEval::Result exp; }; - auto scalar = [&](const constant::Constant* s) { + auto scalar = [&](const constant::Value* s) { int exp = 0; double fract = std::frexp(s->ValueAs(), &exp); return Switch( @@ -2637,8 +2580,8 @@ ConstEval::Result ConstEval::frexp(const type::Type* ty, }; if (auto* vec = arg->Type()->As()) { - utils::Vector fract_els; - utils::Vector exp_els; + utils::Vector fract_els; + utils::Vector exp_els; for (uint32_t i = 0; i < vec->Width(); i++) { auto fe = scalar(arg->Index(i)); if (!fe.fract || !fe.exp) { @@ -2650,7 +2593,7 @@ ConstEval::Result ConstEval::frexp(const type::Type* ty, auto fract_ty = builder.create(fract_els[0]->Type(), vec->Width()); auto exp_ty = builder.create(exp_els[0]->Type(), vec->Width()); return CreateComposite(builder, ty, - utils::Vector{ + utils::Vector{ CreateComposite(builder, fract_ty, std::move(fract_els)), CreateComposite(builder, exp_ty, std::move(exp_els)), }); @@ -2660,7 +2603,7 @@ ConstEval::Result ConstEval::frexp(const type::Type* ty, return utils::Failure; } return CreateComposite(builder, ty, - utils::Vector{ + utils::Vector{ fe.fract.Get(), fe.exp.Get(), }); @@ -2668,9 +2611,9 @@ ConstEval::Result ConstEval::frexp(const type::Type* ty, } ConstEval::Result ConstEval::insertBits(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto in_e, auto in_newbits) -> ConstEval::Result { using NumberT = decltype(in_e); using T = UnwrapNumber; @@ -2720,9 +2663,9 @@ ConstEval::Result ConstEval::insertBits(const type::Type* ty, } ConstEval::Result ConstEval::inverseSqrt(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) -> ConstEval::Result { using NumberT = decltype(e); @@ -2754,7 +2697,7 @@ ConstEval::Result ConstEval::inverseSqrt(const type::Type* ty, } ConstEval::Result ConstEval::length(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto r = Length(source, ty, args[0]); if (!r) { @@ -2764,9 +2707,9 @@ ConstEval::Result ConstEval::length(const type::Type* ty, } ConstEval::Result ConstEval::log(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto v) -> ConstEval::Result { using NumberT = decltype(v); if (v <= NumberT(0)) { @@ -2781,9 +2724,9 @@ ConstEval::Result ConstEval::log(const type::Type* ty, } ConstEval::Result ConstEval::log2(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto v) -> ConstEval::Result { using NumberT = decltype(v); if (v <= NumberT(0)) { @@ -2798,9 +2741,9 @@ ConstEval::Result ConstEval::log2(const type::Type* ty, } ConstEval::Result ConstEval::max(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto e0, auto e1) { return CreateScalar(builder, source, c0->Type(), decltype(e0)(std::max(e0, e1))); }; @@ -2810,9 +2753,9 @@ ConstEval::Result ConstEval::max(const type::Type* ty, } ConstEval::Result ConstEval::min(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto e0, auto e1) { return CreateScalar(builder, source, c0->Type(), decltype(e0)(std::min(e0, e1))); }; @@ -2822,9 +2765,9 @@ ConstEval::Result ConstEval::min(const type::Type* ty, } ConstEval::Result ConstEval::mix(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1, size_t index) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1, size_t index) { auto create = [&](auto e1, auto e2) -> ConstEval::Result { using NumberT = decltype(e1); // e3 is either a vector or a scalar @@ -2865,23 +2808,23 @@ ConstEval::Result ConstEval::mix(const type::Type* ty, } ConstEval::Result ConstEval::modf(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform_fract = [&](const constant::Constant* c) { + auto transform_fract = [&](const constant::Value* c) { auto create = [&](auto e) { return CreateScalar(builder, source, c->Type(), decltype(e)(e.value - std::trunc(e.value))); }; return Dispatch_fa_f32_f16(create, c); }; - auto transform_whole = [&](const constant::Constant* c) { + auto transform_whole = [&](const constant::Value* c) { auto create = [&](auto e) { return CreateScalar(builder, source, c->Type(), decltype(e)(std::trunc(e.value))); }; return Dispatch_fa_f32_f16(create, c); }; - utils::Vector fields; + utils::Vector fields; if (auto fract = TransformElements(builder, args[0]->Type(), transform_fract, args[0])) { fields.Push(fract.Get()); @@ -2899,7 +2842,7 @@ ConstEval::Result ConstEval::modf(const type::Type* ty, } ConstEval::Result ConstEval::normalize(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto* len_ty = type::Type::DeepestElementOf(ty); auto len = Length(source, len_ty, args[0]); @@ -2916,7 +2859,7 @@ ConstEval::Result ConstEval::normalize(const type::Type* ty, } ConstEval::Result ConstEval::pack2x16float(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto convert = [&](f32 val) -> utils::Result { auto conv = CheckedConvert(val); @@ -2944,7 +2887,7 @@ ConstEval::Result ConstEval::pack2x16float(const type::Type* ty, } ConstEval::Result ConstEval::pack2x16snorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto calc = [&](f32 val) -> u32 { auto clamped = Clamp(source, val, f32(-1.0f), f32(1.0f)).Get(); @@ -2961,7 +2904,7 @@ ConstEval::Result ConstEval::pack2x16snorm(const type::Type* ty, } ConstEval::Result ConstEval::pack2x16unorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto calc = [&](f32 val) -> u32 { auto clamped = Clamp(source, val, f32(0.0f), f32(1.0f)).Get(); @@ -2977,7 +2920,7 @@ ConstEval::Result ConstEval::pack2x16unorm(const type::Type* ty, } ConstEval::Result ConstEval::pack4x8snorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto calc = [&](f32 val) -> u32 { auto clamped = Clamp(source, val, f32(-1.0f), f32(1.0f)).Get(); @@ -2997,7 +2940,7 @@ ConstEval::Result ConstEval::pack4x8snorm(const type::Type* ty, } ConstEval::Result ConstEval::pack4x8unorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto calc = [&](f32 val) -> u32 { auto clamped = Clamp(source, val, f32(0.0f), f32(1.0f)).Get(); @@ -3016,9 +2959,9 @@ ConstEval::Result ConstEval::pack4x8unorm(const type::Type* ty, } ConstEval::Result ConstEval::pow(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto e1, auto e2) -> ConstEval::Result { auto r = CheckedPow(e1, e2); if (!r) { @@ -3033,9 +2976,9 @@ ConstEval::Result ConstEval::pow(const type::Type* ty, } ConstEval::Result ConstEval::radians(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) -> ConstEval::Result { using NumberT = decltype(e); using T = UnwrapNumber; @@ -3059,7 +3002,7 @@ ConstEval::Result ConstEval::radians(const type::Type* ty, } ConstEval::Result ConstEval::reflect(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto calculate = [&]() -> ConstEval::Result { // For the incident vector e1 and surface orientation e2, returns the reflection direction @@ -3102,7 +3045,7 @@ ConstEval::Result ConstEval::reflect(const type::Type* ty, } ConstEval::Result ConstEval::refract(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto* vec_ty = ty->As(); auto* el_ty = vec_ty->type(); @@ -3200,9 +3143,9 @@ ConstEval::Result ConstEval::refract(const type::Type* ty, } ConstEval::Result ConstEval::reverseBits(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto in_e) -> ConstEval::Result { using NumberT = decltype(in_e); using T = UnwrapNumber; @@ -3227,9 +3170,9 @@ ConstEval::Result ConstEval::reverseBits(const type::Type* ty, } ConstEval::Result ConstEval::round(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { using NumberT = decltype(e); using T = UnwrapNumber; @@ -3263,9 +3206,9 @@ ConstEval::Result ConstEval::round(const type::Type* ty, } ConstEval::Result ConstEval::saturate(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) { using NumberT = decltype(e); return CreateScalar(builder, source, c0->Type(), @@ -3277,10 +3220,10 @@ ConstEval::Result ConstEval::saturate(const type::Type* ty, } ConstEval::Result ConstEval::select_bool(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto cond = args[2]->ValueAs(); - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto f, auto t) -> ConstEval::Result { return CreateScalar(builder, source, type::Type::DeepestElementOf(ty), cond ? t : f); }; @@ -3291,9 +3234,9 @@ ConstEval::Result ConstEval::select_bool(const type::Type* ty, } ConstEval::Result ConstEval::select_boolvec(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1, size_t index) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1, size_t index) { auto create = [&](auto f, auto t) -> ConstEval::Result { // Get corresponding bool value at the current vector value index auto cond = args[2]->Index(index)->ValueAs(); @@ -3306,9 +3249,9 @@ ConstEval::Result ConstEval::select_boolvec(const type::Type* ty, } ConstEval::Result ConstEval::sign(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto e) -> ConstEval::Result { using NumberT = decltype(e); NumberT result; @@ -3328,9 +3271,9 @@ ConstEval::Result ConstEval::sign(const type::Type* ty, } ConstEval::Result ConstEval::sin(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); return CreateScalar(builder, source, c0->Type(), NumberT(std::sin(i.value))); @@ -3341,9 +3284,9 @@ ConstEval::Result ConstEval::sin(const type::Type* ty, } ConstEval::Result ConstEval::sinh(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); return CreateScalar(builder, source, c0->Type(), NumberT(std::sinh(i.value))); @@ -3354,10 +3297,10 @@ ConstEval::Result ConstEval::sinh(const type::Type* ty, } ConstEval::Result ConstEval::smoothstep(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1, - const constant::Constant* c2) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1, + const constant::Value* c2) { auto create = [&](auto low, auto high, auto x) -> ConstEval::Result { using NumberT = decltype(low); @@ -3405,9 +3348,9 @@ ConstEval::Result ConstEval::smoothstep(const type::Type* ty, } ConstEval::Result ConstEval::step(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0, const constant::Constant* c1) { + auto transform = [&](const constant::Value* c0, const constant::Value* c1) { auto create = [&](auto edge, auto x) -> ConstEval::Result { using NumberT = decltype(edge); NumberT result = x.value < edge.value ? NumberT(0.0) : NumberT(1.0); @@ -3419,9 +3362,9 @@ ConstEval::Result ConstEval::step(const type::Type* ty, } ConstEval::Result ConstEval::sqrt(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { return Dispatch_fa_f32_f16(SqrtFunc(source, c0->Type()), c0); }; @@ -3429,9 +3372,9 @@ ConstEval::Result ConstEval::sqrt(const type::Type* ty, } ConstEval::Result ConstEval::tan(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); return CreateScalar(builder, source, c0->Type(), NumberT(std::tan(i.value))); @@ -3442,9 +3385,9 @@ ConstEval::Result ConstEval::tan(const type::Type* ty, } ConstEval::Result ConstEval::tanh(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) -> ConstEval::Result { using NumberT = decltype(i); return CreateScalar(builder, source, c0->Type(), NumberT(std::tanh(i.value))); @@ -3455,7 +3398,7 @@ ConstEval::Result ConstEval::tanh(const type::Type* ty, } ConstEval::Result ConstEval::transpose(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source&) { auto* m = args[0]; auto* mat_ty = m->Type()->As(); @@ -3463,9 +3406,9 @@ ConstEval::Result ConstEval::transpose(const type::Type* ty, auto* result_mat_ty = ty->As(); // Produce column vectors from each row - utils::Vector result_mat; + utils::Vector result_mat; for (size_t r = 0; r < mat_ty->rows(); ++r) { - utils::Vector new_col_vec; + utils::Vector new_col_vec; for (size_t c = 0; c < mat_ty->columns(); ++c) { new_col_vec.Push(me(r, c)); } @@ -3475,9 +3418,9 @@ ConstEval::Result ConstEval::transpose(const type::Type* ty, } ConstEval::Result ConstEval::trunc(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c0) { + auto transform = [&](const constant::Value* c0) { auto create = [&](auto i) { return CreateScalar(builder, source, c0->Type(), decltype(i)(std::trunc(i.value))); }; @@ -3487,12 +3430,12 @@ ConstEval::Result ConstEval::trunc(const type::Type* ty, } ConstEval::Result ConstEval::unpack2x16float(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto* inner_ty = type::Type::DeepestElementOf(ty); auto e = args[0]->ValueAs().value; - utils::Vector els; + utils::Vector els; els.Reserve(2); for (size_t i = 0; i < 2; ++i) { auto in = f16::FromBits(uint16_t((e >> (16 * i)) & 0x0000'ffff)); @@ -3511,12 +3454,12 @@ ConstEval::Result ConstEval::unpack2x16float(const type::Type* ty, } ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto* inner_ty = type::Type::DeepestElementOf(ty); auto e = args[0]->ValueAs().value; - utils::Vector els; + utils::Vector els; els.Reserve(2); for (size_t i = 0; i < 2; ++i) { auto val = f32( @@ -3531,12 +3474,12 @@ ConstEval::Result ConstEval::unpack2x16snorm(const type::Type* ty, } ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto* inner_ty = type::Type::DeepestElementOf(ty); auto e = args[0]->ValueAs().value; - utils::Vector els; + utils::Vector els; els.Reserve(2); for (size_t i = 0; i < 2; ++i) { auto val = f32(static_cast(uint16_t((e >> (16 * i)) & 0x0000'ffff)) / 65535.f); @@ -3550,12 +3493,12 @@ ConstEval::Result ConstEval::unpack2x16unorm(const type::Type* ty, } ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto* inner_ty = type::Type::DeepestElementOf(ty); auto e = args[0]->ValueAs().value; - utils::Vector els; + utils::Vector els; els.Reserve(4); for (size_t i = 0; i < 4; ++i) { auto val = @@ -3570,12 +3513,12 @@ ConstEval::Result ConstEval::unpack4x8snorm(const type::Type* ty, } ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { auto* inner_ty = type::Type::DeepestElementOf(ty); auto e = args[0]->ValueAs().value; - utils::Vector els; + utils::Vector els; els.Reserve(4); for (size_t i = 0; i < 4; ++i) { auto val = f32(static_cast(uint8_t((e >> (8 * i)) & 0x0000'00ff)) / 255.f); @@ -3589,9 +3532,9 @@ ConstEval::Result ConstEval::unpack4x8unorm(const type::Type* ty, } ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source) { - auto transform = [&](const constant::Constant* c) -> ConstEval::Result { + auto transform = [&](const constant::Value* c) -> ConstEval::Result { auto value = c->ValueAs(); auto conv = CheckedConvert(f16(value)); if (!conv) { @@ -3604,7 +3547,7 @@ ConstEval::Result ConstEval::quantizeToF16(const type::Type* ty, } ConstEval::Result ConstEval::Convert(const type::Type* target_ty, - const constant::Constant* value, + const constant::Value* value, const Source& source) { if (value->Type() == target_ty) { return value; diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h index 6caac076f0..baeb7e468c 100644 --- a/src/tint/resolver/const_eval.h +++ b/src/tint/resolver/const_eval.h @@ -31,7 +31,7 @@ namespace tint::ast { class LiteralExpression; } // namespace tint::ast namespace tint::constant { -class Constant; +class Value; } // namespace tint::constant namespace tint::sem { class Expression; @@ -50,20 +50,20 @@ class ConstEval { public: /// The result type of a method that may raise a diagnostic error and the caller should abort /// resolving. Can be one of three distinct values: - /// * A non-null constant::Constant pointer. Returned when a expression resolves to a creation + /// * A non-null constant::Value pointer. Returned when a expression resolves to a creation /// time /// value. - /// * A null constant::Constant pointer. Returned when a expression cannot resolve to a creation + /// * A null constant::Value pointer. Returned when a expression cannot resolve to a creation /// time /// value, but is otherwise legal. /// * `utils::Failure`. Returned when there was a resolver error. In this situation the method /// will have already reported a diagnostic error message, and the caller should abort /// resolving. - using Result = utils::Result; + using Result = utils::Result; /// Typedef for a constant evaluation function using Function = Result (ConstEval::*)(const type::Type* result_ty, - utils::VectorRef, + utils::VectorRef, const Source&); /// Constructor @@ -113,7 +113,7 @@ class ConstEval { /// @param value the value being converted /// @param source the source location /// @return the converted value, or null if the value cannot be calculated - Result Convert(const type::Type* ty, const constant::Constant* value, const Source& source); + Result Convert(const type::Type* ty, const constant::Value* value, const Source& source); //////////////////////////////////////////////////////////////////////////////////////////////// // Constant value evaluation methods, to be indirectly called via the intrinsic table @@ -125,7 +125,7 @@ class ConstEval { /// @param source the source location /// @return the converted value, or null if the value cannot be calculated Result Conv(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Zero value type initializer @@ -134,7 +134,7 @@ class ConstEval { /// @param source the source location /// @return the constructed value, or null if the value cannot be calculated Result Zero(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Identity value type initializer @@ -143,7 +143,7 @@ class ConstEval { /// @param source the source location /// @return the constructed value, or null if the value cannot be calculated Result Identity(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Vector splat initializer @@ -152,7 +152,7 @@ class ConstEval { /// @param source the source location /// @return the constructed value, or null if the value cannot be calculated Result VecSplat(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Vector initializer using scalars @@ -161,7 +161,7 @@ class ConstEval { /// @param source the source location /// @return the constructed value, or null if the value cannot be calculated Result VecInitS(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Vector initializer using a mix of scalars and smaller vectors @@ -170,7 +170,7 @@ class ConstEval { /// @param source the source location /// @return the constructed value, or null if the value cannot be calculated Result VecInitM(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Matrix initializer using scalar values @@ -179,7 +179,7 @@ class ConstEval { /// @param source the source location /// @return the constructed value, or null if the value cannot be calculated Result MatInitS(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Matrix initializer using column vectors @@ -188,7 +188,7 @@ class ConstEval { /// @param source the source location /// @return the constructed value, or null if the value cannot be calculated Result MatInitV(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); //////////////////////////////////////////////////////////////////////////// @@ -201,7 +201,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpComplement(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Unary minus operator '-' @@ -210,7 +210,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpUnaryMinus(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Unary not operator '!' @@ -219,7 +219,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpNot(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); //////////////////////////////////////////////////////////////////////////// @@ -232,7 +232,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpPlus(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Minus operator '-' @@ -241,7 +241,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpMinus(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Multiply operator '*' for the same type on the LHS and RHS @@ -250,7 +250,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpMultiply(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Multiply operator '*' for matCxR * vecC @@ -259,7 +259,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpMultiplyMatVec(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Multiply operator '*' for vecR * matCxR @@ -268,7 +268,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpMultiplyVecMat(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Multiply operator '*' for matKxR * matCxK @@ -277,7 +277,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpMultiplyMatMat(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Divide operator '/' @@ -286,7 +286,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpDivide(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Modulo operator '%' @@ -295,7 +295,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpModulo(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Equality operator '==' @@ -304,7 +304,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpEqual(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Inequality operator '!=' @@ -313,7 +313,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpNotEqual(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Less than operator '<' @@ -322,7 +322,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpLessThan(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Greater than operator '>' @@ -331,7 +331,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpGreaterThan(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Less than or equal operator '<=' @@ -340,7 +340,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpLessThanEqual(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Greater than or equal operator '>=' @@ -349,7 +349,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpGreaterThanEqual(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Logical and operator '&&' @@ -358,7 +358,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpLogicalAnd(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Logical or operator '||' @@ -367,7 +367,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpLogicalOr(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Bitwise and operator '&' @@ -376,7 +376,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpAnd(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Bitwise or operator '|' @@ -385,7 +385,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpOr(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Bitwise xor operator '^' @@ -394,7 +394,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpXor(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Bitwise shift left operator '<<' @@ -403,7 +403,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpShiftLeft(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// Bitwise shift right operator '<<' @@ -412,7 +412,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result OpShiftRight(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); //////////////////////////////////////////////////////////////////////////// @@ -425,7 +425,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result abs(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// acos builtin @@ -434,7 +434,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result acos(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// acosh builtin @@ -443,7 +443,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result acosh(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// all builtin @@ -452,7 +452,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result all(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// any builtin @@ -461,7 +461,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result any(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// asin builtin @@ -470,7 +470,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result asin(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// asinh builtin @@ -479,7 +479,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result asinh(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// atan builtin @@ -488,7 +488,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result atan(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// atanh builtin @@ -497,7 +497,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result atanh(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// atan2 builtin @@ -506,7 +506,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result atan2(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// ceil builtin @@ -515,7 +515,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result ceil(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// clamp builtin @@ -524,7 +524,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result clamp(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// cos builtin @@ -533,7 +533,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result cos(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// cosh builtin @@ -542,7 +542,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result cosh(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// countLeadingZeros builtin @@ -551,7 +551,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result countLeadingZeros(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// countOneBits builtin @@ -560,7 +560,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result countOneBits(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// countTrailingZeros builtin @@ -569,7 +569,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result countTrailingZeros(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// cross builtin @@ -578,7 +578,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result cross(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// degrees builtin @@ -587,7 +587,7 @@ class ConstEval { /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated Result degrees(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// determinant builtin @@ -596,7 +596,7 @@ class ConstEval { /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated Result determinant(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// distance builtin @@ -605,7 +605,7 @@ class ConstEval { /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated Result distance(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// dot builtin @@ -614,7 +614,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result dot(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// exp builtin @@ -623,7 +623,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result exp(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// exp2 builtin @@ -632,7 +632,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result exp2(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// extractBits builtin @@ -641,7 +641,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result extractBits(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// faceForward builtin @@ -650,7 +650,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result faceForward(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// firstLeadingBit builtin @@ -659,7 +659,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result firstLeadingBit(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// firstTrailingBit builtin @@ -668,7 +668,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result firstTrailingBit(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// floor builtin @@ -677,7 +677,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result floor(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// fma builtin @@ -686,7 +686,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result fma(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// fract builtin @@ -695,7 +695,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result fract(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// frexp builtin @@ -704,7 +704,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result frexp(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// insertBits builtin @@ -713,7 +713,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result insertBits(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// inverseSqrt builtin @@ -722,7 +722,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result inverseSqrt(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// length builtin @@ -731,7 +731,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result length(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// log builtin @@ -740,7 +740,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result log(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// log2 builtin @@ -749,7 +749,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result log2(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// max builtin @@ -758,7 +758,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result max(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// min builtin @@ -767,7 +767,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result min(const type::Type* ty, // NOLINT(build/include_what_you_use) -- confused by min - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// mix builtin @@ -776,7 +776,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result mix(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// modf builtin @@ -785,7 +785,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result modf(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// normalize builtin @@ -794,7 +794,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result normalize(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// pack2x16float builtin @@ -803,7 +803,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result pack2x16float(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// pack2x16snorm builtin @@ -812,7 +812,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result pack2x16snorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// pack2x16unorm builtin @@ -821,7 +821,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result pack2x16unorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// pack4x8snorm builtin @@ -830,7 +830,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result pack4x8snorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// pack4x8unorm builtin @@ -839,7 +839,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result pack4x8unorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// pow builtin @@ -848,7 +848,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result pow(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// radians builtin @@ -857,7 +857,7 @@ class ConstEval { /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated Result radians(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// reflect builtin @@ -866,7 +866,7 @@ class ConstEval { /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated Result reflect(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// refract builtin @@ -875,7 +875,7 @@ class ConstEval { /// @param source the source location of the conversion /// @return the result value, or null if the value cannot be calculated Result refract(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// reverseBits builtin @@ -884,7 +884,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result reverseBits(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// round builtin @@ -893,7 +893,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result round(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// saturate builtin @@ -902,7 +902,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result saturate(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// select builtin with single bool third arg @@ -911,7 +911,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result select_bool(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// select builtin with vector of bool third arg @@ -920,7 +920,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result select_boolvec(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// sign builtin @@ -929,7 +929,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result sign(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// sin builtin @@ -938,7 +938,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result sin(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// sinh builtin @@ -947,7 +947,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result sinh(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// smoothstep builtin @@ -956,7 +956,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result smoothstep(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// step builtin @@ -965,7 +965,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result step(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// sqrt builtin @@ -974,7 +974,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result sqrt(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// tan builtin @@ -983,7 +983,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result tan(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// tanh builtin @@ -992,7 +992,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result tanh(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// transpose builtin @@ -1001,7 +1001,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result transpose(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// trunc builtin @@ -1010,7 +1010,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result trunc(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// unpack2x16float builtin @@ -1019,7 +1019,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result unpack2x16float(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// unpack2x16snorm builtin @@ -1028,7 +1028,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result unpack2x16snorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// unpack2x16unorm builtin @@ -1037,7 +1037,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result unpack2x16unorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// unpack4x8snorm builtin @@ -1046,7 +1046,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result unpack4x8snorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// unpack4x8unorm builtin @@ -1055,7 +1055,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result unpack4x8unorm(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); /// quantizeToF16 builtin @@ -1064,7 +1064,7 @@ class ConstEval { /// @param source the source location /// @return the result value, or null if the value cannot be calculated Result quantizeToF16(const type::Type* ty, - utils::VectorRef args, + utils::VectorRef args, const Source& source); private: @@ -1361,14 +1361,14 @@ class ConstEval { /// @param v1 the first vector /// @param v2 the second vector /// @returns the dot product - Result Dot(const Source& source, const constant::Constant* v1, const constant::Constant* v2); + Result Dot(const Source& source, const constant::Value* v1, const constant::Value* v2); /// Returns the length of c0 /// @param source the source location /// @param ty the return type /// @param c0 the constant to calculate the length of /// @returns the length of c0 - Result Length(const Source& source, const type::Type* ty, const constant::Constant* c0); + Result Length(const Source& source, const type::Type* ty, const constant::Value* c0); /// Returns the product of v1 and v2 /// @param source the source location @@ -1378,8 +1378,8 @@ class ConstEval { /// @returns the product of v1 and v2 Result Mul(const Source& source, const type::Type* ty, - const constant::Constant* v1, - const constant::Constant* v2); + const constant::Value* v1, + const constant::Value* v2); /// Returns the difference between v2 and v1 /// @param source the source location @@ -1389,8 +1389,8 @@ class ConstEval { /// @returns the difference between v2 and v1 Result Sub(const Source& source, const type::Type* ty, - const constant::Constant* v1, - const constant::Constant* v2); + const constant::Value* v1, + const constant::Value* v2); ProgramBuilder& builder; }; diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc index 0f9e8ef69d..e7e306af20 100644 --- a/src/tint/resolver/const_eval_binary_op_test.cc +++ b/src/tint/resolver/const_eval_binary_op_test.cc @@ -99,7 +99,7 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) { auto& expected = expected_case.value; auto* sem = Sem().Get(expr); - const constant::Constant* value = sem->ConstantValue(); + const constant::Value* value = sem->ConstantValue(); ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); @@ -892,20 +892,19 @@ TEST_F(ResolverConstEvalTest, NotAndOrOfVecs) { EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(expr); - const constant::Constant* value = sem->ConstantValue(); + const constant::Value* value = sem->ConstantValue(); ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); auto* expected_sem = Sem().Get(expected_expr); - const constant::Constant* expected_value = expected_sem->ConstantValue(); + const constant::Value* expected_value = expected_sem->ConstantValue(); ASSERT_NE(expected_value, nullptr); EXPECT_TYPE(expected_value->Type(), expected_sem->Type()); - ForEachElemPair(value, expected_value, - [&](const constant::Constant* a, const constant::Constant* b) { - EXPECT_EQ(a->ValueAs(), b->ValueAs()); - return HasFailure() ? Action::kStop : Action::kContinue; - }); + ForEachElemPair(value, expected_value, [&](const constant::Value* a, const constant::Value* b) { + EXPECT_EQ(a->ValueAs(), b->ValueAs()); + return HasFailure() ? Action::kStop : Action::kContinue; + }); } template diff --git a/src/tint/resolver/const_eval_builtin_test.cc b/src/tint/resolver/const_eval_builtin_test.cc index 25adcef3c1..10070dbc62 100644 --- a/src/tint/resolver/const_eval_builtin_test.cc +++ b/src/tint/resolver/const_eval_builtin_test.cc @@ -162,7 +162,7 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) { auto* sem = Sem().Get(expr); ASSERT_NE(sem, nullptr); - const constant::Constant* value = sem->ConstantValue(); + const constant::Value* value = sem->ConstantValue(); ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); diff --git a/src/tint/resolver/const_eval_member_access_test.cc b/src/tint/resolver/const_eval_member_access_test.cc index a1e7e59860..705ed270e6 100644 --- a/src/tint/resolver/const_eval_member_access_test.cc +++ b/src/tint/resolver/const_eval_member_access_test.cc @@ -89,10 +89,10 @@ TEST_F(ResolverConstEvalTest, Matrix_AFloat_Construct_From_AInt_Vectors) { EXPECT_FALSE(cv->AllZero()); auto* c0 = cv->Index(0); auto* c1 = cv->Index(1); - EXPECT_EQ(std::get(c0->Index(0)->Value()), 1.0); - EXPECT_EQ(std::get(c0->Index(1)->Value()), 2.0); - EXPECT_EQ(std::get(c1->Index(0)->Value()), 3.0); - EXPECT_EQ(std::get(c1->Index(1)->Value()), 4.0); + EXPECT_EQ(c0->Index(0)->ValueAs(), 1.0); + EXPECT_EQ(c0->Index(1)->ValueAs(), 2.0); + EXPECT_EQ(c1->Index(0)->ValueAs(), 3.0); + EXPECT_EQ(c1->Index(1)->ValueAs(), 4.0); } } // namespace } // namespace tint::resolver diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h index 1420be8f10..c31d89f7a8 100644 --- a/src/tint/resolver/const_eval_test.h +++ b/src/tint/resolver/const_eval_test.h @@ -36,10 +36,9 @@ inline const auto kPiOver4 = T(UnwrapNumber(0.785398163397448309616)); template inline const auto k3PiOver4 = T(UnwrapNumber(2.356194490192344928846)); -/// Walks the constant::Constant @p c, accumulating all the inner-most scalar values into @p args +/// Walks the constant::Value @p c, accumulating all the inner-most scalar values into @p args template -inline void CollectScalars(const constant::Constant* c, - utils::Vector& scalars) { +inline void CollectScalars(const constant::Value* c, utils::Vector& scalars) { Switch( c->Type(), // [&](const type::AbstractInt*) { scalars.Push(c->ValueAs()); }, @@ -57,8 +56,8 @@ inline void CollectScalars(const constant::Constant* c, }); } -/// Walks the constant::Constant @p c, returning all the inner-most scalar values. -inline utils::Vector ScalarsFrom(const constant::Constant* c) { +/// Walks the constant::Value @p c, returning all the inner-most scalar values. +inline utils::Vector ScalarsFrom(const constant::Value* c) { utils::Vector out; CollectScalars(c, out); return out; @@ -89,7 +88,7 @@ struct CheckConstantFlags { /// @param got_constant the constant value evaluated by the resolver /// @param expected_value the expected value for the test /// @param flags optional flags for controlling the comparisons -inline void CheckConstant(const constant::Constant* got_constant, +inline void CheckConstant(const constant::Value* got_constant, const builder::Value& expected_value, CheckConstantFlags flags = {}) { auto values_flat = ScalarsFrom(got_constant); @@ -258,7 +257,7 @@ using builder::Vec; // TODO(amaiorano): Move to Constant.h? enum class Action { kStop, kContinue }; template -inline Action ForEachElemPair(const constant::Constant* a, const constant::Constant* b, Func&& f) { +inline Action ForEachElemPair(const constant::Value* a, const constant::Value* b, Func&& f) { EXPECT_EQ(a->Type(), b->Type()); size_t i = 0; while (true) { diff --git a/src/tint/resolver/const_eval_unary_op_test.cc b/src/tint/resolver/const_eval_unary_op_test.cc index df67e54723..f9a1e578e9 100644 --- a/src/tint/resolver/const_eval_unary_op_test.cc +++ b/src/tint/resolver/const_eval_unary_op_test.cc @@ -57,7 +57,7 @@ TEST_P(ResolverConstEvalUnaryOpTest, Test) { ASSERT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(expr); - const constant::Constant* value = sem->ConstantValue(); + const constant::Value* value = sem->ConstantValue(); ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc index 82f6bdde28..8de7af9fe8 100644 --- a/src/tint/resolver/materialize_test.cc +++ b/src/tint/resolver/materialize_test.cc @@ -101,7 +101,7 @@ class MaterializeTest : public resolver::ResolverTestWithParam { auto* el = value->Index(i); ASSERT_NE(el, nullptr); EXPECT_TYPE(el->Type(), v->type()); - EXPECT_EQ(std::get(el->Value()), expected_value); + EXPECT_EQ(el->ValueAs(), expected_value); } }, [&](const type::Matrix* m) { @@ -113,7 +113,7 @@ class MaterializeTest : public resolver::ResolverTestWithParam { auto* el = column->Index(r); ASSERT_NE(el, nullptr); EXPECT_TYPE(el->Type(), m->type()); - EXPECT_EQ(std::get(el->Value()), expected_value); + EXPECT_EQ(el->ValueAs(), expected_value); } } }, @@ -124,10 +124,10 @@ class MaterializeTest : public resolver::ResolverTestWithParam { auto* el = value->Index(i); ASSERT_NE(el, nullptr); EXPECT_TYPE(el->Type(), a->ElemType()); - EXPECT_EQ(std::get(el->Value()), expected_value); + EXPECT_EQ(el->ValueAs(), expected_value); } }, - [&](Default) { EXPECT_EQ(std::get(value->Value()), expected_value); }); + [&](Default) { EXPECT_EQ(value->ValueAs(), expected_value); }); } }; diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 6bc8383bdb..f0cbada3a1 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -1278,7 +1278,7 @@ sem::CaseStatement* Resolver::CaseStatement(const ast::CaseStatement* stmt, cons ExprEvalStageConstraint constraint{sem::EvaluationStage::kConstant, "case selector"}; TINT_SCOPED_ASSIGNMENT(expr_eval_stage_constraint_, constraint); - const constant::Constant* const_value = nullptr; + const constant::Value* const_value = nullptr; if (!sel->IsDefault()) { // The sem statement was created in the switch when attempting to determine the // common type. @@ -1797,7 +1797,7 @@ const sem::Expression* Resolver::Materialize(const sem::Expression* expr, return nullptr; } - const constant::Constant* materialized_val = nullptr; + const constant::Value* materialized_val = nullptr; if (!skip_const_eval_.Contains(decl)) { auto expr_val = expr->ConstantValue(); if (!expr_val) { @@ -1849,7 +1849,7 @@ bool Resolver::ShouldMaterializeArgument(const type::Type* parameter_ty) const { return param_el_ty && !param_el_ty->Is(); } -bool Resolver::Convert(const constant::Constant*& c, +bool Resolver::Convert(const constant::Value*& c, const type::Type* target_ty, const Source& source) { auto r = const_eval_.Convert(target_ty, c, source); @@ -1861,7 +1861,7 @@ bool Resolver::Convert(const constant::Constant*& c, } template -utils::Result> Resolver::ConvertArguments( +utils::Result> Resolver::ConvertArguments( const utils::Vector& args, const sem::CallTarget* target) { auto const_args = utils::Transform(args, [](auto* arg) { return arg->ConstantValue(); }); @@ -1919,7 +1919,7 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp ty = builder_->create(ty, ref->AddressSpace(), ref->Access()); } - const constant::Constant* val = nullptr; + const constant::Value* val = nullptr; auto stage = sem::EarliestStage(obj->Stage(), idx->Stage()); if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) { stage = sem::EvaluationStage::kNotEvaluated; @@ -1950,7 +1950,7 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) { RegisterLoadIfNeeded(inner); - const constant::Constant* val = nullptr; + const constant::Value* val = nullptr; // TODO(crbug.com/tint/1582): short circuit 'expr' once const eval of Bitcast is implemented. if (auto r = const_eval_.Bitcast(ty, inner)) { val = r.Get(); @@ -2012,7 +2012,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { return nullptr; } - const constant::Constant* value = nullptr; + const constant::Value* value = nullptr; auto stage = sem::EarliestStage(ctor_or_conv.target->Stage(), args_stage); if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) { stage = sem::EvaluationStage::kNotEvaluated; @@ -2042,7 +2042,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { } auto stage = args_stage; // The evaluation stage of the call - const constant::Constant* value = nullptr; // The constant value for the call + const constant::Value* value = nullptr; // The constant value for the call if (stage == sem::EvaluationStage::kConstant) { if (auto r = const_eval_.ArrayOrStructInit(ty, args)) { value = r.Get(); @@ -2336,7 +2336,7 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, // If the builtin is @const, and all arguments have constant values, evaluate the builtin // now. - const constant::Constant* value = nullptr; + const constant::Value* value = nullptr; auto stage = sem::EarliestStage(arg_stage, builtin.sem->Stage()); if (stage == sem::EvaluationStage::kConstant && skip_const_eval_.Contains(expr)) { stage = sem::EvaluationStage::kNotEvaluated; @@ -2607,7 +2607,7 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) { return nullptr; } - const constant::Constant* val = nullptr; + const constant::Value* val = nullptr; if (auto r = const_eval_.Literal(ty, literal)) { val = r.Get(); } else { @@ -2875,7 +2875,7 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { RegisterLoadIfNeeded(lhs); RegisterLoadIfNeeded(rhs); - const constant::Constant* value = nullptr; + const constant::Value* value = nullptr; if (stage == sem::EvaluationStage::kConstant) { if (op.const_eval_fn) { if (skip_const_eval_.Contains(expr)) { @@ -2920,7 +2920,7 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) { const type::Type* ty = nullptr; const sem::Variable* root_ident = nullptr; - const constant::Constant* value = nullptr; + const constant::Value* value = nullptr; auto stage = sem::EvaluationStage::kRuntime; switch (unary->op) { diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 186c653443..8b16abf862 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -23,7 +23,7 @@ #include #include -#include "src/tint/constant/constant.h" +#include "src/tint/constant/value.h" #include "src/tint/program_builder.h" #include "src/tint/resolver/const_eval.h" #include "src/tint/resolver/dependency_graph.h" @@ -197,13 +197,13 @@ class Resolver { /// Converts `c` to `target_ty` /// @returns true on success, false on failure. - bool Convert(const constant::Constant*& c, const type::Type* target_ty, const Source& source); + bool Convert(const constant::Value*& c, const type::Type* target_ty, const Source& source); /// Transforms `args` to a vector of constants, and converts each constant to the call target's /// parameter type. /// @returns the vector of constants, `utils::Failure` on failure. template - utils::Result> ConvertArguments( + utils::Result> ConvertArguments( const utils::Vector& args, const sem::CallTarget* target); diff --git a/src/tint/sem/call.cc b/src/tint/sem/call.cc index 0ed2a4adba..7a28e9f08c 100644 --- a/src/tint/sem/call.cc +++ b/src/tint/sem/call.cc @@ -26,7 +26,7 @@ Call::Call(const ast::CallExpression* declaration, EvaluationStage stage, utils::VectorRef arguments, const Statement* statement, - const constant::Constant* constant, + const constant::Value* constant, bool has_side_effects) : Base(declaration, target->ReturnType(), stage, statement, constant, has_side_effects), target_(target), diff --git a/src/tint/sem/call.h b/src/tint/sem/call.h index 152ebb94c8..08737590f4 100644 --- a/src/tint/sem/call.h +++ b/src/tint/sem/call.h @@ -41,7 +41,7 @@ class Call final : public Castable { EvaluationStage stage, utils::VectorRef arguments, const Statement* statement, - const constant::Constant* constant, + const constant::Value* constant, bool has_side_effects); /// Destructor diff --git a/src/tint/sem/expression.cc b/src/tint/sem/expression.cc index 7c5911105e..9d231be728 100644 --- a/src/tint/sem/expression.cc +++ b/src/tint/sem/expression.cc @@ -26,7 +26,7 @@ Expression::Expression(const ast::Expression* declaration, const type::Type* type, EvaluationStage stage, const Statement* statement, - const constant::Constant* constant, + const constant::Value* constant, bool has_side_effects, const Variable* root_ident /* = nullptr */) : declaration_(declaration), diff --git a/src/tint/sem/expression.h b/src/tint/sem/expression.h index 7127a1e032..e39e84483c 100644 --- a/src/tint/sem/expression.h +++ b/src/tint/sem/expression.h @@ -16,7 +16,7 @@ #define SRC_TINT_SEM_EXPRESSION_H_ #include "src/tint/ast/expression.h" -#include "src/tint/constant/constant.h" +#include "src/tint/constant/value.h" #include "src/tint/sem/behavior.h" #include "src/tint/sem/evaluation_stage.h" #include "src/tint/sem/node.h" @@ -44,7 +44,7 @@ class Expression : public Castable { const type::Type* type, EvaluationStage stage, const Statement* statement, - const constant::Constant* constant, + const constant::Value* constant, bool has_side_effects, const Variable* root_ident = nullptr); @@ -64,7 +64,7 @@ class Expression : public Castable { const Statement* Stmt() const { return statement_; } /// @return the constant value of this expression - const constant::Constant* ConstantValue() const { return constant_; } + const constant::Value* ConstantValue() const { return constant_; } /// Returns the variable or parameter that this expression derives from. /// For reference and pointer expressions, this will either be the originating @@ -95,7 +95,7 @@ class Expression : public Castable { const type::Type* const type_; const EvaluationStage stage_; const Statement* const statement_; - const constant::Constant* const constant_; + const constant::Value* const constant_; sem::Behaviors behaviors_{sem::Behavior::kNext}; const bool has_side_effects_; }; diff --git a/src/tint/sem/expression_test.cc b/src/tint/sem/expression_test.cc index 1913f56823..891400c353 100644 --- a/src/tint/sem/expression_test.cc +++ b/src/tint/sem/expression_test.cc @@ -23,18 +23,20 @@ using namespace tint::number_suffixes; // NOLINT namespace tint::sem { namespace { -class MockConstant : public constant::Constant { +class MockConstant : public constant::Value { public: explicit MockConstant(const type::Type* ty) : type(ty) {} ~MockConstant() override {} const type::Type* Type() const override { return type; } - std::variant Value() const override { return {}; } - const constant::Constant* Index(size_t) const override { return {}; } + const constant::Value* Index(size_t) const override { return {}; } bool AllZero() const override { return {}; } bool AnyZero() const override { return {}; } bool AllEqual() const override { return {}; } size_t Hash() const override { return 0; } + protected: + std::variant InternalValue() const override { return {}; } + private: const type::Type* type; }; diff --git a/src/tint/sem/index_accessor_expression.cc b/src/tint/sem/index_accessor_expression.cc index f8a3fdb600..ed5a4684d8 100644 --- a/src/tint/sem/index_accessor_expression.cc +++ b/src/tint/sem/index_accessor_expression.cc @@ -28,7 +28,7 @@ IndexAccessorExpression::IndexAccessorExpression(const ast::IndexAccessorExpress const Expression* object, const Expression* index, const Statement* statement, - const constant::Constant* constant, + const constant::Value* constant, bool has_side_effects, const Variable* root_ident /* = nullptr */) : Base(declaration, type, stage, statement, constant, has_side_effects, root_ident), diff --git a/src/tint/sem/index_accessor_expression.h b/src/tint/sem/index_accessor_expression.h index 0e7586d899..8327b79cde 100644 --- a/src/tint/sem/index_accessor_expression.h +++ b/src/tint/sem/index_accessor_expression.h @@ -45,7 +45,7 @@ class IndexAccessorExpression final : public CastableDeclaration(), /* type */ type, /* stage */ constant ? EvaluationStage::kConstant : EvaluationStage::kNotEvaluated, diff --git a/src/tint/sem/materialize.h b/src/tint/sem/materialize.h index 8a65a03711..532b3b2133 100644 --- a/src/tint/sem/materialize.h +++ b/src/tint/sem/materialize.h @@ -35,7 +35,7 @@ class Materialize final : public Castable { Materialize(const Expression* expr, const Statement* statement, const type::Type* type, - const constant::Constant* constant); + const constant::Value* constant); /// Destructor ~Materialize() override; diff --git a/src/tint/sem/member_accessor_expression.cc b/src/tint/sem/member_accessor_expression.cc index be77b3f009..06cfb76c10 100644 --- a/src/tint/sem/member_accessor_expression.cc +++ b/src/tint/sem/member_accessor_expression.cc @@ -27,7 +27,7 @@ MemberAccessorExpression::MemberAccessorExpression(const ast::MemberAccessorExpr const type::Type* type, EvaluationStage stage, const Statement* statement, - const constant::Constant* constant, + const constant::Value* constant, const Expression* object, bool has_side_effects, const Variable* root_ident /* = nullptr */) @@ -39,7 +39,7 @@ MemberAccessorExpression::~MemberAccessorExpression() = default; StructMemberAccess::StructMemberAccess(const ast::MemberAccessorExpression* declaration, const type::Type* type, const Statement* statement, - const constant::Constant* constant, + const constant::Value* constant, const Expression* object, const StructMember* member, bool has_side_effects, @@ -59,7 +59,7 @@ StructMemberAccess::~StructMemberAccess() = default; Swizzle::Swizzle(const ast::MemberAccessorExpression* declaration, const type::Type* type, const Statement* statement, - const constant::Constant* constant, + const constant::Value* constant, const Expression* object, utils::VectorRef indices, bool has_side_effects, diff --git a/src/tint/sem/member_accessor_expression.h b/src/tint/sem/member_accessor_expression.h index aaad5caa1c..1951afcb25 100644 --- a/src/tint/sem/member_accessor_expression.h +++ b/src/tint/sem/member_accessor_expression.h @@ -52,7 +52,7 @@ class MemberAccessorExpression : public Castable { Swizzle(const ast::MemberAccessorExpression* declaration, const type::Type* type, const Statement* statement, - const constant::Constant* constant, + const constant::Value* constant, const Expression* object, utils::VectorRef indices, bool has_side_effects, diff --git a/src/tint/sem/switch_statement.cc b/src/tint/sem/switch_statement.cc index 4cc857db58..874c3904ad 100644 --- a/src/tint/sem/switch_statement.cc +++ b/src/tint/sem/switch_statement.cc @@ -49,7 +49,7 @@ const ast::CaseStatement* CaseStatement::Declaration() const { return static_cast(Base::Declaration()); } -CaseSelector::CaseSelector(const ast::CaseSelector* decl, const constant::Constant* val) +CaseSelector::CaseSelector(const ast::CaseSelector* decl, const constant::Value* val) : Base(), decl_(decl), val_(val) {} CaseSelector::~CaseSelector() = default; diff --git a/src/tint/sem/switch_statement.h b/src/tint/sem/switch_statement.h index 4d182107ca..2476906311 100644 --- a/src/tint/sem/switch_statement.h +++ b/src/tint/sem/switch_statement.h @@ -26,7 +26,7 @@ class CaseSelector; class SwitchStatement; } // namespace tint::ast namespace tint::constant { -class Constant; +class Value; } // namespace tint::constant namespace tint::sem { class CaseStatement; @@ -103,7 +103,7 @@ class CaseSelector final : public Castable { /// Constructor /// @param decl the selector declaration /// @param val the case selector value, nullptr for a default selector - explicit CaseSelector(const ast::CaseSelector* decl, const constant::Constant* val = nullptr); + explicit CaseSelector(const ast::CaseSelector* decl, const constant::Value* val = nullptr); /// Destructor ~CaseSelector() override; @@ -115,11 +115,11 @@ class CaseSelector final : public Castable { const ast::CaseSelector* Declaration() const; /// @returns the selector constant value, or nullptr if this is the default selector - const constant::Constant* Value() const { return val_; } + const constant::Value* Value() const { return val_; } private: const ast::CaseSelector* const decl_; - const constant::Constant* const val_; + const constant::Value* const val_; }; } // namespace tint::sem diff --git a/src/tint/sem/variable.cc b/src/tint/sem/variable.cc index 41bff70d59..db5edfe52e 100644 --- a/src/tint/sem/variable.cc +++ b/src/tint/sem/variable.cc @@ -33,7 +33,7 @@ Variable::Variable(const ast::Variable* declaration, EvaluationStage stage, ast::AddressSpace address_space, ast::Access access, - const constant::Constant* constant_value) + const constant::Value* constant_value) : declaration_(declaration), type_(type), stage_(stage), @@ -49,7 +49,7 @@ LocalVariable::LocalVariable(const ast::Variable* declaration, ast::AddressSpace address_space, ast::Access access, const sem::Statement* statement, - const constant::Constant* constant_value) + const constant::Value* constant_value) : Base(declaration, type, stage, address_space, access, constant_value), statement_(statement) {} @@ -60,7 +60,7 @@ GlobalVariable::GlobalVariable(const ast::Variable* declaration, EvaluationStage stage, ast::AddressSpace address_space, ast::Access access, - const constant::Constant* constant_value, + const constant::Value* constant_value, sem::BindingPoint binding_point, std::optional location) : Base(declaration, type, stage, address_space, access, constant_value), diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h index 0ff1c2c1a2..bf2a2d5acc 100644 --- a/src/tint/sem/variable.h +++ b/src/tint/sem/variable.h @@ -59,7 +59,7 @@ class Variable : public Castable { EvaluationStage stage, ast::AddressSpace address_space, ast::Access access, - const constant::Constant* constant_value); + const constant::Value* constant_value); /// Destructor ~Variable() override; @@ -80,7 +80,7 @@ class Variable : public Castable { ast::Access Access() const { return access_; } /// @return the constant value of this expression - const constant::Constant* ConstantValue() const { return constant_value_; } + const constant::Value* ConstantValue() const { return constant_value_; } /// @returns the variable initializer expression, or nullptr if the variable /// does not have one. @@ -102,7 +102,7 @@ class Variable : public Castable { const EvaluationStage stage_; const ast::AddressSpace address_space_; const ast::Access access_; - const constant::Constant* constant_value_; + const constant::Value* constant_value_; const Expression* initializer_ = nullptr; std::vector users_; }; @@ -124,7 +124,7 @@ class LocalVariable final : public Castable { ast::AddressSpace address_space, ast::Access access, const sem::Statement* statement, - const constant::Constant* constant_value); + const constant::Value* constant_value); /// Destructor ~LocalVariable() override; @@ -164,7 +164,7 @@ class GlobalVariable final : public Castable { EvaluationStage stage, ast::AddressSpace address_space, ast::Access access, - const constant::Constant* constant_value, + const constant::Value* constant_value, sem::BindingPoint binding_point = {}, std::optional location = std::nullopt); diff --git a/src/tint/tint.natvis b/src/tint/tint.natvis index fccb23af30..71fbd1d57b 100644 --- a/src/tint/tint.natvis +++ b/src/tint/tint.natvis @@ -256,7 +256,7 @@ vec{width_}<{*subtype_}> - + Type={*Type()} Value={Value()} diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc index 221d9bc364..2da5c8c918 100644 --- a/src/tint/writer/glsl/generator_impl.cc +++ b/src/tint/writer/glsl/generator_impl.cc @@ -26,7 +26,7 @@ #include "src/tint/ast/internal_attribute.h" #include "src/tint/ast/interpolate_attribute.h" #include "src/tint/ast/variable_decl_statement.h" -#include "src/tint/constant/constant.h" +#include "src/tint/constant/value.h" #include "src/tint/debug.h" #include "src/tint/sem/block_statement.h" #include "src/tint/sem/call.h" @@ -2294,7 +2294,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { return true; } -bool GeneratorImpl::EmitConstant(std::ostream& out, const constant::Constant* constant) { +bool GeneratorImpl::EmitConstant(std::ostream& out, const constant::Value* constant) { return Switch( constant->Type(), // [&](const type::Bool*) { diff --git a/src/tint/writer/glsl/generator_impl.h b/src/tint/writer/glsl/generator_impl.h index 2d5d015e0f..35043e9aba 100644 --- a/src/tint/writer/glsl/generator_impl.h +++ b/src/tint/writer/glsl/generator_impl.h @@ -43,7 +43,6 @@ // Forward declarations namespace tint::sem { class Call; -class Constant; class Builtin; class TypeInitializer; class TypeConversion; @@ -363,7 +362,7 @@ class GeneratorImpl : public TextGenerator { /// @param out the output stream /// @param constant the constant value to emit /// @returns true if the constant value was successfully emitted - bool EmitConstant(std::ostream& out, const constant::Constant* constant); + bool EmitConstant(std::ostream& out, const constant::Value* constant); /// Handles a literal /// @param out the output stream /// @param lit the literal to emit diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc index d21865b168..ece37231b4 100644 --- a/src/tint/writer/hlsl/generator_impl.cc +++ b/src/tint/writer/hlsl/generator_impl.cc @@ -27,7 +27,7 @@ #include "src/tint/ast/internal_attribute.h" #include "src/tint/ast/interpolate_attribute.h" #include "src/tint/ast/variable_decl_statement.h" -#include "src/tint/constant/constant.h" +#include "src/tint/constant/value.h" #include "src/tint/debug.h" #include "src/tint/sem/block_statement.h" #include "src/tint/sem/call.h" @@ -1117,7 +1117,7 @@ bool GeneratorImpl::EmitUniformBufferAccess( if (auto* val = offset_arg->ConstantValue()) { TINT_ASSERT(Writer, val->Type()->Is()); - scalar_offset_bytes = static_cast(std::get(val->Value())); + scalar_offset_bytes = static_cast(val->ValueAs()); scalar_offset_index = scalar_offset_bytes / 4; // bytes -> scalar index scalar_offset_constant = true; } @@ -3263,7 +3263,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { } bool GeneratorImpl::EmitConstant(std::ostream& out, - const constant::Constant* constant, + const constant::Value* constant, bool is_variable_initializer) { return Switch( constant->Type(), // diff --git a/src/tint/writer/hlsl/generator_impl.h b/src/tint/writer/hlsl/generator_impl.h index caf3816ef2..1558fa5971 100644 --- a/src/tint/writer/hlsl/generator_impl.h +++ b/src/tint/writer/hlsl/generator_impl.h @@ -43,7 +43,6 @@ // Forward declarations namespace tint::sem { class Call; -class Constant; class Builtin; class TypeInitializer; class TypeConversion; @@ -352,7 +351,7 @@ class GeneratorImpl : public TextGenerator { /// initializer /// @returns true if the constant value was successfully emitted bool EmitConstant(std::ostream& out, - const constant::Constant* constant, + const constant::Value* constant, bool is_variable_initializer); /// Handles a literal /// @param out the output stream diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index 15e190b937..c3fbe4efda 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -31,7 +31,7 @@ #include "src/tint/ast/module.h" #include "src/tint/ast/variable_decl_statement.h" #include "src/tint/ast/void.h" -#include "src/tint/constant/constant.h" +#include "src/tint/constant/value.h" #include "src/tint/sem/call.h" #include "src/tint/sem/function.h" #include "src/tint/sem/member_accessor_expression.h" @@ -1658,7 +1658,7 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const type::Type* type) { }); } -bool GeneratorImpl::EmitConstant(std::ostream& out, const constant::Constant* constant) { +bool GeneratorImpl::EmitConstant(std::ostream& out, const constant::Value* constant) { return Switch( constant->Type(), // [&](const type::Bool*) { diff --git a/src/tint/writer/msl/generator_impl.h b/src/tint/writer/msl/generator_impl.h index 53fff205ca..64c83afab2 100644 --- a/src/tint/writer/msl/generator_impl.h +++ b/src/tint/writer/msl/generator_impl.h @@ -46,7 +46,6 @@ // Forward declarations namespace tint::sem { class Call; -class Constant; class Builtin; class TypeInitializer; class TypeConversion; @@ -260,7 +259,7 @@ class GeneratorImpl : public TextGenerator { /// @param out the output stream /// @param constant the constant value to emit /// @returns true if the constant value was successfully emitted - bool EmitConstant(std::ostream& out, const constant::Constant* constant); + bool EmitConstant(std::ostream& out, const constant::Value* constant); /// Handles a literal /// @param out the output of the expression stream /// @param lit the literal to emit diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc index 79b0c92f31..dab6869f45 100644 --- a/src/tint/writer/spirv/builder.cc +++ b/src/tint/writer/spirv/builder.cc @@ -22,7 +22,7 @@ #include "src/tint/ast/id_attribute.h" #include "src/tint/ast/internal_attribute.h" #include "src/tint/ast/traverse_expressions.h" -#include "src/tint/constant/constant.h" +#include "src/tint/constant/value.h" #include "src/tint/sem/builtin.h" #include "src/tint/sem/call.h" #include "src/tint/sem/function.h" @@ -1641,7 +1641,7 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::LiteralExpression* lit) { return GenerateConstantIfNeeded(constant); } -uint32_t Builder::GenerateConstantIfNeeded(const constant::Constant* constant) { +uint32_t Builder::GenerateConstantIfNeeded(const constant::Value* constant) { if (constant->AllZero()) { return GenerateConstantNullIfNeeded(constant->Type()); } diff --git a/src/tint/writer/spirv/builder.h b/src/tint/writer/spirv/builder.h index f3af99a485..68c8f25deb 100644 --- a/src/tint/writer/spirv/builder.h +++ b/src/tint/writer/spirv/builder.h @@ -43,7 +43,6 @@ // Forward declarations namespace tint::sem { class Call; -class Constant; class TypeInitializer; class TypeConversion; } // namespace tint::sem @@ -559,7 +558,7 @@ class Builder { /// Generates a constant value if needed /// @param constant the constant to generate. /// @returns the ID on success or 0 on failure - uint32_t GenerateConstantIfNeeded(const constant::Constant* constant); + uint32_t GenerateConstantIfNeeded(const constant::Value* constant); /// Generates a scalar constant if needed /// @param constant the constant to generate.