From 4f10a256d5328cafe45962ddf9759d540ec74a1d Mon Sep 17 00:00:00 2001 From: David Neto Date: Mon, 20 Apr 2020 21:06:43 +0000 Subject: [PATCH] [spirv-reader] Fix OpSDiv operand and result signedness (I expect that) the WGSL signed division operator expects both operands to be signed and the result will also be signed. When the operands of a SPIR-V OpSDiv is unsigned, then wrap the operand in an as-cast to the corresponding signed type. When the result type of a SPIR-V OpSDiv instruction is unsigned, we have to wrap the generated WGSL operator with an as-cast to that unsigned type. This first CL addresses OpSDiv. We'll address other operations in future CLs. Bug: tint:3 Change-Id: If3849ceb44b21db87c1efd2c6a2cd63c6d648c88 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19800 Reviewed-by: dan sinclair --- src/reader/spirv/function.cc | 25 +++- src/reader/spirv/function_arithmetic_test.cc | 114 +++++++++++++++++ src/reader/spirv/parser_impl.cc | 125 ++++++++++++++++++- src/reader/spirv/parser_impl.h | 28 +++++ 4 files changed, 280 insertions(+), 12 deletions(-) diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 0a17fd5b82..3d42885e7e 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -20,6 +20,7 @@ #include "source/opt/function.h" #include "source/opt/instruction.h" #include "source/opt/module.h" +#include "src/ast/as_expression.h" #include "src/ast/assignment_statement.h" #include "src/ast/binary_expression.h" #include "src/ast/identifier_expression.h" @@ -83,6 +84,7 @@ ast::BinaryOp ConvertBinaryOp(SpvOp opcode) { } return ast::BinaryOp::kNone; } + } // namespace FunctionEmitter::FunctionEmitter(ParserImpl* pi, @@ -358,19 +360,30 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( // TODO(dneto): Fill in the following cases. auto operand = [this, &inst](uint32_t operand_index) { - return this->MakeExpression(inst.GetSingleWordInOperand(operand_index)); + auto expr = + this->MakeExpression(inst.GetSingleWordInOperand(operand_index)); + return parser_impl_.RectifyOperandSignedness(inst.opcode(), + std::move(expr)); }; - auto* ast_type = + ast::type::Type* ast_type = inst.type_id() != 0 ? parser_impl_.ConvertType(inst.type_id()) : nullptr; auto binary_op = ConvertBinaryOp(inst.opcode()); if (binary_op != ast::BinaryOp::kNone) { - return {ast_type, std::make_unique( - binary_op, std::move(operand(0).expr), - std::move(operand(1).expr))}; + auto arg0 = operand(0); + auto arg1 = operand(1); + auto binary_expr = std::make_unique( + binary_op, std::move(arg0.expr), std::move(arg1.expr)); + auto* forced_result_ty = + parser_impl_.ForcedResultType(inst.opcode(), arg0.type); + if (forced_result_ty && forced_result_ty != ast_type) { + return {ast_type, std::make_unique( + ast_type, std::move(binary_expr))}; + } + return {ast_type, std::move(binary_expr)}; } - // binary operator + // unary operator // builtin readonly function // glsl.std.450 readonly function diff --git a/src/reader/spirv/function_arithmetic_test.cc b/src/reader/spirv/function_arithmetic_test.cc index a0ccf6e2b9..31ac7f0250 100644 --- a/src/reader/spirv/function_arithmetic_test.cc +++ b/src/reader/spirv/function_arithmetic_test.cc @@ -91,6 +91,15 @@ std::string AstFor(std::string assembly) { ScalarConstructor{30} })"; } + if (assembly == "cast_int_v2uint_10_20") { + return R"(As<__vec_2__i32>{ + TypeConstructor{ + __vec_2__u32 + ScalarConstructor{10} + ScalarConstructor{20} + } + })"; + } if (assembly == "v2float_50_60") { return R"(TypeConstructor{ __vec_2__f32 @@ -126,6 +135,7 @@ inline std::ostream& operator<<(std::ostream& out, BinaryData data) { } using SpvBinaryTest = SpvParserTestBase<::testing::TestWithParam>; +using SpvBinaryTestBasic = SpvParserTestBase<::testing::Test>; TEST_P(SpvBinaryTest, EmitExpression) { const auto assembly = CommonTypes() + R"( @@ -324,6 +334,110 @@ INSTANTIATE_TEST_SUITE_P( "__vec_2__i32", AstFor("v2int_30_40"), "divide", AstFor("v2int_40_30")})); +INSTANTIATE_TEST_SUITE_P( + SpvParserTest_SDiv_MixedSignednessOperands, + SpvBinaryTest, + ::testing::Values( + // Mixed, returning int, second arg uint + BinaryData{"int", "int_30", "OpSDiv", "uint_10", "__i32", + "ScalarConstructor{30}", "divide", + R"(As<__i32>{ + ScalarConstructor{10} + })"}, + // Mixed, returning int, first arg uint + BinaryData{"int", "uint_10", "OpSDiv", "int_30", "__i32", + R"(As<__i32>{ + ScalarConstructor{10} + })", + "divide", "ScalarConstructor{30}"}, + // Mixed, returning v2int, first arg v2uint + BinaryData{"v2int", "v2uint_10_20", "OpSDiv", "v2int_30_40", + "__vec_2__i32", AstFor("cast_int_v2uint_10_20"), "divide", + AstFor("v2int_30_40")}, + // Mixed, returning v2int, second arg v2uint + BinaryData{"v2int", "v2int_30_40", "OpSDiv", "v2uint_10_20", + "__vec_2__i32", AstFor("v2int_30_40"), "divide", + AstFor("cast_int_v2uint_10_20")})); + +TEST_F(SpvBinaryTestBasic, SDiv_Scalar_UnsignedResult) { + // The WGSL signed division operator expects both operands to be signed + // and the result is signed as well. + // In this test SPIR-V demands an unsigned result, so we have to + // wrap the result with an as-cast. + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpSDiv %uint %int_30 %int_40 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) + << p->error() << "\n" + << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( + Variable{ + x_1 + none + __u32 + { + As<__u32>{ + Binary{ + ScalarConstructor{30} + divide + ScalarConstructor{40} + } + } + } + })")); +} + +TEST_F(SpvBinaryTestBasic, SDiv_Vector_UnsignedResult) { + // The WGSL signed division operator expects both operands to be signed + // and the result is signed as well. + // In this test SPIR-V demands an unsigned result, so we have to + // wrap the result with an as-cast. + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpSDiv %v2uint %v2int_30_40 %v2int_40_30 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) + << p->error() << "\n" + << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"( + Variable{ + x_1 + none + __vec_2__u32 + { + As<__vec_2__u32>{ + Binary{ + TypeConstructor{ + __vec_2__i32 + ScalarConstructor{30} + ScalarConstructor{40} + } + divide + TypeConstructor{ + __vec_2__i32 + ScalarConstructor{40} + ScalarConstructor{30} + } + } + } + } + })")) + << ToString(fe.ast_body()); +} + INSTANTIATE_TEST_SUITE_P( SpvParserTest_FDiv, SpvBinaryTest, diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 795dc7d73b..3139cb9fc8 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -33,6 +33,8 @@ #include "source/opt/type_manager.h" #include "source/opt/types.h" #include "spirv-tools/libspirv.hpp" +#include "src/ast/as_expression.h" +#include "src/ast/binary_expression.h" #include "src/ast/bool_literal.h" #include "src/ast/builtin_decoration.h" #include "src/ast/decorated_variable.h" @@ -119,6 +121,52 @@ class FunctionTraverser { std::vector ordered_; }; +// Returns true if the opcode operates as if its operands are signed integral. +bool AssumesSignedOperands(SpvOp opcode) { + switch (opcode) { + case SpvOpSDiv: + case SpvOpSRem: + case SpvOpSMod: + case SpvOpSLessThan: + case SpvOpSLessThanEqual: + case SpvOpSGreaterThan: + case SpvOpSGreaterThanEqual: + return true; + default: + break; + } + return false; +} + +// Returns true if the opcode operates as if its operands are unsigned integral. +bool AssumesUnsignedOperands(SpvOp opcode) { + switch (opcode) { + case SpvOpUDiv: + case SpvOpUMod: + case SpvOpULessThan: + case SpvOpULessThanEqual: + case SpvOpUGreaterThan: + case SpvOpUGreaterThanEqual: + return true; + default: + break; + } + return false; +} + +// Returns true if the operation is binary, and the WGSL operation requires +// the signedness of the result to match the signedness of the first operand. +bool AssumesResultSignednessMatchesBinaryFirstOperand(SpvOp opcode) { + switch (opcode) { + // TODO(dneto): More arithmetic operations. + case SpvOpSDiv: + return true; + default: + break; + } + return false; +} + } // namespace ParserImpl::ParserImpl(Context* ctx, const std::vector& spv_binary) @@ -458,11 +506,13 @@ bool ParserImpl::EmitEntryPoints() { ast::type::Type* ParserImpl::ConvertType( const spvtools::opt::analysis::Integer* int_ty) { if (int_ty->width() == 32) { - if (int_ty->IsSigned()) { - return ctx_.type_mgr().Get(std::make_unique()); - } else { - return ctx_.type_mgr().Get(std::make_unique()); - } + auto signed_ty = + ctx_.type_mgr().Get(std::make_unique()); + auto unsigned_ty = + ctx_.type_mgr().Get(std::make_unique()); + signed_type_for_[unsigned_ty] = signed_ty; + unsigned_type_for_[signed_ty] = unsigned_ty; + return int_ty->IsSigned() ? signed_ty : unsigned_ty; } Fail() << "unhandled integer width: " << int_ty->width(); return nullptr; @@ -484,8 +534,23 @@ ast::type::Type* ParserImpl::ConvertType( if (ast_elem_ty == nullptr) { return nullptr; } - return ctx_.type_mgr().Get( + auto* this_ty = ctx_.type_mgr().Get( std::make_unique(ast_elem_ty, num_elem)); + // Generate the opposite-signedness vector type, if this type is integral. + if (unsigned_type_for_.count(ast_elem_ty)) { + auto* other_ty = + ctx_.type_mgr().Get(std::make_unique( + unsigned_type_for_[ast_elem_ty], num_elem)); + signed_type_for_[other_ty] = this_ty; + unsigned_type_for_[this_ty] = other_ty; + } else if (signed_type_for_.count(ast_elem_ty)) { + auto* other_ty = + ctx_.type_mgr().Get(std::make_unique( + signed_type_for_[ast_elem_ty], num_elem)); + unsigned_type_for_[other_ty] = this_ty; + signed_type_for_[this_ty] = other_ty; + } + return this_ty; } ast::type::Type* ParserImpl::ConvertType( @@ -782,6 +847,7 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) { Fail() << "ID " << id << " is not a constant"; return {}; } + // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0. // So canonicalization should map that way too. // Currently "null" is missing from the WGSL parser. @@ -839,6 +905,53 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) { return {}; } +TypedExpression ParserImpl::RectifyOperandSignedness(SpvOp op, + TypedExpression&& expr) { + const bool requires_signed = AssumesSignedOperands(op); + const bool requires_unsigned = AssumesUnsignedOperands(op); + if (!requires_signed && !requires_unsigned) { + // No conversion is required, assuming our tables are complete. + return std::move(expr); + } + if (!expr.expr) { + Fail() << "internal error: RectifyOperandSignedness given a null expr\n"; + return {}; + } + auto* type = expr.type; + if (!type) { + Fail() << "internal error: unmapped type for: " << expr.expr->str() << "\n"; + return {}; + } + if (requires_unsigned) { + auto* unsigned_ty = unsigned_type_for_[type]; + if (unsigned_ty != nullptr) { + // Conversion is required. + return {unsigned_ty, std::make_unique( + unsigned_ty, std::move(expr.expr))}; + } + } else if (requires_signed) { + auto* signed_ty = signed_type_for_[type]; + if (signed_ty != nullptr) { + // Conversion is required. + return {signed_ty, std::make_unique( + signed_ty, std::move(expr.expr))}; + } + } + // We should not reach here. + return std::move(expr); +} + +ast::type::Type* ParserImpl::ForcedResultType( + SpvOp op, + ast::type::Type* first_operand_type) { + const bool binary_match_first_operand = + AssumesResultSignednessMatchesBinaryFirstOperand(op); + if (binary_match_first_operand) { + return first_operand_type; + } + return nullptr; +} + bool ParserImpl::EmitFunctions() { if (!success_) { return false; diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index 91c22e01eb..2412398a35 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -30,6 +30,7 @@ #include "source/opt/type_manager.h" #include "source/opt/types.h" #include "spirv-tools/libspirv.hpp" +#include "src/ast/expression.h" #include "src/ast/import.h" #include "src/ast/module.h" #include "src/ast/struct_member_decoration.h" @@ -245,6 +246,28 @@ class ParserImpl : Reader { /// @returns a new Literal node TypedExpression MakeConstantExpression(uint32_t id); + /// Converts a given expression to the signedness demanded for an operand + /// of the given SPIR-V opcode, if required. If the operation assumes + /// signed integer operands, and |expr| is unsigned, then return an + /// as-cast expression converting it to signed. Otherwise, return + /// |expr| itself. Similarly, convert as required from unsigned + /// to signed. Assumes all SPIR-V types have been mapped to AST types. + /// @param op the SPIR-V opcode + /// @param expr an expression + /// @returns expr, or a cast of expr + TypedExpression RectifyOperandSignedness(SpvOp op, TypedExpression&& expr); + + /// Returns the "forced" result type for the given SPIR-V opcode. + /// If the WGSL result type for an operation has a more strict rule than + /// requried by SPIR-V, then we say the result type is "forced". This occurs + /// for signed integer division (OpSDiv), for example, where the result type + /// in WGSL must match the operand types. + /// @param op the SPIR-V opcode + /// @param first_operand_type the AST type for the first operand. + /// @returns the forced AST result type, or nullptr if no forcing is required. + ast::type::Type* ForcedResultType(SpvOp op, + ast::type::Type* first_operand_type); + private: /// Converts a specific SPIR-V type to a Tint type. Integer case ast::type::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty); @@ -303,6 +326,11 @@ class ParserImpl : Reader { // Maps a SPIR-V type ID to a Tint type. std::unordered_map id_to_type_; + + // Maps an unsigned type corresponding to the given signed type. + std::unordered_map signed_type_for_; + // Maps an signed type corresponding to the given unsigned type. + std::unordered_map unsigned_type_for_; }; } // namespace spirv