mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-16 00:17:03 +00:00
transform: Fixes for DecomposeMemoryAccess
CloneContext::Replace(T* what, T* with) is bug-prone, as complex transforms may want to clone `what` multiple times, or not at all. In both cases, this will likely result in an internal compiler error (ICE), as the replacement will either be reachable multiple times, or not at all.

The CTS test webgpu:shader,execution,robust_access:linear_memory:storageClass="storage";storageMode="read_write";access="read";atomic=true;baseType="i32" was triggering this brokenness with DecomposeMemoryAccess's use of CloneContext::Replace(T*, T*). Switch the usage of CloneContext::Replace(T*, T*) to the new function form.

As std::function requires the callable it stores to be copyable, it cannot hold a lambda that captures a std::unique_ptr. This prevented the Replace() lambdas from capturing the necessary `BufferAccess` data, as this held a `std::unique_ptr<Offset>`. To fix this, use a `BlockAllocator` for Offsets, and use raw pointers instead.

Because the function passed to Replace() is called just before the node is cloned, insertion of new functions will occur just before the currently evaluated module-scope entity. This allows us to remove the "insert_after" arguments to LoadFunc(), StoreFunc(), and AtomicFunc(). We can also kill the icky InsertGlobal() and TypeDeclOf() helpers.

Bug: tint:993
Change-Id: I60972bc13a2fa819a163ee2671f61e82d0e68d2a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58222
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
@@ -26,6 +26,7 @@
|
||||
#include "src/ast/scalar_constructor_expression.h"
|
||||
#include "src/ast/type_name.h"
|
||||
#include "src/ast/unary_op.h"
|
||||
#include "src/block_allocator.h"
|
||||
#include "src/program_builder.h"
|
||||
#include "src/sem/array.h"
|
||||
#include "src/sem/atomic_type.h"
|
||||
@@ -50,7 +51,7 @@ namespace {
|
||||
/// offsets for storage and uniform buffer accesses.
|
||||
struct Offset : Castable<Offset> {
|
||||
/// @returns builds and returns the ast::Expression in `ctx.dst`
|
||||
virtual ast::Expression* Build(CloneContext& ctx) = 0;
|
||||
virtual ast::Expression* Build(CloneContext& ctx) const = 0;
|
||||
};
|
||||
|
||||
/// OffsetExpr is an implementation of Offset that clones and casts the given
|
||||
@@ -60,7 +61,7 @@ struct OffsetExpr : Offset {
|
||||
|
||||
explicit OffsetExpr(ast::Expression* e) : expr(e) {}
|
||||
|
||||
ast::Expression* Build(CloneContext& ctx) override {
|
||||
ast::Expression* Build(CloneContext& ctx) const override {
|
||||
auto* type = ctx.src->Sem().Get(expr)->Type()->UnwrapRef();
|
||||
auto* res = ctx.Clone(expr);
|
||||
if (!type->Is<sem::U32>()) {
|
||||
@@ -77,7 +78,7 @@ struct OffsetLiteral : Castable<OffsetLiteral, Offset> {
|
||||
|
||||
explicit OffsetLiteral(uint32_t lit) : literal(lit) {}
|
||||
|
||||
ast::Expression* Build(CloneContext& ctx) override {
|
||||
ast::Expression* Build(CloneContext& ctx) const override {
|
||||
return ctx.dst->Expr(literal);
|
||||
}
|
||||
};
|
||||
@@ -86,103 +87,20 @@ struct OffsetLiteral : Castable<OffsetLiteral, Offset> {
|
||||
/// two Offsets.
|
||||
struct OffsetBinOp : Offset {
|
||||
ast::BinaryOp op;
|
||||
std::unique_ptr<Offset> lhs;
|
||||
std::unique_ptr<Offset> rhs;
|
||||
Offset const* lhs = nullptr;
|
||||
Offset const* rhs = nullptr;
|
||||
|
||||
ast::Expression* Build(CloneContext& ctx) override {
|
||||
ast::Expression* Build(CloneContext& ctx) const override {
|
||||
return ctx.dst->create<ast::BinaryExpression>(op, lhs->Build(ctx),
|
||||
rhs->Build(ctx));
|
||||
}
|
||||
};
|
||||
|
||||
/// @returns an Offset for the given literal value
|
||||
std::unique_ptr<Offset> ToOffset(uint32_t offset) {
|
||||
return std::make_unique<OffsetLiteral>(offset);
|
||||
}
|
||||
|
||||
/// @returns an Offset for the given ast::Expression
|
||||
std::unique_ptr<Offset> ToOffset(ast::Expression* expr) {
|
||||
if (auto* scalar = expr->As<ast::ScalarConstructorExpression>()) {
|
||||
if (auto* u32 = scalar->literal()->As<ast::UintLiteral>()) {
|
||||
return std::make_unique<OffsetLiteral>(u32->value());
|
||||
} else if (auto* i32 = scalar->literal()->As<ast::SintLiteral>()) {
|
||||
if (i32->value() > 0) {
|
||||
return std::make_unique<OffsetLiteral>(i32->value());
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_unique<OffsetExpr>(expr);
|
||||
}
|
||||
|
||||
/// @returns the given offset (pass-through)
|
||||
std::unique_ptr<Offset> ToOffset(std::unique_ptr<Offset> offset) {
|
||||
return offset;
|
||||
}
|
||||
|
||||
/// @return an Offset that is a sum of lhs and rhs, performing basic constant
|
||||
/// folding if possible
|
||||
template <typename LHS, typename RHS>
|
||||
std::unique_ptr<Offset> Add(LHS&& lhs_, RHS&& rhs_) {
|
||||
std::unique_ptr<Offset> lhs = ToOffset(std::forward<LHS>(lhs_));
|
||||
std::unique_ptr<Offset> rhs = ToOffset(std::forward<RHS>(rhs_));
|
||||
auto* lhs_lit = lhs->As<OffsetLiteral>();
|
||||
auto* rhs_lit = rhs->As<OffsetLiteral>();
|
||||
if (lhs_lit && lhs_lit->literal == 0) {
|
||||
return rhs;
|
||||
}
|
||||
if (rhs_lit && rhs_lit->literal == 0) {
|
||||
return lhs;
|
||||
}
|
||||
if (lhs_lit && rhs_lit) {
|
||||
if (static_cast<uint64_t>(lhs_lit->literal) +
|
||||
static_cast<uint64_t>(rhs_lit->literal) <=
|
||||
0xffffffff) {
|
||||
return std::make_unique<OffsetLiteral>(lhs_lit->literal +
|
||||
rhs_lit->literal);
|
||||
}
|
||||
}
|
||||
auto out = std::make_unique<OffsetBinOp>();
|
||||
out->op = ast::BinaryOp::kAdd;
|
||||
out->lhs = std::move(lhs);
|
||||
out->rhs = std::move(rhs);
|
||||
return out;
|
||||
}
|
||||
|
||||
/// @return an Offset that is the multiplication of lhs and rhs, performing
|
||||
/// basic constant folding if possible
|
||||
template <typename LHS, typename RHS>
|
||||
std::unique_ptr<Offset> Mul(LHS&& lhs_, RHS&& rhs_) {
|
||||
std::unique_ptr<Offset> lhs = ToOffset(std::forward<LHS>(lhs_));
|
||||
std::unique_ptr<Offset> rhs = ToOffset(std::forward<RHS>(rhs_));
|
||||
auto* lhs_lit = lhs->As<OffsetLiteral>();
|
||||
auto* rhs_lit = rhs->As<OffsetLiteral>();
|
||||
if (lhs_lit && lhs_lit->literal == 0) {
|
||||
return std::make_unique<OffsetLiteral>(0);
|
||||
}
|
||||
if (rhs_lit && rhs_lit->literal == 0) {
|
||||
return std::make_unique<OffsetLiteral>(0);
|
||||
}
|
||||
if (lhs_lit && lhs_lit->literal == 1) {
|
||||
return rhs;
|
||||
}
|
||||
if (rhs_lit && rhs_lit->literal == 1) {
|
||||
return lhs;
|
||||
}
|
||||
if (lhs_lit && rhs_lit) {
|
||||
return std::make_unique<OffsetLiteral>(lhs_lit->literal * rhs_lit->literal);
|
||||
}
|
||||
auto out = std::make_unique<OffsetBinOp>();
|
||||
out->op = ast::BinaryOp::kMultiply;
|
||||
out->lhs = std::move(lhs);
|
||||
out->rhs = std::move(rhs);
|
||||
return out;
|
||||
}
|
||||
|
||||
/// LoadStoreKey is the unordered map key to a load or store intrinsic.
|
||||
struct LoadStoreKey {
|
||||
ast::StorageClass const storage_class; // buffer storage class
|
||||
sem::Type const* buf_ty; // buffer type
|
||||
sem::Type const* el_ty; // element type
|
||||
sem::Type const* buf_ty = nullptr; // buffer type
|
||||
sem::Type const* el_ty = nullptr; // element type
|
||||
bool operator==(const LoadStoreKey& rhs) const {
|
||||
return storage_class == rhs.storage_class && buf_ty == rhs.buf_ty &&
|
||||
el_ty == rhs.el_ty;
|
||||
@@ -196,9 +114,9 @@ struct LoadStoreKey {
|
||||
|
||||
/// AtomicKey is the unordered map key to an atomic intrinsic.
|
||||
struct AtomicKey {
|
||||
sem::Type const* buf_ty; // buffer type
|
||||
sem::Type const* el_ty; // element type
|
||||
sem::IntrinsicType const op; // atomic op
|
||||
sem::Type const* buf_ty = nullptr; // buffer type
|
||||
sem::Type const* el_ty = nullptr; // element type
|
||||
sem::IntrinsicType const op; // atomic op
|
||||
bool operator==(const AtomicKey& rhs) const {
|
||||
return buf_ty == rhs.buf_ty && el_ty == rhs.el_ty && op == rhs.op;
|
||||
}
|
||||
@@ -367,39 +285,10 @@ DecomposeMemoryAccess::Intrinsic* IntrinsicAtomicFor(ProgramBuilder* builder,
|
||||
builder->ID(), op, ast::StorageClass::kStorage, type);
|
||||
}
|
||||
|
||||
/// Inserts `node` before `insert_after` in the global declarations of
|
||||
/// `ctx.dst`. If `insert_after` is nullptr, then `node` is inserted at the top
|
||||
/// of the module.
|
||||
void InsertGlobal(CloneContext& ctx,
|
||||
const Cloneable* insert_after,
|
||||
Cloneable* node) {
|
||||
auto& globals = ctx.src->AST().GlobalDeclarations();
|
||||
if (insert_after) {
|
||||
ctx.InsertAfter(globals, insert_after, node);
|
||||
} else {
|
||||
ctx.InsertBefore(globals, *globals.begin(), node);
|
||||
}
|
||||
}
|
||||
|
||||
/// @returns the unwrapped, user-declared type of ty.
|
||||
const ast::TypeDecl* TypeDeclOf(const sem::Type* ty) {
|
||||
while (true) {
|
||||
if (auto* ref = ty->As<sem::Reference>()) {
|
||||
ty = ref->StoreType();
|
||||
continue;
|
||||
}
|
||||
if (auto* str = ty->As<sem::Struct>()) {
|
||||
return str->Declaration();
|
||||
}
|
||||
// Not a declared type
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
/// BufferAccess describes a single storage or uniform buffer access
|
||||
struct BufferAccess {
|
||||
sem::Expression const* var = nullptr; // Storage buffer variable
|
||||
std::unique_ptr<Offset> offset; // The byte offset on var
|
||||
Offset const* offset = nullptr; // The byte offset on var
|
||||
sem::Type const* type = nullptr; // The type of the access
|
||||
operator bool() const { return var; } // Returns true if valid
|
||||
};
|
||||
@@ -430,14 +319,105 @@ struct DecomposeMemoryAccess::State {
|
||||
std::unordered_map<AtomicKey, Symbol, AtomicKey::Hasher> atomic_funcs;
|
||||
/// List of storage or uniform buffer writes
|
||||
std::vector<Store> stores;
|
||||
/// Allocations for offsets
|
||||
BlockAllocator<Offset> offsets_;
|
||||
|
||||
/// @param offset the offset value to wrap in an Offset
|
||||
/// @returns an Offset for the given literal value
|
||||
const Offset* ToOffset(uint32_t offset) {
|
||||
return offsets_.Create<OffsetLiteral>(offset);
|
||||
}
|
||||
|
||||
/// @param expr the expression to convert to an Offset
|
||||
/// @returns an Offset for the given ast::Expression
|
||||
const Offset* ToOffset(ast::Expression* expr) {
|
||||
if (auto* scalar = expr->As<ast::ScalarConstructorExpression>()) {
|
||||
if (auto* u32 = scalar->literal()->As<ast::UintLiteral>()) {
|
||||
return offsets_.Create<OffsetLiteral>(u32->value());
|
||||
} else if (auto* i32 = scalar->literal()->As<ast::SintLiteral>()) {
|
||||
if (i32->value() > 0) {
|
||||
return offsets_.Create<OffsetLiteral>(i32->value());
|
||||
}
|
||||
}
|
||||
}
|
||||
return offsets_.Create<OffsetExpr>(expr);
|
||||
}
|
||||
|
||||
/// @param offset the Offset that is returned
|
||||
/// @returns the given offset (pass-through)
|
||||
const Offset* ToOffset(const Offset* offset) { return offset; }
|
||||
|
||||
/// @param lhs_ the left-hand side of the add expression
|
||||
/// @param rhs_ the right-hand side of the add expression
|
||||
/// @return an Offset that is a sum of lhs and rhs, performing basic constant
|
||||
/// folding if possible
|
||||
template <typename LHS, typename RHS>
|
||||
const Offset* Add(LHS&& lhs_, RHS&& rhs_) {
|
||||
auto* lhs = ToOffset(std::forward<LHS>(lhs_));
|
||||
auto* rhs = ToOffset(std::forward<RHS>(rhs_));
|
||||
auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
|
||||
auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
|
||||
if (lhs_lit && lhs_lit->literal == 0) {
|
||||
return rhs;
|
||||
}
|
||||
if (rhs_lit && rhs_lit->literal == 0) {
|
||||
return lhs;
|
||||
}
|
||||
if (lhs_lit && rhs_lit) {
|
||||
if (static_cast<uint64_t>(lhs_lit->literal) +
|
||||
static_cast<uint64_t>(rhs_lit->literal) <=
|
||||
0xffffffff) {
|
||||
return offsets_.Create<OffsetLiteral>(lhs_lit->literal +
|
||||
rhs_lit->literal);
|
||||
}
|
||||
}
|
||||
auto* out = offsets_.Create<OffsetBinOp>();
|
||||
out->op = ast::BinaryOp::kAdd;
|
||||
out->lhs = lhs;
|
||||
out->rhs = rhs;
|
||||
return out;
|
||||
}
|
||||
|
||||
/// @param lhs_ the left-hand side of the multiply expression
|
||||
/// @param rhs_ the right-hand side of the multiply expression
|
||||
/// @return an Offset that is the multiplication of lhs and rhs, performing
|
||||
/// basic constant folding if possible
|
||||
template <typename LHS, typename RHS>
|
||||
const Offset* Mul(LHS&& lhs_, RHS&& rhs_) {
|
||||
auto* lhs = ToOffset(std::forward<LHS>(lhs_));
|
||||
auto* rhs = ToOffset(std::forward<RHS>(rhs_));
|
||||
auto* lhs_lit = tint::As<OffsetLiteral>(lhs);
|
||||
auto* rhs_lit = tint::As<OffsetLiteral>(rhs);
|
||||
if (lhs_lit && lhs_lit->literal == 0) {
|
||||
return offsets_.Create<OffsetLiteral>(0);
|
||||
}
|
||||
if (rhs_lit && rhs_lit->literal == 0) {
|
||||
return offsets_.Create<OffsetLiteral>(0);
|
||||
}
|
||||
if (lhs_lit && lhs_lit->literal == 1) {
|
||||
return rhs;
|
||||
}
|
||||
if (rhs_lit && rhs_lit->literal == 1) {
|
||||
return lhs;
|
||||
}
|
||||
if (lhs_lit && rhs_lit) {
|
||||
return offsets_.Create<OffsetLiteral>(lhs_lit->literal *
|
||||
rhs_lit->literal);
|
||||
}
|
||||
auto* out = offsets_.Create<OffsetBinOp>();
|
||||
out->op = ast::BinaryOp::kMultiply;
|
||||
out->lhs = lhs;
|
||||
out->rhs = rhs;
|
||||
return out;
|
||||
}
|
||||
|
||||
/// AddAccess() adds the `expr -> access` map item to #accesses, and `expr`
|
||||
/// to #expression_order.
|
||||
/// @param expr the expression that performs the access
|
||||
/// @param access the access
|
||||
void AddAccess(ast::Expression* expr, BufferAccess&& access) {
|
||||
void AddAccess(ast::Expression* expr, const BufferAccess& access) {
|
||||
TINT_ASSERT(Transform, access.type);
|
||||
accesses.emplace(expr, std::move(access));
|
||||
accesses.emplace(expr, access);
|
||||
expression_order.emplace_back(expr);
|
||||
}
|
||||
|
||||
@@ -451,7 +431,7 @@ struct DecomposeMemoryAccess::State {
|
||||
if (lhs_it == accesses.end()) {
|
||||
return {};
|
||||
}
|
||||
auto access = std::move(lhs_it->second);
|
||||
auto access = lhs_it->second;
|
||||
accesses.erase(node);
|
||||
return access;
|
||||
}
|
||||
@@ -461,13 +441,11 @@ struct DecomposeMemoryAccess::State {
|
||||
/// The emitted function has the signature:
|
||||
/// `fn load(buf : buf_ty, offset : u32) -> el_ty`
|
||||
/// @param ctx the CloneContext
|
||||
/// @param insert_after the user-declared type to insert the function after
|
||||
/// @param buf_ty the storage or uniform buffer type
|
||||
/// @param el_ty the storage or uniform buffer element type
|
||||
/// @param var_user the variable user
|
||||
/// @return the name of the function that performs the load
|
||||
Symbol LoadFunc(CloneContext& ctx,
|
||||
const ast::TypeDecl* insert_after,
|
||||
const sem::Type* buf_ty,
|
||||
const sem::Type* el_ty,
|
||||
const sem::VariableUser* var_user) {
|
||||
@@ -509,8 +487,7 @@ struct DecomposeMemoryAccess::State {
|
||||
ast::ExpressionList values;
|
||||
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
|
||||
auto* vec_ty = mat_ty->ColumnType();
|
||||
Symbol load =
|
||||
LoadFunc(ctx, insert_after, buf_ty, vec_ty, var_user);
|
||||
Symbol load = LoadFunc(ctx, buf_ty, vec_ty, var_user);
|
||||
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
|
||||
auto* offset =
|
||||
ctx.dst->Add("offset", i * MatrixColumnStride(mat_ty));
|
||||
@@ -519,14 +496,14 @@ struct DecomposeMemoryAccess::State {
|
||||
} else if (auto* str = el_ty->As<sem::Struct>()) {
|
||||
for (auto* member : str->Members()) {
|
||||
auto* offset = ctx.dst->Add("offset", member->Offset());
|
||||
Symbol load = LoadFunc(ctx, insert_after, buf_ty,
|
||||
member->Type()->UnwrapRef(), var_user);
|
||||
Symbol load = LoadFunc(ctx, buf_ty, member->Type()->UnwrapRef(),
|
||||
var_user);
|
||||
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
|
||||
}
|
||||
} else if (auto* arr = el_ty->As<sem::Array>()) {
|
||||
for (uint32_t i = 0; i < arr->Count(); i++) {
|
||||
auto* offset = ctx.dst->Add("offset", arr->Stride() * i);
|
||||
Symbol load = LoadFunc(ctx, insert_after, buf_ty,
|
||||
Symbol load = LoadFunc(ctx, buf_ty,
|
||||
arr->ElemType()->UnwrapRef(), var_user);
|
||||
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
|
||||
}
|
||||
@@ -539,7 +516,7 @@ struct DecomposeMemoryAccess::State {
|
||||
CreateASTTypeFor(&ctx, el_ty), values))),
|
||||
ast::DecorationList{}, ast::DecorationList{});
|
||||
}
|
||||
InsertGlobal(ctx, insert_after, func);
|
||||
ctx.dst->AST().AddFunction(func);
|
||||
return func->symbol();
|
||||
});
|
||||
}
|
||||
@@ -549,13 +526,11 @@ struct DecomposeMemoryAccess::State {
|
||||
/// The function has the signature:
|
||||
/// `fn store(buf : buf_ty, offset : u32, value : el_ty)`
|
||||
/// @param ctx the CloneContext
|
||||
/// @param insert_after the user-declared type to insert the function after
|
||||
/// @param buf_ty the storage buffer type
|
||||
/// @param el_ty the storage buffer element type
|
||||
/// @param var_user the variable user
|
||||
/// @return the name of the function that performs the store
|
||||
Symbol StoreFunc(CloneContext& ctx,
|
||||
const ast::TypeDecl* insert_after,
|
||||
const sem::Type* buf_ty,
|
||||
const sem::Type* el_ty,
|
||||
const sem::VariableUser* var_user) {
|
||||
@@ -597,8 +572,7 @@ struct DecomposeMemoryAccess::State {
|
||||
ast::StatementList body;
|
||||
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
|
||||
auto* vec_ty = mat_ty->ColumnType();
|
||||
Symbol store =
|
||||
StoreFunc(ctx, insert_after, buf_ty, vec_ty, var_user);
|
||||
Symbol store = StoreFunc(ctx, buf_ty, vec_ty, var_user);
|
||||
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
|
||||
auto* offset =
|
||||
ctx.dst->Add("offset", i * MatrixColumnStride(mat_ty));
|
||||
@@ -611,7 +585,7 @@ struct DecomposeMemoryAccess::State {
|
||||
auto* offset = ctx.dst->Add("offset", member->Offset());
|
||||
auto* access = ctx.dst->MemberAccessor(
|
||||
"value", ctx.Clone(member->Declaration()->symbol()));
|
||||
Symbol store = StoreFunc(ctx, insert_after, buf_ty,
|
||||
Symbol store = StoreFunc(ctx, buf_ty,
|
||||
member->Type()->UnwrapRef(), var_user);
|
||||
auto* call = ctx.dst->Call(store, "buffer", offset, access);
|
||||
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
|
||||
@@ -621,9 +595,8 @@ struct DecomposeMemoryAccess::State {
|
||||
auto* offset = ctx.dst->Add("offset", arr->Stride() * i);
|
||||
auto* access =
|
||||
ctx.dst->IndexAccessor("value", ctx.dst->Expr(i));
|
||||
Symbol store =
|
||||
StoreFunc(ctx, insert_after, buf_ty,
|
||||
arr->ElemType()->UnwrapRef(), var_user);
|
||||
Symbol store = StoreFunc(
|
||||
ctx, buf_ty, arr->ElemType()->UnwrapRef(), var_user);
|
||||
auto* call = ctx.dst->Call(store, "buffer", offset, access);
|
||||
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
|
||||
}
|
||||
@@ -634,7 +607,7 @@ struct DecomposeMemoryAccess::State {
|
||||
ast::DecorationList{});
|
||||
}
|
||||
|
||||
InsertGlobal(ctx, insert_after, func);
|
||||
ctx.dst->AST().AddFunction(func);
|
||||
return func->symbol();
|
||||
});
|
||||
}
|
||||
@@ -644,14 +617,12 @@ struct DecomposeMemoryAccess::State {
|
||||
/// the signature:
|
||||
// `fn atomic_op(buf : buf_ty, offset : u32, ...) -> T`
|
||||
/// @param ctx the CloneContext
|
||||
/// @param insert_after the user-declared type to insert the function after
|
||||
/// @param buf_ty the storage buffer type
|
||||
/// @param el_ty the storage buffer element type
|
||||
/// @param intrinsic the atomic intrinsic
|
||||
/// @param var_user the variable user
|
||||
/// @return the name of the function that performs the load
|
||||
Symbol AtomicFunc(CloneContext& ctx,
|
||||
const ast::TypeDecl* insert_after,
|
||||
const sem::Type* buf_ty,
|
||||
const sem::Type* el_ty,
|
||||
const sem::Intrinsic* intrinsic,
|
||||
@@ -700,7 +671,7 @@ struct DecomposeMemoryAccess::State {
|
||||
},
|
||||
ast::DecorationList{});
|
||||
|
||||
InsertGlobal(ctx, insert_after, func);
|
||||
ctx.dst->AST().AddFunction(func);
|
||||
return func->symbol();
|
||||
});
|
||||
}
|
||||
@@ -825,7 +796,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
// Variable to a storage or uniform buffer
|
||||
state.AddAccess(ident, {
|
||||
var,
|
||||
ToOffset(0u),
|
||||
state.ToOffset(0u),
|
||||
var->Type()->UnwrapRef(),
|
||||
});
|
||||
}
|
||||
@@ -840,14 +811,13 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
if (swizzle->Indices().size() == 1) {
|
||||
if (auto access = state.TakeAccess(accessor->structure())) {
|
||||
auto* vec_ty = access.type->As<sem::Vector>();
|
||||
auto offset =
|
||||
Mul(ScalarSize(vec_ty->type()), swizzle->Indices()[0]);
|
||||
state.AddAccess(
|
||||
accessor, {
|
||||
access.var,
|
||||
Add(std::move(access.offset), std::move(offset)),
|
||||
vec_ty->type()->UnwrapRef(),
|
||||
});
|
||||
auto* offset =
|
||||
state.Mul(ScalarSize(vec_ty->type()), swizzle->Indices()[0]);
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
state.Add(access.offset, offset),
|
||||
vec_ty->type()->UnwrapRef(),
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -855,12 +825,11 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
auto* str_ty = access.type->As<sem::Struct>();
|
||||
auto* member = str_ty->FindMember(accessor->member()->symbol());
|
||||
auto offset = member->Offset();
|
||||
state.AddAccess(accessor,
|
||||
{
|
||||
access.var,
|
||||
Add(std::move(access.offset), std::move(offset)),
|
||||
member->Type()->UnwrapRef(),
|
||||
});
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
state.Add(access.offset, offset),
|
||||
member->Type()->UnwrapRef(),
|
||||
});
|
||||
}
|
||||
}
|
||||
continue;
|
||||
@@ -870,33 +839,32 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
if (auto access = state.TakeAccess(accessor->array())) {
|
||||
// X[Y]
|
||||
if (auto* arr = access.type->As<sem::Array>()) {
|
||||
auto offset = Mul(arr->Stride(), accessor->idx_expr());
|
||||
state.AddAccess(accessor,
|
||||
{
|
||||
access.var,
|
||||
Add(std::move(access.offset), std::move(offset)),
|
||||
arr->ElemType()->UnwrapRef(),
|
||||
});
|
||||
auto* offset = state.Mul(arr->Stride(), accessor->idx_expr());
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
state.Add(access.offset, offset),
|
||||
arr->ElemType()->UnwrapRef(),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (auto* vec_ty = access.type->As<sem::Vector>()) {
|
||||
auto offset = Mul(ScalarSize(vec_ty->type()), accessor->idx_expr());
|
||||
state.AddAccess(accessor,
|
||||
{
|
||||
access.var,
|
||||
Add(std::move(access.offset), std::move(offset)),
|
||||
vec_ty->type()->UnwrapRef(),
|
||||
});
|
||||
auto* offset =
|
||||
state.Mul(ScalarSize(vec_ty->type()), accessor->idx_expr());
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
state.Add(access.offset, offset),
|
||||
vec_ty->type()->UnwrapRef(),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (auto* mat_ty = access.type->As<sem::Matrix>()) {
|
||||
auto offset = Mul(MatrixColumnStride(mat_ty), accessor->idx_expr());
|
||||
state.AddAccess(accessor,
|
||||
{
|
||||
access.var,
|
||||
Add(std::move(access.offset), std::move(offset)),
|
||||
mat_ty->ColumnType(),
|
||||
});
|
||||
auto* offset =
|
||||
state.Mul(MatrixColumnStride(mat_ty), accessor->idx_expr());
|
||||
state.AddAccess(accessor, {
|
||||
access.var,
|
||||
state.Add(access.offset, offset),
|
||||
mat_ty->ColumnType(),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -908,7 +876,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
if (auto access = state.TakeAccess(op->expr())) {
|
||||
// HLSL does not support pointers, so just take the access from the
|
||||
// reference and place it on the pointer.
|
||||
state.AddAccess(op, std::move(access));
|
||||
state.AddAccess(op, access);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -918,7 +886,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
// X = Y
|
||||
// Move the LHS access to a store.
|
||||
if (auto lhs = state.TakeAccess(assign->lhs())) {
|
||||
state.stores.emplace_back(Store{assign, std::move(lhs)});
|
||||
state.stores.emplace_back(Store{assign, lhs});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -934,23 +902,22 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
if (intrinsic->IsAtomic()) {
|
||||
if (auto access = state.TakeAccess(call_expr->params()[0])) {
|
||||
// atomic___(X)
|
||||
ctx.Replace(call_expr, [=, &ctx, &state] {
|
||||
auto* buf = access.var->Declaration();
|
||||
auto* offset = access.offset->Build(ctx);
|
||||
auto* buf_ty = access.var->Type()->UnwrapRef();
|
||||
auto* el_ty = access.type->UnwrapRef()->As<sem::Atomic>()->Type();
|
||||
Symbol func =
|
||||
state.AtomicFunc(ctx, buf_ty, el_ty, intrinsic,
|
||||
access.var->As<sem::VariableUser>());
|
||||
|
||||
auto* buf = access.var->Declaration();
|
||||
auto* offset = access.offset->Build(ctx);
|
||||
auto* buf_ty = access.var->Type()->UnwrapRef();
|
||||
auto* el_ty = access.type->UnwrapRef()->As<sem::Atomic>()->Type();
|
||||
auto* insert_after = TypeDeclOf(access.var->Type());
|
||||
Symbol func =
|
||||
state.AtomicFunc(ctx, insert_after, buf_ty, el_ty, intrinsic,
|
||||
access.var->As<sem::VariableUser>());
|
||||
|
||||
ast::ExpressionList args{ctx.Clone(buf), offset};
|
||||
for (size_t i = 1; i < call_expr->params().size(); i++) {
|
||||
auto* arg = call_expr->params()[i];
|
||||
args.emplace_back(ctx.Clone(arg));
|
||||
}
|
||||
|
||||
ctx.Replace(call_expr, ctx.dst->Call(func, args));
|
||||
ast::ExpressionList args{ctx.Clone(buf), offset};
|
||||
for (size_t i = 1; i < call_expr->params().size(); i++) {
|
||||
auto* arg = call_expr->params()[i];
|
||||
args.emplace_back(ctx.Clone(arg));
|
||||
}
|
||||
return ctx.dst->Call(func, args);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -964,36 +931,32 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
if (access_it == state.accesses.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto access = std::move(access_it->second);
|
||||
|
||||
auto* buf = access.var->Declaration();
|
||||
auto* offset = access.offset->Build(ctx);
|
||||
auto* buf_ty = access.var->Type()->UnwrapRef();
|
||||
auto* el_ty = access.type->UnwrapRef();
|
||||
auto* insert_after = TypeDeclOf(access.var->Type());
|
||||
Symbol func = state.LoadFunc(ctx, insert_after, buf_ty, el_ty,
|
||||
access.var->As<sem::VariableUser>());
|
||||
|
||||
auto* load = ctx.dst->Call(func, ctx.Clone(buf), offset);
|
||||
|
||||
ctx.Replace(expr, load);
|
||||
BufferAccess access = access_it->second;
|
||||
ctx.Replace(expr, [=, &ctx, &state] {
|
||||
auto* buf = access.var->Declaration();
|
||||
auto* offset = access.offset->Build(ctx);
|
||||
auto* buf_ty = access.var->Type()->UnwrapRef();
|
||||
auto* el_ty = access.type->UnwrapRef();
|
||||
Symbol func = state.LoadFunc(ctx, buf_ty, el_ty,
|
||||
access.var->As<sem::VariableUser>());
|
||||
return ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset);
|
||||
});
|
||||
}
|
||||
|
||||
// And replace all storage and uniform buffer assignments with stores
|
||||
for (auto& store : state.stores) {
|
||||
auto* buf = store.target.var->Declaration();
|
||||
auto* offset = store.target.offset->Build(ctx);
|
||||
auto* buf_ty = store.target.var->Type()->UnwrapRef();
|
||||
auto* el_ty = store.target.type->UnwrapRef();
|
||||
auto* value = store.assignment->rhs();
|
||||
auto* insert_after = TypeDeclOf(store.target.var->Type());
|
||||
Symbol func = state.StoreFunc(ctx, insert_after, buf_ty, el_ty,
|
||||
store.target.var->As<sem::VariableUser>());
|
||||
|
||||
auto* call = ctx.dst->Call(func, ctx.Clone(buf), offset, ctx.Clone(value));
|
||||
|
||||
ctx.Replace(store.assignment, ctx.dst->create<ast::CallStatement>(call));
|
||||
for (auto store : state.stores) {
|
||||
ctx.Replace(store.assignment, [=, &ctx, &state] {
|
||||
auto* buf = store.target.var->Declaration();
|
||||
auto* offset = store.target.offset->Build(ctx);
|
||||
auto* buf_ty = store.target.var->Type()->UnwrapRef();
|
||||
auto* el_ty = store.target.type->UnwrapRef();
|
||||
auto* value = store.assignment->rhs();
|
||||
Symbol func = state.StoreFunc(ctx, buf_ty, el_ty,
|
||||
store.target.var->As<sem::VariableUser>());
|
||||
auto* call = ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset,
|
||||
ctx.Clone(value));
|
||||
return ctx.dst->create<ast::CallStatement>(call);
|
||||
});
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
|
||||
@@ -106,6 +106,8 @@ struct SB {
|
||||
v : array<vec3<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[internal(intrinsic_load_storage_i32), internal(disable_validation__function_has_no_body)]]
|
||||
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, offset : u32) -> i32
|
||||
|
||||
@@ -182,8 +184,6 @@ fn tint_symbol_21([[internal(disable_validation__ignore_constructible_function_p
|
||||
return array<vec3<f32>, 2>(tint_symbol_8(buffer, (offset + 0u)), tint_symbol_8(buffer, (offset + 16u)));
|
||||
}
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
var a : i32 = tint_symbol(sb, 0u);
|
||||
@@ -300,6 +300,8 @@ struct UB {
|
||||
v : array<vec3<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<uniform> ub : UB;
|
||||
|
||||
[[internal(intrinsic_load_uniform_i32), internal(disable_validation__function_has_no_body)]]
|
||||
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : UB, offset : u32) -> i32
|
||||
|
||||
@@ -376,8 +378,6 @@ fn tint_symbol_21([[internal(disable_validation__ignore_constructible_function_p
|
||||
return array<vec3<f32>, 2>(tint_symbol_8(buffer, (offset + 0u)), tint_symbol_8(buffer, (offset + 16u)));
|
||||
}
|
||||
|
||||
[[group(0), binding(0)]] var<uniform> ub : UB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
var a : i32 = tint_symbol(ub, 0u);
|
||||
@@ -494,6 +494,8 @@ struct SB {
|
||||
v : array<vec3<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[internal(intrinsic_store_storage_i32), internal(disable_validation__function_has_no_body)]]
|
||||
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, offset : u32, value : i32)
|
||||
|
||||
@@ -589,8 +591,6 @@ fn tint_symbol_21([[internal(disable_validation__ignore_constructible_function_p
|
||||
tint_symbol_8(buffer, (offset + 16u), value[1u]);
|
||||
}
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
tint_symbol(sb, 0u, i32());
|
||||
@@ -686,6 +686,8 @@ struct SB {
|
||||
v : array<vec3<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[internal(intrinsic_load_storage_i32), internal(disable_validation__function_has_no_body)]]
|
||||
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, offset : u32) -> i32
|
||||
|
||||
@@ -766,8 +768,6 @@ fn tint_symbol_22([[internal(disable_validation__ignore_constructible_function_p
|
||||
return SB(tint_symbol(buffer, (offset + 0u)), tint_symbol_1(buffer, (offset + 4u)), tint_symbol_2(buffer, (offset + 8u)), tint_symbol_3(buffer, (offset + 16u)), tint_symbol_4(buffer, (offset + 24u)), tint_symbol_5(buffer, (offset + 32u)), tint_symbol_6(buffer, (offset + 48u)), tint_symbol_7(buffer, (offset + 64u)), tint_symbol_8(buffer, (offset + 80u)), tint_symbol_9(buffer, (offset + 96u)), tint_symbol_10(buffer, (offset + 112u)), tint_symbol_11(buffer, (offset + 128u)), tint_symbol_12(buffer, (offset + 144u)), tint_symbol_13(buffer, (offset + 160u)), tint_symbol_14(buffer, (offset + 192u)), tint_symbol_15(buffer, (offset + 224u)), tint_symbol_16(buffer, (offset + 256u)), tint_symbol_17(buffer, (offset + 304u)), tint_symbol_18(buffer, (offset + 352u)), tint_symbol_19(buffer, (offset + 384u)), tint_symbol_20(buffer, (offset + 448u)), tint_symbol_21(buffer, (offset + 512u)));
|
||||
}
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
var x : SB = tint_symbol_22(sb, 0u);
|
||||
@@ -842,6 +842,8 @@ struct SB {
|
||||
v : array<vec3<f32>, 2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[internal(intrinsic_store_storage_i32), internal(disable_validation__function_has_no_body)]]
|
||||
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, offset : u32, value : i32)
|
||||
|
||||
@@ -962,8 +964,6 @@ fn tint_symbol_22([[internal(disable_validation__ignore_constructible_function_p
|
||||
tint_symbol_21(buffer, (offset + 512u), value.v);
|
||||
}
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
tint_symbol_22(sb, 0u, SB());
|
||||
@@ -1031,11 +1031,11 @@ struct SB {
|
||||
b : [[stride(256)]] array<S2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
|
||||
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, offset : u32) -> f32
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
var x : f32 = tint_symbol(sb, 1224u);
|
||||
@@ -1099,11 +1099,11 @@ struct SB {
|
||||
b : [[stride(256)]] array<S2>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
|
||||
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, offset : u32) -> f32
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
var i : i32 = 4;
|
||||
@@ -1186,11 +1186,11 @@ struct SB {
|
||||
b : A2_Array;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[internal(intrinsic_load_storage_f32), internal(disable_validation__function_has_no_body)]]
|
||||
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, offset : u32) -> f32
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
var i : i32 = 4;
|
||||
@@ -1250,6 +1250,8 @@ struct SB {
|
||||
b : atomic<u32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[internal(intrinsic_atomic_store_storage_i32), internal(disable_validation__function_has_no_body)]]
|
||||
fn tint_symbol([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, offset : u32, param_1 : i32)
|
||||
|
||||
@@ -1310,8 +1312,6 @@ fn tint_symbol_18([[internal(disable_validation__ignore_constructible_function_p
|
||||
[[internal(intrinsic_atomic_compare_exchange_weak_storage_u32), internal(disable_validation__function_has_no_body)]]
|
||||
fn tint_symbol_19([[internal(disable_validation__ignore_constructible_function_parameter)]] buffer : SB, offset : u32, param_1 : u32, param_2 : u32) -> vec2<u32>
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> sb : SB;
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main() {
|
||||
tint_symbol(sb, 16u, 123);
|
||||
|
||||
Reference in New Issue
Block a user