Produce tint::ast::Module in the frontend if UseTintGenerator

This factors code to move parsing of tint::ast::Module to the
frontend. All backends will use this code path when
UseTintGenerator is enabled for both SPIR-V and WGSL ingestion.

To avoid too much code explosion, parsing and validating the
shader is moved into ValidateShaderModuleDescriptor which
returns a result struct that gets passed into creation.

Bug: dawn:571
Change-Id: I598693ef36954fd0056a0744a2a0ebd7cc7d40a4
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/32301
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
Austin Eng 2020-12-07 18:12:13 +00:00 committed by Commit Bot service account
parent 224a3a4ab5
commit 0d948f7752
23 changed files with 388 additions and 251 deletions

View File

@ -584,7 +584,8 @@ namespace dawn_native {
} }
ResultOrError<ShaderModuleBase*> DeviceBase::GetOrCreateShaderModule( ResultOrError<ShaderModuleBase*> DeviceBase::GetOrCreateShaderModule(
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
ShaderModuleBase blueprint(this, descriptor); ShaderModuleBase blueprint(this, descriptor);
const size_t blueprintHash = blueprint.ComputeContentHash(); const size_t blueprintHash = blueprint.ComputeContentHash();
@ -597,7 +598,18 @@ namespace dawn_native {
} }
ShaderModuleBase* backendObj; ShaderModuleBase* backendObj;
DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor)); if (parseResult == nullptr) {
// We skip the parse on creation if validation isn't enabled which let's us quickly
// lookup in the cache without validating and parsing. We need the parsed module now, so
// call validate. Most of |ValidateShaderModuleDescriptor| is parsing, but we can
// consider splitting it if additional validation is added.
ASSERT(!IsValidationEnabled());
ShaderModuleParseResult localParseResult =
ValidateShaderModuleDescriptor(this, descriptor).AcquireSuccess();
DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor, &localParseResult));
} else {
DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor, parseResult));
}
backendObj->SetIsCachedReference(); backendObj->SetIsCachedReference();
backendObj->SetContentHash(blueprintHash); backendObj->SetContentHash(blueprintHash);
mCaches->shaderModules.insert(backendObj); mCaches->shaderModules.insert(backendObj);
@ -1062,10 +1074,15 @@ namespace dawn_native {
MaybeError DeviceBase::CreateShaderModuleInternal(ShaderModuleBase** result, MaybeError DeviceBase::CreateShaderModuleInternal(ShaderModuleBase** result,
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor) {
DAWN_TRY(ValidateIsAlive()); DAWN_TRY(ValidateIsAlive());
ShaderModuleParseResult parseResult = {};
ShaderModuleParseResult* parseResultPtr = nullptr;
if (IsValidationEnabled()) { if (IsValidationEnabled()) {
DAWN_TRY(ValidateShaderModuleDescriptor(this, descriptor)); DAWN_TRY_ASSIGN(parseResult, ValidateShaderModuleDescriptor(this, descriptor));
parseResultPtr = &parseResult;
} }
DAWN_TRY_ASSIGN(*result, GetOrCreateShaderModule(descriptor));
DAWN_TRY_ASSIGN(*result, GetOrCreateShaderModule(descriptor, parseResultPtr));
return {}; return {};
} }

View File

@ -40,6 +40,7 @@ namespace dawn_native {
class PersistentCache; class PersistentCache;
class StagingBufferBase; class StagingBufferBase;
struct InternalPipelineStore; struct InternalPipelineStore;
struct ShaderModuleParseResult;
class DeviceBase { class DeviceBase {
public: public:
@ -129,7 +130,8 @@ namespace dawn_native {
void UncacheSampler(SamplerBase* obj); void UncacheSampler(SamplerBase* obj);
ResultOrError<ShaderModuleBase*> GetOrCreateShaderModule( ResultOrError<ShaderModuleBase*> GetOrCreateShaderModule(
const ShaderModuleDescriptor* descriptor); const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult);
void UncacheShaderModule(ShaderModuleBase* obj); void UncacheShaderModule(ShaderModuleBase* obj);
Ref<AttachmentState> GetOrCreateAttachmentState(AttachmentStateBlueprint* blueprint); Ref<AttachmentState> GetOrCreateAttachmentState(AttachmentStateBlueprint* blueprint);
@ -275,7 +277,8 @@ namespace dawn_native {
virtual ResultOrError<SamplerBase*> CreateSamplerImpl( virtual ResultOrError<SamplerBase*> CreateSamplerImpl(
const SamplerDescriptor* descriptor) = 0; const SamplerDescriptor* descriptor) = 0;
virtual ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl( virtual ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) = 0; const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) = 0;
virtual ResultOrError<SwapChainBase*> CreateSwapChainImpl( virtual ResultOrError<SwapChainBase*> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) = 0; const SwapChainDescriptor* descriptor) = 0;
// Note that previousSwapChain may be nullptr, or come from a different backend. // Note that previousSwapChain may be nullptr, or come from a different backend.

View File

@ -173,13 +173,12 @@ namespace dawn_native {
} }
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
MaybeError ValidateWGSL(const char* source) { ResultOrError<tint::ast::Module> ParseWGSL(const char* wgsl) {
std::ostringstream errorStream; std::ostringstream errorStream;
errorStream << "Tint WGSL failure:" << std::endl; errorStream << "Tint WGSL reader failure:" << std::endl;
tint::Source::File file("", source); tint::Source::File file("", wgsl);
tint::reader::wgsl::Parser parser(&file); tint::reader::wgsl::Parser parser(&file);
if (!parser.Parse()) { if (!parser.Parse()) {
errorStream << "Parser: " << parser.error() << std::endl; errorStream << "Parser: " << parser.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
@ -191,14 +190,46 @@ namespace dawn_native {
return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
} }
tint::TypeDeterminer type_determiner(&module); tint::TypeDeterminer typeDeterminer(&module);
if (!type_determiner.Determine()) { if (!typeDeterminer.Determine()) {
errorStream << "Type Determination: " << type_determiner.error(); errorStream << "Type Determination: " << typeDeterminer.error();
return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
} }
return std::move(module);
}
ResultOrError<tint::ast::Module> ParseSPIRV(const std::vector<uint32_t>& spirv) {
std::ostringstream errorStream;
errorStream << "Tint SPIRV reader failure:" << std::endl;
tint::reader::spirv::Parser parser(spirv);
if (!parser.Parse()) {
errorStream << "Parser: " << parser.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::ast::Module module = parser.module();
if (!module.IsValid()) {
errorStream << "Invalid module generated..." << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::TypeDeterminer typeDeterminer(&module);
if (!typeDeterminer.Determine()) {
errorStream << "Type Determination: " << typeDeterminer.error();
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
return std::move(module);
}
MaybeError ValidateModule(tint::ast::Module* module) {
std::ostringstream errorStream;
errorStream << "Tint module validation" << std::endl;
tint::Validator validator; tint::Validator validator;
if (!validator.Validate(&module)) { if (!validator.Validate(module)) {
errorStream << "Validation: " << validator.error() << std::endl; errorStream << "Validation: " << validator.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
} }
@ -206,111 +237,9 @@ namespace dawn_native {
return {}; return {};
} }
ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRV(const char* source) { ResultOrError<std::vector<uint32_t>> ModuleToSPIRV(tint::ast::Module module) {
std::ostringstream errorStream; std::ostringstream errorStream;
errorStream << "Tint WGSL->SPIR-V failure:" << std::endl; errorStream << "Tint SPIR-V writer failure:" << std::endl;
tint::Source::File file("", source);
tint::reader::wgsl::Parser parser(&file);
// TODO: This is a duplicate parse with ValidateWGSL, need to store
// state between calls to avoid this.
if (!parser.Parse()) {
errorStream << "Parser: " << parser.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::ast::Module module = parser.module();
if (!module.IsValid()) {
errorStream << "Invalid module generated..." << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::TypeDeterminer type_determiner(&module);
if (!type_determiner.Determine()) {
errorStream << "Type Determination: " << type_determiner.error();
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::writer::spirv::Generator generator(std::move(module));
if (!generator.Generate()) {
errorStream << "Generator: " << generator.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
std::vector<uint32_t> spirv = generator.result();
DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
return std::move(spirv);
}
ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRVWithPulling(
const char* source,
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
uint32_t pullingBufferBindingSet) {
std::ostringstream errorStream;
errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
tint::Source::File file("", source);
tint::reader::wgsl::Parser parser(&file);
// TODO: This is a duplicate parse with ValidateWGSL, need to store
// state between calls to avoid this.
if (!parser.Parse()) {
errorStream << "Parser: " << parser.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::ast::Module module = parser.module();
if (!module.IsValid()) {
errorStream << "Invalid module generated..." << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::transform::Manager transformManager;
{
auto transform = std::make_unique<tint::transform::VertexPulling>();
tint::transform::VertexStateDescriptor state;
for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) {
auto& vertexBuffer = vertexState.vertexBuffers[i];
tint::transform::VertexBufferLayoutDescriptor layout;
layout.array_stride = vertexBuffer.arrayStride;
layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode);
for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
auto& attribute = vertexBuffer.attributes[j];
tint::transform::VertexAttributeDescriptor attr;
attr.format = ToTintVertexFormat(attribute.format);
attr.offset = attribute.offset;
attr.shader_location = attribute.shaderLocation;
layout.attributes.push_back(std::move(attr));
}
state.push_back(std::move(layout));
}
transform->SetVertexState(std::move(state));
transform->SetEntryPoint(entryPoint);
transform->SetPullingBufferBindingSet(pullingBufferBindingSet);
transformManager.append(std::move(transform));
}
auto result = transformManager.Run(&module);
if (result.diagnostics.contains_errors()) {
errorStream << "Vertex pulling transform: "
<< tint::diag::Formatter{}.format(result.diagnostics);
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
module = std::move(result.module);
tint::TypeDeterminer type_determiner(&module);
if (!type_determiner.Determine()) {
errorStream << "Type Determination: " << type_determiner.error();
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::writer::spirv::Generator generator(std::move(module)); tint::writer::spirv::Generator generator(std::move(module));
if (!generator.Generate()) { if (!generator.Generate()) {
@ -745,38 +674,16 @@ namespace dawn_native {
// completed using PopulateMetadataUsingSPIRVCross. In the future, once // completed using PopulateMetadataUsingSPIRVCross. In the future, once
// this function is complete, ReflectShaderUsingSPIRVCross and // this function is complete, ReflectShaderUsingSPIRVCross and
// PopulateMetadataUsingSPIRVCross will be removed. // PopulateMetadataUsingSPIRVCross will be removed.
ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint(DeviceBase* device, ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint(
std::vector<uint32_t> spirv) { DeviceBase* device,
const tint::ast::Module& module) {
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
ASSERT(module.IsValid());
EntryPointMetadataTable result; EntryPointMetadataTable result;
std::ostringstream errorStream; std::ostringstream errorStream;
errorStream << "Tint Reflection failure:" << std::endl; errorStream << "Tint Reflection failure:" << std::endl;
tint::reader::spirv::Parser parser(spirv);
if (!parser.Parse()) {
errorStream << "Parser: " << parser.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::ast::Module module = parser.module();
if (!module.IsValid()) {
errorStream << "Invalid module generated..." << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::TypeDeterminer typeDeterminer(&module);
if (!typeDeterminer.Determine()) {
errorStream << "Type Determination: " << typeDeterminer.error();
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::Validator validator;
if (!validator.Validate(&module)) {
errorStream << "Validation: " << validator.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::inspector::Inspector inspector(module); tint::inspector::Inspector inspector(module);
auto entryPoints = inspector.GetEntryPoints(); auto entryPoints = inspector.GetEntryPoints();
if (inspector.has_error()) { if (inspector.has_error()) {
@ -862,8 +769,17 @@ namespace dawn_native {
} // anonymous namespace } // anonymous namespace
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device, ShaderModuleParseResult::ShaderModuleParseResult() = default;
const ShaderModuleDescriptor* descriptor) { ShaderModuleParseResult::~ShaderModuleParseResult() = default;
ShaderModuleParseResult::ShaderModuleParseResult(ShaderModuleParseResult&& rhs) = default;
ShaderModuleParseResult& ShaderModuleParseResult::operator=(ShaderModuleParseResult&& rhs) =
default;
ResultOrError<ShaderModuleParseResult> ValidateShaderModuleDescriptor(
DeviceBase* device,
const ShaderModuleDescriptor* descriptor) {
const ChainedStruct* chainedDescriptor = descriptor->nextInChain; const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
if (chainedDescriptor == nullptr) { if (chainedDescriptor == nullptr) {
return DAWN_VALIDATION_ERROR("Shader module descriptor missing chained descriptor"); return DAWN_VALIDATION_ERROR("Shader module descriptor missing chained descriptor");
@ -874,11 +790,29 @@ namespace dawn_native {
"Shader module descriptor chained nextInChain must be nullptr"); "Shader module descriptor chained nextInChain must be nullptr");
} }
ShaderModuleParseResult parseResult = {};
switch (chainedDescriptor->sType) { switch (chainedDescriptor->sType) {
case wgpu::SType::ShaderModuleSPIRVDescriptor: { case wgpu::SType::ShaderModuleSPIRVDescriptor: {
const auto* spirvDesc = const auto* spirvDesc =
static_cast<const ShaderModuleSPIRVDescriptor*>(chainedDescriptor); static_cast<const ShaderModuleSPIRVDescriptor*>(chainedDescriptor);
DAWN_TRY(ValidateSpirv(spirvDesc->code, spirvDesc->codeSize)); std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
#ifdef DAWN_ENABLE_WGSL
tint::ast::Module module;
DAWN_TRY_ASSIGN(module, ParseSPIRV(spirv));
if (device->IsValidationEnabled()) {
DAWN_TRY(ValidateModule(&module));
}
parseResult.tintModule = std::make_unique<tint::ast::Module>(std::move(module));
#else
return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
#endif // DAWN_ENABLE_WGSL
} else {
if (device->IsValidationEnabled()) {
DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
}
parseResult.spirv = std::move(spirv);
}
break; break;
} }
@ -886,17 +820,35 @@ namespace dawn_native {
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
const auto* wgslDesc = const auto* wgslDesc =
static_cast<const ShaderModuleWGSLDescriptor*>(chainedDescriptor); static_cast<const ShaderModuleWGSLDescriptor*>(chainedDescriptor);
DAWN_TRY(ValidateWGSL(wgslDesc->source));
if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
tint::ast::Module module;
DAWN_TRY_ASSIGN(module, ParseWGSL(wgslDesc->source));
if (device->IsValidationEnabled()) {
DAWN_TRY(ValidateModule(&module));
}
parseResult.tintModule = std::make_unique<tint::ast::Module>(std::move(module));
} else {
tint::ast::Module module;
DAWN_TRY_ASSIGN(module, ParseWGSL(wgslDesc->source));
if (device->IsValidationEnabled()) {
DAWN_TRY(ValidateModule(&module));
}
std::vector<uint32_t> spirv;
DAWN_TRY_ASSIGN(spirv, ModuleToSPIRV(std::move(module)));
DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
parseResult.spirv = std::move(spirv);
}
break; break;
#else #else
return DAWN_VALIDATION_ERROR("WGSL not supported (yet)"); return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
#endif // DAWN_ENABLE_WGSL #endif // DAWN_ENABLE_WGSL
} }
default: default:
return DAWN_VALIDATION_ERROR("Unsupported sType"); return DAWN_VALIDATION_ERROR("Unsupported sType");
} }
return {}; return std::move(parseResult);
} }
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint, RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
@ -910,6 +862,23 @@ namespace dawn_native {
return bufferSizes; return bufferSizes;
} }
#ifdef DAWN_ENABLE_WGSL
ResultOrError<tint::ast::Module> RunTransforms(tint::transform::Manager* manager,
tint::ast::Module* module) {
tint::transform::Transform::Output output = manager->Run(module);
if (output.diagnostics.contains_errors()) {
std::string err =
"Tint transform failure: " + tint::diag::Formatter{}.format(output.diagnostics);
return DAWN_VALIDATION_ERROR(err.c_str());
}
if (!output.module.IsValid()) {
return DAWN_VALIDATION_ERROR("Tint transform did not produce valid module.");
}
return std::move(output.module);
}
#endif
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device, MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
const EntryPointMetadata& entryPoint, const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout) { const PipelineLayoutBase* layout) {
@ -994,49 +963,136 @@ namespace dawn_native {
} }
const std::vector<uint32_t>& ShaderModuleBase::GetSpirv() const { const std::vector<uint32_t>& ShaderModuleBase::GetSpirv() const {
ASSERT(!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
return mSpirv; return mSpirv;
} }
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv( ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
const std::vector<uint32_t>& spirv,
const VertexStateDescriptor& vertexState, const VertexStateDescriptor& vertexState,
const std::string& entryPoint, const std::string& entryPoint,
uint32_t pullingBufferBindingSet) const { uint32_t pullingBufferBindingSet) const {
std::vector<uint32_t> spirv; tint::ast::Module module;
DAWN_TRY_ASSIGN(spirv, ConvertWGSLToSPIRVWithPulling(mWgsl.c_str(), vertexState, entryPoint, DAWN_TRY_ASSIGN(module, ParseSPIRV(spirv));
pullingBufferBindingSet));
return GeneratePullingSpirv(&module, vertexState, entryPoint, pullingBufferBindingSet);
}
ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
tint::ast::Module* moduleIn,
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
uint32_t pullingBufferBindingSet) const {
std::ostringstream errorStream;
errorStream << "Tint vertex pulling failure:" << std::endl;
tint::transform::Manager transformManager;
{
auto transform = std::make_unique<tint::transform::VertexPulling>();
tint::transform::VertexStateDescriptor state;
for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) {
const auto& vertexBuffer = vertexState.vertexBuffers[i];
tint::transform::VertexBufferLayoutDescriptor layout;
layout.array_stride = vertexBuffer.arrayStride;
layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode);
for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) {
const auto& attribute = vertexBuffer.attributes[j];
tint::transform::VertexAttributeDescriptor attr;
attr.format = ToTintVertexFormat(attribute.format);
attr.offset = attribute.offset;
attr.shader_location = attribute.shaderLocation;
layout.attributes.push_back(std::move(attr));
}
state.push_back(std::move(layout));
}
transform->SetVertexState(std::move(state));
transform->SetEntryPoint(entryPoint);
transform->SetPullingBufferBindingSet(pullingBufferBindingSet);
transformManager.append(std::move(transform));
}
if (GetDevice()->IsRobustnessEnabled()) {
// TODO(enga): Run the Tint BoundArrayAccessors transform instead of the SPIRV Tools
// one, but it appears to crash after running VertexPulling.
// transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
}
tint::ast::Module module;
DAWN_TRY_ASSIGN(module, RunTransforms(&transformManager, moduleIn));
tint::writer::spirv::Generator generator(std::move(module));
if (!generator.Generate()) {
errorStream << "Generator: " << generator.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
std::vector<uint32_t> spirv = generator.result();
if (GetDevice()->IsRobustnessEnabled()) { if (GetDevice()->IsRobustnessEnabled()) {
DAWN_TRY_ASSIGN(spirv, RunRobustBufferAccessPass(spirv)); DAWN_TRY_ASSIGN(spirv, RunRobustBufferAccessPass(spirv));
} }
DAWN_TRY(ValidateSpirv(spirv.data(), spirv.size()));
return std::move(spirv); return std::move(spirv);
} }
#endif #endif
MaybeError ShaderModuleBase::InitializeBase() { MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) {
std::vector<uint32_t> spirv;
if (mType == Type::Wgsl) {
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
DAWN_TRY_ASSIGN(spirv, ConvertWGSLToSPIRV(mWgsl.c_str())); tint::ast::Module* module = parseResult->tintModule.get();
#endif
mSpirv = std::move(parseResult->spirv);
// If not using Tint to generate backend code, run the robust buffer access pass now since
// all backends will use this SPIR-V. If Tint is used, the robustness pass should be run
// per-backend.
if (!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator) &&
GetDevice()->IsRobustnessEnabled()) {
DAWN_TRY_ASSIGN(mSpirv, RunRobustBufferAccessPass(mSpirv));
}
// We still need the spirv for reflection. Remove this when we use the Tint inspector
// completely.
std::vector<uint32_t>* spirvPtr = &mSpirv;
std::vector<uint32_t> localSpirv;
if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
#ifdef DAWN_ENABLE_WGSL
ASSERT(module != nullptr);
tint::ast::Module clonedModule = module->Clone();
tint::TypeDeterminer typeDeterminer(&clonedModule);
if (!typeDeterminer.Determine()) {
return DAWN_VALIDATION_ERROR(typeDeterminer.error().c_str());
}
DAWN_TRY_ASSIGN(localSpirv, ModuleToSPIRV(std::move(clonedModule)));
DAWN_TRY(ValidateSpirv(localSpirv.data(), localSpirv.size()));
spirvPtr = &localSpirv;
#else
UNREACHABLE();
#endif
}
if (GetDevice()->IsToggleEnabled(Toggle::UseTintInspector)) {
#ifdef DAWN_ENABLE_WGSL
tint::ast::Module localModule;
tint::ast::Module* modulePtr = module;
if (!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
// We have mSpirv, but no Tint module
DAWN_TRY_ASSIGN(localModule, ParseSPIRV(mSpirv));
DAWN_TRY(ValidateModule(&localModule));
modulePtr = &localModule;
}
EntryPointMetadataTable table;
DAWN_TRY_ASSIGN(table, ReflectShaderUsingTint(GetDevice(), *modulePtr));
DAWN_TRY(PopulateMetadataUsingSPIRVCross(GetDevice(), *spirvPtr, &table));
mEntryPoints = std::move(table);
#else #else
return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build."); return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
#endif // DAWN_ENABLE_WGSL #endif
} else { } else {
spirv = mOriginalSpirv; DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), *spirvPtr));
}
if (GetDevice()->IsRobustnessEnabled()) {
DAWN_TRY_ASSIGN(spirv, RunRobustBufferAccessPass(spirv));
}
mSpirv = std::move(spirv);
if (GetDevice()->IsToggleEnabled(Toggle::UseTintInspector)) {
EntryPointMetadataTable table;
DAWN_TRY_ASSIGN(table, ReflectShaderUsingTint(GetDevice(), mSpirv));
DAWN_TRY(PopulateMetadataUsingSPIRVCross(GetDevice(), mSpirv, &table));
mEntryPoints = std::move(table);
} else {
DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), mSpirv));
} }
return {}; return {};

View File

@ -31,6 +31,18 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
namespace tint {
namespace ast {
class Module;
} // namespace ast
namespace transform {
class Manager;
} // namespace transform
} // namespace tint
namespace spirv_cross { namespace spirv_cross {
class Compiler; class Compiler;
} }
@ -43,14 +55,31 @@ namespace dawn_native {
using EntryPointMetadataTable = using EntryPointMetadataTable =
std::unordered_map<std::string, std::unique_ptr<EntryPointMetadata>>; std::unordered_map<std::string, std::unique_ptr<EntryPointMetadata>>;
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device, struct ShaderModuleParseResult {
const ShaderModuleDescriptor* descriptor); ShaderModuleParseResult();
~ShaderModuleParseResult();
ShaderModuleParseResult(ShaderModuleParseResult&& rhs);
ShaderModuleParseResult& operator=(ShaderModuleParseResult&& rhs);
#ifdef DAWN_ENABLE_WGSL
std::unique_ptr<tint::ast::Module> tintModule;
#endif
std::vector<uint32_t> spirv;
};
ResultOrError<ShaderModuleParseResult> ValidateShaderModuleDescriptor(
DeviceBase* device,
const ShaderModuleDescriptor* descriptor);
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device, MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
const EntryPointMetadata& entryPoint, const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout); const PipelineLayoutBase* layout);
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint, RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout); const PipelineLayoutBase* layout);
#ifdef DAWN_ENABLE_WGSL
ResultOrError<tint::ast::Module> RunTransforms(tint::transform::Manager* manager,
tint::ast::Module* module);
#endif
// Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are // Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
// stored in the ShaderModuleBase and destroyed only when the shader module is destroyed so // stored in the ShaderModuleBase and destroyed only when the shader module is destroyed so
@ -116,13 +145,20 @@ namespace dawn_native {
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv( ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
const std::vector<uint32_t>& spirv,
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
uint32_t pullingBufferBindingSet) const;
ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
tint::ast::Module* module,
const VertexStateDescriptor& vertexState, const VertexStateDescriptor& vertexState,
const std::string& entryPoint, const std::string& entryPoint,
uint32_t pullingBufferBindingSet) const; uint32_t pullingBufferBindingSet) const;
#endif #endif
protected: protected:
MaybeError InitializeBase(); MaybeError InitializeBase(ShaderModuleParseResult* parseResult);
private: private:
ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag); ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);

View File

@ -143,7 +143,7 @@ namespace dawn_native {
"http://crbug.com/1138528"}}, "http://crbug.com/1138528"}},
{Toggle::UseTintGenerator, {Toggle::UseTintGenerator,
{"use_tint_generator", "Use Tint instead of SPRIV-cross to generate shaders.", {"use_tint_generator", "Use Tint instead of SPRIV-cross to generate shaders.",
"https://crbug.com/dawn/548"}}, "https://crbug.com/dawn/571"}},
{Toggle::UseTintInspector, {Toggle::UseTintInspector,
{"use_tint_inspector", "Use Tint instead of SPRIV-cross for shader reflection.", {"use_tint_inspector", "Use Tint instead of SPRIV-cross for shader reflection.",
"https://crbug.com/dawn/578"}}, "https://crbug.com/dawn/578"}},

View File

@ -325,8 +325,9 @@ namespace dawn_native { namespace d3d12 {
return new Sampler(this, descriptor); return new Sampler(this, descriptor);
} }
ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl( ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor,
return ShaderModule::Create(this, descriptor); ShaderModuleParseResult* parseResult) {
return ShaderModule::Create(this, descriptor, parseResult);
} }
ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl( ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) { const SwapChainDescriptor* descriptor) {

View File

@ -160,7 +160,8 @@ namespace dawn_native { namespace d3d12 {
const RenderPipelineDescriptor* descriptor) override; const RenderPipelineDescriptor* descriptor) override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl( ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) override; const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override;
ResultOrError<SwapChainBase*> CreateSwapChainImpl( ResultOrError<SwapChainBase*> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override; const SwapChainDescriptor* descriptor) override;
ResultOrError<NewSwapChainBase*> CreateSwapChainImpl( ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(

View File

@ -174,9 +174,10 @@ namespace dawn_native { namespace d3d12 {
// static // static
ResultOrError<ShaderModule*> ShaderModule::Create(Device* device, ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor)); Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
DAWN_TRY(module->InitializeBase()); DAWN_TRY(module->Initialize(parseResult));
return module.Detach(); return module.Detach();
} }
@ -184,6 +185,14 @@ namespace dawn_native { namespace d3d12 {
: ShaderModuleBase(device, descriptor) { : ShaderModuleBase(device, descriptor) {
} }
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
DAWN_TRY(InitializeBase(parseResult));
#ifdef DAWN_ENABLE_WGSL
mTintModule = std::move(parseResult->tintModule);
#endif
return {};
}
ResultOrError<std::string> ShaderModule::TranslateToHLSLWithTint( ResultOrError<std::string> ShaderModule::TranslateToHLSLWithTint(
const char* entryPointName, const char* entryPointName,
SingleShaderStage stage, SingleShaderStage stage,
@ -195,41 +204,11 @@ namespace dawn_native { namespace d3d12 {
std::ostringstream errorStream; std::ostringstream errorStream;
errorStream << "Tint HLSL failure:" << std::endl; errorStream << "Tint HLSL failure:" << std::endl;
// TODO: Remove redundant SPIRV step between WGSL and HLSL.
tint::reader::spirv::Parser parser(GetSpirv());
if (!parser.Parse()) {
errorStream << "Parser: " << parser.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::ast::Module module = parser.module();
if (!module.IsValid()) {
errorStream << "Invalid module generated..." << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::TypeDeterminer typeDeterminer(&module);
if (!typeDeterminer.Determine()) {
errorStream << "Type Determination: " << typeDeterminer.error();
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::Validator validator;
if (!validator.Validate(&module)) {
errorStream << "Validation: " << validator.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::transform::Manager transformManager; tint::transform::Manager transformManager;
transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>()); transformManager.append(std::make_unique<tint::transform::BoundArrayAccessors>());
auto result = transformManager.Run(&module);
if (result.diagnostics.contains_errors()) { tint::ast::Module module;
errorStream << "Bound Array Accessors Transform: " DAWN_TRY_ASSIGN(module, RunTransforms(&transformManager, mTintModule.get()));
<< tint::diag::Formatter{}.format(result.diagnostics);
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
module = std::move(result.module);
ASSERT(remappedEntryPointName != nullptr); ASSERT(remappedEntryPointName != nullptr);
tint::inspector::Inspector inspector(module); tint::inspector::Inspector inspector(module);

View File

@ -36,7 +36,8 @@ namespace dawn_native { namespace d3d12 {
class ShaderModule final : public ShaderModuleBase { class ShaderModule final : public ShaderModuleBase {
public: public:
static ResultOrError<ShaderModule*> Create(Device* device, static ResultOrError<ShaderModule*> Create(Device* device,
const ShaderModuleDescriptor* descriptor); const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult);
ResultOrError<CompiledShader> Compile(const char* entryPointName, ResultOrError<CompiledShader> Compile(const char* entryPointName,
SingleShaderStage stage, SingleShaderStage stage,
@ -46,6 +47,7 @@ namespace dawn_native { namespace d3d12 {
private: private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default; ~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult);
ResultOrError<std::string> TranslateToHLSLWithTint( ResultOrError<std::string> TranslateToHLSLWithTint(
const char* entryPointName, const char* entryPointName,
@ -61,6 +63,10 @@ namespace dawn_native { namespace d3d12 {
SingleShaderStage stage, SingleShaderStage stage,
const std::string& hlslSource, const std::string& hlslSource,
uint32_t compileFlags) const; uint32_t compileFlags) const;
#ifdef DAWN_ENABLE_WGSL
std::unique_ptr<tint::ast::Module> mTintModule;
#endif
}; };
}} // namespace dawn_native::d3d12 }} // namespace dawn_native::d3d12

View File

@ -92,7 +92,8 @@ namespace dawn_native { namespace metal {
const RenderPipelineDescriptor* descriptor) override; const RenderPipelineDescriptor* descriptor) override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl( ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) override; const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override;
ResultOrError<SwapChainBase*> CreateSwapChainImpl( ResultOrError<SwapChainBase*> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override; const SwapChainDescriptor* descriptor) override;
ResultOrError<NewSwapChainBase*> CreateSwapChainImpl( ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(

View File

@ -151,8 +151,9 @@ namespace dawn_native { namespace metal {
return Sampler::Create(this, descriptor); return Sampler::Create(this, descriptor);
} }
ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl( ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor,
return ShaderModule::Create(this, descriptor); ShaderModuleParseResult* parseResult) {
return ShaderModule::Create(this, descriptor, parseResult);
} }
ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl( ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) { const SwapChainDescriptor* descriptor) {

View File

@ -35,7 +35,8 @@ namespace dawn_native { namespace metal {
class ShaderModule final : public ShaderModuleBase { class ShaderModule final : public ShaderModuleBase {
public: public:
static ResultOrError<ShaderModule*> Create(Device* device, static ResultOrError<ShaderModule*> Create(Device* device,
const ShaderModuleDescriptor* descriptor); const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult);
struct MetalFunctionData { struct MetalFunctionData {
NSPRef<id<MTLFunction>> function; NSPRef<id<MTLFunction>> function;
@ -51,7 +52,11 @@ namespace dawn_native { namespace metal {
private: private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default; ~ShaderModule() override = default;
MaybeError Initialize(); MaybeError Initialize(ShaderModuleParseResult* parseResult);
#ifdef DAWN_ENABLE_WGSL
std::unique_ptr<tint::ast::Module> mTintModule;
#endif
}; };
}} // namespace dawn_native::metal }} // namespace dawn_native::metal

View File

@ -22,15 +22,24 @@
#include <spirv_msl.hpp> #include <spirv_msl.hpp>
#ifdef DAWN_ENABLE_WGSL
// Tint include must be after spirv_msl.hpp, because spirv-cross has its own
// version of spirv_headers. We also need to undef SPV_REVISION because SPIRV-Cross
// is at 3 while spirv-headers is at 4.
# undef SPV_REVISION
# include <tint/tint.h>
#endif // DAWN_ENABLE_WGSL
#include <sstream> #include <sstream>
namespace dawn_native { namespace metal { namespace dawn_native { namespace metal {
// static // static
ResultOrError<ShaderModule*> ShaderModule::Create(Device* device, ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor)); Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
DAWN_TRY(module->Initialize()); DAWN_TRY(module->Initialize(parseResult));
return module.Detach(); return module.Detach();
} }
@ -38,8 +47,12 @@ namespace dawn_native { namespace metal {
: ShaderModuleBase(device, descriptor) { : ShaderModuleBase(device, descriptor) {
} }
MaybeError ShaderModule::Initialize() { MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
return InitializeBase(); DAWN_TRY(InitializeBase(parseResult));
#ifdef DAWN_ENABLE_WGSL
mTintModule = std::move(parseResult->tintModule);
#endif
return {};
} }
MaybeError ShaderModule::CreateFunction(const char* entryPointName, MaybeError ShaderModule::CreateFunction(const char* entryPointName,
@ -59,9 +72,17 @@ namespace dawn_native { namespace metal {
std::vector<uint32_t> pullingSpirv; std::vector<uint32_t> pullingSpirv;
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
stage == SingleShaderStage::Vertex) { stage == SingleShaderStage::Vertex) {
DAWN_TRY_ASSIGN(pullingSpirv, if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
GeneratePullingSpirv(*renderPipeline->GetVertexStateDescriptor(), DAWN_TRY_ASSIGN(pullingSpirv,
entryPointName, kPullingBufferBindingSet)); GeneratePullingSpirv(mTintModule.get(),
*renderPipeline->GetVertexStateDescriptor(),
entryPointName, kPullingBufferBindingSet));
} else {
DAWN_TRY_ASSIGN(
pullingSpirv,
GeneratePullingSpirv(GetSpirv(), *renderPipeline->GetVertexStateDescriptor(),
entryPointName, kPullingBufferBindingSet));
}
spirv = &pullingSpirv; spirv = &pullingSpirv;
} }
#endif #endif

View File

@ -127,9 +127,10 @@ namespace dawn_native { namespace null {
return new Sampler(this, descriptor); return new Sampler(this, descriptor);
} }
ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl( ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor)); Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor));
DAWN_TRY(module->Initialize()); DAWN_TRY(module->Initialize(parseResult));
return module.Detach(); return module.Detach();
} }
ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl( ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
@ -395,8 +396,8 @@ namespace dawn_native { namespace null {
// ShaderModule // ShaderModule
MaybeError ShaderModule::Initialize() { MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
return InitializeBase(); return InitializeBase(parseResult);
} }
// OldSwapChain // OldSwapChain

View File

@ -135,7 +135,8 @@ namespace dawn_native { namespace null {
const RenderPipelineDescriptor* descriptor) override; const RenderPipelineDescriptor* descriptor) override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl( ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) override; const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override;
ResultOrError<SwapChainBase*> CreateSwapChainImpl( ResultOrError<SwapChainBase*> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override; const SwapChainDescriptor* descriptor) override;
ResultOrError<NewSwapChainBase*> CreateSwapChainImpl( ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(
@ -246,7 +247,7 @@ namespace dawn_native { namespace null {
public: public:
using ShaderModuleBase::ShaderModuleBase; using ShaderModuleBase::ShaderModuleBase;
MaybeError Initialize(); MaybeError Initialize(ShaderModuleParseResult* parseResult);
}; };
class SwapChain final : public NewSwapChainBase { class SwapChain final : public NewSwapChainBase {

View File

@ -128,8 +128,9 @@ namespace dawn_native { namespace opengl {
return new Sampler(this, descriptor); return new Sampler(this, descriptor);
} }
ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl( ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor,
return ShaderModule::Create(this, descriptor); ShaderModuleParseResult* parseResult) {
return ShaderModule::Create(this, descriptor, parseResult);
} }
ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl( ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) { const SwapChainDescriptor* descriptor) {

View File

@ -91,7 +91,8 @@ namespace dawn_native { namespace opengl {
const RenderPipelineDescriptor* descriptor) override; const RenderPipelineDescriptor* descriptor) override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl( ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) override; const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override;
ResultOrError<SwapChainBase*> CreateSwapChainImpl( ResultOrError<SwapChainBase*> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override; const SwapChainDescriptor* descriptor) override;
ResultOrError<NewSwapChainBase*> CreateSwapChainImpl( ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(

View File

@ -59,9 +59,10 @@ namespace dawn_native { namespace opengl {
// static // static
ResultOrError<ShaderModule*> ShaderModule::Create(Device* device, ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor)); Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
DAWN_TRY(module->InitializeBase()); DAWN_TRY(module->InitializeBase(parseResult));
return module.Detach(); return module.Detach();
} }

View File

@ -47,7 +47,8 @@ namespace dawn_native { namespace opengl {
class ShaderModule final : public ShaderModuleBase { class ShaderModule final : public ShaderModuleBase {
public: public:
static ResultOrError<ShaderModule*> Create(Device* device, static ResultOrError<ShaderModule*> Create(Device* device,
const ShaderModuleDescriptor* descriptor); const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult);
std::string TranslateToGLSL(const char* entryPointName, std::string TranslateToGLSL(const char* entryPointName,
SingleShaderStage stage, SingleShaderStage stage,

View File

@ -136,8 +136,9 @@ namespace dawn_native { namespace vulkan {
return Sampler::Create(this, descriptor); return Sampler::Create(this, descriptor);
} }
ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl( ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor,
return ShaderModule::Create(this, descriptor); ShaderModuleParseResult* parseResult) {
return ShaderModule::Create(this, descriptor, parseResult);
} }
ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl( ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) { const SwapChainDescriptor* descriptor) {

View File

@ -127,7 +127,8 @@ namespace dawn_native { namespace vulkan {
const RenderPipelineDescriptor* descriptor) override; const RenderPipelineDescriptor* descriptor) override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl( ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) override; const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override;
ResultOrError<SwapChainBase*> CreateSwapChainImpl( ResultOrError<SwapChainBase*> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override; const SwapChainDescriptor* descriptor) override;
ResultOrError<NewSwapChainBase*> CreateSwapChainImpl( ResultOrError<NewSwapChainBase*> CreateSwapChainImpl(

View File

@ -24,12 +24,13 @@ namespace dawn_native { namespace vulkan {
// static // static
ResultOrError<ShaderModule*> ShaderModule::Create(Device* device, ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor)); Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
if (module == nullptr) { if (module == nullptr) {
return DAWN_VALIDATION_ERROR("Unable to create ShaderModule"); return DAWN_VALIDATION_ERROR("Unable to create ShaderModule");
} }
DAWN_TRY(module->Initialize()); DAWN_TRY(module->Initialize(parseResult));
return module.Detach(); return module.Detach();
} }
@ -37,8 +38,8 @@ namespace dawn_native { namespace vulkan {
: ShaderModuleBase(device, descriptor) { : ShaderModuleBase(device, descriptor) {
} }
MaybeError ShaderModule::Initialize() { MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
DAWN_TRY(InitializeBase()); DAWN_TRY(InitializeBase(parseResult));
const std::vector<uint32_t>& spirv = GetSpirv(); const std::vector<uint32_t>& spirv = GetSpirv();
VkShaderModuleCreateInfo createInfo; VkShaderModuleCreateInfo createInfo;

View File

@ -27,14 +27,15 @@ namespace dawn_native { namespace vulkan {
class ShaderModule final : public ShaderModuleBase { class ShaderModule final : public ShaderModuleBase {
public: public:
static ResultOrError<ShaderModule*> Create(Device* device, static ResultOrError<ShaderModule*> Create(Device* device,
const ShaderModuleDescriptor* descriptor); const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult);
VkShaderModule GetHandle() const; VkShaderModule GetHandle() const;
private: private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override; ~ShaderModule() override;
MaybeError Initialize(); MaybeError Initialize(ShaderModuleParseResult* parseResult);
VkShaderModule mHandle = VK_NULL_HANDLE; VkShaderModule mHandle = VK_NULL_HANDLE;
}; };