From 512ecc2762629820b5318f170161051a010a3260 Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Wed, 28 Oct 2020 20:32:22 +0000 Subject: [PATCH] Rename unwrap helpers. With the addition of the AccessControlType we want to look through the access control as well as the aliases as we work through the type tree. This CL renames UnwrapAliasesIfNeeded to be UnwrapIfNeeded and UnwrapAliasPtrAlias to UnwrapAll. Change-Id: I5b027919c3143a89be24c4d87b8106f70358c03b Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/31104 Commit-Queue: David Neto Reviewed-by: Sarah Mashayekhi Reviewed-by: David Neto --- src/ast/expression.cc | 4 +- src/ast/type/alias_type_test.cc | 60 ++++++++++++++----- src/ast/type/type.cc | 16 +++-- src/ast/type/type.h | 14 ++--- src/reader/spirv/parser_impl.cc | 4 +- .../bound_array_accessors_transform.cc | 2 +- src/type_determiner.cc | 4 +- src/validator_impl.cc | 11 ++-- src/writer/hlsl/generator_impl.cc | 10 ++-- src/writer/spirv/builder.cc | 37 +++++------- 10 files changed, 95 insertions(+), 67 deletions(-) diff --git a/src/ast/expression.cc b/src/ast/expression.cc index 6a14603338..90d597e72c 100644 --- a/src/ast/expression.cc +++ b/src/ast/expression.cc @@ -36,8 +36,8 @@ Expression::Expression(const Source& source) : Node(source) {} Expression::~Expression() = default; void Expression::set_result_type(type::Type* type) { - // The expression result should never be an alias type - result_type_ = type->UnwrapAliasesIfNeeded(); + // The expression result should never be an alias or access-controlled type + result_type_ = type->UnwrapIfNeeded(); } bool Expression::IsArrayAccessor() const { diff --git a/src/ast/type/alias_type_test.cc b/src/ast/type/alias_type_test.cc index 76b547a4c7..22b549ebe9 100644 --- a/src/ast/type/alias_type_test.cc +++ b/src/ast/type/alias_type_test.cc @@ -16,6 +16,7 @@ #include "gtest/gtest.h" #include "src/ast/storage_class.h" +#include "src/ast/type/access_control_type.h" #include "src/ast/type/i32_type.h" #include "src/ast/type/pointer_type.h" #include "src/ast/type/u32_type.h" @@ -59,25 +60,40 @@ TEST_F(AliasTypeTest, TypeName) { EXPECT_EQ(at.type_name(), "__alias_Particle__i32"); } -TEST_F(AliasTypeTest, UnwrapAliasesIfNeeded) { +TEST_F(AliasTypeTest, UnwrapIfNeeded_Alias) { U32Type u32; AliasType a{"a_type", &u32}; EXPECT_EQ(a.name(), "a_type"); EXPECT_EQ(a.type(), &u32); - EXPECT_EQ(a.UnwrapAliasesIfNeeded(), &u32); - EXPECT_EQ(u32.UnwrapAliasesIfNeeded(), &u32); + EXPECT_EQ(a.UnwrapIfNeeded(), &u32); + EXPECT_EQ(u32.UnwrapIfNeeded(), &u32); } -TEST_F(AliasTypeTest, UnwrapAliasesIfNeeded_MultiLevel) { +TEST_F(AliasTypeTest, UnwrapIfNeeded_AccessControl) { + U32Type u32; + AccessControlType a{AccessControl::kReadOnly, &u32}; + EXPECT_EQ(a.type(), &u32); + EXPECT_EQ(a.UnwrapIfNeeded(), &u32); +} + +TEST_F(AliasTypeTest, UnwrapIfNeeded_MultiLevel) { U32Type u32; AliasType a{"a_type", &u32}; AliasType aa{"aa_type", &a}; EXPECT_EQ(aa.name(), "aa_type"); EXPECT_EQ(aa.type(), &a); - EXPECT_EQ(aa.UnwrapAliasesIfNeeded(), &u32); + EXPECT_EQ(aa.UnwrapIfNeeded(), &u32); } -TEST_F(AliasTypeTest, UnwrapAliasPtrAlias_TwiceAliasPointerTwiceAlias) { +TEST_F(AliasTypeTest, UnwrapIfNeeded_MultiLevel_AliasAccessControl) { + U32Type u32; + AliasType a{"a_type", &u32}; + AccessControlType aa{AccessControl::kReadWrite, &a}; + EXPECT_EQ(aa.type(), &a); + EXPECT_EQ(aa.UnwrapIfNeeded(), &u32); +} + +TEST_F(AliasTypeTest, UnwrapAll_TwiceAliasPointerTwiceAlias) { U32Type u32; AliasType a{"a_type", &u32}; AliasType aa{"aa_type", &a}; @@ -86,23 +102,21 @@ TEST_F(AliasTypeTest, UnwrapAliasPtrAlias_TwiceAliasPointerTwiceAlias) { AliasType aapaa{"aapaa_type", &apaa}; EXPECT_EQ(aapaa.name(), "aapaa_type"); EXPECT_EQ(aapaa.type(), &apaa); - EXPECT_EQ(aapaa.UnwrapAliasPtrAlias(), &u32); - EXPECT_EQ(u32.UnwrapAliasPtrAlias(), &u32); + EXPECT_EQ(aapaa.UnwrapAll(), &u32); + EXPECT_EQ(u32.UnwrapAll(), &u32); } -TEST_F(AliasTypeTest, - UnwrapAliasPtrAlias_SecondConsecutivePointerBlocksWUnrapping) { +TEST_F(AliasTypeTest, UnwrapAll_SecondConsecutivePointerBlocksUnrapping) { U32Type u32; AliasType a{"a_type", &u32}; AliasType aa{"aa_type", &a}; PointerType paa{&aa, StorageClass::kUniform}; PointerType ppaa{&paa, StorageClass::kUniform}; AliasType appaa{"appaa_type", &ppaa}; - EXPECT_EQ(appaa.UnwrapAliasPtrAlias(), &paa); + EXPECT_EQ(appaa.UnwrapAll(), &paa); } -TEST_F(AliasTypeTest, - UnwrapAliasPtrAlias_SecondNonConsecutivePointerBlocksWUnrapping) { +TEST_F(AliasTypeTest, UnwrapAll_SecondNonConsecutivePointerBlocksUnrapping) { U32Type u32; AliasType a{"a_type", &u32}; AliasType aa{"aa_type", &a}; @@ -111,7 +125,25 @@ TEST_F(AliasTypeTest, AliasType aapaa{"aapaa_type", &apaa}; PointerType paapaa{&aapaa, StorageClass::kUniform}; AliasType apaapaa{"apaapaa_type", &paapaa}; - EXPECT_EQ(apaapaa.UnwrapAliasPtrAlias(), &paa); + EXPECT_EQ(apaapaa.UnwrapAll(), &paa); +} + +TEST_F(AliasTypeTest, UnwrapAll_AccessControlPointer) { + U32Type u32; + AccessControlType a{AccessControl::kReadOnly, &u32}; + PointerType pa{&a, StorageClass::kUniform}; + EXPECT_EQ(pa.type(), &a); + EXPECT_EQ(pa.UnwrapAll(), &u32); + EXPECT_EQ(u32.UnwrapAll(), &u32); +} + +TEST_F(AliasTypeTest, UnwrapAll_PointerAccessControl) { + U32Type u32; + PointerType p{&u32, StorageClass::kUniform}; + AccessControlType a{AccessControl::kReadOnly, &p}; + EXPECT_EQ(a.type(), &p); + EXPECT_EQ(a.UnwrapAll(), &u32); + EXPECT_EQ(u32.UnwrapAll(), &u32); } } // namespace diff --git a/src/ast/type/type.cc b/src/ast/type/type.cc index 7f596afa30..b7d065d827 100644 --- a/src/ast/type/type.cc +++ b/src/ast/type/type.cc @@ -46,16 +46,22 @@ Type* Type::UnwrapPtrIfNeeded() { return this; } -Type* Type::UnwrapAliasesIfNeeded() { +Type* Type::UnwrapIfNeeded() { auto* where = this; - while (where->IsAlias()) { - where = where->AsAlias()->type(); + while (true) { + if (where->IsAlias()) { + where = where->AsAlias()->type(); + } else if (where->IsAccessControl()) { + where = where->AsAccessControl()->type(); + } else { + break; + } } return where; } -Type* Type::UnwrapAliasPtrAlias() { - return UnwrapAliasesIfNeeded()->UnwrapPtrIfNeeded()->UnwrapAliasesIfNeeded(); +Type* Type::UnwrapAll() { + return UnwrapIfNeeded()->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); } bool Type::IsAccessControl() const { diff --git a/src/ast/type/type.h b/src/ast/type/type.h index ff02f281f4..ca4b208642 100644 --- a/src/ast/type/type.h +++ b/src/ast/type/type.h @@ -82,20 +82,20 @@ class Type { /// @returns the pointee type if this is a pointer, |this| otherwise Type* UnwrapPtrIfNeeded(); - /// Removes all levels of aliasing, if this is an alias type. Otherwise - /// returns |this|. This is just enough to assist with WGSL translation + /// Removes all levels of aliasing and access control. + /// This is just enough to assist with WGSL translation /// in that you want see through one level of pointer to get from an /// identifier-like expression as an l-value to its corresponding r-value, - /// plus see through the aliases on either side. + /// plus see through the wrappers on either side. /// @returns the completely unaliased type. - Type* UnwrapAliasesIfNeeded(); + Type* UnwrapIfNeeded(); /// Returns the type found after: - /// - removing all layers of aliasing if they exist, then + /// - removing all layers of aliasing and access control if they exist, then /// - removing the pointer, if it exists, then - /// - removing all further layers of aliasing, if they exist + /// - removing all further layers of aliasing or access control, if they exist /// @returns the unwrapped type - Type* UnwrapAliasPtrAlias(); + Type* UnwrapAll(); /// @returns true if this type is a scalar bool is_scalar(); diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 8bc65a9e52..0f43943e6d 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -1143,7 +1143,7 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) { return {}; } - auto* ast_type = original_ast_type->UnwrapAliasesIfNeeded(); + auto* ast_type = original_ast_type->UnwrapIfNeeded(); // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0. // So canonicalization should map that way too. @@ -1220,7 +1220,7 @@ std::unique_ptr ParserImpl::MakeNullValue( } auto* original_type = type; - type = type->UnwrapAliasesIfNeeded(); + type = type->UnwrapIfNeeded(); if (type->IsBool()) { return std::make_unique( diff --git a/src/transform/bound_array_accessors_transform.cc b/src/transform/bound_array_accessors_transform.cc index 82ecd66efe..39ace7e82b 100644 --- a/src/transform/bound_array_accessors_transform.cc +++ b/src/transform/bound_array_accessors_transform.cc @@ -181,7 +181,7 @@ bool BoundArrayAccessorsTransform::ProcessArrayAccessor( return false; } - auto* ret_type = expr->array()->result_type()->UnwrapAliasPtrAlias(); + auto* ret_type = expr->array()->result_type()->UnwrapAll(); if (!ret_type->IsArray() && !ret_type->IsMatrix() && !ret_type->IsVector()) { return true; } diff --git a/src/type_determiner.cc b/src/type_determiner.cc index acc5b53694..2686d8813c 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -331,7 +331,7 @@ bool TypeDeterminer::DetermineArrayAccessor( } auto* res = expr->array()->result_type(); - auto* parent_type = res->UnwrapAliasPtrAlias(); + auto* parent_type = res->UnwrapAll(); ast::type::Type* ret = nullptr; if (parent_type->IsArray()) { ret = parent_type->AsArray()->type(); @@ -942,7 +942,7 @@ bool TypeDeterminer::DetermineMemberAccessor( } auto* res = expr->structure()->result_type(); - auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapAliasesIfNeeded(); + auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); ast::type::Type* ret = nullptr; if (data_type->IsStruct()) { diff --git a/src/validator_impl.cc b/src/validator_impl.cc index d3d404d1ce..585bb18d61 100644 --- a/src/validator_impl.cc +++ b/src/validator_impl.cc @@ -168,9 +168,8 @@ bool ValidatorImpl::ValidateReturnStatement(const ast::ReturnStatement* ret) { ast::type::Type* func_type = current_function_->return_type(); ast::type::VoidType void_type; - auto* ret_type = ret->has_value() - ? ret->value()->result_type()->UnwrapAliasPtrAlias() - : &void_type; + auto* ret_type = + ret->has_value() ? ret->value()->result_type()->UnwrapAll() : &void_type; if (func_type->type_name() != ret_type->type_name()) { set_error(ret->source(), @@ -249,7 +248,7 @@ bool ValidatorImpl::ValidateSwitch(const ast::SwitchStatement* s) { return false; } - auto* cond_type = s->condition()->result_type()->UnwrapAliasPtrAlias(); + auto* cond_type = s->condition()->result_type()->UnwrapAll(); if (!(cond_type->IsI32() || cond_type->IsU32())) { set_error(s->condition()->source(), "v-0025: switch statement selector expression must be of a " @@ -393,8 +392,8 @@ bool ValidatorImpl::ValidateResultTypes(const ast::AssignmentStatement* a) { return false; } - auto* lhs_result_type = a->lhs()->result_type()->UnwrapAliasPtrAlias(); - auto* rhs_result_type = a->rhs()->result_type()->UnwrapAliasPtrAlias(); + auto* lhs_result_type = a->lhs()->result_type()->UnwrapAll(); + auto* rhs_result_type = a->rhs()->result_type()->UnwrapAll(); if (lhs_result_type != rhs_result_type) { // TODO(sarahM0): figur out what should be the error number. set_error(a->source(), "v-000x: invalid assignment of '" + diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 70a1cb4df1..c771175229 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -1170,7 +1170,7 @@ bool GeneratorImpl::EmitEntryPointData(std::ostream& out, ast::Function* func) { } // auto* set = data.second.set; - auto* type = var->type()->UnwrapAliasesIfNeeded(); + auto* type = var->type()->UnwrapIfNeeded(); if (type->IsStruct()) { auto* strct = type->AsStruct(); @@ -1558,7 +1558,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression( first = false; if (expr->IsMemberAccessor()) { auto* mem = expr->AsMemberAccessor(); - auto* res_type = mem->structure()->result_type()->UnwrapAliasPtrAlias(); + auto* res_type = mem->structure()->result_type()->UnwrapAll(); if (res_type->IsStruct()) { auto* str_type = res_type->AsStruct()->impl(); auto* str_member = str_type->get_member(mem->member()->name()); @@ -1593,7 +1593,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression( expr = mem->structure(); } else if (expr->IsArrayAccessor()) { auto* ary = expr->AsArrayAccessor(); - auto* ary_type = ary->array()->result_type()->UnwrapAliasPtrAlias(); + auto* ary_type = ary->array()->result_type()->UnwrapAll(); out << "("; if (ary_type->IsArray()) { @@ -1641,7 +1641,7 @@ bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& pre, std::ostream& out, ast::Expression* expr, ast::Expression* rhs) { - auto* result_type = expr->result_type()->UnwrapAliasPtrAlias(); + auto* result_type = expr->result_type()->UnwrapAll(); bool is_store = rhs != nullptr; std::string access_method = is_store ? "Store" : "Load"; @@ -1758,7 +1758,7 @@ bool GeneratorImpl::is_storage_buffer_access( bool GeneratorImpl::is_storage_buffer_access( ast::MemberAccessorExpression* expr) { auto* structure = expr->structure(); - auto* data_type = structure->result_type()->UnwrapAliasPtrAlias(); + auto* data_type = structure->result_type()->UnwrapAll(); // If the data is a multi-element swizzle then we will not load the swizzle // portion through the Load command. if (data_type->IsVector() && expr->member()->name().size() > 1) { diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index baca05c0cc..12d0a7a95b 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -814,10 +814,8 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr, bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr, AccessorInfo* info) { - auto* data_type = expr->structure() - ->result_type() - ->UnwrapPtrIfNeeded() - ->UnwrapAliasesIfNeeded(); + auto* data_type = + expr->structure()->result_type()->UnwrapPtrIfNeeded()->UnwrapIfNeeded(); // If the data_type is a structure we're accessing a member, if it's a // vector we're accessing a swizzle. @@ -1137,7 +1135,7 @@ bool Builder::is_constructor_const(ast::Expression* expr, bool is_global_init) { } auto* tc = expr->AsConstructor()->AsTypeConstructor(); - auto* result_type = tc->type()->UnwrapAliasPtrAlias(); + auto* result_type = tc->type()->UnwrapAll(); for (size_t i = 0; i < tc->values().size(); ++i) { auto* e = tc->values()[i].get(); @@ -1165,21 +1163,17 @@ bool Builder::is_constructor_const(ast::Expression* expr, bool is_global_init) { } auto* sc = e->AsConstructor()->AsScalarConstructor(); - ast::type::Type* subtype = result_type->UnwrapAliasPtrAlias(); + ast::type::Type* subtype = result_type->UnwrapAll(); if (subtype->IsVector()) { - subtype = subtype->AsVector()->type()->UnwrapAliasPtrAlias(); + subtype = subtype->AsVector()->type()->UnwrapAll(); } else if (subtype->IsMatrix()) { - subtype = subtype->AsMatrix()->type()->UnwrapAliasPtrAlias(); + subtype = subtype->AsMatrix()->type()->UnwrapAll(); } else if (subtype->IsArray()) { - subtype = subtype->AsArray()->type()->UnwrapAliasPtrAlias(); + subtype = subtype->AsArray()->type()->UnwrapAll(); } else if (subtype->IsStruct()) { - subtype = subtype->AsStruct() - ->impl() - ->members()[i] - ->type() - ->UnwrapAliasPtrAlias(); + subtype = subtype->AsStruct()->impl()->members()[i]->type()->UnwrapAll(); } - if (subtype != sc->result_type()->UnwrapAliasPtrAlias()) { + if (subtype != sc->result_type()->UnwrapAll()) { return false; } } @@ -1200,7 +1194,7 @@ uint32_t Builder::GenerateTypeConstructorExpression( std::ostringstream out; out << "__const"; - auto* result_type = init->type()->UnwrapAliasPtrAlias(); + auto* result_type = init->type()->UnwrapAll(); bool constructor_is_const = is_constructor_const(init, is_global_init); if (has_error()) { return 0; @@ -1209,7 +1203,7 @@ uint32_t Builder::GenerateTypeConstructorExpression( bool can_cast_or_copy = result_type->is_scalar(); if (result_type->IsVector() && result_type->AsVector()->type()->is_scalar()) { - auto* value_type = values[0]->result_type()->UnwrapAliasPtrAlias(); + auto* value_type = values[0]->result_type()->UnwrapAll(); can_cast_or_copy = (value_type->IsVector() && value_type->AsVector()->type()->is_scalar() && @@ -1773,7 +1767,7 @@ uint32_t Builder::GenerateIntrinsic(ast::IdentifierExpression* ident, } params.push_back(Operand::Int(struct_id)); - auto* type = accessor->structure()->result_type()->UnwrapAliasPtrAlias(); + auto* type = accessor->structure()->result_type()->UnwrapAll(); if (!type->IsStruct()) { error_ = "invalid type (" + type->type_name() + ") for runtime array length"; @@ -1866,11 +1860,8 @@ uint32_t Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident, ast::CallExpression* call, uint32_t result_id, OperandList wgsl_params) { - auto* texture_type = call->params()[0] - .get() - ->result_type() - ->UnwrapAliasPtrAlias() - ->AsTexture(); + auto* texture_type = + call->params()[0].get()->result_type()->UnwrapAll()->AsTexture(); // TODO: Remove the LOD param from textureLoad on storage textures when // https://github.com/gpuweb/gpuweb/pull/1032 gets merged.