diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc index 7ec46e73a3..ff0d721b23 100644 --- a/src/tint/reader/spirv/function.cc +++ b/src/tint/reader/spirv/function.cc @@ -3819,7 +3819,14 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( const auto builtin = GetBuiltin(op); if (builtin != builtin::Function::kNone) { - return MakeBuiltinCall(inst); + switch (builtin) { + case builtin::Function::kExtractBits: + return MakeExtractBitsCall(inst); + case builtin::Function::kInsertBits: + return MakeInsertBitsCall(inst); + default: + return MakeBuiltinCall(inst); + } } if (op == spv::Op::OpFMod) { @@ -5274,6 +5281,42 @@ TypedExpression FunctionEmitter::MakeBuiltinCall(const spvtools::opt::Instructio return parser_impl_.RectifyForcedResultType(call, inst, first_operand_type); } +TypedExpression FunctionEmitter::MakeExtractBitsCall(const spvtools::opt::Instruction& inst) { + const auto builtin = GetBuiltin(opcode(inst)); + auto* name = builtin::str(builtin); + auto* ident = create(Source{}, builder_.Symbols().Register(name)); + auto e = MakeOperand(inst, 0); + auto offset = ToU32(MakeOperand(inst, 1)); + auto count = ToU32(MakeOperand(inst, 2)); + auto* call_expr = builder_.Call(ident, ExpressionList{e.expr, offset.expr, count.expr}); + auto* result_type = parser_impl_.ConvertType(inst.type_id()); + if (!result_type) { + Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint(); + return {}; + } + TypedExpression call{result_type, call_expr}; + return parser_impl_.RectifyForcedResultType(call, inst, e.type); +} + +TypedExpression FunctionEmitter::MakeInsertBitsCall(const spvtools::opt::Instruction& inst) { + const auto builtin = GetBuiltin(opcode(inst)); + auto* name = builtin::str(builtin); + auto* ident = create(Source{}, builder_.Symbols().Register(name)); + auto e = MakeOperand(inst, 0); + auto newbits = MakeOperand(inst, 1); + auto offset = ToU32(MakeOperand(inst, 2)); + auto count = ToU32(MakeOperand(inst, 3)); + auto* call_expr = + builder_.Call(ident, ExpressionList{e.expr, newbits.expr, offset.expr, count.expr}); + auto* result_type = parser_impl_.ConvertType(inst.type_id()); + if (!result_type) { + Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint(); + return {}; + } + TypedExpression call{result_type, call_expr}; + return parser_impl_.RectifyForcedResultType(call, inst, e.type); +} + TypedExpression FunctionEmitter::MakeSimpleSelect(const spvtools::opt::Instruction& inst) { auto condition = MakeOperand(inst, 0); auto true_value = MakeOperand(inst, 1); @@ -6053,6 +6096,13 @@ TypedExpression FunctionEmitter::ToI32(TypedExpression value) { return {ty_.I32(), builder_.Call(builder_.ty.i32(), utils::Vector{value.expr})}; } +TypedExpression FunctionEmitter::ToU32(TypedExpression value) { + if (!value || value.type->Is()) { + return value; + } + return {ty_.U32(), builder_.Call(builder_.ty.u32(), utils::Vector{value.expr})}; +} + TypedExpression FunctionEmitter::ToSignedIfUnsigned(TypedExpression value) { if (!value || !value.type->IsUnsignedScalarOrVector()) { return value; diff --git a/src/tint/reader/spirv/function.h b/src/tint/reader/spirv/function.h index 718e8e3ea6..11fc92ce7c 100644 --- a/src/tint/reader/spirv/function.h +++ b/src/tint/reader/spirv/function.h @@ -945,6 +945,12 @@ class FunctionEmitter { /// @returns the value as an i32 value. TypedExpression ToI32(TypedExpression value); + /// Returns the given value as an u32. If it's already an u32 then simply returns @p value. + /// Otherwise, wrap the value in a TypeInitializer expression. + /// @param value the value to pass through or convert + /// @returns the value as an u32 value. + TypedExpression ToU32(TypedExpression value); + /// Returns the given value as a signed integer type of the same shape if the value is unsigned /// scalar or vector, by wrapping the value with a TypeInitializer expression. Returns the /// value itself if the value was already signed. @@ -1035,6 +1041,18 @@ class FunctionEmitter { /// @returns an expression TypedExpression MakeBuiltinCall(const spvtools::opt::Instruction& inst); + /// Returns an expression for a SPIR-V instruction that maps to the extractBits WGSL + /// builtin function call, with special handling to cast offset and count to u32, if needed. + /// @param inst the SPIR-V instruction + /// @returns an expression + TypedExpression MakeExtractBitsCall(const spvtools::opt::Instruction& inst); + + /// Returns an expression for a SPIR-V instruction that maps to the insertBits WGSL + /// builtin function call, with special handling to cast offset and count to u32, if needed. + /// @param inst the SPIR-V instruction + /// @returns an expression + TypedExpression MakeInsertBitsCall(const spvtools::opt::Instruction& inst); + /// Returns an expression for a SPIR-V OpArrayLength instruction. /// @param inst the SPIR-V instruction /// @returns an expression diff --git a/src/tint/reader/spirv/function_bit_test.cc b/src/tint/reader/spirv/function_bit_test.cc index 40f916245e..2a12f010a6 100644 --- a/src/tint/reader/spirv/function_bit_test.cc +++ b/src/tint/reader/spirv/function_bit_test.cc @@ -33,6 +33,8 @@ std::string CommonTypes() { %uint_10 = OpConstant %uint 10 %uint_20 = OpConstant %uint 20 + %int_10 = OpConstant %int 10 + %int_20 = OpConstant %int 20 %int_30 = OpConstant %int 30 %int_40 = OpConstant %int 40 %float_50 = OpConstant %float 50 @@ -832,7 +834,7 @@ TEST_F(SpvUnaryBitTest, BitReverse_IntVector_IntVector) { TEST_F(SpvUnaryBitTest, InsertBits_Int) { const auto assembly = BitTestPreamble() + R"( - %1 = OpBitFieldInsert %v2int %int_30 %int_40 %uint_10 %uint_20 + %1 = OpBitFieldInsert %int %int_30 %int_40 %uint_10 %uint_20 OpReturn OpFunctionEnd )"; @@ -842,7 +844,23 @@ TEST_F(SpvUnaryBitTest, InsertBits_Int) { EXPECT_TRUE(fe.EmitBody()) << p->error(); auto ast_body = fe.ast_body(); auto body = test::ToString(p->program(), ast_body); - EXPECT_THAT(body, HasSubstr("let x_1 : vec2 = insertBits(30i, 40i, 10u, 20u);")) << body; + EXPECT_THAT(body, HasSubstr("let x_1 : i32 = insertBits(30i, 40i, 10u, 20u);")) << body; +} + +TEST_F(SpvUnaryBitTest, InsertBits_Int_SignedOffsetAndCount) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitFieldInsert %int %int_30 %int_40 %int_10 %int_20 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + auto ast_body = fe.ast_body(); + auto body = test::ToString(p->program(), ast_body); + EXPECT_THAT(body, HasSubstr("let x_1 : i32 = insertBits(30i, 40i, u32(10i), u32(20i));")) + << body; } TEST_F(SpvUnaryBitTest, InsertBits_IntVector) { @@ -864,9 +882,9 @@ TEST_F(SpvUnaryBitTest, InsertBits_IntVector) { << body; } -TEST_F(SpvUnaryBitTest, InsertBits_Uint) { +TEST_F(SpvUnaryBitTest, InsertBits_IntVector_SignedOffsetAndCount) { const auto assembly = BitTestPreamble() + R"( - %1 = OpBitFieldInsert %v2uint %uint_20 %uint_10 %uint_10 %uint_20 + %1 = OpBitFieldInsert %v2int %v2int_30_40 %v2int_40_30 %int_10 %int_20 OpReturn OpFunctionEnd )"; @@ -876,7 +894,42 @@ TEST_F(SpvUnaryBitTest, InsertBits_Uint) { EXPECT_TRUE(fe.EmitBody()) << p->error(); auto ast_body = fe.ast_body(); auto body = test::ToString(p->program(), ast_body); - EXPECT_THAT(body, HasSubstr("let x_1 : vec2 = insertBits(20u, 10u, 10u, 20u);")) << body; + EXPECT_THAT( + body, + HasSubstr( + R"(let x_1 : vec2 = insertBits(vec2(30i, 40i), vec2(40i, 30i), u32(10i), u32(20i));)")) + << body; +} + +TEST_F(SpvUnaryBitTest, InsertBits_Uint) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitFieldInsert %uint %uint_20 %uint_10 %uint_10 %uint_20 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + auto ast_body = fe.ast_body(); + auto body = test::ToString(p->program(), ast_body); + EXPECT_THAT(body, HasSubstr("let x_1 : u32 = insertBits(20u, 10u, 10u, 20u);")) << body; +} + +TEST_F(SpvUnaryBitTest, InsertBits_Uint_SignedOffsetAndCount) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitFieldInsert %uint %uint_20 %uint_10 %int_10 %int_20 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + auto ast_body = fe.ast_body(); + auto body = test::ToString(p->program(), ast_body); + EXPECT_THAT(body, HasSubstr("let x_1 : u32 = insertBits(20u, 10u, u32(10i), u32(20i));")) + << body; } TEST_F(SpvUnaryBitTest, InsertBits_UintVector) { @@ -898,9 +951,9 @@ TEST_F(SpvUnaryBitTest, InsertBits_UintVector) { << body; } -TEST_F(SpvUnaryBitTest, ExtractBits_Int) { +TEST_F(SpvUnaryBitTest, InsertBits_UintVector_SignedOffsetAndCount) { const auto assembly = BitTestPreamble() + R"( - %1 = OpBitFieldSExtract %v2int %int_30 %uint_10 %uint_20 + %1 = OpBitFieldInsert %v2uint %v2uint_10_20 %v2uint_20_10 %int_10 %int_20 OpReturn OpFunctionEnd )"; @@ -910,7 +963,41 @@ TEST_F(SpvUnaryBitTest, ExtractBits_Int) { EXPECT_TRUE(fe.EmitBody()) << p->error(); auto ast_body = fe.ast_body(); auto body = test::ToString(p->program(), ast_body); - EXPECT_THAT(body, HasSubstr("let x_1 : vec2 = extractBits(30i, 10u, 20u);")) << body; + EXPECT_THAT( + body, + HasSubstr( + R"(let x_1 : vec2 = insertBits(vec2(10u, 20u), vec2(20u, 10u), u32(10i), u32(20i));)")) + << body; +} + +TEST_F(SpvUnaryBitTest, ExtractBits_Int) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitFieldSExtract %int %int_30 %uint_10 %uint_20 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + auto ast_body = fe.ast_body(); + auto body = test::ToString(p->program(), ast_body); + EXPECT_THAT(body, HasSubstr("let x_1 : i32 = extractBits(30i, 10u, 20u);")) << body; +} + +TEST_F(SpvUnaryBitTest, ExtractBits_Int_SignedOffsetAndCount) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitFieldSExtract %int %int_30 %int_10 %int_20 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + auto ast_body = fe.ast_body(); + auto body = test::ToString(p->program(), ast_body); + EXPECT_THAT(body, HasSubstr("let x_1 : i32 = extractBits(30i, u32(10i), u32(20i));")) << body; } TEST_F(SpvUnaryBitTest, ExtractBits_IntVector) { @@ -930,9 +1017,9 @@ TEST_F(SpvUnaryBitTest, ExtractBits_IntVector) { << body; } -TEST_F(SpvUnaryBitTest, ExtractBits_Uint) { +TEST_F(SpvUnaryBitTest, ExtractBits_IntVector_SignedOffsetAndCount) { const auto assembly = BitTestPreamble() + R"( - %1 = OpBitFieldUExtract %v2uint %uint_20 %uint_10 %uint_20 + %1 = OpBitFieldSExtract %v2int %v2int_30_40 %int_10 %int_20 OpReturn OpFunctionEnd )"; @@ -942,7 +1029,40 @@ TEST_F(SpvUnaryBitTest, ExtractBits_Uint) { EXPECT_TRUE(fe.EmitBody()) << p->error(); auto ast_body = fe.ast_body(); auto body = test::ToString(p->program(), ast_body); - EXPECT_THAT(body, HasSubstr("let x_1 : vec2 = extractBits(20u, 10u, 20u);")) << body; + EXPECT_THAT( + body, + HasSubstr("let x_1 : vec2 = extractBits(vec2(30i, 40i), u32(10i), u32(20i));")) + << body; +} + +TEST_F(SpvUnaryBitTest, ExtractBits_Uint) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitFieldUExtract %uint %uint_20 %uint_10 %uint_20 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + auto ast_body = fe.ast_body(); + auto body = test::ToString(p->program(), ast_body); + EXPECT_THAT(body, HasSubstr("let x_1 : u32 = extractBits(20u, 10u, 20u);")) << body; +} + +TEST_F(SpvUnaryBitTest, ExtractBits_Uint_SignedOffsetAndCount) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitFieldUExtract %uint %uint_20 %int_10 %int_20 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + auto ast_body = fe.ast_body(); + auto body = test::ToString(p->program(), ast_body); + EXPECT_THAT(body, HasSubstr("let x_1 : u32 = extractBits(20u, u32(10i), u32(20i));")) << body; } TEST_F(SpvUnaryBitTest, ExtractBits_UintVector) { @@ -962,5 +1082,23 @@ TEST_F(SpvUnaryBitTest, ExtractBits_UintVector) { << body; } +TEST_F(SpvUnaryBitTest, ExtractBits_UintVector_SignedOffsetAndCount) { + const auto assembly = BitTestPreamble() + R"( + %1 = OpBitFieldUExtract %v2uint %v2uint_10_20 %int_10 %int_20 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + auto ast_body = fe.ast_body(); + auto body = test::ToString(p->program(), ast_body); + EXPECT_THAT( + body, + HasSubstr("let x_1 : vec2 = extractBits(vec2(10u, 20u), u32(10i), u32(20i));")) + << body; +} + } // namespace } // namespace tint::reader::spirv