diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc index 50d1548326..91568cb63e 100644 --- a/src/transform/canonicalize_entry_point_io.cc +++ b/src/transform/canonicalize_entry_point_io.cc @@ -67,6 +67,12 @@ bool IsShaderIODecoration(const ast::Decoration* deco) { ast::InvariantDecoration, ast::LocationDecoration>(); } +// Returns true if `decos` contains a `sample_mask` builtin. +bool HasSampleMask(const ast::DecorationList& decos) { + auto* builtin = ast::GetDecoration(decos); + return builtin && builtin->value() == ast::Builtin::kSampleMask; +} + } // namespace /// State holds the current transform state for a single entry point. @@ -166,9 +172,16 @@ struct CanonicalizeEntryPointIO::State { // Create the global variable and use its value for the shader input. auto var = ctx.dst->Symbols().New(name); + ast::Expression* value = ctx.dst->Expr(var); + if (HasSampleMask(attributes)) { + // Vulkan requires the type of a SampleMask builtin to be an array. + // Declare it as array and then load the first element. + type = ctx.dst->ty.array(type, 1); + value = ctx.dst->IndexAccessor(value, 0); + } ctx.dst->Global(var, type, ast::StorageClass::kInput, std::move(attributes)); - return ctx.dst->Expr(var); + return value; } else if (cfg.shader_style == ShaderStyle::kMsl && ast::HasDecoration(attributes)) { // If this input is a builtin and we are targeting MSL, then add it to the @@ -303,9 +316,7 @@ struct CanonicalizeEntryPointIO::State { void AddFixedSampleMask() { // Check the existing output values for a sample mask builtin. for (auto& outval : wrapper_output_values) { - auto* builtin = - ast::GetDecoration(outval.attributes); - if (builtin && builtin->value() == ast::Builtin::kSampleMask) { + if (HasSampleMask(outval.attributes)) { // Combine the authored sample mask with the fixed mask. outval.value = ctx.dst->And(outval.value, cfg.fixed_sample_mask); return; @@ -390,7 +401,7 @@ struct CanonicalizeEntryPointIO::State { } /// Create and assign the wrapper function's output variables. - void CreateOutputVariables() { + void CreateSpirvOutputVariables() { for (auto& outval : wrapper_output_values) { // Disable validation for use of the `output` storage class. ast::DecorationList attributes = std::move(outval.attributes); @@ -400,9 +411,17 @@ struct CanonicalizeEntryPointIO::State { // Create the global variable and assign it the output value. auto name = ctx.dst->Symbols().New(outval.name); - ctx.dst->Global(name, outval.type, ast::StorageClass::kOutput, + auto* type = outval.type; + ast::Expression* lhs = ctx.dst->Expr(name); + if (HasSampleMask(attributes)) { + // Vulkan requires the type of a SampleMask builtin to be an array. + // Declare it as array and then store to the first element. + type = ctx.dst->ty.array(type, 1); + lhs = ctx.dst->IndexAccessor(lhs, 0); + } + ctx.dst->Global(name, type, ast::StorageClass::kOutput, std::move(attributes)); - wrapper_body.push_back(ctx.dst->Assign(name, outval.value)); + wrapper_body.push_back(ctx.dst->Assign(lhs, outval.value)); } } @@ -498,7 +517,7 @@ struct CanonicalizeEntryPointIO::State { // Produce the entry point outputs, if necessary. if (!wrapper_output_values.empty()) { if (cfg.shader_style == ShaderStyle::kSpirv) { - CreateOutputVariables(); + CreateSpirvOutputVariables(); } else { auto* output_struct = CreateOutputStruct(); wrapper_ret_type = [&, output_struct] { diff --git a/src/transform/canonicalize_entry_point_io_test.cc b/src/transform/canonicalize_entry_point_io_test.cc index e4d8413486..1dc93b17eb 100644 --- a/src/transform/canonicalize_entry_point_io_test.cc +++ b/src/transform/canonicalize_entry_point_io_test.cc @@ -491,7 +491,7 @@ fn frag_main() -> FragOutput { [[builtin(frag_depth), internal(disable_validation__ignore_storage_class)]] var depth_1 : f32; -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var mask_1 : u32; +[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var mask_1 : array; struct FragOutput { color : vec4; @@ -512,7 +512,7 @@ fn frag_main() { let inner_result = frag_main_inner(); color_1 = inner_result.color; depth_1 = inner_result.depth; - mask_1 = inner_result.mask; + mask_1[0] = inner_result.mask; } )"; @@ -2163,6 +2163,42 @@ fn vert_main() -> tint_symbol { EXPECT_EQ(expect, str(got)); } +TEST_F(CanonicalizeEntryPointIOTest, SpirvSampleMaskBuiltins) { + auto* src = R"( +[[stage(fragment)]] +fn main([[builtin(sample_index)]] sample_index : u32, + [[builtin(sample_mask)]] mask_in : u32 + ) -> [[builtin(sample_mask)]] u32 { + return mask_in; +} +)"; + + auto* expect = R"( +[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var sample_index_1 : u32; + +[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var mask_in_1 : array; + +[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var value : array; + +fn main_inner(sample_index : u32, mask_in : u32) -> u32 { + return mask_in; +} + +[[stage(fragment)]] +fn main() { + let inner_result = main_inner(sample_index_1, mask_in_1[0]); + value[0] = inner_result; +} +)"; + + DataMap data; + data.Add( + CanonicalizeEntryPointIO::ShaderStyle::kSpirv); + auto got = Run(src, data); + + EXPECT_EQ(expect, str(got)); +} + } // namespace } // namespace transform } // namespace tint diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc index 6bd220aa3b..76fca09d9c 100644 --- a/src/transform/spirv.cc +++ b/src/transform/spirv.cc @@ -68,61 +68,13 @@ Output Spirv::Run(const Program* in, const DataMap& data) { ProgramBuilder builder; CloneContext ctx(&builder, &transformedInput.program); - HandleSampleMaskBuiltins(ctx); + // TODO(jrprice): Move the sanitizer into the backend. ctx.Clone(); builder.SetTransformApplied(this); return Output{Program(std::move(builder))}; } -void Spirv::HandleSampleMaskBuiltins(CloneContext& ctx) const { - // Find global variables decorated with [[builtin(sample_mask)]] and - // change their type from `u32` to `array`, as required by Vulkan. - // - // Before: - // ``` - // [[builtin(sample_mask)]] var mask_out : u32; - // fn main() { - // mask_out = 1u; - // } - // ``` - // After: - // ``` - // [[builtin(sample_mask)]] var mask_out : array; - // fn main() { - // mask_out[0] = 1u; - // } - // ``` - - for (auto* var : ctx.src->AST().GlobalVariables()) { - for (auto* deco : var->decorations()) { - if (auto* builtin = deco->As()) { - if (builtin->value() != ast::Builtin::kSampleMask) { - continue; - } - - // Use the same name as the old variable. - auto var_name = ctx.Clone(var->symbol()); - // Use `array` for the new variable. - auto* type = ctx.dst->ty.array(ctx.dst->ty.u32(), 1u); - // Create the new variable. - auto* var_arr = ctx.dst->Var(var->source(), var_name, type, - var->declared_storage_class(), nullptr, - ctx.Clone(var->decorations())); - // Replace the variable with the arrayed version. - ctx.Replace(var, var_arr); - - // Replace all uses of the old variable with `var_arr[0]`. - for (auto* user : ctx.src->Sem().Get(var)->Users()) { - auto* new_ident = ctx.dst->IndexAccessor( - ctx.dst->Expr(var_arr->symbol()), ctx.dst->Expr(0)); - ctx.Replace(user->Declaration(), new_ident); - } - } - } - } -} - Spirv::Config::Config(bool emit_vps, bool disable_wi) : emit_vertex_point_size(emit_vps), disable_workgroup_init(disable_wi) {} diff --git a/src/transform/spirv.h b/src/transform/spirv.h index faf5b966e2..7e633ab5f3 100644 --- a/src/transform/spirv.h +++ b/src/transform/spirv.h @@ -67,10 +67,6 @@ class Spirv : public Castable { /// @param data optional extra transform-specific input data /// @returns the transformation result Output Run(const Program* program, const DataMap& data = {}) override; - - private: - /// Change type of sample mask builtin variables to single element arrays. - void HandleSampleMaskBuiltins(CloneContext& ctx) const; }; } // namespace transform diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc index 26ed757a88..e2d4c3c60b 100644 --- a/src/transform/spirv_test.cc +++ b/src/transform/spirv_test.cc @@ -22,140 +22,7 @@ namespace { using SpirvTest = TransformTest; -TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) { - auto* src = R"( -[[stage(fragment)]] -fn main([[builtin(sample_index)]] sample_index : u32, - [[builtin(sample_mask)]] mask_in : u32 - ) -> [[builtin(sample_mask)]] u32 { - return mask_in; -} -)"; - - auto* expect = R"( -[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var sample_index_1 : u32; - -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var mask_in_1 : array; - -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var value : array; - -fn main_inner(sample_index : u32, mask_in : u32) -> u32 { - return mask_in; -} - -[[stage(fragment)]] -fn main() { - let inner_result = main_inner(sample_index_1, mask_in_1[0]); - value[0] = inner_result; -} -)"; - - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - -TEST_F(SpirvTest, HandleSampleMaskBuiltins_FunctionArg) { - auto* src = R"( -fn filter(mask: u32) -> u32 { - return (mask & 3u); -} - -fn set_mask(input : u32) -> u32 { - return input; -} - -[[stage(fragment)]] -fn main([[builtin(sample_mask)]] mask_in : u32 - ) -> [[builtin(sample_mask)]] u32 { - return set_mask(filter(mask_in)); -} -)"; - - auto* expect = R"( -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var mask_in_1 : array; - -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var value : array; - -fn filter(mask : u32) -> u32 { - return (mask & 3u); -} - -fn set_mask(input : u32) -> u32 { - return input; -} - -fn main_inner(mask_in : u32) -> u32 { - return set_mask(filter(mask_in)); -} - -[[stage(fragment)]] -fn main() { - let inner_result = main_inner(mask_in_1[0]); - value[0] = inner_result; -} -)"; - - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - -// Test that different transforms within the sanitizer interact correctly. -TEST_F(SpirvTest, MultipleTransforms) { - auto* src = R"( -[[stage(vertex)]] -fn vert_main() -> [[builtin(position)]] vec4 { - return vec4(); -} - -[[stage(fragment)]] -fn frag_main([[builtin(sample_index)]] sample_index : u32, - [[builtin(sample_mask)]] mask_in : u32) - -> [[builtin(sample_mask)]] u32 { - return mask_in; -} -)"; - - auto* expect = R"( -[[builtin(position), internal(disable_validation__ignore_storage_class)]] var value : vec4; - -[[builtin(pointsize), internal(disable_validation__ignore_storage_class)]] var vertex_point_size : f32; - -[[builtin(sample_index), internal(disable_validation__ignore_storage_class)]] var sample_index_1 : u32; - -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var mask_in_1 : array; - -[[builtin(sample_mask), internal(disable_validation__ignore_storage_class)]] var value_1 : array; - -fn vert_main_inner() -> vec4 { - return vec4(); -} - -[[stage(vertex)]] -fn vert_main() { - let inner_result = vert_main_inner(); - value = inner_result; - vertex_point_size = 1.0; -} - -fn frag_main_inner(sample_index : u32, mask_in : u32) -> u32 { - return mask_in; -} - -[[stage(fragment)]] -fn frag_main() { - let inner_result_1 = frag_main_inner(sample_index_1, mask_in_1[0]); - value_1[0] = inner_result_1; -} -)"; - - DataMap data; - data.Add(true); - auto got = Run(src, data); - - EXPECT_EQ(expect, str(got)); -} +// TODO(jrprice): Remove this file when we remove the sanitizers. } // namespace } // namespace transform