mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-05-14 11:21:40 +00:00
hlsl/writer: Transpose matrices
HLSL's matrices are declared as <type>NxM, where N is the number of rows and M is the number of columns. Despite HLSL's matrices being column-major by default, the index operator and constructors actually operate on row-vectors, where as WGSL operates on column vectors. To simplify everything we use the transpose of the matrices. This is the same approach taken by SPIRV-Cross. Change-Id: I98860e11ff1a68132736980f694b2f68b633ef83 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46873 Commit-Queue: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
parent
8bd7cc7e88
commit
ab26a8fd34
@ -140,7 +140,8 @@ ast::Statement* ProgramBuilder::WrapInStatement(ast::Statement* stmt) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ast::Function* ProgramBuilder::WrapInFunction(ast::StatementList stmts) {
|
ast::Function* ProgramBuilder::WrapInFunction(ast::StatementList stmts) {
|
||||||
return Func("test_function", {}, ty.void_(), std::move(stmts), {});
|
return Func("test_function", {}, ty.void_(), std::move(stmts),
|
||||||
|
{create<ast::StageDecoration>(ast::PipelineStage::kCompute)});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tint
|
} // namespace tint
|
||||||
|
@ -370,12 +370,13 @@ bool GeneratorImpl::EmitBinary(std::ostream& pre,
|
|||||||
((lhs_type->Is<type::Vector>() && rhs_type->Is<type::Matrix>()) ||
|
((lhs_type->Is<type::Vector>() && rhs_type->Is<type::Matrix>()) ||
|
||||||
(lhs_type->Is<type::Matrix>() && rhs_type->Is<type::Vector>()) ||
|
(lhs_type->Is<type::Matrix>() && rhs_type->Is<type::Vector>()) ||
|
||||||
(lhs_type->Is<type::Matrix>() && rhs_type->Is<type::Matrix>()))) {
|
(lhs_type->Is<type::Matrix>() && rhs_type->Is<type::Matrix>()))) {
|
||||||
|
// Matrices are transposed, so swap LHS and RHS.
|
||||||
out << "mul(";
|
out << "mul(";
|
||||||
if (!EmitExpression(pre, out, expr->lhs())) {
|
if (!EmitExpression(pre, out, expr->rhs())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
out << ", ";
|
out << ", ";
|
||||||
if (!EmitExpression(pre, out, expr->rhs())) {
|
if (!EmitExpression(pre, out, expr->lhs())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
out << ")";
|
out << ")";
|
||||||
@ -2529,7 +2530,14 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||||||
if (!EmitType(out, mat->type(), "")) {
|
if (!EmitType(out, mat->type(), "")) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
out << mat->rows() << "x" << mat->columns();
|
// Note: HLSL's matrices are declared as <type>NxM, where N is the number of
|
||||||
|
// rows and M is the number of columns. Despite HLSL's matrices being
|
||||||
|
// column-major by default, the index operator and constructors actually
|
||||||
|
// operate on row-vectors, where as WGSL operates on column vectors.
|
||||||
|
// To simplify everything we use the transpose of the matrices.
|
||||||
|
// See:
|
||||||
|
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-per-component-math#matrix-ordering
|
||||||
|
out << mat->columns() << "x" << mat->rows();
|
||||||
} else if (type->Is<type::Pointer>()) {
|
} else if (type->Is<type::Pointer>()) {
|
||||||
// TODO(dsinclair): What do we do with pointers in HLSL?
|
// TODO(dsinclair): What do we do with pointers in HLSL?
|
||||||
// https://bugs.chromium.org/p/tint/issues/detail?id=183
|
// https://bugs.chromium.org/p/tint/issues/detail?id=183
|
||||||
|
@ -198,7 +198,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
|
|||||||
GeneratorImpl& gen = Build();
|
GeneratorImpl& gen = Build();
|
||||||
|
|
||||||
EXPECT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error();
|
EXPECT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error();
|
||||||
EXPECT_EQ(result(), "mul(mat, float3(1.0f, 1.0f, 1.0f))");
|
EXPECT_EQ(result(), "mul(float3(1.0f, 1.0f, 1.0f), mat)");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
|
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
|
||||||
@ -213,7 +213,7 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
|
|||||||
GeneratorImpl& gen = Build();
|
GeneratorImpl& gen = Build();
|
||||||
|
|
||||||
EXPECT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error();
|
EXPECT_TRUE(gen.EmitExpression(pre, out, expr)) << gen.error();
|
||||||
EXPECT_EQ(result(), "mul(float3(1.0f, 1.0f, 1.0f), mat)");
|
EXPECT_EQ(result(), "mul(mat, float3(1.0f, 1.0f, 1.0f))");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
|
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "gmock/gmock.h"
|
||||||
#include "src/writer/hlsl/test_helper.h"
|
#include "src/writer/hlsl/test_helper.h"
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
@ -19,6 +20,8 @@ namespace writer {
|
|||||||
namespace hlsl {
|
namespace hlsl {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
|
||||||
using HlslGeneratorImplTest_Constructor = TestHelper;
|
using HlslGeneratorImplTest_Constructor = TestHelper;
|
||||||
|
|
||||||
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Bool) {
|
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Bool) {
|
||||||
@ -113,19 +116,19 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_Empty) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat) {
|
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat) {
|
||||||
// WGSL matrix is mat2x3 (it flips for AST, sigh). With a type constructor
|
WrapInFunction(
|
||||||
// of <vec3, vec3>
|
mat2x3<f32>(vec3<f32>(1.f, 2.f, 3.f), vec3<f32>(3.f, 4.f, 5.f)));
|
||||||
|
|
||||||
auto* expr = mat2x3<f32>(vec3<f32>(1.f, 2.f, 3.f), vec3<f32>(3.f, 4.f, 5.f));
|
|
||||||
|
|
||||||
GeneratorImpl& gen = Build();
|
GeneratorImpl& gen = Build();
|
||||||
|
|
||||||
ASSERT_TRUE(gen.EmitConstructor(pre, out, expr)) << gen.error();
|
ASSERT_TRUE(gen.Generate(out)) << gen.error();
|
||||||
|
|
||||||
// A matrix of type T with n columns and m rows can also be constructed from
|
EXPECT_THAT(
|
||||||
// n vectors of type T with m components.
|
result(),
|
||||||
EXPECT_EQ(result(),
|
HasSubstr(
|
||||||
"float3x2(float3(1.0f, 2.0f, 3.0f), float3(3.0f, 4.0f, 5.0f))");
|
"float2x3(float3(1.0f, 2.0f, 3.0f), float3(3.0f, 4.0f, 5.0f))"));
|
||||||
|
|
||||||
|
Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Array) {
|
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Array) {
|
||||||
|
@ -104,7 +104,7 @@ TEST_F(HlslGeneratorImplTest_MemberAccessor,
|
|||||||
// mat2x3<f32> b;
|
// mat2x3<f32> b;
|
||||||
// data.a = b;
|
// data.a = b;
|
||||||
//
|
//
|
||||||
// -> float3x2 _tint_tmp = b;
|
// -> float2x3 _tint_tmp = b;
|
||||||
// data.Store3(4 + 0, asuint(_tint_tmp[0]));
|
// data.Store3(4 + 0, asuint(_tint_tmp[0]));
|
||||||
// data.Store3(4 + 16, asuint(_tint_tmp[1]));
|
// data.Store3(4 + 16, asuint(_tint_tmp[1]));
|
||||||
|
|
||||||
@ -126,7 +126,7 @@ TEST_F(HlslGeneratorImplTest_MemberAccessor,
|
|||||||
gen.register_global(b_var);
|
gen.register_global(b_var);
|
||||||
|
|
||||||
ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error();
|
ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error();
|
||||||
EXPECT_EQ(result(), R"(float3x2 _tint_tmp = b;
|
EXPECT_EQ(result(), R"(float2x3 _tint_tmp = b;
|
||||||
data.Store3(16 + 0, asuint(_tint_tmp[0]));
|
data.Store3(16 + 0, asuint(_tint_tmp[0]));
|
||||||
data.Store3(16 + 16, asuint(_tint_tmp[1]));
|
data.Store3(16 + 16, asuint(_tint_tmp[1]));
|
||||||
)");
|
)");
|
||||||
@ -141,7 +141,7 @@ TEST_F(HlslGeneratorImplTest_MemberAccessor,
|
|||||||
// var<storage> data : Data;
|
// var<storage> data : Data;
|
||||||
// data.a = mat2x3<f32>();
|
// data.a = mat2x3<f32>();
|
||||||
//
|
//
|
||||||
// -> float3x2 _tint_tmp = float3x2(0.0f, 0.0f, 0.0f,
|
// -> float2x3 _tint_tmp = float2x3(0.0f, 0.0f, 0.0f,
|
||||||
// 0.0f, 0.0f, 0.0f);
|
// 0.0f, 0.0f, 0.0f);
|
||||||
// data.Store3(16 + 0, asuint(_tint_tmp[0]);
|
// data.Store3(16 + 0, asuint(_tint_tmp[0]);
|
||||||
// data.Store3(16 + 16, asuint(_tint_tmp[1]));
|
// data.Store3(16 + 16, asuint(_tint_tmp[1]));
|
||||||
@ -164,7 +164,7 @@ TEST_F(HlslGeneratorImplTest_MemberAccessor,
|
|||||||
ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error();
|
ASSERT_TRUE(gen.EmitStatement(out, assign)) << gen.error();
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
result(),
|
result(),
|
||||||
R"(float3x2 _tint_tmp = float3x2(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
R"(float2x3 _tint_tmp = float2x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||||
data.Store3(16 + 0, asuint(_tint_tmp[0]));
|
data.Store3(16 + 0, asuint(_tint_tmp[0]));
|
||||||
data.Store3(16 + 16, asuint(_tint_tmp[1]));
|
data.Store3(16 + 16, asuint(_tint_tmp[1]));
|
||||||
)");
|
)");
|
||||||
|
@ -158,7 +158,7 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Matrix) {
|
|||||||
GeneratorImpl& gen = Build();
|
GeneratorImpl& gen = Build();
|
||||||
|
|
||||||
ASSERT_TRUE(gen.EmitType(out, mat2x3, "")) << gen.error();
|
ASSERT_TRUE(gen.EmitType(out, mat2x3, "")) << gen.error();
|
||||||
EXPECT_EQ(result(), "float3x2");
|
EXPECT_EQ(result(), "float2x3");
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(dsinclair): How to annotate as workgroup?
|
// TODO(dsinclair): How to annotate as workgroup?
|
||||||
|
@ -131,7 +131,7 @@ TEST_F(HlslGeneratorImplTest_VariableDecl,
|
|||||||
|
|
||||||
ASSERT_TRUE(gen.EmitStatement(out, stmt)) << gen.error();
|
ASSERT_TRUE(gen.EmitStatement(out, stmt)) << gen.error();
|
||||||
EXPECT_EQ(result(),
|
EXPECT_EQ(result(),
|
||||||
R"(float3x2 a = float3x2(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
R"(float2x3 a = float2x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,7 +35,8 @@ TEST_F(WgslGeneratorImplTest, Emit_GlobalDeclAfterFunction) {
|
|||||||
gen.increment_indent();
|
gen.increment_indent();
|
||||||
|
|
||||||
ASSERT_TRUE(gen.Generate(nullptr)) << gen.error();
|
ASSERT_TRUE(gen.Generate(nullptr)) << gen.error();
|
||||||
EXPECT_EQ(gen.result(), R"( fn test_function() -> void {
|
EXPECT_EQ(gen.result(), R"( [[stage(compute)]]
|
||||||
|
fn test_function() -> void {
|
||||||
var a : f32;
|
var a : f32;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user