diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2c615df856..ccd47541fc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -399,6 +399,7 @@ if(${TINT_BUILD_SPV_WRITER}) list(APPEND TINT_TEST_SRCS writer/spirv/binary_writer_test.cc writer/spirv/builder_test.cc + writer/spirv/builder_test_entry_point_test.cc writer/spirv/instruction_test.cc writer/spirv/operand_test.cc ) diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 57df564eb7..f5422804b6 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -29,6 +29,26 @@ uint32_t size_of(const std::vector& instructions) { return size; } +uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) { + SpvExecutionModel model = SpvExecutionModelVertex; + + switch (stage) { + case ast::PipelineStage::kFragment: + model = SpvExecutionModelFragment; + break; + case ast::PipelineStage::kVertex: + model = SpvExecutionModelVertex; + break; + case ast::PipelineStage::kCompute: + model = SpvExecutionModelGLCompute; + break; + case ast::PipelineStage::kNone: + model = SpvExecutionModelMax; + break; + } + return model; +} + } // namespace Builder::Builder() = default; @@ -47,6 +67,13 @@ bool Builder::Build(const ast::Module& m) { push_preamble(spv::Op::OpMemoryModel, {Operand::Int(SpvAddressingModelLogical), Operand::Int(SpvMemoryModelVulkanKHR)}); + + for (const auto& ep : m.entry_points()) { + if (!GenerateEntryPoint(ep.get())) { + return false; + } + } + return true; } @@ -85,6 +112,28 @@ void Builder::iterate(std::function cb) const { } } +bool Builder::GenerateEntryPoint(ast::EntryPoint* ep) { + auto name = ep->name(); + if (name.empty()) { + name = ep->function_name(); + } + + auto id = id_for_func_name(ep->function_name()); + if (id == 0) { + return false; + } + + auto stage = pipeline_stage_to_execution_model(ep->stage()); + if (stage == SpvExecutionModelMax) { + return false; + } + + push_preamble(spv::Op::OpEntryPoint, + {Operand::Int(stage), Operand::Int(id), Operand::String(name)}); + + return true; +} + void Builder::GenerateImport(ast::Import* imp) { auto op = result_op(); auto id = op.to_i(); diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 07886211c6..a0ab57176d 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -52,6 +52,22 @@ class Builder { return id; } + /// Sets the id for a given function name + /// @param name the name to set + /// @param id the id to set + void set_func_name_to_id(const std::string& name, uint32_t id) { + func_name_to_id_[name] = id; + } + + /// Retrives the id for the given function name + /// @returns the id for the given name or 0 on failure + uint32_t id_for_func_name(const std::string& name) { + if (func_name_to_id_.count(name) == 0) { + return 0; + } + return func_name_to_id_[name]; + } + /// Iterates over all the instructions in the correct order and calls the /// given callback /// @param cb the callback to execute @@ -98,6 +114,10 @@ class Builder { /// @returns the annotations const std::vector& annot() const { return annotations_; } + /// Generates an entry point instruction + /// @param ep the entry point + /// @returns true if the instruction was generated, false otherwise + bool GenerateEntryPoint(ast::EntryPoint* ep); /// Generates an import instruction /// @param imp the import void GenerateImport(ast::Import* imp); @@ -113,6 +133,7 @@ class Builder { std::vector annotations_; std::unordered_map import_name_to_id_; + std::unordered_map func_name_to_id_; }; } // namespace spirv diff --git a/src/writer/spirv/builder_test_entry_point_test.cc b/src/writer/spirv/builder_test_entry_point_test.cc new file mode 100644 index 0000000000..4a83258ea9 --- /dev/null +++ b/src/writer/spirv/builder_test_entry_point_test.cc @@ -0,0 +1,111 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" +#include "spirv/unified1/spirv.h" +#include "spirv/unified1/spirv.hpp11" +#include "src/ast/entry_point.h" +#include "src/ast/pipeline_stage.h" +#include "src/writer/spirv/builder.h" + +namespace tint { +namespace writer { +namespace spirv { + +using BuilderTest = testing::Test; + +TEST_F(BuilderTest, EntryPoint) { + ast::EntryPoint ep(ast::PipelineStage::kFragment, "main", "frag_main"); + + Builder b; + b.set_func_name_to_id("frag_main", 2); + ASSERT_TRUE(b.GenerateEntryPoint(&ep)); + + auto preamble = b.preamble(); + ASSERT_EQ(preamble.size(), 1); + EXPECT_EQ(preamble[0].opcode(), spv::Op::OpEntryPoint); + + ASSERT_TRUE(preamble[0].operands().size() >= 3); + EXPECT_EQ(preamble[0].operands()[0].to_i(), SpvExecutionModelFragment); + EXPECT_EQ(preamble[0].operands()[1].to_i(), 2); + EXPECT_EQ(preamble[0].operands()[2].to_s(), "main"); +} + +TEST_F(BuilderTest, EntryPoint_WithoutName) { + ast::EntryPoint ep(ast::PipelineStage::kCompute, "", "compute_main"); + + Builder b; + b.set_func_name_to_id("compute_main", 3); + ASSERT_TRUE(b.GenerateEntryPoint(&ep)); + + auto preamble = b.preamble(); + ASSERT_EQ(preamble.size(), 1); + EXPECT_EQ(preamble[0].opcode(), spv::Op::OpEntryPoint); + + ASSERT_TRUE(preamble[0].operands().size() >= 3); + EXPECT_EQ(preamble[0].operands()[0].to_i(), SpvExecutionModelGLCompute); + EXPECT_EQ(preamble[0].operands()[1].to_i(), 3); + EXPECT_EQ(preamble[0].operands()[2].to_s(), "compute_main"); +} + +TEST_F(BuilderTest, EntryPoint_BadFunction) { + ast::EntryPoint ep(ast::PipelineStage::kFragment, "main", "frag_main"); + + Builder b; + EXPECT_FALSE(b.GenerateEntryPoint(&ep)); +} + +struct EntryPointStageData { + ast::PipelineStage stage; + SpvExecutionModel model; +}; +inline std::ostream& operator<<(std::ostream& out, EntryPointStageData data) { + out << data.stage; + return out; +} +using EntryPointStageTest = testing::TestWithParam; +TEST_P(EntryPointStageTest, Emit) { + auto params = GetParam(); + + ast::EntryPoint ep(params.stage, "", "main"); + + Builder b; + b.set_func_name_to_id("main", 3); + ASSERT_TRUE(b.GenerateEntryPoint(&ep)); + + auto preamble = b.preamble(); + ASSERT_EQ(preamble.size(), 1); + EXPECT_EQ(preamble[0].opcode(), spv::Op::OpEntryPoint); + + ASSERT_TRUE(preamble[0].operands().size() >= 3); + EXPECT_EQ(preamble[0].operands()[0].to_i(), params.model); +} +INSTANTIATE_TEST_SUITE_P( + BuilderTest, + EntryPointStageTest, + testing::Values(EntryPointStageData{ast::PipelineStage::kVertex, + SpvExecutionModelVertex}, + EntryPointStageData{ast::PipelineStage::kFragment, + SpvExecutionModelFragment}, + EntryPointStageData{ast::PipelineStage::kCompute, + SpvExecutionModelGLCompute})); + +// TODO(http://crbug.com/tint/28) +TEST_F(BuilderTest, DISABLED_EntryPoint_WithInterfaceIds) {} + +} // namespace spirv +} // namespace writer +} // namespace tint