Resolver: compute canonical types and store them as semantic::Variable::Type

We define the canonical type as a type stripped of all aliases. For
example, Canonical(alias<alias<vec3<alias<f32>>>>) is vec3<f32>. This
change adds Resolver::Canonical(Type*) which caches and returns the
resulting canonical type. We use this throughout the Resolver instead of
UnwrapAliasIfNeeded(), and we store the result in semantic::Variable,
returned from it's Type() member function.

Also:

* Wrote unit tests for Resolver::Canonical()

* Added semantic::Variable::DeclaredType() as a convenience to
retrieve the AST variable's type.

* Updated post-resolve code (transforms) to make use of Type and
DeclaredType appropriately, removing unnecessary calls to
UnwrapAliasIfNeeded.

* Added IntrinsicTableTest.MatchWithNestedAliasUnwrapping to ensure we
don't need to pass canonical parameter types for instrinsic table
lookups.

* ProgramBuilder: added vecN and matMxN overloads that take a Type* arg
to create them with alias types.

Bug: tint:705
Change-Id: I58a3b62538356b8dad2b1161a19b38bcefdd5d62
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47360
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano
2021-04-13 13:32:33 +00:00
committed by Commit Bot service account
parent 2f25ecf8ba
commit 2543686412
14 changed files with 269 additions and 76 deletions

View File

@@ -42,6 +42,7 @@
#include "src/semantic/struct.h"
#include "src/semantic/variable.h"
#include "src/type/access_control_type.h"
#include "src/utils/get_or_create.h"
#include "src/utils/math.h"
namespace tint {
@@ -430,7 +431,7 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func) {
}
// Check that we saw a pipeline IO attribute iff we need one.
if (ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
if (Canonical(ty)->Is<type::Struct>()) {
if (pipeline_io_attribute) {
diagnostics_.add_error(
"entry point IO attributes must not be used on structure " +
@@ -466,11 +467,11 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func) {
return false;
}
if (auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>()) {
if (auto* struct_ty = Canonical(ty)->As<type::Struct>()) {
// Validate the decorations for each struct members, and also check for
// invalid member types.
for (auto* member : struct_ty->impl()->members()) {
auto* member_ty = member->type()->UnwrapAliasIfNeeded();
auto* member_ty = Canonical(member->type());
if (member_ty->Is<type::Struct>()) {
diagnostics_.add_error(
"entry point IO types cannot contain nested structures",
@@ -547,8 +548,7 @@ bool Resolver::Function(ast::Function* func) {
return false;
}
if (auto* str =
param->declared_type()->UnwrapAliasIfNeeded()->As<type::Struct>()) {
if (auto* str = param_info->type->As<type::Struct>()) {
auto* info = Structure(str);
if (!info) {
return false;
@@ -572,8 +572,7 @@ bool Resolver::Function(ast::Function* func) {
}
}
if (auto* str =
func->return_type()->UnwrapAliasIfNeeded()->As<type::Struct>()) {
if (auto* str = Canonical(func->return_type())->As<type::Struct>()) {
if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str,
func->source())) {
diagnostics_.add_note("while instantiating return type for " +
@@ -1288,15 +1287,16 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
using Matrix = type::Matrix;
using Vector = type::Vector;
auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
auto* lhs_declared_type = TypeOf(expr->lhs())->UnwrapAll();
auto* rhs_declared_type = TypeOf(expr->rhs())->UnwrapAll();
auto* lhs_type = Canonical(lhs_declared_type);
auto* rhs_type = Canonical(rhs_declared_type);
auto* lhs_vec = lhs_type->As<Vector>();
auto* lhs_vec_elem_type =
lhs_vec ? lhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
auto* rhs_vec = rhs_type->As<Vector>();
auto* rhs_vec_elem_type =
rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
const bool matching_vec_elem_types =
lhs_vec_elem_type && rhs_vec_elem_type &&
@@ -1348,11 +1348,9 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
}
auto* lhs_mat = lhs_type->As<Matrix>();
auto* lhs_mat_elem_type =
lhs_mat ? lhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
auto* rhs_mat = rhs_type->As<Matrix>();
auto* rhs_mat_elem_type =
rhs_mat ? rhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
// Multiplication of a matrix and a scalar
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
@@ -1438,9 +1436,9 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
diagnostics_.add_error(
"Binary expression operand types are invalid for this operation: " +
lhs_type->FriendlyName(builder_->Symbols()) + " " +
lhs_declared_type->FriendlyName(builder_->Symbols()) + " " +
FriendlyName(expr->op()) + " " +
rhs_type->FriendlyName(builder_->Symbols()),
rhs_declared_type->FriendlyName(builder_->Symbols()),
expr->source());
return false;
}
@@ -1600,7 +1598,7 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
}
Resolver::VariableInfo* Resolver::CreateVariableInfo(ast::Variable* var) {
auto* info = variable_infos_.Create(var);
auto* info = variable_infos_.Create(var, Canonical(var->declared_type()));
variable_to_info_.emplace(var, info);
return info;
}
@@ -1748,13 +1746,13 @@ bool Resolver::DefaultAlignAndSize(type::Type* ty,
/*vec4*/ 16,
};
ty = ty->UnwrapAliasIfNeeded();
if (ty->is_scalar()) {
auto* cty = Canonical(ty);
if (cty->is_scalar()) {
// Note: Also captures booleans, but these are not host-shareable.
align = 4;
size = 4;
return true;
} else if (auto* vec = ty->As<type::Vector>()) {
} else if (auto* vec = cty->As<type::Vector>()) {
if (vec->size() < 2 || vec->size() > 4) {
TINT_UNREACHABLE(diagnostics_)
<< "Invalid vector size: vec" << vec->size();
@@ -1763,7 +1761,7 @@ bool Resolver::DefaultAlignAndSize(type::Type* ty,
align = vector_align[vec->size()];
size = vector_size[vec->size()];
return true;
} else if (auto* mat = ty->As<type::Matrix>()) {
} else if (auto* mat = cty->As<type::Matrix>()) {
if (mat->columns() < 2 || mat->columns() > 4 || mat->rows() < 2 ||
mat->rows() > 4) {
TINT_UNREACHABLE(diagnostics_)
@@ -1773,15 +1771,15 @@ bool Resolver::DefaultAlignAndSize(type::Type* ty,
align = vector_align[mat->rows()];
size = vector_align[mat->rows()] * mat->columns();
return true;
} else if (auto* s = ty->As<type::Struct>()) {
} else if (auto* s = cty->As<type::Struct>()) {
if (auto* si = Structure(s)) {
align = si->align;
size = si->size;
return true;
}
return false;
} else if (auto* arr = ty->As<type::Array>()) {
if (auto* sem = Array(arr)) {
} else if (cty->Is<type::Array>()) {
if (auto* sem = Array(ty->UnwrapAliasIfNeeded()->As<type::Array>())) {
align = sem->Align();
size = sem->Size();
return true;
@@ -2249,9 +2247,37 @@ std::string Resolver::VectorPretty(uint32_t size, type::Type* element_type) {
return vec_type.FriendlyName(builder_->Symbols());
}
Resolver::VariableInfo::VariableInfo(ast::Variable* decl)
type::Type* Resolver::Canonical(type::Type* type) {
using Type = type::Type;
using Alias = type::Alias;
using Matrix = type::Matrix;
using Vector = type::Vector;
std::function<Type*(Type*)> make_canonical;
make_canonical = [&](Type* t) -> type::Type* {
// Unwrap alias sequence
Type* ct = t;
while (auto* p = ct->As<Alias>()) {
ct = p->type();
}
if (auto* v = ct->As<Vector>()) {
return builder_->create<Vector>(make_canonical(v->type()), v->size());
}
if (auto* m = ct->As<Matrix>()) {
return builder_->create<Matrix>(make_canonical(m->type()), m->rows(),
m->columns());
}
return ct;
};
return utils::GetOrCreate(type_to_canonical_, type,
[&] { return make_canonical(type); });
}
Resolver::VariableInfo::VariableInfo(ast::Variable* decl, type::Type* ctype)
: declaration(decl),
type(decl->declared_type()),
type(ctype),
storage_class(decl->declared_storage_class()) {}
Resolver::VariableInfo::~VariableInfo() = default;

View File

@@ -86,11 +86,17 @@ class Resolver {
/// structure member or array element of type `lhs`
static bool IsValidAssignment(type::Type* lhs, type::Type* rhs);
/// @param type the input type
/// @returns the canonical type for `type`; that is, a type with all aliases
/// removed. For example, `Canonical(alias<alias<vec3<alias<f32>>>>)` is
/// `vec3<f32>`.
type::Type* Canonical(type::Type* type);
private:
/// Structure holding semantic information about a variable.
/// Used to build the semantic::Variable nodes at the end of resolving.
struct VariableInfo {
explicit VariableInfo(ast::Variable* decl);
VariableInfo(ast::Variable* decl, type::Type* type);
~VariableInfo();
ast::Variable* const declaration;
@@ -306,6 +312,7 @@ class Resolver {
std::unordered_map<ast::CallExpression*, FunctionCallInfo> function_calls_;
std::unordered_map<ast::Expression*, ExpressionInfo> expr_info_;
std::unordered_map<type::Struct*, StructInfo*> struct_info_;
std::unordered_map<type::Type*, type::Type*> type_to_canonical_;
FunctionInfo* current_function_ = nullptr;
semantic::Statement* current_statement_ = nullptr;
BlockAllocator<VariableInfo> variable_infos_;

View File

@@ -119,18 +119,30 @@ inline type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) {
return ty.f32();
}
using create_type_func_ptr =
type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
template <typename T>
type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec3<T>();
}
template <create_type_func_ptr create_type>
type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
auto* type = create_type(ty);
return ty.vec3(type);
}
template <typename T>
type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x3<T>();
}
using create_type_func_ptr =
type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
template <create_type_func_ptr create_type>
type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
auto* type = create_type(ty);
return ty.mat3x3(type);
}
template <create_type_func_ptr create_type>
type::Type* ty_alias(const ProgramBuilder::TypesBuilder& ty) {

View File

@@ -21,6 +21,7 @@
#include "gmock/gmock.h"
namespace tint {
namespace resolver {
namespace {
class ResolverTypeValidationTest : public resolver::TestHelper,
@@ -463,5 +464,51 @@ TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
namespace GetCanonicalTests {
struct Params {
create_type_func_ptr create_type;
create_type_func_ptr create_canonical_type;
};
static constexpr Params cases[] = {
Params{ty_bool_, ty_bool_},
Params{ty_alias<ty_bool_>, ty_bool_},
Params{ty_alias<ty_alias<ty_bool_>>, ty_bool_},
Params{ty_vec3<ty_f32>, ty_vec3<ty_f32>},
Params{ty_alias<ty_vec3<ty_f32>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_alias<ty_vec3<ty_f32>>>, ty_vec3<ty_f32>},
Params{ty_vec3<ty_alias<ty_f32>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_vec3<ty_alias<ty_f32>>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_f32>>>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_alias<ty_f32>>>>>,
ty_vec3<ty_f32>},
Params{ty_mat3x3<ty_alias<ty_f32>>, ty_mat3x3<ty_f32>},
Params{ty_alias<ty_mat3x3<ty_alias<ty_f32>>>, ty_mat3x3<ty_f32>},
Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_f32>>>>, ty_mat3x3<ty_f32>},
Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_alias<ty_f32>>>>>,
ty_mat3x3<ty_f32>},
};
using CanonicalTest = ResolverTestWithParam<Params>;
TEST_P(CanonicalTest, All) {
auto& params = GetParam();
auto* type = params.create_type(ty);
auto* expected_canonical_type = params.create_canonical_type(ty);
auto* canonical_type = r()->Canonical(type);
EXPECT_EQ(canonical_type, expected_canonical_type);
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
CanonicalTest,
testing::ValuesIn(cases));
} // namespace GetCanonicalTests
} // namespace
} // namespace resolver
} // namespace tint