ShaderModule: Store the tint::Program in the base class.

This is in preparation of removing all the DAWN_ENABLE_WGSL logic: the
ShaderModuleBase will have either mSpirv or mTintProgram set based on
UseTintGenerator.

Also improves the constness of some functions.

Also simplifies a bit ShaderModuleBase::Initialize.

Bug: dawn:706
Change-Id: Ib879e2aec8a004aeb8ac5dc6e1176b1667fc227d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/45422
Commit-Queue: Austin Eng <enga@chromium.org>
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
This commit is contained in:
Corentin Wallez 2021-03-22 18:24:16 +00:00 committed by Commit Bot service account
parent 5ff4978a5d
commit 3e4b57b77e
6 changed files with 42 additions and 66 deletions

View File

@ -266,7 +266,7 @@ namespace dawn_native {
return std::move(program); return std::move(program);
} }
MaybeError ValidateModule(tint::Program* program) { MaybeError ValidateModule(const tint::Program* program) {
std::ostringstream errorStream; std::ostringstream errorStream;
errorStream << "Tint program validation" << std::endl; errorStream << "Tint program validation" << std::endl;
@ -729,14 +729,14 @@ namespace dawn_native {
// PopulateMetadataUsingSPIRVCross will be removed. // PopulateMetadataUsingSPIRVCross will be removed.
ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint( ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint(
DeviceBase*, DeviceBase*,
const tint::Program& program) { const tint::Program* program) {
ASSERT(program.IsValid()); ASSERT(program->IsValid());
EntryPointMetadataTable result; EntryPointMetadataTable result;
std::ostringstream errorStream; std::ostringstream errorStream;
errorStream << "Tint Reflection failure:" << std::endl; errorStream << "Tint Reflection failure:" << std::endl;
tint::inspector::Inspector inspector(&program); tint::inspector::Inspector inspector(program);
auto entryPoints = inspector.GetEntryPoints(); auto entryPoints = inspector.GetEntryPoints();
if (inspector.has_error()) { if (inspector.has_error()) {
errorStream << "Inspector: " << inspector.error() << std::endl; errorStream << "Inspector: " << inspector.error() << std::endl;
@ -842,7 +842,7 @@ namespace dawn_native {
// fallback/source of truth. // fallback/source of truth.
ResultOrError<EntryPointMetadataTable> ReflectShaderUsingSPIRVCross( ResultOrError<EntryPointMetadataTable> ReflectShaderUsingSPIRVCross(
DeviceBase* device, DeviceBase* device,
std::vector<uint32_t> spirv) { const std::vector<uint32_t>& spirv) {
EntryPointMetadataTable result; EntryPointMetadataTable result;
spirv_cross::Compiler compiler(spirv); spirv_cross::Compiler compiler(spirv);
for (const spirv_cross::EntryPoint& entryPoint : for (const spirv_cross::EntryPoint& entryPoint :
@ -866,7 +866,7 @@ namespace dawn_native {
// using the SPIRV-Cross implementation. Once the Tint implementation is // using the SPIRV-Cross implementation. Once the Tint implementation is
// completed, this function will be removed. // completed, this function will be removed.
MaybeError PopulateMetadataUsingSPIRVCross(DeviceBase* device, MaybeError PopulateMetadataUsingSPIRVCross(DeviceBase* device,
std::vector<uint32_t> spirv, const std::vector<uint32_t>& spirv,
EntryPointMetadataTable* tintTable) { EntryPointMetadataTable* tintTable) {
EntryPointMetadataTable crossTable; EntryPointMetadataTable crossTable;
DAWN_TRY_ASSIGN(crossTable, ReflectShaderUsingSPIRVCross(device, spirv)); DAWN_TRY_ASSIGN(crossTable, ReflectShaderUsingSPIRVCross(device, spirv));
@ -998,6 +998,7 @@ namespace dawn_native {
if (device->IsValidationEnabled()) { if (device->IsValidationEnabled()) {
DAWN_TRY(ValidateModule(&program)); DAWN_TRY(ValidateModule(&program));
} }
parseResult.tintProgram = std::make_unique<tint::Program>(std::move(program));
} else { } else {
tint::transform::Manager transformManager; tint::transform::Manager transformManager;
transformManager.append( transformManager.append(
@ -1015,8 +1016,6 @@ namespace dawn_native {
parseResult.spirv = std::move(spirv); parseResult.spirv = std::move(spirv);
} }
parseResult.tintProgram = std::make_unique<tint::Program>(std::move(program));
break; break;
#else #else
return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build."); return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
@ -1042,7 +1041,7 @@ namespace dawn_native {
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform, ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
tint::Program* program) { const tint::Program* program) {
tint::transform::Transform::Output output = transform->Run(program); tint::transform::Transform::Output output = transform->Run(program);
if (!output.program.IsValid()) { if (!output.program.IsValid()) {
std::string err = "Tint program failure: " + output.program.Diagnostics().str(); std::string err = "Tint program failure: " + output.program.Diagnostics().str();
@ -1169,6 +1168,11 @@ namespace dawn_native {
} }
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
const tint::Program* ShaderModuleBase::GetTintProgram() const {
ASSERT(GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
return mTintProgram.get();
}
ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv( ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
const std::vector<uint32_t>& spirv, const std::vector<uint32_t>& spirv,
const VertexStateDescriptor& vertexState, const VertexStateDescriptor& vertexState,
@ -1181,7 +1185,7 @@ namespace dawn_native {
} }
ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv( ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
tint::Program* programIn, const tint::Program* programIn,
const VertexStateDescriptor& vertexState, const VertexStateDescriptor& vertexState,
const std::string& entryPoint, const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet) const { BindGroupIndex pullingBufferBindingSet) const {
@ -1214,7 +1218,7 @@ namespace dawn_native {
MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) { MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) {
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
tint::Program* program = parseResult->tintProgram.get(); mTintProgram = std::move(parseResult->tintProgram);
#endif #endif
mSpirv = std::move(parseResult->spirv); mSpirv = std::move(parseResult->spirv);
@ -1226,43 +1230,23 @@ namespace dawn_native {
DAWN_TRY_ASSIGN(mSpirv, RunRobustBufferAccessPass(mSpirv)); DAWN_TRY_ASSIGN(mSpirv, RunRobustBufferAccessPass(mSpirv));
} }
// We still need the spirv for reflection. Remove this when we use the Tint inspector
// completely.
std::vector<uint32_t>* spirvPtr = &mSpirv;
std::vector<uint32_t> localSpirv;
if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) { if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
ASSERT(program != nullptr); // We still need the spirv for reflection. Remove this when we use the Tint inspector
// completely.
std::vector<uint32_t> reflectionSpirv;
DAWN_TRY_ASSIGN(reflectionSpirv, ModuleToSPIRV(mTintProgram.get()));
DAWN_TRY(ValidateSpirv(reflectionSpirv.data(), reflectionSpirv.size()));
DAWN_TRY_ASSIGN(localSpirv, ModuleToSPIRV(program)); EntryPointMetadataTable table;
DAWN_TRY(ValidateSpirv(localSpirv.data(), localSpirv.size())); DAWN_TRY_ASSIGN(table, ReflectShaderUsingTint(GetDevice(), mTintProgram.get()));
spirvPtr = &localSpirv; DAWN_TRY(PopulateMetadataUsingSPIRVCross(GetDevice(), reflectionSpirv, &table));
mEntryPoints = std::move(table);
#else #else
UNREACHABLE(); UNREACHABLE();
#endif
}
if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
#ifdef DAWN_ENABLE_WGSL
tint::Program localProgram;
tint::Program* programPtr = program;
if (!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
// We have mSpirv, but no Tint program
DAWN_TRY_ASSIGN(localProgram, ParseSPIRV(mSpirv));
DAWN_TRY(ValidateModule(&localProgram));
programPtr = &localProgram;
}
EntryPointMetadataTable table;
DAWN_TRY_ASSIGN(table, ReflectShaderUsingTint(GetDevice(), *programPtr));
DAWN_TRY(PopulateMetadataUsingSPIRVCross(GetDevice(), *spirvPtr, &table));
mEntryPoints = std::move(table);
#else
return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
#endif #endif
} else { } else {
DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), *spirvPtr)); DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), mSpirv));
} }
return {}; return {};

View File

@ -77,7 +77,7 @@ namespace dawn_native {
const PipelineLayoutBase* layout); const PipelineLayoutBase* layout);
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform, ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
tint::Program* program); const tint::Program* program);
std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform( std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform(
const VertexStateDescriptor& vertexState, const VertexStateDescriptor& vertexState,
@ -147,6 +147,8 @@ namespace dawn_native {
const std::vector<uint32_t>& GetSpirv() const; const std::vector<uint32_t>& GetSpirv() const;
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
const tint::Program* GetTintProgram() const;
ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv( ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
const std::vector<uint32_t>& spirv, const std::vector<uint32_t>& spirv,
const VertexStateDescriptor& vertexState, const VertexStateDescriptor& vertexState,
@ -154,7 +156,7 @@ namespace dawn_native {
BindGroupIndex pullingBufferBindingSet) const; BindGroupIndex pullingBufferBindingSet) const;
ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv( ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
tint::Program* program, const tint::Program* program,
const VertexStateDescriptor& vertexState, const VertexStateDescriptor& vertexState,
const std::string& entryPoint, const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet) const; BindGroupIndex pullingBufferBindingSet) const;
@ -166,13 +168,19 @@ namespace dawn_native {
private: private:
ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag); ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);
// The original data in the descriptor for caching.
enum class Type { Undefined, Spirv, Wgsl }; enum class Type { Undefined, Spirv, Wgsl };
Type mType; Type mType;
std::vector<uint32_t> mOriginalSpirv; std::vector<uint32_t> mOriginalSpirv;
std::vector<uint32_t> mSpirv;
std::string mWgsl; std::string mWgsl;
// Data computed from what is in the descriptor. mSpirv is set iff !UseTintGenerator while
// mTintProgram is set iff UseTintGenerator.
EntryPointMetadataTable mEntryPoints; EntryPointMetadataTable mEntryPoints;
std::vector<uint32_t> mSpirv;
#ifdef DAWN_ENABLE_WGSL
std::unique_ptr<tint::Program> mTintProgram;
#endif
}; };
} // namespace dawn_native } // namespace dawn_native

View File

@ -186,11 +186,7 @@ namespace dawn_native { namespace d3d12 {
} }
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) { MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
DAWN_TRY(InitializeBase(parseResult)); return InitializeBase(parseResult);
#ifdef DAWN_ENABLE_WGSL
mTintProgram = std::move(parseResult->tintProgram);
#endif
return {};
} }
ResultOrError<std::string> ShaderModule::TranslateToHLSLWithTint( ResultOrError<std::string> ShaderModule::TranslateToHLSLWithTint(
@ -215,7 +211,7 @@ namespace dawn_native { namespace d3d12 {
transformManager.append(std::make_unique<tint::transform::Renamer>()); transformManager.append(std::make_unique<tint::transform::Renamer>());
transformManager.append(std::make_unique<tint::transform::Hlsl>()); transformManager.append(std::make_unique<tint::transform::Hlsl>());
tint::transform::Transform::Output output = transformManager.Run(mTintProgram.get()); tint::transform::Transform::Output output = transformManager.Run(GetTintProgram());
tint::Program& program = output.program; tint::Program& program = output.program;
if (!program.IsValid()) { if (!program.IsValid()) {

View File

@ -76,10 +76,6 @@ namespace dawn_native { namespace d3d12 {
ResultOrError<uint64_t> GetDXCompilerVersion() const; ResultOrError<uint64_t> GetDXCompilerVersion() const;
uint64_t GetD3DCompilerVersion() const; uint64_t GetD3DCompilerVersion() const;
#ifdef DAWN_ENABLE_WGSL
std::unique_ptr<tint::Program> mTintProgram;
#endif
}; };
}} // namespace dawn_native::d3d12 }} // namespace dawn_native::d3d12

View File

@ -69,10 +69,6 @@ namespace dawn_native { namespace metal {
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default; ~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult); MaybeError Initialize(ShaderModuleParseResult* parseResult);
#ifdef DAWN_ENABLE_WGSL
std::unique_ptr<tint::Program> mTintProgram;
#endif
}; };
}} // namespace dawn_native::metal }} // namespace dawn_native::metal

View File

@ -48,11 +48,7 @@ namespace dawn_native { namespace metal {
} }
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) { MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
DAWN_TRY(InitializeBase(parseResult)); return InitializeBase(parseResult);
#ifdef DAWN_ENABLE_WGSL
mTintProgram = std::move(parseResult->tintProgram);
#endif
return {};
} }
ResultOrError<std::string> ShaderModule::TranslateToMSLWithTint( ResultOrError<std::string> ShaderModule::TranslateToMSLWithTint(
@ -90,7 +86,7 @@ namespace dawn_native { namespace metal {
transformManager.append(std::make_unique<tint::transform::Renamer>()); transformManager.append(std::make_unique<tint::transform::Renamer>());
transformManager.append(std::make_unique<tint::transform::Msl>()); transformManager.append(std::make_unique<tint::transform::Msl>());
tint::transform::Transform::Output output = transformManager.Run(mTintProgram.get()); tint::transform::Transform::Output output = transformManager.Run(GetTintProgram());
tint::Program& program = output.program; tint::Program& program = output.program;
if (!program.IsValid()) { if (!program.IsValid()) {
@ -137,9 +133,9 @@ namespace dawn_native { namespace metal {
std::vector<uint32_t> pullingSpirv; std::vector<uint32_t> pullingSpirv;
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
stage == SingleShaderStage::Vertex) { stage == SingleShaderStage::Vertex) {
if (mTintProgram) { if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
DAWN_TRY_ASSIGN(pullingSpirv, DAWN_TRY_ASSIGN(pullingSpirv,
GeneratePullingSpirv(mTintProgram.get(), GeneratePullingSpirv(GetTintProgram(),
*renderPipeline->GetVertexStateDescriptor(), *renderPipeline->GetVertexStateDescriptor(),
entryPointName, kPullingBufferBindingSet)); entryPointName, kPullingBufferBindingSet));
} else { } else {