Update ComputePipelineDescriptor to use PipelineStageDescriptor

The contents of PipelineStageDescriptor were inlined inside of
ComputePipelineDescriptor. This changes updates
ComputePipelineDescriptor to contain PipelineStageDescriptor to match
WebGPU.

Bug: chromium:877147
Change-Id: Ic030b7bd7a237945cbbaf4c567cc361940e1ad00
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6400
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
Austin Eng 2019-04-09 15:17:30 +00:00 committed by Commit Bot service account
parent 6f0b021dbf
commit fbe6cfdb16
13 changed files with 63 additions and 56 deletions

View File

@ -375,8 +375,7 @@
"extensible": true, "extensible": true,
"members": [ "members": [
{"name": "layout", "type": "pipeline layout"}, {"name": "layout", "type": "pipeline layout"},
{"name": "module", "type": "shader module"}, {"name": "compute stage", "type": "pipeline stage descriptor", "annotation": "const*"}
{"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"}
] ]
}, },
"device": { "device": {

View File

@ -241,9 +241,13 @@ void initSim() {
dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl); dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
dawn::ComputePipelineDescriptor csDesc; dawn::ComputePipelineDescriptor csDesc;
csDesc.module = module;
csDesc.entryPoint = "main";
csDesc.layout = pl; csDesc.layout = pl;
dawn::PipelineStageDescriptor computeStage;
computeStage.module = module;
computeStage.entryPoint = "main";
csDesc.computeStage = &computeStage;
updatePipeline = device.CreateComputePipeline(&csDesc); updatePipeline = device.CreateComputePipeline(&csDesc);
for (uint32_t i = 0; i < 2; ++i) { for (uint32_t i = 0; i < 2; ++i) {

View File

@ -24,21 +24,9 @@ namespace dawn_native {
return DAWN_VALIDATION_ERROR("nextInChain must be nullptr"); return DAWN_VALIDATION_ERROR("nextInChain must be nullptr");
} }
DAWN_TRY(device->ValidateObject(descriptor->module));
DAWN_TRY(device->ValidateObject(descriptor->layout)); DAWN_TRY(device->ValidateObject(descriptor->layout));
DAWN_TRY(ValidatePipelineStageDescriptor(device, descriptor->computeStage,
if (descriptor->entryPoint != std::string("main")) { descriptor->layout, dawn::ShaderStage::Compute));
return DAWN_VALIDATION_ERROR("Currently the entry point has to be main()");
}
if (descriptor->module->GetExecutionModel() != dawn::ShaderStage::Compute) {
return DAWN_VALIDATION_ERROR("Setting module with wrong execution model");
}
if (!descriptor->module->IsCompatibleWithPipelineLayout(descriptor->layout)) {
return DAWN_VALIDATION_ERROR("Stage not compatible with layout");
}
return {}; return {};
} }
@ -47,7 +35,7 @@ namespace dawn_native {
ComputePipelineBase::ComputePipelineBase(DeviceBase* device, ComputePipelineBase::ComputePipelineBase(DeviceBase* device,
const ComputePipelineDescriptor* descriptor) const ComputePipelineDescriptor* descriptor)
: PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute) { : PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute) {
ExtractModuleData(dawn::ShaderStage::Compute, descriptor->module); ExtractModuleData(dawn::ShaderStage::Compute, descriptor->computeStage->module);
} }
ComputePipelineBase::ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag) ComputePipelineBase::ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag)

View File

@ -20,6 +20,24 @@
namespace dawn_native { namespace dawn_native {
MaybeError ValidatePipelineStageDescriptor(DeviceBase* device,
const PipelineStageDescriptor* descriptor,
const PipelineLayoutBase* layout,
dawn::ShaderStage stage) {
DAWN_TRY(device->ValidateObject(descriptor->module));
if (descriptor->entryPoint != std::string("main")) {
return DAWN_VALIDATION_ERROR("Entry point must be \"main\"");
}
if (descriptor->module->GetExecutionModel() != stage) {
return DAWN_VALIDATION_ERROR("Setting module with wrong stages");
}
if (!descriptor->module->IsCompatibleWithPipelineLayout(layout)) {
return DAWN_VALIDATION_ERROR("Stage not compatible with layout");
}
return {};
}
// PipelineBase // PipelineBase
PipelineBase::PipelineBase(DeviceBase* device, PipelineBase::PipelineBase(DeviceBase* device,

View File

@ -34,6 +34,11 @@ namespace dawn_native {
Float, Float,
}; };
MaybeError ValidatePipelineStageDescriptor(DeviceBase* device,
const PipelineStageDescriptor* descriptor,
const PipelineLayoutBase* layout,
dawn::ShaderStage stage);
class PipelineBase : public ObjectBase { class PipelineBase : public ObjectBase {
public: public:
struct PushConstantInfo { struct PushConstantInfo {

View File

@ -97,24 +97,6 @@ namespace dawn_native {
return {}; return {};
} }
MaybeError ValidatePipelineStageDescriptor(DeviceBase* device,
const PipelineStageDescriptor* descriptor,
const PipelineLayoutBase* layout,
dawn::ShaderStage stage) {
DAWN_TRY(device->ValidateObject(descriptor->module));
if (descriptor->entryPoint != std::string("main")) {
return DAWN_VALIDATION_ERROR("Entry point must be \"main\"");
}
if (descriptor->module->GetExecutionModel() != stage) {
return DAWN_VALIDATION_ERROR("Setting module with wrong stages");
}
if (!descriptor->module->IsCompatibleWithPipelineLayout(layout)) {
return DAWN_VALIDATION_ERROR("Stage not compatible with layout");
}
return {};
}
MaybeError ValidateColorStateDescriptor(const ColorStateDescriptor* descriptor) { MaybeError ValidateColorStateDescriptor(const ColorStateDescriptor* descriptor) {
if (descriptor->nextInChain != nullptr) { if (descriptor->nextInChain != nullptr) {
return DAWN_VALIDATION_ERROR("nextInChain must be nullptr"); return DAWN_VALIDATION_ERROR("nextInChain must be nullptr");

View File

@ -32,7 +32,7 @@ namespace dawn_native { namespace d3d12 {
// SPRIV-cross does matrix multiplication expecting row major matrices // SPRIV-cross does matrix multiplication expecting row major matrices
compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR; compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
const ShaderModule* module = ToBackend(descriptor->module); const ShaderModule* module = ToBackend(descriptor->computeStage->module);
const std::string& hlslSource = module->GetHLSLSource(ToBackend(GetLayout())); const std::string& hlslSource = module->GetHLSLSource(ToBackend(GetLayout()));
ComPtr<ID3DBlob> compiledShader; ComPtr<ID3DBlob> compiledShader;
@ -40,8 +40,8 @@ namespace dawn_native { namespace d3d12 {
const PlatformFunctions* functions = device->GetFunctions(); const PlatformFunctions* functions = device->GetFunctions();
if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr, if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
nullptr, descriptor->entryPoint, "cs_5_1", compileFlags, 0, nullptr, descriptor->computeStage->entryPoint, "cs_5_1",
&compiledShader, &errors))) { compileFlags, 0, &compiledShader, &errors))) {
printf("%s\n", reinterpret_cast<char*>(errors->GetBufferPointer())); printf("%s\n", reinterpret_cast<char*>(errors->GetBufferPointer()));
ASSERT(false); ASSERT(false);
} }

View File

@ -23,15 +23,14 @@ namespace dawn_native { namespace metal {
: ComputePipelineBase(device, descriptor) { : ComputePipelineBase(device, descriptor) {
auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice(); auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
const auto& module = ToBackend(descriptor->module); const ShaderModule* computeModule = ToBackend(descriptor->computeStage->module);
const char* entryPoint = descriptor->entryPoint; const char* computeEntryPoint = descriptor->computeStage->entryPoint;
ShaderModule::MetalFunctionData computeData = computeModule->GetFunction(
auto compilationData = computeEntryPoint, dawn::ShaderStage::Compute, ToBackend(GetLayout()));
module->GetFunction(entryPoint, dawn::ShaderStage::Compute, ToBackend(GetLayout()));
NSError* error = nil; NSError* error = nil;
mMtlComputePipelineState = mMtlComputePipelineState =
[mtlDevice newComputePipelineStateWithFunction:compilationData.function error:&error]; [mtlDevice newComputePipelineStateWithFunction:computeData.function error:&error];
if (error != nil) { if (error != nil) {
NSLog(@" error => %@", error); NSLog(@" error => %@", error);
GetDevice()->HandleError("Error creating pipeline state"); GetDevice()->HandleError("Error creating pipeline state");
@ -39,7 +38,7 @@ namespace dawn_native { namespace metal {
} }
// Copy over the local workgroup size as it is passed to dispatch explicitly in Metal // Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
mLocalWorkgroupSize = compilationData.localWorkgroupSize; mLocalWorkgroupSize = computeData.localWorkgroupSize;
} }
ComputePipeline::~ComputePipeline() { ComputePipeline::~ComputePipeline() {

View File

@ -21,7 +21,7 @@ namespace dawn_native { namespace opengl {
ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor) ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor)
: ComputePipelineBase(device, descriptor) { : ComputePipelineBase(device, descriptor) {
PerStage<const ShaderModule*> modules(nullptr); PerStage<const ShaderModule*> modules(nullptr);
modules[dawn::ShaderStage::Compute] = ToBackend(descriptor->module); modules[dawn::ShaderStage::Compute] = ToBackend(descriptor->computeStage->module);
PipelineGL::Initialize(ToBackend(descriptor->layout), modules); PipelineGL::Initialize(ToBackend(descriptor->layout), modules);
} }

View File

@ -35,8 +35,8 @@ namespace dawn_native { namespace vulkan {
createInfo.stage.pNext = nullptr; createInfo.stage.pNext = nullptr;
createInfo.stage.flags = 0; createInfo.stage.flags = 0;
createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
createInfo.stage.module = ToBackend(descriptor->module)->GetHandle(); createInfo.stage.module = ToBackend(descriptor->computeStage->module)->GetHandle();
createInfo.stage.pName = descriptor->entryPoint; createInfo.stage.pName = descriptor->computeStage->entryPoint;
createInfo.stage.pSpecializationInfo = nullptr; createInfo.stage.pSpecializationInfo = nullptr;
if (device->fn.CreateComputePipelines(device->GetVkDevice(), VK_NULL_HANDLE, 1, &createInfo, if (device->fn.CreateComputePipelines(device->GetVkDevice(), VK_NULL_HANDLE, 1, &createInfo,

View File

@ -68,9 +68,13 @@ TEST_P(BindGroupTests, ReusedBindGroupSingleSubmit) {
dawn::ShaderModule module = dawn::ShaderModule module =
utils::CreateShaderModule(device, dawn::ShaderStage::Compute, shader); utils::CreateShaderModule(device, dawn::ShaderStage::Compute, shader);
dawn::ComputePipelineDescriptor cpDesc; dawn::ComputePipelineDescriptor cpDesc;
cpDesc.module = module;
cpDesc.entryPoint = "main";
cpDesc.layout = pl; cpDesc.layout = pl;
dawn::PipelineStageDescriptor computeStage;
computeStage.module = module;
computeStage.entryPoint = "main";
cpDesc.computeStage = &computeStage;
dawn::ComputePipeline cp = device.CreateComputePipeline(&cpDesc); dawn::ComputePipeline cp = device.CreateComputePipeline(&cpDesc);
dawn::BufferDescriptor bufferDesc; dawn::BufferDescriptor bufferDesc;

View File

@ -39,9 +39,13 @@ void ComputeCopyStorageBufferTests::BasicTest(const char* shader) {
auto pl = utils::MakeBasicPipelineLayout(device, &bgl); auto pl = utils::MakeBasicPipelineLayout(device, &bgl);
dawn::ComputePipelineDescriptor csDesc; dawn::ComputePipelineDescriptor csDesc;
csDesc.module = module;
csDesc.entryPoint = "main";
csDesc.layout = pl; csDesc.layout = pl;
dawn::PipelineStageDescriptor computeStage;
computeStage.module = module;
computeStage.entryPoint = "main";
csDesc.computeStage = &computeStage;
dawn::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc); dawn::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
// Set up src storage buffer // Set up src storage buffer

View File

@ -149,9 +149,13 @@ class PushConstantTest: public DawnTest {
); );
dawn::ComputePipelineDescriptor descriptor; dawn::ComputePipelineDescriptor descriptor;
descriptor.module = module;
descriptor.entryPoint = "main";
descriptor.layout = pl; descriptor.layout = pl;
dawn::PipelineStageDescriptor computeStage;
computeStage.module = module;
computeStage.entryPoint = "main";
descriptor.computeStage = &computeStage;
return device.CreateComputePipeline(&descriptor); return device.CreateComputePipeline(&descriptor);
} }