diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index bf66a22868..ca3a445454 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -58,6 +58,7 @@ const char kInStructNameSuffix[] = "in"; const char kOutStructNameSuffix[] = "out"; const char kTintStructInVarPrefix[] = "tint_in"; const char kTintStructOutVarPrefix[] = "tint_out"; +const char kTempNamePrefix[] = "_tint_tmp"; bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) { if (stmts->empty()) { @@ -1524,7 +1525,6 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression( if (expr->IsMemberAccessor()) { auto* mem = expr->AsMemberAccessor(); auto* res_type = mem->structure()->result_type()->UnwrapAliasPtrAlias(); - if (res_type->IsStruct()) { auto* str_type = res_type->AsStruct()->impl(); auto* str_member = str_type->get_member(mem->member()->name()); @@ -1534,6 +1534,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression( return ""; } out << str_member->offset(); + } else if (res_type->IsVector()) { // This must be a single element swizzle if we've got a vector at this // point. @@ -1561,7 +1562,6 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression( auto* ary_type = ary->array()->result_type()->UnwrapAliasPtrAlias(); out << "("; - // TODO(dsinclair): Handle matrix case if (ary_type->IsArray()) { out << ary_type->AsArray()->array_stride(); } else if (ary_type->IsVector()) { @@ -1569,6 +1569,13 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression( // or u32 which are all 4 bytes. When we get f16 or other types we'll // have to ask the type for the byte size. out << "4"; + } else if (ary_type->IsMatrix()) { + auto* mat = ary_type->AsMatrix(); + if (mat->columns() == 2) { + out << "8"; + } else { + out << "16"; + } } else { error_ = "Invalid array type in storage buffer access"; return ""; @@ -1600,14 +1607,18 @@ bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& out, ast::Expression* expr, ast::Expression* rhs) { auto* result_type = expr->result_type()->UnwrapAliasPtrAlias(); - std::string access_method = rhs != nullptr ? "Store" : "Load"; + bool is_store = rhs != nullptr; + + std::string access_method = is_store ? "Store" : "Load"; if (result_type->IsVector()) { access_method += std::to_string(result_type->AsVector()->size()); + } else if (result_type->IsMatrix()) { + access_method += std::to_string(result_type->AsMatrix()->rows()); } // If we aren't storing then we need to put in the outer cast. - if (rhs == nullptr) { - if (result_type->is_float_scalar_or_vector()) { + if (!is_store) { + if (result_type->is_float_scalar_or_vector() || result_type->IsMatrix()) { out << "asfloat("; } else if (result_type->is_signed_scalar_or_vector()) { out << "asint("; @@ -1621,15 +1632,63 @@ bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& out, error_ = "error emitting storage buffer access"; return false; } - out << buffer_name << "." << access_method << "("; auto idx = generate_storage_buffer_index_expression(expr); if (idx.empty()) { return false; } - out << idx; - if (rhs != nullptr) { + if (result_type->IsMatrix()) { + auto* mat = result_type->AsMatrix(); + + // TODO(dsinclair): This is assuming 4 byte elements. Will need to be fixed + // if we get matrixes of f16 or f64. + uint32_t stride = mat->rows() == 2 ? 8 : 16; + + if (is_store) { + if (!EmitType(out, mat, "")) { + return false; + } + + auto name = generate_name(kTempNamePrefix); + out << " " << name << " = "; + if (!EmitExpression(out, rhs)) { + return false; + } + out << ";" << std::endl; + + for (uint32_t i = 0; i < mat->columns(); i++) { + if (i > 0) { + out << ";" << std::endl; + } + + make_indent(out); + out << buffer_name << "." << access_method << "(" << idx << " + " + << (i * stride) << ", asuint(" << name << "[" << i << "]))"; + } + + return true; + } + + out << "matrixrows() << ", " << mat->columns() << ">("; + + for (uint32_t i = 0; i < mat->columns(); i++) { + if (i != 0) { + out << ", "; + } + + out << buffer_name << "." << access_method << "(" << idx << " + " + << (i * stride) << ")"; + } + + // Close the matrix type and outer cast + out << "))"; + + return true; + } + + out << buffer_name << "." << access_method << "(" << idx; + if (is_store) { out << ", asuint("; if (!EmitExpression(out, rhs)) { return false; @@ -1640,7 +1699,7 @@ bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& out, out << ")"; // Close the outer cast. - if (rhs == nullptr) { + if (!is_store) { out << ")"; } diff --git a/src/writer/hlsl/generator_impl_member_accessor_test.cc b/src/writer/hlsl/generator_impl_member_accessor_test.cc index 34be49008a..8eec27ad8a 100644 --- a/src/writer/hlsl/generator_impl_member_accessor_test.cc +++ b/src/writer/hlsl/generator_impl_member_accessor_test.cc @@ -30,6 +30,7 @@ #include "src/ast/type/array_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" +#include "src/ast/type/matrix_type.h" #include "src/ast/type/struct_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" @@ -174,6 +175,344 @@ TEST_F(HlslGeneratorImplTest_MemberAccessor, ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error(); EXPECT_EQ(result(), "asint(data.Load(0))"); } +TEST_F(HlslGeneratorImplTest_MemberAccessor, + EmitExpression_MemberAccessor_StorageBuffer_Store_Matrix) { + // struct Data { + // [[offset 0]] z : f32; + // [[offset 4]] a : mat2x3; + // }; + // var data : Data; + // mat2x3 b; + // data.a = b; + // + // -> matrix _tint_tmp = b; + // data.Store3(4 + 0, asuint(_tint_tmp[0])); + // data.Store3(4 + 16, asuint(_tint_tmp[1])); + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::MatrixType mat(&f32, 3, 2); + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("z", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("a", &mat, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto b_var = + std::make_unique("b", ast::StorageClass::kPrivate, &mat); + + auto coord_var = std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s); + + auto lhs = std::make_unique( + std::make_unique("data"), + std::make_unique("a")); + auto rhs = std::make_unique("b"); + + ast::AssignmentStatement assign(std::move(lhs), std::move(rhs)); + + td().RegisterVariableForTesting(coord_var.get()); + td().RegisterVariableForTesting(b_var.get()); + gen().register_global(coord_var.get()); + gen().register_global(b_var.get()); + mod()->AddGlobalVariable(std::move(coord_var)); + mod()->AddGlobalVariable(std::move(b_var)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(td().DetermineResultType(&assign)); + + ASSERT_TRUE(gen().EmitStatement(out(), &assign)) << gen().error(); + EXPECT_EQ(result(), R"(matrix _tint_tmp = b; +data.Store3(4 + 0, asuint(_tint_tmp[0])); +data.Store3(4 + 16, asuint(_tint_tmp[1])); +)"); +} + +TEST_F(HlslGeneratorImplTest_MemberAccessor, + EmitExpression_MemberAccessor_StorageBuffer_Store_Matrix_Empty) { + // struct Data { + // [[offset 0]] z : f32; + // [[offset 4]] a : mat2x3; + // }; + // var data : Data; + // data.a = mat2x3(); + // + // -> matrix _tint_tmp = matrix(0.0f, 0.0f, 0.0f, + // 0.0f, 0.0f, 0.0f); + // data.Store3(4 + 0, asuint(_tint_tmp[0]); + // data.Store3(4 + 16, asuint(_tint_tmp[1])); + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::MatrixType mat(&f32, 3, 2); + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("z", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("a", &mat, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + auto lhs = std::make_unique( + std::make_unique("data"), + std::make_unique("a")); + auto rhs = std::make_unique( + &mat, ast::ExpressionList{}); + + ast::AssignmentStatement assign(std::move(lhs), std::move(rhs)); + + td().RegisterVariableForTesting(coord_var.get()); + gen().register_global(coord_var.get()); + mod()->AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(td().DetermineResultType(&assign)); + + ASSERT_TRUE(gen().EmitStatement(out(), &assign)) << gen().error(); + EXPECT_EQ( + result(), + R"(matrix _tint_tmp = matrix(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); +data.Store3(4 + 0, asuint(_tint_tmp[0])); +data.Store3(4 + 16, asuint(_tint_tmp[1])); +)"); +} + +TEST_F(HlslGeneratorImplTest_MemberAccessor, + EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix) { + // struct Data { + // [[offset 0]] z : f32; + // [[offset 4]] a : mat3x2; + // }; + // var data : Data; + // data.a; + // + // -> asfloat(matrix(data.Load2(4 + 0), data.Load2(4 + 8), + // data.Load2(4 + 16))); + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::MatrixType mat(&f32, 2, 3); + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("z", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("a", &mat, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + ast::MemberAccessorExpression expr( + std::make_unique("data"), + std::make_unique("a")); + + td().RegisterVariableForTesting(coord_var.get()); + gen().register_global(coord_var.get()); + mod()->AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(td().DetermineResultType(&expr)); + + ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error(); + EXPECT_EQ(result(), + "asfloat(matrix(data.Load2(4 + 0), data.Load2(4 + 8), " + "data.Load2(4 + 16)))"); +} + +TEST_F(HlslGeneratorImplTest_MemberAccessor, + EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_Nested) { + // struct Data { + // [[offset 0]] z : f32; + // [[offset 4]] a : mat2x3 data : Outer; + // data.b.a; + // + // -> asfloat(matrix(data.Load3(4 + 0), data.Load3(4 + 16))); + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::MatrixType mat(&f32, 3, 2); + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("z", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(4)); + members.push_back( + std::make_unique("a", &mat, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + ast::MemberAccessorExpression expr( + std::make_unique("data"), + std::make_unique("a")); + + td().RegisterVariableForTesting(coord_var.get()); + gen().register_global(coord_var.get()); + mod()->AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(td().DetermineResultType(&expr)); + + ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error(); + EXPECT_EQ( + result(), + "asfloat(matrix(data.Load3(4 + 0), data.Load3(4 + 16)))"); +} + +TEST_F( + HlslGeneratorImplTest_MemberAccessor, + EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_By3_Is_16_Bytes) { + // struct Data { + // [[offset 4]] a : mat3x3 data : Data; + // data.a; + // + // -> asfloat(matrix(data.Load3(0), data.Load3(16), + // data.Load3(32))); + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::MatrixType mat(&f32, 3, 3); + + ast::StructMemberList members; + ast::StructMemberDecorationList deco; + deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("a", &mat, std::move(deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + ast::MemberAccessorExpression expr( + std::make_unique("data"), + std::make_unique("a")); + + td().RegisterVariableForTesting(coord_var.get()); + gen().register_global(coord_var.get()); + mod()->AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(td().DetermineResultType(&expr)); + + ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error(); + EXPECT_EQ(result(), + "asfloat(matrix(data.Load3(0 + 0), data.Load3(0 + 16), " + "data.Load3(0 + 32)))"); +} + +TEST_F(HlslGeneratorImplTest_MemberAccessor, + EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_Single_Element) { + // struct Data { + // [[offset 0]] z : f32; + // [[offset 16]] a : mat4x3; + // }; + // var data : Data; + // data.a[2][1]; + // + // -> asfloat(data.Load((2 * 16) + (1 * 4) + 16))) + ast::type::F32Type f32; + ast::type::I32Type i32; + ast::type::MatrixType mat(&f32, 3, 4); + + ast::StructMemberList members; + ast::StructMemberDecorationList a_deco; + a_deco.push_back(std::make_unique(0)); + members.push_back( + std::make_unique("z", &i32, std::move(a_deco))); + + ast::StructMemberDecorationList b_deco; + b_deco.push_back(std::make_unique(16)); + members.push_back( + std::make_unique("a", &mat, std::move(b_deco))); + + auto str = std::make_unique(); + str->set_members(std::move(members)); + + ast::type::StructType s(std::move(str)); + s.set_name("Data"); + + auto coord_var = + std::make_unique(std::make_unique( + "data", ast::StorageClass::kStorageBuffer, &s)); + + ast::ArrayAccessorExpression expr( + std::make_unique( + std::make_unique( + std::make_unique("data"), + std::make_unique("a")), + std::make_unique( + std::make_unique(&i32, 2))), + std::make_unique( + std::make_unique(&i32, 1))); + + td().RegisterVariableForTesting(coord_var.get()); + gen().register_global(coord_var.get()); + mod()->AddGlobalVariable(std::move(coord_var)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(td().DetermineResultType(&expr)); + + ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error(); + EXPECT_EQ(result(), "asfloat(data.Load((4 * 1) + (16 * 2) + 16))"); +} TEST_F(HlslGeneratorImplTest_MemberAccessor, EmitExpression_ArrayAccessor_StorageBuffer_Load_Int_FromArray) {