diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc index 33c8449522..1b4cede1ba 100644 --- a/src/inspector/inspector.cc +++ b/src/inspector/inspector.cc @@ -317,16 +317,10 @@ uint32_t Inspector::GetStorageSize(const std::string& entry_point) { size_t size = 0; auto* func_sem = program_->Sem().Get(func); for (auto& ruv : func_sem->TransitivelyReferencedUniformVariables()) { - const sem::Struct* s = ruv.first->Type()->UnwrapRef()->As(); - if (s) { - size += s->Size(); - } + size += ruv.first->Type()->UnwrapRef()->Size(); } for (auto& rsv : func_sem->TransitivelyReferencedStorageBufferVariables()) { - const sem::Struct* s = rsv.first->Type()->UnwrapRef()->As(); - if (s) { - size += s->Size(); - } + size += rsv.first->Type()->UnwrapRef()->Size(); } if (static_cast(size) > @@ -377,17 +371,18 @@ std::vector Inspector::GetUniformBufferResourceBindings( auto binding_info = ruv.second; auto* unwrapped_type = var->Type()->UnwrapRef(); - auto* str = unwrapped_type->As(); - if (str == nullptr) { - continue; - } ResourceBinding entry; entry.resource_type = ResourceBinding::ResourceType::kUniformBuffer; entry.bind_group = binding_info.group->value; entry.binding = binding_info.binding->value; - entry.size = str->Size(); - entry.size_no_padding = str->SizeNoPadding(); + entry.size = unwrapped_type->Size(); + entry.size_no_padding = entry.size; + if (auto* str = unwrapped_type->As()) { + entry.size_no_padding = str->SizeNoPadding(); + } else { + entry.size_no_padding = entry.size; + } result.push_back(entry); } @@ -667,10 +662,7 @@ std::vector Inspector::GetStorageBufferResourceBindingsImpl( continue; } - auto* str = var->Type()->UnwrapRef()->As(); - if (!str) { - continue; - } + auto* unwrapped_type = var->Type()->UnwrapRef(); ResourceBinding entry; entry.resource_type = @@ -678,8 +670,12 @@ std::vector Inspector::GetStorageBufferResourceBindingsImpl( : ResourceBinding::ResourceType::kStorageBuffer; entry.bind_group = binding_info.group->value; entry.binding = binding_info.binding->value; - entry.size = str->Size(); - entry.size_no_padding = str->SizeNoPadding(); + entry.size = unwrapped_type->Size(); + if (auto* str = unwrapped_type->As()) { + entry.size_no_padding = str->SizeNoPadding(); + } else { + entry.size_no_padding = entry.size; + } result.push_back(entry); } diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc index 1c2d6da978..1d237f80a0 100644 --- a/src/inspector/inspector_test.cc +++ b/src/inspector/inspector_test.cc @@ -1218,7 +1218,24 @@ TEST_F(InspectorGetStorageSizeTest, Empty) { EXPECT_EQ(0u, inspector.GetStorageSize("ep_func")); } -TEST_F(InspectorGetStorageSizeTest, Simple) { +TEST_F(InspectorGetStorageSizeTest, Simple_NonStruct) { + AddUniformBuffer("ub_var", ty.i32(), 0, 0); + AddStorageBuffer("sb_var", ty.i32(), ast::Access::kReadWrite, 1, 0); + AddStorageBuffer("rosb_var", ty.i32(), ast::Access::kRead, 1, 1); + Func("ep_func", {}, ty.void_(), + { + Decl(Const("ub", nullptr, Expr("ub_var"))), + Decl(Const("sb", nullptr, Expr("sb_var"))), + Decl(Const("rosb", nullptr, Expr("rosb_var"))), + }, + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)}); + + Inspector& inspector = Build(); + + EXPECT_EQ(12u, inspector.GetStorageSize("ep_func")); +} + +TEST_F(InspectorGetStorageSizeTest, Simple_Struct) { auto* ub_struct_type = MakeUniformBufferType("ub_type", {ty.i32(), ty.i32()}); AddUniformBuffer("ub_var", ty.Of(ub_struct_type), 0, 0); MakeStructVariableReferenceBodyFunction("ub_func", "ub_var", {{0, ty.i32()}}); @@ -1243,6 +1260,33 @@ TEST_F(InspectorGetStorageSizeTest, Simple) { EXPECT_EQ(16u, inspector.GetStorageSize("ep_func")); } +TEST_F(InspectorGetStorageSizeTest, NonStructVec3) { + AddUniformBuffer("ub_var", ty.vec3(), 0, 0); + Func("ep_func", {}, ty.void_(), + { + Decl(Const("ub", nullptr, Expr("ub_var"))), + }, + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)}); + + Inspector& inspector = Build(); + + EXPECT_EQ(12u, inspector.GetStorageSize("ep_func")); +} + +TEST_F(InspectorGetStorageSizeTest, StructVec3) { + auto* ub_struct_type = MakeUniformBufferType("ub_type", {ty.vec3()}); + AddUniformBuffer("ub_var", ty.Of(ub_struct_type), 0, 0); + Func("ep_func", {}, ty.void_(), + { + Decl(Const("ub", nullptr, Expr("ub_var"))), + }, + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)}); + + Inspector& inspector = Build(); + + EXPECT_EQ(16u, inspector.GetStorageSize("ep_func")); +} + TEST_F(InspectorGetResourceBindingsTest, Empty) { MakeCallerBodyFunction("ep_func", {}, ast::DecorationList{ @@ -1381,7 +1425,30 @@ TEST_F(InspectorGetUniformBufferResourceBindingsTest, NonEntryPointFunc) { EXPECT_TRUE(error.find("not an entry point") != std::string::npos); } -TEST_F(InspectorGetUniformBufferResourceBindingsTest, Simple) { +TEST_F(InspectorGetUniformBufferResourceBindingsTest, Simple_NonStruct) { + AddUniformBuffer("foo_ub", ty.i32(), 0, 0); + MakePlainGlobalReferenceBodyFunction("ub_func", "foo_ub", ty.i32(), {}); + + MakeCallerBodyFunction("ep_func", {"ub_func"}, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + Inspector& inspector = Build(); + + auto result = inspector.GetUniformBufferResourceBindings("ep_func"); + ASSERT_FALSE(inspector.has_error()) << inspector.error(); + ASSERT_EQ(1u, result.size()); + + EXPECT_EQ(ResourceBinding::ResourceType::kUniformBuffer, + result[0].resource_type); + EXPECT_EQ(0u, result[0].bind_group); + EXPECT_EQ(0u, result[0].binding); + EXPECT_EQ(4u, result[0].size); + EXPECT_EQ(4u, result[0].size_no_padding); +} + +TEST_F(InspectorGetUniformBufferResourceBindingsTest, Simple_Struct) { auto* foo_struct_type = MakeUniformBufferType("foo_type", {ty.i32()}); AddUniformBuffer("foo_ub", ty.Of(foo_struct_type), 0, 0); @@ -1459,6 +1526,29 @@ TEST_F(InspectorGetUniformBufferResourceBindingsTest, ContainingPadding) { EXPECT_EQ(12u, result[0].size_no_padding); } +TEST_F(InspectorGetUniformBufferResourceBindingsTest, NonStructVec3) { + AddUniformBuffer("foo_ub", ty.vec3(), 0, 0); + MakePlainGlobalReferenceBodyFunction("ub_func", "foo_ub", ty.vec3(), {}); + + MakeCallerBodyFunction("ep_func", {"ub_func"}, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + Inspector& inspector = Build(); + + auto result = inspector.GetUniformBufferResourceBindings("ep_func"); + ASSERT_FALSE(inspector.has_error()) << inspector.error(); + ASSERT_EQ(1u, result.size()); + + EXPECT_EQ(ResourceBinding::ResourceType::kUniformBuffer, + result[0].resource_type); + EXPECT_EQ(0u, result[0].bind_group); + EXPECT_EQ(0u, result[0].binding); + EXPECT_EQ(12u, result[0].size); + EXPECT_EQ(12u, result[0].size_no_padding); +} + TEST_F(InspectorGetUniformBufferResourceBindingsTest, MultipleUniformBuffers) { auto* ub_struct_type = MakeUniformBufferType("ub_type", {ty.i32(), ty.u32(), ty.f32()}); @@ -1546,7 +1636,30 @@ TEST_F(InspectorGetUniformBufferResourceBindingsTest, ContainingArray) { EXPECT_EQ(80u, result[0].size_no_padding); } -TEST_F(InspectorGetStorageBufferResourceBindingsTest, Simple) { +TEST_F(InspectorGetStorageBufferResourceBindingsTest, Simple_NonStruct) { + AddStorageBuffer("foo_sb", ty.i32(), ast::Access::kReadWrite, 0, 0); + MakePlainGlobalReferenceBodyFunction("sb_func", "foo_sb", ty.i32(), {}); + + MakeCallerBodyFunction("ep_func", {"sb_func"}, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + Inspector& inspector = Build(); + + auto result = inspector.GetStorageBufferResourceBindings("ep_func"); + ASSERT_FALSE(inspector.has_error()) << inspector.error(); + ASSERT_EQ(1u, result.size()); + + EXPECT_EQ(ResourceBinding::ResourceType::kStorageBuffer, + result[0].resource_type); + EXPECT_EQ(0u, result[0].bind_group); + EXPECT_EQ(0u, result[0].binding); + EXPECT_EQ(4u, result[0].size); + EXPECT_EQ(4u, result[0].size_no_padding); +} + +TEST_F(InspectorGetStorageBufferResourceBindingsTest, Simple_Struct) { auto foo_struct_type = MakeStorageBufferTypes("foo_type", {ty.i32()}); AddStorageBuffer("foo_sb", foo_struct_type(), ast::Access::kReadWrite, 0, 0); @@ -1743,6 +1856,29 @@ TEST_F(InspectorGetStorageBufferResourceBindingsTest, ContainingPadding) { EXPECT_EQ(12u, result[0].size_no_padding); } +TEST_F(InspectorGetStorageBufferResourceBindingsTest, NonStructVec3) { + AddStorageBuffer("foo_ub", ty.vec3(), ast::Access::kReadWrite, 0, 0); + MakePlainGlobalReferenceBodyFunction("ub_func", "foo_ub", ty.vec3(), {}); + + MakeCallerBodyFunction("ep_func", {"ub_func"}, + ast::DecorationList{ + Stage(ast::PipelineStage::kFragment), + }); + + Inspector& inspector = Build(); + + auto result = inspector.GetStorageBufferResourceBindings("ep_func"); + ASSERT_FALSE(inspector.has_error()) << inspector.error(); + ASSERT_EQ(1u, result.size()); + + EXPECT_EQ(ResourceBinding::ResourceType::kStorageBuffer, + result[0].resource_type); + EXPECT_EQ(0u, result[0].bind_group); + EXPECT_EQ(0u, result[0].binding); + EXPECT_EQ(12u, result[0].size); + EXPECT_EQ(12u, result[0].size_no_padding); +} + TEST_F(InspectorGetStorageBufferResourceBindingsTest, SkipReadOnly) { auto foo_struct_type = MakeStorageBufferTypes("foo_type", {ty.i32()}); AddStorageBuffer("foo_sb", foo_struct_type(), ast::Access::kRead, 0, 0); diff --git a/test/buffer/storage/types/array.wgsl.expected.msl b/test/buffer/storage/types/array.wgsl.expected.msl index a95963fecf..685b94a589 100644 --- a/test/buffer/storage/types/array.wgsl.expected.msl +++ b/test/buffer/storage/types/array.wgsl.expected.msl @@ -5,7 +5,7 @@ struct tint_array_wrapper { /* 0x0000 */ float arr[4]; }; -kernel void tint_symbol(device tint_array_wrapper* tint_symbol_1 [[buffer(1)]], const device tint_array_wrapper* tint_symbol_2 [[buffer(0)]]) { +kernel void tint_symbol(device tint_array_wrapper* tint_symbol_1 [[buffer(0)]], const device tint_array_wrapper* tint_symbol_2 [[buffer(1)]]) { *(tint_symbol_1) = *(tint_symbol_2); return; } diff --git a/test/buffer/storage/types/f32.wgsl.expected.msl b/test/buffer/storage/types/f32.wgsl.expected.msl index 033b5e2feb..c4b17538be 100644 --- a/test/buffer/storage/types/f32.wgsl.expected.msl +++ b/test/buffer/storage/types/f32.wgsl.expected.msl @@ -1,7 +1,7 @@ #include using namespace metal; -kernel void tint_symbol(device float* tint_symbol_1 [[buffer(1)]], const device float* tint_symbol_2 [[buffer(0)]]) { +kernel void tint_symbol(device float* tint_symbol_1 [[buffer(0)]], const device float* tint_symbol_2 [[buffer(1)]]) { *(tint_symbol_1) = *(tint_symbol_2); return; } diff --git a/test/buffer/storage/types/i32.wgsl.expected.msl b/test/buffer/storage/types/i32.wgsl.expected.msl index 1de9e037ca..d801b7828a 100644 --- a/test/buffer/storage/types/i32.wgsl.expected.msl +++ b/test/buffer/storage/types/i32.wgsl.expected.msl @@ -1,7 +1,7 @@ #include using namespace metal; -kernel void tint_symbol(device int* tint_symbol_1 [[buffer(1)]], const device int* tint_symbol_2 [[buffer(0)]]) { +kernel void tint_symbol(device int* tint_symbol_1 [[buffer(0)]], const device int* tint_symbol_2 [[buffer(1)]]) { *(tint_symbol_1) = *(tint_symbol_2); return; } diff --git a/test/buffer/storage/types/mat2x2.wgsl.expected.msl b/test/buffer/storage/types/mat2x2.wgsl.expected.msl index c9ffff792a..fbc66bf298 100644 --- a/test/buffer/storage/types/mat2x2.wgsl.expected.msl +++ b/test/buffer/storage/types/mat2x2.wgsl.expected.msl @@ -1,7 +1,7 @@ #include using namespace metal; -kernel void tint_symbol(device float2x2* tint_symbol_1 [[buffer(1)]], const device float2x2* tint_symbol_2 [[buffer(0)]]) { +kernel void tint_symbol(device float2x2* tint_symbol_1 [[buffer(0)]], const device float2x2* tint_symbol_2 [[buffer(1)]]) { *(tint_symbol_1) = *(tint_symbol_2); return; } diff --git a/test/buffer/storage/types/mat2x3.wgsl.expected.msl b/test/buffer/storage/types/mat2x3.wgsl.expected.msl index 1b704d4a15..16b12980d2 100644 --- a/test/buffer/storage/types/mat2x3.wgsl.expected.msl +++ b/test/buffer/storage/types/mat2x3.wgsl.expected.msl @@ -1,7 +1,7 @@ #include using namespace metal; -kernel void tint_symbol(device float2x3* tint_symbol_1 [[buffer(1)]], const device float2x3* tint_symbol_2 [[buffer(0)]]) { +kernel void tint_symbol(device float2x3* tint_symbol_1 [[buffer(0)]], const device float2x3* tint_symbol_2 [[buffer(1)]]) { *(tint_symbol_1) = *(tint_symbol_2); return; } diff --git a/test/buffer/storage/types/mat3x2.wgsl.expected.msl b/test/buffer/storage/types/mat3x2.wgsl.expected.msl index b8765f08ab..b74b79a4d9 100644 --- a/test/buffer/storage/types/mat3x2.wgsl.expected.msl +++ b/test/buffer/storage/types/mat3x2.wgsl.expected.msl @@ -1,7 +1,7 @@ #include using namespace metal; -kernel void tint_symbol(device float3x2* tint_symbol_1 [[buffer(1)]], const device float3x2* tint_symbol_2 [[buffer(0)]]) { +kernel void tint_symbol(device float3x2* tint_symbol_1 [[buffer(0)]], const device float3x2* tint_symbol_2 [[buffer(1)]]) { *(tint_symbol_1) = *(tint_symbol_2); return; } diff --git a/test/buffer/storage/types/mat4x4.wgsl.expected.msl b/test/buffer/storage/types/mat4x4.wgsl.expected.msl index 6b33874fa0..d52f1c4956 100644 --- a/test/buffer/storage/types/mat4x4.wgsl.expected.msl +++ b/test/buffer/storage/types/mat4x4.wgsl.expected.msl @@ -1,7 +1,7 @@ #include using namespace metal; -kernel void tint_symbol(device float4x4* tint_symbol_1 [[buffer(1)]], const device float4x4* tint_symbol_2 [[buffer(0)]]) { +kernel void tint_symbol(device float4x4* tint_symbol_1 [[buffer(0)]], const device float4x4* tint_symbol_2 [[buffer(1)]]) { *(tint_symbol_1) = *(tint_symbol_2); return; } diff --git a/test/buffer/storage/types/runtime_array.wgsl.expected.msl b/test/buffer/storage/types/runtime_array.wgsl.expected.msl index ff6a283788..573d6cee1a 100644 --- a/test/buffer/storage/types/runtime_array.wgsl.expected.msl +++ b/test/buffer/storage/types/runtime_array.wgsl.expected.msl @@ -11,7 +11,7 @@ struct tint_symbol_4 { /* 0x0000 */ S arr[1]; }; -kernel void tint_symbol(device tint_symbol_2* tint_symbol_1 [[buffer(1)]], const device tint_symbol_4* tint_symbol_3 [[buffer(0)]]) { +kernel void tint_symbol(device tint_symbol_2* tint_symbol_1 [[buffer(0)]], const device tint_symbol_4* tint_symbol_3 [[buffer(1)]]) { (*(tint_symbol_1)).arr[0] = (*(tint_symbol_3)).arr[0]; return; } diff --git a/test/buffer/storage/types/u32.wgsl.expected.msl b/test/buffer/storage/types/u32.wgsl.expected.msl index bb65310eec..c2e500aba1 100644 --- a/test/buffer/storage/types/u32.wgsl.expected.msl +++ b/test/buffer/storage/types/u32.wgsl.expected.msl @@ -1,7 +1,7 @@ #include using namespace metal; -kernel void tint_symbol(device uint* tint_symbol_1 [[buffer(1)]], const device uint* tint_symbol_2 [[buffer(0)]]) { +kernel void tint_symbol(device uint* tint_symbol_1 [[buffer(0)]], const device uint* tint_symbol_2 [[buffer(1)]]) { *(tint_symbol_1) = *(tint_symbol_2); return; } diff --git a/test/buffer/storage/types/vec2.wgsl.expected.msl b/test/buffer/storage/types/vec2.wgsl.expected.msl index c730e05c48..6627fa21b5 100644 --- a/test/buffer/storage/types/vec2.wgsl.expected.msl +++ b/test/buffer/storage/types/vec2.wgsl.expected.msl @@ -1,7 +1,7 @@ #include using namespace metal; -kernel void tint_symbol(device int2* tint_symbol_1 [[buffer(1)]], const device int2* tint_symbol_2 [[buffer(0)]]) { +kernel void tint_symbol(device int2* tint_symbol_1 [[buffer(0)]], const device int2* tint_symbol_2 [[buffer(1)]]) { *(tint_symbol_1) = *(tint_symbol_2); return; } diff --git a/test/buffer/storage/types/vec3.wgsl.expected.msl b/test/buffer/storage/types/vec3.wgsl.expected.msl index 6d55796494..54cc7c86a3 100644 --- a/test/buffer/storage/types/vec3.wgsl.expected.msl +++ b/test/buffer/storage/types/vec3.wgsl.expected.msl @@ -1,7 +1,7 @@ #include using namespace metal; -kernel void tint_symbol(device uint3* tint_symbol_1 [[buffer(1)]], const device uint3* tint_symbol_2 [[buffer(0)]]) { +kernel void tint_symbol(device uint3* tint_symbol_1 [[buffer(0)]], const device uint3* tint_symbol_2 [[buffer(1)]]) { *(tint_symbol_1) = *(tint_symbol_2); return; } diff --git a/test/buffer/storage/types/vec4.wgsl.expected.msl b/test/buffer/storage/types/vec4.wgsl.expected.msl index 4af349b7cf..65cbadb30b 100644 --- a/test/buffer/storage/types/vec4.wgsl.expected.msl +++ b/test/buffer/storage/types/vec4.wgsl.expected.msl @@ -1,7 +1,7 @@ #include using namespace metal; -kernel void tint_symbol(device float4* tint_symbol_1 [[buffer(1)]], const device float4* tint_symbol_2 [[buffer(0)]]) { +kernel void tint_symbol(device float4* tint_symbol_1 [[buffer(0)]], const device float4* tint_symbol_2 [[buffer(1)]]) { *(tint_symbol_1) = *(tint_symbol_2); return; }