writer/hlsl: Fix storage buffers

Use the new CalculateArrayLength and DecomposeStorageAccess transforms to simplify storage buffer patterns before running the HLSL writer.

The HLSL writer now handles the InternalDecorations for the internal load, store, and buffer-length intrinsics.

GeneratorImpl::EmitStorageBufferAccessor() has now been entirely removed, as all this primitive load / store decomposition performed by DecomposeStorageAccess.

TODOs around runtime arrays have been removed, as this is now handled by CalculateArrayLength.

Bug: tint:185
Bug: tint:683
Change-Id: Ie25a527e7a22da52778c4477cfc22501de558a41
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46878
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-04-09 13:20:28 +00:00 committed by Commit Bot service account
parent 015b9aa93a
commit e2c775f4bb
8 changed files with 973 additions and 993 deletions

View File

@ -22,6 +22,9 @@
#include "src/semantic/expression.h"
#include "src/semantic/statement.h"
#include "src/semantic/variable.h"
#include "src/transform/calculate_array_length.h"
#include "src/transform/decompose_storage_access.h"
#include "src/transform/manager.h"
namespace tint {
namespace transform {
@ -29,13 +32,21 @@ namespace transform {
Hlsl::Hlsl() = default;
Hlsl::~Hlsl() = default;
Transform::Output Hlsl::Run(const Program* in, const DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
Transform::Output Hlsl::Run(const Program* in, const DataMap& data) {
Manager manager;
manager.Add<DecomposeStorageAccess>();
manager.Add<CalculateArrayLength>();
auto out = manager.Run(in, data);
if (!out.program.IsValid()) {
return out;
}
ProgramBuilder builder;
CloneContext ctx(&builder, &out.program);
PromoteInitializersToConstVar(ctx);
AddEmptyEntryPoint(ctx);
ctx.Clone();
return Output{Program(std::move(out))};
return Output{Program(std::move(builder))};
}
void Hlsl::PromoteInitializersToConstVar(CloneContext& ctx) const {

View File

@ -40,6 +40,14 @@ class Manager : public Transform {
transforms_.push_back(std::move(transform));
}
/// Add pass to the manager of type `T`, constructed with the provided
/// arguments.
/// @param args the arguments to forward to the `T` constructor
template <typename T, typename... ARGS>
void Add(ARGS&&... args) {
transforms_.emplace_back(std::make_unique<T>(std::forward<ARGS>(args)...));
}
/// Runs the transforms on `program`, returning the transformation result.
/// @param program the source program to transform
/// @param data optional extra transform-specific input data

View File

@ -28,6 +28,8 @@
#include "src/semantic/member_accessor_expression.h"
#include "src/semantic/struct.h"
#include "src/semantic/variable.h"
#include "src/transform/calculate_array_length.h"
#include "src/transform/decompose_storage_access.h"
#include "src/type/access_control_type.h"
#include "src/type/depth_texture_type.h"
#include "src/type/multisampled_texture_type.h"
@ -220,7 +222,7 @@ bool GeneratorImpl::EmitConstructedType(std::ostream& out,
return true;
}
out << "typedef ";
if (!EmitType(out, alias->type(), "")) {
if (!EmitType(out, alias->type(), ast::StorageClass::kNone, "")) {
return false;
}
out << " " << namer_.NameFor(builder_.Symbols().NameFor(alias->symbol()))
@ -240,11 +242,6 @@ bool GeneratorImpl::EmitConstructedType(std::ostream& out,
bool GeneratorImpl::EmitArrayAccessor(std::ostream& pre,
std::ostream& out,
ast::ArrayAccessorExpression* expr) {
// Handle writing into a storage buffer array
if (is_storage_buffer_access(expr)) {
return EmitStorageBufferAccessor(pre, out, expr, nullptr);
}
if (!EmitExpression(pre, out, expr->array())) {
return false;
}
@ -268,7 +265,7 @@ bool GeneratorImpl::EmitBitcast(std::ostream& pre,
}
out << "as";
if (!EmitType(out, expr->type(), "")) {
if (!EmitType(out, expr->type(), ast::StorageClass::kNone, "")) {
return false;
}
out << "(";
@ -285,30 +282,6 @@ bool GeneratorImpl::EmitAssign(std::ostream& out,
std::ostringstream pre;
// If the LHS is an accessor into a storage buffer then we have to
// emit a Store operation instead of an ='s.
if (auto* mem = stmt->lhs()->As<ast::MemberAccessorExpression>()) {
if (is_storage_buffer_access(mem)) {
std::ostringstream accessor_out;
if (!EmitStorageBufferAccessor(pre, accessor_out, mem, stmt->rhs())) {
return false;
}
out << pre.str();
out << accessor_out.str() << ";" << std::endl;
return true;
}
} else if (auto* ary = stmt->lhs()->As<ast::ArrayAccessorExpression>()) {
if (is_storage_buffer_access(ary)) {
std::ostringstream accessor_out;
if (!EmitStorageBufferAccessor(pre, accessor_out, ary, stmt->rhs())) {
return false;
}
out << pre.str();
out << accessor_out.str() << ";" << std::endl;
return true;
}
}
std::ostringstream lhs_out;
if (!EmitExpression(pre, lhs_out, stmt->lhs())) {
return false;
@ -516,12 +489,130 @@ bool GeneratorImpl::EmitCall(std::ostream& pre,
return 0;
}
const auto& params = expr->params();
auto* call = builder_.Sem().Get(expr);
auto* target = call->Target();
if (auto* func = target->As<semantic::Function>()) {
if (ast::HasDecoration<
transform::CalculateArrayLength::BufferSizeIntrinsic>(
func->Declaration()->decorations())) {
// Special function generated by the CalculateArrayLength transform for
// calling X.GetDimensions(Y)
if (!EmitExpression(pre, out, params[0])) {
return false;
}
out << ".GetDimensions(";
if (!EmitExpression(pre, out, params[1])) {
return false;
}
out << ")";
return true;
}
if (auto* intrinsic =
ast::GetDecoration<transform::DecomposeStorageAccess::Intrinsic>(
func->Declaration()->decorations())) {
auto load = [&](const char* cast, int n) {
if (cast) {
out << cast << "(";
}
if (!EmitExpression(pre, out, params[0])) { // buffer
return false;
}
out << ".Load";
if (n > 1) {
out << n;
}
ScopedParen sp(out);
if (!EmitExpression(pre, out, params[1])) { // offset
return false;
}
if (cast) {
out << ")";
}
return true;
};
auto store = [&](int n) {
if (!EmitExpression(pre, out, params[0])) { // buffer
return false;
}
out << ".Store";
if (n > 1) {
out << n;
}
ScopedParen sp1(out);
if (!EmitExpression(pre, out, params[1])) { // offset
return false;
}
out << ", asuint";
ScopedParen sp2(out);
if (!EmitExpression(pre, out, params[2])) { // value
return false;
}
return true;
};
switch (intrinsic->type) {
case transform::DecomposeStorageAccess::Intrinsic::kLoadU32:
return load(nullptr, 1);
case transform::DecomposeStorageAccess::Intrinsic::kLoadF32:
return load("asfloat", 1);
case transform::DecomposeStorageAccess::Intrinsic::kLoadI32:
return load("asint", 1);
case transform::DecomposeStorageAccess::Intrinsic::kLoadVec2U32:
return load(nullptr, 2);
case transform::DecomposeStorageAccess::Intrinsic::kLoadVec2F32:
return load("asfloat", 2);
case transform::DecomposeStorageAccess::Intrinsic::kLoadVec2I32:
return load("asint", 2);
case transform::DecomposeStorageAccess::Intrinsic::kLoadVec3U32:
return load(nullptr, 3);
case transform::DecomposeStorageAccess::Intrinsic::kLoadVec3F32:
return load("asfloat", 3);
case transform::DecomposeStorageAccess::Intrinsic::kLoadVec3I32:
return load("asint", 3);
case transform::DecomposeStorageAccess::Intrinsic::kLoadVec4U32:
return load(nullptr, 4);
case transform::DecomposeStorageAccess::Intrinsic::kLoadVec4F32:
return load("asfloat", 4);
case transform::DecomposeStorageAccess::Intrinsic::kLoadVec4I32:
return load("asint", 4);
case transform::DecomposeStorageAccess::Intrinsic::kStoreU32:
return store(1);
case transform::DecomposeStorageAccess::Intrinsic::kStoreF32:
return store(1);
case transform::DecomposeStorageAccess::Intrinsic::kStoreI32:
return store(1);
case transform::DecomposeStorageAccess::Intrinsic::kStoreVec2U32:
return store(2);
case transform::DecomposeStorageAccess::Intrinsic::kStoreVec2F32:
return store(2);
case transform::DecomposeStorageAccess::Intrinsic::kStoreVec2I32:
return store(2);
case transform::DecomposeStorageAccess::Intrinsic::kStoreVec3U32:
return store(3);
case transform::DecomposeStorageAccess::Intrinsic::kStoreVec3F32:
return store(3);
case transform::DecomposeStorageAccess::Intrinsic::kStoreVec3I32:
return store(3);
case transform::DecomposeStorageAccess::Intrinsic::kStoreVec4U32:
return store(4);
case transform::DecomposeStorageAccess::Intrinsic::kStoreVec4F32:
return store(4);
case transform::DecomposeStorageAccess::Intrinsic::kStoreVec4I32:
return store(4);
}
TINT_UNIMPLEMENTED(diagnostics_) << static_cast<int>(intrinsic->type);
return false;
}
}
if (auto* intrinsic = call->Target()->As<semantic::Intrinsic>()) {
if (intrinsic->IsTexture()) {
return EmitTextureCall(pre, out, expr, intrinsic);
}
const auto& params = expr->params();
if (intrinsic->Type() == semantic::IntrinsicType::kSelect) {
diagnostics_.add_error("select not supported in HLSL backend yet");
return false;
@ -597,7 +688,6 @@ bool GeneratorImpl::EmitCall(std::ostream& pre,
}
}
const auto& params = expr->params();
for (auto* param : params) {
if (!first) {
out << ", ";
@ -1241,7 +1331,7 @@ bool GeneratorImpl::EmitTypeConstructor(std::ostream& pre,
if (brackets) {
out << "{";
} else {
if (!EmitType(out, expr->type(), "")) {
if (!EmitType(out, expr->type(), ast::StorageClass::kNone, "")) {
return false;
}
out << "(";
@ -1499,7 +1589,7 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
Symbol ep_sym) {
auto name = func->symbol().to_str();
if (!EmitType(out, func->return_type(), "")) {
if (!EmitType(out, func->return_type(), ast::StorageClass::kNone, "")) {
return false;
}
@ -1551,9 +1641,16 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
}
first = false;
auto* type = builder_.Sem().Get(v)->Type();
auto* sem = builder_.Sem().Get(v);
auto* type = sem->Type();
if (!EmitType(out, type, builder_.Symbols().NameFor(v->symbol()))) {
// Note: WGSL only allows for StorageClass::kNone on parameters, however the
// sanitizer transforms generates load / store functions for storage
// buffers. These functions have a storage buffer parameter with
// StorageClass::kStorage. This is required to correctly translate the
// parameter to [RW]ByteAddressBuffer.
if (!EmitType(out, type, sem->StorageClass(),
builder_.Symbols().NameFor(v->symbol()))) {
return false;
}
// Array name is output as part of the type
@ -1638,7 +1735,7 @@ bool GeneratorImpl::EmitEntryPointData(
increment_indent();
make_indent(out);
if (!EmitType(out, type, "")) {
if (!EmitType(out, type, var->StorageClass(), "")) {
return false;
}
out << " " << builder_.Symbols().NameFor(decl->symbol()) << ";"
@ -1663,18 +1760,19 @@ bool GeneratorImpl::EmitEntryPointData(
continue; // Global already emitted
}
auto* ac = var->Type()->As<type::AccessControl>();
if (ac == nullptr) {
auto* access = var->Type()->As<type::AccessControl>();
if (access == nullptr) {
diagnostics_.add_error("access control type required for storage buffer");
return false;
}
if (!ac->IsReadOnly()) {
out << "RW";
if (!EmitType(out, var->Type(), ast::StorageClass::kStorage, "")) {
return false;
}
out << "ByteAddressBuffer " << builder_.Symbols().NameFor(decl->symbol())
<< RegisterAndSpace(ac->IsReadOnly() ? 't' : 'u', binding_point) << ";"
<< std::endl;
out << " " << builder_.Symbols().NameFor(decl->symbol())
<< RegisterAndSpace(access->IsReadOnly() ? 't' : 'u', binding_point)
<< ";" << std::endl;
emitted_storagebuffer = true;
}
if (emitted_storagebuffer) {
@ -1696,10 +1794,12 @@ bool GeneratorImpl::EmitEntryPointData(
for (auto& data : in_variables) {
auto* var = data.first;
auto* deco = data.second;
auto* type = builder_.Sem().Get(var)->Type();
auto* sem = builder_.Sem().Get(var);
auto* type = sem->Type();
make_indent(out);
if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) {
if (!EmitType(out, type, sem->StorageClass(),
builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
@ -1745,10 +1845,12 @@ bool GeneratorImpl::EmitEntryPointData(
for (auto& data : outvariables) {
auto* var = data.first;
auto* deco = data.second;
auto* type = builder_.Sem().Get(var)->Type();
auto* sem = builder_.Sem().Get(var);
auto* type = sem->Type();
make_indent(out);
if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) {
if (!EmitType(out, type, sem->StorageClass(),
builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
@ -1800,7 +1902,7 @@ bool GeneratorImpl::EmitEntryPointData(
continue; // Not interested in this type
}
if (!EmitType(out, var->Type(), "")) {
if (!EmitType(out, var->Type(), var->StorageClass(), "")) {
return false;
}
out << " " << namer_.NameFor(builder_.Symbols().NameFor(decl->symbol()));
@ -1914,7 +2016,8 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
// Emit entry point parameters.
for (auto* var : func->params()) {
auto* type = builder_.Sem().Get(var)->Type();
auto* sem = builder_.Sem().Get(var);
auto* type = sem->Type();
if (!type->Is<type::Struct>()) {
TINT_ICE(diagnostics_) << "Unsupported non-struct entry point parameter";
}
@ -1924,7 +2027,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
}
first = false;
if (!EmitType(out, type, "")) {
if (!EmitType(out, type, sem->StorageClass(), "")) {
return false;
}
@ -1992,7 +2095,7 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, type::Type* type) {
} else if (type->Is<type::U32>()) {
out << "0u";
} else if (auto* vec = type->As<type::Vector>()) {
if (!EmitType(out, type, "")) {
if (!EmitType(out, type, ast::StorageClass::kNone, "")) {
return false;
}
ScopedParen sp(out);
@ -2005,7 +2108,7 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, type::Type* type) {
}
}
} else if (auto* mat = type->As<type::Matrix>()) {
if (!EmitType(out, type, "")) {
if (!EmitType(out, type, ast::StorageClass::kNone, "")) {
return false;
}
ScopedParen sp(out);
@ -2134,263 +2237,9 @@ bool GeneratorImpl::EmitLoop(std::ostream& out, ast::LoopStatement* stmt) {
return true;
}
std::string GeneratorImpl::generate_storage_buffer_index_expression(
std::ostream& pre,
ast::Expression* expr) {
std::ostringstream out;
bool first = true;
for (;;) {
if (expr->Is<ast::IdentifierExpression>()) {
break;
}
if (!first) {
out << " + ";
}
first = false;
if (auto* mem = expr->As<ast::MemberAccessorExpression>()) {
auto* res_type = TypeOf(mem->structure())->UnwrapAll();
if (auto* str = res_type->As<type::Struct>()) {
auto* str_type = str->impl();
auto* str_member = str_type->get_member(mem->member()->symbol());
auto* sem_mem = builder_.Sem().Get(str_member);
if (!sem_mem) {
TINT_ICE(diagnostics_) << "struct member missing semantic info";
return "";
}
out << sem_mem->Offset();
} else if (res_type->Is<type::Vector>()) {
auto swizzle = builder_.Sem().Get(mem)->Swizzle();
// TODO(dsinclair): Swizzle stuff
//
// This must be a single element swizzle if we've got a vector at this
// point.
if (swizzle.size() != 1) {
TINT_ICE(diagnostics_)
<< "Encountered multi-element swizzle when should have only one "
"level";
return "";
}
// TODO(dsinclair): All our types are currently 4 bytes (f32, i32, u32)
// so this is assuming 4. This will need to be fixed when we get f16 or
// f64 types.
out << "(4 * " << swizzle[0] << ")";
} else {
TINT_ICE(diagnostics_) << "Invalid result type for member accessor: "
<< res_type->type_name();
return "";
}
expr = mem->structure();
} else if (auto* ary = expr->As<ast::ArrayAccessorExpression>()) {
auto* ary_type = TypeOf(ary->array())->UnwrapAll();
out << "(";
if (auto* arr = ary_type->As<type::Array>()) {
auto* sem_arr = builder_.Sem().Get(arr);
if (!sem_arr) {
TINT_ICE(diagnostics_) << "array type missing semantic info";
return "";
}
out << sem_arr->Stride();
} else if (ary_type->Is<type::Vector>()) {
// TODO(dsinclair): This is a hack. Our vectors can only be f32, i32
// or u32 which are all 4 bytes. When we get f16 or other types we'll
// have to ask the type for the byte size.
out << "4";
} else if (auto* mat = ary_type->As<type::Matrix>()) {
if (mat->columns() == 2) {
out << "8";
} else {
out << "16";
}
} else {
diagnostics_.add_error("Invalid array type in storage buffer access");
return "";
}
out << " * ";
if (!EmitExpression(pre, out, ary->idx_expr())) {
return "";
}
out << ")";
expr = ary->array();
} else {
diagnostics_.add_error("error emitting storage buffer access");
return "";
}
}
return out.str();
}
// TODO(dsinclair): This currently only handles loading of 4, 8, 12 or 16 byte
// members. If we need to support larger we'll need to do the loading into
// chunks.
//
// TODO(dsinclair): Need to support loading through a pointer. The pointer is
// just a memory address in the storage buffer, so need to do the correct
// calculation.
bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& pre,
std::ostream& out,
ast::Expression* expr,
ast::Expression* rhs) {
auto* result_type = TypeOf(expr)->UnwrapAll();
bool is_store = rhs != nullptr;
std::string access_method = is_store ? "Store" : "Load";
if (auto* vec = result_type->As<type::Vector>()) {
access_method += std::to_string(vec->size());
} else if (auto* mat = result_type->As<type::Matrix>()) {
access_method += std::to_string(mat->rows());
}
// If we aren't storing then we need to put in the outer cast.
if (!is_store) {
if (result_type->is_float_scalar_or_vector() ||
result_type->Is<type::Matrix>()) {
out << "asfloat(";
} else if (result_type->is_signed_scalar_or_vector()) {
out << "asint(";
} else if (result_type->is_unsigned_scalar_or_vector()) {
out << "asuint(";
} else {
TINT_UNIMPLEMENTED(diagnostics_)
<< result_type->FriendlyName(builder_.Symbols());
return false;
}
}
auto buffer_name = get_buffer_name(expr);
if (buffer_name.empty()) {
diagnostics_.add_error("error emitting storage buffer access");
return false;
}
auto idx = generate_storage_buffer_index_expression(pre, expr);
if (idx.empty()) {
return false;
}
if (auto* mat = result_type->As<type::Matrix>()) {
// TODO(dsinclair): This is assuming 4 byte elements. Will need to be fixed
// if we get matrixes of f16 or f64.
uint32_t stride = mat->rows() == 2 ? 8 : 16;
if (is_store) {
if (!EmitType(out, mat, "")) {
return false;
}
auto name = generate_name(kTempNamePrefix);
out << " " << name << " = ";
if (!EmitExpression(pre, out, rhs)) {
return false;
}
out << ";" << std::endl;
for (uint32_t i = 0; i < mat->columns(); i++) {
if (i > 0) {
out << ";" << std::endl;
}
make_indent(out);
out << buffer_name << "." << access_method << "(" << idx << " + "
<< (i * stride) << ", asuint(" << name << "[" << i << "]))";
}
return true;
}
out << "uint" << mat->rows() << "x" << mat->columns();
ScopedParen p(out);
for (uint32_t i = 0; i < mat->columns(); i++) {
if (i != 0) {
out << ", ";
}
out << buffer_name << "." << access_method << "(" << idx << " + "
<< (i * stride) << ")";
}
} else {
out << buffer_name << "." << access_method;
ScopedParen p(out);
out << idx;
if (is_store) {
out << ", asuint";
ScopedParen p2(out);
if (!EmitExpression(pre, out, rhs)) {
return false;
}
}
}
if (!is_store) {
out << ")";
}
return true;
}
bool GeneratorImpl::is_storage_buffer_access(
ast::ArrayAccessorExpression* expr) {
// We only care about array so we can get to the next part of the expression.
// If it isn't an array or a member accessor we can stop looking as it won't
// be a storage buffer.
auto* ary = expr->array();
if (auto* member = ary->As<ast::MemberAccessorExpression>()) {
return is_storage_buffer_access(member);
} else if (auto* array = ary->As<ast::ArrayAccessorExpression>()) {
return is_storage_buffer_access(array);
}
return false;
}
bool GeneratorImpl::is_storage_buffer_access(
ast::MemberAccessorExpression* expr) {
auto* structure = expr->structure();
auto* data_type = TypeOf(structure)->UnwrapAll();
// TODO(dsinclair): Swizzle
//
// If the data is a multi-element swizzle then we will not load the swizzle
// portion through the Load command.
if (data_type->Is<type::Vector>() &&
builder_.Symbols().NameFor(expr->member()->symbol()).size() > 1) {
return false;
}
// Check if this is a storage buffer variable
if (auto* ident = expr->structure()->As<ast::IdentifierExpression>()) {
const semantic::Variable* var = nullptr;
if (!global_variables_.get(ident->symbol(), &var)) {
return false;
}
return var->StorageClass() == ast::StorageClass::kStorage;
} else if (auto* member = structure->As<ast::MemberAccessorExpression>()) {
return is_storage_buffer_access(member);
} else if (auto* array = structure->As<ast::ArrayAccessorExpression>()) {
return is_storage_buffer_access(array);
}
// Technically I don't think this is possible, but if we don't have a struct
// or array accessor then we can't have a storage buffer I believe.
return false;
}
bool GeneratorImpl::EmitMemberAccessor(std::ostream& pre,
std::ostream& out,
ast::MemberAccessorExpression* expr) {
// Look for storage buffer accesses as we have to convert them into Load
// expressions. Stores will be identified in the assignment emission and a
// member accessor store of a storage buffer will not get here.
if (is_storage_buffer_access(expr)) {
return EmitStorageBufferAccessor(pre, out, expr, nullptr);
}
if (!EmitExpression(pre, out, expr->structure())) {
return false;
}
@ -2515,12 +2364,26 @@ bool GeneratorImpl::EmitSwitch(std::ostream& out, ast::SwitchStatement* stmt) {
bool GeneratorImpl::EmitType(std::ostream& out,
type::Type* type,
ast::StorageClass storage_class,
const std::string& name) {
auto* access = type->As<type::AccessControl>();
if (access) {
type = access->type();
}
if (storage_class == ast::StorageClass::kStorage) {
if (access == nullptr) {
diagnostics_.add_error("access control type required for storage buffer");
return false;
}
if (!access->IsReadOnly()) {
out << "RW";
}
out << "ByteAddressBuffer";
return true;
}
if (auto* alias = type->As<type::Alias>()) {
out << namer_.NameFor(builder_.Symbols().NameFor(alias->symbol()));
} else if (auto* ary = type->As<type::Array>()) {
@ -2528,16 +2391,15 @@ bool GeneratorImpl::EmitType(std::ostream& out,
std::vector<uint32_t> sizes;
while (auto* arr = base_type->As<type::Array>()) {
if (arr->IsRuntimeArray()) {
// TODO(dsinclair): Support runtime arrays
// https://bugs.chromium.org/p/tint/issues/detail?id=185
diagnostics_.add_error("runtime array not supported yet.");
TINT_ICE(diagnostics_)
<< "Runtime arrays may only exist in storage buffers, which should "
"have been transformed into a ByteAddressBuffer";
return false;
} else {
sizes.push_back(arr->size());
}
sizes.push_back(arr->size());
base_type = arr->type();
}
if (!EmitType(out, base_type, "")) {
if (!EmitType(out, base_type, storage_class, "")) {
return false;
}
if (!name.empty()) {
@ -2553,7 +2415,7 @@ bool GeneratorImpl::EmitType(std::ostream& out,
} else if (type->Is<type::I32>()) {
out << "int";
} else if (auto* mat = type->As<type::Matrix>()) {
if (!EmitType(out, mat->type(), "")) {
if (!EmitType(out, mat->type(), storage_class, "")) {
return false;
}
// Note: HLSL's matrices are declared as <type>NxM, where N is the number of
@ -2652,7 +2514,7 @@ bool GeneratorImpl::EmitType(std::ostream& out,
out << "uint" << size;
} else {
out << "vector<";
if (!EmitType(out, vec->type(), "")) {
if (!EmitType(out, vec->type(), storage_class, "")) {
return false;
}
out << ", " << size << ">";
@ -2689,7 +2551,7 @@ bool GeneratorImpl::EmitStructType(std::ostream& out,
// TODO(dsinclair): Handle [[offset]] annotation on structs
// https://bugs.chromium.org/p/tint/issues/detail?id=184
if (!EmitType(out, mem->type(),
if (!EmitType(out, mem->type(), ast::StorageClass::kNone,
builder_.Symbols().NameFor(mem->symbol()))) {
return false;
}
@ -2788,8 +2650,10 @@ bool GeneratorImpl::EmitVariable(std::ostream& out,
if (var->is_const()) {
out << "const ";
}
auto* type = builder_.Sem().Get(var)->Type();
if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) {
auto* sem = builder_.Sem().Get(var);
auto* type = sem->Type();
if (!EmitType(out, type, sem->StorageClass(),
builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
if (!type->Is<type::Array>()) {
@ -2824,7 +2688,8 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out,
out << pre.str();
}
auto* type = builder_.Sem().Get(var)->Type();
auto* sem = builder_.Sem().Get(var);
auto* type = sem->Type();
if (ast::HasDecoration<ast::ConstantIdDecoration>(var->decorations())) {
auto const_id = var->constant_id();
@ -2840,7 +2705,8 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out,
}
out << "#endif" << std::endl;
out << "static const ";
if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) {
if (!EmitType(out, type, sem->StorageClass(),
builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
out << " " << builder_.Symbols().NameFor(var->symbol())
@ -2848,7 +2714,8 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out,
out << "#undef WGSL_SPEC_CONSTANT_" << const_id << std::endl;
} else {
out << "static const ";
if (!EmitType(out, type, builder_.Symbols().NameFor(var->symbol()))) {
if (!EmitType(out, type, sem->StorageClass(),
builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
if (!type->Is<type::Array>()) {

View File

@ -37,6 +37,9 @@
namespace tint {
// Forward declarations
namespace type {
class AccessControl;
} // namespace type
namespace semantic {
class Call;
class Intrinsic;
@ -266,16 +269,6 @@ class GeneratorImpl : public TextGenerator {
bool EmitMemberAccessor(std::ostream& pre,
std::ostream& out,
ast::MemberAccessorExpression* expr);
/// Handles a storage buffer accessor expression
/// @param pre the preamble for the expression stream
/// @param out the output of the expression stream
/// @param expr the storage buffer accessor expression
/// @param rhs the right side of a store expression. Set to nullptr for a load
/// @returns true if the storage buffer accessor was emitted
bool EmitStorageBufferAccessor(std::ostream& pre,
std::ostream& out,
ast::Expression* expr,
ast::Expression* rhs);
/// Handles return statements
/// @param out the output stream
/// @param stmt the statement to emit
@ -294,9 +287,13 @@ class GeneratorImpl : public TextGenerator {
/// Handles generating type
/// @param out the output stream
/// @param type the type to generate
/// @param storage_class the storage class of the variable
/// @param name the name of the variable, only used for array emission
/// @returns true if the type is emitted
bool EmitType(std::ostream& out, type::Type* type, const std::string& name);
bool EmitType(std::ostream& out,
type::Type* type,
ast::StorageClass storage_class,
const std::string& name);
/// Handles generating a structure declaration
/// @param out the output stream
/// @param ty the struct to generate
@ -332,15 +329,6 @@ class GeneratorImpl : public TextGenerator {
/// @returns true if the variable was emitted
bool EmitProgramConstVariable(std::ostream& out, const ast::Variable* var);
/// Returns true if the accessor is accessing a storage buffer.
/// @param expr the expression to check
/// @returns true if the accessor is accessing a storage buffer for which
/// we need to execute a Load instruction.
bool is_storage_buffer_access(ast::MemberAccessorExpression* expr);
/// Returns true if the accessor is accessing a storage buffer.
/// @param expr the expression to check
/// @returns true if the accessor is accessing a storage buffer
bool is_storage_buffer_access(ast::ArrayAccessorExpression* expr);
/// Registers the given global with the generator
/// @param global the global to register
void register_global(ast::Variable* global);
@ -348,12 +336,6 @@ class GeneratorImpl : public TextGenerator {
/// @param var the variable to check
/// @returns true if the global is in an input or output struct
bool global_is_in_struct(const semantic::Variable* var) const;
/// Creates a text string representing the index into a storage buffer
/// @param pre the pre stream
/// @param expr the expression to use as the index
/// @returns the index string, or blank if unable to generate
std::string generate_storage_buffer_index_expression(std::ostream& pre,
ast::Expression* expr);
/// Handles generating a builtin method name
/// @param intrinsic the semantic info for the intrinsic
/// @returns the name or "" if not valid

View File

@ -280,30 +280,30 @@ TEST_F(HlslGeneratorImplTest_Function,
EXPECT_EQ(result(), R"(struct VertexOutput {
float4 pos;
};
struct tint_symbol_2 {
struct tint_symbol_6 {
float4 pos : SV_Position;
};
struct tint_symbol_6 {
struct tint_symbol_9 {
float4 pos : SV_Position;
};
VertexOutput foo(float x) {
const VertexOutput tint_symbol_8 = {float4(x, x, x, 1.0f)};
return tint_symbol_8;
}
tint_symbol_2 vert_main1() {
const VertexOutput tint_symbol_4 = {foo(0.5f)};
const tint_symbol_2 tint_symbol_1 = {tint_symbol_4.pos};
const VertexOutput tint_symbol_1 = {float4(x, x, x, 1.0f)};
return tint_symbol_1;
}
tint_symbol_6 vert_main2() {
const VertexOutput tint_symbol_7 = {foo(0.25f)};
tint_symbol_6 vert_main1() {
const VertexOutput tint_symbol_7 = {foo(0.5f)};
const tint_symbol_6 tint_symbol_5 = {tint_symbol_7.pos};
return tint_symbol_5;
}
tint_symbol_9 vert_main2() {
const VertexOutput tint_symbol_10 = {foo(0.25f)};
const tint_symbol_9 tint_symbol_8 = {tint_symbol_10.pos};
return tint_symbol_8;
}
)");
Validate();
@ -415,16 +415,19 @@ TEST_F(HlslGeneratorImplTest_Function,
create<ast::StageDecoration>(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_THAT(result(),
HasSubstr(R"(RWByteAddressBuffer coord : register(u0, space1);
EXPECT_EQ(result(),
R"(
RWByteAddressBuffer coord : register(u0, space1);
void frag_main() {
float v = asfloat(coord.Load(4));
float v = asfloat(coord.Load(4u));
return;
})"));
}
)");
Validate();
}
@ -456,16 +459,19 @@ TEST_F(HlslGeneratorImplTest_Function,
create<ast::StageDecoration>(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_THAT(result(),
HasSubstr(R"(ByteAddressBuffer coord : register(t0, space1);
EXPECT_EQ(result(),
R"(
ByteAddressBuffer coord : register(t0, space1);
void frag_main() {
float v = asfloat(coord.Load(4));
float v = asfloat(coord.Load(4u));
return;
})"));
}
)");
Validate();
}
@ -495,16 +501,19 @@ TEST_F(HlslGeneratorImplTest_Function,
create<ast::StageDecoration>(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_THAT(result(),
HasSubstr(R"(RWByteAddressBuffer coord : register(u0, space1);
EXPECT_EQ(result(),
R"(
RWByteAddressBuffer coord : register(u0, space1);
void frag_main() {
coord.Store(4, asuint(2.0f));
coord.Store(4u, asuint(2.0f));
return;
})"));
}
)");
Validate();
}
@ -534,16 +543,19 @@ TEST_F(HlslGeneratorImplTest_Function,
create<ast::StageDecoration>(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_THAT(result(),
HasSubstr(R"(RWByteAddressBuffer coord : register(u0, space1);
EXPECT_EQ(result(),
R"(
RWByteAddressBuffer coord : register(u0, space1);
void frag_main() {
coord.Store(4, asuint(2.0f));
coord.Store(4u, asuint(2.0f));
return;
})"));
}
)");
Validate();
}
@ -792,20 +804,22 @@ TEST_F(HlslGeneratorImplTest_Function,
create<ast::StageDecoration>(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = Build();
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_THAT(result(),
HasSubstr(R"(RWByteAddressBuffer coord : register(u0, space1);
EXPECT_EQ(result(),
R"(RWByteAddressBuffer coord : register(u0, space1);
float sub_func(float param) {
return asfloat(coord.Load((4 * 0)));
return asfloat(coord.Load(0u));
}
void frag_main() {
float v = sub_func(1.0f);
return;
})"));
}
)");
Validate();
}
@ -946,11 +960,13 @@ TEST_F(HlslGeneratorImplTest_Function,
//
// [[stage(compute)]]
// fn a() {
// var v = data.d;
// return;
// }
//
// [[stage(compute)]]
// fn b() {
// var v = data.d;
// return;
// }
@ -994,7 +1010,7 @@ TEST_F(HlslGeneratorImplTest_Function,
});
}
GeneratorImpl& gen = Build();
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_EQ(result(), R"(
@ -1002,13 +1018,13 @@ RWByteAddressBuffer data : register(u0, space0);
[numthreads(1, 1, 1)]
void a() {
float v = asfloat(data.Load(0));
float v = asfloat(data.Load(0u));
return;
}
[numthreads(1, 1, 1)]
void b() {
float v = asfloat(data.Load(0));
float v = asfloat(data.Load(0u));
return;
}

File diff suppressed because it is too large Load Diff

View File

@ -13,7 +13,9 @@
// limitations under the License.
#include "src/ast/stage_decoration.h"
#include "src/ast/struct_block_decoration.h"
#include "src/ast/variable_decl_statement.h"
#include "src/type/access_control_type.h"
#include "src/writer/hlsl/test_helper.h"
namespace tint {
@ -23,6 +25,54 @@ namespace {
using HlslSanitizerTest = TestHelper;
TEST_F(HlslSanitizerTest, ArrayLength) {
auto* sb_ty = Structure("SB",
{
Member("x", ty.f32()),
Member("arr", ty.array(ty.vec4<f32>())),
},
{
create<ast::StructBlockDecoration>(),
});
auto* ac_ty =
create<type::AccessControl>(ast::AccessControl::kReadOnly, sb_ty);
Global("sb", ac_ty, ast::StorageClass::kStorage, nullptr,
ast::DecorationList{
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(1),
});
Func("main", ast::VariableList{}, ty.void_(),
ast::StatementList{
create<ast::VariableDeclStatement>(
Var("len", ty.u32(), ast::StorageClass::kFunction,
Call("arrayLength", MemberAccessor("sb", "arr")))),
},
ast::DecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex),
});
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
auto got = result();
auto* expect = R"(
ByteAddressBuffer sb : register(t0, space1);
void main() {
uint tint_symbol_9 = 0u;
sb.GetDimensions(tint_symbol_9);
const uint tint_symbol_10 = ((tint_symbol_9 - 16u) / 16u);
uint len = tint_symbol_10;
return;
}
)";
EXPECT_EQ(expect, got);
}
TEST_F(HlslSanitizerTest, PromoteArrayInitializerToConstVar) {
auto* array_init = array<i32, 4>(1, 2, 3, 4);
auto* array_index = IndexAccessor(array_init, 3);

View File

@ -37,7 +37,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Alias) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, alias, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, alias, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "alias");
}
@ -46,7 +47,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Alias_NameCollision) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, alias, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, alias, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "bool_tint_0");
}
@ -55,7 +57,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Array) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, arr, "ary")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "ary"))
<< gen.error();
EXPECT_EQ(result(), "bool ary[4]");
}
@ -64,7 +67,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_ArrayOfArray) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, arr, "ary")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "ary"))
<< gen.error();
EXPECT_EQ(result(), "bool ary[5][4]");
}
@ -75,7 +79,8 @@ TEST_F(HlslGeneratorImplTest_Type,
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, arr, "ary")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "ary"))
<< gen.error();
EXPECT_EQ(result(), "bool ary[5][4][1]");
}
@ -84,7 +89,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_ArrayOfArrayOfArray) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, arr, "ary")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "ary"))
<< gen.error();
EXPECT_EQ(result(), "bool ary[6][5][4]");
}
@ -93,7 +99,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Array_NameCollision) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, arr, "bool")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "bool"))
<< gen.error();
EXPECT_EQ(result(), "bool bool_tint_0[4]");
}
@ -102,7 +109,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Array_WithoutName) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, arr, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "bool[4]");
}
@ -111,7 +119,8 @@ TEST_F(HlslGeneratorImplTest_Type, DISABLED_EmitType_RuntimeArray) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, arr, "ary")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "ary"))
<< gen.error();
EXPECT_EQ(result(), "bool ary[]");
}
@ -121,7 +130,8 @@ TEST_F(HlslGeneratorImplTest_Type,
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, arr, "double")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, arr, ast::StorageClass::kNone, "double"))
<< gen.error();
EXPECT_EQ(result(), "bool double_tint_0[]");
}
@ -130,7 +140,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Bool) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, bool_, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, bool_, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "bool");
}
@ -139,7 +150,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_F32) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, f32, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, f32, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "float");
}
@ -148,7 +160,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_I32) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, i32, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, i32, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "int");
}
@ -157,7 +170,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Matrix) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, mat2x3, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, mat2x3, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "float2x3");
}
@ -167,7 +181,8 @@ TEST_F(HlslGeneratorImplTest_Type, DISABLED_EmitType_Pointer) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, &p, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, &p, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "float*");
}
@ -210,7 +225,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Struct) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, s, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, s, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "S");
}
@ -227,7 +243,8 @@ TEST_F(HlslGeneratorImplTest_Type, DISABLED_EmitType_Struct_InjectPadding) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, s, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, s, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(gen.result(), R"(struct S {
int a;
int8_t pad_0[28];
@ -280,7 +297,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_U32) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, u32, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, u32, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "uint");
}
@ -289,7 +307,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Vector) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, vec3, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, vec3, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "float3");
}
@ -298,7 +317,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Void) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, void_, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, void_, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "void");
}
@ -307,7 +327,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitSampler) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, &sampler, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, &sampler, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "SamplerState");
}
@ -316,7 +337,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitSamplerComparison) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, &sampler, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, &sampler, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "SamplerComparisonState");
}
@ -419,7 +441,8 @@ TEST_F(HlslGeneratorImplTest_Type, EmitMultisampledTexture) {
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitType(out, &s, "")) << gen.error();
ASSERT_TRUE(gen.EmitType(out, &s, ast::StorageClass::kNone, ""))
<< gen.error();
EXPECT_EQ(result(), "Texture2DMS<float4>");
}