diff --git a/src/BUILD.gn b/src/BUILD.gn index 787afe2c87..fc82118034 100644 --- a/src/BUILD.gn +++ b/src/BUILD.gn @@ -558,6 +558,7 @@ source_set("libtint_spv_writer_src") { "writer/spirv/instruction.h", "writer/spirv/operand.cc", "writer/spirv/operand.h", + "writer/spirv/scalar_constant.h", ] configs += [ ":tint_common_config" ] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9614ed4ab8..9e5f29e7f3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -343,6 +343,7 @@ if(${TINT_BUILD_SPV_WRITER}) writer/spirv/instruction.h writer/spirv/operand.cc writer/spirv/operand.h + writer/spirv/scalar_constant.h ) endif() @@ -678,6 +679,7 @@ if(${TINT_BUILD_TESTS}) writer/spirv/builder_unary_op_expression_test.cc writer/spirv/instruction_test.cc writer/spirv/operand_test.cc + writer/spirv/scalar_constant_test.cc writer/spirv/spv_dump.cc writer/spirv/spv_dump.h writer/spirv/test_helper.h diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 71054c8bd8..7209aba4db 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -21,7 +21,6 @@ #include "src/ast/call_statement.h" #include "src/ast/constant_id_decoration.h" #include "src/ast/fallthrough_statement.h" -#include "src/ast/null_literal.h" #include "src/semantic/array.h" #include "src/semantic/call.h" #include "src/semantic/function.h" @@ -365,12 +364,6 @@ bool Builder::GenerateLabel(uint32_t id) { return true; } -uint32_t Builder::GenerateU32Literal(uint32_t val) { - type::U32 u32; - ast::SintLiteral lit(Source{}, &u32, val); - return GenerateLiteralIfNeeded(nullptr, &lit); -} - bool Builder::GenerateAssignStatement(ast::AssignmentStatement* assign) { auto lhs_id = GenerateExpression(assign->lhs()); if (lhs_id == 0) { @@ -648,8 +641,7 @@ bool Builder::GenerateFunctionVariable(ast::Variable* var) { // TODO(dsinclair) We could detect if the constructor is fully const and emit // an initializer value for the variable instead of doing the OpLoad. - ast::NullLiteral nl(Source{}, var->type()->UnwrapPtrIfNeeded()); - auto null_id = GenerateLiteralIfNeeded(var, &nl); + auto null_id = GenerateConstantNullIfNeeded(var->type()->UnwrapPtrIfNeeded()); if (null_id == 0) { return 0; } @@ -779,8 +771,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) { } else if (sem->StorageClass() == ast::StorageClass::kPrivate || sem->StorageClass() == ast::StorageClass::kNone || sem->StorageClass() == ast::StorageClass::kOutput) { - ast::NullLiteral nl(Source{}, type); - init_id = GenerateLiteralIfNeeded(var, &nl); + init_id = GenerateConstantNullIfNeeded(type); if (init_id == 0) { return 0; } @@ -888,7 +879,7 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr, } } - auto idx_id = GenerateU32Literal(i); + auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i)); if (idx_id == 0) { return 0; } @@ -913,7 +904,7 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr, } if (info->source_type->Is()) { - auto idx_id = GenerateU32Literal(val); + auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(val)); if (idx_id == 0) { return 0; } @@ -1044,8 +1035,7 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) { auto ary_result = result_op(); - ast::NullLiteral nl(Source{}, ary_res_type); - auto init = GenerateLiteralIfNeeded(nullptr, &nl); + auto init = GenerateConstantNullIfNeeded(ary_res_type); // If we're access chaining into an array then we must be in a function push_function_var( @@ -1259,8 +1249,7 @@ uint32_t Builder::GenerateTypeConstructorExpression( // Generate the zero initializer if there are no values provided. if (values.empty()) { - ast::NullLiteral nl(Source{}, init->type()->UnwrapPtrIfNeeded()); - return GenerateLiteralIfNeeded(nullptr, &nl); + return GenerateConstantNullIfNeeded(init->type()->UnwrapPtrIfNeeded()); } std::ostringstream out; @@ -1370,7 +1359,7 @@ uint32_t Builder::GenerateTypeConstructorExpression( result_is_constant_composite = false; } else { // A global initializer, must use OpSpecConstantOp. Case 1. - auto idx_id = GenerateU32Literal(i); + auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i)); if (idx_id == 0) { return 0; } @@ -1392,8 +1381,8 @@ uint32_t Builder::GenerateTypeConstructorExpression( } auto str = out.str(); - auto val = const_to_id_.find(str); - if (val != const_to_id_.end()) { + auto val = type_constructor_to_id_.find(str); + if (val != type_constructor_to_id_.end()) { return val->second; } @@ -1401,7 +1390,7 @@ uint32_t Builder::GenerateTypeConstructorExpression( ops.insert(ops.begin(), result); ops.insert(ops.begin(), Operand::Int(type_id)); - const_to_id_[str] = result.to_i(); + type_constructor_to_id_[str] = result.to_i(); if (result_is_spec_composite) { push_type(spv::Op::OpSpecConstantComposite, ops); @@ -1480,59 +1469,133 @@ uint32_t Builder::GenerateCastOrCopyOrPassthrough(type::Type* to_type, uint32_t Builder::GenerateLiteralIfNeeded(ast::Variable* var, ast::Literal* lit) { - auto type_id = GenerateTypeIfNeeded(lit->type()); - if (type_id == 0) { - return 0; - } + ScalarConstant constant; - auto name = lit->name(); - bool is_spec_constant = false; if (var && var->HasConstantIdDecoration()) { - name = "__spec" + name; - is_spec_constant = true; - } - - auto val = const_to_id_.find(name); - if (val != const_to_id_.end()) { - return val->second; - } - - auto result = result_op(); - auto result_id = result.to_i(); - - if (is_spec_constant) { - push_annot(spv::Op::OpDecorate, - {Operand::Int(result_id), Operand::Int(SpvDecorationSpecId), - Operand::Int(var->constant_id())}); + constant.is_spec_op = true; + constant.constant_id = var->constant_id(); } if (auto* l = lit->As()) { - if (l->IsTrue()) { - push_type(is_spec_constant ? spv::Op::OpSpecConstantTrue - : spv::Op::OpConstantTrue, - {Operand::Int(type_id), result}); - } else { - push_type(is_spec_constant ? spv::Op::OpSpecConstantFalse - : spv::Op::OpConstantFalse, - {Operand::Int(type_id), result}); - } + constant.kind = ScalarConstant::Kind::kBool; + constant.value.b = l->IsTrue(); } else if (auto* sl = lit->As()) { - push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant, - {Operand::Int(type_id), result, Operand::Int(sl->value())}); + constant.kind = ScalarConstant::Kind::kI32; + constant.value.i32 = sl->value(); } else if (auto* ul = lit->As()) { - push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant, - {Operand::Int(type_id), result, Operand::Int(ul->value())}); + constant.kind = ScalarConstant::Kind::kU32; + constant.value.u32 = ul->value(); } else if (auto* fl = lit->As()) { - push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant, - {Operand::Int(type_id), result, Operand::Float(fl->value())}); - } else if (lit->Is()) { - push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result}); + constant.kind = ScalarConstant::Kind::kF32; + constant.value.f32 = fl->value(); } else { error_ = "unknown literal type"; return 0; } - const_to_id_[name] = result_id; + return GenerateConstantIfNeeded(constant); +} + +uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) { + auto it = const_to_id_.find(constant); + if (it != const_to_id_.end()) { + return it->second; + } + + uint32_t type_id = 0; + + switch (constant.kind) { + case ScalarConstant::Kind::kU32: { + type::U32 u32; + type_id = GenerateTypeIfNeeded(&u32); + break; + } + case ScalarConstant::Kind::kI32: { + type::I32 i32; + type_id = GenerateTypeIfNeeded(&i32); + break; + } + case ScalarConstant::Kind::kF32: { + type::F32 f32; + type_id = GenerateTypeIfNeeded(&f32); + break; + } + case ScalarConstant::Kind::kBool: { + type::Bool bool_; + type_id = GenerateTypeIfNeeded(&bool_); + break; + } + } + + if (type_id == 0) { + return 0; + } + + auto result = result_op(); + auto result_id = result.to_i(); + + if (constant.is_spec_op) { + push_annot(spv::Op::OpDecorate, + {Operand::Int(result_id), Operand::Int(SpvDecorationSpecId), + Operand::Int(constant.constant_id)}); + } + + switch (constant.kind) { + case ScalarConstant::Kind::kU32: { + push_type( + constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant, + {Operand::Int(type_id), result, Operand::Int(constant.value.u32)}); + break; + } + case ScalarConstant::Kind::kI32: { + push_type( + constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant, + {Operand::Int(type_id), result, Operand::Int(constant.value.i32)}); + break; + } + case ScalarConstant::Kind::kF32: { + push_type( + constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant, + {Operand::Int(type_id), result, Operand::Float(constant.value.f32)}); + break; + } + case ScalarConstant::Kind::kBool: { + if (constant.value.b) { + push_type(constant.is_spec_op ? spv::Op::OpSpecConstantTrue + : spv::Op::OpConstantTrue, + {Operand::Int(type_id), result}); + } else { + push_type(constant.is_spec_op ? spv::Op::OpSpecConstantFalse + : spv::Op::OpConstantFalse, + {Operand::Int(type_id), result}); + } + break; + } + } + + const_to_id_[constant] = result_id; + return result_id; +} + +uint32_t Builder::GenerateConstantNullIfNeeded(type::Type* type) { + auto type_id = GenerateTypeIfNeeded(type); + if (type_id == 0) { + return 0; + } + + auto name = type->type_name(); + + auto it = const_null_to_id_.find(name); + if (it != const_null_to_id_.end()) { + return it->second; + } + + auto result = result_op(); + auto result_id = result.to_i(); + + push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result}); + + const_null_to_id_[name] = result_id; return result_id; } @@ -2955,7 +3018,7 @@ bool Builder::GenerateArrayType(type::Array* ary, const Operand& result) { if (ary->IsRuntimeArray()) { push_type(spv::Op::OpTypeRuntimeArray, {result, Operand::Int(elem_type)}); } else { - auto len_id = GenerateU32Literal(ary->size()); + auto len_id = GenerateConstantIfNeeded(ScalarConstant::U32(ary->size())); if (len_id == 0) { return false; } diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 4f3251d6ca..6147c9ff3f 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -37,6 +37,7 @@ #include "src/type/access_control_type.h" #include "src/type/storage_texture_type.h" #include "src/writer/spirv/function.h" +#include "src/writer/spirv/scalar_constant.h" namespace tint { @@ -208,10 +209,6 @@ class Builder { /// @param id the id to use for the label /// @returns true on success. bool GenerateLabel(uint32_t id); - /// Generates a uint32_t literal. - /// @param val the value to generate - /// @returns the ID of the generated literal - uint32_t GenerateU32Literal(uint32_t val); /// Generates an assignment statement /// @param assign the statement to generate /// @returns true if the statement was successfully generated @@ -486,6 +483,16 @@ class Builder { return builder_.TypeOf(expr); } + /// Generates a constant if needed + /// @param constant the constant to generate. + /// @returns the ID on success or 0 on failure + uint32_t GenerateConstantIfNeeded(const ScalarConstant& constant); + + /// Generates a constant-null of the given type, if needed + /// @param type the type of the constant null to generate. + /// @returns the ID on success or 0 on failure + uint32_t GenerateConstantNullIfNeeded(type::Type* type); + ProgramBuilder builder_; std::string error_; uint32_t next_id_ = 1; @@ -504,7 +511,9 @@ class Builder { std::unordered_map import_name_to_id_; std::unordered_map func_symbol_to_id_; std::unordered_map type_name_to_id_; - std::unordered_map const_to_id_; + std::unordered_map const_to_id_; + std::unordered_map type_constructor_to_id_; + std::unordered_map const_null_to_id_; std::unordered_map texture_type_name_to_sampled_image_type_id_; ScopeStack scope_stack_; diff --git a/src/writer/spirv/scalar_constant.h b/src/writer/spirv/scalar_constant.h new file mode 100644 index 0000000000..18dba915fd --- /dev/null +++ b/src/writer/spirv/scalar_constant.h @@ -0,0 +1,115 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_WRITER_SPIRV_SCALAR_CONSTANT_H_ +#define SRC_WRITER_SPIRV_SCALAR_CONSTANT_H_ + +#include + +#include +#include + +namespace tint { + +// Forward declarations +namespace semantic { +class Call; +} // namespace semantic + +namespace writer { +namespace spirv { + +/// ScalarConstant represents a scalar constant value +struct ScalarConstant { + /// The constant value + union Value { + /// The value as a bool + bool b; + /// The value as a uint32_t + uint32_t u32; + /// The value as a int32_t + int32_t i32; + /// The value as a float + float f32; + + /// The value that is wide enough to encompass all other types (including + /// future 64-bit data types). + uint64_t u64; + }; + + /// The kind of constant + enum class Kind { kBool, kU32, kI32, kF32 }; + + /// Constructor + inline ScalarConstant() { value.u64 = 0; } + + /// @param value the value of the constant + /// @returns a new ScalarConstant with the provided value and kind Kind::kU32 + static inline ScalarConstant U32(uint32_t value) { + ScalarConstant c; + c.value.u32 = value; + c.kind = Kind::kU32; + return c; + } + + /// Equality operator + /// @param rhs the ScalarConstant to compare against + /// @returns true if this ScalarConstant is equal to `rhs` + inline bool operator==(const ScalarConstant& rhs) const { + return value.u64 == rhs.value.u64 && kind == rhs.kind && + is_spec_op == rhs.is_spec_op && constant_id == rhs.constant_id; + } + + /// Inequality operator + /// @param rhs the ScalarConstant to compare against + /// @returns true if this ScalarConstant is not equal to `rhs` + inline bool operator!=(const ScalarConstant& rhs) const { + return !(*this == rhs); + } + + /// The constant value + Value value; + /// The constant value kind + Kind kind = Kind::kBool; + /// True if the constant is a specialization op + bool is_spec_op = false; + /// The identifier if a specialization op + uint32_t constant_id = 0; +}; + +} // namespace spirv +} // namespace writer +} // namespace tint + +namespace std { + +/// Custom std::hash specialization for tint::Symbol so symbols can be used as +/// keys for std::unordered_map and std::unordered_set. +template <> +class hash { + public: + /// @param c the ScalarConstant + /// @return the Symbol internal value + inline std::size_t operator()( + const tint::writer::spirv::ScalarConstant& c) const { + uint32_t value = 0; + std::memcpy(&value, &c.value, sizeof(value)); + return (static_cast(value) << 2) | + (static_cast(c.kind) & 3); + } +}; + +} // namespace std + +#endif // SRC_WRITER_SPIRV_SCALAR_CONSTANT_H_ diff --git a/src/writer/spirv/scalar_constant_test.cc b/src/writer/spirv/scalar_constant_test.cc new file mode 100644 index 0000000000..b514146e39 --- /dev/null +++ b/src/writer/spirv/scalar_constant_test.cc @@ -0,0 +1,60 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/writer/spirv/scalar_constant.h" +#include "src/writer/spirv/test_helper.h" + +namespace tint { +namespace writer { +namespace spirv { +namespace { + +using SpirvScalarConstantTest = TestHelper; + +TEST_F(SpirvScalarConstantTest, Equality) { + ScalarConstant a{}; + ScalarConstant b{}; + EXPECT_EQ(a, b); + + a.kind = ScalarConstant::Kind::kU32; + EXPECT_NE(a, b); + b.kind = ScalarConstant::Kind::kU32; + EXPECT_EQ(a, b); + + a.value.b = true; + EXPECT_NE(a, b); + b.value.b = true; + EXPECT_EQ(a, b); + + a.is_spec_op = true; + EXPECT_NE(a, b); + b.is_spec_op = true; + EXPECT_EQ(a, b); + + a.constant_id = 3; + EXPECT_NE(a, b); + b.constant_id = 3; + EXPECT_EQ(a, b); +} + +TEST_F(SpirvScalarConstantTest, U32) { + auto c = ScalarConstant::U32(123); + EXPECT_EQ(c.value.u32, 123u); + EXPECT_EQ(c.kind, ScalarConstant::Kind::kU32); +} + +} // namespace +} // namespace spirv +} // namespace writer +} // namespace tint diff --git a/test/BUILD.gn b/test/BUILD.gn index b49cf592e3..b69dd73fb4 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -336,6 +336,7 @@ source_set("tint_unittests_spv_writer_src") { "../src/writer/spirv/builder_unary_op_expression_test.cc", "../src/writer/spirv/instruction_test.cc", "../src/writer/spirv/operand_test.cc", + "../src/writer/spirv/scalar_constant_test.cc", "../src/writer/spirv/spv_dump.cc", "../src/writer/spirv/spv_dump.h", "../src/writer/spirv/test_helper.h",