IntrinsicTable: De-duplicate returned Intrinsics

Much like sem::Type, it greatly simplifies downstream logic if we can compare sem::Intrinsic pointers to know if they refer to the same intrinsic overload.

Change-Id: If236247cd3979bbde821d9294f304ab85ba4938e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58061
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
Ben Clayton
2021-07-15 20:34:21 +00:00
committed by Ben Clayton
parent e54e26d7e8
commit b478f97975
18 changed files with 410 additions and 15 deletions

View File

@@ -51,8 +51,8 @@ struct SamplerTexturePair {
namespace std {
/// Custom std::hash specialization for ttint::inspector::SamplerTexturePair so
/// SamplerTexturePairss be used as keys for std::unordered_map and
/// Custom std::hash specialization for tint::inspector::SamplerTexturePair so
/// SamplerTexturePairs be used as keys for std::unordered_map and
/// std::unordered_set.
template <>
class hash<tint::inspector::SamplerTexturePair> {

View File

@@ -27,6 +27,8 @@
#include "src/sem/pipeline_stage_set.h"
#include "src/sem/sampled_texture_type.h"
#include "src/sem/storage_texture_type.h"
#include "src/utils/get_or_create.h"
#include "src/utils/hash.h"
#include "src/utils/scoped_assignment.h"
namespace tint {
@@ -706,13 +708,13 @@ class Impl : public IntrinsicTable {
const sem::Intrinsic* Lookup(sem::IntrinsicType intrinsic_type,
const std::vector<const sem::Type*>& args,
const Source& source) const override;
const Source& source) override;
private:
const sem::Intrinsic* Match(sem::IntrinsicType intrinsic_type,
const OverloadInfo& overload,
const std::vector<const sem::Type*>& args,
int& match_score) const;
int& match_score);
MatchState Match(ClosedState& closed,
const OverloadInfo& overload,
@@ -724,6 +726,7 @@ class Impl : public IntrinsicTable {
ProgramBuilder& builder;
Matchers matchers;
std::unordered_map<sem::Intrinsic, sem::Intrinsic*> intrinsics;
};
/// @return a string representing a call to an intrinsic with the given argument
@@ -760,7 +763,7 @@ Impl::Impl(ProgramBuilder& b) : builder(b) {}
const sem::Intrinsic* Impl::Lookup(sem::IntrinsicType intrinsic_type,
const std::vector<const sem::Type*>& args,
const Source& source) const {
const Source& source) {
// Candidate holds information about a mismatched overload that could be what
// the user intended to call.
struct Candidate {
@@ -809,7 +812,7 @@ const sem::Intrinsic* Impl::Lookup(sem::IntrinsicType intrinsic_type,
const sem::Intrinsic* Impl::Match(sem::IntrinsicType intrinsic_type,
const OverloadInfo& overload,
const std::vector<const sem::Type*>& args,
int& match_score) const {
int& match_score) {
// Score wait for argument <-> parameter count matches / mismatches
constexpr int kScorePerParamArgMismatch = -1;
constexpr int kScorePerMatchedParam = 2;
@@ -896,9 +899,14 @@ const sem::Intrinsic* Impl::Match(sem::IntrinsicType intrinsic_type,
return_type = builder.create<sem::Void>();
}
return builder.create<sem::Intrinsic>(
intrinsic_type, const_cast<sem::Type*>(return_type),
std::move(parameters), overload.supported_stages, overload.is_deprecated);
sem::Intrinsic intrinsic(intrinsic_type, const_cast<sem::Type*>(return_type),
std::move(parameters), overload.supported_stages,
overload.is_deprecated);
// De-duplicate intrinsics that are identical.
return utils::GetOrCreate(intrinsics, intrinsic, [&] {
return builder.create<sem::Intrinsic>(intrinsic);
});
}
MatchState Impl::Match(ClosedState& closed,

View File

@@ -45,7 +45,7 @@ class IntrinsicTable {
virtual const sem::Intrinsic* Lookup(
sem::IntrinsicType type,
const std::vector<const sem::Type*>& args,
const Source& source) const = 0;
const Source& source) = 0;
};
} // namespace tint

View File

@@ -548,5 +548,26 @@ TEST_F(IntrinsicTableTest, OverloadOrderByMatchingParameter) {
)");
}
TEST_F(IntrinsicTableTest, SameOverloadReturnsSameIntrinsicPointer) {
auto* f32 = create<sem::F32>();
auto* vec2_f32 = create<sem::Vector>(create<sem::F32>(), 2);
auto* bool_ = create<sem::Bool>();
auto* a = table->Lookup(IntrinsicType::kSelect, {f32, f32, bool_}, Source{});
ASSERT_NE(a, nullptr) << Diagnostics().str();
auto* b = table->Lookup(IntrinsicType::kSelect, {f32, f32, bool_}, Source{});
ASSERT_NE(b, nullptr) << Diagnostics().str();
ASSERT_EQ(Diagnostics().str(), "");
auto* c = table->Lookup(IntrinsicType::kSelect, {vec2_f32, vec2_f32, bool_},
Source{});
ASSERT_NE(c, nullptr) << Diagnostics().str();
ASSERT_EQ(Diagnostics().str(), "");
EXPECT_EQ(a, b);
EXPECT_NE(a, c);
EXPECT_NE(b, c);
}
} // namespace
} // namespace tint

View File

@@ -26,6 +26,8 @@ CallTarget::CallTarget(sem::Type* return_type, const ParameterList& parameters)
TINT_ASSERT(Semantic, return_type);
}
CallTarget::CallTarget(const CallTarget&) = default;
CallTarget::~CallTarget() = default;
int IndexOf(const ParameterList& parameters, ParameterUsage usage) {

View File

@@ -20,6 +20,7 @@
#include "src/sem/node.h"
#include "src/sem/parameter_usage.h"
#include "src/sem/sampler_type.h"
#include "src/utils/hash.h"
namespace tint {
@@ -37,11 +38,16 @@ struct Parameter {
std::ostream& operator<<(std::ostream& out, Parameter parameter);
/// Comparison operator for Parameters
/// Equality operator for Parameters
static inline bool operator==(const Parameter& a, const Parameter& b) {
return a.type == b.type && a.usage == b.usage;
}
/// Inequality operator for Parameters
static inline bool operator!=(const Parameter& a, const Parameter& b) {
return !(a == b);
}
/// ParameterList is a list of Parameter
using ParameterList = std::vector<Parameter>;
@@ -59,6 +65,9 @@ class CallTarget : public Castable<CallTarget, Node> {
/// @param parameters the parameters for the call target
CallTarget(sem::Type* return_type, const ParameterList& parameters);
/// Copy constructor
CallTarget(const CallTarget&);
/// @return the return type of the call target
sem::Type* ReturnType() const { return return_type_; }
@@ -76,4 +85,19 @@ class CallTarget : public Castable<CallTarget, Node> {
} // namespace sem
} // namespace tint
namespace std {
/// Custom std::hash specialization for tint::sem::Parameter
template <>
class hash<tint::sem::Parameter> {
public:
/// @param p the tint::sem::Parameter to create a hash for
/// @return the hash value
inline std::size_t operator()(const tint::sem::Parameter& p) const {
return tint::utils::Hash(p.type, p.usage);
}
};
} // namespace std
#endif // SRC_SEM_CALL_TARGET_H_

View File

@@ -111,6 +111,8 @@ Intrinsic::Intrinsic(IntrinsicType type,
supported_stages_(supported_stages),
is_deprecated_(is_deprecated) {}
Intrinsic::Intrinsic(const Intrinsic&) = default;
Intrinsic::~Intrinsic() = default;
bool Intrinsic::IsCoarseDerivative() const {
@@ -153,5 +155,25 @@ bool Intrinsic::IsAtomic() const {
return IsAtomicIntrinsic(type_);
}
bool operator==(const Intrinsic& a, const Intrinsic& b) {
static_assert(sizeof(Intrinsic(IntrinsicType::kNone, nullptr, ParameterList{},
PipelineStageSet{}, false)) > 0,
"don't forget to update the comparison below if you change the "
"constructor of Intrinsic!");
if (a.Type() != b.Type() || a.SupportedStages() != b.SupportedStages() ||
a.ReturnType() != b.ReturnType() ||
a.IsDeprecated() != b.IsDeprecated() ||
a.Parameters().size() != b.Parameters().size()) {
return false;
}
for (size_t i = 0; i < a.Parameters().size(); i++) {
if (a.Parameters()[i] != b.Parameters()[i]) {
return false;
}
}
return true;
}
} // namespace sem
} // namespace tint

View File

@@ -20,6 +20,7 @@
#include "src/sem/call_target.h"
#include "src/sem/intrinsic_type.h"
#include "src/sem/pipeline_stage_set.h"
#include "src/utils/hash.h"
namespace tint {
namespace sem {
@@ -91,6 +92,9 @@ class Intrinsic : public Castable<Intrinsic, CallTarget> {
PipelineStageSet supported_stages,
bool is_deprecated);
/// Copy constructor
Intrinsic(const Intrinsic&);
/// Destructor
~Intrinsic() override;
@@ -147,7 +151,31 @@ class Intrinsic : public Castable<Intrinsic, CallTarget> {
/// matches the name in the WGSL spec.
std::ostream& operator<<(std::ostream& out, IntrinsicType i);
/// Equality operator for Intrinsics
bool operator==(const Intrinsic& a, const Intrinsic& b);
/// Inequality operator for Intrinsics
static inline bool operator!=(const Intrinsic& a, const Intrinsic& b) {
return !(a == b);
}
} // namespace sem
} // namespace tint
namespace std {
/// Custom std::hash specialization for tint::sem::Intrinsic
template <>
class hash<tint::sem::Intrinsic> {
public:
/// @param i the Intrinsic to create a hash for
/// @return the hash value
inline std::size_t operator()(const tint::sem::Intrinsic& i) const {
return tint::utils::Hash(i.Type(), i.SupportedStages(), i.ReturnType(),
i.Parameters(), i.IsDeprecated());
}
};
} // namespace std
#endif // SRC_SEM_INTRINSIC_H_

View File

@@ -16,6 +16,7 @@
#define SRC_UTILS_ENUM_SET_H_
#include <cstdint>
#include <functional>
#include <type_traits>
namespace tint {
@@ -58,6 +59,19 @@ struct EnumSet {
/// @return true if the set contains `e`
inline bool Contains(Enum e) { return (set & Bit(e)) != 0; }
/// Equality operator
/// @param rhs the other EnumSet to compare this to
/// @return true if this EnumSet is equal to rhs
inline bool operator==(const EnumSet& rhs) const { return set == rhs.set; }
/// Inequality operator
/// @param rhs the other EnumSet to compare this to
/// @return true if this EnumSet is not equal to rhs
inline bool operator!=(const EnumSet& rhs) const { return set != rhs.set; }
/// @return the underlying value for the EnumSet
inline uint64_t Value() const { return set; }
private:
static constexpr uint64_t Bit(Enum value) {
return static_cast<uint64_t>(1) << static_cast<uint64_t>(value);
@@ -76,4 +90,19 @@ struct EnumSet {
} // namespace utils
} // namespace tint
namespace std {
/// Custom std::hash specialization for tint::utils::EnumSet<T>
template <typename T>
class hash<tint::utils::EnumSet<T>> {
public:
/// @param e the EnumSet to create a hash for
/// @return the hash value
inline std::size_t operator()(const tint::utils::EnumSet<T>& e) const {
return std::hash<uint64_t>()(e.Value());
}
};
} // namespace std
#endif // SRC_UTILS_ENUM_SET_H_

View File

@@ -59,6 +59,30 @@ TEST(EnumSetTest, Remove) {
EXPECT_FALSE(set.Contains(E::C));
}
TEST(EnumSetTest, Equality) {
EXPECT_TRUE(EnumSet<E>(E::A, E::B) == EnumSet<E>(E::A, E::B));
EXPECT_FALSE(EnumSet<E>(E::A, E::B) == EnumSet<E>(E::A, E::C));
}
TEST(EnumSetTest, Inequality) {
EXPECT_FALSE(EnumSet<E>(E::A, E::B) != EnumSet<E>(E::A, E::B));
EXPECT_TRUE(EnumSet<E>(E::A, E::B) != EnumSet<E>(E::A, E::C));
}
TEST(EnumSetTest, Hash) {
auto hash = [&](EnumSet<E> s) { return std::hash<EnumSet<E>>()(s); };
EXPECT_EQ(hash(EnumSet<E>(E::A, E::B)), hash(EnumSet<E>(E::A, E::B)));
EXPECT_NE(hash(EnumSet<E>(E::A, E::B)), hash(EnumSet<E>(E::A, E::C)));
}
TEST(EnumSetTest, Value) {
EXPECT_EQ(EnumSet<E>().Value(), 0u);
EXPECT_EQ(EnumSet<E>(E::A).Value(), 1u);
EXPECT_EQ(EnumSet<E>(E::B).Value(), 2u);
EXPECT_EQ(EnumSet<E>(E::C).Value(), 4u);
EXPECT_EQ(EnumSet<E>(E::A, E::C).Value(), 5u);
}
} // namespace
} // namespace utils
} // namespace tint

View File

@@ -27,8 +27,10 @@ namespace utils {
/// @param key the map key of the item to query or add
/// @param create a callable function-like object with the signature `V()`
/// @return the value of the item with the given key, or the newly created item
template <typename K, typename V, typename CREATE, typename H>
V GetOrCreate(std::unordered_map<K, V, H>& map, K key, CREATE&& create) {
template <typename K, typename V, typename H, typename C, typename CREATE>
V GetOrCreate(std::unordered_map<K, V, H, C>& map,
const K& key,
CREATE&& create) {
auto it = map.find(key);
if (it != map.end()) {
return it->second;

View File

@@ -18,6 +18,7 @@
#include <stdint.h>
#include <cstdio>
#include <functional>
#include <vector>
namespace tint {
namespace utils {
@@ -51,6 +52,15 @@ void HashCombine(size_t* hash, const T& value) {
*hash ^= std::hash<T>()(value) + offset + (*hash << 6) + (*hash >> 2);
}
// Helper for hashing vectors
template <typename T>
void HashCombine(size_t* hash, const std::vector<T>& vector) {
HashCombine(hash, vector.size());
for (auto& el : vector) {
HashCombine(hash, el);
}
}
template <typename T, typename... ARGS>
void HashCombine(size_t* hash, const T& value, const ARGS&... args) {
HashCombine(hash, value);

View File

@@ -23,14 +23,25 @@ namespace utils {
namespace {
TEST(HashTests, Basic) {
EXPECT_EQ(Hash(123), Hash(123));
EXPECT_NE(Hash(123), Hash(321));
EXPECT_EQ(Hash(123, 456), Hash(123, 456));
EXPECT_NE(Hash(123, 456), Hash(456, 123));
EXPECT_NE(Hash(123, 456), Hash(123));
EXPECT_EQ(Hash(123, 456, false), Hash(123, 456, false));
EXPECT_NE(Hash(123, 456, false), Hash(123, 456));
EXPECT_EQ(Hash(std::string("hello")), Hash(std::string("hello")));
EXPECT_NE(Hash(std::string("hello")), Hash(std::string("world")));
}
TEST(HashTests, Order) {
EXPECT_NE(Hash(123, 456), Hash(456, 123));
TEST(HashTests, Vector) {
EXPECT_EQ(Hash(std::vector<int>({})), Hash(std::vector<int>({})));
EXPECT_EQ(Hash(std::vector<int>({1, 2, 3})),
Hash(std::vector<int>({1, 2, 3})));
EXPECT_NE(Hash(std::vector<int>({1, 2, 3})),
Hash(std::vector<int>({1, 2, 4})));
EXPECT_NE(Hash(std::vector<int>({1, 2, 3})),
Hash(std::vector<int>({1, 2, 3, 4})));
}
} // namespace