tint/resolver: Ensure materialized values are representable

by the materialized type.

Bug: tint:1504
Change-Id: I3534ce62308ba2ff32c52a2f5bc8480d102153a1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91422
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-05-25 21:16:55 +00:00 committed by Dawn LUCI CQ
parent a8d5228049
commit e34e059804
4 changed files with 360 additions and 149 deletions

View File

@ -134,35 +134,65 @@ static std::ostream& operator<<(std::ostream& o, Method m) {
struct Data { struct Data {
std::string target_type_name; std::string target_type_name;
std::string target_element_type_name;
builder::ast_type_func_ptr target_ast_ty; builder::ast_type_func_ptr target_ast_ty;
builder::sem_type_func_ptr target_sem_ty; builder::sem_type_func_ptr target_sem_ty;
builder::ast_expr_func_ptr target_expr; builder::ast_expr_func_ptr target_expr;
std::string literal_type_name; std::string source_type_name;
builder::ast_expr_func_ptr literal_value; builder::ast_expr_func_ptr source_builder;
std::variant<AInt, AFloat> materialized_value; std::variant<AInt, AFloat> materialized_value;
double literal_value;
}; };
template <typename TARGET_TYPE, typename LITERAL_TYPE, typename MATERIALIZED_TYPE = AInt> template <typename TARGET_TYPE, typename SOURCE_TYPE, typename MATERIALIZED_TYPE>
Data Types(MATERIALIZED_TYPE materialized_value = 0_a) { Data Types(MATERIALIZED_TYPE materialized_value, double literal_value) {
using TargetDataType = builder::DataType<TARGET_TYPE>;
using SourceDataType = builder::DataType<SOURCE_TYPE>;
using TargetElementDataType = builder::DataType<typename TargetDataType::ElementType>;
return { return {
builder::DataType<TARGET_TYPE>::Name(), // TargetDataType::Name(), // target_type_name
builder::DataType<TARGET_TYPE>::AST, // TargetElementDataType::Name(), // target_element_type_name
builder::DataType<TARGET_TYPE>::Sem, // TargetDataType::AST, // target_ast_ty
builder::DataType<TARGET_TYPE>::Expr, // TargetDataType::Sem, // target_sem_ty
builder::DataType<LITERAL_TYPE>::Name(), // TargetDataType::Expr, // target_expr
builder::DataType<LITERAL_TYPE>::Expr, // SourceDataType::Name(), // literal_type_name
SourceDataType::Expr, // literal_builder
materialized_value, materialized_value,
literal_value,
};
}
template <typename TARGET_TYPE, typename SOURCE_TYPE>
Data Types() {
using TargetDataType = builder::DataType<TARGET_TYPE>;
using SourceDataType = builder::DataType<SOURCE_TYPE>;
using TargetElementDataType = builder::DataType<typename TargetDataType::ElementType>;
return {
TargetDataType::Name(), // target_type_name
TargetElementDataType::Name(), // target_element_type_name
TargetDataType::AST, // target_ast_ty
TargetDataType::Sem, // target_sem_ty
TargetDataType::Expr, // target_expr
SourceDataType::Name(), // literal_type_name
SourceDataType::Expr, // literal_builder
0_a,
0.0,
}; };
} }
static std::ostream& operator<<(std::ostream& o, const Data& c) { static std::ostream& operator<<(std::ostream& o, const Data& c) {
return o << "[" << c.target_type_name << " <- " << c.literal_type_name << "]"; auto print_value = [&](auto&& v) { o << v; };
o << "[" << c.target_type_name << " <- " << c.source_type_name << "] [";
std::visit(print_value, c.materialized_value);
o << " <- " << c.literal_value << "]";
return o;
} }
enum class Expectation { enum class Expectation {
kMaterialize, kMaterialize,
kNoMaterialize, kNoMaterialize,
kInvalidCast, kInvalidCast,
kValueCannotBeRepresented,
}; };
static std::ostream& operator<<(std::ostream& o, Expectation m) { static std::ostream& operator<<(std::ostream& o, Expectation m) {
@ -173,6 +203,8 @@ static std::ostream& operator<<(std::ostream& o, Expectation m) {
return o << "no-materialize"; return o << "no-materialize";
case Expectation::kInvalidCast: case Expectation::kInvalidCast:
return o << "invalid-cast"; return o << "invalid-cast";
case Expectation::kValueCannotBeRepresented:
return o << "value too low or high";
} }
return o << "<unknown>"; return o << "<unknown>";
} }
@ -191,7 +223,7 @@ TEST_P(MaterializeAbstractNumeric, Test) {
auto target_ty = [&] { return data.target_ast_ty(*this); }; auto target_ty = [&] { return data.target_ast_ty(*this); };
auto target_expr = [&] { return data.target_expr(*this, 42); }; auto target_expr = [&] { return data.target_expr(*this, 42); };
auto* literal = data.literal_value(*this, 1); auto* literal = data.source_builder(*this, data.literal_value);
switch (method) { switch (method) {
case Method::kVar: case Method::kVar:
WrapInFunction(Decl(Var("a", target_ty(), literal))); WrapInFunction(Decl(Var("a", target_ty(), literal)));
@ -283,110 +315,191 @@ TEST_P(MaterializeAbstractNumeric, Test) {
switch (method) { switch (method) {
case Method::kBuiltinArg: case Method::kBuiltinArg:
expect = "error: no matching call to min(" + data.target_type_name + ", " + expect = "error: no matching call to min(" + data.target_type_name + ", " +
data.literal_type_name + ")"; data.source_type_name + ")";
break; break;
case Method::kBinaryOp: case Method::kBinaryOp:
expect = "error: no matching overload for operator + (" + expect = "error: no matching overload for operator + (" +
data.target_type_name + ", " + data.literal_type_name + ")"; data.target_type_name + ", " + data.source_type_name + ")";
break; break;
default: default:
expect = "error: cannot convert value of type '" + data.literal_type_name + expect = "error: cannot convert value of type '" + data.source_type_name +
"' to type '" + data.target_type_name + "'"; "' to type '" + data.target_type_name + "'";
break; break;
} }
EXPECT_THAT(r()->error(), testing::StartsWith(expect)); EXPECT_THAT(r()->error(), testing::StartsWith(expect));
break; break;
} }
case Expectation::kValueCannotBeRepresented:
ASSERT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), testing::HasSubstr("cannot be represented as '" +
data.target_element_type_name + "'"));
break;
} }
} }
// TODO(crbug.com/tint/1504): Test for abstract-numeric values not fitting in materialized types. /// Methods that support scalar materialization
constexpr Method kScalarMethods[] = {Method::kLet, //
Method::kVar, //
Method::kFnArg, //
Method::kBuiltinArg, //
Method::kReturn, //
Method::kArray, //
Method::kStruct, //
Method::kBinaryOp};
INSTANTIATE_TEST_SUITE_P(MaterializeScalar, /// Methods that support vector materialization
MaterializeAbstractNumeric, // constexpr Method kVectorMethods[] = {Method::kLet, //
testing::Combine(testing::Values(Expectation::kMaterialize), // Method::kVar, //
testing::Values(Method::kLet, // Method::kFnArg, //
Method::kVar, // Method::kBuiltinArg, //
Method::kFnArg, // Method::kReturn, //
Method::kBuiltinArg, // Method::kArray, //
Method::kReturn, // Method::kStruct, //
Method::kArray, // Method::kBinaryOp};
Method::kStruct, //
Method::kBinaryOp), //
testing::Values(Types<i32, AInt>(1_a), //
Types<u32, AInt>(1_a), //
Types<f32, AFloat>(1.0_a) //
/* Types<f16, AFloat>(1.0_a), */ //
/* Types<f16, AFloat>(1.0_a), */)));
INSTANTIATE_TEST_SUITE_P(MaterializeVector, /// Methods that support matrix materialization
MaterializeAbstractNumeric, // constexpr Method kMatrixMethods[] = {Method::kLet, //
testing::Combine(testing::Values(Expectation::kMaterialize), // Method::kVar, //
testing::Values(Method::kLet, // Method::kFnArg, //
Method::kVar, // Method::kReturn, //
Method::kFnArg, // Method::kArray, //
Method::kBuiltinArg, // Method::kStruct, //
Method::kReturn, // Method::kBinaryOp};
Method::kArray, //
Method::kStruct, //
Method::kBinaryOp), //
testing::Values(Types<i32V, AIntV>(1_a), //
Types<u32V, AIntV>(1_a), //
Types<f32V, AFloatV>(1.0_a) //
/* Types<f16V, AFloatV>(1.0_a), */ //
/* Types<f16V, AFloatV>(1.0_a), */)));
INSTANTIATE_TEST_SUITE_P(MaterializeMatrix, /// Methods that support materialization for switch cases
MaterializeAbstractNumeric, // constexpr Method kSwitchMethods[] = {Method::kSwitchCond, //
testing::Combine(testing::Values(Expectation::kMaterialize), // Method::kSwitchCase, //
testing::Values(Method::kLet, // Method::kSwitchCondWithAbstractCase, //
Method::kVar, // Method::kSwitchCaseWithAbstractCase};
Method::kFnArg, //
Method::kReturn, //
Method::kArray, //
Method::kStruct, //
Method::kBinaryOp), //
testing::Values(Types<f32M, AFloatM>(1.0_a) //
/* Types<f16V, AFloatM>(1.0_a), */ //
)));
INSTANTIATE_TEST_SUITE_P(MaterializeSwitch, constexpr double kMaxF32 = static_cast<double>(f32::kHighest);
MaterializeAbstractNumeric, // constexpr double kPiF64 = 3.141592653589793;
testing::Combine(testing::Values(Expectation::kMaterialize), // constexpr double kPiF32 = 3.1415927410125732; // kPiF64 quantized to f32
testing::Values(Method::kSwitchCond, //
Method::kSwitchCase, // // (2^-127)×(1+(0xfffffffffffff÷0x10000000000000))
Method::kSwitchCondWithAbstractCase, // constexpr double kTooSmallF32 = 1.1754943508222874e-38;
Method::kSwitchCaseWithAbstractCase), //
testing::Values(Types<i32, AInt>(1_a), // INSTANTIATE_TEST_SUITE_P(
Types<u32, AInt>(1_a)))); MaterializeScalar,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kMaterialize), //
testing::ValuesIn(kScalarMethods), //
testing::Values(Types<i32, AInt>(0_a, 0.0), //
Types<i32, AInt>(2147483647_a, 2147483647.0), //
Types<i32, AInt>(-2147483648_a, -2147483648.0), //
Types<u32, AInt>(0_a, 0.0), //
Types<u32, AInt>(4294967295_a, 4294967295.0), //
Types<f32, AFloat>(0.0_a, 0.0), //
Types<f32, AFloat>(AFloat(kMaxF32), kMaxF32), //
Types<f32, AFloat>(AFloat(-kMaxF32), -kMaxF32), //
Types<f32, AFloat>(AFloat(kPiF32), kPiF64), //
Types<f32, AFloat>(0.0_a, kTooSmallF32), //
Types<f32, AFloat>(-0.0_a, -kTooSmallF32) //
/* Types<f16, AFloat>(1.0_a), */ //
/* Types<f16, AFloat>(1.0_a), */)));
INSTANTIATE_TEST_SUITE_P(
MaterializeVector,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kMaterialize), //
testing::ValuesIn(kVectorMethods), //
testing::Values(Types<i32V, AIntV>(0_a, 0.0), //
Types<i32V, AIntV>(2147483647_a, 2147483647.0), //
Types<i32V, AIntV>(-2147483648_a, -2147483648.0), //
Types<u32V, AIntV>(0_a, 0.0), //
Types<u32V, AIntV>(4294967295_a, 4294967295.0), //
Types<f32V, AFloatV>(0.0_a, 0.0), //
Types<f32V, AFloatV>(AFloat(kMaxF32), kMaxF32), //
Types<f32V, AFloatV>(AFloat(-kMaxF32), -kMaxF32), //
Types<f32V, AFloatV>(AFloat(kPiF32), kPiF64), //
Types<f32V, AFloatV>(0.0_a, kTooSmallF32), //
Types<f32V, AFloatV>(-0.0_a, -kTooSmallF32) //
/* Types<f16V, AFloatV>(1.0_a), */ //
/* Types<f16V, AFloatV>(1.0_a), */)));
INSTANTIATE_TEST_SUITE_P(
MaterializeMatrix,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kMaterialize), //
testing::ValuesIn(kMatrixMethods), //
testing::Values(Types<f32M, AFloatM>(0.0_a, 0.0), //
Types<f32M, AFloatM>(AFloat(kMaxF32), kMaxF32), //
Types<f32M, AFloatM>(AFloat(-kMaxF32), -kMaxF32), //
Types<f32M, AFloatM>(AFloat(kPiF32), kPiF64), //
Types<f32M, AFloatM>(0.0_a, kTooSmallF32), //
Types<f32M, AFloatM>(-0.0_a, -kTooSmallF32) //
/* Types<f16V, AFloatM>(1.0_a), */ //
)));
INSTANTIATE_TEST_SUITE_P(
MaterializeSwitch,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kMaterialize), //
testing::ValuesIn(kSwitchMethods), //
testing::Values(Types<i32, AInt>(0_a, 0.0), //
Types<i32, AInt>(2147483647_a, 2147483647.0), //
Types<i32, AInt>(-2147483648_a, -2147483648.0), //
Types<u32, AInt>(0_a, 0.0), //
Types<u32, AInt>(4294967295_a, 4294967295.0))));
// TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary ops. // TODO(crbug.com/tint/1504): Enable once we have abstract overloads of builtins / binary ops.
INSTANTIATE_TEST_SUITE_P(DISABLED_NoMaterialize, INSTANTIATE_TEST_SUITE_P(DISABLED_NoMaterialize,
MaterializeAbstractNumeric, // MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kNoMaterialize), // testing::Combine(testing::Values(Expectation::kNoMaterialize), //
testing::Values(Method::kBuiltinArg, // testing::Values(Method::kBuiltinArg, //
Method::kBinaryOp), // Method::kBinaryOp), //
testing::Values(Types<AInt, AInt>(1_a), // testing::Values(Types<AInt, AInt>(), //
Types<AFloat, AFloat>(1.0_a), // Types<AFloat, AFloat>(), //
Types<AIntV, AIntV>(1_a), // Types<AIntV, AIntV>(), //
Types<AFloatV, AFloatV>(1.0_a), // Types<AFloatV, AFloatV>(), //
Types<AFloatM, AFloatM>(1.0_a)))); Types<AFloatM, AFloatM>())));
INSTANTIATE_TEST_SUITE_P(InvalidCast, INSTANTIATE_TEST_SUITE_P(InvalidCast,
MaterializeAbstractNumeric, // MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kInvalidCast), // testing::Combine(testing::Values(Expectation::kInvalidCast), //
testing::Values(Method::kLet, // testing::ValuesIn(kScalarMethods), //
Method::kVar, //
Method::kFnArg, //
Method::kBuiltinArg, //
Method::kReturn, //
Method::kArray, //
Method::kStruct, //
Method::kBinaryOp), //
testing::Values(Types<i32, AFloat>(), // testing::Values(Types<i32, AFloat>(), //
Types<u32, AFloat>(), // Types<u32, AFloat>(), //
Types<i32V, AFloatV>(), // Types<i32V, AFloatV>(), //
Types<u32V, AFloatV>()))); Types<u32V, AFloatV>())));
INSTANTIATE_TEST_SUITE_P(
ScalarValueCannotBeRepresented,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), //
testing::ValuesIn(kScalarMethods), //
testing::Values(Types<i32, AInt>(0_a, 2147483648.0), //
Types<i32, AInt>(0_a, -2147483649.0), //
Types<u32, AInt>(0_a, 4294967296), //
Types<u32, AInt>(0_a, -1.0), //
Types<f32, AFloat>(0.0_a, 3.5e+38), //
Types<f32, AFloat>(0.0_a, -3.5e+38) //
/* Types<f16, AFloat>(), */ //
/* Types<f16, AFloat>(), */)));
INSTANTIATE_TEST_SUITE_P(
VectorValueCannotBeRepresented,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), //
testing::ValuesIn(kVectorMethods), //
testing::Values(Types<i32V, AIntV>(0_a, 2147483648.0), //
Types<i32V, AIntV>(0_a, -2147483649.0), //
Types<u32V, AIntV>(0_a, 4294967296), //
Types<u32V, AIntV>(0_a, -1.0), //
Types<f32V, AFloatV>(0.0_a, 3.5e+38), //
Types<f32V, AFloatV>(0.0_a, -3.5e+38) //
/* Types<f16V, AFloatV>(), */ //
/* Types<f16V, AFloatV>(), */)));
INSTANTIATE_TEST_SUITE_P(
MatrixValueCannotBeRepresented,
MaterializeAbstractNumeric, //
testing::Combine(testing::Values(Expectation::kValueCannotBeRepresented), //
testing::ValuesIn(kMatrixMethods), //
testing::Values(Types<f32M, AFloatM>(0.0_a, 3.5e+38), //
Types<f32M, AFloatM>(0.0_a, -3.5e+38) //
/* Types<f16M, AFloatM>(), */ //
/* Types<f16M, AFloatM>(), */)));
} // namespace MaterializeTests } // namespace MaterializeTests
} // namespace } // namespace

View File

@ -1110,19 +1110,27 @@ const sem::Expression* Resolver::Materialize(const sem::Expression* expr,
// Helper for actually creating the the materialize node, performing the constant cast, updating // Helper for actually creating the the materialize node, performing the constant cast, updating
// the ast -> sem binding, and performing validation. // the ast -> sem binding, and performing validation.
auto materialize = [&](const sem::Type* target_ty) -> sem::Materialize* { auto materialize = [&](const sem::Type* target_ty) -> sem::Materialize* {
auto expr_val = EvaluateConstantValue(expr->Declaration(), expr->Type()); auto* decl = expr->Declaration();
if (!expr_val.IsValid()) { auto expr_val = EvaluateConstantValue(decl, expr->Type());
if (!expr_val) {
return nullptr;
}
if (!expr_val->IsValid()) {
TINT_ICE(Resolver, builder_->Diagnostics()) TINT_ICE(Resolver, builder_->Diagnostics())
<< expr->Declaration()->source << decl->source
<< " EvaluateConstantValue() returned invalid value for materialized " << " EvaluateConstantValue() returned invalid value for materialized "
"value of type: " "value of type: "
<< (expr->Type() ? expr->Type()->FriendlyName(builder_->Symbols()) : "<null>"); << (expr->Type() ? expr->Type()->FriendlyName(builder_->Symbols()) : "<null>");
return nullptr; return nullptr;
} }
auto materialized_val = ConvertValue(expr_val, target_ty); auto materialized_val = ConvertValue(expr_val.Get(), target_ty, decl->source);
auto* m = builder_->create<sem::Materialize>(expr, current_statement_, materialized_val); if (!materialized_val) {
return nullptr;
}
auto* m =
builder_->create<sem::Materialize>(expr, current_statement_, materialized_val.Get());
m->Behaviors() = expr->Behaviors(); m->Behaviors() = expr->Behaviors();
builder_->Sem().Replace(expr->Declaration(), m); builder_->Sem().Replace(decl, m);
return validator_.Materialize(m) ? m : nullptr; return validator_.Materialize(m) ? m : nullptr;
}; };
@ -1215,8 +1223,11 @@ sem::Expression* Resolver::IndexAccessor(const ast::IndexAccessorExpression* exp
} }
auto val = EvaluateConstantValue(expr, ty); auto val = EvaluateConstantValue(expr, ty);
if (!val) {
return nullptr;
}
bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects(); bool has_side_effects = idx->HasSideEffects() || obj->HasSideEffects();
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, val, auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, val.Get(),
has_side_effects, obj->SourceVariable()); has_side_effects, obj->SourceVariable());
sem->Behaviors() = idx->Behaviors() + obj->Behaviors(); sem->Behaviors() = idx->Behaviors() + obj->Behaviors();
return sem; return sem;
@ -1230,7 +1241,10 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
} }
auto val = EvaluateConstantValue(expr, ty); auto val = EvaluateConstantValue(expr, ty);
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, val, if (!val) {
return nullptr;
}
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_, val.Get(),
inner->HasSideEffects()); inner->HasSideEffects());
sem->Behaviors() = inner->Behaviors(); sem->Behaviors() = inner->Behaviors();
@ -1277,9 +1291,12 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) { if (!MaterializeArguments(args, call_target)) {
return nullptr; return nullptr;
} }
auto value = EvaluateConstantValue(expr, call_target->ReturnType()); auto val = EvaluateConstantValue(expr, call_target->ReturnType());
if (!val) {
return nullptr;
}
return builder_->create<sem::Call>(expr, call_target, std::move(args), current_statement_, return builder_->create<sem::Call>(expr, call_target, std::move(args), current_statement_,
value, has_side_effects); val.Get(), has_side_effects);
}; };
// ct_ctor_or_conv is a helper for building either a sem::TypeConstructor or sem::TypeConversion // ct_ctor_or_conv is a helper for building either a sem::TypeConstructor or sem::TypeConversion
@ -1315,9 +1332,12 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) { if (!MaterializeArguments(args, call_target)) {
return nullptr; return nullptr;
} }
auto value = EvaluateConstantValue(expr, call_target->ReturnType()); auto val = EvaluateConstantValue(expr, call_target->ReturnType());
if (!val) {
return nullptr;
}
return builder_->create<sem::Call>(expr, call_target, std::move(args), return builder_->create<sem::Call>(expr, call_target, std::move(args),
current_statement_, value, has_side_effects); current_statement_, val.Get(), has_side_effects);
}, },
[&](const sem::Struct* str) -> sem::Call* { [&](const sem::Struct* str) -> sem::Call* {
auto* call_target = utils::GetOrCreate( auto* call_target = utils::GetOrCreate(
@ -1337,9 +1357,12 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
if (!MaterializeArguments(args, call_target)) { if (!MaterializeArguments(args, call_target)) {
return nullptr; return nullptr;
} }
auto value = EvaluateConstantValue(expr, call_target->ReturnType()); auto val = EvaluateConstantValue(expr, call_target->ReturnType());
if (!val) {
return nullptr;
}
return builder_->create<sem::Call>(expr, call_target, std::move(args), return builder_->create<sem::Call>(expr, call_target, std::move(args),
current_statement_, value, has_side_effects); current_statement_, val.Get(), has_side_effects);
}, },
[&](Default) { [&](Default) {
AddError("type is not constructible", expr->source); AddError("type is not constructible", expr->source);
@ -1616,7 +1639,10 @@ sem::Expression* Resolver::Literal(const ast::LiteralExpression* literal) {
} }
auto val = EvaluateConstantValue(literal, ty); auto val = EvaluateConstantValue(literal, ty);
return builder_->create<sem::Expression>(literal, ty, current_statement_, val, if (!val) {
return nullptr;
}
return builder_->create<sem::Expression>(literal, ty, current_statement_, val.Get(),
/* has_side_effects */ false); /* has_side_effects */ false);
} }
@ -1828,8 +1854,11 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
} }
auto val = EvaluateConstantValue(expr, op.result); auto val = EvaluateConstantValue(expr, op.result);
if (!val) {
return nullptr;
}
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects(); bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, val, auto* sem = builder_->create<sem::Expression>(expr, op.result, current_statement_, val.Get(),
has_side_effects); has_side_effects);
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors(); sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
@ -1902,7 +1931,10 @@ sem::Expression* Resolver::UnaryOp(const ast::UnaryOpExpression* unary) {
} }
auto val = EvaluateConstantValue(unary, ty); auto val = EvaluateConstantValue(unary, ty);
auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, val, if (!val) {
return nullptr;
}
auto* sem = builder_->create<sem::Expression>(unary, ty, current_statement_, val.Get(),
expr->HasSideEffects(), source_var); expr->HasSideEffects(), source_var);
sem->Behaviors() = expr->Behaviors(); sem->Behaviors() = expr->Behaviors();
return sem; return sem;

View File

@ -34,6 +34,7 @@
#include "src/tint/sem/constant.h" #include "src/tint/sem/constant.h"
#include "src/tint/sem/function.h" #include "src/tint/sem/function.h"
#include "src/tint/sem/struct.h" #include "src/tint/sem/struct.h"
#include "src/tint/utils/result.h"
#include "src/tint/utils/unique_vector.h" #include "src/tint/utils/unique_vector.h"
// Forward declarations // Forward declarations
@ -354,15 +355,19 @@ class Resolver {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
/// Constant value evaluation methods /// Constant value evaluation methods
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
/// The result type of a ConstantEvaluation method. Holds the constant value and a boolean,
/// which is true on success, false on an error.
using ConstantResult = utils::Result<sem::Constant>;
/// Convert the `value` to `target_type` /// Convert the `value` to `target_type`
/// @return the converted value /// @return the converted value
sem::Constant ConvertValue(const sem::Constant& value, const sem::Type* target_type); ConstantResult ConvertValue(const sem::Constant& value,
const sem::Type* target_type,
sem::Constant EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type); const Source& source);
sem::Constant EvaluateConstantValue(const ast::LiteralExpression* literal, ConstantResult EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type);
const sem::Type* type); ConstantResult EvaluateConstantValue(const ast::LiteralExpression* literal,
sem::Constant EvaluateConstantValue(const ast::CallExpression* call, const sem::Type* type); const sem::Type* type);
ConstantResult EvaluateConstantValue(const ast::CallExpression* call, const sem::Type* type);
/// @returns true if the symbol is the name of a builtin function. /// @returns true if the symbol is the name of a builtin function.
bool IsBuiltin(Symbol) const; bool IsBuiltin(Symbol) const;

View File

@ -14,7 +14,9 @@
#include "src/tint/resolver/resolver.h" #include "src/tint/resolver/resolver.h"
#include <optional> #include <cmath>
// TODO(https://crbug.com/dawn/1379) Update cpplint and remove NOLINT
#include <optional> // NOLINT(build/include_order))
#include "src/tint/sem/abstract_float.h" #include "src/tint/sem/abstract_float.h"
#include "src/tint/sem/abstract_int.h" #include "src/tint/sem/abstract_int.h"
@ -30,46 +32,53 @@ namespace tint::resolver {
namespace { namespace {
/// Converts all the element values of `in` to the type `T`. /// Converts and returns all the element values of `in` to the type `T`, using the converter
/// function `CONVERTER`.
/// @param elements_in the vector of elements to be converted /// @param elements_in the vector of elements to be converted
/// @param converter a function-like with the signature `void(TO&, FROM)`
/// @returns the elements converted to type T. /// @returns the elements converted to type T.
template <typename T, typename ELEMENTS_IN> template <typename T, typename ELEMENTS_IN, typename CONVERTER>
sem::Constant::Elements Convert(const ELEMENTS_IN& elements_in) { sem::Constant::Elements Transform(const ELEMENTS_IN& elements_in, CONVERTER&& converter) {
TINT_BEGIN_DISABLE_WARNING_UNREACHABLE_CODE(); TINT_BEGIN_DISABLE_WARNING_UNREACHABLE_CODE();
using E = UnwrapNumber<T>;
return utils::Transform(elements_in, [&](auto value_in) { return utils::Transform(elements_in, [&](auto value_in) {
if constexpr (std::is_same_v<E, bool>) { if constexpr (std::is_same_v<UnwrapNumber<T>, bool>) {
return AInt(value_in != 0); return AInt(value_in != 0);
}
E converted = static_cast<E>(value_in);
if constexpr (IsFloatingPoint<E>) {
return AFloat(converted);
} else { } else {
return AInt(converted); T converted{};
converter(converted, value_in);
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) {
return AFloat(converted);
} else {
return AInt(converted);
}
} }
}); });
TINT_END_DISABLE_WARNING_UNREACHABLE_CODE(); TINT_END_DISABLE_WARNING_UNREACHABLE_CODE();
} }
/// Converts and returns all the element values of `in` to the semantic type `el_ty`. /// Converts and returns all the element values of `in` to the semantic type `el_ty`, using the
/// converter function `CONVERTER`.
/// @param in the constant to convert /// @param in the constant to convert
/// @param el_ty the target element type /// @param el_ty the target element type
/// @returns the elements converted to `type` /// @param converter a function-like with the signature `void(TO&, FROM)`
sem::Constant::Elements Convert(const sem::Constant::Elements& in, const sem::Type* el_ty) { /// @returns the elements converted to `el_ty`
template <typename CONVERTER>
sem::Constant::Elements Transform(const sem::Constant::Elements& in,
const sem::Type* el_ty,
CONVERTER&& converter) {
return std::visit( return std::visit(
[&](auto&& v) { [&](auto&& v) {
return Switch( return Switch(
el_ty, // el_ty, //
[&](const sem::AbstractInt*) { return Convert<AInt>(v); }, [&](const sem::AbstractInt*) { return Transform<AInt>(v, converter); },
[&](const sem::AbstractFloat*) { return Convert<AFloat>(v); }, [&](const sem::AbstractFloat*) { return Transform<AFloat>(v, converter); },
[&](const sem::I32*) { return Convert<i32>(v); }, [&](const sem::I32*) { return Transform<i32>(v, converter); },
[&](const sem::U32*) { return Convert<u32>(v); }, [&](const sem::U32*) { return Transform<u32>(v, converter); },
[&](const sem::F32*) { return Convert<f32>(v); }, [&](const sem::F32*) { return Transform<f32>(v, converter); },
[&](const sem::F16*) { return Convert<f16>(v); }, [&](const sem::F16*) { return Transform<f16>(v, converter); },
[&](const sem::Bool*) { return Convert<bool>(v); }, [&](const sem::Bool*) { return Transform<bool>(v, converter); },
[&](Default) -> sem::Constant::Elements { [&](Default) -> sem::Constant::Elements {
diag::List diags; diag::List diags;
TINT_UNREACHABLE(Semantic, diags) TINT_UNREACHABLE(Semantic, diags)
@ -80,44 +89,91 @@ sem::Constant::Elements Convert(const sem::Constant::Elements& in, const sem::Ty
in); in);
} }
/// Converts and returns all the elements in `in` to the type `el_ty`, by performing a `static_cast`
/// on each element value. No checks will be performed that the value fits in the target type.
/// @param in the input elements
/// @param el_ty the target element type
/// @returns the elements converted to `el_ty`
sem::Constant::Elements ConvertElements(const sem::Constant::Elements& in, const sem::Type* el_ty) {
return Transform(in, el_ty, [](auto& el_out, auto el_in) {
el_out = std::decay_t<decltype(el_out)>(el_in);
});
}
/// Converts and returns all the elements in `in` to the type `el_ty`, by performing a
/// `CheckedConvert` on each element value. A single error diagnostic will be raised if an element
/// value cannot be represented by the target type.
/// @param in the input elements
/// @param el_ty the target element type
/// @returns the elements converted to `el_ty`, or a Failure if some elements could not be
/// represented by the target type.
utils::Result<sem::Constant::Elements> MaterializeElements(const sem::Constant::Elements& in,
const sem::Type* el_ty,
ProgramBuilder& builder,
Source source) {
std::optional<std::string> failure;
auto out = Transform(in, el_ty, [&](auto& el_out, auto el_in) {
using OUT = std::decay_t<decltype(el_out)>;
if (auto conv = CheckedConvert<OUT>(el_in)) {
el_out = conv.Get();
} else if (conv.Failure() == ConversionFailure::kTooSmall) {
el_out = OUT(el_in < 0 ? -0.0 : 0.0);
} else if (!failure.has_value()) {
std::stringstream ss;
ss << "value " << el_in << " cannot be represented as ";
ss << "'" << builder.FriendlyName(el_ty) << "'";
failure = ss.str();
}
});
if (failure.has_value()) {
builder.Diagnostics().add_error(diag::System::Resolver, std::move(failure.value()), source);
return utils::Failure;
}
return out;
}
} // namespace } // namespace
sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) { utils::Result<sem::Constant> Resolver::EvaluateConstantValue(const ast::Expression* expr,
const sem::Type* type) {
if (auto* e = expr->As<ast::LiteralExpression>()) { if (auto* e = expr->As<ast::LiteralExpression>()) {
return EvaluateConstantValue(e, type); return EvaluateConstantValue(e, type);
} }
if (auto* e = expr->As<ast::CallExpression>()) { if (auto* e = expr->As<ast::CallExpression>()) {
return EvaluateConstantValue(e, type); return EvaluateConstantValue(e, type);
} }
return {}; return sem::Constant{};
} }
sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal, utils::Result<sem::Constant> Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal,
const sem::Type* type) { const sem::Type* type) {
return Switch( return Switch(
literal, literal,
[&](const ast::BoolLiteralExpression* lit) {
return sem::Constant{type, {AInt(lit->value ? 1 : 0)}};
},
[&](const ast::IntLiteralExpression* lit) { [&](const ast::IntLiteralExpression* lit) {
return sem::Constant{type, {AInt(lit->value)}}; return sem::Constant{type, {AInt(lit->value)}};
}, },
[&](const ast::FloatLiteralExpression* lit) { [&](const ast::FloatLiteralExpression* lit) {
return sem::Constant{type, {AFloat(lit->value)}}; return sem::Constant{type, {AFloat(lit->value)}};
},
[&](const ast::BoolLiteralExpression* lit) {
return sem::Constant{type, {AInt(lit->value ? 1 : 0)}};
}); });
} }
sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call, utils::Result<sem::Constant> Resolver::EvaluateConstantValue(const ast::CallExpression* call,
const sem::Type* ty) { const sem::Type* ty) {
uint32_t result_size = 0; uint32_t result_size = 0;
auto* el_ty = sem::Type::ElementOf(ty, &result_size); auto* el_ty = sem::Type::ElementOf(ty, &result_size);
if (!el_ty) { if (!el_ty) {
return {}; return sem::Constant{};
} }
// ElementOf() will also return the element type of array, which we do not support. // ElementOf() will also return the element type of array, which we do not support.
if (ty->Is<sem::Array>()) { if (ty->Is<sem::Array>()) {
return {}; return sem::Constant{};
} }
// For zero value init, return 0s // For zero value init, return 0s
@ -142,15 +198,15 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
for (auto* expr : call->args) { for (auto* expr : call->args) {
auto* arg = builder_->Sem().Get(expr); auto* arg = builder_->Sem().Get(expr);
if (!arg) { if (!arg) {
return {}; return sem::Constant{};
} }
auto value = arg->ConstantValue(); auto value = arg->ConstantValue();
if (!value) { if (!value) {
return {}; return sem::Constant{};
} }
// Convert the elements to the desired type. // Convert the elements to the desired type.
auto converted = Convert(value.GetElements(), el_ty); auto converted = ConvertElements(value.GetElements(), el_ty);
if (elements.has_value()) { if (elements.has_value()) {
// Append the converted vector to elements // Append the converted vector to elements
@ -180,20 +236,25 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
return sem::Constant(ty, std::move(elements.value())); return sem::Constant(ty, std::move(elements.value()));
} }
sem::Constant Resolver::ConvertValue(const sem::Constant& value, const sem::Type* ty) { utils::Result<sem::Constant> Resolver::ConvertValue(const sem::Constant& value,
const sem::Type* ty,
const Source& source) {
if (value.Type() == ty) { if (value.Type() == ty) {
return value; return value;
} }
auto* el_ty = sem::Type::ElementOf(ty); auto* el_ty = sem::Type::ElementOf(ty);
if (el_ty == nullptr) { if (el_ty == nullptr) {
return {}; return sem::Constant{};
} }
if (value.ElementType() == el_ty) { if (value.ElementType() == el_ty) {
return sem::Constant(ty, value.GetElements()); return sem::Constant(ty, value.GetElements());
} }
return sem::Constant(ty, Convert(value.GetElements(), el_ty)); if (auto res = MaterializeElements(value.GetElements(), el_ty, *builder_, source)) {
return sem::Constant(ty, std::move(res.Get()));
}
return utils::Failure;
} }
} // namespace tint::resolver } // namespace tint::resolver