mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-16 00:17:03 +00:00
Make all ast and sem pointers const
And remove a whole load of const_cast hackery. Semantic nodes may contain internally mutable fields (although only ever modified during resolving), so these are always passed by `const` pointer. While all AST nodes are internally immutable, we have decided that pointers to AST nodes should also be marked `const`, for consistency. There's still a collection of const_cast calls in the Resolver. These will be fixed up in a later change. Bug: tint:745 Change-Id: I046309b8e586772605fc0fe6b2d27f28806d40ef Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/66606 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@chromium.org> Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
committed by
Tint LUCI CQ
parent
7d0fc07b20
commit
8648120bbe
@@ -64,7 +64,7 @@ void ArrayLengthFromUniform::Run(CloneContext& ctx,
|
||||
|
||||
// Get (or create, on first call) the uniform buffer that will receive the
|
||||
// size of each storage buffer in the module.
|
||||
ast::Variable* buffer_size_ubo = nullptr;
|
||||
const ast::Variable* buffer_size_ubo = nullptr;
|
||||
auto get_ubo = [&]() {
|
||||
if (!buffer_size_ubo) {
|
||||
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
|
||||
|
||||
@@ -89,7 +89,7 @@ class ArrayLengthFromUniform
|
||||
~Result() override;
|
||||
|
||||
/// True if the transform generated the buffer sizes UBO.
|
||||
bool const needs_buffer_sizes;
|
||||
const bool needs_buffer_sizes;
|
||||
};
|
||||
|
||||
protected:
|
||||
|
||||
@@ -129,7 +129,7 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) {
|
||||
return;
|
||||
}
|
||||
auto* ty = sem->Type()->UnwrapRef();
|
||||
ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
|
||||
const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
|
||||
auto* new_var = ctx.dst->create<ast::Variable>(
|
||||
ctx.Clone(var->source), ctx.Clone(var->symbol),
|
||||
var->declared_storage_class, ac, inner_ty, var->is_const,
|
||||
|
||||
@@ -54,14 +54,14 @@ class BindingRemapper : public Castable<BindingRemapper, Transform> {
|
||||
~Remappings() override;
|
||||
|
||||
/// A map of old binding point to new binding point
|
||||
BindingPoints const binding_points;
|
||||
const BindingPoints binding_points;
|
||||
|
||||
/// A map of old binding point to new access controls
|
||||
AccessControls const access_controls;
|
||||
const AccessControls access_controls;
|
||||
|
||||
/// If true, then validation will be disabled for binding point collisions
|
||||
/// generated by this transform
|
||||
bool const allow_collisions;
|
||||
const bool allow_collisions;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
|
||||
@@ -61,7 +61,7 @@ std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const {
|
||||
return "intrinsic_buffer_size";
|
||||
}
|
||||
|
||||
CalculateArrayLength::BufferSizeIntrinsic*
|
||||
const CalculateArrayLength::BufferSizeIntrinsic*
|
||||
CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
|
||||
return ctx->dst->ASTNodes().Create<CalculateArrayLength::BufferSizeIntrinsic>(
|
||||
ctx->dst->ID());
|
||||
|
||||
@@ -48,7 +48,7 @@ class CalculateArrayLength : public Castable<CalculateArrayLength, Transform> {
|
||||
/// Performs a deep clone of this object using the CloneContext `ctx`.
|
||||
/// @param ctx the clone context
|
||||
/// @return the newly cloned object
|
||||
BufferSizeIntrinsic* Clone(CloneContext* ctx) const override;
|
||||
const BufferSizeIntrinsic* Clone(CloneContext* ctx) const override;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
|
||||
@@ -82,11 +82,11 @@ struct CanonicalizeEntryPointIO::State {
|
||||
/// The name of the output value.
|
||||
std::string name;
|
||||
/// The type of the output value.
|
||||
ast::Type* type;
|
||||
const ast::Type* type;
|
||||
/// The shader IO attributes.
|
||||
ast::DecorationList attributes;
|
||||
/// The value itself.
|
||||
ast::Expression* value;
|
||||
const ast::Expression* value;
|
||||
};
|
||||
|
||||
/// The clone context.
|
||||
@@ -94,9 +94,9 @@ struct CanonicalizeEntryPointIO::State {
|
||||
/// The transform config.
|
||||
CanonicalizeEntryPointIO::Config const cfg;
|
||||
/// The entry point function (AST).
|
||||
ast::Function* func_ast;
|
||||
const ast::Function* func_ast;
|
||||
/// The entry point function (SEM).
|
||||
sem::Function const* func_sem;
|
||||
const sem::Function* func_sem;
|
||||
|
||||
/// The new entry point wrapper function's parameters.
|
||||
ast::VariableList wrapper_ep_parameters;
|
||||
@@ -121,7 +121,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||
/// @param function the entry point function
|
||||
State(CloneContext& context,
|
||||
const CanonicalizeEntryPointIO::Config& config,
|
||||
ast::Function* function)
|
||||
const ast::Function* function)
|
||||
: ctx(context),
|
||||
cfg(config),
|
||||
func_ast(function),
|
||||
@@ -154,9 +154,9 @@ struct CanonicalizeEntryPointIO::State {
|
||||
/// @param type the type of the shader input
|
||||
/// @param attributes the attributes to apply to the shader input
|
||||
/// @returns an expression which evaluates to the value of the shader input
|
||||
ast::Expression* AddInput(std::string name,
|
||||
sem::Type* type,
|
||||
ast::DecorationList attributes) {
|
||||
const ast::Expression* AddInput(std::string name,
|
||||
const sem::Type* type,
|
||||
ast::DecorationList attributes) {
|
||||
auto* ast_type = CreateASTTypeFor(ctx, type);
|
||||
if (cfg.shader_style == ShaderStyle::kSpirv) {
|
||||
// Vulkan requires that integer user-defined fragment inputs are
|
||||
@@ -175,7 +175,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||
|
||||
// Create the global variable and use its value for the shader input.
|
||||
auto symbol = ctx.dst->Symbols().New(name);
|
||||
ast::Expression* value = ctx.dst->Expr(symbol);
|
||||
const ast::Expression* value = ctx.dst->Expr(symbol);
|
||||
if (HasSampleMask(attributes)) {
|
||||
// Vulkan requires the type of a SampleMask builtin to be an array.
|
||||
// Declare it as array<u32, 1> and then load the first element.
|
||||
@@ -212,9 +212,9 @@ struct CanonicalizeEntryPointIO::State {
|
||||
/// @param attributes the attributes to apply to the shader output
|
||||
/// @param value the value of the shader output
|
||||
void AddOutput(std::string name,
|
||||
sem::Type* type,
|
||||
const sem::Type* type,
|
||||
ast::DecorationList attributes,
|
||||
ast::Expression* value) {
|
||||
const ast::Expression* value) {
|
||||
// Vulkan requires that integer user-defined vertex outputs are
|
||||
// always decorated with `Flat`.
|
||||
if (cfg.shader_style == ShaderStyle::kSpirv &&
|
||||
@@ -417,7 +417,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||
// Create the global variable and assign it the output value.
|
||||
auto name = ctx.dst->Symbols().New(outval.name);
|
||||
auto* type = outval.type;
|
||||
ast::Expression* lhs = ctx.dst->Expr(name);
|
||||
const ast::Expression* lhs = ctx.dst->Expr(name);
|
||||
if (HasSampleMask(attributes)) {
|
||||
// Vulkan requires the type of a SampleMask builtin to be an array.
|
||||
// Declare it as array<u32, 1> and then store to the first element.
|
||||
@@ -432,7 +432,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||
|
||||
// Recreate the original function without entry point attributes and call it.
|
||||
/// @returns the inner function call expression
|
||||
ast::CallExpression* CallInnerFunction() {
|
||||
const ast::CallExpression* CallInnerFunction() {
|
||||
// Add a suffix to the function name, as the wrapper function will take the
|
||||
// original entry point name.
|
||||
auto ep_name = ctx.src->Symbols().NameFor(func_ast->symbol);
|
||||
@@ -492,7 +492,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||
auto* call_inner = CallInnerFunction();
|
||||
|
||||
// Process the return type, and start building the wrapper function body.
|
||||
std::function<ast::Type*()> wrapper_ret_type = [&] {
|
||||
std::function<const ast::Type*()> wrapper_ret_type = [&] {
|
||||
return ctx.dst->ty.void_();
|
||||
};
|
||||
if (func_sem->ReturnType()->Is<sem::Void>()) {
|
||||
|
||||
@@ -110,14 +110,14 @@ class CanonicalizeEntryPointIO
|
||||
~Config() override;
|
||||
|
||||
/// The approach to use for emitting shader IO.
|
||||
ShaderStyle const shader_style;
|
||||
const ShaderStyle shader_style;
|
||||
|
||||
/// A fixed sample mask to combine into masks produced by fragment shaders.
|
||||
uint32_t const fixed_sample_mask;
|
||||
const uint32_t fixed_sample_mask;
|
||||
|
||||
/// Set to `true` to generate a pointsize builtin and have it set to 1.0
|
||||
/// from all vertex shaders in the module.
|
||||
bool const emit_vertex_point_size;
|
||||
const bool emit_vertex_point_size;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
|
||||
@@ -51,17 +51,17 @@ namespace {
|
||||
/// offsets for storage and uniform buffer accesses.
|
||||
struct Offset : Castable<Offset> {
|
||||
/// @returns builds and returns the ast::Expression in `ctx.dst`
|
||||
virtual ast::Expression* Build(CloneContext& ctx) const = 0;
|
||||
virtual const ast::Expression* Build(CloneContext& ctx) const = 0;
|
||||
};
|
||||
|
||||
/// OffsetExpr is an implementation of Offset that clones and casts the given
|
||||
/// expression to `u32`.
|
||||
struct OffsetExpr : Offset {
|
||||
ast::Expression* const expr = nullptr;
|
||||
const ast::Expression* const expr = nullptr;
|
||||
|
||||
explicit OffsetExpr(ast::Expression* e) : expr(e) {}
|
||||
explicit OffsetExpr(const ast::Expression* e) : expr(e) {}
|
||||
|
||||
ast::Expression* Build(CloneContext& ctx) const override {
|
||||
const ast::Expression* Build(CloneContext& ctx) const override {
|
||||
auto* type = ctx.src->Sem().Get(expr)->Type()->UnwrapRef();
|
||||
auto* res = ctx.Clone(expr);
|
||||
if (!type->Is<sem::U32>()) {
|
||||
@@ -78,7 +78,7 @@ struct OffsetLiteral : Castable<OffsetLiteral, Offset> {
|
||||
|
||||
explicit OffsetLiteral(uint32_t lit) : literal(lit) {}
|
||||
|
||||
ast::Expression* Build(CloneContext& ctx) const override {
|
||||
const ast::Expression* Build(CloneContext& ctx) const override {
|
||||
return ctx.dst->Expr(literal);
|
||||
}
|
||||
};
|
||||
@@ -90,7 +90,7 @@ struct OffsetBinOp : Offset {
|
||||
Offset const* lhs = nullptr;
|
||||
Offset const* rhs = nullptr;
|
||||
|
||||
ast::Expression* Build(CloneContext& ctx) const override {
|
||||
const ast::Expression* Build(CloneContext& ctx) const override {
|
||||
return ctx.dst->create<ast::BinaryExpression>(op, lhs->Build(ctx),
|
||||
rhs->Build(ctx));
|
||||
}
|
||||
@@ -304,9 +304,9 @@ struct DecomposeMemoryAccess::State {
|
||||
/// expressions chain the access.
|
||||
/// Subset of #expression_order, as expressions are not removed from
|
||||
/// #expression_order.
|
||||
std::unordered_map<ast::Expression*, BufferAccess> accesses;
|
||||
std::unordered_map<const ast::Expression*, BufferAccess> accesses;
|
||||
/// The visited order of AST expressions (superset of #accesses)
|
||||
std::vector<ast::Expression*> expression_order;
|
||||
std::vector<const ast::Expression*> expression_order;
|
||||
/// [buffer-type, element-type] -> load function name
|
||||
std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs;
|
||||
/// [buffer-type, element-type] -> store function name
|
||||
@@ -330,7 +330,7 @@ struct DecomposeMemoryAccess::State {
|
||||
|
||||
/// @param expr the expression to convert to an Offset
|
||||
/// @returns an Offset for the given ast::Expression
|
||||
const Offset* ToOffset(ast::Expression* expr) {
|
||||
const Offset* ToOffset(const ast::Expression* expr) {
|
||||
if (auto* scalar = expr->As<ast::ScalarConstructorExpression>()) {
|
||||
if (auto* u32 = scalar->literal->As<ast::UintLiteral>()) {
|
||||
return offsets_.Create<OffsetLiteral>(u32->value);
|
||||
@@ -415,7 +415,7 @@ struct DecomposeMemoryAccess::State {
|
||||
/// to #expression_order.
|
||||
/// @param expr the expression that performs the access
|
||||
/// @param access the access
|
||||
void AddAccess(ast::Expression* expr, const BufferAccess& access) {
|
||||
void AddAccess(const ast::Expression* expr, const BufferAccess& access) {
|
||||
TINT_ASSERT(Transform, access.type);
|
||||
accesses.emplace(expr, access);
|
||||
expression_order.emplace_back(expr);
|
||||
@@ -426,7 +426,7 @@ struct DecomposeMemoryAccess::State {
|
||||
/// `node`, an invalid BufferAccess is returned.
|
||||
/// @param node the expression that performed an access
|
||||
/// @return the BufferAccess for the given expression
|
||||
BufferAccess TakeAccess(ast::Expression* node) {
|
||||
BufferAccess TakeAccess(const ast::Expression* node) {
|
||||
auto lhs_it = accesses.find(node);
|
||||
if (lhs_it == accesses.end()) {
|
||||
return {};
|
||||
@@ -793,7 +793,7 @@ std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
|
||||
const DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
|
||||
CloneContext* ctx) const {
|
||||
return ctx->dst->ASTNodes().Create<DecomposeMemoryAccess::Intrinsic>(
|
||||
ctx->dst->ID(), op, storage_class, type);
|
||||
|
||||
@@ -88,16 +88,16 @@ class DecomposeMemoryAccess
|
||||
/// Performs a deep clone of this object using the CloneContext `ctx`.
|
||||
/// @param ctx the clone context
|
||||
/// @return the newly cloned object
|
||||
Intrinsic* Clone(CloneContext* ctx) const override;
|
||||
const Intrinsic* Clone(CloneContext* ctx) const override;
|
||||
|
||||
/// The op of the intrinsic
|
||||
Op const op;
|
||||
const Op op;
|
||||
|
||||
/// The storage class of the buffer this intrinsic operates on
|
||||
ast::StorageClass const storage_class;
|
||||
|
||||
/// The type of the intrinsic
|
||||
DataType const type;
|
||||
const DataType type;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
|
||||
@@ -37,11 +37,11 @@ struct MatrixInfo {
|
||||
/// The stride in bytes between columns of the matrix
|
||||
uint32_t stride = 0;
|
||||
/// The type of the matrix
|
||||
sem::Matrix const* matrix = nullptr;
|
||||
const sem::Matrix* matrix = nullptr;
|
||||
|
||||
/// @returns a new ast::Array that holds an vector column for each row of the
|
||||
/// matrix.
|
||||
ast::Array* array(ProgramBuilder* b) const {
|
||||
const ast::Array* array(ProgramBuilder* b) const {
|
||||
return b->ty.array(b->ty.vec<ProgramBuilder::f32>(matrix->rows()),
|
||||
matrix->columns(), stride);
|
||||
}
|
||||
@@ -126,7 +126,7 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
// Scan the program for all storage and uniform structure matrix members with
|
||||
// a custom stride attribute. Replace these matrices with an equivalent array,
|
||||
// and populate the `decomposed` map with the members that have been replaced.
|
||||
std::unordered_map<ast::StructMember*, MatrixInfo> decomposed;
|
||||
std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed;
|
||||
GatherCustomStrideMatrixMembers(
|
||||
ctx.src, [&](const sem::StructMember* member, sem::Matrix* matrix,
|
||||
uint32_t stride) {
|
||||
@@ -144,19 +144,19 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
// preserve these without calling conversion functions.
|
||||
// Example:
|
||||
// ssbo.mat[2] -> ssbo.mat[2]
|
||||
ctx.ReplaceAll(
|
||||
[&](ast::ArrayAccessorExpression* expr) -> ast::ArrayAccessorExpression* {
|
||||
if (auto* access =
|
||||
ctx.src->Sem().Get<sem::StructMemberAccess>(expr->array)) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it != decomposed.end()) {
|
||||
auto* obj = ctx.CloneWithoutTransform(expr->array);
|
||||
auto* idx = ctx.Clone(expr->index);
|
||||
return ctx.dst->IndexAccessor(obj, idx);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
ctx.ReplaceAll([&](const ast::ArrayAccessorExpression* expr)
|
||||
-> const ast::ArrayAccessorExpression* {
|
||||
if (auto* access =
|
||||
ctx.src->Sem().Get<sem::StructMemberAccess>(expr->array)) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it != decomposed.end()) {
|
||||
auto* obj = ctx.CloneWithoutTransform(expr->array);
|
||||
auto* idx = ctx.Clone(expr->index);
|
||||
return ctx.dst->IndexAccessor(obj, idx);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// For all struct member accesses to the matrix on the LHS of an assignment,
|
||||
// we need to convert the matrix to the array before assigning to the
|
||||
@@ -164,7 +164,8 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
// Example:
|
||||
// ssbo.mat = mat_to_arr(m)
|
||||
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
|
||||
ctx.ReplaceAll([&](ast::AssignmentStatement* stmt) -> ast::Statement* {
|
||||
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt)
|
||||
-> const ast::Statement* {
|
||||
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it == decomposed.end()) {
|
||||
@@ -206,42 +207,44 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
// matrix type. Example:
|
||||
// m = arr_to_mat(ssbo.mat)
|
||||
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
|
||||
ctx.ReplaceAll([&](ast::MemberAccessorExpression* expr) -> ast::Expression* {
|
||||
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it == decomposed.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
MatrixInfo info = it->second;
|
||||
auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
|
||||
auto name = ctx.dst->Symbols().New(
|
||||
"arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
|
||||
std::to_string(info.matrix->rows()) + "_stride_" +
|
||||
std::to_string(info.stride));
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
|
||||
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it == decomposed.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
MatrixInfo info = it->second;
|
||||
auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
|
||||
auto name = ctx.dst->Symbols().New(
|
||||
"arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
|
||||
std::to_string(info.matrix->rows()) + "_stride_" +
|
||||
std::to_string(info.stride));
|
||||
|
||||
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
|
||||
auto array = [&] { return info.array(ctx.dst); };
|
||||
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
|
||||
auto array = [&] { return info.array(ctx.dst); };
|
||||
|
||||
auto arr = ctx.dst->Sym("arr");
|
||||
ast::ExpressionList columns(info.matrix->columns());
|
||||
for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size()); i++) {
|
||||
columns[i] = ctx.dst->IndexAccessor(arr, i);
|
||||
auto arr = ctx.dst->Sym("arr");
|
||||
ast::ExpressionList columns(info.matrix->columns());
|
||||
for (uint32_t i = 0; i < static_cast<uint32_t>(columns.size());
|
||||
i++) {
|
||||
columns[i] = ctx.dst->IndexAccessor(arr, i);
|
||||
}
|
||||
ctx.dst->Func(
|
||||
name,
|
||||
{
|
||||
ctx.dst->Param(arr, array()),
|
||||
},
|
||||
matrix(),
|
||||
{
|
||||
ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
|
||||
}
|
||||
ctx.dst->Func(
|
||||
name,
|
||||
{
|
||||
ctx.dst->Param(arr, array()),
|
||||
},
|
||||
matrix(),
|
||||
{
|
||||
ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
|
||||
});
|
||||
return name;
|
||||
return nullptr;
|
||||
});
|
||||
return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ void FirstIndexOffset::Run(CloneContext& ctx,
|
||||
// parameters) or structure member accesses.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* var = node->As<ast::Variable>()) {
|
||||
for (ast::Decoration* dec : var->decorations) {
|
||||
for (auto* dec : var->decorations) {
|
||||
if (auto* builtin_dec = dec->As<ast::BuiltinDecoration>()) {
|
||||
ast::Builtin builtin = builtin_dec->builtin;
|
||||
if (builtin == ast::Builtin::kVertexIndex) {
|
||||
@@ -97,7 +97,7 @@ void FirstIndexOffset::Run(CloneContext& ctx,
|
||||
}
|
||||
}
|
||||
if (auto* member = node->As<ast::StructMember>()) {
|
||||
for (ast::Decoration* dec : member->decorations) {
|
||||
for (auto* dec : member->decorations) {
|
||||
if (auto* builtin_dec = dec->As<ast::BuiltinDecoration>()) {
|
||||
ast::Builtin builtin = builtin_dec->builtin;
|
||||
if (builtin == ast::Builtin::kVertexIndex) {
|
||||
@@ -147,28 +147,29 @@ void FirstIndexOffset::Run(CloneContext& ctx,
|
||||
});
|
||||
|
||||
// Fix up all references to the builtins with the offsets
|
||||
ctx.ReplaceAll([=, &ctx](ast::Expression* expr) -> ast::Expression* {
|
||||
if (auto* sem = ctx.src->Sem().Get(expr)) {
|
||||
if (auto* user = sem->As<sem::VariableUser>()) {
|
||||
auto it = builtin_vars.find(user->Variable());
|
||||
if (it != builtin_vars.end()) {
|
||||
return ctx.dst->Add(
|
||||
ctx.CloneWithoutTransform(expr),
|
||||
ctx.dst->MemberAccessor(buffer_name, it->second));
|
||||
ctx.ReplaceAll(
|
||||
[=, &ctx](const ast::Expression* expr) -> const ast::Expression* {
|
||||
if (auto* sem = ctx.src->Sem().Get(expr)) {
|
||||
if (auto* user = sem->As<sem::VariableUser>()) {
|
||||
auto it = builtin_vars.find(user->Variable());
|
||||
if (it != builtin_vars.end()) {
|
||||
return ctx.dst->Add(
|
||||
ctx.CloneWithoutTransform(expr),
|
||||
ctx.dst->MemberAccessor(buffer_name, it->second));
|
||||
}
|
||||
}
|
||||
if (auto* access = sem->As<sem::StructMemberAccess>()) {
|
||||
auto it = builtin_members.find(access->Member());
|
||||
if (it != builtin_members.end()) {
|
||||
return ctx.dst->Add(
|
||||
ctx.CloneWithoutTransform(expr),
|
||||
ctx.dst->MemberAccessor(buffer_name, it->second));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (auto* access = sem->As<sem::StructMemberAccess>()) {
|
||||
auto it = builtin_members.find(access->Member());
|
||||
if (it != builtin_members.end()) {
|
||||
return ctx.dst->Add(
|
||||
ctx.CloneWithoutTransform(expr),
|
||||
ctx.dst->MemberAccessor(buffer_name, it->second));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Not interested in this experssion. Just clone.
|
||||
return nullptr;
|
||||
});
|
||||
// Not interested in this experssion. Just clone.
|
||||
return nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
|
||||
@@ -98,13 +98,13 @@ class FirstIndexOffset : public Castable<FirstIndexOffset, Transform> {
|
||||
~Data() override;
|
||||
|
||||
/// True if the shader uses vertex_index
|
||||
bool const has_vertex_index;
|
||||
const bool has_vertex_index;
|
||||
/// True if the shader uses instance_index
|
||||
bool const has_instance_index;
|
||||
const bool has_instance_index;
|
||||
/// Offset of first vertex into constant buffer
|
||||
uint32_t const first_vertex_offset;
|
||||
const uint32_t first_vertex_offset;
|
||||
/// Offset of first instance into constant buffer
|
||||
uint32_t const first_instance_offset;
|
||||
const uint32_t first_instance_offset;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
|
||||
@@ -31,7 +31,7 @@ FoldConstants::FoldConstants() = default;
|
||||
FoldConstants::~FoldConstants() = default;
|
||||
|
||||
void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
ctx.ReplaceAll([&](ast::Expression* expr) -> ast::Expression* {
|
||||
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
|
||||
auto* sem = ctx.src->Sem().Get(expr);
|
||||
if (!sem) {
|
||||
return nullptr;
|
||||
|
||||
@@ -27,7 +27,7 @@ namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
ast::VariableDeclStatement* AsTrivialLetDecl(ast::Statement* stmt) {
|
||||
const ast::VariableDeclStatement* AsTrivialLetDecl(const ast::Statement* stmt) {
|
||||
auto* var_decl = stmt->As<ast::VariableDeclStatement>();
|
||||
if (!var_decl) {
|
||||
return nullptr;
|
||||
|
||||
@@ -26,37 +26,39 @@ ForLoopToLoop::ForLoopToLoop() = default;
|
||||
ForLoopToLoop::~ForLoopToLoop() = default;
|
||||
|
||||
void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
ctx.ReplaceAll([&](ast::ForLoopStatement* for_loop) -> ast::Statement* {
|
||||
ast::StatementList stmts;
|
||||
if (auto* cond = for_loop->condition) {
|
||||
// !condition
|
||||
auto* not_cond = ctx.dst->create<ast::UnaryOpExpression>(
|
||||
ast::UnaryOp::kNot, ctx.Clone(cond));
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {
|
||||
ast::StatementList stmts;
|
||||
if (auto* cond = for_loop->condition) {
|
||||
// !condition
|
||||
auto* not_cond = ctx.dst->create<ast::UnaryOpExpression>(
|
||||
ast::UnaryOp::kNot, ctx.Clone(cond));
|
||||
|
||||
// { break; }
|
||||
auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
|
||||
// { break; }
|
||||
auto* break_body =
|
||||
ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
|
||||
|
||||
// if (!condition) { break; }
|
||||
stmts.emplace_back(ctx.dst->If(not_cond, break_body));
|
||||
}
|
||||
for (auto* stmt : for_loop->body->statements) {
|
||||
stmts.emplace_back(ctx.Clone(stmt));
|
||||
}
|
||||
// if (!condition) { break; }
|
||||
stmts.emplace_back(ctx.dst->If(not_cond, break_body));
|
||||
}
|
||||
for (auto* stmt : for_loop->body->statements) {
|
||||
stmts.emplace_back(ctx.Clone(stmt));
|
||||
}
|
||||
|
||||
ast::BlockStatement* continuing = nullptr;
|
||||
if (auto* cont = for_loop->continuing) {
|
||||
continuing = ctx.dst->Block(ctx.Clone(cont));
|
||||
}
|
||||
const ast::BlockStatement* continuing = nullptr;
|
||||
if (auto* cont = for_loop->continuing) {
|
||||
continuing = ctx.dst->Block(ctx.Clone(cont));
|
||||
}
|
||||
|
||||
auto* body = ctx.dst->Block(stmts);
|
||||
auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
|
||||
auto* body = ctx.dst->Block(stmts);
|
||||
auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
|
||||
|
||||
if (auto* init = for_loop->initializer) {
|
||||
return ctx.dst->Block(ctx.Clone(init), loop);
|
||||
}
|
||||
if (auto* init = for_loop->initializer) {
|
||||
return ctx.dst->Block(ctx.Clone(init), loop);
|
||||
}
|
||||
|
||||
return loop;
|
||||
});
|
||||
return loop;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ namespace {
|
||||
/// expression
|
||||
template <typename F>
|
||||
void CollectSavedArrayIndices(const Program* program,
|
||||
ast::Expression* expr,
|
||||
const ast::Expression* expr,
|
||||
F&& cb) {
|
||||
if (auto* a = expr->As<ast::ArrayAccessorExpression>()) {
|
||||
CollectSavedArrayIndices(program, a->array, cb);
|
||||
@@ -95,7 +95,7 @@ void InlinePointerLets::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
// * Sub-expressions inside the pointer-typed `let` initializer expression
|
||||
// that have been hoisted to a saved variable are replaced with the saved
|
||||
// variable identifier.
|
||||
ctx.ReplaceAll([&](ast::Expression* expr) -> ast::Expression* {
|
||||
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
|
||||
if (current_ptr_let) {
|
||||
// We're currently processing the initializer expression of a
|
||||
// pointer-typed `let` declaration. Look to see if we need to swap this
|
||||
@@ -150,7 +150,7 @@ void InlinePointerLets::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
// to be hoist to temporary "saved" variables.
|
||||
CollectSavedArrayIndices(
|
||||
ctx.src, var->Declaration()->constructor,
|
||||
[&](ast::Expression* idx_expr) {
|
||||
[&](const ast::Expression* idx_expr) {
|
||||
// We have a sub-expression that needs to be saved.
|
||||
// Create a new variable
|
||||
auto saved_name = ctx.dst->Symbols().New(
|
||||
|
||||
@@ -29,7 +29,7 @@ namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
bool IsBlockWithSingleBreak(ast::BlockStatement* block) {
|
||||
bool IsBlockWithSingleBreak(const ast::BlockStatement* block) {
|
||||
if (block->statements.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
@@ -37,8 +37,8 @@ bool IsBlockWithSingleBreak(ast::BlockStatement* block) {
|
||||
}
|
||||
|
||||
bool IsVarUsedByStmt(const sem::Info& sem,
|
||||
ast::Variable* var,
|
||||
ast::Statement* stmt) {
|
||||
const ast::Variable* var,
|
||||
const ast::Statement* stmt) {
|
||||
auto* var_sem = sem.Get(var);
|
||||
for (auto* user : var_sem->Users()) {
|
||||
if (auto* s = user->Stmt()) {
|
||||
@@ -57,7 +57,7 @@ LoopToForLoop::LoopToForLoop() = default;
|
||||
LoopToForLoop::~LoopToForLoop() = default;
|
||||
|
||||
void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
ctx.ReplaceAll([&](ast::LoopStatement* loop) -> ast::Statement* {
|
||||
ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* {
|
||||
// For loop condition is taken from the first statement in the loop.
|
||||
// This requires an if-statement with either:
|
||||
// * A true block with no else statements, and the true block contains a
|
||||
@@ -90,7 +90,7 @@ void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
|
||||
// The continuing block must be empty or contain a single, assignment or
|
||||
// function call statement.
|
||||
ast::Statement* continuing = nullptr;
|
||||
const ast::Statement* continuing = nullptr;
|
||||
if (auto* loop_cont = loop->continuing) {
|
||||
if (loop_cont->statements.size() != 1) {
|
||||
return nullptr;
|
||||
|
||||
@@ -77,7 +77,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||
// Clone the struct and add it to the global declaration list.
|
||||
// Remove the old declaration.
|
||||
auto* ast_str = str->Declaration();
|
||||
ctx.dst->AST().AddTypeDecl(ctx.Clone(const_cast<ast::Struct*>(ast_str)));
|
||||
ctx.dst->AST().AddTypeDecl(ctx.Clone(ast_str));
|
||||
ctx.Remove(ctx.src->AST().GlobalDeclarations(), ast_str);
|
||||
} else if (auto* arr = ty->As<sem::Array>()) {
|
||||
CloneStructTypes(arr->ElemType());
|
||||
@@ -90,7 +90,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||
using CallList = std::vector<const ast::CallExpression*>;
|
||||
std::unordered_map<const ast::Function*, CallList> calls_to_replace;
|
||||
|
||||
std::vector<ast::Function*> functions_to_process;
|
||||
std::vector<const ast::Function*> functions_to_process;
|
||||
|
||||
// Build a list of functions that transitively reference any private or
|
||||
// workgroup variables, or texture/sampler variables.
|
||||
@@ -123,7 +123,8 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||
// rules when this expression is passed to a function.
|
||||
// TODO(jrprice): We should add support for bidirectional SEM tree traversal
|
||||
// so that we can do this on the fly instead.
|
||||
std::unordered_map<ast::IdentifierExpression*, ast::UnaryOpExpression*>
|
||||
std::unordered_map<const ast::IdentifierExpression*,
|
||||
const ast::UnaryOpExpression*>
|
||||
ident_to_address_of;
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
auto* address_of = node->As<ast::UnaryOpExpression>();
|
||||
@@ -248,7 +249,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||
// For non-entry points, dereference non-handle pointer parameters.
|
||||
for (auto* user : var->Users()) {
|
||||
if (user->Stmt()->Function() == func_ast) {
|
||||
ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
|
||||
const ast::Expression* expr = ctx.dst->Expr(new_var_symbol);
|
||||
if (is_pointer) {
|
||||
// If this identifier is used by an address-of operator, just
|
||||
// remove the address-of instead of adding a deref, since we
|
||||
@@ -301,7 +302,8 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||
target_var->StorageClass() == ast::StorageClass::kWorkgroup ||
|
||||
target_var->StorageClass() ==
|
||||
ast::StorageClass::kUniformConstant) {
|
||||
ast::Expression* arg = ctx.dst->Expr(var_to_symbol[target_var]);
|
||||
const ast::Expression* arg =
|
||||
ctx.dst->Expr(var_to_symbol[target_var]);
|
||||
if (is_entry_point && !is_handle && !is_workgroup_matrix) {
|
||||
arg = ctx.dst->AddressOf(arg);
|
||||
}
|
||||
|
||||
@@ -116,7 +116,7 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx,
|
||||
|
||||
// Get (or create, on first call) the uniform buffer that will receive the
|
||||
// number of workgroups.
|
||||
ast::Variable* num_workgroups_ubo = nullptr;
|
||||
const ast::Variable* num_workgroups_ubo = nullptr;
|
||||
auto get_ubo = [&]() {
|
||||
if (!num_workgroups_ubo) {
|
||||
auto* num_workgroups_struct = ctx.dst->Structure(
|
||||
|
||||
@@ -28,7 +28,7 @@ namespace tint {
|
||||
namespace transform {
|
||||
namespace {
|
||||
|
||||
using ArrayBuilder = std::function<ast::Array*()>;
|
||||
using ArrayBuilder = std::function<const ast::Array*()>;
|
||||
|
||||
/// PadArray returns a function that constructs a new array in `ctx.dst` with
|
||||
/// the element type padded to account for the explicit stride. PadArray will
|
||||
@@ -55,7 +55,7 @@ ArrayBuilder PadArray(
|
||||
auto name = ctx.dst->Symbols().New("tint_padded_array_element");
|
||||
|
||||
// Examine the element type. Is it also an array?
|
||||
ast::Type* el_ty = nullptr;
|
||||
const ast::Type* el_ty = nullptr;
|
||||
if (auto* el_array = array->ElemType()->As<sem::Array>()) {
|
||||
// Array of array - call PadArray() on the element type
|
||||
if (auto p =
|
||||
@@ -104,7 +104,7 @@ void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
};
|
||||
|
||||
// Replace all array types with their corresponding padded array type
|
||||
ctx.ReplaceAll([&](ast::Type* ast_type) -> ast::Type* {
|
||||
ctx.ReplaceAll([&](const ast::Type* ast_type) -> const ast::Type* {
|
||||
auto* type = ctx.src->TypeOf(ast_type);
|
||||
if (auto* array = type->UnwrapRef()->As<sem::Array>()) {
|
||||
if (auto p = pad(array)) {
|
||||
@@ -115,23 +115,24 @@ void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
});
|
||||
|
||||
// Fix up array accessors so `a[1]` becomes `a[1].el`
|
||||
ctx.ReplaceAll(
|
||||
[&](ast::ArrayAccessorExpression* accessor) -> ast::Expression* {
|
||||
if (auto* array = tint::As<sem::Array>(
|
||||
sem.Get(accessor->array)->Type()->UnwrapRef())) {
|
||||
if (pad(array)) {
|
||||
// Array element is wrapped in a structure. Emit a member accessor
|
||||
// to get to the actual array element.
|
||||
auto* idx = ctx.CloneWithoutTransform(accessor);
|
||||
return ctx.dst->MemberAccessor(idx, "el");
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
ctx.ReplaceAll([&](const ast::ArrayAccessorExpression* accessor)
|
||||
-> const ast::Expression* {
|
||||
if (auto* array = tint::As<sem::Array>(
|
||||
sem.Get(accessor->array)->Type()->UnwrapRef())) {
|
||||
if (pad(array)) {
|
||||
// Array element is wrapped in a structure. Emit a member accessor
|
||||
// to get to the actual array element.
|
||||
auto* idx = ctx.CloneWithoutTransform(accessor);
|
||||
return ctx.dst->MemberAccessor(idx, "el");
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
// Fix up array constructors so `A(1,2)` becomes
|
||||
// `A(padded(1), padded(2))`
|
||||
ctx.ReplaceAll([&](ast::TypeConstructorExpression* ctor) -> ast::Expression* {
|
||||
ctx.ReplaceAll([&](const ast::TypeConstructorExpression* ctor)
|
||||
-> const ast::Expression* {
|
||||
if (auto* array =
|
||||
tint::As<sem::Array>(sem.Get(ctor)->Type()->UnwrapRef())) {
|
||||
if (auto p = pad(array)) {
|
||||
|
||||
@@ -1131,7 +1131,7 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) {
|
||||
|
||||
// Swizzles, intrinsic calls and builtin structure members need to keep their
|
||||
// symbols preserved.
|
||||
std::unordered_set<ast::IdentifierExpression*> preserve;
|
||||
std::unordered_set<const ast::IdentifierExpression*> preserve;
|
||||
for (auto* node : in->ASTNodes().Objects()) {
|
||||
if (auto* member = node->As<ast::MemberAccessorExpression>()) {
|
||||
auto* sem = in->Sem().Get(member);
|
||||
@@ -1213,17 +1213,17 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) {
|
||||
return sym_out;
|
||||
});
|
||||
|
||||
ctx.ReplaceAll(
|
||||
[&](ast::IdentifierExpression* ident) -> ast::IdentifierExpression* {
|
||||
if (preserve.count(ident)) {
|
||||
auto sym_in = ident->symbol;
|
||||
auto str = in->Symbols().NameFor(sym_in);
|
||||
auto sym_out = out.Symbols().Register(str);
|
||||
return ctx.dst->create<ast::IdentifierExpression>(
|
||||
ctx.Clone(ident->source), sym_out);
|
||||
}
|
||||
return nullptr; // Clone ident. Uses the symbol remapping above.
|
||||
});
|
||||
ctx.ReplaceAll([&](const ast::IdentifierExpression* ident)
|
||||
-> const ast::IdentifierExpression* {
|
||||
if (preserve.count(ident)) {
|
||||
auto sym_in = ident->symbol;
|
||||
auto str = in->Symbols().NameFor(sym_in);
|
||||
auto sym_out = out.Symbols().Register(str);
|
||||
return ctx.dst->create<ast::IdentifierExpression>(
|
||||
ctx.Clone(ident->source), sym_out);
|
||||
}
|
||||
return nullptr; // Clone ident. Uses the symbol remapping above.
|
||||
});
|
||||
ctx.Clone();
|
||||
|
||||
return Output(Program(std::move(out)),
|
||||
|
||||
@@ -43,7 +43,7 @@ class Renamer : public Castable<Renamer, Transform> {
|
||||
~Data() override;
|
||||
|
||||
/// A map of old symbol name to new symbol name
|
||||
Remappings const remappings;
|
||||
const Remappings remappings;
|
||||
};
|
||||
|
||||
/// Target is an enumerator of rename targets that can be used
|
||||
|
||||
@@ -41,16 +41,19 @@ struct Robustness::State {
|
||||
|
||||
/// Applies the transformation state to `ctx`.
|
||||
void Transform() {
|
||||
ctx.ReplaceAll([&](const ast::ArrayAccessorExpression* expr) {
|
||||
return Transform(expr);
|
||||
});
|
||||
ctx.ReplaceAll(
|
||||
[&](ast::ArrayAccessorExpression* expr) { return Transform(expr); });
|
||||
ctx.ReplaceAll([&](ast::CallExpression* expr) { return Transform(expr); });
|
||||
[&](const ast::CallExpression* expr) { return Transform(expr); });
|
||||
}
|
||||
|
||||
/// Apply bounds clamping to array, vector and matrix indexing
|
||||
/// @param expr the array, vector or matrix index expression
|
||||
/// @return the clamped replacement expression, or nullptr if `expr` should be
|
||||
/// cloned without changes.
|
||||
ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr) {
|
||||
const ast::ArrayAccessorExpression* Transform(
|
||||
const ast::ArrayAccessorExpression* expr) {
|
||||
auto* ret_type = ctx.src->Sem().Get(expr->array)->Type();
|
||||
|
||||
auto* ref = ret_type->As<sem::Reference>();
|
||||
@@ -64,7 +67,7 @@ struct Robustness::State {
|
||||
using u32 = ProgramBuilder::u32;
|
||||
|
||||
struct Value {
|
||||
ast::Expression* expr = nullptr; // If null, then is a constant
|
||||
const ast::Expression* expr = nullptr; // If null, then is a constant
|
||||
union {
|
||||
uint32_t u32 = 0; // use if is_signed == false
|
||||
int32_t i32; // use if is_signed == true
|
||||
@@ -208,7 +211,7 @@ struct Robustness::State {
|
||||
/// @param expr the intrinsic call expression
|
||||
/// @return the clamped replacement call expression, or nullptr if `expr`
|
||||
/// should be cloned without changes.
|
||||
ast::CallExpression* Transform(ast::CallExpression* expr) {
|
||||
const ast::CallExpression* Transform(const ast::CallExpression* expr) {
|
||||
auto* call = ctx.src->Sem().Get(expr);
|
||||
auto* call_target = call->Target();
|
||||
auto* intrinsic = call_target->As<sem::Intrinsic>();
|
||||
@@ -235,7 +238,7 @@ struct Robustness::State {
|
||||
// to clamp both usages.
|
||||
// TODO(bclayton): We probably want to place this into a let so that the
|
||||
// calculation can be reused. This is fiddly to get right.
|
||||
std::function<ast::Expression*()> level_arg;
|
||||
std::function<const ast::Expression*()> level_arg;
|
||||
if (level_idx >= 0) {
|
||||
level_arg = [&] {
|
||||
auto* arg = expr->args[level_idx];
|
||||
|
||||
@@ -35,7 +35,7 @@ Simplify::Simplify() = default;
|
||||
Simplify::~Simplify() = default;
|
||||
|
||||
void Simplify::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
ctx.ReplaceAll([&](ast::Expression* expr) -> ast::Expression* {
|
||||
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
|
||||
if (auto* outer = expr->As<ast::UnaryOpExpression>()) {
|
||||
if (auto* inner = outer->expr->As<ast::UnaryOpExpression>()) {
|
||||
if (outer->op == ast::UnaryOp::kAddressOf &&
|
||||
|
||||
@@ -42,7 +42,7 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) {
|
||||
}
|
||||
|
||||
// Find the target entry point.
|
||||
ast::Function* entry_point = nullptr;
|
||||
const ast::Function* entry_point = nullptr;
|
||||
for (auto* f : ctx.src->AST().Functions()) {
|
||||
if (!f->IsEntryPoint()) {
|
||||
continue;
|
||||
|
||||
@@ -90,7 +90,8 @@ void Transform::RemoveStatement(CloneContext& ctx, ast::Statement* stmt) {
|
||||
<< sem->TypeInfo().name;
|
||||
}
|
||||
|
||||
ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx, const sem::Type* ty) {
|
||||
const ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx,
|
||||
const sem::Type* ty) {
|
||||
if (ty->Is<sem::Void>()) {
|
||||
return ctx.dst->create<ast::Void>();
|
||||
}
|
||||
|
||||
@@ -198,7 +198,8 @@ class Transform : public Castable<Transform> {
|
||||
/// @param ty the semantic type to reconstruct
|
||||
/// @returns a ast::Type that when resolved, will produce the semantic type
|
||||
/// `ty`.
|
||||
static ast::Type* CreateASTTypeFor(CloneContext& ctx, const sem::Type* ty);
|
||||
static const ast::Type* CreateASTTypeFor(CloneContext& ctx,
|
||||
const sem::Type* ty);
|
||||
};
|
||||
|
||||
} // namespace transform
|
||||
|
||||
@@ -26,7 +26,7 @@ namespace {
|
||||
struct CreateASTTypeForTest : public testing::Test, public Transform {
|
||||
Output Run(const Program*, const DataMap&) override { return {}; }
|
||||
|
||||
ast::Type* create(
|
||||
const ast::Type* create(
|
||||
std::function<sem::Type*(ProgramBuilder&)> create_sem_type) {
|
||||
ProgramBuilder sem_type_builder;
|
||||
auto* sem_type = create_sem_type(sem_type_builder);
|
||||
|
||||
@@ -137,7 +137,7 @@ struct DataType {
|
||||
uint32_t width; // 1 for scalar, 2+ for a vector
|
||||
};
|
||||
|
||||
DataType DataTypeOf(sem::Type* ty) {
|
||||
DataType DataTypeOf(const sem::Type* ty) {
|
||||
if (ty->Is<sem::I32>()) {
|
||||
return {BaseType::kI32, 1};
|
||||
}
|
||||
@@ -217,15 +217,15 @@ struct State {
|
||||
};
|
||||
|
||||
struct LocationInfo {
|
||||
std::function<ast::Expression*()> expr;
|
||||
sem::Type* type;
|
||||
std::function<const ast::Expression*()> expr;
|
||||
const sem::Type* type;
|
||||
};
|
||||
|
||||
CloneContext& ctx;
|
||||
VertexPulling::Config const cfg;
|
||||
std::unordered_map<uint32_t, LocationInfo> location_info;
|
||||
std::function<ast::Expression*()> vertex_index_expr = nullptr;
|
||||
std::function<ast::Expression*()> instance_index_expr = nullptr;
|
||||
std::function<const ast::Expression*()> vertex_index_expr = nullptr;
|
||||
std::function<const ast::Expression*()> instance_index_expr = nullptr;
|
||||
Symbol pulling_position_name;
|
||||
Symbol struct_buffer_name;
|
||||
std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
|
||||
@@ -369,7 +369,7 @@ struct State {
|
||||
}
|
||||
} else if (var_dt.width > fmt_dt.width) {
|
||||
// WGSL variable vector width is wider than the loaded vector width
|
||||
ast::Type* ty = nullptr;
|
||||
const ast::Type* ty = nullptr;
|
||||
ast::ExpressionList values{fetch};
|
||||
switch (var_dt.base_type) {
|
||||
case BaseType::kI32:
|
||||
@@ -416,10 +416,10 @@ struct State {
|
||||
/// @param offset the byte offset of the data from `buffer_base`
|
||||
/// @param buffer the index of the vertex buffer
|
||||
/// @param format the format to read
|
||||
ast::Expression* Fetch(Symbol array_base,
|
||||
uint32_t offset,
|
||||
uint32_t buffer,
|
||||
VertexFormat format) {
|
||||
const ast::Expression* Fetch(Symbol array_base,
|
||||
uint32_t offset,
|
||||
uint32_t buffer,
|
||||
VertexFormat format) {
|
||||
using u32 = ProgramBuilder::u32;
|
||||
using i32 = ProgramBuilder::i32;
|
||||
using f32 = ProgramBuilder::f32;
|
||||
@@ -642,15 +642,15 @@ struct State {
|
||||
/// @param buffer the index of the vertex buffer
|
||||
/// @param format VertexFormat::kUint32, VertexFormat::kSint32 or
|
||||
/// VertexFormat::kFloat32
|
||||
ast::Expression* LoadPrimitive(Symbol array_base,
|
||||
uint32_t offset,
|
||||
uint32_t buffer,
|
||||
VertexFormat format) {
|
||||
ast::Expression* u32 = nullptr;
|
||||
const ast::Expression* LoadPrimitive(Symbol array_base,
|
||||
uint32_t offset,
|
||||
uint32_t buffer,
|
||||
VertexFormat format) {
|
||||
const ast::Expression* u32 = nullptr;
|
||||
if ((offset & 3) == 0) {
|
||||
// Aligned load.
|
||||
|
||||
ast ::Expression* index = nullptr;
|
||||
const ast ::Expression* index = nullptr;
|
||||
if (offset > 0) {
|
||||
index = ctx.dst->Add(array_base, offset / 4);
|
||||
} else {
|
||||
@@ -700,13 +700,13 @@ struct State {
|
||||
/// @param base_type underlying AST type
|
||||
/// @param base_format underlying vertex format
|
||||
/// @param count how many elements the vector has
|
||||
ast::Expression* LoadVec(Symbol array_base,
|
||||
uint32_t offset,
|
||||
uint32_t buffer,
|
||||
uint32_t element_stride,
|
||||
ast::Type* base_type,
|
||||
VertexFormat base_format,
|
||||
uint32_t count) {
|
||||
const ast::Expression* LoadVec(Symbol array_base,
|
||||
uint32_t offset,
|
||||
uint32_t buffer,
|
||||
uint32_t element_stride,
|
||||
const ast::Type* base_type,
|
||||
VertexFormat base_format,
|
||||
uint32_t count) {
|
||||
ast::ExpressionList expr_list;
|
||||
for (uint32_t i = 0; i < count; ++i) {
|
||||
// Offset read position by element_stride for each component
|
||||
@@ -724,7 +724,8 @@ struct State {
|
||||
/// vertex_index and instance_index builtins if present.
|
||||
/// @param func the entry point function
|
||||
/// @param param the parameter to process
|
||||
void ProcessNonStructParameter(ast::Function* func, ast::Variable* param) {
|
||||
void ProcessNonStructParameter(const ast::Function* func,
|
||||
const ast::Variable* param) {
|
||||
if (auto* location =
|
||||
ast::GetDecoration<ast::LocationDecoration>(param->decorations)) {
|
||||
// Create a function-scope variable to replace the parameter.
|
||||
@@ -764,8 +765,8 @@ struct State {
|
||||
/// @param func the entry point function
|
||||
/// @param param the parameter to process
|
||||
/// @param struct_ty the structure type
|
||||
void ProcessStructParameter(ast::Function* func,
|
||||
ast::Variable* param,
|
||||
void ProcessStructParameter(const ast::Function* func,
|
||||
const ast::Variable* param,
|
||||
const ast::Struct* struct_ty) {
|
||||
auto param_sym = ctx.Clone(param->symbol);
|
||||
|
||||
@@ -774,8 +775,8 @@ struct State {
|
||||
ast::StructMemberList members_to_clone;
|
||||
for (auto* member : struct_ty->members) {
|
||||
auto member_sym = ctx.Clone(member->symbol);
|
||||
std::function<ast::Expression*()> member_expr = [this, param_sym,
|
||||
member_sym]() {
|
||||
std::function<const ast::Expression*()> member_expr = [this, param_sym,
|
||||
member_sym]() {
|
||||
return ctx.dst->MemberAccessor(param_sym, member_sym);
|
||||
};
|
||||
|
||||
@@ -842,7 +843,7 @@ struct State {
|
||||
|
||||
/// Process an entry point function.
|
||||
/// @param func the entry point function
|
||||
void Process(ast::Function* func) {
|
||||
void Process(const ast::Function* func) {
|
||||
if (func->body->Empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ void WrapArraysInStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
};
|
||||
|
||||
// Replace all array types with their corresponding wrapper
|
||||
ctx.ReplaceAll([&](ast::Type* ast_type) -> ast::Type* {
|
||||
ctx.ReplaceAll([&](const ast::Type* ast_type) -> const ast::Type* {
|
||||
auto* type = ctx.src->TypeOf(ast_type);
|
||||
if (auto* array = type->UnwrapRef()->As<sem::Array>()) {
|
||||
return wrapper_typename(array);
|
||||
@@ -57,8 +57,8 @@ void WrapArraysInStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
});
|
||||
|
||||
// Fix up array accessors so `a[1]` becomes `a.arr[1]`
|
||||
ctx.ReplaceAll([&](ast::ArrayAccessorExpression* accessor)
|
||||
-> ast::ArrayAccessorExpression* {
|
||||
ctx.ReplaceAll([&](const ast::ArrayAccessorExpression* accessor)
|
||||
-> const ast::ArrayAccessorExpression* {
|
||||
if (auto* array = ::tint::As<sem::Array>(
|
||||
sem.Get(accessor->array)->Type()->UnwrapRef())) {
|
||||
if (wrapper(array)) {
|
||||
@@ -74,7 +74,8 @@ void WrapArraysInStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) {
|
||||
});
|
||||
|
||||
// Fix up array constructors so `A(1,2)` becomes `tint_array_wrapper(A(1,2))`
|
||||
ctx.ReplaceAll([&](ast::TypeConstructorExpression* ctor) -> ast::Expression* {
|
||||
ctx.ReplaceAll([&](const ast::TypeConstructorExpression* ctor)
|
||||
-> const ast::Expression* {
|
||||
if (auto* array =
|
||||
::tint::As<sem::Array>(sem.Get(ctor)->Type()->UnwrapRef())) {
|
||||
if (auto w = wrapper(array)) {
|
||||
@@ -107,7 +108,7 @@ WrapArraysInStructs::WrappedArrayInfo WrapArraysInStructs::WrapArray(
|
||||
info.wrapper_name = ctx.dst->Symbols().New("tint_array_wrapper");
|
||||
|
||||
// Examine the element type. Is it also an array?
|
||||
std::function<ast::Type*(CloneContext&)> el_type;
|
||||
std::function<const ast::Type*(CloneContext&)> el_type;
|
||||
if (auto* el_array = array->ElemType()->As<sem::Array>()) {
|
||||
// Array of array - call WrapArray() on the element type
|
||||
if (auto el = WrapArray(ctx, wrapped_arrays, el_array)) {
|
||||
|
||||
@@ -60,7 +60,7 @@ class WrapArraysInStructs : public Castable<WrapArraysInStructs, Transform> {
|
||||
~WrappedArrayInfo();
|
||||
|
||||
Symbol wrapper_name;
|
||||
std::function<ast::Type*(CloneContext&)> array_type;
|
||||
std::function<const ast::Type*(CloneContext&)> array_type;
|
||||
|
||||
operator bool() { return wrapper_name.IsValid(); }
|
||||
};
|
||||
|
||||
@@ -46,7 +46,7 @@ struct ZeroInitWorkgroupMemory::State {
|
||||
uint32_t workgroup_size_const = 0;
|
||||
/// The size of the workgroup as an expression generator. Use if
|
||||
/// #workgroup_size_const is 0.
|
||||
std::function<ast::Expression*()> workgroup_size_expr;
|
||||
std::function<const ast::Expression*()> workgroup_size_expr;
|
||||
|
||||
/// ArrayIndex represents a function on the local invocation index, of
|
||||
/// the form: `array_index = (local_invocation_index % modulo) / division`
|
||||
@@ -80,7 +80,7 @@ struct ZeroInitWorkgroupMemory::State {
|
||||
/// statement will zero workgroup values.
|
||||
struct Expression {
|
||||
/// The AST expression node
|
||||
ast::Expression* expr = nullptr;
|
||||
const ast::Expression* expr = nullptr;
|
||||
/// The number of iterations required to zero the value
|
||||
uint32_t num_iterations = 0;
|
||||
/// All array indices used by this expression
|
||||
@@ -91,7 +91,7 @@ struct ZeroInitWorkgroupMemory::State {
|
||||
/// values.
|
||||
struct Statement {
|
||||
/// The AST statement node
|
||||
ast::Statement* stmt;
|
||||
const ast::Statement* stmt;
|
||||
/// The number of iterations required to zero the value
|
||||
uint32_t num_iterations;
|
||||
/// All array indices used by this statement
|
||||
@@ -112,7 +112,7 @@ struct ZeroInitWorkgroupMemory::State {
|
||||
/// Run inserts the workgroup memory zero-initialization logic at the top of
|
||||
/// the given function
|
||||
/// @param fn a compute shader entry point function
|
||||
void Run(ast::Function* fn) {
|
||||
void Run(const ast::Function* fn) {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
CalculateWorkgroupSize(
|
||||
@@ -137,7 +137,7 @@ struct ZeroInitWorkgroupMemory::State {
|
||||
|
||||
// Scan the entry point for an existing local_invocation_index builtin
|
||||
// parameter
|
||||
std::function<ast::Expression*()> local_index;
|
||||
std::function<const ast::Expression*()> local_index;
|
||||
for (auto* param : fn->params) {
|
||||
if (auto* builtin =
|
||||
ast::GetDecoration<ast::BuiltinDecoration>(param->decorations)) {
|
||||
@@ -341,7 +341,7 @@ struct ZeroInitWorkgroupMemory::State {
|
||||
ast::StatementList DeclareArrayIndices(
|
||||
uint32_t num_iterations,
|
||||
const ArrayIndices& array_indices,
|
||||
const std::function<ast::Expression*()>& iteration) {
|
||||
const std::function<const ast::Expression*()>& iteration) {
|
||||
ast::StatementList stmts;
|
||||
std::map<Symbol, ArrayIndex> indices_by_name;
|
||||
for (auto index : array_indices) {
|
||||
|
||||
Reference in New Issue
Block a user