tint/resolver: Consistently use utils::Result in ConstEval

Instead of using builder.Diagnostics().contains_errors()

Produces cleaner code and reduces scope of error handling.

Bug: tint:1661
Change-Id: I35af5ad1c6553f2cf74d1ce92dc14984f93b9db4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/102161
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2022-09-20 09:26:21 +00:00 committed by Dawn LUCI CQ
parent fc6167b9a8
commit e68d4506c0
4 changed files with 337 additions and 367 deletions

View File

@ -182,24 +182,27 @@ std::string OverflowErrorMessage(NumberT lhs, const char* op, NumberT rhs) {
return ss.str(); return ss.str();
} }
/// Constant inherits from sem::Constant to add an private implementation method for conversion. /// ImplConstant inherits from sem::Constant to add an private implementation method for conversion.
struct Constant : public sem::Constant { struct ImplConstant : public sem::Constant {
/// Convert attempts to convert the constant value to the given type. On error, Convert() /// Convert attempts to convert the constant value to the given type. On error, Convert()
/// creates a new diagnostic message and returns a Failure. /// creates a new diagnostic message and returns a Failure.
virtual utils::Result<const Constant*> Convert(ProgramBuilder& builder, virtual utils::Result<const ImplConstant*> Convert(ProgramBuilder& builder,
const sem::Type* target_ty, const sem::Type* target_ty,
const Source& source) const = 0; const Source& source) const = 0;
}; };
/// A result templated with a ImplConstant.
using ImplResult = utils::Result<const ImplConstant*>;
// Forward declaration // Forward declaration
const Constant* CreateComposite(ProgramBuilder& builder, const ImplConstant* CreateComposite(ProgramBuilder& builder,
const sem::Type* type, const sem::Type* type,
utils::VectorRef<const sem::Constant*> elements); utils::VectorRef<const sem::Constant*> elements);
/// Element holds a single scalar or abstract-numeric value. /// Element holds a single scalar or abstract-numeric value.
/// Element implements the Constant interface. /// Element implements the Constant interface.
template <typename T> template <typename T>
struct Element : Constant { struct Element : ImplConstant {
static_assert(!std::is_same_v<UnwrapNumber<T>, T> || std::is_same_v<T, bool>, static_assert(!std::is_same_v<UnwrapNumber<T>, T> || std::is_same_v<T, bool>,
"T must be a Number or bool"); "T must be a Number or bool");
@ -219,16 +222,15 @@ struct Element : Constant {
bool AllEqual() const override { return true; } bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, ValueOf(value)); } size_t Hash() const override { return utils::Hash(type, ValueOf(value)); }
utils::Result<const Constant*> Convert(ProgramBuilder& builder, ImplResult Convert(ProgramBuilder& builder,
const sem::Type* target_ty, const sem::Type* target_ty,
const Source& source) const override { const Source& source) const override {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE); TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
if (target_ty == type) { if (target_ty == type) {
// If the types are identical, then no conversion is needed. // If the types are identical, then no conversion is needed.
return this; return this;
} }
bool failed = false; return ZeroTypeDispatch(target_ty, [&](auto zero_to) -> ImplResult {
auto* res = ZeroTypeDispatch(target_ty, [&](auto zero_to) -> const Constant* {
// `T` is the source type, `value` is the source value. // `T` is the source type, `value` is the source value.
// `TO` is the target type. // `TO` is the target type.
using TO = std::decay_t<decltype(zero_to)>; using TO = std::decay_t<decltype(zero_to)>;
@ -248,7 +250,7 @@ struct Element : Constant {
ss << "value " << value << " cannot be represented as "; ss << "value " << value << " cannot be represented as ";
ss << "'" << builder.FriendlyName(target_ty) << "'"; ss << "'" << builder.FriendlyName(target_ty) << "'";
builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source); builder.Diagnostics().add_error(tint::diag::System::Resolver, ss.str(), source);
failed = true; return utils::Failure;
} else if constexpr (IsFloatingPoint<UnwrapNumber<TO>>) { } else if constexpr (IsFloatingPoint<UnwrapNumber<TO>>) {
// [x -> floating-point] - number not exactly representable // [x -> floating-point] - number not exactly representable
// https://www.w3.org/TR/WGSL/#floating-point-conversion // https://www.w3.org/TR/WGSL/#floating-point-conversion
@ -270,11 +272,6 @@ struct Element : Constant {
} }
return nullptr; // Expression is not constant. return nullptr; // Expression is not constant.
}); });
if (failed) {
// A diagnostic error has been raised, and resolving should abort.
return utils::Failure;
}
return res;
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE); TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
} }
@ -286,7 +283,7 @@ struct Element : Constant {
/// Splat is used for zero-initializers, 'splat' constructors, or constructors where each element is /// Splat is used for zero-initializers, 'splat' constructors, or constructors where each element is
/// identical. Splat may be of a vector, matrix or array type. /// identical. Splat may be of a vector, matrix or array type.
/// Splat implements the Constant interface. /// Splat implements the Constant interface.
struct Splat : Constant { struct Splat : ImplConstant {
Splat(const sem::Type* t, const sem::Constant* e, size_t n) : type(t), el(e), count(n) {} Splat(const sem::Type* t, const sem::Constant* e, size_t n) : type(t), el(e), count(n) {}
~Splat() override = default; ~Splat() override = default;
const sem::Type* Type() const override { return type; } const sem::Type* Type() const override { return type; }
@ -297,13 +294,13 @@ struct Splat : Constant {
bool AllEqual() const override { return true; } bool AllEqual() const override { return true; }
size_t Hash() const override { return utils::Hash(type, el->Hash(), count); } size_t Hash() const override { return utils::Hash(type, el->Hash(), count); }
utils::Result<const Constant*> Convert(ProgramBuilder& builder, ImplResult Convert(ProgramBuilder& builder,
const sem::Type* target_ty, const sem::Type* target_ty,
const Source& source) const override { const Source& source) const override {
// Convert the single splatted element type. // Convert the single splatted element type.
// Note: This file is the only place where `sem::Constant`s are created, so this static_cast // Note: This file is the only place where `sem::Constant`s are created, so this static_cast
// is safe. // is safe.
auto conv_el = static_cast<const Constant*>(el)->Convert( auto conv_el = static_cast<const ImplConstant*>(el)->Convert(
builder, sem::Type::ElementOf(target_ty), source); builder, sem::Type::ElementOf(target_ty), source);
if (!conv_el) { if (!conv_el) {
return utils::Failure; return utils::Failure;
@ -324,7 +321,7 @@ struct Splat : Constant {
/// If each element is the same type and value, then a Splat would be a more efficient constant /// If each element is the same type and value, then a Splat would be a more efficient constant
/// implementation. Use CreateComposite() to create the appropriate Constant type. /// implementation. Use CreateComposite() to create the appropriate Constant type.
/// Composite implements the Constant interface. /// Composite implements the Constant interface.
struct Composite : Constant { struct Composite : ImplConstant {
Composite(const sem::Type* t, Composite(const sem::Type* t,
utils::VectorRef<const sem::Constant*> els, utils::VectorRef<const sem::Constant*> els,
bool all_0, bool all_0,
@ -341,9 +338,9 @@ struct Composite : Constant {
bool AllEqual() const override { return false; /* otherwise this should be a Splat */ } bool AllEqual() const override { return false; /* otherwise this should be a Splat */ }
size_t Hash() const override { return hash; } size_t Hash() const override { return hash; }
utils::Result<const Constant*> Convert(ProgramBuilder& builder, ImplResult Convert(ProgramBuilder& builder,
const sem::Type* target_ty, const sem::Type* target_ty,
const Source& source) const override { const Source& source) const override {
// Convert each of the composite element types. // Convert each of the composite element types.
auto* el_ty = sem::Type::ElementOf(target_ty); auto* el_ty = sem::Type::ElementOf(target_ty);
utils::Vector<const sem::Constant*, 4> conv_els; utils::Vector<const sem::Constant*, 4> conv_els;
@ -351,7 +348,7 @@ struct Composite : Constant {
for (auto* el : elements) { for (auto* el : elements) {
// Note: This file is the only place where `sem::Constant`s are created, so this // Note: This file is the only place where `sem::Constant`s are created, so this
// static_cast is safe. // static_cast is safe.
auto conv_el = static_cast<const Constant*>(el)->Convert(builder, el_ty, source); auto conv_el = static_cast<const ImplConstant*>(el)->Convert(builder, el_ty, source);
if (!conv_el) { if (!conv_el) {
return utils::Failure; return utils::Failure;
} }
@ -380,30 +377,30 @@ struct Composite : Constant {
/// CreateElement constructs and returns an Element<T>. /// CreateElement constructs and returns an Element<T>.
template <typename T> template <typename T>
const Constant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) { const ImplConstant* CreateElement(ProgramBuilder& builder, const sem::Type* t, T v) {
return builder.create<Element<T>>(t, v); return builder.create<Element<T>>(t, v);
} }
/// ZeroValue returns a Constant for the zero-value of the type `type`. /// ZeroValue returns a Constant for the zero-value of the type `type`.
const Constant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) { const ImplConstant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
return Switch( return Switch(
type, // type, //
[&](const sem::Vector* v) -> const Constant* { [&](const sem::Vector* v) -> const ImplConstant* {
auto* zero_el = ZeroValue(builder, v->type()); auto* zero_el = ZeroValue(builder, v->type());
return builder.create<Splat>(type, zero_el, v->Width()); return builder.create<Splat>(type, zero_el, v->Width());
}, },
[&](const sem::Matrix* m) -> const Constant* { [&](const sem::Matrix* m) -> const ImplConstant* {
auto* zero_el = ZeroValue(builder, m->ColumnType()); auto* zero_el = ZeroValue(builder, m->ColumnType());
return builder.create<Splat>(type, zero_el, m->columns()); return builder.create<Splat>(type, zero_el, m->columns());
}, },
[&](const sem::Array* a) -> const Constant* { [&](const sem::Array* a) -> const ImplConstant* {
if (auto* zero_el = ZeroValue(builder, a->ElemType())) { if (auto* zero_el = ZeroValue(builder, a->ElemType())) {
return builder.create<Splat>(type, zero_el, a->Count()); return builder.create<Splat>(type, zero_el, a->Count());
} }
return nullptr; return nullptr;
}, },
[&](const sem::Struct* s) -> const Constant* { [&](const sem::Struct* s) -> const ImplConstant* {
std::unordered_map<const sem::Type*, const Constant*> zero_by_type; std::unordered_map<const sem::Type*, const ImplConstant*> zero_by_type;
utils::Vector<const sem::Constant*, 4> zeros; utils::Vector<const sem::Constant*, 4> zeros;
zeros.Reserve(s->Members().size()); zeros.Reserve(s->Members().size());
for (auto* member : s->Members()) { for (auto* member : s->Members()) {
@ -420,8 +417,8 @@ const Constant* ZeroValue(ProgramBuilder& builder, const sem::Type* type) {
} }
return CreateComposite(builder, s, std::move(zeros)); return CreateComposite(builder, s, std::move(zeros));
}, },
[&](Default) -> const Constant* { [&](Default) -> const ImplConstant* {
return ZeroTypeDispatch(type, [&](auto zero) -> const Constant* { return ZeroTypeDispatch(type, [&](auto zero) -> const ImplConstant* {
return CreateElement(builder, type, zero); return CreateElement(builder, type, zero);
}); });
}); });
@ -467,9 +464,9 @@ bool Equal(const sem::Constant* a, const sem::Constant* b) {
/// CreateComposite is used to construct a constant of a vector, matrix or array type. /// CreateComposite is used to construct a constant of a vector, matrix or array type.
/// CreateComposite examines the element values and will return either a Composite or a Splat, /// CreateComposite examines the element values and will return either a Composite or a Splat,
/// depending on the element types and values. /// depending on the element types and values.
const Constant* CreateComposite(ProgramBuilder& builder, const ImplConstant* CreateComposite(ProgramBuilder& builder,
const sem::Type* type, const sem::Type* type,
utils::VectorRef<const sem::Constant*> elements) { utils::VectorRef<const sem::Constant*> elements) {
if (elements.IsEmpty()) { if (elements.IsEmpty()) {
return nullptr; return nullptr;
} }
@ -504,10 +501,10 @@ const Constant* CreateComposite(ProgramBuilder& builder,
/// transformation function 'f' on each of the most deeply nested elements of 'cs'. Assumes that all /// transformation function 'f' on each of the most deeply nested elements of 'cs'. Assumes that all
/// input constants `cs` are of the same type. /// input constants `cs` are of the same type.
template <typename F, typename... CONSTANTS> template <typename F, typename... CONSTANTS>
const Constant* TransformElements(ProgramBuilder& builder, ImplResult TransformElements(ProgramBuilder& builder,
const sem::Type* composite_ty, const sem::Type* composite_ty,
F&& f, F&& f,
CONSTANTS&&... cs) { CONSTANTS&&... cs) {
uint32_t n = 0; uint32_t n = 0;
auto* ty = First(cs...)->Type(); auto* ty = First(cs...)->Type();
auto* el_ty = sem::Type::ElementOf(ty, &n); auto* el_ty = sem::Type::ElementOf(ty, &n);
@ -517,8 +514,13 @@ const Constant* TransformElements(ProgramBuilder& builder,
utils::Vector<const sem::Constant*, 8> els; utils::Vector<const sem::Constant*, 8> els;
els.Reserve(n); els.Reserve(n);
for (uint32_t i = 0; i < n; i++) { for (uint32_t i = 0; i < n; i++) {
els.Push(TransformElements(builder, sem::Type::ElementOf(composite_ty), std::forward<F>(f), if (auto el = TransformElements(builder, sem::Type::ElementOf(composite_ty),
cs->Index(i)...)); std::forward<F>(f), cs->Index(i)...)) {
els.Push(el.Get());
} else {
return el.Failure();
}
} }
return CreateComposite(builder, composite_ty, std::move(els)); return CreateComposite(builder, composite_ty, std::move(els));
} }
@ -528,11 +530,11 @@ const Constant* TransformElements(ProgramBuilder& builder,
/// Unlike TransformElements, this function handles the constants being of different types, e.g. /// Unlike TransformElements, this function handles the constants being of different types, e.g.
/// vector-scalar, scalar-vector. /// vector-scalar, scalar-vector.
template <typename F> template <typename F>
const Constant* TransformBinaryElements(ProgramBuilder& builder, ImplResult TransformBinaryElements(ProgramBuilder& builder,
const sem::Type* composite_ty, const sem::Type* composite_ty,
F&& f, F&& f,
const sem::Constant* c0, const sem::Constant* c0,
const sem::Constant* c1) { const sem::Constant* c1) {
uint32_t n0 = 0, n1 = 0; uint32_t n0 = 0, n1 = 0;
sem::Type::ElementOf(c0->Type(), &n0); sem::Type::ElementOf(c0->Type(), &n0);
sem::Type::ElementOf(c1->Type(), &n1); sem::Type::ElementOf(c1->Type(), &n1);
@ -551,9 +553,13 @@ const Constant* TransformBinaryElements(ProgramBuilder& builder,
} }
return c->Index(i); return c->Index(i);
}; };
els.Push(TransformBinaryElements(builder, sem::Type::ElementOf(composite_ty), if (auto el = TransformBinaryElements(builder, sem::Type::ElementOf(composite_ty),
std::forward<F>(f), nested_or_self(c0, n0), std::forward<F>(f), nested_or_self(c0, n0),
nested_or_self(c1, n1))); nested_or_self(c1, n1))) {
els.Push(el.Get());
} else {
return el.Failure();
}
} }
return CreateComposite(builder, composite_ty, std::move(els)); return CreateComposite(builder, composite_ty, std::move(els));
} }
@ -703,7 +709,7 @@ utils::Result<NumberT> ConstEval::Dot4(NumberT a1,
} }
auto ConstEval::AddFunc(const sem::Type* elem_ty) { auto ConstEval::AddFunc(const sem::Type* elem_ty) {
return [=](auto a1, auto a2) -> utils::Result<const Constant*> { return [=](auto a1, auto a2) -> ImplResult {
if (auto r = Add(a1, a2)) { if (auto r = Add(a1, a2)) {
return CreateElement(builder, elem_ty, r.Get()); return CreateElement(builder, elem_ty, r.Get());
} }
@ -712,7 +718,7 @@ auto ConstEval::AddFunc(const sem::Type* elem_ty) {
} }
auto ConstEval::MulFunc(const sem::Type* elem_ty) { auto ConstEval::MulFunc(const sem::Type* elem_ty) {
return [=](auto a1, auto a2) -> utils::Result<const Constant*> { return [=](auto a1, auto a2) -> ImplResult {
if (auto r = Mul(a1, a2)) { if (auto r = Mul(a1, a2)) {
return CreateElement(builder, elem_ty, r.Get()); return CreateElement(builder, elem_ty, r.Get());
} }
@ -721,7 +727,7 @@ auto ConstEval::MulFunc(const sem::Type* elem_ty) {
} }
auto ConstEval::Dot2Func(const sem::Type* elem_ty) { auto ConstEval::Dot2Func(const sem::Type* elem_ty) {
return [=](auto a1, auto a2, auto b1, auto b2) -> utils::Result<const Constant*> { return [=](auto a1, auto a2, auto b1, auto b2) -> ImplResult {
if (auto r = Dot2(a1, a2, b1, b2)) { if (auto r = Dot2(a1, a2, b1, b2)) {
return CreateElement(builder, elem_ty, r.Get()); return CreateElement(builder, elem_ty, r.Get());
} }
@ -730,8 +736,7 @@ auto ConstEval::Dot2Func(const sem::Type* elem_ty) {
} }
auto ConstEval::Dot3Func(const sem::Type* elem_ty) { auto ConstEval::Dot3Func(const sem::Type* elem_ty) {
return [=](auto a1, auto a2, auto a3, auto b1, auto b2, return [=](auto a1, auto a2, auto a3, auto b1, auto b2, auto b3) -> ImplResult {
auto b3) -> utils::Result<const Constant*> {
if (auto r = Dot3(a1, a2, a3, b1, b2, b3)) { if (auto r = Dot3(a1, a2, a3, b1, b2, b3)) {
return CreateElement(builder, elem_ty, r.Get()); return CreateElement(builder, elem_ty, r.Get());
} }
@ -740,23 +745,22 @@ auto ConstEval::Dot3Func(const sem::Type* elem_ty) {
} }
auto ConstEval::Dot4Func(const sem::Type* elem_ty) { auto ConstEval::Dot4Func(const sem::Type* elem_ty) {
return [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3, return
auto b4) -> utils::Result<const Constant*> { [=](auto a1, auto a2, auto a3, auto a4, auto b1, auto b2, auto b3, auto b4) -> ImplResult {
if (auto r = Dot4(a1, a2, a3, a4, b1, b2, b3, b4)) { if (auto r = Dot4(a1, a2, a3, a4, b1, b2, b3, b4)) {
return CreateElement(builder, elem_ty, r.Get()); return CreateElement(builder, elem_ty, r.Get());
} }
return utils::Failure; return utils::Failure;
}; };
} }
ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty, ConstEval::Result ConstEval::Literal(const sem::Type* ty, const ast::LiteralExpression* literal) {
const ast::LiteralExpression* literal) {
return Switch( return Switch(
literal, literal,
[&](const ast::BoolLiteralExpression* lit) { [&](const ast::BoolLiteralExpression* lit) {
return CreateElement(builder, ty, lit->value); return CreateElement(builder, ty, lit->value);
}, },
[&](const ast::IntLiteralExpression* lit) -> const Constant* { [&](const ast::IntLiteralExpression* lit) -> ImplResult {
switch (lit->suffix) { switch (lit->suffix) {
case ast::IntLiteralExpression::Suffix::kNone: case ast::IntLiteralExpression::Suffix::kNone:
return CreateElement(builder, ty, AInt(lit->value)); return CreateElement(builder, ty, AInt(lit->value));
@ -767,7 +771,7 @@ ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty,
} }
return nullptr; return nullptr;
}, },
[&](const ast::FloatLiteralExpression* lit) -> const Constant* { [&](const ast::FloatLiteralExpression* lit) -> ImplResult {
switch (lit->suffix) { switch (lit->suffix) {
case ast::FloatLiteralExpression::Suffix::kNone: case ast::FloatLiteralExpression::Suffix::kNone:
return CreateElement(builder, ty, AFloat(lit->value)); return CreateElement(builder, ty, AFloat(lit->value));
@ -780,9 +784,8 @@ ConstEval::ConstantResult ConstEval::Literal(const sem::Type* ty,
}); });
} }
ConstEval::ConstantResult ConstEval::ArrayOrStructCtor( ConstEval::Result ConstEval::ArrayOrStructCtor(const sem::Type* ty,
const sem::Type* ty, utils::VectorRef<const sem::Expression*> args) {
utils::VectorRef<const sem::Expression*> args) {
if (args.IsEmpty()) { if (args.IsEmpty()) {
return ZeroValue(builder, ty); return ZeroValue(builder, ty);
} }
@ -801,9 +804,9 @@ ConstEval::ConstantResult ConstEval::ArrayOrStructCtor(
return CreateComposite(builder, ty, std::move(els)); return CreateComposite(builder, ty, std::move(els));
} }
ConstEval::ConstantResult ConstEval::Conv(const sem::Type* ty, ConstEval::Result ConstEval::Conv(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source) { const Source& source) {
uint32_t el_count = 0; uint32_t el_count = 0;
auto* el_ty = sem::Type::ElementOf(ty, &el_count); auto* el_ty = sem::Type::ElementOf(ty, &el_count);
if (!el_ty) { if (!el_ty) {
@ -821,36 +824,36 @@ ConstEval::ConstantResult ConstEval::Conv(const sem::Type* ty,
return nullptr; return nullptr;
} }
ConstEval::ConstantResult ConstEval::Zero(const sem::Type* ty, ConstEval::Result ConstEval::Zero(const sem::Type* ty,
utils::VectorRef<const sem::Constant*>, utils::VectorRef<const sem::Constant*>,
const Source&) { const Source&) {
return ZeroValue(builder, ty); return ZeroValue(builder, ty);
} }
ConstEval::ConstantResult ConstEval::Identity(const sem::Type*, ConstEval::Result ConstEval::Identity(const sem::Type*,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
return args[0]; return args[0];
} }
ConstEval::ConstantResult ConstEval::VecSplat(const sem::Type* ty, ConstEval::Result ConstEval::VecSplat(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
if (auto* arg = args[0]) { if (auto* arg = args[0]) {
return builder.create<Splat>(ty, arg, static_cast<const sem::Vector*>(ty)->Width()); return builder.create<Splat>(ty, arg, static_cast<const sem::Vector*>(ty)->Width());
} }
return nullptr; return nullptr;
} }
ConstEval::ConstantResult ConstEval::VecCtorS(const sem::Type* ty, ConstEval::Result ConstEval::VecCtorS(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
return CreateComposite(builder, ty, args); return CreateComposite(builder, ty, args);
} }
ConstEval::ConstantResult ConstEval::VecCtorM(const sem::Type* ty, ConstEval::Result ConstEval::VecCtorM(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
utils::Vector<const sem::Constant*, 4> els; utils::Vector<const sem::Constant*, 4> els;
for (auto* arg : args) { for (auto* arg : args) {
auto* val = arg; auto* val = arg;
@ -874,9 +877,9 @@ ConstEval::ConstantResult ConstEval::VecCtorM(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(els)); return CreateComposite(builder, ty, std::move(els));
} }
ConstEval::ConstantResult ConstEval::MatCtorS(const sem::Type* ty, ConstEval::Result ConstEval::MatCtorS(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto* m = static_cast<const sem::Matrix*>(ty); auto* m = static_cast<const sem::Matrix*>(ty);
utils::Vector<const sem::Constant*, 4> els; utils::Vector<const sem::Constant*, 4> els;
@ -891,14 +894,14 @@ ConstEval::ConstantResult ConstEval::MatCtorS(const sem::Type* ty,
return CreateComposite(builder, ty, std::move(els)); return CreateComposite(builder, ty, std::move(els));
} }
ConstEval::ConstantResult ConstEval::MatCtorV(const sem::Type* ty, ConstEval::Result ConstEval::MatCtorV(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
return CreateComposite(builder, ty, args); return CreateComposite(builder, ty, args);
} }
ConstEval::ConstantResult ConstEval::Index(const sem::Expression* obj_expr, ConstEval::Result ConstEval::Index(const sem::Expression* obj_expr,
const sem::Expression* idx_expr) { const sem::Expression* idx_expr) {
auto idx_val = idx_expr->ConstantValue(); auto idx_val = idx_expr->ConstantValue();
if (!idx_val) { if (!idx_val) {
return nullptr; return nullptr;
@ -926,8 +929,8 @@ ConstEval::ConstantResult ConstEval::Index(const sem::Expression* obj_expr,
return obj_val->Index(static_cast<size_t>(idx)); return obj_val->Index(static_cast<size_t>(idx));
} }
ConstEval::ConstantResult ConstEval::MemberAccess(const sem::Expression* obj_expr, ConstEval::Result ConstEval::MemberAccess(const sem::Expression* obj_expr,
const sem::StructMember* member) { const sem::StructMember* member) {
auto obj_val = obj_expr->ConstantValue(); auto obj_val = obj_expr->ConstantValue();
if (!obj_val) { if (!obj_val) {
return nullptr; return nullptr;
@ -935,30 +938,29 @@ ConstEval::ConstantResult ConstEval::MemberAccess(const sem::Expression* obj_exp
return obj_val->Index(static_cast<size_t>(member->Index())); return obj_val->Index(static_cast<size_t>(member->Index()));
} }
ConstEval::ConstantResult ConstEval::Swizzle(const sem::Type* ty, ConstEval::Result ConstEval::Swizzle(const sem::Type* ty,
const sem::Expression* vec_expr, const sem::Expression* vec_expr,
utils::VectorRef<uint32_t> indices) { utils::VectorRef<uint32_t> indices) {
auto* vec_val = vec_expr->ConstantValue(); auto* vec_val = vec_expr->ConstantValue();
if (!vec_val) { if (!vec_val) {
return nullptr; return nullptr;
} }
if (indices.Length() == 1) { if (indices.Length() == 1) {
return vec_val->Index(static_cast<size_t>(indices[0])); return vec_val->Index(static_cast<size_t>(indices[0]));
} else {
auto values = utils::Transform<4>(
indices, [&](uint32_t i) { return vec_val->Index(static_cast<size_t>(i)); });
return CreateComposite(builder, ty, std::move(values));
} }
auto values = utils::Transform<4>(
indices, [&](uint32_t i) { return vec_val->Index(static_cast<size_t>(i)); });
return CreateComposite(builder, ty, std::move(values));
} }
ConstEval::ConstantResult ConstEval::Bitcast(const sem::Type*, const sem::Expression*) { ConstEval::Result ConstEval::Bitcast(const sem::Type*, const sem::Expression*) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics // TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr; return nullptr;
} }
ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type* ty, ConstEval::Result ConstEval::OpComplement(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c) { auto transform = [&](const sem::Constant* c) {
auto create = [&](auto i) { auto create = [&](auto i) {
return CreateElement(builder, c->Type(), decltype(i)(~i.value)); return CreateElement(builder, c->Type(), decltype(i)(~i.value));
@ -968,9 +970,9 @@ ConstEval::ConstantResult ConstEval::OpComplement(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0]); return TransformElements(builder, ty, transform, args[0]);
} }
ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type* ty, ConstEval::Result ConstEval::OpUnaryMinus(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c) { auto transform = [&](const sem::Constant* c) {
auto create = [&](auto i) { auto create = [&](auto i) {
// For signed integrals, avoid C++ UB by not negating the // For signed integrals, avoid C++ UB by not negating the
@ -993,9 +995,9 @@ ConstEval::ConstantResult ConstEval::OpUnaryMinus(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0]); return TransformElements(builder, ty, transform, args[0]);
} }
ConstEval::ConstantResult ConstEval::OpNot(const sem::Type* ty, ConstEval::Result ConstEval::OpNot(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c) { auto transform = [&](const sem::Constant* c) {
auto create = [&](auto i) { return CreateElement(builder, c->Type(), decltype(i)(!i)); }; auto create = [&](auto i) { return CreateElement(builder, c->Type(), decltype(i)(!i)); };
return Dispatch_bool(create, c); return Dispatch_bool(create, c);
@ -1003,29 +1005,22 @@ ConstEval::ConstantResult ConstEval::OpNot(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0]); return TransformElements(builder, ty, transform, args[0]);
} }
ConstEval::ConstantResult ConstEval::OpPlus(const sem::Type* ty, ConstEval::Result ConstEval::OpPlus(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source) { const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source); TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
if (auto r = Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1)) { return Dispatch_fia_fiu32_f16(AddFunc(c0->Type()), c0, c1);
return r.Get();
}
return nullptr;
}; };
auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]); return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
} }
ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty, ConstEval::Result ConstEval::OpMinus(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* { auto create = [&](auto i, auto j) -> ImplResult {
using NumberT = decltype(i); using NumberT = decltype(i);
NumberT result; NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) { if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
@ -1034,7 +1029,7 @@ ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty,
result = r->value; result = r->value;
} else { } else {
AddError(OverflowErrorMessage(i, "-", j), source); AddError(OverflowErrorMessage(i, "-", j), source);
return nullptr; return utils::Failure;
} }
} else { } else {
using T = UnwrapNumber<NumberT>; using T = UnwrapNumber<NumberT>;
@ -1054,41 +1049,30 @@ ConstEval::ConstantResult ConstEval::OpMinus(const sem::Type* ty,
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
}; };
auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]); return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
} }
ConstEval::ConstantResult ConstEval::OpMultiply(const sem::Type* ty, ConstEval::Result ConstEval::OpMultiply(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source) { const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source); TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) -> const Constant* { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
if (auto r = Dispatch_fia_fiu32_f16(MulFunc(c0->Type()), c0, c1)) { return Dispatch_fia_fiu32_f16(MulFunc(c0->Type()), c0, c1);
return r.Get();
}
return nullptr;
}; };
auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]); return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
} }
ConstEval::ConstantResult ConstEval::OpMultiplyMatVec(const sem::Type* ty, ConstEval::Result ConstEval::OpMultiplyMatVec(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source) { const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source); TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto* mat_ty = args[0]->Type()->As<sem::Matrix>(); auto* mat_ty = args[0]->Type()->As<sem::Matrix>();
auto* vec_ty = args[1]->Type()->As<sem::Vector>(); auto* vec_ty = args[1]->Type()->As<sem::Vector>();
auto* elem_ty = vec_ty->type(); auto* elem_ty = vec_ty->type();
auto dot = [&](const sem::Constant* m, size_t row, const sem::Constant* v) { auto dot = [&](const sem::Constant* m, size_t row, const sem::Constant* v) {
utils::Result<const Constant*> result; ImplResult result;
switch (mat_ty->columns()) { switch (mat_ty->columns()) {
case 2: case 2:
result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), // result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
@ -1130,16 +1114,16 @@ ConstEval::ConstantResult ConstEval::OpMultiplyMatVec(const sem::Type* ty,
} }
return CreateComposite(builder, ty, result); return CreateComposite(builder, ty, result);
} }
ConstEval::ConstantResult ConstEval::OpMultiplyVecMat(const sem::Type* ty, ConstEval::Result ConstEval::OpMultiplyVecMat(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source) { const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source); TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto* vec_ty = args[0]->Type()->As<sem::Vector>(); auto* vec_ty = args[0]->Type()->As<sem::Vector>();
auto* mat_ty = args[1]->Type()->As<sem::Matrix>(); auto* mat_ty = args[1]->Type()->As<sem::Matrix>();
auto* elem_ty = vec_ty->type(); auto* elem_ty = vec_ty->type();
auto dot = [&](const sem::Constant* v, const sem::Constant* m, size_t col) { auto dot = [&](const sem::Constant* v, const sem::Constant* m, size_t col) {
utils::Result<const Constant*> result; ImplResult result;
switch (mat_ty->rows()) { switch (mat_ty->rows()) {
case 2: case 2:
result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), // result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
@ -1182,9 +1166,9 @@ ConstEval::ConstantResult ConstEval::OpMultiplyVecMat(const sem::Type* ty,
return CreateComposite(builder, ty, result); return CreateComposite(builder, ty, result);
} }
ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty, ConstEval::Result ConstEval::OpMultiplyMatMat(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source) { const Source& source) {
TINT_SCOPED_ASSIGNMENT(current_source, &source); TINT_SCOPED_ASSIGNMENT(current_source, &source);
auto* mat1 = args[0]; auto* mat1 = args[0];
auto* mat2 = args[1]; auto* mat2 = args[1];
@ -1196,7 +1180,7 @@ ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty,
auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); }; auto m1e = [&](size_t r, size_t c) { return m1->Index(c)->Index(r); };
auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); }; auto m2e = [&](size_t r, size_t c) { return m2->Index(c)->Index(r); };
utils::Result<const Constant*> result; ImplResult result;
switch (mat1_ty->columns()) { switch (mat1_ty->columns()) {
case 2: case 2:
result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), // result = Dispatch_fa_f32_f16(Dot2Func(elem_ty), //
@ -1247,11 +1231,11 @@ ConstEval::ConstantResult ConstEval::OpMultiplyMatMat(const sem::Type* ty,
return CreateComposite(builder, ty, result_mat); return CreateComposite(builder, ty, result_mat);
} }
ConstEval::ConstantResult ConstEval::OpDivide(const sem::Type* ty, ConstEval::Result ConstEval::OpDivide(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source) { const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* { auto create = [&](auto i, auto j) -> ImplResult {
using NumberT = decltype(i); using NumberT = decltype(i);
NumberT result; NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) { if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) {
@ -1260,7 +1244,7 @@ ConstEval::ConstantResult ConstEval::OpDivide(const sem::Type* ty,
result = r->value; result = r->value;
} else { } else {
AddError(OverflowErrorMessage(i, "/", j), source); AddError(OverflowErrorMessage(i, "/", j), source);
return nullptr; return utils::Failure;
} }
} else { } else {
using T = UnwrapNumber<NumberT>; using T = UnwrapNumber<NumberT>;
@ -1288,120 +1272,92 @@ ConstEval::ConstantResult ConstEval::OpDivide(const sem::Type* ty,
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
}; };
auto r = TransformBinaryElements(builder, ty, transform, args[0], args[1]); return TransformBinaryElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
} }
ConstEval::ConstantResult ConstEval::OpEqual(const sem::Type* ty, ConstEval::Result ConstEval::OpEqual(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), i == j); return CreateElement(builder, sem::Type::DeepestElementOf(ty), i == j);
}; };
return Dispatch_fia_fiu32_f16_bool(create, c0, c1); return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
}; };
auto r = TransformElements(builder, ty, transform, args[0], args[1]); return TransformElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
} }
ConstEval::ConstantResult ConstEval::OpNotEqual(const sem::Type* ty, ConstEval::Result ConstEval::OpNotEqual(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), i != j); return CreateElement(builder, sem::Type::DeepestElementOf(ty), i != j);
}; };
return Dispatch_fia_fiu32_f16_bool(create, c0, c1); return Dispatch_fia_fiu32_f16_bool(create, c0, c1);
}; };
auto r = TransformElements(builder, ty, transform, args[0], args[1]); return TransformElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
} }
ConstEval::ConstantResult ConstEval::OpLessThan(const sem::Type* ty, ConstEval::Result ConstEval::OpLessThan(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), i < j); return CreateElement(builder, sem::Type::DeepestElementOf(ty), i < j);
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
}; };
auto r = TransformElements(builder, ty, transform, args[0], args[1]); return TransformElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
} }
ConstEval::ConstantResult ConstEval::OpGreaterThan(const sem::Type* ty, ConstEval::Result ConstEval::OpGreaterThan(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), i > j); return CreateElement(builder, sem::Type::DeepestElementOf(ty), i > j);
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
}; };
auto r = TransformElements(builder, ty, transform, args[0], args[1]); return TransformElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
} }
ConstEval::ConstantResult ConstEval::OpLessThanEqual(const sem::Type* ty, ConstEval::Result ConstEval::OpLessThanEqual(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), i <= j); return CreateElement(builder, sem::Type::DeepestElementOf(ty), i <= j);
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
}; };
auto r = TransformElements(builder, ty, transform, args[0], args[1]); return TransformElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
} }
ConstEval::ConstantResult ConstEval::OpGreaterThanEqual(const sem::Type* ty, ConstEval::Result ConstEval::OpGreaterThanEqual(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* { auto create = [&](auto i, auto j) -> ImplResult {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), i >= j); return CreateElement(builder, sem::Type::DeepestElementOf(ty), i >= j);
}; };
return Dispatch_fia_fiu32_f16(create, c0, c1); return Dispatch_fia_fiu32_f16(create, c0, c1);
}; };
auto r = TransformElements(builder, ty, transform, args[0], args[1]); return TransformElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
} }
ConstEval::ConstantResult ConstEval::OpAnd(const sem::Type* ty, ConstEval::Result ConstEval::OpAnd(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* { auto create = [&](auto i, auto j) -> ImplResult {
using T = decltype(i); using T = decltype(i);
T result; T result;
if constexpr (std::is_same_v<T, bool>) { if constexpr (std::is_same_v<T, bool>) {
@ -1414,18 +1370,14 @@ ConstEval::ConstantResult ConstEval::OpAnd(const sem::Type* ty,
return Dispatch_ia_iu32_bool(create, c0, c1); return Dispatch_ia_iu32_bool(create, c0, c1);
}; };
auto r = TransformElements(builder, ty, transform, args[0], args[1]); return TransformElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
} }
ConstEval::ConstantResult ConstEval::OpOr(const sem::Type* ty, ConstEval::Result ConstEval::OpOr(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* { auto create = [&](auto i, auto j) -> ImplResult {
using T = decltype(i); using T = decltype(i);
T result; T result;
if constexpr (std::is_same_v<T, bool>) { if constexpr (std::is_same_v<T, bool>) {
@ -1438,18 +1390,14 @@ ConstEval::ConstantResult ConstEval::OpOr(const sem::Type* ty,
return Dispatch_ia_iu32_bool(create, c0, c1); return Dispatch_ia_iu32_bool(create, c0, c1);
}; };
auto r = TransformElements(builder, ty, transform, args[0], args[1]); return TransformElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
} }
ConstEval::ConstantResult ConstEval::OpXor(const sem::Type* ty, ConstEval::Result ConstEval::OpXor(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) -> const Constant* { auto create = [&](auto i, auto j) -> const ImplConstant* {
return CreateElement(builder, sem::Type::DeepestElementOf(ty), decltype(i){i ^ j}); return CreateElement(builder, sem::Type::DeepestElementOf(ty), decltype(i){i ^ j});
}; };
return Dispatch_ia_iu32(create, c0, c1); return Dispatch_ia_iu32(create, c0, c1);
@ -1462,9 +1410,9 @@ ConstEval::ConstantResult ConstEval::OpXor(const sem::Type* ty,
return r; return r;
} }
ConstEval::ConstantResult ConstEval::atan2(const sem::Type* ty, ConstEval::Result ConstEval::atan2(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto i, auto j) { auto create = [&](auto i, auto j) {
return CreateElement(builder, c0->Type(), decltype(i)(std::atan2(i.value, j.value))); return CreateElement(builder, c0->Type(), decltype(i)(std::atan2(i.value, j.value)));
@ -1474,9 +1422,9 @@ ConstEval::ConstantResult ConstEval::atan2(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]); return TransformElements(builder, ty, transform, args[0], args[1]);
} }
ConstEval::ConstantResult ConstEval::clamp(const sem::Type* ty, ConstEval::Result ConstEval::clamp(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1, auto transform = [&](const sem::Constant* c0, const sem::Constant* c1,
const sem::Constant* c2) { const sem::Constant* c2) {
auto create = [&](auto e, auto low, auto high) { auto create = [&](auto e, auto low, auto high) {
@ -1488,17 +1436,13 @@ ConstEval::ConstantResult ConstEval::clamp(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1], args[2]); return TransformElements(builder, ty, transform, args[0], args[1], args[2]);
} }
utils::Result<const sem::Constant*> ConstEval::Convert(const sem::Type* target_ty, ConstEval::Result ConstEval::Convert(const sem::Type* target_ty,
const sem::Constant* value, const sem::Constant* value,
const Source& source) { const Source& source) {
if (value->Type() == target_ty) { if (value->Type() == target_ty) {
return value; return value;
} }
auto conv = static_cast<const Constant*>(value)->Convert(builder, target_ty, source); return static_cast<const ImplConstant*>(value)->Convert(builder, target_ty, source);
if (!conv) {
return utils::Failure;
}
return conv.Get();
} }
void ConstEval::AddError(const std::string& msg, const Source& source) const { void ConstEval::AddError(const std::string& msg, const Source& source) const {

View File

@ -53,12 +53,12 @@ class ConstEval {
/// * `utils::Failure`. Returned when there was a resolver error. In this situation the method /// * `utils::Failure`. Returned when there was a resolver error. In this situation the method
/// will have already reported a diagnostic error message, and the caller should abort /// will have already reported a diagnostic error message, and the caller should abort
/// resolving. /// resolving.
using ConstantResult = utils::Result<const sem::Constant*>; using Result = utils::Result<const sem::Constant*>;
/// Typedef for a constant evaluation function /// Typedef for a constant evaluation function
using Function = ConstantResult (ConstEval::*)(const sem::Type* result_ty, using Function = Result (ConstEval::*)(const sem::Type* result_ty,
utils::VectorRef<const sem::Constant*>, utils::VectorRef<const sem::Constant*>,
const Source&); const Source&);
/// Constructor /// Constructor
/// @param b the program builder /// @param b the program builder
@ -71,44 +71,43 @@ class ConstEval {
/// @param ty the target type - must be an array or constructor /// @param ty the target type - must be an array or constructor
/// @param args the input arguments /// @param args the input arguments
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
ConstantResult ArrayOrStructCtor(const sem::Type* ty, Result ArrayOrStructCtor(const sem::Type* ty, utils::VectorRef<const sem::Expression*> args);
utils::VectorRef<const sem::Expression*> args);
/// @param ty the target type /// @param ty the target type
/// @param expr the input expression /// @param expr the input expression
/// @return the bit-cast of the given expression to the given type, or null if the value cannot /// @return the bit-cast of the given expression to the given type, or null if the value cannot
/// be calculated /// be calculated
ConstantResult Bitcast(const sem::Type* ty, const sem::Expression* expr); Result Bitcast(const sem::Type* ty, const sem::Expression* expr);
/// @param obj the object being indexed /// @param obj the object being indexed
/// @param idx the index expression /// @param idx the index expression
/// @return the result of the index, or null if the value cannot be calculated /// @return the result of the index, or null if the value cannot be calculated
ConstantResult Index(const sem::Expression* obj, const sem::Expression* idx); Result Index(const sem::Expression* obj, const sem::Expression* idx);
/// @param ty the result type /// @param ty the result type
/// @param lit the literal AST node /// @param lit the literal AST node
/// @return the constant value of the literal /// @return the constant value of the literal
ConstantResult Literal(const sem::Type* ty, const ast::LiteralExpression* lit); Result Literal(const sem::Type* ty, const ast::LiteralExpression* lit);
/// @param obj the object being accessed /// @param obj the object being accessed
/// @param member the member /// @param member the member
/// @return the result of the member access, or null if the value cannot be calculated /// @return the result of the member access, or null if the value cannot be calculated
ConstantResult MemberAccess(const sem::Expression* obj, const sem::StructMember* member); Result MemberAccess(const sem::Expression* obj, const sem::StructMember* member);
/// @param ty the result type /// @param ty the result type
/// @param vector the vector being swizzled /// @param vector the vector being swizzled
/// @param indices the swizzle indices /// @param indices the swizzle indices
/// @return the result of the swizzle, or null if the value cannot be calculated /// @return the result of the swizzle, or null if the value cannot be calculated
ConstantResult Swizzle(const sem::Type* ty, Result Swizzle(const sem::Type* ty,
const sem::Expression* vector, const sem::Expression* vector,
utils::VectorRef<uint32_t> indices); utils::VectorRef<uint32_t> indices);
/// Convert the `value` to `target_type` /// Convert the `value` to `target_type`
/// @param ty the result type /// @param ty the result type
/// @param value the value being converted /// @param value the value being converted
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the converted value, or null if the value cannot be calculated /// @return the converted value, or null if the value cannot be calculated
ConstantResult Convert(const sem::Type* ty, const sem::Constant* value, const Source& source); Result Convert(const sem::Type* ty, const sem::Constant* value, const Source& source);
//////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////
// Constant value evaluation methods, to be indirectly called via the intrinsic table // Constant value evaluation methods, to be indirectly called via the intrinsic table
@ -119,72 +118,72 @@ class ConstEval {
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the converted value, or null if the value cannot be calculated /// @return the converted value, or null if the value cannot be calculated
ConstantResult Conv(const sem::Type* ty, Result Conv(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Zero value type constructor /// Zero value type constructor
/// @param ty the result type /// @param ty the result type
/// @param args the input arguments (no arguments provided) /// @param args the input arguments (no arguments provided)
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
ConstantResult Zero(const sem::Type* ty, Result Zero(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Identity value type constructor /// Identity value type constructor
/// @param ty the result type /// @param ty the result type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
ConstantResult Identity(const sem::Type* ty, Result Identity(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Vector splat constructor /// Vector splat constructor
/// @param ty the vector type /// @param ty the vector type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
ConstantResult VecSplat(const sem::Type* ty, Result VecSplat(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Vector constructor using scalars /// Vector constructor using scalars
/// @param ty the vector type /// @param ty the vector type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
ConstantResult VecCtorS(const sem::Type* ty, Result VecCtorS(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Vector constructor using a mix of scalars and smaller vectors /// Vector constructor using a mix of scalars and smaller vectors
/// @param ty the vector type /// @param ty the vector type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
ConstantResult VecCtorM(const sem::Type* ty, Result VecCtorM(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Matrix constructor using scalar values /// Matrix constructor using scalar values
/// @param ty the matrix type /// @param ty the matrix type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
ConstantResult MatCtorS(const sem::Type* ty, Result MatCtorS(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Matrix constructor using column vectors /// Matrix constructor using column vectors
/// @param ty the matrix type /// @param ty the matrix type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the constructed value, or null if the value cannot be calculated /// @return the constructed value, or null if the value cannot be calculated
ConstantResult MatCtorV(const sem::Type* ty, Result MatCtorV(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Unary Operators // Unary Operators
@ -195,27 +194,27 @@ class ConstEval {
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpComplement(const sem::Type* ty, Result OpComplement(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Unary minus operator '-' /// Unary minus operator '-'
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpUnaryMinus(const sem::Type* ty, Result OpUnaryMinus(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Unary not operator '!' /// Unary not operator '!'
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpNot(const sem::Type* ty, Result OpNot(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Binary Operators // Binary Operators
@ -226,142 +225,142 @@ class ConstEval {
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpPlus(const sem::Type* ty, Result OpPlus(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Minus operator '-' /// Minus operator '-'
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpMinus(const sem::Type* ty, Result OpMinus(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Multiply operator '*' for the same type on the LHS and RHS /// Multiply operator '*' for the same type on the LHS and RHS
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpMultiply(const sem::Type* ty, Result OpMultiply(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Multiply operator '*' for matCxR<T> * vecC<T> /// Multiply operator '*' for matCxR<T> * vecC<T>
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpMultiplyMatVec(const sem::Type* ty, Result OpMultiplyMatVec(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Multiply operator '*' for vecR<T> * matCxR<T> /// Multiply operator '*' for vecR<T> * matCxR<T>
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpMultiplyVecMat(const sem::Type* ty, Result OpMultiplyVecMat(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Multiply operator '*' for matKxR<T> * matCxK<T> /// Multiply operator '*' for matKxR<T> * matCxK<T>
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpMultiplyMatMat(const sem::Type* ty, Result OpMultiplyMatMat(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Divide operator '/' /// Divide operator '/'
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpDivide(const sem::Type* ty, Result OpDivide(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Equality operator '==' /// Equality operator '=='
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpEqual(const sem::Type* ty, Result OpEqual(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Inequality operator '!=' /// Inequality operator '!='
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpNotEqual(const sem::Type* ty, Result OpNotEqual(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Less than operator '<' /// Less than operator '<'
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpLessThan(const sem::Type* ty, Result OpLessThan(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Greater than operator '>' /// Greater than operator '>'
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpGreaterThan(const sem::Type* ty, Result OpGreaterThan(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Less than or equal operator '<=' /// Less than or equal operator '<='
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpLessThanEqual(const sem::Type* ty, Result OpLessThanEqual(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Greater than or equal operator '>=' /// Greater than or equal operator '>='
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpGreaterThanEqual(const sem::Type* ty, Result OpGreaterThanEqual(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Bitwise and operator '&' /// Bitwise and operator '&'
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpAnd(const sem::Type* ty, Result OpAnd(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Bitwise or operator '|' /// Bitwise or operator '|'
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpOr(const sem::Type* ty, Result OpOr(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Bitwise xor operator '^' /// Bitwise xor operator '^'
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult OpXor(const sem::Type* ty, Result OpXor(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
@ -374,18 +373,18 @@ class ConstEval {
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult atan2(const sem::Type* ty, Result atan2(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// clamp builtin /// clamp builtin
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
ConstantResult clamp(const sem::Type* ty, Result clamp(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
private: private:
/// Adds the given error message to the diagnostics /// Adds the given error message to the diagnostics

View File

@ -17,6 +17,7 @@
#include <ostream> #include <ostream>
#include <variant> #include <variant>
#include "src/tint/debug.h" #include "src/tint/debug.h"
namespace tint::utils { namespace tint::utils {
@ -50,6 +51,20 @@ struct [[nodiscard]] Result {
Result(const FAILURE_TYPE& failure) // NOLINT(runtime/explicit): Result(const FAILURE_TYPE& failure) // NOLINT(runtime/explicit):
: value{failure} {} : value{failure} {}
/// Copy constructor with success / failure casting
/// @param other the Result to copy
template <typename S,
typename F,
typename = std::void_t<decltype(SUCCESS_TYPE{std::declval<S>()}),
decltype(FAILURE_TYPE{std::declval<F>()})>>
Result(const Result<S, F>& other) { // NOLINT(runtime/explicit):
if (other) {
value = SUCCESS_TYPE{other.Get()};
} else {
value = FAILURE_TYPE{other.Failure()};
}
}
/// @returns true if the result was a success /// @returns true if the result was a success
operator bool() const { operator bool() const {
Validate(); Validate();

View File

@ -51,5 +51,17 @@ TEST(ResultTest, CustomFailure) {
EXPECT_EQ(r.Failure(), "oh noes!"); EXPECT_EQ(r.Failure(), "oh noes!");
} }
TEST(ResultTest, ValueCast) {
struct X {};
struct Y : X {};
Y* y = nullptr;
auto r_y = Result<Y*>{y};
auto r_x = Result<X*>{r_y};
(void)r_x;
(void)r_y;
}
} // namespace } // namespace
} // namespace tint::utils } // namespace tint::utils