tint/resolver: Allow array sizes to be unnamed override-expressions

I got the rules around this wrong. This should be allowed, but the array types cannot compare equal if they are unnamed override-expressions.

Fixed tint:1737

Change-Id: I83dc49703eed015e9c183e804474886da5dad7b9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107685
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton
2022-10-31 17:26:10 +00:00
committed by Dawn LUCI CQ
parent cc85ed6dd1
commit 22c4850b06
14 changed files with 306 additions and 58 deletions

View File

@@ -40,7 +40,8 @@ TypeFlags FlagsFrom(const Type* element, ArrayCount count) {
}
}
if (std::holds_alternative<ConstantArrayCount>(count) ||
std::holds_alternative<OverrideArrayCount>(count)) {
std::holds_alternative<NamedOverrideArrayCount>(count) ||
std::holds_alternative<UnnamedOverrideArrayCount>(count)) {
if (element->HasFixedFootprint()) {
flags.Add(TypeFlag::kFixedFootprint);
}
@@ -92,8 +93,10 @@ std::string Array::FriendlyName(const SymbolTable& symbols) const {
out << "array<" << element_->FriendlyName(symbols);
if (auto* const_count = std::get_if<ConstantArrayCount>(&count_)) {
out << ", " << const_count->value;
} else if (auto* override_count = std::get_if<OverrideArrayCount>(&count_)) {
out << ", " << symbols.NameFor(override_count->variable->Declaration()->symbol);
} else if (auto* named_override_count = std::get_if<NamedOverrideArrayCount>(&count_)) {
out << ", " << symbols.NameFor(named_override_count->variable->Declaration()->symbol);
} else if (std::holds_alternative<UnnamedOverrideArrayCount>(count_)) {
out << ", [unnamed override-expression]";
}
out << ">";
return out.str();

View File

@@ -26,6 +26,7 @@
// Forward declarations
namespace tint::sem {
class Expression;
class GlobalVariable;
} // namespace tint::sem
@@ -48,11 +49,33 @@ struct ConstantArrayCount {
/// override N : i32;
/// type arr = array<i32, N>
/// ```
struct OverrideArrayCount {
struct NamedOverrideArrayCount {
/// The `override` variable.
const GlobalVariable* variable;
};
/// The variant of an ArrayCount when the count is an unnamed override variable.
/// Example:
/// ```
/// override N : i32;
/// type arr = array<i32, N*2>
/// ```
struct UnnamedOverrideArrayCount {
/// The unnamed override expression.
/// Note: Each AST expression gets a unique semantic expression node, so two equivalent AST
/// expressions will not result in the same `expr` pointer. This property is important to ensure
/// that two array declarations with equivalent AST expressions do not compare equal.
/// For example, consider:
/// ```
/// override size : u32;
/// var<workgroup> a : array<f32, size * 2>;
/// var<workgroup> b : array<f32, size * 2>;
/// ```
// The array count for `a` and `b` have equivalent AST expressions, but the types for `a` and
// `b` must not compare equal.
const Expression* expr;
};
/// The variant of an ArrayCount when the array is is runtime-sized.
/// Example:
/// ```
@@ -60,8 +83,12 @@ struct OverrideArrayCount {
/// ```
struct RuntimeArrayCount {};
/// An array count is either a constant-expression value, an override identifier, or runtime-sized.
using ArrayCount = std::variant<ConstantArrayCount, OverrideArrayCount, RuntimeArrayCount>;
/// An array count is either a constant-expression value, a named override identifier, an unnamed
/// override identifier, or runtime-sized.
using ArrayCount = std::variant<ConstantArrayCount,
NamedOverrideArrayCount,
UnnamedOverrideArrayCount,
RuntimeArrayCount>;
/// Equality operator
/// @param a the LHS ConstantArrayCount
@@ -75,10 +102,18 @@ inline bool operator==(const ConstantArrayCount& a, const ConstantArrayCount& b)
/// @param a the LHS OverrideArrayCount
/// @param b the RHS OverrideArrayCount
/// @returns true if @p a is equal to @p b
inline bool operator==(const OverrideArrayCount& a, const OverrideArrayCount& b) {
inline bool operator==(const NamedOverrideArrayCount& a, const NamedOverrideArrayCount& b) {
return a.variable == b.variable;
}
/// Equality operator
/// @param a the LHS OverrideArrayCount
/// @param b the RHS OverrideArrayCount
/// @returns true if @p a is equal to @p b
inline bool operator==(const UnnamedOverrideArrayCount& a, const UnnamedOverrideArrayCount& b) {
return a.expr == b.expr;
}
/// Equality operator
/// @returns true
inline bool operator==(const RuntimeArrayCount&, const RuntimeArrayCount&) {
@@ -90,9 +125,9 @@ inline bool operator==(const RuntimeArrayCount&, const RuntimeArrayCount&) {
/// @param b the RHS count
/// @returns true if @p a is equal to @p b
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, ConstantArrayCount> ||
std::is_same_v<T, OverrideArrayCount> ||
std::is_same_v<T, RuntimeArrayCount>>>
typename = std::enable_if_t<
std::is_same_v<T, ConstantArrayCount> || std::is_same_v<T, NamedOverrideArrayCount> ||
std::is_same_v<T, UnnamedOverrideArrayCount> || std::is_same_v<T, RuntimeArrayCount>>>
inline bool operator==(const ArrayCount& a, const T& b) {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
return std::visit(
@@ -178,8 +213,18 @@ class Array final : public Castable<Array, Type> {
/// @returns true if this array is sized using an const-expression
bool IsConstantSized() const { return std::holds_alternative<ConstantArrayCount>(count_); }
/// @returns true if this array is sized using an override variable
bool IsOverrideSized() const { return std::holds_alternative<OverrideArrayCount>(count_); }
/// @returns true if this array is sized using a named override variable
bool IsNamedOverrideSized() const {
return std::holds_alternative<NamedOverrideArrayCount>(count_);
}
/// @returns true if this array is sized using an unnamed override variable
bool IsUnnamedOverrideSized() const {
return std::holds_alternative<UnnamedOverrideArrayCount>(count_);
}
/// @returns true if this array is sized using a named or unnamed override variable
bool IsOverrideSized() const { return IsNamedOverrideSized() || IsUnnamedOverrideSized(); }
/// @returns true if this array is runtime sized
bool IsRuntimeSized() const { return std::holds_alternative<RuntimeArrayCount>(count_); }
@@ -213,17 +258,28 @@ class hash<tint::sem::ConstantArrayCount> {
}
};
/// Custom std::hash specialization for tint::sem::OverrideArrayCount.
/// Custom std::hash specialization for tint::sem::NamedOverrideArrayCount.
template <>
class hash<tint::sem::OverrideArrayCount> {
class hash<tint::sem::NamedOverrideArrayCount> {
public:
/// @param count the count to hash
/// @return the hash value
inline std::size_t operator()(const tint::sem::OverrideArrayCount& count) const {
inline std::size_t operator()(const tint::sem::NamedOverrideArrayCount& count) const {
return std::hash<decltype(count.variable)>()(count.variable);
}
};
/// Custom std::hash specialization for tint::sem::UnnamedOverrideArrayCount.
template <>
class hash<tint::sem::UnnamedOverrideArrayCount> {
public:
/// @param count the count to hash
/// @return the hash value
inline std::size_t operator()(const tint::sem::UnnamedOverrideArrayCount& count) const {
return std::hash<decltype(count.expr)>()(count.expr);
}
};
/// Custom std::hash specialization for tint::sem::RuntimeArrayCount.
template <>
class hash<tint::sem::RuntimeArrayCount> {

View File

@@ -127,31 +127,43 @@ TEST_F(ArrayTest, FriendlyNameStaticSizedNonImplicitStride) {
TEST_F(ArrayTest, IsConstructable) {
auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* override_sized = create<Array>(create<U32>(), OverrideArrayCount{}, 4u, 8u, 32u, 16u);
auto* named_override_sized =
create<Array>(create<U32>(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
auto* unnamed_override_sized =
create<Array>(create<U32>(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
EXPECT_TRUE(fixed_sized->IsConstructible());
EXPECT_FALSE(override_sized->IsConstructible());
EXPECT_FALSE(named_override_sized->IsConstructible());
EXPECT_FALSE(unnamed_override_sized->IsConstructible());
EXPECT_FALSE(runtime_sized->IsConstructible());
}
TEST_F(ArrayTest, HasCreationFixedFootprint) {
auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* override_sized = create<Array>(create<U32>(), OverrideArrayCount{}, 4u, 8u, 32u, 16u);
auto* named_override_sized =
create<Array>(create<U32>(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
auto* unnamed_override_sized =
create<Array>(create<U32>(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
EXPECT_TRUE(fixed_sized->HasCreationFixedFootprint());
EXPECT_FALSE(override_sized->HasCreationFixedFootprint());
EXPECT_FALSE(named_override_sized->HasCreationFixedFootprint());
EXPECT_FALSE(unnamed_override_sized->HasCreationFixedFootprint());
EXPECT_FALSE(runtime_sized->HasCreationFixedFootprint());
}
TEST_F(ArrayTest, HasFixedFootprint) {
auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* override_sized = create<Array>(create<U32>(), OverrideArrayCount{}, 4u, 8u, 32u, 16u);
auto* named_override_sized =
create<Array>(create<U32>(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
auto* unnamed_override_sized =
create<Array>(create<U32>(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
EXPECT_TRUE(fixed_sized->HasFixedFootprint());
EXPECT_TRUE(override_sized->HasFixedFootprint());
EXPECT_TRUE(named_override_sized->HasFixedFootprint());
EXPECT_TRUE(unnamed_override_sized->HasFixedFootprint());
EXPECT_FALSE(runtime_sized->HasFixedFootprint());
}