diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 4c6f0bd679..57df564eb7 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -35,10 +35,15 @@ Builder::Builder() = default; Builder::~Builder() = default; -bool Builder::Build(const ast::Module&) { +bool Builder::Build(const ast::Module& m) { push_preamble(spv::Op::OpCapability, {Operand::Int(SpvCapabilityShader)}); - push_preamble(spv::Op::OpExtInstImport, - {result_op(), Operand::String("GLSL.std.450")}); + push_preamble(spv::Op::OpCapability, + {Operand::Int(SpvCapabilityVulkanMemoryModel)}); + + for (const auto& imp : m.imports()) { + GenerateImport(imp.get()); + } + push_preamble(spv::Op::OpMemoryModel, {Operand::Int(SpvAddressingModelLogical), Operand::Int(SpvMemoryModelVulkanKHR)}); @@ -80,6 +85,15 @@ void Builder::iterate(std::function cb) const { } } +void Builder::GenerateImport(ast::Import* imp) { + auto op = result_op(); + auto id = op.to_i(); + + push_preamble(spv::Op::OpExtInstImport, {op, Operand::String(imp->path())}); + + import_name_to_id_[imp->name()] = id; +} + } // namespace spirv } // namespace writer } // namespace tint diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 7b27b18575..07886211c6 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -16,6 +16,8 @@ #define SRC_WRITER_SPIRV_BUILDER_H_ #include +#include +#include #include #include "src/ast/module.h" @@ -43,6 +45,13 @@ class Builder { /// @returns the id bound for this module uint32_t id_bound() const { return next_id_; } + /// @returns the next id to be used + uint32_t next_id() { + auto id = next_id_; + next_id_ += 1; + return id; + } + /// Iterates over all the instructions in the correct order and calls the /// given callback /// @param cb the callback to execute @@ -89,13 +98,12 @@ class Builder { /// @returns the annotations const std::vector& annot() const { return annotations_; } + /// Generates an import instruction + /// @param imp the import + void GenerateImport(ast::Import* imp); + private: Operand result_op(); - uint32_t next_id() { - auto id = next_id_; - next_id_ += 1; - return id; - } uint32_t next_id_ = 1; std::vector preamble_; @@ -103,6 +111,8 @@ class Builder { std::vector types_; std::vector instructions_; std::vector annotations_; + + std::unordered_map import_name_to_id_; }; } // namespace spirv diff --git a/src/writer/spirv/builder_test.cc b/src/writer/spirv/builder_test.cc index 5e44c73d48..8d06910b6e 100644 --- a/src/writer/spirv/builder_test.cc +++ b/src/writer/spirv/builder_test.cc @@ -14,8 +14,12 @@ #include "src/writer/spirv/builder.h" +#include + #include "gtest/gtest.h" +#include "spirv/unified1/spirv.h" #include "spirv/unified1/spirv.hpp11" +#include "src/ast/import.h" #include "src/ast/module.h" namespace tint { @@ -24,17 +28,48 @@ namespace spirv { using BuilderTest = testing::Test; -TEST_F(BuilderTest, InsertsPreamble) { +TEST_F(BuilderTest, InsertsPreambleWithImport) { + ast::Module m; + m.AddImport(std::make_unique("GLSL.std.450", "glsl")); + + Builder b; + ASSERT_TRUE(b.Build(m)); + ASSERT_EQ(b.preamble().size(), 4); + + auto pre = b.preamble(); + EXPECT_EQ(pre[0].opcode(), spv::Op::OpCapability); + EXPECT_EQ(pre[0].operands()[0].to_i(), SpvCapabilityShader); + EXPECT_EQ(pre[1].opcode(), spv::Op::OpCapability); + EXPECT_EQ(pre[1].operands()[0].to_i(), SpvCapabilityVulkanMemoryModel); + EXPECT_EQ(pre[2].opcode(), spv::Op::OpExtInstImport); + EXPECT_EQ(pre[2].operands()[1].to_s(), "GLSL.std.450"); + EXPECT_EQ(pre[3].opcode(), spv::Op::OpMemoryModel); +} + +TEST_F(BuilderTest, InsertsPreambleWithoutImport) { ast::Module m; Builder b; ASSERT_TRUE(b.Build(m)); ASSERT_EQ(b.preamble().size(), 3); + auto pre = b.preamble(); EXPECT_EQ(pre[0].opcode(), spv::Op::OpCapability); - EXPECT_EQ(pre[1].opcode(), spv::Op::OpExtInstImport); + EXPECT_EQ(pre[0].operands()[0].to_i(), SpvCapabilityShader); + EXPECT_EQ(pre[1].opcode(), spv::Op::OpCapability); + EXPECT_EQ(pre[1].operands()[0].to_i(), SpvCapabilityVulkanMemoryModel); EXPECT_EQ(pre[2].opcode(), spv::Op::OpMemoryModel); } +TEST_F(BuilderTest, TracksIdBounds) { + Builder b; + + for (size_t i = 0; i < 5; i++) { + EXPECT_EQ(b.next_id(), i + 1); + } + + EXPECT_EQ(6, b.id_bound()); +} + } // namespace spirv } // namespace writer } // namespace tint