[hlsl-writer] Use `mul` method where required.
This CL updates the binary operator emission to use the `mul()` method in the following cases: - vector * matrix - matrix * vector - matrix * matrix This is because the `*` operator works per-component in HLSL which does not do the expected multiply. Bug: tint:301 Change-Id: I0810522ac26fbbea323cf8a05a3ff6f2fb62117e Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/33362 Reviewed-by: dan sinclair <dsinclair@chromium.org> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: dan sinclair <dsinclair@chromium.org> Auto-Submit: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
parent
31df1137d4
commit
d2f73226bc
|
@ -365,6 +365,27 @@ bool GeneratorImpl::EmitBinary(std::ostream& pre,
|
|||
return true;
|
||||
}
|
||||
|
||||
auto* lhs_type = expr->lhs()->result_type()->UnwrapAll();
|
||||
auto* rhs_type = expr->rhs()->result_type()->UnwrapAll();
|
||||
// Multiplying by a matrix requires the use of `mul` in order to get the
|
||||
// type of multiply we desire.
|
||||
if (expr->op() == ast::BinaryOp::kMultiply &&
|
||||
((lhs_type->IsVector() && rhs_type->IsMatrix()) ||
|
||||
(lhs_type->IsMatrix() && rhs_type->IsVector()) ||
|
||||
(lhs_type->IsMatrix() && rhs_type->IsMatrix()))) {
|
||||
out << "mul(";
|
||||
if (!EmitExpression(pre, out, expr->lhs())) {
|
||||
return false;
|
||||
}
|
||||
out << ", ";
|
||||
if (!EmitExpression(pre, out, expr->rhs())) {
|
||||
return false;
|
||||
}
|
||||
out << ")";
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
out << "(";
|
||||
if (!EmitExpression(pre, out, expr->lhs())) {
|
||||
return false;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "src/ast/call_expression.h"
|
||||
#include "src/ast/call_statement.h"
|
||||
#include "src/ast/else_statement.h"
|
||||
#include "src/ast/float_literal.h"
|
||||
#include "src/ast/function.h"
|
||||
#include "src/ast/identifier_expression.h"
|
||||
#include "src/ast/if_statement.h"
|
||||
|
@ -28,8 +29,13 @@
|
|||
#include "src/ast/scalar_constructor_expression.h"
|
||||
#include "src/ast/sint_literal.h"
|
||||
#include "src/ast/type/bool_type.h"
|
||||
#include "src/ast/type/f32_type.h"
|
||||
#include "src/ast/type/i32_type.h"
|
||||
#include "src/ast/type/matrix_type.h"
|
||||
#include "src/ast/type/u32_type.h"
|
||||
#include "src/ast/type/vector_type.h"
|
||||
#include "src/ast/type/void_type.h"
|
||||
#include "src/ast/type_constructor_expression.h"
|
||||
#include "src/ast/variable.h"
|
||||
#include "src/ast/variable_decl_statement.h"
|
||||
#include "src/writer/hlsl/test_helper.h"
|
||||
|
@ -51,14 +57,69 @@ inline std::ostream& operator<<(std::ostream& out, BinaryData data) {
|
|||
}
|
||||
|
||||
using HlslBinaryTest = TestParamHelper<BinaryData>;
|
||||
TEST_P(HlslBinaryTest, Emit) {
|
||||
TEST_P(HlslBinaryTest, Emit_f32) {
|
||||
ast::type::F32Type f32;
|
||||
|
||||
auto params = GetParam();
|
||||
|
||||
auto* left_var =
|
||||
create<ast::Variable>("left", ast::StorageClass::kFunction, &f32);
|
||||
auto* right_var =
|
||||
create<ast::Variable>("right", ast::StorageClass::kFunction, &f32);
|
||||
|
||||
auto* left = create<ast::IdentifierExpression>("left");
|
||||
auto* right = create<ast::IdentifierExpression>("right");
|
||||
|
||||
td.RegisterVariableForTesting(left_var);
|
||||
td.RegisterVariableForTesting(right_var);
|
||||
|
||||
ast::BinaryExpression expr(params.op, left, right);
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
|
||||
EXPECT_EQ(result(), params.result);
|
||||
}
|
||||
TEST_P(HlslBinaryTest, Emit_u32) {
|
||||
ast::type::U32Type u32;
|
||||
|
||||
auto params = GetParam();
|
||||
|
||||
auto* left_var =
|
||||
create<ast::Variable>("left", ast::StorageClass::kFunction, &u32);
|
||||
auto* right_var =
|
||||
create<ast::Variable>("right", ast::StorageClass::kFunction, &u32);
|
||||
|
||||
auto* left = create<ast::IdentifierExpression>("left");
|
||||
auto* right = create<ast::IdentifierExpression>("right");
|
||||
|
||||
td.RegisterVariableForTesting(left_var);
|
||||
td.RegisterVariableForTesting(right_var);
|
||||
|
||||
ast::BinaryExpression expr(params.op, left, right);
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
|
||||
EXPECT_EQ(result(), params.result);
|
||||
}
|
||||
TEST_P(HlslBinaryTest, Emit_i32) {
|
||||
ast::type::I32Type i32;
|
||||
|
||||
auto params = GetParam();
|
||||
|
||||
auto* left_var =
|
||||
create<ast::Variable>("left", ast::StorageClass::kFunction, &i32);
|
||||
auto* right_var =
|
||||
create<ast::Variable>("right", ast::StorageClass::kFunction, &i32);
|
||||
|
||||
auto* left = create<ast::IdentifierExpression>("left");
|
||||
auto* right = create<ast::IdentifierExpression>("right");
|
||||
|
||||
td.RegisterVariableForTesting(left_var);
|
||||
td.RegisterVariableForTesting(right_var);
|
||||
|
||||
ast::BinaryExpression expr(params.op, left, right);
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
|
||||
EXPECT_EQ(result(), params.result);
|
||||
}
|
||||
|
@ -83,6 +144,166 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
BinaryData{"(left / right)", ast::BinaryOp::kDivide},
|
||||
BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
|
||||
auto* lhs = create<ast::TypeConstructorExpression>(
|
||||
&vec3, ast::ExpressionList{
|
||||
create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f)),
|
||||
create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f)),
|
||||
create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f)),
|
||||
});
|
||||
|
||||
auto* rhs = create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f));
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
|
||||
EXPECT_EQ(result(),
|
||||
"(float3(1.00000000f, 1.00000000f, 1.00000000f) * "
|
||||
"1.00000000f)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
|
||||
auto* lhs = create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f));
|
||||
|
||||
ast::ExpressionList vals;
|
||||
vals.push_back(create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f)));
|
||||
auto* rhs = create<ast::TypeConstructorExpression>(&vec3, vals);
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
|
||||
EXPECT_EQ(result(),
|
||||
"(1.00000000f * float3(1.00000000f, 1.00000000f, "
|
||||
"1.00000000f))");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::MatrixType mat3(&f32, 3, 3);
|
||||
|
||||
auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
|
||||
auto* lhs = create<ast::IdentifierExpression>("mat");
|
||||
auto* rhs = create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f));
|
||||
|
||||
td.RegisterVariableForTesting(var);
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
|
||||
EXPECT_EQ(result(), "(mat * 1.00000000f)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::MatrixType mat3(&f32, 3, 3);
|
||||
|
||||
auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
|
||||
auto* lhs = create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f));
|
||||
auto* rhs = create<ast::IdentifierExpression>("mat");
|
||||
|
||||
td.RegisterVariableForTesting(var);
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
|
||||
EXPECT_EQ(result(), "(1.00000000f * mat)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
ast::type::MatrixType mat3(&f32, 3, 3);
|
||||
|
||||
auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
|
||||
auto* lhs = create<ast::IdentifierExpression>("mat");
|
||||
|
||||
ast::ExpressionList vals;
|
||||
vals.push_back(create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f)));
|
||||
auto* rhs = create<ast::TypeConstructorExpression>(&vec3, vals);
|
||||
|
||||
td.RegisterVariableForTesting(var);
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
|
||||
EXPECT_EQ(result(),
|
||||
"mul(mat, float3(1.00000000f, 1.00000000f, 1.00000000f))");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
ast::type::MatrixType mat3(&f32, 3, 3);
|
||||
|
||||
auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
|
||||
|
||||
ast::ExpressionList vals;
|
||||
vals.push_back(create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f)));
|
||||
vals.push_back(create<ast::ScalarConstructorExpression>(
|
||||
create<ast::FloatLiteral>(&f32, 1.f)));
|
||||
auto* lhs = create<ast::TypeConstructorExpression>(&vec3, vals);
|
||||
|
||||
auto* rhs = create<ast::IdentifierExpression>("mat");
|
||||
|
||||
td.RegisterVariableForTesting(var);
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
|
||||
EXPECT_EQ(result(),
|
||||
"mul(float3(1.00000000f, 1.00000000f, 1.00000000f), mat)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
ast::type::MatrixType mat3(&f32, 3, 3);
|
||||
|
||||
auto* var = create<ast::Variable>("mat", ast::StorageClass::kFunction, &mat3);
|
||||
auto* lhs = create<ast::IdentifierExpression>("mat");
|
||||
auto* rhs = create<ast::IdentifierExpression>("mat");
|
||||
|
||||
td.RegisterVariableForTesting(var);
|
||||
|
||||
ast::BinaryExpression expr(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
EXPECT_TRUE(gen.EmitExpression(pre, out, &expr)) << gen.error();
|
||||
EXPECT_EQ(result(), "mul(mat, mat)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Logical_And) {
|
||||
auto* left = create<ast::IdentifierExpression>("left");
|
||||
auto* right = create<ast::IdentifierExpression>("right");
|
||||
|
|
Loading…
Reference in New Issue