diff --git a/fuzzers/tint_common_fuzzer.cc b/fuzzers/tint_common_fuzzer.cc index 7469150470..c663648e07 100644 --- a/fuzzers/tint_common_fuzzer.cc +++ b/fuzzers/tint_common_fuzzer.cc @@ -16,10 +16,15 @@ #include #include +#include #include #include #include +#if TINT_BUILD_SPV_READER +#include "spirv-tools/libspirv.hpp" +#endif // TINT_BUILD_SPV_READER + #include "src/ast/module.h" #include "src/diagnostic/formatter.h" #include "src/program.h" @@ -29,21 +34,19 @@ namespace fuzzers { namespace { -[[noreturn]] void TintInternalCompilerErrorReporter( - const tint::diag::List& diagnostics) { +[[noreturn]] void FatalError(const tint::diag::List& diags, + std::string msg = "") { auto printer = tint::diag::Printer::create(stderr, true); - tint::diag::Formatter{}.format(diagnostics, printer.get()); + if (msg.size()) { + printer->write((msg + "\n").c_str(), {diag::Color::kRed, true}); + } + tint::diag::Formatter().format(diags, printer.get()); __builtin_trap(); } -[[noreturn]] void ValidityErrorReporter(const tint::diag::List& diags) { - auto printer = tint::diag::Printer::create(stderr, true); - printer->write( - "Fuzzing detected valid input program being transformed into an invalid " - "output progam\n", - {diag::Color::kRed, true}); - tint::diag::Formatter().format(diags, printer.get()); - __builtin_trap(); +[[noreturn]] void TintInternalCompilerErrorReporter( + const tint::diag::List& diagnostics) { + FatalError(diagnostics); } transform::VertexAttributeDescriptor ExtractVertexAttributeDescriptor( @@ -66,6 +69,26 @@ transform::VertexBufferLayoutDescriptor ExtractVertexBufferLayoutDescriptor( return desc; } +bool SPIRVToolsValidationCheck(const tint::Program& program, + std::vector spirv) { + spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_1); + const tint::diag::List& diags = program.Diagnostics(); + tools.SetMessageConsumer([diags](spv_message_level_t, const char*, + const spv_position_t& pos, const char* msg) { + std::stringstream out; + out << "Unexpected spirv-val error:\n" + << (pos.line + 1) << ":" << (pos.column + 1) << ": " << msg + << std::endl; + + auto printer = tint::diag::Printer::create(stderr, true); + printer->write(out.str(), {diag::Color::kYellow, false}); + tint::diag::Formatter().format(diags, printer.get()); + }); + + return tools.Validate(spirv.data(), spirv.size(), + spvtools::ValidatorOptions()); +} + } // namespace Reader::Reader(const uint8_t* data, size_t size) : data_(data), size_(size) {} @@ -162,6 +185,13 @@ int CommonFuzzer::Run(const uint8_t* data, size_t size) { std::unique_ptr file; #endif // TINT_BUILD_WGSL_READER +#if TINT_BUILD_SPV_READER + size_t u32_size = size / sizeof(uint32_t); + const uint32_t* u32_data = reinterpret_cast(data); + std::vector spirv_input(u32_data, u32_data + u32_size); + +#endif // TINT_BUILD_SPV_READER + switch (input_) { #if TINT_BUILD_WGSL_READER case InputFormat::kWGSL: { @@ -173,16 +203,12 @@ int CommonFuzzer::Run(const uint8_t* data, size_t size) { #endif // TINT_BUILD_WGSL_READER #if TINT_BUILD_SPV_READER case InputFormat::kSpv: { - size_t sizeInU32 = size / sizeof(uint32_t); - const uint32_t* u32Data = reinterpret_cast(data); - std::vector input(u32Data, u32Data + sizeInU32); - - if (input.size() != 0) { - program = reader::spirv::Parse(input); + if (spirv_input.size() != 0) { + program = reader::spirv::Parse(spirv_input); } break; } -#endif // TINT_BUILD_WGSL_READER +#endif // TINT_BUILD_SPV_READER default: return 0; } @@ -196,6 +222,14 @@ int CommonFuzzer::Run(const uint8_t* data, size_t size) { return 0; } +#if TINT_BUILD_SPV_READER + if (input_ == InputFormat::kSpv && + !SPIRVToolsValidationCheck(program, spirv_input)) { + FatalError(program.Diagnostics(), + "Fuzzing detected invalid input spirv not being caught by Tint"); + } +#endif // TINT_BUILD_SPV_READER + if (inspector_enabled_) { inspector::Inspector inspector(&program); @@ -276,7 +310,9 @@ int CommonFuzzer::Run(const uint8_t* data, size_t size) { for (auto diag : out.program.Diagnostics()) { if (diag.severity > diag::Severity::Error || diag.system != diag::System::Transform) { - ValidityErrorReporter(out.program.Diagnostics()); + FatalError(out.program.Diagnostics(), + "Fuzzing detected valid input program being transformed " + "into an invalid output program"); } } } @@ -314,6 +350,16 @@ int CommonFuzzer::Run(const uint8_t* data, size_t size) { errors_ = writer_->error(); return 0; } + +#if TINT_BUILD_SPV_WRITER + if (output_ == OutputFormat::kSpv && + !SPIRVToolsValidationCheck( + program, + static_cast(writer_.get())->result())) { + FatalError(program.Diagnostics(), + "Fuzzing detected invalid spirv being emitted by Tint"); + } +#endif // TINT_BUILD_SPV_WRITER } return 0;