writer/hlsl: Emit UBO as an array of vector

Instead of a ConstantBuffer.

HLSL requires that each structure field in a UBO is 16 byte aligned.
WGSL has much looser constraints with its UBO field alignment rules.

Instead generate an array of uint4 vectors, and index into this, much
like we index into [RW]ByteAddressBuffers for SSBOs.

Extend the DecomposeStorageAccess transform to support uniforms too.
This has been renamed to DecomposeMemoryAccess.

Change-Id: I3868ff80af1ab3b3dddfbf5b969724cb87ef0744
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/55246
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
Ben Clayton
2021-06-18 21:15:25 +00:00
parent 9efc4fcc89
commit 165512c57e
23 changed files with 1145 additions and 1393 deletions

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/decompose_storage_access.h"
#include "src/transform/decompose_memory_access.h"
#include <memory>
#include <string>
@@ -38,7 +38,7 @@
#include "src/utils/get_or_create.h"
#include "src/utils/hash.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeStorageAccess::Intrinsic);
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess::Intrinsic);
namespace tint {
namespace transform {
@@ -46,7 +46,7 @@ namespace transform {
namespace {
/// Offset is a simple ast::Expression builder interface, used to build byte
/// offsets for storage buffer accesses.
/// offsets for storage and uniform buffer accesses.
struct Offset : Castable<Offset> {
/// @returns builds and returns the ast::Expression in `ctx.dst`
virtual ast::Expression* Build(CloneContext& ctx) = 0;
@@ -179,14 +179,16 @@ std::unique_ptr<Offset> Mul(LHS&& lhs_, RHS&& rhs_) {
/// LoadStoreKey is the unordered map key to a load or store intrinsic.
struct LoadStoreKey {
sem::Type const* buf_ty; // buffer type
sem::Type const* el_ty; // element type
ast::StorageClass const storage_class; // buffer storage class
sem::Type const* buf_ty; // buffer type
sem::Type const* el_ty; // element type
bool operator==(const LoadStoreKey& rhs) const {
return buf_ty == rhs.buf_ty && el_ty == rhs.el_ty;
return storage_class == rhs.storage_class && buf_ty == rhs.buf_ty &&
el_ty == rhs.el_ty;
}
struct Hasher {
inline std::size_t operator()(const LoadStoreKey& u) const {
return utils::Hash(u.buf_ty, u.el_ty);
return utils::Hash(u.storage_class, u.buf_ty, u.el_ty);
}
};
};
@@ -218,60 +220,60 @@ uint32_t MatrixColumnStride(const sem::Matrix* mat) {
}
bool IntrinsicDataTypeFor(const sem::Type* ty,
DecomposeStorageAccess::Intrinsic::DataType& out) {
DecomposeMemoryAccess::Intrinsic::DataType& out) {
if (ty->Is<sem::I32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kI32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kI32;
return true;
}
if (ty->Is<sem::U32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kU32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kU32;
return true;
}
if (ty->Is<sem::F32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kF32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kF32;
return true;
}
if (auto* vec = ty->As<sem::Vector>()) {
switch (vec->size()) {
case 2:
if (vec->type()->Is<sem::I32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec2I32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2I32;
return true;
}
if (vec->type()->Is<sem::U32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec2U32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2U32;
return true;
}
if (vec->type()->Is<sem::F32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec2F32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec2F32;
return true;
}
break;
case 3:
if (vec->type()->Is<sem::I32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec3I32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3I32;
return true;
}
if (vec->type()->Is<sem::U32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec3U32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3U32;
return true;
}
if (vec->type()->Is<sem::F32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec3F32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec3F32;
return true;
}
break;
case 4:
if (vec->type()->Is<sem::I32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec4I32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4I32;
return true;
}
if (vec->type()->Is<sem::U32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec4U32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4U32;
return true;
}
if (vec->type()->Is<sem::F32>()) {
out = DecomposeStorageAccess::Intrinsic::DataType::kVec4F32;
out = DecomposeMemoryAccess::Intrinsic::DataType::kVec4F32;
return true;
}
break;
@@ -282,80 +284,86 @@ bool IntrinsicDataTypeFor(const sem::Type* ty,
return false;
}
/// @returns a DecomposeStorageAccess::Intrinsic decoration that can be applied
/// @returns a DecomposeMemoryAccess::Intrinsic decoration that can be applied
/// to a stub function to load the type `ty`.
DecomposeStorageAccess::Intrinsic* IntrinsicLoadFor(ProgramBuilder* builder,
const sem::Type* ty) {
DecomposeStorageAccess::Intrinsic::DataType type;
DecomposeMemoryAccess::Intrinsic* IntrinsicLoadFor(
ProgramBuilder* builder,
ast::StorageClass storage_class,
const sem::Type* ty) {
DecomposeMemoryAccess::Intrinsic::DataType type;
if (!IntrinsicDataTypeFor(ty, type)) {
return nullptr;
}
return builder->ASTNodes().Create<DecomposeStorageAccess::Intrinsic>(
builder->ID(), DecomposeStorageAccess::Intrinsic::Op::kLoad, type);
return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kLoad, storage_class,
type);
}
/// @returns a DecomposeStorageAccess::Intrinsic decoration that can be applied
/// @returns a DecomposeMemoryAccess::Intrinsic decoration that can be applied
/// to a stub function to store the type `ty`.
DecomposeStorageAccess::Intrinsic* IntrinsicStoreFor(ProgramBuilder* builder,
const sem::Type* ty) {
DecomposeStorageAccess::Intrinsic::DataType type;
DecomposeMemoryAccess::Intrinsic* IntrinsicStoreFor(
ProgramBuilder* builder,
ast::StorageClass storage_class,
const sem::Type* ty) {
DecomposeMemoryAccess::Intrinsic::DataType type;
if (!IntrinsicDataTypeFor(ty, type)) {
return nullptr;
}
return builder->ASTNodes().Create<DecomposeStorageAccess::Intrinsic>(
builder->ID(), DecomposeStorageAccess::Intrinsic::Op::kStore, type);
return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
builder->ID(), DecomposeMemoryAccess::Intrinsic::Op::kStore,
storage_class, type);
}
/// @returns a DecomposeStorageAccess::Intrinsic decoration that can be applied
/// @returns a DecomposeMemoryAccess::Intrinsic decoration that can be applied
/// to a stub function for the atomic op and the type `ty`.
DecomposeStorageAccess::Intrinsic* IntrinsicAtomicFor(ProgramBuilder* builder,
sem::IntrinsicType ity,
const sem::Type* ty) {
auto op = DecomposeStorageAccess::Intrinsic::Op::kAtomicLoad;
DecomposeMemoryAccess::Intrinsic* IntrinsicAtomicFor(ProgramBuilder* builder,
sem::IntrinsicType ity,
const sem::Type* ty) {
auto op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
switch (ity) {
case sem::IntrinsicType::kAtomicLoad:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicLoad;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicLoad;
break;
case sem::IntrinsicType::kAtomicStore:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicStore;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicStore;
break;
case sem::IntrinsicType::kAtomicAdd:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicAdd;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAdd;
break;
case sem::IntrinsicType::kAtomicMax:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicMax;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMax;
break;
case sem::IntrinsicType::kAtomicMin:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicMin;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicMin;
break;
case sem::IntrinsicType::kAtomicAnd:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicAnd;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicAnd;
break;
case sem::IntrinsicType::kAtomicOr:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicOr;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicOr;
break;
case sem::IntrinsicType::kAtomicXor:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicXor;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicXor;
break;
case sem::IntrinsicType::kAtomicExchange:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicExchange;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicExchange;
break;
case sem::IntrinsicType::kAtomicCompareExchangeWeak:
op = DecomposeStorageAccess::Intrinsic::Op::kAtomicCompareExchangeWeak;
op = DecomposeMemoryAccess::Intrinsic::Op::kAtomicCompareExchangeWeak;
break;
default:
TINT_ICE(builder->Diagnostics())
<< "invalid IntrinsicType for DecomposeStorageAccess::Intrinsic: "
<< "invalid IntrinsicType for DecomposeMemoryAccess::Intrinsic: "
<< ty->type_name();
break;
}
DecomposeStorageAccess::Intrinsic::DataType type;
DecomposeMemoryAccess::Intrinsic::DataType type;
if (!IntrinsicDataTypeFor(ty, type)) {
return nullptr;
}
return builder->ASTNodes().Create<DecomposeStorageAccess::Intrinsic>(
builder->ID(), op, type);
return builder->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
builder->ID(), op, ast::StorageClass::kStorage, type);
}
/// Inserts `node` before `insert_after` in the global declarations of
@@ -387,30 +395,30 @@ const ast::TypeDecl* TypeDeclOf(const sem::Type* ty) {
}
}
/// StorageBufferAccess describes a single storage buffer access
struct StorageBufferAccess {
/// BufferAccess describes a single storage or uniform buffer access
struct BufferAccess {
sem::Expression const* var = nullptr; // Storage buffer variable
std::unique_ptr<Offset> offset; // The byte offset on var
sem::Type const* type = nullptr; // The type of the access
operator bool() const { return var; } // Returns true if valid
};
/// Store describes a single storage buffer write
/// Store describes a single storage or uniform buffer write
struct Store {
ast::AssignmentStatement* assignment; // The AST assignment statement
StorageBufferAccess target; // The target for the write
BufferAccess target; // The target for the write
};
} // namespace
/// State holds the current transform state
struct DecomposeStorageAccess::State {
/// Map of AST expression to storage buffer access
struct DecomposeMemoryAccess::State {
/// Map of AST expression to storage or uniform buffer access
/// This map has entries added when encountered, and removed when outer
/// expressions chain the access.
/// Subset of #expression_order, as expressions are not removed from
/// #expression_order.
std::unordered_map<ast::Expression*, StorageBufferAccess> accesses;
std::unordered_map<ast::Expression*, BufferAccess> accesses;
/// The visited order of AST expressions (superset of #accesses)
std::vector<ast::Expression*> expression_order;
/// [buffer-type, element-type] -> load function name
@@ -419,25 +427,25 @@ struct DecomposeStorageAccess::State {
std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> store_funcs;
/// [buffer-type, element-type, atomic-op] -> load function name
std::unordered_map<AtomicKey, Symbol, AtomicKey::Hasher> atomic_funcs;
/// List of storage buffer writes
/// List of storage or uniform buffer writes
std::vector<Store> stores;
/// AddAccess() adds the `expr -> access` map item to #accesses, and `expr`
/// to #expression_order.
/// @param expr the expression that performs the access
/// @param access the access
void AddAccess(ast::Expression* expr, StorageBufferAccess&& access) {
void AddAccess(ast::Expression* expr, BufferAccess&& access) {
TINT_ASSERT(access.type);
accesses.emplace(expr, std::move(access));
expression_order.emplace_back(expr);
}
/// TakeAccess() removes the `node` item from #accesses (if it exists),
/// returning the StorageBufferAccess. If #accesses does not hold an item for
/// `node`, an invalid StorageBufferAccess is returned.
/// returning the BufferAccess. If #accesses does not hold an item for
/// `node`, an invalid BufferAccess is returned.
/// @param node the expression that performed an access
/// @return the StorageBufferAccess for the given expression
StorageBufferAccess TakeAccess(ast::Expression* node) {
/// @return the BufferAccess for the given expression
BufferAccess TakeAccess(ast::Expression* node) {
auto lhs_it = accesses.find(node);
if (lhs_it == accesses.end()) {
return {};
@@ -448,12 +456,13 @@ struct DecomposeStorageAccess::State {
}
/// LoadFunc() returns a symbol to an intrinsic function that loads an element
/// of type `el_ty` from a storage buffer of type `buf_ty`. The function has
/// the signature: `fn load(buf : buf_ty, offset : u32) -> el_ty`
/// of type `el_ty` from a storage or uniform buffer of type `buf_ty`.
/// The emitted function has the signature:
/// `fn load(buf : buf_ty, offset : u32) -> el_ty`
/// @param ctx the CloneContext
/// @param insert_after the user-declared type to insert the function after
/// @param buf_ty the storage buffer type
/// @param el_ty the storage buffer element type
/// @param buf_ty the storage or uniform buffer type
/// @param el_ty the storage or uniform buffer element type
/// @param var_user the variable user
/// @return the name of the function that performs the load
Symbol LoadFunc(CloneContext& ctx,
@@ -461,71 +470,79 @@ struct DecomposeStorageAccess::State {
const sem::Type* buf_ty,
const sem::Type* el_ty,
const sem::VariableUser* var_user) {
return utils::GetOrCreate(load_funcs, LoadStoreKey{buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(&ctx, buf_ty);
auto storage_class = var_user->Variable()->StorageClass();
return utils::GetOrCreate(
load_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(&ctx, buf_ty);
ast::VariableList params = {
// Note: The buffer parameter requires the kStorage StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
var_user->Variable()->Access(), buf_ast_ty, true, nullptr,
ast::DecorationList{}),
ctx.dst->Param("offset", ctx.dst->ty.u32()),
};
ast::VariableList params = {
// Note: The buffer parameter requires the StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer or cbuffer
// array.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), storage_class,
var_user->Variable()->Access(), buf_ast_ty, true, nullptr,
ast::DecorationList{}),
ctx.dst->Param("offset", ctx.dst->ty.u32()),
};
ast::Function* func = nullptr;
if (auto* intrinsic = IntrinsicLoadFor(ctx.dst, el_ty)) {
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, el_ast_ty, nullptr,
ast::DecorationList{
intrinsic,
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kFunctionHasNoBody),
},
ast::DecorationList{});
} else {
ast::ExpressionList values;
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol load = LoadFunc(ctx, insert_after, buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset =
ctx.dst->Add("offset", i * MatrixColumnStride(mat_ty));
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
ast::Function* func = nullptr;
if (auto* intrinsic =
IntrinsicLoadFor(ctx.dst, storage_class, el_ty)) {
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, el_ast_ty, nullptr,
ast::DecorationList{
intrinsic,
ctx.dst->ASTNodes()
.Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kFunctionHasNoBody),
},
ast::DecorationList{});
} else {
ast::ExpressionList values;
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol load =
LoadFunc(ctx, insert_after, buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset =
ctx.dst->Add("offset", i * MatrixColumnStride(mat_ty));
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto* offset = ctx.dst->Add("offset", member->Offset());
Symbol load = LoadFunc(ctx, insert_after, buf_ty,
member->Type()->UnwrapRef(), var_user);
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
} else if (auto* arr = el_ty->As<sem::Array>()) {
for (uint32_t i = 0; i < arr->Count(); i++) {
auto* offset = ctx.dst->Add("offset", arr->Stride() * i);
Symbol load = LoadFunc(ctx, insert_after, buf_ty,
arr->ElemType()->UnwrapRef(), var_user);
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
}
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, el_ast_ty,
ctx.dst->Block(ctx.dst->Return(
ctx.dst->create<ast::TypeConstructorExpression>(
CreateASTTypeFor(&ctx, el_ty), values))),
ast::DecorationList{}, ast::DecorationList{});
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto* offset = ctx.dst->Add("offset", member->Offset());
Symbol load = LoadFunc(ctx, insert_after, buf_ty,
member->Type()->UnwrapRef(), var_user);
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
} else if (auto* arr = el_ty->As<sem::Array>()) {
for (uint32_t i = 0; i < arr->Count(); i++) {
auto* offset = ctx.dst->Add("offset", arr->Stride() * i);
Symbol load = LoadFunc(ctx, insert_after, buf_ty,
arr->ElemType()->UnwrapRef(), var_user);
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
}
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, el_ast_ty,
ctx.dst->Block(
ctx.dst->Return(ctx.dst->create<ast::TypeConstructorExpression>(
CreateASTTypeFor(&ctx, el_ty), values))),
ast::DecorationList{}, ast::DecorationList{});
}
InsertGlobal(ctx, insert_after, func);
return func->symbol();
});
InsertGlobal(ctx, insert_after, func);
return func->symbol();
});
}
/// StoreFunc() returns a symbol to an intrinsic function that stores an
/// element of type `el_ty` to a storage buffer of type `buf_ty`. The function
/// has the signature: `fn store(buf : buf_ty, offset : u32, value : el_ty)`
/// element of type `el_ty` to a storage buffer of type `buf_ty`.
/// The function has the signature:
/// `fn store(buf : buf_ty, offset : u32, value : el_ty)`
/// @param ctx the CloneContext
/// @param insert_after the user-declared type to insert the function after
/// @param buf_ty the storage buffer type
@@ -537,70 +554,79 @@ struct DecomposeStorageAccess::State {
const sem::Type* buf_ty,
const sem::Type* el_ty,
const sem::VariableUser* var_user) {
return utils::GetOrCreate(store_funcs, LoadStoreKey{buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(&ctx, buf_ty);
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
ast::VariableList params{
// Note: The buffer parameter requires the kStorage StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
var_user->Variable()->Access(), buf_ast_ty, true, nullptr,
ast::DecorationList{}),
ctx.dst->Param("offset", ctx.dst->ty.u32()),
ctx.dst->Param("value", el_ast_ty),
};
ast::Function* func = nullptr;
if (auto* intrinsic = IntrinsicStoreFor(ctx.dst, el_ty)) {
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, ctx.dst->ty.void_(), nullptr,
ast::DecorationList{
intrinsic,
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
ctx.dst->ID(), ast::DisabledValidation::kFunctionHasNoBody),
},
ast::DecorationList{});
auto storage_class = var_user->Variable()->StorageClass();
return utils::GetOrCreate(
store_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(&ctx, buf_ty);
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
ast::VariableList params{
// Note: The buffer parameter requires the StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), storage_class,
var_user->Variable()->Access(), buf_ast_ty, true, nullptr,
ast::DecorationList{}),
ctx.dst->Param("offset", ctx.dst->ty.u32()),
ctx.dst->Param("value", el_ast_ty),
};
ast::Function* func = nullptr;
if (auto* intrinsic =
IntrinsicStoreFor(ctx.dst, storage_class, el_ty)) {
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, ctx.dst->ty.void_(), nullptr,
ast::DecorationList{
intrinsic,
ctx.dst->ASTNodes()
.Create<ast::DisableValidationDecoration>(
ctx.dst->ID(),
ast::DisabledValidation::kFunctionHasNoBody),
},
ast::DecorationList{});
} else {
ast::StatementList body;
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol store = StoreFunc(ctx, insert_after, buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset =
ctx.dst->Add("offset", i * MatrixColumnStride(mat_ty));
auto* access = ctx.dst->IndexAccessor("value", i);
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
} else {
ast::StatementList body;
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol store =
StoreFunc(ctx, insert_after, buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset =
ctx.dst->Add("offset", i * MatrixColumnStride(mat_ty));
auto* access = ctx.dst->IndexAccessor("value", i);
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto* offset = ctx.dst->Add("offset", member->Offset());
auto* access = ctx.dst->MemberAccessor(
"value", ctx.Clone(member->Declaration()->symbol()));
Symbol store = StoreFunc(ctx, insert_after, buf_ty,
member->Type()->UnwrapRef(), var_user);
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
}
} else if (auto* arr = el_ty->As<sem::Array>()) {
for (uint32_t i = 0; i < arr->Count(); i++) {
auto* offset = ctx.dst->Add("offset", arr->Stride() * i);
auto* access =
ctx.dst->IndexAccessor("value", ctx.dst->Expr(i));
Symbol store =
StoreFunc(ctx, insert_after, buf_ty,
arr->ElemType()->UnwrapRef(), var_user);
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
}
}
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, ctx.dst->ty.void_(),
ctx.dst->Block(body), ast::DecorationList{},
ast::DecorationList{});
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto* offset = ctx.dst->Add("offset", member->Offset());
auto* access = ctx.dst->MemberAccessor(
"value", ctx.Clone(member->Declaration()->symbol()));
Symbol store = StoreFunc(ctx, insert_after, buf_ty,
member->Type()->UnwrapRef(), var_user);
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
}
} else if (auto* arr = el_ty->As<sem::Array>()) {
for (uint32_t i = 0; i < arr->Count(); i++) {
auto* offset = ctx.dst->Add("offset", arr->Stride() * i);
auto* access = ctx.dst->IndexAccessor("value", ctx.dst->Expr(i));
Symbol store = StoreFunc(ctx, insert_after, buf_ty,
arr->ElemType()->UnwrapRef(), var_user);
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
}
}
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, ctx.dst->ty.void_(), ctx.dst->Block(body),
ast::DecorationList{}, ast::DecorationList{});
}
InsertGlobal(ctx, insert_after, func);
return func->symbol();
});
InsertGlobal(ctx, insert_after, func);
return func->symbol();
});
}
/// AtomicFunc() returns a symbol to an intrinsic function that performs an
@@ -667,12 +693,13 @@ struct DecomposeStorageAccess::State {
}
};
DecomposeStorageAccess::Intrinsic::Intrinsic(ProgramID program_id,
Op o,
DataType ty)
: Base(program_id), op(o), type(ty) {}
DecomposeStorageAccess::Intrinsic::~Intrinsic() = default;
std::string DecomposeStorageAccess::Intrinsic::InternalName() const {
DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID program_id,
Op o,
ast::StorageClass sc,
DataType ty)
: Base(program_id), op(o), storage_class(sc), type(ty) {}
DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default;
std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
std::stringstream ss;
switch (op) {
case Op::kLoad:
@@ -712,6 +739,7 @@ std::string DecomposeStorageAccess::Intrinsic::InternalName() const {
ss << "intrinsic_atomic_compare_exchange_weak_";
break;
}
ss << storage_class << "_";
switch (type) {
case DataType::kU32:
ss << "u32";
@@ -753,16 +781,16 @@ std::string DecomposeStorageAccess::Intrinsic::InternalName() const {
return ss.str();
}
DecomposeStorageAccess::Intrinsic* DecomposeStorageAccess::Intrinsic::Clone(
DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
CloneContext* ctx) const {
return ctx->dst->ASTNodes().Create<DecomposeStorageAccess::Intrinsic>(
ctx->dst->ID(), op, type);
return ctx->dst->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
ctx->dst->ID(), op, storage_class, type);
}
DecomposeStorageAccess::DecomposeStorageAccess() = default;
DecomposeStorageAccess::~DecomposeStorageAccess() = default;
DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
Output DecomposeMemoryAccess::Run(const Program* in, const DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
@@ -770,9 +798,10 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
State state;
// Scan the AST nodes for storage buffer accesses. Complex expression chains
// (e.g. `storage_buffer.foo.bar[20].x`) are handled by maintaining an offset
// chain via the `state.TakeAccess()`, `state.AddAccess()` methods.
// Scan the AST nodes for storage and uniform buffer accesses. Complex
// expression chains (e.g. `storage_buffer.foo.bar[20].x`) are handled by
// maintaining an offset chain via the `state.TakeAccess()`,
// `state.AddAccess()` methods.
//
// Inner-most expression nodes are guaranteed to be visited first because AST
// nodes are fully immutable and require their children to be constructed
@@ -781,8 +810,9 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
if (auto* ident = node->As<ast::IdentifierExpression>()) {
// X
if (auto* var = sem.Get<sem::VariableUser>(ident)) {
if (var->Variable()->StorageClass() == ast::StorageClass::kStorage) {
// Variable to a storage buffer
if (var->Variable()->StorageClass() == ast::StorageClass::kStorage ||
var->Variable()->StorageClass() == ast::StorageClass::kUniform) {
// Variable to a storage or uniform buffer
state.AddAccess(ident, {
var,
ToOffset(0u),
@@ -940,7 +970,7 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
ctx.Replace(expr, load);
}
// And replace all storage buffer assignments with stores
// And replace all storage and uniform buffer assignments with stores
for (auto& store : state.stores) {
auto* buf = store.target.var->Declaration();
auto* offset = store.target.offset->Build(ctx);

View File

@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TRANSFORM_DECOMPOSE_STORAGE_ACCESS_H_
#define SRC_TRANSFORM_DECOMPOSE_STORAGE_ACCESS_H_
#ifndef SRC_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_
#define SRC_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_
#include <string>
@@ -27,10 +27,10 @@ class CloneContext;
namespace transform {
/// DecomposeStorageAccess is a transform used to replace storage buffer
/// accesses with a combination of load, store or atomic functions on primitive
/// types.
class DecomposeStorageAccess : public Transform {
/// DecomposeMemoryAccess is a transform used to replace storage and uniform
/// buffer accesses with a combination of load, store or atomic functions on
/// primitive types.
class DecomposeMemoryAccess : public Transform {
public:
/// Intrinsic is an InternalDecoration that's used to decorate a stub function
/// so that the HLSL transforms this into calls to
@@ -73,8 +73,9 @@ class DecomposeStorageAccess : public Transform {
/// Constructor
/// @param program_id the identifier of the program that owns this node
/// @param o the op of the intrinsic
/// @param sc the storage class of the buffer
/// @param ty the data type of the intrinsic
Intrinsic(ProgramID program_id, Op o, DataType ty);
Intrinsic(ProgramID program_id, Op o, ast::StorageClass sc, DataType ty);
/// Destructor
~Intrinsic() override;
@@ -90,14 +91,17 @@ class DecomposeStorageAccess : public Transform {
/// The op of the intrinsic
Op const op;
/// The storage class of the buffer this intrinsic operates on
ast::StorageClass const storage_class;
/// The type of the intrinsic
DataType const type;
};
/// Constructor
DecomposeStorageAccess();
DecomposeMemoryAccess();
/// Destructor
~DecomposeStorageAccess() override;
~DecomposeMemoryAccess() override;
/// Runs the transform on `program`, returning the transformation result.
/// @param program the source program to transform
@@ -111,4 +115,4 @@ class DecomposeStorageAccess : public Transform {
} // namespace transform
} // namespace tint
#endif // SRC_TRANSFORM_DECOMPOSE_STORAGE_ACCESS_H_
#endif // SRC_TRANSFORM_DECOMPOSE_MEMORY_ACCESS_H_

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/transform/decompose_storage_access.h"
#include "src/transform/decompose_memory_access.h"
#include "src/transform/test_helper.h"
@@ -20,9 +20,9 @@ namespace tint {
namespace transform {
namespace {
using DecomposeStorageAccessTest = TransformTest;
using DecomposeMemoryAccessTest = TransformTest;
TEST_F(DecomposeStorageAccessTest, BasicLoad) {
TEST_F(DecomposeMemoryAccessTest, SB_BasicLoad) {
auto* src = R"(
[[block]]
struct SB {
@@ -106,40 +106,40 @@ struct SB {
v : array<vec3<f32>, 2>;
};
[[internal(intrinsic_load_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32) -> i32
[[internal(intrinsic_load_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_1(buffer : SB, offset : u32) -> u32
[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_2(buffer : SB, offset : u32) -> f32
[[internal(intrinsic_load_vec2_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec2_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_3(buffer : SB, offset : u32) -> vec2<i32>
[[internal(intrinsic_load_vec2_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec2_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_4(buffer : SB, offset : u32) -> vec2<u32>
[[internal(intrinsic_load_vec2_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec2_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_5(buffer : SB, offset : u32) -> vec2<f32>
[[internal(intrinsic_load_vec3_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec3_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_6(buffer : SB, offset : u32) -> vec3<i32>
[[internal(intrinsic_load_vec3_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec3_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_7(buffer : SB, offset : u32) -> vec3<u32>
[[internal(intrinsic_load_vec3_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec3_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_8(buffer : SB, offset : u32) -> vec3<f32>
[[internal(intrinsic_load_vec4_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec4_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_9(buffer : SB, offset : u32) -> vec4<i32>
[[internal(intrinsic_load_vec4_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec4_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_10(buffer : SB, offset : u32) -> vec4<u32>
[[internal(intrinsic_load_vec4_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec4_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_11(buffer : SB, offset : u32) -> vec4<f32>
fn tint_symbol_12(buffer : SB, offset : u32) -> mat2x2<f32> {
@@ -211,12 +211,206 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, BasicStore) {
TEST_F(DecomposeMemoryAccessTest, UB_BasicLoad) {
auto* src = R"(
[[block]]
struct UB {
a : i32;
b : u32;
c : f32;
d : vec2<i32>;
e : vec2<u32>;
f : vec2<f32>;
g : vec3<i32>;
h : vec3<u32>;
i : vec3<f32>;
j : vec4<i32>;
k : vec4<u32>;
l : vec4<f32>;
m : mat2x2<f32>;
n : mat2x3<f32>;
o : mat2x4<f32>;
p : mat3x2<f32>;
q : mat3x3<f32>;
r : mat3x4<f32>;
s : mat4x2<f32>;
t : mat4x3<f32>;
u : mat4x4<f32>;
v : array<vec3<f32>, 2>;
};
[[group(0), binding(0)]] var<uniform> ub : UB;
[[stage(compute)]]
fn main() {
var a : i32 = ub.a;
var b : u32 = ub.b;
var c : f32 = ub.c;
var d : vec2<i32> = ub.d;
var e : vec2<u32> = ub.e;
var f : vec2<f32> = ub.f;
var g : vec3<i32> = ub.g;
var h : vec3<u32> = ub.h;
var i : vec3<f32> = ub.i;
var j : vec4<i32> = ub.j;
var k : vec4<u32> = ub.k;
var l : vec4<f32> = ub.l;
var m : mat2x2<f32> = ub.m;
var n : mat2x3<f32> = ub.n;
var o : mat2x4<f32> = ub.o;
var p : mat3x2<f32> = ub.p;
var q : mat3x3<f32> = ub.q;
var r : mat3x4<f32> = ub.r;
var s : mat4x2<f32> = ub.s;
var t : mat4x3<f32> = ub.t;
var u : mat4x4<f32> = ub.u;
var v : array<vec3<f32>, 2> = ub.v;
}
)";
auto* expect = R"(
[[block]]
struct UB {
a : i32;
b : u32;
c : f32;
d : vec2<i32>;
e : vec2<u32>;
f : vec2<f32>;
g : vec3<i32>;
h : vec3<u32>;
i : vec3<f32>;
j : vec4<i32>;
k : vec4<u32>;
l : vec4<f32>;
m : mat2x2<f32>;
n : mat2x3<f32>;
o : mat2x4<f32>;
p : mat3x2<f32>;
q : mat3x3<f32>;
r : mat3x4<f32>;
s : mat4x2<f32>;
t : mat4x3<f32>;
u : mat4x4<f32>;
v : array<vec3<f32>, 2>;
};
[[internal(intrinsic_load_uniform_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : UB, offset : u32) -> i32
[[internal(intrinsic_load_uniform_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_1(buffer : UB, offset : u32) -> u32
[[internal(intrinsic_load_uniform_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_2(buffer : UB, offset : u32) -> f32
[[internal(intrinsic_load_uniform_vec2_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_3(buffer : UB, offset : u32) -> vec2<i32>
[[internal(intrinsic_load_uniform_vec2_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_4(buffer : UB, offset : u32) -> vec2<u32>
[[internal(intrinsic_load_uniform_vec2_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_5(buffer : UB, offset : u32) -> vec2<f32>
[[internal(intrinsic_load_uniform_vec3_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_6(buffer : UB, offset : u32) -> vec3<i32>
[[internal(intrinsic_load_uniform_vec3_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_7(buffer : UB, offset : u32) -> vec3<u32>
[[internal(intrinsic_load_uniform_vec3_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_8(buffer : UB, offset : u32) -> vec3<f32>
[[internal(intrinsic_load_uniform_vec4_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_9(buffer : UB, offset : u32) -> vec4<i32>
[[internal(intrinsic_load_uniform_vec4_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_10(buffer : UB, offset : u32) -> vec4<u32>
[[internal(intrinsic_load_uniform_vec4_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_11(buffer : UB, offset : u32) -> vec4<f32>
fn tint_symbol_12(buffer : UB, offset : u32) -> mat2x2<f32> {
return mat2x2<f32>(tint_symbol_5(buffer, (offset + 0u)), tint_symbol_5(buffer, (offset + 8u)));
}
fn tint_symbol_13(buffer : UB, offset : u32) -> mat2x3<f32> {
return mat2x3<f32>(tint_symbol_8(buffer, (offset + 0u)), tint_symbol_8(buffer, (offset + 16u)));
}
fn tint_symbol_14(buffer : UB, offset : u32) -> mat2x4<f32> {
return mat2x4<f32>(tint_symbol_11(buffer, (offset + 0u)), tint_symbol_11(buffer, (offset + 16u)));
}
fn tint_symbol_15(buffer : UB, offset : u32) -> mat3x2<f32> {
return mat3x2<f32>(tint_symbol_5(buffer, (offset + 0u)), tint_symbol_5(buffer, (offset + 8u)), tint_symbol_5(buffer, (offset + 16u)));
}
fn tint_symbol_16(buffer : UB, offset : u32) -> mat3x3<f32> {
return mat3x3<f32>(tint_symbol_8(buffer, (offset + 0u)), tint_symbol_8(buffer, (offset + 16u)), tint_symbol_8(buffer, (offset + 32u)));
}
fn tint_symbol_17(buffer : UB, offset : u32) -> mat3x4<f32> {
return mat3x4<f32>(tint_symbol_11(buffer, (offset + 0u)), tint_symbol_11(buffer, (offset + 16u)), tint_symbol_11(buffer, (offset + 32u)));
}
fn tint_symbol_18(buffer : UB, offset : u32) -> mat4x2<f32> {
return mat4x2<f32>(tint_symbol_5(buffer, (offset + 0u)), tint_symbol_5(buffer, (offset + 8u)), tint_symbol_5(buffer, (offset + 16u)), tint_symbol_5(buffer, (offset + 24u)));
}
fn tint_symbol_19(buffer : UB, offset : u32) -> mat4x3<f32> {
return mat4x3<f32>(tint_symbol_8(buffer, (offset + 0u)), tint_symbol_8(buffer, (offset + 16u)), tint_symbol_8(buffer, (offset + 32u)), tint_symbol_8(buffer, (offset + 48u)));
}
fn tint_symbol_20(buffer : UB, offset : u32) -> mat4x4<f32> {
return mat4x4<f32>(tint_symbol_11(buffer, (offset + 0u)), tint_symbol_11(buffer, (offset + 16u)), tint_symbol_11(buffer, (offset + 32u)), tint_symbol_11(buffer, (offset + 48u)));
}
fn tint_symbol_21(buffer : UB, offset : u32) -> array<vec3<f32>, 2> {
return array<vec3<f32>, 2>(tint_symbol_8(buffer, (offset + 0u)), tint_symbol_8(buffer, (offset + 16u)));
}
[[group(0), binding(0)]] var<uniform> ub : UB;
[[stage(compute)]]
fn main() {
var a : i32 = tint_symbol(ub, 0u);
var b : u32 = tint_symbol_1(ub, 4u);
var c : f32 = tint_symbol_2(ub, 8u);
var d : vec2<i32> = tint_symbol_3(ub, 16u);
var e : vec2<u32> = tint_symbol_4(ub, 24u);
var f : vec2<f32> = tint_symbol_5(ub, 32u);
var g : vec3<i32> = tint_symbol_6(ub, 48u);
var h : vec3<u32> = tint_symbol_7(ub, 64u);
var i : vec3<f32> = tint_symbol_8(ub, 80u);
var j : vec4<i32> = tint_symbol_9(ub, 96u);
var k : vec4<u32> = tint_symbol_10(ub, 112u);
var l : vec4<f32> = tint_symbol_11(ub, 128u);
var m : mat2x2<f32> = tint_symbol_12(ub, 144u);
var n : mat2x3<f32> = tint_symbol_13(ub, 160u);
var o : mat2x4<f32> = tint_symbol_14(ub, 192u);
var p : mat3x2<f32> = tint_symbol_15(ub, 224u);
var q : mat3x3<f32> = tint_symbol_16(ub, 256u);
var r : mat3x4<f32> = tint_symbol_17(ub, 304u);
var s : mat4x2<f32> = tint_symbol_18(ub, 352u);
var t : mat4x3<f32> = tint_symbol_19(ub, 384u);
var u : mat4x4<f32> = tint_symbol_20(ub, 448u);
var v : array<vec3<f32>, 2> = tint_symbol_21(ub, 512u);
}
)";
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeMemoryAccessTest, SB_BasicStore) {
auto* src = R"(
[[block]]
struct SB {
@@ -300,40 +494,40 @@ struct SB {
v : array<vec3<f32>, 2>;
};
[[internal(intrinsic_store_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32, value : i32)
[[internal(intrinsic_store_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_1(buffer : SB, offset : u32, value : u32)
[[internal(intrinsic_store_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_2(buffer : SB, offset : u32, value : f32)
[[internal(intrinsic_store_vec2_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec2_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_3(buffer : SB, offset : u32, value : vec2<i32>)
[[internal(intrinsic_store_vec2_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec2_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_4(buffer : SB, offset : u32, value : vec2<u32>)
[[internal(intrinsic_store_vec2_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec2_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_5(buffer : SB, offset : u32, value : vec2<f32>)
[[internal(intrinsic_store_vec3_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec3_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_6(buffer : SB, offset : u32, value : vec3<i32>)
[[internal(intrinsic_store_vec3_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec3_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_7(buffer : SB, offset : u32, value : vec3<u32>)
[[internal(intrinsic_store_vec3_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec3_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_8(buffer : SB, offset : u32, value : vec3<f32>)
[[internal(intrinsic_store_vec4_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec4_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_9(buffer : SB, offset : u32, value : vec4<i32>)
[[internal(intrinsic_store_vec4_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec4_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_10(buffer : SB, offset : u32, value : vec4<u32>)
[[internal(intrinsic_store_vec4_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec4_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_11(buffer : SB, offset : u32, value : vec4<f32>)
fn tint_symbol_12(buffer : SB, offset : u32, value : mat2x2<f32>) {
@@ -424,12 +618,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, LoadStructure) {
TEST_F(DecomposeMemoryAccessTest, LoadStructure) {
auto* src = R"(
[[block]]
struct SB {
@@ -492,40 +686,40 @@ struct SB {
v : array<vec3<f32>, 2>;
};
[[internal(intrinsic_load_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32) -> i32
[[internal(intrinsic_load_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_1(buffer : SB, offset : u32) -> u32
[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_2(buffer : SB, offset : u32) -> f32
[[internal(intrinsic_load_vec2_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec2_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_3(buffer : SB, offset : u32) -> vec2<i32>
[[internal(intrinsic_load_vec2_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec2_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_4(buffer : SB, offset : u32) -> vec2<u32>
[[internal(intrinsic_load_vec2_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec2_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_5(buffer : SB, offset : u32) -> vec2<f32>
[[internal(intrinsic_load_vec3_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec3_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_6(buffer : SB, offset : u32) -> vec3<i32>
[[internal(intrinsic_load_vec3_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec3_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_7(buffer : SB, offset : u32) -> vec3<u32>
[[internal(intrinsic_load_vec3_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec3_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_8(buffer : SB, offset : u32) -> vec3<f32>
[[internal(intrinsic_load_vec4_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec4_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_9(buffer : SB, offset : u32) -> vec4<i32>
[[internal(intrinsic_load_vec4_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec4_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_10(buffer : SB, offset : u32) -> vec4<u32>
[[internal(intrinsic_load_vec4_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_vec4_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_11(buffer : SB, offset : u32) -> vec4<f32>
fn tint_symbol_12(buffer : SB, offset : u32) -> mat2x2<f32> {
@@ -580,12 +774,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, StoreStructure) {
TEST_F(DecomposeMemoryAccessTest, StoreStructure) {
auto* src = R"(
[[block]]
struct SB {
@@ -648,40 +842,40 @@ struct SB {
v : array<vec3<f32>, 2>;
};
[[internal(intrinsic_store_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32, value : i32)
[[internal(intrinsic_store_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_1(buffer : SB, offset : u32, value : u32)
[[internal(intrinsic_store_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_2(buffer : SB, offset : u32, value : f32)
[[internal(intrinsic_store_vec2_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec2_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_3(buffer : SB, offset : u32, value : vec2<i32>)
[[internal(intrinsic_store_vec2_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec2_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_4(buffer : SB, offset : u32, value : vec2<u32>)
[[internal(intrinsic_store_vec2_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec2_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_5(buffer : SB, offset : u32, value : vec2<f32>)
[[internal(intrinsic_store_vec3_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec3_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_6(buffer : SB, offset : u32, value : vec3<i32>)
[[internal(intrinsic_store_vec3_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec3_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_7(buffer : SB, offset : u32, value : vec3<u32>)
[[internal(intrinsic_store_vec3_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec3_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_8(buffer : SB, offset : u32, value : vec3<f32>)
[[internal(intrinsic_store_vec4_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec4_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_9(buffer : SB, offset : u32, value : vec4<i32>)
[[internal(intrinsic_store_vec4_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec4_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_10(buffer : SB, offset : u32, value : vec4<u32>)
[[internal(intrinsic_store_vec4_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_store_storage_vec4_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_11(buffer : SB, offset : u32, value : vec4<f32>)
fn tint_symbol_12(buffer : SB, offset : u32, value : mat2x2<f32>) {
@@ -776,12 +970,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, ComplexStaticAccessChain) {
TEST_F(DecomposeMemoryAccessTest, ComplexStaticAccessChain) {
auto* src = R"(
struct S1 {
a : i32;
@@ -837,7 +1031,7 @@ struct SB {
b : [[stride(256)]] array<S2>;
};
[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32) -> f32
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
@@ -848,12 +1042,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, ComplexDynamicAccessChain) {
TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChain) {
auto* src = R"(
struct S1 {
a : i32;
@@ -905,7 +1099,7 @@ struct SB {
b : [[stride(256)]] array<S2>;
};
[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32) -> f32
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
@@ -919,12 +1113,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, ComplexDynamicAccessChainWithAliases) {
TEST_F(DecomposeMemoryAccessTest, ComplexDynamicAccessChainWithAliases) {
auto* src = R"(
struct S1 {
a : i32;
@@ -992,7 +1186,7 @@ struct SB {
b : A2_Array;
};
[[internal(intrinsic_load_f32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32) -> f32
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
@@ -1006,12 +1200,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, StorageBufferAtomics) {
TEST_F(DecomposeMemoryAccessTest, StorageBufferAtomics) {
auto* src = R"(
[[block]]
struct SB {
@@ -1056,64 +1250,64 @@ struct SB {
b : atomic<u32>;
};
[[internal(intrinsic_atomic_store_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_store_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol(buffer : SB, offset : u32, param_1 : i32)
[[internal(intrinsic_atomic_load_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_load_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_1(buffer : SB, offset : u32) -> i32
[[internal(intrinsic_atomic_add_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_add_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_2(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_max_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_max_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_3(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_min_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_min_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_4(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_and_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_and_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_5(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_or_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_or_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_6(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_xor_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_xor_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_7(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_exchange_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_exchange_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_8(buffer : SB, offset : u32, param_1 : i32) -> i32
[[internal(intrinsic_atomic_compare_exchange_weak_i32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_compare_exchange_weak_storage_i32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_9(buffer : SB, offset : u32, param_1 : i32, param_2 : i32) -> vec2<i32>
[[internal(intrinsic_atomic_store_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_store_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_10(buffer : SB, offset : u32, param_1 : u32)
[[internal(intrinsic_atomic_load_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_load_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_11(buffer : SB, offset : u32) -> u32
[[internal(intrinsic_atomic_add_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_add_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_12(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_max_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_max_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_13(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_min_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_min_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_14(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_and_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_and_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_15(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_or_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_or_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_16(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_xor_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_xor_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_17(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_exchange_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_exchange_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_18(buffer : SB, offset : u32, param_1 : u32) -> u32
[[internal(intrinsic_atomic_compare_exchange_weak_u32), internal(disable_validation__function_has_no_body)]]
[[internal(intrinsic_atomic_compare_exchange_weak_storage_u32), internal(disable_validation__function_has_no_body)]]
fn tint_symbol_19(buffer : SB, offset : u32, param_1 : u32, param_2 : u32) -> vec2<u32>
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
@@ -1143,12 +1337,12 @@ fn main() {
}
)";
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(DecomposeStorageAccessTest, WorkgroupBufferAtomics) {
TEST_F(DecomposeMemoryAccessTest, WorkgroupBufferAtomics) {
auto* src = R"(
struct S {
padding : vec4<f32>;
@@ -1185,7 +1379,7 @@ fn main() {
auto* expect = src;
auto got = Run<DecomposeStorageAccess>(src);
auto got = Run<DecomposeMemoryAccess>(src);
EXPECT_EQ(expect, str(got));
}

View File

@@ -19,7 +19,7 @@
#include "src/program_builder.h"
#include "src/transform/calculate_array_length.h"
#include "src/transform/canonicalize_entry_point_io.h"
#include "src/transform/decompose_storage_access.h"
#include "src/transform/decompose_memory_access.h"
#include "src/transform/external_texture_transform.h"
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/manager.h"
@@ -41,13 +41,13 @@ Output Hlsl::Run(const Program* in, const DataMap&) {
manager.Add<InlinePointerLets>();
// Simplify cleans up messy `*(&(expr))` expressions from InlinePointerLets.
manager.Add<Simplify>();
// DecomposeStorageAccess must come after InlinePointerLets as we cannot take
// the address of calls to DecomposeStorageAccess::Intrinsic. Must also come
// DecomposeMemoryAccess must come after InlinePointerLets as we cannot take
// the address of calls to DecomposeMemoryAccess::Intrinsic. Must also come
// after Simplify, as we need to fold away the address-of and defers of
// `*(&(intrinsic_load()))` expressions.
manager.Add<DecomposeStorageAccess>();
// CalculateArrayLength must come after DecomposeStorageAccess, as
// DecomposeStorageAccess special-cases the arrayLength() intrinsic, which
manager.Add<DecomposeMemoryAccess>();
// CalculateArrayLength must come after DecomposeMemoryAccess, as
// DecomposeMemoryAccess special-cases the arrayLength() intrinsic, which
// will be transformed by CalculateArrayLength
manager.Add<CalculateArrayLength>();
manager.Add<ExternalTextureTransform>();