[hlsl-writer] Support matrices in storage buffers.

This CL adds the needed code to load matrix data from a storage buffer.

Bug: tint:7
Change-Id: I850b03adc7fa957b7babbad40d07ec3544b0617f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/27442
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: David Neto <dneto@google.com>
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
This commit is contained in:
dan sinclair 2020-08-26 19:55:46 +00:00 committed by Commit Bot service account
parent fea2636945
commit 663be30b55
2 changed files with 407 additions and 9 deletions

View File

@ -58,6 +58,7 @@ const char kInStructNameSuffix[] = "in";
const char kOutStructNameSuffix[] = "out"; const char kOutStructNameSuffix[] = "out";
const char kTintStructInVarPrefix[] = "tint_in"; const char kTintStructInVarPrefix[] = "tint_in";
const char kTintStructOutVarPrefix[] = "tint_out"; const char kTintStructOutVarPrefix[] = "tint_out";
const char kTempNamePrefix[] = "_tint_tmp";
bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) { bool last_is_break_or_fallthrough(const ast::BlockStatement* stmts) {
if (stmts->empty()) { if (stmts->empty()) {
@ -1524,7 +1525,6 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
if (expr->IsMemberAccessor()) { if (expr->IsMemberAccessor()) {
auto* mem = expr->AsMemberAccessor(); auto* mem = expr->AsMemberAccessor();
auto* res_type = mem->structure()->result_type()->UnwrapAliasPtrAlias(); auto* res_type = mem->structure()->result_type()->UnwrapAliasPtrAlias();
if (res_type->IsStruct()) { if (res_type->IsStruct()) {
auto* str_type = res_type->AsStruct()->impl(); auto* str_type = res_type->AsStruct()->impl();
auto* str_member = str_type->get_member(mem->member()->name()); auto* str_member = str_type->get_member(mem->member()->name());
@ -1534,6 +1534,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
return ""; return "";
} }
out << str_member->offset(); out << str_member->offset();
} else if (res_type->IsVector()) { } else if (res_type->IsVector()) {
// This must be a single element swizzle if we've got a vector at this // This must be a single element swizzle if we've got a vector at this
// point. // point.
@ -1561,7 +1562,6 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
auto* ary_type = ary->array()->result_type()->UnwrapAliasPtrAlias(); auto* ary_type = ary->array()->result_type()->UnwrapAliasPtrAlias();
out << "("; out << "(";
// TODO(dsinclair): Handle matrix case
if (ary_type->IsArray()) { if (ary_type->IsArray()) {
out << ary_type->AsArray()->array_stride(); out << ary_type->AsArray()->array_stride();
} else if (ary_type->IsVector()) { } else if (ary_type->IsVector()) {
@ -1569,6 +1569,13 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
// or u32 which are all 4 bytes. When we get f16 or other types we'll // or u32 which are all 4 bytes. When we get f16 or other types we'll
// have to ask the type for the byte size. // have to ask the type for the byte size.
out << "4"; out << "4";
} else if (ary_type->IsMatrix()) {
auto* mat = ary_type->AsMatrix();
if (mat->columns() == 2) {
out << "8";
} else {
out << "16";
}
} else { } else {
error_ = "Invalid array type in storage buffer access"; error_ = "Invalid array type in storage buffer access";
return ""; return "";
@ -1600,14 +1607,18 @@ bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& out,
ast::Expression* expr, ast::Expression* expr,
ast::Expression* rhs) { ast::Expression* rhs) {
auto* result_type = expr->result_type()->UnwrapAliasPtrAlias(); auto* result_type = expr->result_type()->UnwrapAliasPtrAlias();
std::string access_method = rhs != nullptr ? "Store" : "Load"; bool is_store = rhs != nullptr;
std::string access_method = is_store ? "Store" : "Load";
if (result_type->IsVector()) { if (result_type->IsVector()) {
access_method += std::to_string(result_type->AsVector()->size()); access_method += std::to_string(result_type->AsVector()->size());
} else if (result_type->IsMatrix()) {
access_method += std::to_string(result_type->AsMatrix()->rows());
} }
// If we aren't storing then we need to put in the outer cast. // If we aren't storing then we need to put in the outer cast.
if (rhs == nullptr) { if (!is_store) {
if (result_type->is_float_scalar_or_vector()) { if (result_type->is_float_scalar_or_vector() || result_type->IsMatrix()) {
out << "asfloat("; out << "asfloat(";
} else if (result_type->is_signed_scalar_or_vector()) { } else if (result_type->is_signed_scalar_or_vector()) {
out << "asint("; out << "asint(";
@ -1621,15 +1632,63 @@ bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& out,
error_ = "error emitting storage buffer access"; error_ = "error emitting storage buffer access";
return false; return false;
} }
out << buffer_name << "." << access_method << "(";
auto idx = generate_storage_buffer_index_expression(expr); auto idx = generate_storage_buffer_index_expression(expr);
if (idx.empty()) { if (idx.empty()) {
return false; return false;
} }
out << idx;
if (rhs != nullptr) { if (result_type->IsMatrix()) {
auto* mat = result_type->AsMatrix();
// TODO(dsinclair): This is assuming 4 byte elements. Will need to be fixed
// if we get matrixes of f16 or f64.
uint32_t stride = mat->rows() == 2 ? 8 : 16;
if (is_store) {
if (!EmitType(out, mat, "")) {
return false;
}
auto name = generate_name(kTempNamePrefix);
out << " " << name << " = ";
if (!EmitExpression(out, rhs)) {
return false;
}
out << ";" << std::endl;
for (uint32_t i = 0; i < mat->columns(); i++) {
if (i > 0) {
out << ";" << std::endl;
}
make_indent(out);
out << buffer_name << "." << access_method << "(" << idx << " + "
<< (i * stride) << ", asuint(" << name << "[" << i << "]))";
}
return true;
}
out << "matrix<uint, " << mat->rows() << ", " << mat->columns() << ">(";
for (uint32_t i = 0; i < mat->columns(); i++) {
if (i != 0) {
out << ", ";
}
out << buffer_name << "." << access_method << "(" << idx << " + "
<< (i * stride) << ")";
}
// Close the matrix type and outer cast
out << "))";
return true;
}
out << buffer_name << "." << access_method << "(" << idx;
if (is_store) {
out << ", asuint("; out << ", asuint(";
if (!EmitExpression(out, rhs)) { if (!EmitExpression(out, rhs)) {
return false; return false;
@ -1640,7 +1699,7 @@ bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& out,
out << ")"; out << ")";
// Close the outer cast. // Close the outer cast.
if (rhs == nullptr) { if (!is_store) {
out << ")"; out << ")";
} }

View File

@ -30,6 +30,7 @@
#include "src/ast/type/array_type.h" #include "src/ast/type/array_type.h"
#include "src/ast/type/f32_type.h" #include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h" #include "src/ast/type/i32_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/struct_type.h" #include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h" #include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h" #include "src/ast/type_constructor_expression.h"
@ -174,6 +175,344 @@ TEST_F(HlslGeneratorImplTest_MemberAccessor,
ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error(); ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
EXPECT_EQ(result(), "asint(data.Load(0))"); EXPECT_EQ(result(), "asint(data.Load(0))");
} }
TEST_F(HlslGeneratorImplTest_MemberAccessor,
EmitExpression_MemberAccessor_StorageBuffer_Store_Matrix) {
// struct Data {
// [[offset 0]] z : f32;
// [[offset 4]] a : mat2x3<f32>;
// };
// var<storage_buffer> data : Data;
// mat2x3<f32> b;
// data.a = b;
//
// -> matrix<float, 3, 2> _tint_tmp = b;
// data.Store3(4 + 0, asuint(_tint_tmp[0]));
// data.Store3(4 + 16, asuint(_tint_tmp[1]));
ast::type::F32Type f32;
ast::type::I32Type i32;
ast::type::MatrixType mat(&f32, 3, 2);
ast::StructMemberList members;
ast::StructMemberDecorationList a_deco;
a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
members.push_back(
std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
ast::StructMemberDecorationList b_deco;
b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
members.push_back(
std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
auto str = std::make_unique<ast::Struct>();
str->set_members(std::move(members));
ast::type::StructType s(std::move(str));
s.set_name("Data");
auto b_var =
std::make_unique<ast::Variable>("b", ast::StorageClass::kPrivate, &mat);
auto coord_var = std::make_unique<ast::Variable>(
"data", ast::StorageClass::kStorageBuffer, &s);
auto lhs = std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::IdentifierExpression>("data"),
std::make_unique<ast::IdentifierExpression>("a"));
auto rhs = std::make_unique<ast::IdentifierExpression>("b");
ast::AssignmentStatement assign(std::move(lhs), std::move(rhs));
td().RegisterVariableForTesting(coord_var.get());
td().RegisterVariableForTesting(b_var.get());
gen().register_global(coord_var.get());
gen().register_global(b_var.get());
mod()->AddGlobalVariable(std::move(coord_var));
mod()->AddGlobalVariable(std::move(b_var));
ASSERT_TRUE(td().Determine()) << td().error();
ASSERT_TRUE(td().DetermineResultType(&assign));
ASSERT_TRUE(gen().EmitStatement(out(), &assign)) << gen().error();
EXPECT_EQ(result(), R"(matrix<float, 3, 2> _tint_tmp = b;
data.Store3(4 + 0, asuint(_tint_tmp[0]));
data.Store3(4 + 16, asuint(_tint_tmp[1]));
)");
}
TEST_F(HlslGeneratorImplTest_MemberAccessor,
EmitExpression_MemberAccessor_StorageBuffer_Store_Matrix_Empty) {
// struct Data {
// [[offset 0]] z : f32;
// [[offset 4]] a : mat2x3<f32>;
// };
// var<storage_buffer> data : Data;
// data.a = mat2x3<f32>();
//
// -> matrix<float, 3, 2> _tint_tmp = matrix<float, 3, 2>(0.0f, 0.0f, 0.0f,
// 0.0f, 0.0f, 0.0f);
// data.Store3(4 + 0, asuint(_tint_tmp[0]);
// data.Store3(4 + 16, asuint(_tint_tmp[1]));
ast::type::F32Type f32;
ast::type::I32Type i32;
ast::type::MatrixType mat(&f32, 3, 2);
ast::StructMemberList members;
ast::StructMemberDecorationList a_deco;
a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
members.push_back(
std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
ast::StructMemberDecorationList b_deco;
b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
members.push_back(
std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
auto str = std::make_unique<ast::Struct>();
str->set_members(std::move(members));
ast::type::StructType s(std::move(str));
s.set_name("Data");
auto coord_var =
std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
"data", ast::StorageClass::kStorageBuffer, &s));
auto lhs = std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::IdentifierExpression>("data"),
std::make_unique<ast::IdentifierExpression>("a"));
auto rhs = std::make_unique<ast::TypeConstructorExpression>(
&mat, ast::ExpressionList{});
ast::AssignmentStatement assign(std::move(lhs), std::move(rhs));
td().RegisterVariableForTesting(coord_var.get());
gen().register_global(coord_var.get());
mod()->AddGlobalVariable(std::move(coord_var));
ASSERT_TRUE(td().Determine()) << td().error();
ASSERT_TRUE(td().DetermineResultType(&assign));
ASSERT_TRUE(gen().EmitStatement(out(), &assign)) << gen().error();
EXPECT_EQ(
result(),
R"(matrix<float, 3, 2> _tint_tmp = matrix<float, 3, 2>(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
data.Store3(4 + 0, asuint(_tint_tmp[0]));
data.Store3(4 + 16, asuint(_tint_tmp[1]));
)");
}
TEST_F(HlslGeneratorImplTest_MemberAccessor,
EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix) {
// struct Data {
// [[offset 0]] z : f32;
// [[offset 4]] a : mat3x2<f32>;
// };
// var<storage_buffer> data : Data;
// data.a;
//
// -> asfloat(matrix<uint, 2, 3>(data.Load2(4 + 0), data.Load2(4 + 8),
// data.Load2(4 + 16)));
ast::type::F32Type f32;
ast::type::I32Type i32;
ast::type::MatrixType mat(&f32, 2, 3);
ast::StructMemberList members;
ast::StructMemberDecorationList a_deco;
a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
members.push_back(
std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
ast::StructMemberDecorationList b_deco;
b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
members.push_back(
std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
auto str = std::make_unique<ast::Struct>();
str->set_members(std::move(members));
ast::type::StructType s(std::move(str));
s.set_name("Data");
auto coord_var =
std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
"data", ast::StorageClass::kStorageBuffer, &s));
ast::MemberAccessorExpression expr(
std::make_unique<ast::IdentifierExpression>("data"),
std::make_unique<ast::IdentifierExpression>("a"));
td().RegisterVariableForTesting(coord_var.get());
gen().register_global(coord_var.get());
mod()->AddGlobalVariable(std::move(coord_var));
ASSERT_TRUE(td().Determine()) << td().error();
ASSERT_TRUE(td().DetermineResultType(&expr));
ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
EXPECT_EQ(result(),
"asfloat(matrix<uint, 2, 3>(data.Load2(4 + 0), data.Load2(4 + 8), "
"data.Load2(4 + 16)))");
}
TEST_F(HlslGeneratorImplTest_MemberAccessor,
EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_Nested) {
// struct Data {
// [[offset 0]] z : f32;
// [[offset 4]] a : mat2x3<f32;
// };
// struct Outer {
// [[offset 0]] c : f32;
// [[offset 4]] b : Data;
// };
// var<storage_buffer> data : Outer;
// data.b.a;
//
// -> asfloat(matrix<uint, 3, 2>(data.Load3(4 + 0), data.Load3(4 + 16)));
ast::type::F32Type f32;
ast::type::I32Type i32;
ast::type::MatrixType mat(&f32, 3, 2);
ast::StructMemberList members;
ast::StructMemberDecorationList a_deco;
a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
members.push_back(
std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
ast::StructMemberDecorationList b_deco;
b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(4));
members.push_back(
std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
auto str = std::make_unique<ast::Struct>();
str->set_members(std::move(members));
ast::type::StructType s(std::move(str));
s.set_name("Data");
auto coord_var =
std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
"data", ast::StorageClass::kStorageBuffer, &s));
ast::MemberAccessorExpression expr(
std::make_unique<ast::IdentifierExpression>("data"),
std::make_unique<ast::IdentifierExpression>("a"));
td().RegisterVariableForTesting(coord_var.get());
gen().register_global(coord_var.get());
mod()->AddGlobalVariable(std::move(coord_var));
ASSERT_TRUE(td().Determine()) << td().error();
ASSERT_TRUE(td().DetermineResultType(&expr));
ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
EXPECT_EQ(
result(),
"asfloat(matrix<uint, 3, 2>(data.Load3(4 + 0), data.Load3(4 + 16)))");
}
TEST_F(
HlslGeneratorImplTest_MemberAccessor,
EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_By3_Is_16_Bytes) {
// struct Data {
// [[offset 4]] a : mat3x3<f32;
// };
// var<storage_buffer> data : Data;
// data.a;
//
// -> asfloat(matrix<uint, 3, 3>(data.Load3(0), data.Load3(16),
// data.Load3(32)));
ast::type::F32Type f32;
ast::type::I32Type i32;
ast::type::MatrixType mat(&f32, 3, 3);
ast::StructMemberList members;
ast::StructMemberDecorationList deco;
deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
members.push_back(
std::make_unique<ast::StructMember>("a", &mat, std::move(deco)));
auto str = std::make_unique<ast::Struct>();
str->set_members(std::move(members));
ast::type::StructType s(std::move(str));
s.set_name("Data");
auto coord_var =
std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
"data", ast::StorageClass::kStorageBuffer, &s));
ast::MemberAccessorExpression expr(
std::make_unique<ast::IdentifierExpression>("data"),
std::make_unique<ast::IdentifierExpression>("a"));
td().RegisterVariableForTesting(coord_var.get());
gen().register_global(coord_var.get());
mod()->AddGlobalVariable(std::move(coord_var));
ASSERT_TRUE(td().Determine()) << td().error();
ASSERT_TRUE(td().DetermineResultType(&expr));
ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
EXPECT_EQ(result(),
"asfloat(matrix<uint, 3, 3>(data.Load3(0 + 0), data.Load3(0 + 16), "
"data.Load3(0 + 32)))");
}
TEST_F(HlslGeneratorImplTest_MemberAccessor,
EmitExpression_MemberAccessor_StorageBuffer_Load_Matrix_Single_Element) {
// struct Data {
// [[offset 0]] z : f32;
// [[offset 16]] a : mat4x3<f32>;
// };
// var<storage_buffer> data : Data;
// data.a[2][1];
//
// -> asfloat(data.Load((2 * 16) + (1 * 4) + 16)))
ast::type::F32Type f32;
ast::type::I32Type i32;
ast::type::MatrixType mat(&f32, 3, 4);
ast::StructMemberList members;
ast::StructMemberDecorationList a_deco;
a_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(0));
members.push_back(
std::make_unique<ast::StructMember>("z", &i32, std::move(a_deco)));
ast::StructMemberDecorationList b_deco;
b_deco.push_back(std::make_unique<ast::StructMemberOffsetDecoration>(16));
members.push_back(
std::make_unique<ast::StructMember>("a", &mat, std::move(b_deco)));
auto str = std::make_unique<ast::Struct>();
str->set_members(std::move(members));
ast::type::StructType s(std::move(str));
s.set_name("Data");
auto coord_var =
std::make_unique<ast::DecoratedVariable>(std::make_unique<ast::Variable>(
"data", ast::StorageClass::kStorageBuffer, &s));
ast::ArrayAccessorExpression expr(
std::make_unique<ast::ArrayAccessorExpression>(
std::make_unique<ast::MemberAccessorExpression>(
std::make_unique<ast::IdentifierExpression>("data"),
std::make_unique<ast::IdentifierExpression>("a")),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::SintLiteral>(&i32, 2))),
std::make_unique<ast::ScalarConstructorExpression>(
std::make_unique<ast::SintLiteral>(&i32, 1)));
td().RegisterVariableForTesting(coord_var.get());
gen().register_global(coord_var.get());
mod()->AddGlobalVariable(std::move(coord_var));
ASSERT_TRUE(td().Determine()) << td().error();
ASSERT_TRUE(td().DetermineResultType(&expr));
ASSERT_TRUE(gen().EmitExpression(out(), &expr)) << gen().error();
EXPECT_EQ(result(), "asfloat(data.Load((4 * 1) + (16 * 2) + 16))");
}
TEST_F(HlslGeneratorImplTest_MemberAccessor, TEST_F(HlslGeneratorImplTest_MemberAccessor,
EmitExpression_ArrayAccessor_StorageBuffer_Load_Int_FromArray) { EmitExpression_ArrayAccessor_StorageBuffer_Load_Int_FromArray) {