ast: Migrate to using ast::Type

Remove all sem::Type references from the AST.
ConstructedTypes are now all AST types.

The parsers will still create semantic types, but these are now disjoint
and ignored.
The parsers will be updated with future changes to stop creating these
semantic types.

Resolver creates semantic types from the AST types. Most downstream
logic continues to use the semantic types, however transforms will now
need to rebuild AST type information instead of reassigning semantic
information, as semantic nodes are fully rebuilt by the Resolver.

Bug: tint:724
Change-Id: I4ce03a075f13c77648cda5c3691bae202752ecc5
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49747
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton
2021-05-05 09:09:41 +00:00
committed by Commit Bot service account
parent 781de097eb
commit 02ebf0dcae
72 changed files with 1267 additions and 1091 deletions

View File

@@ -65,13 +65,13 @@ Output BindingRemapper::Run(const Program* in, const DataMap& datamap) {
if (ac_it != remappings->access_controls.end()) {
ast::AccessControl::Access ac = ac_it->second;
auto* ty = in->Sem().Get(var)->Type();
sem::Type* inner_ty = nullptr;
ast::Type* inner_ty = nullptr;
if (auto* old_ac = ty->As<sem::AccessControl>()) {
inner_ty = ctx.Clone(old_ac->type());
inner_ty = CreateASTTypeFor(&ctx, old_ac->type());
} else {
inner_ty = ctx.Clone(ty);
inner_ty = CreateASTTypeFor(&ctx, ty);
}
auto* new_ty = ctx.dst->create<sem::AccessControl>(ac, inner_ty);
auto* new_ty = ctx.dst->create<ast::AccessControl>(ac, inner_ty);
auto* new_var = ctx.dst->create<ast::Variable>(
ctx.Clone(var->source()), ctx.Clone(var->symbol()),
var->declared_storage_class(), new_ty, var->is_const(),

View File

@@ -81,6 +81,8 @@ Output CalculateArrayLength::Run(const Program* in, const DataMap&) {
auto get_buffer_size_intrinsic = [&](sem::StructType* buffer_type) {
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
auto name = ctx.dst->Sym();
auto* buffer_typename =
ctx.dst->ty.type_name(ctx.Clone(buffer_type->impl()->name()));
auto* func = ctx.dst->create<ast::Function>(
name,
ast::VariableList{
@@ -88,7 +90,7 @@ Output CalculateArrayLength::Run(const Program* in, const DataMap&) {
// in order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
ctx.Clone(buffer_type), true, nullptr, ast::DecorationList{}),
buffer_typename, true, nullptr, ast::DecorationList{}),
ctx.dst->Param("result",
ctx.dst->ty.pointer(ctx.dst->ty.u32(),
ast::StorageClass::kFunction)),
@@ -98,7 +100,8 @@ Output CalculateArrayLength::Run(const Program* in, const DataMap&) {
ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID()),
},
ast::DecorationList{});
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), buffer_type, func);
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), buffer_type->impl(),
func);
return name;
});
};

View File

@@ -21,6 +21,7 @@
#include "src/program_builder.h"
#include "src/sem/function.h"
#include "src/sem/statement.h"
#include "src/sem/struct.h"
#include "src/sem/variable.h"
namespace tint {
@@ -65,11 +66,11 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
// Strip entry point IO decorations from struct declarations.
// TODO(jrprice): This code is duplicated with the SPIR-V transform.
for (auto ty : ctx.src->AST().ConstructedTypes()) {
if (auto* struct_ty = ty->As<sem::StructType>()) {
for (auto* ty : ctx.src->AST().ConstructedTypes()) {
if (auto* struct_ty = ty->As<ast::Struct>()) {
// Build new list of struct members without entry point IO decorations.
ast::StructMemberList new_struct_members;
for (auto* member : struct_ty->impl()->members()) {
for (auto* member : struct_ty->members()) {
ast::DecorationList new_decorations = RemoveDecorations(
&ctx, member->decorations(), [](const ast::Decoration* deco) {
return deco
@@ -81,49 +82,53 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
}
// Redeclare the struct.
auto new_struct_name = ctx.Clone(struct_ty->impl()->name());
auto new_struct_name = ctx.Clone(struct_ty->name());
auto* new_struct =
ctx.dst->create<sem::StructType>(ctx.dst->create<ast::Struct>(
new_struct_name, new_struct_members,
ctx.Clone(struct_ty->impl()->decorations())));
ctx.dst->create<ast::Struct>(new_struct_name, new_struct_members,
ctx.Clone(struct_ty->decorations()));
ctx.Replace(struct_ty, new_struct);
}
}
for (auto* func : ctx.src->AST().Functions()) {
if (!func->IsEntryPoint()) {
for (auto* func_ast : ctx.src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) {
continue;
}
auto* func = ctx.src->Sem().Get(func_ast);
ast::VariableList new_parameters;
if (!func->params().empty()) {
if (!func->Parameters().empty()) {
// Collect all parameters and build a list of new struct members.
auto new_struct_param_symbol = ctx.dst->Sym();
ast::StructMemberList new_struct_members;
for (auto* param : func->params()) {
auto param_name = ctx.Clone(param->symbol());
auto* param_ty = ctx.src->Sem().Get(param)->Type();
auto* param_declared_ty = ctx.src->Sem().Get(param)->DeclaredType();
for (auto* param : func->Parameters()) {
auto param_name = ctx.Clone(param->Declaration()->symbol());
auto* param_ty = param->Type();
auto* param_declared_ty = param->Declaration()->type();
std::function<ast::Expression*()> func_const_initializer;
if (auto* struct_ty = param_ty->As<sem::StructType>()) {
auto* str = ctx.src->Sem().Get(struct_ty);
// Pull out all struct members and build initializer list.
std::vector<Symbol> member_names;
for (auto* member : struct_ty->impl()->members()) {
if (member->type()->UnwrapAll()->Is<sem::StructType>()) {
for (auto* member : str->Members()) {
if (member->Type()->UnwrapAll()->Is<sem::StructType>()) {
TINT_ICE(ctx.dst->Diagnostics()) << "nested pipeline IO struct";
}
ast::DecorationList new_decorations = RemoveDecorations(
&ctx, member->decorations(), [](const ast::Decoration* deco) {
&ctx, member->Declaration()->decorations(),
[](const ast::Decoration* deco) {
return !deco->IsAnyOf<ast::BuiltinDecoration,
ast::LocationDecoration>();
});
auto member_name = ctx.Clone(member->symbol());
new_struct_members.push_back(ctx.dst->Member(
member_name, ctx.Clone(member->type()), new_decorations));
auto member_name = ctx.Clone(member->Declaration()->symbol());
auto* member_type = ctx.Clone(member->Declaration()->type());
new_struct_members.push_back(
ctx.dst->Member(member_name, member_type, new_decorations));
member_names.emplace_back(member_name);
}
@@ -139,7 +144,8 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
};
} else {
ast::DecorationList new_decorations = RemoveDecorations(
&ctx, param->decorations(), [](const ast::Decoration* deco) {
&ctx, param->Declaration()->decorations(),
[](const ast::Decoration* deco) {
return !deco->IsAnyOf<ast::BuiltinDecoration,
ast::LocationDecoration>();
});
@@ -151,7 +157,7 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
};
}
if (func->body()->empty()) {
if (func_ast->body()->empty()) {
// Don't generate a function-scope const if the function is empty.
continue;
}
@@ -160,11 +166,12 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
// Initialize it with the value extracted from the new struct parameter.
auto* func_const = ctx.dst->Const(
param_name, ctx.Clone(param_declared_ty), func_const_initializer());
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.InsertBefore(func_ast->body()->statements(),
*func_ast->body()->begin(),
ctx.dst->WrapInStatement(func_const));
// Replace all uses of the function parameter with the function const.
for (auto* user : ctx.src->Sem().Get(param)->Users()) {
for (auto* user : param->Users()) {
ctx.Replace<ast::Expression>(user->Declaration(),
ctx.dst->Expr(param_name));
}
@@ -176,44 +183,49 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
// Create the new struct type.
auto in_struct_name = ctx.dst->Sym();
auto* in_struct =
ctx.dst->create<sem::StructType>(ctx.dst->create<ast::Struct>(
in_struct_name, new_struct_members, ast::DecorationList{}));
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, in_struct);
auto* in_struct = ctx.dst->create<ast::Struct>(
in_struct_name, new_struct_members, ast::DecorationList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast,
in_struct);
// Create a new function parameter using this struct type.
auto* struct_param = ctx.dst->Param(new_struct_param_symbol, in_struct);
auto* struct_param = ctx.dst->Param(
new_struct_param_symbol, ctx.dst->ty.type_name(in_struct_name));
new_parameters.push_back(struct_param);
}
// Handle return type.
auto* ret_type = func->return_type()->UnwrapAliasIfNeeded();
sem::Type* new_ret_type;
auto* ret_type = func->ReturnType()->UnwrapAliasIfNeeded();
std::function<ast::Type*()> new_ret_type;
if (ret_type->Is<sem::Void>()) {
new_ret_type = ctx.dst->ty.void_();
new_ret_type = [&ctx] { return ctx.dst->ty.void_(); };
} else {
ast::StructMemberList new_struct_members;
if (auto* struct_ty = ret_type->As<sem::StructType>()) {
auto* str = ctx.src->Sem().Get(struct_ty);
// Rebuild struct with only the entry point IO attributes.
for (auto* member : struct_ty->impl()->members()) {
if (member->type()->UnwrapAll()->Is<sem::StructType>()) {
for (auto* member : str->Members()) {
if (member->Type()->UnwrapAll()->Is<sem::StructType>()) {
TINT_ICE(ctx.dst->Diagnostics()) << "nested pipeline IO struct";
}
ast::DecorationList new_decorations = RemoveDecorations(
&ctx, member->decorations(), [](const ast::Decoration* deco) {
&ctx, member->Declaration()->decorations(),
[](const ast::Decoration* deco) {
return !deco->IsAnyOf<ast::BuiltinDecoration,
ast::LocationDecoration>();
});
auto symbol = ctx.Clone(member->Declaration()->symbol());
auto* member_ty = ctx.Clone(member->Declaration()->type());
new_struct_members.push_back(
ctx.dst->Member(ctx.Clone(member->symbol()),
ctx.Clone(member->type()), new_decorations));
ctx.dst->Member(symbol, member_ty, new_decorations));
}
} else {
auto* member_ty = ctx.Clone(func->Declaration()->return_type());
auto decos = ctx.Clone(func_ast->return_type_decorations());
new_struct_members.push_back(
ctx.dst->Member("value", ctx.Clone(ret_type),
ctx.Clone(func->return_type_decorations())));
ctx.dst->Member("value", member_ty, std::move(decos)));
}
// Sort struct members to satisfy HLSL interfacing matching rules.
@@ -222,15 +234,16 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
// Create the new struct type.
auto out_struct_name = ctx.dst->Sym();
auto* out_struct =
ctx.dst->create<sem::StructType>(ctx.dst->create<ast::Struct>(
out_struct_name, new_struct_members, ast::DecorationList{}));
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, out_struct);
new_ret_type = out_struct;
auto* out_struct = ctx.dst->create<ast::Struct>(
out_struct_name, new_struct_members, ast::DecorationList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast,
out_struct);
new_ret_type = [out_struct_name, &ctx] {
return ctx.dst->ty.type_name(out_struct_name);
};
// Replace all return statements.
auto* sem_func = ctx.src->Sem().Get(func);
for (auto* ret : sem_func->ReturnStatements()) {
for (auto* ret : func->ReturnStatements()) {
auto* ret_sem = ctx.src->Sem().Get(ret);
// Reconstruct the return value using the newly created struct.
std::function<ast::Expression*()> new_ret_value = [&ctx, ret] {
@@ -243,8 +256,9 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
// Create a const to hold the return value expression to avoid
// re-evaluating it multiple times.
auto temp = ctx.dst->Sym();
auto* temp_var = ctx.dst->Decl(
ctx.dst->Const(temp, ctx.Clone(ret_type), new_ret_value()));
auto* ty = CreateASTTypeFor(&ctx, ret_type);
auto* temp_var =
ctx.dst->Decl(ctx.dst->Const(temp, ty, new_ret_value()));
ctx.InsertBefore(ret_sem->Block()->statements(), ret, temp_var);
new_ret_value = [&ctx, temp] { return ctx.dst->Expr(temp); };
}
@@ -258,17 +272,17 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) {
}
auto* new_ret =
ctx.dst->Return(ctx.dst->Construct(new_ret_type, ret_values));
ctx.dst->Return(ctx.dst->Construct(new_ret_type(), ret_values));
ctx.Replace(ret, new_ret);
}
}
// Rewrite the function header with the new parameters.
auto* new_func = ctx.dst->create<ast::Function>(
func->source(), ctx.Clone(func->symbol()), new_parameters, new_ret_type,
ctx.Clone(func->body()), ctx.Clone(func->decorations()),
ast::DecorationList{});
ctx.Replace(func, new_func);
func_ast->source(), ctx.Clone(func_ast->symbol()), new_parameters,
new_ret_type(), ctx.Clone(func_ast->body()),
ctx.Clone(func_ast->decorations()), ast::DecorationList{});
ctx.Replace(func_ast, new_func);
}
ctx.Clone();

View File

@@ -23,6 +23,7 @@
#include "src/ast/assignment_statement.h"
#include "src/ast/call_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/type_name.h"
#include "src/program_builder.h"
#include "src/sem/access_control_type.h"
#include "src/sem/array.h"
@@ -318,7 +319,9 @@ DecomposeStorageAccess::Intrinsic* IntrinsicStoreFor(ProgramBuilder* builder,
/// Inserts `node` before `insert_after` in the global declarations of
/// `ctx.dst`. If `insert_after` is nullptr, then `node` is inserted at the top
/// of the module.
void InsertGlobal(CloneContext& ctx, Cloneable* insert_after, Cloneable* node) {
void InsertGlobal(CloneContext& ctx,
const Cloneable* insert_after,
Cloneable* node) {
auto& globals = ctx.src->AST().GlobalDeclarations();
if (insert_after) {
ctx.InsertAfter(globals, insert_after, node);
@@ -328,7 +331,7 @@ void InsertGlobal(CloneContext& ctx, Cloneable* insert_after, Cloneable* node) {
}
/// @returns the unwrapped, user-declared constructed type of ty.
sem::Type* ConstructedTypeOf(sem::Type* ty) {
ast::NamedType* ConstructedTypeOf(sem::Type* ty) {
while (true) {
if (auto* ptr = ty->As<sem::Pointer>()) {
ty = ptr->type();
@@ -338,11 +341,8 @@ sem::Type* ConstructedTypeOf(sem::Type* ty) {
ty = access->type();
continue;
}
if (auto* alias = ty->As<sem::Alias>()) {
return alias;
}
if (auto* str = ty->As<sem::StructType>()) {
return str;
return str->impl();
}
// Not a constructed type
return nullptr;
@@ -368,8 +368,10 @@ struct Store {
StorageBufferAccess target; // The target for the write
};
} // namespace
/// State holds the current transform state
struct State {
struct DecomposeStorageAccess::State {
/// Map of AST expression to storage buffer access
/// This map has entries added when encountered, and removed when outer
/// expressions chain the access.
@@ -385,9 +387,12 @@ struct State {
/// List of storage buffer writes
std::vector<Store> stores;
/// AddAccesss() adds the `expr -> access` map item to #accesses, and `expr`
/// AddAccess() adds the `expr -> access` map item to #accesses, and `expr`
/// to #expression_order.
void AddAccesss(ast::Expression* expr, StorageBufferAccess&& access) {
/// @param expr the expression that performs the access
/// @param access the access
void AddAccess(ast::Expression* expr, StorageBufferAccess&& access) {
TINT_ASSERT(access.type);
accesses.emplace(expr, std::move(access));
expression_order.emplace_back(expr);
}
@@ -395,6 +400,8 @@ struct State {
/// TakeAccess() removes the `node` item from #accesses (if it exists),
/// returning the StorageBufferAccess. If #accesses does not hold an item for
/// `node`, an invalid StorageBufferAccess is returned.
/// @param node the expression that performed an access
/// @return the StorageBufferAccess for the given expression
StorageBufferAccess TakeAccess(ast::Expression* node) {
auto lhs_it = accesses.find(node);
if (lhs_it == accesses.end()) {
@@ -408,24 +415,31 @@ struct State {
/// LoadFunc() returns a symbol to an intrinsic function that loads an element
/// of type `el_ty` from a storage buffer of type `buf_ty`. The function has
/// the signature: `fn load(buf : buf_ty, offset : u32) -> el_ty`
/// @param ctx the CloneContext
/// @param insert_after the user-declared type to insert the function after
/// @param buf_ty the storage buffer type
/// @param el_ty the storage buffer element type
/// @return the name of the function that performs the load
Symbol LoadFunc(CloneContext& ctx,
Cloneable* insert_after,
ast::NamedType* insert_after,
sem::Type* buf_ty,
sem::Type* el_ty) {
return utils::GetOrCreate(load_funcs, TypePair{buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(&ctx, buf_ty);
ast::VariableList params = {
// Note: The buffer parameter requires the kStorage StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
ctx.Clone(buf_ty), true, nullptr, ast::DecorationList{}),
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage, buf_ast_ty,
true, nullptr, ast::DecorationList{}),
ctx.dst->Param("offset", ctx.dst->ty.u32()),
};
ast::Function* func = nullptr;
if (auto* intrinsic = IntrinsicLoadFor(ctx.dst, el_ty)) {
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, ctx.Clone(el_ty), nullptr,
ctx.dst->Sym(), params, el_ast_ty, nullptr,
ast::DecorationList{intrinsic}, ast::DecorationList{});
} else {
ast::ExpressionList values;
@@ -444,7 +458,7 @@ struct State {
for (auto* member : str->Members()) {
auto* offset = ctx.dst->Add("offset", member->Offset());
Symbol load = LoadFunc(ctx, insert_after, buf_ty,
member->Declaration()->type()->UnwrapAll());
member->Type()->UnwrapAll());
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
} else if (auto* arr_ty = el_ty->As<sem::ArrayType>()) {
@@ -457,11 +471,12 @@ struct State {
values.emplace_back(ctx.dst->Call(load, "buffer", offset));
}
}
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
func = ctx.dst->create<ast::Function>(
ctx.dst->Sym(), params, ctx.Clone(el_ty),
ctx.dst->Sym(), params, el_ast_ty,
ctx.dst->Block(
ctx.dst->Return(ctx.dst->create<ast::TypeConstructorExpression>(
ctx.Clone(el_ty), values))),
CreateASTTypeFor(&ctx, el_ty), values))),
ast::DecorationList{}, ast::DecorationList{});
}
InsertGlobal(ctx, insert_after, func);
@@ -472,19 +487,26 @@ struct State {
/// StoreFunc() returns a symbol to an intrinsic function that stores an
/// element of type `el_ty` to a storage buffer of type `buf_ty`. The function
/// has the signature: `fn store(buf : buf_ty, offset : u32, value : el_ty)`
/// @param ctx the CloneContext
/// @param insert_after the user-declared type to insert the function after
/// @param buf_ty the storage buffer type
/// @param el_ty the storage buffer element type
/// @return the name of the function that performs the store
Symbol StoreFunc(CloneContext& ctx,
Cloneable* insert_after,
ast::NamedType* insert_after,
sem::Type* buf_ty,
sem::Type* el_ty) {
return utils::GetOrCreate(store_funcs, TypePair{buf_ty, el_ty}, [&] {
auto* buf_ast_ty = CreateASTTypeFor(&ctx, buf_ty);
auto* el_ast_ty = CreateASTTypeFor(&ctx, el_ty);
ast::VariableList params{
// Note: The buffer parameter requires the kStorage StorageClass in
// order for HLSL to emit this as a ByteAddressBuffer.
ctx.dst->create<ast::Variable>(
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage,
ctx.Clone(buf_ty), true, nullptr, ast::DecorationList{}),
ctx.dst->Sym("buffer"), ast::StorageClass::kStorage, buf_ast_ty,
true, nullptr, ast::DecorationList{}),
ctx.dst->Param("offset", ctx.dst->ty.u32()),
ctx.dst->Param("value", ctx.Clone(el_ty)),
ctx.dst->Param("value", el_ast_ty),
};
ast::Function* func = nullptr;
if (auto* intrinsic = IntrinsicStoreFor(ctx.dst, el_ty)) {
@@ -512,9 +534,8 @@ struct State {
auto* offset = ctx.dst->Add("offset", member->Offset());
auto* access = ctx.dst->MemberAccessor(
"value", ctx.Clone(member->Declaration()->symbol()));
Symbol store =
StoreFunc(ctx, insert_after, buf_ty,
member->Declaration()->type()->UnwrapAll());
Symbol store = StoreFunc(ctx, insert_after, buf_ty,
member->Type()->UnwrapAll());
auto* call = ctx.dst->Call(store, "buffer", offset, access);
body.emplace_back(ctx.dst->create<ast::CallStatement>(call));
}
@@ -541,8 +562,6 @@ struct State {
}
};
} // namespace
DecomposeStorageAccess::Intrinsic::Intrinsic(ProgramID program_id, Type ty)
: Base(program_id), type(ty) {}
DecomposeStorageAccess::Intrinsic::~Intrinsic() = default;
@@ -630,11 +649,11 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
if (auto* var = sem.Get<sem::VariableUser>(ident)) {
if (var->Variable()->StorageClass() == ast::StorageClass::kStorage) {
// Variable to a storage buffer
state.AddAccesss(ident, {
var,
ToOffset(0u),
var->Type()->UnwrapAll(),
});
state.AddAccess(ident, {
var,
ToOffset(0u),
var->Type()->UnwrapAll(),
});
}
}
continue;
@@ -649,7 +668,7 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
auto* vec_ty = access.type->As<sem::Vector>();
auto offset =
Mul(ScalarSize(vec_ty->type()), swizzle->Indices()[0]);
state.AddAccesss(
state.AddAccess(
accessor, {
access.var,
Add(std::move(access.offset), std::move(offset)),
@@ -663,12 +682,12 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
auto* member =
sem.Get(str_ty)->FindMember(accessor->member()->symbol());
auto offset = member->Offset();
state.AddAccesss(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),
member->Declaration()->type()->UnwrapAll(),
});
state.AddAccess(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),
member->Type()->UnwrapAll(),
});
}
}
continue;
@@ -680,34 +699,34 @@ Output DecomposeStorageAccess::Run(const Program* in, const DataMap&) {
if (auto* arr_ty = access.type->As<sem::ArrayType>()) {
auto stride = sem.Get(arr_ty)->Stride();
auto offset = Mul(stride, accessor->idx_expr());
state.AddAccesss(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),
arr_ty->type()->UnwrapAll(),
});
state.AddAccess(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),
arr_ty->type()->UnwrapAll(),
});
continue;
}
if (auto* vec_ty = access.type->As<sem::Vector>()) {
auto offset = Mul(ScalarSize(vec_ty->type()), accessor->idx_expr());
state.AddAccesss(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),
vec_ty->type()->UnwrapAll(),
});
state.AddAccess(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),
vec_ty->type()->UnwrapAll(),
});
continue;
}
if (auto* mat_ty = access.type->As<sem::Matrix>()) {
auto offset = Mul(MatrixColumnStride(mat_ty), accessor->idx_expr());
auto* vec_ty = ctx.dst->create<sem::Vector>(
ctx.Clone(mat_ty->type()->UnwrapAll()), mat_ty->rows());
state.AddAccesss(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),
vec_ty,
});
state.AddAccess(accessor,
{
access.var,
Add(std::move(access.offset), std::move(offset)),
vec_ty,
});
continue;
}
}

View File

@@ -95,6 +95,8 @@ class DecomposeStorageAccess : public Transform {
/// @param data optional extra transform-specific data
/// @returns the transformation result
Output Run(const Program* program, const DataMap& data = {}) override;
struct State;
};
} // namespace transform

View File

@@ -100,7 +100,7 @@ void Hlsl::PromoteInitializersToConstVar(CloneContext& ctx) const {
// Create a new symbol for the constant
auto dst_symbol = ctx.dst->Sym();
// Clone the type
auto* dst_ty = ctx.Clone(src_ty);
auto* dst_ty = ctx.Clone(src_init->type());
// Clone the initializer
auto* dst_init = ctx.Clone(src_init);
// Construct the constant that holds the hoisted initializer

View File

@@ -69,7 +69,7 @@ Output SingleEntryPoint::Run(const Program* in, const DataMap& data) {
// Clone any module-scope variables, types, and functions that are statically
// referenced by the target entry point.
for (auto* decl : in->AST().GlobalDeclarations()) {
if (auto* ty = decl->As<sem::Type>()) {
if (auto* ty = decl->As<ast::NamedType>()) {
// TODO(jrprice): Strip unused types.
out.AST().AddConstructedType(ctx.Clone(ty));
} else if (auto* var = decl->As<ast::Variable>()) {

View File

@@ -23,6 +23,7 @@
#include "src/program_builder.h"
#include "src/sem/function.h"
#include "src/sem/statement.h"
#include "src/sem/struct.h"
#include "src/sem/variable.h"
namespace tint {
@@ -110,11 +111,11 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
// ```
// Strip entry point IO decorations from struct declarations.
for (auto ty : ctx.src->AST().ConstructedTypes()) {
if (auto* struct_ty = ty->As<sem::StructType>()) {
for (auto* ty : ctx.src->AST().ConstructedTypes()) {
if (auto* struct_ty = ty->As<ast::Struct>()) {
// Build new list of struct members without entry point IO decorations.
ast::StructMemberList new_struct_members;
for (auto* member : struct_ty->impl()->members()) {
for (auto* member : struct_ty->members()) {
ast::DecorationList new_decorations = RemoveDecorations(
&ctx, member->decorations(), [](const ast::Decoration* deco) {
return deco
@@ -126,52 +127,52 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
}
// Redeclare the struct.
auto new_struct_name = ctx.Clone(struct_ty->impl()->name());
auto new_struct_name = ctx.Clone(struct_ty->name());
auto* new_struct =
ctx.dst->create<sem::StructType>(ctx.dst->create<ast::Struct>(
new_struct_name, new_struct_members,
ctx.Clone(struct_ty->impl()->decorations())));
ctx.dst->create<ast::Struct>(new_struct_name, new_struct_members,
ctx.Clone(struct_ty->decorations()));
ctx.Replace(struct_ty, new_struct);
}
}
for (auto* func : ctx.src->AST().Functions()) {
if (!func->IsEntryPoint()) {
for (auto* func_ast : ctx.src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) {
continue;
}
auto* func = ctx.src->Sem().Get(func_ast);
for (auto* param : func->params()) {
for (auto* param : func->Parameters()) {
Symbol new_var = HoistToInputVariables(
ctx, func, ctx.src->Sem().Get(param)->Type(),
ctx.src->Sem().Get(param)->DeclaredType(), param->decorations());
ctx, func_ast, param->Type(), param->Declaration()->type(),
param->Declaration()->decorations());
// Replace all uses of the function parameter with the new variable.
for (auto* user : ctx.src->Sem().Get(param)->Users()) {
for (auto* user : param->Users()) {
ctx.Replace<ast::Expression>(user->Declaration(),
ctx.dst->Expr(new_var));
}
}
if (!func->return_type()->Is<sem::Void>()) {
if (!func->ReturnType()->Is<sem::Void>()) {
ast::StatementList stores;
auto store_value_symbol = ctx.dst->Sym();
HoistToOutputVariables(
ctx, func, func->return_type(), func->return_type(),
func->return_type_decorations(), {}, store_value_symbol, stores);
ctx, func_ast, func->ReturnType(), func_ast->return_type(),
func_ast->return_type_decorations(), {}, store_value_symbol, stores);
// Create a function that writes a return value to all output variables.
auto* store_value =
ctx.dst->Param(store_value_symbol, ctx.Clone(func->return_type()));
auto* store_value = ctx.dst->Param(store_value_symbol,
ctx.Clone(func_ast->return_type()));
auto return_func_symbol = ctx.dst->Sym();
auto* return_func = ctx.dst->create<ast::Function>(
return_func_symbol, ast::VariableList{store_value},
ctx.dst->ty.void_(), ctx.dst->create<ast::BlockStatement>(stores),
ast::DecorationList{}, ast::DecorationList{});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, return_func);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast,
return_func);
// Replace all return statements with calls to the output function.
auto* sem_func = ctx.src->Sem().Get(func);
for (auto* ret : sem_func->ReturnStatements()) {
for (auto* ret : func->ReturnStatements()) {
auto* ret_sem = ctx.src->Sem().Get(ret);
auto* call = ctx.dst->Call(return_func_symbol, ctx.Clone(ret->value()));
ctx.InsertBefore(ret_sem->Block()->statements(), ret,
@@ -181,11 +182,13 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
}
// Rewrite the function header to remove the parameters and return value.
auto name = ctx.Clone(func_ast->symbol());
auto* body = ctx.Clone(func_ast->body());
auto decos = ctx.Clone(func_ast->decorations());
auto* new_func = ctx.dst->create<ast::Function>(
func->source(), ctx.Clone(func->symbol()), ast::VariableList{},
ctx.dst->ty.void_(), ctx.Clone(func->body()),
ctx.Clone(func->decorations()), ast::DecorationList{});
ctx.Replace(func, new_func);
func_ast->source(), name, ast::VariableList{}, ctx.dst->ty.void_(),
body, decos, ast::DecorationList{});
ctx.Replace(func_ast, new_func);
}
}
@@ -253,7 +256,7 @@ Symbol Spirv::HoistToInputVariables(
CloneContext& ctx,
const ast::Function* func,
sem::Type* ty,
sem::Type* declared_ty,
ast::Type* declared_ty,
const ast::DecorationList& decorations) const {
if (!ty->Is<sem::StructType>()) {
// Base case: create a global variable and return.
@@ -273,9 +276,10 @@ Symbol Spirv::HoistToInputVariables(
// Recurse into struct members and build the initializer list.
std::vector<Symbol> init_value_names;
auto* struct_ty = ty->As<sem::StructType>();
for (auto* member : struct_ty->impl()->members()) {
for (auto* member : ctx.src->Sem().Get(struct_ty)->Members()) {
auto member_var = HoistToInputVariables(
ctx, func, member->type(), member->type(), member->decorations());
ctx, func, member->Type(), member->Declaration()->type(),
member->Declaration()->decorations());
init_value_names.emplace_back(member_var);
}
@@ -302,7 +306,7 @@ Symbol Spirv::HoistToInputVariables(
void Spirv::HoistToOutputVariables(CloneContext& ctx,
const ast::Function* func,
sem::Type* ty,
sem::Type* declared_ty,
ast::Type* declared_ty,
const ast::DecorationList& decorations,
std::vector<Symbol> member_accesses,
Symbol store_value,
@@ -333,11 +337,12 @@ void Spirv::HoistToOutputVariables(CloneContext& ctx,
// Recurse into struct members.
auto* struct_ty = ty->As<sem::StructType>();
for (auto* member : struct_ty->impl()->members()) {
member_accesses.push_back(ctx.Clone(member->symbol()));
HoistToOutputVariables(ctx, func, member->type(), member->type(),
member->decorations(), member_accesses, store_value,
stores);
for (auto* member : ctx.src->Sem().Get(struct_ty)->Members()) {
member_accesses.push_back(ctx.Clone(member->Declaration()->symbol()));
HoistToOutputVariables(ctx, func, member->Type(),
member->Declaration()->type(),
member->Declaration()->decorations(),
member_accesses, store_value, stores);
member_accesses.pop_back();
}
}

View File

@@ -60,7 +60,7 @@ class Spirv : public Transform {
Symbol HoistToInputVariables(CloneContext& ctx,
const ast::Function* func,
sem::Type* ty,
sem::Type* declared_ty,
ast::Type* declared_ty,
const ast::DecorationList& decorations) const;
/// Recursively create module-scope output variables for `ty` and build a list
@@ -74,7 +74,7 @@ class Spirv : public Transform {
void HoistToOutputVariables(CloneContext& ctx,
const ast::Function* func,
sem::Type* ty,
sem::Type* declared_ty,
ast::Type* declared_ty,
const ast::DecorationList& decorations,
std::vector<Symbol> member_accesses,
Symbol store_value,

View File

@@ -49,7 +49,7 @@ ast::Function* Transform::CloneWithStatementsAtStart(
auto source = ctx->Clone(in->source());
auto symbol = ctx->Clone(in->symbol());
auto params = ctx->Clone(in->params());
auto return_type = ctx->Clone(in->return_type());
auto* return_type = ctx->Clone(in->return_type());
auto* body = ctx->dst->create<ast::BlockStatement>(
ctx->Clone(in->body()->source()), statements);
auto decos = ctx->Clone(in->decorations());

View File

@@ -184,7 +184,7 @@ struct State {
// identifier strings instead of pointers, so we don't need to update
// any other place in the AST.
auto name = ctx.Clone(v->symbol());
auto* replacement = ctx.dst->Var(name, ctx.Clone(v->declared_type()),
auto* replacement = ctx.dst->Var(name, ctx.Clone(v->type()),
ast::StorageClass::kPrivate);
location_to_expr[location] = [this, name]() {
return ctx.dst->Expr(name);
@@ -212,9 +212,9 @@ struct State {
{
ctx.dst->create<ast::StructBlockDecoration>(),
});
auto access =
ctx.dst->ty.access(ast::AccessControl::kReadOnly, struct_type);
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
auto access =
ctx.dst->ty.access(ast::AccessControl::kReadOnly, struct_type);
// The decorated variable with struct type
ctx.dst->Global(
GetVertexBufferName(i), access, ast::StorageClass::kStorage, nullptr,
@@ -369,7 +369,7 @@ struct State {
/// @param count how many elements the vector has
ast::Expression* AccessVec(uint32_t buffer,
uint32_t element_stride,
sem::Type* base_type,
ast::Type* base_type,
VertexFormat base_format,
uint32_t count) {
ast::ExpressionList expr_list;
@@ -381,7 +381,7 @@ struct State {
}
return ctx.dst->create<ast::TypeConstructorExpression>(
ctx.dst->create<sem::Vector>(base_type, count), std::move(expr_list));
ctx.dst->create<ast::Vector>(base_type, count), std::move(expr_list));
}
/// Process a non-struct entry point parameter.
@@ -394,7 +394,7 @@ struct State {
ast::GetDecoration<ast::LocationDecoration>(param->decorations())) {
// Create a function-scope variable to replace the parameter.
auto func_var_sym = ctx.Clone(param->symbol());
auto* func_var_type = ctx.Clone(param->declared_type());
auto* func_var_type = ctx.Clone(param->type());
auto* func_var = ctx.dst->Var(func_var_sym, func_var_type,
ast::StorageClass::kFunction);
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
@@ -428,18 +428,16 @@ struct State {
/// instance_index builtins.
/// @param func the entry point function
/// @param param the parameter to process
void ProcessStructParameter(ast::Function* func, ast::Variable* param) {
auto* struct_ty = param->declared_type()->As<sem::StructType>();
if (!struct_ty) {
TINT_ICE(ctx.dst->Diagnostics()) << "Invalid struct parameter";
}
/// @param struct_ty the structure type
void ProcessStructParameter(ast::Function* func,
ast::Variable* param,
ast::Struct* struct_ty) {
auto param_sym = ctx.Clone(param->symbol());
// Process the struct members.
bool has_locations = false;
ast::StructMemberList members_to_clone;
for (auto* member : struct_ty->impl()->members()) {
for (auto* member : struct_ty->members()) {
auto member_sym = ctx.Clone(member->symbol());
std::function<ast::Expression*()> member_expr = [this, param_sym,
member_sym]() {
@@ -472,7 +470,7 @@ struct State {
}
// Create a function-scope variable to replace the parameter.
auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->declared_type()),
auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type()),
ast::StorageClass::kFunction);
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->Decl(func_var));
@@ -482,7 +480,7 @@ struct State {
ast::StructMemberList new_members;
for (auto* member : members_to_clone) {
auto member_sym = ctx.Clone(member->symbol());
auto member_type = ctx.Clone(member->type());
auto* member_type = ctx.Clone(member->type());
auto member_decos = ctx.Clone(member->decorations());
new_members.push_back(
ctx.dst->Member(member_sym, member_type, std::move(member_decos)));
@@ -514,8 +512,8 @@ struct State {
// Process entry point parameters.
for (auto* param : func->params()) {
auto* sem = ctx.src->Sem().Get(param);
if (sem->Type()->Is<sem::StructType>()) {
ProcessStructParameter(func, param);
if (auto* str = sem->Type()->As<sem::StructType>()) {
ProcessStructParameter(func, param, str->impl());
} else {
ProcessNonStructParameter(func, param);
}
@@ -553,7 +551,7 @@ struct State {
// Rewrite the function header with the new parameters.
auto func_sym = ctx.Clone(func->symbol());
auto ret_type = ctx.Clone(func->return_type());
auto* ret_type = ctx.Clone(func->return_type());
auto* body = ctx.Clone(func->body());
auto decos = ctx.Clone(func->decorations());
auto ret_decos = ctx.Clone(func->return_type_decorations());