tint/transform: Remove use of StorageClass on parameter

Parameters don't have storage classes or access qualifiers. This was
just (ab)using the fact that a parameter uses the same AST type as a
'var'.

Also simplify the parameter disable validation logic.

Bug: tint:1582
Change-Id: Ic218078a410f991e7956e6cb23621a94a69b75a3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/93603
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2022-06-15 19:32:37 +00:00 committed by Dawn LUCI CQ
parent 357b5eba4c
commit 2032d03400
9 changed files with 734 additions and 725 deletions

View File

@ -35,8 +35,8 @@ std::string DisableValidationAttribute::InternalName() const {
return "disable_validation__ignore_storage_class";
case DisabledValidation::kEntryPointParameter:
return "disable_validation__entry_point_parameter";
case DisabledValidation::kIgnoreConstructibleFunctionParameter:
return "disable_validation__ignore_constructible_function_parameter";
case DisabledValidation::kFunctionParameter:
return "disable_validation__function_parameter";
case DisabledValidation::kIgnoreStrideAttribute:
return "disable_validation__ignore_stride";
case DisabledValidation::kIgnoreInvalidPointerArgument:

View File

@ -24,27 +24,23 @@ namespace tint::ast {
/// Enumerator of validation features that can be disabled with a
/// DisableValidationAttribute attribute.
enum class DisabledValidation {
/// When applied to a function, the validator will not complain there is no
/// body to a function.
/// When applied to a function, the validator will not complain there is no body to a function.
kFunctionHasNoBody,
/// When applied to a module-scoped variable, the validator will not complain
/// if two resource variables have the same binding points.
/// When applied to a module-scoped variable, the validator will not complain if two resource
/// variables have the same binding points.
kBindingPointCollision,
/// When applied to a variable, the validator will not complain about the
/// declared storage class.
/// When applied to a variable, the validator will not complain about the declared storage
/// class.
kIgnoreStorageClass,
/// When applied to an entry-point function parameter, the validator will not
/// check for entry IO attributes.
/// When applied to an entry-point function parameter, the validator will not check for entry IO
/// attributes.
kEntryPointParameter,
/// When applied to a function parameter, the validator will not
/// check if parameter type is constructible
kIgnoreConstructibleFunctionParameter,
/// When applied to a member attribute, a stride attribute may be applied to
/// non-array types.
/// When applied to a function parameter, the parameter will not be validated.
kFunctionParameter,
/// When applied to a member attribute, a stride attribute may be applied to non-array types.
kIgnoreStrideAttribute,
/// When applied to a pointer function parameter, the validator will not
/// require a function call argument passed for that parameter to have a
/// certain form.
/// When applied to a pointer function parameter, the validator will not require a function call
/// argument passed for that parameter to have a certain form.
kIgnoreInvalidPointerArgument,
};

View File

@ -722,19 +722,20 @@ bool Validator::FunctionParameter(const ast::Function* func, const sem::Variable
auto* decl = var->Declaration();
if (IsValidationDisabled(decl->attributes, ast::DisabledValidation::kFunctionParameter)) {
return true;
}
for (auto* attr : decl->attributes) {
if (!func->IsEntryPoint() && !attr->Is<ast::InternalAttribute>()) {
AddError("attribute is not valid for non-entry point function parameters",
attr->source);
return false;
} else if (!attr->IsAnyOf<ast::BuiltinAttribute, ast::InvariantAttribute,
ast::LocationAttribute, ast::InterpolateAttribute,
ast::InternalAttribute>() &&
}
if (!attr->IsAnyOf<ast::BuiltinAttribute, ast::InvariantAttribute, ast::LocationAttribute,
ast::InterpolateAttribute, ast::InternalAttribute>() &&
(IsValidationEnabled(decl->attributes,
ast::DisabledValidation::kEntryPointParameter) &&
IsValidationEnabled(
decl->attributes,
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter))) {
ast::DisabledValidation::kEntryPointParameter))) {
AddError("attribute is not valid for function parameters", attr->source);
return false;
}
@ -753,9 +754,7 @@ bool Validator::FunctionParameter(const ast::Function* func, const sem::Variable
}
if (IsPlain(var->Type())) {
if (!var->Type()->IsConstructible() &&
IsValidationEnabled(decl->attributes,
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter)) {
if (!var->Type()->IsConstructible()) {
AddError("store type of function parameter must be a constructible type", decl->source);
return false;
}
@ -964,9 +963,8 @@ bool Validator::Function(const sem::Function* func, ast::PipelineStage stage) co
ast::InvariantAttribute>() &&
(IsValidationEnabled(decl->attributes,
ast::DisabledValidation::kEntryPointParameter) &&
IsValidationEnabled(
decl->attributes,
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter))) {
IsValidationEnabled(decl->attributes,
ast::DisabledValidation::kFunctionParameter))) {
AddError("attribute is not valid for entry point return types", attr->source);
return false;
}

View File

@ -23,6 +23,7 @@
#include "src/tint/sem/block_statement.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/reference.h"
#include "src/tint/sem/statement.h"
#include "src/tint/sem/struct.h"
#include "src/tint/sem/variable.h"
@ -89,22 +90,20 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
// get_buffer_size_intrinsic() emits the function decorated with
// BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
// [RW]ByteAddressBuffer.GetDimensions().
std::unordered_map<const sem::Type*, Symbol> buffer_size_intrinsics;
auto get_buffer_size_intrinsic = [&](const sem::Type* buffer_type) {
std::unordered_map<const sem::Reference*, Symbol> buffer_size_intrinsics;
auto get_buffer_size_intrinsic = [&](const sem::Reference* buffer_type) {
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
auto name = ctx.dst->Sym();
auto* type = CreateASTTypeFor(ctx, buffer_type);
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
ctx.dst->Disable(ast::DisabledValidation::kFunctionParameter);
ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>(
name,
ast::ParameterList{
// Note: The buffer parameter requires the kStorage StorageClass
// in order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(ctx.dst->Sym("buffer"),
ast::StorageClass::kStorage,
ast::Access::kUndefined, type, true, false,
nullptr, ast::AttributeList{disable_validation}),
ctx.dst->Param("buffer",
ctx.dst->ty.pointer(type, buffer_type->StorageClass(),
buffer_type->Access()),
{disable_validation}),
ctx.dst->Param("result", ctx.dst->ty.pointer(ctx.dst->ty.u32(),
ast::StorageClass::kFunction)),
},
@ -128,10 +127,10 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
// We're dealing with an arrayLength() call
// A runtime-sized array can only appear as the store type of a
// variable, or the last element of a structure (which cannot itself
// be nested). Given that we require SimplifyPointers, we can assume
// that the arrayLength() call has one of two forms:
// A runtime-sized array can only appear as the store type of a variable, or the
// last element of a structure (which cannot itself be nested). Given that we
// require SimplifyPointers, we can assume that the arrayLength() call has one
// of two forms:
// arrayLength(&struct_var.array_member)
// arrayLength(&array_var)
auto* arg = call_expr->args[0];
@ -152,10 +151,9 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
break;
}
auto* storage_buffer_var = storage_buffer_sem->Variable();
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
auto* storage_buffer_type = storage_buffer_sem->Type()->As<sem::Reference>();
// Generate BufferSizeIntrinsic for this storage type if we haven't
// already
// Generate BufferSizeIntrinsic for this storage type if we haven't already
auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type);
// Find the current statement block
@ -177,7 +175,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
// BufferSizeIntrinsic(X, ARGS...) is
// translated to:
// X.GetDimensions(ARGS..) by the writer
buffer_size, ctx.Clone(storage_buffer_expr),
buffer_size, ctx.dst->AddressOf(ctx.Clone(storage_buffer_expr)),
ctx.dst->AddressOf(
ctx.dst->Expr(buffer_size_result->variable->symbol))));
@ -188,22 +186,26 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
auto name = ctx.dst->Sym();
const ast::Expression* total_size =
ctx.dst->Expr(buffer_size_result->variable);
const sem::Array* array_type = nullptr;
if (auto* str = storage_buffer_type->As<sem::Struct>()) {
const sem::Array* array_type = Switch(
storage_buffer_type->StoreType(),
[&](const sem::Struct* str) {
// The variable is a struct, so subtract the byte offset of
// the array member.
auto* array_member_sem = str->Members().back();
array_type = array_member_sem->Type()->As<sem::Array>();
total_size =
ctx.dst->Sub(total_size, u32(array_member_sem->Offset()));
} else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
array_type = arr;
} else {
return array_member_sem->Type()->As<sem::Array>();
},
[&](const sem::Array* arr) { return arr; });
if (!array_type) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be "
"&array_var or &struct_var.array_member";
return name;
}
uint32_t array_stride = array_type->Size();
auto* array_length_var = ctx.dst->Decl(
ctx.dst->Let(name, ctx.dst->ty.u32(),

View File

@ -76,14 +76,14 @@ fn main() {
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<i32>, read>, result : ptr<function, u32>)
@group(0) @binding(0) var<storage, read> sb : array<i32>;
@compute @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1));
tint_symbol(&(sb), &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
var len : u32 = tint_symbol_2;
}
@ -111,7 +111,7 @@ fn main() {
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read>, result : ptr<function, u32>)
struct SB {
x : i32,
@ -123,7 +123,7 @@ struct SB {
@compute @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1));
tint_symbol(&(sb), &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var len : u32 = tint_symbol_2;
}
@ -149,7 +149,7 @@ fn main() {
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<S>, result : ptr<function, u32>)
fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<S>, read>, result : ptr<function, u32>)
struct S {
f : f32,
@ -160,7 +160,7 @@ struct S {
@compute @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(arr, &(tint_symbol_1));
tint_symbol(&(arr), &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
let len = tint_symbol_2;
}
@ -186,7 +186,7 @@ fn main() {
)";
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<array<S, 4u>>, result : ptr<function, u32>)
fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<array<S, 4u>>, read>, result : ptr<function, u32>)
struct S {
f : f32,
@ -197,7 +197,7 @@ struct S {
@compute @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(arr, &(tint_symbol_1));
tint_symbol(&(arr), &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 16u);
let len = tint_symbol_2;
}
@ -222,14 +222,14 @@ fn main() {
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<i32>, read>, result : ptr<function, u32>)
@group(0) @binding(0) var<storage, read> sb : array<i32>;
@compute @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1));
tint_symbol(&(sb), &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
var a : u32 = tint_symbol_2;
var b : u32 = tint_symbol_2;
@ -261,7 +261,7 @@ fn main() {
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read>, result : ptr<function, u32>)
struct SB {
x : i32,
@ -273,7 +273,7 @@ struct SB {
@compute @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1));
tint_symbol(&(sb), &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var a : u32 = tint_symbol_2;
var b : u32 = tint_symbol_2;
@ -309,7 +309,7 @@ fn main() {
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read>, result : ptr<function, u32>)
struct SB {
x : i32,
@ -322,13 +322,13 @@ struct SB {
fn main() {
if (true) {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1));
tint_symbol(&(sb), &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var len : u32 = tint_symbol_2;
} else {
if (true) {
var tint_symbol_3 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_3));
tint_symbol(&(sb), &(tint_symbol_3));
let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
var len : u32 = tint_symbol_4;
}
@ -370,13 +370,13 @@ fn main() {
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB1, result : ptr<function, u32>)
fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB1, read>, result : ptr<function, u32>)
@internal(intrinsic_buffer_size)
fn tint_symbol_3(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB2, result : ptr<function, u32>)
fn tint_symbol_3(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB2, read>, result : ptr<function, u32>)
@internal(intrinsic_buffer_size)
fn tint_symbol_6(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
fn tint_symbol_6(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<i32>, read>, result : ptr<function, u32>)
struct SB1 {
x : i32,
@ -397,13 +397,13 @@ struct SB2 {
@compute @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb1, &(tint_symbol_1));
tint_symbol(&(sb1), &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var tint_symbol_4 : u32 = 0u;
tint_symbol_3(sb2, &(tint_symbol_4));
tint_symbol_3(&(sb2), &(tint_symbol_4));
let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u);
var tint_symbol_7 : u32 = 0u;
tint_symbol_6(sb3, &(tint_symbol_7));
tint_symbol_6(&(sb3), &(tint_symbol_7));
let tint_symbol_8 : u32 = (tint_symbol_7 / 4u);
var len1 : u32 = tint_symbol_2;
var len2 : u32 = tint_symbol_5;
@ -440,7 +440,7 @@ fn main() {
auto* expect =
R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>)
fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read>, result : ptr<function, u32>)
struct SB {
x : i32,
@ -454,12 +454,12 @@ struct SB {
@compute @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(a, &(tint_symbol_1));
tint_symbol(&(a), &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var a_1 : u32 = tint_symbol_2;
{
var tint_symbol_3 : u32 = 0u;
tint_symbol(a, &(tint_symbol_3));
tint_symbol(&(a), &(tint_symbol_3));
let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
var b_1 : u32 = tint_symbol_4;
}
@ -500,24 +500,24 @@ struct SB2 {
auto* expect = R"(
@internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB1, result : ptr<function, u32>)
fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB1, read>, result : ptr<function, u32>)
@internal(intrinsic_buffer_size)
fn tint_symbol_3(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB2, result : ptr<function, u32>)
fn tint_symbol_3(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB2, read>, result : ptr<function, u32>)
@internal(intrinsic_buffer_size)
fn tint_symbol_6(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>)
fn tint_symbol_6(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<i32>, read>, result : ptr<function, u32>)
@compute @workgroup_size(1)
fn main() {
var tint_symbol_1 : u32 = 0u;
tint_symbol(sb1, &(tint_symbol_1));
tint_symbol(&(sb1), &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var tint_symbol_4 : u32 = 0u;
tint_symbol_3(sb2, &(tint_symbol_4));
tint_symbol_3(&(sb2), &(tint_symbol_4));
let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u);
var tint_symbol_7 : u32 = 0u;
tint_symbol_6(sb3, &(tint_symbol_7));
tint_symbol_6(&(sb3), &(tint_symbol_7));
let tint_symbol_8 : u32 = (tint_symbol_7 / 4u);
var len1 : u32 = tint_symbol_2;
var len2 : u32 = tint_symbol_5;

View File

@ -98,29 +98,32 @@ struct OffsetBinOp : Offset {
/// LoadStoreKey is the unordered map key to a load or store intrinsic.
struct LoadStoreKey {
ast::StorageClass const storage_class; // buffer storage class
ast::Access const access; // buffer access
sem::Type const* buf_ty = nullptr; // buffer type
sem::Type const* el_ty = nullptr; // element type
bool operator==(const LoadStoreKey& rhs) const {
return storage_class == rhs.storage_class && buf_ty == rhs.buf_ty && el_ty == rhs.el_ty;
return storage_class == rhs.storage_class && access == rhs.access && buf_ty == rhs.buf_ty &&
el_ty == rhs.el_ty;
}
struct Hasher {
inline std::size_t operator()(const LoadStoreKey& u) const {
return utils::Hash(u.storage_class, u.buf_ty, u.el_ty);
return utils::Hash(u.storage_class, u.access, u.buf_ty, u.el_ty);
}
};
};
/// AtomicKey is the unordered map key to an atomic intrinsic.
struct AtomicKey {
ast::Access const access; // buffer access
sem::Type const* buf_ty = nullptr; // buffer type
sem::Type const* el_ty = nullptr; // element type
sem::BuiltinType const op; // atomic op
bool operator==(const AtomicKey& rhs) const {
return buf_ty == rhs.buf_ty && el_ty == rhs.el_ty && op == rhs.op;
return access == rhs.access && buf_ty == rhs.buf_ty && el_ty == rhs.el_ty && op == rhs.op;
}
struct Hasher {
inline std::size_t operator()(const AtomicKey& u) const {
return utils::Hash(u.buf_ty, u.el_ty, u.op);
return utils::Hash(u.access, u.buf_ty, u.el_ty, u.op);
}
};
};
@ -420,10 +423,10 @@ struct DecomposeMemoryAccess::State {
return access;
}
/// LoadFunc() returns a symbol to an intrinsic function that loads an element
/// of type `el_ty` from a storage or uniform buffer of type `buf_ty`.
/// LoadFunc() returns a symbol to an intrinsic function that loads an element of type `el_ty`
/// from a storage or uniform buffer of type `buf_ty`.
/// The emitted function has the signature:
/// `fn load(buf : buf_ty, offset : u32) -> el_ty`
/// `fn load(buf : ptr<SC, buf_ty, A>, offset : u32) -> el_ty`
/// @param buf_ty the storage or uniform buffer type
/// @param el_ty the storage or uniform buffer element type
/// @param var_user the variable user
@ -432,18 +435,13 @@ struct DecomposeMemoryAccess::State {
const sem::Type* el_ty,
const sem::VariableUser* var_user) {
auto storage_class = var_user->Variable()->StorageClass();
return utils::GetOrCreate(load_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
auto* disable_validation =
b.Disable(ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
auto access = var_user->Variable()->Access();
return utils::GetOrCreate(
load_funcs, LoadStoreKey{storage_class, access, buf_ty, el_ty}, [&] {
ast::ParameterList params = {
// Note: The buffer parameter requires the StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer or cbuffer
// array.
b.create<ast::Variable>(b.Sym("buffer"), storage_class,
var_user->Variable()->Access(), buf_ast_ty, true, false,
nullptr, ast::AttributeList{disable_validation}),
b.Param("buffer",
b.ty.pointer(CreateASTTypeFor(ctx, buf_ty), storage_class, access),
{b.Disable(ast::DisabledValidation::kFunctionParameter)}),
b.Param("offset", b.ty.u32()),
};
@ -460,10 +458,10 @@ struct DecomposeMemoryAccess::State {
ast::AttributeList{});
b.AST().AddFunction(func);
} else if (auto* arr_ty = el_ty->As<sem::Array>()) {
// fn load_func(buf : buf_ty, offset : u32) -> array<T, N> {
// fn load_func(buffer : buf_ty, offset : u32) -> array<T, N> {
// var arr : array<T, N>;
// for (var i = 0u; i < array_count; i = i + 1) {
// arr[i] = el_load_func(buf, offset + i * array_stride)
// arr[i] = el_load_func(buffer, offset + i * array_stride)
// }
// return arr;
// }
@ -514,7 +512,7 @@ struct DecomposeMemoryAccess::State {
/// StoreFunc() returns a symbol to an intrinsic function that stores an
/// element of type `el_ty` to a storage buffer of type `buf_ty`.
/// The function has the signature:
/// `fn store(buf : buf_ty, offset : u32, value : el_ty)`
/// `fn store(buf : ptr<SC, buf_ty, A>, offset : u32, value : el_ty)`
/// @param buf_ty the storage buffer type
/// @param el_ty the storage buffer element type
/// @param var_user the variable user
@ -523,20 +521,15 @@ struct DecomposeMemoryAccess::State {
const sem::Type* el_ty,
const sem::VariableUser* var_user) {
auto storage_class = var_user->Variable()->StorageClass();
return utils::GetOrCreate(store_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
auto* disable_validation =
b.Disable(ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
auto access = var_user->Variable()->Access();
return utils::GetOrCreate(
store_funcs, LoadStoreKey{storage_class, access, buf_ty, el_ty}, [&] {
ast::ParameterList params{
// Note: The buffer parameter requires the StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
b.create<ast::Variable>(b.Sym("buffer"), storage_class,
var_user->Variable()->Access(), buf_ast_ty, true, false,
nullptr, ast::AttributeList{disable_validation}),
b.Param("buffer",
b.ty.pointer(CreateASTTypeFor(ctx, buf_ty), storage_class, access),
{b.Disable(ast::DisabledValidation::kFunctionParameter)}),
b.Param("offset", b.ty.u32()),
b.Param("value", el_ast_ty),
b.Param("value", CreateASTTypeFor(ctx, el_ty)),
};
auto name = b.Sym();
@ -551,48 +544,61 @@ struct DecomposeMemoryAccess::State {
ast::AttributeList{});
b.AST().AddFunction(func);
} else {
ast::StatementList body;
if (auto* arr_ty = el_ty->As<sem::Array>()) {
// fn store_func(buf : buf_ty, offset : u32, value : el_ty) {
auto body = Switch(
el_ty, //
[&](const sem::Array* arr_ty) {
// fn store_func(buffer : buf_ty, offset : u32, value : el_ty) {
// var array = value; // No dynamic indexing on constant arrays
// for (var i = 0u; i < array_count; i = i + 1) {
// arr[i] = el_store_func(buf, offset + i * array_stride,
// arr[i] = el_store_func(buffer, offset + i * array_stride,
// value[i])
// }
// return arr;
// }
auto* array = b.Var(b.Symbols().New("array"), nullptr, b.Expr("value"));
auto store = StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
auto store =
StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0_u));
auto* for_init = b.Decl(i);
auto* for_cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_ty->Count())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(array, i);
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
auto* store_stmt = b.CallStmt(b.Call(store, "buffer", el_offset, arr_el));
auto* for_loop = b.For(for_init, for_cond, for_cont, b.Block(store_stmt));
auto* el_offset =
b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
auto* store_stmt =
b.CallStmt(b.Call(store, "buffer", el_offset, arr_el));
auto* for_loop =
b.For(for_init, for_cond, for_cont, b.Block(store_stmt));
body = {b.Decl(array), for_loop};
} else if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
return ast::StatementList{b.Decl(array), for_loop};
},
[&](const sem::Matrix* mat_ty) {
auto* vec_ty = mat_ty->ColumnType();
Symbol store = StoreFunc(buf_ty, vec_ty, var_user);
ast::StatementList stmts;
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset = b.Add("offset", u32(i * mat_ty->ColumnStride()));
auto* access = b.IndexAccessor("value", u32(i));
auto* call = b.Call(store, "buffer", offset, access);
body.emplace_back(b.CallStmt(call));
auto* element = b.IndexAccessor("value", u32(i));
auto* call = b.Call(store, "buffer", offset, element);
stmts.emplace_back(b.CallStmt(call));
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
return stmts;
},
[&](const sem::Struct* str) {
ast::StatementList stmts;
for (auto* member : str->Members()) {
auto* offset = b.Add("offset", u32(member->Offset()));
auto* access =
b.MemberAccessor("value", ctx.Clone(member->Declaration()->symbol));
Symbol store = StoreFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
auto* call = b.Call(store, "buffer", offset, access);
body.emplace_back(b.CallStmt(call));
}
auto* element = b.MemberAccessor(
"value", ctx.Clone(member->Declaration()->symbol));
Symbol store =
StoreFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
auto* call = b.Call(store, "buffer", offset, element);
stmts.emplace_back(b.CallStmt(call));
}
return stmts;
});
b.Func(name, params, b.ty.void_(), body);
}
@ -603,7 +609,7 @@ struct DecomposeMemoryAccess::State {
/// AtomicFunc() returns a symbol to an intrinsic function that performs an
/// atomic operation from a storage buffer of type `buf_ty`. The function has
/// the signature:
// `fn atomic_op(buf : buf_ty, offset : u32, ...) -> T`
// `fn atomic_op(buf : ptr<storage, buf_ty, A>, offset : u32, ...) -> T`
/// @param buf_ty the storage buffer type
/// @param el_ty the storage buffer element type
/// @param intrinsic the atomic intrinsic
@ -614,19 +620,15 @@ struct DecomposeMemoryAccess::State {
const sem::Builtin* intrinsic,
const sem::VariableUser* var_user) {
auto op = intrinsic->Type();
return utils::GetOrCreate(atomic_funcs, AtomicKey{buf_ty, el_ty, op}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty);
auto* disable_validation =
b.Disable(ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
auto access = var_user->Variable()->Access();
return utils::GetOrCreate(atomic_funcs, AtomicKey{access, buf_ty, el_ty, op}, [&] {
// The first parameter to all WGSL atomics is the expression to the
// atomic. This is replaced with two parameters: the buffer and offset.
ast::ParameterList params = {
// Note: The buffer parameter requires the kStorage StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
b.create<ast::Variable>(b.Sym("buffer"), ast::StorageClass::kStorage,
var_user->Variable()->Access(), buf_ast_ty, true, false,
nullptr, ast::AttributeList{disable_validation}),
b.Param("buffer",
b.ty.pointer(CreateASTTypeFor(ctx, buf_ty), ast::StorageClass::kStorage,
access),
{b.Disable(ast::DisabledValidation::kFunctionParameter)}),
b.Param("offset", b.ty.u32()),
};
@ -910,8 +912,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
// arrayLength(X)
// Don't convert X into a load, this builtin actually requires the
// real pointer.
// Don't convert X into a load, this builtin actually requires the real pointer.
state.TakeAccess(call_expr->args[0]);
continue;
}
@ -926,7 +927,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con
Symbol func = state.AtomicFunc(buf_ty, el_ty, builtin,
access.var->As<sem::VariableUser>());
ast::ExpressionList args{ctx.Clone(buf), offset};
ast::ExpressionList args{ctx.dst->AddressOf(ctx.Clone(buf)), offset};
for (size_t i = 1; i < call_expr->args.size(); i++) {
auto* arg = call_expr->args[i];
args.emplace_back(ctx.Clone(arg));
@ -948,26 +949,26 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con
}
BufferAccess access = access_it->second;
ctx.Replace(expr, [=, &ctx, &state] {
auto* buf = access.var->Declaration();
auto* buf = ctx.dst->AddressOf(ctx.CloneWithoutTransform(access.var->Declaration()));
auto* offset = access.offset->Build(ctx);
auto* buf_ty = access.var->Type()->UnwrapRef();
auto* el_ty = access.type->UnwrapRef();
Symbol func = state.LoadFunc(buf_ty, el_ty, access.var->As<sem::VariableUser>());
return ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset);
return ctx.dst->Call(func, buf, offset);
});
}
// And replace all storage and uniform buffer assignments with stores
for (auto store : state.stores) {
ctx.Replace(store.assignment, [=, &ctx, &state] {
auto* buf = store.target.var->Declaration();
auto* buf =
ctx.dst->AddressOf(ctx.CloneWithoutTransform((store.target.var->Declaration())));
auto* offset = store.target.offset->Build(ctx);
auto* buf_ty = store.target.var->Type()->UnwrapRef();
auto* el_ty = store.target.type->UnwrapRef();
auto* value = store.assignment->rhs;
Symbol func = state.StoreFunc(buf_ty, el_ty, store.target.var->As<sem::VariableUser>());
auto* call =
ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset, ctx.Clone(value));
auto* call = ctx.dst->Call(func, buf, offset, ctx.Clone(value));
return ctx.dst->CallStmt(call);
});
}

File diff suppressed because it is too large Load Diff

View File

@ -49,7 +49,8 @@ Output Manager::Run(const Program* program, const DataMap& data) const {
Output out;
for (const auto& transform : transforms_) {
if (!transform->ShouldRun(in, data)) {
TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " << transform->TypeInfo().name << std::endl);
TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " << transform->TypeInfo().name
<< std::endl);
continue;
}
TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get()));

View File

@ -2776,14 +2776,25 @@ bool GeneratorImpl::EmitFunction(const ast::Function* func) {
first = false;
auto const* type = v->Type();
auto storage_class = ast::StorageClass::kNone;
auto access = ast::Access::kUndefined;
if (auto* ptr = type->As<sem::Pointer>()) {
// Transform pointer parameters in to `inout` parameters.
// The WGSL spec is highly restrictive in what can be passed in pointer
// parameters, which allows for this transformation. See:
// https://gpuweb.github.io/gpuweb/wgsl/#function-restriction
out << "inout ";
type = ptr->StoreType();
switch (ptr->StorageClass()) {
case ast::StorageClass::kStorage:
case ast::StorageClass::kUniform:
// Not allowed by WGSL, but is used by certain transforms (e.g. DMA) to pass
// storage buffers and uniform buffers down into transform-generated
// functions. In this situation we want to generate the parameter without an
// 'inout', using the storage class and access from the pointer.
storage_class = ptr->StorageClass();
access = ptr->Access();
break;
default:
// Transform regular WGSL pointer parameters in to `inout` parameters.
out << "inout ";
}
}
// Note: WGSL only allows for StorageClass::kNone on parameters, however
@ -2792,7 +2803,7 @@ bool GeneratorImpl::EmitFunction(const ast::Function* func) {
// StorageClass::kStorage or StorageClass::kUniform. This is required to
// correctly translate the parameter to a [RW]ByteAddressBuffer for
// storage buffers and a uint4[N] for uniform buffers.
if (!EmitTypeAndName(out, type, v->StorageClass(), v->Access(),
if (!EmitTypeAndName(out, type, storage_class, access,
builder_.Symbols().NameFor(v->Declaration()->symbol))) {
return false;
}