tint: Implement const-eval of frexp()

Also add abstract overloads.

Bug: tint:1581
Fixed: tint:1768
Change-Id: Icda465e0cfe960b77823c2135f0cfe8f82ed394f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111441
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton
2022-11-23 18:21:38 +00:00
committed by Dawn LUCI CQ
parent d743778ed5
commit 69c2c34326
218 changed files with 6625 additions and 719 deletions

View File

@@ -1079,54 +1079,8 @@ TEST_F(ResolverBuiltinFloatTest, Frexp_Error_FirstParamInt) {
R"(error: no matching call to frexp(i32, ptr<workgroup, i32, read_write>)
2 candidate functions:
frexp(T) -> __frexp_result_T where: T is f32 or f16
frexp(vecN<T>) -> __frexp_result_vecN_T where: T is f32 or f16
)");
}
TEST_F(ResolverBuiltinFloatTest, Frexp_Error_SecondParamFloatPtr) {
GlobalVar("v", ty.f32(), ast::AddressSpace::kWorkgroup);
auto* call = Call("frexp", 1_f, AddressOf("v"));
WrapInFunction(call);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(error: no matching call to frexp(f32, ptr<workgroup, f32, read_write>)
2 candidate functions:
frexp(T) -> __frexp_result_T where: T is f32 or f16
frexp(vecN<T>) -> __frexp_result_vecN_T where: T is f32 or f16
)");
}
TEST_F(ResolverBuiltinFloatTest, Frexp_Error_SecondParamNotAPointer) {
auto* call = Call("frexp", 1_f, 1_i);
WrapInFunction(call);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(error: no matching call to frexp(f32, i32)
2 candidate functions:
frexp(T) -> __frexp_result_T where: T is f32 or f16
frexp(vecN<T>) -> __frexp_result_vecN_T where: T is f32 or f16
)");
}
TEST_F(ResolverBuiltinFloatTest, Frexp_Error_VectorSizesDontMatch) {
GlobalVar("v", ty.vec4<i32>(), ast::AddressSpace::kWorkgroup);
auto* call = Call("frexp", vec2<f32>(1_f, 2_f), AddressOf("v"));
WrapInFunction(call);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(error: no matching call to frexp(vec2<f32>, ptr<workgroup, vec4<i32>, read_write>)
2 candidate functions:
frexp(T) -> __frexp_result_T where: T is f32 or f16
frexp(vecN<T>) -> __frexp_result_vecN_T where: T is f32 or f16
frexp(T) -> __frexp_result_T where: T is abstract-float, f32 or f16
frexp(vecN<T>) -> __frexp_result_vecN_T where: T is abstract-float, f32 or f16
)");
}

View File

@@ -2219,6 +2219,79 @@ ConstEval::Result ConstEval::floor(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::frexp(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto* arg = args[0];
struct FractExp {
ImplResult fract;
ImplResult exp;
};
auto scalar = [&](const sem::Constant* s) {
int exp = 0;
double fract = std::frexp(s->As<AFloat>(), &exp);
return Switch(
s->Type(),
[&](const sem::F32*) {
return FractExp{
CreateElement(builder, source, builder.create<sem::F32>(), f32(fract)),
CreateElement(builder, source, builder.create<sem::I32>(), i32(exp)),
};
},
[&](const sem::F16*) {
return FractExp{
CreateElement(builder, source, builder.create<sem::F16>(), f16(fract)),
CreateElement(builder, source, builder.create<sem::I32>(), i32(exp)),
};
},
[&](const sem::AbstractFloat*) {
return FractExp{
CreateElement(builder, source, builder.create<sem::AbstractFloat>(),
AFloat(fract)),
CreateElement(builder, source, builder.create<sem::AbstractInt>(), AInt(exp)),
};
},
[&](Default) {
TINT_ICE(Resolver, builder.Diagnostics())
<< "unhandled element type for frexp() const-eval: "
<< builder.FriendlyName(s->Type());
return FractExp{utils::Failure, utils::Failure};
});
};
if (auto* vec = arg->Type()->As<sem::Vector>()) {
utils::Vector<const sem::Constant*, 4> fract_els;
utils::Vector<const sem::Constant*, 4> exp_els;
for (uint32_t i = 0; i < vec->Width(); i++) {
auto fe = scalar(arg->Index(i));
if (!fe.fract || !fe.exp) {
return utils::Failure;
}
fract_els.Push(fe.fract.Get());
exp_els.Push(fe.exp.Get());
}
auto fract_ty = builder.create<sem::Vector>(fract_els[0]->Type(), vec->Width());
auto exp_ty = builder.create<sem::Vector>(exp_els[0]->Type(), vec->Width());
return CreateComposite(builder, ty,
utils::Vector<const sem::Constant*, 2>{
CreateComposite(builder, fract_ty, std::move(fract_els)),
CreateComposite(builder, exp_ty, std::move(exp_els)),
});
} else {
auto fe = scalar(arg);
if (!fe.fract || !fe.exp) {
return utils::Failure;
}
return CreateComposite(builder, ty,
utils::Vector<const sem::Constant*, 2>{
fe.fract.Get(),
fe.exp.Get(),
});
}
}
ConstEval::Result ConstEval::insertBits(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@@ -610,6 +610,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// frexp builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result frexp(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// insertBits builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -968,6 +968,62 @@ INSTANTIATE_TEST_SUITE_P( //
testing::ValuesIn(Concat(FloorCases<AFloat>(), //
FloorCases<f32>(),
FloorCases<f16>()))));
template <typename T>
std::vector<Case> FrexpCases() {
using F = T; // fract type
using E = std::conditional_t<std::is_same_v<T, AFloat>, AInt, i32>; // exp type
auto cases = std::vector<Case>{
// Scalar tests
// in fract exp
C({T(-3.5)}, {F(-0.875), E(2)}), //
C({T(-3.0)}, {F(-0.750), E(2)}), //
C({T(-2.5)}, {F(-0.625), E(2)}), //
C({T(-2.0)}, {F(-0.500), E(2)}), //
C({T(-1.5)}, {F(-0.750), E(1)}), //
C({T(-1.0)}, {F(-0.500), E(1)}), //
C({T(+0.0)}, {F(+0.000), E(0)}), //
C({T(+1.0)}, {F(+0.500), E(1)}), //
C({T(+1.5)}, {F(+0.750), E(1)}), //
C({T(+2.0)}, {F(+0.500), E(2)}), //
C({T(+2.5)}, {F(+0.625), E(2)}), //
C({T(+3.0)}, {F(+0.750), E(2)}), //
C({T(+3.5)}, {F(+0.875), E(2)}), //
// Vector tests
// in fract exp
C({Vec(T(-2.5), T(+1.0))}, {Vec(F(-0.625), F(+0.500)), Vec(E(2), E(1))}),
C({Vec(T(+3.5), T(-2.5))}, {Vec(F(+0.875), F(-0.625)), Vec(E(2), E(2))}),
};
ConcatIntoIf<std::is_same_v<T, f16>>(cases, std::vector<Case>{
C({T::Highest()}, {F(0x0.ffep0), E(16)}), //
C({T::Lowest()}, {F(-0x0.ffep0), E(16)}), //
C({T::Smallest()}, {F(0.5), E(-13)}), //
});
ConcatIntoIf<std::is_same_v<T, f32>>(cases,
std::vector<Case>{
C({T::Highest()}, {F(0x0.ffffffp0), E(128)}), //
C({T::Lowest()}, {F(-0x0.ffffffp0), E(128)}), //
C({T::Smallest()}, {F(0.5), E(-125)}), //
});
ConcatIntoIf<std::is_same_v<T, AFloat>>(
cases, std::vector<Case>{
C({T::Highest()}, {F(0x0.fffffffffffff8p0), E(1024)}), //
C({T::Lowest()}, {F(-0x0.fffffffffffff8p0), E(1024)}), //
C({T::Smallest()}, {F(0.5), E(-1021)}), //
});
return cases;
}
INSTANTIATE_TEST_SUITE_P( //
Frexp,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kFrexp),
testing::ValuesIn(Concat(FrexpCases<AFloat>(), //
FrexpCases<f32>(), //
FrexpCases<f16>()))));
template <typename T>
std::vector<Case> InsertBitsCases() {

View File

@@ -855,6 +855,7 @@ const sem::Struct* build_modf_result(MatchState& state, const sem::Type* el) {
return nullptr;
});
}
const sem::Struct* build_modf_result_vec(MatchState& state, Number& n, const sem::Type* el) {
auto prefix = "__modf_result_vec" + std::to_string(n.Value());
auto build_f32 = [&] {
@@ -883,27 +884,70 @@ const sem::Struct* build_modf_result_vec(MatchState& state, Number& n, const sem
return nullptr;
});
}
const sem::Struct* build_frexp_result(MatchState& state, const sem::Type* el) {
std::string display_name;
if (el->Is<sem::F16>()) {
display_name = "__frexp_result_f16";
} else {
display_name = "__frexp_result";
}
auto* i32 = state.builder.create<sem::I32>();
return build_struct(state.builder, display_name, {{"fract", el}, {"exp", i32}});
auto build_f32 = [&] {
auto* f = state.builder.create<sem::F32>();
auto* i = state.builder.create<sem::I32>();
return build_struct(state.builder, "__frexp_result", {{"fract", f}, {"exp", i}});
};
auto build_f16 = [&] {
auto* f = state.builder.create<sem::F16>();
auto* i = state.builder.create<sem::I32>();
return build_struct(state.builder, "__frexp_result_f16", {{"fract", f}, {"exp", i}});
};
return Switch(
el, //
[&](const sem::F32*) { return build_f32(); }, //
[&](const sem::F16*) { return build_f16(); }, //
[&](const sem::AbstractFloat*) {
auto* i = state.builder.create<sem::AbstractInt>();
auto* abstract =
build_struct(state.builder, "__frexp_result_abstract", {{"fract", el}, {"exp", i}});
abstract->SetConcreteTypes(utils::Vector{build_f32(), build_f16()});
return abstract;
},
[&](Default) {
TINT_ICE(Resolver, state.builder.Diagnostics())
<< "unhandled frexp type: " << state.builder.FriendlyName(el);
return nullptr;
});
}
const sem::Struct* build_frexp_result_vec(MatchState& state, Number& n, const sem::Type* el) {
std::string display_name;
if (el->Is<sem::F16>()) {
display_name = "__frexp_result_vec" + std::to_string(n.Value()) + "_f16";
} else {
display_name = "__frexp_result_vec" + std::to_string(n.Value());
}
auto* vec = state.builder.create<sem::Vector>(el, n.Value());
auto* vec_i32 = state.builder.create<sem::Vector>(state.builder.create<sem::I32>(), n.Value());
return build_struct(state.builder, display_name, {{"fract", vec}, {"exp", vec_i32}});
auto prefix = "__frexp_result_vec" + std::to_string(n.Value());
auto build_f32 = [&] {
auto* f = state.builder.create<sem::Vector>(state.builder.create<sem::F32>(), n.Value());
auto* e = state.builder.create<sem::Vector>(state.builder.create<sem::I32>(), n.Value());
return build_struct(state.builder, prefix, {{"fract", f}, {"exp", e}});
};
auto build_f16 = [&] {
auto* f = state.builder.create<sem::Vector>(state.builder.create<sem::F16>(), n.Value());
auto* e = state.builder.create<sem::Vector>(state.builder.create<sem::I32>(), n.Value());
return build_struct(state.builder, prefix + "_f16", {{"fract", f}, {"exp", e}});
};
return Switch(
el, //
[&](const sem::F32*) { return build_f32(); }, //
[&](const sem::F16*) { return build_f16(); }, //
[&](const sem::AbstractFloat*) {
auto* f = state.builder.create<sem::Vector>(el, n.Value());
auto* e = state.builder.create<sem::Vector>(state.builder.create<sem::AbstractInt>(),
n.Value());
auto* abstract =
build_struct(state.builder, prefix + "_abstract", {{"fract", f}, {"exp", e}});
abstract->SetConcreteTypes(utils::Vector{build_f32(), build_f16()});
return abstract;
},
[&](Default) {
TINT_ICE(Resolver, state.builder.Diagnostics())
<< "unhandled frexp type: " << state.builder.FriendlyName(el);
return nullptr;
});
}
const sem::Struct* build_atomic_compare_exchange_result(MatchState& state, const sem::Type* ty) {
return build_struct(
state.builder,

View File

@@ -12558,24 +12558,24 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[862],
/* return matcher indices */ &kMatcherIndices[104],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::frexp,
},
{
/* [354] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[863],
/* return matcher indices */ &kMatcherIndices[39],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::frexp,
},
{
/* [355] */
@@ -14259,8 +14259,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [40] */
/* fn frexp<T : f32_f16>(T) -> __frexp_result<T> */
/* fn frexp<N : num, T : f32_f16>(vec<N, T>) -> __frexp_result_vec<N, T> */
/* fn frexp<T : fa_f32_f16>(T) -> __frexp_result<T> */
/* fn frexp<N : num, T : fa_f32_f16>(vec<N, T>) -> __frexp_result_vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[353],
},

View File

@@ -1320,6 +1320,94 @@ TEST_F(MaterializeAbstractStructure, Modf_Vector_ExplicitType) {
abstract_str->Members()[0]->Type()->As<sem::Vector>()->type()->Is<sem::AbstractFloat>());
}
TEST_F(MaterializeAbstractStructure, Frexp_Scalar_DefaultType) {
// var v = frexp(1);
auto* call = Call("frexp", 1_a);
WrapInFunction(Decl(Var("v", call)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(call);
ASSERT_TRUE(sem->Is<sem::Materialize>());
auto* materialize = sem->As<sem::Materialize>();
ASSERT_TRUE(materialize->Type()->Is<sem::Struct>());
auto* concrete_str = materialize->Type()->As<sem::Struct>();
ASSERT_TRUE(concrete_str->Members()[0]->Type()->Is<sem::F32>());
ASSERT_TRUE(concrete_str->Members()[1]->Type()->Is<sem::I32>());
ASSERT_TRUE(materialize->Expr()->Type()->Is<sem::Struct>());
auto* abstract_str = materialize->Expr()->Type()->As<sem::Struct>();
ASSERT_TRUE(abstract_str->Members()[0]->Type()->Is<sem::AbstractFloat>());
ASSERT_TRUE(abstract_str->Members()[1]->Type()->Is<sem::AbstractInt>());
}
TEST_F(MaterializeAbstractStructure, Frexp_Vector_DefaultType) {
// var v = frexp(vec2(1));
auto* call = Call("frexp", Construct(ty.vec2(nullptr), 1_a));
WrapInFunction(Decl(Var("v", call)));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(call);
ASSERT_TRUE(sem->Is<sem::Materialize>());
auto* materialize = sem->As<sem::Materialize>();
ASSERT_TRUE(materialize->Type()->Is<sem::Struct>());
auto* concrete_str = materialize->Type()->As<sem::Struct>();
ASSERT_TRUE(concrete_str->Members()[0]->Type()->Is<sem::Vector>());
ASSERT_TRUE(concrete_str->Members()[1]->Type()->Is<sem::Vector>());
ASSERT_TRUE(concrete_str->Members()[0]->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
ASSERT_TRUE(concrete_str->Members()[1]->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
ASSERT_TRUE(materialize->Expr()->Type()->Is<sem::Struct>());
auto* abstract_str = materialize->Expr()->Type()->As<sem::Struct>();
ASSERT_TRUE(abstract_str->Members()[0]->Type()->Is<sem::Vector>());
ASSERT_TRUE(
abstract_str->Members()[0]->Type()->As<sem::Vector>()->type()->Is<sem::AbstractFloat>());
ASSERT_TRUE(
abstract_str->Members()[1]->Type()->As<sem::Vector>()->type()->Is<sem::AbstractInt>());
}
TEST_F(MaterializeAbstractStructure, Frexp_Scalar_ExplicitType) {
// var v = frexp(1_h); // v is __frexp_result_f16
// v = frexp(1); // __frexp_result_f16 <- __frexp_result_abstract
Enable(ast::Extension::kF16);
auto* call = Call("frexp", 1_a);
WrapInFunction(Decl(Var("v", Call("frexp", 1_h))), //
Assign("v", call));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(call);
ASSERT_TRUE(sem->Is<sem::Materialize>());
auto* materialize = sem->As<sem::Materialize>();
ASSERT_TRUE(materialize->Type()->Is<sem::Struct>());
auto* concrete_str = materialize->Type()->As<sem::Struct>();
ASSERT_TRUE(concrete_str->Members()[0]->Type()->Is<sem::F16>());
ASSERT_TRUE(concrete_str->Members()[1]->Type()->Is<sem::I32>());
ASSERT_TRUE(materialize->Expr()->Type()->Is<sem::Struct>());
auto* abstract_str = materialize->Expr()->Type()->As<sem::Struct>();
ASSERT_TRUE(abstract_str->Members()[0]->Type()->Is<sem::AbstractFloat>());
ASSERT_TRUE(abstract_str->Members()[1]->Type()->Is<sem::AbstractInt>());
}
TEST_F(MaterializeAbstractStructure, Frexp_Vector_ExplicitType) {
// var v = frexp(vec2(1_h)); // v is __frexp_result_vec2_f16
// v = frexp(vec2(1)); // __frexp_result_vec2_f16 <- __frexp_result_vec2_abstract
Enable(ast::Extension::kF16);
auto* call = Call("frexp", Construct(ty.vec2(nullptr), 1_a));
WrapInFunction(Decl(Var("v", Call("frexp", Construct(ty.vec2(nullptr), 1_h)))),
Assign("v", call));
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(call);
ASSERT_TRUE(sem->Is<sem::Materialize>());
auto* materialize = sem->As<sem::Materialize>();
ASSERT_TRUE(materialize->Type()->Is<sem::Struct>());
auto* concrete_str = materialize->Type()->As<sem::Struct>();
ASSERT_TRUE(concrete_str->Members()[0]->Type()->Is<sem::Vector>());
ASSERT_TRUE(concrete_str->Members()[1]->Type()->Is<sem::Vector>());
ASSERT_TRUE(concrete_str->Members()[0]->Type()->As<sem::Vector>()->type()->Is<sem::F16>());
ASSERT_TRUE(concrete_str->Members()[1]->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
ASSERT_TRUE(materialize->Expr()->Type()->Is<sem::Struct>());
auto* abstract_str = materialize->Expr()->Type()->As<sem::Struct>();
ASSERT_TRUE(abstract_str->Members()[0]->Type()->Is<sem::Vector>());
ASSERT_TRUE(
abstract_str->Members()[0]->Type()->As<sem::Vector>()->type()->Is<sem::AbstractFloat>());
ASSERT_TRUE(
abstract_str->Members()[1]->Type()->As<sem::Vector>()->type()->Is<sem::AbstractInt>());
}
} // namespace materialize_abstract_structure
} // namespace