ShaderModule: Don't create an inspector just to reflect exts

Bug: tint:1472

Change-Id: Ifc170c3da531dd17015f0f36dfccfaa8e250b50c
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/89403
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
This commit is contained in:
Corentin Wallez 2022-05-10 06:41:24 +00:00 committed by Dawn LUCI CQ
parent c167ae12aa
commit 4b6d3f4346
4 changed files with 17 additions and 23 deletions

View File

@ -1238,7 +1238,7 @@ void DeviceBase::SetWGSLExtensionAllowList() {
// mWGSLExtensionAllowList.insert("InternalExtensionForTesting"); // mWGSLExtensionAllowList.insert("InternalExtensionForTesting");
} }
WGSLExtensionsSet DeviceBase::GetWGSLExtensionAllowList() const { WGSLExtensionSet DeviceBase::GetWGSLExtensionAllowList() const {
return mWGSLExtensionAllowList; return mWGSLExtensionAllowList;
} }

View File

@ -55,7 +55,7 @@ struct CallbackTask;
struct InternalPipelineStore; struct InternalPipelineStore;
struct ShaderModuleParseResult; struct ShaderModuleParseResult;
using WGSLExtensionsSet = std::unordered_set<std::string>; using WGSLExtensionSet = std::unordered_set<std::string>;
class DeviceBase : public RefCounted { class DeviceBase : public RefCounted {
public: public:
@ -319,7 +319,7 @@ class DeviceBase : public RefCounted {
std::mutex* GetObjectListMutex(ObjectType type); std::mutex* GetObjectListMutex(ObjectType type);
std::vector<const char*> GetTogglesUsed() const; std::vector<const char*> GetTogglesUsed() const;
WGSLExtensionsSet GetWGSLExtensionAllowList() const; WGSLExtensionSet GetWGSLExtensionAllowList() const;
bool IsFeatureEnabled(Feature feature) const; bool IsFeatureEnabled(Feature feature) const;
bool IsToggleEnabled(Toggle toggle) const; bool IsToggleEnabled(Toggle toggle) const;
bool IsValidationEnabled() const; bool IsValidationEnabled() const;
@ -549,7 +549,7 @@ class DeviceBase : public RefCounted {
CombinedLimits mLimits; CombinedLimits mLimits;
FeaturesSet mEnabledFeatures; FeaturesSet mEnabledFeatures;
WGSLExtensionsSet mWGSLExtensionAllowList; WGSLExtensionSet mWGSLExtensionAllowList;
std::unique_ptr<InternalPipelineStore> mInternalPipelineStore; std::unique_ptr<InternalPipelineStore> mInternalPipelineStore;

View File

@ -901,25 +901,20 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
} }
MaybeError ValidateWGSLProgramExtension(const DeviceBase* device, MaybeError ValidateWGSLProgramExtension(const DeviceBase* device,
const tint::Program* program, const WGSLExtensionSet* enabledExtensions,
OwnedCompilationMessages* outMessages) { OwnedCompilationMessages* outMessages) {
DAWN_ASSERT(program->IsValid()); const WGSLExtensionSet& extensionAllowList = device->GetWGSLExtensionAllowList();
tint::inspector::Inspector inspector(program);
auto enableDirectives = inspector.GetEnableDirectives();
auto extensionAllowList = device->GetWGSLExtensionAllowList();
bool hasDisallowedExtension = false; bool hasDisallowedExtension = false;
tint::diag::List messages; tint::diag::List messages;
for (auto enable : enableDirectives) { for (const std::string& extension : *enabledExtensions) {
if (extensionAllowList.count(enable.first)) { if (extensionAllowList.count(extension)) {
continue; continue;
} }
hasDisallowedExtension = true; hasDisallowedExtension = true;
messages.add_error(tint::diag::System::Program, messages.add_error(tint::diag::System::Program,
"Extension " + enable.first + " is not allowed on the Device.", "Extension " + extension + " is not allowed on the Device.");
enable.second);
} }
if (hasDisallowedExtension) { if (hasDisallowedExtension) {
@ -936,8 +931,8 @@ MaybeError ValidateWGSLProgramExtension(const DeviceBase* device,
MaybeError ReflectShaderUsingTint(const DeviceBase* device, MaybeError ReflectShaderUsingTint(const DeviceBase* device,
const tint::Program* program, const tint::Program* program,
OwnedCompilationMessages* compilationMessages, OwnedCompilationMessages* compilationMessages,
EntryPointMetadataTable& entryPointMetadataTable, EntryPointMetadataTable* entryPointMetadataTable,
WGSLExtensionsSet* enabledWGSLExtensions) { WGSLExtensionSet* enabledWGSLExtensions) {
ASSERT(program->IsValid()); ASSERT(program->IsValid());
tint::inspector::Inspector inspector(program); tint::inspector::Inspector inspector(program);
@ -947,8 +942,7 @@ MaybeError ReflectShaderUsingTint(const DeviceBase* device,
for (std::string name : usedExtensionNames) { for (std::string name : usedExtensionNames) {
enabledWGSLExtensions->insert(name); enabledWGSLExtensions->insert(name);
} }
DAWN_TRY(ValidateWGSLProgramExtension(device, enabledWGSLExtensions, compilationMessages));
DAWN_TRY(ValidateWGSLProgramExtension(device, program, compilationMessages));
std::vector<tint::inspector::EntryPoint> entryPoints = inspector.GetEntryPoints(); std::vector<tint::inspector::EntryPoint> entryPoints = inspector.GetEntryPoints();
DAWN_INVALID_IF(inspector.has_error(), "Tint Reflection failure: Inspector: %s\n", DAWN_INVALID_IF(inspector.has_error(), "Tint Reflection failure: Inspector: %s\n",
@ -960,8 +954,8 @@ MaybeError ReflectShaderUsingTint(const DeviceBase* device,
ReflectEntryPointUsingTint(device, &inspector, entryPoint), ReflectEntryPointUsingTint(device, &inspector, entryPoint),
"processing entry point \"%s\".", entryPoint.name); "processing entry point \"%s\".", entryPoint.name);
ASSERT(entryPointMetadataTable.count(entryPoint.name) == 0); ASSERT(entryPointMetadataTable->count(entryPoint.name) == 0);
entryPointMetadataTable[entryPoint.name] = std::move(metadata); (*entryPointMetadataTable)[entryPoint.name] = std::move(metadata);
} }
return {}; return {};
} }
@ -1336,7 +1330,7 @@ MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult
mTintSource = std::move(parseResult->tintSource); mTintSource = std::move(parseResult->tintSource);
DAWN_TRY(ReflectShaderUsingTint(GetDevice(), mTintProgram.get(), compilationMessages, DAWN_TRY(ReflectShaderUsingTint(GetDevice(), mTintProgram.get(), compilationMessages,
mEntryPoints, &mEnabledWGSLExtensions)); &mEntryPoints, &mEnabledWGSLExtensions));
return {}; return {};
} }

View File

@ -53,7 +53,7 @@ class VertexPulling;
namespace dawn::native { namespace dawn::native {
using WGSLExtensionsSet = std::unordered_set<std::string>; using WGSLExtensionSet = std::unordered_set<std::string>;
struct EntryPointMetadata; struct EntryPointMetadata;
// Base component type of an inter-stage variable // Base component type of an inter-stage variable
@ -307,7 +307,7 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject {
std::string mWgsl; std::string mWgsl;
EntryPointMetadataTable mEntryPoints; EntryPointMetadataTable mEntryPoints;
WGSLExtensionsSet mEnabledWGSLExtensions; WGSLExtensionSet mEnabledWGSLExtensions;
std::unique_ptr<tint::Program> mTintProgram; std::unique_ptr<tint::Program> mTintProgram;
std::unique_ptr<TintSource> mTintSource; // Keep the tint::Source::File alive std::unique_ptr<TintSource> mTintSource; // Keep the tint::Source::File alive