mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-09 21:47:47 +00:00
Const eval for fma
This CL adds const-eval for the `fma` builtin. Bug: tint:1581 Change-Id: Ia4df818fec9d5d969b364b2c165400d787a9e275 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111584 Commit-Queue: Dan Sinclair <dsinclair@chromium.org> Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
a2a8895020
commit
724ad2a290
@@ -473,8 +473,8 @@ fn faceForward<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
@const fn firstTrailingBit<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
|
||||
@const fn floor<T: fa_f32_f16>(@test_value(1.5) T) -> T
|
||||
@const fn floor<N: num, T: fa_f32_f16>(@test_value(1.5) vec<N, T>) -> vec<N, T>
|
||||
fn fma<T: f32_f16>(T, T, T) -> T
|
||||
fn fma<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
@const fn fma<T: fa_f32_f16>(T, T, T) -> T
|
||||
@const fn fma<N: num, T: fa_f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
fn fract<T: f32_f16>(T) -> T
|
||||
fn fract<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
@const fn frexp<T: fa_f32_f16>(T) -> __frexp_result<T>
|
||||
|
||||
@@ -2446,6 +2446,33 @@ ConstEval::Result ConstEval::floor(const sem::Type* ty,
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::fma(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c1, const sem::Constant* c2,
|
||||
const sem::Constant* c3) {
|
||||
auto create = [&](auto e1, auto e2, auto e3) -> ImplResult {
|
||||
auto err_msg = [&] {
|
||||
AddNote("when calculating fma", source);
|
||||
return utils::Failure;
|
||||
};
|
||||
|
||||
auto mul = Mul(source, e1, e2);
|
||||
if (!mul) {
|
||||
return err_msg();
|
||||
}
|
||||
|
||||
auto val = Add(source, mul.Get(), e3);
|
||||
if (!val) {
|
||||
return err_msg();
|
||||
}
|
||||
return CreateElement(builder, source, c1->Type(), val.Get());
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, c1, c2, c3);
|
||||
};
|
||||
return TransformElements(builder, ty, transform, args[0], args[1], args[2]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::frexp(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
|
||||
@@ -629,6 +629,15 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// fma builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result fma(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// frexp builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
|
||||
@@ -1072,6 +1072,31 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
testing::ValuesIn(Concat(FloorCases<AFloat>(), //
|
||||
FloorCases<f32>(),
|
||||
FloorCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> FmaCases() {
|
||||
auto error_msg = [](auto a, const char* op, auto b) {
|
||||
return "12:34 error: " + OverflowErrorMessage(a, op, b) + R"(
|
||||
12:34 note: when calculating fma)";
|
||||
};
|
||||
return {
|
||||
C({T(0), T(0), T(0)}, T(0)),
|
||||
C({T(1), T(2), T(3)}, T(5)),
|
||||
C({Vec(T(1), T(2.5), -T(1)), Vec(T(2), T(2.5), T(1)), Vec(T(4), T(3.75), -T(2))},
|
||||
Vec(T(6), T(10), -T(3))),
|
||||
|
||||
E({T::Highest(), T::Highest(), T(0)}, error_msg(T::Highest(), "*", T::Highest())),
|
||||
E({T::Highest(), T(1), T::Highest()}, error_msg(T::Highest(), "+", T::Highest())),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Fma,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kFma),
|
||||
testing::ValuesIn(Concat(FmaCases<AFloat>(), //
|
||||
FmaCases<f32>(),
|
||||
FmaCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> FrexpCases() {
|
||||
using F = T; // fract type
|
||||
|
||||
@@ -12510,24 +12510,24 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 3,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[462],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::fma,
|
||||
},
|
||||
{
|
||||
/* [350] */
|
||||
/* num parameters */ 3,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[465],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::fma,
|
||||
},
|
||||
{
|
||||
/* [351] */
|
||||
@@ -14245,8 +14245,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [38] */
|
||||
/* fn fma<T : f32_f16>(T, T, T) -> T */
|
||||
/* fn fma<N : num, T : f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* fn fma<T : fa_f32_f16>(T, T, T) -> T */
|
||||
/* fn fma<N : num, T : fa_f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T> */
|
||||
/* num overloads */ 2,
|
||||
/* overloads */ &kOverloads[349],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user