[hlsl-writer] Use `mul` method where required.

This CL updates the binary operator emission to use the `mul()` method
in the following cases:
 - vector * matrix
 - matrix * vector
 - matrix * matrix

This is because the `*` operator works per-component in HLSL which does
not do the expected multiply.

Bug: tint:301
Change-Id: I0810522ac26fbbea323cf8a05a3ff6f2fb62117e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/33362
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2020-11-19 18:38:11 +00:00 committed by Commit Bot service account
parent 31df1137d4
commit d2f73226bc
2 changed files with 243 additions and 1 deletions

View File

@ -365,6 +365,27 @@ bool GeneratorImpl::EmitBinary(std::ostream& pre,
return true;
}
auto* lhs_type = expr->lhs()->result_type()->UnwrapAll();
auto* rhs_type = expr->rhs()->result_type()->UnwrapAll();
// Multiplying by a matrix requires the use of `mul` in order to get the
// type of multiply we desire.
if (expr->op() == ast::BinaryOp::kMultiply &&
((lhs_type->IsVector() && rhs_type->IsMatrix()) ||
(lhs_type->IsMatrix() && rhs_type->IsVector()) ||
(lhs_type->IsMatrix() && rhs_type->IsMatrix()))) {
out << "mul(";
if (!EmitExpression(pre, out, expr->lhs())) {
return false;
}
out << ", ";
if (!EmitExpression(pre, out, expr->rhs())) {
return false;
}
out << ")";
return true;
}
out << "(";
if (!EmitExpression(pre, out, expr->lhs())) {
return false;

View File

@ -20,6 +20,7 @@
#include "src/ast/call_expression.h"
#include "src/ast/call_statement.h"
#include "src/ast/else_statement.h"
#include "src/ast/float_literal.h"
#include "src/ast/function.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
@ -28,8 +29,13 @@
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/sint_literal.h"
#include "src/ast/type/bool_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/u32_type.h"
#include "src/ast/type/vector_type.h"
#include "src/ast/type/void_type.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/variable.h"
#include "src/ast/variable_decl_statement.h"
#include "src/writer/hlsl/test_helper.h"
@ -51,14 +57,69 @@ inline std::ostream& operator<<(std::ostream& out, BinaryData data) {
}
using HlslBinaryTest = TestParamHelper<BinaryData>;
TEST_P(HlslBinaryTest, Emit) {
TEST_P(HlslBinaryTest, Emit_f32) {
ast::type::F32Type f32;
auto params = GetParam();
auto* left_var =
create<ast::Variable>("left", ast::StorageClass::kFunction, &f32);
auto* right_var =
create<ast::Variable>("right", ast::StorageClass::kFunction, &f32);
auto* left = create<ast::IdentifierExpression>("left");
auto* right = create<ast::IdentifierExpression>("right");
td.RegisterVariableForTesting(left_var);
td.RegisterVariableForTesting(right_var);
ast::BinaryExpression expr(params.op, left, right);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
EXPECT_EQ(result(), params.result);
}
TEST_P(HlslBinaryTest, Emit_u32) {
ast::type::U32Type u32;
auto params = GetParam();
auto* left_var =
create<ast::Variable>("left", ast::StorageClass::kFunction, &u32);
auto* right_var =
create<ast::Variable>("right", ast::StorageClass::kFunction, &u32);
auto* left = create<ast::IdentifierExpression>("left");
auto* right = create<ast::IdentifierExpression>("right");
td.RegisterVariableForTesting(left_var);
td.RegisterVariableForTesting(right_var);
ast::BinaryExpression expr(params.op, left, right);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
EXPECT_EQ(result(), params.result);
}
TEST_P(HlslBinaryTest, Emit_i32) {
ast::type::I32Type i32;
auto params = GetParam();
auto* left_var =
create<ast::Variable>("left", ast::StorageClass::kFunction, &i32);
auto* right_var =
create<ast::Variable>("right", ast::StorageClass::kFunction, &i32);
auto* left = create<ast::IdentifierExpression>("left");
auto* right = create<ast::IdentifierExpression>("right");
td.RegisterVariableForTesting(left_var);
td.RegisterVariableForTesting(right_var);
ast::BinaryExpression expr(params.op, left, right);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
EXPECT_EQ(result(), params.result);
}
@ -83,6 +144,166 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{"(left / right)", ast::BinaryOp::kDivide},
BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
auto* lhs = create<ast::TypeConstructorExpression>(
&vec3, ast::ExpressionList{
create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f)),
create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f)),
create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f)),
});
auto* rhs = create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f));
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
EXPECT_EQ(result(),
"(float3(1.00000000f, 1.00000000f, 1.00000000f) * "
"1.00000000f)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
auto* lhs = create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f));
ast::ExpressionList vals;
vals.push_back(create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f)));
auto* rhs = create<ast::TypeConstructorExpression>(&vec3, vals);
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
EXPECT_EQ(result(),
"(1.00000000f * float3(1.00000000f, 1.00000000f, "
"1.00000000f))");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
ast::type::F32Type f32;
ast::type::MatrixType mat3(&f32, 3, 3);
auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
auto* lhs = create<ast::IdentifierExpression>("mat");
auto* rhs = create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f));
td.RegisterVariableForTesting(var);
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
EXPECT_EQ(result(), "(mat * 1.00000000f)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
ast::type::F32Type f32;
ast::type::MatrixType mat3(&f32, 3, 3);
auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
auto* lhs = create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f));
auto* rhs = create<ast::IdentifierExpression>("mat");
td.RegisterVariableForTesting(var);
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
EXPECT_EQ(result(), "(1.00000000f * mat)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
ast::type::MatrixType mat3(&f32, 3, 3);
auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
auto* lhs = create<ast::IdentifierExpression>("mat");
ast::ExpressionList vals;
vals.push_back(create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f)));
auto* rhs = create<ast::TypeConstructorExpression>(&vec3, vals);
td.RegisterVariableForTesting(var);
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
EXPECT_EQ(result(),
"mul(mat, float3(1.00000000f, 1.00000000f, 1.00000000f))");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
ast::type::MatrixType mat3(&f32, 3, 3);
auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
ast::ExpressionList vals;
vals.push_back(create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f)));
vals.push_back(create<ast::ScalarConstructorExpression>(
create<ast::FloatLiteral>(&f32, 1.f)));
auto* lhs = create<ast::TypeConstructorExpression>(&vec3, vals);
auto* rhs = create<ast::IdentifierExpression>("mat");
td.RegisterVariableForTesting(var);
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
EXPECT_EQ(result(),
"mul(float3(1.00000000f, 1.00000000f, 1.00000000f), mat)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
ast::type::MatrixType mat3(&f32, 3, 3);
auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
auto* lhs = create<ast::IdentifierExpression>("mat");
auto* rhs = create<ast::IdentifierExpression>("mat");
td.RegisterVariableForTesting(var);
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
EXPECT_EQ(result(), "mul(mat, mat)");
}
TEST_F(HlslGeneratorImplTest_Binary, Logical_And) {
auto* left = create<ast::IdentifierExpression>("left");
auto* right = create<ast::IdentifierExpression>("right");