Metal: Allocate threadgroup memory based on Tint reflection
Tint passes threadgroup memory in MSL as entrypoint arguments since threadgroup memory at the module scope cannot be default initialized. MSL lacks default constructors for matrices in threadgroup memory. Bug: dawn:1110 Change-Id: I7462fa448c6ebdb3cc4dc24bd5ff0a99287cdba0 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/64240 Reviewed-by: James Price <jrprice@google.com> Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
parent
49d794fbcc
commit
5528d0edd2
|
@ -47,6 +47,7 @@ namespace dawn_native { namespace metal {
|
||||||
NSPRef<id<MTLComputePipelineState>> mMtlComputePipelineState;
|
NSPRef<id<MTLComputePipelineState>> mMtlComputePipelineState;
|
||||||
MTLSize mLocalWorkgroupSize;
|
MTLSize mLocalWorkgroupSize;
|
||||||
bool mRequiresStorageBufferLength;
|
bool mRequiresStorageBufferLength;
|
||||||
|
std::vector<uint32_t> mWorkgroupAllocations;
|
||||||
};
|
};
|
||||||
|
|
||||||
}} // namespace dawn_native::metal
|
}} // namespace dawn_native::metal
|
||||||
|
|
|
@ -54,11 +54,18 @@ namespace dawn_native { namespace metal {
|
||||||
mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z);
|
mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z);
|
||||||
|
|
||||||
mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
|
mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
|
||||||
|
mWorkgroupAllocations = std::move(computeData.workgroupAllocations);
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
void ComputePipeline::Encode(id<MTLComputeCommandEncoder> encoder) {
|
void ComputePipeline::Encode(id<MTLComputeCommandEncoder> encoder) {
|
||||||
[encoder setComputePipelineState:mMtlComputePipelineState.Get()];
|
[encoder setComputePipelineState:mMtlComputePipelineState.Get()];
|
||||||
|
for (size_t i = 0; i < mWorkgroupAllocations.size(); ++i) {
|
||||||
|
if (mWorkgroupAllocations[i] == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
[encoder setThreadgroupMemoryLength:mWorkgroupAllocations[i] atIndex:i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MTLSize ComputePipeline::GetLocalWorkGroupSize() const {
|
MTLSize ComputePipeline::GetLocalWorkGroupSize() const {
|
||||||
|
|
|
@ -37,6 +37,7 @@ namespace dawn_native { namespace metal {
|
||||||
struct MetalFunctionData {
|
struct MetalFunctionData {
|
||||||
NSPRef<id<MTLFunction>> function;
|
NSPRef<id<MTLFunction>> function;
|
||||||
bool needsStorageBufferLength;
|
bool needsStorageBufferLength;
|
||||||
|
std::vector<uint32_t> workgroupAllocations;
|
||||||
};
|
};
|
||||||
MaybeError CreateFunction(const char* entryPointName,
|
MaybeError CreateFunction(const char* entryPointName,
|
||||||
SingleShaderStage stage,
|
SingleShaderStage stage,
|
||||||
|
@ -53,7 +54,8 @@ namespace dawn_native { namespace metal {
|
||||||
const RenderPipeline* renderPipeline,
|
const RenderPipeline* renderPipeline,
|
||||||
std::string* remappedEntryPointName,
|
std::string* remappedEntryPointName,
|
||||||
bool* needsStorageBufferLength,
|
bool* needsStorageBufferLength,
|
||||||
bool* hasInvariantAttribute);
|
bool* hasInvariantAttribute,
|
||||||
|
std::vector<uint32_t>* workgroupAllocations);
|
||||||
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
|
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
|
||||||
~ShaderModule() override = default;
|
~ShaderModule() override = default;
|
||||||
MaybeError Initialize(ShaderModuleParseResult* parseResult);
|
MaybeError Initialize(ShaderModuleParseResult* parseResult);
|
||||||
|
|
|
@ -44,14 +44,16 @@ namespace dawn_native { namespace metal {
|
||||||
return InitializeBase(parseResult);
|
return InitializeBase(parseResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultOrError<std::string> ShaderModule::TranslateToMSL(const char* entryPointName,
|
ResultOrError<std::string> ShaderModule::TranslateToMSL(
|
||||||
|
const char* entryPointName,
|
||||||
SingleShaderStage stage,
|
SingleShaderStage stage,
|
||||||
const PipelineLayout* layout,
|
const PipelineLayout* layout,
|
||||||
uint32_t sampleMask,
|
uint32_t sampleMask,
|
||||||
const RenderPipeline* renderPipeline,
|
const RenderPipeline* renderPipeline,
|
||||||
std::string* remappedEntryPointName,
|
std::string* remappedEntryPointName,
|
||||||
bool* needsStorageBufferLength,
|
bool* needsStorageBufferLength,
|
||||||
bool* hasInvariantAttribute) {
|
bool* hasInvariantAttribute,
|
||||||
|
std::vector<uint32_t>* workgroupAllocations) {
|
||||||
ScopedTintICEHandler scopedICEHandler(GetDevice());
|
ScopedTintICEHandler scopedICEHandler(GetDevice());
|
||||||
|
|
||||||
std::ostringstream errorStream;
|
std::ostringstream errorStream;
|
||||||
|
@ -166,6 +168,7 @@ namespace dawn_native { namespace metal {
|
||||||
|
|
||||||
*needsStorageBufferLength = result.needs_storage_buffer_sizes;
|
*needsStorageBufferLength = result.needs_storage_buffer_sizes;
|
||||||
*hasInvariantAttribute = result.has_invariant_attribute;
|
*hasInvariantAttribute = result.has_invariant_attribute;
|
||||||
|
*workgroupAllocations = std::move(result.workgroup_allocations[*remappedEntryPointName]);
|
||||||
|
|
||||||
return std::move(result.msl);
|
return std::move(result.msl);
|
||||||
}
|
}
|
||||||
|
@ -190,7 +193,7 @@ namespace dawn_native { namespace metal {
|
||||||
DAWN_TRY_ASSIGN(msl,
|
DAWN_TRY_ASSIGN(msl,
|
||||||
TranslateToMSL(entryPointName, stage, layout, sampleMask, renderPipeline,
|
TranslateToMSL(entryPointName, stage, layout, sampleMask, renderPipeline,
|
||||||
&remappedEntryPointName, &out->needsStorageBufferLength,
|
&remappedEntryPointName, &out->needsStorageBufferLength,
|
||||||
&hasInvariantAttribute));
|
&hasInvariantAttribute, &out->workgroupAllocations));
|
||||||
|
|
||||||
// Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
|
// Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
|
||||||
// category. -Wunused-variable in particular comes up a lot in generated code, and some
|
// category. -Wunused-variable in particular comes up a lot in generated code, and some
|
||||||
|
|
|
@ -100,6 +100,102 @@ TEST_P(ComputeSharedMemoryTests, Basic) {
|
||||||
})");
|
})");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test using assorted types in workgroup memory. MSL lacks constructors
|
||||||
|
// for matrices in threadgroup memory. Basic test that reading and
|
||||||
|
// writing a matrix in workgroup memory works.
|
||||||
|
TEST_P(ComputeSharedMemoryTests, AssortedTypes) {
|
||||||
|
wgpu::ComputePipelineDescriptor csDesc;
|
||||||
|
csDesc.compute.module = utils::CreateShaderModule(device, R"(
|
||||||
|
struct StructValues {
|
||||||
|
m: mat2x2<f32>;
|
||||||
|
};
|
||||||
|
|
||||||
|
[[block]] struct Dst {
|
||||||
|
d_struct : StructValues;
|
||||||
|
d_matrix : mat2x2<f32>;
|
||||||
|
d_array : array<u32, 4>;
|
||||||
|
d_vector : vec4<f32>;
|
||||||
|
};
|
||||||
|
|
||||||
|
[[group(0), binding(0)]] var<storage, write> dst : Dst;
|
||||||
|
|
||||||
|
var<workgroup> wg_struct : StructValues;
|
||||||
|
var<workgroup> wg_matrix : mat2x2<f32>;
|
||||||
|
var<workgroup> wg_array : array<u32, 4>;
|
||||||
|
var<workgroup> wg_vector : vec4<f32>;
|
||||||
|
|
||||||
|
[[stage(compute), workgroup_size(4,1,1)]]
|
||||||
|
fn main([[builtin(local_invocation_id)]] LocalInvocationID : vec3<u32>) {
|
||||||
|
|
||||||
|
let i = 4u * LocalInvocationID.x;
|
||||||
|
if (LocalInvocationID.x == 0u) {
|
||||||
|
wg_struct.m = mat2x2<f32>(
|
||||||
|
vec2<f32>(f32(i), f32(i + 1u)),
|
||||||
|
vec2<f32>(f32(i + 2u), f32(i + 3u)));
|
||||||
|
} elseif (LocalInvocationID.x == 1u) {
|
||||||
|
wg_matrix = mat2x2<f32>(
|
||||||
|
vec2<f32>(f32(i), f32(i + 1u)),
|
||||||
|
vec2<f32>(f32(i + 2u), f32(i + 3u)));
|
||||||
|
} elseif (LocalInvocationID.x == 2u) {
|
||||||
|
wg_array[0u] = i;
|
||||||
|
wg_array[1u] = i + 1u;
|
||||||
|
wg_array[2u] = i + 2u;
|
||||||
|
wg_array[3u] = i + 3u;
|
||||||
|
} elseif (LocalInvocationID.x == 3u) {
|
||||||
|
wg_vector = vec4<f32>(
|
||||||
|
f32(i), f32(i + 1u), f32(i + 2u), f32(i + 3u));
|
||||||
|
}
|
||||||
|
|
||||||
|
workgroupBarrier();
|
||||||
|
|
||||||
|
if (LocalInvocationID.x == 0u) {
|
||||||
|
dst.d_struct = wg_struct;
|
||||||
|
dst.d_matrix = wg_matrix;
|
||||||
|
dst.d_array = wg_array;
|
||||||
|
dst.d_vector = wg_vector;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)");
|
||||||
|
csDesc.compute.entryPoint = "main";
|
||||||
|
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
|
||||||
|
|
||||||
|
// Set up dst storage buffer
|
||||||
|
wgpu::BufferDescriptor dstDesc;
|
||||||
|
dstDesc.size = 64;
|
||||||
|
dstDesc.usage =
|
||||||
|
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst;
|
||||||
|
wgpu::Buffer dst = device.CreateBuffer(&dstDesc);
|
||||||
|
|
||||||
|
// Set up bind group and issue dispatch
|
||||||
|
wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
|
||||||
|
{
|
||||||
|
{0, dst},
|
||||||
|
});
|
||||||
|
|
||||||
|
wgpu::CommandBuffer commands;
|
||||||
|
{
|
||||||
|
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||||
|
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||||
|
pass.SetPipeline(pipeline);
|
||||||
|
pass.SetBindGroup(0, bindGroup);
|
||||||
|
pass.Dispatch(1);
|
||||||
|
pass.EndPass();
|
||||||
|
|
||||||
|
commands = encoder.Finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
queue.Submit(1, &commands);
|
||||||
|
|
||||||
|
std::array<float, 4> expectedStruct = {0., 1., 2., 3.};
|
||||||
|
std::array<float, 4> expectedMatrix = {4., 5., 6., 7.};
|
||||||
|
std::array<uint32_t, 4> expectedArray = {8, 9, 10, 11};
|
||||||
|
std::array<float, 4> expectedVector = {12., 13., 14., 15.};
|
||||||
|
EXPECT_BUFFER_FLOAT_RANGE_EQ(expectedStruct.data(), dst, 0, 4);
|
||||||
|
EXPECT_BUFFER_FLOAT_RANGE_EQ(expectedMatrix.data(), dst, 16, 4);
|
||||||
|
EXPECT_BUFFER_U32_RANGE_EQ(expectedArray.data(), dst, 32, 4);
|
||||||
|
EXPECT_BUFFER_FLOAT_RANGE_EQ(expectedVector.data(), dst, 48, 4);
|
||||||
|
}
|
||||||
|
|
||||||
DAWN_INSTANTIATE_TEST(ComputeSharedMemoryTests,
|
DAWN_INSTANTIATE_TEST(ComputeSharedMemoryTests,
|
||||||
D3D12Backend(),
|
D3D12Backend(),
|
||||||
MetalBackend(),
|
MetalBackend(),
|
||||||
|
|
Loading…
Reference in New Issue