tint/spirv-reader: cast offset and count args to u32 for insertBits/extractBits

Bug: tint:1874
Change-Id: Ieadbfcb7fc61a0404dd988df42e0cfe0c8693b02
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/124320
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Antonio Maiorano 2023-03-15 21:10:00 +00:00 committed by Dawn LUCI CQ
parent 71e0f5fe46
commit 81d11b3cf1
3 changed files with 218 additions and 12 deletions

View File

@ -3819,7 +3819,14 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
const auto builtin = GetBuiltin(op);
if (builtin != builtin::Function::kNone) {
return MakeBuiltinCall(inst);
switch (builtin) {
case builtin::Function::kExtractBits:
return MakeExtractBitsCall(inst);
case builtin::Function::kInsertBits:
return MakeInsertBitsCall(inst);
default:
return MakeBuiltinCall(inst);
}
}
if (op == spv::Op::OpFMod) {
@ -5274,6 +5281,42 @@ TypedExpression FunctionEmitter::MakeBuiltinCall(const spvtools::opt::Instructio
return parser_impl_.RectifyForcedResultType(call, inst, first_operand_type);
}
TypedExpression FunctionEmitter::MakeExtractBitsCall(const spvtools::opt::Instruction& inst) {
const auto builtin = GetBuiltin(opcode(inst));
auto* name = builtin::str(builtin);
auto* ident = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
auto e = MakeOperand(inst, 0);
auto offset = ToU32(MakeOperand(inst, 1));
auto count = ToU32(MakeOperand(inst, 2));
auto* call_expr = builder_.Call(ident, ExpressionList{e.expr, offset.expr, count.expr});
auto* result_type = parser_impl_.ConvertType(inst.type_id());
if (!result_type) {
Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint();
return {};
}
TypedExpression call{result_type, call_expr};
return parser_impl_.RectifyForcedResultType(call, inst, e.type);
}
TypedExpression FunctionEmitter::MakeInsertBitsCall(const spvtools::opt::Instruction& inst) {
const auto builtin = GetBuiltin(opcode(inst));
auto* name = builtin::str(builtin);
auto* ident = create<ast::Identifier>(Source{}, builder_.Symbols().Register(name));
auto e = MakeOperand(inst, 0);
auto newbits = MakeOperand(inst, 1);
auto offset = ToU32(MakeOperand(inst, 2));
auto count = ToU32(MakeOperand(inst, 3));
auto* call_expr =
builder_.Call(ident, ExpressionList{e.expr, newbits.expr, offset.expr, count.expr});
auto* result_type = parser_impl_.ConvertType(inst.type_id());
if (!result_type) {
Fail() << "internal error: no mapped type result of call: " << inst.PrettyPrint();
return {};
}
TypedExpression call{result_type, call_expr};
return parser_impl_.RectifyForcedResultType(call, inst, e.type);
}
TypedExpression FunctionEmitter::MakeSimpleSelect(const spvtools::opt::Instruction& inst) {
auto condition = MakeOperand(inst, 0);
auto true_value = MakeOperand(inst, 1);
@ -6053,6 +6096,13 @@ TypedExpression FunctionEmitter::ToI32(TypedExpression value) {
return {ty_.I32(), builder_.Call(builder_.ty.i32(), utils::Vector{value.expr})};
}
TypedExpression FunctionEmitter::ToU32(TypedExpression value) {
if (!value || value.type->Is<U32>()) {
return value;
}
return {ty_.U32(), builder_.Call(builder_.ty.u32(), utils::Vector{value.expr})};
}
TypedExpression FunctionEmitter::ToSignedIfUnsigned(TypedExpression value) {
if (!value || !value.type->IsUnsignedScalarOrVector()) {
return value;

View File

@ -945,6 +945,12 @@ class FunctionEmitter {
/// @returns the value as an i32 value.
TypedExpression ToI32(TypedExpression value);
/// Returns the given value as an u32. If it's already an u32 then simply returns @p value.
/// Otherwise, wrap the value in a TypeInitializer expression.
/// @param value the value to pass through or convert
/// @returns the value as an u32 value.
TypedExpression ToU32(TypedExpression value);
/// Returns the given value as a signed integer type of the same shape if the value is unsigned
/// scalar or vector, by wrapping the value with a TypeInitializer expression. Returns the
/// value itself if the value was already signed.
@ -1035,6 +1041,18 @@ class FunctionEmitter {
/// @returns an expression
TypedExpression MakeBuiltinCall(const spvtools::opt::Instruction& inst);
/// Returns an expression for a SPIR-V instruction that maps to the extractBits WGSL
/// builtin function call, with special handling to cast offset and count to u32, if needed.
/// @param inst the SPIR-V instruction
/// @returns an expression
TypedExpression MakeExtractBitsCall(const spvtools::opt::Instruction& inst);
/// Returns an expression for a SPIR-V instruction that maps to the insertBits WGSL
/// builtin function call, with special handling to cast offset and count to u32, if needed.
/// @param inst the SPIR-V instruction
/// @returns an expression
TypedExpression MakeInsertBitsCall(const spvtools::opt::Instruction& inst);
/// Returns an expression for a SPIR-V OpArrayLength instruction.
/// @param inst the SPIR-V instruction
/// @returns an expression

View File

@ -33,6 +33,8 @@ std::string CommonTypes() {
%uint_10 = OpConstant %uint 10
%uint_20 = OpConstant %uint 20
%int_10 = OpConstant %int 10
%int_20 = OpConstant %int 20
%int_30 = OpConstant %int 30
%int_40 = OpConstant %int 40
%float_50 = OpConstant %float 50
@ -832,7 +834,7 @@ TEST_F(SpvUnaryBitTest, BitReverse_IntVector_IntVector) {
TEST_F(SpvUnaryBitTest, InsertBits_Int) {
const auto assembly = BitTestPreamble() + R"(
%1 = OpBitFieldInsert %v2int %int_30 %int_40 %uint_10 %uint_20
%1 = OpBitFieldInsert %int %int_30 %int_40 %uint_10 %uint_20
OpReturn
OpFunctionEnd
)";
@ -842,7 +844,23 @@ TEST_F(SpvUnaryBitTest, InsertBits_Int) {
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
EXPECT_THAT(body, HasSubstr("let x_1 : vec2<i32> = insertBits(30i, 40i, 10u, 20u);")) << body;
EXPECT_THAT(body, HasSubstr("let x_1 : i32 = insertBits(30i, 40i, 10u, 20u);")) << body;
}
TEST_F(SpvUnaryBitTest, InsertBits_Int_SignedOffsetAndCount) {
const auto assembly = BitTestPreamble() + R"(
%1 = OpBitFieldInsert %int %int_30 %int_40 %int_10 %int_20
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
EXPECT_THAT(body, HasSubstr("let x_1 : i32 = insertBits(30i, 40i, u32(10i), u32(20i));"))
<< body;
}
TEST_F(SpvUnaryBitTest, InsertBits_IntVector) {
@ -864,9 +882,9 @@ TEST_F(SpvUnaryBitTest, InsertBits_IntVector) {
<< body;
}
TEST_F(SpvUnaryBitTest, InsertBits_Uint) {
TEST_F(SpvUnaryBitTest, InsertBits_IntVector_SignedOffsetAndCount) {
const auto assembly = BitTestPreamble() + R"(
%1 = OpBitFieldInsert %v2uint %uint_20 %uint_10 %uint_10 %uint_20
%1 = OpBitFieldInsert %v2int %v2int_30_40 %v2int_40_30 %int_10 %int_20
OpReturn
OpFunctionEnd
)";
@ -876,7 +894,42 @@ TEST_F(SpvUnaryBitTest, InsertBits_Uint) {
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
EXPECT_THAT(body, HasSubstr("let x_1 : vec2<u32> = insertBits(20u, 10u, 10u, 20u);")) << body;
EXPECT_THAT(
body,
HasSubstr(
R"(let x_1 : vec2<i32> = insertBits(vec2<i32>(30i, 40i), vec2<i32>(40i, 30i), u32(10i), u32(20i));)"))
<< body;
}
TEST_F(SpvUnaryBitTest, InsertBits_Uint) {
const auto assembly = BitTestPreamble() + R"(
%1 = OpBitFieldInsert %uint %uint_20 %uint_10 %uint_10 %uint_20
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
EXPECT_THAT(body, HasSubstr("let x_1 : u32 = insertBits(20u, 10u, 10u, 20u);")) << body;
}
TEST_F(SpvUnaryBitTest, InsertBits_Uint_SignedOffsetAndCount) {
const auto assembly = BitTestPreamble() + R"(
%1 = OpBitFieldInsert %uint %uint_20 %uint_10 %int_10 %int_20
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
EXPECT_THAT(body, HasSubstr("let x_1 : u32 = insertBits(20u, 10u, u32(10i), u32(20i));"))
<< body;
}
TEST_F(SpvUnaryBitTest, InsertBits_UintVector) {
@ -898,9 +951,9 @@ TEST_F(SpvUnaryBitTest, InsertBits_UintVector) {
<< body;
}
TEST_F(SpvUnaryBitTest, ExtractBits_Int) {
TEST_F(SpvUnaryBitTest, InsertBits_UintVector_SignedOffsetAndCount) {
const auto assembly = BitTestPreamble() + R"(
%1 = OpBitFieldSExtract %v2int %int_30 %uint_10 %uint_20
%1 = OpBitFieldInsert %v2uint %v2uint_10_20 %v2uint_20_10 %int_10 %int_20
OpReturn
OpFunctionEnd
)";
@ -910,7 +963,41 @@ TEST_F(SpvUnaryBitTest, ExtractBits_Int) {
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
EXPECT_THAT(body, HasSubstr("let x_1 : vec2<i32> = extractBits(30i, 10u, 20u);")) << body;
EXPECT_THAT(
body,
HasSubstr(
R"(let x_1 : vec2<u32> = insertBits(vec2<u32>(10u, 20u), vec2<u32>(20u, 10u), u32(10i), u32(20i));)"))
<< body;
}
TEST_F(SpvUnaryBitTest, ExtractBits_Int) {
const auto assembly = BitTestPreamble() + R"(
%1 = OpBitFieldSExtract %int %int_30 %uint_10 %uint_20
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
EXPECT_THAT(body, HasSubstr("let x_1 : i32 = extractBits(30i, 10u, 20u);")) << body;
}
TEST_F(SpvUnaryBitTest, ExtractBits_Int_SignedOffsetAndCount) {
const auto assembly = BitTestPreamble() + R"(
%1 = OpBitFieldSExtract %int %int_30 %int_10 %int_20
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
EXPECT_THAT(body, HasSubstr("let x_1 : i32 = extractBits(30i, u32(10i), u32(20i));")) << body;
}
TEST_F(SpvUnaryBitTest, ExtractBits_IntVector) {
@ -930,9 +1017,9 @@ TEST_F(SpvUnaryBitTest, ExtractBits_IntVector) {
<< body;
}
TEST_F(SpvUnaryBitTest, ExtractBits_Uint) {
TEST_F(SpvUnaryBitTest, ExtractBits_IntVector_SignedOffsetAndCount) {
const auto assembly = BitTestPreamble() + R"(
%1 = OpBitFieldUExtract %v2uint %uint_20 %uint_10 %uint_20
%1 = OpBitFieldSExtract %v2int %v2int_30_40 %int_10 %int_20
OpReturn
OpFunctionEnd
)";
@ -942,7 +1029,40 @@ TEST_F(SpvUnaryBitTest, ExtractBits_Uint) {
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
EXPECT_THAT(body, HasSubstr("let x_1 : vec2<u32> = extractBits(20u, 10u, 20u);")) << body;
EXPECT_THAT(
body,
HasSubstr("let x_1 : vec2<i32> = extractBits(vec2<i32>(30i, 40i), u32(10i), u32(20i));"))
<< body;
}
TEST_F(SpvUnaryBitTest, ExtractBits_Uint) {
const auto assembly = BitTestPreamble() + R"(
%1 = OpBitFieldUExtract %uint %uint_20 %uint_10 %uint_20
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
EXPECT_THAT(body, HasSubstr("let x_1 : u32 = extractBits(20u, 10u, 20u);")) << body;
}
TEST_F(SpvUnaryBitTest, ExtractBits_Uint_SignedOffsetAndCount) {
const auto assembly = BitTestPreamble() + R"(
%1 = OpBitFieldUExtract %uint %uint_20 %int_10 %int_20
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
EXPECT_THAT(body, HasSubstr("let x_1 : u32 = extractBits(20u, u32(10i), u32(20i));")) << body;
}
TEST_F(SpvUnaryBitTest, ExtractBits_UintVector) {
@ -962,5 +1082,23 @@ TEST_F(SpvUnaryBitTest, ExtractBits_UintVector) {
<< body;
}
TEST_F(SpvUnaryBitTest, ExtractBits_UintVector_SignedOffsetAndCount) {
const auto assembly = BitTestPreamble() + R"(
%1 = OpBitFieldUExtract %v2uint %v2uint_10_20 %int_10 %int_20
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto ast_body = fe.ast_body();
auto body = test::ToString(p->program(), ast_body);
EXPECT_THAT(
body,
HasSubstr("let x_1 : vec2<u32> = extractBits(vec2<u32>(10u, 20u), u32(10i), u32(20i));"))
<< body;
}
} // namespace
} // namespace tint::reader::spirv