diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp index e5cd483815..26d5bb60a4 100644 --- a/src/dawn/native/Device.cpp +++ b/src/dawn/native/Device.cpp @@ -1238,7 +1238,7 @@ void DeviceBase::SetWGSLExtensionAllowList() { // mWGSLExtensionAllowList.insert("InternalExtensionForTesting"); } -WGSLExtensionsSet DeviceBase::GetWGSLExtensionAllowList() const { +WGSLExtensionSet DeviceBase::GetWGSLExtensionAllowList() const { return mWGSLExtensionAllowList; } diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h index 37e65d304b..d903151118 100644 --- a/src/dawn/native/Device.h +++ b/src/dawn/native/Device.h @@ -55,7 +55,7 @@ struct CallbackTask; struct InternalPipelineStore; struct ShaderModuleParseResult; -using WGSLExtensionsSet = std::unordered_set; +using WGSLExtensionSet = std::unordered_set; class DeviceBase : public RefCounted { public: @@ -319,7 +319,7 @@ class DeviceBase : public RefCounted { std::mutex* GetObjectListMutex(ObjectType type); std::vector GetTogglesUsed() const; - WGSLExtensionsSet GetWGSLExtensionAllowList() const; + WGSLExtensionSet GetWGSLExtensionAllowList() const; bool IsFeatureEnabled(Feature feature) const; bool IsToggleEnabled(Toggle toggle) const; bool IsValidationEnabled() const; @@ -549,7 +549,7 @@ class DeviceBase : public RefCounted { CombinedLimits mLimits; FeaturesSet mEnabledFeatures; - WGSLExtensionsSet mWGSLExtensionAllowList; + WGSLExtensionSet mWGSLExtensionAllowList; std::unique_ptr mInternalPipelineStore; diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp index ab9a1305a3..70419f1f15 100644 --- a/src/dawn/native/ShaderModule.cpp +++ b/src/dawn/native/ShaderModule.cpp @@ -901,25 +901,20 @@ ResultOrError> ReflectEntryPointUsingTint( } MaybeError ValidateWGSLProgramExtension(const DeviceBase* device, - const tint::Program* program, + const WGSLExtensionSet* enabledExtensions, OwnedCompilationMessages* outMessages) { - DAWN_ASSERT(program->IsValid()); - tint::inspector::Inspector inspector(program); - auto enableDirectives = inspector.GetEnableDirectives(); - - auto extensionAllowList = device->GetWGSLExtensionAllowList(); + const WGSLExtensionSet& extensionAllowList = device->GetWGSLExtensionAllowList(); bool hasDisallowedExtension = false; tint::diag::List messages; - for (auto enable : enableDirectives) { - if (extensionAllowList.count(enable.first)) { + for (const std::string& extension : *enabledExtensions) { + if (extensionAllowList.count(extension)) { continue; } hasDisallowedExtension = true; messages.add_error(tint::diag::System::Program, - "Extension " + enable.first + " is not allowed on the Device.", - enable.second); + "Extension " + extension + " is not allowed on the Device."); } if (hasDisallowedExtension) { @@ -936,8 +931,8 @@ MaybeError ValidateWGSLProgramExtension(const DeviceBase* device, MaybeError ReflectShaderUsingTint(const DeviceBase* device, const tint::Program* program, OwnedCompilationMessages* compilationMessages, - EntryPointMetadataTable& entryPointMetadataTable, - WGSLExtensionsSet* enabledWGSLExtensions) { + EntryPointMetadataTable* entryPointMetadataTable, + WGSLExtensionSet* enabledWGSLExtensions) { ASSERT(program->IsValid()); tint::inspector::Inspector inspector(program); @@ -947,8 +942,7 @@ MaybeError ReflectShaderUsingTint(const DeviceBase* device, for (std::string name : usedExtensionNames) { enabledWGSLExtensions->insert(name); } - - DAWN_TRY(ValidateWGSLProgramExtension(device, program, compilationMessages)); + DAWN_TRY(ValidateWGSLProgramExtension(device, enabledWGSLExtensions, compilationMessages)); std::vector entryPoints = inspector.GetEntryPoints(); DAWN_INVALID_IF(inspector.has_error(), "Tint Reflection failure: Inspector: %s\n", @@ -960,8 +954,8 @@ MaybeError ReflectShaderUsingTint(const DeviceBase* device, ReflectEntryPointUsingTint(device, &inspector, entryPoint), "processing entry point \"%s\".", entryPoint.name); - ASSERT(entryPointMetadataTable.count(entryPoint.name) == 0); - entryPointMetadataTable[entryPoint.name] = std::move(metadata); + ASSERT(entryPointMetadataTable->count(entryPoint.name) == 0); + (*entryPointMetadataTable)[entryPoint.name] = std::move(metadata); } return {}; } @@ -1336,7 +1330,7 @@ MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult mTintSource = std::move(parseResult->tintSource); DAWN_TRY(ReflectShaderUsingTint(GetDevice(), mTintProgram.get(), compilationMessages, - mEntryPoints, &mEnabledWGSLExtensions)); + &mEntryPoints, &mEnabledWGSLExtensions)); return {}; } diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h index e7068c9228..1df775cfba 100644 --- a/src/dawn/native/ShaderModule.h +++ b/src/dawn/native/ShaderModule.h @@ -53,7 +53,7 @@ class VertexPulling; namespace dawn::native { -using WGSLExtensionsSet = std::unordered_set; +using WGSLExtensionSet = std::unordered_set; struct EntryPointMetadata; // Base component type of an inter-stage variable @@ -307,7 +307,7 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject { std::string mWgsl; EntryPointMetadataTable mEntryPoints; - WGSLExtensionsSet mEnabledWGSLExtensions; + WGSLExtensionSet mEnabledWGSLExtensions; std::unique_ptr mTintProgram; std::unique_ptr mTintSource; // Keep the tint::Source::File alive