sem: Fold together sem::Array and sem::ArrayType

There's now no need to have both.
Removes a whole bunch of Sem().Get() smell, and simplifies the resolver.

Also fixes a long-standing issue where an array with an explicit, but equal-to-implicit-stride attribute would result in a different type to an array without the decoration.

Bug: tint:724
Fixed: tint:782
Change-Id: I0202459009cd45be427cdb621993a5a3b07ff51e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50301
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton
2021-05-07 20:58:34 +00:00
committed by Commit Bot service account
parent 6732e8561c
commit 4cd5eea87e
62 changed files with 486 additions and 591 deletions

View File

@@ -122,7 +122,7 @@ using ArrayDecorationTest = TestWithParams;
TEST_P(ArrayDecorationTest, IsValid) {
auto& params = GetParam();
auto arr =
auto* arr =
ty.array(ty.f32(), 0,
{
createDecoration(Source{{12, 34}}, *this, params.kind),
@@ -360,7 +360,7 @@ TEST_P(ArrayStrideTest, All) {
<< ", should_pass: " << params.should_pass;
SCOPED_TRACE(ss.str());
auto arr = ty.array(Source{{12, 34}}, el_ty, 4, params.stride);
auto* arr = ty.array(Source{{12, 34}}, el_ty, 4, params.stride);
Global("myarray", arr, ast::StorageClass::kInput);
@@ -445,11 +445,11 @@ INSTANTIATE_TEST_SUITE_P(
Params{ast_mat4x4<f32>, (default_mat4x4.align - 1) * 7, false}));
TEST_F(ArrayStrideTest, MultipleDecorations) {
auto arr = ty.array(Source{{12, 34}}, ty.i32(), 4,
{
create<ast::StrideDecoration>(4),
create<ast::StrideDecoration>(4),
});
auto* arr = ty.array(Source{{12, 34}}, ty.i32(), 4,
{
create<ast::StrideDecoration>(4),
create<ast::StrideDecoration>(4),
});
Global("myarray", arr, ast::StorageClass::kInput);
@@ -468,7 +468,7 @@ using StructBlockTest = ResolverTest;
TEST_F(StructBlockTest, StructUsedAsArrayElement) {
auto* s = Structure("S", {Member("x", ty.i32())},
{create<ast::StructBlockDecoration>()});
auto a = ty.array(s, 4);
auto* a = ty.array(s, 4);
Global("G", a, ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());

View File

@@ -756,7 +756,7 @@ INSTANTIATE_TEST_SUITE_P(
using ResolverIntrinsicDataTest = ResolverTest;
TEST_F(ResolverIntrinsicDataTest, ArrayLength_Vector) {
auto ary = ty.array<i32>();
auto* ary = ty.array<i32>();
auto* str = Structure("S", {Member("x", ary)},
{create<ast::StructBlockDecoration>()});
auto ac = ty.access(ast::AccessControl::kReadOnly, str);

View File

@@ -98,11 +98,13 @@ TEST_F(ResolverIsHostShareable, AccessControlI32) {
}
TEST_F(ResolverIsHostShareable, ArraySizedOfHostShareable) {
EXPECT_TRUE(r()->IsHostShareable(ty.array(ty.i32(), 5)));
auto* arr = create<sem::Array>(create<sem::I32>(), 5, 4, 20, 4, true);
EXPECT_TRUE(r()->IsHostShareable(arr));
}
TEST_F(ResolverIsHostShareable, ArrayUnsizedOfHostShareable) {
EXPECT_TRUE(r()->IsHostShareable(ty.array<i32>()));
auto* arr = create<sem::Array>(create<sem::I32>(), 0, 4, 4, 4, true);
EXPECT_TRUE(r()->IsHostShareable(arr));
}
// Note: Structure tests covered in host_shareable_validation_test.cc

View File

@@ -82,11 +82,13 @@ TEST_F(ResolverIsStorableTest, AccessControlI32) {
}
TEST_F(ResolverIsStorableTest, ArraySizedOfStorable) {
EXPECT_TRUE(r()->IsStorable(ty.array(ty.i32(), 5)));
auto* arr = create<sem::Array>(create<sem::I32>(), 5, 4, 20, 4, true);
EXPECT_TRUE(r()->IsStorable(arr));
}
TEST_F(ResolverIsStorableTest, ArrayUnsizedOfStorable) {
EXPECT_TRUE(r()->IsStorable(ty.array<i32>()));
auto* arr = create<sem::Array>(create<sem::I32>(), 0, 4, 4, 4, true);
EXPECT_TRUE(r()->IsStorable(arr));
}
TEST_F(ResolverIsStorableTest, Struct_AllMembersStorable) {

View File

@@ -178,8 +178,8 @@ bool Resolver::IsStorable(const sem::Type* type) {
if (type->is_scalar() || type->Is<sem::Vector>() || type->Is<sem::Matrix>()) {
return true;
}
if (auto* arr = type->As<sem::ArrayType>()) {
return IsStorable(arr->type());
if (auto* arr = type->As<sem::Array>()) {
return IsStorable(arr->ElemType());
}
if (auto* str = type->As<sem::Struct>()) {
for (const auto* member : str->Members()) {
@@ -204,8 +204,8 @@ bool Resolver::IsHostShareable(const sem::Type* type) {
if (auto* mat = type->As<sem::Matrix>()) {
return IsHostShareable(mat->type());
}
if (auto* arr = type->As<sem::ArrayType>()) {
return IsHostShareable(arr->type());
if (auto* arr = type->As<sem::Array>()) {
return IsHostShareable(arr->ElemType());
}
if (auto* str = type->As<sem::Struct>()) {
for (auto* member : str->Members()) {
@@ -287,7 +287,7 @@ bool Resolver::ResolveInternal() {
// TODO(crbug.com/tint/724) - Remove once tint:724 is complete.
// ast::AccessDecorations are generated by the WGSL parser, used to
// build sem::AccessControls and then leaked.
// ast::StrideDecoration are used to build a sem::ArrayTypes, but
// ast::StrideDecoration are used to build a sem::Arrays, but
// multiple arrays of the same stride, size and element type are
// currently de-duplicated by the type manager, and we leak these
// decorations.
@@ -350,14 +350,7 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
return nullptr;
}
if (auto* t = ty->As<ast::Array>()) {
if (auto* el = Type(t->type())) {
auto* sem = builder_->create<sem::ArrayType>(
const_cast<sem::Type*>(el), t->size(), t->decorations());
if (Array(sem, ty->source())) {
return sem;
}
}
return nullptr;
return Array(t);
}
if (auto* t = ty->As<ast::Pointer>()) {
if (auto* el = Type(t->type())) {
@@ -420,9 +413,10 @@ sem::Type* Resolver::Type(const ast::Type* ty) {
return s;
}
Resolver::VariableInfo* Resolver::Variable(ast::Variable* var,
sem::Type* type, /* = nullptr */
std::string type_name /* = "" */) {
Resolver::VariableInfo* Resolver::Variable(
ast::Variable* var,
const sem::Type* type, /* = nullptr */
std::string type_name /* = "" */) {
auto it = variable_to_info_.find(var);
if (it != variable_to_info_.end()) {
return it->second;
@@ -436,18 +430,10 @@ Resolver::VariableInfo* Resolver::Variable(ast::Variable* var,
return nullptr;
}
auto* ctype = Canonical(type);
auto* ctype = Canonical(const_cast<sem::Type*>(type));
auto* info = variable_infos_.Create(var, ctype, type_name);
variable_to_info_.emplace(var, info);
// TODO(bclayton): Why is this here? Needed?
// Resolve variable's type
if (auto* arr = info->type->As<sem::ArrayType>()) {
if (!Array(arr, var->source())) {
return nullptr;
}
}
return info;
}
@@ -596,8 +582,8 @@ bool Resolver::ValidateGlobalVariable(const VariableInfo* info) {
bool Resolver::ValidateVariable(const ast::Variable* var) {
auto* type = variable_to_info_[var]->type;
if (auto* r = type->As<sem::ArrayType>()) {
if (r->IsRuntimeArray()) {
if (auto* r = type->As<sem::Array>()) {
if (r->IsRuntimeSized()) {
diagnostics_.add_error(
"v-0015",
"runtime arrays may only appear as the last member of a struct",
@@ -873,8 +859,8 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func,
builder_->Symbols().NameFor(func->symbol()),
func->source());
return false;
} else if (auto* arr = member_ty->As<sem::ArrayType>()) {
if (arr->IsRuntimeArray()) {
} else if (auto* arr = member_ty->As<sem::Array>()) {
if (arr->IsRuntimeSized()) {
diagnostics_.add_error(
"entry point IO types cannot contain runtime sized arrays",
member->Declaration()->source());
@@ -1276,9 +1262,9 @@ bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) {
auto* res = TypeOf(expr->array());
auto* parent_type = res->UnwrapAll();
sem::Type* ret = nullptr;
if (auto* arr = parent_type->As<sem::ArrayType>()) {
ret = arr->type();
const sem::Type* ret = nullptr;
if (auto* arr = parent_type->As<sem::Array>()) {
ret = arr->ElemType();
} else if (auto* vec = parent_type->As<sem::Vector>()) {
ret = vec->type();
} else if (auto* mat = parent_type->As<sem::Matrix>()) {
@@ -1293,8 +1279,8 @@ bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) {
// If we're extracting from a pointer, we return a pointer.
if (auto* ptr = res->As<sem::Pointer>()) {
ret = builder_->create<sem::Pointer>(ret, ptr->storage_class());
} else if (auto* arr = parent_type->As<sem::ArrayType>()) {
if (!arr->type()->is_scalar()) {
} else if (auto* arr = parent_type->As<sem::Array>()) {
if (!arr->ElemType()->is_scalar()) {
// If we extract a non-scalar from an array then we also get a pointer. We
// will generate a Function storage class variable to store this into.
ret = builder_->create<sem::Pointer>(ret, ast::StorageClass::kFunction);
@@ -1459,7 +1445,7 @@ bool Resolver::ValidateVectorConstructor(
value_cardinality_sum++;
} else if (auto* value_vec = value_type->As<sem::Vector>()) {
sem::Type* value_elem_type = value_vec->type()->UnwrapAll();
auto* value_elem_type = value_vec->type()->UnwrapAll();
// A mismatch of vector type parameter T is only an error if multiple
// arguments are present. A single argument constructor constitutes a
// type conversion expression.
@@ -1754,8 +1740,8 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
auto* lhs_declared_type = TypeOf(expr->lhs())->UnwrapAll();
auto* rhs_declared_type = TypeOf(expr->rhs())->UnwrapAll();
auto* lhs_type = Canonical(lhs_declared_type);
auto* rhs_type = Canonical(rhs_declared_type);
auto* lhs_type = Canonical(const_cast<sem::Type*>(lhs_declared_type));
auto* rhs_type = Canonical(const_cast<sem::Type*>(rhs_declared_type));
auto* lhs_vec = lhs_type->As<Vector>();
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
@@ -2006,7 +1992,7 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
// If the variable has a declared type, resolve it.
std::string type_name;
sem::Type* type = nullptr;
const sem::Type* type = nullptr;
if (auto* ast_ty = var->type()) {
type_name = ast_ty->FriendlyName(builder_->Symbols());
type = Type(ast_ty);
@@ -2065,7 +2051,7 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
}
// TODO(bclayton): Remove this and fix tests. We're overriding the semantic
// type stored in info->type here with a possibly non-canonicalized type.
info->type = type;
info->type = const_cast<sem::Type*>(type);
variable_stack_.set(var->symbol(), info);
current_block_->decls.push_back(var);
@@ -2251,8 +2237,7 @@ void Resolver::CreateSemanticNodes() const {
bool Resolver::DefaultAlignAndSize(const sem::Type* ty,
uint32_t& align,
uint32_t& size,
const Source& source) {
uint32_t& size) {
static constexpr uint32_t vector_size[] = {
/* padding */ 0,
/* padding */ 0,
@@ -2297,76 +2282,71 @@ bool Resolver::DefaultAlignAndSize(const sem::Type* ty,
align = s->Align();
size = s->Size();
return true;
} else if (cty->Is<sem::ArrayType>()) {
if (auto* sem =
Array(ty->UnwrapAliasIfNeeded()->As<sem::ArrayType>(), source)) {
align = sem->Align();
size = sem->Size();
return true;
}
return false;
} else if (auto* a = cty->As<sem::Array>()) {
align = a->Align();
size = a->SizeInBytes();
return true;
}
TINT_UNREACHABLE(diagnostics_) << "Invalid type " << ty->TypeInfo().name;
return false;
}
const sem::Array* Resolver::Array(const sem::ArrayType* arr,
const Source& source) {
if (auto* sem = builder_->Sem().Get(arr)) {
// Semantic info already constructed for this array type
return sem;
}
sem::Array* Resolver::Array(const ast::Array* arr) {
auto source = arr->source();
if (!ValidateArray(arr, source)) {
auto* el_ty = Type(arr->type());
if (!el_ty) {
return nullptr;
}
auto* el_ty = arr->type();
uint32_t el_align = 0;
uint32_t el_size = 0;
if (!DefaultAlignAndSize(el_ty, el_align, el_size, source)) {
if (!DefaultAlignAndSize(el_ty, el_align, el_size)) {
return nullptr;
}
auto create_semantic = [&](uint32_t stride) -> sem::Array* {
auto align = el_align;
// WebGPU requires runtime arrays have at least one element, but the AST
// records an element count of 0 for it.
auto size = std::max<uint32_t>(arr->size(), 1) * stride;
auto* sem = builder_->create<sem::Array>(const_cast<sem::ArrayType*>(arr),
align, size, stride);
builder_->Sem().Add(arr, sem);
return sem;
};
// Look for explicit stride via [[stride(n)]] decoration
uint32_t explicit_stride = 0;
for (auto* deco : arr->decorations()) {
Mark(deco);
if (auto* stride = deco->As<ast::StrideDecoration>()) {
if (auto* sd = deco->As<ast::StrideDecoration>()) {
if (explicit_stride) {
diagnostics_.add_error(
"array must have at most one [[stride]] decoration", source);
return nullptr;
}
explicit_stride = stride->stride();
if (!ValidateArrayStrideDecoration(stride, el_size, el_align, source)) {
explicit_stride = sd->stride();
if (!ValidateArrayStrideDecoration(sd, el_size, el_align, source)) {
return nullptr;
}
continue;
}
}
if (explicit_stride) {
return create_semantic(explicit_stride);
diagnostics_.add_error("decoration is not valid for array types",
deco->source());
return nullptr;
}
// Calculate implicit stride
auto implicit_stride = utils::RoundUp(el_align, el_size);
return create_semantic(implicit_stride);
auto stride = explicit_stride ? explicit_stride : implicit_stride;
// WebGPU requires runtime arrays have at least one element, but the AST
// records an element count of 0 for it.
auto size = std::max<uint32_t>(arr->size(), 1) * stride;
auto* sem = builder_->create<sem::Array>(el_ty, arr->size(), el_align, size,
stride, stride == implicit_stride);
if (!ValidateArray(sem, source)) {
return nullptr;
}
return sem;
}
bool Resolver::ValidateArray(const sem::ArrayType* arr, const Source& source) {
auto* el_ty = arr->type();
bool Resolver::ValidateArray(const sem::Array* arr, const Source& source) {
auto* el_ty = arr->ElemType();
if (!IsStorable(el_ty)) {
builder_->Diagnostics().add_error(
@@ -2416,8 +2396,8 @@ bool Resolver::ValidateArrayStrideDecoration(const ast::StrideDecoration* deco,
bool Resolver::ValidateStructure(const sem::Struct* str) {
for (auto* member : str->Members()) {
if (auto* r = member->Type()->UnwrapAll()->As<sem::ArrayType>()) {
if (r->IsRuntimeArray()) {
if (auto* r = member->Type()->UnwrapAll()->As<sem::Array>()) {
if (r->IsRuntimeSized()) {
if (member != str->Members().back()) {
diagnostics_.add_error(
"v-0015",
@@ -2434,14 +2414,6 @@ bool Resolver::ValidateStructure(const sem::Struct* str) {
member->Declaration()->source());
return false;
}
for (auto* deco : r->decorations()) {
if (!deco->Is<ast::StrideDecoration>()) {
diagnostics_.add_error("decoration is not valid for array types",
deco->source());
return false;
}
}
}
}
@@ -2511,7 +2483,7 @@ sem::Struct* Resolver::Structure(const ast::Struct* str) {
uint32_t offset = struct_size;
uint32_t align = 0;
uint32_t size = 0;
if (!DefaultAlignAndSize(type, align, size, member->source())) {
if (!DefaultAlignAndSize(type, align, size)) {
return nullptr;
}
@@ -2779,7 +2751,7 @@ bool Resolver::Assignment(ast::AssignmentStatement* a) {
bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
sem::Type* ty,
const Source& usage) {
ty = ty->UnwrapIfNeeded();
ty = const_cast<sem::Type*>(ty->UnwrapIfNeeded());
if (auto* str = ty->As<sem::Struct>()) {
if (str->StorageClassUsage().count(sc)) {
@@ -2801,8 +2773,9 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
return true;
}
if (auto* arr = ty->As<sem::ArrayType>()) {
return ApplyStorageClassUsageToType(sc, arr->type(), usage);
if (auto* arr = ty->As<sem::Array>()) {
return ApplyStorageClassUsageToType(
sc, const_cast<sem::Type*>(arr->ElemType()), usage);
}
if (ast::IsHostShareable(sc) && !IsHostShareable(ty)) {
@@ -2829,7 +2802,8 @@ bool Resolver::BlockScope(const ast::BlockStatement* block,
return result;
}
std::string Resolver::VectorPretty(uint32_t size, sem::Type* element_type) {
std::string Resolver::VectorPretty(uint32_t size,
const sem::Type* element_type) {
sem::Vector vec_type(element_type, size);
return vec_type.FriendlyName(builder_->Symbols());
}

View File

@@ -224,7 +224,7 @@ class Resolver {
// AST and Type validation methods
// Each return true on success, false on failure.
bool ValidateArray(const sem::ArrayType* arr, const Source& source);
bool ValidateArray(const sem::Array* arr, const Source& source);
bool ValidateArrayStrideDecoration(const ast::StrideDecoration* deco,
uint32_t el_size,
uint32_t el_align,
@@ -250,15 +250,18 @@ class Resolver {
/// @param ty the ast::Type
sem::Type* Type(const ast::Type* ty);
/// @returns the semantic information for the array `arr`, building it if it
/// hasn't been constructed already. If an error is raised, nullptr is
/// returned.
/// Builds and returns the semantic information for the array `arr`.
/// This method does not mark the ast::Array node, nor attach the generated
/// semantic information to the AST node.
/// @returns the semantic Array information, or nullptr if an error is raised.
/// @param arr the Array to get semantic information for
/// @param source the Source of the ast node with this array as its type
const sem::Array* Array(const sem::ArrayType* arr, const Source& source);
sem::Array* Array(const ast::Array* arr);
/// @returns the sem::Struct for the AST structure `str`. If an error is
/// raised, nullptr is returned.
/// Builds and returns the semantic information for the structure `str`.
/// This method does not mark the ast::Struct node, nor attach the generated
/// semantic information to the AST node.
/// @returns the semantic Struct information, or nullptr if an error is
/// raised. raised, nullptr is returned.
sem::Struct* Structure(const ast::Struct* str);
/// @returns the VariableInfo for the variable `var`, building it if it hasn't
@@ -268,7 +271,7 @@ class Resolver {
/// @param type_name optional type name of `var` to use instead of
/// `var->type()->FriendlyName()`.
VariableInfo* Variable(ast::Variable* var,
sem::Type* type = nullptr,
const sem::Type* type = nullptr,
std::string type_name = "");
/// Records the storage class usage for the given type, and any transient
@@ -285,12 +288,10 @@ class Resolver {
/// @param align the output default alignment in bytes for the type `ty`
/// @param size the output default size in bytes for the type `ty`
/// @param source the Source of the variable declaration of type `ty`
/// @returns true on success, false on error
bool DefaultAlignAndSize(const sem::Type* ty,
uint32_t& align,
uint32_t& size,
const Source& source);
uint32_t& size);
/// @returns the resolved type of the ast::Expression `expr`
/// @param expr the expression
@@ -333,7 +334,7 @@ class Resolver {
/// @param size the vector dimension
/// @param element_type scalar vector sub-element type
/// @return pretty string representation
std::string VectorPretty(uint32_t size, sem::Type* element_type);
std::string VectorPretty(uint32_t size, const sem::Type* element_type);
/// Mark records that the given AST node has been visited, and asserts that
/// the given node has not already been seen. Diamonds in the AST are illegal.

View File

@@ -61,7 +61,7 @@ TEST_F(ResolverStorageClassValidationTest, StorageBufferPointer) {
TEST_F(ResolverStorageClassValidationTest, StorageBufferArray) {
// var<storage> g : [[access(read)]] array<S, 3>;
auto* s = Structure("S", {Member("a", ty.f32())});
auto a = ty.array(s, 3);
auto* a = ty.array(s, 3);
auto ac = ty.access(ast::AccessControl::kReadOnly, a);
Global(Source{{56, 78}}, "g", ac, ast::StorageClass::kStorage);
@@ -169,7 +169,7 @@ TEST_F(ResolverStorageClassValidationTest, UniformBufferPointer) {
TEST_F(ResolverStorageClassValidationTest, UniformBufferArray) {
// var<uniform> g : [[access(read)]] array<S, 3>;
auto* s = Structure("S", {Member("a", ty.f32())});
auto a = ty.array(s, 3);
auto* a = ty.array(s, 3);
auto ac = ty.access(ast::AccessControl::kReadOnly, a);
Global(Source{{56, 78}}, "g", ac, ast::StorageClass::kUniform);

View File

@@ -173,8 +173,8 @@ TEST_F(ResolverStructLayoutTest, ExplicitStrideArrayRuntimeSized) {
}
TEST_F(ResolverStructLayoutTest, ImplicitStrideArrayOfExplicitStrideArray) {
auto inner = ty.array<i32, 2>(/*stride*/ 16); // size: 32
auto outer = ty.array(inner, 12); // size: 12 * 32
auto* inner = ty.array<i32, 2>(/*stride*/ 16); // size: 32
auto* outer = ty.array(inner, 12); // size: 12 * 32
auto* s = Structure("S", {
Member("c", outer),
});
@@ -198,7 +198,7 @@ TEST_F(ResolverStructLayoutTest, ImplicitStrideArrayOfStructure) {
Member("b", ty.vec3<i32>()),
Member("c", ty.vec4<i32>()),
}); // size: 48
auto outer = ty.array(inner, 12); // size: 12 * 48
auto* outer = ty.array(inner, 12); // size: 12 * 48
auto* s = Structure("S", {
Member("c", outer),
});

View File

@@ -105,7 +105,7 @@ TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalStruct) {
TEST_F(ResolverStorageClassUseTest, StructReachableViaGlobalArray) {
auto* s = Structure("S", {Member("a", ty.f32())});
auto a = ty.array(s, 3);
auto* a = ty.array(s, 3);
Global("g", a, ast::StorageClass::kPrivate);
ASSERT_TRUE(r()->Resolve()) << r()->error();
@@ -158,7 +158,7 @@ TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalStruct) {
TEST_F(ResolverStorageClassUseTest, StructReachableViaLocalArray) {
auto* s = Structure("S", {Member("a", ty.f32())});
auto a = ty.array(s, 3);
auto* a = ty.array(s, 3);
WrapInFunction(Var("g", a, ast::StorageClass::kFunction));
ASSERT_TRUE(r()->Resolve()) << r()->error();