tint: Add f16 support for parts of float built-in, part 1

This patch add f16 support for a major part of numeric built-in, and
implement corresponding unittests for resolver and backends. This patch
also enable f16 constant evaluation for unary minus operator, `atan2`
and `clamp`.

The following numeric built-ins are not supported yet:
* frexp
* modf

The end-to-end tests for f16 built-in are not added yet.

Bug: tint:1473, tint:1502
Change-Id: If807185617b21c510a1a9c371179a60800c4f875
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/96722
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
This commit is contained in:
Zhaoming Jiang 2022-07-29 11:41:51 +00:00 committed by Dawn LUCI CQ
parent 65dcdcbad0
commit 6fe1f515d4
59 changed files with 4136 additions and 2714 deletions

View File

@ -165,6 +165,7 @@ match scalar_no_f16: f32 | i32 | u32 | bool
match scalar_no_i32: f32 | f16 | u32 | bool match scalar_no_i32: f32 | f16 | u32 | bool
match scalar_no_u32: f32 | f16 | i32 | bool match scalar_no_u32: f32 | f16 | i32 | bool
match scalar_no_bool: f32 | f16 | i32 | u32 match scalar_no_bool: f32 | f16 | i32 | u32
match fia_fiu32_f16: fa | ia | f32 | i32 | u32 | f16
match fia_fi32_f16: fa | ia | f32 | i32 | f16 match fia_fi32_f16: fa | ia | f32 | i32 | f16
match fia_fiu32: fa | ia | f32 | i32 | u32 match fia_fiu32: fa | ia | f32 | i32 | u32
match fa_f32: fa | f32 match fa_f32: fa | f32
@ -386,48 +387,48 @@ match storage
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// https://gpuweb.github.io/gpuweb/wgsl/#builtin-functions // https://gpuweb.github.io/gpuweb/wgsl/#builtin-functions
fn abs<T: fiu32>(T) -> T fn abs<T: fiu32_f16>(T) -> T
fn abs<N: num, T: fiu32>(vec<N, T>) -> vec<N, T> fn abs<N: num, T: fiu32_f16>(vec<N, T>) -> vec<N, T>
fn acos(f32) -> f32 fn acos<T: f32_f16>(T) -> T
fn acos<N: num>(vec<N, f32>) -> vec<N, f32> fn acos<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn acosh(f32) -> f32 fn acosh<T: f32_f16>(T) -> T
fn acosh<N: num>(vec<N, f32>) -> vec<N, f32> fn acosh<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn all(bool) -> bool fn all(bool) -> bool
fn all<N: num>(vec<N, bool>) -> bool fn all<N: num>(vec<N, bool>) -> bool
fn any(bool) -> bool fn any(bool) -> bool
fn any<N: num>(vec<N, bool>) -> bool fn any<N: num>(vec<N, bool>) -> bool
fn arrayLength<T, A: access>(ptr<storage, array<T>, A>) -> u32 fn arrayLength<T, A: access>(ptr<storage, array<T>, A>) -> u32
fn asin(f32) -> f32 fn asin<T: f32_f16>(T) -> T
fn asin<N: num>(vec<N, f32>) -> vec<N, f32> fn asin<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn asinh(f32) -> f32 fn asinh<T: f32_f16>(T) -> T
fn asinh<N: num>(vec<N, f32>) -> vec<N, f32> fn asinh<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn atan(f32) -> f32 fn atan<T: f32_f16>(T) -> T
fn atan<N: num>(vec<N, f32>) -> vec<N, f32> fn atan<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn atan2<T: fa_f32>(T, T) -> T @const fn atan2<T: fa_f32_f16>(T, T) -> T
@const fn atan2<N: num, T: fa_f32>(vec<N, T>, vec<N, T>) -> vec<N, T> @const fn atan2<T: fa_f32_f16, N: num>(vec<N, T>, vec<N, T>) -> vec<N, T>
fn atanh(f32) -> f32 fn atanh<T: f32_f16>(T) -> T
fn atanh<N: num>(vec<N, f32>) -> vec<N, f32> fn atanh<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn ceil(f32) -> f32 fn ceil<T: f32_f16>(T) -> T
fn ceil<N: num>(vec<N, f32>) -> vec<N, f32> fn ceil<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn clamp<T: fia_fiu32>(T, T, T) -> T @const fn clamp<T: fia_fiu32_f16>(T, T, T) -> T
@const fn clamp<N: num, T: fia_fiu32>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T> @const fn clamp<T: fia_fiu32_f16, N: num>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
fn cos(f32) -> f32 fn cos<T: f32_f16>(T) -> T
fn cos<N: num>(vec<N, f32>) -> vec<N, f32> fn cos<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn cosh(f32) -> f32 fn cosh<T: f32_f16>(T) -> T
fn cosh<N: num>(vec<N, f32>) -> vec<N, f32> fn cosh<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn countLeadingZeros<T: iu32>(T) -> T fn countLeadingZeros<T: iu32>(T) -> T
fn countLeadingZeros<N: num, T: iu32>(vec<N, T>) -> vec<N, T> fn countLeadingZeros<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
fn countOneBits<T: iu32>(T) -> T fn countOneBits<T: iu32>(T) -> T
fn countOneBits<N: num, T: iu32>(vec<N, T>) -> vec<N, T> fn countOneBits<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
fn countTrailingZeros<T: iu32>(T) -> T fn countTrailingZeros<T: iu32>(T) -> T
fn countTrailingZeros<N: num, T: iu32>(vec<N, T>) -> vec<N, T> fn countTrailingZeros<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
fn cross(vec3<f32>, vec3<f32>) -> vec3<f32> fn cross<T: f32_f16>(vec3<T>, vec3<T>) -> vec3<T>
fn degrees(f32) -> f32 fn degrees<T: f32_f16>(T) -> T
fn degrees<N: num>(vec<N, f32>) -> vec<N, f32> fn degrees<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn determinant<N: num>(mat<N, N, f32>) -> f32 fn determinant<N: num, T: f32_f16>(mat<N, N, T>) -> T
fn distance(f32, f32) -> f32 fn distance<T: f32_f16>(T, T) -> T
fn distance<N: num>(vec<N, f32>, vec<N, f32>) -> f32 fn distance<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> T
fn dot<N: num, T: fiu32>(vec<N, T>, vec<N, T>) -> T fn dot<N: num, T: fiu32_f16>(vec<N, T>, vec<N, T>) -> T
fn dot4I8Packed(u32, u32) -> i32 fn dot4I8Packed(u32, u32) -> i32
fn dot4U8Packed(u32, u32) -> u32 fn dot4U8Packed(u32, u32) -> u32
@stage("fragment") fn dpdx(f32) -> f32 @stage("fragment") fn dpdx(f32) -> f32
@ -442,23 +443,23 @@ fn dot4U8Packed(u32, u32) -> u32
@stage("fragment") fn dpdyCoarse<N: num>(vec<N, f32>) -> vec<N, f32> @stage("fragment") fn dpdyCoarse<N: num>(vec<N, f32>) -> vec<N, f32>
@stage("fragment") fn dpdyFine(f32) -> f32 @stage("fragment") fn dpdyFine(f32) -> f32
@stage("fragment") fn dpdyFine<N: num>(vec<N, f32>) -> vec<N, f32> @stage("fragment") fn dpdyFine<N: num>(vec<N, f32>) -> vec<N, f32>
fn exp(f32) -> f32 fn exp<T: f32_f16>(T) -> T
fn exp<N: num>(vec<N, f32>) -> vec<N, f32> fn exp<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn exp2(f32) -> f32 fn exp2<T: f32_f16>(T) -> T
fn exp2<N: num>(vec<N, f32>) -> vec<N, f32> fn exp2<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn extractBits<T: iu32>(T, u32, u32) -> T fn extractBits<T: iu32>(T, u32, u32) -> T
fn extractBits<N: num, T: iu32>(vec<N, T>, u32, u32) -> vec<N, T> fn extractBits<N: num, T: iu32>(vec<N, T>, u32, u32) -> vec<N, T>
fn faceForward<N: num>(vec<N, f32>, vec<N, f32>, vec<N, f32>) -> vec<N, f32> fn faceForward<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
fn firstLeadingBit<T: iu32>(T) -> T fn firstLeadingBit<T: iu32>(T) -> T
fn firstLeadingBit<N: num, T: iu32>(vec<N, T>) -> vec<N, T> fn firstLeadingBit<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
fn firstTrailingBit<T: iu32>(T) -> T fn firstTrailingBit<T: iu32>(T) -> T
fn firstTrailingBit<N: num, T: iu32>(vec<N, T>) -> vec<N, T> fn firstTrailingBit<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
fn floor(f32) -> f32 fn floor<T: f32_f16>(T) -> T
fn floor<N: num>(vec<N, f32>) -> vec<N, f32> fn floor<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn fma(f32, f32, f32) -> f32 fn fma<T: f32_f16>(T, T, T) -> T
fn fma<N: num>(vec<N, f32>, vec<N, f32>, vec<N, f32>) -> vec<N, f32> fn fma<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
fn fract(f32) -> f32 fn fract<T: f32_f16>(T) -> T
fn fract<N: num>(vec<N, f32>) -> vec<N, f32> fn fract<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn frexp(f32) -> __frexp_result fn frexp(f32) -> __frexp_result
fn frexp<N: num>(vec<N, f32>) -> __frexp_result_vec<N> fn frexp<N: num>(vec<N, f32>) -> __frexp_result_vec<N>
@stage("fragment") fn fwidth(f32) -> f32 @stage("fragment") fn fwidth(f32) -> f32
@ -469,64 +470,64 @@ fn frexp<N: num>(vec<N, f32>) -> __frexp_result_vec<N>
@stage("fragment") fn fwidthFine<N: num>(vec<N, f32>) -> vec<N, f32> @stage("fragment") fn fwidthFine<N: num>(vec<N, f32>) -> vec<N, f32>
fn insertBits<T: iu32>(T, T, u32, u32) -> T fn insertBits<T: iu32>(T, T, u32, u32) -> T
fn insertBits<N: num, T: iu32>(vec<N, T>, vec<N, T>, u32, u32) -> vec<N, T> fn insertBits<N: num, T: iu32>(vec<N, T>, vec<N, T>, u32, u32) -> vec<N, T>
fn inverseSqrt(f32) -> f32 fn inverseSqrt<T: f32_f16>(T) -> T
fn inverseSqrt<N: num>(vec<N, f32>) -> vec<N, f32> fn inverseSqrt<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn ldexp(f32, i32) -> f32 fn ldexp<T: f32_f16>(T, i32) -> T
fn ldexp<N: num>(vec<N, f32>, vec<N, i32>) -> vec<N, f32> fn ldexp<N: num, T: f32_f16>(vec<N, T>, vec<N, i32>) -> vec<N, T>
fn length(f32) -> f32 fn length<T: f32_f16>(T) -> T
fn length<N: num>(vec<N, f32>) -> f32 fn length<N: num, T: f32_f16>(vec<N, T>) -> T
fn log(f32) -> f32 fn log<T: f32_f16>(T) -> T
fn log<N: num>(vec<N, f32>) -> vec<N, f32> fn log<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn log2(f32) -> f32 fn log2<T: f32_f16>(T) -> T
fn log2<N: num>(vec<N, f32>) -> vec<N, f32> fn log2<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn max<T: fiu32>(T, T) -> T fn max<T: fiu32_f16>(T, T) -> T
fn max<N: num, T: fiu32>(vec<N, T>, vec<N, T>) -> vec<N, T> fn max<N: num, T: fiu32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
fn min<T: fiu32>(T, T) -> T fn min<T: fiu32_f16>(T, T) -> T
fn min<N: num, T: fiu32>(vec<N, T>, vec<N, T>) -> vec<N, T> fn min<N: num, T: fiu32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
fn mix(f32, f32, f32) -> f32 fn mix<T: f32_f16>(T, T, T) -> T
fn mix<N: num>(vec<N, f32>, vec<N, f32>, vec<N, f32>) -> vec<N, f32> fn mix<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
fn mix<N: num>(vec<N, f32>, vec<N, f32>, f32) -> vec<N, f32> fn mix<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T>
fn modf(f32) -> __modf_result fn modf(f32) -> __modf_result
fn modf<N: num>(vec<N, f32>) -> __modf_result_vec<N> fn modf<N: num>(vec<N, f32>) -> __modf_result_vec<N>
fn normalize<N: num>(vec<N, f32>) -> vec<N, f32> fn normalize<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn pack2x16float(vec2<f32>) -> u32 fn pack2x16float(vec2<f32>) -> u32
fn pack2x16snorm(vec2<f32>) -> u32 fn pack2x16snorm(vec2<f32>) -> u32
fn pack2x16unorm(vec2<f32>) -> u32 fn pack2x16unorm(vec2<f32>) -> u32
fn pack4x8snorm(vec4<f32>) -> u32 fn pack4x8snorm(vec4<f32>) -> u32
fn pack4x8unorm(vec4<f32>) -> u32 fn pack4x8unorm(vec4<f32>) -> u32
fn pow(f32, f32) -> f32 fn pow<T: f32_f16>(T, T) -> T
fn pow<N: num>(vec<N, f32>, vec<N, f32>) -> vec<N, f32> fn pow<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
fn radians(f32) -> f32 fn radians<T: f32_f16>(T) -> T
fn radians<N: num>(vec<N, f32>) -> vec<N, f32> fn radians<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn reflect<N: num>(vec<N, f32>, vec<N, f32>) -> vec<N, f32> fn reflect<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
fn refract<N: num>(vec<N, f32>, vec<N, f32>, f32) -> vec<N, f32> fn refract<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, T) -> vec<N, T>
fn reverseBits<T: iu32>(T) -> T fn reverseBits<T: iu32>(T) -> T
fn reverseBits<N: num, T: iu32>(vec<N, T>) -> vec<N, T> fn reverseBits<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
fn round(f32) -> f32 fn round<T: f32_f16>(T) -> T
fn round<N: num>(vec<N, f32>) -> vec<N, f32> fn round<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn select<T: scalar_no_f16>(T, T, bool) -> T fn select<T: scalar>(T, T, bool) -> T
fn select<T: scalar_no_f16, N: num>(vec<N, T>, vec<N, T>, bool) -> vec<N, T> fn select<T: scalar, N: num>(vec<N, T>, vec<N, T>, bool) -> vec<N, T>
fn select<N: num, T: scalar_no_f16>(vec<N, T>, vec<N, T>, vec<N, bool>) -> vec<N, T> fn select<N: num, T: scalar>(vec<N, T>, vec<N, T>, vec<N, bool>) -> vec<N, T>
fn sign(f32) -> f32 fn sign<T: f32_f16>(T) -> T
fn sign<N: num>(vec<N, f32>) -> vec<N, f32> fn sign<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn sin(f32) -> f32 fn sin<T: f32_f16>(T) -> T
fn sin<N: num>(vec<N, f32>) -> vec<N, f32> fn sin<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn sinh(f32) -> f32 fn sinh<T: f32_f16>(T) -> T
fn sinh<N: num>(vec<N, f32>) -> vec<N, f32> fn sinh<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn smoothstep(f32, f32, f32) -> f32 fn smoothstep<T: f32_f16>(T, T, T) -> T
fn smoothstep<N: num>(vec<N, f32>, vec<N, f32>, vec<N, f32>) -> vec<N, f32> fn smoothstep<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
fn sqrt(f32) -> f32 fn sqrt<T: f32_f16>(T) -> T
fn sqrt<N: num>(vec<N, f32>) -> vec<N, f32> fn sqrt<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn step(f32, f32) -> f32 fn step<T: f32_f16>(T, T) -> T
fn step<N: num>(vec<N, f32>, vec<N, f32>) -> vec<N, f32> fn step<N: num, T: f32_f16>(vec<N, T>, vec<N, T>) -> vec<N, T>
@stage("compute") fn storageBarrier() @stage("compute") fn storageBarrier()
fn tan(f32) -> f32 fn tan<T: f32_f16>(T) -> T
fn tan<N: num>(vec<N, f32>) -> vec<N, f32> fn tan<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn tanh(f32) -> f32 fn tanh<T: f32_f16>(T) -> T
fn tanh<N: num>(vec<N, f32>) -> vec<N, f32> fn tanh<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn transpose<M: num, N: num>(mat<M, N, f32>) -> mat<N, M, f32> fn transpose<M: num, N: num, T: f32_f16>(mat<M, N, T>) -> mat<N, M, T>
fn trunc(f32) -> f32 fn trunc<T: f32_f16>(T) -> T
fn trunc<N: num>(vec<N, f32>) -> vec<N, f32> fn trunc<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn unpack2x16float(u32) -> vec2<f32> fn unpack2x16float(u32) -> vec2<f32>
fn unpack2x16snorm(u32) -> vec2<f32> fn unpack2x16snorm(u32) -> vec2<f32>
fn unpack2x16unorm(u32) -> vec2<f32> fn unpack2x16unorm(u32) -> vec2<f32>

View File

@ -129,9 +129,9 @@ TEST_F(ResolverBuiltinTest, Select_Error_NoParams) {
R"(error: no matching call to select() R"(error: no matching call to select()
3 candidate functions: 3 candidate functions:
select(T, T, bool) -> T where: T is f32, i32, u32 or bool select(T, T, bool) -> T where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, i32, u32 or bool select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, i32, u32 or bool select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, f16, i32, u32 or bool
)"); )");
} }
@ -145,9 +145,9 @@ TEST_F(ResolverBuiltinTest, Select_Error_SelectorInt) {
R"(error: no matching call to select(i32, i32, i32) R"(error: no matching call to select(i32, i32, i32)
3 candidate functions: 3 candidate functions:
select(T, T, bool) -> T where: T is f32, i32, u32 or bool select(T, T, bool) -> T where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, i32, u32 or bool select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, i32, u32 or bool select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, f16, i32, u32 or bool
)"); )");
} }
@ -162,9 +162,9 @@ TEST_F(ResolverBuiltinTest, Select_Error_Matrix) {
R"(error: no matching call to select(mat2x2<f32>, mat2x2<f32>, bool) R"(error: no matching call to select(mat2x2<f32>, mat2x2<f32>, bool)
3 candidate functions: 3 candidate functions:
select(T, T, bool) -> T where: T is f32, i32, u32 or bool select(T, T, bool) -> T where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, i32, u32 or bool select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, i32, u32 or bool select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, f16, i32, u32 or bool
)"); )");
} }
@ -178,9 +178,9 @@ TEST_F(ResolverBuiltinTest, Select_Error_MismatchTypes) {
R"(error: no matching call to select(f32, vec2<f32>, bool) R"(error: no matching call to select(f32, vec2<f32>, bool)
3 candidate functions: 3 candidate functions:
select(T, T, bool) -> T where: T is f32, i32, u32 or bool select(T, T, bool) -> T where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, i32, u32 or bool select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, i32, u32 or bool select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, f16, i32, u32 or bool
)"); )");
} }
@ -194,9 +194,9 @@ TEST_F(ResolverBuiltinTest, Select_Error_MismatchVectorSize) {
R"(error: no matching call to select(vec2<f32>, vec3<f32>, bool) R"(error: no matching call to select(vec2<f32>, vec3<f32>, bool)
3 candidate functions: 3 candidate functions:
select(T, T, bool) -> T where: T is f32, i32, u32 or bool select(T, T, bool) -> T where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, i32, u32 or bool select(vecN<T>, vecN<T>, bool) -> vecN<T> where: T is f32, f16, i32, u32 or bool
select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, i32, u32 or bool select(vecN<T>, vecN<T>, vecN<bool>) -> vecN<T> where: T is f32, f16, i32, u32 or bool
)"); )");
} }
@ -458,6 +458,206 @@ TEST_P(ResolverBuiltinTest_FloatBuiltin_IdenticalType, FourParams_Vector_f32) {
} }
} }
TEST_P(ResolverBuiltinTest_FloatBuiltin_IdenticalType, OneParam_Scalar_f16) {
auto param = GetParam();
Enable(ast::Extension::kF16);
auto* call = Call(param.name, 1_h);
WrapInFunction(call);
if (param.args_number == 1u) {
// Parameter count matched.
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->Is<sem::F16>());
} else {
// Invalid parameter count.
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(),
HasSubstr("error: no matching call to " + std::string(param.name) + "(f16)"));
}
}
TEST_P(ResolverBuiltinTest_FloatBuiltin_IdenticalType, OneParam_Vector_f16) {
auto param = GetParam();
Enable(ast::Extension::kF16);
auto* call = Call(param.name, vec3<f16>(1_h, 1_h, 3_h));
WrapInFunction(call);
if (param.args_number == 1u) {
// Parameter count matched.
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->is_float_vector());
EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
ASSERT_NE(TypeOf(call)->As<sem::Vector>()->type(), nullptr);
EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F16>());
} else {
// Invalid parameter count.
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), HasSubstr("error: no matching call to " +
std::string(param.name) + "(vec3<f16>)"));
}
}
TEST_P(ResolverBuiltinTest_FloatBuiltin_IdenticalType, TwoParams_Scalar_f16) {
auto param = GetParam();
Enable(ast::Extension::kF16);
auto* call = Call(param.name, 1_h, 1_h);
WrapInFunction(call);
if (param.args_number == 2u) {
// Parameter count matched.
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->Is<sem::F16>());
} else {
// Invalid parameter count.
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), HasSubstr("error: no matching call to " +
std::string(param.name) + "(f16, f16)"));
}
}
TEST_P(ResolverBuiltinTest_FloatBuiltin_IdenticalType, TwoParams_Vector_f16) {
auto param = GetParam();
Enable(ast::Extension::kF16);
auto* call = Call(param.name, vec3<f16>(1_h, 1_h, 3_h), vec3<f16>(1_h, 1_h, 3_h));
WrapInFunction(call);
if (param.args_number == 2u) {
// Parameter count matched.
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->is_float_vector());
EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
ASSERT_NE(TypeOf(call)->As<sem::Vector>()->type(), nullptr);
EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F16>());
} else {
// Invalid parameter count.
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), HasSubstr("error: no matching call to " +
std::string(param.name) + "(vec3<f16>, vec3<f16>)"));
}
}
TEST_P(ResolverBuiltinTest_FloatBuiltin_IdenticalType, ThreeParams_Scalar_f16) {
auto param = GetParam();
Enable(ast::Extension::kF16);
auto* call = Call(param.name, 1_h, 1_h, 1_h);
WrapInFunction(call);
if (param.args_number == 3u) {
// Parameter count matched.
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->Is<sem::F16>());
} else {
// Invalid parameter count.
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), HasSubstr("error: no matching call to " +
std::string(param.name) + "(f16, f16, f16)"));
}
}
TEST_P(ResolverBuiltinTest_FloatBuiltin_IdenticalType, ThreeParams_Vector_f16) {
auto param = GetParam();
Enable(ast::Extension::kF16);
auto* call = Call(param.name, vec3<f16>(1_h, 1_h, 3_h), vec3<f16>(1_h, 1_h, 3_h),
vec3<f16>(1_h, 1_h, 3_h));
WrapInFunction(call);
if (param.args_number == 3u) {
// Parameter count matched.
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->is_float_vector());
EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
ASSERT_NE(TypeOf(call)->As<sem::Vector>()->type(), nullptr);
EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F16>());
} else {
// Invalid parameter count.
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(),
HasSubstr("error: no matching call to " + std::string(param.name) +
"(vec3<f16>, vec3<f16>, vec3<f16>)"));
}
}
TEST_P(ResolverBuiltinTest_FloatBuiltin_IdenticalType, FourParams_Scalar_f16) {
auto param = GetParam();
Enable(ast::Extension::kF16);
auto* call = Call(param.name, 1_h, 1_h, 1_h, 1_h);
WrapInFunction(call);
if (param.args_number == 4u) {
// Parameter count matched.
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->Is<sem::F16>());
} else {
// Invalid parameter count.
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), HasSubstr("error: no matching call to " +
std::string(param.name) + "(f16, f16, f16, f16)"));
}
}
TEST_P(ResolverBuiltinTest_FloatBuiltin_IdenticalType, FourParams_Vector_f16) {
auto param = GetParam();
Enable(ast::Extension::kF16);
auto* call = Call(param.name, vec3<f16>(1_h, 1_h, 3_h), vec3<f16>(1_h, 1_h, 3_h),
vec3<f16>(1_h, 1_h, 3_h), vec3<f16>(1_h, 1_h, 3_h));
WrapInFunction(call);
if (param.args_number == 4u) {
// Parameter count matched.
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->is_float_vector());
EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
ASSERT_NE(TypeOf(call)->As<sem::Vector>()->type(), nullptr);
EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F16>());
} else {
// Invalid parameter count.
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(),
HasSubstr("error: no matching call to " + std::string(param.name) +
"(vec3<f16>, vec3<f16>, vec3<f16>, vec3<f16>)"));
}
}
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
ResolverTest, ResolverTest,
ResolverBuiltinTest_FloatBuiltin_IdenticalType, ResolverBuiltinTest_FloatBuiltin_IdenticalType,
@ -526,6 +726,20 @@ TEST_F(ResolverBuiltinFloatTest, Cross_f32) {
EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F32>());
} }
TEST_F(ResolverBuiltinFloatTest, Cross_f16) {
Enable(ast::Extension::kF16);
auto* call = Call("cross", vec3<f16>(1_h, 2_h, 3_h), vec3<f16>(1_h, 2_h, 3_h));
WrapInFunction(call);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->is_float_vector());
EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F16>());
}
TEST_F(ResolverBuiltinFloatTest, Cross_Error_NoArgs) { TEST_F(ResolverBuiltinFloatTest, Cross_Error_NoArgs) {
auto* call = Call("cross"); auto* call = Call("cross");
WrapInFunction(call); WrapInFunction(call);
@ -535,7 +749,7 @@ TEST_F(ResolverBuiltinFloatTest, Cross_Error_NoArgs) {
EXPECT_EQ(r()->error(), R"(error: no matching call to cross() EXPECT_EQ(r()->error(), R"(error: no matching call to cross()
1 candidate function: 1 candidate function:
cross(vec3<f32>, vec3<f32>) -> vec3<f32> cross(vec3<T>, vec3<T>) -> vec3<T> where: T is f32 or f16
)"); )");
} }
@ -548,7 +762,7 @@ TEST_F(ResolverBuiltinFloatTest, Cross_Error_Scalar) {
EXPECT_EQ(r()->error(), R"(error: no matching call to cross(f32, f32) EXPECT_EQ(r()->error(), R"(error: no matching call to cross(f32, f32)
1 candidate function: 1 candidate function:
cross(vec3<f32>, vec3<f32>) -> vec3<f32> cross(vec3<T>, vec3<T>) -> vec3<T> where: T is f32 or f16
)"); )");
} }
@ -562,7 +776,7 @@ TEST_F(ResolverBuiltinFloatTest, Cross_Error_Vec3Int) {
R"(error: no matching call to cross(vec3<i32>, vec3<i32>) R"(error: no matching call to cross(vec3<i32>, vec3<i32>)
1 candidate function: 1 candidate function:
cross(vec3<f32>, vec3<f32>) -> vec3<f32> cross(vec3<T>, vec3<T>) -> vec3<T> where: T is f32 or f16
)"); )");
} }
@ -577,7 +791,7 @@ TEST_F(ResolverBuiltinFloatTest, Cross_Error_Vec4) {
R"(error: no matching call to cross(vec4<f32>, vec4<f32>) R"(error: no matching call to cross(vec4<f32>, vec4<f32>)
1 candidate function: 1 candidate function:
cross(vec3<f32>, vec3<f32>) -> vec3<f32> cross(vec3<T>, vec3<T>) -> vec3<T> where: T is f32 or f16
)"); )");
} }
@ -593,7 +807,7 @@ TEST_F(ResolverBuiltinFloatTest, Cross_Error_TooManyParams) {
R"(error: no matching call to cross(vec3<f32>, vec3<f32>, vec3<f32>) R"(error: no matching call to cross(vec3<f32>, vec3<f32>, vec3<f32>)
1 candidate function: 1 candidate function:
cross(vec3<f32>, vec3<f32>) -> vec3<f32> cross(vec3<T>, vec3<T>) -> vec3<T> where: T is f32 or f16
)"); )");
} }
@ -608,6 +822,18 @@ TEST_F(ResolverBuiltinFloatTest, Distance_Scalar_f32) {
EXPECT_TRUE(TypeOf(call)->Is<sem::F32>()); EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
} }
TEST_F(ResolverBuiltinFloatTest, Distance_Scalar_f16) {
Enable(ast::Extension::kF16);
auto* call = Call("distance", 1_h, 1_h);
WrapInFunction(call);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->Is<sem::F16>());
}
TEST_F(ResolverBuiltinFloatTest, Distance_Vector_f32) { TEST_F(ResolverBuiltinFloatTest, Distance_Vector_f32) {
auto* call = Call("distance", vec3<f32>(1_f, 1_f, 3_f), vec3<f32>(1_f, 1_f, 3_f)); auto* call = Call("distance", vec3<f32>(1_f, 1_f, 3_f), vec3<f32>(1_f, 1_f, 3_f));
WrapInFunction(call); WrapInFunction(call);
@ -618,31 +844,44 @@ TEST_F(ResolverBuiltinFloatTest, Distance_Vector_f32) {
EXPECT_TRUE(TypeOf(call)->Is<sem::F32>()); EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
} }
TEST_F(ResolverBuiltinFloatTest, Distance_Vector_f16) {
Enable(ast::Extension::kF16);
auto* call = Call("distance", vec3<f16>(1_h, 1_h, 3_h), vec3<f16>(1_h, 1_h, 3_h));
WrapInFunction(call);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->Is<sem::F16>());
}
TEST_F(ResolverBuiltinFloatTest, Distance_TooManyParams) { TEST_F(ResolverBuiltinFloatTest, Distance_TooManyParams) {
auto* call = Call("distance", 1_f, 1_f, 3_f); auto* call = Call("distance", vec3<f32>(1_f, 1_f, 3_f), vec3<f32>(1_f, 1_f, 3_f),
vec3<f32>(1_f, 1_f, 3_f));
WrapInFunction(call); WrapInFunction(call);
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(error: no matching call to distance(f32, f32, f32) EXPECT_EQ(r()->error(), R"(error: no matching call to distance(vec3<f32>, vec3<f32>, vec3<f32>)
2 candidate functions: 2 candidate functions:
distance(f32, f32) -> f32 distance(T, T) -> T where: T is f32 or f16
distance(vecN<f32>, vecN<f32>) -> f32 distance(vecN<T>, vecN<T>) -> T where: T is f32 or f16
)"); )");
} }
TEST_F(ResolverBuiltinFloatTest, Distance_TooFewParams) { TEST_F(ResolverBuiltinFloatTest, Distance_TooFewParams) {
auto* call = Call("distance", 1_f); auto* call = Call("distance", vec3<f32>(1_f, 1_f, 3_f));
WrapInFunction(call); WrapInFunction(call);
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(error: no matching call to distance(f32) EXPECT_EQ(r()->error(), R"(error: no matching call to distance(vec3<f32>)
2 candidate functions: 2 candidate functions:
distance(f32, f32) -> f32 distance(T, T) -> T where: T is f32 or f16
distance(vecN<f32>, vecN<f32>) -> f32 distance(vecN<T>, vecN<T>) -> T where: T is f32 or f16
)"); )");
} }
@ -655,8 +894,8 @@ TEST_F(ResolverBuiltinFloatTest, Distance_NoParams) {
EXPECT_EQ(r()->error(), R"(error: no matching call to distance() EXPECT_EQ(r()->error(), R"(error: no matching call to distance()
2 candidate functions: 2 candidate functions:
distance(f32, f32) -> f32 distance(T, T) -> T where: T is f32 or f16
distance(vecN<f32>, vecN<f32>) -> f32 distance(vecN<T>, vecN<T>) -> T where: T is f32 or f16
)"); )");
} }
@ -796,6 +1035,18 @@ TEST_F(ResolverBuiltinFloatTest, Length_Scalar_f32) {
EXPECT_TRUE(TypeOf(call)->Is<sem::F32>()); EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
} }
TEST_F(ResolverBuiltinFloatTest, Length_Scalar_f16) {
Enable(ast::Extension::kF16);
auto* call = Call("length", 1_h);
WrapInFunction(call);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->Is<sem::F16>());
}
TEST_F(ResolverBuiltinFloatTest, Length_FloatVector_f32) { TEST_F(ResolverBuiltinFloatTest, Length_FloatVector_f32) {
auto* call = Call("length", vec3<f32>(1_f, 1_f, 3_f)); auto* call = Call("length", vec3<f32>(1_f, 1_f, 3_f));
WrapInFunction(call); WrapInFunction(call);
@ -806,6 +1057,18 @@ TEST_F(ResolverBuiltinFloatTest, Length_FloatVector_f32) {
EXPECT_TRUE(TypeOf(call)->Is<sem::F32>()); EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
} }
TEST_F(ResolverBuiltinFloatTest, Length_FloatVector_f16) {
Enable(ast::Extension::kF16);
auto* call = Call("length", vec3<f16>(1_h, 1_h, 3_h));
WrapInFunction(call);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->Is<sem::F16>());
}
TEST_F(ResolverBuiltinFloatTest, Length_NoParams) { TEST_F(ResolverBuiltinFloatTest, Length_NoParams) {
auto* call = Call("length"); auto* call = Call("length");
WrapInFunction(call); WrapInFunction(call);
@ -815,8 +1078,8 @@ TEST_F(ResolverBuiltinFloatTest, Length_NoParams) {
EXPECT_EQ(r()->error(), R"(error: no matching call to length() EXPECT_EQ(r()->error(), R"(error: no matching call to length()
2 candidate functions: 2 candidate functions:
length(f32) -> f32 length(T) -> T where: T is f32 or f16
length(vecN<f32>) -> f32 length(vecN<T>) -> T where: T is f32 or f16
)"); )");
} }
@ -829,8 +1092,8 @@ TEST_F(ResolverBuiltinFloatTest, Length_TooManyParams) {
EXPECT_EQ(r()->error(), R"(error: no matching call to length(f32, f32) EXPECT_EQ(r()->error(), R"(error: no matching call to length(f32, f32)
2 candidate functions: 2 candidate functions:
length(f32) -> f32 length(T) -> T where: T is f32 or f16
length(vecN<f32>) -> f32 length(vecN<T>) -> T where: T is f32 or f16
)"); )");
} }
@ -849,6 +1112,21 @@ TEST_F(ResolverBuiltinFloatTest, Mix_VectorScalar_f32) {
EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F32>());
} }
TEST_F(ResolverBuiltinFloatTest, Mix_VectorScalar_f16) {
Enable(ast::Extension::kF16);
auto* call = Call("mix", vec3<f16>(1_h, 1_h, 1_h), vec3<f16>(1_h, 1_h, 1_h), 4_h);
WrapInFunction(call);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->is_float_vector());
EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
ASSERT_NE(TypeOf(call)->As<sem::Vector>()->type(), nullptr);
EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F16>());
}
// modf: (f32) -> __modf_result, (vecN<f32>) -> __modf_result_vecN // modf: (f32) -> __modf_result, (vecN<f32>) -> __modf_result_vecN
TEST_F(ResolverBuiltinFloatTest, ModfScalar) { TEST_F(ResolverBuiltinFloatTest, ModfScalar) {
auto* call = Call("modf", 1_f); auto* call = Call("modf", 1_f);
@ -987,6 +1265,20 @@ TEST_F(ResolverBuiltinFloatTest, Normalize_Vector_f32) {
EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F32>()); EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F32>());
} }
TEST_F(ResolverBuiltinFloatTest, Normalize_Vector_f16) {
Enable(ast::Extension::kF16);
auto* call = Call("normalize", vec3<f16>(1_h, 1_h, 3_h));
WrapInFunction(call);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->is_float_vector());
EXPECT_EQ(TypeOf(call)->As<sem::Vector>()->Width(), 3u);
EXPECT_TRUE(TypeOf(call)->As<sem::Vector>()->type()->Is<sem::F16>());
}
TEST_F(ResolverBuiltinFloatTest, Normalize_Error_NoParams) { TEST_F(ResolverBuiltinFloatTest, Normalize_Error_NoParams) {
auto* call = Call("normalize"); auto* call = Call("normalize");
WrapInFunction(call); WrapInFunction(call);
@ -996,7 +1288,7 @@ TEST_F(ResolverBuiltinFloatTest, Normalize_Error_NoParams) {
EXPECT_EQ(r()->error(), R"(error: no matching call to normalize() EXPECT_EQ(r()->error(), R"(error: no matching call to normalize()
1 candidate function: 1 candidate function:
normalize(vecN<f32>) -> vecN<f32> normalize(vecN<T>) -> vecN<T> where: T is f32 or f16
)"); )");
} }
@ -1436,6 +1728,20 @@ TEST_F(ResolverBuiltinTest, Determinant_2x2_f32) {
EXPECT_TRUE(TypeOf(call)->Is<sem::F32>()); EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
} }
TEST_F(ResolverBuiltinTest, Determinant_2x2_f16) {
Enable(ast::Extension::kF16);
GlobalVar("var", ty.mat2x2<f16>(), ast::StorageClass::kPrivate);
auto* call = Call("determinant", "var");
WrapInFunction(call);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->Is<sem::F16>());
}
TEST_F(ResolverBuiltinTest, Determinant_3x3_f32) { TEST_F(ResolverBuiltinTest, Determinant_3x3_f32) {
GlobalVar("var", ty.mat3x3<f32>(), ast::StorageClass::kPrivate); GlobalVar("var", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
@ -1448,6 +1754,20 @@ TEST_F(ResolverBuiltinTest, Determinant_3x3_f32) {
EXPECT_TRUE(TypeOf(call)->Is<sem::F32>()); EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
} }
TEST_F(ResolverBuiltinTest, Determinant_3x3_f16) {
Enable(ast::Extension::kF16);
GlobalVar("var", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
auto* call = Call("determinant", "var");
WrapInFunction(call);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->Is<sem::F16>());
}
TEST_F(ResolverBuiltinTest, Determinant_4x4_f32) { TEST_F(ResolverBuiltinTest, Determinant_4x4_f32) {
GlobalVar("var", ty.mat4x4<f32>(), ast::StorageClass::kPrivate); GlobalVar("var", ty.mat4x4<f32>(), ast::StorageClass::kPrivate);
@ -1460,6 +1780,20 @@ TEST_F(ResolverBuiltinTest, Determinant_4x4_f32) {
EXPECT_TRUE(TypeOf(call)->Is<sem::F32>()); EXPECT_TRUE(TypeOf(call)->Is<sem::F32>());
} }
TEST_F(ResolverBuiltinTest, Determinant_4x4_f16) {
Enable(ast::Extension::kF16);
GlobalVar("var", ty.mat4x4<f16>(), ast::StorageClass::kPrivate);
auto* call = Call("determinant", "var");
WrapInFunction(call);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(call), nullptr);
EXPECT_TRUE(TypeOf(call)->Is<sem::F16>());
}
TEST_F(ResolverBuiltinTest, Determinant_NotSquare) { TEST_F(ResolverBuiltinTest, Determinant_NotSquare) {
GlobalVar("var", ty.mat2x3<f32>(), ast::StorageClass::kPrivate); GlobalVar("var", ty.mat2x3<f32>(), ast::StorageClass::kPrivate);
@ -1471,7 +1805,7 @@ TEST_F(ResolverBuiltinTest, Determinant_NotSquare) {
EXPECT_EQ(r()->error(), R"(error: no matching call to determinant(mat2x3<f32>) EXPECT_EQ(r()->error(), R"(error: no matching call to determinant(mat2x3<f32>)
1 candidate function: 1 candidate function:
determinant(matNxN<f32>) -> f32 determinant(matNxN<T>) -> T where: T is f32 or f16
)"); )");
} }
@ -1486,7 +1820,7 @@ TEST_F(ResolverBuiltinTest, Determinant_NotMatrix) {
EXPECT_EQ(r()->error(), R"(error: no matching call to determinant(f32) EXPECT_EQ(r()->error(), R"(error: no matching call to determinant(f32)
1 candidate function: 1 candidate function:
determinant(matNxN<f32>) -> f32 determinant(matNxN<T>) -> T where: T is f32 or f16
)"); )");
} }
@ -1507,6 +1841,20 @@ TEST_F(ResolverBuiltinTest, Dot_Vec2_f32) {
EXPECT_TRUE(TypeOf(expr)->Is<sem::F32>()); EXPECT_TRUE(TypeOf(expr)->Is<sem::F32>());
} }
TEST_F(ResolverBuiltinTest, Dot_Vec2_f16) {
Enable(ast::Extension::kF16);
GlobalVar("my_var", ty.vec2<f16>(), ast::StorageClass::kPrivate);
auto* expr = Call("dot", "my_var", "my_var");
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
EXPECT_TRUE(TypeOf(expr)->Is<sem::F16>());
}
TEST_F(ResolverBuiltinTest, Dot_Vec3_i32) { TEST_F(ResolverBuiltinTest, Dot_Vec3_i32) {
GlobalVar("my_var", ty.vec3<i32>(), ast::StorageClass::kPrivate); GlobalVar("my_var", ty.vec3<i32>(), ast::StorageClass::kPrivate);
@ -1541,7 +1889,7 @@ TEST_F(ResolverBuiltinTest, Dot_Error_Scalar) {
R"(error: no matching call to dot(f32, f32) R"(error: no matching call to dot(f32, f32)
1 candidate function: 1 candidate function:
dot(vecN<T>, vecN<T>) -> T where: T is f32, i32 or u32 dot(vecN<T>, vecN<T>) -> T where: T is f32, i32, u32 or f16
)"); )");
} }

View File

@ -73,33 +73,32 @@ auto Dispatch_fia_fi32_f16(F&& f, CONSTANTS&&... cs) {
[&](const sem::AbstractFloat*) { return f(cs->template As<AFloat>()...); }, [&](const sem::AbstractFloat*) { return f(cs->template As<AFloat>()...); },
[&](const sem::F32*) { return f(cs->template As<f32>()...); }, [&](const sem::F32*) { return f(cs->template As<f32>()...); },
[&](const sem::I32*) { return f(cs->template As<i32>()...); }, [&](const sem::I32*) { return f(cs->template As<i32>()...); },
[&](const sem::F16*) { [&](const sem::F16*) { return f(cs->template As<f16>()...); });
// TODO(crbug.com/tint/1502): Support const eval for f16
return nullptr;
});
} }
/// Helper that calls `f` passing in the value of all `cs`. /// Helper that calls `f` passing in the value of all `cs`.
/// Assumes all `cs` are of the same type. /// Assumes all `cs` are of the same type.
template <typename F, typename... CONSTANTS> template <typename F, typename... CONSTANTS>
auto Dispatch_fia_fiu32(F&& f, CONSTANTS&&... cs) { auto Dispatch_fia_fiu32_f16(F&& f, CONSTANTS&&... cs) {
return Switch( return Switch(
First(cs...)->Type(), // First(cs...)->Type(), //
[&](const sem::AbstractInt*) { return f(cs->template As<AInt>()...); }, [&](const sem::AbstractInt*) { return f(cs->template As<AInt>()...); },
[&](const sem::AbstractFloat*) { return f(cs->template As<AFloat>()...); }, [&](const sem::AbstractFloat*) { return f(cs->template As<AFloat>()...); },
[&](const sem::F32*) { return f(cs->template As<f32>()...); }, [&](const sem::F32*) { return f(cs->template As<f32>()...); },
[&](const sem::I32*) { return f(cs->template As<i32>()...); }, [&](const sem::I32*) { return f(cs->template As<i32>()...); },
[&](const sem::U32*) { return f(cs->template As<u32>()...); }); [&](const sem::U32*) { return f(cs->template As<u32>()...); },
[&](const sem::F16*) { return f(cs->template As<f16>()...); });
} }
/// Helper that calls `f` passing in the value of all `cs`. /// Helper that calls `f` passing in the value of all `cs`.
/// Assumes all `cs` are of the same type. /// Assumes all `cs` are of the same type.
template <typename F, typename... CONSTANTS> template <typename F, typename... CONSTANTS>
auto Dispatch_fa_f32(F&& f, CONSTANTS&&... cs) { auto Dispatch_fa_f32_f16(F&& f, CONSTANTS&&... cs) {
return Switch( return Switch(
First(cs...)->Type(), // First(cs...)->Type(), //
[&](const sem::AbstractFloat*) { return f(cs->template As<AFloat>()...); }, [&](const sem::AbstractFloat*) { return f(cs->template As<AFloat>()...); },
[&](const sem::F32*) { return f(cs->template As<f32>()...); }); [&](const sem::F32*) { return f(cs->template As<f32>()...); },
[&](const sem::F16*) { return f(cs->template As<f16>()...); });
} }
/// ZeroTypeDispatch is a helper for calling the function `f`, passing a single zero-value argument /// ZeroTypeDispatch is a helper for calling the function `f`, passing a single zero-value argument
@ -730,7 +729,7 @@ const sem::Constant* ConstEval::atan2(const sem::Type*,
auto create = [&](auto i, auto j) { auto create = [&](auto i, auto j) {
return CreateElement(builder, c0->Type(), decltype(i)(std::atan2(i.value, j.value))); return CreateElement(builder, c0->Type(), decltype(i)(std::atan2(i.value, j.value)));
}; };
return Dispatch_fa_f32(create, c0, c1); return Dispatch_fa_f32_f16(create, c0, c1);
}; };
return TransformElements(builder, transform, args[0]->ConstantValue(), return TransformElements(builder, transform, args[0]->ConstantValue(),
args[1]->ConstantValue()); args[1]->ConstantValue());
@ -744,7 +743,7 @@ const sem::Constant* ConstEval::clamp(const sem::Type*,
return CreateElement(builder, c0->Type(), return CreateElement(builder, c0->Type(),
decltype(e)(std::min(std::max(e, low), high))); decltype(e)(std::min(std::max(e, low), high)));
}; };
return Dispatch_fia_fiu32(create, c0, c1, c2); return Dispatch_fia_fiu32_f16(create, c0, c1, c2);
}; };
return TransformElements(builder, transform, args[0]->ConstantValue(), args[1]->ConstantValue(), return TransformElements(builder, transform, args[0]->ConstantValue(), args[1]->ConstantValue(),
args[2]->ConstantValue()); args[2]->ConstantValue());

View File

@ -2994,7 +2994,8 @@ struct Values {
}; };
struct Case { struct Case {
std::variant<Values<AInt>, Values<AFloat>, Values<u32>, Values<i32>, Values<f32>> values; std::variant<Values<AInt>, Values<AFloat>, Values<u32>, Values<i32>, Values<f32>, Values<f16>>
values;
}; };
static std::ostream& operator<<(std::ostream& o, const Case& c) { static std::ostream& operator<<(std::ostream& o, const Case& c) {
@ -3101,6 +3102,15 @@ INSTANTIATE_TEST_SUITE_P(Negation,
C(-kHighest<f32>, kHighest<f32>), C(-kHighest<f32>, kHighest<f32>),
C(kLowest<f32>, Negate(kLowest<f32>)), C(kLowest<f32>, Negate(kLowest<f32>)),
C(Negate(kLowest<f32>), kLowest<f32>), C(Negate(kLowest<f32>), kLowest<f32>),
// f16
C(0.0_h, -0.0_h),
C(-0.0_h, 0.0_h),
C(1.0_h, -1.0_h),
C(-1.0_h, 1.0_h),
C(kHighest<f16>, -kHighest<f16>),
C(-kHighest<f16>, kHighest<f16>),
C(kLowest<f16>, Negate(kLowest<f16>)),
C(Negate(kLowest<f16>), kLowest<f16>),
}))); })));
// Make sure UBSan doesn't trip on C++'s undefined behaviour of negating the smallest negative // Make sure UBSan doesn't trip on C++'s undefined behaviour of negating the smallest negative
@ -3252,7 +3262,8 @@ INSTANTIATE_TEST_SUITE_P( //
ResolverConstEvalBuiltinTest, ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kAtan2), testing::Combine(testing::Values(sem::BuiltinType::kAtan2),
testing::ValuesIn(Concat(Atan2Cases<AFloat, true>(), // testing::ValuesIn(Concat(Atan2Cases<AFloat, true>(), //
Atan2Cases<f32, false>())))); Atan2Cases<f32, false>(),
Atan2Cases<f16, false>()))));
template <typename T> template <typename T>
std::vector<Case> ClampCases() { std::vector<Case> ClampCases() {
@ -3277,7 +3288,8 @@ INSTANTIATE_TEST_SUITE_P( //
ClampCases<i32>(), ClampCases<i32>(),
ClampCases<u32>(), ClampCases<u32>(),
ClampCases<AFloat>(), ClampCases<AFloat>(),
ClampCases<f32>())))); ClampCases<f32>(),
ClampCases<f16>()))));
} // namespace builtin } // namespace builtin

File diff suppressed because it is too large Load Diff

View File

@ -1302,10 +1302,12 @@ bool GeneratorImpl::EmitFrexpCall(std::ostream& out,
bool GeneratorImpl::EmitDegreesCall(std::ostream& out, bool GeneratorImpl::EmitDegreesCall(std::ostream& out,
const ast::CallExpression* expr, const ast::CallExpression* expr,
const sem::Builtin* builtin) { const sem::Builtin* builtin) {
auto* return_elem_type = sem::Type::DeepestElementOf(builtin->ReturnType());
const std::string suffix = Is<sem::F16>(return_elem_type) ? "hf" : "f";
return CallBuiltinHelper(out, expr, builtin, return CallBuiltinHelper(out, expr, builtin,
[&](TextBuffer* b, const std::vector<std::string>& params) { [&](TextBuffer* b, const std::vector<std::string>& params) {
line(b) << "return " << params[0] << " * " << std::setprecision(20) line(b) << "return " << params[0] << " * " << std::setprecision(20)
<< sem::kRadToDeg << ";"; << sem::kRadToDeg << suffix << ";";
return true; return true;
}); });
} }
@ -1313,10 +1315,12 @@ bool GeneratorImpl::EmitDegreesCall(std::ostream& out,
bool GeneratorImpl::EmitRadiansCall(std::ostream& out, bool GeneratorImpl::EmitRadiansCall(std::ostream& out,
const ast::CallExpression* expr, const ast::CallExpression* expr,
const sem::Builtin* builtin) { const sem::Builtin* builtin) {
auto* return_elem_type = sem::Type::DeepestElementOf(builtin->ReturnType());
const std::string suffix = Is<sem::F16>(return_elem_type) ? "hf" : "f";
return CallBuiltinHelper(out, expr, builtin, return CallBuiltinHelper(out, expr, builtin,
[&](TextBuffer* b, const std::vector<std::string>& params) { [&](TextBuffer* b, const std::vector<std::string>& params) {
line(b) << "return " << params[0] << " * " << std::setprecision(20) line(b) << "return " << params[0] << " * " << std::setprecision(20)
<< sem::kDegToRad << ";"; << sem::kDegToRad << suffix << ";";
return true; return true;
}); });
} }

View File

@ -29,36 +29,40 @@ using BuiltinType = sem::BuiltinType;
using GlslGeneratorImplTest_Builtin = TestHelper; using GlslGeneratorImplTest_Builtin = TestHelper;
enum class ParamType { enum class CallParamType {
kF32, kF32,
kU32, kU32,
kBool, kBool,
kF16,
}; };
struct BuiltinData { struct BuiltinData {
BuiltinType builtin; BuiltinType builtin;
ParamType type; CallParamType type;
const char* glsl_name; const char* glsl_name;
}; };
inline std::ostream& operator<<(std::ostream& out, BuiltinData data) { inline std::ostream& operator<<(std::ostream& out, BuiltinData data) {
out << data.glsl_name; out << data.glsl_name << "<";
switch (data.type) { switch (data.type) {
case ParamType::kF32: case CallParamType::kF32:
out << "f32"; out << "f32";
break; break;
case ParamType::kU32: case CallParamType::kU32:
out << "u32"; out << "u32";
break; break;
case ParamType::kBool: case CallParamType::kBool:
out << "bool"; out << "bool";
break; break;
case CallParamType::kF16:
out << "f16";
break;
} }
out << ">"; out << ">";
return out; return out;
} }
const ast::CallExpression* GenerateCall(BuiltinType builtin, const ast::CallExpression* GenerateCall(BuiltinType builtin,
ParamType type, CallParamType type,
ProgramBuilder* builder) { ProgramBuilder* builder) {
std::string name; std::string name;
std::ostringstream str(name); std::ostringstream str(name);
@ -96,29 +100,51 @@ const ast::CallExpression* GenerateCall(BuiltinType builtin,
case BuiltinType::kTanh: case BuiltinType::kTanh:
case BuiltinType::kTrunc: case BuiltinType::kTrunc:
case BuiltinType::kSign: case BuiltinType::kSign:
return builder->Call(str.str(), "f2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2");
} else {
return builder->Call(str.str(), "f2");
}
case BuiltinType::kLdexp: case BuiltinType::kLdexp:
return builder->Call(str.str(), "f2", "i2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "i2");
} else {
return builder->Call(str.str(), "f2", "i2");
}
case BuiltinType::kAtan2: case BuiltinType::kAtan2:
case BuiltinType::kDot: case BuiltinType::kDot:
case BuiltinType::kDistance: case BuiltinType::kDistance:
case BuiltinType::kPow: case BuiltinType::kPow:
case BuiltinType::kReflect: case BuiltinType::kReflect:
case BuiltinType::kStep: case BuiltinType::kStep:
return builder->Call(str.str(), "f2", "f2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2");
} else {
return builder->Call(str.str(), "f2", "f2");
}
case BuiltinType::kCross: case BuiltinType::kCross:
return builder->Call(str.str(), "f3", "f3"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h3", "h3");
} else {
return builder->Call(str.str(), "f3", "f3");
}
case BuiltinType::kFma: case BuiltinType::kFma:
case BuiltinType::kMix: case BuiltinType::kMix:
case BuiltinType::kFaceForward: case BuiltinType::kFaceForward:
case BuiltinType::kSmoothstep: case BuiltinType::kSmoothstep:
return builder->Call(str.str(), "f2", "f2", "f2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2", "h2");
} else {
return builder->Call(str.str(), "f2", "f2", "f2");
}
case BuiltinType::kAll: case BuiltinType::kAll:
case BuiltinType::kAny: case BuiltinType::kAny:
return builder->Call(str.str(), "b2"); return builder->Call(str.str(), "b2");
case BuiltinType::kAbs: case BuiltinType::kAbs:
if (type == ParamType::kF32) { if (type == CallParamType::kF32) {
return builder->Call(str.str(), "f2"); return builder->Call(str.str(), "f2");
} else if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2");
} else { } else {
return builder->Call(str.str(), "u2"); return builder->Call(str.str(), "u2");
} }
@ -127,23 +153,39 @@ const ast::CallExpression* GenerateCall(BuiltinType builtin,
return builder->Call(str.str(), "u2"); return builder->Call(str.str(), "u2");
case BuiltinType::kMax: case BuiltinType::kMax:
case BuiltinType::kMin: case BuiltinType::kMin:
if (type == ParamType::kF32) { if (type == CallParamType::kF32) {
return builder->Call(str.str(), "f2", "f2"); return builder->Call(str.str(), "f2", "f2");
} else if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2");
} else { } else {
return builder->Call(str.str(), "u2", "u2"); return builder->Call(str.str(), "u2", "u2");
} }
case BuiltinType::kClamp: case BuiltinType::kClamp:
if (type == ParamType::kF32) { if (type == CallParamType::kF32) {
return builder->Call(str.str(), "f2", "f2", "f2"); return builder->Call(str.str(), "f2", "f2", "f2");
} else if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2", "h2");
} else { } else {
return builder->Call(str.str(), "u2", "u2", "u2"); return builder->Call(str.str(), "u2", "u2", "u2");
} }
case BuiltinType::kSelect: case BuiltinType::kSelect:
return builder->Call(str.str(), "f2", "f2", "b2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2", "b2");
} else {
return builder->Call(str.str(), "f2", "f2", "b2");
}
case BuiltinType::kDeterminant: case BuiltinType::kDeterminant:
return builder->Call(str.str(), "m2x2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "hm2x2");
} else {
return builder->Call(str.str(), "m2x2");
}
case BuiltinType::kTranspose: case BuiltinType::kTranspose:
return builder->Call(str.str(), "m3x2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "hm3x2");
} else {
return builder->Call(str.str(), "m3x2");
}
default: default:
break; break;
} }
@ -153,6 +195,15 @@ using GlslBuiltinTest = TestParamHelper<BuiltinData>;
TEST_P(GlslBuiltinTest, Emit) { TEST_P(GlslBuiltinTest, Emit) {
auto param = GetParam(); auto param = GetParam();
if (param.type == CallParamType::kF16) {
Enable(ast::Extension::kF16);
GlobalVar("h2", ty.vec2<f16>(), ast::StorageClass::kPrivate);
GlobalVar("h3", ty.vec3<f16>(), ast::StorageClass::kPrivate);
GlobalVar("hm2x2", ty.mat2x2<f16>(), ast::StorageClass::kPrivate);
GlobalVar("hm3x2", ty.mat3x2<f16>(), ast::StorageClass::kPrivate);
}
GlobalVar("f2", ty.vec2<f32>(), ast::StorageClass::kPrivate); GlobalVar("f2", ty.vec2<f32>(), ast::StorageClass::kPrivate);
GlobalVar("f3", ty.vec3<f32>(), ast::StorageClass::kPrivate); GlobalVar("f3", ty.vec3<f32>(), ast::StorageClass::kPrivate);
GlobalVar("u2", ty.vec2<u32>(), ast::StorageClass::kPrivate); GlobalVar("u2", ty.vec2<u32>(), ast::StorageClass::kPrivate);
@ -180,64 +231,110 @@ TEST_P(GlslBuiltinTest, Emit) {
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
GlslGeneratorImplTest_Builtin, GlslGeneratorImplTest_Builtin,
GlslBuiltinTest, GlslBuiltinTest,
testing::Values(BuiltinData{BuiltinType::kAbs, ParamType::kF32, "abs"}, testing::Values(/* Logical built-in */
BuiltinData{BuiltinType::kAbs, ParamType::kU32, "abs"}, BuiltinData{BuiltinType::kAll, CallParamType::kBool, "all"},
BuiltinData{BuiltinType::kAcos, ParamType::kF32, "acos"}, BuiltinData{BuiltinType::kAny, CallParamType::kBool, "any"},
BuiltinData{BuiltinType::kAll, ParamType::kBool, "all"}, /* Float built-in */
BuiltinData{BuiltinType::kAny, ParamType::kBool, "any"}, BuiltinData{BuiltinType::kAbs, CallParamType::kF32, "abs"},
BuiltinData{BuiltinType::kAsin, ParamType::kF32, "asin"}, BuiltinData{BuiltinType::kAbs, CallParamType::kF16, "abs"},
BuiltinData{BuiltinType::kAtan, ParamType::kF32, "atan"}, BuiltinData{BuiltinType::kAcos, CallParamType::kF32, "acos"},
BuiltinData{BuiltinType::kAtan2, ParamType::kF32, "atan"}, BuiltinData{BuiltinType::kAcos, CallParamType::kF16, "acos"},
BuiltinData{BuiltinType::kCeil, ParamType::kF32, "ceil"}, BuiltinData{BuiltinType::kAsin, CallParamType::kF32, "asin"},
BuiltinData{BuiltinType::kClamp, ParamType::kF32, "clamp"}, BuiltinData{BuiltinType::kAsin, CallParamType::kF16, "asin"},
BuiltinData{BuiltinType::kClamp, ParamType::kU32, "clamp"}, BuiltinData{BuiltinType::kAtan, CallParamType::kF32, "atan"},
BuiltinData{BuiltinType::kCos, ParamType::kF32, "cos"}, BuiltinData{BuiltinType::kAtan, CallParamType::kF16, "atan"},
BuiltinData{BuiltinType::kCosh, ParamType::kF32, "cosh"}, BuiltinData{BuiltinType::kAtan2, CallParamType::kF32, "atan"},
BuiltinData{BuiltinType::kCountOneBits, ParamType::kU32, "bitCount"}, BuiltinData{BuiltinType::kAtan2, CallParamType::kF16, "atan"},
BuiltinData{BuiltinType::kCross, ParamType::kF32, "cross"}, BuiltinData{BuiltinType::kCeil, CallParamType::kF32, "ceil"},
BuiltinData{BuiltinType::kDeterminant, ParamType::kF32, "determinant"}, BuiltinData{BuiltinType::kCeil, CallParamType::kF16, "ceil"},
BuiltinData{BuiltinType::kDistance, ParamType::kF32, "distance"}, BuiltinData{BuiltinType::kClamp, CallParamType::kF32, "clamp"},
BuiltinData{BuiltinType::kDot, ParamType::kF32, "dot"}, BuiltinData{BuiltinType::kClamp, CallParamType::kF16, "clamp"},
BuiltinData{BuiltinType::kDpdx, ParamType::kF32, "dFdx"}, BuiltinData{BuiltinType::kCos, CallParamType::kF32, "cos"},
BuiltinData{BuiltinType::kDpdxCoarse, ParamType::kF32, "dFdx"}, BuiltinData{BuiltinType::kCos, CallParamType::kF16, "cos"},
BuiltinData{BuiltinType::kDpdxFine, ParamType::kF32, "dFdx"}, BuiltinData{BuiltinType::kCosh, CallParamType::kF32, "cosh"},
BuiltinData{BuiltinType::kDpdy, ParamType::kF32, "dFdy"}, BuiltinData{BuiltinType::kCosh, CallParamType::kF16, "cosh"},
BuiltinData{BuiltinType::kDpdyCoarse, ParamType::kF32, "dFdy"}, BuiltinData{BuiltinType::kCross, CallParamType::kF32, "cross"},
BuiltinData{BuiltinType::kDpdyFine, ParamType::kF32, "dFdy"}, BuiltinData{BuiltinType::kCross, CallParamType::kF16, "cross"},
BuiltinData{BuiltinType::kExp, ParamType::kF32, "exp"}, BuiltinData{BuiltinType::kDistance, CallParamType::kF32, "distance"},
BuiltinData{BuiltinType::kExp2, ParamType::kF32, "exp2"}, BuiltinData{BuiltinType::kDistance, CallParamType::kF16, "distance"},
BuiltinData{BuiltinType::kFaceForward, ParamType::kF32, "faceforward"}, BuiltinData{BuiltinType::kExp, CallParamType::kF32, "exp"},
BuiltinData{BuiltinType::kFloor, ParamType::kF32, "floor"}, BuiltinData{BuiltinType::kExp, CallParamType::kF16, "exp"},
BuiltinData{BuiltinType::kFma, ParamType::kF32, "fma"}, BuiltinData{BuiltinType::kExp2, CallParamType::kF32, "exp2"},
BuiltinData{BuiltinType::kFract, ParamType::kF32, "fract"}, BuiltinData{BuiltinType::kExp2, CallParamType::kF16, "exp2"},
BuiltinData{BuiltinType::kFwidth, ParamType::kF32, "fwidth"}, BuiltinData{BuiltinType::kFaceForward, CallParamType::kF32, "faceforward"},
BuiltinData{BuiltinType::kFwidthCoarse, ParamType::kF32, "fwidth"}, BuiltinData{BuiltinType::kFaceForward, CallParamType::kF16, "faceforward"},
BuiltinData{BuiltinType::kFwidthFine, ParamType::kF32, "fwidth"}, BuiltinData{BuiltinType::kFloor, CallParamType::kF32, "floor"},
BuiltinData{BuiltinType::kInverseSqrt, ParamType::kF32, "inversesqrt"}, BuiltinData{BuiltinType::kFloor, CallParamType::kF16, "floor"},
BuiltinData{BuiltinType::kLdexp, ParamType::kF32, "ldexp"}, BuiltinData{BuiltinType::kFma, CallParamType::kF32, "fma"},
BuiltinData{BuiltinType::kLength, ParamType::kF32, "length"}, BuiltinData{BuiltinType::kFma, CallParamType::kF16, "fma"},
BuiltinData{BuiltinType::kLog, ParamType::kF32, "log"}, BuiltinData{BuiltinType::kFract, CallParamType::kF32, "fract"},
BuiltinData{BuiltinType::kLog2, ParamType::kF32, "log2"}, BuiltinData{BuiltinType::kFract, CallParamType::kF16, "fract"},
BuiltinData{BuiltinType::kMax, ParamType::kF32, "max"}, BuiltinData{BuiltinType::kInverseSqrt, CallParamType::kF32, "inversesqrt"},
BuiltinData{BuiltinType::kMax, ParamType::kU32, "max"}, BuiltinData{BuiltinType::kInverseSqrt, CallParamType::kF16, "inversesqrt"},
BuiltinData{BuiltinType::kMin, ParamType::kF32, "min"}, BuiltinData{BuiltinType::kLdexp, CallParamType::kF32, "ldexp"},
BuiltinData{BuiltinType::kMin, ParamType::kU32, "min"}, BuiltinData{BuiltinType::kLdexp, CallParamType::kF16, "ldexp"},
BuiltinData{BuiltinType::kMix, ParamType::kF32, "mix"}, BuiltinData{BuiltinType::kLength, CallParamType::kF32, "length"},
BuiltinData{BuiltinType::kNormalize, ParamType::kF32, "normalize"}, BuiltinData{BuiltinType::kLength, CallParamType::kF16, "length"},
BuiltinData{BuiltinType::kPow, ParamType::kF32, "pow"}, BuiltinData{BuiltinType::kLog, CallParamType::kF32, "log"},
BuiltinData{BuiltinType::kReflect, ParamType::kF32, "reflect"}, BuiltinData{BuiltinType::kLog, CallParamType::kF16, "log"},
BuiltinData{BuiltinType::kReverseBits, ParamType::kU32, "bitfieldReverse"}, BuiltinData{BuiltinType::kLog2, CallParamType::kF32, "log2"},
BuiltinData{BuiltinType::kRound, ParamType::kU32, "round"}, BuiltinData{BuiltinType::kLog2, CallParamType::kF16, "log2"},
BuiltinData{BuiltinType::kSign, ParamType::kF32, "sign"}, BuiltinData{BuiltinType::kMax, CallParamType::kF32, "max"},
BuiltinData{BuiltinType::kSin, ParamType::kF32, "sin"}, BuiltinData{BuiltinType::kMax, CallParamType::kF16, "max"},
BuiltinData{BuiltinType::kSinh, ParamType::kF32, "sinh"}, BuiltinData{BuiltinType::kMin, CallParamType::kF32, "min"},
BuiltinData{BuiltinType::kSmoothstep, ParamType::kF32, "smoothstep"}, BuiltinData{BuiltinType::kMin, CallParamType::kF16, "min"},
BuiltinData{BuiltinType::kSqrt, ParamType::kF32, "sqrt"}, BuiltinData{BuiltinType::kMix, CallParamType::kF32, "mix"},
BuiltinData{BuiltinType::kStep, ParamType::kF32, "step"}, BuiltinData{BuiltinType::kMix, CallParamType::kF16, "mix"},
BuiltinData{BuiltinType::kTan, ParamType::kF32, "tan"}, BuiltinData{BuiltinType::kNormalize, CallParamType::kF32, "normalize"},
BuiltinData{BuiltinType::kTanh, ParamType::kF32, "tanh"}, BuiltinData{BuiltinType::kNormalize, CallParamType::kF16, "normalize"},
BuiltinData{BuiltinType::kTranspose, ParamType::kF32, "transpose"}, BuiltinData{BuiltinType::kPow, CallParamType::kF32, "pow"},
BuiltinData{BuiltinType::kTrunc, ParamType::kF32, "trunc"})); BuiltinData{BuiltinType::kPow, CallParamType::kF16, "pow"},
BuiltinData{BuiltinType::kReflect, CallParamType::kF32, "reflect"},
BuiltinData{BuiltinType::kReflect, CallParamType::kF16, "reflect"},
BuiltinData{BuiltinType::kSign, CallParamType::kF32, "sign"},
BuiltinData{BuiltinType::kSign, CallParamType::kF16, "sign"},
BuiltinData{BuiltinType::kSin, CallParamType::kF32, "sin"},
BuiltinData{BuiltinType::kSin, CallParamType::kF16, "sin"},
BuiltinData{BuiltinType::kSinh, CallParamType::kF32, "sinh"},
BuiltinData{BuiltinType::kSinh, CallParamType::kF16, "sinh"},
BuiltinData{BuiltinType::kSmoothstep, CallParamType::kF32, "smoothstep"},
BuiltinData{BuiltinType::kSmoothstep, CallParamType::kF16, "smoothstep"},
BuiltinData{BuiltinType::kSqrt, CallParamType::kF32, "sqrt"},
BuiltinData{BuiltinType::kSqrt, CallParamType::kF16, "sqrt"},
BuiltinData{BuiltinType::kStep, CallParamType::kF32, "step"},
BuiltinData{BuiltinType::kStep, CallParamType::kF16, "step"},
BuiltinData{BuiltinType::kTan, CallParamType::kF32, "tan"},
BuiltinData{BuiltinType::kTan, CallParamType::kF16, "tan"},
BuiltinData{BuiltinType::kTanh, CallParamType::kF32, "tanh"},
BuiltinData{BuiltinType::kTanh, CallParamType::kF16, "tanh"},
BuiltinData{BuiltinType::kTrunc, CallParamType::kF32, "trunc"},
BuiltinData{BuiltinType::kTrunc, CallParamType::kF16, "trunc"},
/* Integer built-in */
BuiltinData{BuiltinType::kAbs, CallParamType::kU32, "abs"},
BuiltinData{BuiltinType::kClamp, CallParamType::kU32, "clamp"},
BuiltinData{BuiltinType::kCountOneBits, CallParamType::kU32, "bitCount"},
BuiltinData{BuiltinType::kMax, CallParamType::kU32, "max"},
BuiltinData{BuiltinType::kMin, CallParamType::kU32, "min"},
BuiltinData{BuiltinType::kReverseBits, CallParamType::kU32, "bitfieldReverse"},
BuiltinData{BuiltinType::kRound, CallParamType::kU32, "round"},
/* Matrix built-in */
BuiltinData{BuiltinType::kDeterminant, CallParamType::kF32, "determinant"},
BuiltinData{BuiltinType::kDeterminant, CallParamType::kF16, "determinant"},
BuiltinData{BuiltinType::kTranspose, CallParamType::kF32, "transpose"},
BuiltinData{BuiltinType::kTranspose, CallParamType::kF16, "transpose"},
/* Vector built-in */
BuiltinData{BuiltinType::kDot, CallParamType::kF32, "dot"},
BuiltinData{BuiltinType::kDot, CallParamType::kF16, "dot"},
/* Derivate built-in */
BuiltinData{BuiltinType::kDpdx, CallParamType::kF32, "dFdx"},
BuiltinData{BuiltinType::kDpdxCoarse, CallParamType::kF32, "dFdx"},
BuiltinData{BuiltinType::kDpdxFine, CallParamType::kF32, "dFdx"},
BuiltinData{BuiltinType::kDpdy, CallParamType::kF32, "dFdy"},
BuiltinData{BuiltinType::kDpdyCoarse, CallParamType::kF32, "dFdy"},
BuiltinData{BuiltinType::kDpdyFine, CallParamType::kF32, "dFdy"},
BuiltinData{BuiltinType::kFwidth, CallParamType::kF32, "fwidth"},
BuiltinData{BuiltinType::kFwidthCoarse, CallParamType::kF32, "fwidth"},
BuiltinData{BuiltinType::kFwidthFine, CallParamType::kF32, "fwidth"}));
TEST_F(GlslGeneratorImplTest_Builtin, Builtin_Call) { TEST_F(GlslGeneratorImplTest_Builtin, Builtin_Call) {
auto* call = Call("dot", "param1", "param2"); auto* call = Call("dot", "param1", "param2");
@ -277,6 +374,41 @@ TEST_F(GlslGeneratorImplTest_Builtin, Select_Vector) {
EXPECT_EQ(out.str(), "mix(ivec2(1, 2), ivec2(3, 4), bvec2(true, false))"); EXPECT_EQ(out.str(), "mix(ivec2(1, 2), ivec2(3, 4), bvec2(true, false))");
} }
TEST_F(GlslGeneratorImplTest_Builtin, FMA_f32) {
auto* call = Call("fma", "a", "b", "c");
GlobalVar("a", ty.vec3<f32>(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.vec3<f32>(), ast::StorageClass::kPrivate);
GlobalVar("c", ty.vec3<f32>(), ast::StorageClass::kPrivate);
WrapInFunction(CallStmt(call));
GeneratorImpl& gen = Build();
gen.increment_indent();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, call)) << gen.error();
EXPECT_EQ(out.str(), "((a) * (b) + (c))");
}
TEST_F(GlslGeneratorImplTest_Builtin, FMA_f16) {
Enable(ast::Extension::kF16);
GlobalVar("a", ty.vec3<f16>(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.vec3<f16>(), ast::StorageClass::kPrivate);
GlobalVar("c", ty.vec3<f16>(), ast::StorageClass::kPrivate);
auto* call = Call("fma", "a", "b", "c");
WrapInFunction(CallStmt(call));
GeneratorImpl& gen = Build();
gen.increment_indent();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, call)) << gen.error();
EXPECT_EQ(out.str(), "((a) * (b) + (c))");
}
TEST_F(GlslGeneratorImplTest_Builtin, Modf_Scalar) { TEST_F(GlslGeneratorImplTest_Builtin, Modf_Scalar) {
auto* call = Call("modf", 1_f); auto* call = Call("modf", 1_f);
WrapInFunction(CallStmt(call)); WrapInFunction(CallStmt(call));
@ -402,7 +534,7 @@ void main() {
)")); )"));
} }
TEST_F(GlslGeneratorImplTest_Builtin, Degrees_Scalar) { TEST_F(GlslGeneratorImplTest_Builtin, Degrees_Scalar_f32) {
auto* val = Var("val", ty.f32()); auto* val = Var("val", ty.f32());
auto* call = Call("degrees", val); auto* call = Call("degrees", val);
WrapInFunction(val, call); WrapInFunction(val, call);
@ -413,7 +545,7 @@ TEST_F(GlslGeneratorImplTest_Builtin, Degrees_Scalar) {
EXPECT_EQ(gen.result(), R"(#version 310 es EXPECT_EQ(gen.result(), R"(#version 310 es
float tint_degrees(float param_0) { float tint_degrees(float param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -430,7 +562,7 @@ void main() {
)"); )");
} }
TEST_F(GlslGeneratorImplTest_Builtin, Degrees_Vector) { TEST_F(GlslGeneratorImplTest_Builtin, Degrees_Vector_f32) {
auto* val = Var("val", ty.vec3<f32>()); auto* val = Var("val", ty.vec3<f32>());
auto* call = Call("degrees", val); auto* call = Call("degrees", val);
WrapInFunction(val, call); WrapInFunction(val, call);
@ -441,7 +573,7 @@ TEST_F(GlslGeneratorImplTest_Builtin, Degrees_Vector) {
EXPECT_EQ(gen.result(), R"(#version 310 es EXPECT_EQ(gen.result(), R"(#version 310 es
vec3 tint_degrees(vec3 param_0) { vec3 tint_degrees(vec3 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -458,7 +590,69 @@ void main() {
)"); )");
} }
TEST_F(GlslGeneratorImplTest_Builtin, Radians_Scalar) { TEST_F(GlslGeneratorImplTest_Builtin, Degrees_Scalar_f16) {
Enable(ast::Extension::kF16);
auto* val = Var("val", ty.f16());
auto* call = Call("degrees", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#version 310 es
#extension GL_AMD_gpu_shader_half_float : require
float16_t tint_degrees(float16_t param_0) {
return param_0 * 57.295779513082322865hf;
}
void test_function() {
float16_t val = 0.0hf;
float16_t tint_symbol = tint_degrees(val);
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
test_function();
return;
}
)");
}
TEST_F(GlslGeneratorImplTest_Builtin, Degrees_Vector_f16) {
Enable(ast::Extension::kF16);
auto* val = Var("val", ty.vec3<f16>());
auto* call = Call("degrees", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#version 310 es
#extension GL_AMD_gpu_shader_half_float : require
f16vec3 tint_degrees(f16vec3 param_0) {
return param_0 * 57.295779513082322865hf;
}
void test_function() {
f16vec3 val = f16vec3(0.0hf, 0.0hf, 0.0hf);
f16vec3 tint_symbol = tint_degrees(val);
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
test_function();
return;
}
)");
}
TEST_F(GlslGeneratorImplTest_Builtin, Radians_Scalar_f32) {
auto* val = Var("val", ty.f32()); auto* val = Var("val", ty.f32());
auto* call = Call("radians", val); auto* call = Call("radians", val);
WrapInFunction(val, call); WrapInFunction(val, call);
@ -469,7 +663,7 @@ TEST_F(GlslGeneratorImplTest_Builtin, Radians_Scalar) {
EXPECT_EQ(gen.result(), R"(#version 310 es EXPECT_EQ(gen.result(), R"(#version 310 es
float tint_radians(float param_0) { float tint_radians(float param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -486,7 +680,7 @@ void main() {
)"); )");
} }
TEST_F(GlslGeneratorImplTest_Builtin, Radians_Vector) { TEST_F(GlslGeneratorImplTest_Builtin, Radians_Vector_f32) {
auto* val = Var("val", ty.vec3<f32>()); auto* val = Var("val", ty.vec3<f32>());
auto* call = Call("radians", val); auto* call = Call("radians", val);
WrapInFunction(val, call); WrapInFunction(val, call);
@ -497,7 +691,7 @@ TEST_F(GlslGeneratorImplTest_Builtin, Radians_Vector) {
EXPECT_EQ(gen.result(), R"(#version 310 es EXPECT_EQ(gen.result(), R"(#version 310 es
vec3 tint_radians(vec3 param_0) { vec3 tint_radians(vec3 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -514,6 +708,68 @@ void main() {
)"); )");
} }
TEST_F(GlslGeneratorImplTest_Builtin, Radians_Scalar_f16) {
Enable(ast::Extension::kF16);
auto* val = Var("val", ty.f16());
auto* call = Call("radians", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#version 310 es
#extension GL_AMD_gpu_shader_half_float : require
float16_t tint_radians(float16_t param_0) {
return param_0 * 0.017453292519943295474hf;
}
void test_function() {
float16_t val = 0.0hf;
float16_t tint_symbol = tint_radians(val);
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
test_function();
return;
}
)");
}
TEST_F(GlslGeneratorImplTest_Builtin, Radians_Vector_f16) {
Enable(ast::Extension::kF16);
auto* val = Var("val", ty.vec3<f16>());
auto* call = Call("radians", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#version 310 es
#extension GL_AMD_gpu_shader_half_float : require
f16vec3 tint_radians(f16vec3 param_0) {
return param_0 * 0.017453292519943295474hf;
}
void test_function() {
f16vec3 val = f16vec3(0.0hf, 0.0hf, 0.0hf);
f16vec3 tint_symbol = tint_radians(val);
}
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
test_function();
return;
}
)");
}
TEST_F(GlslGeneratorImplTest_Builtin, ExtractBits) { TEST_F(GlslGeneratorImplTest_Builtin, ExtractBits) {
auto* v = Var("v", ty.vec3<u32>()); auto* v = Var("v", ty.vec3<u32>());
auto* offset = Var("offset", ty.u32()); auto* offset = Var("offset", ty.u32());
@ -828,23 +1084,6 @@ void main() {
)"); )");
} }
TEST_F(GlslGeneratorImplTest_Builtin, FMA) {
auto* call = Call("fma", "a", "b", "c");
GlobalVar("a", ty.vec3<f32>(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.vec3<f32>(), ast::StorageClass::kPrivate);
GlobalVar("c", ty.vec3<f32>(), ast::StorageClass::kPrivate);
WrapInFunction(CallStmt(call));
GeneratorImpl& gen = Build();
gen.increment_indent();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, call)) << gen.error();
EXPECT_EQ(out.str(), "((a) * (b) + (c))");
}
TEST_F(GlslGeneratorImplTest_Builtin, DotU32) { TEST_F(GlslGeneratorImplTest_Builtin, DotU32) {
GlobalVar("v", ty.vec3<u32>(), ast::StorageClass::kPrivate); GlobalVar("v", ty.vec3<u32>(), ast::StorageClass::kPrivate);
WrapInFunction(CallStmt(Call("dot", "v", "v"))); WrapInFunction(CallStmt(Call("dot", "v", "v")));

View File

@ -28,36 +28,40 @@ namespace {
using BuiltinType = sem::BuiltinType; using BuiltinType = sem::BuiltinType;
using HlslGeneratorImplTest_Builtin = TestHelper; using HlslGeneratorImplTest_Builtin = TestHelper;
enum class ParamType { enum class CallParamType {
kF32, kF32,
kU32, kU32,
kBool, kBool,
kF16,
}; };
struct BuiltinData { struct BuiltinData {
BuiltinType builtin; BuiltinType builtin;
ParamType type; CallParamType type;
const char* hlsl_name; const char* hlsl_name;
}; };
inline std::ostream& operator<<(std::ostream& out, BuiltinData data) { inline std::ostream& operator<<(std::ostream& out, BuiltinData data) {
out << data.hlsl_name; out << data.hlsl_name << "<";
switch (data.type) { switch (data.type) {
case ParamType::kF32: case CallParamType::kF32:
out << "f32"; out << "f32";
break; break;
case ParamType::kU32: case CallParamType::kU32:
out << "u32"; out << "u32";
break; break;
case ParamType::kBool: case CallParamType::kBool:
out << "bool"; out << "bool";
break; break;
case CallParamType::kF16:
out << "f16";
break;
} }
out << ">"; out << ">";
return out; return out;
} }
const ast::CallExpression* GenerateCall(BuiltinType builtin, const ast::CallExpression* GenerateCall(BuiltinType builtin,
ParamType type, CallParamType type,
ProgramBuilder* builder) { ProgramBuilder* builder) {
std::string name; std::string name;
std::ostringstream str(name); std::ostringstream str(name);
@ -95,29 +99,51 @@ const ast::CallExpression* GenerateCall(BuiltinType builtin,
case BuiltinType::kTanh: case BuiltinType::kTanh:
case BuiltinType::kTrunc: case BuiltinType::kTrunc:
case BuiltinType::kSign: case BuiltinType::kSign:
return builder->Call(str.str(), "f2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2");
} else {
return builder->Call(str.str(), "f2");
}
case BuiltinType::kLdexp: case BuiltinType::kLdexp:
return builder->Call(str.str(), "f2", "i2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "i2");
} else {
return builder->Call(str.str(), "f2", "i2");
}
case BuiltinType::kAtan2: case BuiltinType::kAtan2:
case BuiltinType::kDot: case BuiltinType::kDot:
case BuiltinType::kDistance: case BuiltinType::kDistance:
case BuiltinType::kPow: case BuiltinType::kPow:
case BuiltinType::kReflect: case BuiltinType::kReflect:
case BuiltinType::kStep: case BuiltinType::kStep:
return builder->Call(str.str(), "f2", "f2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2");
} else {
return builder->Call(str.str(), "f2", "f2");
}
case BuiltinType::kCross: case BuiltinType::kCross:
return builder->Call(str.str(), "f3", "f3"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h3", "h3");
} else {
return builder->Call(str.str(), "f3", "f3");
}
case BuiltinType::kFma: case BuiltinType::kFma:
case BuiltinType::kMix: case BuiltinType::kMix:
case BuiltinType::kFaceForward: case BuiltinType::kFaceForward:
case BuiltinType::kSmoothstep: case BuiltinType::kSmoothstep:
return builder->Call(str.str(), "f2", "f2", "f2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2", "h2");
} else {
return builder->Call(str.str(), "f2", "f2", "f2");
}
case BuiltinType::kAll: case BuiltinType::kAll:
case BuiltinType::kAny: case BuiltinType::kAny:
return builder->Call(str.str(), "b2"); return builder->Call(str.str(), "b2");
case BuiltinType::kAbs: case BuiltinType::kAbs:
if (type == ParamType::kF32) { if (type == CallParamType::kF32) {
return builder->Call(str.str(), "f2"); return builder->Call(str.str(), "f2");
} else if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2");
} else { } else {
return builder->Call(str.str(), "u2"); return builder->Call(str.str(), "u2");
} }
@ -126,32 +152,58 @@ const ast::CallExpression* GenerateCall(BuiltinType builtin,
return builder->Call(str.str(), "u2"); return builder->Call(str.str(), "u2");
case BuiltinType::kMax: case BuiltinType::kMax:
case BuiltinType::kMin: case BuiltinType::kMin:
if (type == ParamType::kF32) { if (type == CallParamType::kF32) {
return builder->Call(str.str(), "f2", "f2"); return builder->Call(str.str(), "f2", "f2");
} else if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2");
} else { } else {
return builder->Call(str.str(), "u2", "u2"); return builder->Call(str.str(), "u2", "u2");
} }
case BuiltinType::kClamp: case BuiltinType::kClamp:
if (type == ParamType::kF32) { if (type == CallParamType::kF32) {
return builder->Call(str.str(), "f2", "f2", "f2"); return builder->Call(str.str(), "f2", "f2", "f2");
} else if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2", "h2");
} else { } else {
return builder->Call(str.str(), "u2", "u2", "u2"); return builder->Call(str.str(), "u2", "u2", "u2");
} }
case BuiltinType::kSelect: case BuiltinType::kSelect:
return builder->Call(str.str(), "f2", "f2", "b2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2", "b2");
} else {
return builder->Call(str.str(), "f2", "f2", "b2");
}
case BuiltinType::kDeterminant: case BuiltinType::kDeterminant:
return builder->Call(str.str(), "m2x2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "hm2x2");
} else {
return builder->Call(str.str(), "m2x2");
}
case BuiltinType::kTranspose: case BuiltinType::kTranspose:
return builder->Call(str.str(), "m3x2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "hm3x2");
} else {
return builder->Call(str.str(), "m3x2");
}
default: default:
break; break;
} }
return nullptr; return nullptr;
} }
using HlslBuiltinTest = TestParamHelper<BuiltinData>; using HlslBuiltinTest = TestParamHelper<BuiltinData>;
TEST_P(HlslBuiltinTest, Emit) { TEST_P(HlslBuiltinTest, Emit) {
auto param = GetParam(); auto param = GetParam();
if (param.type == CallParamType::kF16) {
Enable(ast::Extension::kF16);
GlobalVar("h2", ty.vec2<f16>(), ast::StorageClass::kPrivate);
GlobalVar("h3", ty.vec3<f16>(), ast::StorageClass::kPrivate);
GlobalVar("hm2x2", ty.mat2x2<f16>(), ast::StorageClass::kPrivate);
GlobalVar("hm3x2", ty.mat3x2<f16>(), ast::StorageClass::kPrivate);
}
GlobalVar("f2", ty.vec2<f32>(), ast::StorageClass::kPrivate); GlobalVar("f2", ty.vec2<f32>(), ast::StorageClass::kPrivate);
GlobalVar("f3", ty.vec3<f32>(), ast::StorageClass::kPrivate); GlobalVar("f3", ty.vec3<f32>(), ast::StorageClass::kPrivate);
GlobalVar("u2", ty.vec2<u32>(), ast::StorageClass::kPrivate); GlobalVar("u2", ty.vec2<u32>(), ast::StorageClass::kPrivate);
@ -179,64 +231,110 @@ TEST_P(HlslBuiltinTest, Emit) {
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
HlslGeneratorImplTest_Builtin, HlslGeneratorImplTest_Builtin,
HlslBuiltinTest, HlslBuiltinTest,
testing::Values(BuiltinData{BuiltinType::kAbs, ParamType::kF32, "abs"}, testing::Values(/* Logical built-in */
BuiltinData{BuiltinType::kAbs, ParamType::kU32, "abs"}, BuiltinData{BuiltinType::kAll, CallParamType::kBool, "all"},
BuiltinData{BuiltinType::kAcos, ParamType::kF32, "acos"}, BuiltinData{BuiltinType::kAny, CallParamType::kBool, "any"},
BuiltinData{BuiltinType::kAll, ParamType::kBool, "all"}, /* Float built-in */
BuiltinData{BuiltinType::kAny, ParamType::kBool, "any"}, BuiltinData{BuiltinType::kAbs, CallParamType::kF32, "abs"},
BuiltinData{BuiltinType::kAsin, ParamType::kF32, "asin"}, BuiltinData{BuiltinType::kAbs, CallParamType::kF16, "abs"},
BuiltinData{BuiltinType::kAtan, ParamType::kF32, "atan"}, BuiltinData{BuiltinType::kAcos, CallParamType::kF32, "acos"},
BuiltinData{BuiltinType::kAtan2, ParamType::kF32, "atan2"}, BuiltinData{BuiltinType::kAcos, CallParamType::kF16, "acos"},
BuiltinData{BuiltinType::kCeil, ParamType::kF32, "ceil"}, BuiltinData{BuiltinType::kAsin, CallParamType::kF32, "asin"},
BuiltinData{BuiltinType::kClamp, ParamType::kF32, "clamp"}, BuiltinData{BuiltinType::kAsin, CallParamType::kF16, "asin"},
BuiltinData{BuiltinType::kClamp, ParamType::kU32, "clamp"}, BuiltinData{BuiltinType::kAtan, CallParamType::kF32, "atan"},
BuiltinData{BuiltinType::kCos, ParamType::kF32, "cos"}, BuiltinData{BuiltinType::kAtan, CallParamType::kF16, "atan"},
BuiltinData{BuiltinType::kCosh, ParamType::kF32, "cosh"}, BuiltinData{BuiltinType::kAtan2, CallParamType::kF32, "atan2"},
BuiltinData{BuiltinType::kCountOneBits, ParamType::kU32, "countbits"}, BuiltinData{BuiltinType::kAtan2, CallParamType::kF16, "atan2"},
BuiltinData{BuiltinType::kCross, ParamType::kF32, "cross"}, BuiltinData{BuiltinType::kCeil, CallParamType::kF32, "ceil"},
BuiltinData{BuiltinType::kDeterminant, ParamType::kF32, "determinant"}, BuiltinData{BuiltinType::kCeil, CallParamType::kF16, "ceil"},
BuiltinData{BuiltinType::kDistance, ParamType::kF32, "distance"}, BuiltinData{BuiltinType::kClamp, CallParamType::kF32, "clamp"},
BuiltinData{BuiltinType::kDot, ParamType::kF32, "dot"}, BuiltinData{BuiltinType::kClamp, CallParamType::kF16, "clamp"},
BuiltinData{BuiltinType::kDpdx, ParamType::kF32, "ddx"}, BuiltinData{BuiltinType::kCos, CallParamType::kF32, "cos"},
BuiltinData{BuiltinType::kDpdxCoarse, ParamType::kF32, "ddx_coarse"}, BuiltinData{BuiltinType::kCos, CallParamType::kF16, "cos"},
BuiltinData{BuiltinType::kDpdxFine, ParamType::kF32, "ddx_fine"}, BuiltinData{BuiltinType::kCosh, CallParamType::kF32, "cosh"},
BuiltinData{BuiltinType::kDpdy, ParamType::kF32, "ddy"}, BuiltinData{BuiltinType::kCosh, CallParamType::kF16, "cosh"},
BuiltinData{BuiltinType::kDpdyCoarse, ParamType::kF32, "ddy_coarse"}, BuiltinData{BuiltinType::kCross, CallParamType::kF32, "cross"},
BuiltinData{BuiltinType::kDpdyFine, ParamType::kF32, "ddy_fine"}, BuiltinData{BuiltinType::kCross, CallParamType::kF16, "cross"},
BuiltinData{BuiltinType::kExp, ParamType::kF32, "exp"}, BuiltinData{BuiltinType::kDistance, CallParamType::kF32, "distance"},
BuiltinData{BuiltinType::kExp2, ParamType::kF32, "exp2"}, BuiltinData{BuiltinType::kDistance, CallParamType::kF16, "distance"},
BuiltinData{BuiltinType::kFaceForward, ParamType::kF32, "faceforward"}, BuiltinData{BuiltinType::kExp, CallParamType::kF32, "exp"},
BuiltinData{BuiltinType::kFloor, ParamType::kF32, "floor"}, BuiltinData{BuiltinType::kExp, CallParamType::kF16, "exp"},
BuiltinData{BuiltinType::kFma, ParamType::kF32, "mad"}, BuiltinData{BuiltinType::kExp2, CallParamType::kF32, "exp2"},
BuiltinData{BuiltinType::kFract, ParamType::kF32, "frac"}, BuiltinData{BuiltinType::kExp2, CallParamType::kF16, "exp2"},
BuiltinData{BuiltinType::kFwidth, ParamType::kF32, "fwidth"}, BuiltinData{BuiltinType::kFaceForward, CallParamType::kF32, "faceforward"},
BuiltinData{BuiltinType::kFwidthCoarse, ParamType::kF32, "fwidth"}, BuiltinData{BuiltinType::kFaceForward, CallParamType::kF16, "faceforward"},
BuiltinData{BuiltinType::kFwidthFine, ParamType::kF32, "fwidth"}, BuiltinData{BuiltinType::kFloor, CallParamType::kF32, "floor"},
BuiltinData{BuiltinType::kInverseSqrt, ParamType::kF32, "rsqrt"}, BuiltinData{BuiltinType::kFloor, CallParamType::kF16, "floor"},
BuiltinData{BuiltinType::kLdexp, ParamType::kF32, "ldexp"}, BuiltinData{BuiltinType::kFma, CallParamType::kF32, "mad"},
BuiltinData{BuiltinType::kLength, ParamType::kF32, "length"}, BuiltinData{BuiltinType::kFma, CallParamType::kF16, "mad"},
BuiltinData{BuiltinType::kLog, ParamType::kF32, "log"}, BuiltinData{BuiltinType::kFract, CallParamType::kF32, "frac"},
BuiltinData{BuiltinType::kLog2, ParamType::kF32, "log2"}, BuiltinData{BuiltinType::kFract, CallParamType::kF16, "frac"},
BuiltinData{BuiltinType::kMax, ParamType::kF32, "max"}, BuiltinData{BuiltinType::kInverseSqrt, CallParamType::kF32, "rsqrt"},
BuiltinData{BuiltinType::kMax, ParamType::kU32, "max"}, BuiltinData{BuiltinType::kInverseSqrt, CallParamType::kF16, "rsqrt"},
BuiltinData{BuiltinType::kMin, ParamType::kF32, "min"}, BuiltinData{BuiltinType::kLdexp, CallParamType::kF32, "ldexp"},
BuiltinData{BuiltinType::kMin, ParamType::kU32, "min"}, BuiltinData{BuiltinType::kLdexp, CallParamType::kF16, "ldexp"},
BuiltinData{BuiltinType::kMix, ParamType::kF32, "lerp"}, BuiltinData{BuiltinType::kLength, CallParamType::kF32, "length"},
BuiltinData{BuiltinType::kNormalize, ParamType::kF32, "normalize"}, BuiltinData{BuiltinType::kLength, CallParamType::kF16, "length"},
BuiltinData{BuiltinType::kPow, ParamType::kF32, "pow"}, BuiltinData{BuiltinType::kLog, CallParamType::kF32, "log"},
BuiltinData{BuiltinType::kReflect, ParamType::kF32, "reflect"}, BuiltinData{BuiltinType::kLog, CallParamType::kF16, "log"},
BuiltinData{BuiltinType::kReverseBits, ParamType::kU32, "reversebits"}, BuiltinData{BuiltinType::kLog2, CallParamType::kF32, "log2"},
BuiltinData{BuiltinType::kRound, ParamType::kU32, "round"}, BuiltinData{BuiltinType::kLog2, CallParamType::kF16, "log2"},
BuiltinData{BuiltinType::kSign, ParamType::kF32, "sign"}, BuiltinData{BuiltinType::kMax, CallParamType::kF32, "max"},
BuiltinData{BuiltinType::kSin, ParamType::kF32, "sin"}, BuiltinData{BuiltinType::kMax, CallParamType::kF16, "max"},
BuiltinData{BuiltinType::kSinh, ParamType::kF32, "sinh"}, BuiltinData{BuiltinType::kMin, CallParamType::kF32, "min"},
BuiltinData{BuiltinType::kSmoothstep, ParamType::kF32, "smoothstep"}, BuiltinData{BuiltinType::kMin, CallParamType::kF16, "min"},
BuiltinData{BuiltinType::kSqrt, ParamType::kF32, "sqrt"}, BuiltinData{BuiltinType::kMix, CallParamType::kF32, "lerp"},
BuiltinData{BuiltinType::kStep, ParamType::kF32, "step"}, BuiltinData{BuiltinType::kMix, CallParamType::kF16, "lerp"},
BuiltinData{BuiltinType::kTan, ParamType::kF32, "tan"}, BuiltinData{BuiltinType::kNormalize, CallParamType::kF32, "normalize"},
BuiltinData{BuiltinType::kTanh, ParamType::kF32, "tanh"}, BuiltinData{BuiltinType::kNormalize, CallParamType::kF16, "normalize"},
BuiltinData{BuiltinType::kTranspose, ParamType::kF32, "transpose"}, BuiltinData{BuiltinType::kPow, CallParamType::kF32, "pow"},
BuiltinData{BuiltinType::kTrunc, ParamType::kF32, "trunc"})); BuiltinData{BuiltinType::kPow, CallParamType::kF16, "pow"},
BuiltinData{BuiltinType::kReflect, CallParamType::kF32, "reflect"},
BuiltinData{BuiltinType::kReflect, CallParamType::kF16, "reflect"},
BuiltinData{BuiltinType::kSign, CallParamType::kF32, "sign"},
BuiltinData{BuiltinType::kSign, CallParamType::kF16, "sign"},
BuiltinData{BuiltinType::kSin, CallParamType::kF32, "sin"},
BuiltinData{BuiltinType::kSin, CallParamType::kF16, "sin"},
BuiltinData{BuiltinType::kSinh, CallParamType::kF32, "sinh"},
BuiltinData{BuiltinType::kSinh, CallParamType::kF16, "sinh"},
BuiltinData{BuiltinType::kSmoothstep, CallParamType::kF32, "smoothstep"},
BuiltinData{BuiltinType::kSmoothstep, CallParamType::kF16, "smoothstep"},
BuiltinData{BuiltinType::kSqrt, CallParamType::kF32, "sqrt"},
BuiltinData{BuiltinType::kSqrt, CallParamType::kF16, "sqrt"},
BuiltinData{BuiltinType::kStep, CallParamType::kF32, "step"},
BuiltinData{BuiltinType::kStep, CallParamType::kF16, "step"},
BuiltinData{BuiltinType::kTan, CallParamType::kF32, "tan"},
BuiltinData{BuiltinType::kTan, CallParamType::kF16, "tan"},
BuiltinData{BuiltinType::kTanh, CallParamType::kF32, "tanh"},
BuiltinData{BuiltinType::kTanh, CallParamType::kF16, "tanh"},
BuiltinData{BuiltinType::kTrunc, CallParamType::kF32, "trunc"},
BuiltinData{BuiltinType::kTrunc, CallParamType::kF16, "trunc"},
/* Integer built-in */
BuiltinData{BuiltinType::kAbs, CallParamType::kU32, "abs"},
BuiltinData{BuiltinType::kClamp, CallParamType::kU32, "clamp"},
BuiltinData{BuiltinType::kCountOneBits, CallParamType::kU32, "countbits"},
BuiltinData{BuiltinType::kMax, CallParamType::kU32, "max"},
BuiltinData{BuiltinType::kMin, CallParamType::kU32, "min"},
BuiltinData{BuiltinType::kReverseBits, CallParamType::kU32, "reversebits"},
BuiltinData{BuiltinType::kRound, CallParamType::kU32, "round"},
/* Matrix built-in */
BuiltinData{BuiltinType::kDeterminant, CallParamType::kF32, "determinant"},
BuiltinData{BuiltinType::kDeterminant, CallParamType::kF16, "determinant"},
BuiltinData{BuiltinType::kTranspose, CallParamType::kF32, "transpose"},
BuiltinData{BuiltinType::kTranspose, CallParamType::kF16, "transpose"},
/* Vector built-in */
BuiltinData{BuiltinType::kDot, CallParamType::kF32, "dot"},
BuiltinData{BuiltinType::kDot, CallParamType::kF16, "dot"},
/* Derivate built-in */
BuiltinData{BuiltinType::kDpdx, CallParamType::kF32, "ddx"},
BuiltinData{BuiltinType::kDpdxCoarse, CallParamType::kF32, "ddx_coarse"},
BuiltinData{BuiltinType::kDpdxFine, CallParamType::kF32, "ddx_fine"},
BuiltinData{BuiltinType::kDpdy, CallParamType::kF32, "ddy"},
BuiltinData{BuiltinType::kDpdyCoarse, CallParamType::kF32, "ddy_coarse"},
BuiltinData{BuiltinType::kDpdyFine, CallParamType::kF32, "ddy_fine"},
BuiltinData{BuiltinType::kFwidth, CallParamType::kF32, "fwidth"},
BuiltinData{BuiltinType::kFwidthCoarse, CallParamType::kF32, "fwidth"},
BuiltinData{BuiltinType::kFwidthFine, CallParamType::kF32, "fwidth"}));
TEST_F(HlslGeneratorImplTest_Builtin, Builtin_Call) { TEST_F(HlslGeneratorImplTest_Builtin, Builtin_Call) {
auto* call = Call("dot", "param1", "param2"); auto* call = Call("dot", "param1", "param2");
@ -380,7 +478,7 @@ void test_function() {
)"); )");
} }
TEST_F(HlslGeneratorImplTest_Builtin, Degrees_Scalar) { TEST_F(HlslGeneratorImplTest_Builtin, Degrees_Scalar_f32) {
auto* val = Var("val", ty.f32()); auto* val = Var("val", ty.f32());
auto* call = Call("degrees", val); auto* call = Call("degrees", val);
WrapInFunction(val, call); WrapInFunction(val, call);
@ -401,7 +499,7 @@ void test_function() {
)"); )");
} }
TEST_F(HlslGeneratorImplTest_Builtin, Degrees_Vector) { TEST_F(HlslGeneratorImplTest_Builtin, Degrees_Vector_f32) {
auto* val = Var("val", ty.vec3<f32>()); auto* val = Var("val", ty.vec3<f32>());
auto* call = Call("degrees", val); auto* call = Call("degrees", val);
WrapInFunction(val, call); WrapInFunction(val, call);
@ -422,7 +520,53 @@ void test_function() {
)"); )");
} }
TEST_F(HlslGeneratorImplTest_Builtin, Radians_Scalar) { TEST_F(HlslGeneratorImplTest_Builtin, Degrees_Scalar_f16) {
Enable(ast::Extension::kF16);
auto* val = Var("val", ty.f16());
auto* call = Call("degrees", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(float16_t tint_degrees(float16_t param_0) {
return param_0 * 57.295779513082322865;
}
[numthreads(1, 1, 1)]
void test_function() {
float16_t val = float16_t(0.0h);
const float16_t tint_symbol = tint_degrees(val);
return;
}
)");
}
TEST_F(HlslGeneratorImplTest_Builtin, Degrees_Vector_f16) {
Enable(ast::Extension::kF16);
auto* val = Var("val", ty.vec3<f16>());
auto* call = Call("degrees", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(vector<float16_t, 3> tint_degrees(vector<float16_t, 3> param_0) {
return param_0 * 57.295779513082322865;
}
[numthreads(1, 1, 1)]
void test_function() {
vector<float16_t, 3> val = vector<float16_t, 3>(float16_t(0.0h), float16_t(0.0h), float16_t(0.0h));
const vector<float16_t, 3> tint_symbol = tint_degrees(val);
return;
}
)");
}
TEST_F(HlslGeneratorImplTest_Builtin, Radians_Scalar_f32) {
auto* val = Var("val", ty.f32()); auto* val = Var("val", ty.f32());
auto* call = Call("radians", val); auto* call = Call("radians", val);
WrapInFunction(val, call); WrapInFunction(val, call);
@ -443,7 +587,7 @@ void test_function() {
)"); )");
} }
TEST_F(HlslGeneratorImplTest_Builtin, Radians_Vector) { TEST_F(HlslGeneratorImplTest_Builtin, Radians_Vector_f32) {
auto* val = Var("val", ty.vec3<f32>()); auto* val = Var("val", ty.vec3<f32>());
auto* call = Call("radians", val); auto* call = Call("radians", val);
WrapInFunction(val, call); WrapInFunction(val, call);
@ -464,6 +608,52 @@ void test_function() {
)"); )");
} }
TEST_F(HlslGeneratorImplTest_Builtin, Radians_Scalar_f16) {
Enable(ast::Extension::kF16);
auto* val = Var("val", ty.f16());
auto* call = Call("radians", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(float16_t tint_radians(float16_t param_0) {
return param_0 * 0.017453292519943295474;
}
[numthreads(1, 1, 1)]
void test_function() {
float16_t val = float16_t(0.0h);
const float16_t tint_symbol = tint_radians(val);
return;
}
)");
}
TEST_F(HlslGeneratorImplTest_Builtin, Radians_Vector_f16) {
Enable(ast::Extension::kF16);
auto* val = Var("val", ty.vec3<f16>());
auto* call = Call("radians", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(vector<float16_t, 3> tint_radians(vector<float16_t, 3> param_0) {
return param_0 * 0.017453292519943295474;
}
[numthreads(1, 1, 1)]
void test_function() {
vector<float16_t, 3> val = vector<float16_t, 3>(float16_t(0.0h), float16_t(0.0h), float16_t(0.0h));
const vector<float16_t, 3> tint_symbol = tint_radians(val);
return;
}
)");
}
TEST_F(HlslGeneratorImplTest_Builtin, Pack4x8Snorm) { TEST_F(HlslGeneratorImplTest_Builtin, Pack4x8Snorm) {
auto* call = Call("pack4x8snorm", "p1"); auto* call = Call("pack4x8snorm", "p1");
GlobalVar("p1", ty.vec4<f32>(), ast::StorageClass::kPrivate); GlobalVar("p1", ty.vec4<f32>(), ast::StorageClass::kPrivate);

View File

@ -25,36 +25,40 @@ using BuiltinType = sem::BuiltinType;
using MslGeneratorImplTest = TestHelper; using MslGeneratorImplTest = TestHelper;
enum class ParamType { enum class CallParamType {
kF32, kF32,
kU32, kU32,
kBool, kBool,
kF16,
}; };
struct BuiltinData { struct BuiltinData {
BuiltinType builtin; BuiltinType builtin;
ParamType type; CallParamType type;
const char* msl_name; const char* msl_name;
}; };
inline std::ostream& operator<<(std::ostream& out, BuiltinData data) { inline std::ostream& operator<<(std::ostream& out, BuiltinData data) {
out << data.msl_name << "<"; out << data.msl_name << "<";
switch (data.type) { switch (data.type) {
case ParamType::kF32: case CallParamType::kF32:
out << "f32"; out << "f32";
break; break;
case ParamType::kU32: case CallParamType::kU32:
out << "u32"; out << "u32";
break; break;
case ParamType::kBool: case CallParamType::kBool:
out << "bool"; out << "bool";
break; break;
case CallParamType::kF16:
out << "f16";
break;
} }
out << ">"; out << ">";
return out; return out;
} }
const ast::CallExpression* GenerateCall(BuiltinType builtin, const ast::CallExpression* GenerateCall(BuiltinType builtin,
ParamType type, CallParamType type,
ProgramBuilder* builder) { ProgramBuilder* builder) {
std::string name; std::string name;
std::ostringstream str(name); std::ostringstream str(name);
@ -92,31 +96,53 @@ const ast::CallExpression* GenerateCall(BuiltinType builtin,
case BuiltinType::kTanh: case BuiltinType::kTanh:
case BuiltinType::kTrunc: case BuiltinType::kTrunc:
case BuiltinType::kSign: case BuiltinType::kSign:
return builder->Call(str.str(), "f2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2");
} else {
return builder->Call(str.str(), "f2");
}
case BuiltinType::kLdexp: case BuiltinType::kLdexp:
return builder->Call(str.str(), "f2", "i2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "i2");
} else {
return builder->Call(str.str(), "f2", "i2");
}
case BuiltinType::kAtan2: case BuiltinType::kAtan2:
case BuiltinType::kDot: case BuiltinType::kDot:
case BuiltinType::kDistance: case BuiltinType::kDistance:
case BuiltinType::kPow: case BuiltinType::kPow:
case BuiltinType::kReflect: case BuiltinType::kReflect:
case BuiltinType::kStep: case BuiltinType::kStep:
return builder->Call(str.str(), "f2", "f2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2");
} else {
return builder->Call(str.str(), "f2", "f2");
}
case BuiltinType::kStorageBarrier: case BuiltinType::kStorageBarrier:
return builder->Call(str.str()); return builder->Call(str.str());
case BuiltinType::kCross: case BuiltinType::kCross:
return builder->Call(str.str(), "f3", "f3"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h3", "h3");
} else {
return builder->Call(str.str(), "f3", "f3");
}
case BuiltinType::kFma: case BuiltinType::kFma:
case BuiltinType::kMix: case BuiltinType::kMix:
case BuiltinType::kFaceForward: case BuiltinType::kFaceForward:
case BuiltinType::kSmoothstep: case BuiltinType::kSmoothstep:
return builder->Call(str.str(), "f2", "f2", "f2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2", "h2");
} else {
return builder->Call(str.str(), "f2", "f2", "f2");
}
case BuiltinType::kAll: case BuiltinType::kAll:
case BuiltinType::kAny: case BuiltinType::kAny:
return builder->Call(str.str(), "b2"); return builder->Call(str.str(), "b2");
case BuiltinType::kAbs: case BuiltinType::kAbs:
if (type == ParamType::kF32) { if (type == CallParamType::kF32) {
return builder->Call(str.str(), "f2"); return builder->Call(str.str(), "f2");
} else if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2");
} else { } else {
return builder->Call(str.str(), "u2"); return builder->Call(str.str(), "u2");
} }
@ -131,21 +157,33 @@ const ast::CallExpression* GenerateCall(BuiltinType builtin,
return builder->Call(str.str(), "u2", "u2", "u1", "u1"); return builder->Call(str.str(), "u2", "u2", "u1", "u1");
case BuiltinType::kMax: case BuiltinType::kMax:
case BuiltinType::kMin: case BuiltinType::kMin:
if (type == ParamType::kF32) { if (type == CallParamType::kF32) {
return builder->Call(str.str(), "f2", "f2"); return builder->Call(str.str(), "f2", "f2");
} else if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2");
} else { } else {
return builder->Call(str.str(), "u2", "u2"); return builder->Call(str.str(), "u2", "u2");
} }
case BuiltinType::kClamp: case BuiltinType::kClamp:
if (type == ParamType::kF32) { if (type == CallParamType::kF32) {
return builder->Call(str.str(), "f2", "f2", "f2"); return builder->Call(str.str(), "f2", "f2", "f2");
} else if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2", "h2");
} else { } else {
return builder->Call(str.str(), "u2", "u2", "u2"); return builder->Call(str.str(), "u2", "u2", "u2");
} }
case BuiltinType::kSelect: case BuiltinType::kSelect:
return builder->Call(str.str(), "f2", "f2", "b2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "h2", "h2", "b2");
} else {
return builder->Call(str.str(), "f2", "f2", "b2");
}
case BuiltinType::kDeterminant: case BuiltinType::kDeterminant:
return builder->Call(str.str(), "m2x2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "hm2x2");
} else {
return builder->Call(str.str(), "m2x2");
}
case BuiltinType::kPack2x16snorm: case BuiltinType::kPack2x16snorm:
case BuiltinType::kPack2x16unorm: case BuiltinType::kPack2x16unorm:
return builder->Call(str.str(), "f2"); return builder->Call(str.str(), "f2");
@ -160,7 +198,11 @@ const ast::CallExpression* GenerateCall(BuiltinType builtin,
case BuiltinType::kWorkgroupBarrier: case BuiltinType::kWorkgroupBarrier:
return builder->Call(str.str()); return builder->Call(str.str());
case BuiltinType::kTranspose: case BuiltinType::kTranspose:
return builder->Call(str.str(), "m3x2"); if (type == CallParamType::kF16) {
return builder->Call(str.str(), "hm3x2");
} else {
return builder->Call(str.str(), "m3x2");
}
default: default:
break; break;
} }
@ -171,6 +213,15 @@ using MslBuiltinTest = TestParamHelper<BuiltinData>;
TEST_P(MslBuiltinTest, Emit) { TEST_P(MslBuiltinTest, Emit) {
auto param = GetParam(); auto param = GetParam();
if (param.type == CallParamType::kF16) {
Enable(ast::Extension::kF16);
GlobalVar("h2", ty.vec2<f16>(), ast::StorageClass::kPrivate);
GlobalVar("h3", ty.vec3<f16>(), ast::StorageClass::kPrivate);
GlobalVar("hm2x2", ty.mat2x2<f16>(), ast::StorageClass::kPrivate);
GlobalVar("hm3x2", ty.mat3x2<f16>(), ast::StorageClass::kPrivate);
}
GlobalVar("f2", ty.vec2<f32>(), ast::StorageClass::kPrivate); GlobalVar("f2", ty.vec2<f32>(), ast::StorageClass::kPrivate);
GlobalVar("f3", ty.vec3<f32>(), ast::StorageClass::kPrivate); GlobalVar("f3", ty.vec3<f32>(), ast::StorageClass::kPrivate);
GlobalVar("f4", ty.vec4<f32>(), ast::StorageClass::kPrivate); GlobalVar("f4", ty.vec4<f32>(), ast::StorageClass::kPrivate);
@ -201,76 +252,122 @@ INSTANTIATE_TEST_SUITE_P(
MslGeneratorImplTest, MslGeneratorImplTest,
MslBuiltinTest, MslBuiltinTest,
testing::Values( testing::Values(
BuiltinData{BuiltinType::kAbs, ParamType::kF32, "fabs"}, /* Logical built-in */
BuiltinData{BuiltinType::kAbs, ParamType::kU32, "abs"}, BuiltinData{BuiltinType::kAll, CallParamType::kBool, "all"},
BuiltinData{BuiltinType::kAcos, ParamType::kF32, "acos"}, BuiltinData{BuiltinType::kAny, CallParamType::kBool, "any"},
BuiltinData{BuiltinType::kAll, ParamType::kBool, "all"}, BuiltinData{BuiltinType::kSelect, CallParamType::kF32, "select"},
BuiltinData{BuiltinType::kAny, ParamType::kBool, "any"}, /* Float built-in */
BuiltinData{BuiltinType::kAsin, ParamType::kF32, "asin"}, BuiltinData{BuiltinType::kAbs, CallParamType::kF32, "fabs"},
BuiltinData{BuiltinType::kAtan, ParamType::kF32, "atan"}, BuiltinData{BuiltinType::kAbs, CallParamType::kF16, "fabs"},
BuiltinData{BuiltinType::kAtan2, ParamType::kF32, "atan2"}, BuiltinData{BuiltinType::kAcos, CallParamType::kF32, "acos"},
BuiltinData{BuiltinType::kCeil, ParamType::kF32, "ceil"}, BuiltinData{BuiltinType::kAcos, CallParamType::kF16, "acos"},
BuiltinData{BuiltinType::kClamp, ParamType::kF32, "clamp"}, BuiltinData{BuiltinType::kAsin, CallParamType::kF32, "asin"},
BuiltinData{BuiltinType::kClamp, ParamType::kU32, "clamp"}, BuiltinData{BuiltinType::kAsin, CallParamType::kF16, "asin"},
BuiltinData{BuiltinType::kCos, ParamType::kF32, "cos"}, BuiltinData{BuiltinType::kAtan, CallParamType::kF32, "atan"},
BuiltinData{BuiltinType::kCosh, ParamType::kF32, "cosh"}, BuiltinData{BuiltinType::kAtan, CallParamType::kF16, "atan"},
BuiltinData{BuiltinType::kCountLeadingZeros, ParamType::kU32, "clz"}, BuiltinData{BuiltinType::kAtan2, CallParamType::kF32, "atan2"},
BuiltinData{BuiltinType::kCountOneBits, ParamType::kU32, "popcount"}, BuiltinData{BuiltinType::kAtan2, CallParamType::kF16, "atan2"},
BuiltinData{BuiltinType::kCountTrailingZeros, ParamType::kU32, "ctz"}, BuiltinData{BuiltinType::kCeil, CallParamType::kF32, "ceil"},
BuiltinData{BuiltinType::kCross, ParamType::kF32, "cross"}, BuiltinData{BuiltinType::kCeil, CallParamType::kF16, "ceil"},
BuiltinData{BuiltinType::kDeterminant, ParamType::kF32, "determinant"}, BuiltinData{BuiltinType::kClamp, CallParamType::kF32, "clamp"},
BuiltinData{BuiltinType::kDistance, ParamType::kF32, "distance"}, BuiltinData{BuiltinType::kClamp, CallParamType::kF16, "clamp"},
BuiltinData{BuiltinType::kDot, ParamType::kF32, "dot"}, BuiltinData{BuiltinType::kCos, CallParamType::kF32, "cos"},
BuiltinData{BuiltinType::kDpdx, ParamType::kF32, "dfdx"}, BuiltinData{BuiltinType::kCos, CallParamType::kF16, "cos"},
BuiltinData{BuiltinType::kDpdxCoarse, ParamType::kF32, "dfdx"}, BuiltinData{BuiltinType::kCosh, CallParamType::kF32, "cosh"},
BuiltinData{BuiltinType::kDpdxFine, ParamType::kF32, "dfdx"}, BuiltinData{BuiltinType::kCosh, CallParamType::kF16, "cosh"},
BuiltinData{BuiltinType::kDpdy, ParamType::kF32, "dfdy"}, BuiltinData{BuiltinType::kCross, CallParamType::kF32, "cross"},
BuiltinData{BuiltinType::kDpdyCoarse, ParamType::kF32, "dfdy"}, BuiltinData{BuiltinType::kCross, CallParamType::kF16, "cross"},
BuiltinData{BuiltinType::kDpdyFine, ParamType::kF32, "dfdy"}, BuiltinData{BuiltinType::kDistance, CallParamType::kF32, "distance"},
BuiltinData{BuiltinType::kExp, ParamType::kF32, "exp"}, BuiltinData{BuiltinType::kDistance, CallParamType::kF16, "distance"},
BuiltinData{BuiltinType::kExp2, ParamType::kF32, "exp2"}, BuiltinData{BuiltinType::kExp, CallParamType::kF32, "exp"},
BuiltinData{BuiltinType::kExtractBits, ParamType::kU32, "extract_bits"}, BuiltinData{BuiltinType::kExp, CallParamType::kF16, "exp"},
BuiltinData{BuiltinType::kFaceForward, ParamType::kF32, "faceforward"}, BuiltinData{BuiltinType::kExp2, CallParamType::kF32, "exp2"},
BuiltinData{BuiltinType::kFloor, ParamType::kF32, "floor"}, BuiltinData{BuiltinType::kExp2, CallParamType::kF16, "exp2"},
BuiltinData{BuiltinType::kFma, ParamType::kF32, "fma"}, BuiltinData{BuiltinType::kFaceForward, CallParamType::kF32, "faceforward"},
BuiltinData{BuiltinType::kFract, ParamType::kF32, "fract"}, BuiltinData{BuiltinType::kFaceForward, CallParamType::kF16, "faceforward"},
BuiltinData{BuiltinType::kFwidth, ParamType::kF32, "fwidth"}, BuiltinData{BuiltinType::kFloor, CallParamType::kF32, "floor"},
BuiltinData{BuiltinType::kFwidthCoarse, ParamType::kF32, "fwidth"}, BuiltinData{BuiltinType::kFloor, CallParamType::kF16, "floor"},
BuiltinData{BuiltinType::kFwidthFine, ParamType::kF32, "fwidth"}, BuiltinData{BuiltinType::kFma, CallParamType::kF32, "fma"},
BuiltinData{BuiltinType::kInsertBits, ParamType::kU32, "insert_bits"}, BuiltinData{BuiltinType::kFma, CallParamType::kF16, "fma"},
BuiltinData{BuiltinType::kInverseSqrt, ParamType::kF32, "rsqrt"}, BuiltinData{BuiltinType::kFract, CallParamType::kF32, "fract"},
BuiltinData{BuiltinType::kLdexp, ParamType::kF32, "ldexp"}, BuiltinData{BuiltinType::kFract, CallParamType::kF16, "fract"},
BuiltinData{BuiltinType::kLength, ParamType::kF32, "length"}, BuiltinData{BuiltinType::kInverseSqrt, CallParamType::kF32, "rsqrt"},
BuiltinData{BuiltinType::kLog, ParamType::kF32, "log"}, BuiltinData{BuiltinType::kInverseSqrt, CallParamType::kF16, "rsqrt"},
BuiltinData{BuiltinType::kLog2, ParamType::kF32, "log2"}, BuiltinData{BuiltinType::kLdexp, CallParamType::kF32, "ldexp"},
BuiltinData{BuiltinType::kMax, ParamType::kF32, "fmax"}, BuiltinData{BuiltinType::kLdexp, CallParamType::kF16, "ldexp"},
BuiltinData{BuiltinType::kMax, ParamType::kU32, "max"}, BuiltinData{BuiltinType::kLength, CallParamType::kF32, "length"},
BuiltinData{BuiltinType::kMin, ParamType::kF32, "fmin"}, BuiltinData{BuiltinType::kLength, CallParamType::kF16, "length"},
BuiltinData{BuiltinType::kMin, ParamType::kU32, "min"}, BuiltinData{BuiltinType::kLog, CallParamType::kF32, "log"},
BuiltinData{BuiltinType::kNormalize, ParamType::kF32, "normalize"}, BuiltinData{BuiltinType::kLog, CallParamType::kF16, "log"},
BuiltinData{BuiltinType::kPack4x8snorm, ParamType::kF32, "pack_float_to_snorm4x8"}, BuiltinData{BuiltinType::kLog2, CallParamType::kF32, "log2"},
BuiltinData{BuiltinType::kPack4x8unorm, ParamType::kF32, "pack_float_to_unorm4x8"}, BuiltinData{BuiltinType::kLog2, CallParamType::kF16, "log2"},
BuiltinData{BuiltinType::kPack2x16snorm, ParamType::kF32, "pack_float_to_snorm2x16"}, BuiltinData{BuiltinType::kMax, CallParamType::kF32, "fmax"},
BuiltinData{BuiltinType::kPack2x16unorm, ParamType::kF32, "pack_float_to_unorm2x16"}, BuiltinData{BuiltinType::kMax, CallParamType::kF16, "fmax"},
BuiltinData{BuiltinType::kPow, ParamType::kF32, "pow"}, BuiltinData{BuiltinType::kMin, CallParamType::kF32, "fmin"},
BuiltinData{BuiltinType::kReflect, ParamType::kF32, "reflect"}, BuiltinData{BuiltinType::kMin, CallParamType::kF16, "fmin"},
BuiltinData{BuiltinType::kReverseBits, ParamType::kU32, "reverse_bits"}, BuiltinData{BuiltinType::kNormalize, CallParamType::kF32, "normalize"},
BuiltinData{BuiltinType::kRound, ParamType::kU32, "rint"}, BuiltinData{BuiltinType::kNormalize, CallParamType::kF16, "normalize"},
BuiltinData{BuiltinType::kSelect, ParamType::kF32, "select"}, BuiltinData{BuiltinType::kPow, CallParamType::kF32, "pow"},
BuiltinData{BuiltinType::kSign, ParamType::kF32, "sign"}, BuiltinData{BuiltinType::kPow, CallParamType::kF16, "pow"},
BuiltinData{BuiltinType::kSin, ParamType::kF32, "sin"}, BuiltinData{BuiltinType::kReflect, CallParamType::kF32, "reflect"},
BuiltinData{BuiltinType::kSinh, ParamType::kF32, "sinh"}, BuiltinData{BuiltinType::kReflect, CallParamType::kF16, "reflect"},
BuiltinData{BuiltinType::kSmoothstep, ParamType::kF32, "smoothstep"}, BuiltinData{BuiltinType::kSign, CallParamType::kF32, "sign"},
BuiltinData{BuiltinType::kSqrt, ParamType::kF32, "sqrt"}, BuiltinData{BuiltinType::kSign, CallParamType::kF16, "sign"},
BuiltinData{BuiltinType::kStep, ParamType::kF32, "step"}, BuiltinData{BuiltinType::kSin, CallParamType::kF32, "sin"},
BuiltinData{BuiltinType::kTan, ParamType::kF32, "tan"}, BuiltinData{BuiltinType::kSin, CallParamType::kF16, "sin"},
BuiltinData{BuiltinType::kTanh, ParamType::kF32, "tanh"}, BuiltinData{BuiltinType::kSinh, CallParamType::kF32, "sinh"},
BuiltinData{BuiltinType::kTranspose, ParamType::kF32, "transpose"}, BuiltinData{BuiltinType::kSinh, CallParamType::kF16, "sinh"},
BuiltinData{BuiltinType::kTrunc, ParamType::kF32, "trunc"}, BuiltinData{BuiltinType::kSmoothstep, CallParamType::kF32, "smoothstep"},
BuiltinData{BuiltinType::kUnpack4x8snorm, ParamType::kU32, "unpack_snorm4x8_to_float"}, BuiltinData{BuiltinType::kSmoothstep, CallParamType::kF16, "smoothstep"},
BuiltinData{BuiltinType::kUnpack4x8unorm, ParamType::kU32, "unpack_unorm4x8_to_float"}, BuiltinData{BuiltinType::kSqrt, CallParamType::kF32, "sqrt"},
BuiltinData{BuiltinType::kUnpack2x16snorm, ParamType::kU32, "unpack_snorm2x16_to_float"}, BuiltinData{BuiltinType::kSqrt, CallParamType::kF16, "sqrt"},
BuiltinData{BuiltinType::kUnpack2x16unorm, ParamType::kU32, "unpack_unorm2x16_to_float"})); BuiltinData{BuiltinType::kStep, CallParamType::kF32, "step"},
BuiltinData{BuiltinType::kStep, CallParamType::kF16, "step"},
BuiltinData{BuiltinType::kTan, CallParamType::kF32, "tan"},
BuiltinData{BuiltinType::kTan, CallParamType::kF16, "tan"},
BuiltinData{BuiltinType::kTanh, CallParamType::kF32, "tanh"},
BuiltinData{BuiltinType::kTanh, CallParamType::kF16, "tanh"},
BuiltinData{BuiltinType::kTrunc, CallParamType::kF32, "trunc"},
BuiltinData{BuiltinType::kTrunc, CallParamType::kF16, "trunc"},
/* Integer built-in */
BuiltinData{BuiltinType::kAbs, CallParamType::kU32, "abs"},
BuiltinData{BuiltinType::kClamp, CallParamType::kU32, "clamp"},
BuiltinData{BuiltinType::kCountLeadingZeros, CallParamType::kU32, "clz"},
BuiltinData{BuiltinType::kCountOneBits, CallParamType::kU32, "popcount"},
BuiltinData{BuiltinType::kCountTrailingZeros, CallParamType::kU32, "ctz"},
BuiltinData{BuiltinType::kExtractBits, CallParamType::kU32, "extract_bits"},
BuiltinData{BuiltinType::kInsertBits, CallParamType::kU32, "insert_bits"},
BuiltinData{BuiltinType::kMax, CallParamType::kU32, "max"},
BuiltinData{BuiltinType::kMin, CallParamType::kU32, "min"},
BuiltinData{BuiltinType::kReverseBits, CallParamType::kU32, "reverse_bits"},
BuiltinData{BuiltinType::kRound, CallParamType::kU32, "rint"},
/* Matrix built-in */
BuiltinData{BuiltinType::kDeterminant, CallParamType::kF32, "determinant"},
BuiltinData{BuiltinType::kTranspose, CallParamType::kF32, "transpose"},
/* Vector built-in */
BuiltinData{BuiltinType::kDot, CallParamType::kF32, "dot"},
/* Derivate built-in */
BuiltinData{BuiltinType::kDpdx, CallParamType::kF32, "dfdx"},
BuiltinData{BuiltinType::kDpdxCoarse, CallParamType::kF32, "dfdx"},
BuiltinData{BuiltinType::kDpdxFine, CallParamType::kF32, "dfdx"},
BuiltinData{BuiltinType::kDpdy, CallParamType::kF32, "dfdy"},
BuiltinData{BuiltinType::kDpdyCoarse, CallParamType::kF32, "dfdy"},
BuiltinData{BuiltinType::kDpdyFine, CallParamType::kF32, "dfdy"},
BuiltinData{BuiltinType::kFwidth, CallParamType::kF32, "fwidth"},
BuiltinData{BuiltinType::kFwidthCoarse, CallParamType::kF32, "fwidth"},
BuiltinData{BuiltinType::kFwidthFine, CallParamType::kF32, "fwidth"},
/* Data packing builtin */
BuiltinData{BuiltinType::kPack4x8snorm, CallParamType::kF32, "pack_float_to_snorm4x8"},
BuiltinData{BuiltinType::kPack4x8unorm, CallParamType::kF32, "pack_float_to_unorm4x8"},
BuiltinData{BuiltinType::kPack2x16snorm, CallParamType::kF32, "pack_float_to_snorm2x16"},
BuiltinData{BuiltinType::kPack2x16unorm, CallParamType::kF32, "pack_float_to_unorm2x16"},
/* Data unpacking builtin */
BuiltinData{BuiltinType::kUnpack4x8snorm, CallParamType::kU32, "unpack_snorm4x8_to_float"},
BuiltinData{BuiltinType::kUnpack4x8unorm, CallParamType::kU32, "unpack_unorm4x8_to_float"},
BuiltinData{BuiltinType::kUnpack2x16snorm, CallParamType::kU32,
"unpack_snorm2x16_to_float"},
BuiltinData{BuiltinType::kUnpack2x16unorm, CallParamType::kU32,
"unpack_unorm2x16_to_float"}));
TEST_F(MslGeneratorImplTest, Builtin_Call) { TEST_F(MslGeneratorImplTest, Builtin_Call) {
GlobalVar("param1", ty.vec2<f32>(), ast::StorageClass::kPrivate); GlobalVar("param1", ty.vec2<f32>(), ast::StorageClass::kPrivate);
@ -308,12 +405,12 @@ TEST_F(MslGeneratorImplTest, WorkgroupBarrier) {
EXPECT_EQ(out.str(), "threadgroup_barrier(mem_flags::mem_threadgroup)"); EXPECT_EQ(out.str(), "threadgroup_barrier(mem_flags::mem_threadgroup)");
} }
TEST_F(MslGeneratorImplTest, Degrees_Scalar) { TEST_F(MslGeneratorImplTest, Degrees_Scalar_f32) {
auto* val = Var("val", ty.f32()); auto* val = Var("val", ty.f32());
auto* call = Call("degrees", val); auto* call = Call("degrees", val);
WrapInFunction(val, call); WrapInFunction(val, call);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error(); ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib> EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
@ -333,12 +430,12 @@ kernel void test_function() {
)"); )");
} }
TEST_F(MslGeneratorImplTest, Degrees_Vector) { TEST_F(MslGeneratorImplTest, Degrees_Vector_f32) {
auto* val = Var("val", ty.vec3<f32>()); auto* val = Var("val", ty.vec3<f32>());
auto* call = Call("degrees", val); auto* call = Call("degrees", val);
WrapInFunction(val, call); WrapInFunction(val, call);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error(); ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib> EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
@ -358,12 +455,66 @@ kernel void test_function() {
)"); )");
} }
TEST_F(MslGeneratorImplTest, Radians_Scalar) { TEST_F(MslGeneratorImplTest, Degrees_Scalar_f16) {
Enable(ast::Extension::kF16);
auto* val = Var("val", ty.f16());
auto* call = Call("degrees", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
half tint_degrees(half param_0) {
return param_0 * 57.295779513082322865;
}
kernel void test_function() {
half val = 0.0h;
half const tint_symbol = tint_degrees(val);
return;
}
)");
}
TEST_F(MslGeneratorImplTest, Degrees_Vector_f16) {
Enable(ast::Extension::kF16);
auto* val = Var("val", ty.vec3<f16>());
auto* call = Call("degrees", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
half3 tint_degrees(half3 param_0) {
return param_0 * 57.295779513082322865;
}
kernel void test_function() {
half3 val = 0.0h;
half3 const tint_symbol = tint_degrees(val);
return;
}
)");
}
TEST_F(MslGeneratorImplTest, Radians_Scalar_f32) {
auto* val = Var("val", ty.f32()); auto* val = Var("val", ty.f32());
auto* call = Call("radians", val); auto* call = Call("radians", val);
WrapInFunction(val, call); WrapInFunction(val, call);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error(); ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib> EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
@ -383,12 +534,12 @@ kernel void test_function() {
)"); )");
} }
TEST_F(MslGeneratorImplTest, Radians_Vector) { TEST_F(MslGeneratorImplTest, Radians_Vector_f32) {
auto* val = Var("val", ty.vec3<f32>()); auto* val = Var("val", ty.vec3<f32>());
auto* call = Call("radians", val); auto* call = Call("radians", val);
WrapInFunction(val, call); WrapInFunction(val, call);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error(); ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib> EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
@ -408,6 +559,60 @@ kernel void test_function() {
)"); )");
} }
TEST_F(MslGeneratorImplTest, Radians_Scalar_f16) {
Enable(ast::Extension::kF16);
auto* val = Var("val", ty.f16());
auto* call = Call("radians", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
half tint_radians(half param_0) {
return param_0 * 0.017453292519943295474;
}
kernel void test_function() {
half val = 0.0h;
half const tint_symbol = tint_radians(val);
return;
}
)");
}
TEST_F(MslGeneratorImplTest, Radians_Vector_f16) {
Enable(ast::Extension::kF16);
auto* val = Var("val", ty.vec3<f16>());
auto* call = Call("radians", val);
WrapInFunction(val, call);
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
half3 tint_radians(half3 param_0) {
return param_0 * 0.017453292519943295474;
}
kernel void test_function() {
half3 val = 0.0h;
half3 const tint_symbol = tint_radians(val);
return;
}
)");
}
TEST_F(MslGeneratorImplTest, Pack2x16Float) { TEST_F(MslGeneratorImplTest, Pack2x16Float) {
auto* call = Call("pack2x16float", "p1"); auto* call = Call("pack2x16float", "p1");
GlobalVar("p1", ty.vec2<f32>(), ast::StorageClass::kPrivate); GlobalVar("p1", ty.vec2<f32>(), ast::StorageClass::kPrivate);

View File

@ -131,6 +131,42 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_F(BuiltinBuilderTest, Call_GLSLMethod_WithLoad_f16) {
Enable(ast::Extension::kF16);
auto* var = GlobalVar("ident", ty.f16(), ast::StorageClass::kPrivate);
auto* expr = Call("round", "ident");
auto* func = Func("a_func", {}, ty.void_(),
{
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect =
R"(%10 = OpExtInstImport "GLSL.std.450"
OpName %1 "ident"
OpName %7 "a_func"
%3 = OpTypeFloat 16
%2 = OpTypePointer Private %3
%4 = OpConstantNull %3
%1 = OpVariable %2 Private %4
%6 = OpTypeVoid
%5 = OpTypeFunction %6
%7 = OpFunction %6 None %5
%8 = OpLabel
%11 = OpLoad %3 %1
%9 = OpExtInst %3 %10 RoundEven %11
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
// Tests for Logical builtins // Tests for Logical builtins
namespace logical_builtin_tests { namespace logical_builtin_tests {
@ -485,6 +521,47 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_P(Builtin_Builder_SingleParam_Float_Test, Call_Scalar_f16) {
Enable(ast::Extension::kF16);
auto param = GetParam();
// Use a variable to prevent the function being evaluated as constant.
auto* scalar = Var("a", nullptr, Expr(1_h));
auto* expr = Call(param.name, scalar);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(scalar),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%11 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %7 "a"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%5 = OpTypeFloat 16
%6 = OpConstant %5 0x1p+0
%8 = OpTypePointer Function %5
%9 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%7 = OpVariable %8 Function %9
OpStore %7 %6
%12 = OpLoad %5 %7
%10 = OpExtInst %5 %11 )" +
param.op +
R"( %12
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
TEST_P(Builtin_Builder_SingleParam_Float_Test, Call_Vector_f32) { TEST_P(Builtin_Builder_SingleParam_Float_Test, Call_Vector_f32) {
auto param = GetParam(); auto param = GetParam();
@ -527,6 +604,50 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_P(Builtin_Builder_SingleParam_Float_Test, Call_Vector_f16) {
Enable(ast::Extension::kF16);
auto param = GetParam();
// Use a variable to prevent the function being evaluated as constant.
auto* vec = Var("a", nullptr, vec2<f16>(1_h, 1_h));
auto* expr = Call(param.name, vec);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(vec),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%13 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %9 "a"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%6 = OpTypeFloat 16
%5 = OpTypeVector %6 2
%7 = OpConstant %6 0x1p+0
%8 = OpConstantComposite %5 %7 %7
%10 = OpTypePointer Function %5
%11 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%9 = OpVariable %10 Function %11
OpStore %9 %8
%14 = OpLoad %5 %9
%12 = OpExtInst %5 %13 )" +
param.op +
R"( %14
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
INSTANTIATE_TEST_SUITE_P(BuiltinBuilderTest, INSTANTIATE_TEST_SUITE_P(BuiltinBuilderTest,
Builtin_Builder_SingleParam_Float_Test, Builtin_Builder_SingleParam_Float_Test,
testing::Values(BuiltinData{"abs", "FAbs"}, testing::Values(BuiltinData{"abs", "FAbs"},
@ -589,6 +710,43 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_F(BuiltinBuilderTest, Call_Length_Scalar_f16) {
Enable(ast::Extension::kF16);
auto* scalar = Var("a", nullptr, Expr(1_h));
auto* expr = Call("length", scalar);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(scalar),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%11 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %7 "a"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%5 = OpTypeFloat 16
%6 = OpConstant %5 0x1p+0
%8 = OpTypePointer Function %5
%9 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%7 = OpVariable %8 Function %9
OpStore %7 %6
%12 = OpLoad %5 %7
%10 = OpExtInst %5 %11 Length %12
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
TEST_F(BuiltinBuilderTest, Call_Length_Vector_f32) { TEST_F(BuiltinBuilderTest, Call_Length_Vector_f32) {
auto* vec = Var("a", nullptr, vec2<f32>(1_f, 1_f)); auto* vec = Var("a", nullptr, vec2<f32>(1_f, 1_f));
auto* expr = Call("length", vec); auto* expr = Call("length", vec);
@ -626,6 +784,45 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_F(BuiltinBuilderTest, Call_Length_Vector_f16) {
Enable(ast::Extension::kF16);
auto* vec = Var("a", nullptr, vec2<f16>(1_h, 1_h));
auto* expr = Call("length", vec);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(vec),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%13 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %9 "a"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%6 = OpTypeFloat 16
%5 = OpTypeVector %6 2
%7 = OpConstant %6 0x1p+0
%8 = OpConstantComposite %5 %7 %7
%10 = OpTypePointer Function %5
%11 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%9 = OpVariable %10 Function %11
OpStore %9 %8
%14 = OpLoad %5 %9
%12 = OpExtInst %6 %13 Length %14
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
TEST_F(BuiltinBuilderTest, Call_Normalize_f32) { TEST_F(BuiltinBuilderTest, Call_Normalize_f32) {
auto* vec = Var("a", nullptr, vec2<f32>(1_f, 1_f)); auto* vec = Var("a", nullptr, vec2<f32>(1_f, 1_f));
auto* expr = Call("normalize", vec); auto* expr = Call("normalize", vec);
@ -663,6 +860,45 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_F(BuiltinBuilderTest, Call_Normalize_f16) {
Enable(ast::Extension::kF16);
auto* vec = Var("a", nullptr, vec2<f16>(1_h, 1_h));
auto* expr = Call("normalize", vec);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(vec),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%13 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %9 "a"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%6 = OpTypeFloat 16
%5 = OpTypeVector %6 2
%7 = OpConstant %6 0x1p+0
%8 = OpConstantComposite %5 %7 %7
%10 = OpTypePointer Function %5
%11 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%9 = OpVariable %10 Function %11
OpStore %9 %8
%14 = OpLoad %5 %9
%12 = OpExtInst %5 %13 Normalize %14
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
using Builtin_Builder_DualParam_Float_Test = BuiltinBuilderTestWithParam<BuiltinData>; using Builtin_Builder_DualParam_Float_Test = BuiltinBuilderTestWithParam<BuiltinData>;
TEST_P(Builtin_Builder_DualParam_Float_Test, Call_Scalar_f32) { TEST_P(Builtin_Builder_DualParam_Float_Test, Call_Scalar_f32) {
auto param = GetParam(); auto param = GetParam();
@ -703,6 +939,47 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_P(Builtin_Builder_DualParam_Float_Test, Call_Scalar_f16) {
Enable(ast::Extension::kF16);
auto param = GetParam();
auto* scalar = Var("scalar", nullptr, Expr(1_h));
auto* expr = Call(param.name, scalar, scalar);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(scalar),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%11 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %7 "scalar"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%5 = OpTypeFloat 16
%6 = OpConstant %5 0x1p+0
%8 = OpTypePointer Function %5
%9 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%7 = OpVariable %8 Function %9
OpStore %7 %6
%12 = OpLoad %5 %7
%13 = OpLoad %5 %7
%10 = OpExtInst %5 %11 )" +
param.op +
R"( %12 %13
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
TEST_P(Builtin_Builder_DualParam_Float_Test, Call_Vector_f32) { TEST_P(Builtin_Builder_DualParam_Float_Test, Call_Vector_f32) {
auto param = GetParam(); auto param = GetParam();
auto* vec = Var("vec", nullptr, vec2<f32>(1_f, 1_f)); auto* vec = Var("vec", nullptr, vec2<f32>(1_f, 1_f));
@ -744,6 +1021,49 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_P(Builtin_Builder_DualParam_Float_Test, Call_Vector_f16) {
Enable(ast::Extension::kF16);
auto param = GetParam();
auto* vec = Var("vec", nullptr, vec2<f16>(1_h, 1_h));
auto* expr = Call(param.name, vec, vec);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(vec),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%13 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %9 "vec"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%6 = OpTypeFloat 16
%5 = OpTypeVector %6 2
%7 = OpConstant %6 0x1p+0
%8 = OpConstantComposite %5 %7 %7
%10 = OpTypePointer Function %5
%11 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%9 = OpVariable %10 Function %11
OpStore %9 %8
%14 = OpLoad %5 %9
%15 = OpLoad %5 %9
%12 = OpExtInst %5 %13 )" +
param.op +
R"( %14 %15
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
INSTANTIATE_TEST_SUITE_P(BuiltinBuilderTest, INSTANTIATE_TEST_SUITE_P(BuiltinBuilderTest,
Builtin_Builder_DualParam_Float_Test, Builtin_Builder_DualParam_Float_Test,
testing::Values(BuiltinData{"atan2", "Atan2"}, testing::Values(BuiltinData{"atan2", "Atan2"},
@ -790,6 +1110,46 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_F(BuiltinBuilderTest, Call_Reflect_Vector_f16) {
Enable(ast::Extension::kF16);
auto* vec = Var("vec", nullptr, vec2<f16>(1_h, 1_h));
auto* expr = Call("reflect", vec, vec);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(vec),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%13 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %9 "vec"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%6 = OpTypeFloat 16
%5 = OpTypeVector %6 2
%7 = OpConstant %6 0x1p+0
%8 = OpConstantComposite %5 %7 %7
%10 = OpTypePointer Function %5
%11 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%9 = OpVariable %10 Function %11
OpStore %9 %8
%14 = OpLoad %5 %9
%15 = OpLoad %5 %9
%12 = OpExtInst %5 %13 Reflect %14 %15
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
TEST_F(BuiltinBuilderTest, Call_Distance_Scalar_f32) { TEST_F(BuiltinBuilderTest, Call_Distance_Scalar_f32) {
auto* scalar = Var("scalar", nullptr, Expr(1_f)); auto* scalar = Var("scalar", nullptr, Expr(1_f));
auto* expr = Call("distance", scalar, scalar); auto* expr = Call("distance", scalar, scalar);
@ -826,6 +1186,44 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_F(BuiltinBuilderTest, Call_Distance_Scalar_f16) {
Enable(ast::Extension::kF16);
auto* scalar = Var("scalar", nullptr, Expr(1_h));
auto* expr = Call("distance", scalar, scalar);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(scalar),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%11 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %7 "scalar"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%5 = OpTypeFloat 16
%6 = OpConstant %5 0x1p+0
%8 = OpTypePointer Function %5
%9 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%7 = OpVariable %8 Function %9
OpStore %7 %6
%12 = OpLoad %5 %7
%13 = OpLoad %5 %7
%10 = OpExtInst %5 %11 Distance %12 %13
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
TEST_F(BuiltinBuilderTest, Call_Distance_Vector_f32) { TEST_F(BuiltinBuilderTest, Call_Distance_Vector_f32) {
auto* vec = Var("vec", nullptr, vec2<f32>(1_f, 1_f)); auto* vec = Var("vec", nullptr, vec2<f32>(1_f, 1_f));
auto* expr = Call("distance", vec, vec); auto* expr = Call("distance", vec, vec);
@ -864,6 +1262,46 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_F(BuiltinBuilderTest, Call_Distance_Vector_f16) {
Enable(ast::Extension::kF16);
auto* vec = Var("vec", nullptr, vec2<f16>(1_h, 1_h));
auto* expr = Call("distance", vec, vec);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(vec),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%13 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %9 "vec"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%6 = OpTypeFloat 16
%5 = OpTypeVector %6 2
%7 = OpConstant %6 0x1p+0
%8 = OpConstantComposite %5 %7 %7
%10 = OpTypePointer Function %5
%11 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%9 = OpVariable %10 Function %11
OpStore %9 %8
%14 = OpLoad %5 %9
%15 = OpLoad %5 %9
%12 = OpExtInst %6 %13 Distance %14 %15
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
TEST_F(BuiltinBuilderTest, Call_Cross_f32) { TEST_F(BuiltinBuilderTest, Call_Cross_f32) {
auto* vec = Var("vec", nullptr, vec3<f32>(1_f, 1_f, 1_f)); auto* vec = Var("vec", nullptr, vec3<f32>(1_f, 1_f, 1_f));
auto* expr = Call("cross", vec, vec); auto* expr = Call("cross", vec, vec);
@ -902,6 +1340,46 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_F(BuiltinBuilderTest, Call_Cross_f16) {
Enable(ast::Extension::kF16);
auto* vec = Var("vec", nullptr, vec3<f16>(1_h, 1_h, 1_h));
auto* expr = Call("cross", vec, vec);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(vec),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%13 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %9 "vec"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%6 = OpTypeFloat 16
%5 = OpTypeVector %6 3
%7 = OpConstant %6 0x1p+0
%8 = OpConstantComposite %5 %7 %7 %7
%10 = OpTypePointer Function %5
%11 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%9 = OpVariable %10 Function %11
OpStore %9 %8
%14 = OpLoad %5 %9
%15 = OpLoad %5 %9
%12 = OpExtInst %5 %13 Cross %14 %15
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
using Builtin_Builder_ThreeParam_Float_Test = BuiltinBuilderTestWithParam<BuiltinData>; using Builtin_Builder_ThreeParam_Float_Test = BuiltinBuilderTestWithParam<BuiltinData>;
TEST_P(Builtin_Builder_ThreeParam_Float_Test, Call_Scalar_f32) { TEST_P(Builtin_Builder_ThreeParam_Float_Test, Call_Scalar_f32) {
auto param = GetParam(); auto param = GetParam();
@ -943,6 +1421,48 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_P(Builtin_Builder_ThreeParam_Float_Test, Call_Scalar_f16) {
Enable(ast::Extension::kF16);
auto param = GetParam();
auto* scalar = Var("scalar", nullptr, Expr(1_h));
auto* expr = Call(param.name, scalar, scalar, scalar);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(scalar),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%11 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %7 "scalar"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%5 = OpTypeFloat 16
%6 = OpConstant %5 0x1p+0
%8 = OpTypePointer Function %5
%9 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%7 = OpVariable %8 Function %9
OpStore %7 %6
%12 = OpLoad %5 %7
%13 = OpLoad %5 %7
%14 = OpLoad %5 %7
%10 = OpExtInst %5 %11 )" +
param.op +
R"( %12 %13 %14
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
TEST_P(Builtin_Builder_ThreeParam_Float_Test, Call_Vector_f32) { TEST_P(Builtin_Builder_ThreeParam_Float_Test, Call_Vector_f32) {
auto param = GetParam(); auto param = GetParam();
auto* vec = Var("vec", nullptr, vec2<f32>(1_f, 1_f)); auto* vec = Var("vec", nullptr, vec2<f32>(1_f, 1_f));
@ -985,6 +1505,50 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_P(Builtin_Builder_ThreeParam_Float_Test, Call_Vector_f16) {
Enable(ast::Extension::kF16);
auto param = GetParam();
auto* vec = Var("vec", nullptr, vec2<f16>(1_h, 1_h));
auto* expr = Call(param.name, vec, vec, vec);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(vec),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%13 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %9 "vec"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%6 = OpTypeFloat 16
%5 = OpTypeVector %6 2
%7 = OpConstant %6 0x1p+0
%8 = OpConstantComposite %5 %7 %7
%10 = OpTypePointer Function %5
%11 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%9 = OpVariable %10 Function %11
OpStore %9 %8
%14 = OpLoad %5 %9
%15 = OpLoad %5 %9
%16 = OpLoad %5 %9
%12 = OpExtInst %5 %13 )" +
param.op +
R"( %14 %15 %16
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
INSTANTIATE_TEST_SUITE_P(BuiltinBuilderTest, INSTANTIATE_TEST_SUITE_P(BuiltinBuilderTest,
Builtin_Builder_ThreeParam_Float_Test, Builtin_Builder_ThreeParam_Float_Test,
testing::Values(BuiltinData{"clamp", "NClamp"}, testing::Values(BuiltinData{"clamp", "NClamp"},
@ -1032,6 +1596,47 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_F(BuiltinBuilderTest, Call_FaceForward_Vector_f16) {
Enable(ast::Extension::kF16);
auto* vec = Var("vec", nullptr, vec2<f16>(1_h, 1_h));
auto* expr = Call("faceForward", vec, vec, vec);
auto* func = Func("a_func", {}, ty.void_(),
{
Decl(vec),
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%13 = OpExtInstImport "GLSL.std.450"
OpName %3 "a_func"
OpName %9 "vec"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%6 = OpTypeFloat 16
%5 = OpTypeVector %6 2
%7 = OpConstant %6 0x1p+0
%8 = OpConstantComposite %5 %7 %7
%10 = OpTypePointer Function %5
%11 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%9 = OpVariable %10 Function %11
OpStore %9 %8
%14 = OpLoad %5 %9
%15 = OpLoad %5 %9
%16 = OpLoad %5 %9
%12 = OpExtInst %5 %13 FaceForward %14 %15 %16
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
TEST_F(BuiltinBuilderTest, Call_Modf) { TEST_F(BuiltinBuilderTest, Call_Modf) {
auto* vec = Var("vec", nullptr, vec2<f32>(1_f, 2_f)); auto* vec = Var("vec", nullptr, vec2<f32>(1_f, 2_f));
auto* expr = Call("modf", vec); auto* expr = Call("modf", vec);
@ -2130,6 +2735,43 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_F(BuiltinBuilderTest, Call_Determinant_f16) {
Enable(ast::Extension::kF16);
auto* var = GlobalVar("var", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
auto* expr = Call("determinant", "var");
auto* func = Func("a_func", {}, ty.void_(),
{
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(%12 = OpExtInstImport "GLSL.std.450"
OpName %1 "var"
OpName %9 "a_func"
%5 = OpTypeFloat 16
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Private %3
%6 = OpConstantNull %3
%1 = OpVariable %2 Private %6
%8 = OpTypeVoid
%7 = OpTypeFunction %8
%9 = OpFunction %8 None %7
%10 = OpLabel
%13 = OpLoad %3 %1
%11 = OpExtInst %5 %12 Determinant %13
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
TEST_F(BuiltinBuilderTest, Call_Transpose_f32) { TEST_F(BuiltinBuilderTest, Call_Transpose_f32) {
auto* var = GlobalVar("var", ty.mat2x3<f32>(), ast::StorageClass::kPrivate); auto* var = GlobalVar("var", ty.mat2x3<f32>(), ast::StorageClass::kPrivate);
auto* expr = Call("transpose", "var"); auto* expr = Call("transpose", "var");
@ -2166,6 +2808,44 @@ OpFunctionEnd
EXPECT_EQ(got, expect); EXPECT_EQ(got, expect);
} }
TEST_F(BuiltinBuilderTest, Call_Transpose_f16) {
Enable(ast::Extension::kF16);
auto* var = GlobalVar("var", ty.mat2x3<f16>(), ast::StorageClass::kPrivate);
auto* expr = Call("transpose", "var");
auto* func = Func("a_func", {}, ty.void_(),
{
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
auto got = DumpBuilder(b);
auto expect = R"(OpName %1 "var"
OpName %9 "a_func"
%5 = OpTypeFloat 16
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 2
%2 = OpTypePointer Private %3
%6 = OpConstantNull %3
%1 = OpVariable %2 Private %6
%8 = OpTypeVoid
%7 = OpTypeFunction %8
%13 = OpTypeVector %5 2
%12 = OpTypeMatrix %13 3
%9 = OpFunction %8 None %7
%10 = OpLabel
%14 = OpLoad %3 %1
%11 = OpTranspose %12 %14
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(got, expect);
}
} // namespace matrix_builtin_tests } // namespace matrix_builtin_tests
// Tests for Numeric builtins with float and integer vector parameter, i.e. "dot" // Tests for Numeric builtins with float and integer vector parameter, i.e. "dot"
@ -2200,6 +2880,37 @@ OpReturn
)"); )");
} }
TEST_F(BuiltinBuilderTest, Call_Dot_F16) {
Enable(ast::Extension::kF16);
auto* var = GlobalVar("v", ty.vec3<f16>(), ast::StorageClass::kPrivate);
auto* expr = Call("dot", "v", "v");
auto* func = Func("a_func", {}, ty.void_(),
{
Assign(Phony(), expr),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
ASSERT_TRUE(b.GenerateFunction(func)) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 16
%3 = OpTypeVector %4 3
%2 = OpTypePointer Private %3
%5 = OpConstantNull %3
%1 = OpVariable %2 Private %5
%7 = OpTypeVoid
%6 = OpTypeFunction %7
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
R"(%11 = OpLoad %3 %1
%12 = OpLoad %3 %1
%10 = OpDot %4 %11 %12
OpReturn
)");
}
TEST_F(BuiltinBuilderTest, Call_Dot_U32) { TEST_F(BuiltinBuilderTest, Call_Dot_U32) {
auto* var = GlobalVar("v", ty.vec3<u32>(), ast::StorageClass::kPrivate); auto* var = GlobalVar("v", ty.vec3<u32>(), ast::StorageClass::kPrivate);
auto* expr = Call("dot", "v", "v"); auto* expr = Call("dot", "v", "v");

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
float tint_degrees(float param_0) { float tint_degrees(float param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }

View File

@ -13,16 +13,20 @@ See:
{{- range Sem.Builtins -}} {{- range Sem.Builtins -}}
{{- range .Overloads -}} {{- range .Overloads -}}
{{- range Permute . -}} {{- range Permute . -}}
{{- /* Generate a ./literal/<function>/<permuataion-hash>.wgsl file using {{- /* TODO(crbug.com/tint/1502): Remove the bodge below and emit "enable f16;" after
the Permutation macro defined below */ -}} getting ready for F16 end-to-end tests. */ -}}
{{- $file := printf "./literal/%v/%v.wgsl" .Intrinsic.Name .Hash -}} {{- if not (OverloadUsesF16 .Overload) -}}
{{- $content := Eval "Permutation" "Overload" . "Mode" "literal" -}} {{- /* Generate a ./literal/<function>/<permuataion-hash>.wgsl file using
{{- WriteFile $file $content -}} the Permutation macro defined below */ -}}
{{- /* Generate a ./var/<function>/<permuataion-hash>.wgsl file using {{- $file := printf "./literal/%v/%v.wgsl" .Intrinsic.Name .Hash -}}
the Permutation macro defined below */ -}} {{- $content := Eval "Permutation" "Overload" . "Mode" "literal" -}}
{{- $file := printf "./var/%v/%v.wgsl" .Intrinsic.Name .Hash -}} {{- WriteFile $file $content -}}
{{- $content := Eval "Permutation" "Overload" . "Mode" "var" -}} {{- /* Generate a ./var/<function>/<permuataion-hash>.wgsl file using
{{- WriteFile $file $content -}} the Permutation macro defined below */ -}}
{{- $file := printf "./var/%v/%v.wgsl" .Intrinsic.Name .Hash -}}
{{- $content := Eval "Permutation" "Overload" . "Mode" "var" -}}
{{- WriteFile $file $content -}}
{{- end }}
{{- end }} {{- end }}
{{- end }} {{- end }}
{{- end }} {{- end }}
@ -319,4 +323,4 @@ enable chromium_experimental_dp4a;
{{- end -}} {{- end -}}
> >
{{- end -}} {{- end -}}
{{- end -}} {{- end -}}

View File

@ -1,47 +0,0 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
////////////////////////////////////////////////////////////////////////////////
// File generated by tools/src/cmd/gen
// using the template:
// test/tint/builtins/gen/gen.wgsl.tmpl
//
// Do not modify this file directly
////////////////////////////////////////////////////////////////////////////////
struct SB_RO {
arg_0: array<f16>,
};
@group(0) @binding(1) var<storage, read> sb_ro : SB_RO;
// fn arrayLength(ptr<storage, array<f16>, read>) -> u32
fn arrayLength_8421b9() {
var res: u32 = arrayLength(&sb_ro.arg_0);
}
@vertex
fn vertex_main() -> @builtin(position) vec4<f32> {
arrayLength_8421b9();
return vec4<f32>();
}
@fragment
fn fragment_main() {
arrayLength_8421b9();
}
@compute @workgroup_size(1)
fn compute_main() {
arrayLength_8421b9();
}

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/literal/arrayLength/8421b9.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/literal/arrayLength/8421b9.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/literal/arrayLength/8421b9.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/literal/arrayLength/8421b9.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/literal/arrayLength/8421b9.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/literal/arrayLength/8421b9.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,47 +0,0 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
////////////////////////////////////////////////////////////////////////////////
// File generated by tools/src/cmd/gen
// using the template:
// test/tint/builtins/gen/gen.wgsl.tmpl
//
// Do not modify this file directly
////////////////////////////////////////////////////////////////////////////////
struct SB_RW {
arg_0: array<f16>,
};
@group(0) @binding(0) var<storage, read_write> sb_rw : SB_RW;
// fn arrayLength(ptr<storage, array<f16>, read_write>) -> u32
fn arrayLength_cbd6b5() {
var res: u32 = arrayLength(&sb_rw.arg_0);
}
@vertex
fn vertex_main() -> @builtin(position) vec4<f32> {
arrayLength_cbd6b5();
return vec4<f32>();
}
@fragment
fn fragment_main() {
arrayLength_cbd6b5();
}
@compute @workgroup_size(1)
fn compute_main() {
arrayLength_cbd6b5();
}

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/literal/arrayLength/cbd6b5.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/literal/arrayLength/cbd6b5.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/literal/arrayLength/cbd6b5.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/literal/arrayLength/cbd6b5.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/literal/arrayLength/cbd6b5.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/literal/arrayLength/cbd6b5.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
vec4 tint_degrees(vec4 param_0) { vec4 tint_degrees(vec4 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -26,7 +26,7 @@ void main() {
precision mediump float; precision mediump float;
vec4 tint_degrees(vec4 param_0) { vec4 tint_degrees(vec4 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -45,7 +45,7 @@ void main() {
#version 310 es #version 310 es
vec4 tint_degrees(vec4 param_0) { vec4 tint_degrees(vec4 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
vec2 tint_degrees(vec2 param_0) { vec2 tint_degrees(vec2 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -26,7 +26,7 @@ void main() {
precision mediump float; precision mediump float;
vec2 tint_degrees(vec2 param_0) { vec2 tint_degrees(vec2 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -45,7 +45,7 @@ void main() {
#version 310 es #version 310 es
vec2 tint_degrees(vec2 param_0) { vec2 tint_degrees(vec2 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
vec3 tint_degrees(vec3 param_0) { vec3 tint_degrees(vec3 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -26,7 +26,7 @@ void main() {
precision mediump float; precision mediump float;
vec3 tint_degrees(vec3 param_0) { vec3 tint_degrees(vec3 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -45,7 +45,7 @@ void main() {
#version 310 es #version 310 es
vec3 tint_degrees(vec3 param_0) { vec3 tint_degrees(vec3 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
float tint_degrees(float param_0) { float tint_degrees(float param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -26,7 +26,7 @@ void main() {
precision mediump float; precision mediump float;
float tint_degrees(float param_0) { float tint_degrees(float param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -45,7 +45,7 @@ void main() {
#version 310 es #version 310 es
float tint_degrees(float param_0) { float tint_degrees(float param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
vec4 tint_radians(vec4 param_0) { vec4 tint_radians(vec4 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -26,7 +26,7 @@ void main() {
precision mediump float; precision mediump float;
vec4 tint_radians(vec4 param_0) { vec4 tint_radians(vec4 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -45,7 +45,7 @@ void main() {
#version 310 es #version 310 es
vec4 tint_radians(vec4 param_0) { vec4 tint_radians(vec4 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
vec2 tint_radians(vec2 param_0) { vec2 tint_radians(vec2 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -26,7 +26,7 @@ void main() {
precision mediump float; precision mediump float;
vec2 tint_radians(vec2 param_0) { vec2 tint_radians(vec2 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -45,7 +45,7 @@ void main() {
#version 310 es #version 310 es
vec2 tint_radians(vec2 param_0) { vec2 tint_radians(vec2 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
float tint_radians(float param_0) { float tint_radians(float param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -26,7 +26,7 @@ void main() {
precision mediump float; precision mediump float;
float tint_radians(float param_0) { float tint_radians(float param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -45,7 +45,7 @@ void main() {
#version 310 es #version 310 es
float tint_radians(float param_0) { float tint_radians(float param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
vec3 tint_radians(vec3 param_0) { vec3 tint_radians(vec3 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -26,7 +26,7 @@ void main() {
precision mediump float; precision mediump float;
vec3 tint_radians(vec3 param_0) { vec3 tint_radians(vec3 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -45,7 +45,7 @@ void main() {
#version 310 es #version 310 es
vec3 tint_radians(vec3 param_0) { vec3 tint_radians(vec3 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }

View File

@ -1,47 +0,0 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
////////////////////////////////////////////////////////////////////////////////
// File generated by tools/src/cmd/gen
// using the template:
// test/tint/builtins/gen/gen.wgsl.tmpl
//
// Do not modify this file directly
////////////////////////////////////////////////////////////////////////////////
struct SB_RO {
arg_0: array<f16>,
};
@group(0) @binding(1) var<storage, read> sb_ro : SB_RO;
// fn arrayLength(ptr<storage, array<f16>, read>) -> u32
fn arrayLength_8421b9() {
var res: u32 = arrayLength(&sb_ro.arg_0);
}
@vertex
fn vertex_main() -> @builtin(position) vec4<f32> {
arrayLength_8421b9();
return vec4<f32>();
}
@fragment
fn fragment_main() {
arrayLength_8421b9();
}
@compute @workgroup_size(1)
fn compute_main() {
arrayLength_8421b9();
}

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/var/arrayLength/8421b9.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/var/arrayLength/8421b9.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/var/arrayLength/8421b9.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/var/arrayLength/8421b9.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/var/arrayLength/8421b9.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/var/arrayLength/8421b9.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,47 +0,0 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
////////////////////////////////////////////////////////////////////////////////
// File generated by tools/src/cmd/gen
// using the template:
// test/tint/builtins/gen/gen.wgsl.tmpl
//
// Do not modify this file directly
////////////////////////////////////////////////////////////////////////////////
struct SB_RW {
arg_0: array<f16>,
};
@group(0) @binding(0) var<storage, read_write> sb_rw : SB_RW;
// fn arrayLength(ptr<storage, array<f16>, read_write>) -> u32
fn arrayLength_cbd6b5() {
var res: u32 = arrayLength(&sb_rw.arg_0);
}
@vertex
fn vertex_main() -> @builtin(position) vec4<f32> {
arrayLength_cbd6b5();
return vec4<f32>();
}
@fragment
fn fragment_main() {
arrayLength_cbd6b5();
}
@compute @workgroup_size(1)
fn compute_main() {
arrayLength_cbd6b5();
}

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/var/arrayLength/cbd6b5.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/var/arrayLength/cbd6b5.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/var/arrayLength/cbd6b5.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/var/arrayLength/cbd6b5.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/var/arrayLength/cbd6b5.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,6 +0,0 @@
SKIP: FAILED
builtins/gen/var/arrayLength/cbd6b5.wgsl:26:16 error: f16 used without 'f16' extension enabled
arg_0: array<f16>,
^^^

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
vec4 tint_degrees(vec4 param_0) { vec4 tint_degrees(vec4 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -27,7 +27,7 @@ void main() {
precision mediump float; precision mediump float;
vec4 tint_degrees(vec4 param_0) { vec4 tint_degrees(vec4 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -47,7 +47,7 @@ void main() {
#version 310 es #version 310 es
vec4 tint_degrees(vec4 param_0) { vec4 tint_degrees(vec4 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
vec2 tint_degrees(vec2 param_0) { vec2 tint_degrees(vec2 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -27,7 +27,7 @@ void main() {
precision mediump float; precision mediump float;
vec2 tint_degrees(vec2 param_0) { vec2 tint_degrees(vec2 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -47,7 +47,7 @@ void main() {
#version 310 es #version 310 es
vec2 tint_degrees(vec2 param_0) { vec2 tint_degrees(vec2 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
vec3 tint_degrees(vec3 param_0) { vec3 tint_degrees(vec3 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -27,7 +27,7 @@ void main() {
precision mediump float; precision mediump float;
vec3 tint_degrees(vec3 param_0) { vec3 tint_degrees(vec3 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -47,7 +47,7 @@ void main() {
#version 310 es #version 310 es
vec3 tint_degrees(vec3 param_0) { vec3 tint_degrees(vec3 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
float tint_degrees(float param_0) { float tint_degrees(float param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -27,7 +27,7 @@ void main() {
precision mediump float; precision mediump float;
float tint_degrees(float param_0) { float tint_degrees(float param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
@ -47,7 +47,7 @@ void main() {
#version 310 es #version 310 es
float tint_degrees(float param_0) { float tint_degrees(float param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
vec4 tint_radians(vec4 param_0) { vec4 tint_radians(vec4 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -27,7 +27,7 @@ void main() {
precision mediump float; precision mediump float;
vec4 tint_radians(vec4 param_0) { vec4 tint_radians(vec4 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -47,7 +47,7 @@ void main() {
#version 310 es #version 310 es
vec4 tint_radians(vec4 param_0) { vec4 tint_radians(vec4 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
vec2 tint_radians(vec2 param_0) { vec2 tint_radians(vec2 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -27,7 +27,7 @@ void main() {
precision mediump float; precision mediump float;
vec2 tint_radians(vec2 param_0) { vec2 tint_radians(vec2 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -47,7 +47,7 @@ void main() {
#version 310 es #version 310 es
vec2 tint_radians(vec2 param_0) { vec2 tint_radians(vec2 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
float tint_radians(float param_0) { float tint_radians(float param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -27,7 +27,7 @@ void main() {
precision mediump float; precision mediump float;
float tint_radians(float param_0) { float tint_radians(float param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -47,7 +47,7 @@ void main() {
#version 310 es #version 310 es
float tint_radians(float param_0) { float tint_radians(float param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
vec3 tint_radians(vec3 param_0) { vec3 tint_radians(vec3 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -27,7 +27,7 @@ void main() {
precision mediump float; precision mediump float;
vec3 tint_radians(vec3 param_0) { vec3 tint_radians(vec3 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }
@ -47,7 +47,7 @@ void main() {
#version 310 es #version 310 es
vec3 tint_radians(vec3 param_0) { vec3 tint_radians(vec3 param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }

View File

@ -1,7 +1,7 @@
#version 310 es #version 310 es
float tint_radians(float param_0) { float tint_radians(float param_0) {
return param_0 * 0.017453292519943295474; return param_0 * 0.017453292519943295474f;
} }

View File

@ -1,19 +1,19 @@
#version 310 es #version 310 es
vec4 tint_degrees(vec4 param_0) { vec4 tint_degrees(vec4 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
vec3 tint_degrees_1(vec3 param_0) { vec3 tint_degrees_1(vec3 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
vec2 tint_degrees_2(vec2 param_0) { vec2 tint_degrees_2(vec2 param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }
float tint_degrees_3(float param_0) { float tint_degrees_3(float param_0) {
return param_0 * 57.295779513082322865; return param_0 * 57.295779513082322865f;
} }

View File

@ -448,6 +448,8 @@ func SplitDisplayName(displayName string) []string {
// If the type is not a composite type, then the fully qualified name is returned // If the type is not a composite type, then the fully qualified name is returned
func ElementType(fqn sem.FullyQualifiedName) sem.FullyQualifiedName { func ElementType(fqn sem.FullyQualifiedName) sem.FullyQualifiedName {
switch fqn.Target.GetName() { switch fqn.Target.GetName() {
case "vec2", "vec3", "vec4":
return fqn.TemplateArguments[0].(sem.FullyQualifiedName)
case "vec": case "vec":
return fqn.TemplateArguments[1].(sem.FullyQualifiedName) return fqn.TemplateArguments[1].(sem.FullyQualifiedName)
case "mat": case "mat":
@ -462,12 +464,16 @@ func ElementType(fqn sem.FullyQualifiedName) sem.FullyQualifiedName {
// fully qualified name. // fully qualified name.
func DeepestElementType(fqn sem.FullyQualifiedName) sem.FullyQualifiedName { func DeepestElementType(fqn sem.FullyQualifiedName) sem.FullyQualifiedName {
switch fqn.Target.GetName() { switch fqn.Target.GetName() {
case "vec2", "vec3", "vec4":
return fqn.TemplateArguments[0].(sem.FullyQualifiedName)
case "vec": case "vec":
return fqn.TemplateArguments[1].(sem.FullyQualifiedName) return fqn.TemplateArguments[1].(sem.FullyQualifiedName)
case "mat": case "mat":
return DeepestElementType(fqn.TemplateArguments[2].(sem.FullyQualifiedName)) return DeepestElementType(fqn.TemplateArguments[2].(sem.FullyQualifiedName))
case "array": case "array":
return DeepestElementType(fqn.TemplateArguments[0].(sem.FullyQualifiedName)) return DeepestElementType(fqn.TemplateArguments[0].(sem.FullyQualifiedName))
case "ptr":
return DeepestElementType(fqn.TemplateArguments[1].(sem.FullyQualifiedName))
} }
return fqn return fqn
} }
@ -489,7 +495,7 @@ func IsDeclarable(fqn sem.FullyQualifiedName) bool {
} }
// OverloadUsesF16 returns true if the overload uses the f16 type anywhere in the signature. // OverloadUsesF16 returns true if the overload uses the f16 type anywhere in the signature.
func OverloadUsesF16(overload *sem.Overload, typename string) bool { func OverloadUsesF16(overload *sem.Overload) bool {
for _, param := range overload.Parameters { for _, param := range overload.Parameters {
if DeepestElementType(param.Type).Target.GetName() == "f16" { if DeepestElementType(param.Type).Target.GetName() == "f16" {
return true return true