Add const-eval for pack and unpack of 2x16float.

This CL adds const-eval for pack and unpack of 2x16 float values.

Bug: tint:1581
Change-Id: I59a1925148124e628c3771ca96d309fad045f27d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/109280
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair
2022-11-09 20:04:37 +00:00
committed by Dawn LUCI CQ
parent 00d0fd5e84
commit 5ac2a365d9
15 changed files with 4179 additions and 4105 deletions

View File

@@ -2076,6 +2076,34 @@ ConstEval::Result ConstEval::insertBits(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::pack2x16float(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto convert = [&](f32 val) -> utils::Result<uint32_t> {
auto conv = CheckedConvert<f16>(val);
if (!conv) {
AddError(OverflowErrorMessage(val, "f16"), source);
return utils::Failure;
}
uint16_t v = conv.Get().BitsRepresentation();
return utils::Result<uint32_t>{v};
};
auto* e = args[0];
auto e0 = convert(e->Index(0)->As<f32>());
if (!e0) {
return utils::Failure;
}
auto e1 = convert(e->Index(1)->As<f32>());
if (!e1) {
return utils::Failure;
}
u32 ret = u32((e0.Get() & 0x0000'ffff) | (e1.Get() << 16));
return CreateElement(builder, ty, ret);
}
ConstEval::Result ConstEval::pack2x16snorm(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
@@ -2254,6 +2282,26 @@ ConstEval::Result ConstEval::step(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::unpack2x16float(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto* inner_ty = sem::Type::DeepestElementOf(ty);
auto e = args[0]->As<u32>().value;
utils::Vector<const sem::Constant*, 2> els;
els.Reserve(2);
for (size_t i = 0; i < 2; ++i) {
auto in = f16::FromBits(uint16_t((e >> (16 * i)) & 0x0000'ffff));
auto val = CheckedConvert<f32>(in);
if (!val) {
AddError(OverflowErrorMessage(in, "f32"), source);
return utils::Failure;
}
els.Push(CreateElement(builder, inner_ty, val.Get()));
}
return CreateComposite(builder, ty, std::move(els));
}
ConstEval::Result ConstEval::unpack2x16snorm(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {

View File

@@ -557,6 +557,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// pack2x16float builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result pack2x16float(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// pack2x16snorm builtin
/// @param ty the expression type
/// @param args the input arguments
@@ -647,6 +656,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// unpack2x16float builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result unpack2x16float(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// unpack2x16snorm builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -1265,6 +1265,25 @@ INSTANTIATE_TEST_SUITE_P( //
testing::Combine(testing::Values(sem::BuiltinType::kPack4X8Unorm),
testing::ValuesIn(Pack4x8unormCases())));
std::vector<Case> Pack2x16floatCases() {
return {
C({Vec(f32(f16::Lowest()), f32(f16::Highest()))}, Val(u32(0x7bff'fbff))),
C({Vec(f32(1), f32(-1))}, Val(u32(0xbc00'3c00))),
C({Vec(f32(0), f32(0))}, Val(u32(0x0000'0000))),
C({Vec(f32(10), f32(-10.5))}, Val(u32(0xc940'4900))),
E({Vec(f32(0), f32::Highest())},
"12:34 error: value 3.4028234663852885981e+38 cannot be represented as 'f16'"),
E({Vec(f32::Lowest(), f32(0))},
"12:34 error: value -3.4028234663852885981e+38 cannot be represented as 'f16'"),
};
}
INSTANTIATE_TEST_SUITE_P( //
Pack2x16float,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kPack2X16Float),
testing::ValuesIn(Pack2x16floatCases())));
std::vector<Case> Pack2x16snormCases() {
return {
C({Vec(f32(0), f32(0))}, Val(u32(0x0000'0000))),
@@ -1508,6 +1527,20 @@ INSTANTIATE_TEST_SUITE_P( //
testing::Combine(testing::Values(sem::BuiltinType::kUnpack4X8Unorm),
testing::ValuesIn(Unpack4x8unormCases())));
std::vector<Case> Unpack2x16floatCases() {
return {
C({Val(u32(0x7bff'fbff))}, Vec(f32(f16::Lowest()), f32(f16::Highest()))),
C({Val(u32(0xbc00'3c00))}, Vec(f32(1), f32(-1))),
C({Val(u32(0x0000'0000))}, Vec(f32(0), f32(0))),
C({Val(u32(0xc940'4900))}, Vec(f32(10), f32(-10.5))),
};
}
INSTANTIATE_TEST_SUITE_P( //
Unpack2x16float,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kUnpack2X16Float),
testing::ValuesIn(Unpack2x16floatCases())));
std::vector<Case> Unpack2x16snormCases() {
return {
C({Val(u32(0x0000'0000))}, Vec(f32(0), f32(0))),

File diff suppressed because it is too large Load Diff