From b04d992f8395c812aacc727f0fa1150bddbe5c49 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 31 Aug 2022 23:15:38 +0000 Subject: [PATCH] tint/utils: Fix Hashmap::GetOrCreate() for map mutation in create Its not unreasonable for the create callback to mutate the map. If this happened, the map would be corrupted. This change fixes this. Change-Id: I2bb3820061c741c6da36ebe3667cb6b878515a27 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/100903 Kokoro: Kokoro Reviewed-by: Dan Sinclair Commit-Queue: Ben Clayton --- src/tint/utils/hashmap.h | 80 ++++++++++------------------------ src/tint/utils/hashmap_test.cc | 70 ++++++++++++++++++++++++++++- 2 files changed, 92 insertions(+), 58 deletions(-) diff --git a/src/tint/utils/hashmap.h b/src/tint/utils/hashmap.h index 1154086d2d..87a0dd9b9b 100644 --- a/src/tint/utils/hashmap.h +++ b/src/tint/utils/hashmap.h @@ -35,28 +35,6 @@ template , typename EQUAL = std::equal_to> class Hashmap { - /// LazyCreator is a transient structure used to late-build the Entry::value, when inserted into - /// the underlying Hashset. - /// - /// LazyCreator holds a #key, and a #create function used to build the final Entry::value. - /// The #create function must be of the signature `V()`. - /// - /// LazyCreator can be compared to Entry and hashed, allowing them to be passed to - /// Hashset::Insert(). If the set does not contain an existing entry with #key, - /// Hashset::Insert() will construct a new Entry passing the rvalue LazyCreator as the - /// constructor argument, which in turn calls the #create function to generate the entry value. - /// - /// @see Entry - /// @see Hasher - /// @see Equality - template - struct LazyCreator { - /// The key of the entry to insert into the map - const K& key; - /// The value creation function - CREATE create; - }; - /// Entry holds a key and value pair, and is used as the element type of the underlying Hashset. /// Entries are compared and hashed using only the #key. /// @see Hasher @@ -71,23 +49,6 @@ class Hashmap { /// Move-constructor. Entry(Entry&&) = default; - /// Constructor from a LazyCreator. - /// The constructor invokes the LazyCreator::create function to build the #value. - /// @see LazyCreator - template - Entry(const LazyCreator& creator) // NOLINT(runtime/explicit) - : key(creator.key), value(creator.create()) {} - - /// Assignment operator from a LazyCreator. - /// The assignment invokes the LazyCreator::create function to build the #value. - /// @see LazyCreator - template - Entry& operator=(LazyCreator&& creator) { - key = std::move(creator.key); - value = creator.create(); - return *this; - } - /// Copy-assignment operator Entry& operator=(const Entry&) = default; @@ -99,33 +60,23 @@ class Hashmap { }; /// Hash provider for the underlying Hashset. - /// Provides hash functions for an Entry, K or LazyCreator. + /// Provides hash functions for an Entry or K. /// The hash functions only consider the key of an entry. struct Hasher { /// Calculates a hash from an Entry size_t operator()(const Entry& entry) const { return HASH()(entry.key); } /// Calculates a hash from a K size_t operator()(const K& key) const { return HASH()(key); } - /// Calculates a hash from a LazyCreator - template - size_t operator()(const LazyCreator& lc) const { - return HASH()(lc.key); - } }; /// Equality provider for the underlying Hashset. - /// Provides equality functions for an Entry, K or LazyCreator to an Entry. + /// Provides equality functions for an Entry or K to an Entry. /// The equality functions only consider the key for equality. struct Equality { /// Compares an Entry to an Entry for equality. bool operator()(const Entry& a, const Entry& b) const { return EQUAL()(a.key, b.key); } /// Compares a K to an Entry for equality. bool operator()(const K& a, const Entry& b) const { return EQUAL()(a, b.key); } - /// Compares a LazyCreator to an Entry for equality. - template - bool operator()(const LazyCreator& lc, const Entry& b) const { - return EQUAL()(lc.key, b.key); - } }; /// The underlying set @@ -151,7 +102,8 @@ class Hashmap { /// Used by gmock for the `ElementsAre` checks. using value_type = KeyValue; - /// Iterator for the map + /// Iterator for the map. + /// Iterators are invalidated if the map is modified. class Iterator { public: /// @returns the key of the entry pointed to by this iterator @@ -226,13 +178,29 @@ class Hashmap { /// Searches for an entry with the given key value, adding and returning the result of /// calling `create` if the entry was not found. + /// @note: Before calling `create`, the map will insert a zero-initialized value for the given + /// key, which will be replaced with the value returned by `create`. If `create` adds an entry + /// with `key` to this map, it will be replaced. /// @param key the entry's key value to search for. /// @param create the create function to call if the map does not contain the key. /// @returns the value of the entry. template V& GetOrCreate(const K& key, CREATE&& create) { - LazyCreator lc{key, std::forward(create)}; - auto res = set_.Add(std::move(lc)); + auto res = set_.Add(Entry{key, V{}}); + if (res.action == AddAction::kAdded) { + // Store the set generation before calling create() + auto generation = set_.Generation(); + // Call create(), which might modify this map. + auto value = create(); + // Was this map mutated? + if (set_.Generation() == generation) { + // Calling create() did not touch the map. No need to lookup again. + res.entry->value = std::move(value); + } else { + // Calling create() modified the map. Need to insert again. + res = set_.Replace(Entry{key, std::move(value)}); + } + } return res.entry->value; } @@ -241,9 +209,7 @@ class Hashmap { /// @param key the entry's key value to search for. /// @returns the value of the entry. V& GetOrZero(const K& key) { - auto zero = [] { return V{}; }; - LazyCreator lc{key, zero}; - auto res = set_.Add(std::move(lc)); + auto res = set_.Add(Entry{key, V{}}); return res.entry->value; } diff --git a/src/tint/utils/hashmap_test.cc b/src/tint/utils/hashmap_test.cc index e52b144fd0..9a5b01e0b4 100644 --- a/src/tint/utils/hashmap_test.cc +++ b/src/tint/utils/hashmap_test.cc @@ -119,9 +119,16 @@ TEST(Hashmap, AddMany) { TEST(Hashmap, GetOrCreate) { Hashmap map; - EXPECT_EQ(map.GetOrCreate(0, [&] { return "zero"; }), "zero"); + std::optional value_of_key_0_at_create; + EXPECT_EQ(map.GetOrCreate(0, + [&] { + value_of_key_0_at_create = map.Get(0); + return "zero"; + }), + "zero"); EXPECT_EQ(map.Count(), 1u); EXPECT_EQ(map.Get(0), "zero"); + EXPECT_EQ(value_of_key_0_at_create, ""); bool create_called = false; EXPECT_EQ(map.GetOrCreate(0, @@ -139,6 +146,67 @@ TEST(Hashmap, GetOrCreate) { EXPECT_EQ(map.Get(1), "one"); } +TEST(Hashmap, GetOrCreate_CreateModifiesMap) { + Hashmap map; + EXPECT_EQ(map.GetOrCreate(0, + [&] { + map.Add(3, "three"); + map.Add(1, "one"); + map.Add(2, "two"); + return "zero"; + }), + "zero"); + EXPECT_EQ(map.Count(), 4u); + EXPECT_EQ(map.Get(0), "zero"); + EXPECT_EQ(map.Get(1), "one"); + EXPECT_EQ(map.Get(2), "two"); + EXPECT_EQ(map.Get(3), "three"); + + bool create_called = false; + EXPECT_EQ(map.GetOrCreate(0, + [&] { + create_called = true; + return "oh noes"; + }), + "zero"); + EXPECT_FALSE(create_called); + EXPECT_EQ(map.Count(), 4u); + EXPECT_EQ(map.Get(0), "zero"); + EXPECT_EQ(map.Get(1), "one"); + EXPECT_EQ(map.Get(2), "two"); + EXPECT_EQ(map.Get(3), "three"); + + EXPECT_EQ(map.GetOrCreate(4, + [&] { + map.Add(6, "six"); + map.Add(5, "five"); + map.Add(7, "seven"); + return "four"; + }), + "four"); + EXPECT_EQ(map.Count(), 8u); + EXPECT_EQ(map.Get(0), "zero"); + EXPECT_EQ(map.Get(1), "one"); + EXPECT_EQ(map.Get(2), "two"); + EXPECT_EQ(map.Get(3), "three"); + EXPECT_EQ(map.Get(4), "four"); + EXPECT_EQ(map.Get(5), "five"); + EXPECT_EQ(map.Get(6), "six"); + EXPECT_EQ(map.Get(7), "seven"); +} + +TEST(Hashmap, GetOrCreate_CreateAddsSameKeyedValue) { + Hashmap map; + EXPECT_EQ(map.GetOrCreate(42, + [&] { + map.Add(42, "should-be-replaced"); + return "expected-value"; + }), + "expected-value"); + EXPECT_EQ(map.Count(), 1u); + EXPECT_EQ(map.Get(42), "expected-value"); +} + TEST(Hashmap, Soak) { std::mt19937 rnd; std::unordered_map reference;