tint/resolver: Ensure materialized values are representable

by the materialized type.

Bug: tint:1504
Change-Id: I3534ce62308ba2ff32c52a2f5bc8480d102153a1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91422
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-05-25 21:16:55 +00:00 committed by Dawn LUCI CQ
parent a8d5228049
commit e34e059804
4 changed files with 360 additions and 149 deletions

View File

@ -134,35 +134,65 @@ static std::ostream& operator<<(std::ostream& o, Method m) {
struct Data {
std::string target_type_name;
std::string target_element_type_name;
builder::ast_type_func_ptr target_ast_ty;
builder::sem_type_func_ptr target_sem_ty;
builder::ast_expr_func_ptr target_expr;
std::string literal_type_name;
builder::ast_expr_func_ptr literal_value;
std::string source_type_name;
builder::ast_expr_func_ptr source_builder;
std::variant<AInt, AFloat> materialized_value;
double literal_value;
};
template <typename TARGET_TYPE, typename LITERAL_TYPE, typename MATERIALIZED_TYPE = AInt>
Data Types(MATERIALIZED_TYPE materialized_value = 0_a) {
template <typename TARGET_TYPE, typename SOURCE_TYPE, typename MATERIALIZED_TYPE>
Data Types(MATERIALIZED_TYPE materialized_value, double literal_value) {
using TargetDataType = builder::DataType<TARGET_TYPE>;
using SourceDataType = builder::DataType<SOURCE_TYPE>;
using TargetElementDataType = builder::DataType<typename TargetDataType::ElementType>;
return {
builder::DataType<TARGET_TYPE>::Name(), //
builder::DataType<TARGET_TYPE>::AST, //
builder::DataType<TARGET_TYPE>::Sem, //
builder::DataType<TARGET_TYPE>::Expr, //
builder::DataType<LITERAL_TYPE>::Name(), //
builder::DataType<LITERAL_TYPE>::Expr, //
TargetDataType::Name(), // target_type_name
TargetElementDataType::Name(), // target_element_type_name
TargetDataType::AST, // target_ast_ty
TargetDataType::Sem, // target_sem_ty
TargetDataType::Expr, // target_expr
SourceDataType::Name(), // literal_type_name
SourceDataType::Expr, // literal_builder
materialized_value,
literal_value,
};
}
template <typename TARGET_TYPE, typename SOURCE_TYPE>
Data Types() {
using TargetDataType = builder::DataType<TARGET_TYPE>;
using SourceDataType = builder::DataType<SOURCE_TYPE>;
using TargetElementDataType = builder::DataType<typename TargetDataType::ElementType>;
return {
TargetDataType::Name(), // target_type_name
TargetElementDataType::Name(), // target_element_type_name
TargetDataType::AST, // target_ast_ty
TargetDataType::Sem, // target_sem_ty
TargetDataType::Expr, // target_expr
SourceDataType::Name(), // literal_type_name
SourceDataType::Expr, // literal_builder
0_a,
0.0,
};
}
static std::ostream& operator<<(std::ostream& o, const Data& c) {
return o << "[" << c.target_type_name << " <- " << c.literal_type_name << "]";
auto print_value = [&](auto&& v) { o << v; };
o << "[" << c.target_type_name << " <- " << c.source_type_name << "] [";
std::visit(print_value, c.materialized_value);
o << " <- " << c.literal_value << "]";
return o;
}
enum class Expectation {
kMaterialize,
kNoMaterialize,
kInvalidCast,
kValueCannotBeRepresented,
};
static std::ostream& operator<<(std::ostream& o, Expectation m) {
@ -173,6 +203,8 @@ static std::ostream& operator<<(std::ostream& o, Expectation m) {
return o << "no-materialize";
case Expectation::kInvalidCast:
return o << "invalid-cast";
case Expectation::kValueCannotBeRepresented:
return o << "value too low or high";
}
return o << "<unknown>";
}
@ -191,7 +223,7 @@ TEST_P(MaterializeAbstractNumeric, Test) {
auto target_ty = [&] { return data.target_ast_ty(*this); };
auto target_expr = [&] { return data.target_expr(*this, 42); };
auto* literal = data.literal_value(*this, 1);
auto* literal = data.source_builder(*this, data.literal_value);
switch (method) {
case Method::kVar:
WrapInFunction(Decl(Var("a", target_ty(), literal)));
@ -283,82 +315,132 @@ TEST_P(MaterializeAbstractNumeric, Test) {
switch (method) {
case Method::kBuiltinArg:
expect = "error: no matching call to min(" + data.target_type_name + ", " +
data.literal_type_name + ")";
data.source_type_name + ")";
break;
case Method::kBinaryOp:
expect = "error: no matching overload for operator + (" +
data.target_type_name + ", " + data.literal_type_name + ")";
data.target_type_name + ", " + data.source_type_name + ")";
break;
default:
expect = "error: cannot convert value of type '" + data.literal_type_name +
expect = "error: cannot convert value of type '" + data.source_type_name +
"' to type '" + data.target_type_name + "'";
break;
}
EXPECT_THAT(r()->error(), testing::StartsWith(expect));
break;
}
case Expectation::kValueCannotBeRepresented:
ASSERT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), testing::HasSubstr("cannot be represented as '" +
data.target_element_type_name + "'"));
break;
}
}
// TODO(crbug.com/tint/1504): Test for abstract-numeric values not fitting in materialized types.
INSTANTIATE_TEST_SUITE_P(MaterializeScalar,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kMaterialize), //
testing::Values(Method::kLet, //
/// Methods that support scalar materialization
constexpr Method kScalarMethods[] = {Method::kLet, //
Method::kVar, //
Method::kFnArg, //
Method::kBuiltinArg, //
Method::kReturn, //
Method::kArray, //
Method::kStruct, //
Method::kBinaryOp), //
testing::Values(Types<i32, AInt>(1_a), //
Types<u32, AInt>(1_a), //
Types<f32, AFloat>(1.0_a) //
Method::kBinaryOp};
/// Methods that support vector materialization
constexpr Method kVectorMethods[] = {Method::kLet, //
Method::kVar, //
Method::kFnArg, //
Method::kBuiltinArg, //
Method::kReturn, //
Method::kArray, //
Method::kStruct, //
Method::kBinaryOp};
/// Methods that support matrix materialization
constexpr Method kMatrixMethods[] = {Method::kLet, //
Method::kVar, //
Method::kFnArg, //
Method::kReturn, //
Method::kArray, //
Method::kStruct, //
Method::kBinaryOp};
/// Methods that support materialization for switch cases
constexpr Method kSwitchMethods[] = {Method::kSwitchCond, //
Method::kSwitchCase, //
Method::kSwitchCondWithAbstractCase, //
Method::kSwitchCaseWithAbstractCase};
constexpr double kMaxF32 = static_cast<double>(f32::kHighest);
constexpr double kPiF64 = 3.141592653589793;
constexpr double kPiF32 = 3.1415927410125732; // kPiF64 quantized to f32
// (2^-127)×(1+(0xfffffffffffff÷0x10000000000000))
constexpr double kTooSmallF32 = 1.1754943508222874e-38;
INSTANTIATE_TEST_SUITE_P(
MaterializeScalar,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kMaterialize), //
testing::ValuesIn(kScalarMethods), //
testing::Values(Types<i32, AInt>(0_a, 0.0), //
Types<i32, AInt>(2147483647_a, 2147483647.0), //
Types<i32, AInt>(-2147483648_a, -2147483648.0), //
Types<u32, AInt>(0_a, 0.0), //
Types<u32, AInt>(4294967295_a, 4294967295.0), //
Types<f32, AFloat>(0.0_a, 0.0), //
Types<f32, AFloat>(AFloat(kMaxF32), kMaxF32), //
Types<f32, AFloat>(AFloat(-kMaxF32), -kMaxF32), //
Types<f32, AFloat>(AFloat(kPiF32), kPiF64), //
Types<f32, AFloat>(0.0_a, kTooSmallF32), //
Types<f32, AFloat>(-0.0_a, -kTooSmallF32) //
/* Types<f16, AFloat>(1.0_a), */ //
/* Types<f16, AFloat>(1.0_a), */)));
INSTANTIATE_TEST_SUITE_P(MaterializeVector,
INSTANTIATE_TEST_SUITE_P(
MaterializeVector,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kMaterialize), //
testing::Values(Method::kLet, //
Method::kVar, //
Method::kFnArg, //
Method::kBuiltinArg, //
Method::kReturn, //
Method::kArray, //
Method::kStruct, //
Method::kBinaryOp), //
testing::Values(Types<i32V, AIntV>(1_a), //
Types<u32V, AIntV>(1_a), //
Types<f32V, AFloatV>(1.0_a) //
testing::ValuesIn(kVectorMethods), //
testing::Values(Types<i32V, AIntV>(0_a, 0.0), //
Types<i32V, AIntV>(2147483647_a, 2147483647.0), //
Types<i32V, AIntV>(-2147483648_a, -2147483648.0), //
Types<u32V, AIntV>(0_a, 0.0), //
Types<u32V, AIntV>(4294967295_a, 4294967295.0), //
Types<f32V, AFloatV>(0.0_a, 0.0), //
Types<f32V, AFloatV>(AFloat(kMaxF32), kMaxF32), //
Types<f32V, AFloatV>(AFloat(-kMaxF32), -kMaxF32), //
Types<f32V, AFloatV>(AFloat(kPiF32), kPiF64), //
Types<f32V, AFloatV>(0.0_a, kTooSmallF32), //
Types<f32V, AFloatV>(-0.0_a, -kTooSmallF32) //
/* Types<f16V, AFloatV>(1.0_a), */ //
/* Types<f16V, AFloatV>(1.0_a), */)));
INSTANTIATE_TEST_SUITE_P(MaterializeMatrix,
INSTANTIATE_TEST_SUITE_P(
MaterializeMatrix,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kMaterialize), //
testing::Values(Method::kLet, //
Method::kVar, //
Method::kFnArg, //
Method::kReturn, //
Method::kArray, //
Method::kStruct, //
Method::kBinaryOp), //
testing::Values(Types<f32M, AFloatM>(1.0_a) //
testing::ValuesIn(kMatrixMethods), //
testing::Values(Types<f32M, AFloatM>(0.0_a, 0.0), //
Types<f32M, AFloatM>(AFloat(kMaxF32), kMaxF32), //
Types<f32M, AFloatM>(AFloat(-kMaxF32), -kMaxF32), //
Types<f32M, AFloatM>(AFloat(kPiF32), kPiF64), //
Types<f32M, AFloatM>(0.0_a, kTooSmallF32), //
Types<f32M, AFloatM>(-0.0_a, -kTooSmallF32) //
/* Types<f16V, AFloatM>(1.0_a), */ //
)));
INSTANTIATE_TEST_SUITE_P(MaterializeSwitch,
INSTANTIATE_TEST_SUITE_P(
MaterializeSwitch,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kMaterialize), //
testing::Values(Method::kSwitchCond, //
Method::kSwitchCase, //
Method::kSwitchCondWithAbstractCase, //
Method::kSwitchCaseWithAbstractCase), //
testing::Values(Types<i32, AInt>(1_a), //
Types<u32, AInt>(1_a))));
testing::ValuesIn(kSwitchMethods), //
testing::Values(Types<i32, AInt>(0_a, 0.0), //
Types<i32, AInt>(2147483647_a, 2147483647.0), //
Types<i32, AInt>(-2147483648_a, -2147483648.0), //
Types<u32, AInt>(0_a, 0.0), //
Types<u32, AInt>(4294967295_a, 4294967295.0))));
// TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary ops.
INSTANTIATE_TEST_SUITE_P(DISABLED_NoMaterialize,
@ -366,27 +448,58 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_NoMaterialize,
testing::Combine(testing::Values(Expectation::kNoMaterialize), //
testing::Values(Method::kBuiltinArg, //
Method::kBinaryOp), //
testing::Values(Types<AInt, AInt>(1_a), //
Types<AFloat, AFloat>(1.0_a), //
Types<AIntV, AIntV>(1_a), //
Types<AFloatV, AFloatV>(1.0_a), //
Types<AFloatM, AFloatM>(1.0_a))));
testing::Values(Types<AInt, AInt>(), //
Types<AFloat, AFloat>(), //
Types<AIntV, AIntV>(), //
Types<AFloatV, AFloatV>(), //
Types<AFloatM, AFloatM>())));
INSTANTIATE_TEST_SUITE_P(InvalidCast,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kInvalidCast), //
testing::Values(Method::kLet, //
Method::kVar, //
Method::kFnArg, //
Method::kBuiltinArg, //
Method::kReturn, //
Method::kArray, //
Method::kStruct, //
Method::kBinaryOp), //
testing::ValuesIn(kScalarMethods), //
testing::Values(Types<i32, AFloat>(), //
Types<u32, AFloat>(), //
Types<i32V, AFloatV>(), //
Types<u32V, AFloatV>())));
INSTANTIATE_TEST_SUITE_P(
ScalarValueCannotBeRepresented,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), //
testing::ValuesIn(kScalarMethods), //
testing::Values(Types<i32, AInt>(0_a, 2147483648.0), //
Types<i32, AInt>(0_a, -2147483649.0), //
Types<u32, AInt>(0_a, 4294967296), //
Types<u32, AInt>(0_a, -1.0), //
Types<f32, AFloat>(0.0_a, 3.5e+38), //
Types<f32, AFloat>(0.0_a, -3.5e+38) //
/* Types<f16, AFloat>(), */ //
/* Types<f16, AFloat>(), */)));
INSTANTIATE_TEST_SUITE_P(
VectorValueCannotBeRepresented,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), //
testing::ValuesIn(kVectorMethods), //
testing::Values(Types<i32V, AIntV>(0_a, 2147483648.0), //
Types<i32V, AIntV>(0_a, -2147483649.0), //
Types<u32V, AIntV>(0_a, 4294967296), //
Types<u32V, AIntV>(0_a, -1.0), //
Types<f32V, AFloatV>(0.0_a, 3.5e+38), //
Types<f32V, AFloatV>(0.0_a, -3.5e+38) //
/* Types<f16V, AFloatV>(), */ //
/* Types<f16V, AFloatV>(), */)));
INSTANTIATE_TEST_SUITE_P(
MatrixValueCannotBeRepresented,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), //
testing::ValuesIn(kMatrixMethods), //
testing::Values(Types<f32M, AFloatM>(0.0_a, 3.5e+38), //
Types<f32M, AFloatM>(0.0_a, -3.5e+38) //
/* Types<f16M, AFloatM>(), */ //
/* Types<f16M, AFloatM>(), */)));
} // namespace MaterializeTests
} // namespace

View File

@ -1110,19 +1110,27 @@ const sem::Expression* Resolver::Materialize(const sem::Expression* expr,
// Helper for actually creating the the materialize node, performing the constant cast, updating
// the ast -> sem binding, and performing validation.
auto materialize = [&](const sem::Type* target_ty) -> sem::Materialize* {
auto expr_val = EvaluateConstantValue(expr->Declaration(), expr->Type());
if (!expr_val.IsValid()) {
auto* decl = expr->Declaration();
auto expr_val = EvaluateConstantValue(decl, expr->Type());
if (!expr_val) {
return nullptr;
}
if (!expr_val->IsValid()) {
TINT_ICE(Resolver, builder_->Diagnostics())
<< expr->Declaration()->source
<< decl->source
<< " EvaluateConstantValue() returned invalid value for materialized "
"value of type: "
<< (expr->Type() ? expr->Type()->FriendlyName(builder_->Symbols()) : "<null>");
return nullptr;
}
auto materialized_val = ConvertValue(expr_val, target_ty);
auto* m = builder_->create<sem::Materialize>(expr, current_statement_, materialized_val);
auto materialized_val = ConvertValue(expr_val.Get(), target_ty, decl->source);
if (!materialized_val) {
return nullptr;
}
auto* m =
builder_->create<sem::Materialize>(expr, current_statement_, materialized_val.Get());
m->Behaviors() = expr->Behaviors();
builder_->Sem().Replace(expr->Declaration(), m);
builder_->Sem().Replace(decl, m);
return validator_.Materialize(m) ? m : nullptr;
};
@ -1215,8 +1223,11 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp
}
auto val = EvaluateConstantValue(expr, ty);
if (!val) {
return nullptr;
}
bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects();
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, val,
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, val.Get(),
has_side_effects, obj->SourceVariable());
sem->Behaviors() = idx->Behaviors() + obj->Behaviors();
return sem;
@ -1230,7 +1241,10 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
}
auto val = EvaluateConstantValue(expr, ty);
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, val,
if (!val) {
return nullptr;
}
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, val.Get(),
inner->HasSideEffects());
sem->Behaviors() = inner->Behaviors();
@ -1277,9 +1291,12 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) {
return nullptr;
}
auto value = EvaluateConstantValue(expr, call_target->ReturnType());
auto val = EvaluateConstantValue(expr, call_target->ReturnType());
if (!val) {
return nullptr;
}
return builder_->create<sem::Call>(expr, call_target, std::move(args), current_statement_,
value, has_side_effects);
val.Get(), has_side_effects);
};
// ct_ctor_or_conv is a helper for building either a sem::TypeConstructor or sem::TypeConversion
@ -1315,9 +1332,12 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) {
return nullptr;
}
auto value = EvaluateConstantValue(expr, call_target->ReturnType());
auto val = EvaluateConstantValue(expr, call_target->ReturnType());
if (!val) {
return nullptr;
}
return builder_->create<sem::Call>(expr, call_target, std::move(args),
current_statement_, value, has_side_effects);
current_statement_, val.Get(), has_side_effects);
},
[&](const sem::Struct* str) -> sem::Call* {
auto* call_target = utils::GetOrCreate(
@ -1337,9 +1357,12 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) {
return nullptr;
}
auto value = EvaluateConstantValue(expr, call_target->ReturnType());
auto val = EvaluateConstantValue(expr, call_target->ReturnType());
if (!val) {
return nullptr;
}
return builder_->create<sem::Call>(expr, call_target, std::move(args),
current_statement_, value, has_side_effects);
current_statement_, val.Get(), has_side_effects);
},
[&](Default) {
AddError("type is not constructible", expr->source);
@ -1616,7 +1639,10 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
}
auto val = EvaluateConstantValue(literal, ty);
return builder_->create<sem::Expression>(literal, ty, current_statement_, val,
if (!val) {
return nullptr;
}
return builder_->create<sem::Expression>(literal, ty, current_statement_, val.Get(),
/* has_side_effects */ false);
}
@ -1828,8 +1854,11 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
}
auto val = EvaluateConstantValue(expr, op.result);
if (!val) {
return nullptr;
}
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, val,
auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, val.Get(),
has_side_effects);
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
@ -1902,7 +1931,10 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
}
auto val = EvaluateConstantValue(unary, ty);
auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, val,
if (!val) {
return nullptr;
}
auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, val.Get(),
expr->HasSideEffects(), source_var);
sem->Behaviors() = expr->Behaviors();
return sem;

View File

@ -34,6 +34,7 @@
#include "src/tint/sem/constant.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/struct.h"
#include "src/tint/utils/result.h"
#include "src/tint/utils/unique_vector.h"
// Forward declarations
@ -354,15 +355,19 @@ class Resolver {
//////////////////////////////////////////////////////////////////////////////
/// Constant value evaluation methods
//////////////////////////////////////////////////////////////////////////////
/// The result type of a ConstantEvaluation method. Holds the constant value and a boolean,
/// which is true on success, false on an error.
using ConstantResult = utils::Result<sem::Constant>;
/// Convert the `value` to `target_type`
/// @return the converted value
sem::Constant ConvertValue(const sem::Constant& value, const sem::Type* target_type);
sem::Constant EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type);
sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal,
ConstantResult ConvertValue(const sem::Constant& value,
const sem::Type* target_type,
const Source& source);
ConstantResult EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type);
ConstantResult EvaluateConstantValue(const ast::LiteralExpression* literal,
const sem::Type* type);
sem::Constant EvaluateConstantValue(const ast::CallExpression* call, const sem::Type* type);
ConstantResult EvaluateConstantValue(const ast::CallExpression* call, const sem::Type* type);
/// @returns true if the symbol is the name of a builtin function.
bool IsBuiltin(Symbol) const;

View File

@ -14,7 +14,9 @@
#include "src/tint/resolver/resolver.h"
#include <optional>
#include <cmath>
// TODO(https://crbug.com/dawn/1379) Update cpplint and remove NOLINT
#include <optional> // NOLINT(build/include_order))
#include "src/tint/sem/abstract_float.h"
#include "src/tint/sem/abstract_int.h"
@ -30,46 +32,53 @@ namespace tint::resolver {
namespace {
/// Converts all the element values of `in` to the type `T`.
/// Converts and returns all the element values of `in` to the type `T`, using the converter
/// function `CONVERTER`.
/// @param elements_in the vector of elements to be converted
/// @param converter a function-like with the signature `void(TO&, FROM)`
/// @returns the elements converted to type T.
template <typename T, typename ELEMENTS_IN>
sem::Constant::Elements Convert(const ELEMENTS_IN& elements_in) {
template <typename T, typename ELEMENTS_IN, typename CONVERTER>
sem::Constant::Elements Transform(const ELEMENTS_IN& elements_in, CONVERTER&& converter) {
TINT_BEGIN_DISABLE_WARNING_UNREACHABLE_CODE();
using E = UnwrapNumber<T>;
return utils::Transform(elements_in, [&](auto value_in) {
if constexpr (std::is_same_v<E, bool>) {
if constexpr (std::is_same_v<UnwrapNumber<T>, bool>) {
return AInt(value_in != 0);
}
E converted = static_cast<E>(value_in);
if constexpr (IsFloatingPoint<E>) {
} else {
T converted{};
converter(converted, value_in);
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
return AFloat(converted);
} else {
return AInt(converted);
}
}
});
TINT_END_DISABLE_WARNING_UNREACHABLE_CODE();
}
/// Converts and returns all the element values of `in` to the semantic type `el_ty`.
/// Converts and returns all the element values of `in` to the semantic type `el_ty`, using the
/// converter function `CONVERTER`.
/// @param in the constant to convert
/// @param el_ty the target element type
/// @returns the elements converted to `type`
sem::Constant::Elements Convert(const sem::Constant::Elements& in, const sem::Type* el_ty) {
/// @param converter a function-like with the signature `void(TO&, FROM)`
/// @returns the elements converted to `el_ty`
template <typename CONVERTER>
sem::Constant::Elements Transform(const sem::Constant::Elements& in,
const sem::Type* el_ty,
CONVERTER&& converter) {
return std::visit(
[&](auto&& v) {
return Switch(
el_ty, //
[&](const sem::AbstractInt*) { return Convert<AInt>(v); },
[&](const sem::AbstractFloat*) { return Convert<AFloat>(v); },
[&](const sem::I32*) { return Convert<i32>(v); },
[&](const sem::U32*) { return Convert<u32>(v); },
[&](const sem::F32*) { return Convert<f32>(v); },
[&](const sem::F16*) { return Convert<f16>(v); },
[&](const sem::Bool*) { return Convert<bool>(v); },
[&](const sem::AbstractInt*) { return Transform<AInt>(v, converter); },
[&](const sem::AbstractFloat*) { return Transform<AFloat>(v, converter); },
[&](const sem::I32*) { return Transform<i32>(v, converter); },
[&](const sem::U32*) { return Transform<u32>(v, converter); },
[&](const sem::F32*) { return Transform<f32>(v, converter); },
[&](const sem::F16*) { return Transform<f16>(v, converter); },
[&](const sem::Bool*) { return Transform<bool>(v, converter); },
[&](Default) -> sem::Constant::Elements {
diag::List diags;
TINT_UNREACHABLE(Semantic, diags)
@ -80,44 +89,91 @@ sem::Constant::Elements Convert(const sem::Constant::Elements& in, const sem::Ty
in);
}
/// Converts and returns all the elements in `in` to the type `el_ty`, by performing a `static_cast`
/// on each element value. No checks will be performed that the value fits in the target type.
/// @param in the input elements
/// @param el_ty the target element type
/// @returns the elements converted to `el_ty`
sem::Constant::Elements ConvertElements(const sem::Constant::Elements& in, const sem::Type* el_ty) {
return Transform(in, el_ty, [](auto& el_out, auto el_in) {
el_out = std::decay_t<decltype(el_out)>(el_in);
});
}
/// Converts and returns all the elements in `in` to the type `el_ty`, by performing a
/// `CheckedConvert` on each element value. A single error diagnostic will be raised if an element
/// value cannot be represented by the target type.
/// @param in the input elements
/// @param el_ty the target element type
/// @returns the elements converted to `el_ty`, or a Failure if some elements could not be
/// represented by the target type.
utils::Result<sem::Constant::Elements> MaterializeElements(const sem::Constant::Elements& in,
const sem::Type* el_ty,
ProgramBuilder& builder,
Source source) {
std::optional<std::string> failure;
auto out = Transform(in, el_ty, [&](auto& el_out, auto el_in) {
using OUT = std::decay_t<decltype(el_out)>;
if (auto conv = CheckedConvert<OUT>(el_in)) {
el_out = conv.Get();
} else if (conv.Failure() == ConversionFailure::kTooSmall) {
el_out = OUT(el_in < 0 ? -0.0 : 0.0);
} else if (!failure.has_value()) {
std::stringstream ss;
ss << "value " << el_in << " cannot be represented as ";
ss << "'" << builder.FriendlyName(el_ty) << "'";
failure = ss.str();
}
});
if (failure.has_value()) {
builder.Diagnostics().add_error(diag::System::Resolver, std::move(failure.value()), source);
return utils::Failure;
}
return out;
}
} // namespace
sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) {
utils::Result<sem::Constant> Resolver::EvaluateConstantValue(const ast::Expression* expr,
const sem::Type* type) {
if (auto* e = expr->As<ast::LiteralExpression>()) {
return EvaluateConstantValue(e, type);
}
if (auto* e = expr->As<ast::CallExpression>()) {
return EvaluateConstantValue(e, type);
}
return {};
return sem::Constant{};
}
sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal,
utils::Result<sem::Constant> Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal,
const sem::Type* type) {
return Switch(
literal,
[&](const ast::BoolLiteralExpression* lit) {
return sem::Constant{type, {AInt(lit->value ? 1 : 0)}};
},
[&](const ast::IntLiteralExpression* lit) {
return sem::Constant{type, {AInt(lit->value)}};
},
[&](const ast::FloatLiteralExpression* lit) {
return sem::Constant{type, {AFloat(lit->value)}};
},
[&](const ast::BoolLiteralExpression* lit) {
return sem::Constant{type, {AInt(lit->value ? 1 : 0)}};
});
}
sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
utils::Result<sem::Constant> Resolver::EvaluateConstantValue(const ast::CallExpression* call,
const sem::Type* ty) {
uint32_t result_size = 0;
auto* el_ty = sem::Type::ElementOf(ty, &result_size);
if (!el_ty) {
return {};
return sem::Constant{};
}
// ElementOf() will also return the element type of array, which we do not support.
if (ty->Is<sem::Array>()) {
return {};
return sem::Constant{};
}
// For zero value init, return 0s
@ -142,15 +198,15 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
for (auto* expr : call->args) {
auto* arg = builder_->Sem().Get(expr);
if (!arg) {
return {};
return sem::Constant{};
}
auto value = arg->ConstantValue();
if (!value) {
return {};
return sem::Constant{};
}
// Convert the elements to the desired type.
auto converted = Convert(value.GetElements(), el_ty);
auto converted = ConvertElements(value.GetElements(), el_ty);
if (elements.has_value()) {
// Append the converted vector to elements
@ -180,20 +236,25 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
return sem::Constant(ty, std::move(elements.value()));
}
sem::Constant Resolver::ConvertValue(const sem::Constant& value, const sem::Type* ty) {
utils::Result<sem::Constant> Resolver::ConvertValue(const sem::Constant& value,
const sem::Type* ty,
const Source& source) {
if (value.Type() == ty) {
return value;
}
auto* el_ty = sem::Type::ElementOf(ty);
if (el_ty == nullptr) {
return {};
return sem::Constant{};
}
if (value.ElementType() == el_ty) {
return sem::Constant(ty, value.GetElements());
}
return sem::Constant(ty, Convert(value.GetElements(), el_ty));
if (auto res = MaterializeElements(value.GetElements(), el_ty, *builder_, source)) {
return sem::Constant(ty, std::move(res.Get()));
}
return utils::Failure;
}
} // namespace tint::resolver