Const eval for fma

This CL adds const-eval for the `fma` builtin.

Bug: tint:1581
Change-Id: Ia4df818fec9d5d969b364b2c165400d787a9e275
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111584
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
dan sinclair
2022-11-24 21:54:00 +00:00
committed by Dawn LUCI CQ
parent a2a8895020
commit 724ad2a290
97 changed files with 2429 additions and 198 deletions

View File

@@ -473,8 +473,8 @@ fn faceForward<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
@const fn firstTrailingBit<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
@const fn floor<T: fa_f32_f16>(@test_value(1.5) T) -> T
@const fn floor<N: num, T: fa_f32_f16>(@test_value(1.5) vec<N, T>) -> vec<N, T>
fn fma<T: f32_f16>(T, T, T) -> T
fn fma<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
@const fn fma<T: fa_f32_f16>(T, T, T) -> T
@const fn fma<N: num, T: fa_f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
fn fract<T: f32_f16>(T) -> T
fn fract<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn frexp<T: fa_f32_f16>(T) -> __frexp_result<T>

View File

@@ -2446,6 +2446,33 @@ ConstEval::Result ConstEval::floor(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0]);
}
// Const-evaluates the `fma` builtin.
//
// The three arguments (args[0], args[1], args[2]) are walked element-wise by
// TransformElements. For each element triple the value is computed as a
// multiply followed by an add, using the Mul() and Add() helpers, i.e. as two
// separate steps rather than a single fused operation.
//
// If either step fails, a "when calculating fma" note is attached at `source`
// and utils::Failure is returned for that element.
// NOTE(review): the primary error diagnostic is presumably raised inside
// Mul()/Add() themselves - confirm against their definitions.
ConstEval::Result ConstEval::fma(const sem::Type* ty,
                                 utils::VectorRef<const sem::Constant*> args,
                                 const Source& source) {
    auto per_element = [&](const sem::Constant* lhs, const sem::Constant* rhs,
                           const sem::Constant* addend) {
        // Invoked by Dispatch_fa_f32_f16 with the unwrapped element values.
        auto compute = [&](auto x, auto y, auto z) -> ImplResult {
            if (auto product = Mul(source, x, y)) {
                if (auto sum = Add(source, product.Get(), z)) {
                    // The result element takes the type of the first operand.
                    return CreateElement(builder, source, lhs->Type(), sum.Get());
                }
            }
            // Either the multiply or the add failed: add context for the user.
            AddNote("when calculating fma", source);
            return utils::Failure;
        };
        return Dispatch_fa_f32_f16(compute, lhs, rhs, addend);
    };
    return TransformElements(builder, ty, per_element, args[0], args[1], args[2]);
}
ConstEval::Result ConstEval::frexp(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@@ -629,6 +629,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// fma builtin
/// Const-evaluates `fma(e1, e2, e3)`, computed per element as `e1 * e2 + e3`
/// (a multiply followed by an add).
/// @param ty the expression type
/// @param args the input arguments: the three operands of `fma`, in call order
/// @param source the source location, used for any diagnostics raised
/// @return the result value, or null if the value cannot be calculated
Result fma(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// frexp builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -1072,6 +1072,31 @@ INSTANTIATE_TEST_SUITE_P( //
testing::ValuesIn(Concat(FloorCases<AFloat>(), //
FloorCases<f32>(),
FloorCases<f16>()))));
// Builds the const-eval test cases for the `fma` builtin with element type T.
//
// Success cases cover scalar and vector operands. The two error cases feed
// T::Highest() so that first the multiply step, then the add step, is expected
// to fail; each expects the message from OverflowErrorMessage followed by the
// "when calculating fma" note.
template <typename T>
std::vector<Case> FmaCases() {
    // Expected diagnostic text: the error line plus the trailing fma note.
    auto expect_failure = [](auto lhs, const char* op, auto rhs) {
        return "12:34 error: " + OverflowErrorMessage(lhs, op, rhs) + R"(
12:34 note: when calculating fma)";
    };

    const auto highest = T::Highest();

    return std::vector<Case>{
        // Scalars
        C({T(0), T(0), T(0)}, T(0)),
        C({T(1), T(2), T(3)}, T(5)),
        // Vectors, evaluated element-wise
        C({Vec(T(1), T(2.5), -T(1)), Vec(T(2), T(2.5), T(1)), Vec(T(4), T(3.75), -T(2))},
          Vec(T(6), T(10), -T(3))),
        // Failure in the multiply step
        E({highest, highest, T(0)}, expect_failure(highest, "*", highest)),
        // Failure in the add step
        E({highest, T(1), highest}, expect_failure(highest, "+", highest)),
    };
}
// Instantiates ResolverConstEvalBuiltinTest for the `fma` builtin
// (sem::BuiltinType::kFma), combining it with the FmaCases generated for the
// AFloat, f32 and f16 element types.
INSTANTIATE_TEST_SUITE_P( //
Fma,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kFma),
testing::ValuesIn(Concat(FmaCases<AFloat>(), //
FmaCases<f32>(),
FmaCases<f16>()))));
template <typename T>
std::vector<Case> FrexpCases() {
using F = T; // fract type

View File

@@ -12510,24 +12510,24 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 3,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[462],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::fma,
},
{
/* [350] */
/* num parameters */ 3,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[465],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::fma,
},
{
/* [351] */
@@ -14245,8 +14245,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [38] */
/* fn fma<T : f32_f16>(T, T, T) -> T */
/* fn fma<N : num, T : f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T> */
/* fn fma<T : fa_f32_f16>(T, T, T) -> T */
/* fn fma<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[349],
},