tint/utils: Make Hashmap::Find() safer to use
Don't return a raw pointer to the map entry's value, instead return a new Reference which re-looks up the entry if the map is mutated. Change-Id: I031749785faeac98e2a129a776493cb0371a5cb9 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110540 Reviewed-by: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
597ad53029
commit
7c6e229a18
|
@ -73,7 +73,7 @@ const tint::Cloneable* CloneContext::CloneCloneable(const Cloneable* object) {
|
|||
}
|
||||
|
||||
// Was Replace() called for this object?
|
||||
if (auto* fn = replacements_.Find(object)) {
|
||||
if (auto fn = replacements_.Find(object)) {
|
||||
return (*fn)();
|
||||
}
|
||||
|
||||
|
|
|
@ -208,7 +208,7 @@ class CloneContext {
|
|||
to.Push(CheckedCast<T>(builder()));
|
||||
}
|
||||
for (auto& el : from) {
|
||||
if (auto* insert_before = transforms->insert_before_.Find(el)) {
|
||||
if (auto insert_before = transforms->insert_before_.Find(el)) {
|
||||
for (auto& builder : *insert_before) {
|
||||
to.Push(CheckedCast<T>(builder()));
|
||||
}
|
||||
|
@ -216,7 +216,7 @@ class CloneContext {
|
|||
if (!transforms->remove_.Contains(el)) {
|
||||
to.Push(Clone(el));
|
||||
}
|
||||
if (auto* insert_after = transforms->insert_after_.Find(el)) {
|
||||
if (auto insert_after = transforms->insert_after_.Find(el)) {
|
||||
for (auto& builder : *insert_after) {
|
||||
to.Push(CheckedCast<T>(builder()));
|
||||
}
|
||||
|
@ -232,7 +232,7 @@ class CloneContext {
|
|||
// Clone(el) may have updated the transformation list, adding an `insert_after`
|
||||
// transform for `from`.
|
||||
if (transforms) {
|
||||
if (auto* insert_after = transforms->insert_after_.Find(el)) {
|
||||
if (auto insert_after = transforms->insert_after_.Find(el)) {
|
||||
for (auto& builder : *insert_after) {
|
||||
to.Push(CheckedCast<T>(builder()));
|
||||
}
|
||||
|
@ -389,7 +389,7 @@ class CloneContext {
|
|||
return *this;
|
||||
}
|
||||
|
||||
list_transforms_.Edit(&vector).remove_.Add(object);
|
||||
list_transforms_.GetOrZero(&vector)->remove_.Add(object);
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -411,7 +411,7 @@ class CloneContext {
|
|||
/// @returns this CloneContext so calls can be chained
|
||||
template <typename T, size_t N, typename BUILDER>
|
||||
CloneContext& InsertFront(const utils::Vector<T, N>& vector, BUILDER&& builder) {
|
||||
list_transforms_.Edit(&vector).insert_front_.Push(std::forward<BUILDER>(builder));
|
||||
list_transforms_.GetOrZero(&vector)->insert_front_.Push(std::forward<BUILDER>(builder));
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -434,7 +434,7 @@ class CloneContext {
|
|||
/// @returns this CloneContext so calls can be chained
|
||||
template <typename T, size_t N, typename BUILDER>
|
||||
CloneContext& InsertBack(const utils::Vector<T, N>& vector, BUILDER&& builder) {
|
||||
list_transforms_.Edit(&vector).insert_back_.Push(std::forward<BUILDER>(builder));
|
||||
list_transforms_.GetOrZero(&vector)->insert_back_.Push(std::forward<BUILDER>(builder));
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -456,7 +456,7 @@ class CloneContext {
|
|||
return *this;
|
||||
}
|
||||
|
||||
list_transforms_.Edit(&vector).insert_before_.GetOrZero(before).Push(
|
||||
list_transforms_.GetOrZero(&vector)->insert_before_.GetOrZero(before)->Push(
|
||||
[object] { return object; });
|
||||
return *this;
|
||||
}
|
||||
|
@ -475,7 +475,7 @@ class CloneContext {
|
|||
CloneContext& InsertBefore(const utils::Vector<T, N>& vector,
|
||||
const BEFORE* before,
|
||||
BUILDER&& builder) {
|
||||
list_transforms_.Edit(&vector).insert_before_.GetOrZero(before).Push(
|
||||
list_transforms_.GetOrZero(&vector)->insert_before_.GetOrZero(before)->Push(
|
||||
std::forward<BUILDER>(builder));
|
||||
return *this;
|
||||
}
|
||||
|
@ -498,7 +498,7 @@ class CloneContext {
|
|||
return *this;
|
||||
}
|
||||
|
||||
list_transforms_.Edit(&vector).insert_after_.GetOrZero(after).Push(
|
||||
list_transforms_.GetOrZero(&vector)->insert_after_.GetOrZero(after)->Push(
|
||||
[object] { return object; });
|
||||
return *this;
|
||||
}
|
||||
|
@ -517,7 +517,7 @@ class CloneContext {
|
|||
CloneContext& InsertAfter(const utils::Vector<T, N>& vector,
|
||||
const AFTER* after,
|
||||
BUILDER&& builder) {
|
||||
list_transforms_.Edit(&vector).insert_after_.GetOrZero(after).Push(
|
||||
list_transforms_.GetOrZero(&vector)->insert_after_.GetOrZero(after)->Push(
|
||||
std::forward<BUILDER>(builder));
|
||||
return *this;
|
||||
}
|
||||
|
@ -601,61 +601,6 @@ class CloneContext {
|
|||
/// @returns the diagnostic list of #dst
|
||||
diag::List& Diagnostics() const;
|
||||
|
||||
/// VectorListTransforms is a map of utils::Vector pointer to transforms for that list
|
||||
struct VectorListTransforms {
|
||||
using Map = utils::Hashmap<const void*, ListTransforms, 4>;
|
||||
|
||||
/// An accessor to the VectorListTransforms map.
|
||||
/// Index caches the last map lookup, and will only re-search the map if the transform map
|
||||
/// was modified since the last lookup.
|
||||
struct Index {
|
||||
/// @returns true if the map now holds a value for the index
|
||||
operator bool() {
|
||||
Update();
|
||||
return cached_;
|
||||
}
|
||||
|
||||
/// @returns a pointer to the indexed map entry
|
||||
const ListTransforms* operator->() {
|
||||
Update();
|
||||
return cached_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend VectorListTransforms;
|
||||
|
||||
Index(const void* list, Map* map)
|
||||
: list_(list),
|
||||
map_(map),
|
||||
generation_(map->Generation()),
|
||||
cached_(map_->Find(list)) {}
|
||||
|
||||
void Update() {
|
||||
if (map_->Generation() != generation_) {
|
||||
cached_ = map_->Find(list_);
|
||||
generation_ = map_->Generation();
|
||||
}
|
||||
}
|
||||
|
||||
const void* list_;
|
||||
Map* map_;
|
||||
uint64_t generation_;
|
||||
const ListTransforms* cached_;
|
||||
};
|
||||
|
||||
/// Edit returns a reference to the ListTransforms for the given vector pointer and
|
||||
/// increments #list_transform_generation_ signalling that the list transforms have been
|
||||
/// modified.
|
||||
inline ListTransforms& Edit(const void* list) { return map_.GetOrZero(list); }
|
||||
|
||||
/// @returns an Index to the transforms for the given list.
|
||||
inline Index Find(const void* list) { return Index{list, &map_}; }
|
||||
|
||||
private:
|
||||
/// The map of vector pointer to ListTransforms
|
||||
Map map_;
|
||||
};
|
||||
|
||||
/// A map of object in #src to functions that create their replacement in #dst
|
||||
utils::Hashmap<const Cloneable*, std::function<const Cloneable*()>, 8> replacements_;
|
||||
|
||||
|
@ -666,7 +611,7 @@ class CloneContext {
|
|||
utils::Vector<CloneableTransform, 8> transforms_;
|
||||
|
||||
/// Transformations to apply to vectors
|
||||
VectorListTransforms list_transforms_;
|
||||
utils::Hashmap<const void*, ListTransforms, 4> list_transforms_;
|
||||
|
||||
/// Symbol transform registered with ReplaceAll()
|
||||
SymbolTransform symbol_transform_;
|
||||
|
|
|
@ -3522,7 +3522,7 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
|
|||
const auto phi_id = assignment.phi_id;
|
||||
auto* const lhs_expr = builder_.Expr(namer_.Name(phi_id));
|
||||
// If RHS value is actually a phi we just cpatured, then use it.
|
||||
auto* const copy_sym = copied_phis.Find(assignment.value_id);
|
||||
auto copy_sym = copied_phis.Find(assignment.value_id);
|
||||
auto* const rhs_expr =
|
||||
copy_sym ? builder_.Expr(*copy_sym) : MakeExpression(assignment.value_id).expr;
|
||||
AddStatement(builder_.Assign(lhs_expr, rhs_expr));
|
||||
|
|
|
@ -666,7 +666,7 @@ class ParserImpl : Reader {
|
|||
/// @param id a SPIR-V ID
|
||||
/// @returns the AST variable or null.
|
||||
const ast::Var* GetModuleVariable(uint32_t id) {
|
||||
auto* entry = module_variable_.Find(id);
|
||||
auto entry = module_variable_.Find(id);
|
||||
return entry ? *entry : nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -471,7 +471,7 @@ class DependencyScanner {
|
|||
}
|
||||
}
|
||||
|
||||
if (auto* global = globals_.Find(to); global && (*global)->node == resolved) {
|
||||
if (auto global = globals_.Find(to); global && (*global)->node == resolved) {
|
||||
if (dependency_edges_.Add(DependencyEdge{current_global_, *global},
|
||||
DependencyInfo{from->source, action})) {
|
||||
current_global_->deps.Push(*global);
|
||||
|
|
|
@ -1128,8 +1128,8 @@ TEST_P(ResolverDependencyGraphResolvedSymbolTest, Test) {
|
|||
|
||||
if (expect_pass) {
|
||||
// Check that the use resolves to the declaration
|
||||
auto* resolved_symbol = graph.resolved_symbols.Find(use);
|
||||
ASSERT_NE(resolved_symbol, nullptr);
|
||||
auto resolved_symbol = graph.resolved_symbols.Find(use);
|
||||
ASSERT_TRUE(resolved_symbol);
|
||||
EXPECT_EQ(*resolved_symbol, decl)
|
||||
<< "resolved: " << (*resolved_symbol ? (*resolved_symbol)->TypeInfo().name : "<null>")
|
||||
<< "\n"
|
||||
|
@ -1179,8 +1179,8 @@ TEST_P(ResolverDependencyShadowTest, Test) {
|
|||
helper.Build();
|
||||
|
||||
auto shadows = Build().shadows;
|
||||
auto* shadow = shadows.Find(inner_var);
|
||||
ASSERT_NE(shadow, nullptr);
|
||||
auto shadow = shadows.Find(inner_var);
|
||||
ASSERT_TRUE(shadow);
|
||||
EXPECT_EQ(*shadow, outer);
|
||||
}
|
||||
|
||||
|
@ -1310,8 +1310,8 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
|
|||
|
||||
auto graph = Build();
|
||||
for (auto use : symbol_uses) {
|
||||
auto* resolved_symbol = graph.resolved_symbols.Find(use.use);
|
||||
ASSERT_NE(resolved_symbol, nullptr) << use.where;
|
||||
auto resolved_symbol = graph.resolved_symbols.Find(use.use);
|
||||
ASSERT_TRUE(resolved_symbol) << use.where;
|
||||
EXPECT_EQ(*resolved_symbol, use.decl) << use.where;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2481,7 +2481,7 @@ sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
|
|||
if (loop_block->FirstContinue()) {
|
||||
// If our identifier is in loop_block->decls, make sure its index is
|
||||
// less than first_continue
|
||||
if (auto* decl = loop_block->Decls().Find(symbol)) {
|
||||
if (auto decl = loop_block->Decls().Find(symbol)) {
|
||||
if (decl->order >= loop_block->NumDeclsAtFirstContinue()) {
|
||||
AddError("continue statement bypasses declaration of '" +
|
||||
builder_->Symbols().NameFor(symbol) + "'",
|
||||
|
|
|
@ -54,7 +54,7 @@ class SemHelper {
|
|||
/// @param node the node to retrieve
|
||||
template <typename SEM = sem::Node>
|
||||
SEM* ResolvedSymbol(const ast::Node* node) const {
|
||||
auto* resolved = dependencies_.resolved_symbols.Find(node);
|
||||
auto resolved = dependencies_.resolved_symbols.Find(node);
|
||||
return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(*resolved)) : nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -600,7 +600,7 @@ class UniformityGraph {
|
|||
}
|
||||
|
||||
// Add an edge from the variable's loop input node to its value at this point.
|
||||
auto** in_node = info.var_in_nodes.Find(var);
|
||||
auto in_node = info.var_in_nodes.Find(var);
|
||||
TINT_ASSERT(Resolver, in_node != nullptr);
|
||||
auto* out_node = current_function_->variables.Get(var);
|
||||
if (out_node != *in_node) {
|
||||
|
@ -1334,7 +1334,7 @@ class UniformityGraph {
|
|||
[&](const sem::Function* func) {
|
||||
// We must have already analyzed the user-defined function since we process
|
||||
// functions in dependency order.
|
||||
auto* info = functions_.Find(func->Declaration());
|
||||
auto info = functions_.Find(func->Declaration());
|
||||
TINT_ASSERT(Resolver, info != nullptr);
|
||||
callsite_tag = info->callsite_tag;
|
||||
function_tag = info->function_tag;
|
||||
|
@ -1466,7 +1466,7 @@ class UniformityGraph {
|
|||
} else if (auto* user = target->As<sem::Function>()) {
|
||||
// This is a call to a user-defined function, so inspect the functions called by that
|
||||
// function and look for one whose node has an edge from the RequiredToBeUniform node.
|
||||
auto* target_info = functions_.Find(user->Declaration());
|
||||
auto target_info = functions_.Find(user->Declaration());
|
||||
for (auto* call_node : target_info->required_to_be_uniform->edges) {
|
||||
if (call_node->type == Node::kRegular) {
|
||||
auto* child_call = call_node->ast->As<ast::CallExpression>();
|
||||
|
@ -1636,7 +1636,7 @@ class UniformityGraph {
|
|||
// If this is a call to a user-defined function, add a note to show the reason that the
|
||||
// parameter is required to be uniform.
|
||||
if (auto* user = target->As<sem::Function>()) {
|
||||
auto* next_function = functions_.Find(user->Declaration());
|
||||
auto next_function = functions_.Find(user->Declaration());
|
||||
Node* next_cause = next_function->parameters[cause->arg_index].init_value;
|
||||
MakeError(*next_function, next_cause, true);
|
||||
}
|
||||
|
|
|
@ -719,7 +719,7 @@ bool Validator::AtomicVariable(
|
|||
return false;
|
||||
}
|
||||
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
|
||||
if (auto* found = atomic_composite_info.Find(type)) {
|
||||
if (auto found = atomic_composite_info.Find(type)) {
|
||||
if (address_space != ast::AddressSpace::kStorage &&
|
||||
address_space != ast::AddressSpace::kWorkgroup) {
|
||||
AddError("atomic variables must have <storage> or <workgroup> address space",
|
||||
|
@ -798,7 +798,7 @@ bool Validator::Override(
|
|||
for (auto* attr : decl->attributes) {
|
||||
if (attr->Is<ast::IdAttribute>()) {
|
||||
auto id = v->OverrideId();
|
||||
if (auto* var = override_ids.Find(id); var && *var != v) {
|
||||
if (auto var = override_ids.Find(id); var && *var != v) {
|
||||
AddError("@id values must be unique", attr->source);
|
||||
AddNote(
|
||||
"a override with an ID of " + std::to_string(id.value) +
|
||||
|
|
|
@ -44,7 +44,7 @@ class ScopeStack {
|
|||
/// stack, otherwise the zero initializer for type T.
|
||||
V Set(const K& key, V val) {
|
||||
auto& back = stack_.Back();
|
||||
if (auto* el = back.Find(key)) {
|
||||
if (auto el = back.Find(key)) {
|
||||
std::swap(val, *el);
|
||||
return val;
|
||||
}
|
||||
|
@ -57,7 +57,7 @@ class ScopeStack {
|
|||
/// @returns the value, or the zero initializer if the value was not found
|
||||
V Get(const K& key) const {
|
||||
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
|
||||
if (auto* val = iter->Find(key)) {
|
||||
if (auto val = iter->Find(key)) {
|
||||
return *val;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -129,7 +129,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
|
|||
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
|
||||
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
|
||||
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
|
||||
if (auto* info = decomposed.Find(access->Member()->Declaration())) {
|
||||
if (auto info = decomposed.Find(access->Member()->Declaration())) {
|
||||
auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
|
||||
auto name =
|
||||
b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
|
||||
|
@ -168,7 +168,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
|
|||
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
|
||||
ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
|
||||
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr)) {
|
||||
if (auto* info = decomposed.Find(access->Member()->Declaration())) {
|
||||
if (auto info = decomposed.Find(access->Member()->Declaration())) {
|
||||
auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
|
||||
auto name =
|
||||
b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +
|
||||
|
|
|
@ -140,7 +140,7 @@ struct SimplifyPointers::State {
|
|||
// variable identifier.
|
||||
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
|
||||
// Look to see if we need to swap this Expression with a saved variable.
|
||||
if (auto* saved_var = saved_vars.Find(expr)) {
|
||||
if (auto saved_var = saved_vars.Find(expr)) {
|
||||
return ctx.dst->Expr(*saved_var);
|
||||
}
|
||||
|
||||
|
|
|
@ -401,7 +401,7 @@ struct Std140::State {
|
|||
return Switch(
|
||||
ty, //
|
||||
[&](const sem::Struct* str) -> const ast::Type* {
|
||||
if (auto* std140 = std140_structs.Find(str)) {
|
||||
if (auto std140 = std140_structs.Find(str)) {
|
||||
return b.create<ast::TypeName>(*std140);
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -695,7 +695,7 @@ struct Std140::State {
|
|||
// call, or by reassembling a std140 matrix from column vector members.
|
||||
utils::Vector<const ast::Expression*, 8> args;
|
||||
for (auto* member : str->Members()) {
|
||||
if (auto* col_members = std140_mat_members.Find(member)) {
|
||||
if (auto col_members = std140_mat_members.Find(member)) {
|
||||
// std140 decomposed matrix. Reassemble.
|
||||
auto* mat_ty = CreateASTTypeFor(ctx, member->Type());
|
||||
auto mat_args =
|
||||
|
|
|
@ -161,7 +161,7 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
|
|||
ctx.ReplaceAll(
|
||||
[&](const ast::ReturnStatement* return_statement) -> const ast::ReturnStatement* {
|
||||
auto* return_sem = sem.Get(return_statement);
|
||||
if (auto* mapping_fn_sym =
|
||||
if (auto mapping_fn_sym =
|
||||
entry_point_functions_to_truncate_functions.Find(return_sem->Function())) {
|
||||
return b.Return(return_statement->source,
|
||||
b.Call(*mapping_fn_sym, ctx.Clone(return_statement->value)));
|
||||
|
|
|
@ -97,7 +97,7 @@ struct Unshadow::State {
|
|||
ctx.ReplaceAll(
|
||||
[&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* {
|
||||
if (auto* user = sem.Get<sem::VariableUser>(ident)) {
|
||||
if (auto* renamed = renamed_to.Find(user->Variable())) {
|
||||
if (auto renamed = renamed_to.Find(user->Variable())) {
|
||||
return b.Expr(*renamed);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -135,7 +135,7 @@ struct HoistToDeclBefore::State {
|
|||
/// automatically called.
|
||||
/// @warning the returned reference is invalid if this is called a second time, or the
|
||||
/// #for_loops map is mutated.
|
||||
LoopInfo& ForLoop(const sem::ForLoopStatement* for_loop) {
|
||||
auto ForLoop(const sem::ForLoopStatement* for_loop) {
|
||||
if (for_loops.IsEmpty()) {
|
||||
RegisterForLoopTransform();
|
||||
}
|
||||
|
@ -147,7 +147,7 @@ struct HoistToDeclBefore::State {
|
|||
/// automatically called.
|
||||
/// @warning the returned reference is invalid if this is called a second time, or the
|
||||
/// #for_loops map is mutated.
|
||||
LoopInfo& WhileLoop(const sem::WhileStatement* while_loop) {
|
||||
auto WhileLoop(const sem::WhileStatement* while_loop) {
|
||||
if (while_loops.IsEmpty()) {
|
||||
RegisterWhileLoopTransform();
|
||||
}
|
||||
|
@ -159,7 +159,7 @@ struct HoistToDeclBefore::State {
|
|||
/// automatically called.
|
||||
/// @warning the returned reference is invalid if this is called a second time, or the
|
||||
/// #else_ifs map is mutated.
|
||||
ElseIfInfo& ElseIf(const ast::IfStatement* else_if) {
|
||||
auto ElseIf(const ast::IfStatement* else_if) {
|
||||
if (else_ifs.IsEmpty()) {
|
||||
RegisterElseIfTransform();
|
||||
}
|
||||
|
@ -172,7 +172,7 @@ struct HoistToDeclBefore::State {
|
|||
auto& sem = ctx.src->Sem();
|
||||
|
||||
if (auto* fl = sem.Get(stmt)) {
|
||||
if (auto* info = for_loops.Find(fl)) {
|
||||
if (auto info = for_loops.Find(fl)) {
|
||||
auto* for_loop = fl->Declaration();
|
||||
// For-loop needs to be decomposed to a loop.
|
||||
// Build the loop body's statements.
|
||||
|
@ -222,7 +222,7 @@ struct HoistToDeclBefore::State {
|
|||
auto& sem = ctx.src->Sem();
|
||||
|
||||
if (auto* w = sem.Get(stmt)) {
|
||||
if (auto* info = while_loops.Find(w)) {
|
||||
if (auto info = while_loops.Find(w)) {
|
||||
auto* while_loop = w->Declaration();
|
||||
// While needs to be decomposed to a loop.
|
||||
// Build the loop body's statements.
|
||||
|
@ -259,7 +259,7 @@ struct HoistToDeclBefore::State {
|
|||
void RegisterElseIfTransform() const {
|
||||
// Decompose 'else-if' statements into 'else { if }' blocks.
|
||||
ctx.ReplaceAll([&](const ast::IfStatement* stmt) -> const ast::Statement* {
|
||||
if (auto* info = else_ifs.Find(stmt)) {
|
||||
if (auto info = else_ifs.Find(stmt)) {
|
||||
// Build the else block's body statements, starting with let decls for the
|
||||
// conditional expression.
|
||||
auto body_stmts = Build(info->cond_decls);
|
||||
|
@ -291,10 +291,10 @@ struct HoistToDeclBefore::State {
|
|||
if (else_if && else_if->Parent()->Is<sem::IfStatement>()) {
|
||||
// Insertion point is an 'else if' condition.
|
||||
// Need to convert 'else if' to 'else { if }'.
|
||||
auto& else_if_info = ElseIf(else_if->Declaration());
|
||||
auto else_if_info = ElseIf(else_if->Declaration());
|
||||
|
||||
// Index the map to convert this else if, even if `stmt` is nullptr.
|
||||
auto& decls = else_if_info.cond_decls;
|
||||
auto& decls = else_if_info->cond_decls;
|
||||
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
|
||||
decls.Push(std::forward<BUILDER>(builder));
|
||||
}
|
||||
|
@ -306,7 +306,7 @@ struct HoistToDeclBefore::State {
|
|||
// For-loop needs to be decomposed to a loop.
|
||||
|
||||
// Index the map to convert this for-loop, even if `stmt` is nullptr.
|
||||
auto& decls = ForLoop(fl).cond_decls;
|
||||
auto& decls = ForLoop(fl)->cond_decls;
|
||||
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
|
||||
decls.Push(std::forward<BUILDER>(builder));
|
||||
}
|
||||
|
@ -318,7 +318,7 @@ struct HoistToDeclBefore::State {
|
|||
// While needs to be decomposed to a loop.
|
||||
|
||||
// Index the map to convert this while, even if `stmt` is nullptr.
|
||||
auto& decls = WhileLoop(w).cond_decls;
|
||||
auto& decls = WhileLoop(w)->cond_decls;
|
||||
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
|
||||
decls.Push(std::forward<BUILDER>(builder));
|
||||
}
|
||||
|
@ -354,7 +354,7 @@ struct HoistToDeclBefore::State {
|
|||
// For-loop needs to be decomposed to a loop.
|
||||
|
||||
// Index the map to convert this for-loop, even if `stmt` is nullptr.
|
||||
auto& decls = ForLoop(fl).cont_decls;
|
||||
auto& decls = ForLoop(fl)->cont_decls;
|
||||
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
|
||||
decls.Push(std::forward<BUILDER>(builder));
|
||||
}
|
||||
|
|
|
@ -47,6 +47,67 @@ class Hashmap : public HashmapBase<KEY, VALUE, N, HASH, EQUAL> {
|
|||
/// Result of Add()
|
||||
using AddResult = typename Base::PutResult;
|
||||
|
||||
/// Reference is returned by Hashmap::Find(), and performs dynamic Hashmap lookups.
|
||||
/// The value returned by the Reference reflects the current state of the Hashmap, and so the
|
||||
/// referenced value may change, or transition between valid or invalid based on the current
|
||||
/// state of the Hashmap.
|
||||
template <bool IS_CONST>
|
||||
class ReferenceT {
|
||||
/// `const Value` if IS_CONST, or `Value` if !IS_CONST
|
||||
using T = std::conditional_t<IS_CONST, const Value, Value>;
|
||||
|
||||
/// `const Hashmap` if IS_CONST, or `Hashmap` if !IS_CONST
|
||||
using Map = std::conditional_t<IS_CONST, const Hashmap, Hashmap>;
|
||||
|
||||
public:
|
||||
/// @returns true if the reference is valid.
|
||||
operator bool() const { return Get() != nullptr; }
|
||||
|
||||
/// @returns the pointer to the Value, or nullptr if the reference is invalid.
|
||||
operator T*() const { return Get(); }
|
||||
|
||||
/// @returns the pointer to the Value
|
||||
/// @warning if the Hashmap does not contain a value for the reference, then this will
|
||||
/// trigger a TINT_ASSERT, or invalid pointer dereference.
|
||||
T* operator->() const {
|
||||
auto* hashmap_reference_lookup = Get();
|
||||
TINT_ASSERT(Utils, hashmap_reference_lookup != nullptr);
|
||||
return hashmap_reference_lookup;
|
||||
}
|
||||
|
||||
/// @returns the pointer to the Value, or nullptr if the reference is invalid.
|
||||
T* Get() const {
|
||||
auto generation = map_.Generation();
|
||||
if (generation_ != generation) {
|
||||
cached_ = map_.Lookup(key_);
|
||||
generation_ = generation;
|
||||
}
|
||||
return cached_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend Hashmap;
|
||||
|
||||
/// Constructor
|
||||
ReferenceT(Map& map, const Key& key)
|
||||
: map_(map), key_(key), cached_(nullptr), generation_(map.Generation() - 1) {}
|
||||
|
||||
/// Constructor
|
||||
ReferenceT(Map& map, const Key& key, T* value)
|
||||
: map_(map), key_(key), cached_(value), generation_(map.Generation()) {}
|
||||
|
||||
Map& map_;
|
||||
const Key key_;
|
||||
mutable T* cached_ = nullptr;
|
||||
mutable size_t generation_ = 0;
|
||||
};
|
||||
|
||||
/// A mutable reference returned by Find()
|
||||
using Reference = ReferenceT</*IS_CONST*/ false>;
|
||||
|
||||
/// An immutable reference returned by Find()
|
||||
using ConstReference = ReferenceT</*IS_CONST*/ true>;
|
||||
|
||||
/// Adds a value to the map, if the map does not already contain an entry with the key @p key.
|
||||
/// @param key the entry key.
|
||||
/// @param value the value of the entry to add to the map.
|
||||
|
@ -108,25 +169,28 @@ class Hashmap : public HashmapBase<KEY, VALUE, N, HASH, EQUAL> {
|
|||
/// @param key the entry's key value to search for.
|
||||
/// @returns the value of the entry.
|
||||
template <typename K>
|
||||
Value& GetOrZero(K&& key) {
|
||||
Reference GetOrZero(K&& key) {
|
||||
auto res = Add(std::forward<K>(key), Value{});
|
||||
return *res.value;
|
||||
return Reference(*this, key, res.value);
|
||||
}
|
||||
|
||||
/// @param key the key to search for.
|
||||
/// @returns a pointer to the entry that is equal to the given value, or nullptr if the map does
|
||||
/// not contain the given value.
|
||||
const Value* Find(const Key& key) const {
|
||||
/// @returns a reference to the entry that is equal to the given value.
|
||||
Reference Find(const Key& key) { return Reference(*this, key); }
|
||||
|
||||
/// @param key the key to search for.
|
||||
/// @returns a reference to the entry that is equal to the given value.
|
||||
ConstReference Find(const Key& key) const { return ConstReference(*this, key); }
|
||||
|
||||
private:
|
||||
Value* Lookup(const Key& key) {
|
||||
if (auto [found, index] = this->IndexOf(key); found) {
|
||||
return &this->slots_[index].entry->value;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// @param key the key to search for.
|
||||
/// @returns a pointer to the entry that is equal to the given value, or nullptr if the map does
|
||||
/// not contain the given value.
|
||||
Value* Find(const Key& key) {
|
||||
const Value* Lookup(const Key& key) const {
|
||||
if (auto [found, index] = this->IndexOf(key); found) {
|
||||
return &this->slots_[index].entry->value;
|
||||
}
|
||||
|
|
|
@ -90,6 +90,71 @@ TEST(Hashmap, Generation) {
|
|||
EXPECT_EQ(map.Generation(), 5u);
|
||||
}
|
||||
|
||||
TEST(Hashmap, Index) {
|
||||
Hashmap<int, std::string, 4> map;
|
||||
auto zero = map.Find(0);
|
||||
EXPECT_FALSE(zero);
|
||||
|
||||
map.Add(3, "three");
|
||||
auto three = map.Find(3);
|
||||
map.Add(2, "two");
|
||||
auto two = map.Find(2);
|
||||
map.Add(4, "four");
|
||||
auto four = map.Find(4);
|
||||
map.Add(8, "eight");
|
||||
auto eight = map.Find(8);
|
||||
|
||||
EXPECT_FALSE(zero);
|
||||
ASSERT_TRUE(three);
|
||||
ASSERT_TRUE(two);
|
||||
ASSERT_TRUE(four);
|
||||
ASSERT_TRUE(eight);
|
||||
|
||||
EXPECT_EQ(*three, "three");
|
||||
EXPECT_EQ(*two, "two");
|
||||
EXPECT_EQ(*four, "four");
|
||||
EXPECT_EQ(*eight, "eight");
|
||||
|
||||
map.Add(0, "zero"); // Note: Find called before Add() is okay!
|
||||
|
||||
map.Add(5, "five");
|
||||
auto five = map.Find(5);
|
||||
map.Add(6, "six");
|
||||
auto six = map.Find(6);
|
||||
map.Add(1, "one");
|
||||
auto one = map.Find(1);
|
||||
map.Add(7, "seven");
|
||||
auto seven = map.Find(7);
|
||||
|
||||
ASSERT_TRUE(zero);
|
||||
ASSERT_TRUE(three);
|
||||
ASSERT_TRUE(two);
|
||||
ASSERT_TRUE(four);
|
||||
ASSERT_TRUE(eight);
|
||||
ASSERT_TRUE(five);
|
||||
ASSERT_TRUE(six);
|
||||
ASSERT_TRUE(one);
|
||||
ASSERT_TRUE(seven);
|
||||
|
||||
EXPECT_EQ(*zero, "zero");
|
||||
EXPECT_EQ(*three, "three");
|
||||
EXPECT_EQ(*two, "two");
|
||||
EXPECT_EQ(*four, "four");
|
||||
EXPECT_EQ(*eight, "eight");
|
||||
EXPECT_EQ(*five, "five");
|
||||
EXPECT_EQ(*six, "six");
|
||||
EXPECT_EQ(*one, "one");
|
||||
EXPECT_EQ(*seven, "seven");
|
||||
|
||||
map.Remove(2);
|
||||
map.Remove(8);
|
||||
map.Remove(1);
|
||||
|
||||
EXPECT_FALSE(two);
|
||||
EXPECT_FALSE(eight);
|
||||
EXPECT_FALSE(one);
|
||||
}
|
||||
|
||||
TEST(Hashmap, Iterator) {
|
||||
using Map = Hashmap<int, std::string, 8>;
|
||||
using Entry = typename Map::Entry;
|
||||
|
|
Loading…
Reference in New Issue