mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-10 05:57:51 +00:00
Implement const-eval for some pack routines.
This CL adds const-eval for pack2x16snorm, pack2x16unorm, pack4x8snorm and pack4x8unorm. Bug: tint:1581 Change-Id: I58d8f02da32a6a173ca54ee5110ca7be39e2c52f Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/108466 Reviewed-by: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
3b2b5484e2
commit
2d706a0436
@@ -508,10 +508,10 @@ fn modf<T: f32_f16>(T) -> __modf_result<T>
|
||||
fn modf<N: num, T: f32_f16>(vec<N, T>) -> __modf_result_vec<N, T>
|
||||
fn normalize<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
|
||||
fn pack2x16float(vec2<f32>) -> u32
|
||||
fn pack2x16snorm(vec2<f32>) -> u32
|
||||
fn pack2x16unorm(vec2<f32>) -> u32
|
||||
fn pack4x8snorm(vec4<f32>) -> u32
|
||||
fn pack4x8unorm(vec4<f32>) -> u32
|
||||
@const fn pack2x16snorm(vec2<f32>) -> u32
|
||||
@const fn pack2x16unorm(vec2<f32>) -> u32
|
||||
@const fn pack4x8snorm(vec4<f32>) -> u32
|
||||
@const fn pack4x8unorm(vec4<f32>) -> u32
|
||||
fn pow<T: f32_f16>(T, T) -> T
|
||||
fn pow<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
@const fn quantizeToF16(f32) -> f32
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
#include "src/tint/sem/type_initializer.h"
|
||||
#include "src/tint/sem/u32.h"
|
||||
#include "src/tint/sem/vector.h"
|
||||
#include "src/tint/utils/bitcast.h"
|
||||
#include "src/tint/utils/compiler_macros.h"
|
||||
#include "src/tint/utils/map.h"
|
||||
#include "src/tint/utils/scoped_assignment.h"
|
||||
@@ -793,6 +794,20 @@ utils::Result<NumberT> ConstEval::Dot4(NumberT a1,
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> ConstEval::Clamp(NumberT e, NumberT low, NumberT high) {
|
||||
return NumberT{std::min(std::max(e, low), high)};
|
||||
}
|
||||
|
||||
auto ConstEval::ClampFunc(const sem::Type* elem_ty) {
|
||||
return [=](auto e, auto low, auto high) -> ImplResult {
|
||||
if (auto r = Clamp(e, low, high)) {
|
||||
return CreateElement(builder, elem_ty, r.Get());
|
||||
}
|
||||
return utils::Failure;
|
||||
};
|
||||
}
|
||||
|
||||
auto ConstEval::AddFunc(const sem::Type* elem_ty) {
|
||||
return [=](auto a1, auto a2) -> ImplResult {
|
||||
if (auto r = Add(a1, a2)) {
|
||||
@@ -1727,11 +1742,7 @@ ConstEval::Result ConstEval::clamp(const sem::Type* ty,
|
||||
const Source&) {
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1,
|
||||
const sem::Constant* c2) {
|
||||
auto create = [&](auto e, auto low, auto high) {
|
||||
return CreateElement(builder, c0->Type(),
|
||||
decltype(e)(std::min(std::max(e, low), high)));
|
||||
};
|
||||
return Dispatch_fia_fiu32_f16(create, c0, c1, c2);
|
||||
return Dispatch_fia_fiu32_f16(ClampFunc(c0->Type()), c0, c1, c2);
|
||||
};
|
||||
return TransformElements(builder, ty, transform, args[0], args[1], args[2]);
|
||||
}
|
||||
@@ -1979,6 +1990,78 @@ ConstEval::Result ConstEval::insertBits(const sem::Type* ty,
|
||||
return TransformElements(builder, ty, transform, args[0], args[1]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::pack2x16snorm(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
auto calc = [&](f32 val) -> u32 {
|
||||
auto clamped = Clamp(val, f32(-1.0f), f32(1.0f)).Get();
|
||||
return u32(utils::Bitcast<uint16_t>(
|
||||
static_cast<int16_t>(std::floor(0.5f + (32767.0f * clamped)))));
|
||||
};
|
||||
|
||||
auto* e = args[0];
|
||||
auto e0 = calc(e->Index(0)->As<f32>());
|
||||
auto e1 = calc(e->Index(1)->As<f32>());
|
||||
|
||||
u32 ret = u32((e0 & 0x0000'ffff) | (e1 << 16));
|
||||
return CreateElement(builder, ty, ret);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::pack2x16unorm(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
auto calc = [&](f32 val) -> u32 {
|
||||
auto clamped = Clamp(val, f32(0.0f), f32(1.0f)).Get();
|
||||
return u32{std::floor(0.5f + (65535.0f * clamped))};
|
||||
};
|
||||
|
||||
auto* e = args[0];
|
||||
auto e0 = calc(e->Index(0)->As<f32>());
|
||||
auto e1 = calc(e->Index(1)->As<f32>());
|
||||
|
||||
u32 ret = u32((e0 & 0x0000'ffff) | (e1 << 16));
|
||||
return CreateElement(builder, ty, ret);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::pack4x8snorm(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
auto calc = [&](f32 val) -> u32 {
|
||||
auto clamped = Clamp(val, f32(-1.0f), f32(1.0f)).Get();
|
||||
return u32(
|
||||
utils::Bitcast<uint8_t>(static_cast<int8_t>(std::floor(0.5f + (127.0f * clamped)))));
|
||||
};
|
||||
|
||||
auto* e = args[0];
|
||||
auto e0 = calc(e->Index(0)->As<f32>());
|
||||
auto e1 = calc(e->Index(1)->As<f32>());
|
||||
auto e2 = calc(e->Index(2)->As<f32>());
|
||||
auto e3 = calc(e->Index(3)->As<f32>());
|
||||
|
||||
uint32_t mask = 0x0000'00ff;
|
||||
u32 ret = u32((e0 & mask) | ((e1 & mask) << 8) | ((e2 & mask) << 16) | ((e3 & mask) << 24));
|
||||
return CreateElement(builder, ty, ret);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::pack4x8unorm(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
auto calc = [&](f32 val) -> u32 {
|
||||
auto clamped = Clamp(val, f32(0.0f), f32(1.0f)).Get();
|
||||
return u32{std::floor(0.5f + (255.0f * clamped))};
|
||||
};
|
||||
|
||||
auto* e = args[0];
|
||||
auto e0 = calc(e->Index(0)->As<f32>());
|
||||
auto e1 = calc(e->Index(1)->As<f32>());
|
||||
auto e2 = calc(e->Index(2)->As<f32>());
|
||||
auto e3 = calc(e->Index(3)->As<f32>());
|
||||
|
||||
uint32_t mask = 0x0000'00ff;
|
||||
u32 ret = u32((e0 & mask) | ((e1 & mask) << 8) | ((e2 & mask) << 16) | ((e3 & mask) << 24));
|
||||
return CreateElement(builder, ty, ret);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::reverseBits(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
|
||||
@@ -548,6 +548,42 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// pack2x16snorm builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result pack2x16snorm(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// pack2x16unorm builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result pack2x16unorm(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// pack4x8snorm builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result pack4x8snorm(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// pack4x8unorm builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result pack4x8unorm(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// reverseBits builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
@@ -677,6 +713,14 @@ class ConstEval {
|
||||
NumberT b3,
|
||||
NumberT b4);
|
||||
|
||||
/// Clamps e between low and high
|
||||
/// @param e the number to clamp
|
||||
/// @param low the lower bound
|
||||
/// @param high the upper bound
|
||||
/// @returns the result number on success, or logs an error and returns Failure
|
||||
template <typename NumberT>
|
||||
utils::Result<NumberT> Clamp(NumberT e, NumberT low, NumberT high);
|
||||
|
||||
/// Returns a callable that calls Add, and creates a Constant with its result of type `elem_ty`
|
||||
/// if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
@@ -707,6 +751,12 @@ class ConstEval {
|
||||
/// @returns the callable function
|
||||
auto Dot4Func(const sem::Type* elem_ty);
|
||||
|
||||
/// Returns a callable that calls Clamp, and creates a Constant with its result of type
|
||||
/// `elem_ty` if successful, or returns Failure otherwise.
|
||||
/// @param elem_ty the element type of the Constant to create on success
|
||||
/// @returns the callable function
|
||||
auto ClampFunc(const sem::Type* elem_ty);
|
||||
|
||||
ProgramBuilder& builder;
|
||||
const Source* current_source = nullptr;
|
||||
};
|
||||
|
||||
@@ -1093,6 +1093,73 @@ INSTANTIATE_TEST_SUITE_P(ExtractBits,
|
||||
std::make_tuple(1000, 1000), //
|
||||
std::make_tuple(u32::Highest(), u32::Highest())));
|
||||
|
||||
std::vector<Case> Pack4x8snormCases() {
|
||||
return {
|
||||
C({Vec(f32(0), f32(0), f32(0), f32(0))}, Val(u32(0x0000'0000))),
|
||||
C({Vec(f32(0), f32(0), f32(0), f32(-1))}, Val(u32(0x8100'0000))),
|
||||
C({Vec(f32(0), f32(0), f32(0), f32(1))}, Val(u32(0x7f00'0000))),
|
||||
C({Vec(f32(0), f32(0), f32(-1), f32(0))}, Val(u32(0x0081'0000))),
|
||||
C({Vec(f32(0), f32(1), f32(0), f32(0))}, Val(u32(0x0000'7f00))),
|
||||
C({Vec(f32(-1), f32(0), f32(0), f32(0))}, Val(u32(0x0000'0081))),
|
||||
C({Vec(f32(1), f32(-1), f32(1), f32(-1))}, Val(u32(0x817f'817f))),
|
||||
C({Vec(f32::Highest(), f32(-0.5), f32(0.5), f32::Lowest())}, Val(u32(0x8140'c17f))),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Pack4x8snorm,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kPack4X8Snorm),
|
||||
testing::ValuesIn(Pack4x8snormCases())));
|
||||
|
||||
std::vector<Case> Pack4x8unormCases() {
|
||||
return {
|
||||
C({Vec(f32(0), f32(0), f32(0), f32(0))}, Val(u32(0x0000'0000))),
|
||||
C({Vec(f32(0), f32(0), f32(0), f32(1))}, Val(u32(0xff00'0000))),
|
||||
C({Vec(f32(0), f32(0), f32(1), f32(0))}, Val(u32(0x00ff'0000))),
|
||||
C({Vec(f32(0), f32(1), f32(0), f32(0))}, Val(u32(0x0000'ff00))),
|
||||
C({Vec(f32(1), f32(0), f32(0), f32(0))}, Val(u32(0x0000'00ff))),
|
||||
C({Vec(f32(1), f32(0), f32(1), f32(0))}, Val(u32(0x00ff'00ff))),
|
||||
C({Vec(f32::Highest(), f32(0), f32(0.5), f32::Lowest())}, Val(u32(0x0080'00ff))),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Pack4x8unorm,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kPack4X8Unorm),
|
||||
testing::ValuesIn(Pack4x8unormCases())));
|
||||
|
||||
std::vector<Case> Pack2x16snormCases() {
|
||||
return {
|
||||
C({Vec(f32(0), f32(0))}, Val(u32(0x0000'0000))),
|
||||
C({Vec(f32(0), f32(-1))}, Val(u32(0x8001'0000))),
|
||||
C({Vec(f32(0), f32(1))}, Val(u32(0x7fff'0000))),
|
||||
C({Vec(f32(-1), f32(0))}, Val(u32(0x0000'8001))),
|
||||
C({Vec(f32(1), f32(0))}, Val(u32(0x0000'7fff))),
|
||||
C({Vec(f32(1), f32(-1))}, Val(u32(0x8001'7fff))),
|
||||
C({Vec(f32::Highest(), f32::Lowest())}, Val(u32(0x8001'7fff))),
|
||||
C({Vec(f32(-0.5), f32(0.5))}, Val(u32(0x4000'c001))),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Pack2x16snorm,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kPack2X16Snorm),
|
||||
testing::ValuesIn(Pack2x16snormCases())));
|
||||
|
||||
std::vector<Case> Pack2x16unormCases() {
|
||||
return {
|
||||
C({Vec(f32(0), f32(1))}, Val(u32(0xffff'0000))),
|
||||
C({Vec(f32(1), f32(0))}, Val(u32(0x0000'ffff))),
|
||||
C({Vec(f32(0.5), f32(0))}, Val(u32(0x0000'8000))),
|
||||
C({Vec(f32::Highest(), f32::Lowest())}, Val(u32(0x0000'ffff))),
|
||||
};
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Pack2x16unorm,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kPack2X16Unorm),
|
||||
testing::ValuesIn(Pack2x16unormCases())));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> ReverseBitsCases() {
|
||||
using B = BitValues<T>;
|
||||
|
||||
@@ -13946,7 +13946,7 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* parameters */ &kParameters[878],
|
||||
/* return matcher indices */ &kMatcherIndices[95],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::pack4x8unorm,
|
||||
},
|
||||
{
|
||||
/* [468] */
|
||||
@@ -13958,7 +13958,7 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* parameters */ &kParameters[877],
|
||||
/* return matcher indices */ &kMatcherIndices[95],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::pack4x8snorm,
|
||||
},
|
||||
{
|
||||
/* [469] */
|
||||
@@ -13970,7 +13970,7 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* parameters */ &kParameters[868],
|
||||
/* return matcher indices */ &kMatcherIndices[95],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::pack2x16unorm,
|
||||
},
|
||||
{
|
||||
/* [470] */
|
||||
@@ -13982,7 +13982,7 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* parameters */ &kParameters[867],
|
||||
/* return matcher indices */ &kMatcherIndices[95],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::pack2x16snorm,
|
||||
},
|
||||
{
|
||||
/* [471] */
|
||||
|
||||
Reference in New Issue
Block a user