Resolver: compute canonical types and store them as semantic::Variable::Type

We define the canonical type as a type stripped of all aliases. For
example, Canonical(alias<alias<vec3<alias<f32>>>>) is vec3<f32>. This
change adds Resolver::Canonical(Type*) which caches and returns the
resulting canonical type. We use this throughout the Resolver instead of
UnwrapAliasIfNeeded(), and we store the result in semantic::Variable,
returned from it's Type() member function.

Also:

* Wrote unit tests for Resolver::Canonical()

* Added semantic::Variable::DeclaredType() as a convenience to
retrieve the AST variable's type.

* Updated post-resolve code (transforms) to make use of Type and
DeclaredType appropriately, removing unnecessary calls to
UnwrapAliasIfNeeded.

* Added IntrinsicTableTest.MatchWithNestedAliasUnwrapping to ensure we
don't need to pass canonical parameter types for instrinsic table
lookups.

* ProgramBuilder: added vecN and matMxN overloads that take a Type* arg
to create them with alias types.

Bug: tint:705
Change-Id: I58a3b62538356b8dad2b1161a19b38bcefdd5d62
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/47360
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2021-04-13 13:32:33 +00:00 committed by Commit Bot service account
parent 2f25ecf8ba
commit 2543686412
14 changed files with 269 additions and 76 deletions

View File

@ -357,6 +357,23 @@ TEST_F(IntrinsicTableTest, MatchWithAliasUnwrapping) {
EXPECT_THAT(result.intrinsic->Parameters(), ElementsAre(Parameter{ty.f32()}));
}
TEST_F(IntrinsicTableTest, MatchWithNestedAliasUnwrapping) {
auto* alias_a = ty.alias("alias_a", ty.bool_());
auto* alias_b = ty.alias("alias_b", alias_a);
auto* alias_c = ty.alias("alias_c", alias_b);
auto* vec4_of_c = ty.vec4(alias_c);
auto* alias_d = ty.alias("alias_d", vec4_of_c);
auto* alias_e = ty.alias("alias_e", alias_d);
auto result = table->Lookup(*this, IntrinsicType::kAll, {alias_e}, Source{});
ASSERT_NE(result.intrinsic, nullptr);
ASSERT_EQ(result.diagnostics.str(), "");
EXPECT_THAT(result.intrinsic->Type(), IntrinsicType::kAll);
EXPECT_THAT(result.intrinsic->ReturnType(), ty.bool_());
EXPECT_THAT(result.intrinsic->Parameters(),
ElementsAre(Parameter{ty.vec4<bool>()}));
}
TEST_F(IntrinsicTableTest, MatchOpenType) {
auto result = table->Lookup(*this, IntrinsicType::kClamp,
{ty.f32(), ty.f32(), ty.f32()}, Source{});

View File

@ -338,76 +338,148 @@ class ProgramBuilder {
/// @returns a void type
type::Void* void_() const { return builder->create<type::Void>(); }
/// @param type vector subtype
/// @return the tint AST type for a 2-element vector of `type`.
type::Vector* vec2(type::Type* type) const {
return builder->create<type::Vector>(type, 2u);
}
/// @param type vector subtype
/// @return the tint AST type for a 3-element vector of `type`.
type::Vector* vec3(type::Type* type) const {
return builder->create<type::Vector>(type, 3u);
}
/// @param type vector subtype
/// @return the tint AST type for a 4-element vector of `type`.
type::Type* vec4(type::Type* type) const {
return builder->create<type::Vector>(type, 4u);
}
/// @return the tint AST type for a 2-element vector of the C type `T`.
template <typename T>
type::Vector* vec2() const {
return builder->create<type::Vector>(Of<T>(), 2);
return vec2(Of<T>());
}
/// @return the tint AST type for a 3-element vector of the C type `T`.
template <typename T>
type::Vector* vec3() const {
return builder->create<type::Vector>(Of<T>(), 3);
return vec3(Of<T>());
}
/// @return the tint AST type for a 4-element vector of the C type `T`.
template <typename T>
type::Type* vec4() const {
return builder->create<type::Vector>(Of<T>(), 4);
return vec4(Of<T>());
}
/// @param type matrix subtype
/// @return the tint AST type for a 2x3 matrix of `type`.
type::Matrix* mat2x2(type::Type* type) const {
return builder->create<type::Matrix>(type, 2u, 2u);
}
/// @param type matrix subtype
/// @return the tint AST type for a 2x3 matrix of `type`.
type::Matrix* mat2x3(type::Type* type) const {
return builder->create<type::Matrix>(type, 3u, 2u);
}
/// @param type matrix subtype
/// @return the tint AST type for a 2x4 matrix of `type`.
type::Matrix* mat2x4(type::Type* type) const {
return builder->create<type::Matrix>(type, 4u, 2u);
}
/// @param type matrix subtype
/// @return the tint AST type for a 3x2 matrix of `type`.
type::Matrix* mat3x2(type::Type* type) const {
return builder->create<type::Matrix>(type, 2u, 3u);
}
/// @param type matrix subtype
/// @return the tint AST type for a 3x3 matrix of `type`.
type::Matrix* mat3x3(type::Type* type) const {
return builder->create<type::Matrix>(type, 3u, 3u);
}
/// @param type matrix subtype
/// @return the tint AST type for a 3x4 matrix of `type`.
type::Matrix* mat3x4(type::Type* type) const {
return builder->create<type::Matrix>(type, 4u, 3u);
}
/// @param type matrix subtype
/// @return the tint AST type for a 4x2 matrix of `type`.
type::Matrix* mat4x2(type::Type* type) const {
return builder->create<type::Matrix>(type, 2u, 4u);
}
/// @param type matrix subtype
/// @return the tint AST type for a 4x3 matrix of `type`.
type::Matrix* mat4x3(type::Type* type) const {
return builder->create<type::Matrix>(type, 3u, 4u);
}
/// @param type matrix subtype
/// @return the tint AST type for a 4x4 matrix of `type`.
type::Matrix* mat4x4(type::Type* type) const {
return builder->create<type::Matrix>(type, 4u, 4u);
}
/// @return the tint AST type for a 2x3 matrix of the C type `T`.
template <typename T>
type::Matrix* mat2x2() const {
return builder->create<type::Matrix>(Of<T>(), 2, 2);
return mat2x2(Of<T>());
}
/// @return the tint AST type for a 2x3 matrix of the C type `T`.
template <typename T>
type::Matrix* mat2x3() const {
return builder->create<type::Matrix>(Of<T>(), 3, 2);
return mat2x3(Of<T>());
}
/// @return the tint AST type for a 2x4 matrix of the C type `T`.
template <typename T>
type::Matrix* mat2x4() const {
return builder->create<type::Matrix>(Of<T>(), 4, 2);
return mat2x4(Of<T>());
}
/// @return the tint AST type for a 3x2 matrix of the C type `T`.
template <typename T>
type::Matrix* mat3x2() const {
return builder->create<type::Matrix>(Of<T>(), 2, 3);
return mat3x2(Of<T>());
}
/// @return the tint AST type for a 3x3 matrix of the C type `T`.
template <typename T>
type::Matrix* mat3x3() const {
return builder->create<type::Matrix>(Of<T>(), 3, 3);
return mat3x3(Of<T>());
}
/// @return the tint AST type for a 3x4 matrix of the C type `T`.
template <typename T>
type::Matrix* mat3x4() const {
return builder->create<type::Matrix>(Of<T>(), 4, 3);
return mat3x4(Of<T>());
}
/// @return the tint AST type for a 4x2 matrix of the C type `T`.
template <typename T>
type::Matrix* mat4x2() const {
return builder->create<type::Matrix>(Of<T>(), 2, 4);
return mat4x2(Of<T>());
}
/// @return the tint AST type for a 4x3 matrix of the C type `T`.
template <typename T>
type::Matrix* mat4x3() const {
return builder->create<type::Matrix>(Of<T>(), 3, 4);
return mat4x3(Of<T>());
}
/// @return the tint AST type for a 4x4 matrix of the C type `T`.
template <typename T>
type::Matrix* mat4x4() const {
return builder->create<type::Matrix>(Of<T>(), 4, 4);
return mat4x4(Of<T>());
}
/// @param subtype the array element type

View File

@ -42,6 +42,7 @@
#include "src/semantic/struct.h"
#include "src/semantic/variable.h"
#include "src/type/access_control_type.h"
#include "src/utils/get_or_create.h"
#include "src/utils/math.h"
namespace tint {
@ -430,7 +431,7 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func) {
}
// Check that we saw a pipeline IO attribute iff we need one.
if (ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
if (Canonical(ty)->Is<type::Struct>()) {
if (pipeline_io_attribute) {
diagnostics_.add_error(
"entry point IO attributes must not be used on structure " +
@ -466,11 +467,11 @@ bool Resolver::ValidateEntryPoint(const ast::Function* func) {
return false;
}
if (auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>()) {
if (auto* struct_ty = Canonical(ty)->As<type::Struct>()) {
// Validate the decorations for each struct members, and also check for
// invalid member types.
for (auto* member : struct_ty->impl()->members()) {
auto* member_ty = member->type()->UnwrapAliasIfNeeded();
auto* member_ty = Canonical(member->type());
if (member_ty->Is<type::Struct>()) {
diagnostics_.add_error(
"entry point IO types cannot contain nested structures",
@ -547,8 +548,7 @@ bool Resolver::Function(ast::Function* func) {
return false;
}
if (auto* str =
param->declared_type()->UnwrapAliasIfNeeded()->As<type::Struct>()) {
if (auto* str = param_info->type->As<type::Struct>()) {
auto* info = Structure(str);
if (!info) {
return false;
@ -572,8 +572,7 @@ bool Resolver::Function(ast::Function* func) {
}
}
if (auto* str =
func->return_type()->UnwrapAliasIfNeeded()->As<type::Struct>()) {
if (auto* str = Canonical(func->return_type())->As<type::Struct>()) {
if (!ApplyStorageClassUsageToType(ast::StorageClass::kNone, str,
func->source())) {
diagnostics_.add_note("while instantiating return type for " +
@ -1288,15 +1287,16 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
using Matrix = type::Matrix;
using Vector = type::Vector;
auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
auto* lhs_declared_type = TypeOf(expr->lhs())->UnwrapAll();
auto* rhs_declared_type = TypeOf(expr->rhs())->UnwrapAll();
auto* lhs_type = Canonical(lhs_declared_type);
auto* rhs_type = Canonical(rhs_declared_type);
auto* lhs_vec = lhs_type->As<Vector>();
auto* lhs_vec_elem_type =
lhs_vec ? lhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
auto* rhs_vec = rhs_type->As<Vector>();
auto* rhs_vec_elem_type =
rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
const bool matching_vec_elem_types =
lhs_vec_elem_type && rhs_vec_elem_type &&
@ -1348,11 +1348,9 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
}
auto* lhs_mat = lhs_type->As<Matrix>();
auto* lhs_mat_elem_type =
lhs_mat ? lhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
auto* rhs_mat = rhs_type->As<Matrix>();
auto* rhs_mat_elem_type =
rhs_mat ? rhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
// Multiplication of a matrix and a scalar
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
@ -1438,9 +1436,9 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
diagnostics_.add_error(
"Binary expression operand types are invalid for this operation: " +
lhs_type->FriendlyName(builder_->Symbols()) + " " +
lhs_declared_type->FriendlyName(builder_->Symbols()) + " " +
FriendlyName(expr->op()) + " " +
rhs_type->FriendlyName(builder_->Symbols()),
rhs_declared_type->FriendlyName(builder_->Symbols()),
expr->source());
return false;
}
@ -1600,7 +1598,7 @@ bool Resolver::VariableDeclStatement(const ast::VariableDeclStatement* stmt) {
}
Resolver::VariableInfo* Resolver::CreateVariableInfo(ast::Variable* var) {
auto* info = variable_infos_.Create(var);
auto* info = variable_infos_.Create(var, Canonical(var->declared_type()));
variable_to_info_.emplace(var, info);
return info;
}
@ -1748,13 +1746,13 @@ bool Resolver::DefaultAlignAndSize(type::Type* ty,
/*vec4*/ 16,
};
ty = ty->UnwrapAliasIfNeeded();
if (ty->is_scalar()) {
auto* cty = Canonical(ty);
if (cty->is_scalar()) {
// Note: Also captures booleans, but these are not host-shareable.
align = 4;
size = 4;
return true;
} else if (auto* vec = ty->As<type::Vector>()) {
} else if (auto* vec = cty->As<type::Vector>()) {
if (vec->size() < 2 || vec->size() > 4) {
TINT_UNREACHABLE(diagnostics_)
<< "Invalid vector size: vec" << vec->size();
@ -1763,7 +1761,7 @@ bool Resolver::DefaultAlignAndSize(type::Type* ty,
align = vector_align[vec->size()];
size = vector_size[vec->size()];
return true;
} else if (auto* mat = ty->As<type::Matrix>()) {
} else if (auto* mat = cty->As<type::Matrix>()) {
if (mat->columns() < 2 || mat->columns() > 4 || mat->rows() < 2 ||
mat->rows() > 4) {
TINT_UNREACHABLE(diagnostics_)
@ -1773,15 +1771,15 @@ bool Resolver::DefaultAlignAndSize(type::Type* ty,
align = vector_align[mat->rows()];
size = vector_align[mat->rows()] * mat->columns();
return true;
} else if (auto* s = ty->As<type::Struct>()) {
} else if (auto* s = cty->As<type::Struct>()) {
if (auto* si = Structure(s)) {
align = si->align;
size = si->size;
return true;
}
return false;
} else if (auto* arr = ty->As<type::Array>()) {
if (auto* sem = Array(arr)) {
} else if (cty->Is<type::Array>()) {
if (auto* sem = Array(ty->UnwrapAliasIfNeeded()->As<type::Array>())) {
align = sem->Align();
size = sem->Size();
return true;
@ -2249,9 +2247,37 @@ std::string Resolver::VectorPretty(uint32_t size, type::Type* element_type) {
return vec_type.FriendlyName(builder_->Symbols());
}
Resolver::VariableInfo::VariableInfo(ast::Variable* decl)
type::Type* Resolver::Canonical(type::Type* type) {
using Type = type::Type;
using Alias = type::Alias;
using Matrix = type::Matrix;
using Vector = type::Vector;
std::function<Type*(Type*)> make_canonical;
make_canonical = [&](Type* t) -> type::Type* {
// Unwrap alias sequence
Type* ct = t;
while (auto* p = ct->As<Alias>()) {
ct = p->type();
}
if (auto* v = ct->As<Vector>()) {
return builder_->create<Vector>(make_canonical(v->type()), v->size());
}
if (auto* m = ct->As<Matrix>()) {
return builder_->create<Matrix>(make_canonical(m->type()), m->rows(),
m->columns());
}
return ct;
};
return utils::GetOrCreate(type_to_canonical_, type,
[&] { return make_canonical(type); });
}
Resolver::VariableInfo::VariableInfo(ast::Variable* decl, type::Type* ctype)
: declaration(decl),
type(decl->declared_type()),
type(ctype),
storage_class(decl->declared_storage_class()) {}
Resolver::VariableInfo::~VariableInfo() = default;

View File

@ -86,11 +86,17 @@ class Resolver {
/// structure member or array element of type `lhs`
static bool IsValidAssignment(type::Type* lhs, type::Type* rhs);
/// @param type the input type
/// @returns the canonical type for `type`; that is, a type with all aliases
/// removed. For example, `Canonical(alias<alias<vec3<alias<f32>>>>)` is
/// `vec3<f32>`.
type::Type* Canonical(type::Type* type);
private:
/// Structure holding semantic information about a variable.
/// Used to build the semantic::Variable nodes at the end of resolving.
struct VariableInfo {
explicit VariableInfo(ast::Variable* decl);
VariableInfo(ast::Variable* decl, type::Type* type);
~VariableInfo();
ast::Variable* const declaration;
@ -306,6 +312,7 @@ class Resolver {
std::unordered_map<ast::CallExpression*, FunctionCallInfo> function_calls_;
std::unordered_map<ast::Expression*, ExpressionInfo> expr_info_;
std::unordered_map<type::Struct*, StructInfo*> struct_info_;
std::unordered_map<type::Type*, type::Type*> type_to_canonical_;
FunctionInfo* current_function_ = nullptr;
semantic::Statement* current_statement_ = nullptr;
BlockAllocator<VariableInfo> variable_infos_;

View File

@ -119,18 +119,30 @@ inline type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) {
return ty.f32();
}
using create_type_func_ptr =
type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
template <typename T>
type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec3<T>();
}
template <create_type_func_ptr create_type>
type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
auto* type = create_type(ty);
return ty.vec3(type);
}
template <typename T>
type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x3<T>();
}
using create_type_func_ptr =
type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
template <create_type_func_ptr create_type>
type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
auto* type = create_type(ty);
return ty.mat3x3(type);
}
template <create_type_func_ptr create_type>
type::Type* ty_alias(const ProgramBuilder::TypesBuilder& ty) {

View File

@ -21,6 +21,7 @@
#include "gmock/gmock.h"
namespace tint {
namespace resolver {
namespace {
class ResolverTypeValidationTest : public resolver::TestHelper,
@ -463,5 +464,51 @@ TEST_F(ResolverTypeValidationTest, AliasRuntimeArrayIsLast_Pass) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
namespace GetCanonicalTests {
struct Params {
create_type_func_ptr create_type;
create_type_func_ptr create_canonical_type;
};
static constexpr Params cases[] = {
Params{ty_bool_, ty_bool_},
Params{ty_alias<ty_bool_>, ty_bool_},
Params{ty_alias<ty_alias<ty_bool_>>, ty_bool_},
Params{ty_vec3<ty_f32>, ty_vec3<ty_f32>},
Params{ty_alias<ty_vec3<ty_f32>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_alias<ty_vec3<ty_f32>>>, ty_vec3<ty_f32>},
Params{ty_vec3<ty_alias<ty_f32>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_vec3<ty_alias<ty_f32>>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_f32>>>>, ty_vec3<ty_f32>},
Params{ty_alias<ty_alias<ty_vec3<ty_alias<ty_alias<ty_f32>>>>>,
ty_vec3<ty_f32>},
Params{ty_mat3x3<ty_alias<ty_f32>>, ty_mat3x3<ty_f32>},
Params{ty_alias<ty_mat3x3<ty_alias<ty_f32>>>, ty_mat3x3<ty_f32>},
Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_f32>>>>, ty_mat3x3<ty_f32>},
Params{ty_alias<ty_alias<ty_mat3x3<ty_alias<ty_alias<ty_f32>>>>>,
ty_mat3x3<ty_f32>},
};
using CanonicalTest = ResolverTestWithParam<Params>;
TEST_P(CanonicalTest, All) {
auto& params = GetParam();
auto* type = params.create_type(ty);
auto* expected_canonical_type = params.create_canonical_type(ty);
auto* canonical_type = r()->Canonical(type);
EXPECT_EQ(canonical_type, expected_canonical_type);
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,
CanonicalTest,
testing::ValuesIn(cases));
} // namespace GetCanonicalTests
} // namespace
} // namespace resolver
} // namespace tint

View File

@ -15,6 +15,7 @@
#include "src/semantic/variable.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::semantic::Variable);
TINT_INSTANTIATE_TYPEINFO(tint::semantic::VariableUser);
@ -29,6 +30,10 @@ Variable::Variable(const ast::Variable* declaration,
Variable::~Variable() = default;
type::Type* Variable::DeclaredType() const {
return declaration_->declared_type();
}
VariableUser::VariableUser(ast::IdentifierExpression* declaration,
type::Type* type,
Statement* statement,

View File

@ -52,9 +52,12 @@ class Variable : public Castable<Variable, Node> {
/// @returns the AST declaration node
const ast::Variable* Declaration() const { return declaration_; }
/// @returns the type for the variable
/// @returns the canonical type for the variable
type::Type* Type() const { return type_; }
/// @returns the AST node's type. May be nullptr.
type::Type* DeclaredType() const;
/// @returns the storage class for the variable
ast::StorageClass StorageClass() const { return storage_class_; }

View File

@ -65,8 +65,7 @@ Transform::Output BindingRemapper::Run(const Program* in,
auto ac_it = remappings->access_controls.find(from);
if (ac_it != remappings->access_controls.end()) {
ast::AccessControl ac = ac_it->second;
auto* var_ty = in->Sem().Get(var)->Type();
auto* ty = var_ty->UnwrapAliasIfNeeded();
auto* ty = in->Sem().Get(var)->Type();
type::Type* inner_ty = nullptr;
if (auto* old_ac = ty->As<type::AccessControl>()) {
inner_ty = ctx.Clone(old_ac->type());

View File

@ -72,11 +72,11 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
for (auto* param : func->params()) {
auto param_name = ctx.Clone(param->symbol());
auto* param_ty = ctx.src->Sem().Get(param)->Type();
auto* param_declared_ty = ctx.src->Sem().Get(param)->DeclaredType();
ast::Expression* func_const_initializer = nullptr;
if (auto* struct_ty =
param_ty->UnwrapAliasIfNeeded()->As<type::Struct>()) {
if (auto* struct_ty = param_ty->As<type::Struct>()) {
// Pull out all struct members and build initializer list.
ast::ExpressionList init_values;
for (auto* member : struct_ty->impl()->members()) {
@ -97,7 +97,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
}
func_const_initializer =
ctx.dst->Construct(ctx.Clone(param_ty), init_values);
ctx.dst->Construct(ctx.Clone(param_declared_ty), init_values);
} else {
ast::DecorationList new_decorations = RemoveDecorations(
&ctx, param->decorations(), [](const ast::Decoration* deco) {
@ -105,7 +105,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
ast::LocationDecoration>();
});
new_struct_members.push_back(ctx.dst->Member(
param_name, ctx.Clone(param_ty), new_decorations));
param_name, ctx.Clone(param_declared_ty), new_decorations));
func_const_initializer =
ctx.dst->MemberAccessor(new_struct_param_symbol, param_name);
}
@ -117,8 +117,8 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
// Create a function-scope const to replace the parameter.
// Initialize it with the value extracted from the new struct parameter.
auto* func_const = ctx.dst->Const(param_name, ctx.Clone(param_ty),
func_const_initializer);
auto* func_const = ctx.dst->Const(
param_name, ctx.Clone(param_declared_ty), func_const_initializer);
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->WrapInStatement(func_const));

View File

@ -141,7 +141,8 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
for (auto* param : func->params()) {
Symbol new_var = HoistToInputVariables(
ctx, func, ctx.src->Sem().Get(param)->Type(), param->decorations());
ctx, func, ctx.src->Sem().Get(param)->Type(),
ctx.src->Sem().Get(param)->DeclaredType(), param->decorations());
// Replace all uses of the function parameter with the new variable.
for (auto* user : ctx.src->Sem().Get(param)->Users()) {
@ -153,9 +154,9 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
if (!func->return_type()->Is<type::Void>()) {
ast::StatementList stores;
auto store_value_symbol = ctx.dst->Symbols().New();
HoistToOutputVariables(ctx, func, func->return_type(),
func->return_type_decorations(), {},
store_value_symbol, stores);
HoistToOutputVariables(
ctx, func, func->return_type(), func->return_type(),
func->return_type_decorations(), {}, store_value_symbol, stores);
// Create a function that writes a return value to all output variables.
auto* store_value =
@ -251,8 +252,9 @@ Symbol Spirv::HoistToInputVariables(
CloneContext& ctx,
const ast::Function* func,
type::Type* ty,
type::Type* declared_ty,
const ast::DecorationList& decorations) const {
if (!ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
if (!ty->Is<type::Struct>()) {
// Base case: create a global variable and return.
ast::DecorationList new_decorations =
RemoveDecorations(&ctx, decorations, [](const ast::Decoration* deco) {
@ -261,7 +263,7 @@ Symbol Spirv::HoistToInputVariables(
});
auto global_var_symbol = ctx.dst->Symbols().New();
auto* global_var =
ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
ctx.dst->Var(global_var_symbol, ctx.Clone(declared_ty),
ast::StorageClass::kInput, nullptr, new_decorations);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
return global_var_symbol;
@ -269,10 +271,10 @@ Symbol Spirv::HoistToInputVariables(
// Recurse into struct members and build the initializer list.
ast::ExpressionList init_values;
auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>();
auto* struct_ty = ty->As<type::Struct>();
for (auto* member : struct_ty->impl()->members()) {
auto member_var =
HoistToInputVariables(ctx, func, member->type(), member->decorations());
auto member_var = HoistToInputVariables(
ctx, func, member->type(), member->type(), member->decorations());
init_values.push_back(ctx.dst->Expr(member_var));
}
@ -283,8 +285,9 @@ Symbol Spirv::HoistToInputVariables(
}
// Create a function-scope variable for the struct.
auto* initializer = ctx.dst->Construct(ctx.Clone(ty), init_values);
auto* func_var = ctx.dst->Const(func_var_symbol, ctx.Clone(ty), initializer);
auto* initializer = ctx.dst->Construct(ctx.Clone(declared_ty), init_values);
auto* func_var =
ctx.dst->Const(func_var_symbol, ctx.Clone(declared_ty), initializer);
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->WrapInStatement(func_var));
return func_var_symbol;
@ -293,12 +296,13 @@ Symbol Spirv::HoistToInputVariables(
void Spirv::HoistToOutputVariables(CloneContext& ctx,
const ast::Function* func,
type::Type* ty,
type::Type* declared_ty,
const ast::DecorationList& decorations,
std::vector<Symbol> member_accesses,
Symbol store_value,
ast::StatementList& stores) const {
// Base case.
if (!ty->UnwrapAliasIfNeeded()->Is<type::Struct>()) {
if (!ty->Is<type::Struct>()) {
// Create a global variable.
ast::DecorationList new_decorations =
RemoveDecorations(&ctx, decorations, [](const ast::Decoration* deco) {
@ -307,7 +311,7 @@ void Spirv::HoistToOutputVariables(CloneContext& ctx,
});
auto global_var_symbol = ctx.dst->Symbols().New();
auto* global_var =
ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
ctx.dst->Var(global_var_symbol, ctx.Clone(declared_ty),
ast::StorageClass::kOutput, nullptr, new_decorations);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
@ -322,11 +326,12 @@ void Spirv::HoistToOutputVariables(CloneContext& ctx,
}
// Recurse into struct members.
auto* struct_ty = ty->UnwrapAliasIfNeeded()->As<type::Struct>();
auto* struct_ty = ty->As<type::Struct>();
for (auto* member : struct_ty->impl()->members()) {
member_accesses.push_back(ctx.Clone(member->symbol()));
HoistToOutputVariables(ctx, func, member->type(), member->decorations(),
member_accesses, store_value, stores);
HoistToOutputVariables(ctx, func, member->type(), member->type(),
member->decorations(), member_accesses, store_value,
stores);
member_accesses.pop_back();
}
}

View File

@ -60,6 +60,7 @@ class Spirv : public Transform {
Symbol HoistToInputVariables(CloneContext& ctx,
const ast::Function* func,
type::Type* ty,
type::Type* declared_ty,
const ast::DecorationList& decorations) const;
/// Recursively create module-scope output variables for `ty` and build a list
@ -73,6 +74,7 @@ class Spirv : public Transform {
void HoistToOutputVariables(CloneContext& ctx,
const ast::Function* func,
type::Type* ty,
type::Type* declared_ty,
const ast::DecorationList& decorations,
std::vector<Symbol> member_accesses,
Symbol store_value,

View File

@ -1894,7 +1894,7 @@ bool GeneratorImpl::EmitEntryPointData(
for (auto* var : func_sem->ReferencedModuleVariables()) {
auto* decl = var->Declaration();
auto* unwrapped_type = var->Type()->UnwrapAll();
auto* unwrapped_type = var->DeclaredType()->UnwrapAll();
if (!emitted_globals.emplace(decl->symbol()).second) {
continue; // Global already emitted
}
@ -1905,7 +1905,7 @@ bool GeneratorImpl::EmitEntryPointData(
continue; // Not interested in this type
}
if (!EmitType(out, var->Type(), var->StorageClass(), "")) {
if (!EmitType(out, var->DeclaredType(), var->StorageClass(), "")) {
return false;
}
out << " " << builder_.Symbols().NameFor(decl->symbol());
@ -1915,9 +1915,7 @@ bool GeneratorImpl::EmitEntryPointData(
if (unwrapped_type->Is<type::Texture>()) {
register_space = "t";
if (unwrapped_type->Is<type::StorageTexture>()) {
if (auto* ac = var->Type()
->UnwrapAliasIfNeeded()
->As<type::AccessControl>()) {
if (auto* ac = var->Type()->As<type::AccessControl>()) {
if (!ac->IsReadOnly()) {
register_space = "u";
}

View File

@ -325,7 +325,7 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) {
out_ << program_->Symbols().NameFor(v->symbol()) << " : ";
if (!EmitType(program_->Sem().Get(v)->Type())) {
if (!EmitType(program_->Sem().Get(v)->DeclaredType())) {
return false;
}
}
@ -599,7 +599,7 @@ bool GeneratorImpl::EmitVariable(ast::Variable* var) {
}
out_ << " " << program_->Symbols().NameFor(var->symbol()) << " : ";
if (!EmitType(sem->Type())) {
if (!EmitType(sem->DeclaredType())) {
return false;
}