Fix operator% for f32 and vecN<f32>

https://github.com/gpuweb/gpuweb/pull/1945 changes the SPIR-V mapping of this operator so that it now maps to OpFRem instead of OpFMod. Polyfill OpFMod with `x - y * floor(x / y)`

Also map the MSL output of this operator to use `fmod()`.

Behavior of this operator is now consistent across all backends.

Fixed: tint:945
Fixed: tint:977
Fixed: tint:1010
Change-Id: Iefa009b905989c55ace24e073ab0e261c7cf69b0
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58393
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
Ben Clayton
2021-07-21 14:11:01 +00:00
committed by Tint LUCI CQ
parent 1ec484410a
commit 81d4ed0d9c
17 changed files with 453 additions and 98 deletions

View File

@@ -212,7 +212,7 @@ ast::BinaryOp ConvertBinaryOp(SpvOp opcode) {
return ast::BinaryOp::kDivide;
case SpvOpUMod:
case SpvOpSMod:
case SpvOpFMod:
case SpvOpFRem:
return ast::BinaryOp::kModulo;
case SpvOpLogicalEqual:
case SpvOpIEqual:
@@ -398,8 +398,9 @@ std::string GetGlslStd450FuncName(uint32_t ext_opcode) {
return "unpack2x16float";
default:
// TODO(dneto) - The following are not implemented.
// They are grouped semantically, as in GLSL.std.450.h.
// TODO(dneto) - The following are not implemented.
// They are grouped semantically, as in GLSL.std.450.h.
case GLSLstd450SSign:
case GLSLstd450Radians:
@@ -3854,6 +3855,10 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
return MakeIntrinsicCall(inst);
}
if (opcode == SpvOpFMod) {
return MakeFMod(inst);
}
if (opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain) {
return MakeAccessChain(inst);
}
@@ -4074,6 +4079,21 @@ ast::IdentifierExpression* FunctionEmitter::PrefixSwizzle(uint32_t n) {
return nullptr;
}
TypedExpression FunctionEmitter::MakeFMod(
const spvtools::opt::Instruction& inst) {
auto x = MakeOperand(inst, 0);
auto y = MakeOperand(inst, 1);
if (!x || !y) {
return {};
}
// Emulated with: x - y * floor(x / y)
auto* div = builder_.Div(x.expr, y.expr);
auto* floor = builder_.Call("floor", div);
auto* y_floor = builder_.Mul(y.expr, floor);
auto* res = builder_.Sub(x.expr, y_floor);
return {x.type, res};
}
TypedExpression FunctionEmitter::MakeAccessChain(
const spvtools::opt::Instruction& inst) {
if (inst.NumInOperands() < 1) {

View File

@@ -966,6 +966,11 @@ class FunctionEmitter {
/// @results a copy of the expression, with possibly updated type
TypedExpression InferFunctionStorageClass(TypedExpression expr);
/// Returns an expression for a SPIR-V OpFMod instruction.
/// @param inst the SPIR-V instruction
/// @returns an expression
TypedExpression MakeFMod(const spvtools::opt::Instruction& inst);
/// Returns an expression for a SPIR-V OpAccessChain or OpInBoundsAccessChain
/// instruction.
/// @param inst the SPIR-V instruction

View File

@@ -1239,18 +1239,120 @@ TEST_F(SpvBinaryArithTestBasic, SMod_Vector_UnsignedResult) {
}
INSTANTIATE_TEST_SUITE_P(
SpvParserTest_FMod,
SpvParserTest_FRem,
SpvBinaryArithTest,
::testing::Values(
// Scalar float
BinaryData{"float", "float_50", "OpFMod", "float_60", "__f32",
BinaryData{"float", "float_50", "OpFRem", "float_60", "__f32",
"ScalarConstructor[not set]{50.000000}", "modulo",
"ScalarConstructor[not set]{60.000000}"},
// Vector float
BinaryData{"v2float", "v2float_50_60", "OpFMod", "v2float_60_50",
BinaryData{"v2float", "v2float_50_60", "OpFRem", "v2float_60_50",
"__vec_2__f32", AstFor("v2float_50_60"), "modulo",
AstFor("v2float_60_50")}));
TEST_F(SpvBinaryArithTestBasic, FMod_Scalar) {
const auto assembly = Preamble() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpFMod %float %float_50 %float_60
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< p->error() << "\n"
<< assembly;
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(p->builder(), fe.ast_body()), HasSubstr(R"(
VariableConst{
x_1
none
undefined
__f32
{
Binary[not set]{
ScalarConstructor[not set]{50.000000}
subtract
Binary[not set]{
ScalarConstructor[not set]{60.000000}
multiply
Call[not set]{
Identifier[not set]{floor}
(
Binary[not set]{
ScalarConstructor[not set]{50.000000}
divide
ScalarConstructor[not set]{60.000000}
}
)
}
}
}
}
})"));
}
TEST_F(SpvBinaryArithTestBasic, FMod_Vector) {
const auto assembly = Preamble() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpFMod %v2float %v2float_50_60 %v2float_60_50
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions())
<< p->error() << "\n"
<< assembly;
auto fe = p->function_emitter(100);
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(p->builder(), fe.ast_body()), HasSubstr(R"(
VariableConst{
x_1
none
undefined
__vec_2__f32
{
Binary[not set]{
TypeConstructor[not set]{
__vec_2__f32
ScalarConstructor[not set]{50.000000}
ScalarConstructor[not set]{60.000000}
}
subtract
Binary[not set]{
TypeConstructor[not set]{
__vec_2__f32
ScalarConstructor[not set]{60.000000}
ScalarConstructor[not set]{50.000000}
}
multiply
Call[not set]{
Identifier[not set]{floor}
(
Binary[not set]{
TypeConstructor[not set]{
__vec_2__f32
ScalarConstructor[not set]{50.000000}
ScalarConstructor[not set]{60.000000}
}
divide
TypeConstructor[not set]{
__vec_2__f32
ScalarConstructor[not set]{60.000000}
ScalarConstructor[not set]{50.000000}
}
}
)
}
}
}
}
})"));
}
TEST_F(SpvBinaryArithTestBasic, VectorTimesScalar) {
const auto assembly = Preamble() + R"(
%100 = OpFunction %void None %voidfn

View File

@@ -225,7 +225,21 @@ bool GeneratorImpl::EmitAssign(ast::AssignmentStatement* stmt) {
}
bool GeneratorImpl::EmitBinary(std::ostream& out, ast::BinaryExpression* expr) {
out << "(";
if (expr->op() == ast::BinaryOp::kModulo &&
TypeOf(expr)->UnwrapRef()->is_float_scalar_or_vector()) {
out << "fmod";
ScopedParen sp(out);
if (!EmitExpression(out, expr->lhs())) {
return false;
}
out << ", ";
if (!EmitExpression(out, expr->rhs())) {
return false;
}
return true;
}
ScopedParen sp(out);
if (!EmitExpression(out, expr->lhs())) {
return false;
@@ -303,7 +317,6 @@ bool GeneratorImpl::EmitBinary(std::ostream& out, ast::BinaryExpression* expr) {
return false;
}
out << ")";
return true;
}

View File

@@ -42,7 +42,7 @@ TEST_P(MslBinaryTest, Emit) {
auto* right = Var("right", type());
auto* expr =
create<ast::BinaryExpression>(params.op, Expr("left"), Expr("right"));
create<ast::BinaryExpression>(params.op, Expr(left), Expr(right));
WrapInFunction(left, right, expr);
GeneratorImpl& gen = Build();
@@ -74,6 +74,34 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{"(left / right)", ast::BinaryOp::kDivide},
BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
TEST_F(MslBinaryTest, ModF32) {
auto* left = Var("left", ty.f32());
auto* right = Var("right", ty.f32());
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr(left),
Expr(right));
WrapInFunction(left, right, expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "fmod(left, right)");
}
TEST_F(MslBinaryTest, ModVec3F32) {
auto* left = Var("left", ty.vec3<f32>());
auto* right = Var("right", ty.vec3<f32>());
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr(left),
Expr(right));
WrapInFunction(left, right, expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "fmod(left, right)");
}
} // namespace
} // namespace msl
} // namespace writer

View File

@@ -2097,7 +2097,7 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
}
} else if (expr->IsModulo()) {
if (lhs_is_float_or_vec) {
op = spv::Op::OpFMod;
op = spv::Op::OpFRem;
} else if (lhs_is_unsigned) {
op = spv::Op::OpUMod;
} else {

View File

@@ -246,7 +246,7 @@ INSTANTIATE_TEST_SUITE_P(
BinaryArithFloatTest,
testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"},
BinaryData{ast::BinaryOp::kDivide, "OpFDiv"},
BinaryData{ast::BinaryOp::kModulo, "OpFMod"},
BinaryData{ast::BinaryOp::kModulo, "OpFRem"},
BinaryData{ast::BinaryOp::kMultiply, "OpFMul"},
BinaryData{ast::BinaryOp::kSubtract, "OpFSub"}));