writer/spirv: Clean up constant generation

Requiring a temporary stack-allocated ast::Literal is an unpleasant requirement to generate a SPIR-V constant value.
GenerateU32Literal() was also creating an invalid AST - the type was U32, yet an an ast::SintLiteral was used.

Instead add Constant for holding a constant value, and use this as the map key.

This also removes the last remaining use of ast::NullLiteral, which will be removed in the next change.

Change-Id: Ia85732784075f153503dbef101ba95018eaa4bf5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45342
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-03-22 15:32:35 +00:00 committed by Commit Bot service account
parent d6fe74b01f
commit eae161cd9f
7 changed files with 318 additions and 67 deletions

View File

@ -558,6 +558,7 @@ source_set("libtint_spv_writer_src") {
"writer/spirv/instruction.h",
"writer/spirv/operand.cc",
"writer/spirv/operand.h",
"writer/spirv/scalar_constant.h",
]
configs += [ ":tint_common_config" ]

View File

@ -343,6 +343,7 @@ if(${TINT_BUILD_SPV_WRITER})
writer/spirv/instruction.h
writer/spirv/operand.cc
writer/spirv/operand.h
writer/spirv/scalar_constant.h
)
endif()
@ -678,6 +679,7 @@ if(${TINT_BUILD_TESTS})
writer/spirv/builder_unary_op_expression_test.cc
writer/spirv/instruction_test.cc
writer/spirv/operand_test.cc
writer/spirv/scalar_constant_test.cc
writer/spirv/spv_dump.cc
writer/spirv/spv_dump.h
writer/spirv/test_helper.h

View File

@ -21,7 +21,6 @@
#include "src/ast/call_statement.h"
#include "src/ast/constant_id_decoration.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/null_literal.h"
#include "src/semantic/array.h"
#include "src/semantic/call.h"
#include "src/semantic/function.h"
@ -365,12 +364,6 @@ bool Builder::GenerateLabel(uint32_t id) {
return true;
}
uint32_t Builder::GenerateU32Literal(uint32_t val) {
type::U32 u32;
ast::SintLiteral lit(Source{}, &u32, val);
return GenerateLiteralIfNeeded(nullptr, &lit);
}
bool Builder::GenerateAssignStatement(ast::AssignmentStatement* assign) {
auto lhs_id = GenerateExpression(assign->lhs());
if (lhs_id == 0) {
@ -648,8 +641,7 @@ bool Builder::GenerateFunctionVariable(ast::Variable* var) {
// TODO(dsinclair) We could detect if the constructor is fully const and emit
// an initializer value for the variable instead of doing the OpLoad.
ast::NullLiteral nl(Source{}, var->type()->UnwrapPtrIfNeeded());
auto null_id = GenerateLiteralIfNeeded(var, &nl);
auto null_id = GenerateConstantNullIfNeeded(var->type()->UnwrapPtrIfNeeded());
if (null_id == 0) {
return 0;
}
@ -779,8 +771,7 @@ bool Builder::GenerateGlobalVariable(ast::Variable* var) {
} else if (sem->StorageClass() == ast::StorageClass::kPrivate ||
sem->StorageClass() == ast::StorageClass::kNone ||
sem->StorageClass() == ast::StorageClass::kOutput) {
ast::NullLiteral nl(Source{}, type);
init_id = GenerateLiteralIfNeeded(var, &nl);
init_id = GenerateConstantNullIfNeeded(type);
if (init_id == 0) {
return 0;
}
@ -888,7 +879,7 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
}
}
auto idx_id = GenerateU32Literal(i);
auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i));
if (idx_id == 0) {
return 0;
}
@ -913,7 +904,7 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
}
if (info->source_type->Is<type::Pointer>()) {
auto idx_id = GenerateU32Literal(val);
auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(val));
if (idx_id == 0) {
return 0;
}
@ -1044,8 +1035,7 @@ uint32_t Builder::GenerateAccessorExpression(ast::Expression* expr) {
auto ary_result = result_op();
ast::NullLiteral nl(Source{}, ary_res_type);
auto init = GenerateLiteralIfNeeded(nullptr, &nl);
auto init = GenerateConstantNullIfNeeded(ary_res_type);
// If we're access chaining into an array then we must be in a function
push_function_var(
@ -1259,8 +1249,7 @@ uint32_t Builder::GenerateTypeConstructorExpression(
// Generate the zero initializer if there are no values provided.
if (values.empty()) {
ast::NullLiteral nl(Source{}, init->type()->UnwrapPtrIfNeeded());
return GenerateLiteralIfNeeded(nullptr, &nl);
return GenerateConstantNullIfNeeded(init->type()->UnwrapPtrIfNeeded());
}
std::ostringstream out;
@ -1370,7 +1359,7 @@ uint32_t Builder::GenerateTypeConstructorExpression(
result_is_constant_composite = false;
} else {
// A global initializer, must use OpSpecConstantOp. Case 1.
auto idx_id = GenerateU32Literal(i);
auto idx_id = GenerateConstantIfNeeded(ScalarConstant::U32(i));
if (idx_id == 0) {
return 0;
}
@ -1392,8 +1381,8 @@ uint32_t Builder::GenerateTypeConstructorExpression(
}
auto str = out.str();
auto val = const_to_id_.find(str);
if (val != const_to_id_.end()) {
auto val = type_constructor_to_id_.find(str);
if (val != type_constructor_to_id_.end()) {
return val->second;
}
@ -1401,7 +1390,7 @@ uint32_t Builder::GenerateTypeConstructorExpression(
ops.insert(ops.begin(), result);
ops.insert(ops.begin(), Operand::Int(type_id));
const_to_id_[str] = result.to_i();
type_constructor_to_id_[str] = result.to_i();
if (result_is_spec_composite) {
push_type(spv::Op::OpSpecConstantComposite, ops);
@ -1480,59 +1469,133 @@ uint32_t Builder::GenerateCastOrCopyOrPassthrough(type::Type* to_type,
uint32_t Builder::GenerateLiteralIfNeeded(ast::Variable* var,
ast::Literal* lit) {
auto type_id = GenerateTypeIfNeeded(lit->type());
if (type_id == 0) {
return 0;
}
ScalarConstant constant;
auto name = lit->name();
bool is_spec_constant = false;
if (var && var->HasConstantIdDecoration()) {
name = "__spec" + name;
is_spec_constant = true;
}
auto val = const_to_id_.find(name);
if (val != const_to_id_.end()) {
return val->second;
}
auto result = result_op();
auto result_id = result.to_i();
if (is_spec_constant) {
push_annot(spv::Op::OpDecorate,
{Operand::Int(result_id), Operand::Int(SpvDecorationSpecId),
Operand::Int(var->constant_id())});
constant.is_spec_op = true;
constant.constant_id = var->constant_id();
}
if (auto* l = lit->As<ast::BoolLiteral>()) {
if (l->IsTrue()) {
push_type(is_spec_constant ? spv::Op::OpSpecConstantTrue
: spv::Op::OpConstantTrue,
{Operand::Int(type_id), result});
} else {
push_type(is_spec_constant ? spv::Op::OpSpecConstantFalse
: spv::Op::OpConstantFalse,
{Operand::Int(type_id), result});
}
constant.kind = ScalarConstant::Kind::kBool;
constant.value.b = l->IsTrue();
} else if (auto* sl = lit->As<ast::SintLiteral>()) {
push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand::Int(type_id), result, Operand::Int(sl->value())});
constant.kind = ScalarConstant::Kind::kI32;
constant.value.i32 = sl->value();
} else if (auto* ul = lit->As<ast::UintLiteral>()) {
push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand::Int(type_id), result, Operand::Int(ul->value())});
constant.kind = ScalarConstant::Kind::kU32;
constant.value.u32 = ul->value();
} else if (auto* fl = lit->As<ast::FloatLiteral>()) {
push_type(is_spec_constant ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand::Int(type_id), result, Operand::Float(fl->value())});
} else if (lit->Is<ast::NullLiteral>()) {
push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result});
constant.kind = ScalarConstant::Kind::kF32;
constant.value.f32 = fl->value();
} else {
error_ = "unknown literal type";
return 0;
}
const_to_id_[name] = result_id;
return GenerateConstantIfNeeded(constant);
}
uint32_t Builder::GenerateConstantIfNeeded(const ScalarConstant& constant) {
auto it = const_to_id_.find(constant);
if (it != const_to_id_.end()) {
return it->second;
}
uint32_t type_id = 0;
switch (constant.kind) {
case ScalarConstant::Kind::kU32: {
type::U32 u32;
type_id = GenerateTypeIfNeeded(&u32);
break;
}
case ScalarConstant::Kind::kI32: {
type::I32 i32;
type_id = GenerateTypeIfNeeded(&i32);
break;
}
case ScalarConstant::Kind::kF32: {
type::F32 f32;
type_id = GenerateTypeIfNeeded(&f32);
break;
}
case ScalarConstant::Kind::kBool: {
type::Bool bool_;
type_id = GenerateTypeIfNeeded(&bool_);
break;
}
}
if (type_id == 0) {
return 0;
}
auto result = result_op();
auto result_id = result.to_i();
if (constant.is_spec_op) {
push_annot(spv::Op::OpDecorate,
{Operand::Int(result_id), Operand::Int(SpvDecorationSpecId),
Operand::Int(constant.constant_id)});
}
switch (constant.kind) {
case ScalarConstant::Kind::kU32: {
push_type(
constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand::Int(type_id), result, Operand::Int(constant.value.u32)});
break;
}
case ScalarConstant::Kind::kI32: {
push_type(
constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand::Int(type_id), result, Operand::Int(constant.value.i32)});
break;
}
case ScalarConstant::Kind::kF32: {
push_type(
constant.is_spec_op ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
{Operand::Int(type_id), result, Operand::Float(constant.value.f32)});
break;
}
case ScalarConstant::Kind::kBool: {
if (constant.value.b) {
push_type(constant.is_spec_op ? spv::Op::OpSpecConstantTrue
: spv::Op::OpConstantTrue,
{Operand::Int(type_id), result});
} else {
push_type(constant.is_spec_op ? spv::Op::OpSpecConstantFalse
: spv::Op::OpConstantFalse,
{Operand::Int(type_id), result});
}
break;
}
}
const_to_id_[constant] = result_id;
return result_id;
}
uint32_t Builder::GenerateConstantNullIfNeeded(type::Type* type) {
auto type_id = GenerateTypeIfNeeded(type);
if (type_id == 0) {
return 0;
}
auto name = type->type_name();
auto it = const_null_to_id_.find(name);
if (it != const_null_to_id_.end()) {
return it->second;
}
auto result = result_op();
auto result_id = result.to_i();
push_type(spv::Op::OpConstantNull, {Operand::Int(type_id), result});
const_null_to_id_[name] = result_id;
return result_id;
}
@ -2955,7 +3018,7 @@ bool Builder::GenerateArrayType(type::Array* ary, const Operand& result) {
if (ary->IsRuntimeArray()) {
push_type(spv::Op::OpTypeRuntimeArray, {result, Operand::Int(elem_type)});
} else {
auto len_id = GenerateU32Literal(ary->size());
auto len_id = GenerateConstantIfNeeded(ScalarConstant::U32(ary->size()));
if (len_id == 0) {
return false;
}

View File

@ -37,6 +37,7 @@
#include "src/type/access_control_type.h"
#include "src/type/storage_texture_type.h"
#include "src/writer/spirv/function.h"
#include "src/writer/spirv/scalar_constant.h"
namespace tint {
@ -208,10 +209,6 @@ class Builder {
/// @param id the id to use for the label
/// @returns true on success.
bool GenerateLabel(uint32_t id);
/// Generates a uint32_t literal.
/// @param val the value to generate
/// @returns the ID of the generated literal
uint32_t GenerateU32Literal(uint32_t val);
/// Generates an assignment statement
/// @param assign the statement to generate
/// @returns true if the statement was successfully generated
@ -486,6 +483,16 @@ class Builder {
return builder_.TypeOf(expr);
}
/// Generates a constant if needed
/// @param constant the constant to generate.
/// @returns the ID on success or 0 on failure
uint32_t GenerateConstantIfNeeded(const ScalarConstant& constant);
/// Generates a constant-null of the given type, if needed
/// @param type the type of the constant null to generate.
/// @returns the ID on success or 0 on failure
uint32_t GenerateConstantNullIfNeeded(type::Type* type);
ProgramBuilder builder_;
std::string error_;
uint32_t next_id_ = 1;
@ -504,7 +511,9 @@ class Builder {
std::unordered_map<std::string, uint32_t> import_name_to_id_;
std::unordered_map<Symbol, uint32_t> func_symbol_to_id_;
std::unordered_map<std::string, uint32_t> type_name_to_id_;
std::unordered_map<std::string, uint32_t> const_to_id_;
std::unordered_map<ScalarConstant, uint32_t> const_to_id_;
std::unordered_map<std::string, uint32_t> type_constructor_to_id_;
std::unordered_map<std::string, uint32_t> const_null_to_id_;
std::unordered_map<std::string, uint32_t>
texture_type_name_to_sampled_image_type_id_;
ScopeStack<uint32_t> scope_stack_;

View File

@ -0,0 +1,115 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_WRITER_SPIRV_SCALAR_CONSTANT_H_
#define SRC_WRITER_SPIRV_SCALAR_CONSTANT_H_
#include <stdint.h>
#include <cstring>
#include <functional>
namespace tint {
// Forward declarations
namespace semantic {
class Call;
} // namespace semantic
namespace writer {
namespace spirv {
/// ScalarConstant represents a scalar constant value
struct ScalarConstant {
/// The constant value
union Value {
/// The value as a bool
bool b;
/// The value as a uint32_t
uint32_t u32;
/// The value as a int32_t
int32_t i32;
/// The value as a float
float f32;
/// The value that is wide enough to encompass all other types (including
/// future 64-bit data types).
uint64_t u64;
};
/// The kind of constant
enum class Kind { kBool, kU32, kI32, kF32 };
/// Constructor
inline ScalarConstant() { value.u64 = 0; }
/// @param value the value of the constant
/// @returns a new ScalarConstant with the provided value and kind Kind::kU32
static inline ScalarConstant U32(uint32_t value) {
ScalarConstant c;
c.value.u32 = value;
c.kind = Kind::kU32;
return c;
}
/// Equality operator
/// @param rhs the ScalarConstant to compare against
/// @returns true if this ScalarConstant is equal to `rhs`
inline bool operator==(const ScalarConstant& rhs) const {
return value.u64 == rhs.value.u64 && kind == rhs.kind &&
is_spec_op == rhs.is_spec_op && constant_id == rhs.constant_id;
}
/// Inequality operator
/// @param rhs the ScalarConstant to compare against
/// @returns true if this ScalarConstant is not equal to `rhs`
inline bool operator!=(const ScalarConstant& rhs) const {
return !(*this == rhs);
}
/// The constant value
Value value;
/// The constant value kind
Kind kind = Kind::kBool;
/// True if the constant is a specialization op
bool is_spec_op = false;
/// The identifier if a specialization op
uint32_t constant_id = 0;
};
} // namespace spirv
} // namespace writer
} // namespace tint
namespace std {
/// Custom std::hash specialization for tint::Symbol so symbols can be used as
/// keys for std::unordered_map and std::unordered_set.
template <>
class hash<tint::writer::spirv::ScalarConstant> {
public:
/// @param c the ScalarConstant
/// @return the Symbol internal value
inline std::size_t operator()(
const tint::writer::spirv::ScalarConstant& c) const {
uint32_t value = 0;
std::memcpy(&value, &c.value, sizeof(value));
return (static_cast<std::size_t>(value) << 2) |
(static_cast<std::size_t>(c.kind) & 3);
}
};
} // namespace std
#endif // SRC_WRITER_SPIRV_SCALAR_CONSTANT_H_

View File

@ -0,0 +1,60 @@
// Copyright 2020 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/writer/spirv/scalar_constant.h"
#include "src/writer/spirv/test_helper.h"
namespace tint {
namespace writer {
namespace spirv {
namespace {
using SpirvScalarConstantTest = TestHelper;
TEST_F(SpirvScalarConstantTest, Equality) {
ScalarConstant a{};
ScalarConstant b{};
EXPECT_EQ(a, b);
a.kind = ScalarConstant::Kind::kU32;
EXPECT_NE(a, b);
b.kind = ScalarConstant::Kind::kU32;
EXPECT_EQ(a, b);
a.value.b = true;
EXPECT_NE(a, b);
b.value.b = true;
EXPECT_EQ(a, b);
a.is_spec_op = true;
EXPECT_NE(a, b);
b.is_spec_op = true;
EXPECT_EQ(a, b);
a.constant_id = 3;
EXPECT_NE(a, b);
b.constant_id = 3;
EXPECT_EQ(a, b);
}
TEST_F(SpirvScalarConstantTest, U32) {
auto c = ScalarConstant::U32(123);
EXPECT_EQ(c.value.u32, 123u);
EXPECT_EQ(c.kind, ScalarConstant::Kind::kU32);
}
} // namespace
} // namespace spirv
} // namespace writer
} // namespace tint

View File

@ -336,6 +336,7 @@ source_set("tint_unittests_spv_writer_src") {
"../src/writer/spirv/builder_unary_op_expression_test.cc",
"../src/writer/spirv/instruction_test.cc",
"../src/writer/spirv/operand_test.cc",
"../src/writer/spirv/scalar_constant_test.cc",
"../src/writer/spirv/spv_dump.cc",
"../src/writer/spirv/spv_dump.h",
"../src/writer/spirv/test_helper.h",