Rename unwrap helpers.

With the addition of the AccessControlType we want to look through the
access control as well as the aliases as we work through the type tree.
This CL renames UnwrapAliasesIfNeeded to be UnwrapIfNeeded and
UnwrapAliasPtrAlias to UnwrapAll.

Change-Id: I5b027919c3143a89be24c4d87b8106f70358c03b
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/31104
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: Sarah Mashayekhi <sarahmashay@google.com>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-10-28 20:32:22 +00:00 committed by Commit Bot service account
parent c23a5652bd
commit 512ecc2762
10 changed files with 95 additions and 67 deletions

View File

@ -36,8 +36,8 @@ Expression::Expression(const Source& source) : Node(source) {}
Expression::~Expression() = default; Expression::~Expression() = default;
void Expression::set_result_type(type::Type* type) { void Expression::set_result_type(type::Type* type) {
// The expression result should never be an alias type // The expression result should never be an alias or access-controlled type
result_type_ = type->UnwrapAliasesIfNeeded(); result_type_ = type->UnwrapIfNeeded();
} }
bool Expression::IsArrayAccessor() const { bool Expression::IsArrayAccessor() const {

View File

@ -16,6 +16,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "src/ast/storage_class.h" #include "src/ast/storage_class.h"
#include "src/ast/type/access_control_type.h"
#include "src/ast/type/i32_type.h" #include "src/ast/type/i32_type.h"
#include "src/ast/type/pointer_type.h" #include "src/ast/type/pointer_type.h"
#include "src/ast/type/u32_type.h" #include "src/ast/type/u32_type.h"
@ -59,25 +60,40 @@ TEST_F(AliasTypeTest, TypeName) {
EXPECT_EQ(at.type_name(), "__alias_Particle__i32"); EXPECT_EQ(at.type_name(), "__alias_Particle__i32");
} }
TEST_F(AliasTypeTest, UnwrapAliasesIfNeeded) { TEST_F(AliasTypeTest, UnwrapIfNeeded_Alias) {
U32Type u32; U32Type u32;
AliasType a{"a_type", &u32}; AliasType a{"a_type", &u32};
EXPECT_EQ(a.name(), "a_type"); EXPECT_EQ(a.name(), "a_type");
EXPECT_EQ(a.type(), &u32); EXPECT_EQ(a.type(), &u32);
EXPECT_EQ(a.UnwrapAliasesIfNeeded(), &u32); EXPECT_EQ(a.UnwrapIfNeeded(), &u32);
EXPECT_EQ(u32.UnwrapAliasesIfNeeded(), &u32); EXPECT_EQ(u32.UnwrapIfNeeded(), &u32);
} }
TEST_F(AliasTypeTest, UnwrapAliasesIfNeeded_MultiLevel) { TEST_F(AliasTypeTest, UnwrapIfNeeded_AccessControl) {
U32Type u32;
AccessControlType a{AccessControl::kReadOnly, &u32};
EXPECT_EQ(a.type(), &u32);
EXPECT_EQ(a.UnwrapIfNeeded(), &u32);
}
TEST_F(AliasTypeTest, UnwrapIfNeeded_MultiLevel) {
U32Type u32; U32Type u32;
AliasType a{"a_type", &u32}; AliasType a{"a_type", &u32};
AliasType aa{"aa_type", &a}; AliasType aa{"aa_type", &a};
EXPECT_EQ(aa.name(), "aa_type"); EXPECT_EQ(aa.name(), "aa_type");
EXPECT_EQ(aa.type(), &a); EXPECT_EQ(aa.type(), &a);
EXPECT_EQ(aa.UnwrapAliasesIfNeeded(), &u32); EXPECT_EQ(aa.UnwrapIfNeeded(), &u32);
} }
TEST_F(AliasTypeTest, UnwrapAliasPtrAlias_TwiceAliasPointerTwiceAlias) { TEST_F(AliasTypeTest, UnwrapIfNeeded_MultiLevel_AliasAccessControl) {
U32Type u32;
AliasType a{"a_type", &u32};
AccessControlType aa{AccessControl::kReadWrite, &a};
EXPECT_EQ(aa.type(), &a);
EXPECT_EQ(aa.UnwrapIfNeeded(), &u32);
}
TEST_F(AliasTypeTest, UnwrapAll_TwiceAliasPointerTwiceAlias) {
U32Type u32; U32Type u32;
AliasType a{"a_type", &u32}; AliasType a{"a_type", &u32};
AliasType aa{"aa_type", &a}; AliasType aa{"aa_type", &a};
@ -86,23 +102,21 @@ TEST_F(AliasTypeTest, UnwrapAliasPtrAlias_TwiceAliasPointerTwiceAlias) {
AliasType aapaa{"aapaa_type", &apaa}; AliasType aapaa{"aapaa_type", &apaa};
EXPECT_EQ(aapaa.name(), "aapaa_type"); EXPECT_EQ(aapaa.name(), "aapaa_type");
EXPECT_EQ(aapaa.type(), &apaa); EXPECT_EQ(aapaa.type(), &apaa);
EXPECT_EQ(aapaa.UnwrapAliasPtrAlias(), &u32); EXPECT_EQ(aapaa.UnwrapAll(), &u32);
EXPECT_EQ(u32.UnwrapAliasPtrAlias(), &u32); EXPECT_EQ(u32.UnwrapAll(), &u32);
} }
TEST_F(AliasTypeTest, TEST_F(AliasTypeTest, UnwrapAll_SecondConsecutivePointerBlocksUnrapping) {
UnwrapAliasPtrAlias_SecondConsecutivePointerBlocksWUnrapping) {
U32Type u32; U32Type u32;
AliasType a{"a_type", &u32}; AliasType a{"a_type", &u32};
AliasType aa{"aa_type", &a}; AliasType aa{"aa_type", &a};
PointerType paa{&aa, StorageClass::kUniform}; PointerType paa{&aa, StorageClass::kUniform};
PointerType ppaa{&paa, StorageClass::kUniform}; PointerType ppaa{&paa, StorageClass::kUniform};
AliasType appaa{"appaa_type", &ppaa}; AliasType appaa{"appaa_type", &ppaa};
EXPECT_EQ(appaa.UnwrapAliasPtrAlias(), &paa); EXPECT_EQ(appaa.UnwrapAll(), &paa);
} }
TEST_F(AliasTypeTest, TEST_F(AliasTypeTest, UnwrapAll_SecondNonConsecutivePointerBlocksUnrapping) {
UnwrapAliasPtrAlias_SecondNonConsecutivePointerBlocksWUnrapping) {
U32Type u32; U32Type u32;
AliasType a{"a_type", &u32}; AliasType a{"a_type", &u32};
AliasType aa{"aa_type", &a}; AliasType aa{"aa_type", &a};
@ -111,7 +125,25 @@ TEST_F(AliasTypeTest,
AliasType aapaa{"aapaa_type", &apaa}; AliasType aapaa{"aapaa_type", &apaa};
PointerType paapaa{&aapaa, StorageClass::kUniform}; PointerType paapaa{&aapaa, StorageClass::kUniform};
AliasType apaapaa{"apaapaa_type", &paapaa}; AliasType apaapaa{"apaapaa_type", &paapaa};
EXPECT_EQ(apaapaa.UnwrapAliasPtrAlias(), &paa); EXPECT_EQ(apaapaa.UnwrapAll(), &paa);
}
TEST_F(AliasTypeTest, UnwrapAll_AccessControlPointer) {
U32Type u32;
AccessControlType a{AccessControl::kReadOnly, &u32};
PointerType pa{&a, StorageClass::kUniform};
EXPECT_EQ(pa.type(), &a);
EXPECT_EQ(pa.UnwrapAll(), &u32);
EXPECT_EQ(u32.UnwrapAll(), &u32);
}
TEST_F(AliasTypeTest, UnwrapAll_PointerAccessControl) {
U32Type u32;
PointerType p{&u32, StorageClass::kUniform};
AccessControlType a{AccessControl::kReadOnly, &p};
EXPECT_EQ(a.type(), &p);
EXPECT_EQ(a.UnwrapAll(), &u32);
EXPECT_EQ(u32.UnwrapAll(), &u32);
} }
} // namespace } // namespace

View File

@ -46,16 +46,22 @@ Type* Type::UnwrapPtrIfNeeded() {
return this; return this;
} }
Type* Type::UnwrapAliasesIfNeeded() { Type* Type::UnwrapIfNeeded() {
auto* where = this; auto* where = this;
while (where->IsAlias()) { while (true) {
if (where->IsAlias()) {
where = where->AsAlias()->type(); where = where->AsAlias()->type();
} else if (where->IsAccessControl()) {
where = where->AsAccessControl()->type();
} else {
break;
}
} }
return where; return where;
} }
Type* Type::UnwrapAliasPtrAlias() { Type* Type::UnwrapAll() {
return UnwrapAliasesIfNeeded()->UnwrapPtrIfNeeded()->UnwrapAliasesIfNeeded(); return UnwrapIfNeeded()->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
} }
bool Type::IsAccessControl() const { bool Type::IsAccessControl() const {

View File

@ -82,20 +82,20 @@ class Type {
/// @returns the pointee type if this is a pointer, |this| otherwise /// @returns the pointee type if this is a pointer, |this| otherwise
Type* UnwrapPtrIfNeeded(); Type* UnwrapPtrIfNeeded();
/// Removes all levels of aliasing, if this is an alias type. Otherwise /// Removes all levels of aliasing and access control.
/// returns |this|. This is just enough to assist with WGSL translation /// This is just enough to assist with WGSL translation
/// in that you want see through one level of pointer to get from an /// in that you want see through one level of pointer to get from an
/// identifier-like expression as an l-value to its corresponding r-value, /// identifier-like expression as an l-value to its corresponding r-value,
/// plus see through the aliases on either side. /// plus see through the wrappers on either side.
/// @returns the completely unaliased type. /// @returns the completely unaliased type.
Type* UnwrapAliasesIfNeeded(); Type* UnwrapIfNeeded();
/// Returns the type found after: /// Returns the type found after:
/// - removing all layers of aliasing if they exist, then /// - removing all layers of aliasing and access control if they exist, then
/// - removing the pointer, if it exists, then /// - removing the pointer, if it exists, then
/// - removing all further layers of aliasing, if they exist /// - removing all further layers of aliasing or access control, if they exist
/// @returns the unwrapped type /// @returns the unwrapped type
Type* UnwrapAliasPtrAlias(); Type* UnwrapAll();
/// @returns true if this type is a scalar /// @returns true if this type is a scalar
bool is_scalar(); bool is_scalar();

View File

@ -1143,7 +1143,7 @@ TypedExpression ParserImpl::MakeConstantExpression(uint32_t id) {
return {}; return {};
} }
auto* ast_type = original_ast_type->UnwrapAliasesIfNeeded(); auto* ast_type = original_ast_type->UnwrapIfNeeded();
// TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0. // TODO(dneto): Note: NullConstant for int, uint, float map to a regular 0.
// So canonicalization should map that way too. // So canonicalization should map that way too.
@ -1220,7 +1220,7 @@ std::unique_ptr<ast::Expression> ParserImpl::MakeNullValue(
} }
auto* original_type = type; auto* original_type = type;
type = type->UnwrapAliasesIfNeeded(); type = type->UnwrapIfNeeded();
if (type->IsBool()) { if (type->IsBool()) {
return std::make_unique<ast::ScalarConstructorExpression>( return std::make_unique<ast::ScalarConstructorExpression>(

View File

@ -181,7 +181,7 @@ bool BoundArrayAccessorsTransform::ProcessArrayAccessor(
return false; return false;
} }
auto* ret_type = expr->array()->result_type()->UnwrapAliasPtrAlias(); auto* ret_type = expr->array()->result_type()->UnwrapAll();
if (!ret_type->IsArray() && !ret_type->IsMatrix() && !ret_type->IsVector()) { if (!ret_type->IsArray() && !ret_type->IsMatrix() && !ret_type->IsVector()) {
return true; return true;
} }

View File

@ -331,7 +331,7 @@ bool TypeDeterminer::DetermineArrayAccessor(
} }
auto* res = expr->array()->result_type(); auto* res = expr->array()->result_type();
auto* parent_type = res->UnwrapAliasPtrAlias(); auto* parent_type = res->UnwrapAll();
ast::type::Type* ret = nullptr; ast::type::Type* ret = nullptr;
if (parent_type->IsArray()) { if (parent_type->IsArray()) {
ret = parent_type->AsArray()->type(); ret = parent_type->AsArray()->type();
@ -942,7 +942,7 @@ bool TypeDeterminer::DetermineMemberAccessor(
} }
auto* res = expr->structure()->result_type(); auto* res = expr->structure()->result_type();
auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapAliasesIfNeeded(); auto* data_type = res->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
ast::type::Type* ret = nullptr; ast::type::Type* ret = nullptr;
if (data_type->IsStruct()) { if (data_type->IsStruct()) {

View File

@ -168,9 +168,8 @@ bool ValidatorImpl::ValidateReturnStatement(const ast::ReturnStatement* ret) {
ast::type::Type* func_type = current_function_->return_type(); ast::type::Type* func_type = current_function_->return_type();
ast::type::VoidType void_type; ast::type::VoidType void_type;
auto* ret_type = ret->has_value() auto* ret_type =
? ret->value()->result_type()->UnwrapAliasPtrAlias() ret->has_value() ? ret->value()->result_type()->UnwrapAll() : &void_type;
: &void_type;
if (func_type->type_name() != ret_type->type_name()) { if (func_type->type_name() != ret_type->type_name()) {
set_error(ret->source(), set_error(ret->source(),
@ -249,7 +248,7 @@ bool ValidatorImpl::ValidateSwitch(const ast::SwitchStatement* s) {
return false; return false;
} }
auto* cond_type = s->condition()->result_type()->UnwrapAliasPtrAlias(); auto* cond_type = s->condition()->result_type()->UnwrapAll();
if (!(cond_type->IsI32() || cond_type->IsU32())) { if (!(cond_type->IsI32() || cond_type->IsU32())) {
set_error(s->condition()->source(), set_error(s->condition()->source(),
"v-0025: switch statement selector expression must be of a " "v-0025: switch statement selector expression must be of a "
@ -393,8 +392,8 @@ bool ValidatorImpl::ValidateResultTypes(const ast::AssignmentStatement* a) {
return false; return false;
} }
auto* lhs_result_type = a->lhs()->result_type()->UnwrapAliasPtrAlias(); auto* lhs_result_type = a->lhs()->result_type()->UnwrapAll();
auto* rhs_result_type = a->rhs()->result_type()->UnwrapAliasPtrAlias(); auto* rhs_result_type = a->rhs()->result_type()->UnwrapAll();
if (lhs_result_type != rhs_result_type) { if (lhs_result_type != rhs_result_type) {
// TODO(sarahM0): figur out what should be the error number. // TODO(sarahM0): figur out what should be the error number.
set_error(a->source(), "v-000x: invalid assignment of '" + set_error(a->source(), "v-000x: invalid assignment of '" +

View File

@ -1170,7 +1170,7 @@ bool GeneratorImpl::EmitEntryPointData(std::ostream& out, ast::Function* func) {
} }
// auto* set = data.second.set; // auto* set = data.second.set;
auto* type = var->type()->UnwrapAliasesIfNeeded(); auto* type = var->type()->UnwrapIfNeeded();
if (type->IsStruct()) { if (type->IsStruct()) {
auto* strct = type->AsStruct(); auto* strct = type->AsStruct();
@ -1558,7 +1558,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
first = false; first = false;
if (expr->IsMemberAccessor()) { if (expr->IsMemberAccessor()) {
auto* mem = expr->AsMemberAccessor(); auto* mem = expr->AsMemberAccessor();
auto* res_type = mem->structure()->result_type()->UnwrapAliasPtrAlias(); auto* res_type = mem->structure()->result_type()->UnwrapAll();
if (res_type->IsStruct()) { if (res_type->IsStruct()) {
auto* str_type = res_type->AsStruct()->impl(); auto* str_type = res_type->AsStruct()->impl();
auto* str_member = str_type->get_member(mem->member()->name()); auto* str_member = str_type->get_member(mem->member()->name());
@ -1593,7 +1593,7 @@ std::string GeneratorImpl::generate_storage_buffer_index_expression(
expr = mem->structure(); expr = mem->structure();
} else if (expr->IsArrayAccessor()) { } else if (expr->IsArrayAccessor()) {
auto* ary = expr->AsArrayAccessor(); auto* ary = expr->AsArrayAccessor();
auto* ary_type = ary->array()->result_type()->UnwrapAliasPtrAlias(); auto* ary_type = ary->array()->result_type()->UnwrapAll();
out << "("; out << "(";
if (ary_type->IsArray()) { if (ary_type->IsArray()) {
@ -1641,7 +1641,7 @@ bool GeneratorImpl::EmitStorageBufferAccessor(std::ostream& pre,
std::ostream& out, std::ostream& out,
ast::Expression* expr, ast::Expression* expr,
ast::Expression* rhs) { ast::Expression* rhs) {
auto* result_type = expr->result_type()->UnwrapAliasPtrAlias(); auto* result_type = expr->result_type()->UnwrapAll();
bool is_store = rhs != nullptr; bool is_store = rhs != nullptr;
std::string access_method = is_store ? "Store" : "Load"; std::string access_method = is_store ? "Store" : "Load";
@ -1758,7 +1758,7 @@ bool GeneratorImpl::is_storage_buffer_access(
bool GeneratorImpl::is_storage_buffer_access( bool GeneratorImpl::is_storage_buffer_access(
ast::MemberAccessorExpression* expr) { ast::MemberAccessorExpression* expr) {
auto* structure = expr->structure(); auto* structure = expr->structure();
auto* data_type = structure->result_type()->UnwrapAliasPtrAlias(); auto* data_type = structure->result_type()->UnwrapAll();
// If the data is a multi-element swizzle then we will not load the swizzle // If the data is a multi-element swizzle then we will not load the swizzle
// portion through the Load command. // portion through the Load command.
if (data_type->IsVector() && expr->member()->name().size() > 1) { if (data_type->IsVector() && expr->member()->name().size() > 1) {

View File

@ -814,10 +814,8 @@ bool Builder::GenerateArrayAccessor(ast::ArrayAccessorExpression* expr,
bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr, bool Builder::GenerateMemberAccessor(ast::MemberAccessorExpression* expr,
AccessorInfo* info) { AccessorInfo* info) {
auto* data_type = expr->structure() auto* data_type =
->result_type() expr->structure()->result_type()->UnwrapPtrIfNeeded()->UnwrapIfNeeded();
->UnwrapPtrIfNeeded()
->UnwrapAliasesIfNeeded();
// If the data_type is a structure we're accessing a member, if it's a // If the data_type is a structure we're accessing a member, if it's a
// vector we're accessing a swizzle. // vector we're accessing a swizzle.
@ -1137,7 +1135,7 @@ bool Builder::is_constructor_const(ast::Expression* expr, bool is_global_init) {
} }
auto* tc = expr->AsConstructor()->AsTypeConstructor(); auto* tc = expr->AsConstructor()->AsTypeConstructor();
auto* result_type = tc->type()->UnwrapAliasPtrAlias(); auto* result_type = tc->type()->UnwrapAll();
for (size_t i = 0; i < tc->values().size(); ++i) { for (size_t i = 0; i < tc->values().size(); ++i) {
auto* e = tc->values()[i].get(); auto* e = tc->values()[i].get();
@ -1165,21 +1163,17 @@ bool Builder::is_constructor_const(ast::Expression* expr, bool is_global_init) {
} }
auto* sc = e->AsConstructor()->AsScalarConstructor(); auto* sc = e->AsConstructor()->AsScalarConstructor();
ast::type::Type* subtype = result_type->UnwrapAliasPtrAlias(); ast::type::Type* subtype = result_type->UnwrapAll();
if (subtype->IsVector()) { if (subtype->IsVector()) {
subtype = subtype->AsVector()->type()->UnwrapAliasPtrAlias(); subtype = subtype->AsVector()->type()->UnwrapAll();
} else if (subtype->IsMatrix()) { } else if (subtype->IsMatrix()) {
subtype = subtype->AsMatrix()->type()->UnwrapAliasPtrAlias(); subtype = subtype->AsMatrix()->type()->UnwrapAll();
} else if (subtype->IsArray()) { } else if (subtype->IsArray()) {
subtype = subtype->AsArray()->type()->UnwrapAliasPtrAlias(); subtype = subtype->AsArray()->type()->UnwrapAll();
} else if (subtype->IsStruct()) { } else if (subtype->IsStruct()) {
subtype = subtype->AsStruct() subtype = subtype->AsStruct()->impl()->members()[i]->type()->UnwrapAll();
->impl()
->members()[i]
->type()
->UnwrapAliasPtrAlias();
} }
if (subtype != sc->result_type()->UnwrapAliasPtrAlias()) { if (subtype != sc->result_type()->UnwrapAll()) {
return false; return false;
} }
} }
@ -1200,7 +1194,7 @@ uint32_t Builder::GenerateTypeConstructorExpression(
std::ostringstream out; std::ostringstream out;
out << "__const"; out << "__const";
auto* result_type = init->type()->UnwrapAliasPtrAlias(); auto* result_type = init->type()->UnwrapAll();
bool constructor_is_const = is_constructor_const(init, is_global_init); bool constructor_is_const = is_constructor_const(init, is_global_init);
if (has_error()) { if (has_error()) {
return 0; return 0;
@ -1209,7 +1203,7 @@ uint32_t Builder::GenerateTypeConstructorExpression(
bool can_cast_or_copy = result_type->is_scalar(); bool can_cast_or_copy = result_type->is_scalar();
if (result_type->IsVector() && result_type->AsVector()->type()->is_scalar()) { if (result_type->IsVector() && result_type->AsVector()->type()->is_scalar()) {
auto* value_type = values[0]->result_type()->UnwrapAliasPtrAlias(); auto* value_type = values[0]->result_type()->UnwrapAll();
can_cast_or_copy = can_cast_or_copy =
(value_type->IsVector() && (value_type->IsVector() &&
value_type->AsVector()->type()->is_scalar() && value_type->AsVector()->type()->is_scalar() &&
@ -1773,7 +1767,7 @@ uint32_t Builder::GenerateIntrinsic(ast::IdentifierExpression* ident,
} }
params.push_back(Operand::Int(struct_id)); params.push_back(Operand::Int(struct_id));
auto* type = accessor->structure()->result_type()->UnwrapAliasPtrAlias(); auto* type = accessor->structure()->result_type()->UnwrapAll();
if (!type->IsStruct()) { if (!type->IsStruct()) {
error_ = error_ =
"invalid type (" + type->type_name() + ") for runtime array length"; "invalid type (" + type->type_name() + ") for runtime array length";
@ -1866,11 +1860,8 @@ uint32_t Builder::GenerateTextureIntrinsic(ast::IdentifierExpression* ident,
ast::CallExpression* call, ast::CallExpression* call,
uint32_t result_id, uint32_t result_id,
OperandList wgsl_params) { OperandList wgsl_params) {
auto* texture_type = call->params()[0] auto* texture_type =
.get() call->params()[0].get()->result_type()->UnwrapAll()->AsTexture();
->result_type()
->UnwrapAliasPtrAlias()
->AsTexture();
// TODO: Remove the LOD param from textureLoad on storage textures when // TODO: Remove the LOD param from textureLoad on storage textures when
// https://github.com/gpuweb/gpuweb/pull/1032 gets merged. // https://github.com/gpuweb/gpuweb/pull/1032 gets merged.