dawn: Add shader module validation for WGSL extension

Tint has already implemented the enable directive for using WGSL
extension in the future, and using a WGSL extension that is not allowed
for the device should result in a shader creation error.
In this patch a WGSL extension allow list is added in DeviceBase, and
a validation is added in shader module base initialization to make sure
all extensions used in the WGSL program are in the allow list. This
patch also rename the `ValidateShaderModuleDescriptor` to
`ValidateAndParseShaderModule`, which is more descriptive for what it
actually does.

Bug: tint:1472
Change-Id: I4b039a3e37c25159b4fc6cfa37488aa817004ab2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/88241
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
This commit is contained in:
Zhaoming Jiang 2022-05-06 08:51:12 +00:00 committed by Dawn LUCI CQ
parent 4cb75245e3
commit d6b2501be2
25 changed files with 189 additions and 78 deletions

View File

@ -199,6 +199,8 @@ DeviceBase::DeviceBase(AdapterBase* adapter, const DeviceDescriptor* descriptor)
mFormatTable = BuildFormatTable(this); mFormatTable = BuildFormatTable(this);
SetDefaultToggles(); SetDefaultToggles();
SetWGSLExtensionAllowList();
if (descriptor->label != nullptr && strlen(descriptor->label) != 0) { if (descriptor->label != nullptr && strlen(descriptor->label) != 0) {
mLabel = descriptor->label; mLabel = descriptor->label;
} }
@ -914,13 +916,13 @@ ResultOrError<Ref<ShaderModuleBase>> DeviceBase::GetOrCreateShaderModule(
if (!parseResult->HasParsedShader()) { if (!parseResult->HasParsedShader()) {
// We skip the parse on creation if validation isn't enabled which let's us quickly // We skip the parse on creation if validation isn't enabled which let's us quickly
// lookup in the cache without validating and parsing. We need the parsed module // lookup in the cache without validating and parsing. We need the parsed module
// now, so call validate. Most of |ValidateShaderModuleDescriptor| is parsing, but // now.
// we can consider splitting it if additional validation is added.
ASSERT(!IsValidationEnabled()); ASSERT(!IsValidationEnabled());
DAWN_TRY( DAWN_TRY(
ValidateShaderModuleDescriptor(this, descriptor, parseResult, compilationMessages)); ValidateAndParseShaderModule(this, descriptor, parseResult, compilationMessages));
} }
DAWN_TRY_ASSIGN(result, CreateShaderModuleImpl(descriptor, parseResult)); DAWN_TRY_ASSIGN(result,
CreateShaderModuleImpl(descriptor, parseResult, compilationMessages));
result->SetIsCachedReference(); result->SetIsCachedReference();
result->SetContentHash(blueprintHash); result->SetContentHash(blueprintHash);
mCaches->shaderModules.insert(result.Get()); mCaches->shaderModules.insert(result.Get());
@ -1119,7 +1121,8 @@ ShaderModuleBase* DeviceBase::APICreateShaderModule(const ShaderModuleDescriptor
result = ShaderModuleBase::MakeError(this); result = ShaderModuleBase::MakeError(this);
} }
// Move compilation messages into ShaderModuleBase and emit tint errors and warnings // Move compilation messages into ShaderModuleBase and emit tint errors and warnings
// after all other operations are finished successfully. // after all other operations are finished, even if any of them is failed and result
// is an error shader module.
result->InjectCompilationMessages(std::move(compilationMessages)); result->InjectCompilationMessages(std::move(compilationMessages));
return result.Detach(); return result.Detach();
@ -1229,6 +1232,16 @@ bool DeviceBase::IsFeatureEnabled(Feature feature) const {
return mEnabledFeatures.IsEnabled(feature); return mEnabledFeatures.IsEnabled(feature);
} }
void DeviceBase::SetWGSLExtensionAllowList() {
// Set the WGSL extensions allow list based on device's enabled features and other
// propority. For example:
// mWGSLExtensionAllowList.insert("InternalExtensionForTesting");
}
WGSLExtensionsSet DeviceBase::GetWGSLExtensionAllowList() const {
return mWGSLExtensionAllowList;
}
bool DeviceBase::IsValidationEnabled() const { bool DeviceBase::IsValidationEnabled() const {
return !IsToggleEnabled(Toggle::SkipValidation); return !IsToggleEnabled(Toggle::SkipValidation);
} }
@ -1589,7 +1602,7 @@ ResultOrError<Ref<ShaderModuleBase>> DeviceBase::CreateShaderModule(
if (IsValidationEnabled()) { if (IsValidationEnabled()) {
DAWN_TRY_CONTEXT( DAWN_TRY_CONTEXT(
ValidateShaderModuleDescriptor(this, descriptor, &parseResult, compilationMessages), ValidateAndParseShaderModule(this, descriptor, &parseResult, compilationMessages),
"validating %s", descriptor); "validating %s", descriptor);
} }

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <string> #include <string>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -54,6 +55,8 @@ struct CallbackTask;
struct InternalPipelineStore; struct InternalPipelineStore;
struct ShaderModuleParseResult; struct ShaderModuleParseResult;
using WGSLExtensionsSet = std::unordered_set<std::string>;
class DeviceBase : public RefCounted { class DeviceBase : public RefCounted {
public: public:
DeviceBase(AdapterBase* adapter, const DeviceDescriptor* descriptor); DeviceBase(AdapterBase* adapter, const DeviceDescriptor* descriptor);
@ -316,6 +319,7 @@ class DeviceBase : public RefCounted {
std::mutex* GetObjectListMutex(ObjectType type); std::mutex* GetObjectListMutex(ObjectType type);
std::vector<const char*> GetTogglesUsed() const; std::vector<const char*> GetTogglesUsed() const;
WGSLExtensionsSet GetWGSLExtensionAllowList() const;
bool IsFeatureEnabled(Feature feature) const; bool IsFeatureEnabled(Feature feature) const;
bool IsToggleEnabled(Toggle toggle) const; bool IsToggleEnabled(Toggle toggle) const;
bool IsValidationEnabled() const; bool IsValidationEnabled() const;
@ -412,7 +416,8 @@ class DeviceBase : public RefCounted {
const SamplerDescriptor* descriptor) = 0; const SamplerDescriptor* descriptor) = 0;
virtual ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl( virtual ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) = 0; ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) = 0;
virtual ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl( virtual ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) = 0; const SwapChainDescriptor* descriptor) = 0;
// Note that previousSwapChain may be nullptr, or come from a different backend. // Note that previousSwapChain may be nullptr, or come from a different backend.
@ -456,6 +461,8 @@ class DeviceBase : public RefCounted {
void SetDefaultToggles(); void SetDefaultToggles();
void SetWGSLExtensionAllowList();
void ConsumeError(std::unique_ptr<ErrorData> error); void ConsumeError(std::unique_ptr<ErrorData> error);
// Each backend should implement to check their passed fences if there are any and return a // Each backend should implement to check their passed fences if there are any and return a
@ -542,6 +549,7 @@ class DeviceBase : public RefCounted {
CombinedLimits mLimits; CombinedLimits mLimits;
FeaturesSet mEnabledFeatures; FeaturesSet mEnabledFeatures;
WGSLExtensionsSet mWGSLExtensionAllowList;
std::unique_ptr<InternalPipelineStore> mInternalPipelineStore; std::unique_ptr<InternalPipelineStore> mInternalPipelineStore;

View File

@ -900,27 +900,70 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
return std::move(metadata); return std::move(metadata);
} }
ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint(const DeviceBase* device, MaybeError ValidateWGSLProgramExtension(const DeviceBase* device,
const tint::Program* program) { const tint::Program* program,
OwnedCompilationMessages* outMessages) {
DAWN_ASSERT(program->IsValid());
tint::inspector::Inspector inspector(program);
auto enableDirectives = inspector.GetEnableDirectives();
auto extensionAllowList = device->GetWGSLExtensionAllowList();
bool hasDisallowedExtension = false;
tint::diag::List messages;
for (auto enable : enableDirectives) {
if (extensionAllowList.count(enable.first)) {
continue;
}
hasDisallowedExtension = true;
messages.add_error(tint::diag::System::Program,
"Extension " + enable.first + " is not allowed on the Device.",
enable.second);
}
if (hasDisallowedExtension) {
if (outMessages != nullptr) {
outMessages->AddMessages(messages);
}
return DAWN_MAKE_ERROR(InternalErrorType::Validation,
"Shader module uses extension(s) not enabled for its device.");
}
return {};
}
MaybeError ReflectShaderUsingTint(const DeviceBase* device,
const tint::Program* program,
OwnedCompilationMessages* compilationMessages,
EntryPointMetadataTable& entryPointMetadataTable,
WGSLExtensionsSet* enabledWGSLExtensions) {
ASSERT(program->IsValid()); ASSERT(program->IsValid());
tint::inspector::Inspector inspector(program); tint::inspector::Inspector inspector(program);
ASSERT(enabledWGSLExtensions->empty());
auto usedExtensionNames = inspector.GetUsedExtensionNames();
for (std::string name : usedExtensionNames) {
enabledWGSLExtensions->insert(name);
}
DAWN_TRY(ValidateWGSLProgramExtension(device, program, compilationMessages));
std::vector<tint::inspector::EntryPoint> entryPoints = inspector.GetEntryPoints(); std::vector<tint::inspector::EntryPoint> entryPoints = inspector.GetEntryPoints();
DAWN_INVALID_IF(inspector.has_error(), "Tint Reflection failure: Inspector: %s\n", DAWN_INVALID_IF(inspector.has_error(), "Tint Reflection failure: Inspector: %s\n",
inspector.error()); inspector.error());
EntryPointMetadataTable result;
for (const tint::inspector::EntryPoint& entryPoint : entryPoints) { for (const tint::inspector::EntryPoint& entryPoint : entryPoints) {
std::unique_ptr<EntryPointMetadata> metadata; std::unique_ptr<EntryPointMetadata> metadata;
DAWN_TRY_ASSIGN_CONTEXT(metadata, DAWN_TRY_ASSIGN_CONTEXT(metadata,
ReflectEntryPointUsingTint(device, &inspector, entryPoint), ReflectEntryPointUsingTint(device, &inspector, entryPoint),
"processing entry point \"%s\".", entryPoint.name); "processing entry point \"%s\".", entryPoint.name);
ASSERT(result.count(entryPoint.name) == 0); ASSERT(entryPointMetadataTable.count(entryPoint.name) == 0);
result[entryPoint.name] = std::move(metadata); entryPointMetadataTable[entryPoint.name] = std::move(metadata);
} }
return std::move(result); return {};
} }
} // anonymous namespace } // anonymous namespace
@ -946,7 +989,7 @@ class TintSource {
tint::Source::File file; tint::Source::File file;
}; };
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device, MaybeError ValidateAndParseShaderModule(DeviceBase* device,
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult, ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* outMessages) { OwnedCompilationMessages* outMessages) {
@ -1287,11 +1330,13 @@ void ShaderModuleBase::AddExternalTextureTransform(const PipelineLayoutBase* lay
} }
} }
MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) { MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
mTintProgram = std::move(parseResult->tintProgram); mTintProgram = std::move(parseResult->tintProgram);
mTintSource = std::move(parseResult->tintSource); mTintSource = std::move(parseResult->tintSource);
DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingTint(GetDevice(), mTintProgram.get())); DAWN_TRY(ReflectShaderUsingTint(GetDevice(), mTintProgram.get(), compilationMessages,
mEntryPoints, &mEnabledWGSLExtensions));
return {}; return {};
} }

View File

@ -53,6 +53,7 @@ class VertexPulling;
namespace dawn::native { namespace dawn::native {
using WGSLExtensionsSet = std::unordered_set<std::string>;
struct EntryPointMetadata; struct EntryPointMetadata;
// Base component type of an inter-stage variable // Base component type of an inter-stage variable
@ -99,7 +100,7 @@ struct ShaderModuleParseResult {
std::unique_ptr<TintSource> tintSource; std::unique_ptr<TintSource> tintSource;
}; };
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device, MaybeError ValidateAndParseShaderModule(DeviceBase* device,
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult, ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* outMessages); OwnedCompilationMessages* outMessages);
@ -289,7 +290,8 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject {
explicit ShaderModuleBase(DeviceBase* device); explicit ShaderModuleBase(DeviceBase* device);
void DestroyImpl() override; void DestroyImpl() override;
MaybeError InitializeBase(ShaderModuleParseResult* parseResult); MaybeError InitializeBase(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
static void AddExternalTextureTransform(const PipelineLayoutBase* layout, static void AddExternalTextureTransform(const PipelineLayoutBase* layout,
tint::transform::Manager* transformManager, tint::transform::Manager* transformManager,
@ -305,6 +307,7 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject {
std::string mWgsl; std::string mWgsl;
EntryPointMetadataTable mEntryPoints; EntryPointMetadataTable mEntryPoints;
WGSLExtensionsSet mEnabledWGSLExtensions;
std::unique_ptr<tint::Program> mTintProgram; std::unique_ptr<tint::Program> mTintProgram;
std::unique_ptr<TintSource> mTintSource; // Keep the tint::Source::File alive std::unique_ptr<TintSource> mTintSource; // Keep the tint::Source::File alive

View File

@ -422,8 +422,9 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
} }
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl( ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) { ShaderModuleParseResult* parseResult,
return ShaderModule::Create(this, descriptor, parseResult); OwnedCompilationMessages* compilationMessages) {
return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
} }
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl( ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) { const SwapChainDescriptor* descriptor) {

View File

@ -170,7 +170,8 @@ class Device final : public DeviceBase {
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl( ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override; ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl( ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override; const SwapChainDescriptor* descriptor) override;
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl( ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(

View File

@ -729,20 +729,23 @@ MaybeError CompileShader(dawn::platform::Platform* platform,
} // anonymous namespace } // anonymous namespace
// static // static
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device, ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) { ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor)); Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
DAWN_TRY(module->Initialize(parseResult)); DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module; return module;
} }
ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor) ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
: ShaderModuleBase(device, descriptor) {} : ShaderModuleBase(device, descriptor) {}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) { MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
ScopedTintICEHandler scopedICEHandler(GetDevice()); ScopedTintICEHandler scopedICEHandler(GetDevice());
return InitializeBase(parseResult); return InitializeBase(parseResult, compilationMessages);
} }
ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& programmableStage, ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& programmableStage,

View File

@ -42,7 +42,8 @@ class ShaderModule final : public ShaderModuleBase {
public: public:
static ResultOrError<Ref<ShaderModule>> Create(Device* device, static ResultOrError<Ref<ShaderModule>> Create(Device* device,
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult); ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
ResultOrError<CompiledShader> Compile(const ProgrammableStage& programmableStage, ResultOrError<CompiledShader> Compile(const ProgrammableStage& programmableStage,
SingleShaderStage stage, SingleShaderStage stage,
@ -52,7 +53,8 @@ class ShaderModule final : public ShaderModuleBase {
private: private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default; ~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult); MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
}; };
} // namespace dawn::native::d3d12 } // namespace dawn::native::d3d12

View File

@ -92,7 +92,8 @@ class Device final : public DeviceBase {
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl( ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override; ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl( ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override; const SwapChainDescriptor* descriptor) override;
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl( ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(

View File

@ -260,8 +260,9 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
} }
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl( ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) { ShaderModuleParseResult* parseResult,
return ShaderModule::Create(this, descriptor, parseResult); OwnedCompilationMessages* compilationMessages) {
return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
} }
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl( ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) { const SwapChainDescriptor* descriptor) {

View File

@ -35,7 +35,8 @@ class ShaderModule final : public ShaderModuleBase {
public: public:
static ResultOrError<Ref<ShaderModule>> Create(Device* device, static ResultOrError<Ref<ShaderModule>> Create(Device* device,
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult); ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
struct MetalFunctionData { struct MetalFunctionData {
NSPRef<id<MTLFunction>> function; NSPRef<id<MTLFunction>> function;
@ -65,7 +66,8 @@ class ShaderModule final : public ShaderModuleBase {
std::vector<uint32_t>* workgroupAllocations); std::vector<uint32_t>* workgroupAllocations);
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default; ~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult); MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
}; };
} // namespace dawn::native::metal } // namespace dawn::native::metal

View File

@ -29,20 +29,23 @@
namespace dawn::native::metal { namespace dawn::native::metal {
// static // static
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device, ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) { ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor)); Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
DAWN_TRY(module->Initialize(parseResult)); DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module; return module;
} }
ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor) ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
: ShaderModuleBase(device, descriptor) {} : ShaderModuleBase(device, descriptor) {}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) { MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
ScopedTintICEHandler scopedICEHandler(GetDevice()); ScopedTintICEHandler scopedICEHandler(GetDevice());
return InitializeBase(parseResult); return InitializeBase(parseResult, compilationMessages);
} }
ResultOrError<std::string> ShaderModule::TranslateToMSL( ResultOrError<std::string> ShaderModule::TranslateToMSL(

View File

@ -155,9 +155,10 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
} }
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl( ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) { ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor)); Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor));
DAWN_TRY(module->Initialize(parseResult)); DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module; return module;
} }
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl( ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
@ -429,8 +430,9 @@ void SwapChain::DetachFromSurfaceImpl() {
// ShaderModule // ShaderModule
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) { MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
return InitializeBase(parseResult); OwnedCompilationMessages* compilationMessages) {
return InitializeBase(parseResult, compilationMessages);
} }
// OldSwapChain // OldSwapChain

View File

@ -142,7 +142,8 @@ class Device final : public DeviceBase {
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl( ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override; ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl( ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override; const SwapChainDescriptor* descriptor) override;
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl( ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(
@ -277,7 +278,8 @@ class ShaderModule final : public ShaderModuleBase {
public: public:
using ShaderModuleBase::ShaderModuleBase; using ShaderModuleBase::ShaderModuleBase;
MaybeError Initialize(ShaderModuleParseResult* parseResult); MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
}; };
class SwapChain final : public NewSwapChainBase { class SwapChain final : public NewSwapChainBase {

View File

@ -171,8 +171,9 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
} }
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl( ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) { ShaderModuleParseResult* parseResult,
return ShaderModule::Create(this, descriptor, parseResult); OwnedCompilationMessages* compilationMessages) {
return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
} }
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl( ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) { const SwapChainDescriptor* descriptor) {

View File

@ -98,7 +98,8 @@ class Device final : public DeviceBase {
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl( ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override; ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl( ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override; const SwapChainDescriptor* descriptor) override;
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl( ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(

View File

@ -59,21 +59,24 @@ std::string CombinedSampler::GetName() const {
} }
// static // static
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device, ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) { ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor)); Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
DAWN_TRY(module->Initialize(parseResult)); DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module; return module;
} }
ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor) ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
: ShaderModuleBase(device, descriptor) {} : ShaderModuleBase(device, descriptor) {}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) { MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
ScopedTintICEHandler scopedICEHandler(GetDevice()); ScopedTintICEHandler scopedICEHandler(GetDevice());
DAWN_TRY(InitializeBase(parseResult)); DAWN_TRY(InitializeBase(parseResult, compilationMessages));
return {}; return {};
} }

View File

@ -56,7 +56,8 @@ class ShaderModule final : public ShaderModuleBase {
public: public:
static ResultOrError<Ref<ShaderModule>> Create(Device* device, static ResultOrError<Ref<ShaderModule>> Create(Device* device,
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult); ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
ResultOrError<std::string> TranslateToGLSL(const char* entryPointName, ResultOrError<std::string> TranslateToGLSL(const char* entryPointName,
SingleShaderStage stage, SingleShaderStage stage,
@ -67,7 +68,8 @@ class ShaderModule final : public ShaderModuleBase {
private: private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default; ~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult); MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
}; };
} // namespace dawn::native::opengl } // namespace dawn::native::opengl

View File

@ -151,8 +151,9 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
} }
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl( ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) { ShaderModuleParseResult* parseResult,
return ShaderModule::Create(this, descriptor, parseResult); OwnedCompilationMessages* compilationMessages) {
return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
} }
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl( ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) { const SwapChainDescriptor* descriptor) {

View File

@ -127,7 +127,8 @@ class Device final : public DeviceBase {
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl( ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override; ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl( ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override; const SwapChainDescriptor* descriptor) override;
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl( ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(

View File

@ -73,11 +73,13 @@ ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCach
} }
// static // static
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device, ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) { ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor)); Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
DAWN_TRY(module->Initialize(parseResult)); DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module; return module;
} }
@ -86,7 +88,8 @@ ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descrip
mTransformedShaderModuleCache( mTransformedShaderModuleCache(
std::make_unique<ConcurrentTransformedShaderModuleCache>(device)) {} std::make_unique<ConcurrentTransformedShaderModuleCache>(device)) {}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) { MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
if (GetDevice()->IsRobustnessEnabled()) { if (GetDevice()->IsRobustnessEnabled()) {
ScopedTintICEHandler scopedICEHandler(GetDevice()); ScopedTintICEHandler scopedICEHandler(GetDevice());
@ -100,7 +103,7 @@ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program)); parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
} }
return InitializeBase(parseResult); return InitializeBase(parseResult, compilationMessages);
} }
void ShaderModule::DestroyImpl() { void ShaderModule::DestroyImpl() {

View File

@ -40,7 +40,8 @@ class ShaderModule final : public ShaderModuleBase {
static ResultOrError<Ref<ShaderModule>> Create(Device* device, static ResultOrError<Ref<ShaderModule>> Create(Device* device,
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult); ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const char* entryPointName, ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const char* entryPointName,
const PipelineLayout* layout); const PipelineLayout* layout);
@ -48,7 +49,8 @@ class ShaderModule final : public ShaderModuleBase {
private: private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override; ~ShaderModule() override;
MaybeError Initialize(ShaderModuleParseResult* parseResult); MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
void DestroyImpl() override; void DestroyImpl() override;
// New handles created by GetHandleAndSpirv at pipeline creation time. // New handles created by GetHandleAndSpirv at pipeline creation time.

View File

@ -90,7 +90,9 @@ class DeviceMock : public DeviceBase {
(override)); (override));
MOCK_METHOD(ResultOrError<Ref<ShaderModuleBase>>, MOCK_METHOD(ResultOrError<Ref<ShaderModuleBase>>,
CreateShaderModuleImpl, CreateShaderModuleImpl,
(const ShaderModuleDescriptor*, ShaderModuleParseResult*), (const ShaderModuleDescriptor*,
ShaderModuleParseResult*,
OwnedCompilationMessages*),
(override)); (override));
MOCK_METHOD(ResultOrError<Ref<SwapChainBase>>, MOCK_METHOD(ResultOrError<Ref<SwapChainBase>>,
CreateSwapChainImpl, CreateSwapChainImpl,

View File

@ -30,8 +30,8 @@ ResultOrError<Ref<ShaderModuleMock>> ShaderModuleMock::Create(DeviceBase* device
desc.nextInChain = &wgslDesc; desc.nextInChain = &wgslDesc;
ShaderModuleParseResult parseResult; ShaderModuleParseResult parseResult;
DAWN_TRY(ValidateShaderModuleDescriptor(device, &desc, &parseResult, nullptr)); DAWN_TRY(ValidateAndParseShaderModule(device, &desc, &parseResult, nullptr));
DAWN_TRY(mock->InitializeBase(&parseResult)); DAWN_TRY(mock->InitializeBase(&parseResult, nullptr));
return AcquireRef(mock); return AcquireRef(mock);
} }

View File

@ -659,3 +659,11 @@ TEST_F(ShaderModuleValidationTest, MissingDecorations) {
} }
)")); )"));
} }
// Test that WGSL extension used by enable directives must be allowed by WebGPU.
TEST_F(ShaderModuleValidationTest, ExtensionMustBeAllowed) {
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"(
enable InternalExtensionForTesting;
@stage(compute) @workgroup_size(1) fn main() {})"));
}