dawn: Add shader module validation for WGSL extension

Tint has already implemented the enable directive for using WGSL
extension in the future, and using a WGSL extension that is not allowed
for the device should result in a shader creation error.
In this patch a WGSL extension allow list is added in DeviceBase, and
a validation is added in shader module base initialization to make sure
all extensions used in the WGSL program are in the allow list. This
patch also rename the `ValidateShaderModuleDescriptor` to
`ValidateAndParseShaderModule`, which is more descriptive for what it
actually does.

Bug: tint:1472
Change-Id: I4b039a3e37c25159b4fc6cfa37488aa817004ab2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/88241
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
This commit is contained in:
Zhaoming Jiang 2022-05-06 08:51:12 +00:00 committed by Dawn LUCI CQ
parent 4cb75245e3
commit d6b2501be2
25 changed files with 189 additions and 78 deletions

View File

@ -199,6 +199,8 @@ DeviceBase::DeviceBase(AdapterBase* adapter, const DeviceDescriptor* descriptor)
mFormatTable = BuildFormatTable(this);
SetDefaultToggles();
SetWGSLExtensionAllowList();
if (descriptor->label != nullptr && strlen(descriptor->label) != 0) {
mLabel = descriptor->label;
}
@ -914,13 +916,13 @@ ResultOrError<Ref<ShaderModuleBase>> DeviceBase::GetOrCreateShaderModule(
if (!parseResult->HasParsedShader()) {
// We skip the parse on creation if validation isn't enabled which let's us quickly
// lookup in the cache without validating and parsing. We need the parsed module
// now, so call validate. Most of |ValidateShaderModuleDescriptor| is parsing, but
// we can consider splitting it if additional validation is added.
// now.
ASSERT(!IsValidationEnabled());
DAWN_TRY(
ValidateShaderModuleDescriptor(this, descriptor, parseResult, compilationMessages));
ValidateAndParseShaderModule(this, descriptor, parseResult, compilationMessages));
}
DAWN_TRY_ASSIGN(result, CreateShaderModuleImpl(descriptor, parseResult));
DAWN_TRY_ASSIGN(result,
CreateShaderModuleImpl(descriptor, parseResult, compilationMessages));
result->SetIsCachedReference();
result->SetContentHash(blueprintHash);
mCaches->shaderModules.insert(result.Get());
@ -1119,7 +1121,8 @@ ShaderModuleBase* DeviceBase::APICreateShaderModule(const ShaderModuleDescriptor
result = ShaderModuleBase::MakeError(this);
}
// Move compilation messages into ShaderModuleBase and emit tint errors and warnings
// after all other operations are finished successfully.
// after all other operations are finished, even if any of them is failed and result
// is an error shader module.
result->InjectCompilationMessages(std::move(compilationMessages));
return result.Detach();
@ -1229,6 +1232,16 @@ bool DeviceBase::IsFeatureEnabled(Feature feature) const {
return mEnabledFeatures.IsEnabled(feature);
}
void DeviceBase::SetWGSLExtensionAllowList() {
// Set the WGSL extensions allow list based on device's enabled features and other
// propority. For example:
// mWGSLExtensionAllowList.insert("InternalExtensionForTesting");
}
WGSLExtensionsSet DeviceBase::GetWGSLExtensionAllowList() const {
return mWGSLExtensionAllowList;
}
bool DeviceBase::IsValidationEnabled() const {
return !IsToggleEnabled(Toggle::SkipValidation);
}
@ -1589,7 +1602,7 @@ ResultOrError<Ref<ShaderModuleBase>> DeviceBase::CreateShaderModule(
if (IsValidationEnabled()) {
DAWN_TRY_CONTEXT(
ValidateShaderModuleDescriptor(this, descriptor, &parseResult, compilationMessages),
ValidateAndParseShaderModule(this, descriptor, &parseResult, compilationMessages),
"validating %s", descriptor);
}

View File

@ -18,6 +18,7 @@
#include <memory>
#include <mutex>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
@ -54,6 +55,8 @@ struct CallbackTask;
struct InternalPipelineStore;
struct ShaderModuleParseResult;
using WGSLExtensionsSet = std::unordered_set<std::string>;
class DeviceBase : public RefCounted {
public:
DeviceBase(AdapterBase* adapter, const DeviceDescriptor* descriptor);
@ -316,6 +319,7 @@ class DeviceBase : public RefCounted {
std::mutex* GetObjectListMutex(ObjectType type);
std::vector<const char*> GetTogglesUsed() const;
WGSLExtensionsSet GetWGSLExtensionAllowList() const;
bool IsFeatureEnabled(Feature feature) const;
bool IsToggleEnabled(Toggle toggle) const;
bool IsValidationEnabled() const;
@ -412,7 +416,8 @@ class DeviceBase : public RefCounted {
const SamplerDescriptor* descriptor) = 0;
virtual ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) = 0;
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) = 0;
virtual ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) = 0;
// Note that previousSwapChain may be nullptr, or come from a different backend.
@ -456,6 +461,8 @@ class DeviceBase : public RefCounted {
void SetDefaultToggles();
void SetWGSLExtensionAllowList();
void ConsumeError(std::unique_ptr<ErrorData> error);
// Each backend should implement to check their passed fences if there are any and return a
@ -542,6 +549,7 @@ class DeviceBase : public RefCounted {
CombinedLimits mLimits;
FeaturesSet mEnabledFeatures;
WGSLExtensionsSet mWGSLExtensionAllowList;
std::unique_ptr<InternalPipelineStore> mInternalPipelineStore;

View File

@ -900,27 +900,70 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
return std::move(metadata);
}
ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint(const DeviceBase* device,
const tint::Program* program) {
MaybeError ValidateWGSLProgramExtension(const DeviceBase* device,
const tint::Program* program,
OwnedCompilationMessages* outMessages) {
DAWN_ASSERT(program->IsValid());
tint::inspector::Inspector inspector(program);
auto enableDirectives = inspector.GetEnableDirectives();
auto extensionAllowList = device->GetWGSLExtensionAllowList();
bool hasDisallowedExtension = false;
tint::diag::List messages;
for (auto enable : enableDirectives) {
if (extensionAllowList.count(enable.first)) {
continue;
}
hasDisallowedExtension = true;
messages.add_error(tint::diag::System::Program,
"Extension " + enable.first + " is not allowed on the Device.",
enable.second);
}
if (hasDisallowedExtension) {
if (outMessages != nullptr) {
outMessages->AddMessages(messages);
}
return DAWN_MAKE_ERROR(InternalErrorType::Validation,
"Shader module uses extension(s) not enabled for its device.");
}
return {};
}
MaybeError ReflectShaderUsingTint(const DeviceBase* device,
const tint::Program* program,
OwnedCompilationMessages* compilationMessages,
EntryPointMetadataTable& entryPointMetadataTable,
WGSLExtensionsSet* enabledWGSLExtensions) {
ASSERT(program->IsValid());
tint::inspector::Inspector inspector(program);
ASSERT(enabledWGSLExtensions->empty());
auto usedExtensionNames = inspector.GetUsedExtensionNames();
for (std::string name : usedExtensionNames) {
enabledWGSLExtensions->insert(name);
}
DAWN_TRY(ValidateWGSLProgramExtension(device, program, compilationMessages));
std::vector<tint::inspector::EntryPoint> entryPoints = inspector.GetEntryPoints();
DAWN_INVALID_IF(inspector.has_error(), "Tint Reflection failure: Inspector: %s\n",
inspector.error());
EntryPointMetadataTable result;
for (const tint::inspector::EntryPoint& entryPoint : entryPoints) {
std::unique_ptr<EntryPointMetadata> metadata;
DAWN_TRY_ASSIGN_CONTEXT(metadata,
ReflectEntryPointUsingTint(device, &inspector, entryPoint),
"processing entry point \"%s\".", entryPoint.name);
ASSERT(result.count(entryPoint.name) == 0);
result[entryPoint.name] = std::move(metadata);
ASSERT(entryPointMetadataTable.count(entryPoint.name) == 0);
entryPointMetadataTable[entryPoint.name] = std::move(metadata);
}
return std::move(result);
return {};
}
} // anonymous namespace
@ -946,10 +989,10 @@ class TintSource {
tint::Source::File file;
};
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* outMessages) {
MaybeError ValidateAndParseShaderModule(DeviceBase* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* outMessages) {
ASSERT(parseResult != nullptr);
const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
@ -1287,11 +1330,13 @@ void ShaderModuleBase::AddExternalTextureTransform(const PipelineLayoutBase* lay
}
}
MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) {
MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
mTintProgram = std::move(parseResult->tintProgram);
mTintSource = std::move(parseResult->tintSource);
DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingTint(GetDevice(), mTintProgram.get()));
DAWN_TRY(ReflectShaderUsingTint(GetDevice(), mTintProgram.get(), compilationMessages,
mEntryPoints, &mEnabledWGSLExtensions));
return {};
}

View File

@ -53,6 +53,7 @@ class VertexPulling;
namespace dawn::native {
using WGSLExtensionsSet = std::unordered_set<std::string>;
struct EntryPointMetadata;
// Base component type of an inter-stage variable
@ -99,10 +100,10 @@ struct ShaderModuleParseResult {
std::unique_ptr<TintSource> tintSource;
};
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* outMessages);
MaybeError ValidateAndParseShaderModule(DeviceBase* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* outMessages);
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout);
@ -289,7 +290,8 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject {
explicit ShaderModuleBase(DeviceBase* device);
void DestroyImpl() override;
MaybeError InitializeBase(ShaderModuleParseResult* parseResult);
MaybeError InitializeBase(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
static void AddExternalTextureTransform(const PipelineLayoutBase* layout,
tint::transform::Manager* transformManager,
@ -305,6 +307,7 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject {
std::string mWgsl;
EntryPointMetadataTable mEntryPoints;
WGSLExtensionsSet mEnabledWGSLExtensions;
std::unique_ptr<tint::Program> mTintProgram;
std::unique_ptr<TintSource> mTintSource; // Keep the tint::Source::File alive

View File

@ -422,8 +422,9 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
}
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
return ShaderModule::Create(this, descriptor, parseResult);
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
}
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) {

View File

@ -170,7 +170,8 @@ class Device final : public DeviceBase {
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override;
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override;
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(

View File

@ -729,20 +729,23 @@ MaybeError CompileShader(dawn::platform::Platform* platform,
} // anonymous namespace
// static
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
DAWN_TRY(module->Initialize(parseResult));
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module;
}
ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
: ShaderModuleBase(device, descriptor) {}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
ScopedTintICEHandler scopedICEHandler(GetDevice());
return InitializeBase(parseResult);
return InitializeBase(parseResult, compilationMessages);
}
ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& programmableStage,

View File

@ -42,7 +42,8 @@ class ShaderModule final : public ShaderModuleBase {
public:
static ResultOrError<Ref<ShaderModule>> Create(Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult);
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
ResultOrError<CompiledShader> Compile(const ProgrammableStage& programmableStage,
SingleShaderStage stage,
@ -52,7 +53,8 @@ class ShaderModule final : public ShaderModuleBase {
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult);
MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
};
} // namespace dawn::native::d3d12

View File

@ -92,7 +92,8 @@ class Device final : public DeviceBase {
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override;
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override;
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(

View File

@ -260,8 +260,9 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
}
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
return ShaderModule::Create(this, descriptor, parseResult);
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
}
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) {

View File

@ -35,7 +35,8 @@ class ShaderModule final : public ShaderModuleBase {
public:
static ResultOrError<Ref<ShaderModule>> Create(Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult);
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
struct MetalFunctionData {
NSPRef<id<MTLFunction>> function;
@ -65,7 +66,8 @@ class ShaderModule final : public ShaderModuleBase {
std::vector<uint32_t>* workgroupAllocations);
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult);
MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
};
} // namespace dawn::native::metal

View File

@ -29,20 +29,23 @@
namespace dawn::native::metal {
// static
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
DAWN_TRY(module->Initialize(parseResult));
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module;
}
ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
: ShaderModuleBase(device, descriptor) {}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
ScopedTintICEHandler scopedICEHandler(GetDevice());
return InitializeBase(parseResult);
return InitializeBase(parseResult, compilationMessages);
}
ResultOrError<std::string> ShaderModule::TranslateToMSL(

View File

@ -155,9 +155,10 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
}
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor));
DAWN_TRY(module->Initialize(parseResult));
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module;
}
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
@ -429,8 +430,9 @@ void SwapChain::DetachFromSurfaceImpl() {
// ShaderModule
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
return InitializeBase(parseResult);
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
return InitializeBase(parseResult, compilationMessages);
}
// OldSwapChain

View File

@ -142,7 +142,8 @@ class Device final : public DeviceBase {
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override;
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override;
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(
@ -277,7 +278,8 @@ class ShaderModule final : public ShaderModuleBase {
public:
using ShaderModuleBase::ShaderModuleBase;
MaybeError Initialize(ShaderModuleParseResult* parseResult);
MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
};
class SwapChain final : public NewSwapChainBase {

View File

@ -171,8 +171,9 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
}
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
return ShaderModule::Create(this, descriptor, parseResult);
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
}
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) {

View File

@ -98,7 +98,8 @@ class Device final : public DeviceBase {
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override;
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override;
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(

View File

@ -59,21 +59,24 @@ std::string CombinedSampler::GetName() const {
}
// static
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
DAWN_TRY(module->Initialize(parseResult));
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module;
}
ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
: ShaderModuleBase(device, descriptor) {}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
ScopedTintICEHandler scopedICEHandler(GetDevice());
DAWN_TRY(InitializeBase(parseResult));
DAWN_TRY(InitializeBase(parseResult, compilationMessages));
return {};
}

View File

@ -56,7 +56,8 @@ class ShaderModule final : public ShaderModuleBase {
public:
static ResultOrError<Ref<ShaderModule>> Create(Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult);
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
ResultOrError<std::string> TranslateToGLSL(const char* entryPointName,
SingleShaderStage stage,
@ -67,7 +68,8 @@ class ShaderModule final : public ShaderModuleBase {
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default;
MaybeError Initialize(ShaderModuleParseResult* parseResult);
MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
};
} // namespace dawn::native::opengl

View File

@ -151,8 +151,9 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
}
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
return ShaderModule::Create(this, descriptor, parseResult);
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
}
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) {

View File

@ -127,7 +127,8 @@ class Device final : public DeviceBase {
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) override;
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) override;
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
const SwapChainDescriptor* descriptor) override;
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(

View File

@ -73,11 +73,13 @@ ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCach
}
// static
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult) {
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
DAWN_TRY(module->Initialize(parseResult));
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
return module;
}
@ -86,7 +88,8 @@ ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descrip
mTransformedShaderModuleCache(
std::make_unique<ConcurrentTransformedShaderModuleCache>(device)) {}
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages) {
if (GetDevice()->IsRobustnessEnabled()) {
ScopedTintICEHandler scopedICEHandler(GetDevice());
@ -100,7 +103,7 @@ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
}
return InitializeBase(parseResult);
return InitializeBase(parseResult, compilationMessages);
}
void ShaderModule::DestroyImpl() {

View File

@ -40,7 +40,8 @@ class ShaderModule final : public ShaderModuleBase {
static ResultOrError<Ref<ShaderModule>> Create(Device* device,
const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult);
ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const char* entryPointName,
const PipelineLayout* layout);
@ -48,7 +49,8 @@ class ShaderModule final : public ShaderModuleBase {
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override;
MaybeError Initialize(ShaderModuleParseResult* parseResult);
MaybeError Initialize(ShaderModuleParseResult* parseResult,
OwnedCompilationMessages* compilationMessages);
void DestroyImpl() override;
// New handles created by GetHandleAndSpirv at pipeline creation time.

View File

@ -90,7 +90,9 @@ class DeviceMock : public DeviceBase {
(override));
MOCK_METHOD(ResultOrError<Ref<ShaderModuleBase>>,
CreateShaderModuleImpl,
(const ShaderModuleDescriptor*, ShaderModuleParseResult*),
(const ShaderModuleDescriptor*,
ShaderModuleParseResult*,
OwnedCompilationMessages*),
(override));
MOCK_METHOD(ResultOrError<Ref<SwapChainBase>>,
CreateSwapChainImpl,

View File

@ -30,8 +30,8 @@ ResultOrError<Ref<ShaderModuleMock>> ShaderModuleMock::Create(DeviceBase* device
desc.nextInChain = &wgslDesc;
ShaderModuleParseResult parseResult;
DAWN_TRY(ValidateShaderModuleDescriptor(device, &desc, &parseResult, nullptr));
DAWN_TRY(mock->InitializeBase(&parseResult));
DAWN_TRY(ValidateAndParseShaderModule(device, &desc, &parseResult, nullptr));
DAWN_TRY(mock->InitializeBase(&parseResult, nullptr));
return AcquireRef(mock);
}

View File

@ -659,3 +659,11 @@ TEST_F(ShaderModuleValidationTest, MissingDecorations) {
}
)"));
}
// Test that WGSL extension used by enable directives must be allowed by WebGPU.
TEST_F(ShaderModuleValidationTest, ExtensionMustBeAllowed) {
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"(
enable InternalExtensionForTesting;
@stage(compute) @workgroup_size(1) fn main() {})"));
}