diff --git a/src/tint/cmd/main.cc b/src/tint/cmd/main.cc index a30a133196..fa8dbc03f4 100644 --- a/src/tint/cmd/main.cc +++ b/src/tint/cmd/main.cc @@ -103,6 +103,10 @@ struct Options { bool rename_all = false; +#if TINT_BUILD_SPV_READER + tint::reader::spirv::Options spirv_reader_options; +#endif + std::vector transforms; std::string fxc_path; @@ -135,6 +139,9 @@ const char kUsage[] = R"(Usage: tint [options] --transform -- Runs transforms, name list is comma separated Available transforms: ${transforms} --parse-only -- Stop after parsing the input + --allow-non-uniform-derivatives -- When using SPIR-V input, allow non-uniform derivatives by + inserting a module-scope directive to suppress any uniformity + violations that may be produced. --disable-workgroup-init -- Disable workgroup memory zero initialization. --demangle -- Preserve original source names. Demangle them. Affects AST dumping, and text-based output languages. @@ -443,6 +450,13 @@ bool ParseArgs(const std::vector& args, Options* opts) { opts->transforms = split_on_comma(args[i]); } else if (arg == "--parse-only") { opts->parse_only = true; + } else if (arg == "--allow-non-uniform-derivatives") { +#if TINT_BUILD_SPV_READER + opts->spirv_reader_options.allow_non_uniform_derivatives = true; +#else + std::cerr << "Tint not built with the SPIR-V reader enabled" << std::endl; + return false; +#endif } else if (arg == "--disable-workgroup-init") { opts->disable_workgroup_init = true; } else if (arg == "--demangle") { @@ -1285,7 +1299,8 @@ int main(int argc, const char** argv) { if (!ReadFile(options.input_filename, &data)) { return 1; } - program = std::make_unique(tint::reader::spirv::Parse(data)); + program = std::make_unique( + tint::reader::spirv::Parse(data, options.spirv_reader_options)); break; #else std::cerr << "Tint not built with the SPIR-V reader enabled" << std::endl; @@ -1309,7 +1324,8 @@ int main(int argc, const char** argv) { SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS)) { return 1; } - program = std::make_unique(tint::reader::spirv::Parse(data)); + program = std::make_unique( + tint::reader::spirv::Parse(data, options.spirv_reader_options)); break; #else std::cerr << "Tint not built with the SPIR-V reader enabled" << std::endl; diff --git a/src/tint/reader/spirv/parser.cc b/src/tint/reader/spirv/parser.cc index ac43b9e54c..41e6df3589 100644 --- a/src/tint/reader/spirv/parser.cc +++ b/src/tint/reader/spirv/parser.cc @@ -27,7 +27,7 @@ namespace tint::reader::spirv { -Program Parse(const std::vector& input) { +Program Parse(const std::vector& input, const Options& options) { ParserImpl parser(input); bool parsed = parser.Parse(); @@ -38,6 +38,13 @@ Program Parse(const std::vector& input) { return Program(std::move(builder)); } + if (options.allow_non_uniform_derivatives) { + // Suppress errors regarding non-uniform derivative operations if requested, by adding a + // diagnostic directive to the module. + builder.DiagnosticDirective(ast::DiagnosticSeverity::kOff, + builder.Expr("derivative_uniformity")); + } + // The SPIR-V parser can construct disjoint AST nodes, which is invalid for // the Resolver. Clone the Program to clean these up. builder.SetResolveOnBuild(false); diff --git a/src/tint/reader/spirv/parser.h b/src/tint/reader/spirv/parser.h index 3641e08daa..78c80c0d84 100644 --- a/src/tint/reader/spirv/parser.h +++ b/src/tint/reader/spirv/parser.h @@ -21,13 +21,20 @@ namespace tint::reader::spirv { +/// Options that control how the SPIR-V parser should behave. +struct Options { + /// Set to `true` to allow calls to derivative builtins in non-uniform control flow. + bool allow_non_uniform_derivatives = false; +}; + /// Parses the SPIR-V source data, returning the parsed program. /// If the source data fails to parse then the returned /// `program.Diagnostics.contains_errors()` will be true, and the /// `program.Diagnostics()` will describe the error. /// @param input the source data +/// @param options the parser options /// @returns the parsed program -Program Parse(const std::vector& input); +Program Parse(const std::vector& input, const Options& options = {}); } // namespace tint::reader::spirv diff --git a/src/tint/reader/spirv/parser_test.cc b/src/tint/reader/spirv/parser_test.cc index 35cb5da8e9..3f5e3703ce 100644 --- a/src/tint/reader/spirv/parser_test.cc +++ b/src/tint/reader/spirv/parser_test.cc @@ -14,7 +14,9 @@ #include "src/tint/reader/spirv/parser.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include "src/tint/reader/spirv/spirv_tools_helpers_test.h" namespace tint::reader::spirv { namespace { @@ -29,6 +31,54 @@ TEST_F(ParserTest, DataEmpty) { EXPECT_EQ(errs, "error: line:0: Invalid SPIR-V magic number.\n"); } +constexpr auto kShaderWithNonUniformDerivative = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %foo "foo" %x + OpExecutionMode %foo OriginUpperLeft + OpDecorate %x Location 0 + %float = OpTypeFloat 32 +%_ptr_Input_float = OpTypePointer Input %float + %x = OpVariable %_ptr_Input_float Input + %void = OpTypeVoid + %float_0 = OpConstantNull %float + %bool = OpTypeBool + %func_type = OpTypeFunction %void + %foo = OpFunction %void None %func_type + %foo_start = OpLabel + %x_value = OpLoad %float %x + %condition = OpFOrdGreaterThan %bool %x_value %float_0 + OpSelectionMerge %merge None + OpBranchConditional %condition %true_branch %merge +%true_branch = OpLabel + %result = OpDPdx %float %x_value + OpBranch %merge + %merge = OpLabel + OpReturn + OpFunctionEnd +)"; + +TEST_F(ParserTest, AllowNonUniformDerivatives_False) { + auto spv = test::Assemble(kShaderWithNonUniformDerivative); + Options options; + options.allow_non_uniform_derivatives = false; + auto program = Parse(spv, options); + auto errs = diag::Formatter().format(program.Diagnostics()); + // TODO(jrprice): This will become EXPECT_FALSE. + EXPECT_TRUE(program.IsValid()) << errs; + EXPECT_THAT(errs, ::testing::HasSubstr("'dpdx' must only be called from uniform control flow")); +} + +TEST_F(ParserTest, AllowNonUniformDerivatives_True) { + auto spv = test::Assemble(kShaderWithNonUniformDerivative); + Options options; + options.allow_non_uniform_derivatives = true; + auto program = Parse(spv, options); + auto errs = diag::Formatter().format(program.Diagnostics()); + EXPECT_TRUE(program.IsValid()) << errs; + EXPECT_EQ(program.Diagnostics().count(), 0u) << errs; +} + // TODO(dneto): uint32 vec, valid SPIR-V // TODO(dneto): uint32 vec, invalid SPIR-V