dawn: Add shader module validation for WGSL extension
Tint has already implemented the enable directive for using WGSL extension in the future, and using a WGSL extension that is not allowed for the device should result in a shader creation error. In this patch a WGSL extension allow list is added in DeviceBase, and a validation is added in shader module base initialization to make sure all extensions used in the WGSL program are in the allow list. This patch also rename the `ValidateShaderModuleDescriptor` to `ValidateAndParseShaderModule`, which is more descriptive for what it actually does. Bug: tint:1472 Change-Id: I4b039a3e37c25159b4fc6cfa37488aa817004ab2 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/88241 Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com> Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Kai Ninomiya <kainino@chromium.org>
This commit is contained in:
parent
4cb75245e3
commit
d6b2501be2
|
@ -199,6 +199,8 @@ DeviceBase::DeviceBase(AdapterBase* adapter, const DeviceDescriptor* descriptor)
|
|||
mFormatTable = BuildFormatTable(this);
|
||||
SetDefaultToggles();
|
||||
|
||||
SetWGSLExtensionAllowList();
|
||||
|
||||
if (descriptor->label != nullptr && strlen(descriptor->label) != 0) {
|
||||
mLabel = descriptor->label;
|
||||
}
|
||||
|
@ -914,13 +916,13 @@ ResultOrError<Ref<ShaderModuleBase>> DeviceBase::GetOrCreateShaderModule(
|
|||
if (!parseResult->HasParsedShader()) {
|
||||
// We skip the parse on creation if validation isn't enabled which let's us quickly
|
||||
// lookup in the cache without validating and parsing. We need the parsed module
|
||||
// now, so call validate. Most of |ValidateShaderModuleDescriptor| is parsing, but
|
||||
// we can consider splitting it if additional validation is added.
|
||||
// now.
|
||||
ASSERT(!IsValidationEnabled());
|
||||
DAWN_TRY(
|
||||
ValidateShaderModuleDescriptor(this, descriptor, parseResult, compilationMessages));
|
||||
ValidateAndParseShaderModule(this, descriptor, parseResult, compilationMessages));
|
||||
}
|
||||
DAWN_TRY_ASSIGN(result, CreateShaderModuleImpl(descriptor, parseResult));
|
||||
DAWN_TRY_ASSIGN(result,
|
||||
CreateShaderModuleImpl(descriptor, parseResult, compilationMessages));
|
||||
result->SetIsCachedReference();
|
||||
result->SetContentHash(blueprintHash);
|
||||
mCaches->shaderModules.insert(result.Get());
|
||||
|
@ -1119,7 +1121,8 @@ ShaderModuleBase* DeviceBase::APICreateShaderModule(const ShaderModuleDescriptor
|
|||
result = ShaderModuleBase::MakeError(this);
|
||||
}
|
||||
// Move compilation messages into ShaderModuleBase and emit tint errors and warnings
|
||||
// after all other operations are finished successfully.
|
||||
// after all other operations are finished, even if any of them is failed and result
|
||||
// is an error shader module.
|
||||
result->InjectCompilationMessages(std::move(compilationMessages));
|
||||
|
||||
return result.Detach();
|
||||
|
@ -1229,6 +1232,16 @@ bool DeviceBase::IsFeatureEnabled(Feature feature) const {
|
|||
return mEnabledFeatures.IsEnabled(feature);
|
||||
}
|
||||
|
||||
void DeviceBase::SetWGSLExtensionAllowList() {
|
||||
// Set the WGSL extensions allow list based on device's enabled features and other
|
||||
// propority. For example:
|
||||
// mWGSLExtensionAllowList.insert("InternalExtensionForTesting");
|
||||
}
|
||||
|
||||
WGSLExtensionsSet DeviceBase::GetWGSLExtensionAllowList() const {
|
||||
return mWGSLExtensionAllowList;
|
||||
}
|
||||
|
||||
bool DeviceBase::IsValidationEnabled() const {
|
||||
return !IsToggleEnabled(Toggle::SkipValidation);
|
||||
}
|
||||
|
@ -1589,7 +1602,7 @@ ResultOrError<Ref<ShaderModuleBase>> DeviceBase::CreateShaderModule(
|
|||
|
||||
if (IsValidationEnabled()) {
|
||||
DAWN_TRY_CONTEXT(
|
||||
ValidateShaderModuleDescriptor(this, descriptor, &parseResult, compilationMessages),
|
||||
ValidateAndParseShaderModule(this, descriptor, &parseResult, compilationMessages),
|
||||
"validating %s", descriptor);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
@ -54,6 +55,8 @@ struct CallbackTask;
|
|||
struct InternalPipelineStore;
|
||||
struct ShaderModuleParseResult;
|
||||
|
||||
using WGSLExtensionsSet = std::unordered_set<std::string>;
|
||||
|
||||
class DeviceBase : public RefCounted {
|
||||
public:
|
||||
DeviceBase(AdapterBase* adapter, const DeviceDescriptor* descriptor);
|
||||
|
@ -316,6 +319,7 @@ class DeviceBase : public RefCounted {
|
|||
std::mutex* GetObjectListMutex(ObjectType type);
|
||||
|
||||
std::vector<const char*> GetTogglesUsed() const;
|
||||
WGSLExtensionsSet GetWGSLExtensionAllowList() const;
|
||||
bool IsFeatureEnabled(Feature feature) const;
|
||||
bool IsToggleEnabled(Toggle toggle) const;
|
||||
bool IsValidationEnabled() const;
|
||||
|
@ -412,7 +416,8 @@ class DeviceBase : public RefCounted {
|
|||
const SamplerDescriptor* descriptor) = 0;
|
||||
virtual ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) = 0;
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) = 0;
|
||||
virtual ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
|
||||
const SwapChainDescriptor* descriptor) = 0;
|
||||
// Note that previousSwapChain may be nullptr, or come from a different backend.
|
||||
|
@ -456,6 +461,8 @@ class DeviceBase : public RefCounted {
|
|||
|
||||
void SetDefaultToggles();
|
||||
|
||||
void SetWGSLExtensionAllowList();
|
||||
|
||||
void ConsumeError(std::unique_ptr<ErrorData> error);
|
||||
|
||||
// Each backend should implement to check their passed fences if there are any and return a
|
||||
|
@ -542,6 +549,7 @@ class DeviceBase : public RefCounted {
|
|||
|
||||
CombinedLimits mLimits;
|
||||
FeaturesSet mEnabledFeatures;
|
||||
WGSLExtensionsSet mWGSLExtensionAllowList;
|
||||
|
||||
std::unique_ptr<InternalPipelineStore> mInternalPipelineStore;
|
||||
|
||||
|
|
|
@ -900,27 +900,70 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
|||
return std::move(metadata);
|
||||
}
|
||||
|
||||
ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint(const DeviceBase* device,
|
||||
const tint::Program* program) {
|
||||
MaybeError ValidateWGSLProgramExtension(const DeviceBase* device,
|
||||
const tint::Program* program,
|
||||
OwnedCompilationMessages* outMessages) {
|
||||
DAWN_ASSERT(program->IsValid());
|
||||
tint::inspector::Inspector inspector(program);
|
||||
auto enableDirectives = inspector.GetEnableDirectives();
|
||||
|
||||
auto extensionAllowList = device->GetWGSLExtensionAllowList();
|
||||
|
||||
bool hasDisallowedExtension = false;
|
||||
tint::diag::List messages;
|
||||
|
||||
for (auto enable : enableDirectives) {
|
||||
if (extensionAllowList.count(enable.first)) {
|
||||
continue;
|
||||
}
|
||||
hasDisallowedExtension = true;
|
||||
messages.add_error(tint::diag::System::Program,
|
||||
"Extension " + enable.first + " is not allowed on the Device.",
|
||||
enable.second);
|
||||
}
|
||||
|
||||
if (hasDisallowedExtension) {
|
||||
if (outMessages != nullptr) {
|
||||
outMessages->AddMessages(messages);
|
||||
}
|
||||
return DAWN_MAKE_ERROR(InternalErrorType::Validation,
|
||||
"Shader module uses extension(s) not enabled for its device.");
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
MaybeError ReflectShaderUsingTint(const DeviceBase* device,
|
||||
const tint::Program* program,
|
||||
OwnedCompilationMessages* compilationMessages,
|
||||
EntryPointMetadataTable& entryPointMetadataTable,
|
||||
WGSLExtensionsSet* enabledWGSLExtensions) {
|
||||
ASSERT(program->IsValid());
|
||||
|
||||
tint::inspector::Inspector inspector(program);
|
||||
|
||||
ASSERT(enabledWGSLExtensions->empty());
|
||||
auto usedExtensionNames = inspector.GetUsedExtensionNames();
|
||||
for (std::string name : usedExtensionNames) {
|
||||
enabledWGSLExtensions->insert(name);
|
||||
}
|
||||
|
||||
DAWN_TRY(ValidateWGSLProgramExtension(device, program, compilationMessages));
|
||||
|
||||
std::vector<tint::inspector::EntryPoint> entryPoints = inspector.GetEntryPoints();
|
||||
DAWN_INVALID_IF(inspector.has_error(), "Tint Reflection failure: Inspector: %s\n",
|
||||
inspector.error());
|
||||
|
||||
EntryPointMetadataTable result;
|
||||
|
||||
for (const tint::inspector::EntryPoint& entryPoint : entryPoints) {
|
||||
std::unique_ptr<EntryPointMetadata> metadata;
|
||||
DAWN_TRY_ASSIGN_CONTEXT(metadata,
|
||||
ReflectEntryPointUsingTint(device, &inspector, entryPoint),
|
||||
"processing entry point \"%s\".", entryPoint.name);
|
||||
|
||||
ASSERT(result.count(entryPoint.name) == 0);
|
||||
result[entryPoint.name] = std::move(metadata);
|
||||
ASSERT(entryPointMetadataTable.count(entryPoint.name) == 0);
|
||||
entryPointMetadataTable[entryPoint.name] = std::move(metadata);
|
||||
}
|
||||
return std::move(result);
|
||||
return {};
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
|
@ -946,10 +989,10 @@ class TintSource {
|
|||
tint::Source::File file;
|
||||
};
|
||||
|
||||
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* outMessages) {
|
||||
MaybeError ValidateAndParseShaderModule(DeviceBase* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* outMessages) {
|
||||
ASSERT(parseResult != nullptr);
|
||||
|
||||
const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
|
||||
|
@ -1287,11 +1330,13 @@ void ShaderModuleBase::AddExternalTextureTransform(const PipelineLayoutBase* lay
|
|||
}
|
||||
}
|
||||
|
||||
MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) {
|
||||
MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
mTintProgram = std::move(parseResult->tintProgram);
|
||||
mTintSource = std::move(parseResult->tintSource);
|
||||
|
||||
DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingTint(GetDevice(), mTintProgram.get()));
|
||||
DAWN_TRY(ReflectShaderUsingTint(GetDevice(), mTintProgram.get(), compilationMessages,
|
||||
mEntryPoints, &mEnabledWGSLExtensions));
|
||||
return {};
|
||||
}
|
||||
|
||||
|
|
|
@ -53,6 +53,7 @@ class VertexPulling;
|
|||
|
||||
namespace dawn::native {
|
||||
|
||||
using WGSLExtensionsSet = std::unordered_set<std::string>;
|
||||
struct EntryPointMetadata;
|
||||
|
||||
// Base component type of an inter-stage variable
|
||||
|
@ -99,10 +100,10 @@ struct ShaderModuleParseResult {
|
|||
std::unique_ptr<TintSource> tintSource;
|
||||
};
|
||||
|
||||
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* outMessages);
|
||||
MaybeError ValidateAndParseShaderModule(DeviceBase* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* outMessages);
|
||||
MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
|
||||
const EntryPointMetadata& entryPoint,
|
||||
const PipelineLayoutBase* layout);
|
||||
|
@ -289,7 +290,8 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject {
|
|||
explicit ShaderModuleBase(DeviceBase* device);
|
||||
void DestroyImpl() override;
|
||||
|
||||
MaybeError InitializeBase(ShaderModuleParseResult* parseResult);
|
||||
MaybeError InitializeBase(ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages);
|
||||
|
||||
static void AddExternalTextureTransform(const PipelineLayoutBase* layout,
|
||||
tint::transform::Manager* transformManager,
|
||||
|
@ -305,6 +307,7 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject {
|
|||
std::string mWgsl;
|
||||
|
||||
EntryPointMetadataTable mEntryPoints;
|
||||
WGSLExtensionsSet mEnabledWGSLExtensions;
|
||||
std::unique_ptr<tint::Program> mTintProgram;
|
||||
std::unique_ptr<TintSource> mTintSource; // Keep the tint::Source::File alive
|
||||
|
||||
|
|
|
@ -422,8 +422,9 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
|
|||
}
|
||||
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) {
|
||||
return ShaderModule::Create(this, descriptor, parseResult);
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
|
||||
}
|
||||
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
|
||||
const SwapChainDescriptor* descriptor) {
|
||||
|
|
|
@ -170,7 +170,8 @@ class Device final : public DeviceBase {
|
|||
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
|
||||
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) override;
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) override;
|
||||
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
|
||||
const SwapChainDescriptor* descriptor) override;
|
||||
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(
|
||||
|
|
|
@ -729,20 +729,23 @@ MaybeError CompileShader(dawn::platform::Platform* platform,
|
|||
} // anonymous namespace
|
||||
|
||||
// static
|
||||
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) {
|
||||
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
|
||||
Device* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
|
||||
DAWN_TRY(module->Initialize(parseResult));
|
||||
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
|
||||
return module;
|
||||
}
|
||||
|
||||
ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
|
||||
: ShaderModuleBase(device, descriptor) {}
|
||||
|
||||
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
|
||||
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
ScopedTintICEHandler scopedICEHandler(GetDevice());
|
||||
return InitializeBase(parseResult);
|
||||
return InitializeBase(parseResult, compilationMessages);
|
||||
}
|
||||
|
||||
ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& programmableStage,
|
||||
|
|
|
@ -42,7 +42,8 @@ class ShaderModule final : public ShaderModuleBase {
|
|||
public:
|
||||
static ResultOrError<Ref<ShaderModule>> Create(Device* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult);
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages);
|
||||
|
||||
ResultOrError<CompiledShader> Compile(const ProgrammableStage& programmableStage,
|
||||
SingleShaderStage stage,
|
||||
|
@ -52,7 +53,8 @@ class ShaderModule final : public ShaderModuleBase {
|
|||
private:
|
||||
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
|
||||
~ShaderModule() override = default;
|
||||
MaybeError Initialize(ShaderModuleParseResult* parseResult);
|
||||
MaybeError Initialize(ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages);
|
||||
};
|
||||
|
||||
} // namespace dawn::native::d3d12
|
||||
|
|
|
@ -92,7 +92,8 @@ class Device final : public DeviceBase {
|
|||
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
|
||||
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) override;
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) override;
|
||||
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
|
||||
const SwapChainDescriptor* descriptor) override;
|
||||
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(
|
||||
|
|
|
@ -260,8 +260,9 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
|
|||
}
|
||||
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) {
|
||||
return ShaderModule::Create(this, descriptor, parseResult);
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
|
||||
}
|
||||
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
|
||||
const SwapChainDescriptor* descriptor) {
|
||||
|
|
|
@ -35,7 +35,8 @@ class ShaderModule final : public ShaderModuleBase {
|
|||
public:
|
||||
static ResultOrError<Ref<ShaderModule>> Create(Device* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult);
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages);
|
||||
|
||||
struct MetalFunctionData {
|
||||
NSPRef<id<MTLFunction>> function;
|
||||
|
@ -65,7 +66,8 @@ class ShaderModule final : public ShaderModuleBase {
|
|||
std::vector<uint32_t>* workgroupAllocations);
|
||||
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
|
||||
~ShaderModule() override = default;
|
||||
MaybeError Initialize(ShaderModuleParseResult* parseResult);
|
||||
MaybeError Initialize(ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages);
|
||||
};
|
||||
|
||||
} // namespace dawn::native::metal
|
||||
|
|
|
@ -29,20 +29,23 @@
|
|||
namespace dawn::native::metal {
|
||||
|
||||
// static
|
||||
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) {
|
||||
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
|
||||
Device* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
|
||||
DAWN_TRY(module->Initialize(parseResult));
|
||||
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
|
||||
return module;
|
||||
}
|
||||
|
||||
ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
|
||||
: ShaderModuleBase(device, descriptor) {}
|
||||
|
||||
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
|
||||
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
ScopedTintICEHandler scopedICEHandler(GetDevice());
|
||||
return InitializeBase(parseResult);
|
||||
return InitializeBase(parseResult, compilationMessages);
|
||||
}
|
||||
|
||||
ResultOrError<std::string> ShaderModule::TranslateToMSL(
|
||||
|
|
|
@ -155,9 +155,10 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
|
|||
}
|
||||
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) {
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor));
|
||||
DAWN_TRY(module->Initialize(parseResult));
|
||||
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
|
||||
return module;
|
||||
}
|
||||
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
|
||||
|
@ -429,8 +430,9 @@ void SwapChain::DetachFromSurfaceImpl() {
|
|||
|
||||
// ShaderModule
|
||||
|
||||
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
|
||||
return InitializeBase(parseResult);
|
||||
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
return InitializeBase(parseResult, compilationMessages);
|
||||
}
|
||||
|
||||
// OldSwapChain
|
||||
|
|
|
@ -142,7 +142,8 @@ class Device final : public DeviceBase {
|
|||
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
|
||||
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) override;
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) override;
|
||||
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
|
||||
const SwapChainDescriptor* descriptor) override;
|
||||
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(
|
||||
|
@ -277,7 +278,8 @@ class ShaderModule final : public ShaderModuleBase {
|
|||
public:
|
||||
using ShaderModuleBase::ShaderModuleBase;
|
||||
|
||||
MaybeError Initialize(ShaderModuleParseResult* parseResult);
|
||||
MaybeError Initialize(ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages);
|
||||
};
|
||||
|
||||
class SwapChain final : public NewSwapChainBase {
|
||||
|
|
|
@ -171,8 +171,9 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
|
|||
}
|
||||
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) {
|
||||
return ShaderModule::Create(this, descriptor, parseResult);
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
|
||||
}
|
||||
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
|
||||
const SwapChainDescriptor* descriptor) {
|
||||
|
|
|
@ -98,7 +98,8 @@ class Device final : public DeviceBase {
|
|||
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
|
||||
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) override;
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) override;
|
||||
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
|
||||
const SwapChainDescriptor* descriptor) override;
|
||||
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(
|
||||
|
|
|
@ -59,21 +59,24 @@ std::string CombinedSampler::GetName() const {
|
|||
}
|
||||
|
||||
// static
|
||||
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) {
|
||||
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
|
||||
Device* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
|
||||
DAWN_TRY(module->Initialize(parseResult));
|
||||
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
|
||||
return module;
|
||||
}
|
||||
|
||||
ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
|
||||
: ShaderModuleBase(device, descriptor) {}
|
||||
|
||||
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
|
||||
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
ScopedTintICEHandler scopedICEHandler(GetDevice());
|
||||
|
||||
DAWN_TRY(InitializeBase(parseResult));
|
||||
DAWN_TRY(InitializeBase(parseResult, compilationMessages));
|
||||
|
||||
return {};
|
||||
}
|
||||
|
|
|
@ -56,7 +56,8 @@ class ShaderModule final : public ShaderModuleBase {
|
|||
public:
|
||||
static ResultOrError<Ref<ShaderModule>> Create(Device* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult);
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages);
|
||||
|
||||
ResultOrError<std::string> TranslateToGLSL(const char* entryPointName,
|
||||
SingleShaderStage stage,
|
||||
|
@ -67,7 +68,8 @@ class ShaderModule final : public ShaderModuleBase {
|
|||
private:
|
||||
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
|
||||
~ShaderModule() override = default;
|
||||
MaybeError Initialize(ShaderModuleParseResult* parseResult);
|
||||
MaybeError Initialize(ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages);
|
||||
};
|
||||
|
||||
} // namespace dawn::native::opengl
|
||||
|
|
|
@ -151,8 +151,9 @@ ResultOrError<Ref<SamplerBase>> Device::CreateSamplerImpl(const SamplerDescripto
|
|||
}
|
||||
ResultOrError<Ref<ShaderModuleBase>> Device::CreateShaderModuleImpl(
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) {
|
||||
return ShaderModule::Create(this, descriptor, parseResult);
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
return ShaderModule::Create(this, descriptor, parseResult, compilationMessages);
|
||||
}
|
||||
ResultOrError<Ref<SwapChainBase>> Device::CreateSwapChainImpl(
|
||||
const SwapChainDescriptor* descriptor) {
|
||||
|
|
|
@ -127,7 +127,8 @@ class Device final : public DeviceBase {
|
|||
ResultOrError<Ref<SamplerBase>> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
|
||||
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModuleImpl(
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) override;
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) override;
|
||||
ResultOrError<Ref<SwapChainBase>> CreateSwapChainImpl(
|
||||
const SwapChainDescriptor* descriptor) override;
|
||||
ResultOrError<Ref<NewSwapChainBase>> CreateSwapChainImpl(
|
||||
|
|
|
@ -73,11 +73,13 @@ ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCach
|
|||
}
|
||||
|
||||
// static
|
||||
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult) {
|
||||
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
|
||||
Device* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
|
||||
DAWN_TRY(module->Initialize(parseResult));
|
||||
DAWN_TRY(module->Initialize(parseResult, compilationMessages));
|
||||
return module;
|
||||
}
|
||||
|
||||
|
@ -86,7 +88,8 @@ ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descrip
|
|||
mTransformedShaderModuleCache(
|
||||
std::make_unique<ConcurrentTransformedShaderModuleCache>(device)) {}
|
||||
|
||||
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
|
||||
MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages) {
|
||||
if (GetDevice()->IsRobustnessEnabled()) {
|
||||
ScopedTintICEHandler scopedICEHandler(GetDevice());
|
||||
|
||||
|
@ -100,7 +103,7 @@ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
|
|||
parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
|
||||
}
|
||||
|
||||
return InitializeBase(parseResult);
|
||||
return InitializeBase(parseResult, compilationMessages);
|
||||
}
|
||||
|
||||
void ShaderModule::DestroyImpl() {
|
||||
|
|
|
@ -40,7 +40,8 @@ class ShaderModule final : public ShaderModuleBase {
|
|||
|
||||
static ResultOrError<Ref<ShaderModule>> Create(Device* device,
|
||||
const ShaderModuleDescriptor* descriptor,
|
||||
ShaderModuleParseResult* parseResult);
|
||||
ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages);
|
||||
|
||||
ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const char* entryPointName,
|
||||
const PipelineLayout* layout);
|
||||
|
@ -48,7 +49,8 @@ class ShaderModule final : public ShaderModuleBase {
|
|||
private:
|
||||
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
|
||||
~ShaderModule() override;
|
||||
MaybeError Initialize(ShaderModuleParseResult* parseResult);
|
||||
MaybeError Initialize(ShaderModuleParseResult* parseResult,
|
||||
OwnedCompilationMessages* compilationMessages);
|
||||
void DestroyImpl() override;
|
||||
|
||||
// New handles created by GetHandleAndSpirv at pipeline creation time.
|
||||
|
|
|
@ -90,7 +90,9 @@ class DeviceMock : public DeviceBase {
|
|||
(override));
|
||||
MOCK_METHOD(ResultOrError<Ref<ShaderModuleBase>>,
|
||||
CreateShaderModuleImpl,
|
||||
(const ShaderModuleDescriptor*, ShaderModuleParseResult*),
|
||||
(const ShaderModuleDescriptor*,
|
||||
ShaderModuleParseResult*,
|
||||
OwnedCompilationMessages*),
|
||||
(override));
|
||||
MOCK_METHOD(ResultOrError<Ref<SwapChainBase>>,
|
||||
CreateSwapChainImpl,
|
||||
|
|
|
@ -30,8 +30,8 @@ ResultOrError<Ref<ShaderModuleMock>> ShaderModuleMock::Create(DeviceBase* device
|
|||
desc.nextInChain = &wgslDesc;
|
||||
|
||||
ShaderModuleParseResult parseResult;
|
||||
DAWN_TRY(ValidateShaderModuleDescriptor(device, &desc, &parseResult, nullptr));
|
||||
DAWN_TRY(mock->InitializeBase(&parseResult));
|
||||
DAWN_TRY(ValidateAndParseShaderModule(device, &desc, &parseResult, nullptr));
|
||||
DAWN_TRY(mock->InitializeBase(&parseResult, nullptr));
|
||||
return AcquireRef(mock);
|
||||
}
|
||||
|
||||
|
|
|
@ -659,3 +659,11 @@ TEST_F(ShaderModuleValidationTest, MissingDecorations) {
|
|||
}
|
||||
)"));
|
||||
}
|
||||
|
||||
// Test that WGSL extension used by enable directives must be allowed by WebGPU.
|
||||
TEST_F(ShaderModuleValidationTest, ExtensionMustBeAllowed) {
|
||||
ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"(
|
||||
enable InternalExtensionForTesting;
|
||||
|
||||
@stage(compute) @workgroup_size(1) fn main() {})"));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue