tint: const eval of bitcast operator

Bug: tint:1581
Change-Id: Ida43b34118282eeb99ae099c91a6465eb3040ca6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/115080
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano
2022-12-20 18:34:06 +00:00
committed by Dawn LUCI CQ
parent ffa83ad1f7
commit 056618541f
35 changed files with 334 additions and 608 deletions

View File

@@ -1186,6 +1186,7 @@ if (tint_build_unittests) {
"resolver/compound_assignment_validation_test.cc",
"resolver/compound_statement_test.cc",
"resolver/const_eval_binary_op_test.cc",
"resolver/const_eval_bitcast_test.cc",
"resolver/const_eval_builtin_test.cc",
"resolver/const_eval_construction_test.cc",
"resolver/const_eval_conversion_test.cc",

View File

@@ -894,6 +894,7 @@ if(TINT_BUILD_TESTS)
resolver/compound_assignment_validation_test.cc
resolver/compound_statement_test.cc
resolver/const_eval_binary_op_test.cc
resolver/const_eval_bitcast_test.cc
resolver/const_eval_builtin_test.cc
resolver/const_eval_construction_test.cc
resolver/const_eval_conversion_test.cc

View File

@@ -68,6 +68,17 @@ auto Dispatch_iu32(F&& f, CONSTANTS&&... cs) {
[&](const type::U32*) { return f(cs->template ValueAs<u32>()...); });
}
/// Helper that calls `f` passing in the value of all `cs`.
/// Calls `f` with all constants cast to the type of the first `cs` argument.
template <typename F, typename... CONSTANTS>
auto Dispatch_fiu32(F&& f, CONSTANTS&&... cs) {
return Switch(
First(cs...)->Type(), //
[&](const type::F32*) { return f(cs->template ValueAs<f32>()...); },
[&](const type::I32*) { return f(cs->template ValueAs<i32>()...); },
[&](const type::U32*) { return f(cs->template ValueAs<u32>()...); });
}
/// Helper that calls `f` passing in the value of all `cs`.
/// Calls `f` with all constants cast to the type of the first `cs` argument.
template <typename F, typename... CONSTANTS>
@@ -1319,9 +1330,33 @@ ConstEval::Result ConstEval::Swizzle(const type::Type* ty,
return builder.create<constant::Composite>(ty, std::move(values));
}
ConstEval::Result ConstEval::Bitcast(const type::Type*, const sem::Expression*) {
// TODO(crbug.com/tint/1581): Implement @const intrinsics
return nullptr;
ConstEval::Result ConstEval::Bitcast(const type::Type* ty, const sem::Expression* expr) {
auto* value = expr->ConstantValue();
if (!value) {
return nullptr;
}
auto* el_ty = type::Type::DeepestElementOf(ty);
auto& source = expr->Declaration()->source;
auto transform = [&](const constant::Value* c0) {
auto create = [&](auto e) {
return Switch(
el_ty,
[&](const type::U32*) { //
auto r = utils::Bitcast<u32>(e);
return CreateScalar(builder, source, el_ty, r);
},
[&](const type::I32*) { //
auto r = utils::Bitcast<i32>(e);
return CreateScalar(builder, source, el_ty, r);
},
[&](const type::F32*) { //
auto r = utils::Bitcast<f32>(e);
return CreateScalar(builder, source, el_ty, r);
});
};
return Dispatch_fiu32(create, c0);
};
return TransformElements(builder, ty, transform, value);
}
ConstEval::Result ConstEval::OpComplement(const type::Type* ty,

View File

@@ -0,0 +1,192 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/resolver/const_eval_test.h"
using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver {
namespace {
struct Case {
Value input;
struct Success {
Value value;
};
struct Failure {
builder::CreatePtrs create_ptrs;
};
utils::Result<Success, Failure> expected;
};
static std::ostream& operator<<(std::ostream& o, const Case& c) {
o << "input: " << c.input;
if (c.expected) {
o << ", expected: " << c.expected.Get().value;
} else {
o << ", expected failed bitcast to " << c.expected.Failure().create_ptrs;
}
return o;
}
template <typename TO, typename FROM>
Case Success(FROM input, TO expected) {
return Case{input, Case::Success{expected}};
}
template <typename TO, typename FROM>
Case Failure(FROM input) {
return Case{input, Case::Failure{builder::CreatePtrsFor<TO>()}};
}
using ResolverConstEvalBitcastTest = ResolverTestWithParam<Case>;
TEST_P(ResolverConstEvalBitcastTest, Test) {
const auto& input = GetParam().input;
const auto& expected = GetParam().expected;
// Get the target type CreatePtrs
builder::CreatePtrs target_create_ptrs;
if (expected) {
target_create_ptrs = expected.Get().value.create_ptrs;
} else {
target_create_ptrs = expected.Failure().create_ptrs;
}
auto* target_ty = target_create_ptrs.ast(*this);
ASSERT_NE(target_ty, nullptr);
auto* input_val = input.Expr(*this);
const ast::Expression* expr = Bitcast(target_ty, input_val);
WrapInFunction(expr);
auto* target_sem_ty = target_create_ptrs.sem(*this);
if (expected) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
EXPECT_TYPE(sem->Type(), target_sem_ty);
ASSERT_NE(sem->ConstantValue(), nullptr);
EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty);
auto expected_values = expected.Get().value.args;
auto got_values = ScalarsFrom(sem->ConstantValue());
EXPECT_EQ(expected_values, got_values);
} else {
ASSERT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), testing::HasSubstr("cannot be represented as"));
}
}
const u32 nan_as_u32 = utils::Bitcast<u32>(std::numeric_limits<float>::quiet_NaN());
const i32 nan_as_i32 = utils::Bitcast<i32>(std::numeric_limits<float>::quiet_NaN());
const u32 inf_as_u32 = utils::Bitcast<u32>(std::numeric_limits<float>::infinity());
const i32 inf_as_i32 = utils::Bitcast<i32>(std::numeric_limits<float>::infinity());
const u32 neg_inf_as_u32 = utils::Bitcast<u32>(-std::numeric_limits<float>::infinity());
const i32 neg_inf_as_i32 = utils::Bitcast<i32>(-std::numeric_limits<float>::infinity());
INSTANTIATE_TEST_SUITE_P(Bitcast,
ResolverConstEvalBitcastTest,
testing::ValuesIn({
// Bitcast to same (concrete) type, no change
Success(Val(0_u), Val(0_u)), //
Success(Val(0_i), Val(0_i)), //
Success(Val(0_f), Val(0_f)), //
Success(Val(123_u), Val(123_u)), //
Success(Val(123_i), Val(123_i)), //
Success(Val(123.456_f), Val(123.456_f)), //
Success(Val(u32::Highest()), Val(u32::Highest())), //
Success(Val(u32::Lowest()), Val(u32::Lowest())), //
Success(Val(i32::Highest()), Val(i32::Highest())), //
Success(Val(i32::Lowest()), Val(i32::Lowest())), //
Success(Val(f32::Highest()), Val(f32::Highest())), //
Success(Val(f32::Lowest()), Val(f32::Lowest())), //
// Bitcast to different type
Success(Val(0_u), Val(0_i)), //
Success(Val(0_u), Val(0_f)), //
Success(Val(0_i), Val(0_u)), //
Success(Val(0_i), Val(0_f)), //
Success(Val(0.0_f), Val(0_i)), //
Success(Val(0.0_f), Val(0_u)), //
Success(Val(1_u), Val(1_i)), //
Success(Val(1_u), Val(1.4013e-45_f)), //
Success(Val(1_i), Val(1_u)), //
Success(Val(1_i), Val(1.4013e-45_f)), //
Success(Val(1.0_f), Val(0x3F800000_u)), //
Success(Val(1.0_f), Val(0x3F800000_i)), //
Success(Val(123_u), Val(123_i)), //
Success(Val(123_u), Val(1.7236e-43_f)), //
Success(Val(123_i), Val(123_u)), //
Success(Val(123_i), Val(1.7236e-43_f)), //
Success(Val(123.0_f), Val(0x42F60000_u)), //
Success(Val(123.0_f), Val(0x42F60000_i)), //
// Bitcast from abstract materializes lhs first,
// so same results as above.
Success(Val(0_a), Val(0_i)), //
Success(Val(0_a), Val(0_f)), //
Success(Val(0_a), Val(0_u)), //
Success(Val(0_a), Val(0_f)), //
Success(Val(0_a), Val(0_i)), //
Success(Val(0_a), Val(0_u)), //
Success(Val(1_a), Val(1_i)), //
Success(Val(1_a), Val(1.4013e-45_f)), //
Success(Val(1_a), Val(1_u)), //
Success(Val(1_a), Val(1.4013e-45_f)), //
Success(Val(1.0_a), Val(0x3F800000_u)), //
Success(Val(1.0_a), Val(0x3F800000_i)), //
Success(Val(123_a), Val(123_i)), //
Success(Val(123_a), Val(1.7236e-43_f)), //
Success(Val(123_a), Val(123_u)), //
Success(Val(123_a), Val(1.7236e-43_f)), //
Success(Val(123.0_a), Val(0x42F60000_u)), //
Success(Val(123.0_a), Val(0x42F60000_i)), //
// u32 <-> i32 sign bit
Success(Val(0xFFFFFFFF_u), Val(-1_i)), //
Success(Val(-1_i), Val(0xFFFFFFFF_u)), //
Success(Val(0x80000000_u), Val(i32::Lowest())), //
Success(Val(i32::Lowest()), Val(0x80000000_u)), //
// Vector tests
Success(Vec(0_u, 1_u, 123_u), Vec(0_i, 1_i, 123_i)),
Success(Vec(0.0_f, 1.0_f, 123.0_f),
Vec(0_i, 0x3F800000_i, 0x42F60000_i)),
// Unrepresentable
Failure<f32>(Val(nan_as_u32)), //
Failure<f32>(Val(nan_as_i32)), //
Failure<f32>(Val(inf_as_u32)), //
Failure<f32>(Val(inf_as_i32)), //
Failure<f32>(Val(neg_inf_as_u32)), //
Failure<f32>(Val(neg_inf_as_i32)), //
Failure<builder::vec2<f32>>(Vec(nan_as_u32, 0_u)), //
Failure<builder::vec2<f32>>(Vec(nan_as_i32, 0_i)), //
Failure<builder::vec2<f32>>(Vec(inf_as_u32, 0_u)), //
Failure<builder::vec2<f32>>(Vec(inf_as_i32, 0_i)), //
Failure<builder::vec2<f32>>(Vec(neg_inf_as_u32, 0_u)), //
Failure<builder::vec2<f32>>(Vec(neg_inf_as_i32, 0_i)), //
Failure<builder::vec2<f32>>(Vec(0_u, nan_as_u32)), //
Failure<builder::vec2<f32>>(Vec(0_i, nan_as_i32)), //
Failure<builder::vec2<f32>>(Vec(0_u, inf_as_u32)), //
Failure<builder::vec2<f32>>(Vec(0_i, inf_as_i32)), //
Failure<builder::vec2<f32>>(Vec(0_u, neg_inf_as_u32)), //
Failure<builder::vec2<f32>>(Vec(0_i, neg_inf_as_i32)), //
}));
} // namespace
} // namespace tint::resolver

View File

@@ -789,11 +789,11 @@ enum class Method {
// let a = abstract_expr;
kLet,
// bitcast<f32>(abstract_expr)
kBitcastF32Arg,
// bitcast<i32>(abstract_expr)
kBitcastI32Arg,
// bitcast<vec3<f32>>(abstract_expr)
kBitcastVec3F32Arg,
// bitcast<vec3<i32>>(abstract_expr)
kBitcastVec3I32Arg,
// array<i32, abstract_expr>()
kArrayLength,
@@ -825,10 +825,10 @@ static std::ostream& operator<<(std::ostream& o, Method m) {
return o << "var";
case Method::kLet:
return o << "let";
case Method::kBitcastF32Arg:
return o << "bitcast-f32-arg";
case Method::kBitcastVec3F32Arg:
return o << "bitcast-vec3-f32-arg";
case Method::kBitcastI32Arg:
return o << "bitcast-i32-arg";
case Method::kBitcastVec3I32Arg:
return o << "bitcast-vec3-i32-arg";
case Method::kArrayLength:
return o << "array-length";
case Method::kSwitch:
@@ -903,12 +903,12 @@ TEST_P(MaterializeAbstractNumericToDefaultType, Test) {
WrapInFunction(Decl(Let("a", abstract_expr())));
break;
}
case Method::kBitcastF32Arg: {
WrapInFunction(Bitcast<f32>(abstract_expr()));
case Method::kBitcastI32Arg: {
WrapInFunction(Bitcast<i32>(abstract_expr()));
break;
}
case Method::kBitcastVec3F32Arg: {
WrapInFunction(Bitcast(ty.vec3<f32>(), abstract_expr()));
case Method::kBitcastVec3I32Arg: {
WrapInFunction(Bitcast(ty.vec3<i32>(), abstract_expr()));
break;
}
case Method::kArrayLength: {
@@ -977,7 +977,7 @@ TEST_P(MaterializeAbstractNumericToDefaultType, Test) {
constexpr Method kScalarMethods[] = {
Method::kLet,
Method::kVar,
Method::kBitcastF32Arg,
Method::kBitcastI32Arg,
Method::kTintMaterializeBuiltin,
};
@@ -985,7 +985,7 @@ constexpr Method kScalarMethods[] = {
constexpr Method kVectorMethods[] = {
Method::kLet,
Method::kVar,
Method::kBitcastVec3F32Arg,
Method::kBitcastVec3I32Arg,
Method::kRuntimeIndex,
Method::kTintMaterializeBuiltin,
};

View File

@@ -1957,24 +1957,27 @@ sem::Expression* Resolver::Bitcast(const ast::BitcastExpression* expr) {
if (!ty) {
return nullptr;
}
if (!validator_.Bitcast(expr, ty)) {
return nullptr;
}
//
const constant::Value* val = nullptr;
sem::EvaluationStage stage = sem::EvaluationStage::kRuntime;
// TODO(crbug.com/tint/1582): short circuit 'expr' once const eval of Bitcast is implemented.
if (auto r = const_eval_.Bitcast(ty, inner)) {
val = r.Get();
if (val) {
stage = sem::EvaluationStage::kConstant;
}
} else {
return nullptr;
}
auto stage = sem::EvaluationStage::kRuntime; // TODO(crbug.com/tint/1581)
auto* sem = builder_->create<sem::Expression>(expr, ty, stage, current_statement_,
std::move(val), inner->HasSideEffects());
sem->Behaviors() = inner->Behaviors();
if (!validator_.Bitcast(expr, ty)) {
return nullptr;
}
return sem;
}
@@ -2047,8 +2050,8 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
return nullptr;
}
auto stage = args_stage; // The evaluation stage of the call
const constant::Value* value = nullptr; // The constant value for the call
auto stage = args_stage; // The evaluation stage of the call
const constant::Value* value = nullptr; // The constant value for the call
if (stage == sem::EvaluationStage::kConstant) {
if (auto r = const_eval_.ArrayOrStructInit(ty, args)) {
value = r.Get();

View File

@@ -755,15 +755,18 @@ struct Value {
static_assert(IsDataTypeSpecializedFor<T>, "No DataType<T> specialization exists");
using EL_TY = typename builder::DataType<T>::ElementType;
return Value{
std::move(args), CreatePtrsFor<T>().expr, tint::IsAbstract<EL_TY>,
tint::IsIntegral<EL_TY>, tint::FriendlyName<EL_TY>(),
std::move(args), //
CreatePtrsFor<T>(), //
tint::IsAbstract<EL_TY>, //
tint::IsIntegral<EL_TY>, //
tint::FriendlyName<EL_TY>(),
};
}
/// Creates an `ast::Expression` for the type T passing in previously stored args
/// @param b the ProgramBuilder
/// @returns an expression node
const ast::Expression* Expr(ProgramBuilder& b) const { return (*create)(b, args); }
const ast::Expression* Expr(ProgramBuilder& b) const { return (*create_ptrs.expr)(b, args); }
/// Prints this value to the output stream
/// @param o the output stream
@@ -782,8 +785,8 @@ struct Value {
/// The arguments used to construct the value
utils::Vector<Scalar, 4> args;
/// Function used to construct an expression with the given value
builder::ast_expr_func_ptr create;
/// CreatePtrs for value's type used to create an expression with `args`
builder::CreatePtrs create_ptrs;
/// True if the element type is abstract
bool is_abstract = false;
/// True if the element type is an integer
@@ -809,9 +812,11 @@ Value Val(T v) {
}
/// Creates a Value of DataType<vec<N, T>> from N scalar `args`
template <typename... T>
Value Vec(T... args) {
using FirstT = std::tuple_element_t<0, std::tuple<T...>>;
template <typename... Ts>
Value Vec(Ts... args) {
using FirstT = std::tuple_element_t<0, std::tuple<Ts...>>;
static_assert(std::conjunction_v<std::is_same<FirstT, Ts>...>,
"Vector args must all be the same type");
constexpr size_t N = sizeof...(args);
utils::Vector<Scalar, sizeof...(args)> v{args...};
return Value::Create<vec<N, FirstT>>(std::move(v));

View File

@@ -22,36 +22,39 @@ namespace {
using GlslGeneratorImplTest_Bitcast = TestHelper;
TEST_F(GlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_Float) {
auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr(1_i));
WrapInFunction(bitcast);
auto* a = Let("a", Expr(1_i));
auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr("a"));
WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
EXPECT_EQ(out.str(), "intBitsToFloat(1)");
EXPECT_EQ(out.str(), "intBitsToFloat(a)");
}
TEST_F(GlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_Int) {
auto* bitcast = create<ast::BitcastExpression>(ty.i32(), Expr(1_u));
WrapInFunction(bitcast);
auto* a = Let("a", Expr(1_u));
auto* bitcast = create<ast::BitcastExpression>(ty.i32(), Expr("a"));
WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
EXPECT_EQ(out.str(), "int(1u)");
EXPECT_EQ(out.str(), "int(a)");
}
TEST_F(GlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_Uint) {
auto* bitcast = create<ast::BitcastExpression>(ty.u32(), Expr(1_i));
WrapInFunction(bitcast);
auto* a = Let("a", Expr(1_i));
auto* bitcast = create<ast::BitcastExpression>(ty.u32(), Expr("a"));
WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
EXPECT_EQ(out.str(), "uint(1)");
EXPECT_EQ(out.str(), "uint(a)");
}
} // namespace

View File

@@ -22,36 +22,39 @@ namespace {
using HlslGeneratorImplTest_Bitcast = TestHelper;
TEST_F(HlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_Float) {
auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr(1_i));
WrapInFunction(bitcast);
auto* a = Let("a", Expr(1_i));
auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr("a"));
WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
EXPECT_EQ(out.str(), "asfloat(1)");
EXPECT_EQ(out.str(), "asfloat(a)");
}
TEST_F(HlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_Int) {
auto* bitcast = create<ast::BitcastExpression>(ty.i32(), Expr(1_u));
WrapInFunction(bitcast);
auto* a = Let("a", Expr(1_u));
auto* bitcast = create<ast::BitcastExpression>(ty.i32(), Expr("a"));
WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
EXPECT_EQ(out.str(), "asint(1u)");
EXPECT_EQ(out.str(), "asint(a)");
}
TEST_F(HlslGeneratorImplTest_Bitcast, EmitExpression_Bitcast_Uint) {
auto* bitcast = create<ast::BitcastExpression>(ty.u32(), Expr(1_i));
WrapInFunction(bitcast);
auto* a = Let("a", Expr(1_i));
auto* bitcast = create<ast::BitcastExpression>(ty.u32(), Expr("a"));
WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
EXPECT_EQ(out.str(), "asuint(1)");
EXPECT_EQ(out.str(), "asuint(a)");
}
} // namespace

View File

@@ -22,14 +22,15 @@ namespace {
using MslGeneratorImplTest = TestHelper;
TEST_F(MslGeneratorImplTest, EmitExpression_Bitcast) {
auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr(1_i));
WrapInFunction(bitcast);
auto* a = Let("a", Expr(1_i));
auto* bitcast = create<ast::BitcastExpression>(ty.f32(), Expr("a"));
WrapInFunction(a, bitcast);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, bitcast)) << gen.error();
EXPECT_EQ(out.str(), "as_type<float>(1)");
EXPECT_EQ(out.str(), "as_type<float>(a)");
}
} // namespace