diff --git a/BUILD.gn b/BUILD.gn index fad9992103..dc6641efab 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -314,6 +314,7 @@ source_set("libdawn_native_sources") { deps = [ ":dawn_common", ":libdawn_native_utils_gen", + "${dawn_spirv_tools_dir}:spvtools_val", "third_party:spirv_cross", ] @@ -772,6 +773,7 @@ test("dawn_unittests") { "src/tests/unittests/validation/PushConstantsValidationTests.cpp", "src/tests/unittests/validation/RenderPassDescriptorValidationTests.cpp", "src/tests/unittests/validation/RenderPipelineValidationTests.cpp", + "src/tests/unittests/validation/ShaderModuleValidationTests.cpp", "src/tests/unittests/validation/ValidationTest.cpp", "src/tests/unittests/validation/ValidationTest.h", "src/tests/unittests/validation/VertexBufferValidationTests.cpp", diff --git a/src/dawn_native/CMakeLists.txt b/src/dawn_native/CMakeLists.txt index a24999d14c..031728a94d 100644 --- a/src/dawn_native/CMakeLists.txt +++ b/src/dawn_native/CMakeLists.txt @@ -31,8 +31,8 @@ Generate( ) set(DAWN_NATIVE_SOURCES) -set(DAWN_NATIVE_DEPS dawn_common spirv_cross dawn_native_utils_autogen) -set(DAWN_NATIVE_INCLUDE_DIRS ${SPIRV_CROSS_INCLUDE_DIR}) +set(DAWN_NATIVE_DEPS dawn_common spirv_cross dawn_native_utils_autogen SPIRV-Tools) +set(DAWN_NATIVE_INCLUDE_DIRS ${SPIRV_CROSS_INCLUDE_DIR} ${SPIRV_TOOLS_INCLUDE_DIR}) ################################################################################ # OpenGL Backend diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index f7f7a24a6e..faf9551ba9 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -20,13 +20,44 @@ #include "dawn_native/PipelineLayout.h" #include +#include namespace dawn_native { MaybeError ValidateShaderModuleDescriptor(DeviceBase*, const ShaderModuleDescriptor* descriptor) { DAWN_TRY_ASSERT(descriptor->nextInChain == nullptr, "nextInChain must be nullptr"); - // TODO(cwallez@chromium.org): Use spirv-val to check the module is well-formed + + spvtools::SpirvTools spirvTools(SPV_ENV_WEBGPU_0); + + std::ostringstream errorStream; + errorStream << "SPIRV Validation failure:" << std::endl; + + spirvTools.SetMessageConsumer([&errorStream](spv_message_level_t level, const char*, + const spv_position_t& position, + const char* message) { + switch (level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + errorStream << "error: line " << position.index << ": " << message << std::endl; + break; + case SPV_MSG_WARNING: + errorStream << "warning: line " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_INFO: + errorStream << "info: line " << position.index << ": " << message << std::endl; + break; + default: + break; + } + }); + + if (!spirvTools.Validate(descriptor->code, descriptor->codeSize)) { + DAWN_RETURN_ERROR(errorStream.str().c_str()); + } + return {}; } diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 27828e20c0..68c573b0b5 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -53,6 +53,7 @@ list(APPEND UNITTEST_SOURCES ${VALIDATION_TESTS_DIR}/PushConstantsValidationTests.cpp ${VALIDATION_TESTS_DIR}/RenderPassDescriptorValidationTests.cpp ${VALIDATION_TESTS_DIR}/RenderPipelineValidationTests.cpp + ${VALIDATION_TESTS_DIR}/ShaderModuleValidationTests.cpp ${VALIDATION_TESTS_DIR}/VertexBufferValidationTests.cpp ${VALIDATION_TESTS_DIR}/ValidationTest.cpp ${VALIDATION_TESTS_DIR}/ValidationTest.h diff --git a/src/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp new file mode 100644 index 0000000000..5bf4a2163b --- /dev/null +++ b/src/tests/unittests/validation/ShaderModuleValidationTests.cpp @@ -0,0 +1,89 @@ +// Copyright 2018 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tests/unittests/validation/ValidationTest.h" + +#include "utils/DawnHelpers.h" + +class ShaderModuleValidationTest : public ValidationTest { +}; + +// Test case with a simpler shader that should successfully be created +TEST_F(ShaderModuleValidationTest, CreationSuccess) { + const char* shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %fragColor + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %fragColor "fragColor" + OpDecorate %fragColor Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %_ptr_Output_v4float = OpTypePointer Output %v4float + %fragColor = OpVariable %_ptr_Output_v4float Output + %float_1 = OpConstant %float 1 + %float_0 = OpConstant %float 0 + %12 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1 + %main = OpFunction %void None %3 + %5 = OpLabel + OpStore %fragColor %12 + OpReturn + OpFunctionEnd)"; + + utils::CreateShaderModuleFromASM(device, shader); +} + +// Test case with a shader with OpUndef to test WebGPU-specific validation +TEST_F(ShaderModuleValidationTest, OpUndef) { + const char* shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" %fragColor + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpSourceExtension "GL_GOOGLE_cpp_style_line_directive" + OpSourceExtension "GL_GOOGLE_include_directive" + OpName %main "main" + OpName %fragColor "fragColor" + OpDecorate %fragColor Location 0 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %float = OpTypeFloat 32 + %v4float = OpTypeVector %float 4 + %_ptr_Output_v4float = OpTypePointer Output %v4float + %fragColor = OpVariable %_ptr_Output_v4float Output + %float_1 = OpConstant %float 1 + %float_0 = OpConstant %float 0 + %12 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1 + %main = OpFunction %void None %3 + %5 = OpLabel + %6 = OpUndef %v4float + OpStore %fragColor %12 + OpReturn + OpFunctionEnd)"; + + // Notice "%6 = OpUndef %v4float" above + ASSERT_DEVICE_ERROR(utils::CreateShaderModuleFromASM(device, shader)); + + std::string error = GetLastDeviceErrorMessage(); + ASSERT_NE(error.find("OpUndef"), std::string::npos); +} diff --git a/src/tests/unittests/validation/ValidationTest.cpp b/src/tests/unittests/validation/ValidationTest.cpp index f511b5f947..85ba23707b 100644 --- a/src/tests/unittests/validation/ValidationTest.cpp +++ b/src/tests/unittests/validation/ValidationTest.cpp @@ -69,6 +69,9 @@ bool ValidationTest::EndExpectDeviceError() { mExpectError = false; return mError; } +std::string ValidationTest::GetLastDeviceErrorMessage() const { + return mDeviceErrorMessage; +} dawn::RenderPassDescriptor ValidationTest::CreateSimpleRenderPass() { dawn::TextureDescriptor descriptor; @@ -91,14 +94,16 @@ dawn::RenderPassDescriptor ValidationTest::CreateSimpleRenderPass() { } void ValidationTest::OnDeviceError(const char* message, dawnCallbackUserdata userdata) { + auto self = reinterpret_cast(static_cast(userdata)); + self->mDeviceErrorMessage = message; + // Skip this one specific error that is raised when a builder is used after it got an error // this is important because we don't want to wrap all creation tests in ASSERT_DEVICE_ERROR. // Yes the error message is misleading. - if (std::string(message) == "Builder cannot be used after GetResult") { + if (self->mDeviceErrorMessage == "Builder cannot be used after GetResult") { return; } - auto self = reinterpret_cast(static_cast(userdata)); ASSERT_TRUE(self->mExpectError) << "Got unexpected device error: " << message; ASSERT_FALSE(self->mError) << "Got two errors in expect block"; self->mError = true; diff --git a/src/tests/unittests/validation/ValidationTest.h b/src/tests/unittests/validation/ValidationTest.h index 2189b1077c..c6bd9083cb 100644 --- a/src/tests/unittests/validation/ValidationTest.h +++ b/src/tests/unittests/validation/ValidationTest.h @@ -47,6 +47,7 @@ class ValidationTest : public testing::Test { void StartExpectDeviceError(); bool EndExpectDeviceError(); + std::string GetLastDeviceErrorMessage() const; dawn::RenderPassDescriptor CreateSimpleRenderPass(); @@ -66,6 +67,7 @@ class ValidationTest : public testing::Test { private: static void OnDeviceError(const char* message, dawnCallbackUserdata userdata); + std::string mDeviceErrorMessage; bool mExpectError = false; bool mError = false; diff --git a/src/utils/DawnHelpers.cpp b/src/utils/DawnHelpers.cpp index 01daa9bd7b..0871f26ee0 100644 --- a/src/utils/DawnHelpers.cpp +++ b/src/utils/DawnHelpers.cpp @@ -25,46 +25,53 @@ namespace utils { + namespace { + + shaderc_shader_kind ShadercShaderKind(dawn::ShaderStage stage) { + switch (stage) { + case dawn::ShaderStage::Vertex: + return shaderc_glsl_vertex_shader; + case dawn::ShaderStage::Fragment: + return shaderc_glsl_fragment_shader; + case dawn::ShaderStage::Compute: + return shaderc_glsl_compute_shader; + default: + UNREACHABLE(); + } + } + + dawn::ShaderModule CreateShaderModuleFromResult( + const dawn::Device& device, + const shaderc::SpvCompilationResult& result) { + // result.cend and result.cbegin return pointers to uint32_t. + const uint32_t* resultBegin = result.cbegin(); + const uint32_t* resultEnd = result.cend(); + // So this size is in units of sizeof(uint32_t). + ptrdiff_t resultSize = resultEnd - resultBegin; + // SetSource takes data as uint32_t*. + + dawn::ShaderModuleDescriptor descriptor; + descriptor.codeSize = static_cast(resultSize); + descriptor.code = result.cbegin(); + return device.CreateShaderModule(&descriptor); + } + + } // anonymous namespace + dawn::ShaderModule CreateShaderModule(const dawn::Device& device, dawn::ShaderStage stage, const char* source) { + shaderc_shader_kind kind = ShadercShaderKind(stage); + shaderc::Compiler compiler; - shaderc::CompileOptions options; - - shaderc_shader_kind kind; - switch (stage) { - case dawn::ShaderStage::Vertex: - kind = shaderc_glsl_vertex_shader; - break; - case dawn::ShaderStage::Fragment: - kind = shaderc_glsl_fragment_shader; - break; - case dawn::ShaderStage::Compute: - kind = shaderc_glsl_compute_shader; - break; - default: - UNREACHABLE(); - } - - auto result = compiler.CompileGlslToSpv(source, strlen(source), kind, "myshader?", options); + auto result = compiler.CompileGlslToSpv(source, strlen(source), kind, "myshader?"); if (result.GetCompilationStatus() != shaderc_compilation_status_success) { std::cerr << result.GetErrorMessage(); return {}; } - - // result.cend and result.cbegin return pointers to uint32_t. - const uint32_t* resultBegin = result.cbegin(); - const uint32_t* resultEnd = result.cend(); - // So this size is in units of sizeof(uint32_t). - ptrdiff_t resultSize = resultEnd - resultBegin; - // SetSource takes data as uint32_t*. - - dawn::ShaderModuleDescriptor descriptor; - descriptor.codeSize = static_cast(resultSize); - descriptor.code = result.cbegin(); - #ifdef DUMP_SPIRV_ASSEMBLY { + shaderc::CompileOptions options; auto resultAsm = compiler.CompileGlslToSpvAssembly(source, strlen(source), kind, "myshader?", options); size_t sizeAsm = (resultAsm.cend() - resultAsm.cbegin()); @@ -91,7 +98,18 @@ namespace utils { printf("SPIRV JS ARRAY DUMP END\n"); #endif - return device.CreateShaderModule(&descriptor); + return CreateShaderModuleFromResult(device, result); + } + + dawn::ShaderModule CreateShaderModuleFromASM(const dawn::Device& device, const char* source) { + shaderc::Compiler compiler; + shaderc::SpvCompilationResult result = compiler.AssembleToSpv(source, strlen(source)); + if (result.GetCompilationStatus() != shaderc_compilation_status_success) { + std::cerr << result.GetErrorMessage(); + return {}; + } + + return CreateShaderModuleFromResult(device, result); } dawn::Buffer CreateBufferFromData(const dawn::Device& device, diff --git a/src/utils/DawnHelpers.h b/src/utils/DawnHelpers.h index 88698ceff3..03a3e3018e 100644 --- a/src/utils/DawnHelpers.h +++ b/src/utils/DawnHelpers.h @@ -21,6 +21,8 @@ namespace utils { dawn::ShaderModule CreateShaderModule(const dawn::Device& device, dawn::ShaderStage stage, const char* source); + dawn::ShaderModule CreateShaderModuleFromASM(const dawn::Device& device, const char* source); + dawn::Buffer CreateBufferFromData(const dawn::Device& device, const void* data, uint32_t size, diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 4d70444c25..61e614c0c6 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -41,6 +41,9 @@ set(GLAD_INCLUDE_DIR ${GLAD_INCLUDE_DIR} PARENT_SCOPE) target_include_directories(glad SYSTEM PUBLIC ${GLAD_INCLUDE_DIR}) DawnExternalTarget("third_party" glad) +# SPIRV-Tools +set(SPIRV_TOOLS_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/spirv-tools/include PARENT_SCOPE) + # ShaderC # Prevent SPIRV-Tools from using Werror as it has a warning on MSVC set(SPIRV_WERROR OFF CACHE BOOL "" FORCE)