Add support for WGSL shaders via Tint

BUG=dawn:405

Change-Id: I7a79a0d7ce58ff995ec1ff917dd427875fb4deaf
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/21340
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Ryan Harrison <rharrison@chromium.org>
This commit is contained in:
Ryan Harrison 2020-05-21 13:42:26 +00:00 committed by Commit Bot service account
parent 214c71769b
commit bd0ad7921d
6 changed files with 131 additions and 10 deletions

View File

@ -23,6 +23,18 @@
#include <spirv-tools/libspirv.hpp>
#include <spirv_cross.hpp>
#ifdef DAWN_ENABLE_WGSL
// Tint includes must be after spirv_cross.hpp, because spirv-cross has its own
// version of spirv_headers.
// clang-format off
#include "tint/src/reader/wgsl/parser.h"
#include "tint/src/type_determiner.h"
#include "tint/src/validator.h"
#include "tint/src/writer/spirv/generator.h"
#include "tint/src/writer/writer.h"
// clang-format on
#endif // DAWN_ENABLE_WGSL
#include <sstream>
namespace dawn_native {
@ -316,6 +328,77 @@ namespace dawn_native {
return {};
}
#ifdef DAWN_ENABLE_WGSL
MaybeError ValidateWGSL(const char* source) {
std::ostringstream errorStream;
errorStream << "Tint WGSL failure:" << std::endl;
tint::Context context;
tint::reader::wgsl::Parser parser(&context, source);
if (!parser.Parse()) {
errorStream << "Parser: " << parser.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::ast::Module module = parser.module();
if (!module.IsValid()) {
errorStream << "Invalid module generated..." << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::TypeDeterminer type_determiner(&context, &module);
if (!type_determiner.Determine()) {
errorStream << "Type Determination: " << type_determiner.error();
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::Validator validator;
if (!validator.Validate(module)) {
errorStream << "Validation: " << validator.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
return {};
}
ResultOrError<std::vector<uint32_t>> ConvertWGSLToSPIRV(const char* source) {
std::ostringstream errorStream;
errorStream << "Tint WGSL->SPIR-V failure:" << std::endl;
tint::Context context;
tint::reader::wgsl::Parser parser(&context, source);
// TODO: This is a duplicate parse with ValidateWGSL, need to store
// state between calls to avoid this.
if (!parser.Parse()) {
errorStream << "Parser: " << parser.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::ast::Module module = parser.module();
if (!module.IsValid()) {
errorStream << "Invalid module generated..." << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::TypeDeterminer type_determiner(&context, &module);
if (!type_determiner.Determine()) {
errorStream << "Type Determination: " << type_determiner.error();
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::writer::spirv::Generator generator(std::move(module));
if (!generator.Generate()) {
errorStream << "Generator: " << generator.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
std::vector<uint32_t> spirv = generator.result();
return std::move(spirv);
}
#endif // DAWN_ENABLE_WGSL
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
const ShaderModuleDescriptor* descriptor) {
const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
@ -330,17 +413,22 @@ namespace dawn_native {
switch (chainedDescriptor->sType) {
case wgpu::SType::ShaderModuleSPIRVDescriptor: {
const ShaderModuleSPIRVDescriptor* spirvDesc =
const auto* spirvDesc =
static_cast<const ShaderModuleSPIRVDescriptor*>(chainedDescriptor);
DAWN_TRY(ValidateSpirv(device, spirvDesc->code, spirvDesc->codeSize));
break;
}
case wgpu::SType::ShaderModuleWGSLDescriptor: {
return DAWN_VALIDATION_ERROR("WGSL not supported (yet)");
#ifdef DAWN_ENABLE_WGSL
const auto* wgslDesc =
static_cast<const ShaderModuleWGSLDescriptor*>(chainedDescriptor);
DAWN_TRY(ValidateWGSL(wgslDesc->source));
break;
#else
return DAWN_VALIDATION_ERROR("WGSL not supported (yet)");
#endif // DAWN_ENABLE_WGSL
}
default:
return DAWN_VALIDATION_ERROR("Unsupported sType");
}
@ -351,13 +439,26 @@ namespace dawn_native {
// ShaderModuleBase
ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor)
: CachedObject(device) {
: CachedObject(device), mType(Type::Undefined) {
ASSERT(descriptor->nextInChain != nullptr);
ASSERT(descriptor->nextInChain->sType == wgpu::SType::ShaderModuleSPIRVDescriptor);
const ShaderModuleSPIRVDescriptor* spirvDesc =
static_cast<const ShaderModuleSPIRVDescriptor*>(descriptor->nextInChain);
mSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
switch (descriptor->nextInChain->sType) {
case wgpu::SType::ShaderModuleSPIRVDescriptor: {
mType = Type::Spirv;
const auto* spirvDesc =
static_cast<const ShaderModuleSPIRVDescriptor*>(descriptor->nextInChain);
mSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
break;
}
case wgpu::SType::ShaderModuleWGSLDescriptor: {
mType = Type::Wgsl;
const auto* wgslDesc =
static_cast<const ShaderModuleWGSLDescriptor*>(descriptor->nextInChain);
mWgsl = std::string(wgslDesc->source);
break;
}
default:
UNREACHABLE();
}
mFragmentOutputFormatBaseTypes.fill(Format::Other);
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvcParser)) {
@ -366,7 +467,7 @@ namespace dawn_native {
}
ShaderModuleBase::ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag)
: CachedObject(device, tag) {
: CachedObject(device, tag), mType(Type::Undefined) {
}
ShaderModuleBase::~ShaderModuleBase() {
@ -908,4 +1009,15 @@ namespace dawn_native {
return options;
}
MaybeError ShaderModuleBase::InitializeBase() {
if (mType == Type::Wgsl) {
#ifdef DAWN_ENABLE_WGSL
DAWN_TRY_ASSIGN(mSpirv, ConvertWGSLToSPIRV(mWgsl.c_str()));
#else
return DAWN_VALIDATION_ERROR("WGSL not supported (yet)");
#endif // DAWN_ENABLE_WGSL
}
return {};
}
} // namespace dawn_native

View File

@ -43,6 +43,8 @@ namespace dawn_native {
class ShaderModuleBase : public CachedObject {
public:
enum class Type { Undefined, Spirv, Wgsl };
ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor);
~ShaderModuleBase() override;
@ -89,6 +91,7 @@ namespace dawn_native {
protected:
static MaybeError CheckSpvcSuccess(shaderc_spvc_status status, const char* error_msg);
shaderc_spvc::CompileOptions GetCompileOptions() const;
MaybeError InitializeBase();
shaderc_spvc::Context mSpvcContext;
@ -102,7 +105,9 @@ namespace dawn_native {
MaybeError ExtractSpirvInfoWithSpvc();
MaybeError ExtractSpirvInfoWithSpirvCross(const spirv_cross::Compiler& compiler);
Type mType;
std::vector<uint32_t> mSpirv;
std::string mWgsl;
ModuleBindingInfo mBindingInfo;
std::bitset<kMaxVertexAttributes> mUsedVertexAttributes;

View File

@ -91,6 +91,7 @@ namespace dawn_native { namespace d3d12 {
}
MaybeError ShaderModule::Initialize() {
DAWN_TRY(InitializeBase());
const std::vector<uint32_t>& spirv = GetSpirv();
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {

View File

@ -67,6 +67,7 @@ namespace dawn_native { namespace metal {
}
MaybeError ShaderModule::Initialize() {
DAWN_TRY(InitializeBase());
const std::vector<uint32_t>& spirv = GetSpirv();
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {

View File

@ -68,6 +68,7 @@ namespace dawn_native { namespace opengl {
}
MaybeError ShaderModule::Initialize() {
DAWN_TRY(InitializeBase());
const std::vector<uint32_t>& spirv = GetSpirv();
std::unique_ptr<spirv_cross::CompilerGLSL> compilerImpl;

View File

@ -37,6 +37,7 @@ namespace dawn_native { namespace vulkan {
}
MaybeError ShaderModule::Initialize() {
DAWN_TRY(InitializeBase());
const std::vector<uint32_t>& spirv = GetSpirv();
// Use SPIRV-Cross to extract info from the SPIRV even if Vulkan consumes SPIRV. We want to