tint/utils: Make Hashmap::Find() safer to use

Don't return a raw pointer to the map entry's value, instead return a new Reference which re-looks up the entry if the map is mutated.

Change-Id: I031749785faeac98e2a129a776493cb0371a5cb9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110540
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-11-23 21:04:25 +00:00 committed by Dawn LUCI CQ
parent 597ad53029
commit 7c6e229a18
19 changed files with 187 additions and 113 deletions

View File

@ -73,7 +73,7 @@ const tint::Cloneable* CloneContext::CloneCloneable(const Cloneable* object) {
}
// Was Replace() called for this object?
if (auto* fn = replacements_.Find(object)) {
if (auto fn = replacements_.Find(object)) {
return (*fn)();
}

View File

@ -208,7 +208,7 @@ class CloneContext {
to.Push(CheckedCast<T>(builder()));
}
for (auto& el : from) {
if (auto* insert_before = transforms->insert_before_.Find(el)) {
if (auto insert_before = transforms->insert_before_.Find(el)) {
for (auto& builder : *insert_before) {
to.Push(CheckedCast<T>(builder()));
}
@ -216,7 +216,7 @@ class CloneContext {
if (!transforms->remove_.Contains(el)) {
to.Push(Clone(el));
}
if (auto* insert_after = transforms->insert_after_.Find(el)) {
if (auto insert_after = transforms->insert_after_.Find(el)) {
for (auto& builder : *insert_after) {
to.Push(CheckedCast<T>(builder()));
}
@ -232,7 +232,7 @@ class CloneContext {
// Clone(el) may have updated the transformation list, adding an `insert_after`
// transform for `from`.
if (transforms) {
if (auto* insert_after = transforms->insert_after_.Find(el)) {
if (auto insert_after = transforms->insert_after_.Find(el)) {
for (auto& builder : *insert_after) {
to.Push(CheckedCast<T>(builder()));
}
@ -389,7 +389,7 @@ class CloneContext {
return *this;
}
list_transforms_.Edit(&vector).remove_.Add(object);
list_transforms_.GetOrZero(&vector)->remove_.Add(object);
return *this;
}
@ -411,7 +411,7 @@ class CloneContext {
/// @returns this CloneContext so calls can be chained
template <typename T, size_t N, typename BUILDER>
CloneContext& InsertFront(const utils::Vector<T, N>& vector, BUILDER&& builder) {
list_transforms_.Edit(&vector).insert_front_.Push(std::forward<BUILDER>(builder));
list_transforms_.GetOrZero(&vector)->insert_front_.Push(std::forward<BUILDER>(builder));
return *this;
}
@ -434,7 +434,7 @@ class CloneContext {
/// @returns this CloneContext so calls can be chained
template <typename T, size_t N, typename BUILDER>
CloneContext& InsertBack(const utils::Vector<T, N>& vector, BUILDER&& builder) {
list_transforms_.Edit(&vector).insert_back_.Push(std::forward<BUILDER>(builder));
list_transforms_.GetOrZero(&vector)->insert_back_.Push(std::forward<BUILDER>(builder));
return *this;
}
@ -456,7 +456,7 @@ class CloneContext {
return *this;
}
list_transforms_.Edit(&vector).insert_before_.GetOrZero(before).Push(
list_transforms_.GetOrZero(&vector)->insert_before_.GetOrZero(before)->Push(
[object] { return object; });
return *this;
}
@ -475,7 +475,7 @@ class CloneContext {
CloneContext& InsertBefore(const utils::Vector<T, N>& vector,
const BEFORE* before,
BUILDER&& builder) {
list_transforms_.Edit(&vector).insert_before_.GetOrZero(before).Push(
list_transforms_.GetOrZero(&vector)->insert_before_.GetOrZero(before)->Push(
std::forward<BUILDER>(builder));
return *this;
}
@ -498,7 +498,7 @@ class CloneContext {
return *this;
}
list_transforms_.Edit(&vector).insert_after_.GetOrZero(after).Push(
list_transforms_.GetOrZero(&vector)->insert_after_.GetOrZero(after)->Push(
[object] { return object; });
return *this;
}
@ -517,7 +517,7 @@ class CloneContext {
CloneContext& InsertAfter(const utils::Vector<T, N>& vector,
const AFTER* after,
BUILDER&& builder) {
list_transforms_.Edit(&vector).insert_after_.GetOrZero(after).Push(
list_transforms_.GetOrZero(&vector)->insert_after_.GetOrZero(after)->Push(
std::forward<BUILDER>(builder));
return *this;
}
@ -601,61 +601,6 @@ class CloneContext {
/// @returns the diagnostic list of #dst
diag::List& Diagnostics() const;
/// VectorListTransforms is a map of utils::Vector pointer to transforms for that list
struct VectorListTransforms {
using Map = utils::Hashmap<const void*, ListTransforms, 4>;
/// An accessor to the VectorListTransforms map.
/// Index caches the last map lookup, and will only re-search the map if the transform map
/// was modified since the last lookup.
struct Index {
/// @returns true if the map now holds a value for the index
operator bool() {
Update();
return cached_;
}
/// @returns a pointer to the indexed map entry
const ListTransforms* operator->() {
Update();
return cached_;
}
private:
friend VectorListTransforms;
Index(const void* list, Map* map)
: list_(list),
map_(map),
generation_(map->Generation()),
cached_(map_->Find(list)) {}
void Update() {
if (map_->Generation() != generation_) {
cached_ = map_->Find(list_);
generation_ = map_->Generation();
}
}
const void* list_;
Map* map_;
uint64_t generation_;
const ListTransforms* cached_;
};
/// Edit returns a reference to the ListTransforms for the given vector pointer and
/// increments #list_transform_generation_ signalling that the list transforms have been
/// modified.
inline ListTransforms& Edit(const void* list) { return map_.GetOrZero(list); }
/// @returns an Index to the transforms for the given list.
inline Index Find(const void* list) { return Index{list, &map_}; }
private:
/// The map of vector pointer to ListTransforms
Map map_;
};
/// A map of object in #src to functions that create their replacement in #dst
utils::Hashmap<const Cloneable*, std::function<const Cloneable*()>, 8> replacements_;
@ -666,7 +611,7 @@ class CloneContext {
utils::Vector<CloneableTransform, 8> transforms_;
/// Transformations to apply to vectors
VectorListTransforms list_transforms_;
utils::Hashmap<const void*, ListTransforms, 4> list_transforms_;
/// Symbol transform registered with ReplaceAll()
SymbolTransform symbol_transform_;

View File

@ -3522,7 +3522,7 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info,
const auto phi_id = assignment.phi_id;
auto* const lhs_expr = builder_.Expr(namer_.Name(phi_id));
// If RHS value is actually a phi we just cpatured, then use it.
auto* const copy_sym = copied_phis.Find(assignment.value_id);
auto copy_sym = copied_phis.Find(assignment.value_id);
auto* const rhs_expr =
copy_sym ? builder_.Expr(*copy_sym) : MakeExpression(assignment.value_id).expr;
AddStatement(builder_.Assign(lhs_expr, rhs_expr));

View File

@ -666,7 +666,7 @@ class ParserImpl : Reader {
/// @param id a SPIR-V ID
/// @returns the AST variable or null.
const ast::Var* GetModuleVariable(uint32_t id) {
auto* entry = module_variable_.Find(id);
auto entry = module_variable_.Find(id);
return entry ? *entry : nullptr;
}

View File

@ -471,7 +471,7 @@ class DependencyScanner {
}
}
if (auto* global = globals_.Find(to); global && (*global)->node == resolved) {
if (auto global = globals_.Find(to); global && (*global)->node == resolved) {
if (dependency_edges_.Add(DependencyEdge{current_global_, *global},
DependencyInfo{from->source, action})) {
current_global_->deps.Push(*global);

View File

@ -1128,8 +1128,8 @@ TEST_P(ResolverDependencyGraphResolvedSymbolTest, Test) {
if (expect_pass) {
// Check that the use resolves to the declaration
auto* resolved_symbol = graph.resolved_symbols.Find(use);
ASSERT_NE(resolved_symbol, nullptr);
auto resolved_symbol = graph.resolved_symbols.Find(use);
ASSERT_TRUE(resolved_symbol);
EXPECT_EQ(*resolved_symbol, decl)
<< "resolved: " << (*resolved_symbol ? (*resolved_symbol)->TypeInfo().name : "<null>")
<< "\n"
@ -1179,8 +1179,8 @@ TEST_P(ResolverDependencyShadowTest, Test) {
helper.Build();
auto shadows = Build().shadows;
auto* shadow = shadows.Find(inner_var);
ASSERT_NE(shadow, nullptr);
auto shadow = shadows.Find(inner_var);
ASSERT_TRUE(shadow);
EXPECT_EQ(*shadow, outer);
}
@ -1310,8 +1310,8 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
auto graph = Build();
for (auto use : symbol_uses) {
auto* resolved_symbol = graph.resolved_symbols.Find(use.use);
ASSERT_NE(resolved_symbol, nullptr) << use.where;
auto resolved_symbol = graph.resolved_symbols.Find(use.use);
ASSERT_TRUE(resolved_symbol) << use.where;
EXPECT_EQ(*resolved_symbol, use.decl) << use.where;
}
}

View File

@ -2481,7 +2481,7 @@ sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
if (loop_block->FirstContinue()) {
// If our identifier is in loop_block->decls, make sure its index is
// less than first_continue
if (auto* decl = loop_block->Decls().Find(symbol)) {
if (auto decl = loop_block->Decls().Find(symbol)) {
if (decl->order >= loop_block->NumDeclsAtFirstContinue()) {
AddError("continue statement bypasses declaration of '" +
builder_->Symbols().NameFor(symbol) + "'",

View File

@ -54,7 +54,7 @@ class SemHelper {
/// @param node the node to retrieve
template <typename SEM = sem::Node>
SEM* ResolvedSymbol(const ast::Node* node) const {
auto* resolved = dependencies_.resolved_symbols.Find(node);
auto resolved = dependencies_.resolved_symbols.Find(node);
return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(*resolved)) : nullptr;
}

View File

@ -600,7 +600,7 @@ class UniformityGraph {
}
// Add an edge from the variable's loop input node to its value at this point.
auto** in_node = info.var_in_nodes.Find(var);
auto in_node = info.var_in_nodes.Find(var);
TINT_ASSERT(Resolver, in_node != nullptr);
auto* out_node = current_function_->variables.Get(var);
if (out_node != *in_node) {
@ -1334,7 +1334,7 @@ class UniformityGraph {
[&](const sem::Function* func) {
// We must have already analyzed the user-defined function since we process
// functions in dependency order.
auto* info = functions_.Find(func->Declaration());
auto info = functions_.Find(func->Declaration());
TINT_ASSERT(Resolver, info != nullptr);
callsite_tag = info->callsite_tag;
function_tag = info->function_tag;
@ -1466,7 +1466,7 @@ class UniformityGraph {
} else if (auto* user = target->As<sem::Function>()) {
// This is a call to a user-defined function, so inspect the functions called by that
// function and look for one whose node has an edge from the RequiredToBeUniform node.
auto* target_info = functions_.Find(user->Declaration());
auto target_info = functions_.Find(user->Declaration());
for (auto* call_node : target_info->required_to_be_uniform->edges) {
if (call_node->type == Node::kRegular) {
auto* child_call = call_node->ast->As<ast::CallExpression>();
@ -1636,7 +1636,7 @@ class UniformityGraph {
// If this is a call to a user-defined function, add a note to show the reason that the
// parameter is required to be uniform.
if (auto* user = target->As<sem::Function>()) {
auto* next_function = functions_.Find(user->Declaration());
auto next_function = functions_.Find(user->Declaration());
Node* next_cause = next_function->parameters[cause->arg_index].init_value;
MakeError(*next_function, next_cause, true);
}

View File

@ -719,7 +719,7 @@ bool Validator::AtomicVariable(
return false;
}
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
if (auto* found = atomic_composite_info.Find(type)) {
if (auto found = atomic_composite_info.Find(type)) {
if (address_space != ast::AddressSpace::kStorage &&
address_space != ast::AddressSpace::kWorkgroup) {
AddError("atomic variables must have <storage> or <workgroup> address space",
@ -798,7 +798,7 @@ bool Validator::Override(
for (auto* attr : decl->attributes) {
if (attr->Is<ast::IdAttribute>()) {
auto id = v->OverrideId();
if (auto* var = override_ids.Find(id); var && *var != v) {
if (auto var = override_ids.Find(id); var && *var != v) {
AddError("@id values must be unique", attr->source);
AddNote(
"a override with an ID of " + std::to_string(id.value) +

View File

@ -44,7 +44,7 @@ class ScopeStack {
/// stack, otherwise the zero initializer for type T.
V Set(const K& key, V val) {
auto& back = stack_.Back();
if (auto* el = back.Find(key)) {
if (auto el = back.Find(key)) {
std::swap(val, *el);
return val;
}
@ -57,7 +57,7 @@ class ScopeStack {
/// @returns the value, or the zero initializer if the value was not found
V Get(const K& key) const {
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
if (auto* val = iter->Find(key)) {
if (auto val = iter->Find(key)) {
return *val;
}
}

View File

@ -129,7 +129,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
if (auto* info = decomposed.Find(access->Member()->Declaration())) {
if (auto info = decomposed.Find(access->Member()->Declaration())) {
auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
auto name =
b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
@ -168,7 +168,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr)) {
if (auto* info = decomposed.Find(access->Member()->Declaration())) {
if (auto info = decomposed.Find(access->Member()->Declaration())) {
auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
auto name =
b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +

View File

@ -140,7 +140,7 @@ struct SimplifyPointers::State {
// variable identifier.
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
// Look to see if we need to swap this Expression with a saved variable.
if (auto* saved_var = saved_vars.Find(expr)) {
if (auto saved_var = saved_vars.Find(expr)) {
return ctx.dst->Expr(*saved_var);
}

View File

@ -401,7 +401,7 @@ struct Std140::State {
return Switch(
ty, //
[&](const sem::Struct* str) -> const ast::Type* {
if (auto* std140 = std140_structs.Find(str)) {
if (auto std140 = std140_structs.Find(str)) {
return b.create<ast::TypeName>(*std140);
}
return nullptr;
@ -695,7 +695,7 @@ struct Std140::State {
// call, or by reassembling a std140 matrix from column vector members.
utils::Vector<const ast::Expression*, 8> args;
for (auto* member : str->Members()) {
if (auto* col_members = std140_mat_members.Find(member)) {
if (auto col_members = std140_mat_members.Find(member)) {
// std140 decomposed matrix. Reassemble.
auto* mat_ty = CreateASTTypeFor(ctx, member->Type());
auto mat_args =

View File

@ -161,7 +161,7 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
ctx.ReplaceAll(
[&](const ast::ReturnStatement* return_statement) -> const ast::ReturnStatement* {
auto* return_sem = sem.Get(return_statement);
if (auto* mapping_fn_sym =
if (auto mapping_fn_sym =
entry_point_functions_to_truncate_functions.Find(return_sem->Function())) {
return b.Return(return_statement->source,
b.Call(*mapping_fn_sym, ctx.Clone(return_statement->value)));

View File

@ -97,7 +97,7 @@ struct Unshadow::State {
ctx.ReplaceAll(
[&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* {
if (auto* user = sem.Get<sem::VariableUser>(ident)) {
if (auto* renamed = renamed_to.Find(user->Variable())) {
if (auto renamed = renamed_to.Find(user->Variable())) {
return b.Expr(*renamed);
}
}

View File

@ -135,7 +135,7 @@ struct HoistToDeclBefore::State {
/// automatically called.
/// @warning the returned reference is invalid if this is called a second time, or the
/// #for_loops map is mutated.
LoopInfo& ForLoop(const sem::ForLoopStatement* for_loop) {
auto ForLoop(const sem::ForLoopStatement* for_loop) {
if (for_loops.IsEmpty()) {
RegisterForLoopTransform();
}
@ -147,7 +147,7 @@ struct HoistToDeclBefore::State {
/// automatically called.
/// @warning the returned reference is invalid if this is called a second time, or the
/// #for_loops map is mutated.
LoopInfo& WhileLoop(const sem::WhileStatement* while_loop) {
auto WhileLoop(const sem::WhileStatement* while_loop) {
if (while_loops.IsEmpty()) {
RegisterWhileLoopTransform();
}
@ -159,7 +159,7 @@ struct HoistToDeclBefore::State {
/// automatically called.
/// @warning the returned reference is invalid if this is called a second time, or the
/// #else_ifs map is mutated.
ElseIfInfo& ElseIf(const ast::IfStatement* else_if) {
auto ElseIf(const ast::IfStatement* else_if) {
if (else_ifs.IsEmpty()) {
RegisterElseIfTransform();
}
@ -172,7 +172,7 @@ struct HoistToDeclBefore::State {
auto& sem = ctx.src->Sem();
if (auto* fl = sem.Get(stmt)) {
if (auto* info = for_loops.Find(fl)) {
if (auto info = for_loops.Find(fl)) {
auto* for_loop = fl->Declaration();
// For-loop needs to be decomposed to a loop.
// Build the loop body's statements.
@ -222,7 +222,7 @@ struct HoistToDeclBefore::State {
auto& sem = ctx.src->Sem();
if (auto* w = sem.Get(stmt)) {
if (auto* info = while_loops.Find(w)) {
if (auto info = while_loops.Find(w)) {
auto* while_loop = w->Declaration();
// While needs to be decomposed to a loop.
// Build the loop body's statements.
@ -259,7 +259,7 @@ struct HoistToDeclBefore::State {
void RegisterElseIfTransform() const {
// Decompose 'else-if' statements into 'else { if }' blocks.
ctx.ReplaceAll([&](const ast::IfStatement* stmt) -> const ast::Statement* {
if (auto* info = else_ifs.Find(stmt)) {
if (auto info = else_ifs.Find(stmt)) {
// Build the else block's body statements, starting with let decls for the
// conditional expression.
auto body_stmts = Build(info->cond_decls);
@ -291,10 +291,10 @@ struct HoistToDeclBefore::State {
if (else_if && else_if->Parent()->Is<sem::IfStatement>()) {
// Insertion point is an 'else if' condition.
// Need to convert 'else if' to 'else { if }'.
auto& else_if_info = ElseIf(else_if->Declaration());
auto else_if_info = ElseIf(else_if->Declaration());
// Index the map to convert this else if, even if `stmt` is nullptr.
auto& decls = else_if_info.cond_decls;
auto& decls = else_if_info->cond_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
@ -306,7 +306,7 @@ struct HoistToDeclBefore::State {
// For-loop needs to be decomposed to a loop.
// Index the map to convert this for-loop, even if `stmt` is nullptr.
auto& decls = ForLoop(fl).cond_decls;
auto& decls = ForLoop(fl)->cond_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
@ -318,7 +318,7 @@ struct HoistToDeclBefore::State {
// While needs to be decomposed to a loop.
// Index the map to convert this while, even if `stmt` is nullptr.
auto& decls = WhileLoop(w).cond_decls;
auto& decls = WhileLoop(w)->cond_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
@ -354,7 +354,7 @@ struct HoistToDeclBefore::State {
// For-loop needs to be decomposed to a loop.
// Index the map to convert this for-loop, even if `stmt` is nullptr.
auto& decls = ForLoop(fl).cont_decls;
auto& decls = ForLoop(fl)->cont_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}

View File

@ -47,6 +47,67 @@ class Hashmap : public HashmapBase<KEY, VALUE, N, HASH, EQUAL> {
/// Result of Add()
using AddResult = typename Base::PutResult;
/// Reference is returned by Hashmap::Find(), and performs dynamic Hashmap lookups.
/// The value returned by the Reference reflects the current state of the Hashmap, and so the
/// referenced value may change, or transition between valid or invalid based on the current
/// state of the Hashmap.
template <bool IS_CONST>
class ReferenceT {
/// `const Value` if IS_CONST, or `Value` if !IS_CONST
using T = std::conditional_t<IS_CONST, const Value, Value>;
/// `const Hashmap` if IS_CONST, or `Hashmap` if !IS_CONST
using Map = std::conditional_t<IS_CONST, const Hashmap, Hashmap>;
public:
/// @returns true if the reference is valid.
operator bool() const { return Get() != nullptr; }
/// @returns the pointer to the Value, or nullptr if the reference is invalid.
operator T*() const { return Get(); }
/// @returns the pointer to the Value
/// @warning if the Hashmap does not contain a value for the reference, then this will
/// trigger a TINT_ASSERT, or invalid pointer dereference.
T* operator->() const {
auto* hashmap_reference_lookup = Get();
TINT_ASSERT(Utils, hashmap_reference_lookup != nullptr);
return hashmap_reference_lookup;
}
/// @returns the pointer to the Value, or nullptr if the reference is invalid.
T* Get() const {
auto generation = map_.Generation();
if (generation_ != generation) {
cached_ = map_.Lookup(key_);
generation_ = generation;
}
return cached_;
}
private:
friend Hashmap;
/// Constructor
ReferenceT(Map& map, const Key& key)
: map_(map), key_(key), cached_(nullptr), generation_(map.Generation() - 1) {}
/// Constructor
ReferenceT(Map& map, const Key& key, T* value)
: map_(map), key_(key), cached_(value), generation_(map.Generation()) {}
Map& map_;
const Key key_;
mutable T* cached_ = nullptr;
mutable size_t generation_ = 0;
};
/// A mutable reference returned by Find()
using Reference = ReferenceT</*IS_CONST*/ false>;
/// An immutable reference returned by Find()
using ConstReference = ReferenceT</*IS_CONST*/ true>;
/// Adds a value to the map, if the map does not already contain an entry with the key @p key.
/// @param key the entry key.
/// @param value the value of the entry to add to the map.
@ -108,25 +169,28 @@ class Hashmap : public HashmapBase<KEY, VALUE, N, HASH, EQUAL> {
/// @param key the entry's key value to search for.
/// @returns the value of the entry.
template <typename K>
Value& GetOrZero(K&& key) {
Reference GetOrZero(K&& key) {
auto res = Add(std::forward<K>(key), Value{});
return *res.value;
return Reference(*this, key, res.value);
}
/// @param key the key to search for.
/// @returns a pointer to the entry that is equal to the given value, or nullptr if the map does
/// not contain the given value.
const Value* Find(const Key& key) const {
/// @returns a reference to the entry that is equal to the given value.
Reference Find(const Key& key) { return Reference(*this, key); }
/// @param key the key to search for.
/// @returns a reference to the entry that is equal to the given value.
ConstReference Find(const Key& key) const { return ConstReference(*this, key); }
private:
Value* Lookup(const Key& key) {
if (auto [found, index] = this->IndexOf(key); found) {
return &this->slots_[index].entry->value;
}
return nullptr;
}
/// @param key the key to search for.
/// @returns a pointer to the entry that is equal to the given value, or nullptr if the map does
/// not contain the given value.
Value* Find(const Key& key) {
const Value* Lookup(const Key& key) const {
if (auto [found, index] = this->IndexOf(key); found) {
return &this->slots_[index].entry->value;
}

View File

@ -90,6 +90,71 @@ TEST(Hashmap, Generation) {
EXPECT_EQ(map.Generation(), 5u);
}
TEST(Hashmap, Index) {
Hashmap<int, std::string, 4> map;
auto zero = map.Find(0);
EXPECT_FALSE(zero);
map.Add(3, "three");
auto three = map.Find(3);
map.Add(2, "two");
auto two = map.Find(2);
map.Add(4, "four");
auto four = map.Find(4);
map.Add(8, "eight");
auto eight = map.Find(8);
EXPECT_FALSE(zero);
ASSERT_TRUE(three);
ASSERT_TRUE(two);
ASSERT_TRUE(four);
ASSERT_TRUE(eight);
EXPECT_EQ(*three, "three");
EXPECT_EQ(*two, "two");
EXPECT_EQ(*four, "four");
EXPECT_EQ(*eight, "eight");
map.Add(0, "zero"); // Note: Find called before Add() is okay!
map.Add(5, "five");
auto five = map.Find(5);
map.Add(6, "six");
auto six = map.Find(6);
map.Add(1, "one");
auto one = map.Find(1);
map.Add(7, "seven");
auto seven = map.Find(7);
ASSERT_TRUE(zero);
ASSERT_TRUE(three);
ASSERT_TRUE(two);
ASSERT_TRUE(four);
ASSERT_TRUE(eight);
ASSERT_TRUE(five);
ASSERT_TRUE(six);
ASSERT_TRUE(one);
ASSERT_TRUE(seven);
EXPECT_EQ(*zero, "zero");
EXPECT_EQ(*three, "three");
EXPECT_EQ(*two, "two");
EXPECT_EQ(*four, "four");
EXPECT_EQ(*eight, "eight");
EXPECT_EQ(*five, "five");
EXPECT_EQ(*six, "six");
EXPECT_EQ(*one, "one");
EXPECT_EQ(*seven, "seven");
map.Remove(2);
map.Remove(8);
map.Remove(1);
EXPECT_FALSE(two);
EXPECT_FALSE(eight);
EXPECT_FALSE(one);
}
TEST(Hashmap, Iterator) {
using Map = Hashmap<int, std::string, 8>;
using Entry = typename Map::Entry;