writer/hlsl: Emit UBO as an array of vector

Instead of a ConstantBuffer.

HLSL requires that each structure field in a UBO is 16 byte aligned.
WGSL has much looser constraints with its UBO field alignment rules.

Instead generate an array of uint4 vectors, and index into this, much
like we index into [RW]ByteAddressBuffers for SSBOs.

Extend the DecomposeStorageAccess transform to support uniforms too.
This has been renamed to DecomposeMemoryAccess.

Change-Id: I3868ff80af1ab3b3dddfbf5b969724cb87ef0744
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/55246
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
Ben Clayton
2021-06-18 21:15:25 +00:00
parent 9efc4fcc89
commit 165512c57e
23 changed files with 1145 additions and 1393 deletions

View File

@@ -556,8 +556,8 @@ libtint_source_set("libtint_core_all_src") {
"transform/calculate_array_length.h",
"transform/canonicalize_entry_point_io.cc",
"transform/canonicalize_entry_point_io.h",
"transform/decompose_storage_access.cc",
"transform/decompose_storage_access.h",
"transform/decompose_memory_access.cc",
"transform/decompose_memory_access.h",
"transform/external_texture_transform.cc",
"transform/external_texture_transform.h",
"transform/first_index_offset.cc",

View File

@@ -281,8 +281,8 @@ set(TINT_LIB_SRCS
transform/calculate_array_length.h
transform/canonicalize_entry_point_io.cc
transform/canonicalize_entry_point_io.h
transform/decompose_storage_access.cc
transform/decompose_storage_access.h
transform/decompose_memory_access.cc
transform/decompose_memory_access.h
transform/external_texture_transform.cc
transform/external_texture_transform.h
transform/first_index_offset.cc
@@ -860,7 +860,7 @@ if(${TINT_BUILD_TESTS})
transform/bound_array_accessors_test.cc
transform/calculate_array_length_test.cc
transform/canonicalize_entry_point_io_test.cc
transform/decompose_storage_access_test.cc
transform/decompose_memory_access_test.cc
transform/external_texture_transform_test.cc
transform/first_index_offset_test.cc
transform/fold_constants_test.cc

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/decompose_storage_access.h"
#include "src/transform/decompose_memory_access.h"
#include <memory>
#include <string>
@@ -38,7 +38,7 @@
#include "src/utils/get_or_create.h"
#include "src/utils/hash.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStorageAccess::Intrinsic);
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess::Intrinsic);
namespace tint {
namespace transform {
@@ -46,7 +46,7 @@ namespace transform {
namespace {
/// Offset is a simple ast::Expression builder interface, used to build byte
/// offsets for storage buffer accesses.
/// offsets for storage and uniform buffer accesses.
struct Offset : Castable<Offset> {
/// @returns builds and returns the ast::Expression in `ctx.dst`
virtual ast::Expression* Build(CloneContext& ctx) = 0;
@@ -179,14 +179,16 @@ std::unique_ptr<Offset> Mul(LHS&& lhs_, RHS&& rhs_) {
/// LoadStoreKey is the unordered map key to a load or store intrinsic.
struct LoadStoreKey {
sem::Type const* buf_ty; // buffer type
sem::Type const* el_ty; // element type
ast::StorageClass const storage_class; // buffer storage class
sem::Type const* buf_ty; // buffer type
sem::Type const* el_ty; // element type
bool operator==(const LoadStoreKey& rhs) const {
return buf_ty == rhs.buf_ty && el_ty == rhs.el_ty;
return storage_class == rhs.storage_class && buf_ty == rhs.buf_ty &&
el_ty == rhs.el_ty;
}
struct Hasher {
inline std::size_t operator()(const LoadStoreKey& u) const {
return utils::Hash(u.buf_ty, u.el_ty);
return utils::Hash(u.storage_class, u.buf_ty, u.el_ty);
}
};
};
@@ -218,60 +220,60 @@ uint32_t MatrixColumnStride(const sem::Matrix* mat) {
}
bool IntrinsicDataTypeFor(const sem::Type* ty,
DecomposeStorageAccess::Intrinsic::DataType& out) {
DecomposeMemoryAccess::Intrinsic::DataType& out) {
if (ty->Is<sem::I32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kI32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kI32;
return true;
}
if (ty->Is<sem::U32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kU32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kU32;
return true;
}
if (ty->Is<sem::F32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kF32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kF32;
return true;
}
if (auto* vec = ty->As<sem::Vector>()) {
switch (vec->size()) {
case 2:
if (vec->type()->Is<sem::I32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec2I32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2I32;
return true;
}
if (vec->type()->Is<sem::U32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec2U32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2U32;
return true;
}
if (vec->type()->Is<sem::F32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec2F32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2F32;
return true;
}
break;
case 3:
if (vec->type()->Is<sem::I32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec3I32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3I32;
return true;
}
if (vec->type()->Is<sem::U32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec3U32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3U32;
return true;
}
if (vec->type()->Is<sem::F32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec3F32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3F32;
return true;
}
break;
case 4:
if (vec->type()->Is<sem::I32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec4I32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4I32;
return true;
}
if (vec->type()->Is<sem::U32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec4U32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4U32;
return true;
}
if (vec->type()->Is<sem::F32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec4F32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4F32;
return true;
}
break;
@@ -282,80 +284,86 @@ bool IntrinsicDataTypeFor(const sem::Type* ty,
return false;
}
/// @returns a DecomposeStorageAccess::Intrinsic decoration that can be applied
/// @returns a DecomposeMemoryAccess::Intrinsic decoration that can be applied
/// to a stub function to load the type `ty`.
DecomposeStorageAccess::Intrinsic* IntrinsicLoadFor(ProgramBuilder* builder,
const sem::Type* ty) {
DecomposeStorageAccess::Intrinsic::DataType type;
DecomposeMemoryAccess::Intrinsic* IntrinsicLoadFor(
ProgramBuilder* builder,
ast::StorageClass storage_class,
const sem::Type* ty) {
DecomposeMemoryAccess::Intrinsic::DataType type;
if (!IntrinsicDataTypeFor(ty, type)) {
return nullptr;
}
return builder->ASTNodes().Create<DecomposeStorageAccess::Intrinsic>(
builder->ID(), DecomposeStorageAccess::Intrinsic::Op::kLoad, type);
return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kLoad, storage_class,
type);
}
/// @returns a DecomposeStorageAccess::Intrinsic decoration that can be applied
/// @returns a DecomposeMemoryAccess::Intrinsic decoration that can be applied
/// to a stub function to store the type `ty`.
DecomposeStorageAccess::Intrinsic* IntrinsicStoreFor(ProgramBuilder* builder,
const sem::Type* ty) {
DecomposeStorageAccess::Intrinsic::DataType type;
DecomposeMemoryAccess::Intrinsic* IntrinsicStoreFor(
ProgramBuilder* builder,
ast::StorageClass storage_class,
const sem::Type* ty) {
DecomposeMemoryAccess::Intrinsic::DataType type;
if (!IntrinsicDataTypeFor(ty, type)) {
return nullptr;
}
return builder->ASTNodes().Create<DecomposeStorageAccess::Intrinsic>(
builder->ID(), DecomposeStorageAccess::Intrinsic::Op::kStore, type);
return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kStore,
storage_class, type);
}
/// @returns a DecomposeStorageAccess::Intrinsic decoration that can be applied
/// @returns a DecomposeMemoryAccess::Intrinsic decoration that can be applied
/// to a stub function for the atomic op and the type `ty`.
DecomposeStorageAccess::Intrinsic* IntrinsicAtomicFor(ProgramBuilder* builder,
sem::IntrinsicType ity,
const sem::Type* ty) {
auto op = DecomposeStorageAccess::Intrinsic::Op::kAtomicLoad;
DecomposeMemoryAccess::Intrinsic* IntrinsicAtomicFor(ProgramBuilder* builder,
sem::IntrinsicType ity,
const sem::Type* ty) {
auto op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
switch (ity) {
case sem::IntrinsicType::kAtomicLoad:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicLoad;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
break;
case sem::IntrinsicType::kAtomicStore:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicStore;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicStore;
break;
case sem::IntrinsicType::kAtomicAdd:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicAdd;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAdd;
break;
case sem::IntrinsicType::kAtomicMax:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicMax;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMax;
break;
case sem::IntrinsicType::kAtomicMin:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicMin;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMin;
break;
case sem::IntrinsicType::kAtomicAnd:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicAnd;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAnd;
break;
case sem::IntrinsicType::kAtomicOr:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicOr;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicOr;
break;
case sem::IntrinsicType::kAtomicXor:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicXor;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicXor;
break;
case sem::IntrinsicType::kAtomicExchange:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicExchange;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicExchange;
break;
case sem::IntrinsicType::kAtomicCompareExchangeWeak:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicCompareExchangeWeak;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicCompareExchangeWeak;
break;
default:
TINT_ICE(builder->Diagnostics())
<< "invalid IntrinsicType for DecomposeStorageAccess::Intrinsic: "
<< "invalid IntrinsicType for DecomposeMemoryAccess::Intrinsic: "
<< ty->type_name();
break;
}
DecomposeStorageAccess::Intrinsic::DataType type;
DecomposeMemoryAccess::Intrinsic::DataType type;
if (!IntrinsicDataTypeFor(ty, type)) {
return nullptr;
}
return builder->ASTNodes().Create<DecomposeStorageAccess::Intrinsic>(
builder->ID(), op, type);
return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
builder->ID(), op, ast::StorageClass::kStorage, type);
}
/// Inserts `node` before `insert_after` in the global declarations of
@@ -387,30 +395,30 @@ const ast::TypeDecl* TypeDeclOf(const sem::Type* ty) {
}
}
/// StorageBufferAccess describes a single storage buffer access
struct StorageBufferAccess {
/// BufferAccess describes a single storage or uniform buffer access
struct BufferAccess {
sem::Expression const* var = nullptr; // Storage buffer variable
std::unique_ptr<Offset> offset; // The byte offset on var
sem::Type const* type = nullptr; // The type of the access
operator bool() const { return var; } // Returns true if valid
};
/// Store describes a single storage buffer write
/// Store describes a single storage or uniform buffer write
struct Store {
ast::AssignmentStatement* assignment; // The AST assignment statement
StorageBufferAccess target; // The target for the write
BufferAccess target; // The target for the write
};
} // namespace
/// State holds the current transform state
struct DecomposeStorageAccess::State {
/// Map of AST expression to storage buffer access
struct DecomposeMemoryAccess::State {
/// Map of AST expression to storage or uniform buffer access
/// This map has entries added when encountered, and removed when outer
/// expressions chain the access.
/// Subset of #expression_order, as expressions are not removed from
/// #expression_order.
std::unordered_map<ast::Expression*, StorageBufferAccess> accesses;
std::unordered_map<ast::Expression*, BufferAccess> accesses;
/// The visited order of AST expressions (superset of #accesses)
std::vector<ast::Expression*> expression_order;
/// [buffer-type, element-type] -> load function name
@@ -419,25 +427,25 @@ struct DecomposeStorageAccess::State {
std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> store_funcs;
/// [buffer-type, element-type, atomic-op] -> load function name
std::unordered_map<AtomicKey, Symbol, AtomicKey::Hasher> atomic_funcs;
/// List of storage buffer writes
/// List of storage or uniform buffer writes
std::vector<Store> stores;
/// AddAccess() adds the `expr -> access` map item to #accesses, and `expr`
/// to #expression_order.
/// @param expr the expression that performs the access
/// @param access the access
void AddAccess(ast::Expression* expr, StorageBufferAccess&& access) {
void AddAccess(ast::Expression* expr, BufferAccess&& access) {
TINT_ASSERT(access.type);
accesses.emplace(expr, std::move(access));
expression_order.emplace_back(expr);
}
/// TakeAccess() removes the `node` item from #accesses (if it exists),
/// returning the StorageBufferAccess. If #accesses does not hold an item for
/// `node`, an invalid StorageBufferAccess is returned.
/// returning the BufferAccess. If #accesses does not hold an item for
/// `node`, an invalid BufferAccess is returned.
/// @param node the expression that performed an access
/// @return the StorageBufferAccess for the given expression
StorageBufferAccess TakeAccess(ast::Expression* node) {
/// @return the BufferAccess for the given expression
BufferAccess TakeAccess(ast::Expression* node) {
auto lhs_it = accesses.find(node);
if (lhs_it == accesses.end()) {
return {};
@@ -448,12 +456,13 @@ struct DecomposeStorageAccess::State {
}
/// LoadFunc() returns a symbol to an intrinsic function that loads an element
/// of type `el_ty` from a storage buffer of type `buf_ty`. The function has
/// the signature: `fn load(buf : buf_ty, offset : u32) -> el_ty`
/// of type `el_ty` from a storage or uniform buffer of type `buf_ty`.
/// The emitted function has the signature:
/// `fn load(buf : buf_ty, offset : u32) -> el_ty`
/// @param ctx the CloneContext
/// @param insert_after the user-declared type to insert the function after
/// @param buf_ty the storage buffer type
/// @param el_ty the storage buffer element type
/// @param buf_ty the storage or uniform buffer type
/// @param el_ty the storage or uniform buffer element type
/// @param var_user the variable user
/// @return the name of the function that performs the load
Symbol LoadFunc(CloneContext& ctx,
@@ -461,71 +470,79 @@ struct DecomposeStorageAccess::State {
const sem::Type* buf_ty,
const sem::Type* el_ty,
const sem::VariableUser* var_user) {
return utils::GetOrCreate(load_funcs, LoadStoreKey{buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(&ctx, buf_ty);
auto storage_class = var_user->Variable()->StorageClass();
return utils::GetOrCreate(
load_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(&ctx, buf_ty);
ast::VariableList params = {
// Note: The buffer parameter requires the kStorage StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
var_user->Variable()->Access(), buf_ast_ty, true, nullptr,
ast::DecorationList{}),
ctx.dst->Param("offset", ctx.dst->ty.u32()),
};
ast::VariableList params = {
// Note: The buffer parameter requires the StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer or cbuffer
// array.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), storage_class,
var_user->Variable()->Access(), buf_ast_ty, true, nullptr,
ast::DecorationList{}),
ctx.dst->Param("offset", ctx.dst->ty.u32()),
};
ast::Function* func = nullptr;
if (auto* intrinsic = IntrinsicLoadFor(ctx.dst, el_ty)) {
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, el_ast_ty, nullptr,
ast::DecorationList{
intrinsic,
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kFunctionHasNoBody),
},
ast::DecorationList{});
} else {
ast::ExpressionList values;
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol load = LoadFunc(ctx, insert_after, buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset =
ctx.dst->Add("offset", i * MatrixColumnStride(mat_ty));
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
ast::Function* func = nullptr;
if (auto* intrinsic =
IntrinsicLoadFor(ctx.dst, storage_class, el_ty)) {
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, el_ast_ty, nullptr,
ast::DecorationList{
intrinsic,
ctx.dst->ASTNodes()
.Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kFunctionHasNoBody),
},
ast::DecorationList{});
} else {
ast::ExpressionList values;
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol load =
LoadFunc(ctx, insert_after, buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset =
ctx.dst->Add("offset", i * MatrixColumnStride(mat_ty));
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto* offset = ctx.dst->Add("offset", member->Offset());
Symbol load = LoadFunc(ctx, insert_after, buf_ty,
member->Type()->UnwrapRef(), var_user);
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
} else if (auto* arr = el_ty->As<sem::Array>()) {
for (uint32_t i = 0; i < arr->Count(); i++) {
auto* offset = ctx.dst->Add("offset", arr->Stride() * i);
Symbol load = LoadFunc(ctx, insert_after, buf_ty,
arr->ElemType()->UnwrapRef(), var_user);
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
}
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, el_ast_ty,
ctx.dst->Block(ctx.dst->Return(
ctx.dst->create<ast::TypeConstructorExpression>(
CreateASTTypeFor(&ctx, el_ty), values))),
ast::DecorationList{}, ast::DecorationList{});
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto* offset = ctx.dst->Add("offset", member->Offset());
Symbol load = LoadFunc(ctx, insert_after, buf_ty,
member->Type()->UnwrapRef(), var_user);
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
} else if (auto* arr = el_ty->As<sem::Array>()) {
for (uint32_t i = 0; i < arr->Count(); i++) {
auto* offset = ctx.dst->Add("offset", arr->Stride() * i);
Symbol load = LoadFunc(ctx, insert_after, buf_ty,
arr->ElemType()->UnwrapRef(), var_user);
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
}
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, el_ast_ty,
ctx.dst->Block(
ctx.dst->Return(ctx.dst->create<ast::TypeConstructorExpression>(
CreateASTTypeFor(&ctx, el_ty), values))),
ast::DecorationList{}, ast::DecorationList{});
}
InsertGlobal(ctx, insert_after, func);
return func->symbol();
});
InsertGlobal(ctx, insert_after, func);
return func->symbol();
});
}
/// StoreFunc() returns a symbol to an intrinsic function that stores an
/// element of type `el_ty` to a storage buffer of type `buf_ty`. The function
/// has the signature: `fn store(buf : buf_ty, offset : u32, value : el_ty)`
/// element of type `el_ty` to a storage buffer of type `buf_ty`.
/// The function has the signature:
/// `fn store(buf : buf_ty, offset : u32, value : el_ty)`
/// @param ctx the CloneContext
/// @param insert_after the user-declared type to insert the function after
/// @param buf_ty the storage buffer type
@@ -537,70 +554,79 @@ struct DecomposeStorageAccess::State {
const sem::Type* buf_ty,
const sem::Type* el_ty,
const sem::VariableUser* var_user) {
return utils::GetOrCreate(store_funcs, LoadStoreKey{buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(&ctx, buf_ty);
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
ast::VariableList params{
// Note: The buffer parameter requires the kStorage StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
var_user->Variable()->Access(), buf_ast_ty, true, nullptr,
ast::DecorationList{}),
ctx.dst->Param("offset", ctx.dst->ty.u32()),
ctx.dst->Param("value", el_ast_ty),
};
ast::Function* func = nullptr;
if (auto* intrinsic = IntrinsicStoreFor(ctx.dst, el_ty)) {
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, ctx.dst->ty.void_(), nullptr,
ast::DecorationList{
intrinsic,
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kFunctionHasNoBody),
},
ast::DecorationList{});
auto storage_class = var_user->Variable()->StorageClass();
return utils::GetOrCreate(
store_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(&ctx, buf_ty);
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
ast::VariableList params{
// Note: The buffer parameter requires the StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), storage_class,
var_user->Variable()->Access(), buf_ast_ty, true, nullptr,
ast::DecorationList{}),
ctx.dst->Param("offset", ctx.dst->ty.u32()),
ctx.dst->Param("value", el_ast_ty),
};
ast::Function* func = nullptr;
if (auto* intrinsic =
IntrinsicStoreFor(ctx.dst, storage_class, el_ty)) {
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, ctx.dst->ty.void_(), nullptr,
ast::DecorationList{
intrinsic,
ctx.dst->ASTNodes()
.Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kFunctionHasNoBody),
},
ast::DecorationList{});
} else {
ast::StatementList body;
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol store = StoreFunc(ctx, insert_after, buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset =
ctx.dst->Add("offset", i * MatrixColumnStride(mat_ty));
auto* access = ctx.dst->IndexAccessor("value", i);
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
} else {
ast::StatementList body;
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol store =
StoreFunc(ctx, insert_after, buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset =
ctx.dst->Add("offset", i * MatrixColumnStride(mat_ty));
auto* access = ctx.dst->IndexAccessor("value", i);
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto* offset = ctx.dst->Add("offset", member->Offset());
auto* access = ctx.dst->MemberAccessor(
"value", ctx.Clone(member->Declaration()->symbol()));
Symbol store = StoreFunc(ctx, insert_after, buf_ty,
member->Type()->UnwrapRef(), var_user);
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
}
} else if (auto* arr = el_ty->As<sem::Array>()) {
for (uint32_t i = 0; i < arr->Count(); i++) {
auto* offset = ctx.dst->Add("offset", arr->Stride() * i);
auto* access =
ctx.dst->IndexAccessor("value", ctx.dst->Expr(i));
Symbol store =
StoreFunc(ctx, insert_after, buf_ty,
arr->ElemType()->UnwrapRef(), var_user);
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
}
}
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, ctx.dst->ty.void_(),
ctx.dst->Block(body), ast::DecorationList{},
ast::DecorationList{});
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto* offset = ctx.dst->Add("offset", member->Offset());
auto* access = ctx.dst->MemberAccessor(
"value", ctx.Clone(member->Declaration()->symbol()));
Symbol store = StoreFunc(ctx, insert_after, buf_ty,
member->Type()->UnwrapRef(), var_user);
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
}
} else if (auto* arr = el_ty->As<sem::Array>()) {
for (uint32_t i = 0; i < arr->Count(); i++) {
auto* offset = ctx.dst->Add("offset", arr->Stride() * i);
auto* access = ctx.dst->IndexAccessor("value", ctx.dst->Expr(i));
Symbol store = StoreFunc(ctx, insert_after, buf_ty,
arr->ElemType()->UnwrapRef(), var_user);
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
}
}
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, ctx.dst->ty.void_(), ctx.dst->Block(body),
ast::DecorationList{}, ast::DecorationList{});
}
InsertGlobal(ctx, insert_after, func);
return func->symbol();
});
InsertGlobal(ctx, insert_after, func);
return func->symbol();
});
}
/// AtomicFunc() returns a symbol to an intrinsic function that performs an
@@ -667,12 +693,13 @@ struct DecomposeStorageAccess::State {
}
};
DecomposeStorageAccess::Intrinsic::Intrinsic(ProgramID program_id,
Op o,
DataType ty)
: Base(program_id), op(o), type(ty) {}
DecomposeStorageAccess::Intrinsic::~Intrinsic() = default;
std::string DecomposeStorageAccess::Intrinsic::InternalName() const {
DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID program_id,
Op o,
ast::StorageClass sc,
DataType ty)
: Base(program_id), op(o), storage_class(sc), type(ty) {}
DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default;
std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
std::stringstream ss;
switch (op) {
case Op::kLoad:
@@ -712,6 +739,7 @@ std::string DecomposeStorageAccess::Intrinsic::InternalName() const {
ss << "intrinsic_atomic_compare_exchange_weak_";
break;
}
ss << storage_class << "_";
switch (type) {
case DataType::kU32:
ss << "u32";
@@ -753,16 +781,16 @@ std::string DecomposeStorageAccess::Intrinsic::InternalName() const {
return ss.str();
}
DecomposeStorageAccess::Intrinsic* DecomposeStorageAccess::Intrinsic::Clone(
DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
CloneContext* ctx) const {
return ctx->dst->ASTNodes().Create<DecomposeStorageAccess::Intrinsic>(
ctx->dst->ID(), op, type);
return ctx->dst->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
ctx->dst->ID(), op, storage_class, type);
}
DecomposeStorageAccess::DecomposeStorageAccess() = default;
DecomposeStorageAccess::~DecomposeStorageAccess() = default;
DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
Output DecomposeMemoryAccess::Run(const Program* in, const DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
@@ -770,9 +798,10 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
State state;
// Scan the AST nodes for storage buffer accesses. Complex expression chains
// (e.g. `storage_buffer.foo.bar[20].x`) are handled by maintaining an offset
// chain via the `state.TakeAccess()`, `state.AddAccess()` methods.
// Scan the AST nodes for storage and uniform buffer accesses. Complex
// expression chains (e.g. `storage_buffer.foo.bar[20].x`) are handled by
// maintaining an offset chain via the `state.TakeAccess()`,
// `state.AddAccess()` methods.
//
// Inner-most expression nodes are guaranteed to be visited first because AST
// nodes are fully immutable and require their children to be constructed
@@ -781,8 +810,9 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
if (auto* ident = node->As<ast::IdentifierExpression>()) {
// X
if (auto* var = sem.Get<sem::VariableUser>(ident)) {
if (var->Variable()->StorageClass() == ast::StorageClass::kStorage) {
// Variable to a storage buffer
if (var->Variable()->StorageClass() == ast::StorageClass::kStorage ||
var->Variable()->StorageClass() == ast::StorageClass::kUniform) {
// Variable to a storage or uniform buffer
state.AddAccess(ident, {
var,
ToOffset(0u),
@@ -940,7 +970,7 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
ctx.Replace(expr, load);
}
// And replace all storage buffer assignments with stores
// And replace all storage and uniform buffer assignments with stores
for (auto& store : state.stores) {
auto* buf = store.target.var->Declaration();
auto* offset = store.target.offset->Build(ctx);

View File

@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_DECOMPOSE_STORAGE_ACCESS_H_
#define SRC_TRANSFORM_DECOMPOSE_STORAGE_ACCESS_H_
#ifndef SRC_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_
#define SRC_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_
#include <string>
@@ -27,10 +27,10 @@ class CloneContext;
namespace transform {
/// DecomposeStorageAccess is a transform used to replace storage buffer
/// accesses with a combination of load, store or atomic functions on primitive
/// types.
class DecomposeStorageAccess : public Transform {
/// DecomposeMemoryAccess is a transform used to replace storage and uniform
/// buffer accesses with a combination of load, store or atomic functions on
/// primitive types.
class DecomposeMemoryAccess : public Transform {
public:
/// Intrinsic is an InternalDecoration that's used to decorate a stub function
/// so that the HLSL transforms this into calls to
@@ -73,8 +73,9 @@ class DecomposeStorageAccess : public Transform {
/// Constructor
/// @param program_id the identifier of the program that owns this node
/// @param o the op of the intrinsic
/// @param sc the storage class of the buffer
/// @param ty the data type of the intrinsic
Intrinsic(ProgramID program_id, Op o, DataType ty);
Intrinsic(ProgramID program_id, Op o, ast::StorageClass sc, DataType ty);
/// Destructor
~Intrinsic() override;
@@ -90,14 +91,17 @@ class DecomposeStorageAccess : public Transform {
/// The op of the intrinsic
Op const op;
/// The storage class of the buffer this intrinsic operates on
ast::StorageClass const storage_class;
/// The type of the intrinsic
DataType const type;
};
/// Constructor
DecomposeStorageAccess();
DecomposeMemoryAccess();
/// Destructor
~DecomposeStorageAccess() override;
~DecomposeMemoryAccess() override;
/// Runs the transform on `program`, returning the transformation result.
/// @param program the source program to transform
@@ -111,4 +115,4 @@ class DecomposeStorageAccess : public Transform {
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_DECOMPOSE_STORAGE_ACCESS_H_
#endif // SRC_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/decompose_storage_access.h"
#include "src/transform/decompose_memory_access.h"
#include "src/transform/test_helper.h"
@@ -20,9 +20,9 @@ namespace tint {
namespace transform {
namespace {
using DecomposeStorageAccessTest = TransformTest;
using DecomposeMemoryAccessTest = TransformTest;
TEST_F(DecomposeStorageAccessTest, BasicLoad) {
TEST_F(DecomposeMemoryAccessTest, SB_BasicLoad) {
auto* src = R"(
[[block]]
struct SB {
@@ -106,40 +106,40 @@ struct SB {
v : array<vec3<f32>, 2>;
};
[[internal(intrinsic_load_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32) -> i32
[[internal(intrinsic_load_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_1(buffer : SB, offset : u32) -> u32
[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_2(buffer : SB, offset : u32) -> f32
[[internal(intrinsic_load_vec2_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec2_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_3(buffer : SB, offset : u32) -> vec2<i32>
[[internal(intrinsic_load_vec2_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec2_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_4(buffer : SB, offset : u32) -> vec2<u32>
[[internal(intrinsic_load_vec2_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec2_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_5(buffer : SB, offset : u32) -> vec2<f32>
[[internal(intrinsic_load_vec3_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec3_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_6(buffer : SB, offset : u32) -> vec3<i32>
[[internal(intrinsic_load_vec3_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec3_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_7(buffer : SB, offset : u32) -> vec3<u32>
[[internal(intrinsic_load_vec3_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec3_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_8(buffer : SB, offset : u32) -> vec3<f32>
[[internal(intrinsic_load_vec4_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec4_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_9(buffer : SB, offset : u32) -> vec4<i32>
[[internal(intrinsic_load_vec4_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec4_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_10(buffer : SB, offset : u32) -> vec4<u32>
[[internal(intrinsic_load_vec4_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec4_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_11(buffer : SB, offset : u32) -> vec4<f32>
fn tint_symbol_12(buffer : SB, offset : u32) -> mat2x2<f32> {
@@ -211,12 +211,206 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, BasicStore) {
TEST_F(DecomposeMemoryAccessTest, UB_BasicLoad) {
auto* src = R"(
[[block]]
struct UB {
a : i32;
b : u32;
c : f32;
d : vec2<i32>;
e : vec2<u32>;
f : vec2<f32>;
g : vec3<i32>;
h : vec3<u32>;
i : vec3<f32>;
j : vec4<i32>;
k : vec4<u32>;
l : vec4<f32>;
m : mat2x2<f32>;
n : mat2x3<f32>;
o : mat2x4<f32>;
p : mat3x2<f32>;
q : mat3x3<f32>;
r : mat3x4<f32>;
s : mat4x2<f32>;
t : mat4x3<f32>;
u : mat4x4<f32>;
v : array<vec3<f32>, 2>;
};
[[group(0), binding(0)]] var<uniform> ub : UB;
[[stage(compute)]]
fn main() {
var a : i32 = ub.a;
var b : u32 = ub.b;
var c : f32 = ub.c;
var d : vec2<i32> = ub.d;
var e : vec2<u32> = ub.e;
var f : vec2<f32> = ub.f;
var g : vec3<i32> = ub.g;
var h : vec3<u32> = ub.h;
var i : vec3<f32> = ub.i;
var j : vec4<i32> = ub.j;
var k : vec4<u32> = ub.k;
var l : vec4<f32> = ub.l;
var m : mat2x2<f32> = ub.m;
var n : mat2x3<f32> = ub.n;
var o : mat2x4<f32> = ub.o;
var p : mat3x2<f32> = ub.p;
var q : mat3x3<f32> = ub.q;
var r : mat3x4<f32> = ub.r;
var s : mat4x2<f32> = ub.s;
var t : mat4x3<f32> = ub.t;
var u : mat4x4<f32> = ub.u;
var v : array<vec3<f32>, 2> = ub.v;
}
)";
auto* expect = R"(
[[block]]
struct UB {
a : i32;
b : u32;
c : f32;
d : vec2<i32>;
e : vec2<u32>;
f : vec2<f32>;
g : vec3<i32>;
h : vec3<u32>;
i : vec3<f32>;
j : vec4<i32>;
k : vec4<u32>;
l : vec4<f32>;
m : mat2x2<f32>;
n : mat2x3<f32>;
o : mat2x4<f32>;
p : mat3x2<f32>;
q : mat3x3<f32>;
r : mat3x4<f32>;
s : mat4x2<f32>;
t : mat4x3<f32>;
u : mat4x4<f32>;
v : array<vec3<f32>, 2>;
};
[[internal(intrinsic_load_uniform_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : UB, offset : u32) -> i32
[[internal(intrinsic_load_uniform_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_1(buffer : UB, offset : u32) -> u32
[[internal(intrinsic_load_uniform_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_2(buffer : UB, offset : u32) -> f32
[[internal(intrinsic_load_uniform_vec2_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_3(buffer : UB, offset : u32) -> vec2<i32>
[[internal(intrinsic_load_uniform_vec2_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_4(buffer : UB, offset : u32) -> vec2<u32>
[[internal(intrinsic_load_uniform_vec2_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_5(buffer : UB, offset : u32) -> vec2<f32>
[[internal(intrinsic_load_uniform_vec3_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_6(buffer : UB, offset : u32) -> vec3<i32>
[[internal(intrinsic_load_uniform_vec3_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_7(buffer : UB, offset : u32) -> vec3<u32>
[[internal(intrinsic_load_uniform_vec3_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_8(buffer : UB, offset : u32) -> vec3<f32>
[[internal(intrinsic_load_uniform_vec4_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_9(buffer : UB, offset : u32) -> vec4<i32>
[[internal(intrinsic_load_uniform_vec4_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_10(buffer : UB, offset : u32) -> vec4<u32>
[[internal(intrinsic_load_uniform_vec4_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_11(buffer : UB, offset : u32) -> vec4<f32>
fn tint_symbol_12(buffer : UB, offset : u32) -> mat2x2<f32> {
return mat2x2<f32>(tint_symbol_5(buffer, (offset + 0u)), tint_symbol_5(buffer, (offset + 8u)));
}
fn tint_symbol_13(buffer : UB, offset : u32) -> mat2x3<f32> {
return mat2x3<f32>(tint_symbol_8(buffer, (offset + 0u)), tint_symbol_8(buffer, (offset + 16u)));
}
fn tint_symbol_14(buffer : UB, offset : u32) -> mat2x4<f32> {
return mat2x4<f32>(tint_symbol_11(buffer, (offset + 0u)), tint_symbol_11(buffer, (offset + 16u)));
}
fn tint_symbol_15(buffer : UB, offset : u32) -> mat3x2<f32> {
return mat3x2<f32>(tint_symbol_5(buffer, (offset + 0u)), tint_symbol_5(buffer, (offset + 8u)), tint_symbol_5(buffer, (offset + 16u)));
}
fn tint_symbol_16(buffer : UB, offset : u32) -> mat3x3<f32> {
return mat3x3<f32>(tint_symbol_8(buffer, (offset + 0u)), tint_symbol_8(buffer, (offset + 16u)), tint_symbol_8(buffer, (offset + 32u)));
}
fn tint_symbol_17(buffer : UB, offset : u32) -> mat3x4<f32> {
return mat3x4<f32>(tint_symbol_11(buffer, (offset + 0u)), tint_symbol_11(buffer, (offset + 16u)), tint_symbol_11(buffer, (offset + 32u)));
}
fn tint_symbol_18(buffer : UB, offset : u32) -> mat4x2<f32> {
return mat4x2<f32>(tint_symbol_5(buffer, (offset + 0u)), tint_symbol_5(buffer, (offset + 8u)), tint_symbol_5(buffer, (offset + 16u)), tint_symbol_5(buffer, (offset + 24u)));
}
fn tint_symbol_19(buffer : UB, offset : u32) -> mat4x3<f32> {
return mat4x3<f32>(tint_symbol_8(buffer, (offset + 0u)), tint_symbol_8(buffer, (offset + 16u)), tint_symbol_8(buffer, (offset + 32u)), tint_symbol_8(buffer, (offset + 48u)));
}
fn tint_symbol_20(buffer : UB, offset : u32) -> mat4x4<f32> {
return mat4x4<f32>(tint_symbol_11(buffer, (offset + 0u)), tint_symbol_11(buffer, (offset + 16u)), tint_symbol_11(buffer, (offset + 32u)), tint_symbol_11(buffer, (offset + 48u)));
}
fn tint_symbol_21(buffer : UB, offset : u32) -> array<vec3<f32>, 2> {
return array<vec3<f32>, 2>(tint_symbol_8(buffer, (offset + 0u)), tint_symbol_8(buffer, (offset + 16u)));
}
[[group(0), binding(0)]] var<uniform> ub : UB;
[[stage(compute)]]
fn main() {
var a : i32 = tint_symbol(ub, 0u);
var b : u32 = tint_symbol_1(ub, 4u);
var c : f32 = tint_symbol_2(ub, 8u);
var d : vec2<i32> = tint_symbol_3(ub, 16u);
var e : vec2<u32> = tint_symbol_4(ub, 24u);
var f : vec2<f32> = tint_symbol_5(ub, 32u);
var g : vec3<i32> = tint_symbol_6(ub, 48u);
var h : vec3<u32> = tint_symbol_7(ub, 64u);
var i : vec3<f32> = tint_symbol_8(ub, 80u);
var j : vec4<i32> = tint_symbol_9(ub, 96u);
var k : vec4<u32> = tint_symbol_10(ub, 112u);
var l : vec4<f32> = tint_symbol_11(ub, 128u);
var m : mat2x2<f32> = tint_symbol_12(ub, 144u);
var n : mat2x3<f32> = tint_symbol_13(ub, 160u);
var o : mat2x4<f32> = tint_symbol_14(ub, 192u);
var p : mat3x2<f32> = tint_symbol_15(ub, 224u);
var q : mat3x3<f32> = tint_symbol_16(ub, 256u);
var r : mat3x4<f32> = tint_symbol_17(ub, 304u);
var s : mat4x2<f32> = tint_symbol_18(ub, 352u);
var t : mat4x3<f32> = tint_symbol_19(ub, 384u);
var u : mat4x4<f32> = tint_symbol_20(ub, 448u);
var v : array<vec3<f32>, 2> = tint_symbol_21(ub, 512u);
}
)";
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, SB_BasicStore) {
auto* src = R"(
[[block]]
struct SB {
@@ -300,40 +494,40 @@ struct SB {
v : array<vec3<f32>, 2>;
};
[[internal(intrinsic_store_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32, value : i32)
[[internal(intrinsic_store_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_1(buffer : SB, offset : u32, value : u32)
[[internal(intrinsic_store_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_2(buffer : SB, offset : u32, value : f32)
[[internal(intrinsic_store_vec2_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec2_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_3(buffer : SB, offset : u32, value : vec2<i32>)
[[internal(intrinsic_store_vec2_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec2_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_4(buffer : SB, offset : u32, value : vec2<u32>)
[[internal(intrinsic_store_vec2_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec2_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_5(buffer : SB, offset : u32, value : vec2<f32>)
[[internal(intrinsic_store_vec3_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec3_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_6(buffer : SB, offset : u32, value : vec3<i32>)
[[internal(intrinsic_store_vec3_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec3_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_7(buffer : SB, offset : u32, value : vec3<u32>)
[[internal(intrinsic_store_vec3_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec3_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_8(buffer : SB, offset : u32, value : vec3<f32>)
[[internal(intrinsic_store_vec4_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec4_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_9(buffer : SB, offset : u32, value : vec4<i32>)
[[internal(intrinsic_store_vec4_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec4_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_10(buffer : SB, offset : u32, value : vec4<u32>)
[[internal(intrinsic_store_vec4_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec4_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_11(buffer : SB, offset : u32, value : vec4<f32>)
fn tint_symbol_12(buffer : SB, offset : u32, value : mat2x2<f32>) {
@@ -424,12 +618,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, LoadStructure) {
TEST_F(DecomposeMemoryAccessTest, LoadStructure) {
auto* src = R"(
[[block]]
struct SB {
@@ -492,40 +686,40 @@ struct SB {
v : array<vec3<f32>, 2>;
};
[[internal(intrinsic_load_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32) -> i32
[[internal(intrinsic_load_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_1(buffer : SB, offset : u32) -> u32
[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_2(buffer : SB, offset : u32) -> f32
[[internal(intrinsic_load_vec2_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec2_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_3(buffer : SB, offset : u32) -> vec2<i32>
[[internal(intrinsic_load_vec2_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec2_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_4(buffer : SB, offset : u32) -> vec2<u32>
[[internal(intrinsic_load_vec2_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec2_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_5(buffer : SB, offset : u32) -> vec2<f32>
[[internal(intrinsic_load_vec3_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec3_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_6(buffer : SB, offset : u32) -> vec3<i32>
[[internal(intrinsic_load_vec3_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec3_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_7(buffer : SB, offset : u32) -> vec3<u32>
[[internal(intrinsic_load_vec3_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec3_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_8(buffer : SB, offset : u32) -> vec3<f32>
[[internal(intrinsic_load_vec4_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec4_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_9(buffer : SB, offset : u32) -> vec4<i32>
[[internal(intrinsic_load_vec4_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec4_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_10(buffer : SB, offset : u32) -> vec4<u32>
[[internal(intrinsic_load_vec4_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec4_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_11(buffer : SB, offset : u32) -> vec4<f32>
fn tint_symbol_12(buffer : SB, offset : u32) -> mat2x2<f32> {
@@ -580,12 +774,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, StoreStructure) {
TEST_F(DecomposeMemoryAccessTest, StoreStructure) {
auto* src = R"(
[[block]]
struct SB {
@@ -648,40 +842,40 @@ struct SB {
v : array<vec3<f32>, 2>;
};
[[internal(intrinsic_store_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32, value : i32)
[[internal(intrinsic_store_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_1(buffer : SB, offset : u32, value : u32)
[[internal(intrinsic_store_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_2(buffer : SB, offset : u32, value : f32)
[[internal(intrinsic_store_vec2_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec2_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_3(buffer : SB, offset : u32, value : vec2<i32>)
[[internal(intrinsic_store_vec2_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec2_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_4(buffer : SB, offset : u32, value : vec2<u32>)
[[internal(intrinsic_store_vec2_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec2_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_5(buffer : SB, offset : u32, value : vec2<f32>)
[[internal(intrinsic_store_vec3_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec3_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_6(buffer : SB, offset : u32, value : vec3<i32>)
[[internal(intrinsic_store_vec3_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec3_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_7(buffer : SB, offset : u32, value : vec3<u32>)
[[internal(intrinsic_store_vec3_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec3_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_8(buffer : SB, offset : u32, value : vec3<f32>)
[[internal(intrinsic_store_vec4_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec4_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_9(buffer : SB, offset : u32, value : vec4<i32>)
[[internal(intrinsic_store_vec4_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec4_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_10(buffer : SB, offset : u32, value : vec4<u32>)
[[internal(intrinsic_store_vec4_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec4_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_11(buffer : SB, offset : u32, value : vec4<f32>)
fn tint_symbol_12(buffer : SB, offset : u32, value : mat2x2<f32>) {
@@ -776,12 +970,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, ComplexStaticAccessChain) {
TEST_F(DecomposeMemoryAccessTest, ComplexStaticAccessChain) {
auto* src = R"(
struct S1 {
a : i32;
@@ -837,7 +1031,7 @@ struct SB {
b : [[stride(256)]] array<S2>;
};
[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32) -> f32
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
@@ -848,12 +1042,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, ComplexDynamicAccessChain) {
TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChain) {
auto* src = R"(
struct S1 {
a : i32;
@@ -905,7 +1099,7 @@ struct SB {
b : [[stride(256)]] array<S2>;
};
[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32) -> f32
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
@@ -919,12 +1113,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, ComplexDynamicAccessChainWithAliases) {
TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChainWithAliases) {
auto* src = R"(
struct S1 {
a : i32;
@@ -992,7 +1186,7 @@ struct SB {
b : A2_Array;
};
[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32) -> f32
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
@@ -1006,12 +1200,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, StorageBufferAtomics) {
TEST_F(DecomposeMemoryAccessTest, StorageBufferAtomics) {
auto* src = R"(
[[block]]
struct SB {
@@ -1056,64 +1250,64 @@ struct SB {
b : atomic<u32>;
};
[[internal(intrinsic_atomic_store_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_store_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32, param_1 : i32)
[[internal(intrinsic_atomic_load_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_load_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_1(buffer : SB, offset : u32) -> i32
[[internal(intrinsic_atomic_add_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_add_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_2(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_max_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_max_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_3(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_min_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_min_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_4(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_and_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_and_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_5(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_or_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_or_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_6(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_xor_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_xor_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_7(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_exchange_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_exchange_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_8(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_compare_exchange_weak_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_compare_exchange_weak_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_9(buffer : SB, offset : u32, param_1 : i32, param_2 : i32) -> vec2<i32>
[[internal(intrinsic_atomic_store_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_store_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_10(buffer : SB, offset : u32, param_1 : u32)
[[internal(intrinsic_atomic_load_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_load_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_11(buffer : SB, offset : u32) -> u32
[[internal(intrinsic_atomic_add_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_add_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_12(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_max_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_max_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_13(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_min_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_min_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_14(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_and_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_and_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_15(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_or_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_or_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_16(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_xor_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_xor_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_17(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_exchange_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_exchange_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_18(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_compare_exchange_weak_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_compare_exchange_weak_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_19(buffer : SB, offset : u32, param_1 : u32, param_2 : u32) -> vec2<u32>
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
@@ -1143,12 +1337,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, WorkgroupBufferAtomics) {
TEST_F(DecomposeMemoryAccessTest, WorkgroupBufferAtomics) {
auto* src = R"(
struct S {
padding : vec4<f32>;
@@ -1185,7 +1379,7 @@ fn main() {
auto* expect = src;
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}

View File

@@ -19,7 +19,7 @@
#include "src/program_builder.h"
#include "src/transform/calculate_array_length.h"
#include "src/transform/canonicalize_entry_point_io.h"
#include "src/transform/decompose_storage_access.h"
#include "src/transform/decompose_memory_access.h"
#include "src/transform/external_texture_transform.h"
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/manager.h"
@@ -41,13 +41,13 @@ Output Hlsl::Run(const Program* in, const DataMap&) {
manager.Add<InlinePointerLets>();
// Simplify cleans up messy `*(&(expr))` expressions from InlinePointerLets.
manager.Add<Simplify>();
// DecomposeStorageAccess must come after InlinePointerLets as we cannot take
// the address of calls to DecomposeStorageAccess::Intrinsic. Must also come
// DecomposeMemoryAccess must come after InlinePointerLets as we cannot take
// the address of calls to DecomposeMemoryAccess::Intrinsic. Must also come
// after Simplify, as we need to fold away the address-of and defers of
// `*(&(intrinsic_load()))` expressions.
manager.Add<DecomposeStorageAccess>();
// CalculateArrayLength must come after DecomposeStorageAccess, as
// DecomposeStorageAccess special-cases the arrayLength() intrinsic, which
manager.Add<DecomposeMemoryAccess>();
// CalculateArrayLength must come after DecomposeMemoryAccess, as
// DecomposeMemoryAccess special-cases the arrayLength() intrinsic, which
// will be transformed by CalculateArrayLength
manager.Add<CalculateArrayLength>();
manager.Add<ExternalTextureTransform>();

View File

@@ -429,9 +429,19 @@ bool GeneratorImpl::EmitCall(std::ostream& pre,
}
if (auto* intrinsic =
ast::GetDecoration<transform::DecomposeStorageAccess::Intrinsic>(
ast::GetDecoration<transform::DecomposeMemoryAccess::Intrinsic>(
func->Declaration()->decorations())) {
return EmitDecomposeStorageAccessIntrinsic(pre, out, expr, intrinsic);
switch (intrinsic->storage_class) {
case ast::StorageClass::kUniform:
return EmitUniformBufferAccess(pre, out, expr, intrinsic);
case ast::StorageClass::kStorage:
return EmitStorageBufferAccess(pre, out, expr, intrinsic);
default:
TINT_UNREACHABLE(diagnostics_)
<< "unsupported DecomposeMemoryAccess::Intrinsic storage class:"
<< intrinsic->storage_class;
return false;
}
}
}
@@ -507,15 +517,122 @@ bool GeneratorImpl::EmitCall(std::ostream& pre,
return true;
}
bool GeneratorImpl::EmitDecomposeStorageAccessIntrinsic(
bool GeneratorImpl::EmitUniformBufferAccess(
std::ostream& pre,
std::ostream& out,
ast::CallExpression* expr,
const transform::DecomposeStorageAccess::Intrinsic* intrinsic) {
const transform::DecomposeMemoryAccess::Intrinsic* intrinsic) {
const auto& params = expr->params();
using Op = transform::DecomposeStorageAccess::Intrinsic::Op;
using DataType = transform::DecomposeStorageAccess::Intrinsic::DataType;
std::string scalar_offset = generate_name("scalar_offset");
{
std::stringstream ss;
ss << "const int " << scalar_offset << " = (";
if (!EmitExpression(pre, ss, params[1])) { // offset
return false;
}
make_indent(ss << ") / 4;" << std::endl);
pre << ss.str();
}
using Op = transform::DecomposeMemoryAccess::Intrinsic::Op;
using DataType = transform::DecomposeMemoryAccess::Intrinsic::DataType;
switch (intrinsic->op) {
case Op::kLoad: {
auto cast = [&](const char* to, auto&& load) {
out << to << "(";
auto result = load();
out << ")";
return result;
};
auto load_scalar = [&]() {
if (!EmitExpression(pre, out, params[0])) { // buffer
return false;
}
out << "[" << scalar_offset << " / 4][" << scalar_offset << " % 4]";
return true;
};
// Has a minimum alignment of 8 bytes, so is either .xy or .zw
auto load_vec2 = [&] {
std::string ubo_load = generate_name("ubo_load");
std::stringstream ss;
ss << "uint4 " << ubo_load << " = ";
if (!EmitExpression(pre, ss, params[0])) { // buffer
return false;
}
ss << "[" << scalar_offset << " / 4]";
make_indent(ss << ";" << std::endl);
pre << ss.str();
out << "(" << scalar_offset << " & 2) ? " << ubo_load
<< ".zw : " << ubo_load << ".xy";
return true;
};
// vec3 has a minimum alignment of 16 bytes, so is just a .xyz swizzle
auto load_vec3 = [&] {
if (!EmitExpression(pre, out, params[0])) { // buffer
return false;
}
out << "[" << scalar_offset << " / 4].xyz";
return true;
};
// vec4 has a minimum alignment of 16 bytes, easiest case
auto load_vec4 = [&] {
if (!EmitExpression(pre, out, params[0])) { // buffer
return false;
}
out << "[" << scalar_offset << " / 4]";
return true;
};
switch (intrinsic->type) {
case DataType::kU32:
return load_scalar();
case DataType::kF32:
return cast("asfloat", load_scalar);
case DataType::kI32:
return cast("asint", load_scalar);
case DataType::kVec2U32:
return load_vec2();
case DataType::kVec2F32:
return cast("asfloat", load_vec2);
case DataType::kVec2I32:
return cast("asint", load_vec2);
case DataType::kVec3U32:
return load_vec3();
case DataType::kVec3F32:
return cast("asfloat", load_vec3);
case DataType::kVec3I32:
return cast("asint", load_vec3);
case DataType::kVec4U32:
return load_vec4();
case DataType::kVec4F32:
return cast("asfloat", load_vec4);
case DataType::kVec4I32:
return cast("asint", load_vec4);
}
TINT_UNREACHABLE(diagnostics_)
<< "unsupported DecomposeMemoryAccess::Intrinsic::DataType: "
<< static_cast<int>(intrinsic->type);
return false;
}
default:
break;
}
TINT_UNREACHABLE(diagnostics_)
<< "unsupported DecomposeMemoryAccess::Intrinsic::Op: "
<< static_cast<int>(intrinsic->op);
return false;
}
bool GeneratorImpl::EmitStorageBufferAccess(
std::ostream& pre,
std::ostream& out,
ast::CallExpression* expr,
const transform::DecomposeMemoryAccess::Intrinsic* intrinsic) {
const auto& params = expr->params();
using Op = transform::DecomposeMemoryAccess::Intrinsic::Op;
using DataType = transform::DecomposeMemoryAccess::Intrinsic::DataType;
switch (intrinsic->op) {
case Op::kLoad: {
auto load = [&](const char* cast, int n) {
@@ -565,7 +682,7 @@ bool GeneratorImpl::EmitDecomposeStorageAccessIntrinsic(
return load("asint", 4);
}
TINT_UNREACHABLE(diagnostics_)
<< "unsupported DecomposeStorageAccess::Intrinsic::DataType: "
<< "unsupported DecomposeMemoryAccess::Intrinsic::DataType: "
<< static_cast<int>(intrinsic->type);
return false;
}
@@ -617,7 +734,7 @@ bool GeneratorImpl::EmitDecomposeStorageAccessIntrinsic(
return store(4);
}
TINT_UNREACHABLE(diagnostics_)
<< "unsupported DecomposeStorageAccess::Intrinsic::DataType: "
<< "unsupported DecomposeMemoryAccess::Intrinsic::DataType: "
<< static_cast<int>(intrinsic->type);
return false;
}
@@ -636,7 +753,7 @@ bool GeneratorImpl::EmitDecomposeStorageAccessIntrinsic(
}
TINT_UNREACHABLE(diagnostics_)
<< "unsupported DecomposeStorageAccess::Intrinsic::Op: "
<< "unsupported DecomposeMemoryAccess::Intrinsic::Op: "
<< static_cast<int>(intrinsic->op);
return false;
}
@@ -645,19 +762,19 @@ bool GeneratorImpl::EmitStorageAtomicCall(
std::ostream& pre,
std::ostream& out,
ast::CallExpression* expr,
transform::DecomposeStorageAccess::Intrinsic::Op op) {
using Op = transform::DecomposeStorageAccess::Intrinsic::Op;
transform::DecomposeMemoryAccess::Intrinsic::Op op) {
using Op = transform::DecomposeMemoryAccess::Intrinsic::Op;
std::stringstream ss;
std::string result = generate_name("atomic_result");
auto* result_ty = TypeOf(expr);
if (!result_ty->Is<sem::Void>()) {
if (!EmitType(ss, TypeOf(expr), ast::StorageClass::kNone,
ast::Access::kUndefined, "")) {
if (!EmitTypeAndName(ss, TypeOf(expr), ast::StorageClass::kNone,
ast::Access::kUndefined, result)) {
return false;
}
ss << " " << result << " = ";
ss << " = ";
if (!EmitZeroValue(ss, result_ty)) {
return false;
}
@@ -693,11 +810,11 @@ bool GeneratorImpl::EmitStorageAtomicCall(
// InterlockedExchange and discard the returned value
auto* value = expr->params()[2];
auto* value_ty = TypeOf(value);
if (!EmitType(pre, value_ty, ast::StorageClass::kNone,
ast::Access::kUndefined, "")) {
if (!EmitTypeAndName(pre, value_ty, ast::StorageClass::kNone,
ast::Access::kUndefined, result)) {
return false;
}
pre << " " << result << " = ";
pre << " = ";
if (!EmitZeroValue(pre, value_ty)) {
return false;
}
@@ -725,11 +842,11 @@ bool GeneratorImpl::EmitStorageAtomicCall(
auto* value = expr->params()[3];
std::string compare = generate_name("atomic_compare_value");
if (!EmitType(ss, TypeOf(compare_value), ast::StorageClass::kNone,
ast::Access::kUndefined, "")) {
if (!EmitTypeAndName(ss, TypeOf(compare_value), ast::StorageClass::kNone,
ast::Access::kUndefined, compare)) {
return false;
}
ss << " " << compare << " = ";
ss << " = ";
if (!EmitExpression(pre, ss, compare_value)) {
return false;
}
@@ -805,7 +922,7 @@ bool GeneratorImpl::EmitStorageAtomicCall(
default:
TINT_UNREACHABLE(diagnostics_)
<< "unsupported atomic DecomposeStorageAccess::Intrinsic::Op: "
<< "unsupported atomic DecomposeMemoryAccess::Intrinsic::Op: "
<< static_cast<int>(op);
return false;
}
@@ -842,11 +959,11 @@ bool GeneratorImpl::EmitWorkgroupAtomicCall(std::ostream& pre,
std::string result = generate_name("atomic_result");
if (!intrinsic->ReturnType()->Is<sem::Void>()) {
if (!EmitType(ss, intrinsic->ReturnType(), ast::StorageClass::kNone,
ast::Access::kUndefined, "")) {
if (!EmitTypeAndName(ss, intrinsic->ReturnType(), ast::StorageClass::kNone,
ast::Access::kUndefined, result)) {
return false;
}
ss << " " << result << " = ";
ss << " = ";
if (!EmitZeroValue(ss, intrinsic->ReturnType())) {
return false;
}
@@ -875,11 +992,11 @@ bool GeneratorImpl::EmitWorkgroupAtomicCall(std::ostream& pre,
// HLSL does not have an InterlockedStore, so we emulate it with
// InterlockedExchange and discard the returned value
auto* value_ty = intrinsic->Parameters()[1].type;
if (!EmitType(pre, value_ty, ast::StorageClass::kNone,
ast::Access::kUndefined, "")) {
if (!EmitTypeAndName(pre, value_ty, ast::StorageClass::kNone,
ast::Access::kUndefined, result)) {
return false;
}
pre << " " << result << " = ";
pre << " = ";
if (!EmitZeroValue(pre, value_ty)) {
return false;
}
@@ -905,11 +1022,11 @@ bool GeneratorImpl::EmitWorkgroupAtomicCall(std::ostream& pre,
auto* value = expr->params()[2];
std::string compare = generate_name("atomic_compare_value");
if (!EmitType(ss, TypeOf(compare_value), ast::StorageClass::kNone,
ast::Access::kUndefined, "")) {
if (!EmitTypeAndName(ss, TypeOf(compare_value), ast::StorageClass::kNone,
ast::Access::kUndefined, compare)) {
return false;
}
ss << " " << compare << " = ";
ss << " = ";
if (!EmitExpression(pre, ss, compare_value)) {
return false;
}
@@ -1902,18 +2019,16 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) {
}
// Note: WGSL only allows for StorageClass::kNone on parameters, however the
// sanitizer transforms generates load / store functions for storage
// buffers. These functions have a storage buffer parameter with
// StorageClass::kStorage. This is required to correctly translate the
// parameter to [RW]ByteAddressBuffer.
if (!EmitType(out, type, v->StorageClass(), v->Access(),
builder_.Symbols().NameFor(v->Declaration()->symbol()))) {
// sanitizer transforms generates load / store functions for storage or
// uniform buffers. These functions have a buffer parameter with
// StorageClass::kStorage or StorageClass::kUniform. This is required to
// correctly translate the parameter to a [RW]ByteAddressBuffer for storage
// buffers and a uint4[N] for uniform buffers.
if (!EmitTypeAndName(
out, type, v->StorageClass(), v->Access(),
builder_.Symbols().NameFor(v->Declaration()->symbol()))) {
return false;
}
// Array name is output as part of the type
if (!type->Is<sem::Array>()) {
out << " " << builder_.Symbols().NameFor(v->Declaration()->symbol());
}
}
out << ") ";
@@ -1959,26 +2074,28 @@ bool GeneratorImpl::EmitUniformVariable(std::ostream& out,
auto binding_point = decl->binding_point();
auto* type = var->Type()->UnwrapRef();
if (auto* strct = type->As<sem::Struct>()) {
out << "ConstantBuffer<"
<< builder_.Symbols().NameFor(strct->Declaration()->name()) << "> "
<< builder_.Symbols().NameFor(decl->symbol())
<< RegisterAndSpace('b', binding_point) << ";" << std::endl;
} else {
auto name = "cbuffer_" + builder_.Symbols().NameFor(decl->symbol());
out << "cbuffer " << name << RegisterAndSpace('b', binding_point) << " {"
<< std::endl;
increment_indent();
make_indent(out);
if (!EmitType(out, type, var->StorageClass(), var->Access(), "")) {
return false;
}
out << " " << builder_.Symbols().NameFor(decl->symbol()) << ";"
<< std::endl;
decrement_indent();
out << "};" << std::endl;
auto* str = type->As<sem::Struct>();
if (!str) {
// https://www.w3.org/TR/WGSL/#module-scope-variables
TINT_ICE(diagnostics_)
<< "variables with uniform storage must be structure";
}
auto name = builder_.Symbols().NameFor(decl->symbol());
out << "cbuffer cbuffer_" << name << RegisterAndSpace('b', binding_point)
<< " {" << std::endl;
increment_indent();
make_indent(out);
if (!EmitTypeAndName(out, type, ast::StorageClass::kUniform, var->Access(),
name)) {
return false;
}
out << ";" << std::endl;
decrement_indent();
out << "};" << std::endl;
return true;
}
@@ -1988,12 +2105,12 @@ bool GeneratorImpl::EmitStorageVariable(std::ostream& out,
auto* decl = var->Declaration();
auto* type = var->Type()->UnwrapRef();
if (!EmitType(out, type, ast::StorageClass::kStorage, var->Access(), "")) {
if (!EmitTypeAndName(out, type, ast::StorageClass::kStorage, var->Access(),
builder_.Symbols().NameFor(decl->symbol()))) {
return false;
}
out << " " << builder_.Symbols().NameFor(decl->symbol())
<< RegisterAndSpace(var->Access() == ast::Access::kRead ? 't' : 'u',
out << RegisterAndSpace(var->Access() == ast::Access::kRead ? 't' : 'u',
decl->binding_point())
<< ";" << std::endl;
@@ -2016,12 +2133,9 @@ bool GeneratorImpl::EmitHandleVariable(std::ostream& out,
auto name = builder_.Symbols().NameFor(decl->symbol());
auto* type = var->Type()->UnwrapRef();
if (!EmitType(out, type, var->StorageClass(), var->Access(), name)) {
if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
return false;
}
if (!type->Is<sem::Array>()) {
out << " " << name;
}
const char* register_space = nullptr;
@@ -2067,12 +2181,9 @@ bool GeneratorImpl::EmitPrivateVariable(std::ostream& out,
auto name = builder_.Symbols().NameFor(decl->symbol());
auto* type = var->Type()->UnwrapRef();
if (!EmitType(out, type, var->StorageClass(), var->Access(), name)) {
if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
return false;
}
if (!type->Is<sem::Array>()) {
out << " " << name;
}
if (constructor_out.str().length()) {
out << " = " << constructor_out.str();
@@ -2099,12 +2210,9 @@ bool GeneratorImpl::EmitWorkgroupVariable(std::ostream& out,
auto name = builder_.Symbols().NameFor(decl->symbol());
auto* type = var->Type()->UnwrapRef();
if (!EmitType(out, type, var->StorageClass(), var->Access(), name)) {
if (!EmitTypeAndName(out, type, var->StorageClass(), var->Access(), name)) {
return false;
}
if (!type->Is<sem::Array>()) {
out << " " << name;
}
if (constructor_out.str().length()) {
out << " = " << constructor_out.str();
@@ -2193,11 +2301,10 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
}
first = false;
if (!EmitType(out, type, sem->StorageClass(), sem->Access(), "")) {
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
out << " " << builder_.Symbols().NameFor(var->symbol());
}
out << ") {" << std::endl;
@@ -2467,13 +2574,31 @@ bool GeneratorImpl::EmitType(std::ostream& out,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access,
const std::string& name) {
if (storage_class == ast::StorageClass::kStorage) {
if (access != ast::Access::kRead) {
out << "RW";
const std::string& name,
bool* name_printed /* = nullptr */) {
switch (storage_class) {
case ast::StorageClass::kStorage:
if (access != ast::Access::kRead) {
out << "RW";
}
out << "ByteAddressBuffer";
return true;
case ast::StorageClass::kUniform: {
auto* str = type->As<sem::Struct>();
if (!str) {
// https://www.w3.org/TR/WGSL/#module-scope-variables
TINT_ICE(diagnostics_)
<< "variables with uniform storage must be structure";
}
auto array_length = (str->Size() + 15) / 16;
out << "uint4 " << name << "[" << array_length << "]";
if (name_printed) {
*name_printed = true;
}
return true;
}
out << "ByteAddressBuffer";
return true;
default:
break;
}
if (auto* ary = type->As<sem::Array>()) {
@@ -2494,6 +2619,9 @@ bool GeneratorImpl::EmitType(std::ostream& out,
}
if (!name.empty()) {
out << " " << name;
if (name_printed) {
*name_printed = true;
}
}
for (uint32_t size : sizes) {
out << "[" << size << "]";
@@ -2620,44 +2748,35 @@ bool GeneratorImpl::EmitType(std::ostream& out,
return true;
}
bool GeneratorImpl::EmitTypeAndName(std::ostream& out,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access,
const std::string& name) {
bool printed_name = false;
if (!EmitType(out, type, storage_class, access, name, &printed_name)) {
return false;
}
if (!name.empty() && !printed_name) {
out << " " << name;
}
return true;
}
bool GeneratorImpl::EmitStructType(std::ostream& out, const sem::Struct* str) {
auto storage_class_uses = str->StorageClassUsage();
if (storage_class_uses.size() ==
storage_class_uses.count(ast::StorageClass::kStorage)) {
// The only use of the structure is as a storage buffer.
(storage_class_uses.count(ast::StorageClass::kStorage) +
storage_class_uses.count(ast::StorageClass::kUniform))) {
// The only use of the structure is as a storage buffer and / or uniform
// buffer.
// Structures used as storage buffer are read and written to via a
// ByteAddressBuffer instead of true structure.
// Structures used as uniform buffer are read from an array of vectors
// instead of true structure.
return true;
}
bool is_host_shareable = str->IsHostShareable();
uint32_t hlsl_offset = 0;
// Emits a `/* 0xnnnn */` byte offset comment for a struct member.
auto add_byte_offset_comment = [&](uint32_t offset) {
std::ios_base::fmtflags saved_flag_state(out.flags());
out << "/* 0x" << std::hex << std::setfill('0') << std::setw(4) << offset
<< " */ ";
out.flags(saved_flag_state);
};
uint32_t pad_count = 0;
auto add_padding = [&](uint32_t size) {
if (size & 3) {
TINT_ICE(builder_.Diagnostics())
<< "attempting to pad field with " << size
<< " bytes, but we require a multiple of 4 bytes";
return false;
}
std::string name;
do {
name = "tint_pad_" + std::to_string(pad_count++);
} while (str->FindMember(builder_.Symbols().Get(name)));
out << "int " << name << "[" << (size / 4) << "];" << std::endl;
return true;
};
auto struct_name = builder_.Symbols().NameFor(str->Declaration()->name());
out << "struct " << struct_name << " {" << std::endl;
@@ -2666,40 +2785,13 @@ bool GeneratorImpl::EmitStructType(std::ostream& out, const sem::Struct* str) {
make_indent(out);
auto name = builder_.Symbols().NameFor(mem->Declaration()->symbol());
auto wgsl_offset = mem->Offset();
if (is_host_shareable) {
if (wgsl_offset < hlsl_offset) {
// Unimplementable layout
TINT_ICE(diagnostics_)
<< "Structure member WGSL offset (" << wgsl_offset
<< ") is behind HLSL offset (" << hlsl_offset << ")";
return false;
}
// Generate padding if required
if (auto padding = wgsl_offset - hlsl_offset) {
add_byte_offset_comment(hlsl_offset);
if (!add_padding(padding)) {
return false;
}
hlsl_offset += padding;
make_indent(out);
}
add_byte_offset_comment(hlsl_offset);
}
auto* ty = mem->Type();
if (!EmitType(out, ty, ast::StorageClass::kNone, ast::Access::kReadWrite,
name)) {
if (!EmitTypeAndName(out, ty, ast::StorageClass::kNone,
ast::Access::kReadWrite, name)) {
return false;
}
// Array member name will be output with the type
if (!ty->Is<sem::Array>()) {
out << " " << name;
}
for (auto* deco : mem->Declaration()->decorations()) {
if (auto* location = deco->As<ast::LocationDecoration>()) {
@@ -2733,26 +2825,6 @@ bool GeneratorImpl::EmitStructType(std::ostream& out, const sem::Struct* str) {
}
out << ";" << std::endl;
if (is_host_shareable) {
// Calculate new HLSL offset
auto size_align = HlslPackedTypeSizeAndAlign(ty);
if (hlsl_offset % size_align.align) {
TINT_ICE(diagnostics_)
<< "Misaligned HLSL structure member "
<< ty->FriendlyName(builder_.Symbols()) << " " << name;
return false;
}
hlsl_offset += size_align.size;
}
}
if (is_host_shareable && str->Size() != hlsl_offset) {
make_indent(out);
add_byte_offset_comment(hlsl_offset);
if (!add_padding(str->Size() - hlsl_offset)) {
return false;
}
}
decrement_indent();
@@ -2760,57 +2832,6 @@ bool GeneratorImpl::EmitStructType(std::ostream& out, const sem::Struct* str) {
out << "};" << std::endl;
// If the structure has padding members, create a helper function for building
// the structure.
if (pad_count) {
auto builder_name = generate_name("make_" + struct_name);
out << std::endl;
out << struct_name << " " << builder_name << "(";
uint32_t idx = 0;
for (auto* mem : str->Members()) {
if (idx > 0) {
out << ",";
make_indent(out << std::endl);
out << std::string(struct_name.length() + builder_name.length() + 2,
' ');
}
auto name = "param_" + std::to_string(idx++);
auto* ty = mem->Type();
if (!EmitType(out, ty, ast::StorageClass::kNone, ast::Access::kReadWrite,
name)) {
return false;
}
// Array member name will be output with the type
if (!ty->Is<sem::Array>()) {
out << " " << name;
}
}
out << ") {";
increment_indent();
make_indent(out << std::endl);
out << struct_name << " output;";
make_indent(out << std::endl);
idx = 0;
for (auto* mem : str->Members()) {
out << "output."
<< builder_.Symbols().NameFor(mem->Declaration()->symbol()) << " = "
<< "param_" + std::to_string(idx++) << ";";
make_indent(out << std::endl);
}
out << "return output;";
decrement_indent();
make_indent(out << std::endl);
out << "}";
make_indent(out << std::endl);
structure_builders_[str] = builder_name;
}
return true;
}
@@ -2872,13 +2893,10 @@ bool GeneratorImpl::EmitVariable(std::ostream& out, ast::Variable* var) {
if (var->is_const()) {
out << "const ";
}
if (!EmitType(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol()))) {
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
if (!type->Is<sem::Array>()) {
out << " " << builder_.Symbols().NameFor(var->symbol());
}
out << constructor_out.str() << ";" << std::endl;
return true;
@@ -2925,21 +2943,17 @@ bool GeneratorImpl::EmitProgramConstVariable(std::ostream& out,
}
out << "#endif" << std::endl;
out << "static const ";
if (!EmitType(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol()))) {
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
out << " " << builder_.Symbols().NameFor(var->symbol()) << " = "
<< kSpecConstantPrefix << const_id << ";" << std::endl;
out << " = " << kSpecConstantPrefix << const_id << ";" << std::endl;
} else {
out << "static const ";
if (!EmitType(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol()))) {
if (!EmitTypeAndName(out, type, sem->StorageClass(), sem->Access(),
builder_.Symbols().NameFor(var->symbol()))) {
return false;
}
if (!type->Is<sem::Array>()) {
out << " " << builder_.Symbols().NameFor(var->symbol());
}
if (var->constructor() != nullptr) {
out << " = " << constructor_out.str();
@@ -2982,63 +2996,6 @@ bool GeneratorImpl::EmitBlockBraces(std::ostream& out,
return true;
}
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-packing-rules
// TODO(crbug.com/tint/898): We need CTS and / or Dawn e2e tests for this logic.
GeneratorImpl::SizeAndAlign GeneratorImpl::HlslPackedTypeSizeAndAlign(
const sem::Type* ty) {
if (ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
return {4, 4};
}
if (auto* vec = ty->As<sem::Vector>()) {
auto num_els = vec->size();
auto* el_ty = vec->type();
if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
return SizeAndAlign{num_els * 4, 4};
}
}
if (auto* mat = ty->As<sem::Matrix>()) {
auto cols = mat->columns();
auto rows = mat->rows();
auto* el_ty = mat->type();
if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
static constexpr SizeAndAlign table[] = {
/* float2x2 */ {16, 8},
/* float2x3 */ {32, 16},
/* float2x4 */ {32, 16},
/* float3x2 */ {24, 8},
/* float3x3 */ {48, 16},
/* float3x4 */ {48, 16},
/* float4x2 */ {32, 8},
/* float4x3 */ {64, 16},
/* float4x4 */ {64, 16},
};
if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) {
return table[(3 * (cols - 2)) + (rows - 2)];
}
}
}
if (auto* arr = ty->As<sem::Array>()) {
auto el_size_align = HlslPackedTypeSizeAndAlign(arr->ElemType());
if (!arr->IsStrideImplicit()) {
TINT_ICE(diagnostics_) << "arrays with explicit strides should have "
"removed with the PadArrayElements transform";
return {};
}
auto num_els = std::max<uint32_t>(arr->Count(), 1);
return SizeAndAlign{el_size_align.size * num_els, el_size_align.align};
}
if (auto* str = ty->As<sem::Struct>()) {
return SizeAndAlign{str->Size(), str->Align()};
}
TINT_UNREACHABLE(diagnostics_) << "Unhandled type " << ty->TypeInfo().name;
return {};
}
} // namespace hlsl
} // namespace writer
} // namespace tint

View File

@@ -32,7 +32,7 @@
#include "src/ast/unary_op_expression.h"
#include "src/program_builder.h"
#include "src/scope_stack.h"
#include "src/transform/decompose_storage_access.h"
#include "src/transform/decompose_memory_access.h"
#include "src/writer/text_generator.h"
namespace tint {
@@ -117,17 +117,29 @@ class GeneratorImpl : public TextGenerator {
std::ostream& out,
ast::CallExpression* expr);
/// Handles generating a call expression to a
/// transform::DecomposeStorageAccess::Intrinsic
/// transform::DecomposeMemoryAccess::Intrinsic for a uniform buffer
/// @param pre the preamble for the expression stream
/// @param out the output of the expression stream
/// @param expr the call expression
/// @param intrinsic the transform::DecomposeStorageAccess::Intrinsic
/// @param intrinsic the transform::DecomposeMemoryAccess::Intrinsic
/// @returns true if the call expression is emitted
bool EmitDecomposeStorageAccessIntrinsic(
bool EmitUniformBufferAccess(
std::ostream& pre,
std::ostream& out,
ast::CallExpression* expr,
const transform::DecomposeStorageAccess::Intrinsic* intrinsic);
const transform::DecomposeMemoryAccess::Intrinsic* intrinsic);
/// Handles generating a call expression to a
/// transform::DecomposeMemoryAccess::Intrinsic for a storage buffer
/// @param pre the preamble for the expression stream
/// @param out the output of the expression stream
/// @param expr the call expression
/// @param intrinsic the transform::DecomposeMemoryAccess::Intrinsic
/// @returns true if the call expression is emitted
bool EmitStorageBufferAccess(
std::ostream& pre,
std::ostream& out,
ast::CallExpression* expr,
const transform::DecomposeMemoryAccess::Intrinsic* intrinsic);
/// Handles generating a barrier intrinsic call
/// @param pre the preamble for the expression stream
/// @param out the output of the expression stream
@@ -146,7 +158,7 @@ class GeneratorImpl : public TextGenerator {
std::ostream& pre,
std::ostream& out,
ast::CallExpression* expr,
transform::DecomposeStorageAccess::Intrinsic::Op op);
transform::DecomposeMemoryAccess::Intrinsic::Op op);
/// Handles generating an atomic intrinsic call for a workgroup variable
/// @param pre the preamble for the expression stream
/// @param out the output of the expression stream
@@ -361,13 +373,28 @@ class GeneratorImpl : public TextGenerator {
/// @param type the type to generate
/// @param storage_class the storage class of the variable
/// @param access the access control type of the variable
/// @param name the name of the variable, only used for array emission
/// @param name the name of the variable, used for array emission.
/// @param name_printed (optional) if not nullptr and an array was printed
/// then the boolean is set to true.
/// @returns true if the type is emitted
bool EmitType(std::ostream& out,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access,
const std::string& name);
const std::string& name,
bool* name_printed = nullptr);
/// Handles generating type and name
/// @param out the output stream
/// @param type the type to generate
/// @param storage_class the storage class of the variable
/// @param access the access control type of the variable
/// @param name the name of the variable, used for array emission.
/// @returns true if the type is emitted
bool EmitTypeAndName(std::ostream& out,
const sem::Type* type,
ast::StorageClass storage_class,
ast::Access access,
const std::string& name);
/// Handles generating a structure declaration
/// @param out the output stream
/// @param ty the struct to generate
@@ -460,16 +487,6 @@ class GeneratorImpl : public TextGenerator {
return EmitBlockBraces(out, "", std::forward<F>(cb));
}
// A pair of byte size and alignment `uint32_t`s.
struct SizeAndAlign {
uint32_t size;
uint32_t align;
};
/// @returns the HLSL packed type size and alignment in bytes for the given
/// type.
SizeAndAlign HlslPackedTypeSizeAndAlign(const sem::Type* ty);
ProgramBuilder builder_;
std::function<bool(std::ostream& out)> emit_continuing_;
std::unordered_map<const sem::Struct*, std::string> structure_builders_;

View File

@@ -341,12 +341,10 @@ TEST_F(HlslGeneratorImplTest_Function,
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_EQ(result(), R"(struct UBO {
/* 0x0000 */ float4 coord;
EXPECT_EQ(result(), R"(cbuffer cbuffer_ubo : register(b0, space1) {
uint4 ubo[1];
};
ConstantBuffer<UBO> ubo : register(b0, space1);
float sub_func(float param) {
return ubo.coord.x;
}
@@ -384,12 +382,10 @@ TEST_F(HlslGeneratorImplTest_Function,
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_EQ(result(), R"(struct Uniforms {
/* 0x0000 */ float4 coord;
EXPECT_EQ(result(), R"(cbuffer cbuffer_uniforms : register(b0, space1) {
uint4 uniforms[1];
};
ConstantBuffer<Uniforms> uniforms : register(b0, space1);
void frag_main() {
float v = uniforms.coord.x;
return;
@@ -583,12 +579,10 @@ TEST_F(HlslGeneratorImplTest_Function,
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_EQ(result(), R"(struct S {
/* 0x0000 */ float x;
EXPECT_EQ(result(), R"(cbuffer cbuffer_coord : register(b0, space1) {
uint4 coord[1];
};
ConstantBuffer<S> coord : register(b0, space1);
float sub_func(float param) {
return coord.x;
}

View File

@@ -93,31 +93,6 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Array_WithoutName) {
EXPECT_EQ(result(), "bool[4]");
}
TEST_F(HlslGeneratorImplTest_Type, EmitType_ArrayWithStride) {
auto* s = Structure("s", {Member("arr", ty.array<f32, 4>(64))},
{create<ast::StructBlockDecoration>()});
auto* ubo = Global("ubo", ty.Of(s), ast::StorageClass::kUniform,
ast::DecorationList{
create<ast::GroupDecoration>(1),
create<ast::BindingDecoration>(1),
});
WrapInFunction(MemberAccessor(ubo, "arr"));
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate(out)) << gen.error();
EXPECT_THAT(result(), HasSubstr(R"(struct tint_padded_array_element {
/* 0x0000 */ float el;
/* 0x0004 */ int tint_pad_0[15];
};)"));
EXPECT_THAT(result(), HasSubstr(R"(struct tint_array_wrapper {
/* 0x0000 */ tint_padded_array_element arr[4];
};)"));
EXPECT_THAT(result(), HasSubstr(R"(struct s {
/* 0x0000 */ tint_array_wrapper arr;
};)"));
}
TEST_F(HlslGeneratorImplTest_Type, EmitType_Bool) {
auto* bool_ = create<sem::Bool>();
@@ -232,430 +207,6 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Struct) {
EXPECT_EQ(result(), "S");
}
TEST_F(HlslGeneratorImplTest_Type, EmitType_Struct_Layout_NonComposites) {
auto* s =
Structure("S",
{
Member("a", ty.i32(), {MemberSize(32)}),
Member("b", ty.f32(), {MemberAlign(128), MemberSize(128)}),
Member("c", ty.vec2<f32>()),
Member("d", ty.u32()),
Member("e", ty.vec3<f32>()),
Member("f", ty.u32()),
Member("g", ty.vec4<f32>()),
Member("h", ty.u32()),
Member("i", ty.mat2x2<f32>()),
Member("j", ty.u32()),
Member("k", ty.mat2x3<f32>()),
Member("l", ty.u32()),
Member("m", ty.mat2x4<f32>()),
Member("n", ty.u32()),
Member("o", ty.mat3x2<f32>()),
Member("p", ty.u32()),
Member("q", ty.mat3x3<f32>()),
Member("r", ty.u32()),
Member("s", ty.mat3x4<f32>()),
Member("t", ty.u32()),
Member("u", ty.mat4x2<f32>()),
Member("v", ty.u32()),
Member("w", ty.mat4x3<f32>()),
Member("x", ty.u32()),
Member("y", ty.mat4x4<f32>()),
Member("z", ty.f32()),
},
{create<ast::StructBlockDecoration>()});
Global("G", ty.Of(s), ast::StorageClass::kUniform,
ast::DecorationList{
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(0),
});
GeneratorImpl& gen = Build();
auto* sem_s = program->TypeOf(s)->As<sem::Struct>();
ASSERT_TRUE(gen.EmitStructType(out, sem_s)) << gen.error();
auto* expect = R"(struct S {
/* 0x0000 */ int a;
/* 0x0004 */ int tint_pad_0[31];
/* 0x0080 */ float b;
/* 0x0084 */ int tint_pad_1[31];
/* 0x0100 */ float2 c;
/* 0x0108 */ uint d;
/* 0x010c */ int tint_pad_2[1];
/* 0x0110 */ float3 e;
/* 0x011c */ uint f;
/* 0x0120 */ float4 g;
/* 0x0130 */ uint h;
/* 0x0134 */ int tint_pad_3[1];
/* 0x0138 */ float2x2 i;
/* 0x0148 */ uint j;
/* 0x014c */ int tint_pad_4[1];
/* 0x0150 */ float2x3 k;
/* 0x0170 */ uint l;
/* 0x0174 */ int tint_pad_5[3];
/* 0x0180 */ float2x4 m;
/* 0x01a0 */ uint n;
/* 0x01a4 */ int tint_pad_6[1];
/* 0x01a8 */ float3x2 o;
/* 0x01c0 */ uint p;
/* 0x01c4 */ int tint_pad_7[3];
/* 0x01d0 */ float3x3 q;
/* 0x0200 */ uint r;
/* 0x0204 */ int tint_pad_8[3];
/* 0x0210 */ float3x4 s;
/* 0x0240 */ uint t;
/* 0x0244 */ int tint_pad_9[1];
/* 0x0248 */ float4x2 u;
/* 0x0268 */ uint v;
/* 0x026c */ int tint_pad_10[1];
/* 0x0270 */ float4x3 w;
/* 0x02b0 */ uint x;
/* 0x02b4 */ int tint_pad_11[3];
/* 0x02c0 */ float4x4 y;
/* 0x0300 */ float z;
/* 0x0304 */ int tint_pad_12[31];
};
S make_S(int param_0,
float param_1,
float2 param_2,
uint param_3,
float3 param_4,
uint param_5,
float4 param_6,
uint param_7,
float2x2 param_8,
uint param_9,
float2x3 param_10,
uint param_11,
float2x4 param_12,
uint param_13,
float3x2 param_14,
uint param_15,
float3x3 param_16,
uint param_17,
float3x4 param_18,
uint param_19,
float4x2 param_20,
uint param_21,
float4x3 param_22,
uint param_23,
float4x4 param_24,
float param_25) {
S output;
output.a = param_0;
output.b = param_1;
output.c = param_2;
output.d = param_3;
output.e = param_4;
output.f = param_5;
output.g = param_6;
output.h = param_7;
output.i = param_8;
output.j = param_9;
output.k = param_10;
output.l = param_11;
output.m = param_12;
output.n = param_13;
output.o = param_14;
output.p = param_15;
output.q = param_16;
output.r = param_17;
output.s = param_18;
output.t = param_19;
output.u = param_20;
output.v = param_21;
output.w = param_22;
output.x = param_23;
output.y = param_24;
output.z = param_25;
return output;
}
)";
EXPECT_EQ(result(), expect);
}
TEST_F(HlslGeneratorImplTest_Type, EmitType_Struct_Layout_Structures) {
// inner_x: size(1024), align(512)
auto* inner_x =
Structure("inner_x", {
Member("a", ty.i32()),
Member("b", ty.f32(), {MemberAlign(512)}),
});
// inner_y: size(516), align(4)
auto* inner_y =
Structure("inner_y", {
Member("a", ty.i32(), {MemberSize(512)}),
Member("b", ty.f32()),
});
auto* s = Structure("S",
{
Member("a", ty.i32()),
Member("b", ty.Of(inner_x)),
Member("c", ty.f32()),
Member("d", ty.Of(inner_y)),
Member("e", ty.f32()),
},
{create<ast::StructBlockDecoration>()});
Global("G", ty.Of(s), ast::StorageClass::kUniform,
ast::DecorationList{
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(0),
});
GeneratorImpl& gen = Build();
auto* sem_s = program->TypeOf(s)->As<sem::Struct>();
ASSERT_TRUE(gen.EmitStructType(out, sem_s)) << gen.error();
auto* expect = R"(struct S {
/* 0x0000 */ int a;
/* 0x0004 */ int tint_pad_0[127];
/* 0x0200 */ inner_x b;
/* 0x0600 */ float c;
/* 0x0604 */ inner_y d;
/* 0x0808 */ float e;
/* 0x080c */ int tint_pad_1[125];
};
S make_S(int param_0,
inner_x param_1,
float param_2,
inner_y param_3,
float param_4) {
S output;
output.a = param_0;
output.b = param_1;
output.c = param_2;
output.d = param_3;
output.e = param_4;
return output;
}
)";
EXPECT_EQ(result(), expect);
}
TEST_F(HlslGeneratorImplTest_Type, EmitType_Struct_Layout_ArrayDefaultStride) {
// inner: size(1024), align(512)
auto* inner =
Structure("inner", {
Member("a", ty.i32()),
Member("b", ty.f32(), {MemberAlign(512)}),
});
// array_x: size(28), align(4)
auto* array_x = ty.array<f32, 7>();
// array_y: size(4096), align(512)
auto* array_y = ty.array(ty.Of(inner), 4);
// array_z: size(4), align(4)
auto* array_z = ty.array<f32, 1>();
auto* s =
Structure("S",
{
Member("a", ty.i32()),
Member("b", array_x),
Member("c", ty.f32()),
Member("d", array_y),
Member("e", ty.f32()),
Member("f", array_z),
},
ast::DecorationList{create<ast::StructBlockDecoration>()});
Global("G", ty.Of(s), ast::StorageClass::kUniform,
ast::DecorationList{
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(0),
});
GeneratorImpl& gen = Build();
auto* sem_s = program->TypeOf(s)->As<sem::Struct>();
ASSERT_TRUE(gen.EmitStructType(out, sem_s)) << gen.error();
auto* expect = R"(struct S {
/* 0x0000 */ int a;
/* 0x0004 */ float b[7];
/* 0x0020 */ float c;
/* 0x0024 */ int tint_pad_0[119];
/* 0x0200 */ inner d[4];
/* 0x1200 */ float e;
/* 0x1204 */ float f[1];
/* 0x1208 */ int tint_pad_1[126];
};
S make_S(int param_0,
float param_1[7],
float param_2,
inner param_3[4],
float param_4,
float param_5[1]) {
S output;
output.a = param_0;
output.b = param_1;
output.c = param_2;
output.d = param_3;
output.e = param_4;
output.f = param_5;
return output;
}
)";
EXPECT_EQ(result(), expect);
}
TEST_F(HlslGeneratorImplTest_Type, AttemptTintPadSymbolCollision) {
auto* s = Structure(
"S",
{
// uses symbols tint_pad_[0..9] and tint_pad_[20..35]
Member("tint_pad_2", ty.i32(), {MemberSize(32)}),
Member("tint_pad_20", ty.f32(), {MemberAlign(128), MemberSize(128)}),
Member("tint_pad_33", ty.vec2<f32>()),
Member("tint_pad_1", ty.u32()),
Member("tint_pad_3", ty.vec3<f32>()),
Member("tint_pad_7", ty.u32()),
Member("tint_pad_25", ty.vec4<f32>()),
Member("tint_pad_5", ty.u32()),
Member("tint_pad_27", ty.mat2x2<f32>()),
Member("tint_pad_24", ty.u32()),
Member("tint_pad_23", ty.mat2x3<f32>()),
Member("tint_pad_0", ty.u32()),
Member("tint_pad_8", ty.mat2x4<f32>()),
Member("tint_pad_26", ty.u32()),
Member("tint_pad_29", ty.mat3x2<f32>()),
Member("tint_pad_6", ty.u32()),
Member("tint_pad_22", ty.mat3x3<f32>()),
Member("tint_pad_32", ty.u32()),
Member("tint_pad_34", ty.mat3x4<f32>()),
Member("tint_pad_35", ty.u32()),
Member("tint_pad_30", ty.mat4x2<f32>()),
Member("tint_pad_9", ty.u32()),
Member("tint_pad_31", ty.mat4x3<f32>()),
Member("tint_pad_28", ty.u32()),
Member("tint_pad_4", ty.mat4x4<f32>()),
Member("tint_pad_21", ty.f32()),
},
{create<ast::StructBlockDecoration>()});
Global("G", ty.Of(s), ast::StorageClass::kUniform,
ast::DecorationList{
create<ast::BindingDecoration>(0),
create<ast::GroupDecoration>(0),
});
GeneratorImpl& gen = Build();
auto* sem_s = program->TypeOf(s)->As<sem::Struct>();
ASSERT_TRUE(gen.EmitStructType(out, sem_s)) << gen.error();
EXPECT_EQ(result(), R"(struct S {
/* 0x0000 */ int tint_pad_2;
/* 0x0004 */ int tint_pad_10[31];
/* 0x0080 */ float tint_pad_20;
/* 0x0084 */ int tint_pad_11[31];
/* 0x0100 */ float2 tint_pad_33;
/* 0x0108 */ uint tint_pad_1;
/* 0x010c */ int tint_pad_12[1];
/* 0x0110 */ float3 tint_pad_3;
/* 0x011c */ uint tint_pad_7;
/* 0x0120 */ float4 tint_pad_25;
/* 0x0130 */ uint tint_pad_5;
/* 0x0134 */ int tint_pad_13[1];
/* 0x0138 */ float2x2 tint_pad_27;
/* 0x0148 */ uint tint_pad_24;
/* 0x014c */ int tint_pad_14[1];
/* 0x0150 */ float2x3 tint_pad_23;
/* 0x0170 */ uint tint_pad_0;
/* 0x0174 */ int tint_pad_15[3];
/* 0x0180 */ float2x4 tint_pad_8;
/* 0x01a0 */ uint tint_pad_26;
/* 0x01a4 */ int tint_pad_16[1];
/* 0x01a8 */ float3x2 tint_pad_29;
/* 0x01c0 */ uint tint_pad_6;
/* 0x01c4 */ int tint_pad_17[3];
/* 0x01d0 */ float3x3 tint_pad_22;
/* 0x0200 */ uint tint_pad_32;
/* 0x0204 */ int tint_pad_18[3];
/* 0x0210 */ float3x4 tint_pad_34;
/* 0x0240 */ uint tint_pad_35;
/* 0x0244 */ int tint_pad_19[1];
/* 0x0248 */ float4x2 tint_pad_30;
/* 0x0268 */ uint tint_pad_9;
/* 0x026c */ int tint_pad_36[1];
/* 0x0270 */ float4x3 tint_pad_31;
/* 0x02b0 */ uint tint_pad_28;
/* 0x02b4 */ int tint_pad_37[3];
/* 0x02c0 */ float4x4 tint_pad_4;
/* 0x0300 */ float tint_pad_21;
/* 0x0304 */ int tint_pad_38[31];
};
S make_S(int param_0,
float param_1,
float2 param_2,
uint param_3,
float3 param_4,
uint param_5,
float4 param_6,
uint param_7,
float2x2 param_8,
uint param_9,
float2x3 param_10,
uint param_11,
float2x4 param_12,
uint param_13,
float3x2 param_14,
uint param_15,
float3x3 param_16,
uint param_17,
float3x4 param_18,
uint param_19,
float4x2 param_20,
uint param_21,
float4x3 param_22,
uint param_23,
float4x4 param_24,
float param_25) {
S output;
output.tint_pad_2 = param_0;
output.tint_pad_20 = param_1;
output.tint_pad_33 = param_2;
output.tint_pad_1 = param_3;
output.tint_pad_3 = param_4;
output.tint_pad_7 = param_5;
output.tint_pad_25 = param_6;
output.tint_pad_5 = param_7;
output.tint_pad_27 = param_8;
output.tint_pad_24 = param_9;
output.tint_pad_23 = param_10;
output.tint_pad_0 = param_11;
output.tint_pad_8 = param_12;
output.tint_pad_26 = param_13;
output.tint_pad_29 = param_14;
output.tint_pad_6 = param_15;
output.tint_pad_22 = param_16;
output.tint_pad_32 = param_17;
output.tint_pad_34 = param_18;
output.tint_pad_35 = param_19;
output.tint_pad_30 = param_20;
output.tint_pad_9 = param_21;
output.tint_pad_31 = param_22;
output.tint_pad_28 = param_23;
output.tint_pad_4 = param_24;
output.tint_pad_21 = param_25;
return output;
}
)");
}
TEST_F(HlslGeneratorImplTest_Type, EmitType_Struct_NameCollision) {
auto* s = Structure("S", {
Member("double", ty.i32()),