From d7868e34c2cd5e9bd63b2e165fe2236ee5794c61 Mon Sep 17 00:00:00 2001 From: David Neto Date: Thu, 18 Jun 2020 17:57:23 +0000 Subject: [PATCH] [spirv-reader] Add ConvertFToU, ConvertFToS Bug: tint:3 Change-Id: I9f3188e0aac64e98da785c4df2e8b2aa42b71cf8 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23402 Reviewed-by: dan sinclair --- src/reader/spirv/function.cc | 77 +++- src/reader/spirv/function_conversion_test.cc | 429 ++++++++++++++++++- src/reader/spirv/parser_impl.cc | 48 +++ src/reader/spirv/parser_impl.h | 28 ++ 4 files changed, 556 insertions(+), 26 deletions(-) diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 05c3b3017c..260fd69d27 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -50,6 +50,7 @@ #include "src/ast/storage_class.h" #include "src/ast/switch_statement.h" #include "src/ast/type/bool_type.h" +#include "src/ast/type/type.h" #include "src/ast/type/u32_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" @@ -2473,12 +2474,9 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( auto arg1 = MakeOperand(inst, 1); auto binary_expr = std::make_unique( binary_op, std::move(arg0.expr), std::move(arg1.expr)); - auto* forced_result_ty = parser_impl_.ForcedResultType(opcode, arg0.type); - if (forced_result_ty && forced_result_ty != ast_type) { - return {ast_type, std::make_unique( - ast_type, std::move(binary_expr))}; - } - return {ast_type, std::move(binary_expr)}; + TypedExpression result(ast_type, std::move(binary_expr)); + return parser_impl_.RectifyForcedResultType(std::move(result), opcode, + arg0.type); } auto unary_op = ast::UnaryOp::kNegation; @@ -2486,12 +2484,9 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( auto arg0 = MakeOperand(inst, 0); auto unary_expr = std::make_unique( unary_op, std::move(arg0.expr)); - auto* forced_result_ty = parser_impl_.ForcedResultType(opcode, arg0.type); - if (forced_result_ty && forced_result_ty != ast_type) { - return {ast_type, std::make_unique( - ast_type, std::move(unary_expr))}; - } - return {ast_type, std::move(unary_expr)}; + TypedExpression result(ast_type, std::move(unary_expr)); + return parser_impl_.RectifyForcedResultType(std::move(result), opcode, + arg0.type); } if (opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain) { @@ -2539,7 +2534,8 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( if (opcode == SpvOpVectorShuffle) { return MakeVectorShuffle(inst); } - if (opcode == SpvOpConvertSToF || opcode == SpvOpConvertUToF) { + if (opcode == SpvOpConvertSToF || opcode == SpvOpConvertUToF || + opcode == SpvOpConvertFToS || opcode == SpvOpConvertFToU) { return MakeNumericConversion(inst); } @@ -2547,13 +2543,9 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( // glsl.std.450 readonly function // Instructions: - // OpCopyObject // OpUndef - // OpBitcast - // OpSatConvertSToU - // OpSatConvertUToS - // OpConvertFToS - // OpConvertFToU + // OpSatConvertSToU // Only in Kernel (OpenCL), not in WebGPU + // OpSatConvertUToS // Only in Kernel (OpenCL), not in WebGPU // OpUConvert // Only needed when multiple widths supported // OpSConvert // Only needed when multiple widths supported // OpFConvert // Only needed when multiple widths supported @@ -2902,11 +2894,52 @@ void FunctionEmitter::RegisterValuesNeedingNamedDefinition() { TypedExpression FunctionEmitter::MakeNumericConversion( const spvtools::opt::Instruction& inst) { - auto* result_type = parser_impl_.ConvertType(inst.type_id()); + const auto opcode = inst.opcode(); + auto* requested_type = parser_impl_.ConvertType(inst.type_id()); auto arg_expr = MakeOperand(inst, 0); + if (!arg_expr.expr || !arg_expr.type) { + return {}; + } - return {result_type, std::make_unique( - result_type, std::move(arg_expr.expr))}; + ast::type::Type* expr_type = nullptr; + if ((opcode == SpvOpConvertSToF) || (opcode == SpvOpConvertUToF)) { + if (arg_expr.type->is_integer_scalar_or_vector()) { + expr_type = requested_type; + } else { + Fail() << "operand for conversion to floating point must be integral " + "scalar or vector, but got: " + << arg_expr.type->type_name(); + } + } else if (inst.opcode() == SpvOpConvertFToU) { + if (arg_expr.type->is_float_scalar_or_vector()) { + expr_type = parser_impl_.GetUnsignedIntMatchingShape(arg_expr.type); + } else { + Fail() << "operand for conversion to unsigned integer must be floating " + "point scalar or vector, but got: " + << arg_expr.type->type_name(); + } + } else if (inst.opcode() == SpvOpConvertFToS) { + if (arg_expr.type->is_float_scalar_or_vector()) { + expr_type = parser_impl_.GetSignedIntMatchingShape(arg_expr.type); + } else { + Fail() << "operand for conversion to signed integer must be floating " + "point scalar or vector, but got: " + << arg_expr.type->type_name(); + } + } + if (expr_type == nullptr) { + // The diagnostic has already been emitted. + return {}; + } + + TypedExpression result(expr_type, std::make_unique( + expr_type, std::move(arg_expr.expr))); + + if (requested_type == expr_type) { + return result; + } + return {requested_type, std::make_unique( + requested_type, std::move(result.expr))}; } } // namespace spirv diff --git a/src/reader/spirv/function_conversion_test.cc b/src/reader/spirv/function_conversion_test.cc index 92d3c6cb1f..1500eb4183 100644 --- a/src/reader/spirv/function_conversion_test.cc +++ b/src/reader/spirv/function_conversion_test.cc @@ -26,6 +26,7 @@ namespace reader { namespace spirv { namespace { +using ::testing::Eq; using ::testing::HasSubstr; std::string CommonTypes() { @@ -33,10 +34,16 @@ std::string CommonTypes() { %void = OpTypeVoid %voidfn = OpTypeFunction %void + %bool = OpTypeBool %uint = OpTypeInt 32 0 %int = OpTypeInt 32 1 %float = OpTypeFloat 32 + %true = OpConstantTrue %bool + %false = OpConstantFalse %bool + %v2bool = OpTypeVector %bool 2 + %v2bool_t_f = OpConstantComposite %v2bool %true %false + %uint_10 = OpConstant %uint 10 %uint_20 = OpConstant %uint 20 %int_30 = OpConstant %int 30 @@ -119,6 +126,105 @@ TEST_F(SpvUnaryConversionTest, Bitcast_Vector) { << ToString(fe.ast_body()); } +TEST_F(SpvUnaryConversionTest, ConvertSToF_BadArg) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpConvertSToF %float %void + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), + HasSubstr("unhandled expression for ID 2\n%2 = OpTypeVoid")); +} + +TEST_F(SpvUnaryConversionTest, ConvertUToF_BadArg) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpConvertUToF %float %void + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), + HasSubstr("unhandled expression for ID 2\n%2 = OpTypeVoid")); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToS_BadArg) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpConvertFToS %float %void + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), + HasSubstr("unhandled expression for ID 2\n%2 = OpTypeVoid")); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToU_BadArg) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpConvertFToU %float %void + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), + HasSubstr("unhandled expression for ID 2\n%2 = OpTypeVoid")); +} + +TEST_F(SpvUnaryConversionTest, ConvertSToF_Scalar_BadArgType) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpConvertSToF %float %false + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), + HasSubstr("operand for conversion to floating point must be " + "integral scalar or vector, but got: __bool")); +} + +TEST_F(SpvUnaryConversionTest, ConvertSToF_Vector_BadArgType) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpConvertSToF %v2float %v2bool_t_f + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT( + p->error(), + HasSubstr("operand for conversion to floating point must be integral " + "scalar or vector, but got: __vec_2__bool")); +} + TEST_F(SpvUnaryConversionTest, ConvertSToF_Scalar_FromSigned) { const auto assembly = CommonTypes() + R"( %100 = OpFunction %void None %voidfn @@ -227,6 +333,39 @@ TEST_F(SpvUnaryConversionTest, ConvertSToF_Vector_FromUnsigned) { << ToString(fe.ast_body()); } +TEST_F(SpvUnaryConversionTest, ConvertUToF_Scalar_BadArgType) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpConvertUToF %float %false + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), Eq("operand for conversion to floating point must be " + "integral scalar or vector, but got: __bool")); +} + +TEST_F(SpvUnaryConversionTest, ConvertUToF_Vector_BadArgType) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpConvertUToF %v2float %v2bool_t_f + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), + Eq("operand for conversion to floating point must be integral " + "scalar or vector, but got: __vec_2__bool")); +} + TEST_F(SpvUnaryConversionTest, ConvertUToF_Scalar_FromSigned) { const auto assembly = CommonTypes() + R"( %100 = OpFunction %void None %voidfn @@ -335,14 +474,296 @@ TEST_F(SpvUnaryConversionTest, ConvertUToF_Vector_FromUnsigned) { << ToString(fe.ast_body()); } -// TODO(dneto): OpConvertFToU -// TODO(dneto): OpConvertFToS +TEST_F(SpvUnaryConversionTest, ConvertFToS_Scalar_BadArgType) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpConvertFToS %int %uint_10 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), + Eq("operand for conversion to signed integer must be floating " + "point scalar or vector, but got: __u32")); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToS_Vector_BadArgType) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpConvertFToS %v2float %v2bool_t_f + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), + Eq("operand for conversion to signed integer must be floating " + "point scalar or vector, but got: __vec_2__bool")); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToS_Scalar_ToSigned) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %30 = OpCopyObject %float %float_50 + %1 = OpConvertFToS %int %30 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_1 + none + __i32 + { + Cast<__i32>( + Identifier{x_30} + ) + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToS_Scalar_ToUnsigned) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %30 = OpCopyObject %float %float_50 + %1 = OpConvertFToS %uint %30 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_1 + none + __u32 + { + As<__u32>{ + Cast<__i32>( + Identifier{x_30} + ) + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToS_Vector_ToSigned) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %30 = OpCopyObject %v2float %v2float_50_60 + %1 = OpConvertFToS %v2int %30 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_1 + none + __vec_2__i32 + { + Cast<__vec_2__i32>( + Identifier{x_30} + ) + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToS_Vector_ToUnsigned) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %30 = OpCopyObject %v2float %v2float_50_60 + %1 = OpConvertFToS %v2uint %30 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_1 + none + __vec_2__u32 + { + As<__vec_2__u32>{ + Cast<__vec_2__i32>( + Identifier{x_30} + ) + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToU_Scalar_BadArgType) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpConvertFToU %int %uint_10 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), + Eq("operand for conversion to unsigned integer must be floating " + "point scalar or vector, but got: __u32")); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToU_Vector_BadArgType) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpConvertFToU %v2float %v2bool_t_f + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), + Eq("operand for conversion to unsigned integer must be floating " + "point scalar or vector, but got: __vec_2__bool")); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToU_Scalar_ToSigned) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %30 = OpCopyObject %float %float_50 + %1 = OpConvertFToU %int %30 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_1 + none + __i32 + { + As<__i32>{ + Cast<__u32>( + Identifier{x_30} + ) + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToU_Scalar_ToUnsigned) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %30 = OpCopyObject %float %float_50 + %1 = OpConvertFToU %uint %30 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_1 + none + __u32 + { + Cast<__u32>( + Identifier{x_30} + ) + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToU_Vector_ToSigned) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %30 = OpCopyObject %v2float %v2float_50_60 + %1 = OpConvertFToU %v2int %30 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_1 + none + __vec_2__i32 + { + As<__vec_2__i32>{ + Cast<__vec_2__u32>( + Identifier{x_30} + ) + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvUnaryConversionTest, ConvertFToU_Vector_ToUnsigned) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %30 = OpCopyObject %v2float %v2float_50_60 + %1 = OpConvertFToU %v2uint %30 + OpReturn + OpFunctionEnd + )"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_1 + none + __vec_2__u32 + { + Cast<__vec_2__u32>( + Identifier{x_30} + ) + } + })")) + << ToString(fe.ast_body()); +} + // TODO(dneto): OpSConvert // only if multiple widths // TODO(dneto): OpUConvert // only if multiple widths // TODO(dneto): OpFConvert // only if multiple widths // TODO(dneto): OpQuantizeToF16 // only if f16 supported -// TODO(dneto): OpConvertSToU -// TODO(dneto): OpConvertUToS +// TODO(dneto): OpSatConvertSToU // Kernel (OpenCL), not in WebGPU +// TODO(dneto): OpSatConvertUToS // Kernel (OpenCL), not in WebGPU } // namespace } // namespace spirv diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 54dcff5dfb..f5b896fa54 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -1087,6 +1087,54 @@ ast::type::Type* ParserImpl::ForcedResultType( return nullptr; } +ast::type::Type* ParserImpl::GetSignedIntMatchingShape(ast::type::Type* other) { + if (other == nullptr) { + Fail() << "no type provided"; + } + auto* i32 = ctx_.type_mgr().Get(std::make_unique()); + if (other->IsF32() || other->IsU32() || other->IsI32()) { + return i32; + } + auto* vec_ty = other->AsVector(); + if (vec_ty) { + return ctx_.type_mgr().Get( + std::make_unique(i32, vec_ty->size())); + } + Fail() << "required numeric scalar or vector, but got " << other->type_name(); + return nullptr; +} + +ast::type::Type* ParserImpl::GetUnsignedIntMatchingShape( + ast::type::Type* other) { + if (other == nullptr) { + Fail() << "no type provided"; + return nullptr; + } + auto* u32 = ctx_.type_mgr().Get(std::make_unique()); + if (other->IsF32() || other->IsU32() || other->IsI32()) { + return u32; + } + auto* vec_ty = other->AsVector(); + if (vec_ty) { + return ctx_.type_mgr().Get( + std::make_unique(u32, vec_ty->size())); + } + Fail() << "required numeric scalar or vector, but got " << other->type_name(); + return nullptr; +} + +TypedExpression ParserImpl::RectifyForcedResultType( + TypedExpression expr, + SpvOp op, + ast::type::Type* first_operand_type) { + auto* forced_result_ty = ForcedResultType(op, first_operand_type); + if ((forced_result_ty == nullptr) || (forced_result_ty == expr.type)) { + return expr; + } + return {expr.type, + std::make_unique(expr.type, std::move(expr.expr))}; +} + bool ParserImpl::EmitFunctions() { if (!success_) { return false; diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index b269455f3b..8c66bcb4ab 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -280,6 +280,34 @@ class ParserImpl : Reader { ast::type::Type* ForcedResultType(SpvOp op, ast::type::Type* first_operand_type); + /// Returns a signed integer scalar or vector type matching the shape (scalar, + /// vector, and component bit width) of another type, which itself is a + /// numeric scalar or vector. Returns null if the other type does not meet the + /// requirement. + /// @param other the type whose shape must be matched + /// @returns the signed scalar or vector type + ast::type::Type* GetSignedIntMatchingShape(ast::type::Type* other); + + /// Returns a signed integer scalar or vector type matching the shape (scalar, + /// vector, and component bit width) of another type, which itself is a + /// numeric scalar or vector. Returns null if the other type does not meet the + /// requirement. + /// @param other the type whose shape must be matched + /// @returns the unsigned scalar or vector type + ast::type::Type* GetUnsignedIntMatchingShape(ast::type::Type* other); + + /// Wraps the given expression in an as-cast to the given expression's type, + /// when the underlying operation produces a forced result type different + /// from the expression's result type. Otherwise, returns the given expression + /// unchanged. + /// @param expr the expression to pass through or to wrap + /// @param op the SPIR-V opcode + /// @param first_operand_type the AST type for the first operand. + /// @returns the forced AST result type, or nullptr if no forcing is required. + TypedExpression RectifyForcedResultType(TypedExpression expr, + SpvOp op, + ast::type::Type* first_operand_type); + /// @returns the registered boolean type. ast::type::Type* BoolType() const { return bool_type_; }