From 77f7f5d3692496b2cf34323142a8e9852eb31fb2 Mon Sep 17 00:00:00 2001 From: David Neto Date: Thu, 29 Apr 2021 00:09:04 +0000 Subject: [PATCH] spirv-reader: register statically accessed inputs and outputs Bug: tint:508 Change-Id: I585abb0791f5ea0bcb282f12f6940e718da4956d Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/48861 Kokoro: Kokoro Reviewed-by: James Price Reviewed-by: Ben Clayton Commit-Queue: James Price Auto-Submit: David Neto --- src/BUILD.gn | 1 + src/CMakeLists.txt | 1 + src/reader/spirv/entry_point_info.cc | 38 +++++++ src/reader/spirv/entry_point_info.h | 20 ++++ src/reader/spirv/parser_impl.cc | 31 +++++- .../spirv/parser_impl_module_var_test.cc | 99 +++++++++++++++++++ .../spirv/parser_impl_user_name_test.cc | 2 +- 7 files changed, 189 insertions(+), 3 deletions(-) create mode 100644 src/reader/spirv/entry_point_info.cc diff --git a/src/BUILD.gn b/src/BUILD.gn index 941e1e79e8..0cdaac47f7 100644 --- a/src/BUILD.gn +++ b/src/BUILD.gn @@ -581,6 +581,7 @@ libtint_source_set("libtint_spv_reader_src") { sources = [ "reader/spirv/construct.cc", "reader/spirv/construct.h", + "reader/spirv/entry_point_info.cc", "reader/spirv/entry_point_info.h", "reader/spirv/enum_converter.cc", "reader/spirv/enum_converter.h", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 795091f2eb..658f7dde70 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -352,6 +352,7 @@ if(${TINT_BUILD_SPV_READER}) reader/spirv/construct.h reader/spirv/construct.cc reader/spirv/entry_point_info.h + reader/spirv/entry_point_info.cc reader/spirv/enum_converter.h reader/spirv/enum_converter.cc reader/spirv/fail_stream.h diff --git a/src/reader/spirv/entry_point_info.cc b/src/reader/spirv/entry_point_info.cc new file mode 100644 index 0000000000..61f1586db7 --- /dev/null +++ b/src/reader/spirv/entry_point_info.cc @@ -0,0 +1,38 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/reader/spirv/entry_point_info.h" + +#include + +namespace tint { +namespace reader { +namespace spirv { + +EntryPointInfo::EntryPointInfo(std::string the_name, + ast::PipelineStage the_stage, + std::vector&& the_inputs, + std::vector&& the_outputs) + : name(the_name), + stage(the_stage), + inputs(std::move(the_inputs)), + outputs(std::move(the_outputs)) {} + +EntryPointInfo::EntryPointInfo(const EntryPointInfo&) = default; + +EntryPointInfo::~EntryPointInfo() = default; + +} // namespace spirv +} // namespace reader +} // namespace tint diff --git a/src/reader/spirv/entry_point_info.h b/src/reader/spirv/entry_point_info.h index 8256794e1f..8cb11f307c 100644 --- a/src/reader/spirv/entry_point_info.h +++ b/src/reader/spirv/entry_point_info.h @@ -16,6 +16,7 @@ #define SRC_READER_SPIRV_ENTRY_POINT_INFO_H_ #include +#include #include "src/ast/pipeline_stage.h" @@ -25,10 +26,29 @@ namespace spirv { /// Entry point information for a function struct EntryPointInfo { + // Constructor. + // @param the_name the name of the entry point + // @param the_stage the pipeline stage + // @param the_inputs list of IDs for Input variables used by the shader + // @param the_outputs list of IDs for Output variables used by the shader + EntryPointInfo(std::string the_name, + ast::PipelineStage the_stage, + std::vector&& the_inputs, + std::vector&& the_outputs); + // Copy constructor + // @param other the other entry point info to be built from + EntryPointInfo(const EntryPointInfo& other); + // Destructor + ~EntryPointInfo(); + /// The entry point name std::string name; /// The entry point stage ast::PipelineStage stage = ast::PipelineStage::kNone; + /// IDs of pipeline input variables, sorted and without duplicates. + std::vector inputs; + /// IDs of pipeline output variables, sorted and without duplicates. + std::vector outputs; }; } // namespace spirv diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index f147b70395..32b32b0539 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -14,8 +14,10 @@ #include "src/reader/spirv/parser_impl.h" +#include #include #include +#include #include "source/opt/build_module.h" #include "src/ast/bitcast_expression.h" @@ -26,6 +28,7 @@ #include "src/sem/depth_texture_type.h" #include "src/sem/multisampled_texture_type.h" #include "src/sem/sampled_texture_type.h" +#include "src/utils/unique_vector.h" namespace tint { namespace reader { @@ -711,8 +714,32 @@ bool ParserImpl::RegisterEntryPoints() { const uint32_t function_id = entry_point.GetSingleWordInOperand(1); const std::string ep_name = entry_point.GetOperand(2).AsString(); - EntryPointInfo info{ep_name, enum_converter_.ToPipelineStage(stage)}; - function_to_ep_info_[function_id].push_back(info); + tint::UniqueVector inputs; + tint::UniqueVector outputs; + for (unsigned iarg = 3; iarg < entry_point.NumInOperands(); iarg++) { + const uint32_t var_id = entry_point.GetSingleWordInOperand(iarg); + if (const auto* var_inst = def_use_mgr_->GetDef(var_id)) { + switch (SpvStorageClass(var_inst->GetSingleWordInOperand(0))) { + case SpvStorageClassInput: + inputs.add(var_id); + break; + case SpvStorageClassOutput: + outputs.add(var_id); + break; + default: + break; + } + } + } + // Save the lists, in ID-sorted order. + std::vector sorted_inputs(inputs.begin(), inputs.end()); + std::sort(sorted_inputs.begin(), sorted_inputs.end()); + std::vector sorted_outputs(outputs.begin(), outputs.end()); + std::sort(sorted_inputs.begin(), sorted_inputs.end()); + + function_to_ep_info_[function_id].emplace_back( + ep_name, enum_converter_.ToPipelineStage(stage), + std::move(sorted_inputs), std::move(sorted_outputs)); } // The enum conversion could have failed, so return the existing status value. return success_; diff --git a/src/reader/spirv/parser_impl_module_var_test.cc b/src/reader/spirv/parser_impl_module_var_test.cc index b52daa2537..714d8dc7d5 100644 --- a/src/reader/spirv/parser_impl_module_var_test.cc +++ b/src/reader/spirv/parser_impl_module_var_test.cc @@ -24,6 +24,7 @@ namespace { using SpvModuleScopeVarParserTest = SpvParserTest; +using ::testing::ElementsAre; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::Not; @@ -3722,6 +3723,104 @@ TEST_F(SpvModuleScopeVarParserTest, InstanceIndex_U32_FunctParam) { })")) << module_str; } +TEST_F(SpvModuleScopeVarParserTest, RegisterInputOutputVars) { + const std::string assembly = + R"( + OpCapability Shader + OpMemoryModel Logical Simple + OpEntryPoint GLCompute %1000 "w1000" + OpEntryPoint GLCompute %1100 "w1100" %1 + OpEntryPoint GLCompute %1200 "w1200" %2 %15 + ; duplication is tolerated prior to SPIR-V 1.4 + OpEntryPoint GLCompute %1300 "w1300" %1 %15 %2 %1 + +)" + CommonTypes() + + R"( + + %ptr_in_uint = OpTypePointer Input %uint + %ptr_out_uint = OpTypePointer Output %uint + + %1 = OpVariable %ptr_in_uint Input + %2 = OpVariable %ptr_in_uint Input + %5 = OpVariable %ptr_in_uint Input + %11 = OpVariable %ptr_out_uint Output + %12 = OpVariable %ptr_out_uint Output + %15 = OpVariable %ptr_out_uint Output + + %100 = OpFunction %void None %voidfn + %entry_100 = OpLabel + %load_100 = OpLoad %uint %1 + OpReturn + OpFunctionEnd + + %200 = OpFunction %void None %voidfn + %entry_200 = OpLabel + %load_200 = OpLoad %uint %2 + OpStore %15 %load_200 + OpStore %15 %load_200 + OpReturn + OpFunctionEnd + + %300 = OpFunction %void None %voidfn + %entry_300 = OpLabel + %dummy_300_1 = OpFunctionCall %void %100 + %dummy_300_2 = OpFunctionCall %void %200 + OpReturn + OpFunctionEnd + + ; Call nothing + %1000 = OpFunction %void None %voidfn + %entry_1000 = OpLabel + OpReturn + OpFunctionEnd + + ; Call %100 + %1100 = OpFunction %void None %voidfn + %entry_1100 = OpLabel + %dummy_1100_1 = OpFunctionCall %void %100 + OpReturn + OpFunctionEnd + + ; Call %200 + %1200 = OpFunction %void None %voidfn + %entry_1200 = OpLabel + %dummy_1200_1 = OpFunctionCall %void %200 + OpReturn + OpFunctionEnd + + ; Call %300 + %1300 = OpFunction %void None %voidfn + %entry_1300 = OpLabel + %dummy_1300_1 = OpFunctionCall %void %300 + OpReturn + OpFunctionEnd + + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error() << assembly; + EXPECT_TRUE(p->error().empty()); + + const auto& info_1000 = p->GetEntryPointInfo(1000); + EXPECT_EQ(1u, info_1000.size()); + EXPECT_TRUE(info_1000[0].inputs.empty()); + EXPECT_TRUE(info_1000[0].outputs.empty()); + + const auto& info_1100 = p->GetEntryPointInfo(1100); + EXPECT_EQ(1u, info_1100.size()); + EXPECT_THAT(info_1100[0].inputs, ElementsAre(1)); + EXPECT_TRUE(info_1100[0].outputs.empty()); + + const auto& info_1200 = p->GetEntryPointInfo(1200); + EXPECT_EQ(1u, info_1200.size()); + EXPECT_THAT(info_1200[0].inputs, ElementsAre(2)); + EXPECT_THAT(info_1200[0].outputs, ElementsAre(15)); + + const auto& info_1300 = p->GetEntryPointInfo(1300); + EXPECT_EQ(1u, info_1300.size()); + EXPECT_THAT(info_1300[0].inputs, ElementsAre(1, 2)); + EXPECT_THAT(info_1300[0].outputs, ElementsAre(15)); +} + // TODO(dneto): Test passing pointer to SampleMask as function parameter, // both input case and output case. diff --git a/src/reader/spirv/parser_impl_user_name_test.cc b/src/reader/spirv/parser_impl_user_name_test.cc index d854a9b878..99ffbcc55e 100644 --- a/src/reader/spirv/parser_impl_user_name_test.cc +++ b/src/reader/spirv/parser_impl_user_name_test.cc @@ -130,7 +130,7 @@ TEST_F(SpvParserTest, EntryPointNamesAlwaysTakePrecedence) { // has grabbed "main_1" first. EXPECT_THAT(p->namer().Name(1), Eq("main_1_1")); - const auto ep_info = p->GetEntryPointInfo(100); + const auto& ep_info = p->GetEntryPointInfo(100); ASSERT_EQ(2u, ep_info.size()); EXPECT_EQ(ep_info[0].name, "main"); EXPECT_EQ(ep_info[1].name, "main_1");