Add MSL support for UseTintGenerator toggle

Turns on Tint generation of MSL if UseTintGenerator is on

Bug: dawn:571
Change-Id: Icfa523c36a509baf5da3b2a54152a7fb462c86f4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/32303
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
Austin Eng 2020-12-07 22:04:23 +00:00 committed by Commit Bot service account
parent 042184128e
commit 8a73e1876d
5 changed files with 173 additions and 73 deletions

View File

@ -877,6 +877,36 @@ namespace dawn_native {
}
return std::move(output.module);
}
std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform(
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet) {
auto transform = std::make_unique<tint::transform::VertexPulling>();
tint::transform::VertexStateDescriptor state;
for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) {
const auto& vertexBuffer = vertexState.vertexBuffers[i];
tint::transform::VertexBufferLayoutDescriptor layout;
layout.array_stride = vertexBuffer.arrayStride;
layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode);
for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
const auto& attribute = vertexBuffer.attributes[j];
tint::transform::VertexAttributeDescriptor attr;
attr.format = ToTintVertexFormat(attribute.format);
attr.offset = attribute.offset;
attr.shader_location = attribute.shaderLocation;
layout.attributes.push_back(std::move(attr));
}
state.push_back(std::move(layout));
}
transform->SetVertexState(std::move(state));
transform->SetEntryPoint(entryPoint);
transform->SetPullingBufferBindingSet(static_cast<uint32_t>(pullingBufferBindingSet));
return transform;
}
#endif
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
@ -972,7 +1002,7 @@ namespace dawn_native {
const std::vector<uint32_t>& spirv,
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
uint32_t pullingBufferBindingSet) const {
BindGroupIndex pullingBufferBindingSet) const {
tint::ast::Module module;
DAWN_TRY_ASSIGN(module, ParseSPIRV(spirv));
@ -983,37 +1013,13 @@ namespace dawn_native {
tint::ast::Module* moduleIn,
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
uint32_t pullingBufferBindingSet) const {
BindGroupIndex pullingBufferBindingSet) const {
std::ostringstream errorStream;
errorStream << "Tint vertex pulling failure:" << std::endl;
tint::transform::Manager transformManager;
{
auto transform = std::make_unique<tint::transform::VertexPulling>();
tint::transform::VertexStateDescriptor state;
for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) {
const auto& vertexBuffer = vertexState.vertexBuffers[i];
tint::transform::VertexBufferLayoutDescriptor layout;
layout.array_stride = vertexBuffer.arrayStride;
layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode);
for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
const auto& attribute = vertexBuffer.attributes[j];
tint::transform::VertexAttributeDescriptor attr;
attr.format = ToTintVertexFormat(attribute.format);
attr.offset = attribute.offset;
attr.shader_location = attribute.shaderLocation;
layout.attributes.push_back(std::move(attr));
}
state.push_back(std::move(layout));
}
transform->SetVertexState(std::move(state));
transform->SetEntryPoint(entryPoint);
transform->SetPullingBufferBindingSet(pullingBufferBindingSet);
transformManager.append(std::move(transform));
}
transformManager.append(
MakeVertexPullingTransform(vertexState, entryPoint, pullingBufferBindingSet));
if (GetDevice()->IsRobustnessEnabled()) {
// TODO(enga): Run the Tint BoundArrayAccessors transform instead of the SPIRV Tools
// one, but it appears to crash after running VertexPulling.

View File

@ -39,6 +39,7 @@ namespace tint {
namespace transform {
class Manager;
class VertexPulling;
} // namespace transform
} // namespace tint
@ -79,6 +80,11 @@ namespace dawn_native {
#ifdef DAWN_ENABLE_WGSL
ResultOrError<tint::ast::Module> RunTransforms(tint::transform::Manager* manager,
tint::ast::Module* module);
std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform(
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet);
#endif
// Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
@ -148,13 +154,13 @@ namespace dawn_native {
const std::vector<uint32_t>& spirv,
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
uint32_t pullingBufferBindingSet) const;
BindGroupIndex pullingBufferBindingSet) const;
ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
tint::ast::Module* module,
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
uint32_t pullingBufferBindingSet) const;
BindGroupIndex pullingBufferBindingSet) const;
#endif
protected:

View File

@ -38,6 +38,8 @@ namespace dawn_native { namespace metal {
// The number of Metal buffers Dawn can use in a generic way (i.e. that aren't reserved)
static constexpr size_t kGenericMetalBufferSlots = kMetalBufferTableSize - 1;
static constexpr BindGroupIndex kPullingBufferBindingSet = BindGroupIndex(kMaxBindGroups);
class PipelineLayout final : public PipelineLayoutBase {
public:
PipelineLayout(Device* device, const PipelineLayoutDescriptor* descriptor);

View File

@ -50,6 +50,22 @@ namespace dawn_native { namespace metal {
const RenderPipeline* renderPipeline = nullptr);
private:
ResultOrError<std::string> TranslateToMSLWithTint(const char* entryPointName,
SingleShaderStage stage,
const PipelineLayout* layout,
uint32_t sampleMask,
const RenderPipeline* renderPipeline,
std::string* remappedEntryPointName,
bool* needsStorageBufferLength);
ResultOrError<std::string> TranslateToMSLWithSPIRVCross(
const char* entryPointName,
SingleShaderStage stage,
const PipelineLayout* layout,
uint32_t sampleMask,
const RenderPipeline* renderPipeline,
std::string* remappedEntryPointName,
bool* needsStorageBufferLength);
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult);

View File

@ -55,20 +55,71 @@ namespace dawn_native { namespace metal {
return {};
}
MaybeError ShaderModule::CreateFunction(const char* entryPointName,
SingleShaderStage stage,
const PipelineLayout* layout,
ShaderModule::MetalFunctionData* out,
uint32_t sampleMask,
const RenderPipeline* renderPipeline) {
ASSERT(!IsError());
ASSERT(out);
ResultOrError<std::string> ShaderModule::TranslateToMSLWithTint(
const char* entryPointName,
SingleShaderStage stage,
const PipelineLayout* layout,
// TODO(crbug.com/tint/387): AND in a fixed sample mask in the shader.
uint32_t sampleMask,
const RenderPipeline* renderPipeline,
std::string* remappedEntryPointName,
bool* needsStorageBufferLength) {
#if DAWN_ENABLE_WGSL
// TODO(crbug.com/tint/256): Set this accordingly if arrayLength(..) is used.
*needsStorageBufferLength = false;
std::ostringstream errorStream;
errorStream << "Tint MSL failure:" << std::endl;
tint::transform::Manager transformManager;
if (stage == SingleShaderStage::Vertex &&
GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
transformManager.append(
MakeVertexPullingTransform(*renderPipeline->GetVertexStateDescriptor(),
entryPointName, kPullingBufferBindingSet));
for (VertexBufferSlot slot :
IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);
DAWN_UNUSED(metalIndex);
// TODO(crbug.com/tint/104): Tell Tint to map (kPullingBufferBindingSet, slot) to
// this MSL buffer index.
}
}
transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
tint::ast::Module module;
DAWN_TRY_ASSIGN(module, RunTransforms(&transformManager, mTintModule.get()));
ASSERT(remappedEntryPointName != nullptr);
tint::inspector::Inspector inspector(module);
*remappedEntryPointName = inspector.GetRemappedNameForEntryPoint(entryPointName);
tint::writer::msl::Generator generator(std::move(module));
if (!generator.Generate()) {
errorStream << "Generator: " << generator.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
std::string msl = generator.result();
return std::move(msl);
#else
UNREACHABLE();
#endif
}
ResultOrError<std::string> ShaderModule::TranslateToMSLWithSPIRVCross(
const char* entryPointName,
SingleShaderStage stage,
const PipelineLayout* layout,
uint32_t sampleMask,
const RenderPipeline* renderPipeline,
std::string* remappedEntryPointName,
bool* needsStorageBufferLength) {
const std::vector<uint32_t>* spirv = &GetSpirv();
spv::ExecutionModel executionModel = ShaderStageToExecutionModel(stage);
#ifdef DAWN_ENABLE_WGSL
// Use set 4 since it is bigger than what users can access currently
static const uint32_t kPullingBufferBindingSet = 4;
std::vector<uint32_t> pullingSpirv;
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
stage == SingleShaderStage::Vertex) {
@ -152,7 +203,7 @@ namespace dawn_native { namespace metal {
spirv_cross::MSLResourceBinding mslBinding;
mslBinding.stage = spv::ExecutionModelVertex;
mslBinding.desc_set = kPullingBufferBindingSet;
mslBinding.desc_set = static_cast<uint32_t>(kPullingBufferBindingSet);
mslBinding.binding = static_cast<uint8_t>(slot);
mslBinding.msl_buffer = metalIndex;
compiler.add_msl_resource_binding(mslBinding);
@ -160,49 +211,68 @@ namespace dawn_native { namespace metal {
}
#endif
{
// SPIRV-Cross also supports re-ordering attributes but it seems to do the correct thing
// by default.
std::string msl = compiler.compile();
// SPIRV-Cross also supports re-ordering attributes but it seems to do the correct thing
// by default.
std::string msl = compiler.compile();
// Some entry point names are forbidden in MSL so SPIRV-Cross modifies them. Query the
// modified entryPointName from it.
const std::string& modifiedEntryPointName =
compiler.get_entry_point(entryPointName, executionModel).name;
// Some entry point names are forbidden in MSL so SPIRV-Cross modifies them. Query the
// modified entryPointName from it.
*remappedEntryPointName = compiler.get_entry_point(entryPointName, executionModel).name;
*needsStorageBufferLength = compiler.needs_buffer_size_buffer();
// Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
// category. -Wunused-variable in particular comes up a lot in generated code, and some
// (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead
// of a warning.
msl = R"(\
return std::move(msl);
}
MaybeError ShaderModule::CreateFunction(const char* entryPointName,
SingleShaderStage stage,
const PipelineLayout* layout,
ShaderModule::MetalFunctionData* out,
uint32_t sampleMask,
const RenderPipeline* renderPipeline) {
ASSERT(!IsError());
ASSERT(out);
std::string remappedEntryPointName;
std::string msl;
if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
DAWN_TRY_ASSIGN(msl, TranslateToMSLWithTint(entryPointName, stage, layout, sampleMask,
renderPipeline, &remappedEntryPointName,
&out->needsStorageBufferLength));
} else {
DAWN_TRY_ASSIGN(msl, TranslateToMSLWithSPIRVCross(
entryPointName, stage, layout, sampleMask, renderPipeline,
&remappedEntryPointName, &out->needsStorageBufferLength));
}
// Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
// category. -Wunused-variable in particular comes up a lot in generated code, and some
// (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead
// of a warning.
msl = R"(\
#ifdef __clang__
#pragma clang diagnostic ignored "-Wall"
#endif
)" + msl;
NSRef<NSString> mslSource =
AcquireNSRef([[NSString alloc] initWithUTF8String:msl.c_str()]);
NSRef<NSString> mslSource = AcquireNSRef([[NSString alloc] initWithUTF8String:msl.c_str()]);
auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
NSError* error = nullptr;
NSPRef<id<MTLLibrary>> library =
AcquireNSPRef([mtlDevice newLibraryWithSource:mslSource.Get()
options:nullptr
error:&error]);
if (error != nullptr) {
if (error.code != MTLLibraryErrorCompileWarning) {
const char* errorString = [error.localizedDescription UTF8String];
return DAWN_VALIDATION_ERROR(std::string("Unable to create library object: ") +
errorString);
}
auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
NSError* error = nullptr;
NSPRef<id<MTLLibrary>> library =
AcquireNSPRef([mtlDevice newLibraryWithSource:mslSource.Get()
options:nullptr
error:&error]);
if (error != nullptr) {
if (error.code != MTLLibraryErrorCompileWarning) {
const char* errorString = [error.localizedDescription UTF8String];
return DAWN_VALIDATION_ERROR(std::string("Unable to create library object: ") +
errorString);
}
NSRef<NSString> name =
AcquireNSRef([[NSString alloc] initWithUTF8String:modifiedEntryPointName.c_str()]);
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
}
out->needsStorageBufferLength = compiler.needs_buffer_size_buffer();
NSRef<NSString> name =
AcquireNSRef([[NSString alloc] initWithUTF8String:remappedEntryPointName.c_str()]);
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
GetEntryPoint(entryPointName).usedVertexAttributes.any()) {