diff --git a/src/inspector/entry_point.h b/src/inspector/entry_point.h index b3570bbb53..6fc4de673e 100644 --- a/src/inspector/entry_point.h +++ b/src/inspector/entry_point.h @@ -69,17 +69,19 @@ struct EntryPoint { /// The entry point stage ast::PipelineStage stage = ast::PipelineStage::kNone; /// The workgroup x size - uint32_t workgroup_size_x; + uint32_t workgroup_size_x = 0; /// The workgroup y size - uint32_t workgroup_size_y; + uint32_t workgroup_size_y = 0; /// The workgroup z size - uint32_t workgroup_size_z; + uint32_t workgroup_size_z = 0; /// List of the input variable accessed via this entry point. std::vector input_variables; /// List of the output variable accessed via this entry point. std::vector output_variables; /// List of the pipeline overridable constants accessed via this entry point. std::vector overridable_constants; + /// Does the entry point use the sample_mask builtin + bool sample_mask_used = false; /// @returns the size of the workgroup in {x,y,z} format std::tuple workgroup_size() { diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc index c99e6b8a84..6afa77d70e 100644 --- a/src/inspector/inspector.cc +++ b/src/inspector/inspector.cc @@ -100,6 +100,9 @@ std::vector Inspector::GetEntryPoints() { entry_point.output_variables); } + entry_point.sample_mask_used = ContainsSampleMaskBuiltin( + sem->ReturnType(), func->return_type_decorations()); + for (auto* var : sem->ReferencedModuleVariables()) { auto* decl = var->Declaration(); @@ -535,6 +538,31 @@ void Inspector::AddEntryPointInOutVariables( variables.push_back(stage_variable); } +bool Inspector::ContainsSampleMaskBuiltin( + sem::Type* type, + const ast::DecorationList& decorations) const { + auto* unwrapped_type = type->UnwrapRef(); + + if (auto* struct_ty = unwrapped_type->As()) { + // Recurse into members. + for (auto* member : struct_ty->Members()) { + if (ContainsSampleMaskBuiltin(member->Type(), + member->Declaration()->decorations())) { + return true; + } + } + return false; + } + + // Base case: check for [[builtin(sample_mask)]] + auto* builtin = ast::GetDecoration(decorations); + if (!builtin || builtin->value() != ast::Builtin::kSampleMask) { + return false; + } + + return true; +} + std::vector Inspector::GetStorageBufferResourceBindingsImpl( const std::string& entry_point, bool read_only) { diff --git a/src/inspector/inspector.h b/src/inspector/inspector.h index 954907f8ec..aba355c03a 100644 --- a/src/inspector/inspector.h +++ b/src/inspector/inspector.h @@ -151,6 +151,12 @@ class Inspector { const ast::DecorationList& decorations, std::vector& variables) const; + /// Recursively determine if the type contains [[builtin(sample_mask)]] + /// If `type` is a struct, recurse into members to check for the decoration. + /// Otherwise, check `decorations` for the decoration. + bool ContainsSampleMaskBuiltin(sem::Type* type, + const ast::DecorationList& decorations) const; + /// @param entry_point name of the entry point to get information about. /// @param read_only if true get only read-only bindings, if false get /// write-only bindings. diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc index 8d0d4a8124..4b8af26940 100644 --- a/src/inspector/inspector_test.cc +++ b/src/inspector/inspector_test.cc @@ -1559,6 +1559,50 @@ TEST_F(InspectorGetEntryPointTest, NonOverridableConstantSkipped) { EXPECT_EQ(0u, result[0].overridable_constants.size()); } +TEST_F(InspectorGetEntryPointTest, SampleMaskNotReferenced) { + MakeEmptyBodyFunction("ep_func", {Stage(ast::PipelineStage::kFragment)}); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + EXPECT_FALSE(result[0].sample_mask_used); +} + +TEST_F(InspectorGetEntryPointTest, SampleMaskSimpleReferenced) { + auto* in_var = + Param("in_var", ty.u32(), {Builtin(ast::Builtin::kSampleMask)}); + Func("ep_func", {in_var}, ty.u32(), {Return("in_var")}, + {Stage(ast::PipelineStage::kFragment)}, + {Builtin(ast::Builtin::kSampleMask)}); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + EXPECT_TRUE(result[0].sample_mask_used); +} + +TEST_F(InspectorGetEntryPointTest, SampleMaskStructReferenced) { + ast::StructMemberList members; + members.push_back(Member("inner_sample_mask", ty.u32(), + {Builtin(ast::Builtin::kSampleMask)})); + Structure("out_struct", members, {}); + + Func("ep_func", {}, ty.type_name("out_struct"), + {Decl(Var("out_var", ty.type_name("out_struct"))), Return("out_var")}, + {Stage(ast::PipelineStage::kFragment)}, {}); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + EXPECT_TRUE(result[0].sample_mask_used); +} + // TODO(rharrison): Reenable once GetRemappedNameForEntryPoint isn't a pass // through TEST_F(InspectorGetRemappedNameForEntryPointTest, DISABLED_NoFunctions) {