Resolver: compute canonical types and store them as semantic::Variable::Type
We define the canonical type as a type stripped of all aliases. For example, Canonical(alias<alias<vec3<alias<f32>>>>) is vec3<f32>. This change adds Resolver::Canonical(Type*) which caches and returns the resulting canonical type. We use this throughout the Resolver instead of UnwrapAliasIfNeeded(), and we store the result in semantic::Variable, returned from it's Type() member function. Also: * Wrote unit tests for Resolver::Canonical() * Added semantic::Variable::DeclaredType() as a convenience to retrieve the AST variable's type. * Updated post-resolve code (transforms) to make use of Type and DeclaredType appropriately, removing unnecessary calls to UnwrapAliasIfNeeded. * Added IntrinsicTableTest.MatchWithNestedAliasUnwrapping to ensure we don't need to pass canonical parameter types for instrinsic table lookups. * ProgramBuilder: added vecN and matMxN overloads that take a Type* arg to create them with alias types. Bug: tint:705 Change-Id: I58a3b62538356b8dad2b1161a19b38bcefdd5d62 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47360 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
2f25ecf8ba
commit
2543686412
|
@ -357,6 +357,23 @@ TEST_F(IntrinsicTableTest, MatchWithAliasUnwrapping) {
|
|||
EXPECT_THAT(result.intrinsic->Parameters(), ElementsAre(Parameter{ty.f32()}));
|
||||
}
|
||||
|
||||
TEST_F(IntrinsicTableTest, MatchWithNestedAliasUnwrapping) {
|
||||
auto* alias_a = ty.alias("alias_a", ty.bool_());
|
||||
auto* alias_b = ty.alias("alias_b", alias_a);
|
||||
auto* alias_c = ty.alias("alias_c", alias_b);
|
||||
auto* vec4_of_c = ty.vec4(alias_c);
|
||||
auto* alias_d = ty.alias("alias_d", vec4_of_c);
|
||||
auto* alias_e = ty.alias("alias_e", alias_d);
|
||||
|
||||
auto result = table->Lookup(*this, IntrinsicType::kAll, {alias_e}, Source{});
|
||||
ASSERT_NE(result.intrinsic, nullptr);
|
||||
ASSERT_EQ(result.diagnostics.str(), "");
|
||||
EXPECT_THAT(result.intrinsic->Type(), IntrinsicType::kAll);
|
||||
EXPECT_THAT(result.intrinsic->ReturnType(), ty.bool_());
|
||||
EXPECT_THAT(result.intrinsic->Parameters(),
|
||||
ElementsAre(Parameter{ty.vec4<bool>()}));
|
||||
}
|
||||
|
||||
TEST_F(IntrinsicTableTest, MatchOpenType) {
|
||||
auto result = table->Lookup(*this, IntrinsicType::kClamp,
|
||||
{ty.f32(), ty.f32(), ty.f32()}, Source{});
|
||||
|
|
|
@ -338,76 +338,148 @@ class ProgramBuilder {
|
|||
/// @returns a void type
|
||||
type::Void* void_() const { return builder->create<type::Void>(); }
|
||||
|
||||
/// @param type vector subtype
|
||||
/// @return the tint AST type for a 2-element vector of `type`.
|
||||
type::Vector* vec2(type::Type* type) const {
|
||||
return builder->create<type::Vector>(type, 2u);
|
||||
}
|
||||
|
||||
/// @param type vector subtype
|
||||
/// @return the tint AST type for a 3-element vector of `type`.
|
||||
type::Vector* vec3(type::Type* type) const {
|
||||
return builder->create<type::Vector>(type, 3u);
|
||||
}
|
||||
|
||||
/// @param type vector subtype
|
||||
/// @return the tint AST type for a 4-element vector of `type`.
|
||||
type::Type* vec4(type::Type* type) const {
|
||||
return builder->create<type::Vector>(type, 4u);
|
||||
}
|
||||
|
||||
/// @return the tint AST type for a 2-element vector of the C type `T`.
|
||||
template <typename T>
|
||||
type::Vector* vec2() const {
|
||||
return builder->create<type::Vector>(Of<T>(), 2);
|
||||
return vec2(Of<T>());
|
||||
}
|
||||
|
||||
/// @return the tint AST type for a 3-element vector of the C type `T`.
|
||||
template <typename T>
|
||||
type::Vector* vec3() const {
|
||||
return builder->create<type::Vector>(Of<T>(), 3);
|
||||
return vec3(Of<T>());
|
||||
}
|
||||
|
||||
/// @return the tint AST type for a 4-element vector of the C type `T`.
|
||||
template <typename T>
|
||||
type::Type* vec4() const {
|
||||
return builder->create<type::Vector>(Of<T>(), 4);
|
||||
return vec4(Of<T>());
|
||||
}
|
||||
|
||||
/// @param type matrix subtype
|
||||
/// @return the tint AST type for a 2x3 matrix of `type`.
|
||||
type::Matrix* mat2x2(type::Type* type) const {
|
||||
return builder->create<type::Matrix>(type, 2u, 2u);
|
||||
}
|
||||
|
||||
/// @param type matrix subtype
|
||||
/// @return the tint AST type for a 2x3 matrix of `type`.
|
||||
type::Matrix* mat2x3(type::Type* type) const {
|
||||
return builder->create<type::Matrix>(type, 3u, 2u);
|
||||
}
|
||||
|
||||
/// @param type matrix subtype
|
||||
/// @return the tint AST type for a 2x4 matrix of `type`.
|
||||
type::Matrix* mat2x4(type::Type* type) const {
|
||||
return builder->create<type::Matrix>(type, 4u, 2u);
|
||||
}
|
||||
|
||||
/// @param type matrix subtype
|
||||
/// @return the tint AST type for a 3x2 matrix of `type`.
|
||||
type::Matrix* mat3x2(type::Type* type) const {
|
||||
return builder->create<type::Matrix>(type, 2u, 3u);
|
||||
}
|
||||
|
||||
/// @param type matrix subtype
|
||||
/// @return the tint AST type for a 3x3 matrix of `type`.
|
||||
type::Matrix* mat3x3(type::Type* type) const {
|
||||
return builder->create<type::Matrix>(type, 3u, 3u);
|
||||
}
|
||||
|
||||
/// @param type matrix subtype
|
||||
/// @return the tint AST type for a 3x4 matrix of `type`.
|
||||
type::Matrix* mat3x4(type::Type* type) const {
|
||||
return builder->create<type::Matrix>(type, 4u, 3u);
|
||||
}
|
||||
|
||||
/// @param type matrix subtype
|
||||
/// @return the tint AST type for a 4x2 matrix of `type`.
|
||||
type::Matrix* mat4x2(type::Type* type) const {
|
||||
return builder->create<type::Matrix>(type, 2u, 4u);
|
||||
}
|
||||
|
||||
/// @param type matrix subtype
|
||||
/// @return the tint AST type for a 4x3 matrix of `type`.
|
||||
type::Matrix* mat4x3(type::Type* type) const {
|
||||
return builder->create<type::Matrix>(type, 3u, 4u);
|
||||
}
|
||||
|
||||
/// @param type matrix subtype
|
||||
/// @return the tint AST type for a 4x4 matrix of `type`.
|
||||
type::Matrix* mat4x4(type::Type* type) const {
|
||||
return builder->create<type::Matrix>(type, 4u, 4u);
|
||||
}
|
||||
|
||||
/// @return the tint AST type for a 2x3 matrix of the C type `T`.
|
||||
template <typename T>
|
||||
type::Matrix* mat2x2() const {
|
||||
return builder->create<type::Matrix>(Of<T>(), 2, 2);
|
||||
return mat2x2(Of<T>());
|
||||
}
|
||||
|
||||
/// @return the tint AST type for a 2x3 matrix of the C type `T`.
|
||||
template <typename T>
|
||||
type::Matrix* mat2x3() const {
|
||||
return builder->create<type::Matrix>(Of<T>(), 3, 2);
|
||||
return mat2x3(Of<T>());
|
||||
}
|
||||
|
||||
/// @return the tint AST type for a 2x4 matrix of the C type `T`.
|
||||
template <typename T>
|
||||
type::Matrix* mat2x4() const {
|
||||
return builder->create<type::Matrix>(Of<T>(), 4, 2);
|
||||
return mat2x4(Of<T>());
|
||||
}
|
||||
|
||||
/// @return the tint AST type for a 3x2 matrix of the C type `T`.
|
||||
template <typename T>
|
||||
type::Matrix* mat3x2() const {
|
||||
return builder->create<type::Matrix>(Of<T>(), 2, 3);
|
||||
return mat3x2(Of<T>());
|
||||
}
|
||||
|
||||
/// @return the tint AST type for a 3x3 matrix of the C type `T`.
|
||||
template <typename T>
|
||||
type::Matrix* mat3x3() const {
|
||||
return builder->create<type::Matrix>(Of<T>(), 3, 3);
|
||||
return mat3x3(Of<T>());
|
||||
}
|
||||
|
||||
/// @return the tint AST type for a 3x4 matrix of the C type `T`.
|
||||
template <typename T>
|
||||
type::Matrix* mat3x4() const {
|
||||
return builder->create<type::Matrix>(Of<T>(), 4, 3);
|
||||
return mat3x4(Of<T>());
|
||||
}
|
||||
|
||||
/// @return the tint AST type for a 4x2 matrix of the C type `T`.
|
||||
template <typename T>
|
||||
type::Matrix* mat4x2() const {
|
||||
return builder->create<type::Matrix>(Of<T>(), 2, 4);
|
||||
return mat4x2(Of<T>());
|
||||
}
|
||||
|
||||
/// @return the tint AST type for a 4x3 matrix of the C type `T`.
|
||||
template <typename T>
|
||||
type::Matrix* mat4x3() const {
|
||||
return builder->create<type::Matrix>(Of<T>(), 3, 4);
|
||||
return mat4x3(Of<T>());
|
||||
}
|
||||
|
||||
/// @return the tint AST type for a 4x4 matrix of the C type `T`.
|
||||
template <typename T>
|
||||
type::Matrix* mat4x4() const {
|
||||
return builder->create<type::Matrix>(Of<T>(), 4, 4);
|
||||
return mat4x4(Of<T>());
|
||||
}
|
||||
|
||||
/// @param subtype the array element type
|
||||
|
|
|
@ -42,6 +42,7 @@
|
|||
#include "src/semantic/struct.h"
|
||||
#include "src/semantic/variable.h"
|
||||
#include "src/type/access_control_type.h"
|
||||
#include "src/utils/get_or_create.h"
|
||||
#include "src/utils/math.h"
|
||||
|
||||
namespace tint {
|
||||
|
@ -430,7 +431,7 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func) {
|
|||
}
|
||||
|
||||
// Check that we saw a pipeline IO attribute iff we need one.
|
||||
if (ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
|
||||
if (Canonical(ty)->Is<type::Struct>()) {
|
||||
if (pipeline_io_attribute) {
|
||||
diagnostics_.add_error(
|
||||
"entry point IO attributes must not be used on structure " +
|
||||
|
@ -466,11 +467,11 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func) {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>()) {
|
||||
if (auto* struct_ty = Canonical(ty)->As<type::Struct>()) {
|
||||
// Validate the decorations for each struct members, and also check for
|
||||
// invalid member types.
|
||||
for (auto* member : struct_ty->impl()->members()) {
|
||||
auto* member_ty = member->type()->UnwrapAliasIfNeeded();
|
||||
auto* member_ty = Canonical(member->type());
|
||||
if (member_ty->Is<type::Struct>()) {
|
||||
diagnostics_.add_error(
|
||||
"entry point IO types cannot contain nested structures",
|
||||
|
@ -547,8 +548,7 @@ bool Resolver::Function(ast::Function* func) {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (auto* str =
|
||||
param->declared_type()->UnwrapAliasIfNeeded()->As<type::Struct>()) {
|
||||
if (auto* str = param_info->type->As<type::Struct>()) {
|
||||
auto* info = Structure(str);
|
||||
if (!info) {
|
||||
return false;
|
||||
|
@ -572,8 +572,7 @@ bool Resolver::Function(ast::Function* func) {
|
|||
}
|
||||
}
|
||||
|
||||
if (auto* str =
|
||||
func->return_type()->UnwrapAliasIfNeeded()->As<type::Struct>()) {
|
||||
if (auto* str = Canonical(func->return_type())->As<type::Struct>()) {
|
||||
if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str,
|
||||
func->source())) {
|
||||
diagnostics_.add_note("while instantiating return type for " +
|
||||
|
@ -1288,15 +1287,16 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
|
|||
using Matrix = type::Matrix;
|
||||
using Vector = type::Vector;
|
||||
|
||||
auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
|
||||
auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
|
||||
auto* lhs_declared_type = TypeOf(expr->lhs())->UnwrapAll();
|
||||
auto* rhs_declared_type = TypeOf(expr->rhs())->UnwrapAll();
|
||||
|
||||
auto* lhs_type = Canonical(lhs_declared_type);
|
||||
auto* rhs_type = Canonical(rhs_declared_type);
|
||||
|
||||
auto* lhs_vec = lhs_type->As<Vector>();
|
||||
auto* lhs_vec_elem_type =
|
||||
lhs_vec ? lhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
|
||||
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
|
||||
auto* rhs_vec = rhs_type->As<Vector>();
|
||||
auto* rhs_vec_elem_type =
|
||||
rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
|
||||
auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
|
||||
|
||||
const bool matching_vec_elem_types =
|
||||
lhs_vec_elem_type && rhs_vec_elem_type &&
|
||||
|
@ -1348,11 +1348,9 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
|
|||
}
|
||||
|
||||
auto* lhs_mat = lhs_type->As<Matrix>();
|
||||
auto* lhs_mat_elem_type =
|
||||
lhs_mat ? lhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
|
||||
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
|
||||
auto* rhs_mat = rhs_type->As<Matrix>();
|
||||
auto* rhs_mat_elem_type =
|
||||
rhs_mat ? rhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
|
||||
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
|
||||
|
||||
// Multiplication of a matrix and a scalar
|
||||
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
|
||||
|
@ -1438,9 +1436,9 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
|
|||
|
||||
diagnostics_.add_error(
|
||||
"Binary expression operand types are invalid for this operation: " +
|
||||
lhs_type->FriendlyName(builder_->Symbols()) + " " +
|
||||
lhs_declared_type->FriendlyName(builder_->Symbols()) + " " +
|
||||
FriendlyName(expr->op()) + " " +
|
||||
rhs_type->FriendlyName(builder_->Symbols()),
|
||||
rhs_declared_type->FriendlyName(builder_->Symbols()),
|
||||
expr->source());
|
||||
return false;
|
||||
}
|
||||
|
@ -1600,7 +1598,7 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
|
|||
}
|
||||
|
||||
Resolver::VariableInfo* Resolver::CreateVariableInfo(ast::Variable* var) {
|
||||
auto* info = variable_infos_.Create(var);
|
||||
auto* info = variable_infos_.Create(var, Canonical(var->declared_type()));
|
||||
variable_to_info_.emplace(var, info);
|
||||
return info;
|
||||
}
|
||||
|
@ -1748,13 +1746,13 @@ bool Resolver::DefaultAlignAndSize(type::Type* ty,
|
|||
/*vec4*/ 16,
|
||||
};
|
||||
|
||||
ty = ty->UnwrapAliasIfNeeded();
|
||||
if (ty->is_scalar()) {
|
||||
auto* cty = Canonical(ty);
|
||||
if (cty->is_scalar()) {
|
||||
// Note: Also captures booleans, but these are not host-shareable.
|
||||
align = 4;
|
||||
size = 4;
|
||||
return true;
|
||||
} else if (auto* vec = ty->As<type::Vector>()) {
|
||||
} else if (auto* vec = cty->As<type::Vector>()) {
|
||||
if (vec->size() < 2 || vec->size() > 4) {
|
||||
TINT_UNREACHABLE(diagnostics_)
|
||||
<< "Invalid vector size: vec" << vec->size();
|
||||
|
@ -1763,7 +1761,7 @@ bool Resolver::DefaultAlignAndSize(type::Type* ty,
|
|||
align = vector_align[vec->size()];
|
||||
size = vector_size[vec->size()];
|
||||
return true;
|
||||
} else if (auto* mat = ty->As<type::Matrix>()) {
|
||||
} else if (auto* mat = cty->As<type::Matrix>()) {
|
||||
if (mat->columns() < 2 || mat->columns() > 4 || mat->rows() < 2 ||
|
||||
mat->rows() > 4) {
|
||||
TINT_UNREACHABLE(diagnostics_)
|
||||
|
@ -1773,15 +1771,15 @@ bool Resolver::DefaultAlignAndSize(type::Type* ty,
|
|||
align = vector_align[mat->rows()];
|
||||
size = vector_align[mat->rows()] * mat->columns();
|
||||
return true;
|
||||
} else if (auto* s = ty->As<type::Struct>()) {
|
||||
} else if (auto* s = cty->As<type::Struct>()) {
|
||||
if (auto* si = Structure(s)) {
|
||||
align = si->align;
|
||||
size = si->size;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} else if (auto* arr = ty->As<type::Array>()) {
|
||||
if (auto* sem = Array(arr)) {
|
||||
} else if (cty->Is<type::Array>()) {
|
||||
if (auto* sem = Array(ty->UnwrapAliasIfNeeded()->As<type::Array>())) {
|
||||
align = sem->Align();
|
||||
size = sem->Size();
|
||||
return true;
|
||||
|
@ -2249,9 +2247,37 @@ std::string Resolver::VectorPretty(uint32_t size, type::Type* element_type) {
|
|||
return vec_type.FriendlyName(builder_->Symbols());
|
||||
}
|
||||
|
||||
Resolver::VariableInfo::VariableInfo(ast::Variable* decl)
|
||||
type::Type* Resolver::Canonical(type::Type* type) {
|
||||
using Type = type::Type;
|
||||
using Alias = type::Alias;
|
||||
using Matrix = type::Matrix;
|
||||
using Vector = type::Vector;
|
||||
|
||||
std::function<Type*(Type*)> make_canonical;
|
||||
make_canonical = [&](Type* t) -> type::Type* {
|
||||
// Unwrap alias sequence
|
||||
Type* ct = t;
|
||||
while (auto* p = ct->As<Alias>()) {
|
||||
ct = p->type();
|
||||
}
|
||||
|
||||
if (auto* v = ct->As<Vector>()) {
|
||||
return builder_->create<Vector>(make_canonical(v->type()), v->size());
|
||||
}
|
||||
if (auto* m = ct->As<Matrix>()) {
|
||||
return builder_->create<Matrix>(make_canonical(m->type()), m->rows(),
|
||||
m->columns());
|
||||
}
|
||||
return ct;
|
||||
};
|
||||
|
||||
return utils::GetOrCreate(type_to_canonical_, type,
|
||||
[&] { return make_canonical(type); });
|
||||
}
|
||||
|
||||
Resolver::VariableInfo::VariableInfo(ast::Variable* decl, type::Type* ctype)
|
||||
: declaration(decl),
|
||||
type(decl->declared_type()),
|
||||
type(ctype),
|
||||
storage_class(decl->declared_storage_class()) {}
|
||||
|
||||
Resolver::VariableInfo::~VariableInfo() = default;
|
||||
|
|
|
@ -86,11 +86,17 @@ class Resolver {
|
|||
/// structure member or array element of type `lhs`
|
||||
static bool IsValidAssignment(type::Type* lhs, type::Type* rhs);
|
||||
|
||||
/// @param type the input type
|
||||
/// @returns the canonical type for `type`; that is, a type with all aliases
|
||||
/// removed. For example, `Canonical(alias<alias<vec3<alias<f32>>>>)` is
|
||||
/// `vec3<f32>`.
|
||||
type::Type* Canonical(type::Type* type);
|
||||
|
||||
private:
|
||||
/// Structure holding semantic information about a variable.
|
||||
/// Used to build the semantic::Variable nodes at the end of resolving.
|
||||
struct VariableInfo {
|
||||
explicit VariableInfo(ast::Variable* decl);
|
||||
VariableInfo(ast::Variable* decl, type::Type* type);
|
||||
~VariableInfo();
|
||||
|
||||
ast::Variable* const declaration;
|
||||
|
@ -306,6 +312,7 @@ class Resolver {
|
|||
std::unordered_map<ast::CallExpression*, FunctionCallInfo> function_calls_;
|
||||
std::unordered_map<ast::Expression*, ExpressionInfo> expr_info_;
|
||||
std::unordered_map<type::Struct*, StructInfo*> struct_info_;
|
||||
std::unordered_map<type::Type*, type::Type*> type_to_canonical_;
|
||||
FunctionInfo* current_function_ = nullptr;
|
||||
semantic::Statement* current_statement_ = nullptr;
|
||||
BlockAllocator<VariableInfo> variable_infos_;
|
||||
|
|
|
@ -119,18 +119,30 @@ inline type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) {
|
|||
return ty.f32();
|
||||
}
|
||||
|
||||
using create_type_func_ptr =
|
||||
type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
|
||||
|
||||
template <typename T>
|
||||
type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.vec3<T>();
|
||||
}
|
||||
|
||||
template <create_type_func_ptr create_type>
|
||||
type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
auto* type = create_type(ty);
|
||||
return ty.vec3(type);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.mat3x3<T>();
|
||||
}
|
||||
|
||||
using create_type_func_ptr =
|
||||
type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
|
||||
template <create_type_func_ptr create_type>
|
||||
type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
auto* type = create_type(ty);
|
||||
return ty.mat3x3(type);
|
||||
}
|
||||
|
||||
template <create_type_func_ptr create_type>
|
||||
type::Type* ty_alias(const ProgramBuilder::TypesBuilder& ty) {
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "gmock/gmock.h"
|
||||
|
||||
namespace tint {
|
||||
namespace resolver {
|
||||
namespace {
|
||||
|
||||
class ResolverTypeValidationTest : public resolver::TestHelper,
|
||||
|
@ -463,5 +464,51 @@ TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) {
|
|||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
namespace GetCanonicalTests {
|
||||
struct Params {
|
||||
create_type_func_ptr create_type;
|
||||
create_type_func_ptr create_canonical_type;
|
||||
};
|
||||
|
||||
static constexpr Params cases[] = {
|
||||
Params{ty_bool_, ty_bool_},
|
||||
Params{ty_alias<ty_bool_>, ty_bool_},
|
||||
Params{ty_alias<ty_alias<ty_bool_>>, ty_bool_},
|
||||
|
||||
Params{ty_vec3<ty_f32>, ty_vec3<ty_f32>},
|
||||
Params{ty_alias<ty_vec3<ty_f32>>, ty_vec3<ty_f32>},
|
||||
Params{ty_alias<ty_alias<ty_vec3<ty_f32>>>, ty_vec3<ty_f32>},
|
||||
|
||||
Params{ty_vec3<ty_alias<ty_f32>>, ty_vec3<ty_f32>},
|
||||
Params{ty_alias<ty_vec3<ty_alias<ty_f32>>>, ty_vec3<ty_f32>},
|
||||
Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_f32>>>>, ty_vec3<ty_f32>},
|
||||
Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_alias<ty_f32>>>>>,
|
||||
ty_vec3<ty_f32>},
|
||||
|
||||
Params{ty_mat3x3<ty_alias<ty_f32>>, ty_mat3x3<ty_f32>},
|
||||
Params{ty_alias<ty_mat3x3<ty_alias<ty_f32>>>, ty_mat3x3<ty_f32>},
|
||||
Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_f32>>>>, ty_mat3x3<ty_f32>},
|
||||
Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_alias<ty_f32>>>>>,
|
||||
ty_mat3x3<ty_f32>},
|
||||
};
|
||||
|
||||
using CanonicalTest = ResolverTestWithParam<Params>;
|
||||
TEST_P(CanonicalTest, All) {
|
||||
auto& params = GetParam();
|
||||
|
||||
auto* type = params.create_type(ty);
|
||||
auto* expected_canonical_type = params.create_canonical_type(ty);
|
||||
|
||||
auto* canonical_type = r()->Canonical(type);
|
||||
|
||||
EXPECT_EQ(canonical_type, expected_canonical_type);
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
|
||||
CanonicalTest,
|
||||
testing::ValuesIn(cases));
|
||||
|
||||
} // namespace GetCanonicalTests
|
||||
|
||||
} // namespace
|
||||
} // namespace resolver
|
||||
} // namespace tint
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "src/semantic/variable.h"
|
||||
|
||||
#include "src/ast/identifier_expression.h"
|
||||
#include "src/ast/variable.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::semantic::Variable);
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::semantic::VariableUser);
|
||||
|
@ -29,6 +30,10 @@ Variable::Variable(const ast::Variable* declaration,
|
|||
|
||||
Variable::~Variable() = default;
|
||||
|
||||
type::Type* Variable::DeclaredType() const {
|
||||
return declaration_->declared_type();
|
||||
}
|
||||
|
||||
VariableUser::VariableUser(ast::IdentifierExpression* declaration,
|
||||
type::Type* type,
|
||||
Statement* statement,
|
||||
|
|
|
@ -52,9 +52,12 @@ class Variable : public Castable<Variable, Node> {
|
|||
/// @returns the AST declaration node
|
||||
const ast::Variable* Declaration() const { return declaration_; }
|
||||
|
||||
/// @returns the type for the variable
|
||||
/// @returns the canonical type for the variable
|
||||
type::Type* Type() const { return type_; }
|
||||
|
||||
/// @returns the AST node's type. May be nullptr.
|
||||
type::Type* DeclaredType() const;
|
||||
|
||||
/// @returns the storage class for the variable
|
||||
ast::StorageClass StorageClass() const { return storage_class_; }
|
||||
|
||||
|
|
|
@ -65,8 +65,7 @@ Transform::Output BindingRemapper::Run(const Program* in,
|
|||
auto ac_it = remappings->access_controls.find(from);
|
||||
if (ac_it != remappings->access_controls.end()) {
|
||||
ast::AccessControl ac = ac_it->second;
|
||||
auto* var_ty = in->Sem().Get(var)->Type();
|
||||
auto* ty = var_ty->UnwrapAliasIfNeeded();
|
||||
auto* ty = in->Sem().Get(var)->Type();
|
||||
type::Type* inner_ty = nullptr;
|
||||
if (auto* old_ac = ty->As<type::AccessControl>()) {
|
||||
inner_ty = ctx.Clone(old_ac->type());
|
||||
|
|
|
@ -72,11 +72,11 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
|
|||
for (auto* param : func->params()) {
|
||||
auto param_name = ctx.Clone(param->symbol());
|
||||
auto* param_ty = ctx.src->Sem().Get(param)->Type();
|
||||
auto* param_declared_ty = ctx.src->Sem().Get(param)->DeclaredType();
|
||||
|
||||
ast::Expression* func_const_initializer = nullptr;
|
||||
|
||||
if (auto* struct_ty =
|
||||
param_ty->UnwrapAliasIfNeeded()->As<type::Struct>()) {
|
||||
if (auto* struct_ty = param_ty->As<type::Struct>()) {
|
||||
// Pull out all struct members and build initializer list.
|
||||
ast::ExpressionList init_values;
|
||||
for (auto* member : struct_ty->impl()->members()) {
|
||||
|
@ -97,7 +97,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
|
|||
}
|
||||
|
||||
func_const_initializer =
|
||||
ctx.dst->Construct(ctx.Clone(param_ty), init_values);
|
||||
ctx.dst->Construct(ctx.Clone(param_declared_ty), init_values);
|
||||
} else {
|
||||
ast::DecorationList new_decorations = RemoveDecorations(
|
||||
&ctx, param->decorations(), [](const ast::Decoration* deco) {
|
||||
|
@ -105,7 +105,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
|
|||
ast::LocationDecoration>();
|
||||
});
|
||||
new_struct_members.push_back(ctx.dst->Member(
|
||||
param_name, ctx.Clone(param_ty), new_decorations));
|
||||
param_name, ctx.Clone(param_declared_ty), new_decorations));
|
||||
func_const_initializer =
|
||||
ctx.dst->MemberAccessor(new_struct_param_symbol, param_name);
|
||||
}
|
||||
|
@ -117,8 +117,8 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
|
|||
|
||||
// Create a function-scope const to replace the parameter.
|
||||
// Initialize it with the value extracted from the new struct parameter.
|
||||
auto* func_const = ctx.dst->Const(param_name, ctx.Clone(param_ty),
|
||||
func_const_initializer);
|
||||
auto* func_const = ctx.dst->Const(
|
||||
param_name, ctx.Clone(param_declared_ty), func_const_initializer);
|
||||
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
|
||||
ctx.dst->WrapInStatement(func_const));
|
||||
|
||||
|
|
|
@ -141,7 +141,8 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
|
|||
|
||||
for (auto* param : func->params()) {
|
||||
Symbol new_var = HoistToInputVariables(
|
||||
ctx, func, ctx.src->Sem().Get(param)->Type(), param->decorations());
|
||||
ctx, func, ctx.src->Sem().Get(param)->Type(),
|
||||
ctx.src->Sem().Get(param)->DeclaredType(), param->decorations());
|
||||
|
||||
// Replace all uses of the function parameter with the new variable.
|
||||
for (auto* user : ctx.src->Sem().Get(param)->Users()) {
|
||||
|
@ -153,9 +154,9 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
|
|||
if (!func->return_type()->Is<type::Void>()) {
|
||||
ast::StatementList stores;
|
||||
auto store_value_symbol = ctx.dst->Symbols().New();
|
||||
HoistToOutputVariables(ctx, func, func->return_type(),
|
||||
func->return_type_decorations(), {},
|
||||
store_value_symbol, stores);
|
||||
HoistToOutputVariables(
|
||||
ctx, func, func->return_type(), func->return_type(),
|
||||
func->return_type_decorations(), {}, store_value_symbol, stores);
|
||||
|
||||
// Create a function that writes a return value to all output variables.
|
||||
auto* store_value =
|
||||
|
@ -251,8 +252,9 @@ Symbol Spirv::HoistToInputVariables(
|
|||
CloneContext& ctx,
|
||||
const ast::Function* func,
|
||||
type::Type* ty,
|
||||
type::Type* declared_ty,
|
||||
const ast::DecorationList& decorations) const {
|
||||
if (!ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
|
||||
if (!ty->Is<type::Struct>()) {
|
||||
// Base case: create a global variable and return.
|
||||
ast::DecorationList new_decorations =
|
||||
RemoveDecorations(&ctx, decorations, [](const ast::Decoration* deco) {
|
||||
|
@ -261,7 +263,7 @@ Symbol Spirv::HoistToInputVariables(
|
|||
});
|
||||
auto global_var_symbol = ctx.dst->Symbols().New();
|
||||
auto* global_var =
|
||||
ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
|
||||
ctx.dst->Var(global_var_symbol, ctx.Clone(declared_ty),
|
||||
ast::StorageClass::kInput, nullptr, new_decorations);
|
||||
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
|
||||
return global_var_symbol;
|
||||
|
@ -269,10 +271,10 @@ Symbol Spirv::HoistToInputVariables(
|
|||
|
||||
// Recurse into struct members and build the initializer list.
|
||||
ast::ExpressionList init_values;
|
||||
auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>();
|
||||
auto* struct_ty = ty->As<type::Struct>();
|
||||
for (auto* member : struct_ty->impl()->members()) {
|
||||
auto member_var =
|
||||
HoistToInputVariables(ctx, func, member->type(), member->decorations());
|
||||
auto member_var = HoistToInputVariables(
|
||||
ctx, func, member->type(), member->type(), member->decorations());
|
||||
init_values.push_back(ctx.dst->Expr(member_var));
|
||||
}
|
||||
|
||||
|
@ -283,8 +285,9 @@ Symbol Spirv::HoistToInputVariables(
|
|||
}
|
||||
|
||||
// Create a function-scope variable for the struct.
|
||||
auto* initializer = ctx.dst->Construct(ctx.Clone(ty), init_values);
|
||||
auto* func_var = ctx.dst->Const(func_var_symbol, ctx.Clone(ty), initializer);
|
||||
auto* initializer = ctx.dst->Construct(ctx.Clone(declared_ty), init_values);
|
||||
auto* func_var =
|
||||
ctx.dst->Const(func_var_symbol, ctx.Clone(declared_ty), initializer);
|
||||
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
|
||||
ctx.dst->WrapInStatement(func_var));
|
||||
return func_var_symbol;
|
||||
|
@ -293,12 +296,13 @@ Symbol Spirv::HoistToInputVariables(
|
|||
void Spirv::HoistToOutputVariables(CloneContext& ctx,
|
||||
const ast::Function* func,
|
||||
type::Type* ty,
|
||||
type::Type* declared_ty,
|
||||
const ast::DecorationList& decorations,
|
||||
std::vector<Symbol> member_accesses,
|
||||
Symbol store_value,
|
||||
ast::StatementList& stores) const {
|
||||
// Base case.
|
||||
if (!ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
|
||||
if (!ty->Is<type::Struct>()) {
|
||||
// Create a global variable.
|
||||
ast::DecorationList new_decorations =
|
||||
RemoveDecorations(&ctx, decorations, [](const ast::Decoration* deco) {
|
||||
|
@ -307,7 +311,7 @@ void Spirv::HoistToOutputVariables(CloneContext& ctx,
|
|||
});
|
||||
auto global_var_symbol = ctx.dst->Symbols().New();
|
||||
auto* global_var =
|
||||
ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
|
||||
ctx.dst->Var(global_var_symbol, ctx.Clone(declared_ty),
|
||||
ast::StorageClass::kOutput, nullptr, new_decorations);
|
||||
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
|
||||
|
||||
|
@ -322,11 +326,12 @@ void Spirv::HoistToOutputVariables(CloneContext& ctx,
|
|||
}
|
||||
|
||||
// Recurse into struct members.
|
||||
auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>();
|
||||
auto* struct_ty = ty->As<type::Struct>();
|
||||
for (auto* member : struct_ty->impl()->members()) {
|
||||
member_accesses.push_back(ctx.Clone(member->symbol()));
|
||||
HoistToOutputVariables(ctx, func, member->type(), member->decorations(),
|
||||
member_accesses, store_value, stores);
|
||||
HoistToOutputVariables(ctx, func, member->type(), member->type(),
|
||||
member->decorations(), member_accesses, store_value,
|
||||
stores);
|
||||
member_accesses.pop_back();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,6 +60,7 @@ class Spirv : public Transform {
|
|||
Symbol HoistToInputVariables(CloneContext& ctx,
|
||||
const ast::Function* func,
|
||||
type::Type* ty,
|
||||
type::Type* declared_ty,
|
||||
const ast::DecorationList& decorations) const;
|
||||
|
||||
/// Recursively create module-scope output variables for `ty` and build a list
|
||||
|
@ -73,6 +74,7 @@ class Spirv : public Transform {
|
|||
void HoistToOutputVariables(CloneContext& ctx,
|
||||
const ast::Function* func,
|
||||
type::Type* ty,
|
||||
type::Type* declared_ty,
|
||||
const ast::DecorationList& decorations,
|
||||
std::vector<Symbol> member_accesses,
|
||||
Symbol store_value,
|
||||
|
|
|
@ -1894,7 +1894,7 @@ bool GeneratorImpl::EmitEntryPointData(
|
|||
for (auto* var : func_sem->ReferencedModuleVariables()) {
|
||||
auto* decl = var->Declaration();
|
||||
|
||||
auto* unwrapped_type = var->Type()->UnwrapAll();
|
||||
auto* unwrapped_type = var->DeclaredType()->UnwrapAll();
|
||||
if (!emitted_globals.emplace(decl->symbol()).second) {
|
||||
continue; // Global already emitted
|
||||
}
|
||||
|
@ -1905,7 +1905,7 @@ bool GeneratorImpl::EmitEntryPointData(
|
|||
continue; // Not interested in this type
|
||||
}
|
||||
|
||||
if (!EmitType(out, var->Type(), var->StorageClass(), "")) {
|
||||
if (!EmitType(out, var->DeclaredType(), var->StorageClass(), "")) {
|
||||
return false;
|
||||
}
|
||||
out << " " << builder_.Symbols().NameFor(decl->symbol());
|
||||
|
@ -1915,9 +1915,7 @@ bool GeneratorImpl::EmitEntryPointData(
|
|||
if (unwrapped_type->Is<type::Texture>()) {
|
||||
register_space = "t";
|
||||
if (unwrapped_type->Is<type::StorageTexture>()) {
|
||||
if (auto* ac = var->Type()
|
||||
->UnwrapAliasIfNeeded()
|
||||
->As<type::AccessControl>()) {
|
||||
if (auto* ac = var->Type()->As<type::AccessControl>()) {
|
||||
if (!ac->IsReadOnly()) {
|
||||
register_space = "u";
|
||||
}
|
||||
|
|
|
@ -325,7 +325,7 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) {
|
|||
|
||||
out_ << program_->Symbols().NameFor(v->symbol()) << " : ";
|
||||
|
||||
if (!EmitType(program_->Sem().Get(v)->Type())) {
|
||||
if (!EmitType(program_->Sem().Get(v)->DeclaredType())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -599,7 +599,7 @@ bool GeneratorImpl::EmitVariable(ast::Variable* var) {
|
|||
}
|
||||
|
||||
out_ << " " << program_->Symbols().NameFor(var->symbol()) << " : ";
|
||||
if (!EmitType(sem->Type())) {
|
||||
if (!EmitType(sem->DeclaredType())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue