tint/resolver: Add ConstEval class

Extract out the methods of Resolver::EvaluateXXXValue() to a new
tint::resolver::ConstEval class.

Removes more bloat from Resolver, and creates a centralized class for
constant evaluation, which can be referred to by the IntrinsicTable.

Change-Id: I3b58882ef293fe07f019ad2138a7e9dbbac8de53
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95951
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2022-07-15 14:14:09 +00:00 committed by Dawn LUCI CQ
parent cfe07a1b33
commit 65c5c9d92b
11 changed files with 731 additions and 694 deletions

View File

@ -398,7 +398,6 @@ libtint_source_set("libtint_core_all_src") {
"resolver/intrinsic_table.inl",
"resolver/resolver.cc",
"resolver/resolver.h",
"resolver/resolver_constants.cc",
"resolver/sem_helper.cc",
"resolver/sem_helper.h",
"resolver/uniformity.cc",
@ -1090,6 +1089,7 @@ if (tint_build_unittests) {
"resolver/call_validation_test.cc",
"resolver/compound_assignment_validation_test.cc",
"resolver/compound_statement_test.cc",
"resolver/const_eval_test.cc",
"resolver/control_block_validation_test.cc",
"resolver/dependency_graph_test.cc",
"resolver/entry_point_validation_test.cc",
@ -1104,7 +1104,6 @@ if (tint_build_unittests) {
"resolver/ptr_ref_test.cc",
"resolver/ptr_ref_validation_test.cc",
"resolver/resolver_behavior_test.cc",
"resolver/resolver_constants_test.cc",
"resolver/resolver_test.cc",
"resolver/resolver_test_helper.cc",
"resolver/resolver_test_helper.h",

View File

@ -258,7 +258,6 @@ set(TINT_LIB_SRCS
resolver/intrinsic_table.cc
resolver/intrinsic_table.h
resolver/intrinsic_table.inl
resolver/resolver_constants.cc
resolver/resolver.cc
resolver/resolver.h
resolver/sem_helper.cc
@ -774,6 +773,7 @@ if(TINT_BUILD_TESTS)
resolver/call_validation_test.cc
resolver/compound_assignment_validation_test.cc
resolver/compound_statement_test.cc
resolver/const_eval_test.cc
resolver/control_block_validation_test.cc
resolver/dependency_graph_test.cc
resolver/entry_point_validation_test.cc
@ -789,7 +789,6 @@ if(TINT_BUILD_TESTS)
resolver/ptr_ref_test.cc
resolver/ptr_ref_validation_test.cc
resolver/resolver_behavior_test.cc
resolver/resolver_constants_test.cc
resolver/resolver_test_helper.cc
resolver/resolver_test_helper.h
resolver/resolver_test.cc

View File

@ -14,6 +14,603 @@
#include "src/tint/resolver/const_eval.h"
#include "src/tint/sem/constant.h"
#include <algorithm>
#include <limits>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
namespace tint::resolver::const_eval {} // namespace tint::resolver::const_eval
#include "src/tint/program_builder.h"
#include "src/tint/sem/abstract_float.h"
#include "src/tint/sem/abstract_int.h"
#include "src/tint/sem/array.h"
#include "src/tint/sem/bool.h"
#include "src/tint/sem/constant.h"
#include "src/tint/sem/f16.h"
#include "src/tint/sem/f32.h"
#include "src/tint/sem/i32.h"
#include "src/tint/sem/matrix.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/u32.h"
#include "src/tint/sem/vector.h"
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/transform.h"
using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver {
namespace {
/// TypeDispatch is a helper for calling the function `f`, passing a single zero-value argument of
/// the C++ type that corresponds to the sem::Type `type`. For example, calling `TypeDispatch()`
/// with a type of `sem::I32*` will call the function f with a single argument of `i32(0)`.
/// @returns the value returned by calling `f`.
/// @note `type` must be a scalar or abstract numeric type. Other types will not call `f`, and will
/// return the zero-initialized value of the return type for `f`.
template <typename F>
auto TypeDispatch(const sem::Type* type, F&& f) {
return Switch(
type, //
[&](const sem::AbstractInt*) { return f(AInt(0)); }, //
[&](const sem::AbstractFloat*) { return f(AFloat(0)); }, //
[&](const sem::I32*) { return f(i32(0)); }, //
[&](const sem::U32*) { return f(u32(0)); }, //
[&](const sem::F32*) { return f(f32(0)); }, //
[&](const sem::F16*) { return f(f16(0)); }, //
[&](const sem::Bool*) { return f(static_cast<bool>(0)); });
}
/// @returns `value` if `T` is not a Number, otherwise ValueOf returns the inner value of the
/// Number.
template <typename T>
inline auto ValueOf(T value) {
if constexpr (std::is_same_v<UnwrapNumber<T>, T>) {
return value;
} else {
return value.value;
}
}
/// @returns true if `value` is a positive zero.
template <typename T>
inline bool IsPositiveZero(T value) {
using N = UnwrapNumber<T>;
return Number<N>(value) == Number<N>(0); // Considers sign bit
}
/// Constant inherits from sem::Constant to add an private implementation method for conversion.
struct Constant : public sem::Constant {
/// Convert attempts to convert the constant value to the given type. On error, Convert()
/// creates a new diagnostic message and returns a Failure.
virtual utils::Result<const Constant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty,
const Source& source) const = 0;
};
// Forward declaration
const Constant* CreateComposite(ProgramBuilder& builder,
const sem::Type* type,
std::vector<const Constant*> elements);
/// Element holds a single scalar or abstract-numeric value.
/// Element implements the Constant interface.
template <typename T>
struct Element : Constant {
Element(const sem::Type* t, T v) : type(t), value(v) {}
~Element() override = default;
const sem::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override {
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
return static_cast<AFloat>(value);
} else {
return static_cast<AInt>(value);
}
}
const Constant* Index(size_t) const override { return nullptr; }
bool AllZero() const override { return IsPositiveZero(value); }
bool AnyZero() const override { return IsPositiveZero(value); }
bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, ValueOf(value)); }
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty,
const Source& source) const override {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
if (target_ty == type) {
// If the types are identical, then no conversion is needed.
return this;
}
bool failed = false;
auto* res = TypeDispatch(target_ty, [&](auto zero_to) -> const Constant* {
// `T` is the source type, `value` is the source value.
// `TO` is the target type.
using TO = std::decay_t<decltype(zero_to)>;
if constexpr (std::is_same_v<TO, bool>) {
// [x -> bool]
return builder.create<Element<TO>>(target_ty, !IsPositiveZero(value));
} else if constexpr (std::is_same_v<T, bool>) {
// [bool -> x]
return builder.create<Element<TO>>(target_ty, TO(value ? 1 : 0));
} else if (auto conv = CheckedConvert<TO>(value)) {
// Conversion success
return builder.create<Element<TO>>(target_ty, conv.Get());
// --- Below this point are the failure cases ---
} else if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) {
// [abstract-numeric -> x] - materialization failure
std::stringstream ss;
ss << "value " << value << " cannot be represented as ";
ss << "'" << builder.FriendlyName(target_ty) << "'";
builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source);
failed = true;
} else if constexpr (IsFloatingPoint<UnwrapNumber<TO>>) {
// [x -> floating-point] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
constexpr auto kInf = std::numeric_limits<double>::infinity();
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
return builder.create<Element<TO>>(target_ty, TO(-kInf));
case ConversionFailure::kExceedsPositiveLimit:
return builder.create<Element<TO>>(target_ty, TO(kInf));
}
} else {
// [x -> integer] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
return builder.create<Element<TO>>(target_ty, TO(TO::kLowest));
case ConversionFailure::kExceedsPositiveLimit:
return builder.create<Element<TO>>(target_ty, TO(TO::kHighest));
}
}
return nullptr; // Expression is not constant.
});
if (failed) {
// A diagnostic error has been raised, and resolving should abort.
return utils::Failure;
}
return res;
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
sem::Type const* const type;
const T value;
};
/// Splat holds a single Constant value, duplicated as all children.
/// Splat is used for zero-initializers, 'splat' constructors, or constructors where each element is
/// identical. Splat may be of a vector, matrix or array type.
/// Splat implements the Constant interface.
struct Splat : Constant {
Splat(const sem::Type* t, const Constant* e, size_t n) : type(t), el(e), count(n) {}
~Splat() override = default;
const sem::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
const Constant* Index(size_t i) const override { return i < count ? el : nullptr; }
bool AllZero() const override { return el->AllZero(); }
bool AnyZero() const override { return el->AnyZero(); }
bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, el->Hash(), count); }
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty,
const Source& source) const override {
// Convert the single splatted element type.
auto conv_el = el->Convert(builder, sem::Type::ElementOf(target_ty), source);
if (!conv_el) {
return utils::Failure;
}
if (!conv_el.Get()) {
return nullptr;
}
return builder.create<Splat>(target_ty, conv_el.Get(), count);
}
sem::Type const* const type;
const Constant* el;
const size_t count;
};
/// Composite holds a number of mixed child Constant values.
/// Composite may be of a vector, matrix or array type.
/// If each element is the same type and value, then a Splat would be a more efficient constant
/// implementation. Use CreateComposite() to create the appropriate Constant type.
/// Composite implements the Constant interface.
struct Composite : Constant {
Composite(const sem::Type* t, std::vector<const Constant*> els, bool all_0, bool any_0)
: type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {}
~Composite() override = default;
const sem::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
const Constant* Index(size_t i) const override {
return i < elements.size() ? elements[i] : nullptr;
}
bool AllZero() const override { return all_zero; }
bool AnyZero() const override { return any_zero; }
bool AllEqual() const override { return false; /* otherwise this should be a Splat */ }
size_t Hash() const override { return hash; }
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty,
const Source& source) const override {
// Convert each of the composite element types.
auto* el_ty = sem::Type::ElementOf(target_ty);
std::vector<const Constant*> conv_els;
conv_els.reserve(elements.size());
for (auto* el : elements) {
auto conv_el = el->Convert(builder, el_ty, source);
if (!conv_el) {
return utils::Failure;
}
if (!conv_el.Get()) {
return nullptr;
}
conv_els.emplace_back(conv_el.Get());
}
return CreateComposite(builder, target_ty, std::move(conv_els));
}
size_t CalcHash() {
auto h = utils::Hash(type, all_zero, any_zero);
for (auto* el : elements) {
utils::HashCombine(&h, el->Hash());
}
return h;
}
sem::Type const* const type;
const std::vector<const Constant*> elements;
const bool all_zero;
const bool any_zero;
const size_t hash;
};
/// CreateElement constructs and returns an Element<T>.
template <typename T>
const Constant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) {
return builder.create<Element<T>>(t, v);
}
/// ZeroValue returns a Constant for the zero-value of the type `type`.
const Constant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
return Switch(
type, //
[&](const sem::Vector* v) -> const Constant* {
auto* zero_el = ZeroValue(builder, v->type());
return builder.create<Splat>(type, zero_el, v->Width());
},
[&](const sem::Matrix* m) -> const Constant* {
auto* zero_el = ZeroValue(builder, m->ColumnType());
return builder.create<Splat>(type, zero_el, m->columns());
},
[&](const sem::Array* a) -> const Constant* {
if (auto* zero_el = ZeroValue(builder, a->ElemType())) {
return builder.create<Splat>(type, zero_el, a->Count());
}
return nullptr;
},
[&](const sem::Struct* s) -> const Constant* {
std::unordered_map<sem::Type*, const Constant*> zero_by_type;
std::vector<const Constant*> zeros;
zeros.reserve(s->Members().size());
for (auto* member : s->Members()) {
auto* zero = utils::GetOrCreate(zero_by_type, member->Type(),
[&] { return ZeroValue(builder, member->Type()); });
if (!zero) {
return nullptr;
}
zeros.emplace_back(zero);
}
if (zero_by_type.size() == 1) {
// All members were of the same type, so the zero value is the same for all members.
return builder.create<Splat>(type, zeros[0], s->Members().size());
}
return CreateComposite(builder, s, std::move(zeros));
},
[&](Default) -> const Constant* {
return TypeDispatch(type, [&](auto zero) -> const Constant* {
return CreateElement(builder, type, zero);
});
});
}
/// Equal returns true if the constants `a` and `b` are of the same type and value.
bool Equal(const sem::Constant* a, const sem::Constant* b) {
if (a->Hash() != b->Hash()) {
return false;
}
if (a->Type() != b->Type()) {
return false;
}
return Switch(
a->Type(), //
[&](const sem::Vector* vec) {
for (size_t i = 0; i < vec->Width(); i++) {
if (!Equal(a->Index(i), b->Index(i))) {
return false;
}
}
return true;
},
[&](const sem::Matrix* mat) {
for (size_t i = 0; i < mat->columns(); i++) {
if (!Equal(a->Index(i), b->Index(i))) {
return false;
}
}
return true;
},
[&](const sem::Array* arr) {
for (size_t i = 0; i < arr->Count(); i++) {
if (!Equal(a->Index(i), b->Index(i))) {
return false;
}
}
return true;
},
[&](Default) { return a->Value() == b->Value(); });
}
/// CreateComposite is used to construct a constant of a vector, matrix or array type.
/// CreateComposite examines the element values and will return either a Composite or a Splat,
/// depending on the element types and values.
const Constant* CreateComposite(ProgramBuilder& builder,
const sem::Type* type,
std::vector<const Constant*> elements) {
if (elements.size() == 0) {
return nullptr;
}
bool any_zero = false;
bool all_zero = true;
bool all_equal = true;
auto* first = elements.front();
for (auto* el : elements) {
if (!el) {
return nullptr;
}
if (!any_zero && el->AnyZero()) {
any_zero = true;
}
if (all_zero && !el->AllZero()) {
all_zero = false;
}
if (all_equal && el != first) {
if (!Equal(el, first)) {
all_equal = false;
}
}
}
if (all_equal) {
return builder.create<Splat>(type, elements[0], elements.size());
} else {
return builder.create<Composite>(type, std::move(elements), all_zero, any_zero);
}
}
} // namespace
ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {}
const sem::Constant* ConstEval::Literal(const sem::Type* ty,
const ast::LiteralExpression* literal) {
return Switch(
literal,
[&](const ast::BoolLiteralExpression* lit) {
return CreateElement(builder, ty, lit->value);
},
[&](const ast::IntLiteralExpression* lit) -> const Constant* {
switch (lit->suffix) {
case ast::IntLiteralExpression::Suffix::kNone:
return CreateElement(builder, ty, AInt(lit->value));
case ast::IntLiteralExpression::Suffix::kI:
return CreateElement(builder, ty, i32(lit->value));
case ast::IntLiteralExpression::Suffix::kU:
return CreateElement(builder, ty, u32(lit->value));
}
return nullptr;
},
[&](const ast::FloatLiteralExpression* lit) -> const Constant* {
switch (lit->suffix) {
case ast::FloatLiteralExpression::Suffix::kNone:
return CreateElement(builder, ty, AFloat(lit->value));
case ast::FloatLiteralExpression::Suffix::kF:
return CreateElement(builder, ty, f32(lit->value));
case ast::FloatLiteralExpression::Suffix::kH:
return CreateElement(builder, ty, f16(lit->value));
}
return nullptr;
});
}
const sem::Constant* ConstEval::CtorOrConv(const sem::Type* ty,
const std::vector<const sem::Expression*>& args) {
// For zero value init, return 0s
if (args.empty()) {
return ZeroValue(builder, ty);
}
if (auto* el_ty = sem::Type::ElementOf(ty); el_ty && args.size() == 1) {
// Type constructor or conversion that takes a single argument.
auto& src = args[0]->Declaration()->source;
auto* arg = static_cast<const Constant*>(args[0]->ConstantValue());
if (!arg) {
return nullptr; // Single argument is not constant.
}
if (ty->is_scalar()) { // Scalar type conversion: i32(x), u32(x), bool(x), etc
return Convert(el_ty, arg, src).Get();
}
if (arg->Type() == el_ty) {
// Argument type matches function type. This is a splat.
auto splat = [&](size_t n) { return builder.create<Splat>(ty, arg, n); };
return Switch(
ty, //
[&](const sem::Vector* v) { return splat(v->Width()); },
[&](const sem::Matrix* m) { return splat(m->columns()); },
[&](const sem::Array* a) { return splat(a->Count()); });
}
// Argument type and function type mismatch. This is a type conversion.
if (auto conv = Convert(ty, arg, src)) {
return conv.Get();
}
return nullptr;
}
// Helper for pushing all the argument constants to `els`.
auto args_as_constants = [&] {
return utils::Transform(
args, [&](auto* expr) { return static_cast<const Constant*>(expr->ConstantValue()); });
};
// Multiple arguments. Must be a type constructor.
return Switch(
ty, // What's the target type being constructed?
[&](const sem::Vector*) -> const Constant* {
// Vector can be constructed with a mix of scalars / abstract numerics and smaller
// vectors.
std::vector<const Constant*> els;
els.reserve(args.size());
for (auto* expr : args) {
auto* arg = static_cast<const Constant*>(expr->ConstantValue());
if (!arg) {
return nullptr;
}
auto* arg_ty = arg->Type();
if (auto* arg_vec = arg_ty->As<sem::Vector>()) {
// Extract out vector elements.
for (uint32_t i = 0; i < arg_vec->Width(); i++) {
auto* el = static_cast<const Constant*>(arg->Index(i));
if (!el) {
return nullptr;
}
els.emplace_back(el);
}
} else {
els.emplace_back(arg);
}
}
return CreateComposite(builder, ty, std::move(els));
},
[&](const sem::Matrix* m) -> const Constant* {
// Matrix can be constructed with a set of scalars / abstract numerics, or column
// vectors.
if (args.size() == m->columns() * m->rows()) {
// Matrix built from scalars / abstract numerics
std::vector<const Constant*> els;
els.reserve(args.size());
for (uint32_t c = 0; c < m->columns(); c++) {
std::vector<const Constant*> column;
column.reserve(m->rows());
for (uint32_t r = 0; r < m->rows(); r++) {
auto* arg =
static_cast<const Constant*>(args[r + c * m->rows()]->ConstantValue());
if (!arg) {
return nullptr;
}
column.emplace_back(arg);
}
els.push_back(CreateComposite(builder, m->ColumnType(), std::move(column)));
}
return CreateComposite(builder, ty, std::move(els));
}
// Matrix built from column vectors
return CreateComposite(builder, ty, args_as_constants());
},
[&](const sem::Array*) {
// Arrays must be constructed using a list of elements
return CreateComposite(builder, ty, args_as_constants());
},
[&](const sem::Struct*) {
// Structures must be constructed using a list of elements
return CreateComposite(builder, ty, args_as_constants());
});
}
const sem::Constant* ConstEval::Index(const sem::Expression* obj_expr,
const sem::Expression* idx_expr) {
auto obj_val = obj_expr->ConstantValue();
if (!obj_val) {
return {};
}
auto idx_val = idx_expr->ConstantValue();
if (!idx_val) {
return {};
}
uint32_t el_count = 0;
sem::Type::ElementOf(obj_val->Type(), &el_count);
AInt idx = idx_val->As<AInt>();
if (idx < 0 || idx >= el_count) {
auto clamped = std::min<AInt::type>(std::max<AInt::type>(idx, 0), el_count - 1);
AddWarning("index " + std::to_string(idx) + " out of bounds [0.." +
std::to_string(el_count - 1) + "]. Clamping index to " +
std::to_string(clamped),
idx_expr->Declaration()->source);
idx = clamped;
}
return obj_val->Index(static_cast<size_t>(idx));
}
const sem::Constant* ConstEval::MemberAccess(const sem::Expression* obj_expr,
const sem::StructMember* member) {
auto obj_val = obj_expr->ConstantValue();
if (!obj_val) {
return {};
}
return obj_val->Index(static_cast<size_t>(member->Index()));
}
const sem::Constant* ConstEval::Swizzle(const sem::Type* ty,
const sem::Expression* vec_expr,
const std::vector<uint32_t>& indices) {
auto* vec_val = vec_expr->ConstantValue();
if (!vec_val) {
return nullptr;
}
if (indices.size() == 1) {
return static_cast<const Constant*>(vec_val->Index(static_cast<size_t>(indices[0])));
} else {
auto values = utils::Transform(indices, [&](uint32_t i) {
return static_cast<const Constant*>(vec_val->Index(static_cast<size_t>(i)));
});
return CreateComposite(builder, ty, std::move(values));
}
}
const sem::Constant* ConstEval::Bitcast(const sem::Type*, const sem::Expression*) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
}
utils::Result<const sem::Constant*> ConstEval::Convert(const sem::Type* target_ty,
const sem::Constant* value,
const Source& source) {
if (value->Type() == target_ty) {
return value;
}
auto conv = static_cast<const Constant*>(value)->Convert(builder, target_ty, source);
if (!conv) {
return utils::Failure;
}
return conv.Get();
}
void ConstEval::AddError(const std::string& msg, const Source& source) const {
builder.Diagnostics().add_error(diag::System::Resolver, msg, source);
}
void ConstEval::AddWarning(const std::string& msg, const Source& source) const {
builder.Diagnostics().add_warning(diag::System::Resolver, msg, source);
}
} // namespace tint::resolver

View File

@ -16,24 +16,111 @@
#define SRC_TINT_RESOLVER_CONST_EVAL_H_
#include <stddef.h>
#include <string>
#include <vector>
#include "src/tint/utils/result.h"
// Forward declarations
namespace tint {
class ProgramBuilder;
class Source;
} // namespace tint
// Forward declarations
namespace tint::ast {
class LiteralExpression;
} // namespace tint::ast
namespace tint::sem {
class Constant;
class Expression;
class StructMember;
class Type;
} // namespace tint::sem
namespace tint::resolver::const_eval {
namespace tint::resolver {
/// ConstEval performs shader creation-time (constant expression) expression evaluation.
/// Methods are called from the resolver, either directly or via member-function pointers indexed by
/// the IntrinsicTable. All child-expression nodes are guaranteed to have been already resolved
/// before calling a method to evaluate an expression's value.
class ConstEval {
public:
/// Typedef for a constant evaluation function
using Function = const sem::Constant*(ProgramBuilder& builder,
sem::Constant const* const* args,
using Function = const sem::Constant* (ConstEval::*)(const sem::Type* result_ty,
sem::Expression const* const* args,
size_t num_args);
} // namespace tint::resolver::const_eval
/// The result type of a method that may raise a diagnostic error and the caller should abort
/// resolving. Can be one of three distinct values:
/// * A non-null sem::Constant pointer. Returned when a expression resolves to a creation time
/// value.
/// * A null sem::Constant pointer. Returned when a expression cannot resolve to a creation time
/// value, but is otherwise legal.
/// * `utils::Failure`. Returned when there was a resolver error. In this situation the method
/// will have already reported a diagnostic error message, and the caller should abort
/// resolving.
using ConstantResult = utils::Result<const sem::Constant*>;
/// Constructor
/// @param b the program builder
explicit ConstEval(ProgramBuilder& b);
////////////////////////////////////////////////////////////////////////////////////////////////
// Constant value evaluation methods, to be called directly from Resolver
////////////////////////////////////////////////////////////////////////////////////////////////
/// @param ty the target type
/// @param expr the input expression
/// @return the bit-cast of the given expression to the given type, or null if the value cannot
/// be calculated
const sem::Constant* Bitcast(const sem::Type* ty, const sem::Expression* expr);
/// @param ty the target type
/// @param args the input arguments
/// @return the resulting type constructor or conversion, or null if the value cannot be
/// calculated
const sem::Constant* CtorOrConv(const sem::Type* ty,
const std::vector<const sem::Expression*>& args);
/// @param obj the object being indexed
/// @param idx the index expression
/// @return the result of the index, or null if the value cannot be calculated
const sem::Constant* Index(const sem::Expression* obj, const sem::Expression* idx);
/// @param ty the result type
/// @param lit the literal AST node
/// @return the constant value of the literal
const sem::Constant* Literal(const sem::Type* ty, const ast::LiteralExpression* lit);
/// @param obj the object being accessed
/// @param member the member
/// @return the result of the member access, or null if the value cannot be calculated
const sem::Constant* MemberAccess(const sem::Expression* obj, const sem::StructMember* member);
/// @param ty the result type
/// @param vector the vector being swizzled
/// @param indices the swizzle indices
/// @return the result of the swizzle, or null if the value cannot be calculated
const sem::Constant* Swizzle(const sem::Type* ty,
const sem::Expression* vector,
const std::vector<uint32_t>& indices);
/// Convert the `value` to `target_type`
/// @param ty the result type
/// @param value the value being converted
/// @param source the source location of the conversion
/// @return the converted value, or null if the value cannot be calculated
ConstantResult Convert(const sem::Type* ty, const sem::Constant* value, const Source& source);
private:
/// Adds the given error message to the diagnostics
void AddError(const std::string& msg, const Source& source) const;
/// Adds the given warning message to the diagnostics
void AddWarning(const std::string& msg, const Source& source) const;
ProgramBuilder& builder;
};
} // namespace tint::resolver
#endif // SRC_TINT_RESOLVER_CONST_EVAL_H_

View File

@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/resolver/resolver.h"
#include <cmath>
#include "gtest/gtest.h"

View File

@ -854,7 +854,7 @@ struct OverloadInfo {
/// The flags for the overload
OverloadFlags flags;
/// The function used to evaluate the overload at shader-creation time.
const_eval::Function* const const_eval_fn;
ConstEval::Function const const_eval_fn;
};
/// IntrinsicInfo describes a builtin function or operator overload

View File

@ -47,7 +47,7 @@ class IntrinsicTable {
/// The semantic info for the builtin
const sem::Builtin* sem = nullptr;
/// The constant evaluation function
const_eval::Function* const_eval_fn = nullptr;
ConstEval::Function const_eval_fn = nullptr;
};
/// UnaryOperator describes a resolved unary operator
@ -56,6 +56,8 @@ class IntrinsicTable {
const sem::Type* result = nullptr;
/// The type of the parameter of the unary operator
const sem::Type* parameter = nullptr;
/// The constant evaluation function
ConstEval::Function const_eval_fn = nullptr;
};
/// BinaryOperator describes a resolved binary operator
@ -66,6 +68,8 @@ class IntrinsicTable {
const sem::Type* lhs = nullptr;
/// The type of RHS parameter of the binary operator
const sem::Type* rhs = nullptr;
/// The constant evaluation function
ConstEval::Function const_eval_fn = nullptr;
};
/// Lookup looks for the builtin overload with the given signature, raising an error diagnostic

View File

@ -103,7 +103,7 @@ constexpr OverloadInfo kOverloads[] = {
{{- end }}
{{- if $o.IsDeprecated}}, OverloadFlag::kIsDeprecated{{end }}),
/* const eval */
{{- if $o.ConstEvalFunction }} const_eval::{{$o.ConstEvalFunction}},
{{- if $o.ConstEvalFunction }} &ConstEval::{{$o.ConstEvalFunction}},
{{- else }} nullptr,
{{- end }}
},

View File

@ -91,6 +91,7 @@ namespace tint::resolver {
Resolver::Resolver(ProgramBuilder* builder)
: builder_(builder),
diagnostics_(builder->Diagnostics()),
const_eval_(*builder),
intrinsic_table_(IntrinsicTable::Create(*builder)),
sem_(builder, dependencies_),
validator_(builder, sem_) {}
@ -1313,7 +1314,7 @@ const sem::Expression* Resolver::Materialize(const sem::Expression* expr,
<< ") called on expression with no constant value";
return nullptr;
}
auto materialized_val = ConvertValue(expr_val, target_ty, decl->source);
auto materialized_val = const_eval_.Convert(target_ty, expr_val, decl->source);
if (!materialized_val) {
// ConvertValue() has already failed and raised an diagnostic error.
return nullptr;
@ -1422,7 +1423,7 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp
ty = builder_->create<sem::Reference>(ty, ref->StorageClass(), ref->Access());
}
auto val = EvaluateIndexValue(obj, idx);
auto val = const_eval_.Index(obj, idx);
bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects();
auto* sem = builder_->create<sem::IndexAccessorExpression>(
expr, ty, obj, idx, current_statement_, std::move(val), has_side_effects,
@ -1441,7 +1442,7 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
return nullptr;
}
auto val = EvaluateBitcastValue(inner, ty);
auto val = const_eval_.Bitcast(ty, inner);
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, std::move(val),
inner->HasSideEffects());
@ -1489,7 +1490,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) {
return nullptr;
}
auto val = EvaluateCtorOrConvValue(args, call_target->ReturnType());
auto val = const_eval_.CtorOrConv(call_target->ReturnType(), args);
return builder_->create<sem::Call>(expr, call_target, std::move(args), current_statement_,
val, has_side_effects);
};
@ -1528,7 +1529,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) {
return nullptr;
}
auto val = EvaluateCtorOrConvValue(args, arr);
auto val = const_eval_.CtorOrConv(arr, args);
return builder_->create<sem::Call>(expr, call_target, std::move(args),
current_statement_, val, has_side_effects);
},
@ -1550,7 +1551,7 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) {
return nullptr;
}
auto val = EvaluateCtorOrConvValue(args, str);
auto val = const_eval_.CtorOrConv(str, args);
return builder_->create<sem::Call>(expr, call_target, std::move(args),
current_statement_, std::move(val),
has_side_effects);
@ -1682,28 +1683,17 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
}
// If the builtin is @const, and all arguments have constant values, evaluate the builtin now.
const sem::Constant* constant = nullptr;
const sem::Constant* value = nullptr;
if (builtin.const_eval_fn) {
std::vector<const sem::Constant*> values(args.size());
bool is_const = true; // all arguments have constant values
for (size_t i = 0; i < values.size(); i++) {
if (auto v = args[i]->ConstantValue()) {
values[i] = std::move(v);
} else {
is_const = false;
break;
}
}
if (is_const) {
constant = builtin.const_eval_fn(*builder_, values.data(), args.size());
}
value = (const_eval_.*builtin.const_eval_fn)(builtin.sem->ReturnType(), args.data(),
args.size());
}
bool has_side_effects =
builtin.sem->HasSideEffects() ||
std::any_of(args.begin(), args.end(), [](auto* e) { return e->HasSideEffects(); });
auto* call = builder_->create<sem::Call>(expr, builtin.sem, std::move(args), current_statement_,
constant, has_side_effects);
value, has_side_effects);
if (current_function_) {
current_function_->AddDirectlyCalledBuiltin(builtin.sem);
@ -1856,7 +1846,7 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
return nullptr;
}
auto val = EvaluateLiteralValue(literal, ty);
auto val = const_eval_.Literal(ty, literal);
return builder_->create<sem::Expression>(literal, ty, current_statement_, std::move(val),
/* has_side_effects */ false);
}
@ -1976,7 +1966,7 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e
ret = builder_->create<sem::Reference>(ret, ref->StorageClass(), ref->Access());
}
auto* val = EvaluateMemberAccessValue(object, member);
auto* val = const_eval_.MemberAccess(object, member);
return builder_->create<sem::StructMemberAccess>(expr, ret, current_statement_, val, object,
member, has_side_effects, source_var);
}
@ -2044,7 +2034,7 @@ sem::Expression* Resolver::MemberAccessor(const ast::MemberAccessorExpression* e
// the swizzle.
ret = builder_->create<sem::Vector>(vec->type(), static_cast<uint32_t>(size));
}
auto* val = EvaluateSwizzleValue(object, ret, swizzle);
auto* val = const_eval_.Swizzle(ret, object, swizzle);
return builder_->create<sem::Swizzle>(expr, ret, current_statement_, val, object,
std::move(swizzle), has_side_effects, source_var);
}
@ -2078,9 +2068,14 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
}
}
auto* val = EvaluateBinaryValue(lhs, rhs, op);
const sem::Constant* value = nullptr;
if (op.const_eval_fn) {
const sem::Expression* args[] = {lhs, rhs};
value = (const_eval_.*op.const_eval_fn)(op.result, args, 2u);
}
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, val,
auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, value,
has_side_effects);
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
@ -2096,7 +2091,7 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
const sem::Type* ty = nullptr;
const sem::Variable* source_var = nullptr;
const sem::Constant* val = nullptr;
const sem::Constant* value = nullptr;
switch (unary->op) {
case ast::UnaryOp::kAddressOf:
@ -2149,12 +2144,14 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
}
}
ty = op.result;
val = EvaluateUnaryValue(expr, op);
if (op.const_eval_fn) {
value = (const_eval_.*op.const_eval_fn)(ty, &expr, 1u);
}
break;
}
}
auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, val,
auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, value,
expr->HasSideEffects(), source_var);
sem->Behaviors() = expr->Behaviors();
return sem;

View File

@ -24,6 +24,7 @@
#include <vector>
#include "src/tint/program_builder.h"
#include "src/tint/resolver/const_eval.h"
#include "src/tint/resolver/dependency_graph.h"
#include "src/tint/resolver/intrinsic_table.h"
#include "src/tint/resolver/sem_helper.h"
@ -34,7 +35,6 @@
#include "src/tint/sem/constant.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/struct.h"
#include "src/tint/utils/result.h"
#include "src/tint/utils/unique_vector.h"
// Forward declarations
@ -204,46 +204,6 @@ class Resolver {
sem::Expression* MemberAccessor(const ast::MemberAccessorExpression*);
sem::Expression* UnaryOp(const ast::UnaryOpExpression*);
////////////////////////////////////////////////////////////////////////////////////////////////
/// Constant value evaluation methods
///
/// These methods are called from the expression resolving methods, and so child-expression
/// nodes are guaranteed to have been already resolved and any constant values calculated.
////////////////////////////////////////////////////////////////////////////////////////////////
const sem::Constant* EvaluateBinaryValue(const sem::Expression* lhs,
const sem::Expression* rhs,
const IntrinsicTable::BinaryOperator&);
const sem::Constant* EvaluateBitcastValue(const sem::Expression*, const sem::Type*);
const sem::Constant* EvaluateCtorOrConvValue(
const std::vector<const sem::Expression*>& args,
const sem::Type* ty); // Note: ty is not an array or structure
const sem::Constant* EvaluateIndexValue(const sem::Expression* obj, const sem::Expression* idx);
const sem::Constant* EvaluateLiteralValue(const ast::LiteralExpression*, const sem::Type*);
const sem::Constant* EvaluateMemberAccessValue(const sem::Expression* obj,
const sem::StructMember* member);
const sem::Constant* EvaluateSwizzleValue(const sem::Expression* vector,
const sem::Type* type,
const std::vector<uint32_t>& indices);
const sem::Constant* EvaluateUnaryValue(const sem::Expression*,
const IntrinsicTable::UnaryOperator&);
/// The result type of a ConstantEvaluation method.
/// Can be one of three distinct values:
/// * A non-null sem::Constant pointer. Returned when a expression resolves to a creation time
/// value.
/// * A null sem::Constant pointer. Returned when a expression cannot resolve to a creation time
/// value, but is otherwise legal.
/// * `utils::Failure`. Returned when there was a resolver error. In this situation the method
/// will have already reported a diagnostic error message, and the caller should abort
/// resolving.
using ConstantResult = utils::Result<const sem::Constant*>;
/// Convert the `value` to `target_type`
/// @return the converted value
ConstantResult ConvertValue(const sem::Constant* value,
const sem::Type* target_type,
const Source& source);
/// If `expr` is not of an abstract-numeric type, then Materialize() will just return `expr`.
/// If `expr` is of an abstract-numeric type:
/// * Materialize will create and return a sem::Materialize node wrapping `expr`.
@ -449,6 +409,7 @@ class Resolver {
ProgramBuilder* const builder_;
diag::List& diagnostics_;
ConstEval const_eval_;
std::unique_ptr<IntrinsicTable> const intrinsic_table_;
DependencyGraph dependencies_;
SemHelper sem_;

View File

@ -1,605 +0,0 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/resolver/resolver.h"
#include <optional>
#include "src/tint/sem/abstract_float.h"
#include "src/tint/sem/abstract_int.h"
#include "src/tint/sem/constant.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/type_constructor.h"
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/transform.h"
using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver {
namespace {
/// TypeDispatch is a helper for calling the function `f`, passing a single zero-value argument of
/// the C++ type that corresponds to the sem::Type `type`. For example, calling `TypeDispatch()`
/// with a type of `sem::I32*` will call the function f with a single argument of `i32(0)`.
/// @returns the value returned by calling `f`.
/// @note `type` must be a scalar or abstract numeric type. Other types will not call `f`, and will
/// return the zero-initialized value of the return type for `f`.
template <typename F>
auto TypeDispatch(const sem::Type* type, F&& f) {
return Switch(
type, //
[&](const sem::AbstractInt*) { return f(AInt(0)); }, //
[&](const sem::AbstractFloat*) { return f(AFloat(0)); }, //
[&](const sem::I32*) { return f(i32(0)); }, //
[&](const sem::U32*) { return f(u32(0)); }, //
[&](const sem::F32*) { return f(f32(0)); }, //
[&](const sem::F16*) { return f(f16(0)); }, //
[&](const sem::Bool*) { return f(static_cast<bool>(0)); });
}
/// @returns `value` if `T` is not a Number, otherwise ValueOf returns the inner value of the
/// Number.
template <typename T>
inline auto ValueOf(T value) {
if constexpr (std::is_same_v<UnwrapNumber<T>, T>) {
return value;
} else {
return value.value;
}
}
/// @returns true if `value` is a positive zero.
template <typename T>
inline bool IsPositiveZero(T value) {
using N = UnwrapNumber<T>;
return Number<N>(value) == Number<N>(0); // Considers sign bit
}
/// Constant inherits from sem::Constant to add an private implementation method for conversion.
struct Constant : public sem::Constant {
/// Convert attempts to convert the constant value to the given type. On error, Convert()
/// creates a new diagnostic message and returns a Failure.
virtual utils::Result<const Constant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty,
const Source& source) const = 0;
};
// Forward declaration
const Constant* CreateComposite(ProgramBuilder& builder,
const sem::Type* type,
std::vector<const Constant*> elements);
/// Element holds a single scalar or abstract-numeric value.
/// Element implements the Constant interface.
template <typename T>
struct Element : Constant {
Element(const sem::Type* t, T v) : type(t), value(v) {}
~Element() override = default;
const sem::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override {
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
return static_cast<AFloat>(value);
} else {
return static_cast<AInt>(value);
}
}
const Constant* Index(size_t) const override { return nullptr; }
bool AllZero() const override { return IsPositiveZero(value); }
bool AnyZero() const override { return IsPositiveZero(value); }
bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, ValueOf(value)); }
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty,
const Source& source) const override {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
if (target_ty == type) {
// If the types are identical, then no conversion is needed.
return this;
}
bool failed = false;
auto* res = TypeDispatch(target_ty, [&](auto zero_to) -> const Constant* {
// `T` is the source type, `value` is the source value.
// `TO` is the target type.
using TO = std::decay_t<decltype(zero_to)>;
if constexpr (std::is_same_v<TO, bool>) {
// [x -> bool]
return builder.create<Element<TO>>(target_ty, !IsPositiveZero(value));
} else if constexpr (std::is_same_v<T, bool>) {
// [bool -> x]
return builder.create<Element<TO>>(target_ty, TO(value ? 1 : 0));
} else if (auto conv = CheckedConvert<TO>(value)) {
// Conversion success
return builder.create<Element<TO>>(target_ty, conv.Get());
// --- Below this point are the failure cases ---
} else if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) {
// [abstract-numeric -> x] - materialization failure
std::stringstream ss;
ss << "value " << value << " cannot be represented as ";
ss << "'" << builder.FriendlyName(target_ty) << "'";
builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source);
failed = true;
} else if constexpr (IsFloatingPoint<UnwrapNumber<TO>>) {
// [x -> floating-point] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
constexpr auto kInf = std::numeric_limits<double>::infinity();
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
return builder.create<Element<TO>>(target_ty, TO(-kInf));
case ConversionFailure::kExceedsPositiveLimit:
return builder.create<Element<TO>>(target_ty, TO(kInf));
}
} else {
// [x -> integer] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion
switch (conv.Failure()) {
case ConversionFailure::kExceedsNegativeLimit:
return builder.create<Element<TO>>(target_ty, TO(TO::kLowest));
case ConversionFailure::kExceedsPositiveLimit:
return builder.create<Element<TO>>(target_ty, TO(TO::kHighest));
}
}
return nullptr; // Expression is not constant.
});
if (failed) {
// A diagnostic error has been raised, and resolving should abort.
return utils::Failure;
}
return res;
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
sem::Type const* const type;
const T value;
};
/// Splat holds a single Constant value, duplicated as all children.
/// Splat is used for zero-initializers, 'splat' constructors, or constructors where each element is
/// identical. Splat may be of a vector, matrix or array type.
/// Splat implements the Constant interface.
struct Splat : Constant {
Splat(const sem::Type* t, const Constant* e, size_t n) : type(t), el(e), count(n) {}
~Splat() override = default;
const sem::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
const Constant* Index(size_t i) const override { return i < count ? el : nullptr; }
bool AllZero() const override { return el->AllZero(); }
bool AnyZero() const override { return el->AnyZero(); }
bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, el->Hash(), count); }
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty,
const Source& source) const override {
// Convert the single splatted element type.
auto conv_el = el->Convert(builder, sem::Type::ElementOf(target_ty), source);
if (!conv_el) {
return utils::Failure;
}
if (!conv_el.Get()) {
return nullptr;
}
return builder.create<Splat>(target_ty, conv_el.Get(), count);
}
sem::Type const* const type;
const Constant* el;
const size_t count;
};
/// Composite holds a number of mixed child Constant values.
/// Composite may be of a vector, matrix or array type.
/// If each element is the same type and value, then a Splat would be a more efficient constant
/// implementation. Use CreateComposite() to create the appropriate Constant type.
/// Composite implements the Constant interface.
struct Composite : Constant {
Composite(const sem::Type* t, std::vector<const Constant*> els, bool all_0, bool any_0)
: type(t), elements(std::move(els)), all_zero(all_0), any_zero(any_0), hash(CalcHash()) {}
~Composite() override = default;
const sem::Type* Type() const override { return type; }
std::variant<std::monostate, AInt, AFloat> Value() const override { return {}; }
const Constant* Index(size_t i) const override {
return i < elements.size() ? elements[i] : nullptr;
}
bool AllZero() const override { return all_zero; }
bool AnyZero() const override { return any_zero; }
bool AllEqual() const override { return false; /* otherwise this should be a Splat */ }
size_t Hash() const override { return hash; }
utils::Result<const Constant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty,
const Source& source) const override {
// Convert each of the composite element types.
auto* el_ty = sem::Type::ElementOf(target_ty);
std::vector<const Constant*> conv_els;
conv_els.reserve(elements.size());
for (auto* el : elements) {
auto conv_el = el->Convert(builder, el_ty, source);
if (!conv_el) {
return utils::Failure;
}
if (!conv_el.Get()) {
return nullptr;
}
conv_els.emplace_back(conv_el.Get());
}
return CreateComposite(builder, target_ty, std::move(conv_els));
}
size_t CalcHash() {
auto h = utils::Hash(type, all_zero, any_zero);
for (auto* el : elements) {
utils::HashCombine(&h, el->Hash());
}
return h;
}
sem::Type const* const type;
const std::vector<const Constant*> elements;
const bool all_zero;
const bool any_zero;
const size_t hash;
};
/// CreateElement constructs and returns an Element<T>.
template <typename T>
const Constant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) {
return builder.create<Element<T>>(t, v);
}
/// ZeroValue returns a Constant for the zero-value of the type `type`.
const Constant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
return Switch(
type, //
[&](const sem::Vector* v) -> const Constant* {
auto* zero_el = ZeroValue(builder, v->type());
return builder.create<Splat>(type, zero_el, v->Width());
},
[&](const sem::Matrix* m) -> const Constant* {
auto* zero_el = ZeroValue(builder, m->ColumnType());
return builder.create<Splat>(type, zero_el, m->columns());
},
[&](const sem::Array* a) -> const Constant* {
if (auto* zero_el = ZeroValue(builder, a->ElemType())) {
return builder.create<Splat>(type, zero_el, a->Count());
}
return nullptr;
},
[&](const sem::Struct* s) -> const Constant* {
std::unordered_map<sem::Type*, const Constant*> zero_by_type;
std::vector<const Constant*> zeros;
zeros.reserve(s->Members().size());
for (auto* member : s->Members()) {
auto* zero = utils::GetOrCreate(zero_by_type, member->Type(),
[&] { return ZeroValue(builder, member->Type()); });
if (!zero) {
return nullptr;
}
zeros.emplace_back(zero);
}
if (zero_by_type.size() == 1) {
// All members were of the same type, so the zero value is the same for all members.
return builder.create<Splat>(type, zeros[0], s->Members().size());
}
return CreateComposite(builder, s, std::move(zeros));
},
[&](Default) -> const Constant* {
return TypeDispatch(type, [&](auto zero) -> const Constant* {
return CreateElement(builder, type, zero);
});
});
}
/// Equal returns true if the constants `a` and `b` are of the same type and value.
bool Equal(const sem::Constant* a, const sem::Constant* b) {
if (a->Hash() != b->Hash()) {
return false;
}
if (a->Type() != b->Type()) {
return false;
}
return Switch(
a->Type(), //
[&](const sem::Vector* vec) {
for (size_t i = 0; i < vec->Width(); i++) {
if (!Equal(a->Index(i), b->Index(i))) {
return false;
}
}
return true;
},
[&](const sem::Matrix* mat) {
for (size_t i = 0; i < mat->columns(); i++) {
if (!Equal(a->Index(i), b->Index(i))) {
return false;
}
}
return true;
},
[&](const sem::Array* arr) {
for (size_t i = 0; i < arr->Count(); i++) {
if (!Equal(a->Index(i), b->Index(i))) {
return false;
}
}
return true;
},
[&](Default) { return a->Value() == b->Value(); });
}
/// CreateComposite is used to construct a constant of a vector, matrix or array type.
/// CreateComposite examines the element values and will return either a Composite or a Splat,
/// depending on the element types and values.
const Constant* CreateComposite(ProgramBuilder& builder,
const sem::Type* type,
std::vector<const Constant*> elements) {
if (elements.size() == 0) {
return nullptr;
}
bool any_zero = false;
bool all_zero = true;
bool all_equal = true;
auto* first = elements.front();
for (auto* el : elements) {
if (!el) {
return nullptr;
}
if (!any_zero && el->AnyZero()) {
any_zero = true;
}
if (all_zero && !el->AllZero()) {
all_zero = false;
}
if (all_equal && el != first) {
if (!Equal(el, first)) {
all_equal = false;
}
}
}
if (all_equal) {
return builder.create<Splat>(type, elements[0], elements.size());
} else {
return builder.create<Composite>(type, std::move(elements), all_zero, any_zero);
}
}
} // namespace
const sem::Constant* Resolver::EvaluateLiteralValue(const ast::LiteralExpression* literal,
const sem::Type* type) {
return Switch(
literal,
[&](const ast::BoolLiteralExpression* lit) {
return CreateElement(*builder_, type, lit->value);
},
[&](const ast::IntLiteralExpression* lit) -> const Constant* {
switch (lit->suffix) {
case ast::IntLiteralExpression::Suffix::kNone:
return CreateElement(*builder_, type, AInt(lit->value));
case ast::IntLiteralExpression::Suffix::kI:
return CreateElement(*builder_, type, i32(lit->value));
case ast::IntLiteralExpression::Suffix::kU:
return CreateElement(*builder_, type, u32(lit->value));
}
return nullptr;
},
[&](const ast::FloatLiteralExpression* lit) -> const Constant* {
switch (lit->suffix) {
case ast::FloatLiteralExpression::Suffix::kNone:
return CreateElement(*builder_, type, AFloat(lit->value));
case ast::FloatLiteralExpression::Suffix::kF:
return CreateElement(*builder_, type, f32(lit->value));
case ast::FloatLiteralExpression::Suffix::kH:
return CreateElement(*builder_, type, f16(lit->value));
}
return nullptr;
});
}
const sem::Constant* Resolver::EvaluateCtorOrConvValue(
const std::vector<const sem::Expression*>& args,
const sem::Type* ty) {
// For zero value init, return 0s
if (args.empty()) {
return ZeroValue(*builder_, ty);
}
if (auto* el_ty = sem::Type::ElementOf(ty); el_ty && args.size() == 1) {
// Type constructor or conversion that takes a single argument.
auto& src = args[0]->Declaration()->source;
auto* arg = static_cast<const Constant*>(args[0]->ConstantValue());
if (!arg) {
return nullptr; // Single argument is not constant.
}
if (ty->is_scalar()) { // Scalar type conversion: i32(x), u32(x), bool(x), etc
return ConvertValue(arg, el_ty, src).Get();
}
if (arg->Type() == el_ty) {
// Argument type matches function type. This is a splat.
auto splat = [&](size_t n) { return builder_->create<Splat>(ty, arg, n); };
return Switch(
ty, //
[&](const sem::Vector* v) { return splat(v->Width()); },
[&](const sem::Matrix* m) { return splat(m->columns()); },
[&](const sem::Array* a) { return splat(a->Count()); });
}
// Argument type and function type mismatch. This is a type conversion.
if (auto conv = ConvertValue(arg, ty, src)) {
return conv.Get();
}
return nullptr;
}
// Helper for pushing all the argument constants to `els`.
auto args_as_constants = [&] {
return utils::Transform(
args, [&](auto* expr) { return static_cast<const Constant*>(expr->ConstantValue()); });
};
// Multiple arguments. Must be a type constructor.
return Switch(
ty, // What's the target type being constructed?
[&](const sem::Vector*) -> const Constant* {
// Vector can be constructed with a mix of scalars / abstract numerics and smaller
// vectors.
std::vector<const Constant*> els;
els.reserve(args.size());
for (auto* expr : args) {
auto* arg = static_cast<const Constant*>(expr->ConstantValue());
if (!arg) {
return nullptr;
}
auto* arg_ty = arg->Type();
if (auto* arg_vec = arg_ty->As<sem::Vector>()) {
// Extract out vector elements.
for (uint32_t i = 0; i < arg_vec->Width(); i++) {
auto* el = static_cast<const Constant*>(arg->Index(i));
if (!el) {
return nullptr;
}
els.emplace_back(el);
}
} else {
els.emplace_back(arg);
}
}
return CreateComposite(*builder_, ty, std::move(els));
},
[&](const sem::Matrix* m) -> const Constant* {
// Matrix can be constructed with a set of scalars / abstract numerics, or column
// vectors.
if (args.size() == m->columns() * m->rows()) {
// Matrix built from scalars / abstract numerics
std::vector<const Constant*> els;
els.reserve(args.size());
for (uint32_t c = 0; c < m->columns(); c++) {
std::vector<const Constant*> column;
column.reserve(m->rows());
for (uint32_t r = 0; r < m->rows(); r++) {
auto* arg =
static_cast<const Constant*>(args[r + c * m->rows()]->ConstantValue());
if (!arg) {
return nullptr;
}
column.emplace_back(arg);
}
els.push_back(CreateComposite(*builder_, m->ColumnType(), std::move(column)));
}
return CreateComposite(*builder_, ty, std::move(els));
}
// Matrix built from column vectors
return CreateComposite(*builder_, ty, args_as_constants());
},
[&](const sem::Array*) {
// Arrays must be constructed using a list of elements
return CreateComposite(*builder_, ty, args_as_constants());
},
[&](const sem::Struct*) {
// Structures must be constructed using a list of elements
return CreateComposite(*builder_, ty, args_as_constants());
});
}
const sem::Constant* Resolver::EvaluateIndexValue(const sem::Expression* obj_expr,
const sem::Expression* idx_expr) {
auto obj_val = obj_expr->ConstantValue();
if (!obj_val) {
return {};
}
auto idx_val = idx_expr->ConstantValue();
if (!idx_val) {
return {};
}
uint32_t el_count = 0;
sem::Type::ElementOf(obj_val->Type(), &el_count);
AInt idx = idx_val->As<AInt>();
if (idx < 0 || idx >= el_count) {
auto clamped = std::min<AInt::type>(std::max<AInt::type>(idx, 0), el_count - 1);
AddWarning("index " + std::to_string(idx) + " out of bounds [0.." +
std::to_string(el_count - 1) + "]. Clamping index to " +
std::to_string(clamped),
idx_expr->Declaration()->source);
idx = clamped;
}
return obj_val->Index(static_cast<size_t>(idx));
}
const sem::Constant* Resolver::EvaluateMemberAccessValue(const sem::Expression* obj_expr,
const sem::StructMember* member) {
auto obj_val = obj_expr->ConstantValue();
if (!obj_val) {
return {};
}
return obj_val->Index(static_cast<size_t>(member->Index()));
}
const sem::Constant* Resolver::EvaluateSwizzleValue(const sem::Expression* vec_expr,
const sem::Type* type,
const std::vector<uint32_t>& indices) {
auto* vec_val = vec_expr->ConstantValue();
if (!vec_val) {
return nullptr;
}
if (indices.size() == 1) {
return static_cast<const Constant*>(vec_val->Index(static_cast<size_t>(indices[0])));
} else {
auto values = utils::Transform(
indices, [&](uint32_t i) { return static_cast<const Constant*>(vec_val->Index(i)); });
return CreateComposite(*builder_, type, std::move(values));
}
}
const sem::Constant* Resolver::EvaluateBitcastValue(const sem::Expression*, const sem::Type*) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
}
const sem::Constant* Resolver::EvaluateBinaryValue(const sem::Expression*,
const sem::Expression*,
const IntrinsicTable::BinaryOperator&) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
}
const sem::Constant* Resolver::EvaluateUnaryValue(const sem::Expression*,
const IntrinsicTable::UnaryOperator&) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
}
utils::Result<const sem::Constant*> Resolver::ConvertValue(const sem::Constant* value,
const sem::Type* target_ty,
const Source& source) {
if (value->Type() == target_ty) {
return value;
}
auto conv = static_cast<const Constant*>(value)->Convert(*builder_, target_ty, source);
if (!conv) {
return utils::Failure;
}
return conv.Get();
}
} // namespace tint::resolver