Rename all type UnwrapXXX() methods

Give them sensible names.
Make them act consistently.
Remove those that were not used.

Change-Id: Ib043a4093cfae9f81630643e1a0e4eae7bca2440
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/50305
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2021-05-10 18:06:31 +00:00 committed by Commit Bot service account
parent fcda15ef67
commit f14e0e1c8c
17 changed files with 103 additions and 233 deletions

View File

@ -80,41 +80,6 @@ TEST_F(AstAliasTest, FriendlyName) {
EXPECT_EQ(at->FriendlyName(Symbols()), "Particle"); EXPECT_EQ(at->FriendlyName(Symbols()), "Particle");
} }
TEST_F(AstAliasTest, UnwrapIfNeeded_Alias) {
auto* u32 = create<U32>();
auto* a = create<Alias>(Sym("a_type"), u32);
EXPECT_EQ(a->symbol(), Symbol(1, ID()));
EXPECT_EQ(a->type(), u32);
EXPECT_EQ(a->UnwrapIfNeeded(), u32);
EXPECT_EQ(u32->UnwrapIfNeeded(), u32);
}
TEST_F(AstAliasTest, UnwrapIfNeeded_AccessControl) {
auto* u32 = create<U32>();
auto* ac = create<AccessControl>(AccessControl::kReadOnly, u32);
EXPECT_EQ(ac->type(), u32);
EXPECT_EQ(ac->UnwrapIfNeeded(), u32);
}
TEST_F(AstAliasTest, UnwrapIfNeeded_MultiLevel) {
auto* u32 = create<U32>();
auto* a = create<Alias>(Sym("a_type"), u32);
auto* aa = create<Alias>(Sym("aa_type"), a);
EXPECT_EQ(aa->symbol(), Symbol(2, ID()));
EXPECT_EQ(aa->type(), a);
EXPECT_EQ(aa->UnwrapIfNeeded(), u32);
}
TEST_F(AstAliasTest, UnwrapIfNeeded_MultiLevel_AliasAccessControl) {
auto* u32 = create<U32>();
auto* a = create<Alias>(Sym("a_type"), u32);
auto* ac = create<AccessControl>(AccessControl::kReadWrite, a);
EXPECT_EQ(ac->type(), a);
EXPECT_EQ(ac->UnwrapIfNeeded(), u32);
}
TEST_F(AstAliasTest, UnwrapAll_TwiceAliasPointerTwiceAlias) { TEST_F(AstAliasTest, UnwrapAll_TwiceAliasPointerTwiceAlias) {
auto* u32 = create<U32>(); auto* u32 = create<U32>();
auto* a = create<Alias>(Sym("a_type"), u32); auto* a = create<Alias>(Sym("a_type"), u32);
@ -128,31 +93,6 @@ TEST_F(AstAliasTest, UnwrapAll_TwiceAliasPointerTwiceAlias) {
EXPECT_EQ(aapaa->UnwrapAll(), u32); EXPECT_EQ(aapaa->UnwrapAll(), u32);
} }
TEST_F(AstAliasTest, UnwrapAll_SecondConsecutivePointerBlocksUnrapping) {
auto* u32 = create<U32>();
auto* a = create<Alias>(Sym("a_type"), u32);
auto* aa = create<Alias>(Sym("aa_type"), a);
auto* paa = create<Pointer>(aa, StorageClass::kUniform);
auto* ppaa = create<Pointer>(paa, StorageClass::kUniform);
auto* appaa = create<Alias>(Sym("appaa_type"), ppaa);
EXPECT_EQ(appaa->UnwrapAll(), paa);
}
TEST_F(AstAliasTest, UnwrapAll_SecondNonConsecutivePointerBlocksUnrapping) {
auto* u32 = create<U32>();
auto* a = create<Alias>(Sym("a_type"), u32);
auto* aa = create<Alias>(Sym("aa_type"), a);
auto* paa = create<Pointer>(aa, StorageClass::kUniform);
auto* apaa = create<Alias>(Sym("apaa_type"), paa);
auto* aapaa = create<Alias>(Sym("aapaa_type"), apaa);
auto* paapaa = create<Pointer>(aapaa, StorageClass::kUniform);
auto* apaapaa = create<Alias>(Sym("apaapaa_type"), paapaa);
EXPECT_EQ(apaapaa->UnwrapAll(), paa);
}
TEST_F(AstAliasTest, UnwrapAll_AccessControlPointer) { TEST_F(AstAliasTest, UnwrapAll_AccessControlPointer) {
auto* u32 = create<U32>(); auto* u32 = create<U32>();
auto* a = create<AccessControl>(AccessControl::kReadOnly, u32); auto* a = create<AccessControl>(AccessControl::kReadOnly, u32);
@ -170,14 +110,6 @@ TEST_F(AstAliasTest, UnwrapAll_PointerAccessControl) {
EXPECT_EQ(a->UnwrapAll(), u32); EXPECT_EQ(a->UnwrapAll(), u32);
} }
TEST_F(AstAliasTest, UnwrapAliasIfNeeded) {
auto* f32 = create<F32>();
auto* alias1 = create<Alias>(Sym("alias1"), f32);
auto* alias2 = create<Alias>(Sym("alias2"), alias1);
auto* alias3 = create<Alias>(Sym("alias3"), alias2);
EXPECT_EQ(alias3->UnwrapAliasIfNeeded(), f32);
}
} // namespace } // namespace
} // namespace ast } // namespace ast
} // namespace tint } // namespace tint

View File

@ -38,37 +38,20 @@ Type::Type(Type&&) = default;
Type::~Type() = default; Type::~Type() = default;
Type* Type::UnwrapPtrIfNeeded() { Type* Type::UnwrapAll() {
if (auto* ptr = As<Pointer>()) { auto* type = this;
return ptr->type();
}
return this;
}
Type* Type::UnwrapAliasIfNeeded() {
Type* unwrapped = this;
while (auto* ptr = unwrapped->As<Alias>()) {
unwrapped = ptr->type();
}
return unwrapped;
}
Type* Type::UnwrapIfNeeded() {
auto* where = this;
while (true) { while (true) {
if (auto* alias = where->As<Alias>()) { if (auto* alias = type->As<Alias>()) {
where = alias->type(); type = alias->type();
} else if (auto* access = where->As<AccessControl>()) { } else if (auto* access = type->As<AccessControl>()) {
where = access->type(); type = access->type();
} else if (auto* ptr = type->As<Pointer>()) {
type = ptr->type();
} else { } else {
break; break;
} }
} }
return where; return type;
}
Type* Type::UnwrapAll() {
return UnwrapIfNeeded()->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
} }
bool Type::is_scalar() const { bool Type::is_scalar() const {

View File

@ -43,49 +43,10 @@ class Type : public Castable<Type, Node> {
/// declared in WGSL. /// declared in WGSL.
virtual std::string FriendlyName(const SymbolTable& symbols) const = 0; virtual std::string FriendlyName(const SymbolTable& symbols) const = 0;
/// @returns the pointee type if this is a pointer, `this` otherwise /// @returns the type with all aliasing, access control and pointers removed
Type* UnwrapPtrIfNeeded();
/// @returns the most deeply nested aliased type if this is an alias, `this`
/// otherwise
const Type* UnwrapAliasIfNeeded() const {
return const_cast<Type*>(this)->UnwrapAliasIfNeeded();
}
/// @returns the most deeply nested aliased type if this is an alias, `this`
/// otherwise
Type* UnwrapAliasIfNeeded();
/// Removes all levels of aliasing and access control.
/// This is just enough to assist with WGSL translation
/// in that you want see through one level of pointer to get from an
/// identifier-like expression as an l-value to its corresponding r-value,
/// plus see through the wrappers on either side.
/// @returns the completely unaliased type.
Type* UnwrapIfNeeded();
/// Removes all levels of aliasing and access control.
/// This is just enough to assist with WGSL translation
/// in that you want see through one level of pointer to get from an
/// identifier-like expression as an l-value to its corresponding r-value,
/// plus see through the wrappers on either side.
/// @returns the completely unaliased type.
const Type* UnwrapIfNeeded() const {
return const_cast<Type*>(this)->UnwrapIfNeeded();
}
/// Returns the type found after:
/// - removing all layers of aliasing and access control if they exist, then
/// - removing the pointer, if it exists, then
/// - removing all further layers of aliasing or access control, if they exist
/// @returns the unwrapped type
Type* UnwrapAll(); Type* UnwrapAll();
/// Returns the type found after: /// @returns the type with all aliasing, access control and pointers removed
/// - removing all layers of aliasing and access control if they exist, then
/// - removing the pointer, if it exists, then
/// - removing all further layers of aliasing or access control, if they exist
/// @returns the unwrapped type
const Type* UnwrapAll() const { return const_cast<Type*>(this)->UnwrapAll(); } const Type* UnwrapAll() const { return const_cast<Type*>(this)->UnwrapAll(); }
/// @returns true if this type is a scalar /// @returns true if this type is a scalar

View File

@ -387,7 +387,7 @@ std::vector<ResourceBinding> Inspector::GetUniformBufferResourceBindings(
auto* var = ruv.first; auto* var = ruv.first;
auto binding_info = ruv.second; auto binding_info = ruv.second;
auto* unwrapped_type = var->Type()->UnwrapIfNeeded(); auto* unwrapped_type = var->Type()->UnwrapAccess();
auto* str = unwrapped_type->As<sem::Struct>(); auto* str = unwrapped_type->As<sem::Struct>();
if (str == nullptr) { if (str == nullptr) {
continue; continue;
@ -509,7 +509,7 @@ std::vector<ResourceBinding> Inspector::GetDepthTextureResourceBindings(
entry.bind_group = binding_info.group->value(); entry.bind_group = binding_info.group->value();
entry.binding = binding_info.binding->value(); entry.binding = binding_info.binding->value();
auto* texture_type = var->Type()->UnwrapIfNeeded()->As<sem::Texture>(); auto* texture_type = var->Type()->UnwrapAccess()->As<sem::Texture>();
entry.dim = TypeTextureDimensionToResourceBindingTextureDimension( entry.dim = TypeTextureDimensionToResourceBindingTextureDimension(
texture_type->dim()); texture_type->dim());
@ -602,7 +602,7 @@ std::vector<ResourceBinding> Inspector::GetStorageBufferResourceBindingsImpl(
continue; continue;
} }
auto* str = var->Type()->UnwrapIfNeeded()->As<sem::Struct>(); auto* str = var->Type()->UnwrapAccess()->As<sem::Struct>();
if (!str) { if (!str) {
continue; continue;
} }
@ -646,18 +646,15 @@ std::vector<ResourceBinding> Inspector::GetSampledTextureResourceBindingsImpl(
entry.bind_group = binding_info.group->value(); entry.bind_group = binding_info.group->value();
entry.binding = binding_info.binding->value(); entry.binding = binding_info.binding->value();
auto* texture_type = var->Type()->UnwrapIfNeeded()->As<sem::Texture>(); auto* texture_type = var->Type()->UnwrapAccess()->As<sem::Texture>();
entry.dim = TypeTextureDimensionToResourceBindingTextureDimension( entry.dim = TypeTextureDimensionToResourceBindingTextureDimension(
texture_type->dim()); texture_type->dim());
const sem::Type* base_type = nullptr; const sem::Type* base_type = nullptr;
if (multisampled_only) { if (multisampled_only) {
base_type = texture_type->As<sem::MultisampledTexture>() base_type = texture_type->As<sem::MultisampledTexture>()->type();
->type()
->UnwrapIfNeeded();
} else { } else {
base_type = base_type = texture_type->As<sem::SampledTexture>()->type();
texture_type->As<sem::SampledTexture>()->type()->UnwrapIfNeeded();
} }
entry.sampled_kind = BaseTypeToSampledKind(base_type); entry.sampled_kind = BaseTypeToSampledKind(base_type);
@ -697,12 +694,11 @@ std::vector<ResourceBinding> Inspector::GetStorageTextureResourceBindingsImpl(
entry.bind_group = binding_info.group->value(); entry.bind_group = binding_info.group->value();
entry.binding = binding_info.binding->value(); entry.binding = binding_info.binding->value();
auto* texture_type = auto* texture_type = var->Type()->UnwrapAccess()->As<sem::StorageTexture>();
var->Type()->UnwrapIfNeeded()->As<sem::StorageTexture>();
entry.dim = TypeTextureDimensionToResourceBindingTextureDimension( entry.dim = TypeTextureDimensionToResourceBindingTextureDimension(
texture_type->dim()); texture_type->dim());
auto* base_type = texture_type->type()->UnwrapIfNeeded(); auto* base_type = texture_type->type();
entry.sampled_kind = BaseTypeToSampledKind(base_type); entry.sampled_kind = BaseTypeToSampledKind(base_type);
entry.image_format = TypeImageFormatToResourceBindingImageFormat( entry.image_format = TypeImageFormatToResourceBindingImageFormat(
texture_type->image_format()); texture_type->image_format());

View File

@ -1472,7 +1472,7 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
} }
auto source = GetSourceForInst(inst); auto source = GetSourceForInst(inst);
auto* ast_type = original_ast_type->UnwrapIfNeeded(); auto* ast_type = original_ast_type->UnwrapAliasAndAccess();
// TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0. // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0.
// So canonicalization should map that way too. // So canonicalization should map that way too.
@ -1548,7 +1548,7 @@ ast::Expression* ParserImpl::MakeNullValue(const Type* type) {
} }
auto* original_type = type; auto* original_type = type;
type = type->UnwrapIfNeeded(); type = type->UnwrapAliasAndAccess();
if (type->Is<Bool>()) { if (type->Is<Bool>()) {
return create<ast::ScalarConstructorExpression>( return create<ast::ScalarConstructorExpression>(

View File

@ -305,37 +305,50 @@ struct TypeManager::State {
storage_textures_; storage_textures_;
}; };
const Type* Type::UnwrapPtrIfNeeded() const { const Type* Type::UnwrapPtr() const {
if (auto* ptr = As<Pointer>()) { const Type* type = this;
return ptr->type; while (auto* ptr = type->As<Pointer>()) {
type = ptr->type;
} }
return this; return type;
} }
const Type* Type::UnwrapAliasIfNeeded() const { const Type* Type::UnwrapAlias() const {
const Type* unwrapped = this; const Type* type = this;
while (auto* ptr = unwrapped->As<Alias>()) { while (auto* alias = type->As<Alias>()) {
unwrapped = ptr->type; type = alias->type;
} }
return unwrapped; return type;
} }
const Type* Type::UnwrapIfNeeded() const { const Type* Type::UnwrapAliasAndAccess() const {
auto* where = this; auto* type = this;
while (true) { while (true) {
if (auto* alias = where->As<Alias>()) { if (auto* alias = type->As<Alias>()) {
where = alias->type; type = alias->type;
} else if (auto* access = where->As<AccessControl>()) { } else if (auto* access = type->As<AccessControl>()) {
where = access->type; type = access->type;
} else { } else {
break; break;
} }
} }
return where; return type;
} }
const Type* Type::UnwrapAll() const { const Type* Type::UnwrapAll() const {
return UnwrapIfNeeded()->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); auto* type = this;
while (true) {
if (auto* alias = type->As<Alias>()) {
type = alias->type;
} else if (auto* access = type->As<AccessControl>()) {
type = access->type;
} else if (auto* ptr = type->As<Pointer>()) {
type = ptr->type;
} else {
break;
}
}
return type;
} }
bool Type::IsFloatScalar() const { bool Type::IsFloatScalar() const {

View File

@ -45,26 +45,17 @@ class Type : public Castable<Type> {
/// @returns the constructed ast::Type node for the given type /// @returns the constructed ast::Type node for the given type
virtual ast::Type* Build(ProgramBuilder& b) const = 0; virtual ast::Type* Build(ProgramBuilder& b) const = 0;
/// @returns the pointee type if this is a pointer, `this` otherwise /// @returns the inner most pointee type if this is a pointer, `this`
const Type* UnwrapPtrIfNeeded() const;
/// @returns the most deeply nested aliased type if this is an alias, `this`
/// otherwise /// otherwise
const Type* UnwrapAliasIfNeeded() const; const Type* UnwrapPtr() const;
/// Removes all levels of aliasing and access control. /// @returns the inner most aliased type if this is an alias, `this` otherwise
/// This is just enough to assist with WGSL translation const Type* UnwrapAlias() const;
/// in that you want see through one level of pointer to get from an
/// identifier-like expression as an l-value to its corresponding r-value,
/// plus see through the wrappers on either side.
/// @returns the completely unaliased type.
const Type* UnwrapIfNeeded() const;
/// Returns the type found after: /// @returns the type with all aliasing and access control removed
/// - removing all layers of aliasing and access control if they exist, then const Type* UnwrapAliasAndAccess() const;
/// - removing the pointer, if it exists, then
/// - removing all further layers of aliasing or access control, if they exist /// @returns the type with all aliasing, access control and pointers removed
/// @returns the unwrapped type
const Type* UnwrapAll() const; const Type* UnwrapAll() const;
/// @returns true if this type is a float scalar /// @returns true if this type is a float scalar

View File

@ -174,7 +174,7 @@ bool Resolver::Resolve() {
// https://gpuweb.github.io/gpuweb/wgsl.html#storable-types // https://gpuweb.github.io/gpuweb/wgsl.html#storable-types
bool Resolver::IsStorable(const sem::Type* type) { bool Resolver::IsStorable(const sem::Type* type) {
type = type->UnwrapIfNeeded(); type = type->UnwrapAccess();
if (type->is_scalar() || type->Is<sem::Vector>() || type->Is<sem::Matrix>()) { if (type->is_scalar() || type->Is<sem::Vector>() || type->Is<sem::Matrix>()) {
return true; return true;
} }
@ -194,7 +194,7 @@ bool Resolver::IsStorable(const sem::Type* type) {
// https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types // https://gpuweb.github.io/gpuweb/wgsl.html#host-shareable-types
bool Resolver::IsHostShareable(const sem::Type* type) { bool Resolver::IsHostShareable(const sem::Type* type) {
type = type->UnwrapIfNeeded(); type = type->UnwrapAccess();
if (type->IsAnyOf<sem::I32, sem::U32, sem::F32>()) { if (type->IsAnyOf<sem::I32, sem::U32, sem::F32>()) {
return true; return true;
} }
@ -224,9 +224,9 @@ bool Resolver::IsValidAssignment(const sem::Type* lhs, const sem::Type* rhs) {
// This will need to be fixed after WGSL agrees the behavior of pointers / // This will need to be fixed after WGSL agrees the behavior of pointers /
// references. // references.
// Check: // Check:
if (lhs->UnwrapIfNeeded() != rhs->UnwrapIfNeeded()) { if (lhs->UnwrapAccess() != rhs->UnwrapAccess()) {
// Try RHS dereference // Try RHS dereference
if (lhs->UnwrapIfNeeded() != rhs->UnwrapAll()) { if (lhs->UnwrapAccess() != rhs->UnwrapAll()) {
return false; return false;
} }
} }
@ -1636,7 +1636,7 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
} }
auto* res = TypeOf(expr->structure()); auto* res = TypeOf(expr->structure());
auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); auto* data_type = res->UnwrapAll();
sem::Type* ret = nullptr; sem::Type* ret = nullptr;
std::vector<uint32_t> swizzle; std::vector<uint32_t> swizzle;
@ -1926,7 +1926,7 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() || if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() || expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() ||
expr->IsDivide() || expr->IsModulo()) { expr->IsDivide() || expr->IsModulo()) {
SetType(expr, TypeOf(expr->lhs())->UnwrapPtrIfNeeded()); SetType(expr, TypeOf(expr->lhs())->UnwrapPtr());
return true; return true;
} }
// Result type is a scalar or vector of boolean type // Result type is a scalar or vector of boolean type
@ -1999,7 +1999,7 @@ bool Resolver::UnaryOp(ast::UnaryOpExpression* expr) {
return false; return false;
} }
auto* result_type = TypeOf(expr->expr())->UnwrapPtrIfNeeded(); auto* result_type = TypeOf(expr->expr())->UnwrapPtr();
SetType(expr, result_type); SetType(expr, result_type);
return true; return true;
} }
@ -2039,7 +2039,7 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
// If the variable has no type, infer it from the rhs // If the variable has no type, infer it from the rhs
if (type == nullptr) { if (type == nullptr) {
type_name = TypeNameOf(ctor); type_name = TypeNameOf(ctor);
type = rhs_type->UnwrapPtrIfNeeded(); type = rhs_type->UnwrapPtr();
} }
if (!IsValidAssignment(type, rhs_type)) { if (!IsValidAssignment(type, rhs_type)) {
@ -2726,7 +2726,7 @@ bool Resolver::ValidateAssignment(const ast::AssignmentStatement* a) {
} }
// lhs must be a pointer or a constant // lhs must be a pointer or a constant
auto* lhs_result_type = TypeOf(lhs)->UnwrapIfNeeded(); auto* lhs_result_type = TypeOf(lhs)->UnwrapAccess();
if (!lhs_result_type->Is<sem::Pointer>()) { if (!lhs_result_type->Is<sem::Pointer>()) {
// In case lhs is a constant identifier, output a nicer message as it's // In case lhs is a constant identifier, output a nicer message as it's
// likely to be a common programmer error. // likely to be a common programmer error.
@ -2768,7 +2768,7 @@ bool Resolver::Assignment(ast::AssignmentStatement* a) {
bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc,
sem::Type* ty, sem::Type* ty,
const Source& usage) { const Source& usage) {
ty = const_cast<sem::Type*>(ty->UnwrapIfNeeded()); ty = const_cast<sem::Type*>(ty->UnwrapAccess());
if (auto* str = ty->As<sem::Struct>()) { if (auto* str = ty->As<sem::Struct>()) {
if (str->StorageClassUsage().count(sc)) { if (str->StorageClassUsage().count(sc)) {

View File

@ -521,7 +521,7 @@ TEST_P(CanonicalTest, All) {
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* got = TypeOf(expr)->UnwrapPtrIfNeeded(); auto* got = TypeOf(expr)->UnwrapPtr();
auto* expected = params.create_sem_type(ty); auto* expected = params.create_sem_type(ty);
EXPECT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n" EXPECT_EQ(got, expected) << "got: " << FriendlyName(got) << "\n"

View File

@ -22,9 +22,7 @@ namespace sem {
Expression::Expression(ast::Expression* declaration, Expression::Expression(ast::Expression* declaration,
const sem::Type* type, const sem::Type* type,
Statement* statement) Statement* statement)
: declaration_(declaration), : declaration_(declaration), type_(type), statement_(statement) {
type_(type->UnwrapIfNeeded()),
statement_(statement) {
TINT_ASSERT(type_); TINT_ASSERT(type_);
} }

View File

@ -138,7 +138,7 @@ Function::VariableBindings Function::ReferencedStorageTextureVariables() const {
VariableBindings ret; VariableBindings ret;
for (auto* var : ReferencedModuleVariables()) { for (auto* var : ReferencedModuleVariables()) {
auto* unwrapped_type = var->Type()->UnwrapIfNeeded(); auto* unwrapped_type = var->Type()->UnwrapAccess();
auto* storage_texture = unwrapped_type->As<sem::StorageTexture>(); auto* storage_texture = unwrapped_type->As<sem::StorageTexture>();
if (storage_texture == nullptr) { if (storage_texture == nullptr) {
continue; continue;
@ -155,7 +155,7 @@ Function::VariableBindings Function::ReferencedDepthTextureVariables() const {
VariableBindings ret; VariableBindings ret;
for (auto* var : ReferencedModuleVariables()) { for (auto* var : ReferencedModuleVariables()) {
auto* unwrapped_type = var->Type()->UnwrapIfNeeded(); auto* unwrapped_type = var->Type()->UnwrapAccess();
auto* storage_texture = unwrapped_type->As<sem::DepthTexture>(); auto* storage_texture = unwrapped_type->As<sem::DepthTexture>();
if (storage_texture == nullptr) { if (storage_texture == nullptr) {
continue; continue;
@ -182,7 +182,7 @@ Function::VariableBindings Function::ReferencedSamplerVariablesImpl(
VariableBindings ret; VariableBindings ret;
for (auto* var : ReferencedModuleVariables()) { for (auto* var : ReferencedModuleVariables()) {
auto* unwrapped_type = var->Type()->UnwrapIfNeeded(); auto* unwrapped_type = var->Type()->UnwrapAccess();
auto* sampler = unwrapped_type->As<sem::Sampler>(); auto* sampler = unwrapped_type->As<sem::Sampler>();
if (sampler == nullptr || sampler->kind() != kind) { if (sampler == nullptr || sampler->kind() != kind) {
continue; continue;
@ -200,7 +200,7 @@ Function::VariableBindings Function::ReferencedSampledTextureVariablesImpl(
VariableBindings ret; VariableBindings ret;
for (auto* var : ReferencedModuleVariables()) { for (auto* var : ReferencedModuleVariables()) {
auto* unwrapped_type = var->Type()->UnwrapIfNeeded(); auto* unwrapped_type = var->Type()->UnwrapAccess();
auto* texture = unwrapped_type->As<sem::Texture>(); auto* texture = unwrapped_type->As<sem::Texture>();
if (texture == nullptr) { if (texture == nullptr) {
continue; continue;

View File

@ -36,7 +36,7 @@ Type::Type(Type&&) = default;
Type::~Type() = default; Type::~Type() = default;
const Type* Type::UnwrapPtrIfNeeded() const { const Type* Type::UnwrapPtr() const {
auto* type = this; auto* type = this;
while (auto* ptr = type->As<sem::Pointer>()) { while (auto* ptr = type->As<sem::Pointer>()) {
type = ptr->type(); type = ptr->type();
@ -44,7 +44,7 @@ const Type* Type::UnwrapPtrIfNeeded() const {
return type; return type;
} }
const Type* Type::UnwrapIfNeeded() const { const Type* Type::UnwrapAccess() const {
auto* type = this; auto* type = this;
while (auto* access = type->As<sem::AccessControl>()) { while (auto* access = type->As<sem::AccessControl>()) {
type = access->type(); type = access->type();
@ -57,14 +57,13 @@ const Type* Type::UnwrapAll() const {
while (true) { while (true) {
if (auto* ptr = type->As<sem::Pointer>()) { if (auto* ptr = type->As<sem::Pointer>()) {
type = ptr->type(); type = ptr->type();
continue; } else if (auto* access = type->As<sem::AccessControl>()) {
}
if (auto* access = type->As<sem::AccessControl>()) {
type = access->type(); type = access->type();
continue; } else {
break;
} }
return type;
} }
return type;
} }
bool Type::is_scalar() const { bool Type::is_scalar() const {

View File

@ -45,19 +45,16 @@ class Type : public Castable<Type, Node> {
/// declared in WGSL. /// declared in WGSL.
virtual std::string FriendlyName(const SymbolTable& symbols) const = 0; virtual std::string FriendlyName(const SymbolTable& symbols) const = 0;
/// @returns the pointee type if this is a pointer, `this` otherwise /// @returns the inner most pointee type if this is a pointer, `this`
const Type* UnwrapPtrIfNeeded() const; /// otherwise
const Type* UnwrapPtr() const;
/// Removes all levels of access control. /// @returns the inner most type if this is an access control, `this`
/// This is just enough to assist with WGSL translation /// otherwise
/// in that you want see through one level of pointer to get from an const Type* UnwrapAccess() const;
/// identifier-like expression as an l-value to its corresponding r-value,
/// plus see through the wrappers on either side.
/// @returns the completely unaliased type.
const Type* UnwrapIfNeeded() const;
/// Returns the type found after removing all layers of access control and /// Returns the type found after removing all layers of access control and
/// pointer. /// pointer
/// @returns the unwrapped type /// @returns the unwrapped type
const Type* UnwrapAll() const; const Type* UnwrapAll() const;

View File

@ -746,7 +746,7 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
auto* buf = access.var->Declaration(); auto* buf = access.var->Declaration();
auto* offset = access.offset->Build(ctx); auto* offset = access.offset->Build(ctx);
auto* buf_ty = access.var->Type()->UnwrapPtrIfNeeded(); auto* buf_ty = access.var->Type()->UnwrapPtr();
auto* el_ty = access.type->UnwrapAll(); auto* el_ty = access.type->UnwrapAll();
auto* insert_after = ConstructedTypeOf(access.var->Type()); auto* insert_after = ConstructedTypeOf(access.var->Type());
Symbol func = state.LoadFunc(ctx, insert_after, buf_ty, el_ty); Symbol func = state.LoadFunc(ctx, insert_after, buf_ty, el_ty);
@ -760,7 +760,7 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
for (auto& store : state.stores) { for (auto& store : state.stores) {
auto* buf = store.target.var->Declaration(); auto* buf = store.target.var->Declaration();
auto* offset = store.target.offset->Build(ctx); auto* offset = store.target.offset->Build(ctx);
auto* buf_ty = store.target.var->Type()->UnwrapPtrIfNeeded(); auto* buf_ty = store.target.var->Type()->UnwrapPtr();
auto* el_ty = store.target.type->UnwrapAll(); auto* el_ty = store.target.type->UnwrapAll();
auto* value = store.assignment->rhs(); auto* value = store.assignment->rhs();
auto* insert_after = ConstructedTypeOf(store.target.var->Type()); auto* insert_after = ConstructedTypeOf(store.target.var->Type());

View File

@ -41,7 +41,7 @@ ast::TypeConstructorExpression* AppendVector(ProgramBuilder* b,
uint32_t packed_size; uint32_t packed_size;
const sem::Type* packed_el_sem_ty; const sem::Type* packed_el_sem_ty;
auto* vector_sem = b->Sem().Get(vector); auto* vector_sem = b->Sem().Get(vector);
auto* vector_ty = vector_sem->Type()->UnwrapPtrIfNeeded(); auto* vector_ty = vector_sem->Type()->UnwrapPtr();
if (auto* vec = vector_ty->As<sem::Vector>()) { if (auto* vec = vector_ty->As<sem::Vector>()) {
packed_size = vec->size() + 1; packed_size = vec->size() + 1;
packed_el_sem_ty = vec->type(); packed_el_sem_ty = vec->type();
@ -72,7 +72,7 @@ ast::TypeConstructorExpression* AppendVector(ProgramBuilder* b,
} else { } else {
packed.emplace_back(vector); packed.emplace_back(vector);
} }
if (packed_el_sem_ty != b->TypeOf(scalar)->UnwrapPtrIfNeeded()) { if (packed_el_sem_ty != b->TypeOf(scalar)->UnwrapPtr()) {
// Cast scalar to the vector element type // Cast scalar to the vector element type
auto* scalar_cast = b->Construct(packed_el_ty, scalar); auto* scalar_cast = b->Construct(packed_el_ty, scalar);
b->Sem().Add(scalar_cast, b->create<sem::Expression>( b->Sem().Add(scalar_cast, b->create<sem::Expression>(

View File

@ -1697,7 +1697,7 @@ bool GeneratorImpl::EmitEntryPointData(
continue; // Global already emitted continue; // Global already emitted
} }
auto* type = var->Type()->UnwrapIfNeeded(); auto* type = var->Type()->UnwrapAccess();
if (auto* strct = type->As<sem::Struct>()) { if (auto* strct = type->As<sem::Struct>()) {
out << "ConstantBuffer<" out << "ConstantBuffer<"
<< builder_.Symbols().NameFor(strct->Declaration()->name()) << "> " << builder_.Symbols().NameFor(strct->Declaration()->name()) << "> "

View File

@ -627,7 +627,7 @@ bool Builder::GenerateFunctionVariable(ast::Variable* var) {
// TODO(dsinclair) We could detect if the constructor is fully const and emit // TODO(dsinclair) We could detect if the constructor is fully const and emit
// an initializer value for the variable instead of doing the OpLoad. // an initializer value for the variable instead of doing the OpLoad.
auto null_id = GenerateConstantNullIfNeeded(type->UnwrapPtrIfNeeded()); auto null_id = GenerateConstantNullIfNeeded(type->UnwrapPtr());
if (null_id == 0) { if (null_id == 0) {
return 0; return 0;
} }
@ -953,7 +953,7 @@ bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
} }
info->source_id = GenerateLoadIfNeeded(expr_type, extract_id); info->source_id = GenerateLoadIfNeeded(expr_type, extract_id);
info->source_type = expr_type->UnwrapPtrIfNeeded(); info->source_type = expr_type->UnwrapPtr();
info->access_chain_indices.clear(); info->access_chain_indices.clear();
} }
@ -1130,7 +1130,7 @@ uint32_t Builder::GenerateLoadIfNeeded(const sem::Type* type, uint32_t id) {
return id; return id;
} }
auto type_id = GenerateTypeIfNeeded(type->UnwrapPtrIfNeeded()); auto type_id = GenerateTypeIfNeeded(type->UnwrapPtr());
auto result = result_op(); auto result = result_op();
auto result_id = result.to_i(); auto result_id = result.to_i();
if (!push_function_inst(spv::Op::OpLoad, if (!push_function_inst(spv::Op::OpLoad,
@ -1271,7 +1271,7 @@ uint32_t Builder::GenerateTypeConstructorExpression(
// Generate the zero initializer if there are no values provided. // Generate the zero initializer if there are no values provided.
if (values.empty()) { if (values.empty()) {
return GenerateConstantNullIfNeeded(result_type->UnwrapPtrIfNeeded()); return GenerateConstantNullIfNeeded(result_type->UnwrapPtr());
} }
std::ostringstream out; std::ostringstream out;
@ -1326,7 +1326,7 @@ uint32_t Builder::GenerateTypeConstructorExpression(
return 0; return 0;
} }
auto* value_type = TypeOf(e)->UnwrapPtrIfNeeded(); auto* value_type = TypeOf(e)->UnwrapPtr();
// If the result and value types are the same we can just use the object. // If the result and value types are the same we can just use the object.
// If the result is not a vector then we should have validated that the // If the result is not a vector then we should have validated that the
// value type is a correctly sized vector so we can just use it directly. // value type is a correctly sized vector so we can just use it directly.
@ -1443,7 +1443,7 @@ uint32_t Builder::GenerateCastOrCopyOrPassthrough(const sem::Type* to_type,
} }
val_id = GenerateLoadIfNeeded(TypeOf(from_expr), val_id); val_id = GenerateLoadIfNeeded(TypeOf(from_expr), val_id);
auto* from_type = TypeOf(from_expr)->UnwrapPtrIfNeeded(); auto* from_type = TypeOf(from_expr)->UnwrapPtr();
spv::Op op = spv::Op::OpNop; spv::Op op = spv::Op::OpNop;
if ((from_type->Is<sem::I32>() && to_type->Is<sem::F32>()) || if ((from_type->Is<sem::I32>() && to_type->Is<sem::F32>()) ||
@ -2578,8 +2578,8 @@ uint32_t Builder::GenerateBitcastExpression(ast::BitcastExpression* expr) {
val_id = GenerateLoadIfNeeded(TypeOf(expr->expr()), val_id); val_id = GenerateLoadIfNeeded(TypeOf(expr->expr()), val_id);
// Bitcast does not allow same types, just emit a CopyObject // Bitcast does not allow same types, just emit a CopyObject
auto* to_type = TypeOf(expr)->UnwrapPtrIfNeeded(); auto* to_type = TypeOf(expr)->UnwrapPtr();
auto* from_type = TypeOf(expr->expr())->UnwrapPtrIfNeeded(); auto* from_type = TypeOf(expr->expr())->UnwrapPtr();
if (to_type->type_name() == from_type->type_name()) { if (to_type->type_name() == from_type->type_name()) {
if (!push_function_inst( if (!push_function_inst(
spv::Op::OpCopyObject, spv::Op::OpCopyObject,
@ -2931,7 +2931,7 @@ uint32_t Builder::GenerateTypeIfNeeded(const sem::Type* type) {
} }
if (auto* ac = type->As<sem::AccessControl>()) { if (auto* ac = type->As<sem::AccessControl>()) {
if (!ac->type()->UnwrapIfNeeded()->Is<sem::Struct>()) { if (!ac->type()->UnwrapAccess()->Is<sem::Struct>()) {
return GenerateTypeIfNeeded(ac->type()); return GenerateTypeIfNeeded(ac->type());
} }
} }
@ -2945,7 +2945,7 @@ uint32_t Builder::GenerateTypeIfNeeded(const sem::Type* type) {
auto id = result.to_i(); auto id = result.to_i();
if (auto* ac = type->As<sem::AccessControl>()) { if (auto* ac = type->As<sem::AccessControl>()) {
// The non-struct case was handled above. // The non-struct case was handled above.
auto* subtype = ac->type()->UnwrapIfNeeded(); auto* subtype = ac->UnwrapAccess();
if (!GenerateStructType(subtype->As<sem::Struct>(), ac->access_control(), if (!GenerateStructType(subtype->As<sem::Struct>(), ac->access_control(),
result)) { result)) {
return 0; return 0;