Add support for WGSL shaders via Tint
BUG=dawn:405 Change-Id: I7a79a0d7ce58ff995ec1ff917dd427875fb4deaf Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/21340 Reviewed-by: Austin Eng <enga@chromium.org> Commit-Queue: Ryan Harrison <rharrison@chromium.org>
This commit is contained in:
parent
214c71769b
commit
bd0ad7921d
|
@ -23,6 +23,18 @@
|
|||
#include <spirv-tools/libspirv.hpp>
|
||||
#include <spirv_cross.hpp>
|
||||
|
||||
#ifdef DAWN_ENABLE_WGSL
|
||||
// Tint includes must be after spirv_cross.hpp, because spirv-cross has its own
|
||||
// version of spirv_headers.
|
||||
// clang-format off
|
||||
#include "tint/src/reader/wgsl/parser.h"
|
||||
#include "tint/src/type_determiner.h"
|
||||
#include "tint/src/validator.h"
|
||||
#include "tint/src/writer/spirv/generator.h"
|
||||
#include "tint/src/writer/writer.h"
|
||||
// clang-format on
|
||||
#endif // DAWN_ENABLE_WGSL
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace dawn_native {
|
||||
|
@ -316,6 +328,77 @@ namespace dawn_native {
|
|||
return {};
|
||||
}
|
||||
|
||||
#ifdef DAWN_ENABLE_WGSL
|
||||
MaybeError ValidateWGSL(const char* source) {
|
||||
std::ostringstream errorStream;
|
||||
errorStream << "Tint WGSL failure:" << std::endl;
|
||||
|
||||
tint::Context context;
|
||||
tint::reader::wgsl::Parser parser(&context, source);
|
||||
|
||||
if (!parser.Parse()) {
|
||||
errorStream << "Parser: " << parser.error() << std::endl;
|
||||
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
|
||||
}
|
||||
|
||||
tint::ast::Module module = parser.module();
|
||||
if (!module.IsValid()) {
|
||||
errorStream << "Invalid module generated..." << std::endl;
|
||||
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
|
||||
}
|
||||
|
||||
tint::TypeDeterminer type_determiner(&context, &module);
|
||||
if (!type_determiner.Determine()) {
|
||||
errorStream << "Type Determination: " << type_determiner.error();
|
||||
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
|
||||
}
|
||||
|
||||
tint::Validator validator;
|
||||
if (!validator.Validate(module)) {
|
||||
errorStream << "Validation: " << validator.error() << std::endl;
|
||||
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRV(const char* source) {
|
||||
std::ostringstream errorStream;
|
||||
errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
|
||||
|
||||
tint::Context context;
|
||||
tint::reader::wgsl::Parser parser(&context, source);
|
||||
|
||||
// TODO: This is a duplicate parse with ValidateWGSL, need to store
|
||||
// state between calls to avoid this.
|
||||
if (!parser.Parse()) {
|
||||
errorStream << "Parser: " << parser.error() << std::endl;
|
||||
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
|
||||
}
|
||||
|
||||
tint::ast::Module module = parser.module();
|
||||
if (!module.IsValid()) {
|
||||
errorStream << "Invalid module generated..." << std::endl;
|
||||
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
|
||||
}
|
||||
|
||||
tint::TypeDeterminer type_determiner(&context, &module);
|
||||
if (!type_determiner.Determine()) {
|
||||
errorStream << "Type Determination: " << type_determiner.error();
|
||||
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
|
||||
}
|
||||
|
||||
tint::writer::spirv::Generator generator(std::move(module));
|
||||
if (!generator.Generate()) {
|
||||
errorStream << "Generator: " << generator.error() << std::endl;
|
||||
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
|
||||
}
|
||||
|
||||
std::vector<uint32_t> spirv = generator.result();
|
||||
return std::move(spirv);
|
||||
}
|
||||
#endif // DAWN_ENABLE_WGSL
|
||||
|
||||
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
|
||||
const ShaderModuleDescriptor* descriptor) {
|
||||
const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
|
||||
|
@ -330,17 +413,22 @@ namespace dawn_native {
|
|||
|
||||
switch (chainedDescriptor->sType) {
|
||||
case wgpu::SType::ShaderModuleSPIRVDescriptor: {
|
||||
const ShaderModuleSPIRVDescriptor* spirvDesc =
|
||||
const auto* spirvDesc =
|
||||
static_cast<const ShaderModuleSPIRVDescriptor*>(chainedDescriptor);
|
||||
DAWN_TRY(ValidateSpirv(device, spirvDesc->code, spirvDesc->codeSize));
|
||||
break;
|
||||
}
|
||||
|
||||
case wgpu::SType::ShaderModuleWGSLDescriptor: {
|
||||
return DAWN_VALIDATION_ERROR("WGSL not supported (yet)");
|
||||
#ifdef DAWN_ENABLE_WGSL
|
||||
const auto* wgslDesc =
|
||||
static_cast<const ShaderModuleWGSLDescriptor*>(chainedDescriptor);
|
||||
DAWN_TRY(ValidateWGSL(wgslDesc->source));
|
||||
break;
|
||||
#else
|
||||
return DAWN_VALIDATION_ERROR("WGSL not supported (yet)");
|
||||
#endif // DAWN_ENABLE_WGSL
|
||||
}
|
||||
|
||||
default:
|
||||
return DAWN_VALIDATION_ERROR("Unsupported sType");
|
||||
}
|
||||
|
@ -351,13 +439,26 @@ namespace dawn_native {
|
|||
// ShaderModuleBase
|
||||
|
||||
ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor)
|
||||
: CachedObject(device) {
|
||||
: CachedObject(device), mType(Type::Undefined) {
|
||||
ASSERT(descriptor->nextInChain != nullptr);
|
||||
ASSERT(descriptor->nextInChain->sType == wgpu::SType::ShaderModuleSPIRVDescriptor);
|
||||
|
||||
const ShaderModuleSPIRVDescriptor* spirvDesc =
|
||||
static_cast<const ShaderModuleSPIRVDescriptor*>(descriptor->nextInChain);
|
||||
mSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
|
||||
switch (descriptor->nextInChain->sType) {
|
||||
case wgpu::SType::ShaderModuleSPIRVDescriptor: {
|
||||
mType = Type::Spirv;
|
||||
const auto* spirvDesc =
|
||||
static_cast<const ShaderModuleSPIRVDescriptor*>(descriptor->nextInChain);
|
||||
mSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
|
||||
break;
|
||||
}
|
||||
case wgpu::SType::ShaderModuleWGSLDescriptor: {
|
||||
mType = Type::Wgsl;
|
||||
const auto* wgslDesc =
|
||||
static_cast<const ShaderModuleWGSLDescriptor*>(descriptor->nextInChain);
|
||||
mWgsl = std::string(wgslDesc->source);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
UNREACHABLE();
|
||||
}
|
||||
|
||||
mFragmentOutputFormatBaseTypes.fill(Format::Other);
|
||||
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvcParser)) {
|
||||
|
@ -366,7 +467,7 @@ namespace dawn_native {
|
|||
}
|
||||
|
||||
ShaderModuleBase::ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag)
|
||||
: CachedObject(device, tag) {
|
||||
: CachedObject(device, tag), mType(Type::Undefined) {
|
||||
}
|
||||
|
||||
ShaderModuleBase::~ShaderModuleBase() {
|
||||
|
@ -908,4 +1009,15 @@ namespace dawn_native {
|
|||
return options;
|
||||
}
|
||||
|
||||
MaybeError ShaderModuleBase::InitializeBase() {
|
||||
if (mType == Type::Wgsl) {
|
||||
#ifdef DAWN_ENABLE_WGSL
|
||||
DAWN_TRY_ASSIGN(mSpirv, ConvertWGSLToSPIRV(mWgsl.c_str()));
|
||||
#else
|
||||
return DAWN_VALIDATION_ERROR("WGSL not supported (yet)");
|
||||
#endif // DAWN_ENABLE_WGSL
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
} // namespace dawn_native
|
||||
|
|
|
@ -43,6 +43,8 @@ namespace dawn_native {
|
|||
|
||||
class ShaderModuleBase : public CachedObject {
|
||||
public:
|
||||
enum class Type { Undefined, Spirv, Wgsl };
|
||||
|
||||
ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor);
|
||||
~ShaderModuleBase() override;
|
||||
|
||||
|
@ -89,6 +91,7 @@ namespace dawn_native {
|
|||
protected:
|
||||
static MaybeError CheckSpvcSuccess(shaderc_spvc_status status, const char* error_msg);
|
||||
shaderc_spvc::CompileOptions GetCompileOptions() const;
|
||||
MaybeError InitializeBase();
|
||||
|
||||
shaderc_spvc::Context mSpvcContext;
|
||||
|
||||
|
@ -102,7 +105,9 @@ namespace dawn_native {
|
|||
MaybeError ExtractSpirvInfoWithSpvc();
|
||||
MaybeError ExtractSpirvInfoWithSpirvCross(const spirv_cross::Compiler& compiler);
|
||||
|
||||
Type mType;
|
||||
std::vector<uint32_t> mSpirv;
|
||||
std::string mWgsl;
|
||||
|
||||
ModuleBindingInfo mBindingInfo;
|
||||
std::bitset<kMaxVertexAttributes> mUsedVertexAttributes;
|
||||
|
|
|
@ -91,6 +91,7 @@ namespace dawn_native { namespace d3d12 {
|
|||
}
|
||||
|
||||
MaybeError ShaderModule::Initialize() {
|
||||
DAWN_TRY(InitializeBase());
|
||||
const std::vector<uint32_t>& spirv = GetSpirv();
|
||||
|
||||
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
|
||||
|
|
|
@ -67,6 +67,7 @@ namespace dawn_native { namespace metal {
|
|||
}
|
||||
|
||||
MaybeError ShaderModule::Initialize() {
|
||||
DAWN_TRY(InitializeBase());
|
||||
const std::vector<uint32_t>& spirv = GetSpirv();
|
||||
|
||||
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
|
||||
|
|
|
@ -68,6 +68,7 @@ namespace dawn_native { namespace opengl {
|
|||
}
|
||||
|
||||
MaybeError ShaderModule::Initialize() {
|
||||
DAWN_TRY(InitializeBase());
|
||||
const std::vector<uint32_t>& spirv = GetSpirv();
|
||||
|
||||
std::unique_ptr<spirv_cross::CompilerGLSL> compilerImpl;
|
||||
|
|
|
@ -37,6 +37,7 @@ namespace dawn_native { namespace vulkan {
|
|||
}
|
||||
|
||||
MaybeError ShaderModule::Initialize() {
|
||||
DAWN_TRY(InitializeBase());
|
||||
const std::vector<uint32_t>& spirv = GetSpirv();
|
||||
|
||||
// Use SPIRV-Cross to extract info from the SPIRV even if Vulkan consumes SPIRV. We want to
|
||||
|
|
Loading…
Reference in New Issue