Implement const-eval for some pack routines.

This CL adds const-eval for pack2x16snorm, pack2x16unorm,
pack4x8snorm and pack4x8unorm.

Bug: tint:1581
Change-Id: I58d8f02da32a6a173ca54ee5110ca7be39e2c52f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/108466
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair
2022-11-07 14:32:16 +00:00
committed by Dawn LUCI CQ
parent 3b2b5484e2
commit 2d706a0436
26 changed files with 307 additions and 205 deletions

View File

@@ -508,10 +508,10 @@ fn modf<T: f32_f16>(T) -> __modf_result<T>
fn modf<N: num, T: f32_f16>(vec<N, T>) -> __modf_result_vec<N, T>
fn normalize<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn pack2x16float(vec2<f32>) -> u32
fn pack2x16snorm(vec2<f32>) -> u32
fn pack2x16unorm(vec2<f32>) -> u32
fn pack4x8snorm(vec4<f32>) -> u32
fn pack4x8unorm(vec4<f32>) -> u32
@const fn pack2x16snorm(vec2<f32>) -> u32
@const fn pack2x16unorm(vec2<f32>) -> u32
@const fn pack4x8snorm(vec4<f32>) -> u32
@const fn pack4x8unorm(vec4<f32>) -> u32
fn pow<T: f32_f16>(T, T) -> T
fn pow<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
@const fn quantizeToF16(f32) -> f32

View File

@@ -37,6 +37,7 @@
#include "src/tint/sem/type_initializer.h"
#include "src/tint/sem/u32.h"
#include "src/tint/sem/vector.h"
#include "src/tint/utils/bitcast.h"
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/scoped_assignment.h"
@@ -793,6 +794,20 @@ utils::Result<NumberT> ConstEval::Dot4(NumberT a1,
return r;
}
// Restricts `e` to the range bounded by `low` and `high`: the value is first
// raised to at least `low`, and that result is then capped at `high`.
template <typename NumberT>
utils::Result<NumberT> ConstEval::Clamp(NumberT e, NumberT low, NumberT high) {
    const NumberT& raised = std::max(e, low);
    const NumberT& capped = std::min(raised, high);
    return NumberT{capped};
}
// Builds a callable that clamps its first argument between the other two via
// Clamp() and wraps a successful result in a constant element of type
// `elem_ty`. A failed Clamp() is propagated as utils::Failure.
auto ConstEval::ClampFunc(const sem::Type* elem_ty) {
    return [=](auto e, auto low, auto high) -> ImplResult {
        auto clamped = Clamp(e, low, high);
        if (!clamped) {
            return utils::Failure;
        }
        return CreateElement(builder, elem_ty, clamped.Get());
    };
}
auto ConstEval::AddFunc(const sem::Type* elem_ty) {
return [=](auto a1, auto a2) -> ImplResult {
if (auto r = Add(a1, a2)) {
@@ -1727,11 +1742,7 @@ ConstEval::Result ConstEval::clamp(const sem::Type* ty,
const Source&) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1,
const sem::Constant* c2) {
auto create = [&](auto e, auto low, auto high) {
return CreateElement(builder, c0->Type(),
decltype(e)(std::min(std::max(e, low), high)));
};
return Dispatch_fia_fiu32_f16(create, c0, c1, c2);
return Dispatch_fia_fiu32_f16(ClampFunc(c0->Type()), c0, c1, c2);
};
return TransformElements(builder, ty, transform, args[0], args[1], args[2]);
}
@@ -1979,6 +1990,78 @@ ConstEval::Result ConstEval::insertBits(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]);
}
// Const-evaluates pack2x16snorm: each of the two f32 components of args[0] is
// converted to a 16-bit signed normalized value, and the pair is packed into a
// single u32 with component 0 in the low 16 bits and component 1 in the high
// 16 bits.
ConstEval::Result ConstEval::pack2x16snorm(const sem::Type* ty,
                                           utils::VectorRef<const sem::Constant*> args,
                                           const Source&) {
    // Clamp to [-1, 1], scale by 32767, round via floor(x + 0.5), then
    // reinterpret the signed 16-bit result as its unsigned bit pattern.
    auto to_snorm16 = [&](f32 component) -> u32 {
        const f32 limited = Clamp(component, f32(-1.0f), f32(1.0f)).Get();
        const auto as_signed = static_cast<int16_t>(std::floor(0.5f + (32767.0f * limited)));
        return u32(utils::Bitcast<uint16_t>(as_signed));
    };

    auto* input = args[0];
    const u32 low_half = to_snorm16(input->Index(0)->As<f32>());
    const u32 high_half = to_snorm16(input->Index(1)->As<f32>());

    const u32 packed = u32((low_half & 0x0000'ffff) | (high_half << 16));
    return CreateElement(builder, ty, packed);
}
// Const-evaluates pack2x16unorm: each of the two f32 components of args[0] is
// converted to a 16-bit unsigned normalized value, and the pair is packed into
// a single u32 with component 0 in the low 16 bits and component 1 in the high
// 16 bits.
ConstEval::Result ConstEval::pack2x16unorm(const sem::Type* ty,
                                           utils::VectorRef<const sem::Constant*> args,
                                           const Source&) {
    // Clamp to [0, 1], scale by 65535 and round via floor(x + 0.5).
    auto to_unorm16 = [&](f32 component) -> u32 {
        const f32 limited = Clamp(component, f32(0.0f), f32(1.0f)).Get();
        return u32{std::floor(0.5f + (65535.0f * limited))};
    };

    auto* input = args[0];
    const u32 low_half = to_unorm16(input->Index(0)->As<f32>());
    const u32 high_half = to_unorm16(input->Index(1)->As<f32>());

    const u32 packed = u32((low_half & 0x0000'ffff) | (high_half << 16));
    return CreateElement(builder, ty, packed);
}
// Const-evaluates pack4x8snorm: each of the four f32 components of args[0] is
// converted to an 8-bit signed normalized value, and the four bytes are packed
// into a single u32 with component i occupying bits [8*i, 8*i + 7].
ConstEval::Result ConstEval::pack4x8snorm(const sem::Type* ty,
                                          utils::VectorRef<const sem::Constant*> args,
                                          const Source&) {
    // Clamp to [-1, 1], scale by 127, round via floor(x + 0.5), then
    // reinterpret the signed 8-bit result as its unsigned bit pattern.
    auto to_snorm8 = [&](f32 component) -> u32 {
        const f32 limited = Clamp(component, f32(-1.0f), f32(1.0f)).Get();
        const auto as_signed = static_cast<int8_t>(std::floor(0.5f + (127.0f * limited)));
        return u32(utils::Bitcast<uint8_t>(as_signed));
    };

    auto* input = args[0];
    const uint32_t byte_mask = 0x0000'00ff;

    // Accumulate the four lanes in component order, lowest byte first.
    uint32_t packed = 0;
    for (size_t lane = 0; lane < 4; ++lane) {
        const u32 bits = to_snorm8(input->Index(lane)->As<f32>());
        packed |= (bits & byte_mask) << (8u * lane);
    }
    return CreateElement(builder, ty, u32(packed));
}
// Const-evaluates pack4x8unorm: each of the four f32 components of args[0] is
// converted to an 8-bit unsigned normalized value, and the four bytes are
// packed into a single u32 with component i occupying bits [8*i, 8*i + 7].
ConstEval::Result ConstEval::pack4x8unorm(const sem::Type* ty,
                                          utils::VectorRef<const sem::Constant*> args,
                                          const Source&) {
    // Clamp to [0, 1], scale by 255 and round via floor(x + 0.5).
    auto to_unorm8 = [&](f32 component) -> u32 {
        const f32 limited = Clamp(component, f32(0.0f), f32(1.0f)).Get();
        return u32{std::floor(0.5f + (255.0f * limited))};
    };

    auto* input = args[0];
    const uint32_t byte_mask = 0x0000'00ff;

    // Accumulate the four lanes in component order, lowest byte first.
    uint32_t packed = 0;
    for (size_t lane = 0; lane < 4; ++lane) {
        const u32 bits = to_unorm8(input->Index(lane)->As<f32>());
        packed |= (bits & byte_mask) << (8u * lane);
    }
    return CreateElement(builder, ty, u32(packed));
}
ConstEval::Result ConstEval::reverseBits(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {

View File

@@ -548,6 +548,42 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// pack2x16snorm builtin
/// Clamps each of the two f32 components of the argument to [-1, 1], converts it to a 16-bit
/// signed normalized integer (`floor(0.5 + 32767 * c)`), and packs both into one u32 with
/// component 0 in the low 16 bits and component 1 in the high 16 bits.
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result pack2x16snorm(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// pack2x16unorm builtin
/// Clamps each of the two f32 components of the argument to [0, 1], converts it to a 16-bit
/// unsigned normalized integer (`floor(0.5 + 65535 * c)`), and packs both into one u32 with
/// component 0 in the low 16 bits and component 1 in the high 16 bits.
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result pack2x16unorm(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// pack4x8snorm builtin
/// Clamps each of the four f32 components of the argument to [-1, 1], converts it to an 8-bit
/// signed normalized integer (`floor(0.5 + 127 * c)`), and packs all four into one u32 with
/// component 0 in the lowest byte and component 3 in the highest byte.
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result pack4x8snorm(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// pack4x8unorm builtin
/// Clamps each of the four f32 components of the argument to [0, 1], converts it to an 8-bit
/// unsigned normalized integer (`floor(0.5 + 255 * c)`), and packs all four into one u32 with
/// component 0 in the lowest byte and component 3 in the highest byte.
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result pack4x8unorm(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// reverseBits builtin
/// @param ty the expression type
/// @param args the input arguments
@@ -677,6 +713,14 @@ class ConstEval {
NumberT b3,
NumberT b4);
/// Clamps e between low and high
/// The result is computed as `min(max(e, low), high)`, so when `low` is greater than `high`
/// the upper bound `high` wins.
/// @param e the number to clamp
/// @param low the lower bound
/// @param high the upper bound
/// @returns the result number on success, or logs an error and returns Failure
/// @note NOTE(review): the current definition always produces a value and has no failure
/// path; the Result return type matches the other arithmetic helpers in this class.
template <typename NumberT>
utils::Result<NumberT> Clamp(NumberT e, NumberT low, NumberT high);
/// Returns a callable that calls Add, and creates a Constant with its result of type `elem_ty`
/// if successful, or returns Failure otherwise.
/// @param elem_ty the element type of the Constant to create on success
@@ -707,6 +751,12 @@ class ConstEval {
/// @returns the callable function
auto Dot4Func(const sem::Type* elem_ty);
/// Returns a callable that calls Clamp, and creates a Constant with its result of type
/// `elem_ty` if successful, or returns Failure otherwise.
/// The returned callable takes three arguments, in order: the value to clamp, the lower bound
/// and the upper bound.
/// @param elem_ty the element type of the Constant to create on success
/// @returns the callable function
auto ClampFunc(const sem::Type* elem_ty);
ProgramBuilder& builder;
const Source* current_source = nullptr;
};

View File

@@ -1093,6 +1093,73 @@ INSTANTIATE_TEST_SUITE_P(ExtractBits,
std::make_tuple(1000, 1000), //
std::make_tuple(u32::Highest(), u32::Highest())));
std::vector<Case> Pack4x8snormCases() {
return {
C({Vec(f32(0), f32(0), f32(0), f32(0))}, Val(u32(0x0000'0000))),
C({Vec(f32(0), f32(0), f32(0), f32(-1))}, Val(u32(0x8100'0000))),
C({Vec(f32(0), f32(0), f32(0), f32(1))}, Val(u32(0x7f00'0000))),
C({Vec(f32(0), f32(0), f32(-1), f32(0))}, Val(u32(0x0081'0000))),
C({Vec(f32(0), f32(1), f32(0), f32(0))}, Val(u32(0x0000'7f00))),
C({Vec(f32(-1), f32(0), f32(0), f32(0))}, Val(u32(0x0000'0081))),
C({Vec(f32(1), f32(-1), f32(1), f32(-1))}, Val(u32(0x817f'817f))),
C({Vec(f32::Highest(), f32(-0.5), f32(0.5), f32::Lowest())}, Val(u32(0x8140'c17f))),
};
}
INSTANTIATE_TEST_SUITE_P( //
Pack4x8snorm,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kPack4X8Snorm),
testing::ValuesIn(Pack4x8snormCases())));
std::vector<Case> Pack4x8unormCases() {
return {
C({Vec(f32(0), f32(0), f32(0), f32(0))}, Val(u32(0x0000'0000))),
C({Vec(f32(0), f32(0), f32(0), f32(1))}, Val(u32(0xff00'0000))),
C({Vec(f32(0), f32(0), f32(1), f32(0))}, Val(u32(0x00ff'0000))),
C({Vec(f32(0), f32(1), f32(0), f32(0))}, Val(u32(0x0000'ff00))),
C({Vec(f32(1), f32(0), f32(0), f32(0))}, Val(u32(0x0000'00ff))),
C({Vec(f32(1), f32(0), f32(1), f32(0))}, Val(u32(0x00ff'00ff))),
C({Vec(f32::Highest(), f32(0), f32(0.5), f32::Lowest())}, Val(u32(0x0080'00ff))),
};
}
INSTANTIATE_TEST_SUITE_P( //
Pack4x8unorm,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kPack4X8Unorm),
testing::ValuesIn(Pack4x8unormCases())));
std::vector<Case> Pack2x16snormCases() {
return {
C({Vec(f32(0), f32(0))}, Val(u32(0x0000'0000))),
C({Vec(f32(0), f32(-1))}, Val(u32(0x8001'0000))),
C({Vec(f32(0), f32(1))}, Val(u32(0x7fff'0000))),
C({Vec(f32(-1), f32(0))}, Val(u32(0x0000'8001))),
C({Vec(f32(1), f32(0))}, Val(u32(0x0000'7fff))),
C({Vec(f32(1), f32(-1))}, Val(u32(0x8001'7fff))),
C({Vec(f32::Highest(), f32::Lowest())}, Val(u32(0x8001'7fff))),
C({Vec(f32(-0.5), f32(0.5))}, Val(u32(0x4000'c001))),
};
}
INSTANTIATE_TEST_SUITE_P( //
Pack2x16snorm,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kPack2X16Snorm),
testing::ValuesIn(Pack2x16snormCases())));
std::vector<Case> Pack2x16unormCases() {
return {
C({Vec(f32(0), f32(1))}, Val(u32(0xffff'0000))),
C({Vec(f32(1), f32(0))}, Val(u32(0x0000'ffff))),
C({Vec(f32(0.5), f32(0))}, Val(u32(0x0000'8000))),
C({Vec(f32::Highest(), f32::Lowest())}, Val(u32(0x0000'ffff))),
};
}
INSTANTIATE_TEST_SUITE_P( //
Pack2x16unorm,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kPack2X16Unorm),
testing::ValuesIn(Pack2x16unormCases())));
template <typename T>
std::vector<Case> ReverseBitsCases() {
using B = BitValues<T>;

View File

@@ -13946,7 +13946,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[878],
/* return matcher indices */ &kMatcherIndices[95],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::pack4x8unorm,
},
{
/* [468] */
@@ -13958,7 +13958,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[877],
/* return matcher indices */ &kMatcherIndices[95],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::pack4x8snorm,
},
{
/* [469] */
@@ -13970,7 +13970,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[868],
/* return matcher indices */ &kMatcherIndices[95],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::pack2x16unorm,
},
{
/* [470] */
@@ -13982,7 +13982,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[867],
/* return matcher indices */ &kMatcherIndices[95],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::pack2x16snorm,
},
{
/* [471] */