[tint] Remove ast:: prefixes from AST transforms

Now that AST transforms live in the AST namespace, these prefixes are
no longer necessary.

Change-Id: I658746ac04220075653ec57d6dc998947232d8bc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/132425
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
James Price 2023-05-12 01:43:50 +00:00 committed by Dawn LUCI CQ
parent 4ae03fa8d0
commit 2b7406ad55
59 changed files with 1025 additions and 1064 deletions

View File

@ -41,7 +41,7 @@ Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
// A map from a type in the source program to a block-decorated wrapper that contains it in the
// destination program.
utils::Hashmap<const type::Type*, const ast::Struct*, 8> wrapper_structs;
utils::Hashmap<const type::Type*, const Struct*, 8> wrapper_structs;
// Process global 'var' declarations that are buffers.
bool made_changes = false;
@ -71,7 +71,7 @@ Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] {
auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
auto wrapper_name = global->name->symbol.Name() + "_block";
auto* ret = b.create<ast::Struct>(
auto* ret = b.create<Struct>(
b.Ident(b.Symbols().New(wrapper_name)),
utils::Vector{b.Member(kMemberName, CreateASTTypeFor(ctx, ty))},
utils::Vector{block});
@ -101,7 +101,7 @@ Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
return Program(std::move(b));
}
AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, ast::NodeID nid)
AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, NodeID nid)
: Base(pid, nid, utils::Empty) {}
AddBlockAttribute::BlockAttribute::~BlockAttribute() = default;
std::string AddBlockAttribute::BlockAttribute::InternalName() const {

View File

@ -28,12 +28,12 @@ class AddBlockAttribute final : public utils::Castable<AddBlockAttribute, Transf
public:
/// BlockAttribute is an InternalAttribute that is used to decorate a
// structure that is used as a buffer in SPIR-V or GLSL.
class BlockAttribute final : public utils::Castable<BlockAttribute, ast::InternalAttribute> {
class BlockAttribute final : public utils::Castable<BlockAttribute, InternalAttribute> {
public:
/// Constructor
/// @param program_id the identifier of the program that owns this node
/// @param nid the unique node identifier
BlockAttribute(ProgramID program_id, ast::NodeID nid);
BlockAttribute(ProgramID program_id, NodeID nid);
/// Destructor
~BlockAttribute() override;

View File

@ -52,7 +52,7 @@ Transform::ApplyResult AddEmptyEntryPoint::Apply(const Program* src,
b.Func(b.Symbols().New("unused_entry_point"), {}, b.ty.void_(), {},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});

View File

@ -81,8 +81,8 @@ struct ArrayLengthFromUniform::State {
// Determine the size of the buffer size array.
uint32_t max_buffer_size_index = 0;
IterateArrayLengthOnStorageVar([&](const ast::CallExpression*, const sem::VariableUser*,
const sem::GlobalVariable* var) {
IterateArrayLengthOnStorageVar(
[&](const CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) {
if (auto binding = var->BindingPoint()) {
auto idx_itr = cfg->bindpoint_to_size_index.find(*binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
@ -96,7 +96,7 @@ struct ArrayLengthFromUniform::State {
// Get (or create, on first call) the uniform buffer that will receive the
// size of each storage buffer in the module.
const ast::Variable* buffer_size_ubo = nullptr;
const Variable* buffer_size_ubo = nullptr;
auto get_ubo = [&]() {
if (!buffer_size_ubo) {
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
@ -118,7 +118,7 @@ struct ArrayLengthFromUniform::State {
std::unordered_set<uint32_t> used_size_indices;
IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr,
IterateArrayLengthOnStorageVar([&](const CallExpression* call_expr,
const sem::VariableUser* storage_buffer_sem,
const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
@ -144,7 +144,7 @@ struct ArrayLengthFromUniform::State {
// total_storage_buffer_size - array_offset
// array_length = ----------------------------------------
// array_stride
const ast::Expression* total_size = total_storage_buffer_size;
const Expression* total_size = total_storage_buffer_size;
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
const type::Array* array_type = nullptr;
if (auto* str = storage_buffer_type->As<type::Struct>()) {
@ -186,9 +186,9 @@ struct ArrayLengthFromUniform::State {
/// Iterate over all arrayLength() builtins that operate on
/// storage buffer variables.
/// @param functor of type void(const ast::CallExpression*, const
/// @param functor of type void(const CallExpression*, const
/// sem::VariableUser, const sem::GlobalVariable*). It takes in an
/// ast::CallExpression of the arrayLength call expression node, a
/// CallExpression of the arrayLength call expression node, a
/// sem::VariableUser of the used storage buffer variable, and the
/// sem::GlobalVariable for the storage buffer.
template <typename F>
@ -197,7 +197,7 @@ struct ArrayLengthFromUniform::State {
// Find all calls to the arrayLength() builtin.
for (auto* node : src->ASTNodes().Objects()) {
auto* call_expr = node->As<ast::CallExpression>();
auto* call_expr = node->As<CallExpression>();
if (!call_expr) {
continue;
}
@ -208,7 +208,7 @@ struct ArrayLengthFromUniform::State {
continue;
}
if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) {
if (auto* call_stmt = call->Stmt()->Declaration()->As<CallStatement>()) {
if (call_stmt->expr == call_expr) {
// arrayLength() is used as a statement.
// The argument expression must be side-effect free, so just drop the statement.
@ -222,15 +222,15 @@ struct ArrayLengthFromUniform::State {
// call has one of two forms:
// arrayLength(&struct_var.array_member)
// arrayLength(&array_var)
auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
if (TINT_UNLIKELY(!param || param->op != ast::UnaryOp::kAddressOf)) {
auto* param = call_expr->args[0]->As<UnaryOpExpression>();
if (TINT_UNLIKELY(!param || param->op != UnaryOp::kAddressOf)) {
TINT_ICE(Transform, b.Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
}
auto* storage_buffer_expr = param->expr;
if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) {
if (auto* accessor = param->expr->As<MemberAccessorExpression>()) {
storage_buffer_expr = accessor->object;
}
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);

View File

@ -90,7 +90,7 @@ Transform::ApplyResult BindingRemapper::Apply(const Program* src,
}
}
for (auto* var : src->AST().Globals<ast::Var>()) {
for (auto* var : src->AST().Globals<Var>()) {
if (var->HasBindingPoint()) {
auto* global_sem = src->Sem().Get<sem::GlobalVariable>(var);
@ -109,8 +109,8 @@ Transform::ApplyResult BindingRemapper::Apply(const Program* src,
auto* new_group = b.Group(AInt(to.group));
auto* new_binding = b.Binding(AInt(to.binding));
auto* old_group = ast::GetAttribute<ast::GroupAttribute>(var->attributes);
auto* old_binding = ast::GetAttribute<ast::BindingAttribute>(var->attributes);
auto* old_group = GetAttribute<GroupAttribute>(var->attributes);
auto* old_binding = GetAttribute<BindingAttribute>(var->attributes);
ctx.Replace(old_group, new_group);
ctx.Replace(old_binding, new_binding);
@ -139,7 +139,7 @@ Transform::ApplyResult BindingRemapper::Apply(const Program* src,
auto* ty = sem->Type()->UnwrapRef();
auto inner_ty = CreateASTTypeFor(ctx, ty);
auto* new_var =
b.create<ast::Var>(ctx.Clone(var->source), // source
b.create<Var>(ctx.Clone(var->source), // source
b.Ident(ctx.Clone(var->name->symbol)), // name
inner_ty, // type
ctx.Clone(var->declared_address_space), // address space
@ -151,7 +151,7 @@ Transform::ApplyResult BindingRemapper::Apply(const Program* src,
// Add `DisableValidationAttribute`s if required
if (add_collision_attr.count(bp)) {
auto* attribute = b.Disable(ast::DisabledValidation::kBindingPointCollision);
auto* attribute = b.Disable(DisabledValidation::kBindingPointCollision);
ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
}
}

View File

@ -37,7 +37,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::BuiltinPolyfill::Config);
namespace tint::ast::transform {
/// BinaryOpSignature is tuple of a binary op, LHS type and RHS type
using BinaryOpSignature = std::tuple<ast::BinaryOp, const type::Type*, const type::Type*>;
using BinaryOpSignature = std::tuple<BinaryOp, const type::Type*, const type::Type*>;
/// PIMPL state for the transform
struct BuiltinPolyfill::State {
@ -60,16 +60,16 @@ struct BuiltinPolyfill::State {
for (auto* node : src->ASTNodes().Objects()) {
Switch(
node, //
[&](const ast::CallExpression* expr) { Call(expr); },
[&](const ast::BinaryExpression* bin_op) {
[&](const CallExpression* expr) { Call(expr); },
[&](const BinaryExpression* bin_op) {
if (auto* s = src->Sem().Get(bin_op);
!s || s->Stage() == sem::EvaluationStage::kConstant ||
s->Stage() == sem::EvaluationStage::kNotEvaluated) {
return; // Don't polyfill @const expressions
}
switch (bin_op->op) {
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight: {
case BinaryOp::kShiftLeft:
case BinaryOp::kShiftRight: {
if (cfg.builtins.bitshift_modulo) {
ctx.Replace(bin_op,
[this, bin_op] { return BitshiftModulo(bin_op); });
@ -77,7 +77,7 @@ struct BuiltinPolyfill::State {
}
break;
}
case ast::BinaryOp::kDivide: {
case BinaryOp::kDivide: {
if (cfg.builtins.int_div_mod) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
if (lhs_ty->is_integer_scalar_or_vector()) {
@ -88,7 +88,7 @@ struct BuiltinPolyfill::State {
}
break;
}
case ast::BinaryOp::kModulo: {
case BinaryOp::kModulo: {
if (cfg.builtins.int_div_mod) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
if (lhs_ty->is_integer_scalar_or_vector()) {
@ -111,7 +111,7 @@ struct BuiltinPolyfill::State {
break;
}
},
[&](const ast::Expression* expr) {
[&](const Expression* expr) {
if (cfg.builtins.bgra8unorm) {
if (auto* ty_expr = src->Sem().Get<sem::TypeExpression>(expr)) {
if (auto* tex = ty_expr->Type()->As<type::StorageTexture>()) {
@ -170,15 +170,15 @@ struct BuiltinPolyfill::State {
auto name = b.Symbols().New("tint_acosh");
uint32_t width = WidthOf(ty);
auto V = [&](AFloat value) -> const ast::Expression* {
const ast::Expression* expr = b.Expr(value);
auto V = [&](AFloat value) -> const Expression* {
const Expression* expr = b.Expr(value);
if (width == 1) {
return expr;
}
return b.Call(T(ty), expr);
};
utils::Vector<const ast::Statement*, 4> body;
utils::Vector<const Statement*, 4> body;
switch (cfg.builtins.acosh) {
case Level::kFull:
// return log(x + sqrt(x*x - 1));
@ -224,15 +224,15 @@ struct BuiltinPolyfill::State {
auto name = b.Symbols().New("tint_atanh");
uint32_t width = WidthOf(ty);
auto V = [&](AFloat value) -> const ast::Expression* {
const ast::Expression* expr = b.Expr(value);
auto V = [&](AFloat value) -> const Expression* {
const Expression* expr = b.Expr(value);
if (width == 1) {
return expr;
}
return b.Call(T(ty), expr);
};
utils::Vector<const ast::Statement*, 1> body;
utils::Vector<const Statement*, 1> body;
switch (cfg.builtins.atanh) {
case Level::kFull:
// return log((1+x) / (1-x)) * 0.5
@ -290,7 +290,7 @@ struct BuiltinPolyfill::State {
}
return b.ty.vec<u32>(width);
};
auto V = [&](uint32_t value) -> const ast::Expression* {
auto V = [&](uint32_t value) -> const Expression* {
return ScalarOrVector(width, u32(value));
};
b.Func(
@ -348,10 +348,10 @@ struct BuiltinPolyfill::State {
}
return b.ty.vec<u32>(width);
};
auto V = [&](uint32_t value) -> const ast::Expression* {
auto V = [&](uint32_t value) -> const Expression* {
return ScalarOrVector(width, u32(value));
};
auto B = [&](const ast::Expression* value) -> const ast::Expression* {
auto B = [&](const Expression* value) -> const Expression* {
if (width == 1) {
return b.Call<bool>(value);
}
@ -402,14 +402,14 @@ struct BuiltinPolyfill::State {
constexpr uint32_t W = 32u; // 32-bit
auto vecN_u32 = [&](const ast::Expression* value) -> const ast::Expression* {
auto vecN_u32 = [&](const Expression* value) -> const Expression* {
if (width == 1) {
return value;
}
return b.Call(b.ty.vec<u32>(width), value);
};
utils::Vector<const ast::Statement*, 8> body{
utils::Vector<const Statement*, 8> body{
b.Decl(b.Let("s", b.Call("min", "offset", u32(W)))),
b.Decl(b.Let("e", b.Call("min", u32(W), b.Add("s", "count")))),
};
@ -465,17 +465,17 @@ struct BuiltinPolyfill::State {
}
return b.ty.vec<u32>(width);
};
auto V = [&](uint32_t value) -> const ast::Expression* {
auto V = [&](uint32_t value) -> const Expression* {
return ScalarOrVector(width, u32(value));
};
auto B = [&](const ast::Expression* value) -> const ast::Expression* {
auto B = [&](const Expression* value) -> const Expression* {
if (width == 1) {
return b.Call<bool>(value);
}
return b.Call(b.ty.vec<bool>(width), value);
};
const ast::Expression* x = nullptr;
const Expression* x = nullptr;
if (ty->is_unsigned_integer_scalar_or_vector()) {
x = b.Expr("v");
} else {
@ -537,10 +537,10 @@ struct BuiltinPolyfill::State {
}
return b.ty.vec<u32>(width);
};
auto V = [&](uint32_t value) -> const ast::Expression* {
auto V = [&](uint32_t value) -> const Expression* {
return ScalarOrVector(width, u32(value));
};
auto B = [&](const ast::Expression* value) -> const ast::Expression* {
auto B = [&](const Expression* value) -> const Expression* {
if (width == 1) {
return b.Call<bool>(value);
}
@ -599,8 +599,8 @@ struct BuiltinPolyfill::State {
constexpr uint32_t W = 32u; // 32-bit
auto V = [&](auto value) -> const ast::Expression* {
const ast::Expression* expr = b.Expr(value);
auto V = [&](auto value) -> const Expression* {
const Expression* expr = b.Expr(value);
if (!ty->is_unsigned_integer_scalar_or_vector()) {
expr = b.Call<i32>(expr);
}
@ -609,7 +609,7 @@ struct BuiltinPolyfill::State {
}
return expr;
};
auto U = [&](auto value) -> const ast::Expression* {
auto U = [&](auto value) -> const Expression* {
if (width == 1) {
return b.Expr(value);
}
@ -638,7 +638,7 @@ struct BuiltinPolyfill::State {
// return ((select(T(), n << offset, offset < 32u) & mask) | (v & ~(mask)));
// }
utils::Vector<const ast::Statement*, 8> body;
utils::Vector<const Statement*, 8> body;
switch (cfg.builtins.insert_bits) {
case Level::kFull:
@ -788,7 +788,7 @@ struct BuiltinPolyfill::State {
/// @return the polyfill function name
Symbol quantizeToF16(const type::Vector* vec) {
auto name = b.Symbols().New("tint_quantizeToF16");
utils::Vector<const ast::Expression*, 4> args;
utils::Vector<const Expression*, 4> args;
for (uint32_t i = 0; i < vec->Width(); i++) {
args.Push(b.Call("quantizeToF16", b.IndexAccessor("v", u32(i))));
}
@ -880,29 +880,29 @@ struct BuiltinPolyfill::State {
/// the RHS is modulo the bit-width of the LHS.
/// @param bin_op the original BinaryExpression
/// @return the polyfill value for bitshift operation
const ast::Expression* BitshiftModulo(const ast::BinaryExpression* bin_op) {
const Expression* BitshiftModulo(const BinaryExpression* bin_op) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
auto* lhs_el_ty = type::Type::DeepestElementOf(lhs_ty);
const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
const Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
if (rhs_ty->Is<type::Vector>()) {
mask = b.Call(CreateASTTypeFor(ctx, rhs_ty), mask);
}
auto* lhs = ctx.Clone(bin_op->lhs);
auto* rhs = b.And(ctx.Clone(bin_op->rhs), mask);
return b.create<ast::BinaryExpression>(ctx.Clone(bin_op->source), bin_op->op, lhs, rhs);
return b.create<BinaryExpression>(ctx.Clone(bin_op->source), bin_op->op, lhs, rhs);
}
/// Builds the polyfill inline expression for a integer divide or modulo, preventing DBZs and
/// integer overflows.
/// @param bin_op the original BinaryExpression
/// @return the polyfill divide or modulo
const ast::Expression* IntDivMod(const ast::BinaryExpression* bin_op) {
const Expression* IntDivMod(const BinaryExpression* bin_op) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
const bool is_div = bin_op->op == ast::BinaryOp::kDivide;
const bool is_div = bin_op->op == BinaryOp::kDivide;
uint32_t lhs_width = 1;
uint32_t rhs_width = 1;
@ -914,7 +914,7 @@ struct BuiltinPolyfill::State {
const char* lhs = "lhs";
const char* rhs = "rhs";
utils::Vector<const ast::Statement*, 4> body;
utils::Vector<const Statement*, 4> body;
if (lhs_width < width) {
// lhs is scalar, rhs is vector. Convert lhs to vector.
@ -934,8 +934,8 @@ struct BuiltinPolyfill::State {
if (lhs_ty->is_signed_integer_scalar_or_vector()) {
const auto bits = lhs_el_ty->Size() * 8;
auto min_int = AInt(AInt::kLowestValue >> (AInt::kNumBits - bits));
const ast::Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int));
const ast::Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a));
const Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int));
const Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a));
// use_one = rhs_is_zero | ((lhs == MIN_INT) & (rhs == -1))
auto* use_one = b.Or(rhs_is_zero, b.And(lhs_is_min, rhs_is_minus_one));
@ -992,7 +992,7 @@ struct BuiltinPolyfill::State {
/// Builds the polyfill inline expression for a precise float modulo, as defined in the spec.
/// @param bin_op the original BinaryExpression
/// @return the polyfill divide or modulo
const ast::Expression* PreciseFloatMod(const ast::BinaryExpression* bin_op) {
const Expression* PreciseFloatMod(const BinaryExpression* bin_op) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
@ -1007,7 +1007,7 @@ struct BuiltinPolyfill::State {
const char* lhs = "lhs";
const char* rhs = "rhs";
utils::Vector<const ast::Statement*, 4> body;
utils::Vector<const Statement*, 4> body;
if (lhs_width < width) {
// lhs is scalar, rhs is vector. Convert lhs to vector.
@ -1042,7 +1042,7 @@ struct BuiltinPolyfill::State {
}
/// @returns the AST type for the given sem type
ast::Type T(const type::Type* ty) { return CreateASTTypeFor(ctx, ty); }
Type T(const type::Type* ty) { return CreateASTTypeFor(ctx, ty); }
/// @returns 1 if `ty` is not a vector, otherwise the vector width
uint32_t WidthOf(const type::Type* ty) const {
@ -1055,7 +1055,7 @@ struct BuiltinPolyfill::State {
/// @returns a scalar or vector with the given width, with each element with
/// the given value.
template <typename T>
const ast::Expression* ScalarOrVector(uint32_t width, T value) {
const Expression* ScalarOrVector(uint32_t width, T value) {
if (width == 1) {
return b.Expr(value);
}
@ -1063,7 +1063,7 @@ struct BuiltinPolyfill::State {
}
template <typename To>
const ast::Expression* CastScalarOrVector(uint32_t width, const ast::Expression* e) {
const Expression* CastScalarOrVector(uint32_t width, const Expression* e) {
if (width == 1) {
return b.Call(b.ty.Of<To>(), e);
}
@ -1071,7 +1071,7 @@ struct BuiltinPolyfill::State {
}
/// Examines the call expression @p expr, applying any necessary polyfill transforms
void Call(const ast::CallExpression* expr) {
void Call(const CallExpression* expr) {
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
if (!call || call->Stage() == sem::EvaluationStage::kConstant ||
call->Stage() == sem::EvaluationStage::kNotEvaluated) {
@ -1207,7 +1207,7 @@ struct BuiltinPolyfill::State {
size_t value_idx = static_cast<size_t>(
sig.IndexOf(sem::ParameterUsage::kValue));
ctx.Replace(expr, [this, expr, value_idx] {
utils::Vector<const ast::Expression*, 3> args;
utils::Vector<const Expression*, 3> args;
for (auto* arg : expr->args) {
arg = ctx.Clone(arg);
if (args.Length() == value_idx) { // value

View File

@ -57,7 +57,7 @@ bool ShouldRun(const Program* program) {
/// ArrayUsage describes a runtime array usage.
/// It is used as a key by the array_length_by_usage map.
struct ArrayUsage {
ast::BlockStatement const* const block;
BlockStatement const* const block;
sem::Variable const* const buffer;
bool operator==(const ArrayUsage& rhs) const {
return block == rhs.block && buffer == rhs.buffer;
@ -71,7 +71,7 @@ struct ArrayUsage {
} // namespace
CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid, ast::NodeID nid)
CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid, NodeID nid)
: Base(pid, nid, utils::Empty) {}
CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const {
@ -106,7 +106,7 @@ Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
auto name = b.Sym();
auto type = CreateASTTypeFor(ctx, buffer_type);
auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter);
auto* disable_validation = b.Disable(DisabledValidation::kFunctionParameter);
b.Func(
name,
utils::Vector{
@ -128,13 +128,13 @@ Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
// Find all the arrayLength() calls...
for (auto* node : src->ASTNodes().Objects()) {
if (auto* call_expr = node->As<ast::CallExpression>()) {
if (auto* call_expr = node->As<CallExpression>()) {
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
if (builtin->Type() == builtin::Function::kArrayLength) {
// We're dealing with an arrayLength() call
if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) {
if (auto* call_stmt = call->Stmt()->Declaration()->As<CallStatement>()) {
if (call_stmt->expr == call_expr) {
// arrayLength() is used as a statement.
// The argument expression must be side-effect free, so just drop the
@ -151,13 +151,13 @@ Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
// arrayLength(&struct_var.array_member)
// arrayLength(&array_var)
auto* arg = call_expr->args[0];
auto* address_of = arg->As<ast::UnaryOpExpression>();
if (TINT_UNLIKELY(!address_of || address_of->op != ast::UnaryOp::kAddressOf)) {
auto* address_of = arg->As<UnaryOpExpression>();
if (TINT_UNLIKELY(!address_of || address_of->op != UnaryOp::kAddressOf)) {
TINT_ICE(Transform, b.Diagnostics())
<< "arrayLength() expected address-of, got " << arg->TypeInfo().name;
}
auto* storage_buffer_expr = address_of->expr;
if (auto* accessor = storage_buffer_expr->As<ast::MemberAccessorExpression>()) {
if (auto* accessor = storage_buffer_expr->As<MemberAccessorExpression>()) {
storage_buffer_expr = accessor->object;
}
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
@ -199,8 +199,7 @@ Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
// array_length = ----------------------------------------
// array_stride
auto name = b.Sym();
const ast::Expression* total_size =
b.Expr(buffer_size_result->variable);
const Expression* total_size = b.Expr(buffer_size_result->variable);
const type::Array* array_type = Switch(
storage_buffer_type->StoreType(),

View File

@ -37,12 +37,12 @@ class CalculateArrayLength final : public utils::Castable<CalculateArrayLength,
/// BufferSizeIntrinsic is an InternalAttribute that's applied to intrinsic
/// functions used to obtain the runtime size of a storage buffer.
class BufferSizeIntrinsic final
: public utils::Castable<BufferSizeIntrinsic, ast::InternalAttribute> {
: public utils::Castable<BufferSizeIntrinsic, InternalAttribute> {
public:
/// Constructor
/// @param program_id the identifier of the program that owns this node
/// @param nid the unique node identifier
BufferSizeIntrinsic(ProgramID program_id, ast::NodeID nid);
BufferSizeIntrinsic(ProgramID program_id, NodeID nid);
/// Destructor
~BufferSizeIntrinsic() override;

View File

@ -41,7 +41,7 @@ namespace {
/// Info for a struct member
struct MemberInfo {
/// The struct member item
const ast::StructMember* member;
const StructMember* member;
/// The struct member location if provided
std::optional<uint32_t> location;
};
@ -83,9 +83,9 @@ uint32_t BuiltinOrder(builtin::BuiltinValue builtin) {
}
// Returns true if `attr` is a shader IO attribute.
bool IsShaderIOAttribute(const ast::Attribute* attr) {
return attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute, ast::InvariantAttribute,
ast::LocationAttribute>();
bool IsShaderIOAttribute(const Attribute* attr) {
return attr
->IsAnyOf<BuiltinAttribute, InterpolateAttribute, InvariantAttribute, LocationAttribute>();
}
} // namespace
@ -97,11 +97,11 @@ struct CanonicalizeEntryPointIO::State {
/// The name of the output value.
std::string name;
/// The type of the output value.
ast::Type type;
Type type;
/// The shader IO attributes.
utils::Vector<const ast::Attribute*, 8> attributes;
utils::Vector<const Attribute*, 8> attributes;
/// The value itself.
const ast::Expression* value;
const Expression* value;
/// The output location.
std::optional<uint32_t> location;
};
@ -111,29 +111,29 @@ struct CanonicalizeEntryPointIO::State {
/// The transform config.
CanonicalizeEntryPointIO::Config const cfg;
/// The entry point function (AST).
const ast::Function* func_ast;
const Function* func_ast;
/// The entry point function (SEM).
const sem::Function* func_sem;
/// The new entry point wrapper function's parameters.
utils::Vector<const ast::Parameter*, 8> wrapper_ep_parameters;
utils::Vector<const Parameter*, 8> wrapper_ep_parameters;
/// The members of the wrapper function's struct parameter.
utils::Vector<MemberInfo, 8> wrapper_struct_param_members;
/// The name of the wrapper function's struct parameter.
Symbol wrapper_struct_param_name;
/// The parameters that will be passed to the original function.
utils::Vector<const ast::Expression*, 8> inner_call_parameters;
utils::Vector<const Expression*, 8> inner_call_parameters;
/// The members of the wrapper function's struct return type.
utils::Vector<MemberInfo, 8> wrapper_struct_output_members;
/// The wrapper function output values.
utils::Vector<OutputValue, 8> wrapper_output_values;
/// The body of the wrapper function.
utils::Vector<const ast::Statement*, 8> wrapper_body;
utils::Vector<const Statement*, 8> wrapper_body;
/// Input names used by the entrypoint
std::unordered_set<std::string> input_names;
/// A map of cloned attribute to builtin value
utils::Hashmap<const ast::BuiltinAttribute*, builtin::BuiltinValue, 16> builtin_attrs;
utils::Hashmap<const BuiltinAttribute*, builtin::BuiltinValue, 16> builtin_attrs;
/// Constructor
/// @param context the clone context
@ -141,7 +141,7 @@ struct CanonicalizeEntryPointIO::State {
/// @param function the entry point function
State(CloneContext& context,
const CanonicalizeEntryPointIO::Config& config,
const ast::Function* function)
const Function* function)
: ctx(context), cfg(config), func_ast(function), func_sem(ctx.src->Sem().Get(function)) {}
/// Clones the attributes from @p in and adds it to @p out. If @p in is a builtin attribute,
@ -149,12 +149,11 @@ struct CanonicalizeEntryPointIO::State {
/// @param in the attribute to clone
/// @param out the output Attributes
template <size_t N>
void CloneAttribute(const ast::Attribute* in, utils::Vector<const ast::Attribute*, N>& out) {
void CloneAttribute(const Attribute* in, utils::Vector<const Attribute*, N>& out) {
auto* cloned = ctx.Clone(in);
out.Push(cloned);
if (auto* builtin = in->As<ast::BuiltinAttribute>()) {
builtin_attrs.Add(cloned->As<ast::BuiltinAttribute>(),
ctx.src->Sem().Get(builtin)->Value());
if (auto* builtin = in->As<BuiltinAttribute>()) {
builtin_attrs.Add(cloned->As<BuiltinAttribute>(), ctx.src->Sem().Get(builtin)->Value());
}
}
@ -163,12 +162,11 @@ struct CanonicalizeEntryPointIO::State {
/// @param do_interpolate whether to clone InterpolateAttribute
/// @return the cloned attributes
template <size_t N>
auto CloneShaderIOAttributes(const utils::Vector<const ast::Attribute*, N> in,
bool do_interpolate) {
utils::Vector<const ast::Attribute*, N> out;
auto CloneShaderIOAttributes(const utils::Vector<const Attribute*, N> in, bool do_interpolate) {
utils::Vector<const Attribute*, N> out;
for (auto* attr : in) {
if (IsShaderIOAttribute(attr) &&
(do_interpolate || !attr->template Is<ast::InterpolateAttribute>())) {
(do_interpolate || !attr->template Is<InterpolateAttribute>())) {
CloneAttribute(attr, out);
}
}
@ -177,7 +175,7 @@ struct CanonicalizeEntryPointIO::State {
/// @param attr the input attribute
/// @returns the builtin value of the attribute
builtin::BuiltinValue BuiltinOf(const ast::BuiltinAttribute* attr) {
builtin::BuiltinValue BuiltinOf(const BuiltinAttribute* attr) {
if (attr->program_id == ctx.dst->ID()) {
// attr belongs to the target program.
// Obtain the builtin value from #builtin_attrs.
@ -197,8 +195,8 @@ struct CanonicalizeEntryPointIO::State {
/// @param attrs the input attribute list
/// @returns the builtin value if any of the attributes in @p attrs is a builtin attribute,
/// otherwise builtin::BuiltinValue::kUndefined
builtin::BuiltinValue BuiltinOf(utils::VectorRef<const ast::Attribute*> attrs) {
if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attrs)) {
builtin::BuiltinValue BuiltinOf(utils::VectorRef<const Attribute*> attrs) {
if (auto* builtin = GetAttribute<BuiltinAttribute>(attrs)) {
return BuiltinOf(builtin);
}
return builtin::BuiltinValue::kUndefined;
@ -219,10 +217,10 @@ struct CanonicalizeEntryPointIO::State {
/// @param location the location if provided
/// @param attrs the attributes to apply to the shader input
/// @returns an expression which evaluates to the value of the shader input
const ast::Expression* AddInput(std::string name,
const Expression* AddInput(std::string name,
const type::Type* type,
std::optional<uint32_t> location,
utils::Vector<const ast::Attribute*, 8> attrs) {
utils::Vector<const Attribute*, 8> attrs) {
auto ast_type = CreateASTTypeFor(ctx, type);
auto builtin_attr = BuiltinOf(attrs);
@ -233,17 +231,16 @@ struct CanonicalizeEntryPointIO::State {
// https://www.khronos.org/registry/vulkan/specs/1.3-extensions/man/html/StandaloneSpirv.html#VUID-StandaloneSpirv-Flat-04744
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is
// required for integers.
if (func_ast->PipelineStage() == ast::PipelineStage::kFragment &&
type->is_integer_scalar_or_vector() &&
!ast::HasAttribute<ast::InterpolateAttribute>(attrs) &&
(ast::HasAttribute<ast::LocationAttribute>(attrs) ||
if (func_ast->PipelineStage() == PipelineStage::kFragment &&
type->is_integer_scalar_or_vector() && !HasAttribute<InterpolateAttribute>(attrs) &&
(HasAttribute<LocationAttribute>(attrs) ||
cfg.shader_style == ShaderStyle::kSpirv)) {
attrs.Push(ctx.dst->Interpolate(builtin::InterpolationType::kFlat,
builtin::InterpolationSampling::kUndefined));
}
// Disable validation for use of the `input` address space.
attrs.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace));
attrs.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
// In GLSL, if it's a builtin, override the name with the
// corresponding gl_ builtin name
@ -255,7 +252,7 @@ struct CanonicalizeEntryPointIO::State {
auto symbol = ctx.dst->Symbols().New(name);
// Create the global variable and use its value for the shader input.
const ast::Expression* value = ctx.dst->Expr(symbol);
const Expression* value = ctx.dst->Expr(symbol);
if (builtin_attr != builtin::BuiltinValue::kUndefined) {
if (cfg.shader_style == ShaderStyle::kGlsl) {
@ -296,18 +293,17 @@ struct CanonicalizeEntryPointIO::State {
void AddOutput(std::string name,
const type::Type* type,
std::optional<uint32_t> location,
utils::Vector<const ast::Attribute*, 8> attrs,
const ast::Expression* value) {
utils::Vector<const Attribute*, 8> attrs,
const Expression* value) {
auto builtin_attr = BuiltinOf(attrs);
// Vulkan requires that integer user-defined vertex outputs are always decorated with
// `Flat`.
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is required
// for integers.
if (cfg.shader_style == ShaderStyle::kSpirv &&
func_ast->PipelineStage() == ast::PipelineStage::kVertex &&
type->is_integer_scalar_or_vector() &&
ast::HasAttribute<ast::LocationAttribute>(attrs) &&
!ast::HasAttribute<ast::InterpolateAttribute>(attrs)) {
func_ast->PipelineStage() == PipelineStage::kVertex &&
type->is_integer_scalar_or_vector() && HasAttribute<LocationAttribute>(attrs) &&
!HasAttribute<InterpolateAttribute>(attrs)) {
attrs.Push(ctx.dst->Interpolate(builtin::InterpolationType::kFlat,
builtin::InterpolationSampling::kUndefined));
}
@ -338,14 +334,14 @@ struct CanonicalizeEntryPointIO::State {
/// @param param the original function parameter
void ProcessNonStructParameter(const sem::Parameter* param) {
// Do not add interpolation attributes on vertex input
bool do_interpolate = func_ast->PipelineStage() != ast::PipelineStage::kVertex;
bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kVertex;
// Remove the shader IO attributes from the inner function parameter, and attach them to the
// new object instead.
utils::Vector<const ast::Attribute*, 8> attributes;
utils::Vector<const Attribute*, 8> attributes;
for (auto* attr : param->Declaration()->attributes) {
if (IsShaderIOAttribute(attr)) {
ctx.Remove(param->Declaration()->attributes, attr);
if ((do_interpolate || !attr->Is<ast::InterpolateAttribute>())) {
if ((do_interpolate || !attr->Is<InterpolateAttribute>())) {
CloneAttribute(attr, attributes);
}
}
@ -363,13 +359,13 @@ struct CanonicalizeEntryPointIO::State {
/// @param param the original function parameter
void ProcessStructParameter(const sem::Parameter* param) {
// Do not add interpolation attributes on vertex input
bool do_interpolate = func_ast->PipelineStage() != ast::PipelineStage::kVertex;
bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kVertex;
auto* str = param->Type()->As<sem::Struct>();
// Recreate struct members in the outer entry point and build an initializer
// list to pass them through to the inner function.
utils::Vector<const ast::Expression*, 8> inner_struct_values;
utils::Vector<const Expression*, 8> inner_struct_values;
for (auto* member : str->Members()) {
if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) {
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
@ -397,7 +393,7 @@ struct CanonicalizeEntryPointIO::State {
/// @param original_result the result object produced by the original function
void ProcessReturnType(const type::Type* inner_ret_type, Symbol original_result) {
// Do not add interpolation attributes on fragment output
bool do_interpolate = func_ast->PipelineStage() != ast::PipelineStage::kFragment;
bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kFragment;
if (auto* str = inner_ret_type->As<sem::Struct>()) {
for (auto* member : str->Members()) {
if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) {
@ -456,7 +452,7 @@ struct CanonicalizeEntryPointIO::State {
/// Create an expression for gl_Position.[component]
/// @param component the component of gl_Position to access
/// @returns the new expression
const ast::Expression* GLPosition(const char* component) {
const Expression* GLPosition(const char* component) {
Symbol pos = ctx.dst->Symbols().Register("gl_Position");
Symbol c = ctx.dst->Symbols().Register(component);
return ctx.dst->MemberAccessor(ctx.dst->Expr(pos), c);
@ -469,10 +465,10 @@ struct CanonicalizeEntryPointIO::State {
/// @param b another struct member
/// @returns true if a comes before b
bool StructMemberComparator(const MemberInfo& a, const MemberInfo& b) {
auto* a_loc = ast::GetAttribute<ast::LocationAttribute>(a.member->attributes);
auto* b_loc = ast::GetAttribute<ast::LocationAttribute>(b.member->attributes);
auto* a_blt = ast::GetAttribute<ast::BuiltinAttribute>(a.member->attributes);
auto* b_blt = ast::GetAttribute<ast::BuiltinAttribute>(b.member->attributes);
auto* a_loc = GetAttribute<LocationAttribute>(a.member->attributes);
auto* b_loc = GetAttribute<LocationAttribute>(b.member->attributes);
auto* a_blt = GetAttribute<BuiltinAttribute>(a.member->attributes);
auto* b_blt = GetAttribute<BuiltinAttribute>(b.member->attributes);
if (a_loc) {
if (!b_loc) {
// `a` has location attribute and `b` does not: `a` goes first.
@ -497,15 +493,15 @@ struct CanonicalizeEntryPointIO::State {
std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(),
[&](auto& a, auto& b) { return StructMemberComparator(a, b); });
utils::Vector<const ast::StructMember*, 8> members;
utils::Vector<const StructMember*, 8> members;
for (auto& mem : wrapper_struct_param_members) {
members.Push(mem.member);
}
// Create the new struct type.
auto struct_name = ctx.dst->Sym();
auto* in_struct = ctx.dst->create<ast::Struct>(ctx.dst->Ident(struct_name),
std::move(members), utils::Empty);
auto* in_struct =
ctx.dst->create<Struct>(ctx.dst->Ident(struct_name), std::move(members), utils::Empty);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
// Create a new function parameter using this struct type.
@ -515,8 +511,8 @@ struct CanonicalizeEntryPointIO::State {
/// Create and return the wrapper function's struct result object.
/// @returns the struct type
ast::Struct* CreateOutputStruct() {
utils::Vector<const ast::Statement*, 8> assignments;
Struct* CreateOutputStruct() {
utils::Vector<const Statement*, 8> assignments;
auto wrapper_result = ctx.dst->Symbols().New("wrapper_result");
@ -544,13 +540,13 @@ struct CanonicalizeEntryPointIO::State {
std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(),
[&](auto& a, auto& b) { return StructMemberComparator(a, b); });
utils::Vector<const ast::StructMember*, 8> members;
utils::Vector<const StructMember*, 8> members;
for (auto& mem : wrapper_struct_output_members) {
members.Push(mem.member);
}
// Create the new struct type.
auto* out_struct = ctx.dst->create<ast::Struct>(ctx.dst->Ident(ctx.dst->Sym()),
auto* out_struct = ctx.dst->create<Struct>(ctx.dst->Ident(ctx.dst->Sym()),
std::move(members), utils::Empty);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
@ -570,12 +566,12 @@ struct CanonicalizeEntryPointIO::State {
for (auto& outval : wrapper_output_values) {
// Disable validation for use of the `output` address space.
auto attributes = std::move(outval.attributes);
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace));
attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
// Create the global variable and assign it the output value.
auto name = ctx.dst->Symbols().New(outval.name);
ast::Type type = outval.type;
const ast::Expression* lhs = ctx.dst->Expr(name);
Type type = outval.type;
const Expression* lhs = ctx.dst->Expr(name);
if (BuiltinOf(attributes) == builtin::BuiltinValue::kSampleMask) {
// Vulkan requires the type of a SampleMask builtin to be an array.
// Declare it as array<u32, 1> and then store to the first element.
@ -589,7 +585,7 @@ struct CanonicalizeEntryPointIO::State {
// Recreate the original function without entry point attributes and call it.
/// @returns the inner function call expression
const ast::CallExpression* CallInnerFunction() {
const CallExpression* CallInnerFunction() {
Symbol inner_name;
if (cfg.shader_style == ShaderStyle::kGlsl) {
// In GLSL, clone the original entry point name, as the wrapper will be
@ -606,9 +602,9 @@ struct CanonicalizeEntryPointIO::State {
// The parameter attributes will have already been stripped during
// processing.
auto* inner_function =
ctx.dst->create<ast::Function>(ctx.dst->Ident(inner_name), ctx.Clone(func_ast->params),
ctx.Clone(func_ast->return_type),
ctx.Clone(func_ast->body), utils::Empty, utils::Empty);
ctx.dst->create<Function>(ctx.dst->Ident(inner_name), ctx.Clone(func_ast->params),
ctx.Clone(func_ast->return_type), ctx.Clone(func_ast->body),
utils::Empty, utils::Empty);
ctx.Replace(func_ast, inner_function);
// Call the function.
@ -619,12 +615,11 @@ struct CanonicalizeEntryPointIO::State {
void Process() {
bool needs_fixed_sample_mask = false;
bool needs_vertex_point_size = false;
if (func_ast->PipelineStage() == ast::PipelineStage::kFragment &&
if (func_ast->PipelineStage() == PipelineStage::kFragment &&
cfg.fixed_sample_mask != 0xFFFFFFFF) {
needs_fixed_sample_mask = true;
}
if (func_ast->PipelineStage() == ast::PipelineStage::kVertex &&
cfg.emit_vertex_point_size) {
if (func_ast->PipelineStage() == PipelineStage::kVertex && cfg.emit_vertex_point_size) {
needs_vertex_point_size = true;
}
@ -656,7 +651,7 @@ struct CanonicalizeEntryPointIO::State {
auto* call_inner = CallInnerFunction();
// Process the return type, and start building the wrapper function body.
std::function<ast::Type()> wrapper_ret_type = [&] { return ctx.dst->ty.void_(); };
std::function<Type()> wrapper_ret_type = [&] { return ctx.dst->ty.void_(); };
if (func_sem->ReturnType()->Is<type::Void>()) {
// The function call is just a statement with no result.
wrapper_body.Push(ctx.dst->CallStmt(call_inner));
@ -693,10 +688,10 @@ struct CanonicalizeEntryPointIO::State {
}
if (cfg.shader_style == ShaderStyle::kGlsl &&
func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
func_ast->PipelineStage() == PipelineStage::kVertex) {
auto* pos_y = GLPosition("y");
auto* negate_pos_y =
ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNegation, GLPosition("y"));
ctx.dst->create<UnaryOpExpression>(UnaryOp::kNegation, GLPosition("y"));
wrapper_body.Push(ctx.dst->Assign(pos_y, negate_pos_y));
auto* two_z = ctx.dst->Mul(ctx.dst->Expr(2_f), GLPosition("z"));
@ -714,7 +709,7 @@ struct CanonicalizeEntryPointIO::State {
name = ctx.Clone(func_ast->name->symbol);
}
auto* wrapper_func = ctx.dst->create<ast::Function>(
auto* wrapper_func = ctx.dst->create<Function>(
ctx.dst->Ident(name), wrapper_ep_parameters, ctx.dst->ty(wrapper_ret_type()),
ctx.dst->Block(wrapper_body), ctx.Clone(func_ast->attributes), utils::Empty);
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast, wrapper_func);
@ -726,14 +721,14 @@ struct CanonicalizeEntryPointIO::State {
/// @param address_space the address space (input or output)
/// @returns the gl_ string corresponding to that builtin
const char* GLSLBuiltinToString(builtin::BuiltinValue builtin,
ast::PipelineStage stage,
PipelineStage stage,
builtin::AddressSpace address_space) {
switch (builtin) {
case builtin::BuiltinValue::kPosition:
switch (stage) {
case ast::PipelineStage::kVertex:
case PipelineStage::kVertex:
return "gl_Position";
case ast::PipelineStage::kFragment:
case PipelineStage::kFragment:
return "gl_FragCoord";
default:
return "";
@ -775,9 +770,9 @@ struct CanonicalizeEntryPointIO::State {
/// @param ast_type (inout) the incoming WGSL and outgoing GLSL types
/// @returns an expression representing the GLSL builtin converted to what
/// WGSL expects
const ast::Expression* FromGLSLBuiltin(builtin::BuiltinValue builtin,
const ast::Expression* value,
ast::Type& ast_type) {
const Expression* FromGLSLBuiltin(builtin::BuiltinValue builtin,
const Expression* value,
Type& ast_type) {
switch (builtin) {
case builtin::BuiltinValue::kVertexIndex:
case builtin::BuiltinValue::kInstanceIndex:
@ -805,8 +800,8 @@ struct CanonicalizeEntryPointIO::State {
/// @param value the value to convert
/// @param type (out) the type to which the value was converted
/// @returns the converted value which can be assigned to the GLSL builtin
const ast::Expression* ToGLSLBuiltin(builtin::BuiltinValue builtin,
const ast::Expression* value,
const Expression* ToGLSLBuiltin(builtin::BuiltinValue builtin,
const Expression* value,
const type::Type*& type) {
switch (builtin) {
case builtin::BuiltinValue::kVertexIndex:
@ -839,7 +834,7 @@ Transform::ApplyResult CanonicalizeEntryPointIO::Apply(const Program* src,
// Remove entry point IO attributes from struct declarations.
// New structures will be created for each entry point, as necessary.
for (auto* ty : src->AST().TypeDecls()) {
if (auto* struct_ty = ty->As<ast::Struct>()) {
if (auto* struct_ty = ty->As<Struct>()) {
for (auto* member : struct_ty->members) {
for (auto* attr : member->attributes) {
if (IsShaderIOAttribute(attr)) {

View File

@ -51,7 +51,7 @@ struct ClampFragDepth::State {
Transform::ApplyResult Run() {
// Abort on any use of push constants in the module.
for (auto* global : src->AST().GlobalVariables()) {
if (auto* var = global->As<ast::Var>()) {
if (auto* var = global->As<Var>()) {
auto* v = src->Sem().Get(var);
if (TINT_UNLIKELY(v->AddressSpace() == builtin::AddressSpace::kPushConstant)) {
TINT_ICE(Transform, b.Diagnostics())
@ -101,14 +101,14 @@ struct ClampFragDepth::State {
// Map of io struct to helper function to return the structure with the depth clamping
// applied.
utils::Hashmap<const ast::Struct*, Symbol, 4u> io_structs_clamp_helpers;
utils::Hashmap<const Struct*, Symbol, 4u> io_structs_clamp_helpers;
// Register a callback that will be called for each visted AST function.
// This call wraps the cloning of the function's statements, and will assign to
// `returns_frag_depth_as_value` or `returns_frag_depth_as_struct_helper` if the function's
// return value requires depth clamping.
ctx.ReplaceAll([&](const ast::Function* fn) {
if (fn->PipelineStage() != ast::PipelineStage::kFragment) {
ctx.ReplaceAll([&](const Function* fn) {
if (fn->PipelineStage() != PipelineStage::kFragment) {
return ctx.CloneWithoutTransform(fn);
}
@ -129,9 +129,9 @@ struct ClampFragDepth::State {
auto fn_sym =
b.Symbols().New("clamp_frag_depth_" + struct_ty->name->symbol.Name());
utils::Vector<const ast::Expression*, 8u> initializer_args;
utils::Vector<const Expression*, 8u> initializer_args;
for (auto* member : struct_ty->members) {
const ast::Expression* arg =
const Expression* arg =
b.MemberAccessor("s", ctx.Clone(member->name->symbol));
if (ContainsFragDepth(member->attributes)) {
arg = b.Call(base_fn_sym, arg);
@ -154,7 +154,7 @@ struct ClampFragDepth::State {
});
// Replace the return statements `return expr` with `return clamp_frag_depth(expr)`.
ctx.ReplaceAll([&](const ast::ReturnStatement* stmt) -> const ast::ReturnStatement* {
ctx.ReplaceAll([&](const ReturnStatement* stmt) -> const ReturnStatement* {
if (returns_frag_depth_as_value) {
return b.Return(stmt->source, b.Call(base_fn_sym, ctx.Clone(stmt->value)));
}
@ -173,7 +173,7 @@ struct ClampFragDepth::State {
/// @returns true if the transform should run
bool ShouldRun() {
for (auto* fn : src->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kFragment &&
if (fn->PipelineStage() == PipelineStage::kFragment &&
(ReturnsFragDepthAsValue(fn) || ReturnsFragDepthInStruct(fn))) {
return true;
}
@ -183,9 +183,9 @@ struct ClampFragDepth::State {
}
/// @param attrs the attributes to examine
/// @returns true if @p attrs contains a `@builtin(frag_depth)` attribute
bool ContainsFragDepth(utils::VectorRef<const ast::Attribute*> attrs) {
bool ContainsFragDepth(utils::VectorRef<const Attribute*> attrs) {
for (auto* attribute : attrs) {
if (auto* builtin_attr = attribute->As<ast::BuiltinAttribute>()) {
if (auto* builtin_attr = attribute->As<BuiltinAttribute>()) {
auto builtin = sem.Get(builtin_attr)->Value();
if (builtin == builtin::BuiltinValue::kFragDepth) {
return true;
@ -198,14 +198,14 @@ struct ClampFragDepth::State {
/// @param fn the function to examine
/// @returns true if @p fn has a return type with a `@builtin(frag_depth)` attribute
bool ReturnsFragDepthAsValue(const ast::Function* fn) {
bool ReturnsFragDepthAsValue(const Function* fn) {
return ContainsFragDepth(fn->return_type_attributes);
}
/// @param fn the function to examine
/// @returns true if @p fn has a return structure with a `@builtin(frag_depth)` attribute on one
/// of the members
bool ReturnsFragDepthInStruct(const ast::Function* fn) {
bool ReturnsFragDepthInStruct(const Function* fn) {
if (auto* struct_ty = sem.Get(fn)->ReturnType()->As<sem::Struct>()) {
for (auto* member : struct_ty->Members()) {
if (ContainsFragDepth(member->Declaration()->attributes)) {

View File

@ -61,7 +61,7 @@ struct CombineSamplers::State {
/// Map from a texture/sampler pair to the corresponding combined sampler
/// variable
using CombinedTextureSamplerMap = std::unordered_map<sem::VariablePair, const ast::Variable*>;
using CombinedTextureSamplerMap = std::unordered_map<sem::VariablePair, const Variable*>;
/// Use sem::BindingPoint without scope.
using BindingPoint = sem::BindingPoint;
@ -79,15 +79,14 @@ struct CombineSamplers::State {
/// references (one comparison sampler, one regular). These are also used as
/// temporary sampler parameters to the texture builtins to satisfy the WGSL
/// resolver, but are then ignored and removed by the GLSL writer.
const ast::Variable* placeholder_samplers_[2] = {};
const Variable* placeholder_samplers_[2] = {};
/// Group and binding attributes used by all combined sampler globals.
/// Group 0 and binding 0 are used, with collisions disabled.
/// @returns the newly-created attribute list
auto Attributes() const {
utils::Vector<const ast::Attribute*, 3> attributes{ctx.dst->Group(0_a),
ctx.dst->Binding(0_a)};
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision));
utils::Vector<const Attribute*, 3> attributes{ctx.dst->Group(0_a), ctx.dst->Binding(0_a)};
attributes.Push(ctx.dst->Disable(DisabledValidation::kBindingPointCollision));
return attributes;
}
@ -103,7 +102,7 @@ struct CombineSamplers::State {
/// @param sampler_var the sampler (global) variable
/// @param name the default name to use (may be overridden by map lookup)
/// @returns the newly-created global variable
const ast::Variable* CreateCombinedGlobal(const sem::Variable* texture_var,
const Variable* CreateCombinedGlobal(const sem::Variable* texture_var,
const sem::Variable* sampler_var,
std::string name) {
SamplerTexturePair bp_pair;
@ -115,7 +114,7 @@ struct CombineSamplers::State {
if (it != binding_info->binding_map.end()) {
name = it->second;
}
ast::Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
Symbol symbol = ctx.dst->Symbols().New(name);
return ctx.dst->GlobalVar(symbol, type, Attributes());
}
@ -123,8 +122,8 @@ struct CombineSamplers::State {
/// Creates placeholder global sampler variables.
/// @param kind the sampler kind to create for
/// @returns the newly-created global variable
const ast::Variable* CreatePlaceholder(type::SamplerKind kind) {
ast::Type type = ctx.dst->ty.sampler(kind);
const Variable* CreatePlaceholder(type::SamplerKind kind) {
Type type = ctx.dst->ty.sampler(kind);
const char* name = kind == type::SamplerKind::kComparisonSampler
? "placeholder_comparison_sampler"
: "placeholder_sampler";
@ -132,13 +131,13 @@ struct CombineSamplers::State {
return ctx.dst->GlobalVar(symbol, type, Attributes());
}
/// Creates ast::Identifier for a given texture and sampler variable pair.
/// Creates Identifier for a given texture and sampler variable pair.
/// Depth textures with no samplers are turned into the corresponding
/// f32 texture (e.g., texture_depth_2d -> texture_2d<f32>).
/// @param texture the texture variable of interest
/// @param sampler the texture variable of interest
/// @returns the newly-created type
ast::Type CreateCombinedASTTypeFor(const sem::Variable* texture, const sem::Variable* sampler) {
Type CreateCombinedASTTypeFor(const sem::Variable* texture, const sem::Variable* sampler) {
const type::Type* texture_type = texture->Type()->UnwrapRef();
const type::DepthTexture* depth = texture_type->As<type::DepthTexture>();
if (depth && !sampler) {
@ -163,8 +162,7 @@ struct CombineSamplers::State {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), global);
} else if (auto binding_point = global_sem->BindingPoint()) {
if (binding_point->group == 0 && binding_point->binding == 0) {
auto* attribute =
ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
auto* attribute = ctx.dst->Disable(DisabledValidation::kBindingPointCollision);
ctx.InsertFront(global->attributes, attribute);
}
}
@ -172,13 +170,13 @@ struct CombineSamplers::State {
// Rewrite all function signatures to use combined samplers, and remove
// separate textures & samplers. Create new combined globals where found.
ctx.ReplaceAll([&](const ast::Function* ast_fn) -> const ast::Function* {
ctx.ReplaceAll([&](const Function* ast_fn) -> const Function* {
if (auto* fn = sem.Get(ast_fn)) {
auto pairs = fn->TextureSamplerPairs();
if (pairs.IsEmpty()) {
return nullptr;
}
utils::Vector<const ast::Parameter*, 8> params;
utils::Vector<const Parameter*, 8> params;
for (auto pair : fn->TextureSamplerPairs()) {
const sem::Variable* texture_var = pair.first;
const sem::Variable* sampler_var = pair.second;
@ -195,7 +193,7 @@ struct CombineSamplers::State {
} else {
// Either texture or sampler (or both) is a function parameter;
// add a new function parameter to represent the combined sampler.
ast::Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type);
params.Push(var);
function_combined_texture_samplers_[fn][pair] = var;
@ -215,7 +213,7 @@ struct CombineSamplers::State {
auto* body = ctx.Clone(ast_fn->body);
auto attributes = ctx.Clone(ast_fn->attributes);
auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes);
return ctx.dst->create<ast::Function>(name, params, return_type, body,
return ctx.dst->create<Function>(name, params, return_type, body,
std::move(attributes),
std::move(return_type_attributes));
}
@ -225,9 +223,9 @@ struct CombineSamplers::State {
// Replace all function call expressions containing texture or
// sampler parameters to use the current function's combined samplers or
// the combined global samplers, as appropriate.
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::Expression* {
ctx.ReplaceAll([&](const CallExpression* expr) -> const Expression* {
if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
utils::Vector<const ast::Expression*, 8> args;
utils::Vector<const Expression*, 8> args;
// Replace all texture builtin calls.
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
const auto& signature = builtin->Signature();
@ -254,7 +252,7 @@ struct CombineSamplers::State {
for (auto* arg : expr->args) {
auto* type = ctx.src->TypeOf(arg)->UnwrapRef();
if (type->Is<type::Texture>()) {
const ast::Variable* var =
const Variable* var =
IsGlobal(new_pair)
? global_combined_texture_samplers_[new_pair]
: function_combined_texture_samplers_[call->Stmt()->Function()]
@ -263,7 +261,7 @@ struct CombineSamplers::State {
} else if (auto* sampler_type = type->As<type::Sampler>()) {
type::SamplerKind kind = sampler_type->kind();
int index = (kind == type::SamplerKind::kSampler) ? 0 : 1;
const ast::Variable*& p = placeholder_samplers_[index];
const Variable*& p = placeholder_samplers_[index];
if (!p) {
p = CreatePlaceholder(kind);
}
@ -272,10 +270,10 @@ struct CombineSamplers::State {
args.Push(ctx.Clone(arg));
}
}
const ast::Expression* value = ctx.dst->Call(ctx.Clone(expr->target), args);
const Expression* value = ctx.dst->Call(ctx.Clone(expr->target), args);
if (builtin->Type() == builtin::Function::kTextureLoad &&
texture_var->Type()->UnwrapRef()->Is<type::DepthTexture>() &&
!call->Stmt()->Declaration()->Is<ast::CallStatement>()) {
!call->Stmt()->Declaration()->Is<CallStatement>()) {
value = ctx.dst->MemberAccessor(value, "x");
}
return value;
@ -307,7 +305,7 @@ struct CombineSamplers::State {
// If both texture and sampler are (now) global, pass that
// global variable to the callee. Otherwise use the caller's
// function parameter for this pair.
const ast::Variable* var =
const Variable* var =
IsGlobal(new_pair)
? global_combined_texture_samplers_[new_pair]
: function_combined_texture_samplers_[call->Stmt()->Function()]

View File

@ -60,21 +60,21 @@ bool ShouldRun(const Program* program) {
return false;
}
/// Offset is a simple ast::Expression builder interface, used to build byte
/// Offset is a simple Expression builder interface, used to build byte
/// offsets for storage and uniform buffer accesses.
struct Offset : utils::Castable<Offset> {
/// @returns builds and returns the ast::Expression in `ctx.dst`
virtual const ast::Expression* Build(CloneContext& ctx) const = 0;
/// @returns builds and returns the Expression in `ctx.dst`
virtual const Expression* Build(CloneContext& ctx) const = 0;
};
/// OffsetExpr is an implementation of Offset that clones and casts the given
/// expression to `u32`.
struct OffsetExpr : Offset {
const ast::Expression* const expr = nullptr;
const Expression* const expr = nullptr;
explicit OffsetExpr(const ast::Expression* e) : expr(e) {}
explicit OffsetExpr(const Expression* e) : expr(e) {}
const ast::Expression* Build(CloneContext& ctx) const override {
const Expression* Build(CloneContext& ctx) const override {
auto* type = ctx.src->Sem().GetVal(expr)->Type()->UnwrapRef();
auto* res = ctx.Clone(expr);
if (!type->Is<type::U32>()) {
@ -91,7 +91,7 @@ struct OffsetLiteral final : utils::Castable<OffsetLiteral, Offset> {
explicit OffsetLiteral(uint32_t lit) : literal(lit) {}
const ast::Expression* Build(CloneContext& ctx) const override {
const Expression* Build(CloneContext& ctx) const override {
return ctx.dst->Expr(u32(literal));
}
};
@ -99,12 +99,12 @@ struct OffsetLiteral final : utils::Castable<OffsetLiteral, Offset> {
/// OffsetBinOp is an implementation of Offset that constructs a binary-op of
/// two Offsets.
struct OffsetBinOp : Offset {
ast::BinaryOp op;
BinaryOp op;
Offset const* lhs = nullptr;
Offset const* rhs = nullptr;
const ast::Expression* Build(CloneContext& ctx) const override {
return ctx.dst->create<ast::BinaryExpression>(op, lhs->Build(ctx), rhs->Build(ctx));
const Expression* Build(CloneContext& ctx) const override {
return ctx.dst->create<BinaryExpression>(op, lhs->Build(ctx), rhs->Build(ctx));
}
};
@ -313,7 +313,7 @@ struct BufferAccess {
/// Store describes a single storage or uniform buffer write
struct Store {
const ast::AssignmentStatement* assignment; // The AST assignment statement
const AssignmentStatement* assignment; // The AST assignment statement
BufferAccess target; // The target for the write
};
@ -330,9 +330,9 @@ struct DecomposeMemoryAccess::State {
/// expressions chain the access.
/// Subset of #expression_order, as expressions are not removed from
/// #expression_order.
std::unordered_map<const ast::Expression*, BufferAccess> accesses;
std::unordered_map<const Expression*, BufferAccess> accesses;
/// The visited order of AST expressions (superset of #accesses)
std::vector<const ast::Expression*> expression_order;
std::vector<const Expression*> expression_order;
/// [buffer-type, element-type] -> load function name
std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs;
/// [buffer-type, element-type] -> store function name
@ -353,9 +353,9 @@ struct DecomposeMemoryAccess::State {
const Offset* ToOffset(uint32_t offset) { return offsets_.Create<OffsetLiteral>(offset); }
/// @param expr the expression to convert to an Offset
/// @returns an Offset for the given ast::Expression
const Offset* ToOffset(const ast::Expression* expr) {
if (auto* lit = expr->As<ast::IntLiteralExpression>()) {
/// @returns an Offset for the given Expression
const Offset* ToOffset(const Expression* expr) {
if (auto* lit = expr->As<IntLiteralExpression>()) {
if (lit->value >= 0) {
return offsets_.Create<OffsetLiteral>(static_cast<uint32_t>(lit->value));
}
@ -390,7 +390,7 @@ struct DecomposeMemoryAccess::State {
}
}
auto* out = offsets_.Create<OffsetBinOp>();
out->op = ast::BinaryOp::kAdd;
out->op = BinaryOp::kAdd;
out->lhs = lhs;
out->rhs = rhs;
return out;
@ -422,7 +422,7 @@ struct DecomposeMemoryAccess::State {
return offsets_.Create<OffsetLiteral>(lhs_lit->literal * rhs_lit->literal);
}
auto* out = offsets_.Create<OffsetBinOp>();
out->op = ast::BinaryOp::kMultiply;
out->op = BinaryOp::kMultiply;
out->lhs = lhs;
out->rhs = rhs;
return out;
@ -432,7 +432,7 @@ struct DecomposeMemoryAccess::State {
/// to #expression_order.
/// @param expr the expression that performs the access
/// @param access the access
void AddAccess(const ast::Expression* expr, const BufferAccess& access) {
void AddAccess(const Expression* expr, const BufferAccess& access) {
TINT_ASSERT(Transform, access.type);
accesses.emplace(expr, access);
expression_order.emplace_back(expr);
@ -443,7 +443,7 @@ struct DecomposeMemoryAccess::State {
/// `node`, an invalid BufferAccess is returned.
/// @param node the expression that performed an access
/// @return the BufferAccess for the given expression
BufferAccess TakeAccess(const ast::Expression* node) {
BufferAccess TakeAccess(const Expression* node) {
auto lhs_it = accesses.find(node);
if (lhs_it == accesses.end()) {
return {};
@ -475,7 +475,7 @@ struct DecomposeMemoryAccess::State {
b.Func(name, params, el_ast_ty, nullptr,
utils::Vector{
intrinsic,
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
b.Disable(DisabledValidation::kFunctionHasNoBody),
});
} else if (auto* arr_ty = el_ty->As<type::Array>()) {
// fn load_func(buffer : buf_ty, offset : u32) -> array<T, N> {
@ -498,8 +498,8 @@ struct DecomposeMemoryAccess::State {
TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count";
arr_cnt = 1;
}
auto* for_cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_cnt.value())));
auto* for_cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr(i),
b.Expr(u32(arr_cnt.value())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(arr, i);
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
@ -514,7 +514,7 @@ struct DecomposeMemoryAccess::State {
b.Return(arr),
});
} else {
utils::Vector<const ast::Expression*, 8> values;
utils::Vector<const Expression*, 8> values;
if (auto* mat_ty = el_ty->As<type::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType();
Symbol load = LoadFunc(vec_ty, address_space, buffer);
@ -557,10 +557,10 @@ struct DecomposeMemoryAccess::State {
b.Func(name, params, b.ty.void_(), nullptr,
utils::Vector{
intrinsic,
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
b.Disable(DisabledValidation::kFunctionHasNoBody),
});
} else {
auto body = Switch<utils::Vector<const ast::Statement*, 8>>(
auto body = Switch<utils::Vector<const Statement*, 8>>(
el_ty, //
[&](const type::Array* arr_ty) {
// fn store_func(buffer : buf_ty, offset : u32, value : el_ty) {
@ -585,8 +585,8 @@ struct DecomposeMemoryAccess::State {
<< "unexpected non-constant array count";
arr_cnt = 1;
}
auto* for_cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_cnt.value())));
auto* for_cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr(i),
b.Expr(u32(arr_cnt.value())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(array, i);
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
@ -598,7 +598,7 @@ struct DecomposeMemoryAccess::State {
[&](const type::Matrix* mat_ty) {
auto* vec_ty = mat_ty->ColumnType();
Symbol store = StoreFunc(vec_ty, buffer);
utils::Vector<const ast::Statement*, 4> stmts;
utils::Vector<const Statement*, 4> stmts;
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset = b.Add("offset", u32(i * mat_ty->ColumnStride()));
auto* element = b.IndexAccessor("value", u32(i));
@ -608,7 +608,7 @@ struct DecomposeMemoryAccess::State {
return stmts;
},
[&](const type::Struct* str) {
utils::Vector<const ast::Statement*, 8> stmts;
utils::Vector<const Statement*, 8> stmts;
for (auto* member : str->Members()) {
auto* offset = b.Add("offset", u32(member->Offset()));
auto* element = b.MemberAccessor("value", ctx.Clone(member->Name()));
@ -656,14 +656,14 @@ struct DecomposeMemoryAccess::State {
<< el_ty->TypeInfo().name;
}
ast::Type ret_ty;
Type ret_ty;
// For intrinsics that return a struct, there is no AST node for it, so create one now.
if (intrinsic->Type() == builtin::Function::kAtomicCompareExchangeWeak) {
auto* str = intrinsic->ReturnType()->As<type::Struct>();
TINT_ASSERT(Transform, str);
utils::Vector<const ast::StructMember*, 8> ast_members;
utils::Vector<const StructMember*, 8> ast_members;
ast_members.Reserve(str->Members().Length());
for (auto& m : str->Members()) {
ast_members.Push(
@ -681,7 +681,7 @@ struct DecomposeMemoryAccess::State {
b.Func(name, std::move(params), ret_ty, nullptr,
utils::Vector{
atomic,
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
b.Disable(DisabledValidation::kFunctionHasNoBody),
});
return name;
});
@ -689,11 +689,11 @@ struct DecomposeMemoryAccess::State {
};
DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID pid,
ast::NodeID nid,
NodeID nid,
Op o,
DataType ty,
builtin::AddressSpace as,
const ast::IdentifierExpression* buf)
const IdentifierExpression* buf)
: Base(pid, nid, utils::Vector{buf}), op(o), type(ty), address_space(as) {}
DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default;
std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
@ -804,7 +804,7 @@ bool DecomposeMemoryAccess::Intrinsic::IsAtomic() const {
return op != Op::kLoad && op != Op::kStore;
}
const ast::IdentifierExpression* DecomposeMemoryAccess::Intrinsic::Buffer() const {
const IdentifierExpression* DecomposeMemoryAccess::Intrinsic::Buffer() const {
return dependencies[0];
}
@ -832,7 +832,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
// nodes are fully immutable and require their children to be constructed
// first so their pointer can be passed to the parent's initializer.
for (auto* node : src->ASTNodes().Objects()) {
if (auto* ident = node->As<ast::IdentifierExpression>()) {
if (auto* ident = node->As<IdentifierExpression>()) {
// X
if (auto* sem_ident = sem.GetVal(ident)) {
if (auto* user = sem_ident->UnwrapLoad()->As<sem::VariableUser>()) {
@ -852,7 +852,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
continue;
}
if (auto* accessor = node->As<ast::MemberAccessorExpression>()) {
if (auto* accessor = node->As<MemberAccessorExpression>()) {
// X.Y
auto* accessor_sem = sem.Get(accessor)->UnwrapLoad();
if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
@ -882,7 +882,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
continue;
}
if (auto* accessor = node->As<ast::IndexAccessorExpression>()) {
if (auto* accessor = node->As<IndexAccessorExpression>()) {
if (auto access = state.TakeAccess(accessor->object)) {
// X[Y]
if (auto* arr = access.type->As<type::Array>()) {
@ -915,8 +915,8 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
}
}
if (auto* op = node->As<ast::UnaryOpExpression>()) {
if (op->op == ast::UnaryOp::kAddressOf) {
if (auto* op = node->As<UnaryOpExpression>()) {
if (op->op == UnaryOp::kAddressOf) {
// &X
if (auto access = state.TakeAccess(op->expr)) {
// HLSL does not support pointers, so just take the access from the
@ -927,7 +927,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
}
}
if (auto* assign = node->As<ast::AssignmentStatement>()) {
if (auto* assign = node->As<AssignmentStatement>()) {
// X = Y
// Move the LHS access to a store.
if (auto lhs = state.TakeAccess(assign->lhs)) {
@ -935,7 +935,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
}
}
if (auto* call_expr = node->As<ast::CallExpression>()) {
if (auto* call_expr = node->As<CallExpression>()) {
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
if (builtin->Type() == builtin::Function::kArrayLength) {
@ -953,7 +953,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
auto buffer = ctx.Clone(access.var->Declaration()->name->symbol);
Symbol func = state.AtomicFunc(el_ty, builtin, buffer);
utils::Vector<const ast::Expression*, 8> args{offset};
utils::Vector<const Expression*, 8> args{offset};
for (size_t i = 1; i < call_expr->args.Length(); i++) {
auto* arg = call_expr->args[i];
args.Push(ctx.Clone(arg));

View File

@ -35,7 +35,7 @@ class DecomposeMemoryAccess final : public utils::Castable<DecomposeMemoryAccess
/// transforms this into calls to
/// `[RW]ByteAddressBuffer.Load[N]()` or `[RW]ByteAddressBuffer.Store[N]()`,
/// with a possible cast.
class Intrinsic final : public utils::Castable<Intrinsic, ast::InternalAttribute> {
class Intrinsic final : public utils::Castable<Intrinsic, InternalAttribute> {
public:
/// Intrinsic op
enum class Op {
@ -82,11 +82,11 @@ class DecomposeMemoryAccess final : public utils::Castable<DecomposeMemoryAccess
/// @param address_space the address space of the buffer
/// @param buffer the storage or uniform buffer identifier
Intrinsic(ProgramID pid,
ast::NodeID nid,
NodeID nid,
Op o,
DataType type,
builtin::AddressSpace address_space,
const ast::IdentifierExpression* buffer);
const IdentifierExpression* buffer);
/// Destructor
~Intrinsic() override;
@ -103,7 +103,7 @@ class DecomposeMemoryAccess final : public utils::Castable<DecomposeMemoryAccess
bool IsAtomic() const;
/// @return the buffer that this intrinsic operates on
const ast::IdentifierExpression* Buffer() const;
const IdentifierExpression* Buffer() const;
/// The op of the intrinsic
const Op op;

View File

@ -37,8 +37,8 @@ using DecomposedArrays = std::unordered_map<const type::Array*, Symbol>;
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* ident = node->As<ast::TemplatedIdentifier>()) {
if (ast::GetAttribute<ast::StrideAttribute>(ident->attributes)) {
if (auto* ident = node->As<TemplatedIdentifier>()) {
if (GetAttribute<StrideAttribute>(ident->attributes)) {
return true;
}
}
@ -74,8 +74,8 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
// stride for the array element type, then replace the array element type with
// a structure, holding a single field with a @size attribute equal to the
// array stride.
ctx.ReplaceAll([&](const ast::IdentifierExpression* expr) -> const ast::IdentifierExpression* {
auto* ident = expr->identifier->As<ast::TemplatedIdentifier>();
ctx.ReplaceAll([&](const IdentifierExpression* expr) -> const IdentifierExpression* {
auto* ident = expr->identifier->As<TemplatedIdentifier>();
if (!ident) {
return nullptr;
}
@ -90,8 +90,8 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
if (!arr->IsStrideImplicit()) {
auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
auto name = b.Symbols().New("strided_arr");
auto* member_ty = ctx.Clone(ident->arguments[0]->As<ast::IdentifierExpression>());
auto* member = b.Member(kMemberName, ast::Type{member_ty},
auto* member_ty = ctx.Clone(ident->arguments[0]->As<IdentifierExpression>());
auto* member = b.Member(kMemberName, Type{member_ty},
utils::Vector{
b.MemberSize(AInt(arr->Stride())),
});
@ -105,14 +105,14 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
return b.Expr(b.ty.array(b.ty(el_ty)));
}
}
if (ast::GetAttribute<ast::StrideAttribute>(ident->attributes)) {
if (GetAttribute<StrideAttribute>(ident->attributes)) {
// Strip the @stride attribute
auto* ty = ctx.Clone(ident->arguments[0]->As<ast::IdentifierExpression>());
auto* ty = ctx.Clone(ident->arguments[0]->As<IdentifierExpression>());
if (ident->arguments.Length() > 1) {
auto* count = ctx.Clone(ident->arguments[1]);
return b.Expr(b.ty.array(ast::Type{ty}, count));
return b.Expr(b.ty.array(Type{ty}, count));
} else {
return b.Expr(b.ty.array(ast::Type{ty}));
return b.Expr(b.ty.array(Type{ty}));
}
}
return nullptr;
@ -122,7 +122,7 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
// element changed to a single field structure. These expressions are adjusted
// to insert an additional member accessor for the single structure field.
// Example: `arr[i]` -> `arr[i].el`
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* {
ctx.ReplaceAll([&](const IndexAccessorExpression* idx) -> const Expression* {
if (auto* ty = src->TypeOf(idx->object)) {
if (auto* arr = ty->UnwrapRef()->As<type::Array>()) {
if (!arr->IsStrideImplicit()) {
@ -140,7 +140,7 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
// `@stride(32) array<i32, 3>(1, 2, 3)`
// ->
// `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))`
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::Expression* {
ctx.ReplaceAll([&](const CallExpression* expr) -> const Expression* {
if (!expr->args.IsEmpty()) {
if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
if (auto* ctor = call->Target()->As<sem::ValueConstructor>()) {
@ -153,7 +153,7 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
auto* target = ctx.Clone(expr->target);
utils::Vector<const ast::Expression*, 8> args;
utils::Vector<const Expression*, 8> args;
if (auto it = decomposed.find(arr); it != decomposed.end()) {
args.Reserve(expr->args.Length());
for (auto* arg : expr->args) {

View File

@ -101,7 +101,7 @@ TEST_F(DecomposeStridedArrayTest, PrivateDefaultStridedArray) {
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -145,7 +145,7 @@ TEST_F(DecomposeStridedArrayTest, PrivateStridedArray) {
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -195,7 +195,7 @@ TEST_F(DecomposeStridedArrayTest, ReadUniformStridedArray) {
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -253,7 +253,7 @@ TEST_F(DecomposeStridedArrayTest, ReadUniformDefaultStridedArray) {
b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 2_i))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -303,7 +303,7 @@ TEST_F(DecomposeStridedArrayTest, ReadStorageStridedArray) {
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -357,7 +357,7 @@ TEST_F(DecomposeStridedArrayTest, ReadStorageDefaultStridedArray) {
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -410,7 +410,7 @@ TEST_F(DecomposeStridedArrayTest, WriteStorageStridedArray) {
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -472,7 +472,7 @@ TEST_F(DecomposeStridedArrayTest, WriteStorageDefaultStridedArray) {
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -531,7 +531,7 @@ TEST_F(DecomposeStridedArrayTest, ReadWriteViaPointerLets) {
b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), 5_f),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -593,7 +593,7 @@ TEST_F(DecomposeStridedArrayTest, PrivateAliasedStridedArray) {
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -696,7 +696,7 @@ TEST_F(DecomposeStridedArrayTest, PrivateNestedStridedArray) {
5_f),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});

View File

@ -38,7 +38,7 @@ struct MatrixInfo {
const type::Matrix* matrix = nullptr;
/// @returns the identifier of an array that holds an vector column for each row of the matrix.
ast::Type array(ProgramBuilder* b) const {
Type array(ProgramBuilder* b) const {
return b->ty.array(b->ty.vec<f32>(matrix->rows()), u32(matrix->columns()),
utils::Vector{
b->Stride(stride),
@ -72,7 +72,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
// and populate the `decomposed` map with the members that have been replaced.
utils::Hashmap<const type::StructMember*, MatrixInfo, 8> decomposed;
for (auto* node : src->ASTNodes().Objects()) {
if (auto* str = node->As<ast::Struct>()) {
if (auto* str = node->As<Struct>()) {
auto* str_ty = src->Sem().Get(str);
if (!str_ty->UsedAs(builtin::AddressSpace::kUniform) &&
!str_ty->UsedAs(builtin::AddressSpace::kStorage)) {
@ -83,8 +83,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
if (!matrix) {
continue;
}
auto* attr =
ast::GetAttribute<ast::StrideAttribute>(member->Declaration()->attributes);
auto* attr = GetAttribute<StrideAttribute>(member->Declaration()->attributes);
if (!attr) {
continue;
}
@ -111,8 +110,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
// preserve these without calling conversion functions.
// Example:
// ssbo.mat[2] -> ssbo.mat[2]
ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
ctx.ReplaceAll([&](const IndexAccessorExpression* expr) -> const IndexAccessorExpression* {
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
if (decomposed.Contains(access->Member())) {
auto* obj = ctx.CloneWithoutTransform(expr->object);
@ -129,7 +127,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
// Example:
// ssbo.mat = mat_to_arr(m)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
ctx.ReplaceAll([&](const AssignmentStatement* stmt) -> const Statement* {
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
if (auto info = decomposed.Find(access->Member())) {
auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
@ -142,7 +140,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
auto array = [&] { return info->array(ctx.dst); };
auto mat = b.Sym("m");
utils::Vector<const ast::Expression*, 4> columns;
utils::Vector<const Expression*, 4> columns;
for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
columns.Push(b.IndexAccessor(mat, u32(i)));
}
@ -168,7 +166,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
// matrix type. Example:
// m = arr_to_mat(ssbo.mat)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
ctx.ReplaceAll([&](const MemberAccessorExpression* expr) -> const Expression* {
if (auto* access = src->Sem().Get(expr)->UnwrapLoad()->As<sem::StructMemberAccess>()) {
if (auto info = decomposed.Find(access->Member())) {
auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
@ -181,7 +179,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
auto array = [&] { return info->array(ctx.dst); };
auto arr = b.Sym("arr");
utils::Vector<const ast::Expression*, 4> columns;
utils::Vector<const Expression*, 4> columns;
for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
columns.Push(b.IndexAccessor(arr, u32(i)));
}

View File

@ -67,13 +67,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S", utils::Vector{
auto* S =
b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
b.MemberOffset(16_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
b.create<StrideAttribute>(32u),
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
@ -82,7 +82,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -124,13 +124,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
// let x : vec2<f32> = s.m[1];
// }
ProgramBuilder b;
auto* S = b.Structure(
"S", utils::Vector{
auto* S =
b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
b.MemberOffset(16_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
b.create<StrideAttribute>(32u),
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
@ -140,7 +140,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -178,13 +178,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S", utils::Vector{
auto* S =
b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
b.MemberOffset(16_u),
b.create<ast::StrideAttribute>(8u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
b.create<StrideAttribute>(8u),
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
@ -193,7 +193,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -232,13 +232,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S", utils::Vector{
auto* S =
b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
b.create<StrideAttribute>(32u),
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
@ -248,7 +248,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -290,13 +290,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
// let x : vec2<f32> = s.m[1];
// }
ProgramBuilder b;
auto* S = b.Structure(
"S", utils::Vector{
auto* S =
b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
b.MemberOffset(16_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
b.create<StrideAttribute>(32u),
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
@ -307,7 +307,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -345,13 +345,13 @@ TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// }
ProgramBuilder b;
auto* S = b.Structure(
"S", utils::Vector{
auto* S =
b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
b.create<StrideAttribute>(32u),
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
@ -362,7 +362,7 @@ TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -404,13 +404,13 @@ TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
// s.m[1] = vec2<f32>(1.0, 2.0);
// }
ProgramBuilder b;
auto* S = b.Structure(
"S", utils::Vector{
auto* S =
b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
b.create<StrideAttribute>(32u),
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
@ -420,7 +420,7 @@ TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i), b.vec2<f32>(1_f, 2_f)),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -464,13 +464,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
// (*b)[1] = vec2<f32>(5.0, 6.0);
// }
ProgramBuilder b;
auto* S = b.Structure(
"S", utils::Vector{
auto* S =
b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
b.create<StrideAttribute>(32u),
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
@ -486,7 +486,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), b.vec2<f32>(5_f, 6_f)),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -536,13 +536,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
// let x : mat2x2<f32> = s.m;
// }
ProgramBuilder b;
auto* S = b.Structure(
"S", utils::Vector{
auto* S =
b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
b.create<StrideAttribute>(32u),
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kPrivate);
@ -551,7 +551,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
@ -590,13 +590,13 @@ TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// }
ProgramBuilder b;
auto* S = b.Structure(
"S", utils::Vector{
auto* S =
b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{
b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
b.create<StrideAttribute>(32u),
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}),
});
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kPrivate);
@ -606,7 +606,7 @@ TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))),
},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});

View File

@ -85,9 +85,8 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
b.GlobalVar(flag, builtin::AddressSpace::kPrivate, b.Expr(false));
// Replace all discard statements with a statement that marks the invocation as discarded.
ctx.ReplaceAll([&](const ast::DiscardStatement*) -> const ast::Statement* {
return b.Assign(flag, b.Expr(true));
});
ctx.ReplaceAll(
[&](const DiscardStatement*) -> const Statement* { return b.Assign(flag, b.Expr(true)); });
// Insert a conditional discard at the end of each entry point that does not end with a return.
for (auto* func : functions_to_process) {
@ -111,7 +110,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
node,
// Mask assignments to storage buffer variables.
[&](const ast::AssignmentStatement* assign) {
[&](const AssignmentStatement* assign) {
// Skip writes in functions that are not called from shaders that discard.
auto* func = sem.Get(assign)->Function();
if (functions_to_process.count(func) == 0) {
@ -119,7 +118,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
}
// Skip phony assignments.
if (assign->lhs->Is<ast::PhonyExpression>()) {
if (assign->lhs->Is<PhonyExpression>()) {
return;
}
@ -144,7 +143,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
},
// Mask builtins that write to host-visible memory.
[&](const ast::CallExpression* call) {
[&](const CallExpression* call) {
auto* sem_call = sem.Get<sem::Call>(call);
auto* stmt = sem_call ? sem_call->Stmt() : nullptr;
auto* func = stmt ? stmt->Function() : nullptr;
@ -161,7 +160,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
} else if (builtin->IsAtomic() &&
builtin->Type() != builtin::Function::kAtomicLoad) {
// A call to an atomic builtin can be a statement or an expression.
if (auto* call_stmt = stmt->Declaration()->As<ast::CallStatement>();
if (auto* call_stmt = stmt->Declaration()->As<CallStatement>();
call_stmt && call_stmt->expr == call) {
// This call is a statement.
// Wrap it inside a conditional block.
@ -178,8 +177,8 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
// }
// let y = x + tmp;
auto result = b.Sym();
ast::Type result_ty;
const ast::Statement* masked_call = nullptr;
Type result_ty;
const Statement* masked_call = nullptr;
if (builtin->Type() == builtin::Function::kAtomicCompareExchangeWeak) {
// Special case for atomicCompareExchangeWeak as we cannot name its
// result type. We have to declare an equivalent struct and copy the
@ -232,7 +231,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
},
// Insert a conditional discard before all return statements in entry points.
[&](const ast::ReturnStatement* ret) {
[&](const ReturnStatement* ret) {
auto* func = sem.Get(ret)->Function();
if (func->Declaration()->IsEntryPoint() && functions_to_process.count(func)) {
auto* discard = b.If(flag, b.Block(b.Discard()));

View File

@ -450,19 +450,19 @@ struct DirectVariableAccess::State {
Switch(
variable->Declaration(),
[&](const ast::Var*) {
[&](const Var*) {
if (variable->AddressSpace() != builtin::AddressSpace::kHandle) {
// Start a new access chain for the non-handle 'var' access
create_new_chain();
}
},
[&](const ast::Parameter*) {
[&](const Parameter*) {
if (variable->Type()->Is<type::Pointer>()) {
// Start a new access chain for the pointer parameter access
create_new_chain();
}
},
[&](const ast::Let*) {
[&](const Let*) {
if (variable->Type()->Is<type::Pointer>()) {
// variable is a pointer-let.
auto* init = sem.GetVal(variable->Declaration()->initializer);
@ -494,11 +494,10 @@ struct DirectVariableAccess::State {
}
},
[&](const sem::ValueExpression* e) {
if (auto* unary = e->Declaration()->As<ast::UnaryOpExpression>()) {
if (auto* unary = e->Declaration()->As<UnaryOpExpression>()) {
// Unary op.
// If this is a '&' or '*', simply move the chain to the unary op expression.
if (unary->op == ast::UnaryOp::kAddressOf ||
unary->op == ast::UnaryOp::kIndirection) {
if (unary->op == UnaryOp::kAddressOf || unary->op == UnaryOp::kIndirection) {
take_chain(sem.GetVal(unary->expr));
}
}
@ -529,7 +528,7 @@ struct DirectVariableAccess::State {
if (auto* idx_variable_user = idx->UnwrapMaterialize()->As<sem::VariableUser>()) {
auto* idx_variable = idx_variable_user->Variable();
if (idx_variable->Declaration()->IsAnyOf<ast::Let, ast::Parameter>()) {
if (idx_variable->Declaration()->IsAnyOf<Let, Parameter>()) {
// Dynamic index is an immutable variable
continue; // Hoisting not required.
}
@ -557,7 +556,7 @@ struct DirectVariableAccess::State {
/// * Casts the resulting expression to a u32 if @p cast_to_u32 is true, and the expression type
/// isn't implicitly usable as a u32. This is to help feed the expression into a
/// `array<u32, N>` argument passed to a callee variant function.
const ast::Expression* BuildDynamicIndex(const sem::ValueExpression* idx, bool cast_to_u32) {
const Expression* BuildDynamicIndex(const sem::ValueExpression* idx, bool cast_to_u32) {
if (auto* val = idx->ConstantValue()) {
// Expression evaluated to a constant value. Just emit that constant.
return b.Expr(val->ValueAs<AInt>());
@ -808,7 +807,7 @@ struct DirectVariableAccess::State {
// many variant functions, keep a record of the last created variant, and explicitly add
// this to the module if it isn't the last. We'll return the last created variant,
// taking the place of the original function.
const ast::Function* pending_variant = nullptr;
const Function* pending_variant = nullptr;
// For each variant of fn...
for (auto variant_it : fn_info->SortedVariants()) {
@ -827,7 +826,7 @@ struct DirectVariableAccess::State {
// Pointer parameters in the 'uniform', 'storage' or 'workgroup' address space are
// either replaced with an array of dynamic indices, or are dropped (if there are no
// dynamic indices).
utils::Vector<const ast::Parameter*, 8> params;
utils::Vector<const Parameter*, 8> params;
for (auto* param : fn->Parameters()) {
if (auto incoming_shape = variant_sig.Find(param)) {
auto& symbols = *variant.ptr_param_symbols.Find(param);
@ -856,7 +855,7 @@ struct DirectVariableAccess::State {
auto attrs = ctx.Clone(fn->Declaration()->attributes);
auto ret_attrs = ctx.Clone(fn->Declaration()->return_type_attributes);
pending_variant =
b.create<ast::Function>(b.Ident(variant.name), std::move(params), ret_ty, body,
b.create<Function>(b.Ident(variant.name), std::move(params), ret_ty, body,
std::move(attrs), std::move(ret_attrs));
}
@ -877,7 +876,7 @@ struct DirectVariableAccess::State {
}
// Build the new call expressions's arguments.
utils::Vector<const ast::Expression*, 8> new_args;
utils::Vector<const Expression*, 8> new_args;
for (size_t arg_idx = 0; arg_idx < call->Arguments().Length(); arg_idx++) {
auto* arg = call->Arguments()[arg_idx];
auto* param = call->Target()->Parameters()[arg_idx];
@ -915,7 +914,7 @@ struct DirectVariableAccess::State {
// Get or create the dynamic indices array.
if (auto dyn_idx_arr_ty = DynamicIndexArrayType(full_indices)) {
// Build an array of dynamic indices to pass as the replacement for the pointer.
utils::Vector<const ast::Expression*, 8> dyn_idx_args;
utils::Vector<const Expression*, 8> dyn_idx_args;
if (auto* root_param = chain->root.variable->As<sem::Parameter>()) {
// Access chain originates from a pointer parameter.
if (auto incoming_chain =
@ -985,7 +984,7 @@ struct DirectVariableAccess::State {
/// let.
void TransformAccessChainExpressions() {
// Register a custom handler for all non-function call expressions
ctx.ReplaceAll([this](const ast::Expression* ast_expr) -> const ast::Expression* {
ctx.ReplaceAll([this](const Expression* ast_expr) -> const Expression* {
if (!clone_state->current_variant) {
// Expression does not belong to a function variant.
return nullptr; // Just clone the expression.
@ -1065,7 +1064,7 @@ struct DirectVariableAccess::State {
/// @returns the type alias used to hold the dynamic indices for @p shape, declaring a new alias
/// if this is the first call for the given shape.
ast::Type DynamicIndexArrayType(const AccessShape& shape) {
Type DynamicIndexArrayType(const AccessShape& shape) {
auto name = dynamic_index_array_aliases.GetOrCreate(shape, [&] {
// Count the number of dynamic indices
uint32_t num_dyn_indices = shape.NumDynamicIndices();
@ -1076,7 +1075,7 @@ struct DirectVariableAccess::State {
b.Alias(symbol, b.ty.array(b.ty.u32(), u32(num_dyn_indices)));
return symbol;
});
return name.IsValid() ? b.ty(name) : ast::Type{};
return name.IsValid() ? b.ty(name) : Type{};
}
/// @returns a name describing the given shape
@ -1113,7 +1112,7 @@ struct DirectVariableAccess::State {
/// Builds an expresion to the root of an access, returning the new expression.
/// @param root the AccessRoot
/// @param deref if true, the returned expression will always be a reference type.
const ast::Expression* BuildAccessRootExpr(const AccessRoot& root, bool deref) {
const Expression* BuildAccessRootExpr(const AccessRoot& root, bool deref) {
if (auto* param = root.variable->As<sem::Parameter>()) {
if (auto symbols = clone_state->current_variant->ptr_param_symbols.Find(param)) {
if (deref) {
@ -1123,7 +1122,7 @@ struct DirectVariableAccess::State {
}
}
const ast::Expression* expr = b.Expr(ctx.Clone(root.variable->Declaration()->name->symbol));
const Expression* expr = b.Expr(ctx.Clone(root.variable->Declaration()->name->symbol));
if (deref) {
if (root.variable->Type()->Is<type::Pointer>()) {
expr = b.Deref(expr);
@ -1137,10 +1136,9 @@ struct DirectVariableAccess::State {
/// @param expr the input expression
/// @param access the access to perform on the current expression
/// @param dynamic_index a function that obtains the i'th dynamic index
const ast::Expression* BuildAccessExpr(
const ast::Expression* expr,
const Expression* BuildAccessExpr(const Expression* expr,
const AccessOp& access,
std::function<const ast::Expression*(size_t)> dynamic_index) {
std::function<const Expression*(size_t)> dynamic_index) {
if (auto* dyn_idx = std::get_if<DynamicIndex>(&access)) {
/// The access uses a dynamic (runtime-expression) index.
auto* idx = dynamic_index(dyn_idx->slot);

View File

@ -35,7 +35,7 @@ namespace {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (node->IsAnyOf<ast::CompoundAssignmentStatement, ast::IncrementDecrementStatement>()) {
if (node->IsAnyOf<CompoundAssignmentStatement, IncrementDecrementStatement>()) {
return true;
}
}
@ -58,17 +58,14 @@ struct ExpandCompoundAssignment::State {
/// @param lhs the lhs expression from the source statement
/// @param rhs the rhs expression in the destination module
/// @param op the binary operator
void Expand(const ast::Statement* stmt,
const ast::Expression* lhs,
const ast::Expression* rhs,
ast::BinaryOp op) {
void Expand(const Statement* stmt, const Expression* lhs, const Expression* rhs, BinaryOp op) {
// Helper function to create the new LHS expression. This will be called
// twice when building the non-compound assignment statement, so must
// not produce expressions that cause side effects.
std::function<const ast::Expression*()> new_lhs;
std::function<const Expression*()> new_lhs;
// Helper function to create a variable that is a pointer to `expr`.
auto hoist_pointer_to = [&](const ast::Expression* expr) {
auto hoist_pointer_to = [&](const Expression* expr) {
auto name = b.Sym();
auto* ptr = b.AddressOf(ctx.Clone(expr));
auto* decl = b.Decl(b.Let(name, ptr));
@ -77,7 +74,7 @@ struct ExpandCompoundAssignment::State {
};
// Helper function to hoist `expr` to a let declaration.
auto hoist_expr_to_let = [&](const ast::Expression* expr) {
auto hoist_expr_to_let = [&](const Expression* expr) {
auto name = b.Sym();
auto* decl = b.Decl(b.Let(name, ctx.Clone(expr)));
hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl);
@ -85,7 +82,7 @@ struct ExpandCompoundAssignment::State {
};
// Helper function that returns `true` if the type of `expr` is a vector.
auto is_vec = [&](const ast::Expression* expr) {
auto is_vec = [&](const Expression* expr) {
if (auto* val_expr = ctx.src->Sem().GetVal(expr)) {
return val_expr->Type()->UnwrapRef()->Is<type::Vector>();
}
@ -96,10 +93,10 @@ struct ExpandCompoundAssignment::State {
// LHS that we can evaluate twice.
// We need to special case compound assignments to vector components since
// we cannot take the address of a vector component.
auto* index_accessor = lhs->As<ast::IndexAccessorExpression>();
auto* member_accessor = lhs->As<ast::MemberAccessorExpression>();
if (lhs->Is<ast::IdentifierExpression>() ||
(member_accessor && member_accessor->object->Is<ast::IdentifierExpression>())) {
auto* index_accessor = lhs->As<IndexAccessorExpression>();
auto* member_accessor = lhs->As<MemberAccessorExpression>();
if (lhs->Is<IdentifierExpression>() ||
(member_accessor && member_accessor->object->Is<IdentifierExpression>())) {
// This is the simple case with no side effects, so we can just use the
// original LHS expression directly.
// Before:
@ -144,7 +141,7 @@ struct ExpandCompoundAssignment::State {
}
// Replace the statement with a regular assignment statement.
auto* value = b.create<ast::BinaryExpression>(op, new_lhs(), rhs);
auto* value = b.create<BinaryExpression>(op, new_lhs(), rhs);
ctx.Replace(stmt, b.Assign(new_lhs(), value));
}
@ -174,11 +171,11 @@ Transform::ApplyResult ExpandCompoundAssignment::Apply(const Program* src,
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
State state(ctx);
for (auto* node : src->ASTNodes().Objects()) {
if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) {
if (auto* assign = node->As<CompoundAssignmentStatement>()) {
state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op);
} else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) {
} else if (auto* inc_dec = node->As<IncrementDecrementStatement>()) {
// For increment/decrement statements, `i++` becomes `i = i + 1`.
auto op = inc_dec->increment ? ast::BinaryOp::kAdd : ast::BinaryOp::kSubtract;
auto op = inc_dec->increment ? BinaryOp::kAdd : BinaryOp::kSubtract;
state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op);
}
}

View File

@ -38,7 +38,7 @@ constexpr char kFirstInstanceName[] = "first_instance_index";
bool ShouldRun(const Program* program) {
for (auto* fn : program->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
if (fn->PipelineStage() == PipelineStage::kVertex) {
return true;
}
}
@ -86,9 +86,9 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
// Traverse the AST scanning for builtin accesses via variables (includes
// parameters) or structure member accesses.
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* var = node->As<ast::Variable>()) {
if (auto* var = node->As<Variable>()) {
for (auto* attr : var->attributes) {
if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
if (auto* builtin_attr = attr->As<BuiltinAttribute>()) {
builtin::BuiltinValue builtin = src->Sem().Get(builtin_attr)->Value();
if (builtin == builtin::BuiltinValue::kVertexIndex) {
auto* sem_var = ctx.src->Sem().Get(var);
@ -103,9 +103,9 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
}
}
}
if (auto* member = node->As<ast::StructMember>()) {
if (auto* member = node->As<StructMember>()) {
for (auto* attr : member->attributes) {
if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
if (auto* builtin_attr = attr->As<BuiltinAttribute>()) {
builtin::BuiltinValue builtin = src->Sem().Get(builtin_attr)->Value();
if (builtin == builtin::BuiltinValue::kVertexIndex) {
auto* sem_mem = ctx.src->Sem().Get(member);
@ -124,7 +124,7 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
if (has_vertex_or_instance_index) {
// Add uniform buffer members and calculate byte offsets
utils::Vector<const ast::StructMember*, 8> members;
utils::Vector<const StructMember*, 8> members;
members.Push(b.Member(kFirstVertexName, b.ty.u32()));
members.Push(b.Member(kFirstInstanceName, b.ty.u32()));
auto* struct_ = b.Structure(b.Sym(), std::move(members));
@ -138,7 +138,7 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
});
// Fix up all references to the builtins with the offsets
ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* {
ctx.ReplaceAll([=, &ctx](const Expression* expr) -> const Expression* {
if (auto* sem = ctx.src->Sem().GetVal(expr)) {
if (auto* user = sem->UnwrapLoad()->As<sem::VariableUser>()) {
auto it = builtin_vars.find(user->Variable());

View File

@ -26,7 +26,7 @@ namespace {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::ForLoopStatement>()) {
if (node->Is<ForLoopStatement>()) {
return true;
}
}
@ -47,8 +47,8 @@ Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&,
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
ctx.ReplaceAll([&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {
utils::Vector<const ast::Statement*, 8> stmts;
ctx.ReplaceAll([&](const ForLoopStatement* for_loop) -> const Statement* {
utils::Vector<const Statement*, 8> stmts;
if (auto* cond = for_loop->condition) {
// !condition
auto* not_cond = b.Not(ctx.Clone(cond));
@ -63,7 +63,7 @@ Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&,
stmts.Push(ctx.Clone(stmt));
}
const ast::BlockStatement* continuing = nullptr;
const BlockStatement* continuing = nullptr;
if (auto* cont = for_loop->continuing) {
continuing = b.Block(ctx.Clone(cont));
}

View File

@ -43,14 +43,14 @@ struct LocalizeStructArrayAssignment::State {
ApplyResult Run() {
struct Shared {
bool process_nested_nodes = false;
utils::Vector<const ast::Statement*, 4> insert_before_stmts;
utils::Vector<const ast::Statement*, 4> insert_after_stmts;
utils::Vector<const Statement*, 4> insert_before_stmts;
utils::Vector<const Statement*, 4> insert_after_stmts;
} s;
bool made_changes = false;
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* assign_stmt = node->As<ast::AssignmentStatement>()) {
if (auto* assign_stmt = node->As<AssignmentStatement>()) {
// Process if it's an assignment statement to a dynamically indexed array
// within a struct on a function or private storage variable. This
// specific use-case is what FXC fails to compile with:
@ -70,7 +70,7 @@ struct LocalizeStructArrayAssignment::State {
// Reset shared state for this assignment statement
s = Shared{};
const ast::Expression* new_lhs = nullptr;
const Expression* new_lhs = nullptr;
{
TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true);
new_lhs = ctx.Clone(assign_stmt->lhs);
@ -98,14 +98,13 @@ struct LocalizeStructArrayAssignment::State {
return SkipTransform;
}
ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* index_access) -> const ast::Expression* {
ctx.ReplaceAll([&](const IndexAccessorExpression* index_access) -> const Expression* {
if (!s.process_nested_nodes) {
return nullptr;
}
// Indexing a member access expr?
auto* mem_access = index_access->object->As<ast::MemberAccessorExpression>();
auto* mem_access = index_access->object->As<MemberAccessorExpression>();
if (!mem_access) {
return nullptr;
}
@ -136,7 +135,7 @@ struct LocalizeStructArrayAssignment::State {
// e.g. *(tint_symbol) = tint_symbol_1;
auto* assign_rhs_to_temp = b.Assign(b.Deref(mem_access_ptr), tmp_var);
{
utils::Vector<const ast::Statement*, 8> stmts{assign_rhs_to_temp};
utils::Vector<const Statement*, 8> stmts{assign_rhs_to_temp};
for (auto* stmt : s.insert_after_stmts) {
stmts.Push(stmt);
}
@ -160,23 +159,22 @@ struct LocalizeStructArrayAssignment::State {
/// Returns true if `expr` contains an index accessor expression to a
/// structure member of array type.
bool ContainsStructArrayIndex(const ast::Expression* expr) {
bool ContainsStructArrayIndex(const Expression* expr) {
bool result = false;
ast::TraverseExpressions(
expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
TraverseExpressions(expr, b.Diagnostics(), [&](const IndexAccessorExpression* ia) {
// Indexing using a runtime value?
auto* idx_sem = src->Sem().GetVal(ia->index);
if (!idx_sem->ConstantValue()) {
// Indexing a member access expr?
if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
if (auto* ma = ia->object->As<MemberAccessorExpression>()) {
// That accesses an array?
if (src->TypeOf(ma)->UnwrapRef()->Is<type::Array>()) {
result = true;
return ast::TraverseAction::Stop;
return TraverseAction::Stop;
}
}
}
return ast::TraverseAction::Descend;
return TraverseAction::Descend;
});
return result;
@ -186,7 +184,7 @@ struct LocalizeStructArrayAssignment::State {
// of the assignment statement.
// See https://www.w3.org/TR/WGSL/#originating-variable-section
std::pair<const type::Type*, builtin::AddressSpace> GetOriginatingTypeAndAddressSpace(
const ast::AssignmentStatement* assign_stmt) {
const AssignmentStatement* assign_stmt) {
auto* root_ident = src->Sem().GetVal(assign_stmt->lhs)->RootIdentifier();
if (TINT_UNLIKELY(!root_ident)) {
TINT_ICE(Transform, b.Diagnostics())

View File

@ -30,12 +30,12 @@ namespace tint::ast::transform {
namespace {
/// Returns `true` if `stmt` has the behavior `behavior`.
bool HasBehavior(const Program* program, const ast::Statement* stmt, sem::Behavior behavior) {
bool HasBehavior(const Program* program, const Statement* stmt, sem::Behavior behavior) {
return program->Sem().Get(stmt)->Behaviors().Contains(behavior);
}
/// Returns `true` if `func` needs to be transformed.
bool NeedsTransform(const Program* program, const ast::Function* func) {
bool NeedsTransform(const Program* program, const Function* func) {
// Entry points and intrinsic declarations never need transforming.
if (func->IsEntryPoint() || func->body == nullptr) {
return false;
@ -49,7 +49,7 @@ bool NeedsTransform(const Program* program, const ast::Function* func) {
if (HasBehavior(program, s, sem::Behavior::kReturn)) {
// If this statement is itself a return, it will be the only exit point,
// so no need to apply the transform to the function.
if (s->Is<ast::ReturnStatement>()) {
if (s->Is<ReturnStatement>()) {
return false;
} else {
// Apply the transform in all other cases.
@ -78,7 +78,7 @@ class State {
ProgramBuilder& b;
/// The function.
const ast::Function* function;
const Function* function;
/// The symbol for the return flag variable.
Symbol flag;
@ -92,32 +92,32 @@ class State {
public:
/// Constructor
/// @param context the clone context
State(CloneContext& context, const ast::Function* func)
State(CloneContext& context, const Function* func)
: ctx(context), b(*ctx.dst), function(func) {}
/// Process a statement (recursively).
void ProcessStatement(const ast::Statement* stmt) {
void ProcessStatement(const Statement* stmt) {
if (stmt == nullptr || !HasBehavior(ctx.src, stmt, sem::Behavior::kReturn)) {
return;
}
Switch(
stmt, [&](const ast::BlockStatement* block) { ProcessBlock(block); },
[&](const ast::CaseStatement* c) { ProcessStatement(c->body); },
[&](const ast::ForLoopStatement* f) {
stmt, [&](const BlockStatement* block) { ProcessBlock(block); },
[&](const CaseStatement* c) { ProcessStatement(c->body); },
[&](const ForLoopStatement* f) {
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
ProcessStatement(f->body);
},
[&](const ast::IfStatement* i) {
[&](const IfStatement* i) {
ProcessStatement(i->body);
ProcessStatement(i->else_statement);
},
[&](const ast::LoopStatement* l) {
[&](const LoopStatement* l) {
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
ProcessStatement(l->body);
},
[&](const ast::ReturnStatement* r) {
utils::Vector<const ast::Statement*, 3> stmts;
[&](const ReturnStatement* r) {
utils::Vector<const Statement*, 3> stmts;
// Set the return flag to signal that we have hit a return.
stmts.Push(b.Assign(b.Expr(flag), true));
if (r->value) {
@ -130,25 +130,25 @@ class State {
}
ctx.Replace(r, b.Block(std::move(stmts)));
},
[&](const ast::SwitchStatement* s) {
[&](const SwitchStatement* s) {
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
for (auto* c : s->body) {
ProcessStatement(c);
}
},
[&](const ast::WhileStatement* w) {
[&](const WhileStatement* w) {
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
ProcessStatement(w->body);
},
[&](Default) { TINT_ICE(Transform, b.Diagnostics()) << "unhandled statement type"; });
}
void ProcessBlock(const ast::BlockStatement* block) {
void ProcessBlock(const BlockStatement* block) {
// We will rebuild the contents of the block statement.
// We may introduce conditionals around statements that follow a statement with the
// `Return` behavior, so build a stack of statement lists that represent the new
// (potentially nested) conditional blocks.
utils::Vector<utils::Vector<const ast::Statement*, 8>, 8> new_stmts({{}});
utils::Vector<utils::Vector<const Statement*, 8>, 8> new_stmts({{}});
// Insert variables for the return flag and return value at the top of the function.
if (block == function->body) {
@ -173,8 +173,7 @@ class State {
if (is_in_loop_or_switch) {
// We're in a loop/switch, and so we would have inserted a `break`.
// If we've just come out of a loop/switch statement, we need to `break` again.
if (s->IsAnyOf<ast::LoopStatement, ast::ForLoopStatement,
ast::SwitchStatement>()) {
if (s->IsAnyOf<LoopStatement, ForLoopStatement, SwitchStatement>()) {
// If the loop only has the 'Return' behavior, we can just unconditionally
// break. Otherwise check the return flag.
if (HasBehavior(ctx.src, s, sem::Behavior::kNext)) {
@ -194,7 +193,7 @@ class State {
// Descend the stack of new block statements, wrapping them in conditionals.
while (new_stmts.Length() > 1) {
const ast::IfStatement* i = nullptr;
const IfStatement* i = nullptr;
if (new_stmts.Back().Length() > 0) {
i = b.If(b.Not(b.Expr(flag)), b.Block(new_stmts.Back()));
}

View File

@ -33,14 +33,14 @@ TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ModuleScopeVarToEntryPointParam)
namespace tint::ast::transform {
namespace {
using StructMemberList = utils::Vector<const ast::StructMember*, 8>;
using StructMemberList = utils::Vector<const StructMember*, 8>;
// The name of the struct member for arrays that are wrapped in structures.
const char* kWrappedArrayMemberName = "arr";
bool ShouldRun(const Program* program) {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (decl->Is<ast::Variable>()) {
if (decl->Is<Variable>()) {
return true;
}
}
@ -110,7 +110,7 @@ struct ModuleScopeVarToEntryPointParam::State {
/// @param workgroup_parameter_members reference to a list of a workgroup struct members
/// @param is_pointer output signalling whether the replacement is a pointer
/// @param is_wrapped output signalling whether the replacement is wrapped in a struct
void ProcessVariableInEntryPoint(const ast::Function* func,
void ProcessVariableInEntryPoint(const Function* func,
const sem::Variable* var,
Symbol new_var_symbol,
std::function<Symbol()> workgroup_param,
@ -128,7 +128,7 @@ struct ModuleScopeVarToEntryPointParam::State {
// For a texture or sampler variable, redeclare it as an entry point parameter.
// Disable entry point parameter validation.
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
ctx.dst->Disable(DisabledValidation::kEntryPointParameter);
auto attrs = ctx.Clone(var->Declaration()->attributes);
attrs.Push(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
@ -141,8 +141,8 @@ struct ModuleScopeVarToEntryPointParam::State {
// Variables into the Storage and Uniform address spaces are redeclared as entry
// point parameters with a pointer type.
auto attributes = ctx.Clone(var->Declaration()->attributes);
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter));
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace));
attributes.Push(ctx.dst->Disable(DisabledValidation::kEntryPointParameter));
attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
auto param_type = store_type();
if (auto* arr = ty->As<type::Array>();
@ -190,7 +190,7 @@ struct ModuleScopeVarToEntryPointParam::State {
is_pointer = true;
} else {
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace);
ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace);
auto* initializer = ctx.Clone(var->Declaration()->initializer);
auto* local_var = ctx.dst->Var(new_var_symbol, store_type(), sc, initializer,
utils::Vector{disable_validation});
@ -218,7 +218,7 @@ struct ModuleScopeVarToEntryPointParam::State {
/// @param var the variable
/// @param new_var_symbol the symbol to use for the replacement
/// @param is_pointer output signalling whether the replacement is a pointer or not
void ProcessVariableInUserFunction(const ast::Function* func,
void ProcessVariableInUserFunction(const Function* func,
const sem::Variable* var,
Symbol new_var_symbol,
bool& is_pointer) {
@ -247,7 +247,7 @@ struct ModuleScopeVarToEntryPointParam::State {
}
// Use a pointer for non-handle types.
utils::Vector<const ast::Attribute*, 2> attributes;
utils::Vector<const Attribute*, 2> attributes;
if (!ty->is_handle()) {
param_type = sc == builtin::AddressSpace::kStorage
? ctx.dst->ty.pointer(param_type, sc, var->Access())
@ -255,9 +255,8 @@ struct ModuleScopeVarToEntryPointParam::State {
is_pointer = true;
// Disable validation of the parameter's address space and of arguments passed to it.
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace));
attributes.Push(
ctx.dst->Disable(ast::DisabledValidation::kIgnoreInvalidPointerArgument));
attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreInvalidPointerArgument));
}
// Redeclare the variable as a parameter.
@ -271,18 +270,18 @@ struct ModuleScopeVarToEntryPointParam::State {
/// @param new_var the symbol to use for replacement
/// @param is_pointer true if `new_var` is a pointer to the new variable
/// @param member_name if valid, the name of the struct member that holds this variable
void ReplaceUsesInFunction(const ast::Function* func,
void ReplaceUsesInFunction(const Function* func,
const sem::Variable* var,
Symbol new_var,
bool is_pointer,
Symbol member_name) {
for (auto* user : var->Users()) {
if (user->Stmt()->Function()->Declaration() == func) {
const ast::Expression* expr = ctx.dst->Expr(new_var);
const Expression* expr = ctx.dst->Expr(new_var);
if (is_pointer) {
// If this identifier is used by an address-of operator, just remove the
// address-of instead of adding a deref, since we already have a pointer.
auto* ident = user->Declaration()->As<ast::IdentifierExpression>();
auto* ident = user->Declaration()->As<IdentifierExpression>();
if (ident_to_address_of_.count(ident) && !member_name.IsValid()) {
ctx.Replace(ident_to_address_of_[ident], expr);
continue;
@ -302,19 +301,19 @@ struct ModuleScopeVarToEntryPointParam::State {
/// Process the module.
void Process() {
// Predetermine the list of function calls that need to be replaced.
using CallList = utils::Vector<const ast::CallExpression*, 8>;
std::unordered_map<const ast::Function*, CallList> calls_to_replace;
using CallList = utils::Vector<const CallExpression*, 8>;
std::unordered_map<const Function*, CallList> calls_to_replace;
utils::Vector<const ast::Function*, 8> functions_to_process;
utils::Vector<const Function*, 8> functions_to_process;
// Collect private variables into a single structure.
StructMemberList private_struct_members;
utils::Vector<std::function<const ast::AssignmentStatement*()>, 4> private_initializers;
std::unordered_set<const ast::Function*> uses_privates;
utils::Vector<std::function<const AssignmentStatement*()>, 4> private_initializers;
std::unordered_set<const Function*> uses_privates;
// Build a list of functions that transitively reference any module-scope variables.
for (auto* decl : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) {
if (auto* var = decl->As<ast::Var>()) {
if (auto* var = decl->As<Var>()) {
auto* sem_var = ctx.src->Sem().Get(var);
if (sem_var->AddressSpace() == builtin::AddressSpace::kPrivate) {
// Create a member in the private variable struct.
@ -335,7 +334,7 @@ struct ModuleScopeVarToEntryPointParam::State {
continue;
}
auto* func_ast = decl->As<ast::Function>();
auto* func_ast = decl->As<Function>();
if (!func_ast) {
continue;
}
@ -376,11 +375,11 @@ struct ModuleScopeVarToEntryPointParam::State {
// TODO(jrprice): We should add support for bidirectional SEM tree traversal so that we can
// do this on the fly instead.
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* address_of = node->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
auto* address_of = node->As<UnaryOpExpression>();
if (!address_of || address_of->op != UnaryOp::kAddressOf) {
continue;
}
if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) {
if (auto* ident = address_of->expr->As<IdentifierExpression>()) {
ident_to_address_of_[ident] = address_of;
}
}
@ -414,11 +413,11 @@ struct ModuleScopeVarToEntryPointParam::State {
if (uses_privates.count(func_ast)) {
if (is_entry_point) {
// Create a local declaration for the private variable struct.
auto* var = ctx.dst->Var(
PrivateStructVariableName(), ctx.dst->ty(PrivateStructName()),
auto* var =
ctx.dst->Var(PrivateStructVariableName(), ctx.dst->ty(PrivateStructName()),
builtin::AddressSpace::kPrivate,
utils::Vector{
ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace),
ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace),
});
ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(var));
@ -482,7 +481,7 @@ struct ModuleScopeVarToEntryPointParam::State {
// Allow pointer aliasing if needed.
if (needs_pointer_aliasing) {
ctx.InsertBack(func_ast->attributes,
ctx.dst->Disable(ast::DisabledValidation::kIgnorePointerAliasing));
ctx.dst->Disable(DisabledValidation::kIgnorePointerAliasing));
}
if (!workgroup_parameter_members.IsEmpty()) {
@ -492,11 +491,11 @@ struct ModuleScopeVarToEntryPointParam::State {
ctx.dst->Structure(ctx.dst->Sym(), std::move(workgroup_parameter_members));
auto param_type =
ctx.dst->ty.pointer(ctx.dst->ty.Of(str), builtin::AddressSpace::kWorkgroup);
auto* param = ctx.dst->Param(
workgroup_param(), param_type,
auto* param =
ctx.dst->Param(workgroup_param(), param_type,
utils::Vector{
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter),
ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace),
ctx.dst->Disable(DisabledValidation::kEntryPointParameter),
ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace),
});
ctx.InsertFront(func_ast->params, param);
}
@ -508,7 +507,7 @@ struct ModuleScopeVarToEntryPointParam::State {
// Pass the private variable struct pointer if needed.
if (uses_privates.count(target_sem->Declaration())) {
const ast::Expression* arg = ctx.dst->Expr(PrivateStructVariableName());
const Expression* arg = ctx.dst->Expr(PrivateStructVariableName());
if (is_entry_point) {
arg = ctx.dst->AddressOf(arg);
}
@ -531,7 +530,7 @@ struct ModuleScopeVarToEntryPointParam::State {
auto new_var = it->second;
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
const Expression* arg = ctx.dst->Expr(new_var.symbol);
if (new_var.is_wrapped) {
// The variable is wrapped in a struct, so we need to pass a pointer to the
// struct member instead.
@ -577,8 +576,7 @@ struct ModuleScopeVarToEntryPointParam::State {
std::unordered_set<const sem::Struct*> cloned_structs_;
// Map from identifier expression to the address-of expression that uses it.
std::unordered_map<const ast::IdentifierExpression*, const ast::UnaryOpExpression*>
ident_to_address_of_;
std::unordered_map<const IdentifierExpression*, const UnaryOpExpression*> ident_to_address_of_;
// The name of the structure that contains all the module-scope private variables.
Symbol private_struct_name;

View File

@ -143,7 +143,7 @@ struct MultiplanarExternalTexture::State {
// Replace the original texture_external binding with a texture_2d<f32> binding.
auto cloned_attributes = ctx.Clone(global->attributes);
const ast::Expression* cloned_initializer = ctx.Clone(global->initializer);
const Expression* cloned_initializer = ctx.Clone(global->initializer);
auto* replacement =
b.Var(syms.plane_0, b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32()),
@ -153,7 +153,7 @@ struct MultiplanarExternalTexture::State {
// We must update all the texture_external parameters for user declared functions.
for (auto* fn : ctx.src->AST().Functions()) {
for (const ast::Variable* param : fn->params) {
for (const Variable* param : fn->params) {
if (auto* sem_var = sem.Get(param)) {
if (!sem_var->Type()->UnwrapRef()->Is<type::ExternalTexture>()) {
continue;
@ -184,7 +184,7 @@ struct MultiplanarExternalTexture::State {
// Transform the external texture builtin calls into calls to the external texture
// functions.
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
ctx.ReplaceAll([&](const CallExpression* expr) -> const CallExpression* {
auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* builtin = call->Target()->As<sem::Builtin>();
@ -305,10 +305,10 @@ struct MultiplanarExternalTexture::State {
/// @param call_type determines which function body to generate
/// @returns a statement list that makes of the body of the chosen function
auto buildTextureBuiltinBody(builtin::Function call_type) {
utils::Vector<const ast::Statement*, 16> stmts;
const ast::CallExpression* single_plane_call = nullptr;
const ast::CallExpression* plane_0_call = nullptr;
const ast::CallExpression* plane_1_call = nullptr;
utils::Vector<const Statement*, 16> stmts;
const CallExpression* single_plane_call = nullptr;
const CallExpression* plane_0_call = nullptr;
const CallExpression* plane_1_call = nullptr;
switch (call_type) {
case builtin::Function::kTextureSampleBaseClampToEdge:
stmts.Push(b.Decl(b.Let(
@ -395,9 +395,9 @@ struct MultiplanarExternalTexture::State {
/// @param expr the call expression being transformed
/// @param syms the expanded symbols to be used in the new call
/// @returns a call expression to textureSampleExternal
const ast::CallExpression* createTextureSampleBaseClampToEdge(const ast::CallExpression* expr,
const CallExpression* createTextureSampleBaseClampToEdge(const CallExpression* expr,
NewBindingSymbols syms) {
const ast::Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
const Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
if (TINT_UNLIKELY(expr->args.Length() != 3)) {
TINT_ICE(Transform, b.Diagnostics())
@ -443,7 +443,7 @@ struct MultiplanarExternalTexture::State {
/// @param call the call expression being transformed
/// @param syms the expanded symbols to be used in the new call
/// @returns a call expression to textureLoadExternal
const ast::CallExpression* createTextureLoad(const sem::Call* call, NewBindingSymbols syms) {
const CallExpression* createTextureLoad(const sem::Call* call, NewBindingSymbols syms) {
if (TINT_UNLIKELY(call->Arguments().Length() != 2)) {
TINT_ICE(Transform, b.Diagnostics())
<< "expected textureLoad call with a texture_external to have 2 arguments, found "

View File

@ -33,7 +33,7 @@ namespace {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* attr = node->As<ast::BuiltinAttribute>()) {
if (auto* attr = node->As<BuiltinAttribute>()) {
if (program->Sem().Get(attr)->Value() == builtin::BuiltinValue::kNumWorkgroups) {
return true;
}
@ -86,7 +86,7 @@ Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
std::unordered_set<Accessor, Accessor::Hasher> to_replace;
for (auto* func : src->AST().Functions()) {
// num_workgroups is only valid for compute stages.
if (func->PipelineStage() != ast::PipelineStage::kCompute) {
if (func->PipelineStage() != PipelineStage::kCompute) {
continue;
}
@ -126,7 +126,7 @@ Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
// Get (or create, on first call) the uniform buffer that will receive the
// number of workgroups.
const ast::Variable* num_workgroups_ubo = nullptr;
const Variable* num_workgroups_ubo = nullptr;
auto get_ubo = [&]() {
if (!num_workgroups_ubo) {
auto* num_workgroups_struct =
@ -166,11 +166,11 @@ Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
// Now replace all the places where the builtins are accessed with the value
// loaded from the uniform buffer.
for (auto* node : src->ASTNodes().Objects()) {
auto* accessor = node->As<ast::MemberAccessorExpression>();
auto* accessor = node->As<MemberAccessorExpression>();
if (!accessor) {
continue;
}
auto* ident = accessor->object->As<ast::IdentifierExpression>();
auto* ident = accessor->object->As<IdentifierExpression>();
if (!ident) {
continue;
}

View File

@ -94,7 +94,7 @@ struct PackedVec3::State {
/// Create a `__packed_vec3` type with the same element type as `ty`.
/// @param ty a three-element vector type
/// @returns the new AST type
ast::Type MakePackedVec3(const type::Type* ty) {
Type MakePackedVec3(const type::Type* ty) {
auto* vec = ty->As<type::Vector>();
TINT_ASSERT(Transform, vec != nullptr && vec->Width() == 3);
return b.ty(builtin::Builtin::kPackedVec3, CreateASTTypeFor(ctx, vec->type()));
@ -109,10 +109,10 @@ struct PackedVec3::State {
/// @param ty the type to rewrite
/// @param array_element `true` if this is being called for the element of an array
/// @returns the new AST type, or nullptr if rewriting was not necessary
ast::Type RewriteType(const type::Type* ty, bool array_element = false) {
Type RewriteType(const type::Type* ty, bool array_element = false) {
return Switch(
ty,
[&](const type::Vector* vec) -> ast::Type {
[&](const type::Vector* vec) -> Type {
if (IsVec3(vec)) {
if (array_element) {
// Create a struct with a single `__packed_vec3` member.
@ -134,7 +134,7 @@ struct PackedVec3::State {
}
return {};
},
[&](const type::Matrix* mat) -> ast::Type {
[&](const type::Matrix* mat) -> Type {
// Rewrite the matrix as an array of columns that use the aligned wrapper struct.
auto new_col_type = RewriteType(mat->ColumnType(), /* array_element */ true);
if (new_col_type) {
@ -142,11 +142,11 @@ struct PackedVec3::State {
}
return {};
},
[&](const type::Array* arr) -> ast::Type {
[&](const type::Array* arr) -> Type {
// Rewrite the array with the modified element type.
auto new_type = RewriteType(arr->ElemType(), /* array_element */ true);
if (new_type) {
utils::Vector<const ast::Attribute*, 1> attrs;
utils::Vector<const Attribute*, 1> attrs;
if (arr->Count()->Is<type::RuntimeArrayCount>()) {
return b.ty.array(new_type, std::move(attrs));
} else if (auto count = arr->ConstantCount()) {
@ -159,21 +159,21 @@ struct PackedVec3::State {
}
return {};
},
[&](const type::Struct* str) -> ast::Type {
[&](const type::Struct* str) -> Type {
if (ContainsVec3(str)) {
auto name = rewritten_structs.GetOrCreate(str, [&]() {
utils::Vector<const ast::StructMember*, 4> members;
utils::Vector<const StructMember*, 4> members;
for (auto* member : str->Members()) {
// If the member type contains a vec3, rewrite it.
auto new_type = RewriteType(member->Type());
if (new_type) {
// Copy the member attributes.
bool needs_align = true;
utils::Vector<const ast::Attribute*, 4> attributes;
utils::Vector<const Attribute*, 4> attributes;
if (auto* sem_mem = member->As<sem::StructMember>()) {
for (auto* attr : sem_mem->Declaration()->attributes) {
if (attr->IsAnyOf<ast::StructMemberAlignAttribute,
ast::StructMemberOffsetAttribute>()) {
if (attr->IsAnyOf<StructMemberAlignAttribute,
StructMemberOffsetAttribute>()) {
needs_align = false;
}
attributes.Push(ctx.Clone(attr));
@ -219,12 +219,12 @@ struct PackedVec3::State {
Symbol MakePackUnpackHelper(
const char* name_prefix,
const type::Type* ty,
const std::function<const ast::Expression*(const ast::Expression*, const type::Type*)>&
const std::function<const Expression*(const Expression*, const type::Type*)>&
pack_or_unpack_element,
const std::function<ast::Type()>& in_type,
const std::function<ast::Type()>& out_type) {
const std::function<Type()>& in_type,
const std::function<Type()>& out_type) {
// Allocate a variable to hold the return value of the function.
utils::Vector<const ast::Statement*, 4> statements;
utils::Vector<const Statement*, 4> statements;
statements.Push(b.Decl(b.Var("result", out_type())));
// Helper that generates a loop to copy and pack/unpack elements of an array to the result:
@ -256,7 +256,7 @@ struct PackedVec3::State {
[&](const type::Struct* str) {
// Copy the struct members over one at a time, packing/unpacking as necessary.
for (auto* member : str->Members()) {
const ast::Expression* element =
const Expression* element =
b.MemberAccessor("in", b.Ident(ctx.Clone(member->Name())));
if (ContainsVec3(member->Type())) {
element = pack_or_unpack_element(element, member->Type());
@ -280,16 +280,16 @@ struct PackedVec3::State {
/// @param expr the composite value expression to unpack
/// @param ty the unpacked type
/// @returns an expression that holds the unpacked value
const ast::Expression* UnpackComposite(const ast::Expression* expr, const type::Type* ty) {
const Expression* UnpackComposite(const Expression* expr, const type::Type* ty) {
auto helper = unpack_helpers.GetOrCreate(ty, [&]() {
return MakePackUnpackHelper(
"tint_unpack_vec3_in_composite", ty,
[&](const ast::Expression* element,
const type::Type* element_type) -> const ast::Expression* {
[&](const Expression* element,
const type::Type* element_type) -> const Expression* {
if (element_type->Is<type::Vector>()) {
// Unpack a `__packed_vec3` by casting it to a regular vec3.
// If it is an array element, extract the vector from the wrapper struct.
if (element->Is<ast::IndexAccessorExpression>()) {
if (element->Is<IndexAccessorExpression>()) {
element = b.MemberAccessor(element, kStructMemberName);
}
return b.Call(CreateASTTypeFor(ctx, element_type), element);
@ -308,17 +308,17 @@ struct PackedVec3::State {
/// @param expr the composite value expression to pack
/// @param ty the unpacked type
/// @returns an expression that holds the packed value
const ast::Expression* PackComposite(const ast::Expression* expr, const type::Type* ty) {
const Expression* PackComposite(const Expression* expr, const type::Type* ty) {
auto helper = pack_helpers.GetOrCreate(ty, [&]() {
return MakePackUnpackHelper(
"tint_pack_vec3_in_composite", ty,
[&](const ast::Expression* element,
const type::Type* element_type) -> const ast::Expression* {
[&](const Expression* element,
const type::Type* element_type) -> const Expression* {
if (element_type->Is<type::Vector>()) {
// Pack a vector element by casting it to a packed_vec3.
// If it is an array element, construct a wrapper struct.
auto* packed = b.Call(MakePackedVec3(element_type), element);
if (element->Is<ast::IndexAccessorExpression>()) {
if (element->Is<IndexAccessorExpression>()) {
packed = b.Call(RewriteType(element_type, true), packed);
}
return packed;
@ -400,7 +400,7 @@ struct PackedVec3::State {
},
[&](const sem::Statement* stmt) {
// Pack the RHS of assignment statements that are writing to packed types.
if (auto* assign = stmt->Declaration()->As<ast::AssignmentStatement>()) {
if (auto* assign = stmt->Declaration()->As<AssignmentStatement>()) {
auto* lhs = sem.GetVal(assign->lhs);
auto* rhs = sem.GetVal(assign->rhs);
if (!ContainsVec3(rhs->Type()) ||
@ -463,7 +463,7 @@ struct PackedVec3::State {
for (auto* expr : to_unpack_sorted) {
TINT_ASSERT(Transform, ContainsVec3(expr->Type()));
auto* packed = ctx.Clone(expr->Declaration());
const ast::Expression* unpacked = nullptr;
const Expression* unpacked = nullptr;
if (IsVec3(expr->Type())) {
if (expr->UnwrapLoad()->Is<sem::IndexAccessorExpression>()) {
// If we are unpacking a vec3 from an array element, extract the vector from the
@ -484,7 +484,7 @@ struct PackedVec3::State {
for (auto* expr : to_pack_sorted) {
TINT_ASSERT(Transform, ContainsVec3(expr->Type()));
auto* unpacked = ctx.Clone(expr->Declaration());
const ast::Expression* packed = nullptr;
const Expression* packed = nullptr;
if (IsVec3(expr->Type())) {
// Cast the regular vec3 to a packed vector type.
packed = b.Call(MakePackedVec3(expr->Type()), unpacked);

View File

@ -33,8 +33,8 @@ namespace tint::ast::transform {
namespace {
void CreatePadding(utils::Vector<const ast::StructMember*, 8>* new_members,
utils::Hashset<const ast::StructMember*, 8>* padding_members,
void CreatePadding(utils::Vector<const StructMember*, 8>* new_members,
utils::Hashset<const StructMember*, 8>* padding_members,
ProgramBuilder* b,
uint32_t bytes) {
const size_t count = bytes / 4u;
@ -59,17 +59,17 @@ Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, Dat
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
auto& sem = src->Sem();
std::unordered_map<const ast::Struct*, const ast::Struct*> replaced_structs;
utils::Hashset<const ast::StructMember*, 8> padding_members;
std::unordered_map<const Struct*, const Struct*> replaced_structs;
utils::Hashset<const StructMember*, 8> padding_members;
ctx.ReplaceAll([&](const ast::Struct* ast_str) -> const ast::Struct* {
ctx.ReplaceAll([&](const Struct* ast_str) -> const Struct* {
auto* str = sem.Get(ast_str);
if (!str || !str->IsHostShareable()) {
return nullptr;
}
uint32_t offset = 0;
bool has_runtime_sized_array = false;
utils::Vector<const ast::StructMember*, 8> new_members;
utils::Vector<const StructMember*, 8> new_members;
for (auto* mem : str->Members()) {
auto name = mem->Name().Name();
@ -104,19 +104,18 @@ Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, Dat
CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset);
}
utils::Vector<const ast::Attribute*, 1> struct_attribs;
utils::Vector<const Attribute*, 1> struct_attribs;
if (!padding_members.IsEmpty()) {
struct_attribs =
utils::Vector{b.Disable(ast::DisabledValidation::kIgnoreStructMemberLimit)};
struct_attribs = utils::Vector{b.Disable(DisabledValidation::kIgnoreStructMemberLimit)};
}
auto* new_struct = b.create<ast::Struct>(ctx.Clone(ast_str->name), std::move(new_members),
auto* new_struct = b.create<Struct>(ctx.Clone(ast_str->name), std::move(new_members),
std::move(struct_attribs));
replaced_structs[ast_str] = new_struct;
return new_struct;
});
ctx.ReplaceAll([&](const ast::CallExpression* ast_call) -> const ast::CallExpression* {
ctx.ReplaceAll([&](const CallExpression* ast_call) -> const CallExpression* {
if (ast_call->args.Length() == 0) {
return nullptr;
}
@ -139,7 +138,7 @@ Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, Dat
return nullptr;
}
utils::Vector<const ast::Expression*, 8> new_args;
utils::Vector<const Expression*, 8> new_args;
auto* arg = ast_call->args.begin();
for (auto* member : new_struct->members) {

View File

@ -44,13 +44,13 @@ struct PreservePadding::State {
/// @returns the ApplyResult
ApplyResult Run() {
// Gather a list of assignments that need to be transformed.
std::unordered_set<const ast::AssignmentStatement*> assignments_to_transform;
std::unordered_set<const AssignmentStatement*> assignments_to_transform;
for (auto* node : ctx.src->ASTNodes().Objects()) {
Switch(
node, //
[&](const ast::AssignmentStatement* assign) {
[&](const AssignmentStatement* assign) {
auto* ty = sem.GetVal(assign->lhs)->Type();
if (assign->lhs->Is<ast::PhonyExpression>()) {
if (assign->lhs->Is<PhonyExpression>()) {
// Ignore phony assignment.
return;
}
@ -65,7 +65,7 @@ struct PreservePadding::State {
assignments_to_transform.insert(assign);
}
},
[&](const ast::Enable* enable) {
[&](const Enable* enable) {
// Check if the full pointer parameters extension is already enabled.
if (enable->HasExtension(
builtin::Extension::kChromiumExperimentalFullPtrParameters)) {
@ -78,7 +78,7 @@ struct PreservePadding::State {
}
// Replace all assignments that include padding with decomposed versions.
ctx.ReplaceAll([&](const ast::AssignmentStatement* assign) -> const ast::Statement* {
ctx.ReplaceAll([&](const AssignmentStatement* assign) -> const Statement* {
if (!assignments_to_transform.count(assign)) {
return nullptr;
}
@ -96,9 +96,9 @@ struct PreservePadding::State {
/// @param lhs the lhs expression (in the destination program)
/// @param rhs the rhs expression (in the destination program)
/// @returns the statement that performs the assignment
const ast::Statement* MakeAssignment(const type::Type* ty,
const ast::Expression* lhs,
const ast::Expression* rhs) {
const Statement* MakeAssignment(const type::Type* ty,
const Expression* lhs,
const Expression* rhs) {
if (!HasPadding(ty)) {
// No padding - use a regular assignment.
return b.Assign(lhs, rhs);
@ -120,7 +120,7 @@ struct PreservePadding::State {
EnableExtension();
auto helper = helpers.GetOrCreate(ty, [&]() {
auto helper_name = b.Symbols().New("assign_and_preserve_padding");
utils::Vector<const ast::Parameter*, 2> params = {
utils::Vector<const Parameter*, 2> params = {
b.Param(kDestParamName,
b.ty.pointer(CreateASTTypeFor(ctx, ty), builtin::AddressSpace::kStorage,
builtin::Access::kReadWrite)),
@ -137,7 +137,7 @@ struct PreservePadding::State {
[&](const type::Array* arr) {
// Call a helper function that uses a loop to assigns each element separately.
return call_helper([&]() {
utils::Vector<const ast::Statement*, 8> body;
utils::Vector<const Statement*, 8> body;
auto* idx = b.Var("i", b.Expr(0_u));
body.Push(
b.For(b.Decl(idx), b.LessThan(idx, u32(arr->ConstantCount().value())),
@ -151,7 +151,7 @@ struct PreservePadding::State {
[&](const type::Matrix* mat) {
// Call a helper function that assigns each column separately.
return call_helper([&]() {
utils::Vector<const ast::Statement*, 4> body;
utils::Vector<const Statement*, 4> body;
for (uint32_t i = 0; i < mat->columns(); i++) {
body.Push(MakeAssignment(mat->ColumnType(),
b.IndexAccessor(b.Deref(kDestParamName), u32(i)),
@ -163,7 +163,7 @@ struct PreservePadding::State {
[&](const type::Struct* str) {
// Call a helper function that assigns each member separately.
return call_helper([&]() {
utils::Vector<const ast::Statement*, 8> body;
utils::Vector<const Statement*, 8> body;
for (auto member : str->Members()) {
auto name = member->Name().Name();
body.Push(MakeAssignment(member->Type(),

View File

@ -69,7 +69,7 @@ Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
}
}
if (auto* src_var_decl = expr->Stmt()->Declaration()->As<ast::VariableDeclStatement>()) {
if (auto* src_var_decl = expr->Stmt()->Declaration()->As<VariableDeclStatement>()) {
if (src_var_decl->variable->initializer == expr->Declaration()) {
// This statement is just a variable declaration with the initializer as the
// initializer value. This is what we're attempting to transform to, and so
@ -84,7 +84,7 @@ Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
// A list of expressions that should be hoisted.
utils::Vector<const sem::ValueExpression*, 32> to_hoist;
// A set of expressions that are constant, which _may_ need to be hoisted.
utils::Hashset<const ast::Expression*, 32> const_chains;
utils::Hashset<const Expression*, 32> const_chains;
// Walk the AST nodes. This order guarantees that leaf-expressions are visited first.
for (auto* node : src->ASTNodes().Objects()) {
@ -104,11 +104,9 @@ Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
// visit leaf-expressions first, this means the content of const_chains only
// contains the outer-most constant expressions.
auto* expr = sem->Declaration();
bool ok = ast::TraverseExpressions(
expr, b.Diagnostics(), [&](const ast::Expression* child) {
bool ok = TraverseExpressions(expr, b.Diagnostics(), [&](const Expression* child) {
const_chains.Remove(child);
return child == expr ? ast::TraverseAction::Descend
: ast::TraverseAction::Skip;
return child == expr ? TraverseAction::Descend : TraverseAction::Skip;
});
if (!ok) {
return Program(std::move(b));

View File

@ -97,49 +97,48 @@ struct DecomposeSideEffects : tint::utils::Castable<PromoteSideEffectsToDecl, Tr
// need to be hoisted to ensure order of evaluation, both those that give
// side-effects, as well as those that receive, and returns a set of these
// expressions.
using ToHoistSet = std::unordered_set<const ast::Expression*>;
using ToHoistSet = std::unordered_set<const Expression*>;
class DecomposeSideEffects::CollectHoistsState : public StateBase {
// Expressions to hoist because they either cause or receive side-effects.
ToHoistSet to_hoist;
// Used to mark expressions as not or no longer having side-effects.
std::unordered_set<const ast::Expression*> no_side_effects;
std::unordered_set<const Expression*> no_side_effects;
// Returns true if `expr` has side-effects. Unlike invoking
// sem::ValueExpression::HasSideEffects(), this function takes into account whether
// `expr` has been hoisted, returning false in that case. Furthermore, it
// returns the correct result on parent expression nodes by traversing the
// expression tree, memoizing the results to ensure O(1) amortized lookup.
bool HasSideEffects(const ast::Expression* expr) {
bool HasSideEffects(const Expression* expr) {
if (no_side_effects.count(expr)) {
return false;
}
return Switch(
expr,
[&](const ast::CallExpression* e) -> bool { return sem.Get(e)->HasSideEffects(); },
[&](const ast::BinaryExpression* e) {
expr, [&](const CallExpression* e) -> bool { return sem.Get(e)->HasSideEffects(); },
[&](const BinaryExpression* e) {
if (HasSideEffects(e->lhs) || HasSideEffects(e->rhs)) {
return true;
}
no_side_effects.insert(e);
return false;
},
[&](const ast::IndexAccessorExpression* e) {
[&](const IndexAccessorExpression* e) {
if (HasSideEffects(e->object) || HasSideEffects(e->index)) {
return true;
}
no_side_effects.insert(e);
return false;
},
[&](const ast::MemberAccessorExpression* e) {
[&](const MemberAccessorExpression* e) {
if (HasSideEffects(e->object)) {
return true;
}
no_side_effects.insert(e);
return false;
},
[&](const ast::BitcastExpression* e) { //
[&](const BitcastExpression* e) { //
if (HasSideEffects(e->expr)) {
return true;
}
@ -147,22 +146,22 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
return false;
},
[&](const ast::UnaryOpExpression* e) { //
[&](const UnaryOpExpression* e) { //
if (HasSideEffects(e->expr)) {
return true;
}
no_side_effects.insert(e);
return false;
},
[&](const ast::IdentifierExpression* e) {
[&](const IdentifierExpression* e) {
no_side_effects.insert(e);
return false;
},
[&](const ast::LiteralExpression* e) {
[&](const LiteralExpression* e) {
no_side_effects.insert(e);
return false;
},
[&](const ast::PhonyExpression* e) {
[&](const PhonyExpression* e) {
no_side_effects.insert(e);
return false;
},
@ -173,14 +172,14 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
}
// Adds `e` to `to_hoist` for hoisting to a let later on.
void Hoist(const ast::Expression* e) {
void Hoist(const Expression* e) {
no_side_effects.insert(e);
to_hoist.emplace(e);
}
// Hoists any expressions in `maybe_hoist` and clears it
template <size_t N>
void Flush(tint::utils::Vector<const ast::Expression*, N>& maybe_hoist) {
void Flush(tint::utils::Vector<const Expression*, N>& maybe_hoist) {
for (auto* m : maybe_hoist) {
Hoist(m);
}
@ -198,13 +197,13 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
// over-hoist the lhs expressions, as these may be be chained to refer to a
// single memory location.
template <size_t N>
bool ProcessExpression(const ast::Expression* expr,
tint::utils::Vector<const ast::Expression*, N>& maybe_hoist) {
auto process = [&](const ast::Expression* e) -> bool {
bool ProcessExpression(const Expression* expr,
tint::utils::Vector<const Expression*, N>& maybe_hoist) {
auto process = [&](const Expression* e) -> bool {
return ProcessExpression(e, maybe_hoist);
};
auto default_process = [&](const ast::Expression* e) {
auto default_process = [&](const Expression* e) {
auto maybe = process(e);
if (maybe) {
maybe_hoist.Push(e);
@ -215,7 +214,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
return false;
};
auto binary_process = [&](const ast::Expression* lhs, const ast::Expression* rhs) {
auto binary_process = [&](const Expression* lhs, const Expression* rhs) {
// If neither side causes side-effects, but at least one receives them,
// let parent node hoist. This avoids over-hoisting side-effect receivers
// of compound binary expressions (e.g. for "((a && b) && c) && f()", we
@ -235,8 +234,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
return false;
};
auto accessor_process = [&](const ast::Expression* lhs,
const ast::Expression* rhs = nullptr) {
auto accessor_process = [&](const Expression* lhs, const Expression* rhs = nullptr) {
auto maybe = process(lhs);
// If lhs is a variable, let parent node hoist otherwise flush it right
// away. This is to avoid over-hoisting the lhs of accessor chains (e.g.
@ -255,7 +253,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
return Switch(
expr,
[&](const ast::CallExpression* e) -> bool {
[&](const CallExpression* e) -> bool {
// We eagerly flush any variables in maybe_hoist for the current
// call expression. Then we scope maybe_hoist to the processing of
// the call args. This ensures that given: g(c, a(0), d) we hoist
@ -276,7 +274,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
// no_side_effects() first.
return true;
},
[&](const ast::IdentifierExpression* e) {
[&](const IdentifierExpression* e) {
if (auto* sem_e = sem.GetVal(e)) {
if (auto* var_user = sem_e->UnwrapLoad()->As<sem::VariableUser>()) {
// Don't hoist constants.
@ -297,7 +295,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
}
return false;
},
[&](const ast::BinaryExpression* e) {
[&](const BinaryExpression* e) {
if (e->IsLogical() && HasSideEffects(e)) {
// Don't hoist children of logical binary expressions with
// side-effects. These will be handled by DecomposeState.
@ -307,27 +305,25 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
}
return binary_process(e->lhs, e->rhs);
},
[&](const ast::BitcastExpression* e) { //
[&](const BitcastExpression* e) { //
return process(e->expr);
},
[&](const ast::UnaryOpExpression* e) { //
[&](const UnaryOpExpression* e) { //
auto r = process(e->expr);
// Don't hoist address-of expressions.
// E.g. for "g(&b, a(0))", we hoist "a(0)" only.
if (e->op == ast::UnaryOp::kAddressOf) {
if (e->op == UnaryOp::kAddressOf) {
return false;
}
return r;
},
[&](const ast::IndexAccessorExpression* e) {
return accessor_process(e->object, e->index);
},
[&](const ast::MemberAccessorExpression* e) { return accessor_process(e->object); },
[&](const ast::LiteralExpression*) {
[&](const IndexAccessorExpression* e) { return accessor_process(e->object, e->index); },
[&](const MemberAccessorExpression* e) { return accessor_process(e->object); },
[&](const LiteralExpression*) {
// Leaf
return false;
},
[&](const ast::PhonyExpression*) {
[&](const PhonyExpression*) {
// Leaf
return false;
},
@ -338,12 +334,12 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
}
// Starts the recursive processing of a statement's expression(s) to hoist side-effects to lets.
void ProcessExpression(const ast::Expression* expr) {
void ProcessExpression(const Expression* expr) {
if (!expr) {
return;
}
tint::utils::Vector<const ast::Expression*, 8> maybe_hoist;
tint::utils::Vector<const Expression*, 8> maybe_hoist;
ProcessExpression(expr, maybe_hoist);
}
@ -354,31 +350,31 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
// Traverse all statements, recursively processing their expression tree(s)
// to hoist side-effects to lets.
for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* stmt = node->As<ast::Statement>();
auto* stmt = node->As<Statement>();
if (!stmt) {
continue;
}
Switch(
stmt, //
[&](const ast::AssignmentStatement* s) {
tint::utils::Vector<const ast::Expression*, 8> maybe_hoist;
[&](const AssignmentStatement* s) {
tint::utils::Vector<const Expression*, 8> maybe_hoist;
ProcessExpression(s->lhs, maybe_hoist);
ProcessExpression(s->rhs, maybe_hoist);
},
[&](const ast::CallStatement* s) { //
[&](const CallStatement* s) { //
ProcessExpression(s->expr);
},
[&](const ast::ForLoopStatement* s) { ProcessExpression(s->condition); },
[&](const ast::WhileStatement* s) { ProcessExpression(s->condition); },
[&](const ast::IfStatement* s) { //
[&](const ForLoopStatement* s) { ProcessExpression(s->condition); },
[&](const WhileStatement* s) { ProcessExpression(s->condition); },
[&](const IfStatement* s) { //
ProcessExpression(s->condition);
},
[&](const ast::ReturnStatement* s) { //
[&](const ReturnStatement* s) { //
ProcessExpression(s->value);
},
[&](const ast::SwitchStatement* s) { ProcessExpression(s->condition); },
[&](const ast::VariableDeclStatement* s) {
[&](const SwitchStatement* s) { ProcessExpression(s->condition); },
[&](const VariableDeclStatement* s) {
ProcessExpression(s->variable->initializer);
});
}
@ -394,20 +390,20 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
ToHoistSet to_hoist;
// Returns true if `binary_expr` should be decomposed for short-circuit eval.
bool IsLogicalWithSideEffects(const ast::BinaryExpression* binary_expr) {
bool IsLogicalWithSideEffects(const BinaryExpression* binary_expr) {
return binary_expr->IsLogical() && (sem.GetVal(binary_expr->lhs)->HasSideEffects() ||
sem.GetVal(binary_expr->rhs)->HasSideEffects());
}
// Recursive function used to decompose an expression for short-circuit eval.
template <size_t N>
const ast::Expression* Decompose(const ast::Expression* expr,
tint::utils::Vector<const ast::Statement*, N>* curr_stmts) {
const Expression* Decompose(const Expression* expr,
tint::utils::Vector<const Statement*, N>* curr_stmts) {
// Helper to avoid passing in same args.
auto decompose = [&](auto& e) { return Decompose(e, curr_stmts); };
// Clones `expr`, possibly hoisting it to a let.
auto clone_maybe_hoisted = [&](const ast::Expression* e) -> const ast::Expression* {
auto clone_maybe_hoisted = [&](const Expression* e) -> const Expression* {
if (to_hoist.count(e)) {
auto name = b.Symbols().New();
auto* v = b.Let(name, ctx.Clone(e));
@ -420,7 +416,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
return Switch(
expr,
[&](const ast::BinaryExpression* bin_expr) -> const ast::Expression* {
[&](const BinaryExpression* bin_expr) -> const Expression* {
if (!IsLogicalWithSideEffects(bin_expr)) {
// No short-circuit, emit usual binary expr
ctx.Replace(bin_expr->lhs, decompose(bin_expr->lhs));
@ -461,16 +457,16 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
auto name = b.Sym();
curr_stmts->Push(b.Decl(b.Var(name, decompose(bin_expr->lhs))));
const ast::Expression* if_cond = nullptr;
const Expression* if_cond = nullptr;
if (bin_expr->IsLogicalOr()) {
if_cond = b.Not(name);
} else {
if_cond = b.Expr(name);
}
const ast::BlockStatement* if_body = nullptr;
const BlockStatement* if_body = nullptr;
{
tint::utils::Vector<const ast::Statement*, N> stmts;
tint::utils::Vector<const Statement*, N> stmts;
TINT_SCOPED_ASSIGNMENT(curr_stmts, &stmts);
auto* new_rhs = decompose(bin_expr->rhs);
curr_stmts->Push(b.Assign(name, new_rhs));
@ -481,36 +477,36 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
return b.Expr(name);
},
[&](const ast::IndexAccessorExpression* idx) {
[&](const IndexAccessorExpression* idx) {
ctx.Replace(idx->object, decompose(idx->object));
ctx.Replace(idx->index, decompose(idx->index));
return clone_maybe_hoisted(idx);
},
[&](const ast::BitcastExpression* bitcast) {
[&](const BitcastExpression* bitcast) {
ctx.Replace(bitcast->expr, decompose(bitcast->expr));
return clone_maybe_hoisted(bitcast);
},
[&](const ast::CallExpression* call) {
[&](const CallExpression* call) {
for (auto* a : call->args) {
ctx.Replace(a, decompose(a));
}
return clone_maybe_hoisted(call);
},
[&](const ast::MemberAccessorExpression* member) {
[&](const MemberAccessorExpression* member) {
ctx.Replace(member->object, decompose(member->object));
return clone_maybe_hoisted(member);
},
[&](const ast::UnaryOpExpression* unary) {
[&](const UnaryOpExpression* unary) {
ctx.Replace(unary->expr, decompose(unary->expr));
return clone_maybe_hoisted(unary);
},
[&](const ast::LiteralExpression* lit) {
[&](const LiteralExpression* lit) {
return clone_maybe_hoisted(lit); // Leaf expression, just clone as is
},
[&](const ast::IdentifierExpression* id) {
[&](const IdentifierExpression* id) {
return clone_maybe_hoisted(id); // Leaf expression, just clone as is
},
[&](const ast::PhonyExpression* phony) {
[&](const PhonyExpression* phony) {
return clone_maybe_hoisted(phony); // Leaf expression, just clone as is
},
[&](Default) {
@ -522,8 +518,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
// Inserts statements in `stmts` before `stmt`
template <size_t N>
void InsertBefore(tint::utils::Vector<const ast::Statement*, N>& stmts,
const ast::Statement* stmt) {
void InsertBefore(tint::utils::Vector<const Statement*, N>& stmts, const Statement* stmt) {
if (!stmts.IsEmpty()) {
auto ip = utils::GetInsertionPoint(ctx, stmt);
for (auto* s : stmts) {
@ -534,86 +529,86 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
// Decomposes expressions of `stmt`, returning a replacement statement or
// nullptr if not replacing it.
const ast::Statement* DecomposeStatement(const ast::Statement* stmt) {
const Statement* DecomposeStatement(const Statement* stmt) {
return Switch(
stmt,
[&](const ast::AssignmentStatement* s) -> const ast::Statement* {
[&](const AssignmentStatement* s) -> const Statement* {
if (!sem.GetVal(s->lhs)->HasSideEffects() &&
!sem.GetVal(s->rhs)->HasSideEffects()) {
return nullptr;
}
// lhs before rhs
tint::utils::Vector<const ast::Statement*, 8> stmts;
tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->lhs, Decompose(s->lhs, &stmts));
ctx.Replace(s->rhs, Decompose(s->rhs, &stmts));
InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s);
},
[&](const ast::CallStatement* s) -> const ast::Statement* {
[&](const CallStatement* s) -> const Statement* {
if (!sem.Get(s->expr)->HasSideEffects()) {
return nullptr;
}
tint::utils::Vector<const ast::Statement*, 8> stmts;
tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->expr, Decompose(s->expr, &stmts));
InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s);
},
[&](const ast::ForLoopStatement* s) -> const ast::Statement* {
[&](const ForLoopStatement* s) -> const Statement* {
if (!s->condition || !sem.GetVal(s->condition)->HasSideEffects()) {
return nullptr;
}
tint::utils::Vector<const ast::Statement*, 8> stmts;
tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s);
},
[&](const ast::WhileStatement* s) -> const ast::Statement* {
[&](const WhileStatement* s) -> const Statement* {
if (!sem.GetVal(s->condition)->HasSideEffects()) {
return nullptr;
}
tint::utils::Vector<const ast::Statement*, 8> stmts;
tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s);
},
[&](const ast::IfStatement* s) -> const ast::Statement* {
[&](const IfStatement* s) -> const Statement* {
if (!sem.GetVal(s->condition)->HasSideEffects()) {
return nullptr;
}
tint::utils::Vector<const ast::Statement*, 8> stmts;
tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s);
},
[&](const ast::ReturnStatement* s) -> const ast::Statement* {
[&](const ReturnStatement* s) -> const Statement* {
if (!s->value || !sem.GetVal(s->value)->HasSideEffects()) {
return nullptr;
}
tint::utils::Vector<const ast::Statement*, 8> stmts;
tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->value, Decompose(s->value, &stmts));
InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s);
},
[&](const ast::SwitchStatement* s) -> const ast::Statement* {
[&](const SwitchStatement* s) -> const Statement* {
if (!sem.Get(s->condition)) {
return nullptr;
}
tint::utils::Vector<const ast::Statement*, 8> stmts;
tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s);
},
[&](const ast::VariableDeclStatement* s) -> const ast::Statement* {
[&](const VariableDeclStatement* s) -> const Statement* {
auto* var = s->variable;
if (!var->initializer || !sem.GetVal(var->initializer)->HasSideEffects()) {
return nullptr;
}
tint::utils::Vector<const ast::Statement*, 8> stmts;
tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(var->initializer, Decompose(var->initializer, &stmts));
InsertBefore(stmts, s);
return b.Decl(ctx.CloneWithoutTransform(var));
},
[](Default) -> const ast::Statement* {
[](Default) -> const Statement* {
// Other statement types don't have expressions
return nullptr;
});
@ -626,7 +621,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
void Run() {
// We replace all BlockStatements as this allows us to iterate over the
// block statements and ctx.InsertBefore hoisted declarations on them.
ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* {
ctx.ReplaceAll([&](const BlockStatement* block) -> const Statement* {
for (auto* stmt : block->statements) {
if (auto* new_stmt = DecomposeStatement(stmt)) {
ctx.Replace(stmt, new_stmt);
@ -634,7 +629,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
// Handle for loops, as they are the only other AST node that
// contains statements outside of BlockStatements.
if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
if (auto* fl = stmt->As<ForLoopStatement>()) {
if (auto* new_stmt = DecomposeStatement(fl->initializer)) {
ctx.Replace(fl->initializer, new_stmt);
}

View File

@ -45,7 +45,7 @@ struct RemoveContinueInSwitch::State {
bool made_changes = false;
for (auto* node : src->ASTNodes().Objects()) {
auto* cont = node->As<ast::ContinueStatement>();
auto* cont = node->As<ContinueStatement>();
if (!cont) {
continue;
}
@ -103,12 +103,12 @@ struct RemoveContinueInSwitch::State {
const sem::Info& sem = src->Sem();
// Map of switch statement to 'tint_continue' variable.
std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name;
std::unordered_map<const SwitchStatement*, Symbol> switch_to_cont_var_name;
// If `cont` is within a switch statement within a loop, returns a pointer to
// that switch statement.
static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
const ast::ContinueStatement* cont) {
static const SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
const ContinueStatement* cont) {
// Find whether first parent is a switch or a loop
auto* sem_stmt = sem.Get(cont);
auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
@ -116,7 +116,7 @@ struct RemoveContinueInSwitch::State {
if (!sem_parent) {
return nullptr;
}
return sem_parent->Declaration()->As<ast::SwitchStatement>();
return sem_parent->Declaration()->As<SwitchStatement>();
}
};

View File

@ -53,14 +53,14 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
for (auto* node : src->ASTNodes().Objects()) {
Switch(
node,
[&](const ast::AssignmentStatement* stmt) {
if (stmt->lhs->Is<ast::PhonyExpression>()) {
[&](const AssignmentStatement* stmt) {
if (stmt->lhs->Is<PhonyExpression>()) {
made_changes = true;
std::vector<const ast::Expression*> side_effects;
if (!ast::TraverseExpressions(
stmt->rhs, b.Diagnostics(), [&](const ast::CallExpression* expr) {
// ast::CallExpression may map to a function or builtin call
std::vector<const Expression*> side_effects;
if (!TraverseExpressions(
stmt->rhs, b.Diagnostics(), [&](const CallExpression* expr) {
// CallExpression may map to a function or builtin call
// (both may have side-effects), or a value constructor or value
// conversion (both do not have side effects).
auto* call = sem.Get<sem::Call>(expr);
@ -68,14 +68,14 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
// Semantic node must be a Materialize, in which case the
// expression was creation-time (compile time), so could not
// have side effects. Just skip.
return ast::TraverseAction::Skip;
return TraverseAction::Skip;
}
if (call->Target()->IsAnyOf<sem::Function, sem::Builtin>() &&
call->HasSideEffects()) {
side_effects.push_back(expr);
return ast::TraverseAction::Skip;
return TraverseAction::Skip;
}
return ast::TraverseAction::Descend;
return TraverseAction::Descend;
})) {
return;
}
@ -88,12 +88,12 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
}
if (side_effects.size() == 1) {
if (auto* call_expr = side_effects[0]->As<ast::CallExpression>()) {
if (auto* call_expr = side_effects[0]->As<CallExpression>()) {
// Phony assignment with single call side effect.
auto* call = sem.Get(call_expr)->Unwrap()->As<sem::Call>();
if (call->Target()->MustUse()) {
// Replace phony assignment assignment to uniquely named let.
ctx.Replace<ast::Statement>(stmt, [&, call_expr] { //
ctx.Replace<Statement>(stmt, [&, call_expr] { //
auto name = b.Symbols().New("tint_phony");
auto* rhs = ctx.Clone(call_expr);
return b.Decl(b.Let(name, rhs));
@ -118,7 +118,7 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
}
auto sink = sinks.GetOrCreate(sig, [&] {
auto name = b.Symbols().New("phony_sink");
utils::Vector<const ast::Parameter*, 8> params;
utils::Vector<const Parameter*, 8> params;
for (auto* ty : sig) {
auto ast_ty = CreateASTTypeFor(ctx, ty);
params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty));
@ -126,7 +126,7 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
b.Func(name, params, b.ty.void_(), {});
return name;
});
utils::Vector<const ast::Expression*, 8> args;
utils::Vector<const Expression*, 8> args;
for (auto* arg : side_effects) {
args.Push(ctx.Clone(arg));
}
@ -134,7 +134,7 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
});
}
},
[&](const ast::CallStatement* stmt) {
[&](const CallStatement* stmt) {
// Remove call statements to const value-returning functions.
// TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects.
auto* sem_expr = sem.Get(stmt->expr);

View File

@ -1265,10 +1265,10 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
}
// Identifiers that need to keep their symbols preserved.
utils::Hashset<const ast::Identifier*, 16> preserved_identifiers;
utils::Hashset<const Identifier*, 16> preserved_identifiers;
for (auto* node : src->ASTNodes().Objects()) {
auto preserve_if_builtin_type = [&](const ast::Identifier* ident) {
auto preserve_if_builtin_type = [&](const Identifier* ident) {
if (!global_decls.Contains(ident->symbol)) {
preserved_identifiers.Add(ident);
}
@ -1276,7 +1276,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
Switch(
node,
[&](const ast::MemberAccessorExpression* accessor) {
[&](const MemberAccessorExpression* accessor) {
auto* sem = src->Sem().Get(accessor)->UnwrapLoad();
if (sem->Is<sem::Swizzle>()) {
preserved_identifiers.Add(accessor->member);
@ -1288,19 +1288,19 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
}
}
},
[&](const ast::DiagnosticAttribute* diagnostic) {
[&](const DiagnosticAttribute* diagnostic) {
if (auto* category = diagnostic->control.rule_name->category) {
preserved_identifiers.Add(category);
}
preserved_identifiers.Add(diagnostic->control.rule_name->name);
},
[&](const ast::DiagnosticDirective* diagnostic) {
[&](const DiagnosticDirective* diagnostic) {
if (auto* category = diagnostic->control.rule_name->category) {
preserved_identifiers.Add(category);
}
preserved_identifiers.Add(diagnostic->control.rule_name->name);
},
[&](const ast::IdentifierExpression* expr) {
[&](const IdentifierExpression* expr) {
Switch(
src->Sem().Get(expr), //
[&](const sem::BuiltinEnumExpressionBase*) {
@ -1310,7 +1310,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
preserve_if_builtin_type(expr->identifier);
});
},
[&](const ast::CallExpression* call) {
[&](const CallExpression* call) {
Switch(
src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>()->Target(),
[&](const sem::Builtin*) {
@ -1372,7 +1372,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ false};
ctx.ReplaceAll([&](const ast::Identifier* ident) -> const ast::Identifier* {
ctx.ReplaceAll([&](const Identifier* ident) -> const Identifier* {
const auto symbol = ident->symbol;
if (preserved_identifiers.Contains(ident) || !should_rename(symbol)) {
return nullptr; // Preserve symbol
@ -1382,12 +1382,12 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
auto replacement = remappings.GetOrCreate(symbol, [&] { return b.Symbols().New(); });
// Reconstruct the identifier
if (auto* tmpl_ident = ident->As<ast::TemplatedIdentifier>()) {
if (auto* tmpl_ident = ident->As<TemplatedIdentifier>()) {
auto args = ctx.Clone(tmpl_ident->arguments);
return ctx.dst->create<ast::TemplatedIdentifier>(ctx.Clone(ident->source), replacement,
return ctx.dst->create<TemplatedIdentifier>(ctx.Clone(ident->source), replacement,
std::move(args), utils::Empty);
}
return ctx.dst->create<ast::Identifier>(ctx.Clone(ident->source), replacement);
return ctx.dst->create<Identifier>(ctx.Clone(ident->source), replacement);
});
ctx.Clone();

View File

@ -58,7 +58,7 @@ struct Robustness::State {
for (auto* node : ctx.src->ASTNodes().Objects()) {
Switch(
node, //
[&](const ast::IndexAccessorExpression* e) {
[&](const IndexAccessorExpression* e) {
// obj[idx]
// Array, matrix and vector indexing may require robustness transformation.
auto* expr = sem.Get(e)->Unwrap()->As<sem::IndexAccessorExpression>();
@ -73,7 +73,7 @@ struct Robustness::State {
break;
}
},
[&](const ast::IdentifierExpression* e) {
[&](const IdentifierExpression* e) {
// Identifiers may resolve to pointer lets, which may be predicated.
// Inspect.
if (auto* user = sem.Get<sem::VariableUser>(e)) {
@ -86,42 +86,42 @@ struct Robustness::State {
}
}
},
[&](const ast::AccessorExpression* e) {
[&](const AccessorExpression* e) {
// obj.member
// Propagate the predication from the object to this expression.
if (auto pred = predicates.Get(e->object)) {
predicates.Add(e, *pred);
}
},
[&](const ast::UnaryOpExpression* e) {
[&](const UnaryOpExpression* e) {
// Includes address-of, or indirection
// Propagate the predication from the inner expression to this expression.
if (auto pred = predicates.Get(e->expr)) {
predicates.Add(e, *pred);
}
},
[&](const ast::AssignmentStatement* s) {
[&](const AssignmentStatement* s) {
if (auto pred = predicates.Get(s->lhs)) {
// Assignment target is predicated
// Replace statement with condition on the predicate
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
}
},
[&](const ast::CompoundAssignmentStatement* s) {
[&](const CompoundAssignmentStatement* s) {
if (auto pred = predicates.Get(s->lhs)) {
// Assignment expression is predicated
// Replace statement with condition on the predicate
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
}
},
[&](const ast::IncrementDecrementStatement* s) {
[&](const IncrementDecrementStatement* s) {
if (auto pred = predicates.Get(s->lhs)) {
// Assignment expression is predicated
// Replace statement with condition on the predicate
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
}
},
[&](const ast::CallExpression* e) {
[&](const CallExpression* e) {
if (auto* call = sem.Get<sem::Call>(e)) {
Switch(
call->Target(), //
@ -163,7 +163,7 @@ struct Robustness::State {
// predicated_expr = expr;
// }
//
if (auto* expr = node->As<ast::Expression>()) {
if (auto* expr = node->As<Expression>()) {
if (auto pred = predicates.Get(expr)) {
// Expression is predicated
auto* sem_expr = sem.GetVal(expr);
@ -202,15 +202,15 @@ struct Robustness::State {
/// Alias to the source program's semantic info
const sem::Info& sem = ctx.src->Sem();
/// Map of expression to predicate condition
utils::Hashmap<const ast::Expression*, Symbol, 32> predicates{};
utils::Hashmap<const Expression*, Symbol, 32> predicates{};
/// @return the `u32` typed expression that represents the maximum indexable value for the index
/// accessor @p expr, or nullptr if there is no robustness limit for this expression.
const ast::Expression* DynamicLimitFor(const sem::IndexAccessorExpression* expr) {
const Expression* DynamicLimitFor(const sem::IndexAccessorExpression* expr) {
auto* obj_type = expr->Object()->Type();
return Switch(
obj_type->UnwrapRef(), //
[&](const type::Vector* vec) -> const ast::Expression* {
[&](const type::Vector* vec) -> const Expression* {
if (expr->Index()->ConstantValue() || expr->Index()->Is<sem::Swizzle>()) {
// Index and size is constant.
// Validation will have rejected any OOB accesses.
@ -218,7 +218,7 @@ struct Robustness::State {
}
return b.Expr(u32(vec->Width() - 1u));
},
[&](const type::Matrix* mat) -> const ast::Expression* {
[&](const type::Matrix* mat) -> const Expression* {
if (expr->Index()->ConstantValue()) {
// Index and size is constant.
// Validation will have rejected any OOB accesses.
@ -226,7 +226,7 @@ struct Robustness::State {
}
return b.Expr(u32(mat->columns() - 1u));
},
[&](const type::Array* arr) -> const ast::Expression* {
[&](const type::Array* arr) -> const Expression* {
if (arr->Count()->Is<type::RuntimeArrayCount>()) {
// Size is unknown until runtime.
// Must clamp, even if the index is constant.
@ -248,7 +248,7 @@ struct Robustness::State {
type::Array::kErrExpectedConstantCount);
return nullptr;
},
[&](Default) -> const ast::Expression* {
[&](Default) -> const Expression* {
TINT_ICE(Transform, b.Diagnostics())
<< "unhandled object type in robustness of array index: "
<< obj_type->UnwrapRef()->FriendlyName();
@ -350,7 +350,7 @@ struct Robustness::State {
/// Applies predication to the non-texture builtin call, if required.
void MaybePredicateNonTextureBuiltin(const sem::Call* call, const sem::Builtin* builtin) {
// Gather the predications for the builtin arguments
const ast::Expression* predicate = nullptr;
const Expression* predicate = nullptr;
for (auto* arg : call->Declaration()->args) {
if (auto pred = predicates.Get(arg)) {
predicate = And(predicate, b.Expr(*pred));
@ -393,7 +393,7 @@ struct Robustness::State {
auto* texture_arg = expr->args[static_cast<size_t>(texture_arg_idx)];
// Build the builtin predicate from the arguments
const ast::Expression* predicate = nullptr;
const Expression* predicate = nullptr;
Symbol level_idx, num_levels;
if (level_arg_idx >= 0) {
@ -554,7 +554,7 @@ struct Robustness::State {
}
/// @returns a bitwise and of the two expressions, or the other expression if one is null.
const ast::Expression* And(const ast::Expression* lhs, const ast::Expression* rhs) {
const Expression* And(const Expression* lhs, const Expression* rhs) {
if (lhs && rhs) {
return b.And(lhs, rhs);
}
@ -568,11 +568,11 @@ struct Robustness::State {
/// predicate.
/// @param else_stmt - the statement to execute for the predication failure
void PredicateCall(const sem::Call* call,
const ast::Expression* predicate,
const ast::BlockStatement* else_stmt = nullptr) {
const Expression* predicate,
const BlockStatement* else_stmt = nullptr) {
auto* expr = call->Declaration();
auto* stmt = call->Stmt();
auto* call_stmt = stmt->Declaration()->As<ast::CallStatement>();
auto* call_stmt = stmt->Declaration()->As<CallStatement>();
if (call_stmt && call_stmt->expr == expr) {
// Wrap the statement in an if-statement with the predicate condition.
hoist.Replace(stmt, b.If(predicate, b.Block(ctx.Clone(stmt->Declaration())),
@ -646,7 +646,7 @@ struct Robustness::State {
}
/// @returns a scalar or vector type with the element type @p scalar and width @p width
ast::Type ScalarOrVecTy(ast::Type scalar, uint32_t width) const {
Type ScalarOrVecTy(Type scalar, uint32_t width) const {
if (width > 1) {
return b.ty.vec(scalar, width);
}
@ -655,7 +655,7 @@ struct Robustness::State {
/// @returns a vector constructed with the scalar expression @p scalar if @p width > 1,
/// otherwise returns @p scalar.
const ast::Expression* ScalarOrVec(const ast::Expression* scalar, uint32_t width) {
const Expression* ScalarOrVec(const Expression* scalar, uint32_t width) {
if (width > 1) {
return b.Call(b.ty.vec<Infer>(width), scalar);
}
@ -664,13 +664,13 @@ struct Robustness::State {
/// @returns @p val cast to a `vecN<i32>`, where `N` is @p width, or cast to i32 if @p width
/// is 1.
const ast::CallExpression* CastToSigned(const ast::Expression* val, uint32_t width) {
const CallExpression* CastToSigned(const Expression* val, uint32_t width) {
return b.Call(ScalarOrVecTy(b.ty.i32(), width), val);
}
/// @returns @p val cast to a `vecN<u32>`, where `N` is @p width, or cast to u32 if @p width
/// is 1.
const ast::CallExpression* CastToUnsigned(const ast::Expression* val, uint32_t width) {
const CallExpression* CastToUnsigned(const Expression* val, uint32_t width) {
return b.Call(ScalarOrVecTy(b.ty.u32(), width), val);
}
};

View File

@ -41,7 +41,7 @@ struct PointerOp {
/// Zero: no pointer op on `expr`
int indirections = 0;
/// The expression being operated on
const ast::Expression* expr = nullptr;
const Expression* expr = nullptr;
};
} // namespace
@ -64,29 +64,29 @@ struct SimplifyPointers::State {
/// expression. The function-like argument `cb` is called for each found.
/// @param expr the expression to traverse
/// @param cb a function-like object with the signature
/// `void(const ast::Expression*)`, which is called for each array index
/// `void(const Expression*)`, which is called for each array index
/// expression
template <typename F>
static void CollectSavedArrayIndices(const ast::Expression* expr, F&& cb) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
static void CollectSavedArrayIndices(const Expression* expr, F&& cb) {
if (auto* a = expr->As<IndexAccessorExpression>()) {
CollectSavedArrayIndices(a->object, cb);
if (!a->index->Is<ast::LiteralExpression>()) {
if (!a->index->Is<LiteralExpression>()) {
cb(a->index);
}
return;
}
if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
if (auto* m = expr->As<MemberAccessorExpression>()) {
CollectSavedArrayIndices(m->object, cb);
return;
}
if (auto* u = expr->As<ast::UnaryOpExpression>()) {
if (auto* u = expr->As<UnaryOpExpression>()) {
CollectSavedArrayIndices(u->expr, cb);
return;
}
// Note: Other ast::Expression types can be safely ignored as they cannot be
// Note: Other Expression types can be safely ignored as they cannot be
// used to generate a reference or pointer.
// See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers
}
@ -95,16 +95,16 @@ struct SimplifyPointers::State {
/// indirection ops into a PointerOp.
/// @param in the expression to walk
/// @returns the reduced PointerOp
PointerOp Reduce(const ast::Expression* in) const {
PointerOp Reduce(const Expression* in) const {
PointerOp op{0, in};
while (true) {
if (auto* unary = op.expr->As<ast::UnaryOpExpression>()) {
if (auto* unary = op.expr->As<UnaryOpExpression>()) {
switch (unary->op) {
case ast::UnaryOp::kIndirection:
case UnaryOp::kIndirection:
op.indirections++;
op.expr = unary->expr;
continue;
case ast::UnaryOp::kAddressOf:
case UnaryOp::kAddressOf:
op.indirections--;
op.expr = unary->expr;
continue;
@ -115,7 +115,7 @@ struct SimplifyPointers::State {
if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) {
auto* var = user->Variable();
if (var->Is<sem::LocalVariable>() && //
var->Declaration()->Is<ast::Let>() && //
var->Declaration()->Is<Let>() && //
var->Type()->Is<type::Pointer>()) {
op.expr = var->Declaration()->initializer;
continue;
@ -129,7 +129,7 @@ struct SimplifyPointers::State {
/// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() {
// A map of saved expressions to their saved variable name
utils::Hashmap<const ast::Expression*, Symbol, 8> saved_vars;
utils::Hashmap<const Expression*, Symbol, 8> saved_vars;
bool needs_transform = false;
for (auto* ty : ctx.src->Types()) {
@ -146,8 +146,8 @@ struct SimplifyPointers::State {
for (auto* node : ctx.src->ASTNodes().Objects()) {
Switch(
node, //
[&](const ast::VariableDeclStatement* let) {
if (!let->variable->Is<ast::Let>()) {
[&](const VariableDeclStatement* let) {
if (!let->variable->Is<Let>()) {
return; // Not a `let` declaration. Ignore.
}
@ -160,9 +160,9 @@ struct SimplifyPointers::State {
// Scan the initializer expression for array index expressions that need
// to be hoist to temporary "saved" variables.
utils::Vector<const ast::VariableDeclStatement*, 8> saved;
utils::Vector<const VariableDeclStatement*, 8> saved;
CollectSavedArrayIndices(
var->Declaration()->initializer, [&](const ast::Expression* idx_expr) {
var->Declaration()->initializer, [&](const Expression* idx_expr) {
// We have a sub-expression that needs to be saved.
// Create a new variable
auto saved_name = ctx.dst->Symbols().New(
@ -205,8 +205,8 @@ struct SimplifyPointers::State {
// need for the original declaration to exist. Remove it.
RemoveStatement(ctx, let);
},
[&](const ast::UnaryOpExpression* op) {
if (op->op == ast::UnaryOp::kAddressOf) {
[&](const UnaryOpExpression* op) {
if (op->op == UnaryOp::kAddressOf) {
// Transform can be skipped if no address-of operator is used, as there
// will be no pointers that can be inlined.
needs_transform = true;
@ -218,7 +218,7 @@ struct SimplifyPointers::State {
return SkipTransform;
}
// Register the ast::Expression transform handler.
// Register the Expression transform handler.
// This performs two different transformations:
// * Identifiers that resolve to the pointer-typed `let` declarations are
// replaced with the recursively inlined initializer expression for the
@ -226,7 +226,7 @@ struct SimplifyPointers::State {
// * Sub-expressions inside the pointer-typed `let` initializer expression
// that have been hoisted to a saved variable are replaced with the saved
// variable identifier.
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
ctx.ReplaceAll([&](const Expression* expr) -> const Expression* {
// Look to see if we need to swap this Expression with a saved variable.
if (auto saved_var = saved_vars.Find(expr)) {
return ctx.dst->Expr(*saved_var);

View File

@ -45,7 +45,7 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
}
// Find the target entry point.
const ast::Function* entry_point = nullptr;
const Function* entry_point = nullptr;
for (auto* f : src->AST().Functions()) {
if (!f->IsEntryPoint()) {
continue;
@ -69,7 +69,7 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
for (auto* decl : src->AST().GlobalDeclarations()) {
Switch(
decl, //
[&](const ast::TypeDecl* ty) {
[&](const TypeDecl* ty) {
// Strip aliases that reference unused override declarations.
if (auto* arr = sem.Get(ty)->As<type::Array>()) {
auto* refs = sem.TransitivelyReferencedOverrides(arr);
@ -85,9 +85,9 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
// TODO(jrprice): Strip other unused types.
b.AST().AddTypeDecl(ctx.Clone(ty));
},
[&](const ast::Override* override) {
[&](const Override* override) {
if (referenced_vars.Contains(sem.Get(override))) {
if (!ast::HasAttribute<ast::IdAttribute>(override->attributes)) {
if (!HasAttribute<IdAttribute>(override->attributes)) {
// If the override doesn't already have an @id() attribute, add one
// so that its allocated ID so that it won't be affected by other
// stripped away overrides
@ -98,26 +98,24 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
b.AST().AddGlobalVariable(ctx.Clone(override));
}
},
[&](const ast::Var* var) {
[&](const Var* var) {
if (referenced_vars.Contains(sem.Get<sem::GlobalVariable>(var))) {
b.AST().AddGlobalVariable(ctx.Clone(var));
}
},
[&](const ast::Const* c) {
[&](const Const* c) {
// Always keep 'const' declarations, as these can be used by attributes and array
// sizes, which are not tracked as transitively used by functions. They also don't
// typically get emitted by the backend unless they're actually used.
b.AST().AddGlobalVariable(ctx.Clone(c));
},
[&](const ast::Function* func) {
[&](const Function* func) {
if (sem.Get(func)->HasAncestorEntryPoint(entry_point->name->symbol)) {
b.AST().AddFunction(ctx.Clone(func));
}
},
[&](const ast::Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); },
[&](const ast::DiagnosticDirective* d) {
b.AST().AddDiagnosticDirective(ctx.Clone(d));
},
[&](const Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); },
[&](const DiagnosticDirective* d) { b.AST().AddDiagnosticDirective(ctx.Clone(d)); },
[&](Default) {
TINT_UNREACHABLE(Transform, b.Diagnostics())
<< "unhandled global declaration: " << decl->TypeInfo().name;

View File

@ -70,7 +70,7 @@ struct SpirvAtomic::State {
// Look for stub functions generated by the SPIR-V reader, which are used as placeholders
// for atomic builtin calls.
for (auto* fn : ctx.src->AST().Functions()) {
if (auto* stub = ast::GetAttribute<Stub>(fn->attributes)) {
if (auto* stub = GetAttribute<Stub>(fn->attributes)) {
auto* sem = ctx.src->Sem().Get(fn);
for (auto* call : sem->CallSites()) {
@ -121,14 +121,14 @@ struct SpirvAtomic::State {
// If we need to change structure members, then fork them.
if (!forked_structs.empty()) {
ctx.ReplaceAll([&](const ast::Struct* str) {
ctx.ReplaceAll([&](const Struct* str) {
// Is `str` a structure we need to fork?
auto* str_ty = ctx.src->Sem().Get(str);
if (auto it = forked_structs.find(str_ty); it != forked_structs.end()) {
const auto& forked = it->second;
// Re-create the structure swapping in the atomic-flavoured members
utils::Vector<const ast::StructMember*, 8> members;
utils::Vector<const StructMember*, 8> members;
members.Reserve(str->members.Length());
for (size_t i = 0; i < str->members.Length(); i++) {
auto* member = str->members[i];
@ -187,14 +187,14 @@ struct SpirvAtomic::State {
atomic_expressions.Add(index->Object());
},
[&](const sem::ValueExpression* e) {
if (auto* unary = e->Declaration()->As<ast::UnaryOpExpression>()) {
if (auto* unary = e->Declaration()->As<UnaryOpExpression>()) {
atomic_expressions.Add(ctx.src->Sem().GetVal(unary->expr));
}
});
}
}
ast::Type AtomicTypeFor(const type::Type* ty) {
Type AtomicTypeFor(const type::Type* ty) {
return Switch(
ty, //
[&](const type::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
@ -221,7 +221,7 @@ struct SpirvAtomic::State {
[&](const type::Reference* ref) { return AtomicTypeFor(ref->StoreType()); },
[&](Default) {
TINT_ICE(Transform, b.Diagnostics()) << "unhandled type: " << ty->FriendlyName();
return ast::Type{};
return Type{};
});
}
@ -249,7 +249,7 @@ struct SpirvAtomic::State {
for (auto* vu : atomic_var->Users()) {
Switch(
vu->Stmt()->Declaration(),
[&](const ast::AssignmentStatement* assign) {
[&](const AssignmentStatement* assign) {
auto* sem_lhs = ctx.src->Sem().GetVal(assign->lhs);
if (is_ref_to_atomic_var(sem_lhs)) {
ctx.Replace(assign, [=] {
@ -272,7 +272,7 @@ struct SpirvAtomic::State {
return;
}
},
[&](const ast::VariableDeclStatement* decl) {
[&](const VariableDeclStatement* decl) {
auto* var = decl->variable;
if (auto* sem_init = ctx.src->Sem().GetVal(var->initializer)) {
if (is_ref_to_atomic_var(sem_init->UnwrapLoad())) {
@ -293,7 +293,7 @@ struct SpirvAtomic::State {
SpirvAtomic::SpirvAtomic() = default;
SpirvAtomic::~SpirvAtomic() = default;
SpirvAtomic::Stub::Stub(ProgramID pid, ast::NodeID nid, builtin::Function b)
SpirvAtomic::Stub::Stub(ProgramID pid, NodeID nid, builtin::Function b)
: Base(pid, nid, utils::Empty), builtin(b) {}
SpirvAtomic::Stub::~Stub() = default;
std::string SpirvAtomic::Stub::InternalName() const {

View File

@ -41,12 +41,12 @@ class SpirvAtomic final : public utils::Castable<SpirvAtomic, Transform> {
/// Stub is an attribute applied to stub SPIR-V reader generated functions that need to be
/// translated to an atomic builtin.
class Stub final : public utils::Castable<Stub, ast::InternalAttribute> {
class Stub final : public utils::Castable<Stub, InternalAttribute> {
public:
/// @param pid the identifier of the program that owns this node
/// @param nid the unique node identifier
/// @param builtin the atomic builtin this stub represents
Stub(ProgramID pid, ast::NodeID nid, builtin::Function builtin);
Stub(ProgramID pid, NodeID nid, builtin::Function builtin);
/// Destructor
~Stub() override;

View File

@ -102,7 +102,7 @@ struct Std140::State {
// Finally, replace all expression chains that used the authored types with those that
// correctly use the forked types.
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
ctx.ReplaceAll([&](const Expression* expr) -> const Expression* {
if (auto access = AccessChainFor(expr)) {
if (!access->std140_mat_idx.has_value()) {
// loading a std140 type, which is not a whole or partial decomposed matrix
@ -230,7 +230,7 @@ struct Std140::State {
// Map of structure member in src of a matrix type, to list of decomposed column
// members in ctx.dst.
utils::Hashmap<const type::StructMember*, utils::Vector<const ast::StructMember*, 4>, 8>
utils::Hashmap<const type::StructMember*, utils::Vector<const StructMember*, 4>, 8>
std140_mat_members;
/// Describes a matrix that has been forked to a std140-structure holding the decomposed column
@ -284,7 +284,7 @@ struct Std140::State {
if (str && str->UsedAs(builtin::AddressSpace::kUniform)) {
// Should this uniform buffer be forked for std140 usage?
bool fork_std140 = false;
utils::Vector<const ast::StructMember*, 8> members;
utils::Vector<const StructMember*, 8> members;
for (auto* member : str->Members()) {
if (auto* mat = member->Type()->As<type::Matrix>()) {
// Is this member a matrix that needs decomposition for std140-layout?
@ -335,7 +335,7 @@ struct Std140::State {
// Create a new forked structure, and insert it just under the original
// structure.
auto name = b.Symbols().New(str->Name().Name() + "_std140");
auto* std140 = b.create<ast::Struct>(b.Ident(name), std::move(members),
auto* std140 = b.create<Struct>(b.Ident(name), std::move(members),
ctx.Clone(str->Declaration()->attributes));
ctx.InsertAfter(src->AST().GlobalDeclarations(), global, std140);
std140_structs.Add(str, name);
@ -349,7 +349,7 @@ struct Std140::State {
/// Populates the #std140_uniforms set.
void ReplaceUniformVarTypes() {
for (auto* global : src->AST().GlobalVariables()) {
if (auto* var = global->As<ast::Var>()) {
if (auto* var = global->As<Var>()) {
auto* v = sem.Get(var);
if (v->AddressSpace() == builtin::AddressSpace::kUniform) {
if (auto std140_ty = Std140Type(v->Type()->UnwrapRef())) {
@ -367,9 +367,7 @@ struct Std140::State {
/// @param str the structure that will hold the uniquely named member.
/// @param unsuffixed the common name prefix to use for the new members.
/// @param count the number of members that need to be created.
std::string PrefixForUniqueNames(const ast::Struct* str,
Symbol unsuffixed,
uint32_t count) const {
std::string PrefixForUniqueNames(const Struct* str, Symbol unsuffixed, uint32_t count) const {
auto prefix = unsuffixed.Name();
// Keep on inserting '_' between the unsuffixed name and the suffix numbers until the name
// is unique.
@ -400,14 +398,14 @@ struct Std140::State {
/// If the semantic type is not split for std140-layout, then nullptr is returned.
/// @note will construct new std140 structures to hold decomposed matrices, populating
/// #std140_mats.
ast::Type Std140Type(const type::Type* ty) {
Type Std140Type(const type::Type* ty) {
return Switch(
ty, //
[&](const type::Struct* str) {
if (auto std140 = std140_structs.Find(str)) {
return b.ty(*std140);
}
return ast::Type{};
return Type{};
},
[&](const type::Matrix* mat) {
if (MatrixNeedsDecomposing(mat)) {
@ -426,13 +424,13 @@ struct Std140::State {
});
return b.ty(std140_mat.name);
}
return ast::Type{};
return Type{};
},
[&](const type::Array* arr) {
if (auto std140 = Std140Type(arr->ElemType())) {
utils::Vector<const ast::Attribute*, 1> attrs;
utils::Vector<const Attribute*, 1> attrs;
if (!arr->IsStrideImplicit()) {
attrs.Push(b.create<ast::StrideAttribute>(arr->Stride()));
attrs.Push(b.create<StrideAttribute>(arr->Stride()));
}
auto count = arr->ConstantCount();
if (TINT_UNLIKELY(!count)) {
@ -446,7 +444,7 @@ struct Std140::State {
}
return b.ty.array(std140, b.Expr(u32(count.value())), std::move(attrs));
}
return ast::Type{};
return Type{};
});
}
@ -455,7 +453,7 @@ struct Std140::State {
/// @param align the alignment in bytes of the matrix.
/// @param size the size in bytes of the matrix.
/// @returns a vector of decomposed matrix column vectors as structure members (in ctx.dst).
utils::Vector<const ast::StructMember*, 4> DecomposedMatrixStructMembers(
utils::Vector<const StructMember*, 4> DecomposedMatrixStructMembers(
const type::Matrix* mat,
const std::string& name_prefix,
uint32_t align,
@ -463,9 +461,9 @@ struct Std140::State {
// Replace the member with column vectors.
const auto num_columns = mat->columns();
// Build a struct member for each column of the matrix
utils::Vector<const ast::StructMember*, 4> out;
utils::Vector<const StructMember*, 4> out;
for (uint32_t i = 0; i < num_columns; i++) {
utils::Vector<const ast::Attribute*, 1> attributes;
utils::Vector<const Attribute*, 1> attributes;
if ((i == 0) && mat->Align() != align) {
// The matrix was @align() annotated with a larger alignment
// than the natural alignment for the matrix. This extra padding
@ -493,7 +491,7 @@ struct Std140::State {
/// Walks the @p ast_expr, constructing and returning an AccessChain.
/// @returns an AccessChain if the expression is an access to a std140-forked uniform buffer,
/// otherwise returns a std::nullopt.
std::optional<AccessChain> AccessChainFor(const ast::Expression* ast_expr) {
std::optional<AccessChain> AccessChainFor(const Expression* ast_expr) {
auto* expr = sem.GetVal(ast_expr);
if (!expr) {
return std::nullopt;
@ -576,10 +574,10 @@ struct Std140::State {
[&](const sem::ValueExpression* e) {
// Walk past indirection and address-of unary ops.
return Switch(e->Declaration(), //
[&](const ast::UnaryOpExpression* u) {
[&](const UnaryOpExpression* u) {
switch (u->op) {
case ast::UnaryOp::kAddressOf:
case ast::UnaryOp::kIndirection:
case UnaryOp::kAddressOf:
case UnaryOp::kIndirection:
expr = sem.GetVal(u->expr);
return Action::kContinue;
default:
@ -660,8 +658,8 @@ struct Std140::State {
/// Generates and returns an expression that loads the value from a std140 uniform buffer,
/// converting the final result to a non-std140 type.
/// @param chain the access chain from a uniform buffer to the value to load.
const ast::Expression* LoadWithConvert(const AccessChain& chain) {
const ast::Expression* expr = nullptr;
const Expression* LoadWithConvert(const AccessChain& chain) {
const Expression* expr = nullptr;
const type::Type* ty = nullptr;
auto dynamic_index = [&](size_t idx) {
return ctx.Clone(chain.dynamic_indices[idx]->Declaration());
@ -678,7 +676,7 @@ struct Std140::State {
/// std140-forked type to the type @p ty. If @p expr is not a std140-forked type, then Convert()
/// will simply return @p expr.
/// @returns the converted value expression.
const ast::Expression* Convert(const type::Type* ty, const ast::Expression* expr) {
const Expression* Convert(const type::Type* ty, const Expression* expr) {
// Get an existing, or create a new function for converting the std140 type to ty.
auto fn = conv_fns.GetOrCreate(ty, [&] {
auto std140_ty = Std140Type(ty);
@ -690,20 +688,20 @@ struct Std140::State {
// The converter function takes a single argument of the std140 type.
auto* param = b.Param("val", std140_ty);
utils::Vector<const ast::Statement*, 3> stmts;
utils::Vector<const Statement*, 3> stmts;
Switch(
ty, //
[&](const type::Struct* str) {
// Convert each of the structure members using either a converter function
// call, or by reassembling a std140 matrix from column vector members.
utils::Vector<const ast::Expression*, 8> args;
utils::Vector<const Expression*, 8> args;
for (auto* member : str->Members()) {
if (auto col_members = std140_mat_members.Find(member)) {
// std140 decomposed matrix. Reassemble.
auto mat_ty = CreateASTTypeFor(ctx, member->Type());
auto mat_args =
utils::Transform(*col_members, [&](const ast::StructMember* m) {
utils::Transform(*col_members, [&](const StructMember* m) {
return b.MemberAccessor(param, m->name->symbol);
});
args.Push(b.Call(mat_ty, std::move(mat_args)));
@ -719,7 +717,7 @@ struct Std140::State {
// Reassemble a std140 matrix from the structure of column vector members.
auto std140_mat = std140_mats.Get(mat);
if (TINT_LIKELY(std140_mat)) {
utils::Vector<const ast::Expression*, 8> args;
utils::Vector<const Expression*, 8> args;
// std140 decomposed matrix. Reassemble.
auto mat_ty = CreateASTTypeFor(ctx, mat);
auto mat_args = utils::Transform(std140_mat->columns, [&](Symbol name) {
@ -782,7 +780,7 @@ struct Std140::State {
/// @param access the access chain from the uniform buffer to either the whole matrix or part of
/// the matrix (column, column-swizzle, or element).
/// @returns the loaded value expression.
const ast::Expression* LoadMatrixWithFn(const AccessChain& access) {
const Expression* LoadMatrixWithFn(const AccessChain& access) {
// Get an existing, or create a new function for loading the uniform buffer value.
// This function is keyed off the uniform buffer variable and the access chain.
auto fn = load_fns.GetOrCreate(LoadFnKey{access.var, access.indices}, [&] {
@ -810,14 +808,14 @@ struct Std140::State {
/// column-swizzle, or element).
/// @note The matrix column must be statically indexed to use this method.
/// @returns the loaded value expression.
const ast::Expression* LoadSubMatrixInline(const AccessChain& chain) {
const Expression* LoadSubMatrixInline(const AccessChain& chain) {
// Method for generating dynamic index expressions.
// As this is inline, we can just clone the expression.
auto dynamic_index = [&](size_t idx) {
return ctx.Clone(chain.dynamic_indices[idx]->Declaration());
};
const ast::Expression* expr = nullptr;
const Expression* expr = nullptr;
const type::Type* ty = nullptr;
// Build the expression up to, but not including the matrix member
@ -891,7 +889,7 @@ struct Std140::State {
std::string name = "load";
// The switch cases
utils::Vector<const ast::CaseStatement*, 4> cases;
utils::Vector<const CaseStatement*, 4> cases;
// The function return type.
const type::Type* ret_ty = nullptr;
@ -899,7 +897,7 @@ struct Std140::State {
// Build switch() cases for each column of the matrix
auto num_columns = chain.std140_mat_ty->columns();
for (uint32_t column_idx = 0; column_idx < num_columns; column_idx++) {
const ast::Expression* expr = nullptr;
const Expression* expr = nullptr;
const type::Type* ty = nullptr;
// Build the expression up to, but not including the matrix
@ -991,7 +989,7 @@ struct Std140::State {
return b.Expr(dynamic_index_params[idx]->name->symbol);
};
const ast::Expression* expr = nullptr;
const Expression* expr = nullptr;
const type::Type* ty = nullptr;
std::string name = "load";
@ -1005,13 +1003,13 @@ struct Std140::State {
name += "_" + access_name;
}
utils::Vector<const ast::Statement*, 2> stmts;
utils::Vector<const Statement*, 2> stmts;
// Create a temporary pointer to the structure that holds the matrix columns
auto* let = b.Let("s", b.AddressOf(expr));
stmts.Push(b.Decl(let));
utils::Vector<const ast::MemberAccessorExpression*, 4> columns;
utils::Vector<const MemberAccessorExpression*, 4> columns;
if (auto* str = tint::As<type::Struct>(ty)) {
// Structure member matrix. The columns are decomposed into the structure.
auto mat_member_idx = std::get<u32>(chain.indices[std140_mat_idx]);
@ -1053,7 +1051,7 @@ struct Std140::State {
/// Return type of BuildAccessExpr()
struct ExprTypeName {
/// The new, post-access expression
const ast::Expression* expr;
const Expression* expr;
/// The type of #expr
const type::Type* type;
/// A name segment which can be used to build sensible names for helper functions
@ -1067,11 +1065,11 @@ struct Std140::State {
/// @param dynamic_index a function that obtains the i'th dynamic index
/// @returns a ExprTypeName which holds the new expression, new type and a name segment which
/// can be used for creating helper function names.
ExprTypeName BuildAccessExpr(const ast::Expression* lhs,
ExprTypeName BuildAccessExpr(const Expression* lhs,
const type::Type* ty,
const AccessChain& chain,
size_t index,
std::function<const ast::Expression*(size_t)> dynamic_index) {
std::function<const Expression*(size_t)> dynamic_index) {
auto& access = chain.indices[index];
if (std::get_if<UniformVariable>(&access)) {

View File

@ -32,7 +32,7 @@ namespace {
bool ShouldRun(const Program* program) {
for (auto* node : program->AST().GlobalVariables()) {
if (node->Is<ast::Override>()) {
if (node->Is<Override>()) {
return true;
}
}
@ -61,12 +61,12 @@ Transform::ApplyResult SubstituteOverride::Apply(const Program* src,
return SkipTransform;
}
ctx.ReplaceAll([&](const ast::Override* w) -> const ast::Const* {
ctx.ReplaceAll([&](const Override* w) -> const Const* {
auto* sem = ctx.src->Sem().Get(w);
auto source = ctx.Clone(w->source);
auto sym = ctx.Clone(w->name->symbol);
ast::Type ty = w->type ? ctx.Clone(w->type) : ast::Type{};
Type ty = w->type ? ctx.Clone(w->type) : Type{};
// No replacement provided, just clone the override node as a const.
auto iter = data->map.find(sem->OverrideId());
@ -102,7 +102,7 @@ Transform::ApplyResult SubstituteOverride::Apply(const Program* src,
// If the object is not materialized, and the 'override' variable is turned to a 'const', the
// resulting type of the index may change. See: crbug.com/tint/1697.
ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
[&](const IndexAccessorExpression* expr) -> const IndexAccessorExpression* {
if (auto* sem = src->Sem().Get(expr)) {
if (auto* access = sem->UnwrapMaterialize()->As<sem::IndexAccessorExpression>()) {
if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() &&

View File

@ -85,18 +85,18 @@ struct Texture1DTo2D::State {
return SkipTransform;
}
auto create_var = [&](const ast::Variable* v, ast::Type type) -> const ast::Variable* {
if (v->As<ast::Parameter>()) {
auto create_var = [&](const Variable* v, Type type) -> const Variable* {
if (v->As<Parameter>()) {
return ctx.dst->Param(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes));
} else {
return ctx.dst->Var(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes));
}
};
ctx.ReplaceAll([&](const ast::Variable* v) -> const ast::Variable* {
const ast::Variable* r = Switch(
ctx.ReplaceAll([&](const Variable* v) -> const Variable* {
const Variable* r = Switch(
sem.Get(v)->Type()->UnwrapRef(),
[&](const type::SampledTexture* tex) -> const ast::Variable* {
[&](const type::SampledTexture* tex) -> const Variable* {
if (tex->dim() == type::TextureDimension::k1d) {
auto type = ctx.dst->ty.sampled_texture(type::TextureDimension::k2d,
CreateASTTypeFor(ctx, tex->type()));
@ -105,7 +105,7 @@ struct Texture1DTo2D::State {
return nullptr;
}
},
[&](const type::StorageTexture* storage_tex) -> const ast::Variable* {
[&](const type::StorageTexture* storage_tex) -> const Variable* {
if (storage_tex->dim() == type::TextureDimension::k1d) {
auto type = ctx.dst->ty.storage_texture(type::TextureDimension::k2d,
storage_tex->texel_format(),
@ -119,7 +119,7 @@ struct Texture1DTo2D::State {
return r;
});
ctx.ReplaceAll([&](const ast::CallExpression* c) -> const ast::Expression* {
ctx.ReplaceAll([&](const CallExpression* c) -> const Expression* {
auto* call = sem.Get(c)->UnwrapMaterialize()->As<sem::Call>();
if (!call) {
return nullptr;
@ -141,7 +141,7 @@ struct Texture1DTo2D::State {
if (builtin->Type() == builtin::Function::kTextureDimensions) {
// If this textureDimensions() call is in a CallStatement, we can leave it
// unmodified since the return value will be dropped on the floor anyway.
if (call->Stmt()->Declaration()->Is<ast::CallStatement>()) {
if (call->Stmt()->Declaration()->Is<CallStatement>()) {
return nullptr;
}
auto* new_call = ctx.CloneWithoutTransform(c);
@ -153,14 +153,14 @@ struct Texture1DTo2D::State {
return nullptr;
}
utils::Vector<const ast::Expression*, 8> args;
utils::Vector<const Expression*, 8> args;
int index = 0;
for (auto* arg : c->args) {
if (index == coords_index) {
auto* ctype = call->Arguments()[static_cast<size_t>(coords_index)]->Type();
auto* coords = c->args[static_cast<size_t>(coords_index)];
const ast::LiteralExpression* half = nullptr;
const LiteralExpression* half = nullptr;
if (ctype->is_integer_scalar()) {
half = ctx.dst->Expr(0_a);
} else {

View File

@ -60,23 +60,23 @@ Output Transform::Run(const Program* src, const DataMap& data /* = {} */) const
return output;
}
void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) {
void Transform::RemoveStatement(CloneContext& ctx, const Statement* stmt) {
auto* sem = ctx.src->Sem().Get(stmt);
if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {
ctx.Remove(block->Declaration()->statements, stmt);
return;
}
if (TINT_LIKELY(tint::Is<sem::ForLoopStatement>(sem->Parent()))) {
ctx.Replace(stmt, static_cast<ast::Expression*>(nullptr));
ctx.Replace(stmt, static_cast<Expression*>(nullptr));
return;
}
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unable to remove statement from parent of type " << sem->TypeInfo().name;
}
ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
if (ty->Is<type::Void>()) {
return ast::Type{};
return Type{};
}
if (ty->Is<type::I32>()) {
return ctx.dst->ty.i32();
@ -108,9 +108,9 @@ ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
}
if (auto* a = ty->As<type::Array>()) {
auto el = CreateASTTypeFor(ctx, a->ElemType());
utils::Vector<const ast::Attribute*, 1> attrs;
utils::Vector<const Attribute*, 1> attrs;
if (!a->IsStrideImplicit()) {
attrs.Push(ctx.dst->create<ast::StrideAttribute>(a->Stride()));
attrs.Push(ctx.dst->create<StrideAttribute>(a->Stride()));
}
if (a->Count()->Is<type::RuntimeArrayCount>()) {
return ctx.dst->ty.array(el, std::move(attrs));
@ -125,7 +125,7 @@ ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
// See crbug.com/tint/1764.
// Look for a type alias for this array.
for (auto* type_decl : ctx.src->AST().TypeDecls()) {
if (auto* alias = type_decl->As<ast::Alias>()) {
if (auto* alias = type_decl->As<Alias>()) {
if (ty == ctx.src->Sem().Get(alias)) {
// Alias found. Use the alias name to ensure types compare equal.
return ctx.dst->ty(ctx.Clone(alias->name->symbol));
@ -184,7 +184,7 @@ ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
}
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
<< "Unhandled type: " << ty->TypeInfo().name;
return ast::Type{};
return Type{};
}
} // namespace tint::ast::transform

View File

@ -181,11 +181,11 @@ class Transform : public utils::Castable<Transform> {
const DataMap& inputs,
DataMap& outputs) const = 0;
/// CreateASTTypeFor constructs new ast::Type that reconstructs the semantic type `ty`.
/// CreateASTTypeFor constructs new Type that reconstructs the semantic type `ty`.
/// @param ctx the clone context
/// @param ty the semantic type to reconstruct
/// @returns an ast::Type that when resolved, will produce the semantic type `ty`.
static ast::Type CreateASTTypeFor(CloneContext& ctx, const type::Type* ty);
/// @returns an Type that when resolved, will produce the semantic type `ty`.
static Type CreateASTTypeFor(CloneContext& ctx, const type::Type* ty);
protected:
/// Removes the statement `stmt` from the transformed program.
@ -193,7 +193,7 @@ class Transform : public utils::Castable<Transform> {
/// continuing of for-loops.
/// @param ctx the clone context
/// @param stmt the statement to remove when the program is cloned
static void RemoveStatement(CloneContext& ctx, const ast::Statement* stmt);
static void RemoveStatement(CloneContext& ctx, const Statement* stmt);
};
} // namespace tint::ast::transform

View File

@ -32,7 +32,7 @@ struct CreateASTTypeForTest : public testing::Test, public Transform {
return SkipTransform;
}
ast::Type create(std::function<type::Type*(ProgramBuilder&)> create_sem_type) {
Type create(std::function<type::Type*(ProgramBuilder&)> create_sem_type) {
ProgramBuilder sem_type_builder;
auto* sem_type = create_sem_type(sem_type_builder);
Program program(std::move(sem_type_builder));
@ -44,9 +44,7 @@ struct CreateASTTypeForTest : public testing::Test, public Transform {
};
TEST_F(CreateASTTypeForTest, Basic) {
auto check = [&](ast::Type ty, const char* expect) {
ast::CheckIdentifier(ty->identifier, expect);
};
auto check = [&](Type ty, const char* expect) { CheckIdentifier(ty->identifier, expect); };
check(create([](ProgramBuilder& b) { return b.create<type::I32>(); }), "i32");
check(create([](ProgramBuilder& b) { return b.create<type::U32>(); }), "u32");
@ -61,14 +59,14 @@ TEST_F(CreateASTTypeForTest, Matrix) {
return b.create<type::Matrix>(column_type, 3u);
});
ast::CheckIdentifier(mat, ast::Template("mat3x2", "f32"));
CheckIdentifier(mat, Template("mat3x2", "f32"));
}
TEST_F(CreateASTTypeForTest, Vector) {
auto vec =
create([](ProgramBuilder& b) { return b.create<type::Vector>(b.create<type::F32>(), 2u); });
ast::CheckIdentifier(vec, ast::Template("vec2", "f32"));
CheckIdentifier(vec, Template("vec2", "f32"));
}
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
@ -77,8 +75,8 @@ TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
4u, 4u, 32u, 32u);
});
ast::CheckIdentifier(arr, ast::Template("array", "f32", 2_u));
auto* tmpl_attr = arr->identifier->As<ast::TemplatedIdentifier>();
CheckIdentifier(arr, Template("array", "f32", 2_u));
auto* tmpl_attr = arr->identifier->As<TemplatedIdentifier>();
ASSERT_NE(tmpl_attr, nullptr);
EXPECT_TRUE(tmpl_attr->attributes.IsEmpty());
}
@ -88,12 +86,12 @@ TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
return b.create<type::Array>(b.create<type::F32>(), b.create<type::ConstantArrayCount>(2u),
4u, 4u, 64u, 32u);
});
ast::CheckIdentifier(arr, ast::Template("array", "f32", 2_u));
auto* tmpl_attr = arr->identifier->As<ast::TemplatedIdentifier>();
CheckIdentifier(arr, Template("array", "f32", 2_u));
auto* tmpl_attr = arr->identifier->As<TemplatedIdentifier>();
ASSERT_NE(tmpl_attr, nullptr);
ASSERT_EQ(tmpl_attr->attributes.Length(), 1u);
ASSERT_TRUE(tmpl_attr->attributes[0]->Is<ast::StrideAttribute>());
ASSERT_EQ(tmpl_attr->attributes[0]->As<ast::StrideAttribute>()->stride, 64u);
ASSERT_TRUE(tmpl_attr->attributes[0]->Is<StrideAttribute>());
ASSERT_EQ(tmpl_attr->attributes[0]->As<StrideAttribute>()->stride, 64u);
}
// crbug.com/tint/1764
@ -114,7 +112,7 @@ TEST_F(CreateASTTypeForTest, AliasedArrayWithComplexOverrideLength) {
CloneContext ctx(&ast_type_builder, &program, false);
auto ast_ty = CreateASTTypeFor(ctx, arr_ty);
ast::CheckIdentifier(ast_ty, "A");
CheckIdentifier(ast_ty, "A");
}
TEST_F(CreateASTTypeForTest, Struct) {
@ -124,7 +122,7 @@ TEST_F(CreateASTTypeForTest, Struct) {
4u /* size */, 4u /* size_no_padding */);
});
ast::CheckIdentifier(str, "S");
CheckIdentifier(str, "S");
}
TEST_F(CreateASTTypeForTest, PrivatePointer) {
@ -133,7 +131,7 @@ TEST_F(CreateASTTypeForTest, PrivatePointer) {
builtin::Access::kReadWrite);
});
ast::CheckIdentifier(ptr, ast::Template("ptr", "private", "i32"));
CheckIdentifier(ptr, Template("ptr", "private", "i32"));
}
TEST_F(CreateASTTypeForTest, StorageReadWritePointer) {
@ -142,7 +140,7 @@ TEST_F(CreateASTTypeForTest, StorageReadWritePointer) {
builtin::Access::kReadWrite);
});
ast::CheckIdentifier(ptr, ast::Template("ptr", "storage", "i32", "read_write"));
CheckIdentifier(ptr, Template("ptr", "storage", "i32", "read_write"));
}
} // namespace

View File

@ -74,7 +74,7 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
continue;
}
if (func_ast->PipelineStage() != ast::PipelineStage::kVertex) {
if (func_ast->PipelineStage() != PipelineStage::kVertex) {
// Currently only vertex stage could have interstage output variables that need
// truncated.
continue;
@ -118,8 +118,8 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
old_shader_io_structs_to_new_struct_and_truncate_functions.GetOrCreate(str, [&] {
auto new_struct_sym = b.Symbols().New();
utils::Vector<const ast::StructMember*, 20> truncated_members;
utils::Vector<const ast::Expression*, 20> initializer_exprs;
utils::Vector<const StructMember*, 20> truncated_members;
utils::Vector<const Expression*, 20> initializer_exprs;
for (auto* member : str->Members()) {
if (omit_members.Contains(member)) {
@ -155,7 +155,7 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
// Replace return statements with new truncated shader IO struct
ctx.ReplaceAll(
[&](const ast::ReturnStatement* return_statement) -> const ast::ReturnStatement* {
[&](const ReturnStatement* return_statement) -> const ReturnStatement* {
auto* return_sem = sem.Get(return_statement);
if (auto mapping_fn_sym =
entry_point_functions_to_truncate_functions.Find(return_sem->Function())) {
@ -168,11 +168,11 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
// Remove IO attributes from old shader IO struct which is not used as entry point output
// anymore.
for (auto it : old_shader_io_structs_to_new_struct_and_truncate_functions) {
const ast::Struct* struct_ty = it.key->Declaration();
const Struct* struct_ty = it.key->Declaration();
for (auto* member : struct_ty->members) {
for (auto* attr : member->attributes) {
if (attr->IsAnyOf<ast::BuiltinAttribute, ast::LocationAttribute,
ast::InterpolateAttribute, ast::InvariantAttribute>()) {
if (attr->IsAnyOf<BuiltinAttribute, LocationAttribute, InterpolateAttribute,
InvariantAttribute>()) {
ctx.Remove(member->attributes, attr);
}
}

View File

@ -50,29 +50,27 @@ struct Unshadow::State {
// Maps a variable to its new name.
utils::Hashmap<const sem::Variable*, Symbol, 8> renamed_to;
auto rename = [&](const sem::Variable* v) -> const ast::Variable* {
auto rename = [&](const sem::Variable* v) -> const Variable* {
auto* decl = v->Declaration();
auto name = decl->name->symbol.Name();
auto symbol = b.Symbols().New(name);
renamed_to.Add(v, symbol);
auto source = ctx.Clone(decl->source);
auto type = decl->type ? ctx.Clone(decl->type) : ast::Type{};
auto type = decl->type ? ctx.Clone(decl->type) : Type{};
auto* initializer = ctx.Clone(decl->initializer);
auto attributes = ctx.Clone(decl->attributes);
return Switch(
decl, //
[&](const ast::Var* var) {
[&](const Var* var) {
return b.Var(source, symbol, type, var->declared_address_space,
var->declared_access, initializer, attributes);
},
[&](const ast::Let*) {
return b.Let(source, symbol, type, initializer, attributes);
},
[&](const ast::Const*) {
[&](const Let*) { return b.Let(source, symbol, type, initializer, attributes); },
[&](const Const*) {
return b.Const(source, symbol, type, initializer, attributes);
},
[&](const ast::Parameter*) { //
[&](const Parameter*) { //
return b.Param(source, symbol, type, attributes);
},
[&](Default) {
@ -105,8 +103,7 @@ struct Unshadow::State {
return SkipTransform;
}
ctx.ReplaceAll(
[&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* {
ctx.ReplaceAll([&](const IdentifierExpression* ident) -> const IdentifierExpression* {
if (auto* sem_ident = sem.GetVal(ident)) {
if (auto* user = sem_ident->Unwrap()->As<sem::VariableUser>()) {
if (auto renamed = renamed_to.Find(user->Variable())) {

View File

@ -20,10 +20,10 @@
namespace tint::ast::transform::utils {
InsertionPoint GetInsertionPoint(CloneContext& ctx, const ast::Statement* stmt) {
InsertionPoint GetInsertionPoint(CloneContext& ctx, const Statement* stmt) {
auto& sem = ctx.src->Sem();
auto& diag = ctx.dst->Diagnostics();
using RetType = std::pair<const sem::BlockStatement*, const ast::Statement*>;
using RetType = std::pair<const sem::BlockStatement*, const Statement*>;
if (auto* sem_stmt = sem.Get(stmt)) {
auto* parent = sem_stmt->Parent();

View File

@ -24,7 +24,7 @@ namespace tint::ast::transform::utils {
/// InsertionPoint is a pair of the block (`first`) within which, and the
/// statement (`second`) before or after which to insert.
using InsertionPoint = std::pair<const sem::BlockStatement*, const ast::Statement*>;
using InsertionPoint = std::pair<const sem::BlockStatement*, const Statement*>;
/// For the input statement, returns the block and statement within that
/// block to insert before/after. If `stmt` is a for-loop continue statement,
@ -32,7 +32,7 @@ using InsertionPoint = std::pair<const sem::BlockStatement*, const ast::Statemen
/// @param ctx the clone context
/// @param stmt the statement to insert before or after
/// @return the insertion point
InsertionPoint GetInsertionPoint(CloneContext& ctx, const ast::Statement* stmt);
InsertionPoint GetInsertionPoint(CloneContext& ctx, const Statement* stmt);
} // namespace tint::ast::transform::utils

View File

@ -37,7 +37,7 @@ struct HoistToDeclBefore::State {
/// @copydoc HoistToDeclBefore::Add()
bool Add(const sem::ValueExpression* before_expr,
const ast::Expression* expr,
const Expression* expr,
VariableKind kind,
const char* decl_name) {
auto name = b.Symbols().New(decl_name);
@ -85,8 +85,8 @@ struct HoistToDeclBefore::State {
return true;
}
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*)
bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) {
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const Statement*)
bool InsertBefore(const sem::Statement* before_stmt, const Statement* stmt) {
if (stmt) {
auto builder = [stmt] { return stmt; };
return InsertBeforeImpl(before_stmt, std::move(builder));
@ -99,8 +99,8 @@ struct HoistToDeclBefore::State {
return InsertBeforeImpl(before_stmt, std::move(builder));
}
/// @copydoc HoistToDeclBefore::Replace(const sem::Statement* what, const ast::Statement* with)
bool Replace(const sem::Statement* what, const ast::Statement* with) {
/// @copydoc HoistToDeclBefore::Replace(const sem::Statement* what, const Statement* with)
bool Replace(const sem::Statement* what, const Statement* with) {
auto builder = [with] { return with; };
return Replace(what, std::move(builder));
}
@ -145,7 +145,7 @@ struct HoistToDeclBefore::State {
utils::Hashmap<const sem::WhileStatement*, LoopInfo, 4> while_loops;
/// 'else if' statements that need to be decomposed to 'else {if}'
utils::Hashmap<const ast::IfStatement*, ElseIfInfo, 4> else_ifs;
utils::Hashmap<const IfStatement*, ElseIfInfo, 4> else_ifs;
template <size_t N>
static auto Build(const utils::Vector<StmtBuilder, N>& builders) {
@ -181,7 +181,7 @@ struct HoistToDeclBefore::State {
/// automatically called.
/// @warning the returned reference is invalid if this is called a second time, or the
/// #else_ifs map is mutated.
auto ElseIf(const ast::IfStatement* else_if) {
auto ElseIf(const IfStatement* else_if) {
if (else_ifs.IsEmpty()) {
RegisterElseIfTransform();
}
@ -190,7 +190,7 @@ struct HoistToDeclBefore::State {
/// Registers the handler for transforming for-loops based on the content of the #for_loops map.
void RegisterForLoopTransform() const {
ctx.ReplaceAll([&](const ast::ForLoopStatement* stmt) -> const ast::Statement* {
ctx.ReplaceAll([&](const ForLoopStatement* stmt) -> const Statement* {
auto& sem = ctx.src->Sem();
if (auto* fl = sem.Get(stmt)) {
@ -205,9 +205,9 @@ struct HoistToDeclBefore::State {
if (auto* cond = for_loop->condition) {
// !condition
auto* not_cond =
b.create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
b.create<UnaryOpExpression>(UnaryOp::kNot, ctx.Clone(cond));
// { break; }
auto* break_body = b.Block(b.create<ast::BreakStatement>());
auto* break_body = b.Block(b.create<BreakStatement>());
// if (!condition) { break; }
body_stmts.Push(b.If(not_cond, break_body));
}
@ -215,7 +215,7 @@ struct HoistToDeclBefore::State {
body_stmts.Push(ctx.Clone(for_loop->body));
// Create the continuing block if there was one.
const ast::BlockStatement* continuing = nullptr;
const BlockStatement* continuing = nullptr;
if (auto* cont = for_loop->continuing) {
// Continuing block starts with any let declarations used by
// the continuing.
@ -249,7 +249,7 @@ struct HoistToDeclBefore::State {
/// map.
void RegisterWhileLoopTransform() const {
// At least one while needs to be transformed into a loop.
ctx.ReplaceAll([&](const ast::WhileStatement* stmt) -> const ast::Statement* {
ctx.ReplaceAll([&](const WhileStatement* stmt) -> const Statement* {
auto& sem = ctx.src->Sem();
if (auto* w = sem.Get(stmt)) {
@ -274,7 +274,7 @@ struct HoistToDeclBefore::State {
// Next emit the body
body_stmts.Push(ctx.Clone(while_loop->body));
const ast::BlockStatement* continuing = nullptr;
const BlockStatement* continuing = nullptr;
auto* body = b.Block(body_stmts);
auto* loop = b.Loop(body, continuing);
@ -289,7 +289,7 @@ struct HoistToDeclBefore::State {
/// map.
void RegisterElseIfTransform() const {
// Decompose 'else-if' statements into 'else { if }' blocks.
ctx.ReplaceAll([&](const ast::IfStatement* stmt) -> const ast::Statement* {
ctx.ReplaceAll([&](const IfStatement* stmt) -> const Statement* {
if (auto info = else_ifs.Find(stmt)) {
// Build the else block's body statements, starting with let decls for the
// conditional expression.
@ -412,14 +412,13 @@ HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx) : state_(std::make_uniqu
HoistToDeclBefore::~HoistToDeclBefore() {}
bool HoistToDeclBefore::Add(const sem::ValueExpression* before_expr,
const ast::Expression* expr,
const Expression* expr,
VariableKind kind,
const char* decl_name) {
return state_->Add(before_expr, expr, kind, decl_name);
}
bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt,
const ast::Statement* stmt) {
bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt, const Statement* stmt) {
return state_->InsertBefore(before_stmt, stmt);
}
@ -428,7 +427,7 @@ bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt,
return state_->InsertBefore(before_stmt, builder);
}
bool HoistToDeclBefore::Replace(const sem::Statement* what, const ast::Statement* with) {
bool HoistToDeclBefore::Replace(const sem::Statement* what, const Statement* with) {
return state_->Replace(what, with);
}

View File

@ -36,7 +36,7 @@ class HoistToDeclBefore {
~HoistToDeclBefore();
/// StmtBuilder is a builder of an AST statement
using StmtBuilder = std::function<const ast::Statement*()>;
using StmtBuilder = std::function<const Statement*()>;
/// VariableKind is either a var, let or const
enum class VariableKind {
@ -53,7 +53,7 @@ class HoistToDeclBefore {
/// @param decl_name optional name to use for the variable/constant name
/// @return true on success
bool Add(const sem::ValueExpression* before_expr,
const ast::Expression* expr,
const Expression* expr,
VariableKind kind,
const char* decl_name = "");
@ -64,7 +64,7 @@ class HoistToDeclBefore {
/// @param before_stmt statement to insert @p stmt before
/// @param stmt statement to insert
/// @return true on success
bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt);
bool InsertBefore(const sem::Statement* before_stmt, const Statement* stmt);
/// Inserts the returned statement of @p builder before @p before_stmt, possibly converting
/// 'for-loop's to 'loop's if necessary.
@ -81,7 +81,7 @@ class HoistToDeclBefore {
/// @param what the statement to replace
/// @param with the replacement statement
/// @return true on success
bool Replace(const sem::Statement* what, const ast::Statement* with);
bool Replace(const sem::Statement* what, const Statement* with);
/// Replaces the statement @p what with the statement returned by @p stmt, possibly converting
/// 'for-loop's to 'loop's if necessary.

View File

@ -628,7 +628,7 @@ TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont) {
ProgramBuilder b;
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd);
auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
@ -637,7 +637,7 @@ TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont) {
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* before_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>());
auto* before_stmt = ctx.src->Sem().Get(cont->As<Statement>());
auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
@ -679,7 +679,7 @@ TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont_Function) {
ProgramBuilder b;
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd);
auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
@ -688,7 +688,7 @@ TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont_Function) {
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* before_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>());
auto* before_stmt = ctx.src->Sem().Get(cont->As<Statement>());
hoistToDeclBefore.InsertBefore(before_stmt,
[&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
@ -1048,7 +1048,7 @@ TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont) {
ProgramBuilder b;
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd);
auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
@ -1057,7 +1057,7 @@ TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont) {
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* target_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>());
auto* target_stmt = ctx.src->Sem().Get(cont->As<Statement>());
auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
hoistToDeclBefore.Replace(target_stmt, new_stmt);
@ -1098,7 +1098,7 @@ TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont_Function) {
ProgramBuilder b;
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd);
auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
@ -1107,7 +1107,7 @@ TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont_Function) {
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* target_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>());
auto* target_stmt = ctx.src->Sem().Get(cont->As<Statement>());
hoistToDeclBefore.Replace(target_stmt, [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
ctx.Clone();

View File

@ -37,7 +37,7 @@ Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,
// Extracts array and matrix values that are dynamically indexed to a
// temporary `var` local that is then indexed.
auto dynamic_index_to_var = [&](const ast::IndexAccessorExpression* access_expr) {
auto dynamic_index_to_var = [&](const IndexAccessorExpression* access_expr) {
auto* index_expr = access_expr->index;
auto* object_expr = access_expr->object;
auto& sem = src->Sem();
@ -62,7 +62,7 @@ Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,
bool index_accessor_found = false;
for (auto* node : src->ASTNodes().Objects()) {
if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) {
if (auto* access_expr = node->As<IndexAccessorExpression>()) {
if (!dynamic_index_to_var(access_expr)) {
return Program(std::move(b));
}

View File

@ -70,7 +70,7 @@ Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
std::unordered_map<HelperFunctionKey, Symbol> matrix_convs;
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
ctx.ReplaceAll([&](const CallExpression* expr) -> const CallExpression* {
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* ty_conv = call->Target()->As<sem::ValueConversion>();
if (!ty_conv) {
@ -102,7 +102,7 @@ Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
}
auto build_vectorized_conversion_expression = [&](auto&& src_expression_builder) {
utils::Vector<const ast::Expression*, 4> columns;
utils::Vector<const Expression*, 4> columns;
for (uint32_t c = 0; c < dst_type->columns(); c++) {
auto* src_matrix_expr = src_expression_builder();
auto* src_column_expr = b.IndexAccessor(src_matrix_expr, b.Expr(tint::AInt(c)));

View File

@ -61,7 +61,7 @@ Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* s
std::unordered_map<const type::Matrix*, Symbol> scalar_inits;
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
ctx.ReplaceAll([&](const CallExpression* expr) -> const CallExpression* {
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* ty_init = call->Target()->As<sem::ValueConstructor>();
if (!ty_init) {
@ -91,9 +91,9 @@ Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* s
// Constructs a matrix using vector columns, with the elements constructed using the
// 'element(uint32_t c, uint32_t r)' callback.
auto build_mat = [&](auto&& element) {
utils::Vector<const ast::Expression*, 4> columns;
utils::Vector<const Expression*, 4> columns;
for (uint32_t c = 0; c < mat_type->columns(); c++) {
utils::Vector<const ast::Expression*, 4> row_values;
utils::Vector<const Expression*, 4> row_values;
for (uint32_t r = 0; r < mat_type->rows(); r++) {
row_values.Push(element(c, r));
}

View File

@ -238,9 +238,9 @@ struct VertexPulling::State {
/// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() {
// Find entry point
const ast::Function* func = nullptr;
const Function* func = nullptr;
for (auto* fn : src->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
if (fn->PipelineStage() == PipelineStage::kVertex) {
if (func != nullptr) {
b.Diagnostics().add_error(
diag::System::Transform,
@ -264,18 +264,18 @@ struct VertexPulling::State {
}
private:
/// LocationReplacement describes an ast::Variable replacement for a location input.
/// LocationReplacement describes an Variable replacement for a location input.
struct LocationReplacement {
/// The variable to replace in the source Program
ast::Variable* from;
Variable* from;
/// The replacement to use in the target ProgramBuilder
ast::Variable* to;
Variable* to;
};
/// LocationInfo describes an input location
struct LocationInfo {
/// A builder that builds the expression that resolves to the (transformed) input location
std::function<const ast::Expression*()> expr;
std::function<const Expression*()> expr;
/// The store type of the location variable
const type::Type* type;
};
@ -289,12 +289,12 @@ struct VertexPulling::State {
/// The clone context
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
std::unordered_map<uint32_t, LocationInfo> location_info;
std::function<const ast::Expression*()> vertex_index_expr = nullptr;
std::function<const ast::Expression*()> instance_index_expr = nullptr;
std::function<const Expression*()> vertex_index_expr = nullptr;
std::function<const Expression*()> instance_index_expr = nullptr;
Symbol pulling_position_name;
Symbol struct_buffer_name;
std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
utils::Vector<const ast::Parameter*, 8> new_function_parameters;
utils::Vector<const Parameter*, 8> new_function_parameters;
/// Generate the vertex buffer binding name
/// @param index index to append to buffer name
@ -331,11 +331,11 @@ struct VertexPulling::State {
}
/// Creates and returns the assignment to the variables from the buffers
const ast::BlockStatement* CreateVertexPullingPreamble() {
const BlockStatement* CreateVertexPullingPreamble() {
// Assign by looking at the vertex descriptor to find attributes with
// matching location.
utils::Vector<const ast::Statement*, 8> stmts;
utils::Vector<const Statement*, 8> stmts;
for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size(); ++buffer_idx) {
const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx];
@ -399,7 +399,7 @@ struct VertexPulling::State {
// Convert the fetched scalar/vector if WGSL variable is of `f16` types
if (var_dt.base_type == BaseWGSLType::kF16) {
// The type of the same element number of base type of target WGSL variable
ast::Type loaded_data_target_type;
Type loaded_data_target_type;
if (fmt_dt.width == 1) {
loaded_data_target_type = b.ty.f16();
} else {
@ -433,7 +433,7 @@ struct VertexPulling::State {
// The components of result vector variable, initialized with type-converted
// loaded data vector.
utils::Vector<const ast::Expression*, 8> values{fetch};
utils::Vector<const Expression*, 8> values{fetch};
// Add padding elements. The result must be of vector types of signed/unsigned
// integer or float, so use the abstract integer or abstract float value to do
@ -470,7 +470,7 @@ struct VertexPulling::State {
/// @param offset the byte offset of the data from `buffer_base`
/// @param buffer the index of the vertex buffer
/// @param format the vertex format to read
const ast::Expression* Fetch(Symbol array_base,
const Expression* Fetch(Symbol array_base,
uint32_t offset,
uint32_t buffer,
VertexFormat format) {
@ -679,11 +679,11 @@ struct VertexPulling::State {
/// @param buffer the index of the vertex buffer
/// @param format VertexFormat::kUint32, VertexFormat::kSint32 or
/// VertexFormat::kFloat32
const ast::Expression* LoadPrimitive(Symbol array_base,
const Expression* LoadPrimitive(Symbol array_base,
uint32_t offset,
uint32_t buffer,
VertexFormat format) {
const ast::Expression* u = nullptr;
const Expression* u = nullptr;
if ((offset & 3) == 0) {
// Aligned load.
@ -734,14 +734,14 @@ struct VertexPulling::State {
/// @param base_type underlying AST type
/// @param base_format underlying vertex format
/// @param count how many elements the vector has
const ast::Expression* LoadVec(Symbol array_base,
const Expression* LoadVec(Symbol array_base,
uint32_t offset,
uint32_t buffer,
uint32_t element_stride,
ast::Type base_type,
Type base_type,
VertexFormat base_format,
uint32_t count) {
utils::Vector<const ast::Expression*, 8> expr_list;
utils::Vector<const Expression*, 8> expr_list;
for (uint32_t i = 0; i < count; ++i) {
// Offset read position by element_stride for each component
uint32_t primitive_offset = offset + element_stride * i;
@ -756,8 +756,8 @@ struct VertexPulling::State {
/// vertex_index and instance_index builtins if present.
/// @param func the entry point function
/// @param param the parameter to process
void ProcessNonStructParameter(const ast::Function* func, const ast::Parameter* param) {
if (ast::HasAttribute<ast::LocationAttribute>(param->attributes)) {
void ProcessNonStructParameter(const Function* func, const Parameter* param) {
if (HasAttribute<LocationAttribute>(param->attributes)) {
// Create a function-scope variable to replace the parameter.
auto func_var_sym = ctx.Clone(param->name->symbol);
auto func_var_type = ctx.Clone(param->type);
@ -776,7 +776,7 @@ struct VertexPulling::State {
}
location_info[sem->Location().value()] = info;
} else {
auto* builtin_attr = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes);
auto* builtin_attr = GetAttribute<BuiltinAttribute>(param->attributes);
if (TINT_UNLIKELY(!builtin_attr)) {
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
return;
@ -804,21 +804,21 @@ struct VertexPulling::State {
/// @param func the entry point function
/// @param param the parameter to process
/// @param struct_ty the structure type
void ProcessStructParameter(const ast::Function* func,
const ast::Parameter* param,
const ast::Struct* struct_ty) {
void ProcessStructParameter(const Function* func,
const Parameter* param,
const Struct* struct_ty) {
auto param_sym = ctx.Clone(param->name->symbol);
// Process the struct members.
bool has_locations = false;
utils::Vector<const ast::StructMember*, 8> members_to_clone;
utils::Vector<const StructMember*, 8> members_to_clone;
for (auto* member : struct_ty->members) {
auto member_sym = ctx.Clone(member->name->symbol);
std::function<const ast::Expression*()> member_expr = [this, param_sym, member_sym]() {
std::function<const Expression*()> member_expr = [this, param_sym, member_sym]() {
return b.MemberAccessor(param_sym, member_sym);
};
if (ast::HasAttribute<ast::LocationAttribute>(member->attributes)) {
if (HasAttribute<LocationAttribute>(member->attributes)) {
// Capture mapping from location to struct member.
LocationInfo info;
info.expr = member_expr;
@ -830,7 +830,7 @@ struct VertexPulling::State {
location_info[sem->Attributes().location.value()] = info;
has_locations = true;
} else {
auto* builtin_attr = ast::GetAttribute<ast::BuiltinAttribute>(member->attributes);
auto* builtin_attr = GetAttribute<BuiltinAttribute>(member->attributes);
if (TINT_UNLIKELY(!builtin_attr)) {
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
return;
@ -858,7 +858,7 @@ struct VertexPulling::State {
if (!members_to_clone.IsEmpty()) {
// Create a new struct without the location attributes.
utils::Vector<const ast::StructMember*, 8> new_members;
utils::Vector<const StructMember*, 8> new_members;
for (auto* member : members_to_clone) {
auto member_name = ctx.Clone(member->name);
auto member_type = ctx.Clone(member->type);
@ -883,7 +883,7 @@ struct VertexPulling::State {
/// Process an entry point function.
/// @param func the entry point function
void Process(const ast::Function* func) {
void Process(const Function* func) {
if (func->body->Empty()) {
return;
}
@ -936,8 +936,8 @@ struct VertexPulling::State {
auto attrs = ctx.Clone(func->attributes);
auto ret_attrs = ctx.Clone(func->return_type_attributes);
auto* new_func =
b.create<ast::Function>(func->source, b.Ident(func_sym), new_function_parameters,
ret_type, body, std::move(attrs), std::move(ret_attrs));
b.create<Function>(func->source, b.Ident(func_sym), new_function_parameters, ret_type,
body, std::move(attrs), std::move(ret_attrs));
ctx.Replace(func, new_func);
}
};

View File

@ -26,7 +26,7 @@ namespace {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::WhileStatement>()) {
if (node->Is<WhileStatement>()) {
return true;
}
}
@ -47,8 +47,8 @@ Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, Da
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* {
utils::Vector<const ast::Statement*, 16> stmts;
ctx.ReplaceAll([&](const WhileStatement* w) -> const Statement* {
utils::Vector<const Statement*, 16> stmts;
auto* cond = w->condition;
// !condition
@ -64,7 +64,7 @@ Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, Da
stmts.Push(ctx.Clone(stmt));
}
const ast::BlockStatement* continuing = nullptr;
const BlockStatement* continuing = nullptr;
auto* body = b.Block(stmts);
auto* loop = b.Loop(body, continuing);

View File

@ -36,7 +36,7 @@ namespace {
bool ShouldRun(const Program* program) {
for (auto* global : program->AST().GlobalVariables()) {
if (auto* var = global->As<ast::Var>()) {
if (auto* var = global->As<Var>()) {
auto* v = program->Sem().Get(var);
if (v->AddressSpace() == builtin::AddressSpace::kWorkgroup) {
return true;
@ -48,7 +48,7 @@ bool ShouldRun(const Program* program) {
} // namespace
using StatementList = utils::Vector<const ast::Statement*, 8>;
using StatementList = utils::Vector<const Statement*, 8>;
/// PIMPL state for the transform
struct ZeroInitWorkgroupMemory::State {
@ -132,10 +132,10 @@ struct ZeroInitWorkgroupMemory::State {
/// Run inserts the workgroup memory zero-initialization logic at the top of
/// the given function
/// @param fn a compute shader entry point function
void Run(const ast::Function* fn) {
void Run(const Function* fn) {
auto& sem = ctx.src->Sem();
CalculateWorkgroupSize(ast::GetAttribute<ast::WorkgroupAttribute>(fn->attributes));
CalculateWorkgroupSize(GetAttribute<WorkgroupAttribute>(fn->attributes));
// Generate a list of statements to zero initialize each of the
// workgroup storage variables used by `fn`. This will populate #statements.
@ -160,7 +160,7 @@ struct ZeroInitWorkgroupMemory::State {
// parameter
std::function<const ast::Expression*()> local_index;
for (auto* param : fn->params) {
if (auto* builtin_attr = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) {
if (auto* builtin_attr = GetAttribute<BuiltinAttribute>(param->attributes)) {
auto builtin = sem.Get(builtin_attr)->Value();
if (builtin == builtin::BuiltinValue::kLocalInvocationIndex) {
local_index = [=] { return b.Expr(ctx.Clone(param->name->symbol)); };
@ -231,7 +231,7 @@ struct ZeroInitWorkgroupMemory::State {
// }
auto idx = b.Symbols().New("idx");
auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index()));
auto* cond = b.create<ast::BinaryExpression>(ast::BinaryOp::kLessThan, b.Expr(idx),
auto* cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr(idx),
b.Expr(u32(num_iterations)));
auto* cont = b.Assign(
idx, b.Add(idx, workgroup_size_const ? b.Expr(u32(workgroup_size_const))
@ -251,8 +251,8 @@ struct ZeroInitWorkgroupMemory::State {
// if (local_index < num_iterations) {
// ...
// }
auto* cond = b.create<ast::BinaryExpression>(
ast::BinaryOp::kLessThan, local_index(), b.Expr(u32(num_iterations)));
auto* cond = b.create<BinaryExpression>(BinaryOp::kLessThan, local_index(),
b.Expr(u32(num_iterations)));
auto block = DeclareArrayIndices(num_iterations, array_indices,
[&] { return b.Expr(local_index()); });
for (auto& s : stmts) {
@ -382,7 +382,7 @@ struct ZeroInitWorkgroupMemory::State {
for (auto index : array_indices) {
auto name = array_index_names.at(index);
auto* mod = (num_iterations > index.modulo)
? b.create<ast::BinaryExpression>(ast::BinaryOp::kModulo, iteration(),
? b.create<BinaryExpression>(BinaryOp::kModulo, iteration(),
b.Expr(u32(index.modulo)))
: iteration();
auto* div = (index.division != 1u) ? b.Div(mod, u32(index.division)) : mod;
@ -395,7 +395,7 @@ struct ZeroInitWorkgroupMemory::State {
/// CalculateWorkgroupSize initializes the members #workgroup_size_const and
/// #workgroup_size_expr with the linear workgroup size.
/// @param attr the workgroup attribute applied to the entry point function
void CalculateWorkgroupSize(const ast::WorkgroupAttribute* attr) {
void CalculateWorkgroupSize(const WorkgroupAttribute* attr) {
bool is_signed = false;
workgroup_size_const = 1u;
workgroup_size_expr = nullptr;
@ -471,7 +471,7 @@ Transform::ApplyResult ZeroInitWorkgroupMemory::Apply(const Program* src,
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
for (auto* fn : src->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kCompute) {
if (fn->PipelineStage() == PipelineStage::kCompute) {
State{ctx}.Run(fn);
}
}