ShaderModule: Store the tint::Program in the base class.

This is in preparation of removing all the DAWN_ENABLE_WGSL logic: the
ShaderModuleBase will have either mSpirv or mTintProgram set based on
UseTintGenerator.

Also improves the constness of some functions.

Also simplifies a bit ShaderModuleBase::Initialize.

Bug: dawn:706
Change-Id: Ib879e2aec8a004aeb8ac5dc6e1176b1667fc227d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/45422
Commit-Queue: Austin Eng <enga@chromium.org>
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
This commit is contained in:
Corentin Wallez 2021-03-22 18:24:16 +00:00 committed by Commit Bot service account
parent 5ff4978a5d
commit 3e4b57b77e
6 changed files with 42 additions and 66 deletions

View File

@ -266,7 +266,7 @@ namespace dawn_native {
return std::move(program);
}
MaybeError ValidateModule(tint::Program* program) {
MaybeError ValidateModule(const tint::Program* program) {
std::ostringstream errorStream;
errorStream << "Tint program validation" << std::endl;
@ -729,14 +729,14 @@ namespace dawn_native {
// PopulateMetadataUsingSPIRVCross will be removed.
ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint(
DeviceBase*,
const tint::Program& program) {
ASSERT(program.IsValid());
const tint::Program* program) {
ASSERT(program->IsValid());
EntryPointMetadataTable result;
std::ostringstream errorStream;
errorStream << "Tint Reflection failure:" << std::endl;
tint::inspector::Inspector inspector(&program);
tint::inspector::Inspector inspector(program);
auto entryPoints = inspector.GetEntryPoints();
if (inspector.has_error()) {
errorStream << "Inspector: " << inspector.error() << std::endl;
@ -842,7 +842,7 @@ namespace dawn_native {
// fallback/source of truth.
ResultOrError<EntryPointMetadataTable> ReflectShaderUsingSPIRVCross(
DeviceBase* device,
std::vector<uint32_t> spirv) {
const std::vector<uint32_t>& spirv) {
EntryPointMetadataTable result;
spirv_cross::Compiler compiler(spirv);
for (const spirv_cross::EntryPoint& entryPoint :
@ -866,7 +866,7 @@ namespace dawn_native {
// using the SPIRV-Cross implementation. Once the Tint implementation is
// completed, this function will be removed.
MaybeError PopulateMetadataUsingSPIRVCross(DeviceBase* device,
std::vector<uint32_t> spirv,
const std::vector<uint32_t>& spirv,
EntryPointMetadataTable* tintTable) {
EntryPointMetadataTable crossTable;
DAWN_TRY_ASSIGN(crossTable, ReflectShaderUsingSPIRVCross(device, spirv));
@ -998,6 +998,7 @@ namespace dawn_native {
if (device->IsValidationEnabled()) {
DAWN_TRY(ValidateModule(&program));
}
parseResult.tintProgram = std::make_unique<tint::Program>(std::move(program));
} else {
tint::transform::Manager transformManager;
transformManager.append(
@ -1015,8 +1016,6 @@ namespace dawn_native {
parseResult.spirv = std::move(spirv);
}
parseResult.tintProgram = std::make_unique<tint::Program>(std::move(program));
break;
#else
return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
@ -1042,7 +1041,7 @@ namespace dawn_native {
#ifdef DAWN_ENABLE_WGSL
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
tint::Program* program) {
const tint::Program* program) {
tint::transform::Transform::Output output = transform->Run(program);
if (!output.program.IsValid()) {
std::string err = "Tint program failure: " + output.program.Diagnostics().str();
@ -1169,6 +1168,11 @@ namespace dawn_native {
}
#ifdef DAWN_ENABLE_WGSL
const tint::Program* ShaderModuleBase::GetTintProgram() const {
ASSERT(GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
return mTintProgram.get();
}
ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
const std::vector<uint32_t>& spirv,
const VertexStateDescriptor& vertexState,
@ -1181,7 +1185,7 @@ namespace dawn_native {
}
ResultOrError<std::vector<uint32_t>> ShaderModuleBase::GeneratePullingSpirv(
tint::Program* programIn,
const tint::Program* programIn,
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet) const {
@ -1214,7 +1218,7 @@ namespace dawn_native {
MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) {
#ifdef DAWN_ENABLE_WGSL
tint::Program* program = parseResult->tintProgram.get();
mTintProgram = std::move(parseResult->tintProgram);
#endif
mSpirv = std::move(parseResult->spirv);
@ -1226,43 +1230,23 @@ namespace dawn_native {
DAWN_TRY_ASSIGN(mSpirv, RunRobustBufferAccessPass(mSpirv));
}
// We still need the spirv for reflection. Remove this when we use the Tint inspector
// completely.
std::vector<uint32_t>* spirvPtr = &mSpirv;
std::vector<uint32_t> localSpirv;
if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
#ifdef DAWN_ENABLE_WGSL
ASSERT(program != nullptr);
// We still need the spirv for reflection. Remove this when we use the Tint inspector
// completely.
std::vector<uint32_t> reflectionSpirv;
DAWN_TRY_ASSIGN(reflectionSpirv, ModuleToSPIRV(mTintProgram.get()));
DAWN_TRY(ValidateSpirv(reflectionSpirv.data(), reflectionSpirv.size()));
DAWN_TRY_ASSIGN(localSpirv, ModuleToSPIRV(program));
DAWN_TRY(ValidateSpirv(localSpirv.data(), localSpirv.size()));
spirvPtr = &localSpirv;
EntryPointMetadataTable table;
DAWN_TRY_ASSIGN(table, ReflectShaderUsingTint(GetDevice(), mTintProgram.get()));
DAWN_TRY(PopulateMetadataUsingSPIRVCross(GetDevice(), reflectionSpirv, &table));
mEntryPoints = std::move(table);
#else
UNREACHABLE();
#endif
}
if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
#ifdef DAWN_ENABLE_WGSL
tint::Program localProgram;
tint::Program* programPtr = program;
if (!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
// We have mSpirv, but no Tint program
DAWN_TRY_ASSIGN(localProgram, ParseSPIRV(mSpirv));
DAWN_TRY(ValidateModule(&localProgram));
programPtr = &localProgram;
}
EntryPointMetadataTable table;
DAWN_TRY_ASSIGN(table, ReflectShaderUsingTint(GetDevice(), *programPtr));
DAWN_TRY(PopulateMetadataUsingSPIRVCross(GetDevice(), *spirvPtr, &table));
mEntryPoints = std::move(table);
#else
return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build.");
#endif
} else {
DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), *spirvPtr));
DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), mSpirv));
}
return {};

View File

@ -77,7 +77,7 @@ namespace dawn_native {
const PipelineLayoutBase* layout);
#ifdef DAWN_ENABLE_WGSL
ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
tint::Program* program);
const tint::Program* program);
std::unique_ptr<tint::transform::VertexPulling> MakeVertexPullingTransform(
const VertexStateDescriptor& vertexState,
@ -147,6 +147,8 @@ namespace dawn_native {
const std::vector<uint32_t>& GetSpirv() const;
#ifdef DAWN_ENABLE_WGSL
const tint::Program* GetTintProgram() const;
ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
const std::vector<uint32_t>& spirv,
const VertexStateDescriptor& vertexState,
@ -154,7 +156,7 @@ namespace dawn_native {
BindGroupIndex pullingBufferBindingSet) const;
ResultOrError<std::vector<uint32_t>> GeneratePullingSpirv(
tint::Program* program,
const tint::Program* program,
const VertexStateDescriptor& vertexState,
const std::string& entryPoint,
BindGroupIndex pullingBufferBindingSet) const;
@ -166,13 +168,19 @@ namespace dawn_native {
private:
ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);
// The original data in the descriptor for caching.
enum class Type { Undefined, Spirv, Wgsl };
Type mType;
std::vector<uint32_t> mOriginalSpirv;
std::vector<uint32_t> mSpirv;
std::string mWgsl;
// Data computed from what is in the descriptor. mSpirv is set iff !UseTintGenerator while
// mTintProgram is set iff UseTintGenerator.
EntryPointMetadataTable mEntryPoints;
std::vector<uint32_t> mSpirv;
#ifdef DAWN_ENABLE_WGSL
std::unique_ptr<tint::Program> mTintProgram;
#endif
};
} // namespace dawn_native

View File

@ -186,11 +186,7 @@ namespace dawn_native { namespace d3d12 {
}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
DAWN_TRY(InitializeBase(parseResult));
#ifdef DAWN_ENABLE_WGSL
mTintProgram = std::move(parseResult->tintProgram);
#endif
return {};
return InitializeBase(parseResult);
}
ResultOrError<std::string> ShaderModule::TranslateToHLSLWithTint(
@ -215,7 +211,7 @@ namespace dawn_native { namespace d3d12 {
transformManager.append(std::make_unique<tint::transform::Renamer>());
transformManager.append(std::make_unique<tint::transform::Hlsl>());
tint::transform::Transform::Output output = transformManager.Run(mTintProgram.get());
tint::transform::Transform::Output output = transformManager.Run(GetTintProgram());
tint::Program& program = output.program;
if (!program.IsValid()) {

View File

@ -76,10 +76,6 @@ namespace dawn_native { namespace d3d12 {
ResultOrError<uint64_t> GetDXCompilerVersion() const;
uint64_t GetD3DCompilerVersion() const;
#ifdef DAWN_ENABLE_WGSL
std::unique_ptr<tint::Program> mTintProgram;
#endif
};
}} // namespace dawn_native::d3d12

View File

@ -69,10 +69,6 @@ namespace dawn_native { namespace metal {
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult);
#ifdef DAWN_ENABLE_WGSL
std::unique_ptr<tint::Program> mTintProgram;
#endif
};
}} // namespace dawn_native::metal

View File

@ -48,11 +48,7 @@ namespace dawn_native { namespace metal {
}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
DAWN_TRY(InitializeBase(parseResult));
#ifdef DAWN_ENABLE_WGSL
mTintProgram = std::move(parseResult->tintProgram);
#endif
return {};
return InitializeBase(parseResult);
}
ResultOrError<std::string> ShaderModule::TranslateToMSLWithTint(
@ -90,7 +86,7 @@ namespace dawn_native { namespace metal {
transformManager.append(std::make_unique<tint::transform::Renamer>());
transformManager.append(std::make_unique<tint::transform::Msl>());
tint::transform::Transform::Output output = transformManager.Run(mTintProgram.get());
tint::transform::Transform::Output output = transformManager.Run(GetTintProgram());
tint::Program& program = output.program;
if (!program.IsValid()) {
@ -137,9 +133,9 @@ namespace dawn_native { namespace metal {
std::vector<uint32_t> pullingSpirv;
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
stage == SingleShaderStage::Vertex) {
if (mTintProgram) {
if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
DAWN_TRY_ASSIGN(pullingSpirv,
GeneratePullingSpirv(mTintProgram.get(),
GeneratePullingSpirv(GetTintProgram(),
*renderPipeline->GetVertexStateDescriptor(),
entryPointName, kPullingBufferBindingSet));
} else {