From ccc67252ffa40dd723f0c114ac8906b3ad13b063 Mon Sep 17 00:00:00 2001 From: David Neto Date: Thu, 10 Dec 2020 18:51:51 +0000 Subject: [PATCH] spirv-reader: support OpBitCount, OpBitReverse Bug: tint:3 Change-Id: I81580136621ab51a9852e1d692ddad2457b9aab9 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35340 Auto-Submit: David Neto Reviewed-by: dan sinclair Commit-Queue: David Neto --- src/reader/spirv/function.cc | 14 +- src/reader/spirv/function_bit_test.cc | 492 +++++++++++++++++++++++++- src/reader/spirv/parser_impl.cc | 17 +- 3 files changed, 509 insertions(+), 14 deletions(-) diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 484fbba1a8..50bb0ab179 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -465,6 +465,10 @@ std::string GetGlslStd450FuncName(uint32_t ext_opcode) { // given instruction, or ast::Intrinsic::kNone ast::Intrinsic GetIntrinsic(SpvOp opcode) { switch (opcode) { + case SpvOpBitCount: + return ast::Intrinsic::kCountOneBits; + case SpvOpBitReverse: + return ast::Intrinsic::kReverseBits; case SpvOpDot: return ast::Intrinsic::kDot; case SpvOpOuterProduct: @@ -3726,8 +3730,13 @@ TypedExpression FunctionEmitter::MakeIntrinsicCall( ident->set_intrinsic(intrinsic); ast::ExpressionList params; + ast::type::Type* first_operand_type = nullptr; for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) { - params.emplace_back(MakeOperand(inst, iarg).expr); + TypedExpression operand = MakeOperand(inst, iarg); + if (first_operand_type == nullptr) { + first_operand_type = operand.type; + } + params.emplace_back(operand.expr); } auto* call_expr = create(ident, std::move(params)); auto* result_type = parser_impl_.ConvertType(inst.type_id()); @@ -3736,7 +3745,8 @@ TypedExpression FunctionEmitter::MakeIntrinsicCall( << inst.PrettyPrint(); return {}; } - return {result_type, call_expr}; + TypedExpression call{result_type, call_expr}; + return parser_impl_.RectifyForcedResultType(call, inst, first_operand_type); } TypedExpression FunctionEmitter::MakeSimpleSelect( diff --git a/src/reader/spirv/function_bit_test.cc b/src/reader/spirv/function_bit_test.cc index 3a03837368..3ed4f4e067 100644 --- a/src/reader/spirv/function_bit_test.cc +++ b/src/reader/spirv/function_bit_test.cc @@ -627,11 +627,499 @@ TEST_F(SpvUnaryBitTest, Not_UnsignedVec_UnsignedVec) { << ToString(fe.ast_body()); } +std::string BitTestPreamble() { + return R"( + OpCapability Shader + %glsl = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %100 "main" + OpExecutionMode %100 LocalSize 1 1 1 + + OpName %u1 "u1" + OpName %i1 "i1" + OpName %v2u1 "v2u1" + OpName %v2i1 "v2i1" + +)" + CommonTypes() + + R"( + + %100 = OpFunction %void None %voidfn + %entry = OpLabel + + %u1 = OpCopyObject %uint %uint_10 + %i1 = OpCopyObject %int %int_30 + %v2u1 = OpCopyObject %v2uint %v2uint_10_20 + %v2i1 = OpCopyObject %v2int %v2int_30_40 +)"; +} + +TEST_F(SpvUnaryBitTest, BitCount_Uint_Uint) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitCount %uint %u1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __u32 + { + Call[not set]{ + Identifier[not set]{countOneBits} + ( + Identifier[not set]{u1} + ) + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitCount_Uint_Int) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitCount %uint %i1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __u32 + { + Bitcast[not set]<__u32>{ + Call[not set]{ + Identifier[not set]{countOneBits} + ( + Identifier[not set]{i1} + ) + } + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitCount_Int_Uint) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitCount %int %u1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __i32 + { + Bitcast[not set]<__i32>{ + Call[not set]{ + Identifier[not set]{countOneBits} + ( + Identifier[not set]{u1} + ) + } + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitCount_Int_Int) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitCount %int %i1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __i32 + { + Call[not set]{ + Identifier[not set]{countOneBits} + ( + Identifier[not set]{i1} + ) + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitCount_UintVector_UintVector) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitCount %v2uint %v2u1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __vec_2__u32 + { + Call[not set]{ + Identifier[not set]{countOneBits} + ( + Identifier[not set]{v2u1} + ) + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitCount_UintVector_IntVector) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitCount %v2uint %v2i1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __vec_2__u32 + { + Bitcast[not set]<__vec_2__u32>{ + Call[not set]{ + Identifier[not set]{countOneBits} + ( + Identifier[not set]{v2i1} + ) + } + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitCount_IntVector_UintVector) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitCount %v2int %v2u1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __vec_2__i32 + { + Bitcast[not set]<__vec_2__i32>{ + Call[not set]{ + Identifier[not set]{countOneBits} + ( + Identifier[not set]{v2u1} + ) + } + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitCount_IntVector_IntVector) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitCount %v2int %v2i1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __vec_2__i32 + { + Call[not set]{ + Identifier[not set]{countOneBits} + ( + Identifier[not set]{v2i1} + ) + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitReverse_Uint_Uint) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitReverse %uint %u1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __u32 + { + Call[not set]{ + Identifier[not set]{reverseBits} + ( + Identifier[not set]{u1} + ) + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitReverse_Uint_Int) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitReverse %uint %i1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __u32 + { + Bitcast[not set]<__u32>{ + Call[not set]{ + Identifier[not set]{reverseBits} + ( + Identifier[not set]{i1} + ) + } + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitReverse_Int_Uint) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitReverse %int %u1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __i32 + { + Bitcast[not set]<__i32>{ + Call[not set]{ + Identifier[not set]{reverseBits} + ( + Identifier[not set]{u1} + ) + } + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitReverse_Int_Int) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitReverse %int %i1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __i32 + { + Call[not set]{ + Identifier[not set]{reverseBits} + ( + Identifier[not set]{i1} + ) + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitReverse_UintVector_UintVector) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitReverse %v2uint %v2u1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __vec_2__u32 + { + Call[not set]{ + Identifier[not set]{reverseBits} + ( + Identifier[not set]{v2u1} + ) + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitReverse_UintVector_IntVector) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitReverse %v2uint %v2i1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __vec_2__u32 + { + Bitcast[not set]<__vec_2__u32>{ + Call[not set]{ + Identifier[not set]{reverseBits} + ( + Identifier[not set]{v2i1} + ) + } + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitReverse_IntVector_UintVector) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitReverse %v2int %v2u1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __vec_2__i32 + { + Bitcast[not set]<__vec_2__i32>{ + Call[not set]{ + Identifier[not set]{reverseBits} + ( + Identifier[not set]{v2u1} + ) + } + } + } + })")) + << body; +} + +TEST_F(SpvUnaryBitTest, BitReverse_IntVector_IntVector) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitReverse %v2int %v2i1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto body = ToString(fe.ast_body()); + EXPECT_THAT(body, HasSubstr(R"( + VariableConst{ + x_1 + none + __vec_2__i32 + { + Call[not set]{ + Identifier[not set]{reverseBits} + ( + Identifier[not set]{v2i1} + ) + } + } + })")) + << body; +} + // TODO(dneto): OpBitFieldInsert // TODO(dneto): OpBitFieldSExtract // TODO(dneto): OpBitFieldUExtract -// TODO(dneto): OpBitReverse -// TODO(dneto): OpBitCount } // namespace } // namespace spirv diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 71e93196c9..6c2bd66b90 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -209,10 +209,14 @@ bool AssumesUnsignedOperands(GLSLstd450 extended_opcode) { return false; } -// Returns true if the operation is binary, and the WGSL operation requires +// Returns true if the corresponding WGSL operation requires // the signedness of the result to match the signedness of the first operand. -bool AssumesResultSignednessMatchesBinaryFirstOperand(SpvOp opcode) { +bool AssumesResultSignednessMatchesFirstOperand(SpvOp opcode) { switch (opcode) { + case SpvOpNot: + case SpvOpSNegate: + case SpvOpBitCount: + case SpvOpBitReverse: case SpvOpSDiv: case SpvOpSMod: case SpvOpSRem: @@ -1501,14 +1505,7 @@ ast::type::Type* ParserImpl::ForcedResultType( const spvtools::opt::Instruction& inst, ast::type::Type* first_operand_type) { const auto opcode = inst.opcode(); - if ((opcode == SpvOpSNegate) || (opcode == SpvOpNot)) { - // The unary operation cases that force the result type to match the - // first operand type. - return first_operand_type; - } - if (AssumesResultSignednessMatchesBinaryFirstOperand(opcode)) { - // The binary operation cases that force the result type to match - // the first operand type. + if (AssumesResultSignednessMatchesFirstOperand(opcode)) { return first_operand_type; } if (IsGlslExtendedInstruction(inst)) {