tint/resolver: Allow array sizes to be unnamed override-expressions
I got the rules around this wrong. This should be allowed, but the array types cannot compare equal if they are unnamed override-expressions. Fixed tint:1737 Change-Id: I83dc49703eed015e9c183e804474886da5dad7b9 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107685 Reviewed-by: James Price <jrprice@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Auto-Submit: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
cc85ed6dd1
commit
22c4850b06
|
@ -2710,14 +2710,15 @@ utils::Result<sem::ArrayCount> Resolver::ArrayCount(const ast::Expression* count
|
||||||
return utils::Failure;
|
return utils::Failure;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: If the array count is an 'override', but not a identifier expression, we do not return
|
if (count_sem->Stage() == sem::EvaluationStage::kOverride) {
|
||||||
// here, but instead continue to the ConstantValue() check below.
|
// array count is an override expression.
|
||||||
if (auto* user = count_sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
|
// Is the count a named 'override'?
|
||||||
if (auto* global = user->Variable()->As<sem::GlobalVariable>()) {
|
if (auto* user = count_sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
|
||||||
if (global->Declaration()->Is<ast::Override>()) {
|
if (auto* global = user->Variable()->As<sem::GlobalVariable>()) {
|
||||||
return sem::ArrayCount{sem::OverrideArrayCount{global}};
|
return sem::ArrayCount{sem::NamedOverrideArrayCount{global}};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return sem::ArrayCount{sem::UnnamedOverrideArrayCount{count_sem}};
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* count_val = count_sem->ConstantValue();
|
auto* count_val = count_sem->ConstantValue();
|
||||||
|
|
|
@ -486,8 +486,8 @@ TEST_F(ResolverTest, ArraySize_SignedConst) {
|
||||||
EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
|
EXPECT_EQ(ary->Count(), sem::ConstantArrayCount{10u});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTest, ArraySize_Override) {
|
TEST_F(ResolverTest, ArraySize_NamedOverride) {
|
||||||
// override size = 0;
|
// override size = 10i;
|
||||||
// var<workgroup> a : array<f32, size>;
|
// var<workgroup> a : array<f32, size>;
|
||||||
auto* override = Override("size", Expr(10_i));
|
auto* override = Override("size", Expr(10_i));
|
||||||
auto* a = GlobalVar("a", ty.array(ty.f32(), Expr("size")), ast::AddressSpace::kWorkgroup);
|
auto* a = GlobalVar("a", ty.array(ty.f32(), Expr("size")), ast::AddressSpace::kWorkgroup);
|
||||||
|
@ -500,11 +500,11 @@ TEST_F(ResolverTest, ArraySize_Override) {
|
||||||
auto* ary = ref->StoreType()->As<sem::Array>();
|
auto* ary = ref->StoreType()->As<sem::Array>();
|
||||||
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
|
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
|
||||||
ASSERT_NE(sem_override, nullptr);
|
ASSERT_NE(sem_override, nullptr);
|
||||||
EXPECT_EQ(ary->Count(), sem::OverrideArrayCount{sem_override});
|
EXPECT_EQ(ary->Count(), sem::NamedOverrideArrayCount{sem_override});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTest, ArraySize_Override_Equivalence) {
|
TEST_F(ResolverTest, ArraySize_NamedOverride_Equivalence) {
|
||||||
// override size = 0;
|
// override size = 10i;
|
||||||
// var<workgroup> a : array<f32, size>;
|
// var<workgroup> a : array<f32, size>;
|
||||||
// var<workgroup> b : array<f32, size>;
|
// var<workgroup> b : array<f32, size>;
|
||||||
auto* override = Override("size", Expr(10_i));
|
auto* override = Override("size", Expr(10_i));
|
||||||
|
@ -525,11 +525,58 @@ TEST_F(ResolverTest, ArraySize_Override_Equivalence) {
|
||||||
|
|
||||||
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
|
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
|
||||||
ASSERT_NE(sem_override, nullptr);
|
ASSERT_NE(sem_override, nullptr);
|
||||||
EXPECT_EQ(ary_a->Count(), sem::OverrideArrayCount{sem_override});
|
EXPECT_EQ(ary_a->Count(), sem::NamedOverrideArrayCount{sem_override});
|
||||||
EXPECT_EQ(ary_b->Count(), sem::OverrideArrayCount{sem_override});
|
EXPECT_EQ(ary_b->Count(), sem::NamedOverrideArrayCount{sem_override});
|
||||||
EXPECT_EQ(ary_a, ary_b);
|
EXPECT_EQ(ary_a, ary_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverTest, ArraySize_UnnamedOverride) {
|
||||||
|
// override size = 10i;
|
||||||
|
// var<workgroup> a : array<f32, size*2>;
|
||||||
|
auto* override = Override("size", Expr(10_i));
|
||||||
|
auto* cnt = Mul("size", 2_a);
|
||||||
|
auto* a = GlobalVar("a", ty.array(ty.f32(), cnt), ast::AddressSpace::kWorkgroup);
|
||||||
|
|
||||||
|
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
|
||||||
|
ASSERT_NE(TypeOf(a), nullptr);
|
||||||
|
auto* ref = TypeOf(a)->As<sem::Reference>();
|
||||||
|
ASSERT_NE(ref, nullptr);
|
||||||
|
auto* ary = ref->StoreType()->As<sem::Array>();
|
||||||
|
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
|
||||||
|
ASSERT_NE(sem_override, nullptr);
|
||||||
|
EXPECT_EQ(ary->Count(), sem::UnnamedOverrideArrayCount{Sem().Get(cnt)});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverTest, ArraySize_UnamedOverride_Equivalence) {
|
||||||
|
// override size = 10i;
|
||||||
|
// var<workgroup> a : array<f32, size>;
|
||||||
|
// var<workgroup> b : array<f32, size>;
|
||||||
|
auto* override = Override("size", Expr(10_i));
|
||||||
|
auto* a_cnt = Mul("size", 2_a);
|
||||||
|
auto* b_cnt = Mul("size", 2_a);
|
||||||
|
auto* a = GlobalVar("a", ty.array(ty.f32(), a_cnt), ast::AddressSpace::kWorkgroup);
|
||||||
|
auto* b = GlobalVar("b", ty.array(ty.f32(), b_cnt), ast::AddressSpace::kWorkgroup);
|
||||||
|
|
||||||
|
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
|
||||||
|
ASSERT_NE(TypeOf(a), nullptr);
|
||||||
|
auto* ref_a = TypeOf(a)->As<sem::Reference>();
|
||||||
|
ASSERT_NE(ref_a, nullptr);
|
||||||
|
auto* ary_a = ref_a->StoreType()->As<sem::Array>();
|
||||||
|
|
||||||
|
ASSERT_NE(TypeOf(b), nullptr);
|
||||||
|
auto* ref_b = TypeOf(b)->As<sem::Reference>();
|
||||||
|
ASSERT_NE(ref_b, nullptr);
|
||||||
|
auto* ary_b = ref_b->StoreType()->As<sem::Array>();
|
||||||
|
|
||||||
|
auto* sem_override = Sem().Get<sem::GlobalVariable>(override);
|
||||||
|
ASSERT_NE(sem_override, nullptr);
|
||||||
|
EXPECT_EQ(ary_a->Count(), sem::UnnamedOverrideArrayCount{Sem().Get(a_cnt)});
|
||||||
|
EXPECT_EQ(ary_b->Count(), sem::UnnamedOverrideArrayCount{Sem().Get(b_cnt)});
|
||||||
|
EXPECT_NE(ary_a, ary_b);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTest, Expr_Bitcast) {
|
TEST_F(ResolverTest, Expr_Bitcast) {
|
||||||
GlobalVar("name", ty.f32(), ast::AddressSpace::kPrivate);
|
GlobalVar("name", ty.f32(), ast::AddressSpace::kPrivate);
|
||||||
|
|
||||||
|
@ -2331,8 +2378,8 @@ TEST_F(ResolverTest, Literal_F16WithExtension) {
|
||||||
EXPECT_TRUE(r()->Resolve());
|
EXPECT_TRUE(r()->Resolve());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Windows debug builds have significantly smaller stack than other builds, and these tests will stack
|
// Windows debug builds have significantly smaller stack than other builds, and these tests will
|
||||||
// overflow.
|
// stack overflow.
|
||||||
#if !defined(NDEBUG)
|
#if !defined(NDEBUG)
|
||||||
|
|
||||||
TEST_F(ResolverTest, ScopeDepth_NestedBlocks) {
|
TEST_F(ResolverTest, ScopeDepth_NestedBlocks) {
|
||||||
|
|
|
@ -371,7 +371,7 @@ TEST_F(ResolverTypeValidationTest, ArraySize_TooBig_ExplicitStride) {
|
||||||
"12:34 error: array byte size (0x7a1185ee00) must not exceed 0xffffffff bytes");
|
"12:34 error: array byte size (0x7a1185ee00) must not exceed 0xffffffff bytes");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTypeValidationTest, ArraySize_Override_PrivateVar) {
|
TEST_F(ResolverTypeValidationTest, ArraySize_NamedOverride_PrivateVar) {
|
||||||
// override size = 10i;
|
// override size = 10i;
|
||||||
// var<private> a : array<f32, size>;
|
// var<private> a : array<f32, size>;
|
||||||
Override("size", Expr(10_i));
|
Override("size", Expr(10_i));
|
||||||
|
@ -382,19 +382,7 @@ TEST_F(ResolverTypeValidationTest, ArraySize_Override_PrivateVar) {
|
||||||
"type of a 'var<workgroup>'");
|
"type of a 'var<workgroup>'");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTypeValidationTest, ArraySize_Override_ComplexExpr) {
|
TEST_F(ResolverTypeValidationTest, ArraySize_NamedOverride_InArray) {
|
||||||
// override size = 10i;
|
|
||||||
// var<workgroup> a : array<f32, size + 1>;
|
|
||||||
Override("size", Expr(10_i));
|
|
||||||
GlobalVar("a", ty.array(ty.f32(), Add(Source{{12, 34}}, "size", 1_i)),
|
|
||||||
ast::AddressSpace::kWorkgroup);
|
|
||||||
EXPECT_FALSE(r()->Resolve());
|
|
||||||
EXPECT_EQ(r()->error(),
|
|
||||||
"12:34 error: array count must evaluate to a constant integer expression or override "
|
|
||||||
"variable");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ResolverTypeValidationTest, ArraySize_Override_InArray) {
|
|
||||||
// override size = 10i;
|
// override size = 10i;
|
||||||
// var<workgroup> a : array<array<f32, size>, 4>;
|
// var<workgroup> a : array<array<f32, size>, 4>;
|
||||||
Override("size", Expr(10_i));
|
Override("size", Expr(10_i));
|
||||||
|
@ -406,7 +394,7 @@ TEST_F(ResolverTypeValidationTest, ArraySize_Override_InArray) {
|
||||||
"type of a 'var<workgroup>'");
|
"type of a 'var<workgroup>'");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTypeValidationTest, ArraySize_Override_InStruct) {
|
TEST_F(ResolverTypeValidationTest, ArraySize_NamedOverride_InStruct) {
|
||||||
// override size = 10i;
|
// override size = 10i;
|
||||||
// struct S {
|
// struct S {
|
||||||
// a : array<f32, size>
|
// a : array<f32, size>
|
||||||
|
@ -419,7 +407,7 @@ TEST_F(ResolverTypeValidationTest, ArraySize_Override_InStruct) {
|
||||||
"type of a 'var<workgroup>'");
|
"type of a 'var<workgroup>'");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionVar_Explicit) {
|
TEST_F(ResolverTypeValidationTest, ArraySize_NamedOverride_FunctionVar_Explicit) {
|
||||||
// override size = 10i;
|
// override size = 10i;
|
||||||
// fn f() {
|
// fn f() {
|
||||||
// var a : array<f32, size>;
|
// var a : array<f32, size>;
|
||||||
|
@ -435,7 +423,7 @@ TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionVar_Explicit) {
|
||||||
"type of a 'var<workgroup>'");
|
"type of a 'var<workgroup>'");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionLet_Explicit) {
|
TEST_F(ResolverTypeValidationTest, ArraySize_NamedOverride_FunctionLet_Explicit) {
|
||||||
// override size = 10i;
|
// override size = 10i;
|
||||||
// fn f() {
|
// fn f() {
|
||||||
// var a : array<f32, size>;
|
// var a : array<f32, size>;
|
||||||
|
@ -451,7 +439,7 @@ TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionLet_Explicit) {
|
||||||
"type of a 'var<workgroup>'");
|
"type of a 'var<workgroup>'");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionVar_Implicit) {
|
TEST_F(ResolverTypeValidationTest, ArraySize_NamedOverride_FunctionVar_Implicit) {
|
||||||
// override size = 10i;
|
// override size = 10i;
|
||||||
// var<workgroup> w : array<f32, size>;
|
// var<workgroup> w : array<f32, size>;
|
||||||
// fn f() {
|
// fn f() {
|
||||||
|
@ -469,7 +457,7 @@ TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionVar_Implicit) {
|
||||||
"type of a 'var<workgroup>'");
|
"type of a 'var<workgroup>'");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionLet_Implicit) {
|
TEST_F(ResolverTypeValidationTest, ArraySize_NamedOverride_FunctionLet_Implicit) {
|
||||||
// override size = 10i;
|
// override size = 10i;
|
||||||
// var<workgroup> w : array<f32, size>;
|
// var<workgroup> w : array<f32, size>;
|
||||||
// fn f() {
|
// fn f() {
|
||||||
|
@ -487,7 +475,24 @@ TEST_F(ResolverTypeValidationTest, ArraySize_Override_FunctionLet_Implicit) {
|
||||||
"type of a 'var<workgroup>'");
|
"type of a 'var<workgroup>'");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTypeValidationTest, ArraySize_Override_Param) {
|
TEST_F(ResolverTypeValidationTest, ArraySize_UnnamedOverride_Equivalence) {
|
||||||
|
// override size = 10i;
|
||||||
|
// var<workgroup> a : array<f32, size + 1>;
|
||||||
|
// var<workgroup> b : array<f32, size + 1>;
|
||||||
|
// fn f() {
|
||||||
|
// a = b;
|
||||||
|
// }
|
||||||
|
Override("size", Expr(10_i));
|
||||||
|
GlobalVar("a", ty.array(ty.f32(), Add("size", 1_i)), ast::AddressSpace::kWorkgroup);
|
||||||
|
GlobalVar("b", ty.array(ty.f32(), Add("size", 1_i)), ast::AddressSpace::kWorkgroup);
|
||||||
|
WrapInFunction(Assign(Source{{12, 34}}, "a", "b"));
|
||||||
|
EXPECT_FALSE(r()->Resolve());
|
||||||
|
EXPECT_EQ(r()->error(),
|
||||||
|
"12:34 error: cannot assign 'array<f32, [unnamed override-expression]>' to "
|
||||||
|
"'array<f32, [unnamed override-expression]>'");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverTypeValidationTest, ArraySize_NamedOverride_Param) {
|
||||||
// override size = 10i;
|
// override size = 10i;
|
||||||
// fn f(a : array<f32, size>) {
|
// fn f(a : array<f32, size>) {
|
||||||
// }
|
// }
|
||||||
|
@ -498,7 +503,7 @@ TEST_F(ResolverTypeValidationTest, ArraySize_Override_Param) {
|
||||||
EXPECT_EQ(r()->error(), "12:34 error: type of function parameter must be constructible");
|
EXPECT_EQ(r()->error(), "12:34 error: type of function parameter must be constructible");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverTypeValidationTest, ArraySize_Override_ReturnType) {
|
TEST_F(ResolverTypeValidationTest, ArraySize_NamedOverride_ReturnType) {
|
||||||
// override size = 10i;
|
// override size = 10i;
|
||||||
// fn f() -> array<f32, size> {
|
// fn f() -> array<f32, size> {
|
||||||
// }
|
// }
|
||||||
|
|
|
@ -40,7 +40,8 @@ TypeFlags FlagsFrom(const Type* element, ArrayCount count) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (std::holds_alternative<ConstantArrayCount>(count) ||
|
if (std::holds_alternative<ConstantArrayCount>(count) ||
|
||||||
std::holds_alternative<OverrideArrayCount>(count)) {
|
std::holds_alternative<NamedOverrideArrayCount>(count) ||
|
||||||
|
std::holds_alternative<UnnamedOverrideArrayCount>(count)) {
|
||||||
if (element->HasFixedFootprint()) {
|
if (element->HasFixedFootprint()) {
|
||||||
flags.Add(TypeFlag::kFixedFootprint);
|
flags.Add(TypeFlag::kFixedFootprint);
|
||||||
}
|
}
|
||||||
|
@ -92,8 +93,10 @@ std::string Array::FriendlyName(const SymbolTable& symbols) const {
|
||||||
out << "array<" << element_->FriendlyName(symbols);
|
out << "array<" << element_->FriendlyName(symbols);
|
||||||
if (auto* const_count = std::get_if<ConstantArrayCount>(&count_)) {
|
if (auto* const_count = std::get_if<ConstantArrayCount>(&count_)) {
|
||||||
out << ", " << const_count->value;
|
out << ", " << const_count->value;
|
||||||
} else if (auto* override_count = std::get_if<OverrideArrayCount>(&count_)) {
|
} else if (auto* named_override_count = std::get_if<NamedOverrideArrayCount>(&count_)) {
|
||||||
out << ", " << symbols.NameFor(override_count->variable->Declaration()->symbol);
|
out << ", " << symbols.NameFor(named_override_count->variable->Declaration()->symbol);
|
||||||
|
} else if (std::holds_alternative<UnnamedOverrideArrayCount>(count_)) {
|
||||||
|
out << ", [unnamed override-expression]";
|
||||||
}
|
}
|
||||||
out << ">";
|
out << ">";
|
||||||
return out.str();
|
return out.str();
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
|
|
||||||
// Forward declarations
|
// Forward declarations
|
||||||
namespace tint::sem {
|
namespace tint::sem {
|
||||||
|
class Expression;
|
||||||
class GlobalVariable;
|
class GlobalVariable;
|
||||||
} // namespace tint::sem
|
} // namespace tint::sem
|
||||||
|
|
||||||
|
@ -48,11 +49,33 @@ struct ConstantArrayCount {
|
||||||
/// override N : i32;
|
/// override N : i32;
|
||||||
/// type arr = array<i32, N>
|
/// type arr = array<i32, N>
|
||||||
/// ```
|
/// ```
|
||||||
struct OverrideArrayCount {
|
struct NamedOverrideArrayCount {
|
||||||
/// The `override` variable.
|
/// The `override` variable.
|
||||||
const GlobalVariable* variable;
|
const GlobalVariable* variable;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// The variant of an ArrayCount when the count is an unnamed override variable.
|
||||||
|
/// Example:
|
||||||
|
/// ```
|
||||||
|
/// override N : i32;
|
||||||
|
/// type arr = array<i32, N*2>
|
||||||
|
/// ```
|
||||||
|
struct UnnamedOverrideArrayCount {
|
||||||
|
/// The unnamed override expression.
|
||||||
|
/// Note: Each AST expression gets a unique semantic expression node, so two equivalent AST
|
||||||
|
/// expressions will not result in the same `expr` pointer. This property is important to ensure
|
||||||
|
/// that two array declarations with equivalent AST expressions do not compare equal.
|
||||||
|
/// For example, consider:
|
||||||
|
/// ```
|
||||||
|
/// override size : u32;
|
||||||
|
/// var<workgroup> a : array<f32, size * 2>;
|
||||||
|
/// var<workgroup> b : array<f32, size * 2>;
|
||||||
|
/// ```
|
||||||
|
// The array count for `a` and `b` have equivalent AST expressions, but the types for `a` and
|
||||||
|
// `b` must not compare equal.
|
||||||
|
const Expression* expr;
|
||||||
|
};
|
||||||
|
|
||||||
/// The variant of an ArrayCount when the array is is runtime-sized.
|
/// The variant of an ArrayCount when the array is is runtime-sized.
|
||||||
/// Example:
|
/// Example:
|
||||||
/// ```
|
/// ```
|
||||||
|
@ -60,8 +83,12 @@ struct OverrideArrayCount {
|
||||||
/// ```
|
/// ```
|
||||||
struct RuntimeArrayCount {};
|
struct RuntimeArrayCount {};
|
||||||
|
|
||||||
/// An array count is either a constant-expression value, an override identifier, or runtime-sized.
|
/// An array count is either a constant-expression value, a named override identifier, an unnamed
|
||||||
using ArrayCount = std::variant<ConstantArrayCount, OverrideArrayCount, RuntimeArrayCount>;
|
/// override identifier, or runtime-sized.
|
||||||
|
using ArrayCount = std::variant<ConstantArrayCount,
|
||||||
|
NamedOverrideArrayCount,
|
||||||
|
UnnamedOverrideArrayCount,
|
||||||
|
RuntimeArrayCount>;
|
||||||
|
|
||||||
/// Equality operator
|
/// Equality operator
|
||||||
/// @param a the LHS ConstantArrayCount
|
/// @param a the LHS ConstantArrayCount
|
||||||
|
@ -75,10 +102,18 @@ inline bool operator==(const ConstantArrayCount& a, const ConstantArrayCount& b)
|
||||||
/// @param a the LHS OverrideArrayCount
|
/// @param a the LHS OverrideArrayCount
|
||||||
/// @param b the RHS OverrideArrayCount
|
/// @param b the RHS OverrideArrayCount
|
||||||
/// @returns true if @p a is equal to @p b
|
/// @returns true if @p a is equal to @p b
|
||||||
inline bool operator==(const OverrideArrayCount& a, const OverrideArrayCount& b) {
|
inline bool operator==(const NamedOverrideArrayCount& a, const NamedOverrideArrayCount& b) {
|
||||||
return a.variable == b.variable;
|
return a.variable == b.variable;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Equality operator
|
||||||
|
/// @param a the LHS OverrideArrayCount
|
||||||
|
/// @param b the RHS OverrideArrayCount
|
||||||
|
/// @returns true if @p a is equal to @p b
|
||||||
|
inline bool operator==(const UnnamedOverrideArrayCount& a, const UnnamedOverrideArrayCount& b) {
|
||||||
|
return a.expr == b.expr;
|
||||||
|
}
|
||||||
|
|
||||||
/// Equality operator
|
/// Equality operator
|
||||||
/// @returns true
|
/// @returns true
|
||||||
inline bool operator==(const RuntimeArrayCount&, const RuntimeArrayCount&) {
|
inline bool operator==(const RuntimeArrayCount&, const RuntimeArrayCount&) {
|
||||||
|
@ -90,9 +125,9 @@ inline bool operator==(const RuntimeArrayCount&, const RuntimeArrayCount&) {
|
||||||
/// @param b the RHS count
|
/// @param b the RHS count
|
||||||
/// @returns true if @p a is equal to @p b
|
/// @returns true if @p a is equal to @p b
|
||||||
template <typename T,
|
template <typename T,
|
||||||
typename = std::enable_if_t<std::is_same_v<T, ConstantArrayCount> ||
|
typename = std::enable_if_t<
|
||||||
std::is_same_v<T, OverrideArrayCount> ||
|
std::is_same_v<T, ConstantArrayCount> || std::is_same_v<T, NamedOverrideArrayCount> ||
|
||||||
std::is_same_v<T, RuntimeArrayCount>>>
|
std::is_same_v<T, UnnamedOverrideArrayCount> || std::is_same_v<T, RuntimeArrayCount>>>
|
||||||
inline bool operator==(const ArrayCount& a, const T& b) {
|
inline bool operator==(const ArrayCount& a, const T& b) {
|
||||||
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
|
TINT_BEGIN_DISABLE_WARNING(UNREACHABLE_CODE);
|
||||||
return std::visit(
|
return std::visit(
|
||||||
|
@ -178,8 +213,18 @@ class Array final : public Castable<Array, Type> {
|
||||||
/// @returns true if this array is sized using an const-expression
|
/// @returns true if this array is sized using an const-expression
|
||||||
bool IsConstantSized() const { return std::holds_alternative<ConstantArrayCount>(count_); }
|
bool IsConstantSized() const { return std::holds_alternative<ConstantArrayCount>(count_); }
|
||||||
|
|
||||||
/// @returns true if this array is sized using an override variable
|
/// @returns true if this array is sized using a named override variable
|
||||||
bool IsOverrideSized() const { return std::holds_alternative<OverrideArrayCount>(count_); }
|
bool IsNamedOverrideSized() const {
|
||||||
|
return std::holds_alternative<NamedOverrideArrayCount>(count_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @returns true if this array is sized using an unnamed override variable
|
||||||
|
bool IsUnnamedOverrideSized() const {
|
||||||
|
return std::holds_alternative<UnnamedOverrideArrayCount>(count_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @returns true if this array is sized using a named or unnamed override variable
|
||||||
|
bool IsOverrideSized() const { return IsNamedOverrideSized() || IsUnnamedOverrideSized(); }
|
||||||
|
|
||||||
/// @returns true if this array is runtime sized
|
/// @returns true if this array is runtime sized
|
||||||
bool IsRuntimeSized() const { return std::holds_alternative<RuntimeArrayCount>(count_); }
|
bool IsRuntimeSized() const { return std::holds_alternative<RuntimeArrayCount>(count_); }
|
||||||
|
@ -213,17 +258,28 @@ class hash<tint::sem::ConstantArrayCount> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Custom std::hash specialization for tint::sem::OverrideArrayCount.
|
/// Custom std::hash specialization for tint::sem::NamedOverrideArrayCount.
|
||||||
template <>
|
template <>
|
||||||
class hash<tint::sem::OverrideArrayCount> {
|
class hash<tint::sem::NamedOverrideArrayCount> {
|
||||||
public:
|
public:
|
||||||
/// @param count the count to hash
|
/// @param count the count to hash
|
||||||
/// @return the hash value
|
/// @return the hash value
|
||||||
inline std::size_t operator()(const tint::sem::OverrideArrayCount& count) const {
|
inline std::size_t operator()(const tint::sem::NamedOverrideArrayCount& count) const {
|
||||||
return std::hash<decltype(count.variable)>()(count.variable);
|
return std::hash<decltype(count.variable)>()(count.variable);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Custom std::hash specialization for tint::sem::UnnamedOverrideArrayCount.
|
||||||
|
template <>
|
||||||
|
class hash<tint::sem::UnnamedOverrideArrayCount> {
|
||||||
|
public:
|
||||||
|
/// @param count the count to hash
|
||||||
|
/// @return the hash value
|
||||||
|
inline std::size_t operator()(const tint::sem::UnnamedOverrideArrayCount& count) const {
|
||||||
|
return std::hash<decltype(count.expr)>()(count.expr);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Custom std::hash specialization for tint::sem::RuntimeArrayCount.
|
/// Custom std::hash specialization for tint::sem::RuntimeArrayCount.
|
||||||
template <>
|
template <>
|
||||||
class hash<tint::sem::RuntimeArrayCount> {
|
class hash<tint::sem::RuntimeArrayCount> {
|
||||||
|
|
|
@ -127,31 +127,43 @@ TEST_F(ArrayTest, FriendlyNameStaticSizedNonImplicitStride) {
|
||||||
|
|
||||||
TEST_F(ArrayTest, IsConstructable) {
|
TEST_F(ArrayTest, IsConstructable) {
|
||||||
auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
|
auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
|
||||||
auto* override_sized = create<Array>(create<U32>(), OverrideArrayCount{}, 4u, 8u, 32u, 16u);
|
auto* named_override_sized =
|
||||||
|
create<Array>(create<U32>(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
|
||||||
|
auto* unnamed_override_sized =
|
||||||
|
create<Array>(create<U32>(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
|
||||||
auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
|
auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
|
||||||
|
|
||||||
EXPECT_TRUE(fixed_sized->IsConstructible());
|
EXPECT_TRUE(fixed_sized->IsConstructible());
|
||||||
EXPECT_FALSE(override_sized->IsConstructible());
|
EXPECT_FALSE(named_override_sized->IsConstructible());
|
||||||
|
EXPECT_FALSE(unnamed_override_sized->IsConstructible());
|
||||||
EXPECT_FALSE(runtime_sized->IsConstructible());
|
EXPECT_FALSE(runtime_sized->IsConstructible());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ArrayTest, HasCreationFixedFootprint) {
|
TEST_F(ArrayTest, HasCreationFixedFootprint) {
|
||||||
auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
|
auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
|
||||||
auto* override_sized = create<Array>(create<U32>(), OverrideArrayCount{}, 4u, 8u, 32u, 16u);
|
auto* named_override_sized =
|
||||||
|
create<Array>(create<U32>(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
|
||||||
|
auto* unnamed_override_sized =
|
||||||
|
create<Array>(create<U32>(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
|
||||||
auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
|
auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
|
||||||
|
|
||||||
EXPECT_TRUE(fixed_sized->HasCreationFixedFootprint());
|
EXPECT_TRUE(fixed_sized->HasCreationFixedFootprint());
|
||||||
EXPECT_FALSE(override_sized->HasCreationFixedFootprint());
|
EXPECT_FALSE(named_override_sized->HasCreationFixedFootprint());
|
||||||
|
EXPECT_FALSE(unnamed_override_sized->HasCreationFixedFootprint());
|
||||||
EXPECT_FALSE(runtime_sized->HasCreationFixedFootprint());
|
EXPECT_FALSE(runtime_sized->HasCreationFixedFootprint());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ArrayTest, HasFixedFootprint) {
|
TEST_F(ArrayTest, HasFixedFootprint) {
|
||||||
auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
|
auto* fixed_sized = create<Array>(create<U32>(), ConstantArrayCount{2u}, 4u, 8u, 32u, 16u);
|
||||||
auto* override_sized = create<Array>(create<U32>(), OverrideArrayCount{}, 4u, 8u, 32u, 16u);
|
auto* named_override_sized =
|
||||||
|
create<Array>(create<U32>(), NamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
|
||||||
|
auto* unnamed_override_sized =
|
||||||
|
create<Array>(create<U32>(), UnnamedOverrideArrayCount{}, 4u, 8u, 32u, 16u);
|
||||||
auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
|
auto* runtime_sized = create<Array>(create<U32>(), RuntimeArrayCount{}, 4u, 8u, 32u, 16u);
|
||||||
|
|
||||||
EXPECT_TRUE(fixed_sized->HasFixedFootprint());
|
EXPECT_TRUE(fixed_sized->HasFixedFootprint());
|
||||||
EXPECT_TRUE(override_sized->HasFixedFootprint());
|
EXPECT_TRUE(named_override_sized->HasFixedFootprint());
|
||||||
|
EXPECT_TRUE(unnamed_override_sized->HasFixedFootprint());
|
||||||
EXPECT_FALSE(runtime_sized->HasFixedFootprint());
|
EXPECT_FALSE(runtime_sized->HasFixedFootprint());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -114,10 +114,14 @@ const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx, const sem::Type*
|
||||||
if (a->IsRuntimeSized()) {
|
if (a->IsRuntimeSized()) {
|
||||||
return ctx.dst->ty.array(el, nullptr, std::move(attrs));
|
return ctx.dst->ty.array(el, nullptr, std::move(attrs));
|
||||||
}
|
}
|
||||||
if (auto* override = std::get_if<sem::OverrideArrayCount>(&a->Count())) {
|
if (auto* override = std::get_if<sem::NamedOverrideArrayCount>(&a->Count())) {
|
||||||
auto* count = ctx.Clone(override->variable->Declaration());
|
auto* count = ctx.Clone(override->variable->Declaration());
|
||||||
return ctx.dst->ty.array(el, count, std::move(attrs));
|
return ctx.dst->ty.array(el, count, std::move(attrs));
|
||||||
}
|
}
|
||||||
|
if (auto* override = std::get_if<sem::UnnamedOverrideArrayCount>(&a->Count())) {
|
||||||
|
auto* count = ctx.Clone(override->expr->Declaration());
|
||||||
|
return ctx.dst->ty.array(el, count, std::move(attrs));
|
||||||
|
}
|
||||||
if (auto count = a->ConstantCount()) {
|
if (auto count = a->ConstantCount()) {
|
||||||
return ctx.dst->ty.array(el, u32(count.value()), std::move(attrs));
|
return ctx.dst->ty.array(el, u32(count.value()), std::move(attrs));
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
// flags: --overrides wgsize=10
|
||||||
|
|
||||||
|
override wgsize : u32;
|
||||||
|
var<workgroup> a : array<f32, wgsize>; // Accepted
|
||||||
|
var<workgroup> b : array<f32, wgsize * 2>; // Rejected
|
||||||
|
|
||||||
|
fn f() {
|
||||||
|
let x = a[0];
|
||||||
|
let y = b[0];
|
||||||
|
}
|
|
@ -0,0 +1,12 @@
|
||||||
|
[numthreads(1, 1, 1)]
|
||||||
|
void unused_entry_point() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
groupshared float a[10];
|
||||||
|
groupshared float b[20];
|
||||||
|
|
||||||
|
void f() {
|
||||||
|
const float x = a[0];
|
||||||
|
const float y = b[0];
|
||||||
|
}
|
|
@ -0,0 +1,12 @@
|
||||||
|
[numthreads(1, 1, 1)]
|
||||||
|
void unused_entry_point() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
groupshared float a[10];
|
||||||
|
groupshared float b[20];
|
||||||
|
|
||||||
|
void f() {
|
||||||
|
const float x = a[0];
|
||||||
|
const float y = b[0];
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
#version 310 es
|
||||||
|
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
void unused_entry_point() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
shared float a[10];
|
||||||
|
shared float b[20];
|
||||||
|
void f() {
|
||||||
|
float x = a[0];
|
||||||
|
float y = b[0];
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
template<typename T, size_t N>
|
||||||
|
struct tint_array {
|
||||||
|
const constant T& operator[](size_t i) const constant { return elements[i]; }
|
||||||
|
device T& operator[](size_t i) device { return elements[i]; }
|
||||||
|
const device T& operator[](size_t i) const device { return elements[i]; }
|
||||||
|
thread T& operator[](size_t i) thread { return elements[i]; }
|
||||||
|
const thread T& operator[](size_t i) const thread { return elements[i]; }
|
||||||
|
threadgroup T& operator[](size_t i) threadgroup { return elements[i]; }
|
||||||
|
const threadgroup T& operator[](size_t i) const threadgroup { return elements[i]; }
|
||||||
|
T elements[N];
|
||||||
|
};
|
||||||
|
|
||||||
|
void f(threadgroup tint_array<float, 10>* const tint_symbol, threadgroup tint_array<float, 20>* const tint_symbol_1) {
|
||||||
|
float const x = (*(tint_symbol))[0];
|
||||||
|
float const y = (*(tint_symbol_1))[0];
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
; SPIR-V
|
||||||
|
; Version: 1.3
|
||||||
|
; Generator: Google Tint Compiler; 0
|
||||||
|
; Bound: 24
|
||||||
|
; Schema: 0
|
||||||
|
OpCapability Shader
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
|
||||||
|
OpExecutionMode %unused_entry_point LocalSize 1 1 1
|
||||||
|
OpName %a "a"
|
||||||
|
OpName %b "b"
|
||||||
|
OpName %unused_entry_point "unused_entry_point"
|
||||||
|
OpName %f "f"
|
||||||
|
OpDecorate %_arr_float_uint_10 ArrayStride 4
|
||||||
|
OpDecorate %_arr_float_uint_20 ArrayStride 4
|
||||||
|
%float = OpTypeFloat 32
|
||||||
|
%uint = OpTypeInt 32 0
|
||||||
|
%uint_10 = OpConstant %uint 10
|
||||||
|
%_arr_float_uint_10 = OpTypeArray %float %uint_10
|
||||||
|
%_ptr_Workgroup__arr_float_uint_10 = OpTypePointer Workgroup %_arr_float_uint_10
|
||||||
|
%a = OpVariable %_ptr_Workgroup__arr_float_uint_10 Workgroup
|
||||||
|
%uint_20 = OpConstant %uint 20
|
||||||
|
%_arr_float_uint_20 = OpTypeArray %float %uint_20
|
||||||
|
%_ptr_Workgroup__arr_float_uint_20 = OpTypePointer Workgroup %_arr_float_uint_20
|
||||||
|
%b = OpVariable %_ptr_Workgroup__arr_float_uint_20 Workgroup
|
||||||
|
%void = OpTypeVoid
|
||||||
|
%11 = OpTypeFunction %void
|
||||||
|
%int = OpTypeInt 32 1
|
||||||
|
%18 = OpConstantNull %int
|
||||||
|
%_ptr_Workgroup_float = OpTypePointer Workgroup %float
|
||||||
|
%unused_entry_point = OpFunction %void None %11
|
||||||
|
%14 = OpLabel
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
%f = OpFunction %void None %11
|
||||||
|
%16 = OpLabel
|
||||||
|
%20 = OpAccessChain %_ptr_Workgroup_float %a %18
|
||||||
|
%21 = OpLoad %float %20
|
||||||
|
%22 = OpAccessChain %_ptr_Workgroup_float %b %18
|
||||||
|
%23 = OpLoad %float %22
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
|
@ -0,0 +1,10 @@
|
||||||
|
const wgsize : u32 = 10u;
|
||||||
|
|
||||||
|
var<workgroup> a : array<f32, wgsize>;
|
||||||
|
|
||||||
|
var<workgroup> b : array<f32, (wgsize * 2)>;
|
||||||
|
|
||||||
|
fn f() {
|
||||||
|
let x = a[0];
|
||||||
|
let y = b[0];
|
||||||
|
}
|
Loading…
Reference in New Issue