tint/transform: Remove use of StorageClass on parameter

Parameters don't have storage classes or access qualifiers. This was
just (ab)using the fact that a parameter uses the same AST type as a
'var'.

Also simplify the parameter disable validation logic.

Bug: tint:1582
Change-Id: Ic218078a410f991e7956e6cb23621a94a69b75a3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/93603
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2022-06-15 19:32:37 +00:00 committed by Dawn LUCI CQ
parent 357b5eba4c
commit 2032d03400
9 changed files with 734 additions and 725 deletions

View File

@ -35,8 +35,8 @@ std::string DisableValidationAttribute::InternalName() const {
return "disable_validation__ignore_storage_class"; return "disable_validation__ignore_storage_class";
case DisabledValidation::kEntryPointParameter: case DisabledValidation::kEntryPointParameter:
return "disable_validation__entry_point_parameter"; return "disable_validation__entry_point_parameter";
case DisabledValidation::kIgnoreConstructibleFunctionParameter: case DisabledValidation::kFunctionParameter:
return "disable_validation__ignore_constructible_function_parameter"; return "disable_validation__function_parameter";
case DisabledValidation::kIgnoreStrideAttribute: case DisabledValidation::kIgnoreStrideAttribute:
return "disable_validation__ignore_stride"; return "disable_validation__ignore_stride";
case DisabledValidation::kIgnoreInvalidPointerArgument: case DisabledValidation::kIgnoreInvalidPointerArgument:

View File

@ -24,27 +24,23 @@ namespace tint::ast {
/// Enumerator of validation features that can be disabled with a /// Enumerator of validation features that can be disabled with a
/// DisableValidationAttribute attribute. /// DisableValidationAttribute attribute.
enum class DisabledValidation { enum class DisabledValidation {
/// When applied to a function, the validator will not complain there is no /// When applied to a function, the validator will not complain there is no body to a function.
/// body to a function.
kFunctionHasNoBody, kFunctionHasNoBody,
/// When applied to a module-scoped variable, the validator will not complain /// When applied to a module-scoped variable, the validator will not complain if two resource
/// if two resource variables have the same binding points. /// variables have the same binding points.
kBindingPointCollision, kBindingPointCollision,
/// When applied to a variable, the validator will not complain about the /// When applied to a variable, the validator will not complain about the declared storage
/// declared storage class. /// class.
kIgnoreStorageClass, kIgnoreStorageClass,
/// When applied to an entry-point function parameter, the validator will not /// When applied to an entry-point function parameter, the validator will not check for entry IO
/// check for entry IO attributes. /// attributes.
kEntryPointParameter, kEntryPointParameter,
/// When applied to a function parameter, the validator will not /// When applied to a function parameter, the parameter will not be validated.
/// check if parameter type is constructible kFunctionParameter,
kIgnoreConstructibleFunctionParameter, /// When applied to a member attribute, a stride attribute may be applied to non-array types.
/// When applied to a member attribute, a stride attribute may be applied to
/// non-array types.
kIgnoreStrideAttribute, kIgnoreStrideAttribute,
/// When applied to a pointer function parameter, the validator will not /// When applied to a pointer function parameter, the validator will not require a function call
/// require a function call argument passed for that parameter to have a /// argument passed for that parameter to have a certain form.
/// certain form.
kIgnoreInvalidPointerArgument, kIgnoreInvalidPointerArgument,
}; };

View File

@ -722,19 +722,20 @@ bool Validator::FunctionParameter(const ast::Function* func, const sem::Variable
auto* decl = var->Declaration(); auto* decl = var->Declaration();
if (IsValidationDisabled(decl->attributes, ast::DisabledValidation::kFunctionParameter)) {
return true;
}
for (auto* attr : decl->attributes) { for (auto* attr : decl->attributes) {
if (!func->IsEntryPoint() && !attr->Is<ast::InternalAttribute>()) { if (!func->IsEntryPoint() && !attr->Is<ast::InternalAttribute>()) {
AddError("attribute is not valid for non-entry point function parameters", AddError("attribute is not valid for non-entry point function parameters",
attr->source); attr->source);
return false; return false;
} else if (!attr->IsAnyOf<ast::BuiltinAttribute, ast::InvariantAttribute, }
ast::LocationAttribute, ast::InterpolateAttribute, if (!attr->IsAnyOf<ast::BuiltinAttribute, ast::InvariantAttribute, ast::LocationAttribute,
ast::InternalAttribute>() && ast::InterpolateAttribute, ast::InternalAttribute>() &&
(IsValidationEnabled(decl->attributes, (IsValidationEnabled(decl->attributes,
ast::DisabledValidation::kEntryPointParameter) && ast::DisabledValidation::kEntryPointParameter))) {
IsValidationEnabled(
decl->attributes,
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter))) {
AddError("attribute is not valid for function parameters", attr->source); AddError("attribute is not valid for function parameters", attr->source);
return false; return false;
} }
@ -753,9 +754,7 @@ bool Validator::FunctionParameter(const ast::Function* func, const sem::Variable
} }
if (IsPlain(var->Type())) { if (IsPlain(var->Type())) {
if (!var->Type()->IsConstructible() && if (!var->Type()->IsConstructible()) {
IsValidationEnabled(decl->attributes,
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter)) {
AddError("store type of function parameter must be a constructible type", decl->source); AddError("store type of function parameter must be a constructible type", decl->source);
return false; return false;
} }
@ -964,9 +963,8 @@ bool Validator::Function(const sem::Function* func, ast::PipelineStage stage) co
ast::InvariantAttribute>() && ast::InvariantAttribute>() &&
(IsValidationEnabled(decl->attributes, (IsValidationEnabled(decl->attributes,
ast::DisabledValidation::kEntryPointParameter) && ast::DisabledValidation::kEntryPointParameter) &&
IsValidationEnabled( IsValidationEnabled(decl->attributes,
decl->attributes, ast::DisabledValidation::kFunctionParameter))) {
ast::DisabledValidation::kIgnoreConstructibleFunctionParameter))) {
AddError("attribute is not valid for entry point return types", attr->source); AddError("attribute is not valid for entry point return types", attr->source);
return false; return false;
} }

View File

@ -23,6 +23,7 @@
#include "src/tint/sem/block_statement.h" #include "src/tint/sem/block_statement.h"
#include "src/tint/sem/call.h" #include "src/tint/sem/call.h"
#include "src/tint/sem/function.h" #include "src/tint/sem/function.h"
#include "src/tint/sem/reference.h"
#include "src/tint/sem/statement.h" #include "src/tint/sem/statement.h"
#include "src/tint/sem/struct.h" #include "src/tint/sem/struct.h"
#include "src/tint/sem/variable.h" #include "src/tint/sem/variable.h"
@ -89,22 +90,20 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
// get_buffer_size_intrinsic() emits the function decorated with // get_buffer_size_intrinsic() emits the function decorated with
// BufferSizeIntrinsic that is transformed by the HLSL writer into a call to // BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
// [RW]ByteAddressBuffer.GetDimensions(). // [RW]ByteAddressBuffer.GetDimensions().
std::unordered_map<const sem::Type*, Symbol> buffer_size_intrinsics; std::unordered_map<const sem::Reference*, Symbol> buffer_size_intrinsics;
auto get_buffer_size_intrinsic = [&](const sem::Type* buffer_type) { auto get_buffer_size_intrinsic = [&](const sem::Reference* buffer_type) {
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] { return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
auto name = ctx.dst->Sym(); auto name = ctx.dst->Sym();
auto* type = CreateASTTypeFor(ctx, buffer_type); auto* type = CreateASTTypeFor(ctx, buffer_type);
auto* disable_validation = auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kIgnoreConstructibleFunctionParameter); ctx.dst->Disable(ast::DisabledValidation::kFunctionParameter);
ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>( ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>(
name, name,
ast::ParameterList{ ast::ParameterList{
// Note: The buffer parameter requires the kStorage StorageClass ctx.dst->Param("buffer",
// in order for HLSL to emit this as a ByteAddressBuffer. ctx.dst->ty.pointer(type, buffer_type->StorageClass(),
ctx.dst->create<ast::Variable>(ctx.dst->Sym("buffer"), buffer_type->Access()),
ast::StorageClass::kStorage, {disable_validation}),
ast::Access::kUndefined, type, true, false,
nullptr, ast::AttributeList{disable_validation}),
ctx.dst->Param("result", ctx.dst->ty.pointer(ctx.dst->ty.u32(), ctx.dst->Param("result", ctx.dst->ty.pointer(ctx.dst->ty.u32(),
ast::StorageClass::kFunction)), ast::StorageClass::kFunction)),
}, },
@ -128,10 +127,10 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
if (builtin->Type() == sem::BuiltinType::kArrayLength) { if (builtin->Type() == sem::BuiltinType::kArrayLength) {
// We're dealing with an arrayLength() call // We're dealing with an arrayLength() call
// A runtime-sized array can only appear as the store type of a // A runtime-sized array can only appear as the store type of a variable, or the
// variable, or the last element of a structure (which cannot itself // last element of a structure (which cannot itself be nested). Given that we
// be nested). Given that we require SimplifyPointers, we can assume // require SimplifyPointers, we can assume that the arrayLength() call has one
// that the arrayLength() call has one of two forms: // of two forms:
// arrayLength(&struct_var.array_member) // arrayLength(&struct_var.array_member)
// arrayLength(&array_var) // arrayLength(&array_var)
auto* arg = call_expr->args[0]; auto* arg = call_expr->args[0];
@ -152,10 +151,9 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
break; break;
} }
auto* storage_buffer_var = storage_buffer_sem->Variable(); auto* storage_buffer_var = storage_buffer_sem->Variable();
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef(); auto* storage_buffer_type = storage_buffer_sem->Type()->As<sem::Reference>();
// Generate BufferSizeIntrinsic for this storage type if we haven't // Generate BufferSizeIntrinsic for this storage type if we haven't already
// already
auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type); auto buffer_size = get_buffer_size_intrinsic(storage_buffer_type);
// Find the current statement block // Find the current statement block
@ -177,7 +175,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
// BufferSizeIntrinsic(X, ARGS...) is // BufferSizeIntrinsic(X, ARGS...) is
// translated to: // translated to:
// X.GetDimensions(ARGS..) by the writer // X.GetDimensions(ARGS..) by the writer
buffer_size, ctx.Clone(storage_buffer_expr), buffer_size, ctx.dst->AddressOf(ctx.Clone(storage_buffer_expr)),
ctx.dst->AddressOf( ctx.dst->AddressOf(
ctx.dst->Expr(buffer_size_result->variable->symbol)))); ctx.dst->Expr(buffer_size_result->variable->symbol))));
@ -188,22 +186,26 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
auto name = ctx.dst->Sym(); auto name = ctx.dst->Sym();
const ast::Expression* total_size = const ast::Expression* total_size =
ctx.dst->Expr(buffer_size_result->variable); ctx.dst->Expr(buffer_size_result->variable);
const sem::Array* array_type = nullptr;
if (auto* str = storage_buffer_type->As<sem::Struct>()) { const sem::Array* array_type = Switch(
// The variable is a struct, so subtract the byte offset of storage_buffer_type->StoreType(),
// the array member. [&](const sem::Struct* str) {
auto* array_member_sem = str->Members().back(); // The variable is a struct, so subtract the byte offset of
array_type = array_member_sem->Type()->As<sem::Array>(); // the array member.
total_size = auto* array_member_sem = str->Members().back();
ctx.dst->Sub(total_size, u32(array_member_sem->Offset())); total_size =
} else if (auto* arr = storage_buffer_type->As<sem::Array>()) { ctx.dst->Sub(total_size, u32(array_member_sem->Offset()));
array_type = arr; return array_member_sem->Type()->As<sem::Array>();
} else { },
[&](const sem::Array* arr) { return arr; });
if (!array_type) {
TINT_ICE(Transform, ctx.dst->Diagnostics()) TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be " << "expected form of arrayLength argument to be "
"&array_var or &struct_var.array_member"; "&array_var or &struct_var.array_member";
return name; return name;
} }
uint32_t array_stride = array_type->Size(); uint32_t array_stride = array_type->Size();
auto* array_length_var = ctx.dst->Decl( auto* array_length_var = ctx.dst->Decl(
ctx.dst->Let(name, ctx.dst->ty.u32(), ctx.dst->Let(name, ctx.dst->ty.u32(),

View File

@ -76,14 +76,14 @@ fn main() {
auto* expect = R"( auto* expect = R"(
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>) fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<i32>, read>, result : ptr<function, u32>)
@group(0) @binding(0) var<storage, read> sb : array<i32>; @group(0) @binding(0) var<storage, read> sb : array<i32>;
@compute @workgroup_size(1) @compute @workgroup_size(1)
fn main() { fn main() {
var tint_symbol_1 : u32 = 0u; var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1)); tint_symbol(&(sb), &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 4u); let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
var len : u32 = tint_symbol_2; var len : u32 = tint_symbol_2;
} }
@ -111,7 +111,7 @@ fn main() {
auto* expect = R"( auto* expect = R"(
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>) fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read>, result : ptr<function, u32>)
struct SB { struct SB {
x : i32, x : i32,
@ -123,7 +123,7 @@ struct SB {
@compute @workgroup_size(1) @compute @workgroup_size(1)
fn main() { fn main() {
var tint_symbol_1 : u32 = 0u; var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1)); tint_symbol(&(sb), &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u); let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var len : u32 = tint_symbol_2; var len : u32 = tint_symbol_2;
} }
@ -149,7 +149,7 @@ fn main() {
)"; )";
auto* expect = R"( auto* expect = R"(
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<S>, result : ptr<function, u32>) fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<S>, read>, result : ptr<function, u32>)
struct S { struct S {
f : f32, f : f32,
@ -160,7 +160,7 @@ struct S {
@compute @workgroup_size(1) @compute @workgroup_size(1)
fn main() { fn main() {
var tint_symbol_1 : u32 = 0u; var tint_symbol_1 : u32 = 0u;
tint_symbol(arr, &(tint_symbol_1)); tint_symbol(&(arr), &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 4u); let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
let len = tint_symbol_2; let len = tint_symbol_2;
} }
@ -186,7 +186,7 @@ fn main() {
)"; )";
auto* expect = R"( auto* expect = R"(
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<array<S, 4u>>, result : ptr<function, u32>) fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<array<S, 4u>>, read>, result : ptr<function, u32>)
struct S { struct S {
f : f32, f : f32,
@ -197,7 +197,7 @@ struct S {
@compute @workgroup_size(1) @compute @workgroup_size(1)
fn main() { fn main() {
var tint_symbol_1 : u32 = 0u; var tint_symbol_1 : u32 = 0u;
tint_symbol(arr, &(tint_symbol_1)); tint_symbol(&(arr), &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 16u); let tint_symbol_2 : u32 = (tint_symbol_1 / 16u);
let len = tint_symbol_2; let len = tint_symbol_2;
} }
@ -222,14 +222,14 @@ fn main() {
auto* expect = R"( auto* expect = R"(
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>) fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<i32>, read>, result : ptr<function, u32>)
@group(0) @binding(0) var<storage, read> sb : array<i32>; @group(0) @binding(0) var<storage, read> sb : array<i32>;
@compute @workgroup_size(1) @compute @workgroup_size(1)
fn main() { fn main() {
var tint_symbol_1 : u32 = 0u; var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1)); tint_symbol(&(sb), &(tint_symbol_1));
let tint_symbol_2 : u32 = (tint_symbol_1 / 4u); let tint_symbol_2 : u32 = (tint_symbol_1 / 4u);
var a : u32 = tint_symbol_2; var a : u32 = tint_symbol_2;
var b : u32 = tint_symbol_2; var b : u32 = tint_symbol_2;
@ -261,7 +261,7 @@ fn main() {
auto* expect = R"( auto* expect = R"(
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>) fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read>, result : ptr<function, u32>)
struct SB { struct SB {
x : i32, x : i32,
@ -273,7 +273,7 @@ struct SB {
@compute @workgroup_size(1) @compute @workgroup_size(1)
fn main() { fn main() {
var tint_symbol_1 : u32 = 0u; var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1)); tint_symbol(&(sb), &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u); let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var a : u32 = tint_symbol_2; var a : u32 = tint_symbol_2;
var b : u32 = tint_symbol_2; var b : u32 = tint_symbol_2;
@ -309,7 +309,7 @@ fn main() {
auto* expect = R"( auto* expect = R"(
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>) fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read>, result : ptr<function, u32>)
struct SB { struct SB {
x : i32, x : i32,
@ -322,13 +322,13 @@ struct SB {
fn main() { fn main() {
if (true) { if (true) {
var tint_symbol_1 : u32 = 0u; var tint_symbol_1 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_1)); tint_symbol(&(sb), &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u); let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var len : u32 = tint_symbol_2; var len : u32 = tint_symbol_2;
} else { } else {
if (true) { if (true) {
var tint_symbol_3 : u32 = 0u; var tint_symbol_3 : u32 = 0u;
tint_symbol(sb, &(tint_symbol_3)); tint_symbol(&(sb), &(tint_symbol_3));
let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u); let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
var len : u32 = tint_symbol_4; var len : u32 = tint_symbol_4;
} }
@ -370,13 +370,13 @@ fn main() {
auto* expect = R"( auto* expect = R"(
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB1, result : ptr<function, u32>) fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB1, read>, result : ptr<function, u32>)
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol_3(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB2, result : ptr<function, u32>) fn tint_symbol_3(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB2, read>, result : ptr<function, u32>)
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol_6(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>) fn tint_symbol_6(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<i32>, read>, result : ptr<function, u32>)
struct SB1 { struct SB1 {
x : i32, x : i32,
@ -397,13 +397,13 @@ struct SB2 {
@compute @workgroup_size(1) @compute @workgroup_size(1)
fn main() { fn main() {
var tint_symbol_1 : u32 = 0u; var tint_symbol_1 : u32 = 0u;
tint_symbol(sb1, &(tint_symbol_1)); tint_symbol(&(sb1), &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u); let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var tint_symbol_4 : u32 = 0u; var tint_symbol_4 : u32 = 0u;
tint_symbol_3(sb2, &(tint_symbol_4)); tint_symbol_3(&(sb2), &(tint_symbol_4));
let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u); let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u);
var tint_symbol_7 : u32 = 0u; var tint_symbol_7 : u32 = 0u;
tint_symbol_6(sb3, &(tint_symbol_7)); tint_symbol_6(&(sb3), &(tint_symbol_7));
let tint_symbol_8 : u32 = (tint_symbol_7 / 4u); let tint_symbol_8 : u32 = (tint_symbol_7 / 4u);
var len1 : u32 = tint_symbol_2; var len1 : u32 = tint_symbol_2;
var len2 : u32 = tint_symbol_5; var len2 : u32 = tint_symbol_5;
@ -440,7 +440,7 @@ fn main() {
auto* expect = auto* expect =
R"( R"(
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB, result : ptr<function, u32>) fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB, read>, result : ptr<function, u32>)
struct SB { struct SB {
x : i32, x : i32,
@ -454,12 +454,12 @@ struct SB {
@compute @workgroup_size(1) @compute @workgroup_size(1)
fn main() { fn main() {
var tint_symbol_1 : u32 = 0u; var tint_symbol_1 : u32 = 0u;
tint_symbol(a, &(tint_symbol_1)); tint_symbol(&(a), &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u); let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var a_1 : u32 = tint_symbol_2; var a_1 : u32 = tint_symbol_2;
{ {
var tint_symbol_3 : u32 = 0u; var tint_symbol_3 : u32 = 0u;
tint_symbol(a, &(tint_symbol_3)); tint_symbol(&(a), &(tint_symbol_3));
let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u); let tint_symbol_4 : u32 = ((tint_symbol_3 - 4u) / 4u);
var b_1 : u32 = tint_symbol_4; var b_1 : u32 = tint_symbol_4;
} }
@ -500,24 +500,24 @@ struct SB2 {
auto* expect = R"( auto* expect = R"(
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB1, result : ptr<function, u32>) fn tint_symbol(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB1, read>, result : ptr<function, u32>)
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol_3(@internal(disable_validation__ignore_constructible_function_parameter) buffer : SB2, result : ptr<function, u32>) fn tint_symbol_3(@internal(disable_validation__function_parameter) buffer : ptr<storage, SB2, read>, result : ptr<function, u32>)
@internal(intrinsic_buffer_size) @internal(intrinsic_buffer_size)
fn tint_symbol_6(@internal(disable_validation__ignore_constructible_function_parameter) buffer : array<i32>, result : ptr<function, u32>) fn tint_symbol_6(@internal(disable_validation__function_parameter) buffer : ptr<storage, array<i32>, read>, result : ptr<function, u32>)
@compute @workgroup_size(1) @compute @workgroup_size(1)
fn main() { fn main() {
var tint_symbol_1 : u32 = 0u; var tint_symbol_1 : u32 = 0u;
tint_symbol(sb1, &(tint_symbol_1)); tint_symbol(&(sb1), &(tint_symbol_1));
let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u); let tint_symbol_2 : u32 = ((tint_symbol_1 - 4u) / 4u);
var tint_symbol_4 : u32 = 0u; var tint_symbol_4 : u32 = 0u;
tint_symbol_3(sb2, &(tint_symbol_4)); tint_symbol_3(&(sb2), &(tint_symbol_4));
let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u); let tint_symbol_5 : u32 = ((tint_symbol_4 - 16u) / 16u);
var tint_symbol_7 : u32 = 0u; var tint_symbol_7 : u32 = 0u;
tint_symbol_6(sb3, &(tint_symbol_7)); tint_symbol_6(&(sb3), &(tint_symbol_7));
let tint_symbol_8 : u32 = (tint_symbol_7 / 4u); let tint_symbol_8 : u32 = (tint_symbol_7 / 4u);
var len1 : u32 = tint_symbol_2; var len1 : u32 = tint_symbol_2;
var len2 : u32 = tint_symbol_5; var len2 : u32 = tint_symbol_5;

View File

@ -98,29 +98,32 @@ struct OffsetBinOp : Offset {
/// LoadStoreKey is the unordered map key to a load or store intrinsic. /// LoadStoreKey is the unordered map key to a load or store intrinsic.
struct LoadStoreKey { struct LoadStoreKey {
ast::StorageClass const storage_class; // buffer storage class ast::StorageClass const storage_class; // buffer storage class
ast::Access const access; // buffer access
sem::Type const* buf_ty = nullptr; // buffer type sem::Type const* buf_ty = nullptr; // buffer type
sem::Type const* el_ty = nullptr; // element type sem::Type const* el_ty = nullptr; // element type
bool operator==(const LoadStoreKey& rhs) const { bool operator==(const LoadStoreKey& rhs) const {
return storage_class == rhs.storage_class && buf_ty == rhs.buf_ty && el_ty == rhs.el_ty; return storage_class == rhs.storage_class && access == rhs.access && buf_ty == rhs.buf_ty &&
el_ty == rhs.el_ty;
} }
struct Hasher { struct Hasher {
inline std::size_t operator()(const LoadStoreKey& u) const { inline std::size_t operator()(const LoadStoreKey& u) const {
return utils::Hash(u.storage_class, u.buf_ty, u.el_ty); return utils::Hash(u.storage_class, u.access, u.buf_ty, u.el_ty);
} }
}; };
}; };
/// AtomicKey is the unordered map key to an atomic intrinsic. /// AtomicKey is the unordered map key to an atomic intrinsic.
struct AtomicKey { struct AtomicKey {
ast::Access const access; // buffer access
sem::Type const* buf_ty = nullptr; // buffer type sem::Type const* buf_ty = nullptr; // buffer type
sem::Type const* el_ty = nullptr; // element type sem::Type const* el_ty = nullptr; // element type
sem::BuiltinType const op; // atomic op sem::BuiltinType const op; // atomic op
bool operator==(const AtomicKey& rhs) const { bool operator==(const AtomicKey& rhs) const {
return buf_ty == rhs.buf_ty && el_ty == rhs.el_ty && op == rhs.op; return access == rhs.access && buf_ty == rhs.buf_ty && el_ty == rhs.el_ty && op == rhs.op;
} }
struct Hasher { struct Hasher {
inline std::size_t operator()(const AtomicKey& u) const { inline std::size_t operator()(const AtomicKey& u) const {
return utils::Hash(u.buf_ty, u.el_ty, u.op); return utils::Hash(u.access, u.buf_ty, u.el_ty, u.op);
} }
}; };
}; };
@ -420,10 +423,10 @@ struct DecomposeMemoryAccess::State {
return access; return access;
} }
/// LoadFunc() returns a symbol to an intrinsic function that loads an element /// LoadFunc() returns a symbol to an intrinsic function that loads an element of type `el_ty`
/// of type `el_ty` from a storage or uniform buffer of type `buf_ty`. /// from a storage or uniform buffer of type `buf_ty`.
/// The emitted function has the signature: /// The emitted function has the signature:
/// `fn load(buf : buf_ty, offset : u32) -> el_ty` /// `fn load(buf : ptr<SC, buf_ty, A>, offset : u32) -> el_ty`
/// @param buf_ty the storage or uniform buffer type /// @param buf_ty the storage or uniform buffer type
/// @param el_ty the storage or uniform buffer element type /// @param el_ty the storage or uniform buffer element type
/// @param var_user the variable user /// @param var_user the variable user
@ -432,89 +435,84 @@ struct DecomposeMemoryAccess::State {
const sem::Type* el_ty, const sem::Type* el_ty,
const sem::VariableUser* var_user) { const sem::VariableUser* var_user) {
auto storage_class = var_user->Variable()->StorageClass(); auto storage_class = var_user->Variable()->StorageClass();
return utils::GetOrCreate(load_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] { auto access = var_user->Variable()->Access();
auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty); return utils::GetOrCreate(
auto* disable_validation = load_funcs, LoadStoreKey{storage_class, access, buf_ty, el_ty}, [&] {
b.Disable(ast::DisabledValidation::kIgnoreConstructibleFunctionParameter); ast::ParameterList params = {
b.Param("buffer",
b.ty.pointer(CreateASTTypeFor(ctx, buf_ty), storage_class, access),
{b.Disable(ast::DisabledValidation::kFunctionParameter)}),
b.Param("offset", b.ty.u32()),
};
ast::ParameterList params = { auto name = b.Sym();
// Note: The buffer parameter requires the StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer or cbuffer
// array.
b.create<ast::Variable>(b.Sym("buffer"), storage_class,
var_user->Variable()->Access(), buf_ast_ty, true, false,
nullptr, ast::AttributeList{disable_validation}),
b.Param("offset", b.ty.u32()),
};
auto name = b.Sym(); if (auto* intrinsic = IntrinsicLoadFor(ctx.dst, storage_class, el_ty)) {
auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty);
auto* func = b.create<ast::Function>(
name, params, el_ast_ty, nullptr,
ast::AttributeList{
intrinsic,
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
},
ast::AttributeList{});
b.AST().AddFunction(func);
} else if (auto* arr_ty = el_ty->As<sem::Array>()) {
// fn load_func(buffer : buf_ty, offset : u32) -> array<T, N> {
// var arr : array<T, N>;
// for (var i = 0u; i < array_count; i = i + 1) {
// arr[i] = el_load_func(buffer, offset + i * array_stride)
// }
// return arr;
// }
auto load = LoadFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
auto* arr = b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty));
auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0_u));
auto* for_init = b.Decl(i);
auto* for_cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_ty->Count())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(arr, i);
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
auto* el_val = b.Call(load, "buffer", el_offset);
auto* for_loop =
b.For(for_init, for_cond, for_cont, b.Block(b.Assign(arr_el, el_val)));
if (auto* intrinsic = IntrinsicLoadFor(ctx.dst, storage_class, el_ty)) { b.Func(name, params, CreateASTTypeFor(ctx, arr_ty),
auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty); {
auto* func = b.create<ast::Function>( b.Decl(arr),
name, params, el_ast_ty, nullptr, for_loop,
ast::AttributeList{ b.Return(arr),
intrinsic, });
b.Disable(ast::DisabledValidation::kFunctionHasNoBody), } else {
}, ast::ExpressionList values;
ast::AttributeList{}); if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
b.AST().AddFunction(func); auto* vec_ty = mat_ty->ColumnType();
} else if (auto* arr_ty = el_ty->As<sem::Array>()) { Symbol load = LoadFunc(buf_ty, vec_ty, var_user);
// fn load_func(buf : buf_ty, offset : u32) -> array<T, N> { for (uint32_t i = 0; i < mat_ty->columns(); i++) {
// var arr : array<T, N>; auto* offset = b.Add("offset", u32(i * mat_ty->ColumnStride()));
// for (var i = 0u; i < array_count; i = i + 1) { values.emplace_back(b.Call(load, "buffer", offset));
// arr[i] = el_load_func(buf, offset + i * array_stride) }
// } } else if (auto* str = el_ty->As<sem::Struct>()) {
// return arr; for (auto* member : str->Members()) {
// } auto* offset = b.Add("offset", u32(member->Offset()));
auto load = LoadFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user); Symbol load = LoadFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
auto* arr = b.Var(b.Symbols().New("arr"), CreateASTTypeFor(ctx, arr_ty)); values.emplace_back(b.Call(load, "buffer", offset));
auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0_u)); }
auto* for_init = b.Decl(i);
auto* for_cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_ty->Count())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(arr, i);
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
auto* el_val = b.Call(load, "buffer", el_offset);
auto* for_loop =
b.For(for_init, for_cond, for_cont, b.Block(b.Assign(arr_el, el_val)));
b.Func(name, params, CreateASTTypeFor(ctx, arr_ty),
{
b.Decl(arr),
for_loop,
b.Return(arr),
});
} else {
ast::ExpressionList values;
if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol load = LoadFunc(buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset = b.Add("offset", u32(i * mat_ty->ColumnStride()));
values.emplace_back(b.Call(load, "buffer", offset));
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto* offset = b.Add("offset", u32(member->Offset()));
Symbol load = LoadFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
values.emplace_back(b.Call(load, "buffer", offset));
} }
b.Func(name, params, CreateASTTypeFor(ctx, el_ty),
{
b.Return(b.Construct(CreateASTTypeFor(ctx, el_ty), values)),
});
} }
b.Func(name, params, CreateASTTypeFor(ctx, el_ty), return name;
{ });
b.Return(b.Construct(CreateASTTypeFor(ctx, el_ty), values)),
});
}
return name;
});
} }
/// StoreFunc() returns a symbol to an intrinsic function that stores an /// StoreFunc() returns a symbol to an intrinsic function that stores an
/// element of type `el_ty` to a storage buffer of type `buf_ty`. /// element of type `el_ty` to a storage buffer of type `buf_ty`.
/// The function has the signature: /// The function has the signature:
/// `fn store(buf : buf_ty, offset : u32, value : el_ty)` /// `fn store(buf : ptr<SC, buf_ty, A>, offset : u32, value : el_ty)`
/// @param buf_ty the storage buffer type /// @param buf_ty the storage buffer type
/// @param el_ty the storage buffer element type /// @param el_ty the storage buffer element type
/// @param var_user the variable user /// @param var_user the variable user
@ -523,87 +521,95 @@ struct DecomposeMemoryAccess::State {
const sem::Type* el_ty, const sem::Type* el_ty,
const sem::VariableUser* var_user) { const sem::VariableUser* var_user) {
auto storage_class = var_user->Variable()->StorageClass(); auto storage_class = var_user->Variable()->StorageClass();
return utils::GetOrCreate(store_funcs, LoadStoreKey{storage_class, buf_ty, el_ty}, [&] { auto access = var_user->Variable()->Access();
auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty); return utils::GetOrCreate(
auto* el_ast_ty = CreateASTTypeFor(ctx, el_ty); store_funcs, LoadStoreKey{storage_class, access, buf_ty, el_ty}, [&] {
auto* disable_validation = ast::ParameterList params{
b.Disable(ast::DisabledValidation::kIgnoreConstructibleFunctionParameter); b.Param("buffer",
ast::ParameterList params{ b.ty.pointer(CreateASTTypeFor(ctx, buf_ty), storage_class, access),
// Note: The buffer parameter requires the StorageClass in {b.Disable(ast::DisabledValidation::kFunctionParameter)}),
// order for HLSL to emit this as a ByteAddressBuffer. b.Param("offset", b.ty.u32()),
b.Param("value", CreateASTTypeFor(ctx, el_ty)),
};
b.create<ast::Variable>(b.Sym("buffer"), storage_class, auto name = b.Sym();
var_user->Variable()->Access(), buf_ast_ty, true, false,
nullptr, ast::AttributeList{disable_validation}),
b.Param("offset", b.ty.u32()),
b.Param("value", el_ast_ty),
};
auto name = b.Sym(); if (auto* intrinsic = IntrinsicStoreFor(ctx.dst, storage_class, el_ty)) {
auto* func = b.create<ast::Function>(
name, params, b.ty.void_(), nullptr,
ast::AttributeList{
intrinsic,
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
},
ast::AttributeList{});
b.AST().AddFunction(func);
} else {
auto body = Switch(
el_ty, //
[&](const sem::Array* arr_ty) {
// fn store_func(buffer : buf_ty, offset : u32, value : el_ty) {
// var array = value; // No dynamic indexing on constant arrays
// for (var i = 0u; i < array_count; i = i + 1) {
// arr[i] = el_store_func(buffer, offset + i * array_stride,
// value[i])
// }
// return arr;
// }
auto* array = b.Var(b.Symbols().New("array"), nullptr, b.Expr("value"));
auto store =
StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user);
auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0_u));
auto* for_init = b.Decl(i);
auto* for_cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_ty->Count())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(array, i);
auto* el_offset =
b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
auto* store_stmt =
b.CallStmt(b.Call(store, "buffer", el_offset, arr_el));
auto* for_loop =
b.For(for_init, for_cond, for_cont, b.Block(store_stmt));
if (auto* intrinsic = IntrinsicStoreFor(ctx.dst, storage_class, el_ty)) { return ast::StatementList{b.Decl(array), for_loop};
auto* func = b.create<ast::Function>( },
name, params, b.ty.void_(), nullptr, [&](const sem::Matrix* mat_ty) {
ast::AttributeList{ auto* vec_ty = mat_ty->ColumnType();
intrinsic, Symbol store = StoreFunc(buf_ty, vec_ty, var_user);
b.Disable(ast::DisabledValidation::kFunctionHasNoBody), ast::StatementList stmts;
}, for (uint32_t i = 0; i < mat_ty->columns(); i++) {
ast::AttributeList{}); auto* offset = b.Add("offset", u32(i * mat_ty->ColumnStride()));
b.AST().AddFunction(func); auto* element = b.IndexAccessor("value", u32(i));
} else { auto* call = b.Call(store, "buffer", offset, element);
ast::StatementList body; stmts.emplace_back(b.CallStmt(call));
if (auto* arr_ty = el_ty->As<sem::Array>()) { }
// fn store_func(buf : buf_ty, offset : u32, value : el_ty) { return stmts;
// var array = value; // No dynamic indexing on constant arrays },
// for (var i = 0u; i < array_count; i = i + 1) { [&](const sem::Struct* str) {
// arr[i] = el_store_func(buf, offset + i * array_stride, ast::StatementList stmts;
// value[i]) for (auto* member : str->Members()) {
// } auto* offset = b.Add("offset", u32(member->Offset()));
// return arr; auto* element = b.MemberAccessor(
// } "value", ctx.Clone(member->Declaration()->symbol));
auto* array = b.Var(b.Symbols().New("array"), nullptr, b.Expr("value")); Symbol store =
auto store = StoreFunc(buf_ty, arr_ty->ElemType()->UnwrapRef(), var_user); StoreFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
auto* i = b.Var(b.Symbols().New("i"), nullptr, b.Expr(0_u)); auto* call = b.Call(store, "buffer", offset, element);
auto* for_init = b.Decl(i); stmts.emplace_back(b.CallStmt(call));
auto* for_cond = b.create<ast::BinaryExpression>( }
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_ty->Count()))); return stmts;
auto* for_cont = b.Assign(i, b.Add(i, 1_u)); });
auto* arr_el = b.IndexAccessor(array, i);
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
auto* store_stmt = b.CallStmt(b.Call(store, "buffer", el_offset, arr_el));
auto* for_loop = b.For(for_init, for_cond, for_cont, b.Block(store_stmt));
body = {b.Decl(array), for_loop}; b.Func(name, params, b.ty.void_(), body);
} else if (auto* mat_ty = el_ty->As<sem::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol store = StoreFunc(buf_ty, vec_ty, var_user);
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset = b.Add("offset", u32(i * mat_ty->ColumnStride()));
auto* access = b.IndexAccessor("value", u32(i));
auto* call = b.Call(store, "buffer", offset, access);
body.emplace_back(b.CallStmt(call));
}
} else if (auto* str = el_ty->As<sem::Struct>()) {
for (auto* member : str->Members()) {
auto* offset = b.Add("offset", u32(member->Offset()));
auto* access =
b.MemberAccessor("value", ctx.Clone(member->Declaration()->symbol));
Symbol store = StoreFunc(buf_ty, member->Type()->UnwrapRef(), var_user);
auto* call = b.Call(store, "buffer", offset, access);
body.emplace_back(b.CallStmt(call));
}
} }
b.Func(name, params, b.ty.void_(), body);
}
return name; return name;
}); });
} }
/// AtomicFunc() returns a symbol to an intrinsic function that performs an /// AtomicFunc() returns a symbol to an intrinsic function that performs an
/// atomic operation from a storage buffer of type `buf_ty`. The function has /// atomic operation from a storage buffer of type `buf_ty`. The function has
/// the signature: /// the signature:
// `fn atomic_op(buf : buf_ty, offset : u32, ...) -> T` // `fn atomic_op(buf : ptr<storage, buf_ty, A>, offset : u32, ...) -> T`
/// @param buf_ty the storage buffer type /// @param buf_ty the storage buffer type
/// @param el_ty the storage buffer element type /// @param el_ty the storage buffer element type
/// @param intrinsic the atomic intrinsic /// @param intrinsic the atomic intrinsic
@ -614,19 +620,15 @@ struct DecomposeMemoryAccess::State {
const sem::Builtin* intrinsic, const sem::Builtin* intrinsic,
const sem::VariableUser* var_user) { const sem::VariableUser* var_user) {
auto op = intrinsic->Type(); auto op = intrinsic->Type();
return utils::GetOrCreate(atomic_funcs, AtomicKey{buf_ty, el_ty, op}, [&] { auto access = var_user->Variable()->Access();
auto* buf_ast_ty = CreateASTTypeFor(ctx, buf_ty); return utils::GetOrCreate(atomic_funcs, AtomicKey{access, buf_ty, el_ty, op}, [&] {
auto* disable_validation =
b.Disable(ast::DisabledValidation::kIgnoreConstructibleFunctionParameter);
// The first parameter to all WGSL atomics is the expression to the // The first parameter to all WGSL atomics is the expression to the
// atomic. This is replaced with two parameters: the buffer and offset. // atomic. This is replaced with two parameters: the buffer and offset.
ast::ParameterList params = { ast::ParameterList params = {
// Note: The buffer parameter requires the kStorage StorageClass in b.Param("buffer",
// order for HLSL to emit this as a ByteAddressBuffer. b.ty.pointer(CreateASTTypeFor(ctx, buf_ty), ast::StorageClass::kStorage,
b.create<ast::Variable>(b.Sym("buffer"), ast::StorageClass::kStorage, access),
var_user->Variable()->Access(), buf_ast_ty, true, false, {b.Disable(ast::DisabledValidation::kFunctionParameter)}),
nullptr, ast::AttributeList{disable_validation}),
b.Param("offset", b.ty.u32()), b.Param("offset", b.ty.u32()),
}; };
@ -910,8 +912,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con
if (auto* builtin = call->Target()->As<sem::Builtin>()) { if (auto* builtin = call->Target()->As<sem::Builtin>()) {
if (builtin->Type() == sem::BuiltinType::kArrayLength) { if (builtin->Type() == sem::BuiltinType::kArrayLength) {
// arrayLength(X) // arrayLength(X)
// Don't convert X into a load, this builtin actually requires the // Don't convert X into a load, this builtin actually requires the real pointer.
// real pointer.
state.TakeAccess(call_expr->args[0]); state.TakeAccess(call_expr->args[0]);
continue; continue;
} }
@ -926,7 +927,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con
Symbol func = state.AtomicFunc(buf_ty, el_ty, builtin, Symbol func = state.AtomicFunc(buf_ty, el_ty, builtin,
access.var->As<sem::VariableUser>()); access.var->As<sem::VariableUser>());
ast::ExpressionList args{ctx.Clone(buf), offset}; ast::ExpressionList args{ctx.dst->AddressOf(ctx.Clone(buf)), offset};
for (size_t i = 1; i < call_expr->args.size(); i++) { for (size_t i = 1; i < call_expr->args.size(); i++) {
auto* arg = call_expr->args[i]; auto* arg = call_expr->args[i];
args.emplace_back(ctx.Clone(arg)); args.emplace_back(ctx.Clone(arg));
@ -948,26 +949,26 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con
} }
BufferAccess access = access_it->second; BufferAccess access = access_it->second;
ctx.Replace(expr, [=, &ctx, &state] { ctx.Replace(expr, [=, &ctx, &state] {
auto* buf = access.var->Declaration(); auto* buf = ctx.dst->AddressOf(ctx.CloneWithoutTransform(access.var->Declaration()));
auto* offset = access.offset->Build(ctx); auto* offset = access.offset->Build(ctx);
auto* buf_ty = access.var->Type()->UnwrapRef(); auto* buf_ty = access.var->Type()->UnwrapRef();
auto* el_ty = access.type->UnwrapRef(); auto* el_ty = access.type->UnwrapRef();
Symbol func = state.LoadFunc(buf_ty, el_ty, access.var->As<sem::VariableUser>()); Symbol func = state.LoadFunc(buf_ty, el_ty, access.var->As<sem::VariableUser>());
return ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset); return ctx.dst->Call(func, buf, offset);
}); });
} }
// And replace all storage and uniform buffer assignments with stores // And replace all storage and uniform buffer assignments with stores
for (auto store : state.stores) { for (auto store : state.stores) {
ctx.Replace(store.assignment, [=, &ctx, &state] { ctx.Replace(store.assignment, [=, &ctx, &state] {
auto* buf = store.target.var->Declaration(); auto* buf =
ctx.dst->AddressOf(ctx.CloneWithoutTransform((store.target.var->Declaration())));
auto* offset = store.target.offset->Build(ctx); auto* offset = store.target.offset->Build(ctx);
auto* buf_ty = store.target.var->Type()->UnwrapRef(); auto* buf_ty = store.target.var->Type()->UnwrapRef();
auto* el_ty = store.target.type->UnwrapRef(); auto* el_ty = store.target.type->UnwrapRef();
auto* value = store.assignment->rhs; auto* value = store.assignment->rhs;
Symbol func = state.StoreFunc(buf_ty, el_ty, store.target.var->As<sem::VariableUser>()); Symbol func = state.StoreFunc(buf_ty, el_ty, store.target.var->As<sem::VariableUser>());
auto* call = auto* call = ctx.dst->Call(func, buf, offset, ctx.Clone(value));
ctx.dst->Call(func, ctx.CloneWithoutTransform(buf), offset, ctx.Clone(value));
return ctx.dst->CallStmt(call); return ctx.dst->CallStmt(call);
}); });
} }

File diff suppressed because it is too large Load Diff

View File

@ -49,7 +49,8 @@ Output Manager::Run(const Program* program, const DataMap& data) const {
Output out; Output out;
for (const auto& transform : transforms_) { for (const auto& transform : transforms_) {
if (!transform->ShouldRun(in, data)) { if (!transform->ShouldRun(in, data)) {
TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " << transform->TypeInfo().name << std::endl); TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " << transform->TypeInfo().name
<< std::endl);
continue; continue;
} }
TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get())); TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get()));

View File

@ -2776,14 +2776,25 @@ bool GeneratorImpl::EmitFunction(const ast::Function* func) {
first = false; first = false;
auto const* type = v->Type(); auto const* type = v->Type();
auto storage_class = ast::StorageClass::kNone;
auto access = ast::Access::kUndefined;
if (auto* ptr = type->As<sem::Pointer>()) { if (auto* ptr = type->As<sem::Pointer>()) {
// Transform pointer parameters in to `inout` parameters.
// The WGSL spec is highly restrictive in what can be passed in pointer
// parameters, which allows for this transformation. See:
// https://gpuweb.github.io/gpuweb/wgsl/#function-restriction
out << "inout ";
type = ptr->StoreType(); type = ptr->StoreType();
switch (ptr->StorageClass()) {
case ast::StorageClass::kStorage:
case ast::StorageClass::kUniform:
// Not allowed by WGSL, but is used by certain transforms (e.g. DMA) to pass
// storage buffers and uniform buffers down into transform-generated
// functions. In this situation we want to generate the parameter without an
// 'inout', using the storage class and access from the pointer.
storage_class = ptr->StorageClass();
access = ptr->Access();
break;
default:
// Transform regular WGSL pointer parameters in to `inout` parameters.
out << "inout ";
}
} }
// Note: WGSL only allows for StorageClass::kNone on parameters, however // Note: WGSL only allows for StorageClass::kNone on parameters, however
@ -2792,7 +2803,7 @@ bool GeneratorImpl::EmitFunction(const ast::Function* func) {
// StorageClass::kStorage or StorageClass::kUniform. This is required to // StorageClass::kStorage or StorageClass::kUniform. This is required to
// correctly translate the parameter to a [RW]ByteAddressBuffer for // correctly translate the parameter to a [RW]ByteAddressBuffer for
// storage buffers and a uint4[N] for uniform buffers. // storage buffers and a uint4[N] for uniform buffers.
if (!EmitTypeAndName(out, type, v->StorageClass(), v->Access(), if (!EmitTypeAndName(out, type, storage_class, access,
builder_.Symbols().NameFor(v->Declaration()->symbol))) { builder_.Symbols().NameFor(v->Declaration()->symbol))) {
return false; return false;
} }