Metal: Allocate threadgroup memory based on Tint reflection

Tint passes threadgroup memory in MSL as entrypoint arguments since
threadgroup memory at the module scope cannot be default initialized.
MSL lacks default constructors for matrices in threadgroup memory.

Bug: dawn:1110
Change-Id: I7462fa448c6ebdb3cc4dc24bd5ff0a99287cdba0
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/64240
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
Austin Eng 2021-09-15 18:16:50 +00:00 committed by Dawn LUCI CQ
parent 49d794fbcc
commit 5528d0edd2
5 changed files with 119 additions and 10 deletions

View File

@ -47,6 +47,7 @@ namespace dawn_native { namespace metal {
NSPRef<id<MTLComputePipelineState>> mMtlComputePipelineState; NSPRef<id<MTLComputePipelineState>> mMtlComputePipelineState;
MTLSize mLocalWorkgroupSize; MTLSize mLocalWorkgroupSize;
bool mRequiresStorageBufferLength; bool mRequiresStorageBufferLength;
std::vector<uint32_t> mWorkgroupAllocations;
}; };
}} // namespace dawn_native::metal }} // namespace dawn_native::metal

View File

@ -54,11 +54,18 @@ namespace dawn_native { namespace metal {
mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z); mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z);
mRequiresStorageBufferLength = computeData.needsStorageBufferLength; mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
mWorkgroupAllocations = std::move(computeData.workgroupAllocations);
return {}; return {};
} }
void ComputePipeline::Encode(id<MTLComputeCommandEncoder> encoder) { void ComputePipeline::Encode(id<MTLComputeCommandEncoder> encoder) {
[encoder setComputePipelineState:mMtlComputePipelineState.Get()]; [encoder setComputePipelineState:mMtlComputePipelineState.Get()];
for (size_t i = 0; i < mWorkgroupAllocations.size(); ++i) {
if (mWorkgroupAllocations[i] == 0) {
continue;
}
[encoder setThreadgroupMemoryLength:mWorkgroupAllocations[i] atIndex:i];
}
} }
MTLSize ComputePipeline::GetLocalWorkGroupSize() const { MTLSize ComputePipeline::GetLocalWorkGroupSize() const {

View File

@ -37,6 +37,7 @@ namespace dawn_native { namespace metal {
struct MetalFunctionData { struct MetalFunctionData {
NSPRef<id<MTLFunction>> function; NSPRef<id<MTLFunction>> function;
bool needsStorageBufferLength; bool needsStorageBufferLength;
std::vector<uint32_t> workgroupAllocations;
}; };
MaybeError CreateFunction(const char* entryPointName, MaybeError CreateFunction(const char* entryPointName,
SingleShaderStage stage, SingleShaderStage stage,
@ -53,7 +54,8 @@ namespace dawn_native { namespace metal {
const RenderPipeline* renderPipeline, const RenderPipeline* renderPipeline,
std::string* remappedEntryPointName, std::string* remappedEntryPointName,
bool* needsStorageBufferLength, bool* needsStorageBufferLength,
bool* hasInvariantAttribute); bool* hasInvariantAttribute,
std::vector<uint32_t>* workgroupAllocations);
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default; ~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult); MaybeError Initialize(ShaderModuleParseResult* parseResult);

View File

@ -44,14 +44,16 @@ namespace dawn_native { namespace metal {
return InitializeBase(parseResult); return InitializeBase(parseResult);
} }
ResultOrError<std::string> ShaderModule::TranslateToMSL(const char* entryPointName, ResultOrError<std::string> ShaderModule::TranslateToMSL(
const char* entryPointName,
SingleShaderStage stage, SingleShaderStage stage,
const PipelineLayout* layout, const PipelineLayout* layout,
uint32_t sampleMask, uint32_t sampleMask,
const RenderPipeline* renderPipeline, const RenderPipeline* renderPipeline,
std::string* remappedEntryPointName, std::string* remappedEntryPointName,
bool* needsStorageBufferLength, bool* needsStorageBufferLength,
bool* hasInvariantAttribute) { bool* hasInvariantAttribute,
std::vector<uint32_t>* workgroupAllocations) {
ScopedTintICEHandler scopedICEHandler(GetDevice()); ScopedTintICEHandler scopedICEHandler(GetDevice());
std::ostringstream errorStream; std::ostringstream errorStream;
@ -166,6 +168,7 @@ namespace dawn_native { namespace metal {
*needsStorageBufferLength = result.needs_storage_buffer_sizes; *needsStorageBufferLength = result.needs_storage_buffer_sizes;
*hasInvariantAttribute = result.has_invariant_attribute; *hasInvariantAttribute = result.has_invariant_attribute;
*workgroupAllocations = std::move(result.workgroup_allocations[*remappedEntryPointName]);
return std::move(result.msl); return std::move(result.msl);
} }
@ -190,7 +193,7 @@ namespace dawn_native { namespace metal {
DAWN_TRY_ASSIGN(msl, DAWN_TRY_ASSIGN(msl,
TranslateToMSL(entryPointName, stage, layout, sampleMask, renderPipeline, TranslateToMSL(entryPointName, stage, layout, sampleMask, renderPipeline,
&remappedEntryPointName, &out->needsStorageBufferLength, &remappedEntryPointName, &out->needsStorageBufferLength,
&hasInvariantAttribute)); &hasInvariantAttribute, &out->workgroupAllocations));
// Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
// category. -Wunused-variable in particular comes up a lot in generated code, and some // category. -Wunused-variable in particular comes up a lot in generated code, and some

View File

@ -100,6 +100,102 @@ TEST_P(ComputeSharedMemoryTests, Basic) {
})"); })");
} }
// Test using assorted types in workgroup memory. MSL lacks constructors
// for matrices in threadgroup memory. Basic test that reading and
// writing a matrix in workgroup memory works.
TEST_P(ComputeSharedMemoryTests, AssortedTypes) {
wgpu::ComputePipelineDescriptor csDesc;
csDesc.compute.module = utils::CreateShaderModule(device, R"(
struct StructValues {
m: mat2x2<f32>;
};
[[block]] struct Dst {
d_struct : StructValues;
d_matrix : mat2x2<f32>;
d_array : array<u32, 4>;
d_vector : vec4<f32>;
};
[[group(0), binding(0)]] var<storage, write> dst : Dst;
var<workgroup> wg_struct : StructValues;
var<workgroup> wg_matrix : mat2x2<f32>;
var<workgroup> wg_array : array<u32, 4>;
var<workgroup> wg_vector : vec4<f32>;
[[stage(compute), workgroup_size(4,1,1)]]
fn main([[builtin(local_invocation_id)]] LocalInvocationID : vec3<u32>) {
let i = 4u * LocalInvocationID.x;
if (LocalInvocationID.x == 0u) {
wg_struct.m = mat2x2<f32>(
vec2<f32>(f32(i), f32(i + 1u)),
vec2<f32>(f32(i + 2u), f32(i + 3u)));
} elseif (LocalInvocationID.x == 1u) {
wg_matrix = mat2x2<f32>(
vec2<f32>(f32(i), f32(i + 1u)),
vec2<f32>(f32(i + 2u), f32(i + 3u)));
} elseif (LocalInvocationID.x == 2u) {
wg_array[0u] = i;
wg_array[1u] = i + 1u;
wg_array[2u] = i + 2u;
wg_array[3u] = i + 3u;
} elseif (LocalInvocationID.x == 3u) {
wg_vector = vec4<f32>(
f32(i), f32(i + 1u), f32(i + 2u), f32(i + 3u));
}
workgroupBarrier();
if (LocalInvocationID.x == 0u) {
dst.d_struct = wg_struct;
dst.d_matrix = wg_matrix;
dst.d_array = wg_array;
dst.d_vector = wg_vector;
}
}
)");
csDesc.compute.entryPoint = "main";
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
// Set up dst storage buffer
wgpu::BufferDescriptor dstDesc;
dstDesc.size = 64;
dstDesc.usage =
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst;
wgpu::Buffer dst = device.CreateBuffer(&dstDesc);
// Set up bind group and issue dispatch
wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
{
{0, dst},
});
wgpu::CommandBuffer commands;
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipeline);
pass.SetBindGroup(0, bindGroup);
pass.Dispatch(1);
pass.EndPass();
commands = encoder.Finish();
}
queue.Submit(1, &commands);
std::array<float, 4> expectedStruct = {0., 1., 2., 3.};
std::array<float, 4> expectedMatrix = {4., 5., 6., 7.};
std::array<uint32_t, 4> expectedArray = {8, 9, 10, 11};
std::array<float, 4> expectedVector = {12., 13., 14., 15.};
EXPECT_BUFFER_FLOAT_RANGE_EQ(expectedStruct.data(), dst, 0, 4);
EXPECT_BUFFER_FLOAT_RANGE_EQ(expectedMatrix.data(), dst, 16, 4);
EXPECT_BUFFER_U32_RANGE_EQ(expectedArray.data(), dst, 32, 4);
EXPECT_BUFFER_FLOAT_RANGE_EQ(expectedVector.data(), dst, 48, 4);
}
DAWN_INSTANTIATE_TEST(ComputeSharedMemoryTests, DAWN_INSTANTIATE_TEST(ComputeSharedMemoryTests,
D3D12Backend(), D3D12Backend(),
MetalBackend(), MetalBackend(),