From 676ec7cf996460021f12070a92c2ec875073fed4 Mon Sep 17 00:00:00 2001 From: James Price Date: Tue, 7 Sep 2021 18:59:21 +0000 Subject: [PATCH] spirv: Handle sample_mask in shader IO transform This is easy to do while we are processing builtins in the main transform now that we use wrapper functions. This is step towards removing the sanitizers completely. Change-Id: If5472ce552e3cce1e5905916eeffa8fef90461c9 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/63585 Kokoro: Kokoro Reviewed-by: Ben Clayton --- src/transform/canonicalize_entry_point_io.cc | 35 +++-- .../canonicalize_entry_point_io_test.cc | 40 +++++- src/transform/spirv.cc | 50 +------ src/transform/spirv.h | 4 - src/transform/spirv_test.cc | 135 +----------------- 5 files changed, 67 insertions(+), 197 deletions(-) diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc index 50d1548326..91568cb63e 100644 --- a/src/transform/canonicalize_entry_point_io.cc +++ b/src/transform/canonicalize_entry_point_io.cc @@ -67,6 +67,12 @@ bool IsShaderIODecoration(const ast::Decoration* deco) { ast::InvariantDecoration, ast::LocationDecoration>(); } +// Returns true if `decos` contains a `sample_mask` builtin. +bool HasSampleMask(const ast::DecorationList& decos) { + auto* builtin = ast::GetDecoration(decos); + return builtin && builtin->value() == ast::Builtin::kSampleMask; +} + } // namespace /// State holds the current transform state for a single entry point. @@ -166,9 +172,16 @@ struct CanonicalizeEntryPointIO::State { // Create the global variable and use its value for the shader input. auto var = ctx.dst->Symbols().New(name); + ast::Expression* value = ctx.dst->Expr(var); + if (HasSampleMask(attributes)) { + // Vulkan requires the type of a SampleMask builtin to be an array. + // Declare it as array and then load the first element. + type = ctx.dst->ty.array(type, 1); + value = ctx.dst->IndexAccessor(value, 0); + } ctx.dst->Global(var, type, ast::StorageClass::kInput, std::move(attributes)); - return ctx.dst->Expr(var); + return value; } else if (cfg.shader_style == ShaderStyle::kMsl && ast::HasDecoration(attributes)) { // If this input is a builtin and we are targeting MSL, then add it to the @@ -303,9 +316,7 @@ struct CanonicalizeEntryPointIO::State { void AddFixedSampleMask() { // Check the existing output values for a sample mask builtin. for (auto& outval : wrapper_output_values) { - auto* builtin = - ast::GetDecoration(outval.attributes); - if (builtin && builtin->value() == ast::Builtin::kSampleMask) { + if (HasSampleMask(outval.attributes)) { // Combine the authored sample mask with the fixed mask. outval.value = ctx.dst->And(outval.value, cfg.fixed_sample_mask); return; @@ -390,7 +401,7 @@ struct CanonicalizeEntryPointIO::State { } /// Create and assign the wrapper function's output variables. - void CreateOutputVariables() { + void CreateSpirvOutputVariables() { for (auto& outval : wrapper_output_values) { // Disable validation for use of the `output` storage class. ast::DecorationList attributes = std::move(outval.attributes); @@ -400,9 +411,17 @@ struct CanonicalizeEntryPointIO::State { // Create the global variable and assign it the output value. auto name = ctx.dst->Symbols().New(outval.name); - ctx.dst->Global(name, outval.type, ast::StorageClass::kOutput, + auto* type = outval.type; + ast::Expression* lhs = ctx.dst->Expr(name); + if (HasSampleMask(attributes)) { + // Vulkan requires the type of a SampleMask builtin to be an array. + // Declare it as array and then store to the first element. + type = ctx.dst->ty.array(type, 1); + lhs = ctx.dst->IndexAccessor(lhs, 0); + } + ctx.dst->Global(name, type, ast::StorageClass::kOutput, std::move(attributes)); - wrapper_body.push_back(ctx.dst->Assign(name, outval.value)); + wrapper_body.push_back(ctx.dst->Assign(lhs, outval.value)); } } @@ -498,7 +517,7 @@ struct CanonicalizeEntryPointIO::State { // Produce the entry point outputs, if necessary. if (!wrapper_output_values.empty()) { if (cfg.shader_style == ShaderStyle::kSpirv) { - CreateOutputVariables(); + CreateSpirvOutputVariables(); } else { auto* output_struct = CreateOutputStruct(); wrapper_ret_type = [&, output_struct] { diff --git a/src/transform/canonicalize_entry_point_io_test.cc b/src/transform/canonicalize_entry_point_io_test.cc index e4d8413486..1dc93b17eb 100644 --- a/src/transform/canonicalize_entry_point_io_test.cc +++ b/src/transform/canonicalize_entry_point_io_test.cc @@ -491,7 +491,7 @@ fn frag_main() -> FragOutput { [[builtin(frag_depth), internal(disable_validation__ignore_storage_class)]] var depth_1 : f32; -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var mask_1 : u32; +[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var mask_1 : array; struct FragOutput { color : vec4; @@ -512,7 +512,7 @@ fn frag_main() { let inner_result = frag_main_inner(); color_1 = inner_result.color; depth_1 = inner_result.depth; - mask_1 = inner_result.mask; + mask_1[0] = inner_result.mask; } )"; @@ -2163,6 +2163,42 @@ fn vert_main() -> tint_symbol { EXPECT_EQ(expect, str(got)); } +TEST_F(CanonicalizeEntryPointIOTest, SpirvSampleMaskBuiltins) { + auto* src = R"( +[[stage(fragment)]] +fn main([[builtin(sample_index)]] sample_index : u32, + [[builtin(sample_mask)]] mask_in : u32 + ) -> [[builtin(sample_mask)]] u32 { + return mask_in; +} +)"; + + auto* expect = R"( +[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var sample_index_1 : u32; + +[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var mask_in_1 : array; + +[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var value : array; + +fn main_inner(sample_index : u32, mask_in : u32) -> u32 { + return mask_in; +} + +[[stage(fragment)]] +fn main() { + let inner_result = main_inner(sample_index_1, mask_in_1[0]); + value[0] = inner_result; +} +)"; + + DataMap data; + data.Add( + CanonicalizeEntryPointIO::ShaderStyle::kSpirv); + auto got = Run(src, data); + + EXPECT_EQ(expect, str(got)); +} + } // namespace } // namespace transform } // namespace tint diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc index 6bd220aa3b..76fca09d9c 100644 --- a/src/transform/spirv.cc +++ b/src/transform/spirv.cc @@ -68,61 +68,13 @@ Output Spirv::Run(const Program* in, const DataMap& data) { ProgramBuilder builder; CloneContext ctx(&builder, &transformedInput.program); - HandleSampleMaskBuiltins(ctx); + // TODO(jrprice): Move the sanitizer into the backend. ctx.Clone(); builder.SetTransformApplied(this); return Output{Program(std::move(builder))}; } -void Spirv::HandleSampleMaskBuiltins(CloneContext& ctx) const { - // Find global variables decorated with [[builtin(sample_mask)]] and - // change their type from `u32` to `array`, as required by Vulkan. - // - // Before: - // ``` - // [[builtin(sample_mask)]] var mask_out : u32; - // fn main() { - // mask_out = 1u; - // } - // ``` - // After: - // ``` - // [[builtin(sample_mask)]] var mask_out : array; - // fn main() { - // mask_out[0] = 1u; - // } - // ``` - - for (auto* var : ctx.src->AST().GlobalVariables()) { - for (auto* deco : var->decorations()) { - if (auto* builtin = deco->As()) { - if (builtin->value() != ast::Builtin::kSampleMask) { - continue; - } - - // Use the same name as the old variable. - auto var_name = ctx.Clone(var->symbol()); - // Use `array` for the new variable. - auto* type = ctx.dst->ty.array(ctx.dst->ty.u32(), 1u); - // Create the new variable. - auto* var_arr = ctx.dst->Var(var->source(), var_name, type, - var->declared_storage_class(), nullptr, - ctx.Clone(var->decorations())); - // Replace the variable with the arrayed version. - ctx.Replace(var, var_arr); - - // Replace all uses of the old variable with `var_arr[0]`. - for (auto* user : ctx.src->Sem().Get(var)->Users()) { - auto* new_ident = ctx.dst->IndexAccessor( - ctx.dst->Expr(var_arr->symbol()), ctx.dst->Expr(0)); - ctx.Replace(user->Declaration(), new_ident); - } - } - } - } -} - Spirv::Config::Config(bool emit_vps, bool disable_wi) : emit_vertex_point_size(emit_vps), disable_workgroup_init(disable_wi) {} diff --git a/src/transform/spirv.h b/src/transform/spirv.h index faf5b966e2..7e633ab5f3 100644 --- a/src/transform/spirv.h +++ b/src/transform/spirv.h @@ -67,10 +67,6 @@ class Spirv : public Castable { /// @param data optional extra transform-specific input data /// @returns the transformation result Output Run(const Program* program, const DataMap& data = {}) override; - - private: - /// Change type of sample mask builtin variables to single element arrays. - void HandleSampleMaskBuiltins(CloneContext& ctx) const; }; } // namespace transform diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc index 26ed757a88..e2d4c3c60b 100644 --- a/src/transform/spirv_test.cc +++ b/src/transform/spirv_test.cc @@ -22,140 +22,7 @@ namespace { using SpirvTest = TransformTest; -TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) { - auto* src = R"( -[[stage(fragment)]] -fn main([[builtin(sample_index)]] sample_index : u32, - [[builtin(sample_mask)]] mask_in : u32 - ) -> [[builtin(sample_mask)]] u32 { - return mask_in; -} -)"; - - auto* expect = R"( -[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var sample_index_1 : u32; - -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var mask_in_1 : array; - -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var value : array; - -fn main_inner(sample_index : u32, mask_in : u32) -> u32 { - return mask_in; -} - -[[stage(fragment)]] -fn main() { - let inner_result = main_inner(sample_index_1, mask_in_1[0]); - value[0] = inner_result; -} -)"; - - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - -TEST_F(SpirvTest, HandleSampleMaskBuiltins_FunctionArg) { - auto* src = R"( -fn filter(mask: u32) -> u32 { - return (mask & 3u); -} - -fn set_mask(input : u32) -> u32 { - return input; -} - -[[stage(fragment)]] -fn main([[builtin(sample_mask)]] mask_in : u32 - ) -> [[builtin(sample_mask)]] u32 { - return set_mask(filter(mask_in)); -} -)"; - - auto* expect = R"( -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var mask_in_1 : array; - -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var value : array; - -fn filter(mask : u32) -> u32 { - return (mask & 3u); -} - -fn set_mask(input : u32) -> u32 { - return input; -} - -fn main_inner(mask_in : u32) -> u32 { - return set_mask(filter(mask_in)); -} - -[[stage(fragment)]] -fn main() { - let inner_result = main_inner(mask_in_1[0]); - value[0] = inner_result; -} -)"; - - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - -// Test that different transforms within the sanitizer interact correctly. -TEST_F(SpirvTest, MultipleTransforms) { - auto* src = R"( -[[stage(vertex)]] -fn vert_main() -> [[builtin(position)]] vec4 { - return vec4(); -} - -[[stage(fragment)]] -fn frag_main([[builtin(sample_index)]] sample_index : u32, - [[builtin(sample_mask)]] mask_in : u32) - -> [[builtin(sample_mask)]] u32 { - return mask_in; -} -)"; - - auto* expect = R"( -[[builtin(position), internal(disable_validation__ignore_storage_class)]] var value : vec4; - -[[builtin(pointsize), internal(disable_validation__ignore_storage_class)]] var vertex_point_size : f32; - -[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var sample_index_1 : u32; - -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var mask_in_1 : array; - -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var value_1 : array; - -fn vert_main_inner() -> vec4 { - return vec4(); -} - -[[stage(vertex)]] -fn vert_main() { - let inner_result = vert_main_inner(); - value = inner_result; - vertex_point_size = 1.0; -} - -fn frag_main_inner(sample_index : u32, mask_in : u32) -> u32 { - return mask_in; -} - -[[stage(fragment)]] -fn frag_main() { - let inner_result_1 = frag_main_inner(sample_index_1, mask_in_1[0]); - value_1[0] = inner_result_1; -} -)"; - - DataMap data; - data.Add(true); - auto got = Run(src, data); - - EXPECT_EQ(expect, str(got)); -} +// TODO(jrprice): Remove this file when we remove the sanitizers. } // namespace } // namespace transform