diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index ef5deafffc..d35f0d919f 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -325,65 +325,99 @@ ast::BinaryOp NegatedFloatCompare(SpvOp opcode) { // @returns the WGSL standard function name, or an empty string. std::string GetGlslStd450FuncName(uint32_t ext_opcode) { switch (ext_opcode) { + case GLSLstd450FAbs: + return "abs"; + case GLSLstd450Acos: + return "acos"; + case GLSLstd450Asin: + return "asin"; + case GLSLstd450Atan: + return "atan"; case GLSLstd450Atan2: return "atan2"; - case GLSLstd450Cos: - return "cos"; - case GLSLstd450Sin: - return "sin"; - case GLSLstd450Distance: - return "distance"; - case GLSLstd450Normalize: - return "normalize"; + case GLSLstd450Ceil: + return "ceil"; case GLSLstd450UClamp: case GLSLstd450SClamp: case GLSLstd450NClamp: case GLSLstd450FClamp: // FClamp is less prescriptive about NaN operands return "clamp"; + case GLSLstd450Cos: + return "cos"; + case GLSLstd450Cosh: + return "cosh"; + case GLSLstd450Cross: + return "cross"; + case GLSLstd450Distance: + return "distance"; + case GLSLstd450Exp: + return "exp"; + case GLSLstd450Exp2: + return "exp2"; + case GLSLstd450FaceForward: + return "faceForward"; + case GLSLstd450Floor: + return "floor"; + case GLSLstd450Fma: + return "fma"; + case GLSLstd450Fract: + return "fract"; + case GLSLstd450InverseSqrt: + return "inverseSqrt"; case GLSLstd450Length: return "length"; - case GLSLstd450NMin: - case GLSLstd450FMin: // FMin is less prescriptive about NaN operands - return "min"; + case GLSLstd450Log: + return "log"; + case GLSLstd450Log2: + return "log2"; case GLSLstd450NMax: case GLSLstd450FMax: // FMax is less prescriptive about NaN operands return "max"; + case GLSLstd450NMin: + case GLSLstd450FMin: // FMin is less prescriptive about NaN operands + return "min"; + case GLSLstd450FMix: + return "mix"; + case GLSLstd450Normalize: + return "normalize"; + case GLSLstd450Pow: + return "pow"; + case GLSLstd450FSign: + return "sign"; + case GLSLstd450Reflect: + return "reflect"; + case GLSLstd450Round: + return "round"; + case GLSLstd450Sin: + return "sin"; + case GLSLstd450Sinh: + return "sinh"; + case GLSLstd450SmoothStep: + return "smoothStep"; + case GLSLstd450Sqrt: + return "sqrt"; + case GLSLstd450Step: + return "step"; + case GLSLstd450Tan: + return "tan"; + case GLSLstd450Tanh: + return "tanh"; + case GLSLstd450Trunc: + return "trunc"; default: // TODO(dneto). The following are not implemented. // They are grouped semantically, as in GLSL.std.450.h. - case GLSLstd450Round: case GLSLstd450RoundEven: - case GLSLstd450Trunc: - case GLSLstd450FAbs: case GLSLstd450SAbs: - case GLSLstd450FSign: case GLSLstd450SSign: - case GLSLstd450Floor: - case GLSLstd450Ceil: - case GLSLstd450Fract: case GLSLstd450Radians: case GLSLstd450Degrees: - case GLSLstd450Tan: - case GLSLstd450Asin: - case GLSLstd450Acos: - case GLSLstd450Atan: - case GLSLstd450Sinh: - case GLSLstd450Cosh: - case GLSLstd450Tanh: case GLSLstd450Asinh: case GLSLstd450Acosh: case GLSLstd450Atanh: - case GLSLstd450Pow: - case GLSLstd450Exp: - case GLSLstd450Log: - case GLSLstd450Exp2: - case GLSLstd450Log2: - case GLSLstd450Sqrt: - case GLSLstd450InverseSqrt: - case GLSLstd450Determinant: case GLSLstd450MatrixInverse: @@ -393,12 +427,8 @@ std::string GetGlslStd450FuncName(uint32_t ext_opcode) { case GLSLstd450SMin: case GLSLstd450UMax: case GLSLstd450SMax: - case GLSLstd450FMix: case GLSLstd450IMix: - case GLSLstd450Step: - case GLSLstd450SmoothStep: - case GLSLstd450Fma: case GLSLstd450Frexp: case GLSLstd450FrexpStruct: case GLSLstd450Ldexp: @@ -416,9 +446,6 @@ std::string GetGlslStd450FuncName(uint32_t ext_opcode) { case GLSLstd450UnpackUnorm4x8: case GLSLstd450UnpackDouble2x32: - case GLSLstd450Cross: - case GLSLstd450FaceForward: - case GLSLstd450Reflect: case GLSLstd450Refract: case GLSLstd450FindILsb: diff --git a/src/reader/spirv/function_glsl_std_450_test.cc b/src/reader/spirv/function_glsl_std_450_test.cc index ffa1e54d0d..add0dac9a7 100644 --- a/src/reader/spirv/function_glsl_std_450_test.cc +++ b/src/reader/spirv/function_glsl_std_450_test.cc @@ -55,6 +55,8 @@ std::string Preamble() { OpName %v2f1 "v2f1" OpName %v2f2 "v2f2" OpName %v2f3 "v2f3" + OpName %v3f1 "v3f1" + OpName %v3f2 "v3f2" %void = OpTypeVoid %voidfn = OpTypeFunction %void @@ -76,6 +78,7 @@ std::string Preamble() { %v2uint = OpTypeVector %uint 2 %v2int = OpTypeVector %int 2 %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 %v2uint_10_20 = OpConstantComposite %v2uint %uint_10 %uint_20 %v2uint_20_10 = OpConstantComposite %v2uint %uint_20 %uint_10 @@ -87,6 +90,9 @@ std::string Preamble() { %v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50 %v2float_70_70 = OpConstantComposite %v2float %float_70 %float_70 + %v3float_50_60_70 = OpConstantComposite %v3float %float_50 %float_60 %float_70 + %v3float_60_70_50 = OpConstantComposite %v3float %float_60 %float_70 %float_50 + %100 = OpFunction %void None %voidfn %entry = OpLabel @@ -114,6 +120,9 @@ std::string Preamble() { %v2f2 = OpCopyObject %v2float %v2float_60_50 %v2f3 = OpCopyObject %v2float %v2float_70_70 + %v3f1 = OpCopyObject %v3float %v3float_50_60_70 + %v3f2 = OpCopyObject %v3float %v3float_60_70_50 + )"; } @@ -129,6 +138,7 @@ inline std::ostream& operator<<(std::ostream& out, GlslStd450Case c) { // Nomenclature: // Float = scalar float // Floating = scalar float or vector-of-float +// Float3 = 3-element vector of float // Int = scalar signed int // Inting = scalar int or vector-of-int // Uint = scalar unsigned int @@ -144,6 +154,8 @@ using SpvParserTest_GlslStd450_Floating_FloatingFloating = SpvParserTestBase<::testing::TestWithParam>; using SpvParserTest_GlslStd450_Floating_FloatingFloatingFloating = SpvParserTestBase<::testing::TestWithParam>; +using SpvParserTest_GlslStd450_Float3_Float3Float3 = + SpvParserTestBase<::testing::TestWithParam>; using SpvParserTest_GlslStd450_Inting_IntingIntingInting = SpvParserTestBase<::testing::TestWithParam>; @@ -276,7 +288,7 @@ TEST_P(SpvParserTest_GlslStd450_Floating_Floating, Scalar) { OpFunctionEnd )"; auto p = parser(test::Assemble(assembly)); - ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); EXPECT_TRUE(fe.EmitBody()) << p->error(); EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( @@ -334,7 +346,7 @@ TEST_P(SpvParserTest_GlslStd450_Floating_FloatingFloating, Scalar) { OpFunctionEnd )"; auto p = parser(test::Assemble(assembly)); - ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); EXPECT_TRUE(fe.EmitBody()) << p->error(); EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( @@ -364,7 +376,7 @@ TEST_P(SpvParserTest_GlslStd450_Floating_FloatingFloating, Vector) { OpFunctionEnd )"; auto p = parser(test::Assemble(assembly)); - ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); EXPECT_TRUE(fe.EmitBody()) << p->error(); EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( @@ -449,6 +461,37 @@ TEST_P(SpvParserTest_GlslStd450_Floating_FloatingFloatingFloating, Vector) { << ToString(fe.ast_body()); } +TEST_P(SpvParserTest_GlslStd450_Float3_Float3Float3, Samples) { + const auto assembly = Preamble() + R"( + %1 = OpExtInst %v3float %glsl )" + + GetParam().opcode + + R"( %v3f1 %v3f2 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( + VariableConst{ + x_1 + none + __vec_3__f32 + { + Call[not set]{ + Identifier[not set]{)" + GetParam().wgsl_func + + R"(} + ( + Identifier[not set]{v3f1} + Identifier[not set]{v3f2} + ) + } + } + })")) + << ToString(fe.ast_body()); +} + INSTANTIATE_TEST_SUITE_P(Samples, SpvParserTest_GlslStd450_Float_Floating, ::testing::Values(GlslStd450Case{"Length", "length"})); @@ -460,10 +503,31 @@ INSTANTIATE_TEST_SUITE_P(Samples, INSTANTIATE_TEST_SUITE_P(Samples, SpvParserTest_GlslStd450_Floating_Floating, - ::testing::Values(GlslStd450Case{"Sin", "sin"}, - GlslStd450Case{"Cos", "cos"}, - GlslStd450Case{"Normalize", - "normalize"})); + ::testing::ValuesIn(std::vector{ + {"Acos", "acos"}, + {"Asin", "asin"}, + {"Atan", "atan"}, + {"Ceil", "ceil"}, + {"Cos", "cos"}, + {"Cosh", "cosh"}, + {"Exp", "exp"}, + {"Exp2", "exp2"}, + {"FAbs", "abs"}, + {"FSign", "sign"}, + {"Floor", "floor"}, + {"Fract", "fract"}, + {"InverseSqrt", "inverseSqrt"}, + {"Log", "log"}, + {"Log2", "log2"}, + {"Normalize", "normalize"}, + {"Round", "round"}, + {"Sin", "sin"}, + {"Sinh", "sinh"}, + {"Sqrt", "sqrt"}, + {"Tan", "tan"}, + {"Tanh", "tanh"}, + {"Trunc", "trunc"}, + })); INSTANTIATE_TEST_SUITE_P(Samples, SpvParserTest_GlslStd450_Floating_FloatingFloating, @@ -472,16 +536,26 @@ INSTANTIATE_TEST_SUITE_P(Samples, {"NMax", "max"}, {"NMin", "min"}, {"FMax", "max"}, // WGSL max promises more for NaN - {"FMin", "min"} // WGSL min promises more for NaN + {"FMin", "min"}, // WGSL min promises more for NaN + {"Pow", "pow"}, + {"Reflect", "reflect"}, + {"Step", "step"}, })); +INSTANTIATE_TEST_SUITE_P(Samples, + SpvParserTest_GlslStd450_Float3_Float3Float3, + ::testing::Values(GlslStd450Case{"Cross", "cross"})); + INSTANTIATE_TEST_SUITE_P( Samples, SpvParserTest_GlslStd450_Floating_FloatingFloatingFloating, ::testing::ValuesIn(std::vector{ {"NClamp", "clamp"}, - {"FClamp", "clamp"} // WGSL FClamp promises more for NaN - })); + {"FClamp", "clamp"}, // WGSL FClamp promises more for NaN + {"FaceForward", "faceForward"}, + {"Fma", "fma"}, + {"FMix", "mix"}, + {"SmoothStep", "smoothStep"}})); TEST_P(SpvParserTest_GlslStd450_Inting_IntingIntingInting, Scalar) { const auto assembly = Preamble() + R"(