Add const-eval for pack and unpack of 2x16float.

This CL adds const-eval for pack and unpack of 2x16 float values.

Bug: tint:1581
Change-Id: I59a1925148124e628c3771ca96d309fad045f27d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/109280
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2022-11-09 20:04:37 +00:00 committed by Dawn LUCI CQ
parent 00d0fd5e84
commit 5ac2a365d9
15 changed files with 4179 additions and 4105 deletions

View File

@ -507,7 +507,7 @@ fn mix<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T>
fn modf<T: f32_f16>(T) -> __modf_result<T> fn modf<T: f32_f16>(T) -> __modf_result<T>
fn modf<N: num, T: f32_f16>(vec<N, T>) -> __modf_result_vec<N, T> fn modf<N: num, T: f32_f16>(vec<N, T>) -> __modf_result_vec<N, T>
fn normalize<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T> fn normalize<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn pack2x16float(vec2<f32>) -> u32 @const fn pack2x16float(vec2<f32>) -> u32
@const fn pack2x16snorm(vec2<f32>) -> u32 @const fn pack2x16snorm(vec2<f32>) -> u32
@const fn pack2x16unorm(vec2<f32>) -> u32 @const fn pack2x16unorm(vec2<f32>) -> u32
@const fn pack4x8snorm(vec4<f32>) -> u32 @const fn pack4x8snorm(vec4<f32>) -> u32
@ -549,7 +549,7 @@ fn tanh<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn transpose<M: num, N: num, T: f32_f16>(mat<M, N, T>) -> mat<N, M, T> fn transpose<M: num, N: num, T: f32_f16>(mat<M, N, T>) -> mat<N, M, T>
fn trunc<T: f32_f16>(T) -> T fn trunc<T: f32_f16>(T) -> T
fn trunc<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T> fn trunc<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn unpack2x16float(u32) -> vec2<f32> @const fn unpack2x16float(u32) -> vec2<f32>
@const fn unpack2x16snorm(u32) -> vec2<f32> @const fn unpack2x16snorm(u32) -> vec2<f32>
@const fn unpack2x16unorm(u32) -> vec2<f32> @const fn unpack2x16unorm(u32) -> vec2<f32>
@const fn unpack4x8snorm(u32) -> vec4<f32> @const fn unpack4x8snorm(u32) -> vec4<f32>

View File

@ -2076,6 +2076,34 @@ ConstEval::Result ConstEval::insertBits(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]); return TransformElements(builder, ty, transform, args[0], args[1]);
} }
ConstEval::Result ConstEval::pack2x16float(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto convert = [&](f32 val) -> utils::Result<uint32_t> {
auto conv = CheckedConvert<f16>(val);
if (!conv) {
AddError(OverflowErrorMessage(val, "f16"), source);
return utils::Failure;
}
uint16_t v = conv.Get().BitsRepresentation();
return utils::Result<uint32_t>{v};
};
auto* e = args[0];
auto e0 = convert(e->Index(0)->As<f32>());
if (!e0) {
return utils::Failure;
}
auto e1 = convert(e->Index(1)->As<f32>());
if (!e1) {
return utils::Failure;
}
u32 ret = u32((e0.Get() & 0x0000'ffff) | (e1.Get() << 16));
return CreateElement(builder, ty, ret);
}
ConstEval::Result ConstEval::pack2x16snorm(const sem::Type* ty, ConstEval::Result ConstEval::pack2x16snorm(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {
@ -2254,6 +2282,26 @@ ConstEval::Result ConstEval::step(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]); return TransformElements(builder, ty, transform, args[0], args[1]);
} }
ConstEval::Result ConstEval::unpack2x16float(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto* inner_ty = sem::Type::DeepestElementOf(ty);
auto e = args[0]->As<u32>().value;
utils::Vector<const sem::Constant*, 2> els;
els.Reserve(2);
for (size_t i = 0; i < 2; ++i) {
auto in = f16::FromBits(uint16_t((e >> (16 * i)) & 0x0000'ffff));
auto val = CheckedConvert<f32>(in);
if (!val) {
AddError(OverflowErrorMessage(in, "f32"), source);
return utils::Failure;
}
els.Push(CreateElement(builder, inner_ty, val.Get()));
}
return CreateComposite(builder, ty, std::move(els));
}
ConstEval::Result ConstEval::unpack2x16snorm(const sem::Type* ty, ConstEval::Result ConstEval::unpack2x16snorm(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {

View File

@ -557,6 +557,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// pack2x16float builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result pack2x16float(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// pack2x16snorm builtin /// pack2x16snorm builtin
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments
@ -647,6 +656,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// unpack2x16float builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result unpack2x16float(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// unpack2x16snorm builtin /// unpack2x16snorm builtin
/// @param ty the expression type /// @param ty the expression type
/// @param args the input arguments /// @param args the input arguments

View File

@ -1265,6 +1265,25 @@ INSTANTIATE_TEST_SUITE_P( //
testing::Combine(testing::Values(sem::BuiltinType::kPack4X8Unorm), testing::Combine(testing::Values(sem::BuiltinType::kPack4X8Unorm),
testing::ValuesIn(Pack4x8unormCases()))); testing::ValuesIn(Pack4x8unormCases())));
std::vector<Case> Pack2x16floatCases() {
return {
C({Vec(f32(f16::Lowest()), f32(f16::Highest()))}, Val(u32(0x7bff'fbff))),
C({Vec(f32(1), f32(-1))}, Val(u32(0xbc00'3c00))),
C({Vec(f32(0), f32(0))}, Val(u32(0x0000'0000))),
C({Vec(f32(10), f32(-10.5))}, Val(u32(0xc940'4900))),
E({Vec(f32(0), f32::Highest())},
"12:34 error: value 3.4028234663852885981e+38 cannot be represented as 'f16'"),
E({Vec(f32::Lowest(), f32(0))},
"12:34 error: value -3.4028234663852885981e+38 cannot be represented as 'f16'"),
};
}
INSTANTIATE_TEST_SUITE_P( //
Pack2x16float,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kPack2X16Float),
testing::ValuesIn(Pack2x16floatCases())));
std::vector<Case> Pack2x16snormCases() { std::vector<Case> Pack2x16snormCases() {
return { return {
C({Vec(f32(0), f32(0))}, Val(u32(0x0000'0000))), C({Vec(f32(0), f32(0))}, Val(u32(0x0000'0000))),
@ -1508,6 +1527,20 @@ INSTANTIATE_TEST_SUITE_P( //
testing::Combine(testing::Values(sem::BuiltinType::kUnpack4X8Unorm), testing::Combine(testing::Values(sem::BuiltinType::kUnpack4X8Unorm),
testing::ValuesIn(Unpack4x8unormCases()))); testing::ValuesIn(Unpack4x8unormCases())));
std::vector<Case> Unpack2x16floatCases() {
return {
C({Val(u32(0x7bff'fbff))}, Vec(f32(f16::Lowest()), f32(f16::Highest()))),
C({Val(u32(0xbc00'3c00))}, Vec(f32(1), f32(-1))),
C({Val(u32(0x0000'0000))}, Vec(f32(0), f32(0))),
C({Val(u32(0xc940'4900))}, Vec(f32(10), f32(-10.5))),
};
}
INSTANTIATE_TEST_SUITE_P( //
Unpack2x16float,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kUnpack2X16Float),
testing::ValuesIn(Unpack2x16floatCases())));
std::vector<Case> Unpack2x16snormCases() { std::vector<Case> Unpack2x16snormCases() {
return { return {
C({Val(u32(0x0000'0000))}, Vec(f32(0), f32(0))), C({Val(u32(0x0000'0000))}, Vec(f32(0), f32(0))),

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,5 @@
uint tint_pack2x16float(float2 param_0) {
uint2 i = f32tof16(param_0);
return i.x | (i.y << 16);
}
void pack2x16float_0e97b3() { void pack2x16float_0e97b3() {
uint res = tint_pack2x16float((1.0f).xx); uint res = 1006648320u;
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,10 +1,5 @@
uint tint_pack2x16float(float2 param_0) {
uint2 i = f32tof16(param_0);
return i.x | (i.y << 16);
}
void pack2x16float_0e97b3() { void pack2x16float_0e97b3() {
uint res = tint_pack2x16float((1.0f).xx); uint res = 1006648320u;
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
void pack2x16float_0e97b3() { void pack2x16float_0e97b3() {
uint res = packHalf2x16(vec2(1.0f)); uint res = 1006648320u;
} }
vec4 vertex_main() { vec4 vertex_main() {
@ -21,7 +21,7 @@ void main() {
precision mediump float; precision mediump float;
void pack2x16float_0e97b3() { void pack2x16float_0e97b3() {
uint res = packHalf2x16(vec2(1.0f)); uint res = 1006648320u;
} }
void fragment_main() { void fragment_main() {
@ -35,7 +35,7 @@ void main() {
#version 310 es #version 310 es
void pack2x16float_0e97b3() { void pack2x16float_0e97b3() {
uint res = packHalf2x16(vec2(1.0f)); uint res = 1006648320u;
} }
void compute_main() { void compute_main() {

View File

@ -2,7 +2,7 @@
using namespace metal; using namespace metal;
void pack2x16float_0e97b3() { void pack2x16float_0e97b3() {
uint res = as_type<uint>(half2(float2(1.0f))); uint res = 1006648320u;
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,10 +1,9 @@
; SPIR-V ; SPIR-V
; Version: 1.3 ; Version: 1.3
; Generator: Google Tint Compiler; 0 ; Generator: Google Tint Compiler; 0
; Bound: 35 ; Bound: 32
; Schema: 0 ; Schema: 0
OpCapability Shader OpCapability Shader
%15 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450 OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %vertex_main "vertex_main" %value %vertex_point_size OpEntryPoint Vertex %vertex_main "vertex_main" %value %vertex_point_size
OpEntryPoint Fragment %fragment_main "fragment_main" OpEntryPoint Fragment %fragment_main "fragment_main"
@ -32,38 +31,36 @@
%void = OpTypeVoid %void = OpTypeVoid
%9 = OpTypeFunction %void %9 = OpTypeFunction %void
%uint = OpTypeInt 32 0 %uint = OpTypeInt 32 0
%v2float = OpTypeVector %float 2 %uint_1006648320 = OpConstant %uint 1006648320
%float_1 = OpConstant %float 1
%18 = OpConstantComposite %v2float %float_1 %float_1
%_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Function_uint = OpTypePointer Function %uint
%21 = OpConstantNull %uint %17 = OpConstantNull %uint
%22 = OpTypeFunction %v4float %18 = OpTypeFunction %v4float
%float_1 = OpConstant %float 1
%pack2x16float_0e97b3 = OpFunction %void None %9 %pack2x16float_0e97b3 = OpFunction %void None %9
%12 = OpLabel %12 = OpLabel
%res = OpVariable %_ptr_Function_uint Function %21 %res = OpVariable %_ptr_Function_uint Function %17
%13 = OpExtInst %uint %15 PackHalf2x16 %18 OpStore %res %uint_1006648320
OpStore %res %13
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd
%vertex_main_inner = OpFunction %v4float None %22 %vertex_main_inner = OpFunction %v4float None %18
%24 = OpLabel %20 = OpLabel
%25 = OpFunctionCall %void %pack2x16float_0e97b3 %21 = OpFunctionCall %void %pack2x16float_0e97b3
OpReturnValue %5 OpReturnValue %5
OpFunctionEnd OpFunctionEnd
%vertex_main = OpFunction %void None %9 %vertex_main = OpFunction %void None %9
%27 = OpLabel %23 = OpLabel
%28 = OpFunctionCall %v4float %vertex_main_inner %24 = OpFunctionCall %v4float %vertex_main_inner
OpStore %value %28 OpStore %value %24
OpStore %vertex_point_size %float_1 OpStore %vertex_point_size %float_1
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd
%fragment_main = OpFunction %void None %9 %fragment_main = OpFunction %void None %9
%27 = OpLabel
%28 = OpFunctionCall %void %pack2x16float_0e97b3
OpReturn
OpFunctionEnd
%compute_main = OpFunction %void None %9
%30 = OpLabel %30 = OpLabel
%31 = OpFunctionCall %void %pack2x16float_0e97b3 %31 = OpFunctionCall %void %pack2x16float_0e97b3
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd
%compute_main = OpFunction %void None %9
%33 = OpLabel
%34 = OpFunctionCall %void %pack2x16float_0e97b3
OpReturn
OpFunctionEnd

View File

@ -1,10 +1,5 @@
float2 tint_unpack2x16float(uint param_0) {
uint i = param_0;
return f16tof32(uint2(i & 0xffff, i >> 16));
}
void unpack2x16float_32a5cf() { void unpack2x16float_32a5cf() {
float2 res = tint_unpack2x16float(1u); float2 res = float2(5.96046448e-08f, 0.0f);
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,10 +1,5 @@
float2 tint_unpack2x16float(uint param_0) {
uint i = param_0;
return f16tof32(uint2(i & 0xffff, i >> 16));
}
void unpack2x16float_32a5cf() { void unpack2x16float_32a5cf() {
float2 res = tint_unpack2x16float(1u); float2 res = float2(5.96046448e-08f, 0.0f);
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
void unpack2x16float_32a5cf() { void unpack2x16float_32a5cf() {
vec2 res = unpackHalf2x16(1u); vec2 res = vec2(5.96046448e-08f, 0.0f);
} }
vec4 vertex_main() { vec4 vertex_main() {
@ -21,7 +21,7 @@ void main() {
precision mediump float; precision mediump float;
void unpack2x16float_32a5cf() { void unpack2x16float_32a5cf() {
vec2 res = unpackHalf2x16(1u); vec2 res = vec2(5.96046448e-08f, 0.0f);
} }
void fragment_main() { void fragment_main() {
@ -35,7 +35,7 @@ void main() {
#version 310 es #version 310 es
void unpack2x16float_32a5cf() { void unpack2x16float_32a5cf() {
vec2 res = unpackHalf2x16(1u); vec2 res = vec2(5.96046448e-08f, 0.0f);
} }
void compute_main() { void compute_main() {

View File

@ -2,7 +2,7 @@
using namespace metal; using namespace metal;
void unpack2x16float_32a5cf() { void unpack2x16float_32a5cf() {
float2 res = float2(as_type<half2>(1u)); float2 res = float2(5.96046448e-08f, 0.0f);
} }
struct tint_symbol { struct tint_symbol {

View File

@ -1,10 +1,9 @@
; SPIR-V ; SPIR-V
; Version: 1.3 ; Version: 1.3
; Generator: Google Tint Compiler; 0 ; Generator: Google Tint Compiler; 0
; Bound: 35 ; Bound: 33
; Schema: 0 ; Schema: 0
OpCapability Shader OpCapability Shader
%15 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450 OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %vertex_main "vertex_main" %value %vertex_point_size OpEntryPoint Vertex %vertex_main "vertex_main" %value %vertex_point_size
OpEntryPoint Fragment %fragment_main "fragment_main" OpEntryPoint Fragment %fragment_main "fragment_main"
@ -32,38 +31,37 @@
%void = OpTypeVoid %void = OpTypeVoid
%9 = OpTypeFunction %void %9 = OpTypeFunction %void
%v2float = OpTypeVector %float 2 %v2float = OpTypeVector %float 2
%uint = OpTypeInt 32 0 %float_5_96046448en08 = OpConstant %float 5.96046448e-08
%uint_1 = OpConstant %uint 1 %15 = OpConstantComposite %v2float %float_5_96046448en08 %8
%_ptr_Function_v2float = OpTypePointer Function %v2float %_ptr_Function_v2float = OpTypePointer Function %v2float
%20 = OpConstantNull %v2float %18 = OpConstantNull %v2float
%21 = OpTypeFunction %v4float %19 = OpTypeFunction %v4float
%float_1 = OpConstant %float 1 %float_1 = OpConstant %float 1
%unpack2x16float_32a5cf = OpFunction %void None %9 %unpack2x16float_32a5cf = OpFunction %void None %9
%12 = OpLabel %12 = OpLabel
%res = OpVariable %_ptr_Function_v2float Function %20 %res = OpVariable %_ptr_Function_v2float Function %18
%13 = OpExtInst %v2float %15 UnpackHalf2x16 %uint_1 OpStore %res %15
OpStore %res %13
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd
%vertex_main_inner = OpFunction %v4float None %21 %vertex_main_inner = OpFunction %v4float None %19
%23 = OpLabel %21 = OpLabel
%24 = OpFunctionCall %void %unpack2x16float_32a5cf %22 = OpFunctionCall %void %unpack2x16float_32a5cf
OpReturnValue %5 OpReturnValue %5
OpFunctionEnd OpFunctionEnd
%vertex_main = OpFunction %void None %9 %vertex_main = OpFunction %void None %9
%26 = OpLabel %24 = OpLabel
%27 = OpFunctionCall %v4float %vertex_main_inner %25 = OpFunctionCall %v4float %vertex_main_inner
OpStore %value %27 OpStore %value %25
OpStore %vertex_point_size %float_1 OpStore %vertex_point_size %float_1
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd
%fragment_main = OpFunction %void None %9 %fragment_main = OpFunction %void None %9
%30 = OpLabel %28 = OpLabel
%31 = OpFunctionCall %void %unpack2x16float_32a5cf %29 = OpFunctionCall %void %unpack2x16float_32a5cf
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd
%compute_main = OpFunction %void None %9 %compute_main = OpFunction %void None %9
%33 = OpLabel %31 = OpLabel
%34 = OpFunctionCall %void %unpack2x16float_32a5cf %32 = OpFunctionCall %void %unpack2x16float_32a5cf
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd