Mirror of https://github.com/encounter/dawn-cmake.git (synced 2025-12-16 00:17:03 +00:00)
Revert "msl: Use a struct for threadgroup memory arguments"
This reverts commit af8cd3b7f5.
Reason for revert: the original change breaks the roll into Dawn.
Original change's description:
> msl: Use a struct for threadgroup memory arguments
>
> MSL has a limit on the number of threadgroup memory arguments, so use
> a struct to support an arbitrary number of workgroup variables.
>
> Bug: tint:938
> Change-Id: I40e4a8d99bc4ae074010479a56e13e2e0acdded3
> Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/64380
> Kokoro: Kokoro <noreply+kokoro@google.com>
> Auto-Submit: James Price <jrprice@google.com>
> Reviewed-by: Ben Clayton <bclayton@google.com>
> Commit-Queue: James Price <jrprice@google.com>
TBR=bclayton@google.com,jrprice@google.com,noreply+kokoro@google.com,tint-scoped@luci-project-accounts.iam.gserviceaccount.com
Change-Id: I58a07c4ab7e92bda205e2bbbab41e0b347aeb1e8
No-Presubmit: true
No-Tree-Checks: true
No-Try: true
Bug: tint:938
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/65162
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Kokoro: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
Committed by: Tint LUCI CQ
Parent: af8cd3b7f5
Commit: 40ef4a8269
@@ -42,8 +42,6 @@ std::string DisableValidationDecoration::InternalName() const {
|
||||
return "disable_validation__ignore_constructible_function_parameter";
|
||||
case DisabledValidation::kIgnoreStrideDecoration:
|
||||
return "disable_validation__ignore_stride";
|
||||
case DisabledValidation::kIgnoreInvalidPointerArgument:
|
||||
return "disable_validation__ignore_invalid_pointer_argument";
|
||||
}
|
||||
return "<invalid>";
|
||||
}
|
||||
|
||||
@@ -43,10 +43,6 @@ enum class DisabledValidation {
|
||||
/// When applied to a member decoration, a stride decoration may be applied to
|
||||
/// non-array types.
|
||||
kIgnoreStrideDecoration,
|
||||
/// When applied to a pointer function parameter, the validator will not
|
||||
/// require a function call argument passed for that parameter to have a
|
||||
/// certain form.
|
||||
kIgnoreInvalidPointerArgument,
|
||||
};
|
||||
|
||||
/// An internal decoration used to tell the validator to ignore specific
|
||||
|
||||
@@ -2659,10 +2659,7 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_valid &&
|
||||
IsValidationEnabled(
|
||||
param->declaration->decorations(),
|
||||
ast::DisabledValidation::kIgnoreInvalidPointerArgument)) {
|
||||
if (!is_valid) {
|
||||
AddError(
|
||||
"expected an address-of expression of a variable identifier "
|
||||
"expression or a function parameter",
|
||||
|
||||
@@ -47,27 +47,6 @@ bool ContainsMatrix(const sem::Type* type) {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Clone any struct types that are contained in `ty` (including `ty` itself),
|
||||
// and add it to the global declarations now, so that they precede new global
|
||||
// declarations that need to reference them.
|
||||
void CloneStructTypes(const sem::Type* ty, CloneContext& ctx) {
|
||||
if (auto* str = ty->As<sem::Struct>()) {
|
||||
// Recurse into members.
|
||||
for (auto* member : str->Members()) {
|
||||
CloneStructTypes(member->Type(), ctx);
|
||||
}
|
||||
|
||||
// Clone the struct and add it to the global declaration list.
|
||||
// Remove the old declaration.
|
||||
auto* ast_str = str->Declaration();
|
||||
ctx.dst->AST().AddTypeDecl(ctx.Clone(const_cast<ast::Struct*>(ast_str)));
|
||||
ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
|
||||
} else if (auto* arr = ty->As<sem::Array>()) {
|
||||
CloneStructTypes(arr->ElemType(), ctx);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
|
||||
@@ -133,17 +112,6 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
// Map module-scope variables onto their function-scope replacement.
|
||||
std::unordered_map<const sem::Variable*, Symbol> var_to_symbol;
|
||||
|
||||
// We aggregate all workgroup variables into a struct to avoid hitting MSL's
|
||||
// limit for threadgroup memory arguments.
|
||||
Symbol workgroup_parameter_symbol;
|
||||
ast::StructMemberList workgroup_parameter_members;
|
||||
auto workgroup_param = [&]() {
|
||||
if (!workgroup_parameter_symbol.IsValid()) {
|
||||
workgroup_parameter_symbol = ctx.dst->Sym();
|
||||
}
|
||||
return workgroup_parameter_symbol;
|
||||
};
|
||||
|
||||
for (auto* var : func_sem->ReferencedModuleVariables()) {
|
||||
if (var->StorageClass() != ast::StorageClass::kPrivate &&
|
||||
var->StorageClass() != ast::StorageClass::kWorkgroup &&
|
||||
@@ -154,16 +122,13 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
// This is the symbol for the variable that replaces the module-scope var.
|
||||
auto new_var_symbol = ctx.dst->Sym();
|
||||
|
||||
// Helper to create an AST node for the store type of the variable.
|
||||
auto store_type = [&]() {
|
||||
return CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
|
||||
};
|
||||
auto* store_type = CreateASTTypeFor(ctx, var->Type()->UnwrapRef());
|
||||
|
||||
// Track whether the new variable is a pointer or not.
|
||||
bool is_pointer = false;
|
||||
|
||||
if (is_entry_point) {
|
||||
if (var->Type()->UnwrapRef()->is_handle()) {
|
||||
if (store_type->is_handle()) {
|
||||
// For a texture or sampler variable, redeclare it as an entry point
|
||||
// parameter. Disable entry point parameter validation.
|
||||
auto* disable_validation =
|
||||
@@ -171,7 +136,7 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
|
||||
auto decos = ctx.Clone(var->Declaration()->decorations());
|
||||
decos.push_back(disable_validation);
|
||||
auto* param = ctx.dst->Param(new_var_symbol, store_type(), decos);
|
||||
auto* param = ctx.dst->Param(new_var_symbol, store_type, decos);
|
||||
ctx.InsertFront(func_ast->params(), param);
|
||||
} else {
|
||||
if (var->StorageClass() == ast::StorageClass::kWorkgroup &&
|
||||
@@ -179,24 +144,15 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
// Due to a bug in the MSL compiler, we use a threadgroup memory
|
||||
// argument for any workgroup allocation that contains a matrix.
|
||||
// See crbug.com/tint/938.
|
||||
// TODO(jrprice): Do this for all other workgroup variables too.
|
||||
|
||||
// Create a member in the workgroup parameter struct.
|
||||
auto member = ctx.Clone(var->Declaration()->symbol());
|
||||
workgroup_parameter_members.push_back(
|
||||
ctx.dst->Member(member, store_type()));
|
||||
CloneStructTypes(var->Type()->UnwrapRef(), ctx);
|
||||
|
||||
// Create a function-scope variable that is a pointer to the member.
|
||||
auto* member_ptr = ctx.dst->AddressOf(ctx.dst->MemberAccessor(
|
||||
ctx.dst->Deref(workgroup_param()), member));
|
||||
auto* local_var =
|
||||
ctx.dst->Const(new_var_symbol,
|
||||
ctx.dst->ty.pointer(
|
||||
store_type(), ast::StorageClass::kWorkgroup),
|
||||
member_ptr);
|
||||
ctx.InsertFront(func_ast->body()->statements(),
|
||||
ctx.dst->Decl(local_var));
|
||||
auto* disable_validation =
|
||||
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
ctx.dst->ID(),
|
||||
ast::DisabledValidation::kEntryPointParameter);
|
||||
auto* param_type =
|
||||
ctx.dst->ty.pointer(store_type, var->StorageClass());
|
||||
auto* param = ctx.dst->Param(new_var_symbol, param_type,
|
||||
{disable_validation});
|
||||
ctx.InsertFront(func_ast->params(), param);
|
||||
is_pointer = true;
|
||||
} else {
|
||||
// For any other private or workgroup variable, redeclare it at
|
||||
@@ -208,7 +164,7 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
ast::DisabledValidation::kIgnoreStorageClass);
|
||||
auto* constructor = ctx.Clone(var->Declaration()->constructor());
|
||||
auto* local_var = ctx.dst->Var(
|
||||
new_var_symbol, store_type(), var->StorageClass(), constructor,
|
||||
new_var_symbol, store_type, var->StorageClass(), constructor,
|
||||
ast::DecorationList{disable_validation});
|
||||
ctx.InsertFront(func_ast->body()->statements(),
|
||||
ctx.dst->Decl(local_var));
|
||||
@@ -217,21 +173,13 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
} else {
|
||||
// For a regular function, redeclare the variable as a parameter.
|
||||
// Use a pointer for non-handle types.
|
||||
auto* param_type = store_type();
|
||||
ast::DecorationList attributes;
|
||||
if (!param_type->is_handle()) {
|
||||
auto* param_type = store_type;
|
||||
if (!store_type->is_handle()) {
|
||||
param_type = ctx.dst->ty.pointer(param_type, var->StorageClass());
|
||||
is_pointer = true;
|
||||
|
||||
// Disable validation of arguments passed to this pointer parameter,
|
||||
// as we will sometimes pass pointers to struct members.
|
||||
attributes.push_back(
|
||||
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
ctx.dst->ID(),
|
||||
ast::DisabledValidation::kIgnoreInvalidPointerArgument));
|
||||
}
|
||||
ctx.InsertBack(func_ast->params(),
|
||||
ctx.dst->Param(new_var_symbol, param_type, attributes));
|
||||
ctx.dst->Param(new_var_symbol, param_type));
|
||||
}
|
||||
|
||||
// Replace all uses of the module-scope variable.
|
||||
@@ -258,22 +206,6 @@ void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx,
|
||||
var_to_symbol[var] = new_var_symbol;
|
||||
}
|
||||
|
||||
if (!workgroup_parameter_members.empty()) {
|
||||
// Create the workgroup memory parameter.
|
||||
// The parameter is a struct that contains members for each workgroup
|
||||
// variable.
|
||||
auto* str = ctx.dst->Structure(ctx.dst->Sym(),
|
||||
std::move(workgroup_parameter_members));
|
||||
auto* param_type = ctx.dst->ty.pointer(ctx.dst->ty.Of(str),
|
||||
ast::StorageClass::kWorkgroup);
|
||||
auto* disable_validation =
|
||||
ctx.dst->ASTNodes().Create<ast::DisableValidationDecoration>(
|
||||
ctx.dst->ID(), ast::DisabledValidation::kEntryPointParameter);
|
||||
auto* param =
|
||||
ctx.dst->Param(workgroup_param(), param_type, {disable_validation});
|
||||
ctx.InsertFront(func_ast->params(), param);
|
||||
}
|
||||
|
||||
// Pass the variables as pointers to any functions that need them.
|
||||
for (auto* call : calls_to_replace[func_ast]) {
|
||||
auto* target = ctx.src->AST().Functions().Find(call->func()->symbol());
|
||||
|
||||
@@ -78,12 +78,12 @@ fn main() {
|
||||
fn no_uses() {
|
||||
}
|
||||
|
||||
fn bar(a : f32, b : f32, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_1 : ptr<workgroup, f32>) {
|
||||
fn bar(a : f32, b : f32, tint_symbol : ptr<private, f32>, tint_symbol_1 : ptr<workgroup, f32>) {
|
||||
*(tint_symbol) = a;
|
||||
*(tint_symbol_1) = b;
|
||||
}
|
||||
|
||||
fn foo(a : f32, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_2 : ptr<private, f32>, [[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol_3 : ptr<workgroup, f32>) {
|
||||
fn foo(a : f32, tint_symbol_2 : ptr<private, f32>, tint_symbol_3 : ptr<workgroup, f32>) {
|
||||
let b : f32 = 2.0;
|
||||
bar(a, b, tint_symbol_2, tint_symbol_3);
|
||||
no_uses();
|
||||
@@ -181,7 +181,7 @@ fn bar(p : ptr<private, f32>) {
|
||||
*(p) = 0.0;
|
||||
}
|
||||
|
||||
fn foo([[internal(disable_validation__ignore_invalid_pointer_argument)]] tint_symbol : ptr<private, f32>) {
|
||||
fn foo(tint_symbol : ptr<private, f32>) {
|
||||
bar(tint_symbol);
|
||||
}
|
||||
|
||||
@@ -340,13 +340,8 @@ fn main() {
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
struct tint_symbol_2 {
|
||||
m : mat2x2<f32>;
|
||||
};
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
|
||||
let tint_symbol : ptr<workgroup, mat2x2<f32>> = &((*(tint_symbol_1)).m);
|
||||
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol : ptr<workgroup, mat2x2<f32>>) {
|
||||
let x = *(tint_symbol);
|
||||
}
|
||||
)";
|
||||
@@ -381,13 +376,8 @@ struct S2 {
|
||||
s : S1;
|
||||
};
|
||||
|
||||
struct tint_symbol_2 {
|
||||
m : array<S2, 4u>;
|
||||
};
|
||||
|
||||
[[stage(compute), workgroup_size(1)]]
|
||||
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol_1 : ptr<workgroup, tint_symbol_2>) {
|
||||
let tint_symbol : ptr<workgroup, array<S2, 4u>> = &((*(tint_symbol_1)).m);
|
||||
fn main([[internal(disable_validation__entry_point_parameter)]] tint_symbol : ptr<workgroup, array<S2, 4u>>) {
|
||||
let x = *(tint_symbol);
|
||||
}
|
||||
)";
|
||||
|
||||
@@ -142,10 +142,6 @@ TEST_F(MslGeneratorImplTest, WorkgroupMatrix) {
|
||||
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
struct tint_symbol_3 {
|
||||
float2x2 m;
|
||||
};
|
||||
|
||||
void comp_main_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol) {
|
||||
{
|
||||
*(tint_symbol) = float2x2();
|
||||
@@ -154,8 +150,8 @@ void comp_main_inner(uint local_invocation_index, threadgroup float2x2* const ti
|
||||
float2x2 const x = *(tint_symbol);
|
||||
}
|
||||
|
||||
kernel void comp_main(threadgroup tint_symbol_3* tint_symbol_2 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
|
||||
comp_main_inner(local_invocation_index, &((*(tint_symbol_2)).m));
|
||||
kernel void comp_main(threadgroup float2x2* tint_symbol_1 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
|
||||
comp_main_inner(local_invocation_index, tint_symbol_1);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -182,9 +178,6 @@ using namespace metal;
|
||||
struct tint_array_wrapper {
|
||||
float2x2 arr[4];
|
||||
};
|
||||
struct tint_symbol_3 {
|
||||
tint_array_wrapper m;
|
||||
};
|
||||
|
||||
void comp_main_inner(uint local_invocation_index, threadgroup tint_array_wrapper* const tint_symbol) {
|
||||
for(uint idx = local_invocation_index; (idx < 4u); idx = (idx + 1u)) {
|
||||
@@ -195,8 +188,8 @@ void comp_main_inner(uint local_invocation_index, threadgroup tint_array_wrapper
|
||||
tint_array_wrapper const x = *(tint_symbol);
|
||||
}
|
||||
|
||||
kernel void comp_main(threadgroup tint_symbol_3* tint_symbol_2 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
|
||||
comp_main_inner(local_invocation_index, &((*(tint_symbol_2)).m));
|
||||
kernel void comp_main(threadgroup tint_array_wrapper* tint_symbol_1 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
|
||||
comp_main_inner(local_invocation_index, tint_symbol_1);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -234,9 +227,6 @@ struct S1 {
|
||||
struct S2 {
|
||||
S1 s;
|
||||
};
|
||||
struct tint_symbol_4 {
|
||||
S2 s;
|
||||
};
|
||||
|
||||
void comp_main_inner(uint local_invocation_index, threadgroup S2* const tint_symbol_1) {
|
||||
{
|
||||
@@ -247,8 +237,8 @@ void comp_main_inner(uint local_invocation_index, threadgroup S2* const tint_sym
|
||||
S2 const x = *(tint_symbol_1);
|
||||
}
|
||||
|
||||
kernel void comp_main(threadgroup tint_symbol_4* tint_symbol_3 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
|
||||
comp_main_inner(local_invocation_index, &((*(tint_symbol_3)).s));
|
||||
kernel void comp_main(threadgroup S2* tint_symbol_2 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
|
||||
comp_main_inner(local_invocation_index, tint_symbol_2);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -301,22 +291,6 @@ TEST_F(MslGeneratorImplTest, WorkgroupMatrix_Multiples) {
|
||||
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
struct tint_symbol_7 {
|
||||
float2x2 m1;
|
||||
float2x3 m2;
|
||||
float2x4 m3;
|
||||
};
|
||||
struct tint_symbol_15 {
|
||||
float3x2 m4;
|
||||
float3x3 m5;
|
||||
float3x4 m6;
|
||||
};
|
||||
struct tint_symbol_23 {
|
||||
float4x2 m7;
|
||||
float4x3 m8;
|
||||
float4x4 m9;
|
||||
};
|
||||
|
||||
void main1_inner(uint local_invocation_index, threadgroup float2x2* const tint_symbol, threadgroup float2x3* const tint_symbol_1, threadgroup float2x4* const tint_symbol_2) {
|
||||
{
|
||||
*(tint_symbol) = float2x2();
|
||||
@@ -329,42 +303,42 @@ void main1_inner(uint local_invocation_index, threadgroup float2x2* const tint_s
|
||||
float2x4 const a3 = *(tint_symbol_2);
|
||||
}
|
||||
|
||||
kernel void main1(threadgroup tint_symbol_7* tint_symbol_4 [[threadgroup(0)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
|
||||
main1_inner(local_invocation_index, &((*(tint_symbol_4)).m1), &((*(tint_symbol_4)).m2), &((*(tint_symbol_4)).m3));
|
||||
kernel void main1(threadgroup float2x2* tint_symbol_3 [[threadgroup(0)]], threadgroup float2x3* tint_symbol_4 [[threadgroup(1)]], threadgroup float2x4* tint_symbol_5 [[threadgroup(2)]], uint local_invocation_index [[thread_index_in_threadgroup]]) {
|
||||
main1_inner(local_invocation_index, tint_symbol_3, tint_symbol_4, tint_symbol_5);
|
||||
return;
|
||||
}
|
||||
|
||||
void main2_inner(uint local_invocation_index_1, threadgroup float3x2* const tint_symbol_8, threadgroup float3x3* const tint_symbol_9, threadgroup float3x4* const tint_symbol_10) {
|
||||
void main2_inner(uint local_invocation_index_1, threadgroup float3x2* const tint_symbol_6, threadgroup float3x3* const tint_symbol_7, threadgroup float3x4* const tint_symbol_8) {
|
||||
{
|
||||
*(tint_symbol_8) = float3x2();
|
||||
*(tint_symbol_9) = float3x3();
|
||||
*(tint_symbol_10) = float3x4();
|
||||
*(tint_symbol_6) = float3x2();
|
||||
*(tint_symbol_7) = float3x3();
|
||||
*(tint_symbol_8) = float3x4();
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float3x2 const a1 = *(tint_symbol_8);
|
||||
float3x3 const a2 = *(tint_symbol_9);
|
||||
float3x4 const a3 = *(tint_symbol_10);
|
||||
float3x2 const a1 = *(tint_symbol_6);
|
||||
float3x3 const a2 = *(tint_symbol_7);
|
||||
float3x4 const a3 = *(tint_symbol_8);
|
||||
}
|
||||
|
||||
kernel void main2(threadgroup tint_symbol_15* tint_symbol_12 [[threadgroup(0)]], uint local_invocation_index_1 [[thread_index_in_threadgroup]]) {
|
||||
main2_inner(local_invocation_index_1, &((*(tint_symbol_12)).m4), &((*(tint_symbol_12)).m5), &((*(tint_symbol_12)).m6));
|
||||
kernel void main2(threadgroup float3x2* tint_symbol_9 [[threadgroup(0)]], threadgroup float3x3* tint_symbol_10 [[threadgroup(1)]], threadgroup float3x4* tint_symbol_11 [[threadgroup(2)]], uint local_invocation_index_1 [[thread_index_in_threadgroup]]) {
|
||||
main2_inner(local_invocation_index_1, tint_symbol_9, tint_symbol_10, tint_symbol_11);
|
||||
return;
|
||||
}
|
||||
|
||||
void main3_inner(uint local_invocation_index_2, threadgroup float4x2* const tint_symbol_16, threadgroup float4x3* const tint_symbol_17, threadgroup float4x4* const tint_symbol_18) {
|
||||
void main3_inner(uint local_invocation_index_2, threadgroup float4x2* const tint_symbol_12, threadgroup float4x3* const tint_symbol_13, threadgroup float4x4* const tint_symbol_14) {
|
||||
{
|
||||
*(tint_symbol_16) = float4x2();
|
||||
*(tint_symbol_17) = float4x3();
|
||||
*(tint_symbol_18) = float4x4();
|
||||
*(tint_symbol_12) = float4x2();
|
||||
*(tint_symbol_13) = float4x3();
|
||||
*(tint_symbol_14) = float4x4();
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
float4x2 const a1 = *(tint_symbol_16);
|
||||
float4x3 const a2 = *(tint_symbol_17);
|
||||
float4x4 const a3 = *(tint_symbol_18);
|
||||
float4x2 const a1 = *(tint_symbol_12);
|
||||
float4x3 const a2 = *(tint_symbol_13);
|
||||
float4x4 const a3 = *(tint_symbol_14);
|
||||
}
|
||||
|
||||
kernel void main3(threadgroup tint_symbol_23* tint_symbol_20 [[threadgroup(0)]], uint local_invocation_index_2 [[thread_index_in_threadgroup]]) {
|
||||
main3_inner(local_invocation_index_2, &((*(tint_symbol_20)).m7), &((*(tint_symbol_20)).m8), &((*(tint_symbol_20)).m9));
|
||||
kernel void main3(threadgroup float4x2* tint_symbol_15 [[threadgroup(0)]], threadgroup float4x3* tint_symbol_16 [[threadgroup(1)]], threadgroup float4x4* tint_symbol_17 [[threadgroup(2)]], uint local_invocation_index_2 [[thread_index_in_threadgroup]]) {
|
||||
main3_inner(local_invocation_index_2, tint_symbol_15, tint_symbol_16, tint_symbol_17);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -379,12 +353,18 @@ kernel void main4_no_usages() {
|
||||
ASSERT_TRUE(allocations.count("main2"));
|
||||
ASSERT_TRUE(allocations.count("main3"));
|
||||
EXPECT_EQ(allocations.count("main4_no_usages"), 0u);
|
||||
ASSERT_EQ(allocations["main1"].size(), 1u);
|
||||
EXPECT_EQ(allocations["main1"][0], 20u * sizeof(float));
|
||||
ASSERT_EQ(allocations["main2"].size(), 1u);
|
||||
EXPECT_EQ(allocations["main2"][0], 32u * sizeof(float));
|
||||
ASSERT_EQ(allocations["main3"].size(), 1u);
|
||||
EXPECT_EQ(allocations["main3"][0], 40u * sizeof(float));
|
||||
ASSERT_EQ(allocations["main1"].size(), 3u);
|
||||
EXPECT_EQ(allocations["main1"][0], 2u * 2u * sizeof(float));
|
||||
EXPECT_EQ(allocations["main1"][1], 2u * 4u * sizeof(float));
|
||||
EXPECT_EQ(allocations["main1"][2], 2u * 4u * sizeof(float));
|
||||
ASSERT_EQ(allocations["main2"].size(), 3u);
|
||||
EXPECT_EQ(allocations["main2"][0], 3u * 2u * sizeof(float));
|
||||
EXPECT_EQ(allocations["main2"][1], 3u * 4u * sizeof(float));
|
||||
EXPECT_EQ(allocations["main2"][2], 3u * 4u * sizeof(float));
|
||||
ASSERT_EQ(allocations["main3"].size(), 3u);
|
||||
EXPECT_EQ(allocations["main3"][0], 4u * 2u * sizeof(float));
|
||||
EXPECT_EQ(allocations["main3"][1], 4u * 4u * sizeof(float));
|
||||
EXPECT_EQ(allocations["main3"][2], 4u * 4u * sizeof(float));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Reference in New Issue
Block a user