diff --git a/src/ast/expression.cc b/src/ast/expression.cc index 6a14603338..90d597e72c 100644 --- a/src/ast/expression.cc +++ b/src/ast/expression.cc @@ -36,8 +36,8 @@ Expression::Expression(const Source& source) : Node(source) {} Expression::~Expression() = default; void Expression::set_result_type(type::Type* type) { - // The expression result should never be an alias type - result_type_ = type->UnwrapAliasesIfNeeded(); + // The expression result should never be an alias or access-controlled type + result_type_ = type->UnwrapIfNeeded(); } bool Expression::IsArrayAccessor() const { diff --git a/src/ast/type/alias_type_test.cc b/src/ast/type/alias_type_test.cc index 76b547a4c7..22b549ebe9 100644 --- a/src/ast/type/alias_type_test.cc +++ b/src/ast/type/alias_type_test.cc @@ -16,6 +16,7 @@ #include "gtest/gtest.h" #include "src/ast/storage_class.h" +#include "src/ast/type/access_control_type.h" #include "src/ast/type/i32_type.h" #include "src/ast/type/pointer_type.h" #include "src/ast/type/u32_type.h" @@ -59,25 +60,40 @@ TEST_F(AliasTypeTest, TypeName) { EXPECT_EQ(at.type_name(), "__alias_Particle__i32"); } -TEST_F(AliasTypeTest, UnwrapAliasesIfNeeded) { +TEST_F(AliasTypeTest, UnwrapIfNeeded_Alias) { U32Type u32; AliasType a{"a_type", &u32}; EXPECT_EQ(a.name(), "a_type"); EXPECT_EQ(a.type(), &u32); - EXPECT_EQ(a.UnwrapAliasesIfNeeded(), &u32); - EXPECT_EQ(u32.UnwrapAliasesIfNeeded(), &u32); + EXPECT_EQ(a.UnwrapIfNeeded(), &u32); + EXPECT_EQ(u32.UnwrapIfNeeded(), &u32); } -TEST_F(AliasTypeTest, UnwrapAliasesIfNeeded_MultiLevel) { +TEST_F(AliasTypeTest, UnwrapIfNeeded_AccessControl) { + U32Type u32; + AccessControlType a{AccessControl::kReadOnly, &u32}; + EXPECT_EQ(a.type(), &u32); + EXPECT_EQ(a.UnwrapIfNeeded(), &u32); +} + +TEST_F(AliasTypeTest, UnwrapIfNeeded_MultiLevel) { U32Type u32; AliasType a{"a_type", &u32}; AliasType aa{"aa_type", &a}; EXPECT_EQ(aa.name(), "aa_type"); EXPECT_EQ(aa.type(), &a); - EXPECT_EQ(aa.UnwrapAliasesIfNeeded(), &u32); + EXPECT_EQ(aa.UnwrapIfNeeded(), &u32); } -TEST_F(AliasTypeTest, UnwrapAliasPtrAlias_TwiceAliasPointerTwiceAlias) { +TEST_F(AliasTypeTest, UnwrapIfNeeded_MultiLevel_AliasAccessControl) { + U32Type u32; + AliasType a{"a_type", &u32}; + AccessControlType aa{AccessControl::kReadWrite, &a}; + EXPECT_EQ(aa.type(), &a); + EXPECT_EQ(aa.UnwrapIfNeeded(), &u32); +} + +TEST_F(AliasTypeTest, UnwrapAll_TwiceAliasPointerTwiceAlias) { U32Type u32; AliasType a{"a_type", &u32}; AliasType aa{"aa_type", &a}; @@ -86,23 +102,21 @@ TEST_F(AliasTypeTest, UnwrapAliasPtrAlias_TwiceAliasPointerTwiceAlias) { AliasType aapaa{"aapaa_type", &apaa}; EXPECT_EQ(aapaa.name(), "aapaa_type"); EXPECT_EQ(aapaa.type(), &apaa); - EXPECT_EQ(aapaa.UnwrapAliasPtrAlias(), &u32); - EXPECT_EQ(u32.UnwrapAliasPtrAlias(), &u32); + EXPECT_EQ(aapaa.UnwrapAll(), &u32); + EXPECT_EQ(u32.UnwrapAll(), &u32); } -TEST_F(AliasTypeTest, - UnwrapAliasPtrAlias_SecondConsecutivePointerBlocksWUnrapping) { +TEST_F(AliasTypeTest, UnwrapAll_SecondConsecutivePointerBlocksUnrapping) { U32Type u32; AliasType a{"a_type", &u32}; AliasType aa{"aa_type", &a}; PointerType paa{&aa, StorageClass::kUniform}; PointerType ppaa{&paa, StorageClass::kUniform}; AliasType appaa{"appaa_type", &ppaa}; - EXPECT_EQ(appaa.UnwrapAliasPtrAlias(), &paa); + EXPECT_EQ(appaa.UnwrapAll(), &paa); } -TEST_F(AliasTypeTest, - UnwrapAliasPtrAlias_SecondNonConsecutivePointerBlocksWUnrapping) { +TEST_F(AliasTypeTest, UnwrapAll_SecondNonConsecutivePointerBlocksUnrapping) { U32Type u32; AliasType a{"a_type", &u32}; AliasType aa{"aa_type", &a}; @@ -111,7 +125,25 @@ TEST_F(AliasTypeTest, AliasType aapaa{"aapaa_type", &apaa}; PointerType paapaa{&aapaa, StorageClass::kUniform}; AliasType apaapaa{"apaapaa_type", &paapaa}; - EXPECT_EQ(apaapaa.UnwrapAliasPtrAlias(), &paa); + EXPECT_EQ(apaapaa.UnwrapAll(), &paa); +} + +TEST_F(AliasTypeTest, UnwrapAll_AccessControlPointer) { + U32Type u32; + AccessControlType a{AccessControl::kReadOnly, &u32}; + PointerType pa{&a, StorageClass::kUniform}; + EXPECT_EQ(pa.type(), &a); + EXPECT_EQ(pa.UnwrapAll(), &u32); + EXPECT_EQ(u32.UnwrapAll(), &u32); +} + +TEST_F(AliasTypeTest, UnwrapAll_PointerAccessControl) { + U32Type u32; + PointerType p{&u32, StorageClass::kUniform}; + AccessControlType a{AccessControl::kReadOnly, &p}; + EXPECT_EQ(a.type(), &p); + EXPECT_EQ(a.UnwrapAll(), &u32); + EXPECT_EQ(u32.UnwrapAll(), &u32); } } // namespace diff --git a/src/ast/type/type.cc b/src/ast/type/type.cc index 7f596afa30..b7d065d827 100644 --- a/src/ast/type/type.cc +++ b/src/ast/type/type.cc @@ -46,16 +46,22 @@ Type* Type::UnwrapPtrIfNeeded() { return this; } -Type* Type::UnwrapAliasesIfNeeded() { +Type* Type::UnwrapIfNeeded() { auto* where = this; - while (where->IsAlias()) { - where = where->AsAlias()->type(); + while (true) { + if (where->IsAlias()) { + where = where->AsAlias()->type(); + } else if (where->IsAccessControl()) { + where = where->AsAccessControl()->type(); + } else { + break; + } } return where; } -Type* Type::UnwrapAliasPtrAlias() { - return UnwrapAliasesIfNeeded()->UnwrapPtrIfNeeded()->UnwrapAliasesIfNeeded(); +Type* Type::UnwrapAll() { + return UnwrapIfNeeded()->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); } bool Type::IsAccessControl() const { diff --git a/src/ast/type/type.h b/src/ast/type/type.h index ff02f281f4..ca4b208642 100644 --- a/src/ast/type/type.h +++ b/src/ast/type/type.h @@ -82,20 +82,20 @@ class Type { /// @returns the pointee type if this is a pointer, |this| otherwise Type* UnwrapPtrIfNeeded(); - /// Removes all levels of aliasing, if this is an alias type. Otherwise - /// returns |this|. This is just enough to assist with WGSL translation + /// Removes all levels of aliasing and access control. + /// This is just enough to assist with WGSL translation /// in that you want see through one level of pointer to get from an /// identifier-like expression as an l-value to its corresponding r-value, - /// plus see through the aliases on either side. + /// plus see through the wrappers on either side. /// @returns the completely unaliased type. - Type* UnwrapAliasesIfNeeded(); + Type* UnwrapIfNeeded(); /// Returns the type found after: - /// - removing all layers of aliasing if they exist, then + /// - removing all layers of aliasing and access control if they exist, then /// - removing the pointer, if it exists, then - /// - removing all further layers of aliasing, if they exist + /// - removing all further layers of aliasing or access control, if they exist /// @returns the unwrapped type - Type* UnwrapAliasPtrAlias(); + Type* UnwrapAll(); /// @returns true if this type is a scalar bool is_scalar(); diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 8bc65a9e52..0f43943e6d 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -1143,7 +1143,7 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) { return {}; } - auto* ast_type = original_ast_type->UnwrapAliasesIfNeeded(); + auto* ast_type = original_ast_type->UnwrapIfNeeded(); // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0. // So canonicalization should map that way too. @@ -1220,7 +1220,7 @@ std::unique_ptr ParserImpl::MakeNullValue( } auto* original_type = type; - type = type->UnwrapAliasesIfNeeded(); + type = type->UnwrapIfNeeded(); if (type->IsBool()) { return std::make_unique( diff --git a/src/transform/bound_array_accessors_transform.cc b/src/transform/bound_array_accessors_transform.cc index 82ecd66efe..39ace7e82b 100644 --- a/src/transform/bound_array_accessors_transform.cc +++ b/src/transform/bound_array_accessors_transform.cc @@ -181,7 +181,7 @@ bool BoundArrayAccessorsTransform::ProcessArrayAccessor( return false; } - auto* ret_type = expr->array()->result_type()->UnwrapAliasPtrAlias(); + auto* ret_type = expr->array()->result_type()->UnwrapAll(); if (!ret_type->IsArray() && !ret_type->IsMatrix() && !ret_type->IsVector()) { return true; } diff --git a/src/type_determiner.cc b/src/type_determiner.cc index acc5b53694..2686d8813c 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -331,7 +331,7 @@ bool TypeDeterminer::DetermineArrayAccessor( } auto* res = expr->array()->result_type(); - auto* parent_type = res->UnwrapAliasPtrAlias(); + auto* parent_type = res->UnwrapAll(); ast::type::Type* ret = nullptr; if (parent_type->IsArray()) { ret = parent_type->AsArray()->type(); @@ -942,7 +942,7 @@ bool TypeDeterminer::DetermineMemberAccessor( } auto* res = expr->structure()->result_type(); - auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapAliasesIfNeeded(); + auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); ast::type::Type* ret = nullptr; if (data_type->IsStruct()) { diff --git a/src/validator_impl.cc b/src/validator_impl.cc index d3d404d1ce..585bb18d61 100644 --- a/src/validator_impl.cc +++ b/src/validator_impl.cc @@ -168,9 +168,8 @@ bool ValidatorImpl::ValidateReturnStatement(const ast::ReturnStatement* ret) { ast::type::Type* func_type = current_function_->return_type(); ast::type::VoidType void_type; - auto* ret_type = ret->has_value() - ? ret->value()->result_type()->UnwrapAliasPtrAlias() - : &void_type; + auto* ret_type = + ret->has_value() ? ret->value()->result_type()->UnwrapAll() : &void_type; if (func_type->type_name() != ret_type->type_name()) { set_error(ret->source(), @@ -249,7 +248,7 @@ bool ValidatorImpl::ValidateSwitch(const ast::SwitchStatement* s) { return false; } - auto* cond_type = s->condition()->result_type()->UnwrapAliasPtrAlias(); + auto* cond_type = s->condition()->result_type()->UnwrapAll(); if (!(cond_type->IsI32() || cond_type->IsU32())) { set_error(s->condition()->source(), "v-0025: switch statement selector expression must be of a " @@ -393,8 +392,8 @@ bool ValidatorImpl::ValidateResultTypes(const ast::AssignmentStatement* a) { return false; } - auto* lhs_result_type = a->lhs()->result_type()->UnwrapAliasPtrAlias(); - auto* rhs_result_type = a->rhs()->result_type()->UnwrapAliasPtrAlias(); + auto* lhs_result_type = a->lhs()->result_type()->UnwrapAll(); + auto* rhs_result_type = a->rhs()->result_type()->UnwrapAll(); if (lhs_result_type != rhs_result_type) { // TODO(sarahM0): figur out what should be the error number. set_error(a->source(), "v-000x: invalid assignment of '" + diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 70a1cb4df1..c771175229 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -1170,7 +1170,7 @@ bool GeneratorImpl::EmitEntryPointData(std::ostream& out, ast::Function* func) { } // auto* set = data.second.set; - auto* type = var->type()->UnwrapAliasesIfNeeded(); + auto* type = var->type()->UnwrapIfNeeded(); if (type->IsStruct()) { auto* strct = type->AsStruct(); @@ -1558,7 +1558,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression( first = false; if (expr->IsMemberAccessor()) { auto* mem = expr->AsMemberAccessor(); - auto* res_type = mem->structure()->result_type()->UnwrapAliasPtrAlias(); + auto* res_type = mem->structure()->result_type()->UnwrapAll(); if (res_type->IsStruct()) { auto* str_type = res_type->AsStruct()->impl(); auto* str_member = str_type->get_member(mem->member()->name()); @@ -1593,7 +1593,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression( expr = mem->structure(); } else if (expr->IsArrayAccessor()) { auto* ary = expr->AsArrayAccessor(); - auto* ary_type = ary->array()->result_type()->UnwrapAliasPtrAlias(); + auto* ary_type = ary->array()->result_type()->UnwrapAll(); out << "("; if (ary_type->IsArray()) { @@ -1641,7 +1641,7 @@ bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& pre, std::ostream& out, ast::Expression* expr, ast::Expression* rhs) { - auto* result_type = expr->result_type()->UnwrapAliasPtrAlias(); + auto* result_type = expr->result_type()->UnwrapAll(); bool is_store = rhs != nullptr; std::string access_method = is_store ? "Store" : "Load"; @@ -1758,7 +1758,7 @@ bool GeneratorImpl::is_storage_buffer_access( bool GeneratorImpl::is_storage_buffer_access( ast::MemberAccessorExpression* expr) { auto* structure = expr->structure(); - auto* data_type = structure->result_type()->UnwrapAliasPtrAlias(); + auto* data_type = structure->result_type()->UnwrapAll(); // If the data is a multi-element swizzle then we will not load the swizzle // portion through the Load command. if (data_type->IsVector() && expr->member()->name().size() > 1) { diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index baca05c0cc..12d0a7a95b 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -814,10 +814,8 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr, bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr, AccessorInfo* info) { - auto* data_type = expr->structure() - ->result_type() - ->UnwrapPtrIfNeeded() - ->UnwrapAliasesIfNeeded(); + auto* data_type = + expr->structure()->result_type()->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); // If the data_type is a structure we're accessing a member, if it's a // vector we're accessing a swizzle. @@ -1137,7 +1135,7 @@ bool Builder::is_constructor_const(ast::Expression* expr, bool is_global_init) { } auto* tc = expr->AsConstructor()->AsTypeConstructor(); - auto* result_type = tc->type()->UnwrapAliasPtrAlias(); + auto* result_type = tc->type()->UnwrapAll(); for (size_t i = 0; i < tc->values().size(); ++i) { auto* e = tc->values()[i].get(); @@ -1165,21 +1163,17 @@ bool Builder::is_constructor_const(ast::Expression* expr, bool is_global_init) { } auto* sc = e->AsConstructor()->AsScalarConstructor(); - ast::type::Type* subtype = result_type->UnwrapAliasPtrAlias(); + ast::type::Type* subtype = result_type->UnwrapAll(); if (subtype->IsVector()) { - subtype = subtype->AsVector()->type()->UnwrapAliasPtrAlias(); + subtype = subtype->AsVector()->type()->UnwrapAll(); } else if (subtype->IsMatrix()) { - subtype = subtype->AsMatrix()->type()->UnwrapAliasPtrAlias(); + subtype = subtype->AsMatrix()->type()->UnwrapAll(); } else if (subtype->IsArray()) { - subtype = subtype->AsArray()->type()->UnwrapAliasPtrAlias(); + subtype = subtype->AsArray()->type()->UnwrapAll(); } else if (subtype->IsStruct()) { - subtype = subtype->AsStruct() - ->impl() - ->members()[i] - ->type() - ->UnwrapAliasPtrAlias(); + subtype = subtype->AsStruct()->impl()->members()[i]->type()->UnwrapAll(); } - if (subtype != sc->result_type()->UnwrapAliasPtrAlias()) { + if (subtype != sc->result_type()->UnwrapAll()) { return false; } } @@ -1200,7 +1194,7 @@ uint32_t Builder::GenerateTypeConstructorExpression( std::ostringstream out; out << "__const"; - auto* result_type = init->type()->UnwrapAliasPtrAlias(); + auto* result_type = init->type()->UnwrapAll(); bool constructor_is_const = is_constructor_const(init, is_global_init); if (has_error()) { return 0; @@ -1209,7 +1203,7 @@ uint32_t Builder::GenerateTypeConstructorExpression( bool can_cast_or_copy = result_type->is_scalar(); if (result_type->IsVector() && result_type->AsVector()->type()->is_scalar()) { - auto* value_type = values[0]->result_type()->UnwrapAliasPtrAlias(); + auto* value_type = values[0]->result_type()->UnwrapAll(); can_cast_or_copy = (value_type->IsVector() && value_type->AsVector()->type()->is_scalar() && @@ -1773,7 +1767,7 @@ uint32_t Builder::GenerateIntrinsic(ast::IdentifierExpression* ident, } params.push_back(Operand::Int(struct_id)); - auto* type = accessor->structure()->result_type()->UnwrapAliasPtrAlias(); + auto* type = accessor->structure()->result_type()->UnwrapAll(); if (!type->IsStruct()) { error_ = "invalid type (" + type->type_name() + ") for runtime array length"; @@ -1866,11 +1860,8 @@ uint32_t Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident, ast::CallExpression* call, uint32_t result_id, OperandList wgsl_params) { - auto* texture_type = call->params()[0] - .get() - ->result_type() - ->UnwrapAliasPtrAlias() - ->AsTexture(); + auto* texture_type = + call->params()[0].get()->result_type()->UnwrapAll()->AsTexture(); // TODO: Remove the LOD param from textureLoad on storage textures when // https://github.com/gpuweb/gpuweb/pull/1032 gets merged.