tint/writer/spirv: Replace Operand with std::variant

Operand is just a tagged union of uint32_t, float and std::string.

Use std::variant for this.
Reduces memory size, and removes the need to always construct an empty string when the operand is float or int.

Bug: tint:1383
Change-Id: I02fc10137d6fab410ea25a8d6c6e279b882b2287
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/88302
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-04-28 18:40:03 +00:00 committed by Dawn LUCI CQ
parent b0b53ba403
commit 7e8df044c6
11 changed files with 382 additions and 511 deletions

View File

@ -15,6 +15,7 @@
#include "src/tint/writer/spirv/binary_writer.h" #include "src/tint/writer/spirv/binary_writer.h"
#include <cstring> #include <cstring>
#include <string>
namespace tint::writer::spirv { namespace tint::writer::spirv {
namespace { namespace {
@ -54,19 +55,22 @@ void BinaryWriter::process_instruction(const Instruction& inst) {
} }
void BinaryWriter::process_op(const Operand& op) { void BinaryWriter::process_op(const Operand& op) {
if (op.IsFloat()) { if (auto* i = std::get_if<uint32_t>(&op)) {
out_.push_back(*i);
return;
}
if (auto* f = std::get_if<float>(&op)) {
// Allocate space for the float // Allocate space for the float
out_.push_back(0); out_.push_back(0);
auto f = op.to_f();
uint8_t* ptr = reinterpret_cast<uint8_t*>(out_.data() + (out_.size() - 1)); uint8_t* ptr = reinterpret_cast<uint8_t*>(out_.data() + (out_.size() - 1));
memcpy(ptr, &f, 4); memcpy(ptr, f, 4);
} else if (op.IsInt()) { return;
out_.push_back(op.to_i()); }
} else { if (auto* str = std::get_if<std::string>(&op)) {
auto idx = out_.size(); auto idx = out_.size();
const auto& str = op.to_s(); out_.resize(out_.size() + OperandLength(op), 0);
out_.resize(out_.size() + op.length(), 0); memcpy(out_.data() + idx, str->c_str(), str->size() + 1);
memcpy(out_.data() + idx, str.c_str(), str.size() + 1); return;
} }
} }

View File

@ -35,7 +35,7 @@ TEST_F(BinaryWriterTest, Preamble) {
TEST_F(BinaryWriterTest, Float) { TEST_F(BinaryWriterTest, Float) {
spirv::Builder& b = Build(); spirv::Builder& b = Build();
b.push_annot(spv::Op::OpKill, {Operand::Float(2.4f)}); b.push_annot(spv::Op::OpKill, {Operand(2.4f)});
BinaryWriter bw; BinaryWriter bw;
bw.WriteBuilder(&b); bw.WriteBuilder(&b);
@ -49,7 +49,7 @@ TEST_F(BinaryWriterTest, Float) {
TEST_F(BinaryWriterTest, Int) { TEST_F(BinaryWriterTest, Int) {
spirv::Builder& b = Build(); spirv::Builder& b = Build();
b.push_annot(spv::Op::OpKill, {Operand::Int(2)}); b.push_annot(spv::Op::OpKill, {Operand(2u)});
BinaryWriter bw; BinaryWriter bw;
bw.WriteBuilder(&b); bw.WriteBuilder(&b);
@ -61,7 +61,7 @@ TEST_F(BinaryWriterTest, Int) {
TEST_F(BinaryWriterTest, String) { TEST_F(BinaryWriterTest, String) {
spirv::Builder& b = Build(); spirv::Builder& b = Build();
b.push_annot(spv::Op::OpKill, {Operand::String("my_string")}); b.push_annot(spv::Op::OpKill, {Operand("my_string")});
BinaryWriter bw; BinaryWriter bw;
bw.WriteBuilder(&b); bw.WriteBuilder(&b);
@ -86,7 +86,7 @@ TEST_F(BinaryWriterTest, String) {
TEST_F(BinaryWriterTest, String_Multiple4Length) { TEST_F(BinaryWriterTest, String_Multiple4Length) {
spirv::Builder& b = Build(); spirv::Builder& b = Build();
b.push_annot(spv::Op::OpKill, {Operand::String("mystring")}); b.push_annot(spv::Op::OpKill, {Operand("mystring")});
BinaryWriter bw; BinaryWriter bw;
bw.WriteBuilder(&b); bw.WriteBuilder(&b);
@ -109,8 +109,8 @@ TEST_F(BinaryWriterTest, String_Multiple4Length) {
} }
TEST_F(BinaryWriterTest, TestInstructionWriter) { TEST_F(BinaryWriterTest, TestInstructionWriter) {
Instruction i1{spv::Op::OpKill, {Operand::Int(2)}}; Instruction i1{spv::Op::OpKill, {Operand(2u)}};
Instruction i2{spv::Op::OpKill, {Operand::Int(4)}}; Instruction i2{spv::Op::OpKill, {Operand(4u)}};
BinaryWriter bw; BinaryWriter bw;
bw.WriteInstruction(i1); bw.WriteInstruction(i1);

File diff suppressed because it is too large Load Diff

View File

@ -79,7 +79,7 @@ TEST_P(Attribute_StageTest, Emit) {
EXPECT_EQ(preamble[0].opcode(), spv::Op::OpEntryPoint); EXPECT_EQ(preamble[0].opcode(), spv::Op::OpEntryPoint);
ASSERT_GE(preamble[0].operands().size(), 3u); ASSERT_GE(preamble[0].operands().size(), 3u);
EXPECT_EQ(preamble[0].operands()[0].to_i(), EXPECT_EQ(std::get<uint32_t>(preamble[0].operands()[0]),
static_cast<uint32_t>(params.model)); static_cast<uint32_t>(params.model));
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(

View File

@ -17,8 +17,7 @@
namespace tint::writer::spirv { namespace tint::writer::spirv {
Function::Function() Function::Function()
: declaration_(Instruction{spv::Op::OpNop, {}}), : declaration_(Instruction{spv::Op::OpNop, {}}), label_op_(Operand(0u)) {}
label_op_(Operand::Int(0)) {}
Function::Function(const Instruction& declaration, Function::Function(const Instruction& declaration,
const Operand& label_op, const Operand& label_op,

View File

@ -48,7 +48,7 @@ class Function {
const Instruction& declaration() const { return declaration_; } const Instruction& declaration() const { return declaration_; }
/// @returns the label ID for the function entry block /// @returns the label ID for the function entry block
uint32_t label_id() const { return label_op_.to_i(); } uint32_t label_id() const { return std::get<uint32_t>(label_op_); }
/// Adds an instruction to the instruction list /// Adds an instruction to the instruction list
/// @param op the op to set /// @param op the op to set

View File

@ -28,7 +28,7 @@ Instruction::~Instruction() = default;
uint32_t Instruction::word_length() const { uint32_t Instruction::word_length() const {
uint32_t size = 1; // Initial 1 for the op and size uint32_t size = 1; // Initial 1 for the op and size
for (const auto& op : operands_) { for (const auto& op : operands_) {
size += op.length(); size += OperandLength(op);
} }
return size; return size;
} }

View File

@ -14,6 +14,8 @@
#include "src/tint/writer/spirv/instruction.h" #include "src/tint/writer/spirv/instruction.h"
#include <string>
#include "gtest/gtest.h" #include "gtest/gtest.h"
namespace tint::writer::spirv { namespace tint::writer::spirv {
@ -22,25 +24,25 @@ namespace {
using InstructionTest = testing::Test; using InstructionTest = testing::Test;
TEST_F(InstructionTest, Create) { TEST_F(InstructionTest, Create) {
Instruction i(spv::Op::OpEntryPoint, {Operand::Float(1.2f), Operand::Int(1), Instruction i(spv::Op::OpEntryPoint,
Operand::String("my_str")}); {Operand(1.2f), Operand(1u), Operand("my_str")});
EXPECT_EQ(i.opcode(), spv::Op::OpEntryPoint); EXPECT_EQ(i.opcode(), spv::Op::OpEntryPoint);
ASSERT_EQ(i.operands().size(), 3u); ASSERT_EQ(i.operands().size(), 3u);
const auto& ops = i.operands(); const auto& ops = i.operands();
EXPECT_TRUE(ops[0].IsFloat()); ASSERT_TRUE(std::holds_alternative<float>(ops[0]));
EXPECT_FLOAT_EQ(ops[0].to_f(), 1.2f); EXPECT_FLOAT_EQ(std::get<float>(ops[0]), 1.2f);
EXPECT_TRUE(ops[1].IsInt()); ASSERT_TRUE(std::holds_alternative<uint32_t>(ops[1]));
EXPECT_EQ(ops[1].to_i(), 1u); EXPECT_EQ(std::get<uint32_t>(ops[1]), 1u);
EXPECT_TRUE(ops[2].IsString()); ASSERT_TRUE(std::holds_alternative<std::string>(ops[2]));
EXPECT_EQ(ops[2].to_s(), "my_str"); EXPECT_EQ(std::get<std::string>(ops[2]), "my_str");
} }
TEST_F(InstructionTest, Length) { TEST_F(InstructionTest, Length) {
Instruction i(spv::Op::OpEntryPoint, {Operand::Float(1.2f), Operand::Int(1), Instruction i(spv::Op::OpEntryPoint,
Operand::String("my_str")}); {Operand(1.2f), Operand(1u), Operand("my_str")});
EXPECT_EQ(i.word_length(), 5u); EXPECT_EQ(i.word_length(), 5u);
} }

View File

@ -16,46 +16,14 @@
namespace tint::writer::spirv { namespace tint::writer::spirv {
// static uint32_t OperandLength(const Operand& o) {
Operand Operand::Float(float val) { if (auto* str = std::get_if<std::string>(&o)) {
Operand o(Kind::kFloat); // SPIR-V always nul-terminates strings. The length is rounded up to a
o.float_val_ = val; // multiple of 4 bytes with 0 bytes padding the end. Accounting for the
return o; // nul terminator is why '+ 4u' is used here instead of '+ 3u'.
} return static_cast<uint32_t>((str->length() + 4u) >> 2);
// static
Operand Operand::Int(uint32_t val) {
Operand o(Kind::kInt);
o.int_val_ = val;
return o;
}
// static
Operand Operand::String(const std::string& val) {
Operand o(Kind::kString);
o.str_val_ = val;
return o;
}
Operand::Operand(Kind kind) : kind_(kind) {}
Operand::~Operand() = default;
uint32_t Operand::length() const {
uint32_t val = 0;
switch (kind_) {
case Kind::kFloat:
case Kind::kInt:
val = 1;
break;
case Kind::kString:
// SPIR-V always nul-terminates strings. The length is rounded up to a
// multiple of 4 bytes with 0 bytes padding the end. Accounting for the
// nul terminator is why '+ 4u' is used here instead of '+ 3u'.
val = static_cast<uint32_t>((str_val_.length() + 4u) >> 2);
break;
} }
return val; return 1;
} }
} // namespace tint::writer::spirv } // namespace tint::writer::spirv

View File

@ -17,6 +17,8 @@
#include <cstring> #include <cstring>
#include <string> #include <string>
// TODO(https://crbug.com/dawn/1379) Update cpplint and remove NOLINT
#include <variant> // NOLINT(build/include_order)
#include <vector> #include <vector>
#include "src/tint/utils/hash.h" #include "src/tint/utils/hash.h"
@ -24,84 +26,16 @@
namespace tint::writer::spirv { namespace tint::writer::spirv {
/// A single SPIR-V instruction operand /// A single SPIR-V instruction operand
class Operand { using Operand = std::variant<uint32_t, float, std::string>;
public:
/// The kind of the operand
// Note, the `kInt` will cover most cases as things like IDs in SPIR-V are
// just ints for the purpose of converting to binary.
enum class Kind { kInt = 0, kFloat, kString };
/// Creates a float operand // Helper for returning an uint32_t Operand with the provided integer value.
/// @param val the float value template <typename T>
/// @returns the operand inline Operand U32Operand(T val) {
static Operand Float(float val); return Operand{static_cast<uint32_t>(val)};
/// Creates an int operand }
/// @param val the int value
/// @returns the operand
static Operand Int(uint32_t val);
/// Creates a string operand
/// @param val the string value
/// @returns the operand
static Operand String(const std::string& val);
/// Constructor /// @returns the number of uint32_t's needed for this operand
/// @param kind the type of operand uint32_t OperandLength(const Operand& o);
explicit Operand(Kind kind);
/// Copy Constructor
Operand(const Operand&) = default;
~Operand();
/// Copy assignment
/// @param b the operand to copy
/// @returns a copy of this operand
Operand& operator=(const Operand& b) = default;
/// Equality operator
/// @param other the RHS of the operator
/// @returns true if this operand is equal to other
bool operator==(const Operand& other) const {
if (kind_ == other.kind_) {
switch (kind_) {
case tint::writer::spirv::Operand::Kind::kFloat:
// Use memcmp to work around:
// error: comparing floating point with == or != is unsafe
// [-Werror,-Wfloat-equal]
return memcmp(&float_val_, &other.float_val_, sizeof(float_val_)) ==
0;
case tint::writer::spirv::Operand::Kind::kInt:
return int_val_ == other.int_val_;
case tint::writer::spirv::Operand::Kind::kString:
return str_val_ == other.str_val_;
}
}
return false;
}
/// @returns the kind of the operand
Kind GetKind() const { return kind_; }
/// @returns true if this is a float operand
bool IsFloat() const { return kind_ == Kind::kFloat; }
/// @returns true if this is an integer operand
bool IsInt() const { return kind_ == Kind::kInt; }
/// @returns true if this is a string operand
bool IsString() const { return kind_ == Kind::kString; }
/// @returns the number of uint32_t's needed for this operand
uint32_t length() const;
/// @returns the float value
float to_f() const { return float_val_; }
/// @returns the int value
uint32_t to_i() const { return int_val_; }
/// @returns the string value
const std::string& to_s() const { return str_val_; }
private:
Kind kind_ = Kind::kInt;
float float_val_ = 0.0;
uint32_t int_val_ = 0;
std::string str_val_;
};
/// A list of operands /// A list of operands
using OperandList = std::vector<Operand>; using OperandList = std::vector<Operand>;
@ -119,15 +53,7 @@ class hash<tint::writer::spirv::Operand> {
/// @param o the Operand /// @param o the Operand
/// @return the hash value /// @return the hash value
inline std::size_t operator()(const tint::writer::spirv::Operand& o) const { inline std::size_t operator()(const tint::writer::spirv::Operand& o) const {
switch (o.GetKind()) { return std::visit([](auto v) { return tint::utils::Hash(v); }, o);
case tint::writer::spirv::Operand::Kind::kFloat:
return tint::utils::Hash(o.to_f());
case tint::writer::spirv::Operand::Kind::kInt:
return tint::utils::Hash(o.to_i());
case tint::writer::spirv::Operand::Kind::kString:
return tint::utils::Hash(o.to_s());
}
return 0;
} }
}; };

View File

@ -22,41 +22,41 @@ namespace {
using OperandTest = testing::Test; using OperandTest = testing::Test;
TEST_F(OperandTest, CreateFloat) { TEST_F(OperandTest, CreateFloat) {
auto o = Operand::Float(1.2f); auto o = Operand(1.2f);
EXPECT_TRUE(o.IsFloat()); ASSERT_TRUE(std::holds_alternative<float>(o));
EXPECT_FLOAT_EQ(o.to_f(), 1.2f); EXPECT_FLOAT_EQ(std::get<float>(o), 1.2f);
} }
TEST_F(OperandTest, CreateInt) { TEST_F(OperandTest, CreateInt) {
auto o = Operand::Int(1); auto o = Operand(1u);
EXPECT_TRUE(o.IsInt()); ASSERT_TRUE(std::holds_alternative<uint32_t>(o));
EXPECT_EQ(o.to_i(), 1u); EXPECT_EQ(std::get<uint32_t>(o), 1u);
} }
TEST_F(OperandTest, CreateString) { TEST_F(OperandTest, CreateString) {
auto o = Operand::String("my string"); auto o = Operand("my string");
EXPECT_TRUE(o.IsString()); ASSERT_TRUE(std::holds_alternative<std::string>(o));
EXPECT_EQ(o.to_s(), "my string"); EXPECT_EQ(std::get<std::string>(o), "my string");
} }
TEST_F(OperandTest, Length_Float) { TEST_F(OperandTest, Length_Float) {
auto o = Operand::Float(1.2f); auto o = Operand(1.2f);
EXPECT_EQ(o.length(), 1u); EXPECT_EQ(OperandLength(o), 1u);
} }
TEST_F(OperandTest, Length_Int) { TEST_F(OperandTest, Length_Int) {
auto o = Operand::Int(1); auto o = U32Operand(1);
EXPECT_EQ(o.length(), 1u); EXPECT_EQ(OperandLength(o), 1u);
} }
TEST_F(OperandTest, Length_String) { TEST_F(OperandTest, Length_String) {
auto o = Operand::String("my string"); auto o = Operand("my string");
EXPECT_EQ(o.length(), 3u); EXPECT_EQ(OperandLength(o), 3u);
} }
TEST_F(OperandTest, Length_String_Empty) { TEST_F(OperandTest, Length_String_Empty) {
auto o = Operand::String(""); auto o = Operand("");
EXPECT_EQ(o.length(), 1u); EXPECT_EQ(OperandLength(o), 1u);
} }
} // namespace } // namespace