Fix overrides in array size.

This CL fixes the usage of overrides in array sizes. Currently
the usage will generate a validation error as we check that the
array size is const.

Bug: tint:1660
Change-Id: Ibf440905c30a73b581d55b0c071b8621b61605e6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/101900
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: dan sinclair <dsinclair@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
This commit is contained in:
dan sinclair
2022-09-22 22:28:21 +00:00
committed by Dawn LUCI CQ
parent 534a198f88
commit 78f8067fd5
42 changed files with 839 additions and 186 deletions

View File

@@ -16,15 +16,22 @@
#include <string>
#include "src/tint/ast/variable.h"
#include "src/tint/debug.h"
#include "src/tint/sem/variable.h"
#include "src/tint/symbol_table.h"
#include "src/tint/utils/hash.h"
TINT_INSTANTIATE_TYPEINFO(tint::sem::Array);
namespace tint::sem {
const char* Array::kErrExpectedConstantCount =
"array size is an override-expression, when expected a constant-expression.\n"
"Was the SubstituteOverride transform run?";
Array::Array(const Type* element,
uint32_t count,
ArrayCount count,
uint32_t align,
uint32_t size,
uint32_t stride,
@@ -35,8 +42,9 @@ Array::Array(const Type* element,
size_(size),
stride_(stride),
implicit_stride_(implicit_stride),
constructible_(count > 0 // Runtime-sized arrays are not constructible
&& element->IsConstructible()) {
// Only constant-expression sized arrays are constructible
constructible_(std::holds_alternative<ConstantArrayCount>(count) &&
element->IsConstructible()) {
TINT_ASSERT(Semantic, element_);
}
@@ -64,8 +72,10 @@ std::string Array::FriendlyName(const SymbolTable& symbols) const {
out << "@stride(" << stride_ << ") ";
}
out << "array<" << element_->FriendlyName(symbols);
if (!IsRuntimeSized()) {
out << ", " << count_;
if (auto* const_count = std::get_if<ConstantArrayCount>(&count_)) {
out << ", " << const_count->value;
} else if (auto* override_count = std::get_if<OverrideArrayCount>(&count_)) {
out << ", " << symbols.NameFor(override_count->variable->Declaration()->symbol);
}
out << ">";
return out.str();

View File

@@ -16,29 +16,116 @@
#define SRC_TINT_SEM_ARRAY_H_
#include <stdint.h>
#include <optional>
#include <string>
#include <variant>
#include "src/tint/sem/node.h"
#include "src/tint/sem/type.h"
#include "src/tint/utils/compiler_macros.h"
// Forward declarations
namespace tint::sem {
class GlobalVariable;
} // namespace tint::sem
namespace tint::sem {
/// The variant of an ArrayCount when the array is a constant expression.
/// Example:
/// ```
/// const N = 123;
/// type arr = array<i32, N>
/// ```
struct ConstantArrayCount {
/// The array count constant-expression value.
uint32_t value;
};
/// The variant of an ArrayCount when the count is a named override variable.
/// Example:
/// ```
/// override N : i32;
/// type arr = array<i32, N>
/// ```
struct OverrideArrayCount {
/// The `override` variable.
const GlobalVariable* variable;
};
/// The variant of an ArrayCount when the array is is runtime-sized.
/// Example:
/// ```
/// type arr = array<i32>
/// ```
struct RuntimeArrayCount {};
/// An array count is either a constant-expression value, an override identifier, or runtime-sized.
using ArrayCount = std::variant<ConstantArrayCount, OverrideArrayCount, RuntimeArrayCount>;
/// Equality operator
/// @param a the LHS ConstantArrayCount
/// @param b the RHS ConstantArrayCount
/// @returns true if @p a is equal to @p b
inline bool operator==(const ConstantArrayCount& a, const ConstantArrayCount& b) {
return a.value == b.value;
}
/// Equality operator
/// @param a the LHS OverrideArrayCount
/// @param b the RHS OverrideArrayCount
/// @returns true if @p a is equal to @p b
inline bool operator==(const OverrideArrayCount& a, const OverrideArrayCount& b) {
return a.variable == b.variable;
}
/// Equality operator
/// @returns true
inline bool operator==(const RuntimeArrayCount&, const RuntimeArrayCount&) {
return true;
}
/// Equality operator
/// @param a the LHS ArrayCount
/// @param b the RHS count
/// @returns true if @p a is equal to @p b
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, ConstantArrayCount> ||
std::is_same_v<T, OverrideArrayCount> ||
std::is_same_v<T, RuntimeArrayCount>>>
inline bool operator==(const ArrayCount& a, const T& b) {
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
return std::visit(
[&](auto count) {
if constexpr (std::is_same_v<std::decay_t<decltype(count)>, T>) {
return count == b;
}
return false;
},
a);
TINT_END_DISABLE_WARNING(UNREACHABLE_CODE);
}
/// Array holds the semantic information for Array nodes.
class Array final : public Castable<Array, Type> {
public:
/// An error message string stating that the array count was expected to be a constant
/// expression. Used by multiple writers and transforms.
static const char* kErrExpectedConstantCount;
/// Constructor
/// @param element the array element type
/// @param count the number of elements in the array. 0 represents a
/// runtime-sized array.
/// @param count the number of elements in the array.
/// @param align the byte alignment of the array
/// @param size the byte size of the array
/// @param size the byte size of the array. The size will be 0 if the array element count is
/// pipeline overridable.
/// @param stride the number of bytes from the start of one element of the
/// array to the start of the next element
/// array to the start of the next element
/// @param implicit_stride the number of bytes from the start of one element
/// of the array to the start of the next element, if there was no `@stride`
/// attribute applied.
Array(Type const* element,
uint32_t count,
ArrayCount count,
uint32_t align,
uint32_t size,
uint32_t stride,
@@ -54,9 +141,16 @@ class Array final : public Castable<Array, Type> {
/// @return the array element type
Type const* ElemType() const { return element_; }
/// @returns the number of elements in the array. 0 represents a runtime-sized
/// array.
uint32_t Count() const { return count_; }
/// @returns the number of elements in the array.
const ArrayCount& Count() const { return count_; }
/// @returns the array count if the count is a constant expression, otherwise returns nullopt.
inline std::optional<uint32_t> ConstantCount() const {
if (auto* count = std::get_if<ConstantArrayCount>(&count_)) {
return count->value;
}
return std::nullopt;
}
/// @returns the byte alignment of the array
/// @note this may differ from the alignment of a structure member of this
@@ -81,8 +175,14 @@ class Array final : public Castable<Array, Type> {
/// natural stride
bool IsStrideImplicit() const { return stride_ == implicit_stride_; }
/// @returns true if this array is sized using an constant expression
bool IsConstantSized() const { return std::holds_alternative<ConstantArrayCount>(count_); }
/// @returns true if this array is sized using an override variable
bool IsOverrideSized() const { return std::holds_alternative<OverrideArrayCount>(count_); }
/// @returns true if this array is runtime sized
bool IsRuntimeSized() const { return count_ == 0; }
bool IsRuntimeSized() const { return std::holds_alternative<RuntimeArrayCount>(count_); }
/// @returns true if constructible as per
/// https://gpuweb.github.io/gpuweb/wgsl/#constructible-types
@@ -95,7 +195,7 @@ class Array final : public Castable<Array, Type> {
private:
Type const* const element_;
const uint32_t count_;
const ArrayCount count_;
const uint32_t align_;
const uint32_t size_;
const uint32_t stride_;
@@ -105,4 +205,38 @@ class Array final : public Castable<Array, Type> {
} // namespace tint::sem
namespace std {
/// Custom std::hash specialization for tint::sem::ConstantArrayCount.
template <>
class hash<tint::sem::ConstantArrayCount> {
public:
/// @param count the count to hash
/// @return the hash value
inline std::size_t operator()(const tint::sem::ConstantArrayCount& count) const {
return std::hash<decltype(count.value)>()(count.value);
}
};
/// Custom std::hash specialization for tint::sem::OverrideArrayCount.
template <>
class hash<tint::sem::OverrideArrayCount> {
public:
/// @param count the count to hash
/// @return the hash value
inline std::size_t operator()(const tint::sem::OverrideArrayCount& count) const {
return std::hash<decltype(count.variable)>()(count.variable);
}
};
/// Custom std::hash specialization for tint::sem::RuntimeArrayCount.
template <>
class hash<tint::sem::RuntimeArrayCount> {
public:
/// @return the hash value
inline std::size_t operator()(const tint::sem::RuntimeArrayCount&) const { return 42; }
};
} // namespace std
#endif // SRC_TINT_SEM_ARRAY_H_

View File

@@ -21,16 +21,16 @@ namespace {
using ArrayTest = TestHelper;
TEST_F(ArrayTest, CreateSizedArray) {
auto* a = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), 3u, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), 2u, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), 2u, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 17u);
auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
EXPECT_EQ(a->ElemType(), create<U32>());
EXPECT_EQ(a->Count(), 2u);
EXPECT_EQ(a->Count(), ConstantArrayCount{2u});
EXPECT_EQ(a->Align(), 4u);
EXPECT_EQ(a->Size(), 8u);
EXPECT_EQ(a->Stride(), 32u);
@@ -47,15 +47,15 @@ TEST_F(ArrayTest, CreateSizedArray) {
}
TEST_F(ArrayTest, CreateRuntimeArray) {
auto* a = create<Array>(create<U32>(), 0u, 4u, 8u, 32u, 32u);
auto* b = create<Array>(create<U32>(), 0u, 4u, 8u, 32u, 32u);
auto* c = create<Array>(create<U32>(), 0u, 5u, 8u, 32u, 32u);
auto* d = create<Array>(create<U32>(), 0u, 4u, 9u, 32u, 32u);
auto* e = create<Array>(create<U32>(), 0u, 4u, 8u, 33u, 32u);
auto* f = create<Array>(create<U32>(), 0u, 4u, 8u, 33u, 17u);
auto* a = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u);
auto* b = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 32u);
auto* c = create<Array>(create<U32>(), RuntimeArrayCount{}, 5u, 8u, 32u, 32u);
auto* d = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 9u, 32u, 32u);
auto* e = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 33u, 32u);
auto* f = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 33u, 17u);
EXPECT_EQ(a->ElemType(), create<U32>());
EXPECT_EQ(a->Count(), 0u);
EXPECT_EQ(a->Count(), sem::RuntimeArrayCount{});
EXPECT_EQ(a->Align(), 4u);
EXPECT_EQ(a->Size(), 8u);
EXPECT_EQ(a->Stride(), 32u);
@@ -71,13 +71,13 @@ TEST_F(ArrayTest, CreateRuntimeArray) {
}
TEST_F(ArrayTest, Hash) {
auto* a = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), 3u, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), 2u, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), 2u, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 17u);
auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
EXPECT_EQ(a->Hash(), b->Hash());
EXPECT_NE(a->Hash(), c->Hash());
@@ -88,13 +88,13 @@ TEST_F(ArrayTest, Hash) {
}
TEST_F(ArrayTest, Equals) {
auto* a = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), 2u, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), 3u, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), 2u, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), 2u, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), 2u, 4u, 8u, 33u, 17u);
auto* a = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* b = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
auto* c = create<Array>(create<U32>(), ConstantArrayCount{3u}, 4u, 8u, 32u, 16u);
auto* d = create<Array>(create<U32>(), ConstantArrayCount{2u}, 5u, 8u, 32u, 16u);
auto* e = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 9u, 32u, 16u);
auto* f = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 16u);
auto* g = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 33u, 17u);
EXPECT_TRUE(a->Equals(*b));
EXPECT_FALSE(a->Equals(*c));
@@ -106,22 +106,22 @@ TEST_F(ArrayTest, Equals) {
}
TEST_F(ArrayTest, FriendlyNameRuntimeSized) {
auto* arr = create<Array>(create<I32>(), 0u, 0u, 4u, 4u, 4u);
auto* arr = create<Array>(create<I32>(), RuntimeArrayCount{}, 0u, 4u, 4u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32>");
}
TEST_F(ArrayTest, FriendlyNameStaticSized) {
auto* arr = create<Array>(create<I32>(), 5u, 4u, 20u, 4u, 4u);
auto* arr = create<Array>(create<I32>(), ConstantArrayCount{5u}, 4u, 20u, 4u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "array<i32, 5>");
}
TEST_F(ArrayTest, FriendlyNameRuntimeSizedNonImplicitStride) {
auto* arr = create<Array>(create<I32>(), 0u, 0u, 4u, 8u, 4u);
auto* arr = create<Array>(create<I32>(), RuntimeArrayCount{}, 0u, 4u, 8u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32>");
}
TEST_F(ArrayTest, FriendlyNameStaticSizedNonImplicitStride) {
auto* arr = create<Array>(create<I32>(), 5u, 4u, 20u, 8u, 4u);
auto* arr = create<Array>(create<I32>(), ConstantArrayCount{5u}, 4u, 20u, 8u, 4u);
EXPECT_EQ(arr->FriendlyName(Symbols()), "@stride(8) array<i32, 5>");
}

View File

@@ -246,7 +246,9 @@ const Type* Type::ElementOf(const Type* ty, uint32_t* count /* = nullptr */) {
},
[&](const Array* a) {
if (count) {
*count = a->Count();
if (auto* const_count = std::get_if<ConstantArrayCount>(&a->Count())) {
*count = const_count->value;
}
}
return a->ElemType();
},

View File

@@ -62,56 +62,56 @@ struct TypeTest : public TestHelper {
/* size_no_padding*/ 4u);
const sem::Array* arr_i32 = create<Array>(
/* element */ i32,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
const sem::Array* arr_ai = create<Array>(
/* element */ ai,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,
/* implicit_stride */ 5u * 4u);
const sem::Array* arr_vec3_i32 = create<Array>(
/* element */ vec3_i32,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 16u,
/* size */ 5u * 16u,
/* stride */ 5u * 16u,
/* implicit_stride */ 5u * 16u);
const sem::Array* arr_vec3_ai = create<Array>(
/* element */ vec3_ai,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 16u,
/* size */ 5u * 16u,
/* stride */ 5u * 16u,
/* implicit_stride */ 5u * 16u);
const sem::Array* arr_mat4x3_f16 = create<Array>(
/* element */ mat4x3_f16,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 32u,
/* size */ 5u * 32u,
/* stride */ 5u * 32u,
/* implicit_stride */ 5u * 32u);
const sem::Array* arr_mat4x3_f32 = create<Array>(
/* element */ mat4x3_f32,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 64u,
/* size */ 5u * 64u,
/* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u);
const sem::Array* arr_mat4x3_af = create<Array>(
/* element */ mat4x3_af,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 64u,
/* size */ 5u * 64u,
/* stride */ 5u * 64u,
/* implicit_stride */ 5u * 64u);
const sem::Array* arr_str = create<Array>(
/* element */ str,
/* count */ 5u,
/* count */ ConstantArrayCount{5u},
/* align */ 4u,
/* size */ 5u * 4u,
/* stride */ 5u * 4u,