[tint] Remove ast:: prefixes from AST transforms

Now that AST transforms live in the AST namespace, these prefixes are
no longer necessary.

Change-Id: I658746ac04220075653ec57d6dc998947232d8bc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/132425
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
James Price 2023-05-12 01:43:50 +00:00 committed by Dawn LUCI CQ
parent 4ae03fa8d0
commit 2b7406ad55
59 changed files with 1025 additions and 1064 deletions

View File

@ -41,7 +41,7 @@ Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
// A map from a type in the source program to a block-decorated wrapper that contains it in the // A map from a type in the source program to a block-decorated wrapper that contains it in the
// destination program. // destination program.
utils::Hashmap<const type::Type*, const ast::Struct*, 8> wrapper_structs; utils::Hashmap<const type::Type*, const Struct*, 8> wrapper_structs;
// Process global 'var' declarations that are buffers. // Process global 'var' declarations that are buffers.
bool made_changes = false; bool made_changes = false;
@ -71,7 +71,7 @@ Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] { auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] {
auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID()); auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
auto wrapper_name = global->name->symbol.Name() + "_block"; auto wrapper_name = global->name->symbol.Name() + "_block";
auto* ret = b.create<ast::Struct>( auto* ret = b.create<Struct>(
b.Ident(b.Symbols().New(wrapper_name)), b.Ident(b.Symbols().New(wrapper_name)),
utils::Vector{b.Member(kMemberName, CreateASTTypeFor(ctx, ty))}, utils::Vector{b.Member(kMemberName, CreateASTTypeFor(ctx, ty))},
utils::Vector{block}); utils::Vector{block});
@ -101,7 +101,7 @@ Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
return Program(std::move(b)); return Program(std::move(b));
} }
AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, ast::NodeID nid) AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, NodeID nid)
: Base(pid, nid, utils::Empty) {} : Base(pid, nid, utils::Empty) {}
AddBlockAttribute::BlockAttribute::~BlockAttribute() = default; AddBlockAttribute::BlockAttribute::~BlockAttribute() = default;
std::string AddBlockAttribute::BlockAttribute::InternalName() const { std::string AddBlockAttribute::BlockAttribute::InternalName() const {

View File

@ -28,12 +28,12 @@ class AddBlockAttribute final : public utils::Castable<AddBlockAttribute, Transf
public: public:
/// BlockAttribute is an InternalAttribute that is used to decorate a /// BlockAttribute is an InternalAttribute that is used to decorate a
// structure that is used as a buffer in SPIR-V or GLSL. // structure that is used as a buffer in SPIR-V or GLSL.
class BlockAttribute final : public utils::Castable<BlockAttribute, ast::InternalAttribute> { class BlockAttribute final : public utils::Castable<BlockAttribute, InternalAttribute> {
public: public:
/// Constructor /// Constructor
/// @param program_id the identifier of the program that owns this node /// @param program_id the identifier of the program that owns this node
/// @param nid the unique node identifier /// @param nid the unique node identifier
BlockAttribute(ProgramID program_id, ast::NodeID nid); BlockAttribute(ProgramID program_id, NodeID nid);
/// Destructor /// Destructor
~BlockAttribute() override; ~BlockAttribute() override;

View File

@ -52,7 +52,7 @@ Transform::ApplyResult AddEmptyEntryPoint::Apply(const Program* src,
b.Func(b.Symbols().New("unused_entry_point"), {}, b.ty.void_(), {}, b.Func(b.Symbols().New("unused_entry_point"), {}, b.ty.void_(), {},
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });

View File

@ -81,8 +81,8 @@ struct ArrayLengthFromUniform::State {
// Determine the size of the buffer size array. // Determine the size of the buffer size array.
uint32_t max_buffer_size_index = 0; uint32_t max_buffer_size_index = 0;
IterateArrayLengthOnStorageVar([&](const ast::CallExpression*, const sem::VariableUser*, IterateArrayLengthOnStorageVar(
const sem::GlobalVariable* var) { [&](const CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) {
if (auto binding = var->BindingPoint()) { if (auto binding = var->BindingPoint()) {
auto idx_itr = cfg->bindpoint_to_size_index.find(*binding); auto idx_itr = cfg->bindpoint_to_size_index.find(*binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) { if (idx_itr == cfg->bindpoint_to_size_index.end()) {
@ -96,7 +96,7 @@ struct ArrayLengthFromUniform::State {
// Get (or create, on first call) the uniform buffer that will receive the // Get (or create, on first call) the uniform buffer that will receive the
// size of each storage buffer in the module. // size of each storage buffer in the module.
const ast::Variable* buffer_size_ubo = nullptr; const Variable* buffer_size_ubo = nullptr;
auto get_ubo = [&]() { auto get_ubo = [&]() {
if (!buffer_size_ubo) { if (!buffer_size_ubo) {
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements. // Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
@ -118,7 +118,7 @@ struct ArrayLengthFromUniform::State {
std::unordered_set<uint32_t> used_size_indices; std::unordered_set<uint32_t> used_size_indices;
IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr, IterateArrayLengthOnStorageVar([&](const CallExpression* call_expr,
const sem::VariableUser* storage_buffer_sem, const sem::VariableUser* storage_buffer_sem,
const sem::GlobalVariable* var) { const sem::GlobalVariable* var) {
auto binding = var->BindingPoint(); auto binding = var->BindingPoint();
@ -144,7 +144,7 @@ struct ArrayLengthFromUniform::State {
// total_storage_buffer_size - array_offset // total_storage_buffer_size - array_offset
// array_length = ---------------------------------------- // array_length = ----------------------------------------
// array_stride // array_stride
const ast::Expression* total_size = total_storage_buffer_size; const Expression* total_size = total_storage_buffer_size;
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef(); auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
const type::Array* array_type = nullptr; const type::Array* array_type = nullptr;
if (auto* str = storage_buffer_type->As<type::Struct>()) { if (auto* str = storage_buffer_type->As<type::Struct>()) {
@ -186,9 +186,9 @@ struct ArrayLengthFromUniform::State {
/// Iterate over all arrayLength() builtins that operate on /// Iterate over all arrayLength() builtins that operate on
/// storage buffer variables. /// storage buffer variables.
/// @param functor of type void(const ast::CallExpression*, const /// @param functor of type void(const CallExpression*, const
/// sem::VariableUser, const sem::GlobalVariable*). It takes in an /// sem::VariableUser, const sem::GlobalVariable*). It takes in an
/// ast::CallExpression of the arrayLength call expression node, a /// CallExpression of the arrayLength call expression node, a
/// sem::VariableUser of the used storage buffer variable, and the /// sem::VariableUser of the used storage buffer variable, and the
/// sem::GlobalVariable for the storage buffer. /// sem::GlobalVariable for the storage buffer.
template <typename F> template <typename F>
@ -197,7 +197,7 @@ struct ArrayLengthFromUniform::State {
// Find all calls to the arrayLength() builtin. // Find all calls to the arrayLength() builtin.
for (auto* node : src->ASTNodes().Objects()) { for (auto* node : src->ASTNodes().Objects()) {
auto* call_expr = node->As<ast::CallExpression>(); auto* call_expr = node->As<CallExpression>();
if (!call_expr) { if (!call_expr) {
continue; continue;
} }
@ -208,7 +208,7 @@ struct ArrayLengthFromUniform::State {
continue; continue;
} }
if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) { if (auto* call_stmt = call->Stmt()->Declaration()->As<CallStatement>()) {
if (call_stmt->expr == call_expr) { if (call_stmt->expr == call_expr) {
// arrayLength() is used as a statement. // arrayLength() is used as a statement.
// The argument expression must be side-effect free, so just drop the statement. // The argument expression must be side-effect free, so just drop the statement.
@ -222,15 +222,15 @@ struct ArrayLengthFromUniform::State {
// call has one of two forms: // call has one of two forms:
// arrayLength(&struct_var.array_member) // arrayLength(&struct_var.array_member)
// arrayLength(&array_var) // arrayLength(&array_var)
auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>(); auto* param = call_expr->args[0]->As<UnaryOpExpression>();
if (TINT_UNLIKELY(!param || param->op != ast::UnaryOp::kAddressOf)) { if (TINT_UNLIKELY(!param || param->op != UnaryOp::kAddressOf)) {
TINT_ICE(Transform, b.Diagnostics()) TINT_ICE(Transform, b.Diagnostics())
<< "expected form of arrayLength argument to be &array_var or " << "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member"; "&struct_var.array_member";
break; break;
} }
auto* storage_buffer_expr = param->expr; auto* storage_buffer_expr = param->expr;
if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) { if (auto* accessor = param->expr->As<MemberAccessorExpression>()) {
storage_buffer_expr = accessor->object; storage_buffer_expr = accessor->object;
} }
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr); auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);

View File

@ -90,7 +90,7 @@ Transform::ApplyResult BindingRemapper::Apply(const Program* src,
} }
} }
for (auto* var : src->AST().Globals<ast::Var>()) { for (auto* var : src->AST().Globals<Var>()) {
if (var->HasBindingPoint()) { if (var->HasBindingPoint()) {
auto* global_sem = src->Sem().Get<sem::GlobalVariable>(var); auto* global_sem = src->Sem().Get<sem::GlobalVariable>(var);
@ -109,8 +109,8 @@ Transform::ApplyResult BindingRemapper::Apply(const Program* src,
auto* new_group = b.Group(AInt(to.group)); auto* new_group = b.Group(AInt(to.group));
auto* new_binding = b.Binding(AInt(to.binding)); auto* new_binding = b.Binding(AInt(to.binding));
auto* old_group = ast::GetAttribute<ast::GroupAttribute>(var->attributes); auto* old_group = GetAttribute<GroupAttribute>(var->attributes);
auto* old_binding = ast::GetAttribute<ast::BindingAttribute>(var->attributes); auto* old_binding = GetAttribute<BindingAttribute>(var->attributes);
ctx.Replace(old_group, new_group); ctx.Replace(old_group, new_group);
ctx.Replace(old_binding, new_binding); ctx.Replace(old_binding, new_binding);
@ -139,7 +139,7 @@ Transform::ApplyResult BindingRemapper::Apply(const Program* src,
auto* ty = sem->Type()->UnwrapRef(); auto* ty = sem->Type()->UnwrapRef();
auto inner_ty = CreateASTTypeFor(ctx, ty); auto inner_ty = CreateASTTypeFor(ctx, ty);
auto* new_var = auto* new_var =
b.create<ast::Var>(ctx.Clone(var->source), // source b.create<Var>(ctx.Clone(var->source), // source
b.Ident(ctx.Clone(var->name->symbol)), // name b.Ident(ctx.Clone(var->name->symbol)), // name
inner_ty, // type inner_ty, // type
ctx.Clone(var->declared_address_space), // address space ctx.Clone(var->declared_address_space), // address space
@ -151,7 +151,7 @@ Transform::ApplyResult BindingRemapper::Apply(const Program* src,
// Add `DisableValidationAttribute`s if required // Add `DisableValidationAttribute`s if required
if (add_collision_attr.count(bp)) { if (add_collision_attr.count(bp)) {
auto* attribute = b.Disable(ast::DisabledValidation::kBindingPointCollision); auto* attribute = b.Disable(DisabledValidation::kBindingPointCollision);
ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute); ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
} }
} }

View File

@ -37,7 +37,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::BuiltinPolyfill::Config);
namespace tint::ast::transform { namespace tint::ast::transform {
/// BinaryOpSignature is tuple of a binary op, LHS type and RHS type /// BinaryOpSignature is tuple of a binary op, LHS type and RHS type
using BinaryOpSignature = std::tuple<ast::BinaryOp, const type::Type*, const type::Type*>; using BinaryOpSignature = std::tuple<BinaryOp, const type::Type*, const type::Type*>;
/// PIMPL state for the transform /// PIMPL state for the transform
struct BuiltinPolyfill::State { struct BuiltinPolyfill::State {
@ -60,16 +60,16 @@ struct BuiltinPolyfill::State {
for (auto* node : src->ASTNodes().Objects()) { for (auto* node : src->ASTNodes().Objects()) {
Switch( Switch(
node, // node, //
[&](const ast::CallExpression* expr) { Call(expr); }, [&](const CallExpression* expr) { Call(expr); },
[&](const ast::BinaryExpression* bin_op) { [&](const BinaryExpression* bin_op) {
if (auto* s = src->Sem().Get(bin_op); if (auto* s = src->Sem().Get(bin_op);
!s || s->Stage() == sem::EvaluationStage::kConstant || !s || s->Stage() == sem::EvaluationStage::kConstant ||
s->Stage() == sem::EvaluationStage::kNotEvaluated) { s->Stage() == sem::EvaluationStage::kNotEvaluated) {
return; // Don't polyfill @const expressions return; // Don't polyfill @const expressions
} }
switch (bin_op->op) { switch (bin_op->op) {
case ast::BinaryOp::kShiftLeft: case BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight: { case BinaryOp::kShiftRight: {
if (cfg.builtins.bitshift_modulo) { if (cfg.builtins.bitshift_modulo) {
ctx.Replace(bin_op, ctx.Replace(bin_op,
[this, bin_op] { return BitshiftModulo(bin_op); }); [this, bin_op] { return BitshiftModulo(bin_op); });
@ -77,7 +77,7 @@ struct BuiltinPolyfill::State {
} }
break; break;
} }
case ast::BinaryOp::kDivide: { case BinaryOp::kDivide: {
if (cfg.builtins.int_div_mod) { if (cfg.builtins.int_div_mod) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
if (lhs_ty->is_integer_scalar_or_vector()) { if (lhs_ty->is_integer_scalar_or_vector()) {
@ -88,7 +88,7 @@ struct BuiltinPolyfill::State {
} }
break; break;
} }
case ast::BinaryOp::kModulo: { case BinaryOp::kModulo: {
if (cfg.builtins.int_div_mod) { if (cfg.builtins.int_div_mod) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
if (lhs_ty->is_integer_scalar_or_vector()) { if (lhs_ty->is_integer_scalar_or_vector()) {
@ -111,7 +111,7 @@ struct BuiltinPolyfill::State {
break; break;
} }
}, },
[&](const ast::Expression* expr) { [&](const Expression* expr) {
if (cfg.builtins.bgra8unorm) { if (cfg.builtins.bgra8unorm) {
if (auto* ty_expr = src->Sem().Get<sem::TypeExpression>(expr)) { if (auto* ty_expr = src->Sem().Get<sem::TypeExpression>(expr)) {
if (auto* tex = ty_expr->Type()->As<type::StorageTexture>()) { if (auto* tex = ty_expr->Type()->As<type::StorageTexture>()) {
@ -170,15 +170,15 @@ struct BuiltinPolyfill::State {
auto name = b.Symbols().New("tint_acosh"); auto name = b.Symbols().New("tint_acosh");
uint32_t width = WidthOf(ty); uint32_t width = WidthOf(ty);
auto V = [&](AFloat value) -> const ast::Expression* { auto V = [&](AFloat value) -> const Expression* {
const ast::Expression* expr = b.Expr(value); const Expression* expr = b.Expr(value);
if (width == 1) { if (width == 1) {
return expr; return expr;
} }
return b.Call(T(ty), expr); return b.Call(T(ty), expr);
}; };
utils::Vector<const ast::Statement*, 4> body; utils::Vector<const Statement*, 4> body;
switch (cfg.builtins.acosh) { switch (cfg.builtins.acosh) {
case Level::kFull: case Level::kFull:
// return log(x + sqrt(x*x - 1)); // return log(x + sqrt(x*x - 1));
@ -224,15 +224,15 @@ struct BuiltinPolyfill::State {
auto name = b.Symbols().New("tint_atanh"); auto name = b.Symbols().New("tint_atanh");
uint32_t width = WidthOf(ty); uint32_t width = WidthOf(ty);
auto V = [&](AFloat value) -> const ast::Expression* { auto V = [&](AFloat value) -> const Expression* {
const ast::Expression* expr = b.Expr(value); const Expression* expr = b.Expr(value);
if (width == 1) { if (width == 1) {
return expr; return expr;
} }
return b.Call(T(ty), expr); return b.Call(T(ty), expr);
}; };
utils::Vector<const ast::Statement*, 1> body; utils::Vector<const Statement*, 1> body;
switch (cfg.builtins.atanh) { switch (cfg.builtins.atanh) {
case Level::kFull: case Level::kFull:
// return log((1+x) / (1-x)) * 0.5 // return log((1+x) / (1-x)) * 0.5
@ -290,7 +290,7 @@ struct BuiltinPolyfill::State {
} }
return b.ty.vec<u32>(width); return b.ty.vec<u32>(width);
}; };
auto V = [&](uint32_t value) -> const ast::Expression* { auto V = [&](uint32_t value) -> const Expression* {
return ScalarOrVector(width, u32(value)); return ScalarOrVector(width, u32(value));
}; };
b.Func( b.Func(
@ -348,10 +348,10 @@ struct BuiltinPolyfill::State {
} }
return b.ty.vec<u32>(width); return b.ty.vec<u32>(width);
}; };
auto V = [&](uint32_t value) -> const ast::Expression* { auto V = [&](uint32_t value) -> const Expression* {
return ScalarOrVector(width, u32(value)); return ScalarOrVector(width, u32(value));
}; };
auto B = [&](const ast::Expression* value) -> const ast::Expression* { auto B = [&](const Expression* value) -> const Expression* {
if (width == 1) { if (width == 1) {
return b.Call<bool>(value); return b.Call<bool>(value);
} }
@ -402,14 +402,14 @@ struct BuiltinPolyfill::State {
constexpr uint32_t W = 32u; // 32-bit constexpr uint32_t W = 32u; // 32-bit
auto vecN_u32 = [&](const ast::Expression* value) -> const ast::Expression* { auto vecN_u32 = [&](const Expression* value) -> const Expression* {
if (width == 1) { if (width == 1) {
return value; return value;
} }
return b.Call(b.ty.vec<u32>(width), value); return b.Call(b.ty.vec<u32>(width), value);
}; };
utils::Vector<const ast::Statement*, 8> body{ utils::Vector<const Statement*, 8> body{
b.Decl(b.Let("s", b.Call("min", "offset", u32(W)))), b.Decl(b.Let("s", b.Call("min", "offset", u32(W)))),
b.Decl(b.Let("e", b.Call("min", u32(W), b.Add("s", "count")))), b.Decl(b.Let("e", b.Call("min", u32(W), b.Add("s", "count")))),
}; };
@ -465,17 +465,17 @@ struct BuiltinPolyfill::State {
} }
return b.ty.vec<u32>(width); return b.ty.vec<u32>(width);
}; };
auto V = [&](uint32_t value) -> const ast::Expression* { auto V = [&](uint32_t value) -> const Expression* {
return ScalarOrVector(width, u32(value)); return ScalarOrVector(width, u32(value));
}; };
auto B = [&](const ast::Expression* value) -> const ast::Expression* { auto B = [&](const Expression* value) -> const Expression* {
if (width == 1) { if (width == 1) {
return b.Call<bool>(value); return b.Call<bool>(value);
} }
return b.Call(b.ty.vec<bool>(width), value); return b.Call(b.ty.vec<bool>(width), value);
}; };
const ast::Expression* x = nullptr; const Expression* x = nullptr;
if (ty->is_unsigned_integer_scalar_or_vector()) { if (ty->is_unsigned_integer_scalar_or_vector()) {
x = b.Expr("v"); x = b.Expr("v");
} else { } else {
@ -537,10 +537,10 @@ struct BuiltinPolyfill::State {
} }
return b.ty.vec<u32>(width); return b.ty.vec<u32>(width);
}; };
auto V = [&](uint32_t value) -> const ast::Expression* { auto V = [&](uint32_t value) -> const Expression* {
return ScalarOrVector(width, u32(value)); return ScalarOrVector(width, u32(value));
}; };
auto B = [&](const ast::Expression* value) -> const ast::Expression* { auto B = [&](const Expression* value) -> const Expression* {
if (width == 1) { if (width == 1) {
return b.Call<bool>(value); return b.Call<bool>(value);
} }
@ -599,8 +599,8 @@ struct BuiltinPolyfill::State {
constexpr uint32_t W = 32u; // 32-bit constexpr uint32_t W = 32u; // 32-bit
auto V = [&](auto value) -> const ast::Expression* { auto V = [&](auto value) -> const Expression* {
const ast::Expression* expr = b.Expr(value); const Expression* expr = b.Expr(value);
if (!ty->is_unsigned_integer_scalar_or_vector()) { if (!ty->is_unsigned_integer_scalar_or_vector()) {
expr = b.Call<i32>(expr); expr = b.Call<i32>(expr);
} }
@ -609,7 +609,7 @@ struct BuiltinPolyfill::State {
} }
return expr; return expr;
}; };
auto U = [&](auto value) -> const ast::Expression* { auto U = [&](auto value) -> const Expression* {
if (width == 1) { if (width == 1) {
return b.Expr(value); return b.Expr(value);
} }
@ -638,7 +638,7 @@ struct BuiltinPolyfill::State {
// return ((select(T(), n << offset, offset < 32u) & mask) | (v & ~(mask))); // return ((select(T(), n << offset, offset < 32u) & mask) | (v & ~(mask)));
// } // }
utils::Vector<const ast::Statement*, 8> body; utils::Vector<const Statement*, 8> body;
switch (cfg.builtins.insert_bits) { switch (cfg.builtins.insert_bits) {
case Level::kFull: case Level::kFull:
@ -788,7 +788,7 @@ struct BuiltinPolyfill::State {
/// @return the polyfill function name /// @return the polyfill function name
Symbol quantizeToF16(const type::Vector* vec) { Symbol quantizeToF16(const type::Vector* vec) {
auto name = b.Symbols().New("tint_quantizeToF16"); auto name = b.Symbols().New("tint_quantizeToF16");
utils::Vector<const ast::Expression*, 4> args; utils::Vector<const Expression*, 4> args;
for (uint32_t i = 0; i < vec->Width(); i++) { for (uint32_t i = 0; i < vec->Width(); i++) {
args.Push(b.Call("quantizeToF16", b.IndexAccessor("v", u32(i)))); args.Push(b.Call("quantizeToF16", b.IndexAccessor("v", u32(i))));
} }
@ -880,29 +880,29 @@ struct BuiltinPolyfill::State {
/// the RHS is modulo the bit-width of the LHS. /// the RHS is modulo the bit-width of the LHS.
/// @param bin_op the original BinaryExpression /// @param bin_op the original BinaryExpression
/// @return the polyfill value for bitshift operation /// @return the polyfill value for bitshift operation
const ast::Expression* BitshiftModulo(const ast::BinaryExpression* bin_op) { const Expression* BitshiftModulo(const BinaryExpression* bin_op) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef(); auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
auto* lhs_el_ty = type::Type::DeepestElementOf(lhs_ty); auto* lhs_el_ty = type::Type::DeepestElementOf(lhs_ty);
const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1)); const Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
if (rhs_ty->Is<type::Vector>()) { if (rhs_ty->Is<type::Vector>()) {
mask = b.Call(CreateASTTypeFor(ctx, rhs_ty), mask); mask = b.Call(CreateASTTypeFor(ctx, rhs_ty), mask);
} }
auto* lhs = ctx.Clone(bin_op->lhs); auto* lhs = ctx.Clone(bin_op->lhs);
auto* rhs = b.And(ctx.Clone(bin_op->rhs), mask); auto* rhs = b.And(ctx.Clone(bin_op->rhs), mask);
return b.create<ast::BinaryExpression>(ctx.Clone(bin_op->source), bin_op->op, lhs, rhs); return b.create<BinaryExpression>(ctx.Clone(bin_op->source), bin_op->op, lhs, rhs);
} }
/// Builds the polyfill inline expression for a integer divide or modulo, preventing DBZs and /// Builds the polyfill inline expression for a integer divide or modulo, preventing DBZs and
/// integer overflows. /// integer overflows.
/// @param bin_op the original BinaryExpression /// @param bin_op the original BinaryExpression
/// @return the polyfill divide or modulo /// @return the polyfill divide or modulo
const ast::Expression* IntDivMod(const ast::BinaryExpression* bin_op) { const Expression* IntDivMod(const BinaryExpression* bin_op) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef(); auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty}; BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
auto fn = binary_op_polyfills.GetOrCreate(sig, [&] { auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
const bool is_div = bin_op->op == ast::BinaryOp::kDivide; const bool is_div = bin_op->op == BinaryOp::kDivide;
uint32_t lhs_width = 1; uint32_t lhs_width = 1;
uint32_t rhs_width = 1; uint32_t rhs_width = 1;
@ -914,7 +914,7 @@ struct BuiltinPolyfill::State {
const char* lhs = "lhs"; const char* lhs = "lhs";
const char* rhs = "rhs"; const char* rhs = "rhs";
utils::Vector<const ast::Statement*, 4> body; utils::Vector<const Statement*, 4> body;
if (lhs_width < width) { if (lhs_width < width) {
// lhs is scalar, rhs is vector. Convert lhs to vector. // lhs is scalar, rhs is vector. Convert lhs to vector.
@ -934,8 +934,8 @@ struct BuiltinPolyfill::State {
if (lhs_ty->is_signed_integer_scalar_or_vector()) { if (lhs_ty->is_signed_integer_scalar_or_vector()) {
const auto bits = lhs_el_ty->Size() * 8; const auto bits = lhs_el_ty->Size() * 8;
auto min_int = AInt(AInt::kLowestValue >> (AInt::kNumBits - bits)); auto min_int = AInt(AInt::kLowestValue >> (AInt::kNumBits - bits));
const ast::Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int)); const Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int));
const ast::Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a)); const Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a));
// use_one = rhs_is_zero | ((lhs == MIN_INT) & (rhs == -1)) // use_one = rhs_is_zero | ((lhs == MIN_INT) & (rhs == -1))
auto* use_one = b.Or(rhs_is_zero, b.And(lhs_is_min, rhs_is_minus_one)); auto* use_one = b.Or(rhs_is_zero, b.And(lhs_is_min, rhs_is_minus_one));
@ -992,7 +992,7 @@ struct BuiltinPolyfill::State {
/// Builds the polyfill inline expression for a precise float modulo, as defined in the spec. /// Builds the polyfill inline expression for a precise float modulo, as defined in the spec.
/// @param bin_op the original BinaryExpression /// @param bin_op the original BinaryExpression
/// @return the polyfill divide or modulo /// @return the polyfill divide or modulo
const ast::Expression* PreciseFloatMod(const ast::BinaryExpression* bin_op) { const Expression* PreciseFloatMod(const BinaryExpression* bin_op) {
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef(); auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef(); auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty}; BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
@ -1007,7 +1007,7 @@ struct BuiltinPolyfill::State {
const char* lhs = "lhs"; const char* lhs = "lhs";
const char* rhs = "rhs"; const char* rhs = "rhs";
utils::Vector<const ast::Statement*, 4> body; utils::Vector<const Statement*, 4> body;
if (lhs_width < width) { if (lhs_width < width) {
// lhs is scalar, rhs is vector. Convert lhs to vector. // lhs is scalar, rhs is vector. Convert lhs to vector.
@ -1042,7 +1042,7 @@ struct BuiltinPolyfill::State {
} }
/// @returns the AST type for the given sem type /// @returns the AST type for the given sem type
ast::Type T(const type::Type* ty) { return CreateASTTypeFor(ctx, ty); } Type T(const type::Type* ty) { return CreateASTTypeFor(ctx, ty); }
/// @returns 1 if `ty` is not a vector, otherwise the vector width /// @returns 1 if `ty` is not a vector, otherwise the vector width
uint32_t WidthOf(const type::Type* ty) const { uint32_t WidthOf(const type::Type* ty) const {
@ -1055,7 +1055,7 @@ struct BuiltinPolyfill::State {
/// @returns a scalar or vector with the given width, with each element with /// @returns a scalar or vector with the given width, with each element with
/// the given value. /// the given value.
template <typename T> template <typename T>
const ast::Expression* ScalarOrVector(uint32_t width, T value) { const Expression* ScalarOrVector(uint32_t width, T value) {
if (width == 1) { if (width == 1) {
return b.Expr(value); return b.Expr(value);
} }
@ -1063,7 +1063,7 @@ struct BuiltinPolyfill::State {
} }
template <typename To> template <typename To>
const ast::Expression* CastScalarOrVector(uint32_t width, const ast::Expression* e) { const Expression* CastScalarOrVector(uint32_t width, const Expression* e) {
if (width == 1) { if (width == 1) {
return b.Call(b.ty.Of<To>(), e); return b.Call(b.ty.Of<To>(), e);
} }
@ -1071,7 +1071,7 @@ struct BuiltinPolyfill::State {
} }
/// Examines the call expression @p expr, applying any necessary polyfill transforms /// Examines the call expression @p expr, applying any necessary polyfill transforms
void Call(const ast::CallExpression* expr) { void Call(const CallExpression* expr) {
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
if (!call || call->Stage() == sem::EvaluationStage::kConstant || if (!call || call->Stage() == sem::EvaluationStage::kConstant ||
call->Stage() == sem::EvaluationStage::kNotEvaluated) { call->Stage() == sem::EvaluationStage::kNotEvaluated) {
@ -1207,7 +1207,7 @@ struct BuiltinPolyfill::State {
size_t value_idx = static_cast<size_t>( size_t value_idx = static_cast<size_t>(
sig.IndexOf(sem::ParameterUsage::kValue)); sig.IndexOf(sem::ParameterUsage::kValue));
ctx.Replace(expr, [this, expr, value_idx] { ctx.Replace(expr, [this, expr, value_idx] {
utils::Vector<const ast::Expression*, 3> args; utils::Vector<const Expression*, 3> args;
for (auto* arg : expr->args) { for (auto* arg : expr->args) {
arg = ctx.Clone(arg); arg = ctx.Clone(arg);
if (args.Length() == value_idx) { // value if (args.Length() == value_idx) { // value

View File

@ -57,7 +57,7 @@ bool ShouldRun(const Program* program) {
/// ArrayUsage describes a runtime array usage. /// ArrayUsage describes a runtime array usage.
/// It is used as a key by the array_length_by_usage map. /// It is used as a key by the array_length_by_usage map.
struct ArrayUsage { struct ArrayUsage {
ast::BlockStatement const* const block; BlockStatement const* const block;
sem::Variable const* const buffer; sem::Variable const* const buffer;
bool operator==(const ArrayUsage& rhs) const { bool operator==(const ArrayUsage& rhs) const {
return block == rhs.block && buffer == rhs.buffer; return block == rhs.block && buffer == rhs.buffer;
@ -71,7 +71,7 @@ struct ArrayUsage {
} // namespace } // namespace
CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid, ast::NodeID nid) CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid, NodeID nid)
: Base(pid, nid, utils::Empty) {} : Base(pid, nid, utils::Empty) {}
CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default; CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const { std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const {
@ -106,7 +106,7 @@ Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] { return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
auto name = b.Sym(); auto name = b.Sym();
auto type = CreateASTTypeFor(ctx, buffer_type); auto type = CreateASTTypeFor(ctx, buffer_type);
auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter); auto* disable_validation = b.Disable(DisabledValidation::kFunctionParameter);
b.Func( b.Func(
name, name,
utils::Vector{ utils::Vector{
@ -128,13 +128,13 @@ Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
// Find all the arrayLength() calls... // Find all the arrayLength() calls...
for (auto* node : src->ASTNodes().Objects()) { for (auto* node : src->ASTNodes().Objects()) {
if (auto* call_expr = node->As<ast::CallExpression>()) { if (auto* call_expr = node->As<CallExpression>()) {
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>(); auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
if (auto* builtin = call->Target()->As<sem::Builtin>()) { if (auto* builtin = call->Target()->As<sem::Builtin>()) {
if (builtin->Type() == builtin::Function::kArrayLength) { if (builtin->Type() == builtin::Function::kArrayLength) {
// We're dealing with an arrayLength() call // We're dealing with an arrayLength() call
if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) { if (auto* call_stmt = call->Stmt()->Declaration()->As<CallStatement>()) {
if (call_stmt->expr == call_expr) { if (call_stmt->expr == call_expr) {
// arrayLength() is used as a statement. // arrayLength() is used as a statement.
// The argument expression must be side-effect free, so just drop the // The argument expression must be side-effect free, so just drop the
@ -151,13 +151,13 @@ Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
// arrayLength(&struct_var.array_member) // arrayLength(&struct_var.array_member)
// arrayLength(&array_var) // arrayLength(&array_var)
auto* arg = call_expr->args[0]; auto* arg = call_expr->args[0];
auto* address_of = arg->As<ast::UnaryOpExpression>(); auto* address_of = arg->As<UnaryOpExpression>();
if (TINT_UNLIKELY(!address_of || address_of->op != ast::UnaryOp::kAddressOf)) { if (TINT_UNLIKELY(!address_of || address_of->op != UnaryOp::kAddressOf)) {
TINT_ICE(Transform, b.Diagnostics()) TINT_ICE(Transform, b.Diagnostics())
<< "arrayLength() expected address-of, got " << arg->TypeInfo().name; << "arrayLength() expected address-of, got " << arg->TypeInfo().name;
} }
auto* storage_buffer_expr = address_of->expr; auto* storage_buffer_expr = address_of->expr;
if (auto* accessor = storage_buffer_expr->As<ast::MemberAccessorExpression>()) { if (auto* accessor = storage_buffer_expr->As<MemberAccessorExpression>()) {
storage_buffer_expr = accessor->object; storage_buffer_expr = accessor->object;
} }
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr); auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
@ -199,8 +199,7 @@ Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
// array_length = ---------------------------------------- // array_length = ----------------------------------------
// array_stride // array_stride
auto name = b.Sym(); auto name = b.Sym();
const ast::Expression* total_size = const Expression* total_size = b.Expr(buffer_size_result->variable);
b.Expr(buffer_size_result->variable);
const type::Array* array_type = Switch( const type::Array* array_type = Switch(
storage_buffer_type->StoreType(), storage_buffer_type->StoreType(),

View File

@ -37,12 +37,12 @@ class CalculateArrayLength final : public utils::Castable<CalculateArrayLength,
/// BufferSizeIntrinsic is an InternalAttribute that's applied to intrinsic /// BufferSizeIntrinsic is an InternalAttribute that's applied to intrinsic
/// functions used to obtain the runtime size of a storage buffer. /// functions used to obtain the runtime size of a storage buffer.
class BufferSizeIntrinsic final class BufferSizeIntrinsic final
: public utils::Castable<BufferSizeIntrinsic, ast::InternalAttribute> { : public utils::Castable<BufferSizeIntrinsic, InternalAttribute> {
public: public:
/// Constructor /// Constructor
/// @param program_id the identifier of the program that owns this node /// @param program_id the identifier of the program that owns this node
/// @param nid the unique node identifier /// @param nid the unique node identifier
BufferSizeIntrinsic(ProgramID program_id, ast::NodeID nid); BufferSizeIntrinsic(ProgramID program_id, NodeID nid);
/// Destructor /// Destructor
~BufferSizeIntrinsic() override; ~BufferSizeIntrinsic() override;

View File

@ -41,7 +41,7 @@ namespace {
/// Info for a struct member /// Info for a struct member
struct MemberInfo { struct MemberInfo {
/// The struct member item /// The struct member item
const ast::StructMember* member; const StructMember* member;
/// The struct member location if provided /// The struct member location if provided
std::optional<uint32_t> location; std::optional<uint32_t> location;
}; };
@ -83,9 +83,9 @@ uint32_t BuiltinOrder(builtin::BuiltinValue builtin) {
} }
// Returns true if `attr` is a shader IO attribute. // Returns true if `attr` is a shader IO attribute.
bool IsShaderIOAttribute(const ast::Attribute* attr) { bool IsShaderIOAttribute(const Attribute* attr) {
return attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute, ast::InvariantAttribute, return attr
ast::LocationAttribute>(); ->IsAnyOf<BuiltinAttribute, InterpolateAttribute, InvariantAttribute, LocationAttribute>();
} }
} // namespace } // namespace
@ -97,11 +97,11 @@ struct CanonicalizeEntryPointIO::State {
/// The name of the output value. /// The name of the output value.
std::string name; std::string name;
/// The type of the output value. /// The type of the output value.
ast::Type type; Type type;
/// The shader IO attributes. /// The shader IO attributes.
utils::Vector<const ast::Attribute*, 8> attributes; utils::Vector<const Attribute*, 8> attributes;
/// The value itself. /// The value itself.
const ast::Expression* value; const Expression* value;
/// The output location. /// The output location.
std::optional<uint32_t> location; std::optional<uint32_t> location;
}; };
@ -111,29 +111,29 @@ struct CanonicalizeEntryPointIO::State {
/// The transform config. /// The transform config.
CanonicalizeEntryPointIO::Config const cfg; CanonicalizeEntryPointIO::Config const cfg;
/// The entry point function (AST). /// The entry point function (AST).
const ast::Function* func_ast; const Function* func_ast;
/// The entry point function (SEM). /// The entry point function (SEM).
const sem::Function* func_sem; const sem::Function* func_sem;
/// The new entry point wrapper function's parameters. /// The new entry point wrapper function's parameters.
utils::Vector<const ast::Parameter*, 8> wrapper_ep_parameters; utils::Vector<const Parameter*, 8> wrapper_ep_parameters;
/// The members of the wrapper function's struct parameter. /// The members of the wrapper function's struct parameter.
utils::Vector<MemberInfo, 8> wrapper_struct_param_members; utils::Vector<MemberInfo, 8> wrapper_struct_param_members;
/// The name of the wrapper function's struct parameter. /// The name of the wrapper function's struct parameter.
Symbol wrapper_struct_param_name; Symbol wrapper_struct_param_name;
/// The parameters that will be passed to the original function. /// The parameters that will be passed to the original function.
utils::Vector<const ast::Expression*, 8> inner_call_parameters; utils::Vector<const Expression*, 8> inner_call_parameters;
/// The members of the wrapper function's struct return type. /// The members of the wrapper function's struct return type.
utils::Vector<MemberInfo, 8> wrapper_struct_output_members; utils::Vector<MemberInfo, 8> wrapper_struct_output_members;
/// The wrapper function output values. /// The wrapper function output values.
utils::Vector<OutputValue, 8> wrapper_output_values; utils::Vector<OutputValue, 8> wrapper_output_values;
/// The body of the wrapper function. /// The body of the wrapper function.
utils::Vector<const ast::Statement*, 8> wrapper_body; utils::Vector<const Statement*, 8> wrapper_body;
/// Input names used by the entrypoint /// Input names used by the entrypoint
std::unordered_set<std::string> input_names; std::unordered_set<std::string> input_names;
/// A map of cloned attribute to builtin value /// A map of cloned attribute to builtin value
utils::Hashmap<const ast::BuiltinAttribute*, builtin::BuiltinValue, 16> builtin_attrs; utils::Hashmap<const BuiltinAttribute*, builtin::BuiltinValue, 16> builtin_attrs;
/// Constructor /// Constructor
/// @param context the clone context /// @param context the clone context
@ -141,7 +141,7 @@ struct CanonicalizeEntryPointIO::State {
/// @param function the entry point function /// @param function the entry point function
State(CloneContext& context, State(CloneContext& context,
const CanonicalizeEntryPointIO::Config& config, const CanonicalizeEntryPointIO::Config& config,
const ast::Function* function) const Function* function)
: ctx(context), cfg(config), func_ast(function), func_sem(ctx.src->Sem().Get(function)) {} : ctx(context), cfg(config), func_ast(function), func_sem(ctx.src->Sem().Get(function)) {}
/// Clones the attributes from @p in and adds it to @p out. If @p in is a builtin attribute, /// Clones the attributes from @p in and adds it to @p out. If @p in is a builtin attribute,
@ -149,12 +149,11 @@ struct CanonicalizeEntryPointIO::State {
/// @param in the attribute to clone /// @param in the attribute to clone
/// @param out the output Attributes /// @param out the output Attributes
template <size_t N> template <size_t N>
void CloneAttribute(const ast::Attribute* in, utils::Vector<const ast::Attribute*, N>& out) { void CloneAttribute(const Attribute* in, utils::Vector<const Attribute*, N>& out) {
auto* cloned = ctx.Clone(in); auto* cloned = ctx.Clone(in);
out.Push(cloned); out.Push(cloned);
if (auto* builtin = in->As<ast::BuiltinAttribute>()) { if (auto* builtin = in->As<BuiltinAttribute>()) {
builtin_attrs.Add(cloned->As<ast::BuiltinAttribute>(), builtin_attrs.Add(cloned->As<BuiltinAttribute>(), ctx.src->Sem().Get(builtin)->Value());
ctx.src->Sem().Get(builtin)->Value());
} }
} }
@ -163,12 +162,11 @@ struct CanonicalizeEntryPointIO::State {
/// @param do_interpolate whether to clone InterpolateAttribute /// @param do_interpolate whether to clone InterpolateAttribute
/// @return the cloned attributes /// @return the cloned attributes
template <size_t N> template <size_t N>
auto CloneShaderIOAttributes(const utils::Vector<const ast::Attribute*, N> in, auto CloneShaderIOAttributes(const utils::Vector<const Attribute*, N> in, bool do_interpolate) {
bool do_interpolate) { utils::Vector<const Attribute*, N> out;
utils::Vector<const ast::Attribute*, N> out;
for (auto* attr : in) { for (auto* attr : in) {
if (IsShaderIOAttribute(attr) && if (IsShaderIOAttribute(attr) &&
(do_interpolate || !attr->template Is<ast::InterpolateAttribute>())) { (do_interpolate || !attr->template Is<InterpolateAttribute>())) {
CloneAttribute(attr, out); CloneAttribute(attr, out);
} }
} }
@ -177,7 +175,7 @@ struct CanonicalizeEntryPointIO::State {
/// @param attr the input attribute /// @param attr the input attribute
/// @returns the builtin value of the attribute /// @returns the builtin value of the attribute
builtin::BuiltinValue BuiltinOf(const ast::BuiltinAttribute* attr) { builtin::BuiltinValue BuiltinOf(const BuiltinAttribute* attr) {
if (attr->program_id == ctx.dst->ID()) { if (attr->program_id == ctx.dst->ID()) {
// attr belongs to the target program. // attr belongs to the target program.
// Obtain the builtin value from #builtin_attrs. // Obtain the builtin value from #builtin_attrs.
@ -197,8 +195,8 @@ struct CanonicalizeEntryPointIO::State {
/// @param attrs the input attribute list /// @param attrs the input attribute list
/// @returns the builtin value if any of the attributes in @p attrs is a builtin attribute, /// @returns the builtin value if any of the attributes in @p attrs is a builtin attribute,
/// otherwise builtin::BuiltinValue::kUndefined /// otherwise builtin::BuiltinValue::kUndefined
builtin::BuiltinValue BuiltinOf(utils::VectorRef<const ast::Attribute*> attrs) { builtin::BuiltinValue BuiltinOf(utils::VectorRef<const Attribute*> attrs) {
if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attrs)) { if (auto* builtin = GetAttribute<BuiltinAttribute>(attrs)) {
return BuiltinOf(builtin); return BuiltinOf(builtin);
} }
return builtin::BuiltinValue::kUndefined; return builtin::BuiltinValue::kUndefined;
@ -219,10 +217,10 @@ struct CanonicalizeEntryPointIO::State {
/// @param location the location if provided /// @param location the location if provided
/// @param attrs the attributes to apply to the shader input /// @param attrs the attributes to apply to the shader input
/// @returns an expression which evaluates to the value of the shader input /// @returns an expression which evaluates to the value of the shader input
const ast::Expression* AddInput(std::string name, const Expression* AddInput(std::string name,
const type::Type* type, const type::Type* type,
std::optional<uint32_t> location, std::optional<uint32_t> location,
utils::Vector<const ast::Attribute*, 8> attrs) { utils::Vector<const Attribute*, 8> attrs) {
auto ast_type = CreateASTTypeFor(ctx, type); auto ast_type = CreateASTTypeFor(ctx, type);
auto builtin_attr = BuiltinOf(attrs); auto builtin_attr = BuiltinOf(attrs);
@ -233,17 +231,16 @@ struct CanonicalizeEntryPointIO::State {
// https://www.khronos.org/registry/vulkan/specs/1.3-extensions/man/html/StandaloneSpirv.html#VUID-StandaloneSpirv-Flat-04744 // https://www.khronos.org/registry/vulkan/specs/1.3-extensions/man/html/StandaloneSpirv.html#VUID-StandaloneSpirv-Flat-04744
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is // TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is
// required for integers. // required for integers.
if (func_ast->PipelineStage() == ast::PipelineStage::kFragment && if (func_ast->PipelineStage() == PipelineStage::kFragment &&
type->is_integer_scalar_or_vector() && type->is_integer_scalar_or_vector() && !HasAttribute<InterpolateAttribute>(attrs) &&
!ast::HasAttribute<ast::InterpolateAttribute>(attrs) && (HasAttribute<LocationAttribute>(attrs) ||
(ast::HasAttribute<ast::LocationAttribute>(attrs) ||
cfg.shader_style == ShaderStyle::kSpirv)) { cfg.shader_style == ShaderStyle::kSpirv)) {
attrs.Push(ctx.dst->Interpolate(builtin::InterpolationType::kFlat, attrs.Push(ctx.dst->Interpolate(builtin::InterpolationType::kFlat,
builtin::InterpolationSampling::kUndefined)); builtin::InterpolationSampling::kUndefined));
} }
// Disable validation for use of the `input` address space. // Disable validation for use of the `input` address space.
attrs.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace)); attrs.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
// In GLSL, if it's a builtin, override the name with the // In GLSL, if it's a builtin, override the name with the
// corresponding gl_ builtin name // corresponding gl_ builtin name
@ -255,7 +252,7 @@ struct CanonicalizeEntryPointIO::State {
auto symbol = ctx.dst->Symbols().New(name); auto symbol = ctx.dst->Symbols().New(name);
// Create the global variable and use its value for the shader input. // Create the global variable and use its value for the shader input.
const ast::Expression* value = ctx.dst->Expr(symbol); const Expression* value = ctx.dst->Expr(symbol);
if (builtin_attr != builtin::BuiltinValue::kUndefined) { if (builtin_attr != builtin::BuiltinValue::kUndefined) {
if (cfg.shader_style == ShaderStyle::kGlsl) { if (cfg.shader_style == ShaderStyle::kGlsl) {
@ -296,18 +293,17 @@ struct CanonicalizeEntryPointIO::State {
void AddOutput(std::string name, void AddOutput(std::string name,
const type::Type* type, const type::Type* type,
std::optional<uint32_t> location, std::optional<uint32_t> location,
utils::Vector<const ast::Attribute*, 8> attrs, utils::Vector<const Attribute*, 8> attrs,
const ast::Expression* value) { const Expression* value) {
auto builtin_attr = BuiltinOf(attrs); auto builtin_attr = BuiltinOf(attrs);
// Vulkan requires that integer user-defined vertex outputs are always decorated with // Vulkan requires that integer user-defined vertex outputs are always decorated with
// `Flat`. // `Flat`.
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is required // TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is required
// for integers. // for integers.
if (cfg.shader_style == ShaderStyle::kSpirv && if (cfg.shader_style == ShaderStyle::kSpirv &&
func_ast->PipelineStage() == ast::PipelineStage::kVertex && func_ast->PipelineStage() == PipelineStage::kVertex &&
type->is_integer_scalar_or_vector() && type->is_integer_scalar_or_vector() && HasAttribute<LocationAttribute>(attrs) &&
ast::HasAttribute<ast::LocationAttribute>(attrs) && !HasAttribute<InterpolateAttribute>(attrs)) {
!ast::HasAttribute<ast::InterpolateAttribute>(attrs)) {
attrs.Push(ctx.dst->Interpolate(builtin::InterpolationType::kFlat, attrs.Push(ctx.dst->Interpolate(builtin::InterpolationType::kFlat,
builtin::InterpolationSampling::kUndefined)); builtin::InterpolationSampling::kUndefined));
} }
@ -338,14 +334,14 @@ struct CanonicalizeEntryPointIO::State {
/// @param param the original function parameter /// @param param the original function parameter
void ProcessNonStructParameter(const sem::Parameter* param) { void ProcessNonStructParameter(const sem::Parameter* param) {
// Do not add interpolation attributes on vertex input // Do not add interpolation attributes on vertex input
bool do_interpolate = func_ast->PipelineStage() != ast::PipelineStage::kVertex; bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kVertex;
// Remove the shader IO attributes from the inner function parameter, and attach them to the // Remove the shader IO attributes from the inner function parameter, and attach them to the
// new object instead. // new object instead.
utils::Vector<const ast::Attribute*, 8> attributes; utils::Vector<const Attribute*, 8> attributes;
for (auto* attr : param->Declaration()->attributes) { for (auto* attr : param->Declaration()->attributes) {
if (IsShaderIOAttribute(attr)) { if (IsShaderIOAttribute(attr)) {
ctx.Remove(param->Declaration()->attributes, attr); ctx.Remove(param->Declaration()->attributes, attr);
if ((do_interpolate || !attr->Is<ast::InterpolateAttribute>())) { if ((do_interpolate || !attr->Is<InterpolateAttribute>())) {
CloneAttribute(attr, attributes); CloneAttribute(attr, attributes);
} }
} }
@ -363,13 +359,13 @@ struct CanonicalizeEntryPointIO::State {
/// @param param the original function parameter /// @param param the original function parameter
void ProcessStructParameter(const sem::Parameter* param) { void ProcessStructParameter(const sem::Parameter* param) {
// Do not add interpolation attributes on vertex input // Do not add interpolation attributes on vertex input
bool do_interpolate = func_ast->PipelineStage() != ast::PipelineStage::kVertex; bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kVertex;
auto* str = param->Type()->As<sem::Struct>(); auto* str = param->Type()->As<sem::Struct>();
// Recreate struct members in the outer entry point and build an initializer // Recreate struct members in the outer entry point and build an initializer
// list to pass them through to the inner function. // list to pass them through to the inner function.
utils::Vector<const ast::Expression*, 8> inner_struct_values; utils::Vector<const Expression*, 8> inner_struct_values;
for (auto* member : str->Members()) { for (auto* member : str->Members()) {
if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) { if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) {
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct"; TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
@ -397,7 +393,7 @@ struct CanonicalizeEntryPointIO::State {
/// @param original_result the result object produced by the original function /// @param original_result the result object produced by the original function
void ProcessReturnType(const type::Type* inner_ret_type, Symbol original_result) { void ProcessReturnType(const type::Type* inner_ret_type, Symbol original_result) {
// Do not add interpolation attributes on fragment output // Do not add interpolation attributes on fragment output
bool do_interpolate = func_ast->PipelineStage() != ast::PipelineStage::kFragment; bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kFragment;
if (auto* str = inner_ret_type->As<sem::Struct>()) { if (auto* str = inner_ret_type->As<sem::Struct>()) {
for (auto* member : str->Members()) { for (auto* member : str->Members()) {
if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) { if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) {
@ -456,7 +452,7 @@ struct CanonicalizeEntryPointIO::State {
/// Create an expression for gl_Position.[component] /// Create an expression for gl_Position.[component]
/// @param component the component of gl_Position to access /// @param component the component of gl_Position to access
/// @returns the new expression /// @returns the new expression
const ast::Expression* GLPosition(const char* component) { const Expression* GLPosition(const char* component) {
Symbol pos = ctx.dst->Symbols().Register("gl_Position"); Symbol pos = ctx.dst->Symbols().Register("gl_Position");
Symbol c = ctx.dst->Symbols().Register(component); Symbol c = ctx.dst->Symbols().Register(component);
return ctx.dst->MemberAccessor(ctx.dst->Expr(pos), c); return ctx.dst->MemberAccessor(ctx.dst->Expr(pos), c);
@ -469,10 +465,10 @@ struct CanonicalizeEntryPointIO::State {
/// @param b another struct member /// @param b another struct member
/// @returns true if a comes before b /// @returns true if a comes before b
bool StructMemberComparator(const MemberInfo& a, const MemberInfo& b) { bool StructMemberComparator(const MemberInfo& a, const MemberInfo& b) {
auto* a_loc = ast::GetAttribute<ast::LocationAttribute>(a.member->attributes); auto* a_loc = GetAttribute<LocationAttribute>(a.member->attributes);
auto* b_loc = ast::GetAttribute<ast::LocationAttribute>(b.member->attributes); auto* b_loc = GetAttribute<LocationAttribute>(b.member->attributes);
auto* a_blt = ast::GetAttribute<ast::BuiltinAttribute>(a.member->attributes); auto* a_blt = GetAttribute<BuiltinAttribute>(a.member->attributes);
auto* b_blt = ast::GetAttribute<ast::BuiltinAttribute>(b.member->attributes); auto* b_blt = GetAttribute<BuiltinAttribute>(b.member->attributes);
if (a_loc) { if (a_loc) {
if (!b_loc) { if (!b_loc) {
// `a` has location attribute and `b` does not: `a` goes first. // `a` has location attribute and `b` does not: `a` goes first.
@ -497,15 +493,15 @@ struct CanonicalizeEntryPointIO::State {
std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(), std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(),
[&](auto& a, auto& b) { return StructMemberComparator(a, b); }); [&](auto& a, auto& b) { return StructMemberComparator(a, b); });
utils::Vector<const ast::StructMember*, 8> members; utils::Vector<const StructMember*, 8> members;
for (auto& mem : wrapper_struct_param_members) { for (auto& mem : wrapper_struct_param_members) {
members.Push(mem.member); members.Push(mem.member);
} }
// Create the new struct type. // Create the new struct type.
auto struct_name = ctx.dst->Sym(); auto struct_name = ctx.dst->Sym();
auto* in_struct = ctx.dst->create<ast::Struct>(ctx.dst->Ident(struct_name), auto* in_struct =
std::move(members), utils::Empty); ctx.dst->create<Struct>(ctx.dst->Ident(struct_name), std::move(members), utils::Empty);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct); ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
// Create a new function parameter using this struct type. // Create a new function parameter using this struct type.
@ -515,8 +511,8 @@ struct CanonicalizeEntryPointIO::State {
/// Create and return the wrapper function's struct result object. /// Create and return the wrapper function's struct result object.
/// @returns the struct type /// @returns the struct type
ast::Struct* CreateOutputStruct() { Struct* CreateOutputStruct() {
utils::Vector<const ast::Statement*, 8> assignments; utils::Vector<const Statement*, 8> assignments;
auto wrapper_result = ctx.dst->Symbols().New("wrapper_result"); auto wrapper_result = ctx.dst->Symbols().New("wrapper_result");
@ -544,13 +540,13 @@ struct CanonicalizeEntryPointIO::State {
std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(), std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(),
[&](auto& a, auto& b) { return StructMemberComparator(a, b); }); [&](auto& a, auto& b) { return StructMemberComparator(a, b); });
utils::Vector<const ast::StructMember*, 8> members; utils::Vector<const StructMember*, 8> members;
for (auto& mem : wrapper_struct_output_members) { for (auto& mem : wrapper_struct_output_members) {
members.Push(mem.member); members.Push(mem.member);
} }
// Create the new struct type. // Create the new struct type.
auto* out_struct = ctx.dst->create<ast::Struct>(ctx.dst->Ident(ctx.dst->Sym()), auto* out_struct = ctx.dst->create<Struct>(ctx.dst->Ident(ctx.dst->Sym()),
std::move(members), utils::Empty); std::move(members), utils::Empty);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct); ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
@ -570,12 +566,12 @@ struct CanonicalizeEntryPointIO::State {
for (auto& outval : wrapper_output_values) { for (auto& outval : wrapper_output_values) {
// Disable validation for use of the `output` address space. // Disable validation for use of the `output` address space.
auto attributes = std::move(outval.attributes); auto attributes = std::move(outval.attributes);
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace)); attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
// Create the global variable and assign it the output value. // Create the global variable and assign it the output value.
auto name = ctx.dst->Symbols().New(outval.name); auto name = ctx.dst->Symbols().New(outval.name);
ast::Type type = outval.type; Type type = outval.type;
const ast::Expression* lhs = ctx.dst->Expr(name); const Expression* lhs = ctx.dst->Expr(name);
if (BuiltinOf(attributes) == builtin::BuiltinValue::kSampleMask) { if (BuiltinOf(attributes) == builtin::BuiltinValue::kSampleMask) {
// Vulkan requires the type of a SampleMask builtin to be an array. // Vulkan requires the type of a SampleMask builtin to be an array.
// Declare it as array<u32, 1> and then store to the first element. // Declare it as array<u32, 1> and then store to the first element.
@ -589,7 +585,7 @@ struct CanonicalizeEntryPointIO::State {
// Recreate the original function without entry point attributes and call it. // Recreate the original function without entry point attributes and call it.
/// @returns the inner function call expression /// @returns the inner function call expression
const ast::CallExpression* CallInnerFunction() { const CallExpression* CallInnerFunction() {
Symbol inner_name; Symbol inner_name;
if (cfg.shader_style == ShaderStyle::kGlsl) { if (cfg.shader_style == ShaderStyle::kGlsl) {
// In GLSL, clone the original entry point name, as the wrapper will be // In GLSL, clone the original entry point name, as the wrapper will be
@ -606,9 +602,9 @@ struct CanonicalizeEntryPointIO::State {
// The parameter attributes will have already been stripped during // The parameter attributes will have already been stripped during
// processing. // processing.
auto* inner_function = auto* inner_function =
ctx.dst->create<ast::Function>(ctx.dst->Ident(inner_name), ctx.Clone(func_ast->params), ctx.dst->create<Function>(ctx.dst->Ident(inner_name), ctx.Clone(func_ast->params),
ctx.Clone(func_ast->return_type), ctx.Clone(func_ast->return_type), ctx.Clone(func_ast->body),
ctx.Clone(func_ast->body), utils::Empty, utils::Empty); utils::Empty, utils::Empty);
ctx.Replace(func_ast, inner_function); ctx.Replace(func_ast, inner_function);
// Call the function. // Call the function.
@ -619,12 +615,11 @@ struct CanonicalizeEntryPointIO::State {
void Process() { void Process() {
bool needs_fixed_sample_mask = false; bool needs_fixed_sample_mask = false;
bool needs_vertex_point_size = false; bool needs_vertex_point_size = false;
if (func_ast->PipelineStage() == ast::PipelineStage::kFragment && if (func_ast->PipelineStage() == PipelineStage::kFragment &&
cfg.fixed_sample_mask != 0xFFFFFFFF) { cfg.fixed_sample_mask != 0xFFFFFFFF) {
needs_fixed_sample_mask = true; needs_fixed_sample_mask = true;
} }
if (func_ast->PipelineStage() == ast::PipelineStage::kVertex && if (func_ast->PipelineStage() == PipelineStage::kVertex && cfg.emit_vertex_point_size) {
cfg.emit_vertex_point_size) {
needs_vertex_point_size = true; needs_vertex_point_size = true;
} }
@ -656,7 +651,7 @@ struct CanonicalizeEntryPointIO::State {
auto* call_inner = CallInnerFunction(); auto* call_inner = CallInnerFunction();
// Process the return type, and start building the wrapper function body. // Process the return type, and start building the wrapper function body.
std::function<ast::Type()> wrapper_ret_type = [&] { return ctx.dst->ty.void_(); }; std::function<Type()> wrapper_ret_type = [&] { return ctx.dst->ty.void_(); };
if (func_sem->ReturnType()->Is<type::Void>()) { if (func_sem->ReturnType()->Is<type::Void>()) {
// The function call is just a statement with no result. // The function call is just a statement with no result.
wrapper_body.Push(ctx.dst->CallStmt(call_inner)); wrapper_body.Push(ctx.dst->CallStmt(call_inner));
@ -693,10 +688,10 @@ struct CanonicalizeEntryPointIO::State {
} }
if (cfg.shader_style == ShaderStyle::kGlsl && if (cfg.shader_style == ShaderStyle::kGlsl &&
func_ast->PipelineStage() == ast::PipelineStage::kVertex) { func_ast->PipelineStage() == PipelineStage::kVertex) {
auto* pos_y = GLPosition("y"); auto* pos_y = GLPosition("y");
auto* negate_pos_y = auto* negate_pos_y =
ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNegation, GLPosition("y")); ctx.dst->create<UnaryOpExpression>(UnaryOp::kNegation, GLPosition("y"));
wrapper_body.Push(ctx.dst->Assign(pos_y, negate_pos_y)); wrapper_body.Push(ctx.dst->Assign(pos_y, negate_pos_y));
auto* two_z = ctx.dst->Mul(ctx.dst->Expr(2_f), GLPosition("z")); auto* two_z = ctx.dst->Mul(ctx.dst->Expr(2_f), GLPosition("z"));
@ -714,7 +709,7 @@ struct CanonicalizeEntryPointIO::State {
name = ctx.Clone(func_ast->name->symbol); name = ctx.Clone(func_ast->name->symbol);
} }
auto* wrapper_func = ctx.dst->create<ast::Function>( auto* wrapper_func = ctx.dst->create<Function>(
ctx.dst->Ident(name), wrapper_ep_parameters, ctx.dst->ty(wrapper_ret_type()), ctx.dst->Ident(name), wrapper_ep_parameters, ctx.dst->ty(wrapper_ret_type()),
ctx.dst->Block(wrapper_body), ctx.Clone(func_ast->attributes), utils::Empty); ctx.dst->Block(wrapper_body), ctx.Clone(func_ast->attributes), utils::Empty);
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast, wrapper_func); ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast, wrapper_func);
@ -726,14 +721,14 @@ struct CanonicalizeEntryPointIO::State {
/// @param address_space the address space (input or output) /// @param address_space the address space (input or output)
/// @returns the gl_ string corresponding to that builtin /// @returns the gl_ string corresponding to that builtin
const char* GLSLBuiltinToString(builtin::BuiltinValue builtin, const char* GLSLBuiltinToString(builtin::BuiltinValue builtin,
ast::PipelineStage stage, PipelineStage stage,
builtin::AddressSpace address_space) { builtin::AddressSpace address_space) {
switch (builtin) { switch (builtin) {
case builtin::BuiltinValue::kPosition: case builtin::BuiltinValue::kPosition:
switch (stage) { switch (stage) {
case ast::PipelineStage::kVertex: case PipelineStage::kVertex:
return "gl_Position"; return "gl_Position";
case ast::PipelineStage::kFragment: case PipelineStage::kFragment:
return "gl_FragCoord"; return "gl_FragCoord";
default: default:
return ""; return "";
@ -775,9 +770,9 @@ struct CanonicalizeEntryPointIO::State {
/// @param ast_type (inout) the incoming WGSL and outgoing GLSL types /// @param ast_type (inout) the incoming WGSL and outgoing GLSL types
/// @returns an expression representing the GLSL builtin converted to what /// @returns an expression representing the GLSL builtin converted to what
/// WGSL expects /// WGSL expects
const ast::Expression* FromGLSLBuiltin(builtin::BuiltinValue builtin, const Expression* FromGLSLBuiltin(builtin::BuiltinValue builtin,
const ast::Expression* value, const Expression* value,
ast::Type& ast_type) { Type& ast_type) {
switch (builtin) { switch (builtin) {
case builtin::BuiltinValue::kVertexIndex: case builtin::BuiltinValue::kVertexIndex:
case builtin::BuiltinValue::kInstanceIndex: case builtin::BuiltinValue::kInstanceIndex:
@ -805,8 +800,8 @@ struct CanonicalizeEntryPointIO::State {
/// @param value the value to convert /// @param value the value to convert
/// @param type (out) the type to which the value was converted /// @param type (out) the type to which the value was converted
/// @returns the converted value which can be assigned to the GLSL builtin /// @returns the converted value which can be assigned to the GLSL builtin
const ast::Expression* ToGLSLBuiltin(builtin::BuiltinValue builtin, const Expression* ToGLSLBuiltin(builtin::BuiltinValue builtin,
const ast::Expression* value, const Expression* value,
const type::Type*& type) { const type::Type*& type) {
switch (builtin) { switch (builtin) {
case builtin::BuiltinValue::kVertexIndex: case builtin::BuiltinValue::kVertexIndex:
@ -839,7 +834,7 @@ Transform::ApplyResult CanonicalizeEntryPointIO::Apply(const Program* src,
// Remove entry point IO attributes from struct declarations. // Remove entry point IO attributes from struct declarations.
// New structures will be created for each entry point, as necessary. // New structures will be created for each entry point, as necessary.
for (auto* ty : src->AST().TypeDecls()) { for (auto* ty : src->AST().TypeDecls()) {
if (auto* struct_ty = ty->As<ast::Struct>()) { if (auto* struct_ty = ty->As<Struct>()) {
for (auto* member : struct_ty->members) { for (auto* member : struct_ty->members) {
for (auto* attr : member->attributes) { for (auto* attr : member->attributes) {
if (IsShaderIOAttribute(attr)) { if (IsShaderIOAttribute(attr)) {

View File

@ -51,7 +51,7 @@ struct ClampFragDepth::State {
Transform::ApplyResult Run() { Transform::ApplyResult Run() {
// Abort on any use of push constants in the module. // Abort on any use of push constants in the module.
for (auto* global : src->AST().GlobalVariables()) { for (auto* global : src->AST().GlobalVariables()) {
if (auto* var = global->As<ast::Var>()) { if (auto* var = global->As<Var>()) {
auto* v = src->Sem().Get(var); auto* v = src->Sem().Get(var);
if (TINT_UNLIKELY(v->AddressSpace() == builtin::AddressSpace::kPushConstant)) { if (TINT_UNLIKELY(v->AddressSpace() == builtin::AddressSpace::kPushConstant)) {
TINT_ICE(Transform, b.Diagnostics()) TINT_ICE(Transform, b.Diagnostics())
@ -101,14 +101,14 @@ struct ClampFragDepth::State {
// Map of io struct to helper function to return the structure with the depth clamping // Map of io struct to helper function to return the structure with the depth clamping
// applied. // applied.
utils::Hashmap<const ast::Struct*, Symbol, 4u> io_structs_clamp_helpers; utils::Hashmap<const Struct*, Symbol, 4u> io_structs_clamp_helpers;
// Register a callback that will be called for each visted AST function. // Register a callback that will be called for each visted AST function.
// This call wraps the cloning of the function's statements, and will assign to // This call wraps the cloning of the function's statements, and will assign to
// `returns_frag_depth_as_value` or `returns_frag_depth_as_struct_helper` if the function's // `returns_frag_depth_as_value` or `returns_frag_depth_as_struct_helper` if the function's
// return value requires depth clamping. // return value requires depth clamping.
ctx.ReplaceAll([&](const ast::Function* fn) { ctx.ReplaceAll([&](const Function* fn) {
if (fn->PipelineStage() != ast::PipelineStage::kFragment) { if (fn->PipelineStage() != PipelineStage::kFragment) {
return ctx.CloneWithoutTransform(fn); return ctx.CloneWithoutTransform(fn);
} }
@ -129,9 +129,9 @@ struct ClampFragDepth::State {
auto fn_sym = auto fn_sym =
b.Symbols().New("clamp_frag_depth_" + struct_ty->name->symbol.Name()); b.Symbols().New("clamp_frag_depth_" + struct_ty->name->symbol.Name());
utils::Vector<const ast::Expression*, 8u> initializer_args; utils::Vector<const Expression*, 8u> initializer_args;
for (auto* member : struct_ty->members) { for (auto* member : struct_ty->members) {
const ast::Expression* arg = const Expression* arg =
b.MemberAccessor("s", ctx.Clone(member->name->symbol)); b.MemberAccessor("s", ctx.Clone(member->name->symbol));
if (ContainsFragDepth(member->attributes)) { if (ContainsFragDepth(member->attributes)) {
arg = b.Call(base_fn_sym, arg); arg = b.Call(base_fn_sym, arg);
@ -154,7 +154,7 @@ struct ClampFragDepth::State {
}); });
// Replace the return statements `return expr` with `return clamp_frag_depth(expr)`. // Replace the return statements `return expr` with `return clamp_frag_depth(expr)`.
ctx.ReplaceAll([&](const ast::ReturnStatement* stmt) -> const ast::ReturnStatement* { ctx.ReplaceAll([&](const ReturnStatement* stmt) -> const ReturnStatement* {
if (returns_frag_depth_as_value) { if (returns_frag_depth_as_value) {
return b.Return(stmt->source, b.Call(base_fn_sym, ctx.Clone(stmt->value))); return b.Return(stmt->source, b.Call(base_fn_sym, ctx.Clone(stmt->value)));
} }
@ -173,7 +173,7 @@ struct ClampFragDepth::State {
/// @returns true if the transform should run /// @returns true if the transform should run
bool ShouldRun() { bool ShouldRun() {
for (auto* fn : src->AST().Functions()) { for (auto* fn : src->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kFragment && if (fn->PipelineStage() == PipelineStage::kFragment &&
(ReturnsFragDepthAsValue(fn) || ReturnsFragDepthInStruct(fn))) { (ReturnsFragDepthAsValue(fn) || ReturnsFragDepthInStruct(fn))) {
return true; return true;
} }
@ -183,9 +183,9 @@ struct ClampFragDepth::State {
} }
/// @param attrs the attributes to examine /// @param attrs the attributes to examine
/// @returns true if @p attrs contains a `@builtin(frag_depth)` attribute /// @returns true if @p attrs contains a `@builtin(frag_depth)` attribute
bool ContainsFragDepth(utils::VectorRef<const ast::Attribute*> attrs) { bool ContainsFragDepth(utils::VectorRef<const Attribute*> attrs) {
for (auto* attribute : attrs) { for (auto* attribute : attrs) {
if (auto* builtin_attr = attribute->As<ast::BuiltinAttribute>()) { if (auto* builtin_attr = attribute->As<BuiltinAttribute>()) {
auto builtin = sem.Get(builtin_attr)->Value(); auto builtin = sem.Get(builtin_attr)->Value();
if (builtin == builtin::BuiltinValue::kFragDepth) { if (builtin == builtin::BuiltinValue::kFragDepth) {
return true; return true;
@ -198,14 +198,14 @@ struct ClampFragDepth::State {
/// @param fn the function to examine /// @param fn the function to examine
/// @returns true if @p fn has a return type with a `@builtin(frag_depth)` attribute /// @returns true if @p fn has a return type with a `@builtin(frag_depth)` attribute
bool ReturnsFragDepthAsValue(const ast::Function* fn) { bool ReturnsFragDepthAsValue(const Function* fn) {
return ContainsFragDepth(fn->return_type_attributes); return ContainsFragDepth(fn->return_type_attributes);
} }
/// @param fn the function to examine /// @param fn the function to examine
/// @returns true if @p fn has a return structure with a `@builtin(frag_depth)` attribute on one /// @returns true if @p fn has a return structure with a `@builtin(frag_depth)` attribute on one
/// of the members /// of the members
bool ReturnsFragDepthInStruct(const ast::Function* fn) { bool ReturnsFragDepthInStruct(const Function* fn) {
if (auto* struct_ty = sem.Get(fn)->ReturnType()->As<sem::Struct>()) { if (auto* struct_ty = sem.Get(fn)->ReturnType()->As<sem::Struct>()) {
for (auto* member : struct_ty->Members()) { for (auto* member : struct_ty->Members()) {
if (ContainsFragDepth(member->Declaration()->attributes)) { if (ContainsFragDepth(member->Declaration()->attributes)) {

View File

@ -61,7 +61,7 @@ struct CombineSamplers::State {
/// Map from a texture/sampler pair to the corresponding combined sampler /// Map from a texture/sampler pair to the corresponding combined sampler
/// variable /// variable
using CombinedTextureSamplerMap = std::unordered_map<sem::VariablePair, const ast::Variable*>; using CombinedTextureSamplerMap = std::unordered_map<sem::VariablePair, const Variable*>;
/// Use sem::BindingPoint without scope. /// Use sem::BindingPoint without scope.
using BindingPoint = sem::BindingPoint; using BindingPoint = sem::BindingPoint;
@ -79,15 +79,14 @@ struct CombineSamplers::State {
/// references (one comparison sampler, one regular). These are also used as /// references (one comparison sampler, one regular). These are also used as
/// temporary sampler parameters to the texture builtins to satisfy the WGSL /// temporary sampler parameters to the texture builtins to satisfy the WGSL
/// resolver, but are then ignored and removed by the GLSL writer. /// resolver, but are then ignored and removed by the GLSL writer.
const ast::Variable* placeholder_samplers_[2] = {}; const Variable* placeholder_samplers_[2] = {};
/// Group and binding attributes used by all combined sampler globals. /// Group and binding attributes used by all combined sampler globals.
/// Group 0 and binding 0 are used, with collisions disabled. /// Group 0 and binding 0 are used, with collisions disabled.
/// @returns the newly-created attribute list /// @returns the newly-created attribute list
auto Attributes() const { auto Attributes() const {
utils::Vector<const ast::Attribute*, 3> attributes{ctx.dst->Group(0_a), utils::Vector<const Attribute*, 3> attributes{ctx.dst->Group(0_a), ctx.dst->Binding(0_a)};
ctx.dst->Binding(0_a)}; attributes.Push(ctx.dst->Disable(DisabledValidation::kBindingPointCollision));
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision));
return attributes; return attributes;
} }
@ -103,7 +102,7 @@ struct CombineSamplers::State {
/// @param sampler_var the sampler (global) variable /// @param sampler_var the sampler (global) variable
/// @param name the default name to use (may be overridden by map lookup) /// @param name the default name to use (may be overridden by map lookup)
/// @returns the newly-created global variable /// @returns the newly-created global variable
const ast::Variable* CreateCombinedGlobal(const sem::Variable* texture_var, const Variable* CreateCombinedGlobal(const sem::Variable* texture_var,
const sem::Variable* sampler_var, const sem::Variable* sampler_var,
std::string name) { std::string name) {
SamplerTexturePair bp_pair; SamplerTexturePair bp_pair;
@ -115,7 +114,7 @@ struct CombineSamplers::State {
if (it != binding_info->binding_map.end()) { if (it != binding_info->binding_map.end()) {
name = it->second; name = it->second;
} }
ast::Type type = CreateCombinedASTTypeFor(texture_var, sampler_var); Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
Symbol symbol = ctx.dst->Symbols().New(name); Symbol symbol = ctx.dst->Symbols().New(name);
return ctx.dst->GlobalVar(symbol, type, Attributes()); return ctx.dst->GlobalVar(symbol, type, Attributes());
} }
@ -123,8 +122,8 @@ struct CombineSamplers::State {
/// Creates placeholder global sampler variables. /// Creates placeholder global sampler variables.
/// @param kind the sampler kind to create for /// @param kind the sampler kind to create for
/// @returns the newly-created global variable /// @returns the newly-created global variable
const ast::Variable* CreatePlaceholder(type::SamplerKind kind) { const Variable* CreatePlaceholder(type::SamplerKind kind) {
ast::Type type = ctx.dst->ty.sampler(kind); Type type = ctx.dst->ty.sampler(kind);
const char* name = kind == type::SamplerKind::kComparisonSampler const char* name = kind == type::SamplerKind::kComparisonSampler
? "placeholder_comparison_sampler" ? "placeholder_comparison_sampler"
: "placeholder_sampler"; : "placeholder_sampler";
@ -132,13 +131,13 @@ struct CombineSamplers::State {
return ctx.dst->GlobalVar(symbol, type, Attributes()); return ctx.dst->GlobalVar(symbol, type, Attributes());
} }
/// Creates ast::Identifier for a given texture and sampler variable pair. /// Creates Identifier for a given texture and sampler variable pair.
/// Depth textures with no samplers are turned into the corresponding /// Depth textures with no samplers are turned into the corresponding
/// f32 texture (e.g., texture_depth_2d -> texture_2d<f32>). /// f32 texture (e.g., texture_depth_2d -> texture_2d<f32>).
/// @param texture the texture variable of interest /// @param texture the texture variable of interest
/// @param sampler the texture variable of interest /// @param sampler the texture variable of interest
/// @returns the newly-created type /// @returns the newly-created type
ast::Type CreateCombinedASTTypeFor(const sem::Variable* texture, const sem::Variable* sampler) { Type CreateCombinedASTTypeFor(const sem::Variable* texture, const sem::Variable* sampler) {
const type::Type* texture_type = texture->Type()->UnwrapRef(); const type::Type* texture_type = texture->Type()->UnwrapRef();
const type::DepthTexture* depth = texture_type->As<type::DepthTexture>(); const type::DepthTexture* depth = texture_type->As<type::DepthTexture>();
if (depth && !sampler) { if (depth && !sampler) {
@ -163,8 +162,7 @@ struct CombineSamplers::State {
ctx.Remove(ctx.src->AST().GlobalDeclarations(), global); ctx.Remove(ctx.src->AST().GlobalDeclarations(), global);
} else if (auto binding_point = global_sem->BindingPoint()) { } else if (auto binding_point = global_sem->BindingPoint()) {
if (binding_point->group == 0 && binding_point->binding == 0) { if (binding_point->group == 0 && binding_point->binding == 0) {
auto* attribute = auto* attribute = ctx.dst->Disable(DisabledValidation::kBindingPointCollision);
ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
ctx.InsertFront(global->attributes, attribute); ctx.InsertFront(global->attributes, attribute);
} }
} }
@ -172,13 +170,13 @@ struct CombineSamplers::State {
// Rewrite all function signatures to use combined samplers, and remove // Rewrite all function signatures to use combined samplers, and remove
// separate textures & samplers. Create new combined globals where found. // separate textures & samplers. Create new combined globals where found.
ctx.ReplaceAll([&](const ast::Function* ast_fn) -> const ast::Function* { ctx.ReplaceAll([&](const Function* ast_fn) -> const Function* {
if (auto* fn = sem.Get(ast_fn)) { if (auto* fn = sem.Get(ast_fn)) {
auto pairs = fn->TextureSamplerPairs(); auto pairs = fn->TextureSamplerPairs();
if (pairs.IsEmpty()) { if (pairs.IsEmpty()) {
return nullptr; return nullptr;
} }
utils::Vector<const ast::Parameter*, 8> params; utils::Vector<const Parameter*, 8> params;
for (auto pair : fn->TextureSamplerPairs()) { for (auto pair : fn->TextureSamplerPairs()) {
const sem::Variable* texture_var = pair.first; const sem::Variable* texture_var = pair.first;
const sem::Variable* sampler_var = pair.second; const sem::Variable* sampler_var = pair.second;
@ -195,7 +193,7 @@ struct CombineSamplers::State {
} else { } else {
// Either texture or sampler (or both) is a function parameter; // Either texture or sampler (or both) is a function parameter;
// add a new function parameter to represent the combined sampler. // add a new function parameter to represent the combined sampler.
ast::Type type = CreateCombinedASTTypeFor(texture_var, sampler_var); Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type); auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type);
params.Push(var); params.Push(var);
function_combined_texture_samplers_[fn][pair] = var; function_combined_texture_samplers_[fn][pair] = var;
@ -215,7 +213,7 @@ struct CombineSamplers::State {
auto* body = ctx.Clone(ast_fn->body); auto* body = ctx.Clone(ast_fn->body);
auto attributes = ctx.Clone(ast_fn->attributes); auto attributes = ctx.Clone(ast_fn->attributes);
auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes); auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes);
return ctx.dst->create<ast::Function>(name, params, return_type, body, return ctx.dst->create<Function>(name, params, return_type, body,
std::move(attributes), std::move(attributes),
std::move(return_type_attributes)); std::move(return_type_attributes));
} }
@ -225,9 +223,9 @@ struct CombineSamplers::State {
// Replace all function call expressions containing texture or // Replace all function call expressions containing texture or
// sampler parameters to use the current function's combined samplers or // sampler parameters to use the current function's combined samplers or
// the combined global samplers, as appropriate. // the combined global samplers, as appropriate.
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::Expression* { ctx.ReplaceAll([&](const CallExpression* expr) -> const Expression* {
if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) { if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
utils::Vector<const ast::Expression*, 8> args; utils::Vector<const Expression*, 8> args;
// Replace all texture builtin calls. // Replace all texture builtin calls.
if (auto* builtin = call->Target()->As<sem::Builtin>()) { if (auto* builtin = call->Target()->As<sem::Builtin>()) {
const auto& signature = builtin->Signature(); const auto& signature = builtin->Signature();
@ -254,7 +252,7 @@ struct CombineSamplers::State {
for (auto* arg : expr->args) { for (auto* arg : expr->args) {
auto* type = ctx.src->TypeOf(arg)->UnwrapRef(); auto* type = ctx.src->TypeOf(arg)->UnwrapRef();
if (type->Is<type::Texture>()) { if (type->Is<type::Texture>()) {
const ast::Variable* var = const Variable* var =
IsGlobal(new_pair) IsGlobal(new_pair)
? global_combined_texture_samplers_[new_pair] ? global_combined_texture_samplers_[new_pair]
: function_combined_texture_samplers_[call->Stmt()->Function()] : function_combined_texture_samplers_[call->Stmt()->Function()]
@ -263,7 +261,7 @@ struct CombineSamplers::State {
} else if (auto* sampler_type = type->As<type::Sampler>()) { } else if (auto* sampler_type = type->As<type::Sampler>()) {
type::SamplerKind kind = sampler_type->kind(); type::SamplerKind kind = sampler_type->kind();
int index = (kind == type::SamplerKind::kSampler) ? 0 : 1; int index = (kind == type::SamplerKind::kSampler) ? 0 : 1;
const ast::Variable*& p = placeholder_samplers_[index]; const Variable*& p = placeholder_samplers_[index];
if (!p) { if (!p) {
p = CreatePlaceholder(kind); p = CreatePlaceholder(kind);
} }
@ -272,10 +270,10 @@ struct CombineSamplers::State {
args.Push(ctx.Clone(arg)); args.Push(ctx.Clone(arg));
} }
} }
const ast::Expression* value = ctx.dst->Call(ctx.Clone(expr->target), args); const Expression* value = ctx.dst->Call(ctx.Clone(expr->target), args);
if (builtin->Type() == builtin::Function::kTextureLoad && if (builtin->Type() == builtin::Function::kTextureLoad &&
texture_var->Type()->UnwrapRef()->Is<type::DepthTexture>() && texture_var->Type()->UnwrapRef()->Is<type::DepthTexture>() &&
!call->Stmt()->Declaration()->Is<ast::CallStatement>()) { !call->Stmt()->Declaration()->Is<CallStatement>()) {
value = ctx.dst->MemberAccessor(value, "x"); value = ctx.dst->MemberAccessor(value, "x");
} }
return value; return value;
@ -307,7 +305,7 @@ struct CombineSamplers::State {
// If both texture and sampler are (now) global, pass that // If both texture and sampler are (now) global, pass that
// global variable to the callee. Otherwise use the caller's // global variable to the callee. Otherwise use the caller's
// function parameter for this pair. // function parameter for this pair.
const ast::Variable* var = const Variable* var =
IsGlobal(new_pair) IsGlobal(new_pair)
? global_combined_texture_samplers_[new_pair] ? global_combined_texture_samplers_[new_pair]
: function_combined_texture_samplers_[call->Stmt()->Function()] : function_combined_texture_samplers_[call->Stmt()->Function()]

View File

@ -60,21 +60,21 @@ bool ShouldRun(const Program* program) {
return false; return false;
} }
/// Offset is a simple ast::Expression builder interface, used to build byte /// Offset is a simple Expression builder interface, used to build byte
/// offsets for storage and uniform buffer accesses. /// offsets for storage and uniform buffer accesses.
struct Offset : utils::Castable<Offset> { struct Offset : utils::Castable<Offset> {
/// @returns builds and returns the ast::Expression in `ctx.dst` /// @returns builds and returns the Expression in `ctx.dst`
virtual const ast::Expression* Build(CloneContext& ctx) const = 0; virtual const Expression* Build(CloneContext& ctx) const = 0;
}; };
/// OffsetExpr is an implementation of Offset that clones and casts the given /// OffsetExpr is an implementation of Offset that clones and casts the given
/// expression to `u32`. /// expression to `u32`.
struct OffsetExpr : Offset { struct OffsetExpr : Offset {
const ast::Expression* const expr = nullptr; const Expression* const expr = nullptr;
explicit OffsetExpr(const ast::Expression* e) : expr(e) {} explicit OffsetExpr(const Expression* e) : expr(e) {}
const ast::Expression* Build(CloneContext& ctx) const override { const Expression* Build(CloneContext& ctx) const override {
auto* type = ctx.src->Sem().GetVal(expr)->Type()->UnwrapRef(); auto* type = ctx.src->Sem().GetVal(expr)->Type()->UnwrapRef();
auto* res = ctx.Clone(expr); auto* res = ctx.Clone(expr);
if (!type->Is<type::U32>()) { if (!type->Is<type::U32>()) {
@ -91,7 +91,7 @@ struct OffsetLiteral final : utils::Castable<OffsetLiteral, Offset> {
explicit OffsetLiteral(uint32_t lit) : literal(lit) {} explicit OffsetLiteral(uint32_t lit) : literal(lit) {}
const ast::Expression* Build(CloneContext& ctx) const override { const Expression* Build(CloneContext& ctx) const override {
return ctx.dst->Expr(u32(literal)); return ctx.dst->Expr(u32(literal));
} }
}; };
@ -99,12 +99,12 @@ struct OffsetLiteral final : utils::Castable<OffsetLiteral, Offset> {
/// OffsetBinOp is an implementation of Offset that constructs a binary-op of /// OffsetBinOp is an implementation of Offset that constructs a binary-op of
/// two Offsets. /// two Offsets.
struct OffsetBinOp : Offset { struct OffsetBinOp : Offset {
ast::BinaryOp op; BinaryOp op;
Offset const* lhs = nullptr; Offset const* lhs = nullptr;
Offset const* rhs = nullptr; Offset const* rhs = nullptr;
const ast::Expression* Build(CloneContext& ctx) const override { const Expression* Build(CloneContext& ctx) const override {
return ctx.dst->create<ast::BinaryExpression>(op, lhs->Build(ctx), rhs->Build(ctx)); return ctx.dst->create<BinaryExpression>(op, lhs->Build(ctx), rhs->Build(ctx));
} }
}; };
@ -313,7 +313,7 @@ struct BufferAccess {
/// Store describes a single storage or uniform buffer write /// Store describes a single storage or uniform buffer write
struct Store { struct Store {
const ast::AssignmentStatement* assignment; // The AST assignment statement const AssignmentStatement* assignment; // The AST assignment statement
BufferAccess target; // The target for the write BufferAccess target; // The target for the write
}; };
@ -330,9 +330,9 @@ struct DecomposeMemoryAccess::State {
/// expressions chain the access. /// expressions chain the access.
/// Subset of #expression_order, as expressions are not removed from /// Subset of #expression_order, as expressions are not removed from
/// #expression_order. /// #expression_order.
std::unordered_map<const ast::Expression*, BufferAccess> accesses; std::unordered_map<const Expression*, BufferAccess> accesses;
/// The visited order of AST expressions (superset of #accesses) /// The visited order of AST expressions (superset of #accesses)
std::vector<const ast::Expression*> expression_order; std::vector<const Expression*> expression_order;
/// [buffer-type, element-type] -> load function name /// [buffer-type, element-type] -> load function name
std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs; std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs;
/// [buffer-type, element-type] -> store function name /// [buffer-type, element-type] -> store function name
@ -353,9 +353,9 @@ struct DecomposeMemoryAccess::State {
const Offset* ToOffset(uint32_t offset) { return offsets_.Create<OffsetLiteral>(offset); } const Offset* ToOffset(uint32_t offset) { return offsets_.Create<OffsetLiteral>(offset); }
/// @param expr the expression to convert to an Offset /// @param expr the expression to convert to an Offset
/// @returns an Offset for the given ast::Expression /// @returns an Offset for the given Expression
const Offset* ToOffset(const ast::Expression* expr) { const Offset* ToOffset(const Expression* expr) {
if (auto* lit = expr->As<ast::IntLiteralExpression>()) { if (auto* lit = expr->As<IntLiteralExpression>()) {
if (lit->value >= 0) { if (lit->value >= 0) {
return offsets_.Create<OffsetLiteral>(static_cast<uint32_t>(lit->value)); return offsets_.Create<OffsetLiteral>(static_cast<uint32_t>(lit->value));
} }
@ -390,7 +390,7 @@ struct DecomposeMemoryAccess::State {
} }
} }
auto* out = offsets_.Create<OffsetBinOp>(); auto* out = offsets_.Create<OffsetBinOp>();
out->op = ast::BinaryOp::kAdd; out->op = BinaryOp::kAdd;
out->lhs = lhs; out->lhs = lhs;
out->rhs = rhs; out->rhs = rhs;
return out; return out;
@ -422,7 +422,7 @@ struct DecomposeMemoryAccess::State {
return offsets_.Create<OffsetLiteral>(lhs_lit->literal * rhs_lit->literal); return offsets_.Create<OffsetLiteral>(lhs_lit->literal * rhs_lit->literal);
} }
auto* out = offsets_.Create<OffsetBinOp>(); auto* out = offsets_.Create<OffsetBinOp>();
out->op = ast::BinaryOp::kMultiply; out->op = BinaryOp::kMultiply;
out->lhs = lhs; out->lhs = lhs;
out->rhs = rhs; out->rhs = rhs;
return out; return out;
@ -432,7 +432,7 @@ struct DecomposeMemoryAccess::State {
/// to #expression_order. /// to #expression_order.
/// @param expr the expression that performs the access /// @param expr the expression that performs the access
/// @param access the access /// @param access the access
void AddAccess(const ast::Expression* expr, const BufferAccess& access) { void AddAccess(const Expression* expr, const BufferAccess& access) {
TINT_ASSERT(Transform, access.type); TINT_ASSERT(Transform, access.type);
accesses.emplace(expr, access); accesses.emplace(expr, access);
expression_order.emplace_back(expr); expression_order.emplace_back(expr);
@ -443,7 +443,7 @@ struct DecomposeMemoryAccess::State {
/// `node`, an invalid BufferAccess is returned. /// `node`, an invalid BufferAccess is returned.
/// @param node the expression that performed an access /// @param node the expression that performed an access
/// @return the BufferAccess for the given expression /// @return the BufferAccess for the given expression
BufferAccess TakeAccess(const ast::Expression* node) { BufferAccess TakeAccess(const Expression* node) {
auto lhs_it = accesses.find(node); auto lhs_it = accesses.find(node);
if (lhs_it == accesses.end()) { if (lhs_it == accesses.end()) {
return {}; return {};
@ -475,7 +475,7 @@ struct DecomposeMemoryAccess::State {
b.Func(name, params, el_ast_ty, nullptr, b.Func(name, params, el_ast_ty, nullptr,
utils::Vector{ utils::Vector{
intrinsic, intrinsic,
b.Disable(ast::DisabledValidation::kFunctionHasNoBody), b.Disable(DisabledValidation::kFunctionHasNoBody),
}); });
} else if (auto* arr_ty = el_ty->As<type::Array>()) { } else if (auto* arr_ty = el_ty->As<type::Array>()) {
// fn load_func(buffer : buf_ty, offset : u32) -> array<T, N> { // fn load_func(buffer : buf_ty, offset : u32) -> array<T, N> {
@ -498,8 +498,8 @@ struct DecomposeMemoryAccess::State {
TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count"; TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count";
arr_cnt = 1; arr_cnt = 1;
} }
auto* for_cond = b.create<ast::BinaryExpression>( auto* for_cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr(i),
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_cnt.value()))); b.Expr(u32(arr_cnt.value())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u)); auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(arr, i); auto* arr_el = b.IndexAccessor(arr, i);
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride()))); auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
@ -514,7 +514,7 @@ struct DecomposeMemoryAccess::State {
b.Return(arr), b.Return(arr),
}); });
} else { } else {
utils::Vector<const ast::Expression*, 8> values; utils::Vector<const Expression*, 8> values;
if (auto* mat_ty = el_ty->As<type::Matrix>()) { if (auto* mat_ty = el_ty->As<type::Matrix>()) {
auto* vec_ty = mat_ty->ColumnType(); auto* vec_ty = mat_ty->ColumnType();
Symbol load = LoadFunc(vec_ty, address_space, buffer); Symbol load = LoadFunc(vec_ty, address_space, buffer);
@ -557,10 +557,10 @@ struct DecomposeMemoryAccess::State {
b.Func(name, params, b.ty.void_(), nullptr, b.Func(name, params, b.ty.void_(), nullptr,
utils::Vector{ utils::Vector{
intrinsic, intrinsic,
b.Disable(ast::DisabledValidation::kFunctionHasNoBody), b.Disable(DisabledValidation::kFunctionHasNoBody),
}); });
} else { } else {
auto body = Switch<utils::Vector<const ast::Statement*, 8>>( auto body = Switch<utils::Vector<const Statement*, 8>>(
el_ty, // el_ty, //
[&](const type::Array* arr_ty) { [&](const type::Array* arr_ty) {
// fn store_func(buffer : buf_ty, offset : u32, value : el_ty) { // fn store_func(buffer : buf_ty, offset : u32, value : el_ty) {
@ -585,8 +585,8 @@ struct DecomposeMemoryAccess::State {
<< "unexpected non-constant array count"; << "unexpected non-constant array count";
arr_cnt = 1; arr_cnt = 1;
} }
auto* for_cond = b.create<ast::BinaryExpression>( auto* for_cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr(i),
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_cnt.value()))); b.Expr(u32(arr_cnt.value())));
auto* for_cont = b.Assign(i, b.Add(i, 1_u)); auto* for_cont = b.Assign(i, b.Add(i, 1_u));
auto* arr_el = b.IndexAccessor(array, i); auto* arr_el = b.IndexAccessor(array, i);
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride()))); auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
@ -598,7 +598,7 @@ struct DecomposeMemoryAccess::State {
[&](const type::Matrix* mat_ty) { [&](const type::Matrix* mat_ty) {
auto* vec_ty = mat_ty->ColumnType(); auto* vec_ty = mat_ty->ColumnType();
Symbol store = StoreFunc(vec_ty, buffer); Symbol store = StoreFunc(vec_ty, buffer);
utils::Vector<const ast::Statement*, 4> stmts; utils::Vector<const Statement*, 4> stmts;
for (uint32_t i = 0; i < mat_ty->columns(); i++) { for (uint32_t i = 0; i < mat_ty->columns(); i++) {
auto* offset = b.Add("offset", u32(i * mat_ty->ColumnStride())); auto* offset = b.Add("offset", u32(i * mat_ty->ColumnStride()));
auto* element = b.IndexAccessor("value", u32(i)); auto* element = b.IndexAccessor("value", u32(i));
@ -608,7 +608,7 @@ struct DecomposeMemoryAccess::State {
return stmts; return stmts;
}, },
[&](const type::Struct* str) { [&](const type::Struct* str) {
utils::Vector<const ast::Statement*, 8> stmts; utils::Vector<const Statement*, 8> stmts;
for (auto* member : str->Members()) { for (auto* member : str->Members()) {
auto* offset = b.Add("offset", u32(member->Offset())); auto* offset = b.Add("offset", u32(member->Offset()));
auto* element = b.MemberAccessor("value", ctx.Clone(member->Name())); auto* element = b.MemberAccessor("value", ctx.Clone(member->Name()));
@ -656,14 +656,14 @@ struct DecomposeMemoryAccess::State {
<< el_ty->TypeInfo().name; << el_ty->TypeInfo().name;
} }
ast::Type ret_ty; Type ret_ty;
// For intrinsics that return a struct, there is no AST node for it, so create one now. // For intrinsics that return a struct, there is no AST node for it, so create one now.
if (intrinsic->Type() == builtin::Function::kAtomicCompareExchangeWeak) { if (intrinsic->Type() == builtin::Function::kAtomicCompareExchangeWeak) {
auto* str = intrinsic->ReturnType()->As<type::Struct>(); auto* str = intrinsic->ReturnType()->As<type::Struct>();
TINT_ASSERT(Transform, str); TINT_ASSERT(Transform, str);
utils::Vector<const ast::StructMember*, 8> ast_members; utils::Vector<const StructMember*, 8> ast_members;
ast_members.Reserve(str->Members().Length()); ast_members.Reserve(str->Members().Length());
for (auto& m : str->Members()) { for (auto& m : str->Members()) {
ast_members.Push( ast_members.Push(
@ -681,7 +681,7 @@ struct DecomposeMemoryAccess::State {
b.Func(name, std::move(params), ret_ty, nullptr, b.Func(name, std::move(params), ret_ty, nullptr,
utils::Vector{ utils::Vector{
atomic, atomic,
b.Disable(ast::DisabledValidation::kFunctionHasNoBody), b.Disable(DisabledValidation::kFunctionHasNoBody),
}); });
return name; return name;
}); });
@ -689,11 +689,11 @@ struct DecomposeMemoryAccess::State {
}; };
DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID pid, DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID pid,
ast::NodeID nid, NodeID nid,
Op o, Op o,
DataType ty, DataType ty,
builtin::AddressSpace as, builtin::AddressSpace as,
const ast::IdentifierExpression* buf) const IdentifierExpression* buf)
: Base(pid, nid, utils::Vector{buf}), op(o), type(ty), address_space(as) {} : Base(pid, nid, utils::Vector{buf}), op(o), type(ty), address_space(as) {}
DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default; DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default;
std::string DecomposeMemoryAccess::Intrinsic::InternalName() const { std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
@ -804,7 +804,7 @@ bool DecomposeMemoryAccess::Intrinsic::IsAtomic() const {
return op != Op::kLoad && op != Op::kStore; return op != Op::kLoad && op != Op::kStore;
} }
const ast::IdentifierExpression* DecomposeMemoryAccess::Intrinsic::Buffer() const { const IdentifierExpression* DecomposeMemoryAccess::Intrinsic::Buffer() const {
return dependencies[0]; return dependencies[0];
} }
@ -832,7 +832,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
// nodes are fully immutable and require their children to be constructed // nodes are fully immutable and require their children to be constructed
// first so their pointer can be passed to the parent's initializer. // first so their pointer can be passed to the parent's initializer.
for (auto* node : src->ASTNodes().Objects()) { for (auto* node : src->ASTNodes().Objects()) {
if (auto* ident = node->As<ast::IdentifierExpression>()) { if (auto* ident = node->As<IdentifierExpression>()) {
// X // X
if (auto* sem_ident = sem.GetVal(ident)) { if (auto* sem_ident = sem.GetVal(ident)) {
if (auto* user = sem_ident->UnwrapLoad()->As<sem::VariableUser>()) { if (auto* user = sem_ident->UnwrapLoad()->As<sem::VariableUser>()) {
@ -852,7 +852,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
continue; continue;
} }
if (auto* accessor = node->As<ast::MemberAccessorExpression>()) { if (auto* accessor = node->As<MemberAccessorExpression>()) {
// X.Y // X.Y
auto* accessor_sem = sem.Get(accessor)->UnwrapLoad(); auto* accessor_sem = sem.Get(accessor)->UnwrapLoad();
if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) { if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
@ -882,7 +882,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
continue; continue;
} }
if (auto* accessor = node->As<ast::IndexAccessorExpression>()) { if (auto* accessor = node->As<IndexAccessorExpression>()) {
if (auto access = state.TakeAccess(accessor->object)) { if (auto access = state.TakeAccess(accessor->object)) {
// X[Y] // X[Y]
if (auto* arr = access.type->As<type::Array>()) { if (auto* arr = access.type->As<type::Array>()) {
@ -915,8 +915,8 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
} }
} }
if (auto* op = node->As<ast::UnaryOpExpression>()) { if (auto* op = node->As<UnaryOpExpression>()) {
if (op->op == ast::UnaryOp::kAddressOf) { if (op->op == UnaryOp::kAddressOf) {
// &X // &X
if (auto access = state.TakeAccess(op->expr)) { if (auto access = state.TakeAccess(op->expr)) {
// HLSL does not support pointers, so just take the access from the // HLSL does not support pointers, so just take the access from the
@ -927,7 +927,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
} }
} }
if (auto* assign = node->As<ast::AssignmentStatement>()) { if (auto* assign = node->As<AssignmentStatement>()) {
// X = Y // X = Y
// Move the LHS access to a store. // Move the LHS access to a store.
if (auto lhs = state.TakeAccess(assign->lhs)) { if (auto lhs = state.TakeAccess(assign->lhs)) {
@ -935,7 +935,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
} }
} }
if (auto* call_expr = node->As<ast::CallExpression>()) { if (auto* call_expr = node->As<CallExpression>()) {
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>(); auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
if (auto* builtin = call->Target()->As<sem::Builtin>()) { if (auto* builtin = call->Target()->As<sem::Builtin>()) {
if (builtin->Type() == builtin::Function::kArrayLength) { if (builtin->Type() == builtin::Function::kArrayLength) {
@ -953,7 +953,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
auto buffer = ctx.Clone(access.var->Declaration()->name->symbol); auto buffer = ctx.Clone(access.var->Declaration()->name->symbol);
Symbol func = state.AtomicFunc(el_ty, builtin, buffer); Symbol func = state.AtomicFunc(el_ty, builtin, buffer);
utils::Vector<const ast::Expression*, 8> args{offset}; utils::Vector<const Expression*, 8> args{offset};
for (size_t i = 1; i < call_expr->args.Length(); i++) { for (size_t i = 1; i < call_expr->args.Length(); i++) {
auto* arg = call_expr->args[i]; auto* arg = call_expr->args[i];
args.Push(ctx.Clone(arg)); args.Push(ctx.Clone(arg));

View File

@ -35,7 +35,7 @@ class DecomposeMemoryAccess final : public utils::Castable<DecomposeMemoryAccess
/// transforms this into calls to /// transforms this into calls to
/// `[RW]ByteAddressBuffer.Load[N]()` or `[RW]ByteAddressBuffer.Store[N]()`, /// `[RW]ByteAddressBuffer.Load[N]()` or `[RW]ByteAddressBuffer.Store[N]()`,
/// with a possible cast. /// with a possible cast.
class Intrinsic final : public utils::Castable<Intrinsic, ast::InternalAttribute> { class Intrinsic final : public utils::Castable<Intrinsic, InternalAttribute> {
public: public:
/// Intrinsic op /// Intrinsic op
enum class Op { enum class Op {
@ -82,11 +82,11 @@ class DecomposeMemoryAccess final : public utils::Castable<DecomposeMemoryAccess
/// @param address_space the address space of the buffer /// @param address_space the address space of the buffer
/// @param buffer the storage or uniform buffer identifier /// @param buffer the storage or uniform buffer identifier
Intrinsic(ProgramID pid, Intrinsic(ProgramID pid,
ast::NodeID nid, NodeID nid,
Op o, Op o,
DataType type, DataType type,
builtin::AddressSpace address_space, builtin::AddressSpace address_space,
const ast::IdentifierExpression* buffer); const IdentifierExpression* buffer);
/// Destructor /// Destructor
~Intrinsic() override; ~Intrinsic() override;
@ -103,7 +103,7 @@ class DecomposeMemoryAccess final : public utils::Castable<DecomposeMemoryAccess
bool IsAtomic() const; bool IsAtomic() const;
/// @return the buffer that this intrinsic operates on /// @return the buffer that this intrinsic operates on
const ast::IdentifierExpression* Buffer() const; const IdentifierExpression* Buffer() const;
/// The op of the intrinsic /// The op of the intrinsic
const Op op; const Op op;

View File

@ -37,8 +37,8 @@ using DecomposedArrays = std::unordered_map<const type::Array*, Symbol>;
bool ShouldRun(const Program* program) { bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) { for (auto* node : program->ASTNodes().Objects()) {
if (auto* ident = node->As<ast::TemplatedIdentifier>()) { if (auto* ident = node->As<TemplatedIdentifier>()) {
if (ast::GetAttribute<ast::StrideAttribute>(ident->attributes)) { if (GetAttribute<StrideAttribute>(ident->attributes)) {
return true; return true;
} }
} }
@ -74,8 +74,8 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
// stride for the array element type, then replace the array element type with // stride for the array element type, then replace the array element type with
// a structure, holding a single field with a @size attribute equal to the // a structure, holding a single field with a @size attribute equal to the
// array stride. // array stride.
ctx.ReplaceAll([&](const ast::IdentifierExpression* expr) -> const ast::IdentifierExpression* { ctx.ReplaceAll([&](const IdentifierExpression* expr) -> const IdentifierExpression* {
auto* ident = expr->identifier->As<ast::TemplatedIdentifier>(); auto* ident = expr->identifier->As<TemplatedIdentifier>();
if (!ident) { if (!ident) {
return nullptr; return nullptr;
} }
@ -90,8 +90,8 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
if (!arr->IsStrideImplicit()) { if (!arr->IsStrideImplicit()) {
auto el_ty = utils::GetOrCreate(decomposed, arr, [&] { auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
auto name = b.Symbols().New("strided_arr"); auto name = b.Symbols().New("strided_arr");
auto* member_ty = ctx.Clone(ident->arguments[0]->As<ast::IdentifierExpression>()); auto* member_ty = ctx.Clone(ident->arguments[0]->As<IdentifierExpression>());
auto* member = b.Member(kMemberName, ast::Type{member_ty}, auto* member = b.Member(kMemberName, Type{member_ty},
utils::Vector{ utils::Vector{
b.MemberSize(AInt(arr->Stride())), b.MemberSize(AInt(arr->Stride())),
}); });
@ -105,14 +105,14 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
return b.Expr(b.ty.array(b.ty(el_ty))); return b.Expr(b.ty.array(b.ty(el_ty)));
} }
} }
if (ast::GetAttribute<ast::StrideAttribute>(ident->attributes)) { if (GetAttribute<StrideAttribute>(ident->attributes)) {
// Strip the @stride attribute // Strip the @stride attribute
auto* ty = ctx.Clone(ident->arguments[0]->As<ast::IdentifierExpression>()); auto* ty = ctx.Clone(ident->arguments[0]->As<IdentifierExpression>());
if (ident->arguments.Length() > 1) { if (ident->arguments.Length() > 1) {
auto* count = ctx.Clone(ident->arguments[1]); auto* count = ctx.Clone(ident->arguments[1]);
return b.Expr(b.ty.array(ast::Type{ty}, count)); return b.Expr(b.ty.array(Type{ty}, count));
} else { } else {
return b.Expr(b.ty.array(ast::Type{ty})); return b.Expr(b.ty.array(Type{ty}));
} }
} }
return nullptr; return nullptr;
@ -122,7 +122,7 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
// element changed to a single field structure. These expressions are adjusted // element changed to a single field structure. These expressions are adjusted
// to insert an additional member accessor for the single structure field. // to insert an additional member accessor for the single structure field.
// Example: `arr[i]` -> `arr[i].el` // Example: `arr[i]` -> `arr[i].el`
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* { ctx.ReplaceAll([&](const IndexAccessorExpression* idx) -> const Expression* {
if (auto* ty = src->TypeOf(idx->object)) { if (auto* ty = src->TypeOf(idx->object)) {
if (auto* arr = ty->UnwrapRef()->As<type::Array>()) { if (auto* arr = ty->UnwrapRef()->As<type::Array>()) {
if (!arr->IsStrideImplicit()) { if (!arr->IsStrideImplicit()) {
@ -140,7 +140,7 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
// `@stride(32) array<i32, 3>(1, 2, 3)` // `@stride(32) array<i32, 3>(1, 2, 3)`
// -> // ->
// `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))` // `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))`
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::Expression* { ctx.ReplaceAll([&](const CallExpression* expr) -> const Expression* {
if (!expr->args.IsEmpty()) { if (!expr->args.IsEmpty()) {
if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) { if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
if (auto* ctor = call->Target()->As<sem::ValueConstructor>()) { if (auto* ctor = call->Target()->As<sem::ValueConstructor>()) {
@ -153,7 +153,7 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
auto* target = ctx.Clone(expr->target); auto* target = ctx.Clone(expr->target);
utils::Vector<const ast::Expression*, 8> args; utils::Vector<const Expression*, 8> args;
if (auto it = decomposed.find(arr); it != decomposed.end()) { if (auto it = decomposed.find(arr); it != decomposed.end()) {
args.Reserve(expr->args.Length()); args.Reserve(expr->args.Length());
for (auto* arg : expr->args) { for (auto* arg : expr->args) {

View File

@ -101,7 +101,7 @@ TEST_F(DecomposeStridedArrayTest, PrivateDefaultStridedArray) {
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))), b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -145,7 +145,7 @@ TEST_F(DecomposeStridedArrayTest, PrivateStridedArray) {
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))), b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -195,7 +195,7 @@ TEST_F(DecomposeStridedArrayTest, ReadUniformStridedArray) {
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))), b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -253,7 +253,7 @@ TEST_F(DecomposeStridedArrayTest, ReadUniformDefaultStridedArray) {
b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 2_i))), b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 2_i))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -303,7 +303,7 @@ TEST_F(DecomposeStridedArrayTest, ReadStorageStridedArray) {
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))), b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -357,7 +357,7 @@ TEST_F(DecomposeStridedArrayTest, ReadStorageDefaultStridedArray) {
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))), b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -410,7 +410,7 @@ TEST_F(DecomposeStridedArrayTest, WriteStorageStridedArray) {
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f), b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -472,7 +472,7 @@ TEST_F(DecomposeStridedArrayTest, WriteStorageDefaultStridedArray) {
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f), b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -531,7 +531,7 @@ TEST_F(DecomposeStridedArrayTest, ReadWriteViaPointerLets) {
b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), 5_f), b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), 5_f),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -593,7 +593,7 @@ TEST_F(DecomposeStridedArrayTest, PrivateAliasedStridedArray) {
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f), b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -696,7 +696,7 @@ TEST_F(DecomposeStridedArrayTest, PrivateNestedStridedArray) {
5_f), 5_f),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });

View File

@ -38,7 +38,7 @@ struct MatrixInfo {
const type::Matrix* matrix = nullptr; const type::Matrix* matrix = nullptr;
/// @returns the identifier of an array that holds an vector column for each row of the matrix. /// @returns the identifier of an array that holds an vector column for each row of the matrix.
ast::Type array(ProgramBuilder* b) const { Type array(ProgramBuilder* b) const {
return b->ty.array(b->ty.vec<f32>(matrix->rows()), u32(matrix->columns()), return b->ty.array(b->ty.vec<f32>(matrix->rows()), u32(matrix->columns()),
utils::Vector{ utils::Vector{
b->Stride(stride), b->Stride(stride),
@ -72,7 +72,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
// and populate the `decomposed` map with the members that have been replaced. // and populate the `decomposed` map with the members that have been replaced.
utils::Hashmap<const type::StructMember*, MatrixInfo, 8> decomposed; utils::Hashmap<const type::StructMember*, MatrixInfo, 8> decomposed;
for (auto* node : src->ASTNodes().Objects()) { for (auto* node : src->ASTNodes().Objects()) {
if (auto* str = node->As<ast::Struct>()) { if (auto* str = node->As<Struct>()) {
auto* str_ty = src->Sem().Get(str); auto* str_ty = src->Sem().Get(str);
if (!str_ty->UsedAs(builtin::AddressSpace::kUniform) && if (!str_ty->UsedAs(builtin::AddressSpace::kUniform) &&
!str_ty->UsedAs(builtin::AddressSpace::kStorage)) { !str_ty->UsedAs(builtin::AddressSpace::kStorage)) {
@ -83,8 +83,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
if (!matrix) { if (!matrix) {
continue; continue;
} }
auto* attr = auto* attr = GetAttribute<StrideAttribute>(member->Declaration()->attributes);
ast::GetAttribute<ast::StrideAttribute>(member->Declaration()->attributes);
if (!attr) { if (!attr) {
continue; continue;
} }
@ -111,8 +110,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
// preserve these without calling conversion functions. // preserve these without calling conversion functions.
// Example: // Example:
// ssbo.mat[2] -> ssbo.mat[2] // ssbo.mat[2] -> ssbo.mat[2]
ctx.ReplaceAll( ctx.ReplaceAll([&](const IndexAccessorExpression* expr) -> const IndexAccessorExpression* {
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) { if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
if (decomposed.Contains(access->Member())) { if (decomposed.Contains(access->Member())) {
auto* obj = ctx.CloneWithoutTransform(expr->object); auto* obj = ctx.CloneWithoutTransform(expr->object);
@ -129,7 +127,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
// Example: // Example:
// ssbo.mat = mat_to_arr(m) // ssbo.mat = mat_to_arr(m)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr; std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* { ctx.ReplaceAll([&](const AssignmentStatement* stmt) -> const Statement* {
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) { if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
if (auto info = decomposed.Find(access->Member())) { if (auto info = decomposed.Find(access->Member())) {
auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] { auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
@ -142,7 +140,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
auto array = [&] { return info->array(ctx.dst); }; auto array = [&] { return info->array(ctx.dst); };
auto mat = b.Sym("m"); auto mat = b.Sym("m");
utils::Vector<const ast::Expression*, 4> columns; utils::Vector<const Expression*, 4> columns;
for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) { for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
columns.Push(b.IndexAccessor(mat, u32(i))); columns.Push(b.IndexAccessor(mat, u32(i)));
} }
@ -168,7 +166,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
// matrix type. Example: // matrix type. Example:
// m = arr_to_mat(ssbo.mat) // m = arr_to_mat(ssbo.mat)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat; std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* { ctx.ReplaceAll([&](const MemberAccessorExpression* expr) -> const Expression* {
if (auto* access = src->Sem().Get(expr)->UnwrapLoad()->As<sem::StructMemberAccess>()) { if (auto* access = src->Sem().Get(expr)->UnwrapLoad()->As<sem::StructMemberAccess>()) {
if (auto info = decomposed.Find(access->Member())) { if (auto info = decomposed.Find(access->Member())) {
auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] { auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
@ -181,7 +179,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
auto array = [&] { return info->array(ctx.dst); }; auto array = [&] { return info->array(ctx.dst); };
auto arr = b.Sym("arr"); auto arr = b.Sym("arr");
utils::Vector<const ast::Expression*, 4> columns; utils::Vector<const Expression*, 4> columns;
for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) { for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
columns.Push(b.IndexAccessor(arr, u32(i))); columns.Push(b.IndexAccessor(arr, u32(i)));
} }

View File

@ -67,13 +67,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
// let x : mat2x2<f32> = s.m; // let x : mat2x2<f32> = s.m;
// } // }
ProgramBuilder b; ProgramBuilder b;
auto* S = b.Structure( auto* S =
"S", utils::Vector{ b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(), b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{ utils::Vector{
b.MemberOffset(16_u), b.MemberOffset(16_u),
b.create<ast::StrideAttribute>(32u), b.create<StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute), b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}), }),
}); });
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a)); b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
@ -82,7 +82,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))), b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -124,13 +124,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
// let x : vec2<f32> = s.m[1]; // let x : vec2<f32> = s.m[1];
// } // }
ProgramBuilder b; ProgramBuilder b;
auto* S = b.Structure( auto* S =
"S", utils::Vector{ b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(), b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{ utils::Vector{
b.MemberOffset(16_u), b.MemberOffset(16_u),
b.create<ast::StrideAttribute>(32u), b.create<StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute), b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}), }),
}); });
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a)); b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
@ -140,7 +140,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))), b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -178,13 +178,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
// let x : mat2x2<f32> = s.m; // let x : mat2x2<f32> = s.m;
// } // }
ProgramBuilder b; ProgramBuilder b;
auto* S = b.Structure( auto* S =
"S", utils::Vector{ b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(), b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{ utils::Vector{
b.MemberOffset(16_u), b.MemberOffset(16_u),
b.create<ast::StrideAttribute>(8u), b.create<StrideAttribute>(8u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute), b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}), }),
}); });
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a)); b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
@ -193,7 +193,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))), b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -232,13 +232,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
// let x : mat2x2<f32> = s.m; // let x : mat2x2<f32> = s.m;
// } // }
ProgramBuilder b; ProgramBuilder b;
auto* S = b.Structure( auto* S =
"S", utils::Vector{ b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(), b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{ utils::Vector{
b.MemberOffset(8_u), b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u), b.create<StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute), b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}), }),
}); });
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite, b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
@ -248,7 +248,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))), b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -290,13 +290,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
// let x : vec2<f32> = s.m[1]; // let x : vec2<f32> = s.m[1];
// } // }
ProgramBuilder b; ProgramBuilder b;
auto* S = b.Structure( auto* S =
"S", utils::Vector{ b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(), b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{ utils::Vector{
b.MemberOffset(16_u), b.MemberOffset(16_u),
b.create<ast::StrideAttribute>(32u), b.create<StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute), b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}), }),
}); });
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite, b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
@ -307,7 +307,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))), b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -345,13 +345,13 @@ TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)); // s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// } // }
ProgramBuilder b; ProgramBuilder b;
auto* S = b.Structure( auto* S =
"S", utils::Vector{ b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(), b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{ utils::Vector{
b.MemberOffset(8_u), b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u), b.create<StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute), b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}), }),
}); });
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite, b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
@ -362,7 +362,7 @@ TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))), b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -404,13 +404,13 @@ TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
// s.m[1] = vec2<f32>(1.0, 2.0); // s.m[1] = vec2<f32>(1.0, 2.0);
// } // }
ProgramBuilder b; ProgramBuilder b;
auto* S = b.Structure( auto* S =
"S", utils::Vector{ b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(), b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{ utils::Vector{
b.MemberOffset(8_u), b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u), b.create<StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute), b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}), }),
}); });
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite, b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
@ -420,7 +420,7 @@ TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i), b.vec2<f32>(1_f, 2_f)), b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i), b.vec2<f32>(1_f, 2_f)),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -464,13 +464,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
// (*b)[1] = vec2<f32>(5.0, 6.0); // (*b)[1] = vec2<f32>(5.0, 6.0);
// } // }
ProgramBuilder b; ProgramBuilder b;
auto* S = b.Structure( auto* S =
"S", utils::Vector{ b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(), b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{ utils::Vector{
b.MemberOffset(8_u), b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u), b.create<StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute), b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}), }),
}); });
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite, b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
@ -486,7 +486,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), b.vec2<f32>(5_f, 6_f)), b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), b.vec2<f32>(5_f, 6_f)),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -536,13 +536,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
// let x : mat2x2<f32> = s.m; // let x : mat2x2<f32> = s.m;
// } // }
ProgramBuilder b; ProgramBuilder b;
auto* S = b.Structure( auto* S =
"S", utils::Vector{ b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(), b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{ utils::Vector{
b.MemberOffset(8_u), b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u), b.create<StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute), b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}), }),
}); });
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kPrivate); b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kPrivate);
@ -551,7 +551,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))), b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });
@ -590,13 +590,13 @@ TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0)); // s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
// } // }
ProgramBuilder b; ProgramBuilder b;
auto* S = b.Structure( auto* S =
"S", utils::Vector{ b.Structure("S", utils::Vector{
b.Member("m", b.ty.mat2x2<f32>(), b.Member("m", b.ty.mat2x2<f32>(),
utils::Vector{ utils::Vector{
b.MemberOffset(8_u), b.MemberOffset(8_u),
b.create<ast::StrideAttribute>(32u), b.create<StrideAttribute>(32u),
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute), b.Disable(DisabledValidation::kIgnoreStrideAttribute),
}), }),
}); });
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kPrivate); b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kPrivate);
@ -606,7 +606,7 @@ TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))), b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))),
}, },
utils::Vector{ utils::Vector{
b.Stage(ast::PipelineStage::kCompute), b.Stage(PipelineStage::kCompute),
b.WorkgroupSize(1_i), b.WorkgroupSize(1_i),
}); });

View File

@ -85,9 +85,8 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
b.GlobalVar(flag, builtin::AddressSpace::kPrivate, b.Expr(false)); b.GlobalVar(flag, builtin::AddressSpace::kPrivate, b.Expr(false));
// Replace all discard statements with a statement that marks the invocation as discarded. // Replace all discard statements with a statement that marks the invocation as discarded.
ctx.ReplaceAll([&](const ast::DiscardStatement*) -> const ast::Statement* { ctx.ReplaceAll(
return b.Assign(flag, b.Expr(true)); [&](const DiscardStatement*) -> const Statement* { return b.Assign(flag, b.Expr(true)); });
});
// Insert a conditional discard at the end of each entry point that does not end with a return. // Insert a conditional discard at the end of each entry point that does not end with a return.
for (auto* func : functions_to_process) { for (auto* func : functions_to_process) {
@ -111,7 +110,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
node, node,
// Mask assignments to storage buffer variables. // Mask assignments to storage buffer variables.
[&](const ast::AssignmentStatement* assign) { [&](const AssignmentStatement* assign) {
// Skip writes in functions that are not called from shaders that discard. // Skip writes in functions that are not called from shaders that discard.
auto* func = sem.Get(assign)->Function(); auto* func = sem.Get(assign)->Function();
if (functions_to_process.count(func) == 0) { if (functions_to_process.count(func) == 0) {
@ -119,7 +118,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
} }
// Skip phony assignments. // Skip phony assignments.
if (assign->lhs->Is<ast::PhonyExpression>()) { if (assign->lhs->Is<PhonyExpression>()) {
return; return;
} }
@ -144,7 +143,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
}, },
// Mask builtins that write to host-visible memory. // Mask builtins that write to host-visible memory.
[&](const ast::CallExpression* call) { [&](const CallExpression* call) {
auto* sem_call = sem.Get<sem::Call>(call); auto* sem_call = sem.Get<sem::Call>(call);
auto* stmt = sem_call ? sem_call->Stmt() : nullptr; auto* stmt = sem_call ? sem_call->Stmt() : nullptr;
auto* func = stmt ? stmt->Function() : nullptr; auto* func = stmt ? stmt->Function() : nullptr;
@ -161,7 +160,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
} else if (builtin->IsAtomic() && } else if (builtin->IsAtomic() &&
builtin->Type() != builtin::Function::kAtomicLoad) { builtin->Type() != builtin::Function::kAtomicLoad) {
// A call to an atomic builtin can be a statement or an expression. // A call to an atomic builtin can be a statement or an expression.
if (auto* call_stmt = stmt->Declaration()->As<ast::CallStatement>(); if (auto* call_stmt = stmt->Declaration()->As<CallStatement>();
call_stmt && call_stmt->expr == call) { call_stmt && call_stmt->expr == call) {
// This call is a statement. // This call is a statement.
// Wrap it inside a conditional block. // Wrap it inside a conditional block.
@ -178,8 +177,8 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
// } // }
// let y = x + tmp; // let y = x + tmp;
auto result = b.Sym(); auto result = b.Sym();
ast::Type result_ty; Type result_ty;
const ast::Statement* masked_call = nullptr; const Statement* masked_call = nullptr;
if (builtin->Type() == builtin::Function::kAtomicCompareExchangeWeak) { if (builtin->Type() == builtin::Function::kAtomicCompareExchangeWeak) {
// Special case for atomicCompareExchangeWeak as we cannot name its // Special case for atomicCompareExchangeWeak as we cannot name its
// result type. We have to declare an equivalent struct and copy the // result type. We have to declare an equivalent struct and copy the
@ -232,7 +231,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
}, },
// Insert a conditional discard before all return statements in entry points. // Insert a conditional discard before all return statements in entry points.
[&](const ast::ReturnStatement* ret) { [&](const ReturnStatement* ret) {
auto* func = sem.Get(ret)->Function(); auto* func = sem.Get(ret)->Function();
if (func->Declaration()->IsEntryPoint() && functions_to_process.count(func)) { if (func->Declaration()->IsEntryPoint() && functions_to_process.count(func)) {
auto* discard = b.If(flag, b.Block(b.Discard())); auto* discard = b.If(flag, b.Block(b.Discard()));

View File

@ -450,19 +450,19 @@ struct DirectVariableAccess::State {
Switch( Switch(
variable->Declaration(), variable->Declaration(),
[&](const ast::Var*) { [&](const Var*) {
if (variable->AddressSpace() != builtin::AddressSpace::kHandle) { if (variable->AddressSpace() != builtin::AddressSpace::kHandle) {
// Start a new access chain for the non-handle 'var' access // Start a new access chain for the non-handle 'var' access
create_new_chain(); create_new_chain();
} }
}, },
[&](const ast::Parameter*) { [&](const Parameter*) {
if (variable->Type()->Is<type::Pointer>()) { if (variable->Type()->Is<type::Pointer>()) {
// Start a new access chain for the pointer parameter access // Start a new access chain for the pointer parameter access
create_new_chain(); create_new_chain();
} }
}, },
[&](const ast::Let*) { [&](const Let*) {
if (variable->Type()->Is<type::Pointer>()) { if (variable->Type()->Is<type::Pointer>()) {
// variable is a pointer-let. // variable is a pointer-let.
auto* init = sem.GetVal(variable->Declaration()->initializer); auto* init = sem.GetVal(variable->Declaration()->initializer);
@ -494,11 +494,10 @@ struct DirectVariableAccess::State {
} }
}, },
[&](const sem::ValueExpression* e) { [&](const sem::ValueExpression* e) {
if (auto* unary = e->Declaration()->As<ast::UnaryOpExpression>()) { if (auto* unary = e->Declaration()->As<UnaryOpExpression>()) {
// Unary op. // Unary op.
// If this is a '&' or '*', simply move the chain to the unary op expression. // If this is a '&' or '*', simply move the chain to the unary op expression.
if (unary->op == ast::UnaryOp::kAddressOf || if (unary->op == UnaryOp::kAddressOf || unary->op == UnaryOp::kIndirection) {
unary->op == ast::UnaryOp::kIndirection) {
take_chain(sem.GetVal(unary->expr)); take_chain(sem.GetVal(unary->expr));
} }
} }
@ -529,7 +528,7 @@ struct DirectVariableAccess::State {
if (auto* idx_variable_user = idx->UnwrapMaterialize()->As<sem::VariableUser>()) { if (auto* idx_variable_user = idx->UnwrapMaterialize()->As<sem::VariableUser>()) {
auto* idx_variable = idx_variable_user->Variable(); auto* idx_variable = idx_variable_user->Variable();
if (idx_variable->Declaration()->IsAnyOf<ast::Let, ast::Parameter>()) { if (idx_variable->Declaration()->IsAnyOf<Let, Parameter>()) {
// Dynamic index is an immutable variable // Dynamic index is an immutable variable
continue; // Hoisting not required. continue; // Hoisting not required.
} }
@ -557,7 +556,7 @@ struct DirectVariableAccess::State {
/// * Casts the resulting expression to a u32 if @p cast_to_u32 is true, and the expression type /// * Casts the resulting expression to a u32 if @p cast_to_u32 is true, and the expression type
/// isn't implicitly usable as a u32. This is to help feed the expression into a /// isn't implicitly usable as a u32. This is to help feed the expression into a
/// `array<u32, N>` argument passed to a callee variant function. /// `array<u32, N>` argument passed to a callee variant function.
const ast::Expression* BuildDynamicIndex(const sem::ValueExpression* idx, bool cast_to_u32) { const Expression* BuildDynamicIndex(const sem::ValueExpression* idx, bool cast_to_u32) {
if (auto* val = idx->ConstantValue()) { if (auto* val = idx->ConstantValue()) {
// Expression evaluated to a constant value. Just emit that constant. // Expression evaluated to a constant value. Just emit that constant.
return b.Expr(val->ValueAs<AInt>()); return b.Expr(val->ValueAs<AInt>());
@ -808,7 +807,7 @@ struct DirectVariableAccess::State {
// many variant functions, keep a record of the last created variant, and explicitly add // many variant functions, keep a record of the last created variant, and explicitly add
// this to the module if it isn't the last. We'll return the last created variant, // this to the module if it isn't the last. We'll return the last created variant,
// taking the place of the original function. // taking the place of the original function.
const ast::Function* pending_variant = nullptr; const Function* pending_variant = nullptr;
// For each variant of fn... // For each variant of fn...
for (auto variant_it : fn_info->SortedVariants()) { for (auto variant_it : fn_info->SortedVariants()) {
@ -827,7 +826,7 @@ struct DirectVariableAccess::State {
// Pointer parameters in the 'uniform', 'storage' or 'workgroup' address space are // Pointer parameters in the 'uniform', 'storage' or 'workgroup' address space are
// either replaced with an array of dynamic indices, or are dropped (if there are no // either replaced with an array of dynamic indices, or are dropped (if there are no
// dynamic indices). // dynamic indices).
utils::Vector<const ast::Parameter*, 8> params; utils::Vector<const Parameter*, 8> params;
for (auto* param : fn->Parameters()) { for (auto* param : fn->Parameters()) {
if (auto incoming_shape = variant_sig.Find(param)) { if (auto incoming_shape = variant_sig.Find(param)) {
auto& symbols = *variant.ptr_param_symbols.Find(param); auto& symbols = *variant.ptr_param_symbols.Find(param);
@ -856,7 +855,7 @@ struct DirectVariableAccess::State {
auto attrs = ctx.Clone(fn->Declaration()->attributes); auto attrs = ctx.Clone(fn->Declaration()->attributes);
auto ret_attrs = ctx.Clone(fn->Declaration()->return_type_attributes); auto ret_attrs = ctx.Clone(fn->Declaration()->return_type_attributes);
pending_variant = pending_variant =
b.create<ast::Function>(b.Ident(variant.name), std::move(params), ret_ty, body, b.create<Function>(b.Ident(variant.name), std::move(params), ret_ty, body,
std::move(attrs), std::move(ret_attrs)); std::move(attrs), std::move(ret_attrs));
} }
@ -877,7 +876,7 @@ struct DirectVariableAccess::State {
} }
// Build the new call expressions's arguments. // Build the new call expressions's arguments.
utils::Vector<const ast::Expression*, 8> new_args; utils::Vector<const Expression*, 8> new_args;
for (size_t arg_idx = 0; arg_idx < call->Arguments().Length(); arg_idx++) { for (size_t arg_idx = 0; arg_idx < call->Arguments().Length(); arg_idx++) {
auto* arg = call->Arguments()[arg_idx]; auto* arg = call->Arguments()[arg_idx];
auto* param = call->Target()->Parameters()[arg_idx]; auto* param = call->Target()->Parameters()[arg_idx];
@ -915,7 +914,7 @@ struct DirectVariableAccess::State {
// Get or create the dynamic indices array. // Get or create the dynamic indices array.
if (auto dyn_idx_arr_ty = DynamicIndexArrayType(full_indices)) { if (auto dyn_idx_arr_ty = DynamicIndexArrayType(full_indices)) {
// Build an array of dynamic indices to pass as the replacement for the pointer. // Build an array of dynamic indices to pass as the replacement for the pointer.
utils::Vector<const ast::Expression*, 8> dyn_idx_args; utils::Vector<const Expression*, 8> dyn_idx_args;
if (auto* root_param = chain->root.variable->As<sem::Parameter>()) { if (auto* root_param = chain->root.variable->As<sem::Parameter>()) {
// Access chain originates from a pointer parameter. // Access chain originates from a pointer parameter.
if (auto incoming_chain = if (auto incoming_chain =
@ -985,7 +984,7 @@ struct DirectVariableAccess::State {
/// let. /// let.
void TransformAccessChainExpressions() { void TransformAccessChainExpressions() {
// Register a custom handler for all non-function call expressions // Register a custom handler for all non-function call expressions
ctx.ReplaceAll([this](const ast::Expression* ast_expr) -> const ast::Expression* { ctx.ReplaceAll([this](const Expression* ast_expr) -> const Expression* {
if (!clone_state->current_variant) { if (!clone_state->current_variant) {
// Expression does not belong to a function variant. // Expression does not belong to a function variant.
return nullptr; // Just clone the expression. return nullptr; // Just clone the expression.
@ -1065,7 +1064,7 @@ struct DirectVariableAccess::State {
/// @returns the type alias used to hold the dynamic indices for @p shape, declaring a new alias /// @returns the type alias used to hold the dynamic indices for @p shape, declaring a new alias
/// if this is the first call for the given shape. /// if this is the first call for the given shape.
ast::Type DynamicIndexArrayType(const AccessShape& shape) { Type DynamicIndexArrayType(const AccessShape& shape) {
auto name = dynamic_index_array_aliases.GetOrCreate(shape, [&] { auto name = dynamic_index_array_aliases.GetOrCreate(shape, [&] {
// Count the number of dynamic indices // Count the number of dynamic indices
uint32_t num_dyn_indices = shape.NumDynamicIndices(); uint32_t num_dyn_indices = shape.NumDynamicIndices();
@ -1076,7 +1075,7 @@ struct DirectVariableAccess::State {
b.Alias(symbol, b.ty.array(b.ty.u32(), u32(num_dyn_indices))); b.Alias(symbol, b.ty.array(b.ty.u32(), u32(num_dyn_indices)));
return symbol; return symbol;
}); });
return name.IsValid() ? b.ty(name) : ast::Type{}; return name.IsValid() ? b.ty(name) : Type{};
} }
/// @returns a name describing the given shape /// @returns a name describing the given shape
@ -1113,7 +1112,7 @@ struct DirectVariableAccess::State {
/// Builds an expresion to the root of an access, returning the new expression. /// Builds an expresion to the root of an access, returning the new expression.
/// @param root the AccessRoot /// @param root the AccessRoot
/// @param deref if true, the returned expression will always be a reference type. /// @param deref if true, the returned expression will always be a reference type.
const ast::Expression* BuildAccessRootExpr(const AccessRoot& root, bool deref) { const Expression* BuildAccessRootExpr(const AccessRoot& root, bool deref) {
if (auto* param = root.variable->As<sem::Parameter>()) { if (auto* param = root.variable->As<sem::Parameter>()) {
if (auto symbols = clone_state->current_variant->ptr_param_symbols.Find(param)) { if (auto symbols = clone_state->current_variant->ptr_param_symbols.Find(param)) {
if (deref) { if (deref) {
@ -1123,7 +1122,7 @@ struct DirectVariableAccess::State {
} }
} }
const ast::Expression* expr = b.Expr(ctx.Clone(root.variable->Declaration()->name->symbol)); const Expression* expr = b.Expr(ctx.Clone(root.variable->Declaration()->name->symbol));
if (deref) { if (deref) {
if (root.variable->Type()->Is<type::Pointer>()) { if (root.variable->Type()->Is<type::Pointer>()) {
expr = b.Deref(expr); expr = b.Deref(expr);
@ -1137,10 +1136,9 @@ struct DirectVariableAccess::State {
/// @param expr the input expression /// @param expr the input expression
/// @param access the access to perform on the current expression /// @param access the access to perform on the current expression
/// @param dynamic_index a function that obtains the i'th dynamic index /// @param dynamic_index a function that obtains the i'th dynamic index
const ast::Expression* BuildAccessExpr( const Expression* BuildAccessExpr(const Expression* expr,
const ast::Expression* expr,
const AccessOp& access, const AccessOp& access,
std::function<const ast::Expression*(size_t)> dynamic_index) { std::function<const Expression*(size_t)> dynamic_index) {
if (auto* dyn_idx = std::get_if<DynamicIndex>(&access)) { if (auto* dyn_idx = std::get_if<DynamicIndex>(&access)) {
/// The access uses a dynamic (runtime-expression) index. /// The access uses a dynamic (runtime-expression) index.
auto* idx = dynamic_index(dyn_idx->slot); auto* idx = dynamic_index(dyn_idx->slot);

View File

@ -35,7 +35,7 @@ namespace {
bool ShouldRun(const Program* program) { bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) { for (auto* node : program->ASTNodes().Objects()) {
if (node->IsAnyOf<ast::CompoundAssignmentStatement, ast::IncrementDecrementStatement>()) { if (node->IsAnyOf<CompoundAssignmentStatement, IncrementDecrementStatement>()) {
return true; return true;
} }
} }
@ -58,17 +58,14 @@ struct ExpandCompoundAssignment::State {
/// @param lhs the lhs expression from the source statement /// @param lhs the lhs expression from the source statement
/// @param rhs the rhs expression in the destination module /// @param rhs the rhs expression in the destination module
/// @param op the binary operator /// @param op the binary operator
void Expand(const ast::Statement* stmt, void Expand(const Statement* stmt, const Expression* lhs, const Expression* rhs, BinaryOp op) {
const ast::Expression* lhs,
const ast::Expression* rhs,
ast::BinaryOp op) {
// Helper function to create the new LHS expression. This will be called // Helper function to create the new LHS expression. This will be called
// twice when building the non-compound assignment statement, so must // twice when building the non-compound assignment statement, so must
// not produce expressions that cause side effects. // not produce expressions that cause side effects.
std::function<const ast::Expression*()> new_lhs; std::function<const Expression*()> new_lhs;
// Helper function to create a variable that is a pointer to `expr`. // Helper function to create a variable that is a pointer to `expr`.
auto hoist_pointer_to = [&](const ast::Expression* expr) { auto hoist_pointer_to = [&](const Expression* expr) {
auto name = b.Sym(); auto name = b.Sym();
auto* ptr = b.AddressOf(ctx.Clone(expr)); auto* ptr = b.AddressOf(ctx.Clone(expr));
auto* decl = b.Decl(b.Let(name, ptr)); auto* decl = b.Decl(b.Let(name, ptr));
@ -77,7 +74,7 @@ struct ExpandCompoundAssignment::State {
}; };
// Helper function to hoist `expr` to a let declaration. // Helper function to hoist `expr` to a let declaration.
auto hoist_expr_to_let = [&](const ast::Expression* expr) { auto hoist_expr_to_let = [&](const Expression* expr) {
auto name = b.Sym(); auto name = b.Sym();
auto* decl = b.Decl(b.Let(name, ctx.Clone(expr))); auto* decl = b.Decl(b.Let(name, ctx.Clone(expr)));
hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl); hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl);
@ -85,7 +82,7 @@ struct ExpandCompoundAssignment::State {
}; };
// Helper function that returns `true` if the type of `expr` is a vector. // Helper function that returns `true` if the type of `expr` is a vector.
auto is_vec = [&](const ast::Expression* expr) { auto is_vec = [&](const Expression* expr) {
if (auto* val_expr = ctx.src->Sem().GetVal(expr)) { if (auto* val_expr = ctx.src->Sem().GetVal(expr)) {
return val_expr->Type()->UnwrapRef()->Is<type::Vector>(); return val_expr->Type()->UnwrapRef()->Is<type::Vector>();
} }
@ -96,10 +93,10 @@ struct ExpandCompoundAssignment::State {
// LHS that we can evaluate twice. // LHS that we can evaluate twice.
// We need to special case compound assignments to vector components since // We need to special case compound assignments to vector components since
// we cannot take the address of a vector component. // we cannot take the address of a vector component.
auto* index_accessor = lhs->As<ast::IndexAccessorExpression>(); auto* index_accessor = lhs->As<IndexAccessorExpression>();
auto* member_accessor = lhs->As<ast::MemberAccessorExpression>(); auto* member_accessor = lhs->As<MemberAccessorExpression>();
if (lhs->Is<ast::IdentifierExpression>() || if (lhs->Is<IdentifierExpression>() ||
(member_accessor && member_accessor->object->Is<ast::IdentifierExpression>())) { (member_accessor && member_accessor->object->Is<IdentifierExpression>())) {
// This is the simple case with no side effects, so we can just use the // This is the simple case with no side effects, so we can just use the
// original LHS expression directly. // original LHS expression directly.
// Before: // Before:
@ -144,7 +141,7 @@ struct ExpandCompoundAssignment::State {
} }
// Replace the statement with a regular assignment statement. // Replace the statement with a regular assignment statement.
auto* value = b.create<ast::BinaryExpression>(op, new_lhs(), rhs); auto* value = b.create<BinaryExpression>(op, new_lhs(), rhs);
ctx.Replace(stmt, b.Assign(new_lhs(), value)); ctx.Replace(stmt, b.Assign(new_lhs(), value));
} }
@ -174,11 +171,11 @@ Transform::ApplyResult ExpandCompoundAssignment::Apply(const Program* src,
CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
State state(ctx); State state(ctx);
for (auto* node : src->ASTNodes().Objects()) { for (auto* node : src->ASTNodes().Objects()) {
if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) { if (auto* assign = node->As<CompoundAssignmentStatement>()) {
state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op); state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op);
} else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) { } else if (auto* inc_dec = node->As<IncrementDecrementStatement>()) {
// For increment/decrement statements, `i++` becomes `i = i + 1`. // For increment/decrement statements, `i++` becomes `i = i + 1`.
auto op = inc_dec->increment ? ast::BinaryOp::kAdd : ast::BinaryOp::kSubtract; auto op = inc_dec->increment ? BinaryOp::kAdd : BinaryOp::kSubtract;
state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op); state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op);
} }
} }

View File

@ -38,7 +38,7 @@ constexpr char kFirstInstanceName[] = "first_instance_index";
bool ShouldRun(const Program* program) { bool ShouldRun(const Program* program) {
for (auto* fn : program->AST().Functions()) { for (auto* fn : program->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kVertex) { if (fn->PipelineStage() == PipelineStage::kVertex) {
return true; return true;
} }
} }
@ -86,9 +86,9 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
// Traverse the AST scanning for builtin accesses via variables (includes // Traverse the AST scanning for builtin accesses via variables (includes
// parameters) or structure member accesses. // parameters) or structure member accesses.
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* var = node->As<ast::Variable>()) { if (auto* var = node->As<Variable>()) {
for (auto* attr : var->attributes) { for (auto* attr : var->attributes) {
if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) { if (auto* builtin_attr = attr->As<BuiltinAttribute>()) {
builtin::BuiltinValue builtin = src->Sem().Get(builtin_attr)->Value(); builtin::BuiltinValue builtin = src->Sem().Get(builtin_attr)->Value();
if (builtin == builtin::BuiltinValue::kVertexIndex) { if (builtin == builtin::BuiltinValue::kVertexIndex) {
auto* sem_var = ctx.src->Sem().Get(var); auto* sem_var = ctx.src->Sem().Get(var);
@ -103,9 +103,9 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
} }
} }
} }
if (auto* member = node->As<ast::StructMember>()) { if (auto* member = node->As<StructMember>()) {
for (auto* attr : member->attributes) { for (auto* attr : member->attributes) {
if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) { if (auto* builtin_attr = attr->As<BuiltinAttribute>()) {
builtin::BuiltinValue builtin = src->Sem().Get(builtin_attr)->Value(); builtin::BuiltinValue builtin = src->Sem().Get(builtin_attr)->Value();
if (builtin == builtin::BuiltinValue::kVertexIndex) { if (builtin == builtin::BuiltinValue::kVertexIndex) {
auto* sem_mem = ctx.src->Sem().Get(member); auto* sem_mem = ctx.src->Sem().Get(member);
@ -124,7 +124,7 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
if (has_vertex_or_instance_index) { if (has_vertex_or_instance_index) {
// Add uniform buffer members and calculate byte offsets // Add uniform buffer members and calculate byte offsets
utils::Vector<const ast::StructMember*, 8> members; utils::Vector<const StructMember*, 8> members;
members.Push(b.Member(kFirstVertexName, b.ty.u32())); members.Push(b.Member(kFirstVertexName, b.ty.u32()));
members.Push(b.Member(kFirstInstanceName, b.ty.u32())); members.Push(b.Member(kFirstInstanceName, b.ty.u32()));
auto* struct_ = b.Structure(b.Sym(), std::move(members)); auto* struct_ = b.Structure(b.Sym(), std::move(members));
@ -138,7 +138,7 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
}); });
// Fix up all references to the builtins with the offsets // Fix up all references to the builtins with the offsets
ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* { ctx.ReplaceAll([=, &ctx](const Expression* expr) -> const Expression* {
if (auto* sem = ctx.src->Sem().GetVal(expr)) { if (auto* sem = ctx.src->Sem().GetVal(expr)) {
if (auto* user = sem->UnwrapLoad()->As<sem::VariableUser>()) { if (auto* user = sem->UnwrapLoad()->As<sem::VariableUser>()) {
auto it = builtin_vars.find(user->Variable()); auto it = builtin_vars.find(user->Variable());

View File

@ -26,7 +26,7 @@ namespace {
bool ShouldRun(const Program* program) { bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) { for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::ForLoopStatement>()) { if (node->Is<ForLoopStatement>()) {
return true; return true;
} }
} }
@ -47,8 +47,8 @@ Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&,
ProgramBuilder b; ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
ctx.ReplaceAll([&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* { ctx.ReplaceAll([&](const ForLoopStatement* for_loop) -> const Statement* {
utils::Vector<const ast::Statement*, 8> stmts; utils::Vector<const Statement*, 8> stmts;
if (auto* cond = for_loop->condition) { if (auto* cond = for_loop->condition) {
// !condition // !condition
auto* not_cond = b.Not(ctx.Clone(cond)); auto* not_cond = b.Not(ctx.Clone(cond));
@ -63,7 +63,7 @@ Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&,
stmts.Push(ctx.Clone(stmt)); stmts.Push(ctx.Clone(stmt));
} }
const ast::BlockStatement* continuing = nullptr; const BlockStatement* continuing = nullptr;
if (auto* cont = for_loop->continuing) { if (auto* cont = for_loop->continuing) {
continuing = b.Block(ctx.Clone(cont)); continuing = b.Block(ctx.Clone(cont));
} }

View File

@ -43,14 +43,14 @@ struct LocalizeStructArrayAssignment::State {
ApplyResult Run() { ApplyResult Run() {
struct Shared { struct Shared {
bool process_nested_nodes = false; bool process_nested_nodes = false;
utils::Vector<const ast::Statement*, 4> insert_before_stmts; utils::Vector<const Statement*, 4> insert_before_stmts;
utils::Vector<const ast::Statement*, 4> insert_after_stmts; utils::Vector<const Statement*, 4> insert_after_stmts;
} s; } s;
bool made_changes = false; bool made_changes = false;
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* assign_stmt = node->As<ast::AssignmentStatement>()) { if (auto* assign_stmt = node->As<AssignmentStatement>()) {
// Process if it's an assignment statement to a dynamically indexed array // Process if it's an assignment statement to a dynamically indexed array
// within a struct on a function or private storage variable. This // within a struct on a function or private storage variable. This
// specific use-case is what FXC fails to compile with: // specific use-case is what FXC fails to compile with:
@ -70,7 +70,7 @@ struct LocalizeStructArrayAssignment::State {
// Reset shared state for this assignment statement // Reset shared state for this assignment statement
s = Shared{}; s = Shared{};
const ast::Expression* new_lhs = nullptr; const Expression* new_lhs = nullptr;
{ {
TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true); TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true);
new_lhs = ctx.Clone(assign_stmt->lhs); new_lhs = ctx.Clone(assign_stmt->lhs);
@ -98,14 +98,13 @@ struct LocalizeStructArrayAssignment::State {
return SkipTransform; return SkipTransform;
} }
ctx.ReplaceAll( ctx.ReplaceAll([&](const IndexAccessorExpression* index_access) -> const Expression* {
[&](const ast::IndexAccessorExpression* index_access) -> const ast::Expression* {
if (!s.process_nested_nodes) { if (!s.process_nested_nodes) {
return nullptr; return nullptr;
} }
// Indexing a member access expr? // Indexing a member access expr?
auto* mem_access = index_access->object->As<ast::MemberAccessorExpression>(); auto* mem_access = index_access->object->As<MemberAccessorExpression>();
if (!mem_access) { if (!mem_access) {
return nullptr; return nullptr;
} }
@ -136,7 +135,7 @@ struct LocalizeStructArrayAssignment::State {
// e.g. *(tint_symbol) = tint_symbol_1; // e.g. *(tint_symbol) = tint_symbol_1;
auto* assign_rhs_to_temp = b.Assign(b.Deref(mem_access_ptr), tmp_var); auto* assign_rhs_to_temp = b.Assign(b.Deref(mem_access_ptr), tmp_var);
{ {
utils::Vector<const ast::Statement*, 8> stmts{assign_rhs_to_temp}; utils::Vector<const Statement*, 8> stmts{assign_rhs_to_temp};
for (auto* stmt : s.insert_after_stmts) { for (auto* stmt : s.insert_after_stmts) {
stmts.Push(stmt); stmts.Push(stmt);
} }
@ -160,23 +159,22 @@ struct LocalizeStructArrayAssignment::State {
/// Returns true if `expr` contains an index accessor expression to a /// Returns true if `expr` contains an index accessor expression to a
/// structure member of array type. /// structure member of array type.
bool ContainsStructArrayIndex(const ast::Expression* expr) { bool ContainsStructArrayIndex(const Expression* expr) {
bool result = false; bool result = false;
ast::TraverseExpressions( TraverseExpressions(expr, b.Diagnostics(), [&](const IndexAccessorExpression* ia) {
expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
// Indexing using a runtime value? // Indexing using a runtime value?
auto* idx_sem = src->Sem().GetVal(ia->index); auto* idx_sem = src->Sem().GetVal(ia->index);
if (!idx_sem->ConstantValue()) { if (!idx_sem->ConstantValue()) {
// Indexing a member access expr? // Indexing a member access expr?
if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) { if (auto* ma = ia->object->As<MemberAccessorExpression>()) {
// That accesses an array? // That accesses an array?
if (src->TypeOf(ma)->UnwrapRef()->Is<type::Array>()) { if (src->TypeOf(ma)->UnwrapRef()->Is<type::Array>()) {
result = true; result = true;
return ast::TraverseAction::Stop; return TraverseAction::Stop;
} }
} }
} }
return ast::TraverseAction::Descend; return TraverseAction::Descend;
}); });
return result; return result;
@ -186,7 +184,7 @@ struct LocalizeStructArrayAssignment::State {
// of the assignment statement. // of the assignment statement.
// See https://www.w3.org/TR/WGSL/#originating-variable-section // See https://www.w3.org/TR/WGSL/#originating-variable-section
std::pair<const type::Type*, builtin::AddressSpace> GetOriginatingTypeAndAddressSpace( std::pair<const type::Type*, builtin::AddressSpace> GetOriginatingTypeAndAddressSpace(
const ast::AssignmentStatement* assign_stmt) { const AssignmentStatement* assign_stmt) {
auto* root_ident = src->Sem().GetVal(assign_stmt->lhs)->RootIdentifier(); auto* root_ident = src->Sem().GetVal(assign_stmt->lhs)->RootIdentifier();
if (TINT_UNLIKELY(!root_ident)) { if (TINT_UNLIKELY(!root_ident)) {
TINT_ICE(Transform, b.Diagnostics()) TINT_ICE(Transform, b.Diagnostics())

View File

@ -30,12 +30,12 @@ namespace tint::ast::transform {
namespace { namespace {
/// Returns `true` if `stmt` has the behavior `behavior`. /// Returns `true` if `stmt` has the behavior `behavior`.
bool HasBehavior(const Program* program, const ast::Statement* stmt, sem::Behavior behavior) { bool HasBehavior(const Program* program, const Statement* stmt, sem::Behavior behavior) {
return program->Sem().Get(stmt)->Behaviors().Contains(behavior); return program->Sem().Get(stmt)->Behaviors().Contains(behavior);
} }
/// Returns `true` if `func` needs to be transformed. /// Returns `true` if `func` needs to be transformed.
bool NeedsTransform(const Program* program, const ast::Function* func) { bool NeedsTransform(const Program* program, const Function* func) {
// Entry points and intrinsic declarations never need transforming. // Entry points and intrinsic declarations never need transforming.
if (func->IsEntryPoint() || func->body == nullptr) { if (func->IsEntryPoint() || func->body == nullptr) {
return false; return false;
@ -49,7 +49,7 @@ bool NeedsTransform(const Program* program, const ast::Function* func) {
if (HasBehavior(program, s, sem::Behavior::kReturn)) { if (HasBehavior(program, s, sem::Behavior::kReturn)) {
// If this statement is itself a return, it will be the only exit point, // If this statement is itself a return, it will be the only exit point,
// so no need to apply the transform to the function. // so no need to apply the transform to the function.
if (s->Is<ast::ReturnStatement>()) { if (s->Is<ReturnStatement>()) {
return false; return false;
} else { } else {
// Apply the transform in all other cases. // Apply the transform in all other cases.
@ -78,7 +78,7 @@ class State {
ProgramBuilder& b; ProgramBuilder& b;
/// The function. /// The function.
const ast::Function* function; const Function* function;
/// The symbol for the return flag variable. /// The symbol for the return flag variable.
Symbol flag; Symbol flag;
@ -92,32 +92,32 @@ class State {
public: public:
/// Constructor /// Constructor
/// @param context the clone context /// @param context the clone context
State(CloneContext& context, const ast::Function* func) State(CloneContext& context, const Function* func)
: ctx(context), b(*ctx.dst), function(func) {} : ctx(context), b(*ctx.dst), function(func) {}
/// Process a statement (recursively). /// Process a statement (recursively).
void ProcessStatement(const ast::Statement* stmt) { void ProcessStatement(const Statement* stmt) {
if (stmt == nullptr || !HasBehavior(ctx.src, stmt, sem::Behavior::kReturn)) { if (stmt == nullptr || !HasBehavior(ctx.src, stmt, sem::Behavior::kReturn)) {
return; return;
} }
Switch( Switch(
stmt, [&](const ast::BlockStatement* block) { ProcessBlock(block); }, stmt, [&](const BlockStatement* block) { ProcessBlock(block); },
[&](const ast::CaseStatement* c) { ProcessStatement(c->body); }, [&](const CaseStatement* c) { ProcessStatement(c->body); },
[&](const ast::ForLoopStatement* f) { [&](const ForLoopStatement* f) {
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true); TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
ProcessStatement(f->body); ProcessStatement(f->body);
}, },
[&](const ast::IfStatement* i) { [&](const IfStatement* i) {
ProcessStatement(i->body); ProcessStatement(i->body);
ProcessStatement(i->else_statement); ProcessStatement(i->else_statement);
}, },
[&](const ast::LoopStatement* l) { [&](const LoopStatement* l) {
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true); TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
ProcessStatement(l->body); ProcessStatement(l->body);
}, },
[&](const ast::ReturnStatement* r) { [&](const ReturnStatement* r) {
utils::Vector<const ast::Statement*, 3> stmts; utils::Vector<const Statement*, 3> stmts;
// Set the return flag to signal that we have hit a return. // Set the return flag to signal that we have hit a return.
stmts.Push(b.Assign(b.Expr(flag), true)); stmts.Push(b.Assign(b.Expr(flag), true));
if (r->value) { if (r->value) {
@ -130,25 +130,25 @@ class State {
} }
ctx.Replace(r, b.Block(std::move(stmts))); ctx.Replace(r, b.Block(std::move(stmts)));
}, },
[&](const ast::SwitchStatement* s) { [&](const SwitchStatement* s) {
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true); TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
for (auto* c : s->body) { for (auto* c : s->body) {
ProcessStatement(c); ProcessStatement(c);
} }
}, },
[&](const ast::WhileStatement* w) { [&](const WhileStatement* w) {
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true); TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
ProcessStatement(w->body); ProcessStatement(w->body);
}, },
[&](Default) { TINT_ICE(Transform, b.Diagnostics()) << "unhandled statement type"; }); [&](Default) { TINT_ICE(Transform, b.Diagnostics()) << "unhandled statement type"; });
} }
void ProcessBlock(const ast::BlockStatement* block) { void ProcessBlock(const BlockStatement* block) {
// We will rebuild the contents of the block statement. // We will rebuild the contents of the block statement.
// We may introduce conditionals around statements that follow a statement with the // We may introduce conditionals around statements that follow a statement with the
// `Return` behavior, so build a stack of statement lists that represent the new // `Return` behavior, so build a stack of statement lists that represent the new
// (potentially nested) conditional blocks. // (potentially nested) conditional blocks.
utils::Vector<utils::Vector<const ast::Statement*, 8>, 8> new_stmts({{}}); utils::Vector<utils::Vector<const Statement*, 8>, 8> new_stmts({{}});
// Insert variables for the return flag and return value at the top of the function. // Insert variables for the return flag and return value at the top of the function.
if (block == function->body) { if (block == function->body) {
@ -173,8 +173,7 @@ class State {
if (is_in_loop_or_switch) { if (is_in_loop_or_switch) {
// We're in a loop/switch, and so we would have inserted a `break`. // We're in a loop/switch, and so we would have inserted a `break`.
// If we've just come out of a loop/switch statement, we need to `break` again. // If we've just come out of a loop/switch statement, we need to `break` again.
if (s->IsAnyOf<ast::LoopStatement, ast::ForLoopStatement, if (s->IsAnyOf<LoopStatement, ForLoopStatement, SwitchStatement>()) {
ast::SwitchStatement>()) {
// If the loop only has the 'Return' behavior, we can just unconditionally // If the loop only has the 'Return' behavior, we can just unconditionally
// break. Otherwise check the return flag. // break. Otherwise check the return flag.
if (HasBehavior(ctx.src, s, sem::Behavior::kNext)) { if (HasBehavior(ctx.src, s, sem::Behavior::kNext)) {
@ -194,7 +193,7 @@ class State {
// Descend the stack of new block statements, wrapping them in conditionals. // Descend the stack of new block statements, wrapping them in conditionals.
while (new_stmts.Length() > 1) { while (new_stmts.Length() > 1) {
const ast::IfStatement* i = nullptr; const IfStatement* i = nullptr;
if (new_stmts.Back().Length() > 0) { if (new_stmts.Back().Length() > 0) {
i = b.If(b.Not(b.Expr(flag)), b.Block(new_stmts.Back())); i = b.If(b.Not(b.Expr(flag)), b.Block(new_stmts.Back()));
} }

View File

@ -33,14 +33,14 @@ TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ModuleScopeVarToEntryPointParam)
namespace tint::ast::transform { namespace tint::ast::transform {
namespace { namespace {
using StructMemberList = utils::Vector<const ast::StructMember*, 8>; using StructMemberList = utils::Vector<const StructMember*, 8>;
// The name of the struct member for arrays that are wrapped in structures. // The name of the struct member for arrays that are wrapped in structures.
const char* kWrappedArrayMemberName = "arr"; const char* kWrappedArrayMemberName = "arr";
bool ShouldRun(const Program* program) { bool ShouldRun(const Program* program) {
for (auto* decl : program->AST().GlobalDeclarations()) { for (auto* decl : program->AST().GlobalDeclarations()) {
if (decl->Is<ast::Variable>()) { if (decl->Is<Variable>()) {
return true; return true;
} }
} }
@ -110,7 +110,7 @@ struct ModuleScopeVarToEntryPointParam::State {
/// @param workgroup_parameter_members reference to a list of a workgroup struct members /// @param workgroup_parameter_members reference to a list of a workgroup struct members
/// @param is_pointer output signalling whether the replacement is a pointer /// @param is_pointer output signalling whether the replacement is a pointer
/// @param is_wrapped output signalling whether the replacement is wrapped in a struct /// @param is_wrapped output signalling whether the replacement is wrapped in a struct
void ProcessVariableInEntryPoint(const ast::Function* func, void ProcessVariableInEntryPoint(const Function* func,
const sem::Variable* var, const sem::Variable* var,
Symbol new_var_symbol, Symbol new_var_symbol,
std::function<Symbol()> workgroup_param, std::function<Symbol()> workgroup_param,
@ -128,7 +128,7 @@ struct ModuleScopeVarToEntryPointParam::State {
// For a texture or sampler variable, redeclare it as an entry point parameter. // For a texture or sampler variable, redeclare it as an entry point parameter.
// Disable entry point parameter validation. // Disable entry point parameter validation.
auto* disable_validation = auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter); ctx.dst->Disable(DisabledValidation::kEntryPointParameter);
auto attrs = ctx.Clone(var->Declaration()->attributes); auto attrs = ctx.Clone(var->Declaration()->attributes);
attrs.Push(disable_validation); attrs.Push(disable_validation);
auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs); auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
@ -141,8 +141,8 @@ struct ModuleScopeVarToEntryPointParam::State {
// Variables into the Storage and Uniform address spaces are redeclared as entry // Variables into the Storage and Uniform address spaces are redeclared as entry
// point parameters with a pointer type. // point parameters with a pointer type.
auto attributes = ctx.Clone(var->Declaration()->attributes); auto attributes = ctx.Clone(var->Declaration()->attributes);
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter)); attributes.Push(ctx.dst->Disable(DisabledValidation::kEntryPointParameter));
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace)); attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
auto param_type = store_type(); auto param_type = store_type();
if (auto* arr = ty->As<type::Array>(); if (auto* arr = ty->As<type::Array>();
@ -190,7 +190,7 @@ struct ModuleScopeVarToEntryPointParam::State {
is_pointer = true; is_pointer = true;
} else { } else {
auto* disable_validation = auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace); ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace);
auto* initializer = ctx.Clone(var->Declaration()->initializer); auto* initializer = ctx.Clone(var->Declaration()->initializer);
auto* local_var = ctx.dst->Var(new_var_symbol, store_type(), sc, initializer, auto* local_var = ctx.dst->Var(new_var_symbol, store_type(), sc, initializer,
utils::Vector{disable_validation}); utils::Vector{disable_validation});
@ -218,7 +218,7 @@ struct ModuleScopeVarToEntryPointParam::State {
/// @param var the variable /// @param var the variable
/// @param new_var_symbol the symbol to use for the replacement /// @param new_var_symbol the symbol to use for the replacement
/// @param is_pointer output signalling whether the replacement is a pointer or not /// @param is_pointer output signalling whether the replacement is a pointer or not
void ProcessVariableInUserFunction(const ast::Function* func, void ProcessVariableInUserFunction(const Function* func,
const sem::Variable* var, const sem::Variable* var,
Symbol new_var_symbol, Symbol new_var_symbol,
bool& is_pointer) { bool& is_pointer) {
@ -247,7 +247,7 @@ struct ModuleScopeVarToEntryPointParam::State {
} }
// Use a pointer for non-handle types. // Use a pointer for non-handle types.
utils::Vector<const ast::Attribute*, 2> attributes; utils::Vector<const Attribute*, 2> attributes;
if (!ty->is_handle()) { if (!ty->is_handle()) {
param_type = sc == builtin::AddressSpace::kStorage param_type = sc == builtin::AddressSpace::kStorage
? ctx.dst->ty.pointer(param_type, sc, var->Access()) ? ctx.dst->ty.pointer(param_type, sc, var->Access())
@ -255,9 +255,8 @@ struct ModuleScopeVarToEntryPointParam::State {
is_pointer = true; is_pointer = true;
// Disable validation of the parameter's address space and of arguments passed to it. // Disable validation of the parameter's address space and of arguments passed to it.
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace)); attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
attributes.Push( attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreInvalidPointerArgument));
ctx.dst->Disable(ast::DisabledValidation::kIgnoreInvalidPointerArgument));
} }
// Redeclare the variable as a parameter. // Redeclare the variable as a parameter.
@ -271,18 +270,18 @@ struct ModuleScopeVarToEntryPointParam::State {
/// @param new_var the symbol to use for replacement /// @param new_var the symbol to use for replacement
/// @param is_pointer true if `new_var` is a pointer to the new variable /// @param is_pointer true if `new_var` is a pointer to the new variable
/// @param member_name if valid, the name of the struct member that holds this variable /// @param member_name if valid, the name of the struct member that holds this variable
void ReplaceUsesInFunction(const ast::Function* func, void ReplaceUsesInFunction(const Function* func,
const sem::Variable* var, const sem::Variable* var,
Symbol new_var, Symbol new_var,
bool is_pointer, bool is_pointer,
Symbol member_name) { Symbol member_name) {
for (auto* user : var->Users()) { for (auto* user : var->Users()) {
if (user->Stmt()->Function()->Declaration() == func) { if (user->Stmt()->Function()->Declaration() == func) {
const ast::Expression* expr = ctx.dst->Expr(new_var); const Expression* expr = ctx.dst->Expr(new_var);
if (is_pointer) { if (is_pointer) {
// If this identifier is used by an address-of operator, just remove the // If this identifier is used by an address-of operator, just remove the
// address-of instead of adding a deref, since we already have a pointer. // address-of instead of adding a deref, since we already have a pointer.
auto* ident = user->Declaration()->As<ast::IdentifierExpression>(); auto* ident = user->Declaration()->As<IdentifierExpression>();
if (ident_to_address_of_.count(ident) && !member_name.IsValid()) { if (ident_to_address_of_.count(ident) && !member_name.IsValid()) {
ctx.Replace(ident_to_address_of_[ident], expr); ctx.Replace(ident_to_address_of_[ident], expr);
continue; continue;
@ -302,19 +301,19 @@ struct ModuleScopeVarToEntryPointParam::State {
/// Process the module. /// Process the module.
void Process() { void Process() {
// Predetermine the list of function calls that need to be replaced. // Predetermine the list of function calls that need to be replaced.
using CallList = utils::Vector<const ast::CallExpression*, 8>; using CallList = utils::Vector<const CallExpression*, 8>;
std::unordered_map<const ast::Function*, CallList> calls_to_replace; std::unordered_map<const Function*, CallList> calls_to_replace;
utils::Vector<const ast::Function*, 8> functions_to_process; utils::Vector<const Function*, 8> functions_to_process;
// Collect private variables into a single structure. // Collect private variables into a single structure.
StructMemberList private_struct_members; StructMemberList private_struct_members;
utils::Vector<std::function<const ast::AssignmentStatement*()>, 4> private_initializers; utils::Vector<std::function<const AssignmentStatement*()>, 4> private_initializers;
std::unordered_set<const ast::Function*> uses_privates; std::unordered_set<const Function*> uses_privates;
// Build a list of functions that transitively reference any module-scope variables. // Build a list of functions that transitively reference any module-scope variables.
for (auto* decl : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) { for (auto* decl : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) {
if (auto* var = decl->As<ast::Var>()) { if (auto* var = decl->As<Var>()) {
auto* sem_var = ctx.src->Sem().Get(var); auto* sem_var = ctx.src->Sem().Get(var);
if (sem_var->AddressSpace() == builtin::AddressSpace::kPrivate) { if (sem_var->AddressSpace() == builtin::AddressSpace::kPrivate) {
// Create a member in the private variable struct. // Create a member in the private variable struct.
@ -335,7 +334,7 @@ struct ModuleScopeVarToEntryPointParam::State {
continue; continue;
} }
auto* func_ast = decl->As<ast::Function>(); auto* func_ast = decl->As<Function>();
if (!func_ast) { if (!func_ast) {
continue; continue;
} }
@ -376,11 +375,11 @@ struct ModuleScopeVarToEntryPointParam::State {
// TODO(jrprice): We should add support for bidirectional SEM tree traversal so that we can // TODO(jrprice): We should add support for bidirectional SEM tree traversal so that we can
// do this on the fly instead. // do this on the fly instead.
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* address_of = node->As<ast::UnaryOpExpression>(); auto* address_of = node->As<UnaryOpExpression>();
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) { if (!address_of || address_of->op != UnaryOp::kAddressOf) {
continue; continue;
} }
if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) { if (auto* ident = address_of->expr->As<IdentifierExpression>()) {
ident_to_address_of_[ident] = address_of; ident_to_address_of_[ident] = address_of;
} }
} }
@ -414,11 +413,11 @@ struct ModuleScopeVarToEntryPointParam::State {
if (uses_privates.count(func_ast)) { if (uses_privates.count(func_ast)) {
if (is_entry_point) { if (is_entry_point) {
// Create a local declaration for the private variable struct. // Create a local declaration for the private variable struct.
auto* var = ctx.dst->Var( auto* var =
PrivateStructVariableName(), ctx.dst->ty(PrivateStructName()), ctx.dst->Var(PrivateStructVariableName(), ctx.dst->ty(PrivateStructName()),
builtin::AddressSpace::kPrivate, builtin::AddressSpace::kPrivate,
utils::Vector{ utils::Vector{
ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace), ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace),
}); });
ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(var)); ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(var));
@ -482,7 +481,7 @@ struct ModuleScopeVarToEntryPointParam::State {
// Allow pointer aliasing if needed. // Allow pointer aliasing if needed.
if (needs_pointer_aliasing) { if (needs_pointer_aliasing) {
ctx.InsertBack(func_ast->attributes, ctx.InsertBack(func_ast->attributes,
ctx.dst->Disable(ast::DisabledValidation::kIgnorePointerAliasing)); ctx.dst->Disable(DisabledValidation::kIgnorePointerAliasing));
} }
if (!workgroup_parameter_members.IsEmpty()) { if (!workgroup_parameter_members.IsEmpty()) {
@ -492,11 +491,11 @@ struct ModuleScopeVarToEntryPointParam::State {
ctx.dst->Structure(ctx.dst->Sym(), std::move(workgroup_parameter_members)); ctx.dst->Structure(ctx.dst->Sym(), std::move(workgroup_parameter_members));
auto param_type = auto param_type =
ctx.dst->ty.pointer(ctx.dst->ty.Of(str), builtin::AddressSpace::kWorkgroup); ctx.dst->ty.pointer(ctx.dst->ty.Of(str), builtin::AddressSpace::kWorkgroup);
auto* param = ctx.dst->Param( auto* param =
workgroup_param(), param_type, ctx.dst->Param(workgroup_param(), param_type,
utils::Vector{ utils::Vector{
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter), ctx.dst->Disable(DisabledValidation::kEntryPointParameter),
ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace), ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace),
}); });
ctx.InsertFront(func_ast->params, param); ctx.InsertFront(func_ast->params, param);
} }
@ -508,7 +507,7 @@ struct ModuleScopeVarToEntryPointParam::State {
// Pass the private variable struct pointer if needed. // Pass the private variable struct pointer if needed.
if (uses_privates.count(target_sem->Declaration())) { if (uses_privates.count(target_sem->Declaration())) {
const ast::Expression* arg = ctx.dst->Expr(PrivateStructVariableName()); const Expression* arg = ctx.dst->Expr(PrivateStructVariableName());
if (is_entry_point) { if (is_entry_point) {
arg = ctx.dst->AddressOf(arg); arg = ctx.dst->AddressOf(arg);
} }
@ -531,7 +530,7 @@ struct ModuleScopeVarToEntryPointParam::State {
auto new_var = it->second; auto new_var = it->second;
bool is_handle = target_var->Type()->UnwrapRef()->is_handle(); bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol); const Expression* arg = ctx.dst->Expr(new_var.symbol);
if (new_var.is_wrapped) { if (new_var.is_wrapped) {
// The variable is wrapped in a struct, so we need to pass a pointer to the // The variable is wrapped in a struct, so we need to pass a pointer to the
// struct member instead. // struct member instead.
@ -577,8 +576,7 @@ struct ModuleScopeVarToEntryPointParam::State {
std::unordered_set<const sem::Struct*> cloned_structs_; std::unordered_set<const sem::Struct*> cloned_structs_;
// Map from identifier expression to the address-of expression that uses it. // Map from identifier expression to the address-of expression that uses it.
std::unordered_map<const ast::IdentifierExpression*, const ast::UnaryOpExpression*> std::unordered_map<const IdentifierExpression*, const UnaryOpExpression*> ident_to_address_of_;
ident_to_address_of_;
// The name of the structure that contains all the module-scope private variables. // The name of the structure that contains all the module-scope private variables.
Symbol private_struct_name; Symbol private_struct_name;

View File

@ -143,7 +143,7 @@ struct MultiplanarExternalTexture::State {
// Replace the original texture_external binding with a texture_2d<f32> binding. // Replace the original texture_external binding with a texture_2d<f32> binding.
auto cloned_attributes = ctx.Clone(global->attributes); auto cloned_attributes = ctx.Clone(global->attributes);
const ast::Expression* cloned_initializer = ctx.Clone(global->initializer); const Expression* cloned_initializer = ctx.Clone(global->initializer);
auto* replacement = auto* replacement =
b.Var(syms.plane_0, b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32()), b.Var(syms.plane_0, b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32()),
@ -153,7 +153,7 @@ struct MultiplanarExternalTexture::State {
// We must update all the texture_external parameters for user declared functions. // We must update all the texture_external parameters for user declared functions.
for (auto* fn : ctx.src->AST().Functions()) { for (auto* fn : ctx.src->AST().Functions()) {
for (const ast::Variable* param : fn->params) { for (const Variable* param : fn->params) {
if (auto* sem_var = sem.Get(param)) { if (auto* sem_var = sem.Get(param)) {
if (!sem_var->Type()->UnwrapRef()->Is<type::ExternalTexture>()) { if (!sem_var->Type()->UnwrapRef()->Is<type::ExternalTexture>()) {
continue; continue;
@ -184,7 +184,7 @@ struct MultiplanarExternalTexture::State {
// Transform the external texture builtin calls into calls to the external texture // Transform the external texture builtin calls into calls to the external texture
// functions. // functions.
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { ctx.ReplaceAll([&](const CallExpression* expr) -> const CallExpression* {
auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>(); auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* builtin = call->Target()->As<sem::Builtin>(); auto* builtin = call->Target()->As<sem::Builtin>();
@ -305,10 +305,10 @@ struct MultiplanarExternalTexture::State {
/// @param call_type determines which function body to generate /// @param call_type determines which function body to generate
/// @returns a statement list that makes of the body of the chosen function /// @returns a statement list that makes of the body of the chosen function
auto buildTextureBuiltinBody(builtin::Function call_type) { auto buildTextureBuiltinBody(builtin::Function call_type) {
utils::Vector<const ast::Statement*, 16> stmts; utils::Vector<const Statement*, 16> stmts;
const ast::CallExpression* single_plane_call = nullptr; const CallExpression* single_plane_call = nullptr;
const ast::CallExpression* plane_0_call = nullptr; const CallExpression* plane_0_call = nullptr;
const ast::CallExpression* plane_1_call = nullptr; const CallExpression* plane_1_call = nullptr;
switch (call_type) { switch (call_type) {
case builtin::Function::kTextureSampleBaseClampToEdge: case builtin::Function::kTextureSampleBaseClampToEdge:
stmts.Push(b.Decl(b.Let( stmts.Push(b.Decl(b.Let(
@ -395,9 +395,9 @@ struct MultiplanarExternalTexture::State {
/// @param expr the call expression being transformed /// @param expr the call expression being transformed
/// @param syms the expanded symbols to be used in the new call /// @param syms the expanded symbols to be used in the new call
/// @returns a call expression to textureSampleExternal /// @returns a call expression to textureSampleExternal
const ast::CallExpression* createTextureSampleBaseClampToEdge(const ast::CallExpression* expr, const CallExpression* createTextureSampleBaseClampToEdge(const CallExpression* expr,
NewBindingSymbols syms) { NewBindingSymbols syms) {
const ast::Expression* plane_0_binding_param = ctx.Clone(expr->args[0]); const Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
if (TINT_UNLIKELY(expr->args.Length() != 3)) { if (TINT_UNLIKELY(expr->args.Length() != 3)) {
TINT_ICE(Transform, b.Diagnostics()) TINT_ICE(Transform, b.Diagnostics())
@ -443,7 +443,7 @@ struct MultiplanarExternalTexture::State {
/// @param call the call expression being transformed /// @param call the call expression being transformed
/// @param syms the expanded symbols to be used in the new call /// @param syms the expanded symbols to be used in the new call
/// @returns a call expression to textureLoadExternal /// @returns a call expression to textureLoadExternal
const ast::CallExpression* createTextureLoad(const sem::Call* call, NewBindingSymbols syms) { const CallExpression* createTextureLoad(const sem::Call* call, NewBindingSymbols syms) {
if (TINT_UNLIKELY(call->Arguments().Length() != 2)) { if (TINT_UNLIKELY(call->Arguments().Length() != 2)) {
TINT_ICE(Transform, b.Diagnostics()) TINT_ICE(Transform, b.Diagnostics())
<< "expected textureLoad call with a texture_external to have 2 arguments, found " << "expected textureLoad call with a texture_external to have 2 arguments, found "

View File

@ -33,7 +33,7 @@ namespace {
bool ShouldRun(const Program* program) { bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) { for (auto* node : program->ASTNodes().Objects()) {
if (auto* attr = node->As<ast::BuiltinAttribute>()) { if (auto* attr = node->As<BuiltinAttribute>()) {
if (program->Sem().Get(attr)->Value() == builtin::BuiltinValue::kNumWorkgroups) { if (program->Sem().Get(attr)->Value() == builtin::BuiltinValue::kNumWorkgroups) {
return true; return true;
} }
@ -86,7 +86,7 @@ Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
std::unordered_set<Accessor, Accessor::Hasher> to_replace; std::unordered_set<Accessor, Accessor::Hasher> to_replace;
for (auto* func : src->AST().Functions()) { for (auto* func : src->AST().Functions()) {
// num_workgroups is only valid for compute stages. // num_workgroups is only valid for compute stages.
if (func->PipelineStage() != ast::PipelineStage::kCompute) { if (func->PipelineStage() != PipelineStage::kCompute) {
continue; continue;
} }
@ -126,7 +126,7 @@ Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
// Get (or create, on first call) the uniform buffer that will receive the // Get (or create, on first call) the uniform buffer that will receive the
// number of workgroups. // number of workgroups.
const ast::Variable* num_workgroups_ubo = nullptr; const Variable* num_workgroups_ubo = nullptr;
auto get_ubo = [&]() { auto get_ubo = [&]() {
if (!num_workgroups_ubo) { if (!num_workgroups_ubo) {
auto* num_workgroups_struct = auto* num_workgroups_struct =
@ -166,11 +166,11 @@ Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
// Now replace all the places where the builtins are accessed with the value // Now replace all the places where the builtins are accessed with the value
// loaded from the uniform buffer. // loaded from the uniform buffer.
for (auto* node : src->ASTNodes().Objects()) { for (auto* node : src->ASTNodes().Objects()) {
auto* accessor = node->As<ast::MemberAccessorExpression>(); auto* accessor = node->As<MemberAccessorExpression>();
if (!accessor) { if (!accessor) {
continue; continue;
} }
auto* ident = accessor->object->As<ast::IdentifierExpression>(); auto* ident = accessor->object->As<IdentifierExpression>();
if (!ident) { if (!ident) {
continue; continue;
} }

View File

@ -94,7 +94,7 @@ struct PackedVec3::State {
/// Create a `__packed_vec3` type with the same element type as `ty`. /// Create a `__packed_vec3` type with the same element type as `ty`.
/// @param ty a three-element vector type /// @param ty a three-element vector type
/// @returns the new AST type /// @returns the new AST type
ast::Type MakePackedVec3(const type::Type* ty) { Type MakePackedVec3(const type::Type* ty) {
auto* vec = ty->As<type::Vector>(); auto* vec = ty->As<type::Vector>();
TINT_ASSERT(Transform, vec != nullptr && vec->Width() == 3); TINT_ASSERT(Transform, vec != nullptr && vec->Width() == 3);
return b.ty(builtin::Builtin::kPackedVec3, CreateASTTypeFor(ctx, vec->type())); return b.ty(builtin::Builtin::kPackedVec3, CreateASTTypeFor(ctx, vec->type()));
@ -109,10 +109,10 @@ struct PackedVec3::State {
/// @param ty the type to rewrite /// @param ty the type to rewrite
/// @param array_element `true` if this is being called for the element of an array /// @param array_element `true` if this is being called for the element of an array
/// @returns the new AST type, or nullptr if rewriting was not necessary /// @returns the new AST type, or nullptr if rewriting was not necessary
ast::Type RewriteType(const type::Type* ty, bool array_element = false) { Type RewriteType(const type::Type* ty, bool array_element = false) {
return Switch( return Switch(
ty, ty,
[&](const type::Vector* vec) -> ast::Type { [&](const type::Vector* vec) -> Type {
if (IsVec3(vec)) { if (IsVec3(vec)) {
if (array_element) { if (array_element) {
// Create a struct with a single `__packed_vec3` member. // Create a struct with a single `__packed_vec3` member.
@ -134,7 +134,7 @@ struct PackedVec3::State {
} }
return {}; return {};
}, },
[&](const type::Matrix* mat) -> ast::Type { [&](const type::Matrix* mat) -> Type {
// Rewrite the matrix as an array of columns that use the aligned wrapper struct. // Rewrite the matrix as an array of columns that use the aligned wrapper struct.
auto new_col_type = RewriteType(mat->ColumnType(), /* array_element */ true); auto new_col_type = RewriteType(mat->ColumnType(), /* array_element */ true);
if (new_col_type) { if (new_col_type) {
@ -142,11 +142,11 @@ struct PackedVec3::State {
} }
return {}; return {};
}, },
[&](const type::Array* arr) -> ast::Type { [&](const type::Array* arr) -> Type {
// Rewrite the array with the modified element type. // Rewrite the array with the modified element type.
auto new_type = RewriteType(arr->ElemType(), /* array_element */ true); auto new_type = RewriteType(arr->ElemType(), /* array_element */ true);
if (new_type) { if (new_type) {
utils::Vector<const ast::Attribute*, 1> attrs; utils::Vector<const Attribute*, 1> attrs;
if (arr->Count()->Is<type::RuntimeArrayCount>()) { if (arr->Count()->Is<type::RuntimeArrayCount>()) {
return b.ty.array(new_type, std::move(attrs)); return b.ty.array(new_type, std::move(attrs));
} else if (auto count = arr->ConstantCount()) { } else if (auto count = arr->ConstantCount()) {
@ -159,21 +159,21 @@ struct PackedVec3::State {
} }
return {}; return {};
}, },
[&](const type::Struct* str) -> ast::Type { [&](const type::Struct* str) -> Type {
if (ContainsVec3(str)) { if (ContainsVec3(str)) {
auto name = rewritten_structs.GetOrCreate(str, [&]() { auto name = rewritten_structs.GetOrCreate(str, [&]() {
utils::Vector<const ast::StructMember*, 4> members; utils::Vector<const StructMember*, 4> members;
for (auto* member : str->Members()) { for (auto* member : str->Members()) {
// If the member type contains a vec3, rewrite it. // If the member type contains a vec3, rewrite it.
auto new_type = RewriteType(member->Type()); auto new_type = RewriteType(member->Type());
if (new_type) { if (new_type) {
// Copy the member attributes. // Copy the member attributes.
bool needs_align = true; bool needs_align = true;
utils::Vector<const ast::Attribute*, 4> attributes; utils::Vector<const Attribute*, 4> attributes;
if (auto* sem_mem = member->As<sem::StructMember>()) { if (auto* sem_mem = member->As<sem::StructMember>()) {
for (auto* attr : sem_mem->Declaration()->attributes) { for (auto* attr : sem_mem->Declaration()->attributes) {
if (attr->IsAnyOf<ast::StructMemberAlignAttribute, if (attr->IsAnyOf<StructMemberAlignAttribute,
ast::StructMemberOffsetAttribute>()) { StructMemberOffsetAttribute>()) {
needs_align = false; needs_align = false;
} }
attributes.Push(ctx.Clone(attr)); attributes.Push(ctx.Clone(attr));
@ -219,12 +219,12 @@ struct PackedVec3::State {
Symbol MakePackUnpackHelper( Symbol MakePackUnpackHelper(
const char* name_prefix, const char* name_prefix,
const type::Type* ty, const type::Type* ty,
const std::function<const ast::Expression*(const ast::Expression*, const type::Type*)>& const std::function<const Expression*(const Expression*, const type::Type*)>&
pack_or_unpack_element, pack_or_unpack_element,
const std::function<ast::Type()>& in_type, const std::function<Type()>& in_type,
const std::function<ast::Type()>& out_type) { const std::function<Type()>& out_type) {
// Allocate a variable to hold the return value of the function. // Allocate a variable to hold the return value of the function.
utils::Vector<const ast::Statement*, 4> statements; utils::Vector<const Statement*, 4> statements;
statements.Push(b.Decl(b.Var("result", out_type()))); statements.Push(b.Decl(b.Var("result", out_type())));
// Helper that generates a loop to copy and pack/unpack elements of an array to the result: // Helper that generates a loop to copy and pack/unpack elements of an array to the result:
@ -256,7 +256,7 @@ struct PackedVec3::State {
[&](const type::Struct* str) { [&](const type::Struct* str) {
// Copy the struct members over one at a time, packing/unpacking as necessary. // Copy the struct members over one at a time, packing/unpacking as necessary.
for (auto* member : str->Members()) { for (auto* member : str->Members()) {
const ast::Expression* element = const Expression* element =
b.MemberAccessor("in", b.Ident(ctx.Clone(member->Name()))); b.MemberAccessor("in", b.Ident(ctx.Clone(member->Name())));
if (ContainsVec3(member->Type())) { if (ContainsVec3(member->Type())) {
element = pack_or_unpack_element(element, member->Type()); element = pack_or_unpack_element(element, member->Type());
@ -280,16 +280,16 @@ struct PackedVec3::State {
/// @param expr the composite value expression to unpack /// @param expr the composite value expression to unpack
/// @param ty the unpacked type /// @param ty the unpacked type
/// @returns an expression that holds the unpacked value /// @returns an expression that holds the unpacked value
const ast::Expression* UnpackComposite(const ast::Expression* expr, const type::Type* ty) { const Expression* UnpackComposite(const Expression* expr, const type::Type* ty) {
auto helper = unpack_helpers.GetOrCreate(ty, [&]() { auto helper = unpack_helpers.GetOrCreate(ty, [&]() {
return MakePackUnpackHelper( return MakePackUnpackHelper(
"tint_unpack_vec3_in_composite", ty, "tint_unpack_vec3_in_composite", ty,
[&](const ast::Expression* element, [&](const Expression* element,
const type::Type* element_type) -> const ast::Expression* { const type::Type* element_type) -> const Expression* {
if (element_type->Is<type::Vector>()) { if (element_type->Is<type::Vector>()) {
// Unpack a `__packed_vec3` by casting it to a regular vec3. // Unpack a `__packed_vec3` by casting it to a regular vec3.
// If it is an array element, extract the vector from the wrapper struct. // If it is an array element, extract the vector from the wrapper struct.
if (element->Is<ast::IndexAccessorExpression>()) { if (element->Is<IndexAccessorExpression>()) {
element = b.MemberAccessor(element, kStructMemberName); element = b.MemberAccessor(element, kStructMemberName);
} }
return b.Call(CreateASTTypeFor(ctx, element_type), element); return b.Call(CreateASTTypeFor(ctx, element_type), element);
@ -308,17 +308,17 @@ struct PackedVec3::State {
/// @param expr the composite value expression to pack /// @param expr the composite value expression to pack
/// @param ty the unpacked type /// @param ty the unpacked type
/// @returns an expression that holds the packed value /// @returns an expression that holds the packed value
const ast::Expression* PackComposite(const ast::Expression* expr, const type::Type* ty) { const Expression* PackComposite(const Expression* expr, const type::Type* ty) {
auto helper = pack_helpers.GetOrCreate(ty, [&]() { auto helper = pack_helpers.GetOrCreate(ty, [&]() {
return MakePackUnpackHelper( return MakePackUnpackHelper(
"tint_pack_vec3_in_composite", ty, "tint_pack_vec3_in_composite", ty,
[&](const ast::Expression* element, [&](const Expression* element,
const type::Type* element_type) -> const ast::Expression* { const type::Type* element_type) -> const Expression* {
if (element_type->Is<type::Vector>()) { if (element_type->Is<type::Vector>()) {
// Pack a vector element by casting it to a packed_vec3. // Pack a vector element by casting it to a packed_vec3.
// If it is an array element, construct a wrapper struct. // If it is an array element, construct a wrapper struct.
auto* packed = b.Call(MakePackedVec3(element_type), element); auto* packed = b.Call(MakePackedVec3(element_type), element);
if (element->Is<ast::IndexAccessorExpression>()) { if (element->Is<IndexAccessorExpression>()) {
packed = b.Call(RewriteType(element_type, true), packed); packed = b.Call(RewriteType(element_type, true), packed);
} }
return packed; return packed;
@ -400,7 +400,7 @@ struct PackedVec3::State {
}, },
[&](const sem::Statement* stmt) { [&](const sem::Statement* stmt) {
// Pack the RHS of assignment statements that are writing to packed types. // Pack the RHS of assignment statements that are writing to packed types.
if (auto* assign = stmt->Declaration()->As<ast::AssignmentStatement>()) { if (auto* assign = stmt->Declaration()->As<AssignmentStatement>()) {
auto* lhs = sem.GetVal(assign->lhs); auto* lhs = sem.GetVal(assign->lhs);
auto* rhs = sem.GetVal(assign->rhs); auto* rhs = sem.GetVal(assign->rhs);
if (!ContainsVec3(rhs->Type()) || if (!ContainsVec3(rhs->Type()) ||
@ -463,7 +463,7 @@ struct PackedVec3::State {
for (auto* expr : to_unpack_sorted) { for (auto* expr : to_unpack_sorted) {
TINT_ASSERT(Transform, ContainsVec3(expr->Type())); TINT_ASSERT(Transform, ContainsVec3(expr->Type()));
auto* packed = ctx.Clone(expr->Declaration()); auto* packed = ctx.Clone(expr->Declaration());
const ast::Expression* unpacked = nullptr; const Expression* unpacked = nullptr;
if (IsVec3(expr->Type())) { if (IsVec3(expr->Type())) {
if (expr->UnwrapLoad()->Is<sem::IndexAccessorExpression>()) { if (expr->UnwrapLoad()->Is<sem::IndexAccessorExpression>()) {
// If we are unpacking a vec3 from an array element, extract the vector from the // If we are unpacking a vec3 from an array element, extract the vector from the
@ -484,7 +484,7 @@ struct PackedVec3::State {
for (auto* expr : to_pack_sorted) { for (auto* expr : to_pack_sorted) {
TINT_ASSERT(Transform, ContainsVec3(expr->Type())); TINT_ASSERT(Transform, ContainsVec3(expr->Type()));
auto* unpacked = ctx.Clone(expr->Declaration()); auto* unpacked = ctx.Clone(expr->Declaration());
const ast::Expression* packed = nullptr; const Expression* packed = nullptr;
if (IsVec3(expr->Type())) { if (IsVec3(expr->Type())) {
// Cast the regular vec3 to a packed vector type. // Cast the regular vec3 to a packed vector type.
packed = b.Call(MakePackedVec3(expr->Type()), unpacked); packed = b.Call(MakePackedVec3(expr->Type()), unpacked);

View File

@ -33,8 +33,8 @@ namespace tint::ast::transform {
namespace { namespace {
void CreatePadding(utils::Vector<const ast::StructMember*, 8>* new_members, void CreatePadding(utils::Vector<const StructMember*, 8>* new_members,
utils::Hashset<const ast::StructMember*, 8>* padding_members, utils::Hashset<const StructMember*, 8>* padding_members,
ProgramBuilder* b, ProgramBuilder* b,
uint32_t bytes) { uint32_t bytes) {
const size_t count = bytes / 4u; const size_t count = bytes / 4u;
@ -59,17 +59,17 @@ Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, Dat
CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
auto& sem = src->Sem(); auto& sem = src->Sem();
std::unordered_map<const ast::Struct*, const ast::Struct*> replaced_structs; std::unordered_map<const Struct*, const Struct*> replaced_structs;
utils::Hashset<const ast::StructMember*, 8> padding_members; utils::Hashset<const StructMember*, 8> padding_members;
ctx.ReplaceAll([&](const ast::Struct* ast_str) -> const ast::Struct* { ctx.ReplaceAll([&](const Struct* ast_str) -> const Struct* {
auto* str = sem.Get(ast_str); auto* str = sem.Get(ast_str);
if (!str || !str->IsHostShareable()) { if (!str || !str->IsHostShareable()) {
return nullptr; return nullptr;
} }
uint32_t offset = 0; uint32_t offset = 0;
bool has_runtime_sized_array = false; bool has_runtime_sized_array = false;
utils::Vector<const ast::StructMember*, 8> new_members; utils::Vector<const StructMember*, 8> new_members;
for (auto* mem : str->Members()) { for (auto* mem : str->Members()) {
auto name = mem->Name().Name(); auto name = mem->Name().Name();
@ -104,19 +104,18 @@ Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, Dat
CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset); CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset);
} }
utils::Vector<const ast::Attribute*, 1> struct_attribs; utils::Vector<const Attribute*, 1> struct_attribs;
if (!padding_members.IsEmpty()) { if (!padding_members.IsEmpty()) {
struct_attribs = struct_attribs = utils::Vector{b.Disable(DisabledValidation::kIgnoreStructMemberLimit)};
utils::Vector{b.Disable(ast::DisabledValidation::kIgnoreStructMemberLimit)};
} }
auto* new_struct = b.create<ast::Struct>(ctx.Clone(ast_str->name), std::move(new_members), auto* new_struct = b.create<Struct>(ctx.Clone(ast_str->name), std::move(new_members),
std::move(struct_attribs)); std::move(struct_attribs));
replaced_structs[ast_str] = new_struct; replaced_structs[ast_str] = new_struct;
return new_struct; return new_struct;
}); });
ctx.ReplaceAll([&](const ast::CallExpression* ast_call) -> const ast::CallExpression* { ctx.ReplaceAll([&](const CallExpression* ast_call) -> const CallExpression* {
if (ast_call->args.Length() == 0) { if (ast_call->args.Length() == 0) {
return nullptr; return nullptr;
} }
@ -139,7 +138,7 @@ Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, Dat
return nullptr; return nullptr;
} }
utils::Vector<const ast::Expression*, 8> new_args; utils::Vector<const Expression*, 8> new_args;
auto* arg = ast_call->args.begin(); auto* arg = ast_call->args.begin();
for (auto* member : new_struct->members) { for (auto* member : new_struct->members) {

View File

@ -44,13 +44,13 @@ struct PreservePadding::State {
/// @returns the ApplyResult /// @returns the ApplyResult
ApplyResult Run() { ApplyResult Run() {
// Gather a list of assignments that need to be transformed. // Gather a list of assignments that need to be transformed.
std::unordered_set<const ast::AssignmentStatement*> assignments_to_transform; std::unordered_set<const AssignmentStatement*> assignments_to_transform;
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
Switch( Switch(
node, // node, //
[&](const ast::AssignmentStatement* assign) { [&](const AssignmentStatement* assign) {
auto* ty = sem.GetVal(assign->lhs)->Type(); auto* ty = sem.GetVal(assign->lhs)->Type();
if (assign->lhs->Is<ast::PhonyExpression>()) { if (assign->lhs->Is<PhonyExpression>()) {
// Ignore phony assignment. // Ignore phony assignment.
return; return;
} }
@ -65,7 +65,7 @@ struct PreservePadding::State {
assignments_to_transform.insert(assign); assignments_to_transform.insert(assign);
} }
}, },
[&](const ast::Enable* enable) { [&](const Enable* enable) {
// Check if the full pointer parameters extension is already enabled. // Check if the full pointer parameters extension is already enabled.
if (enable->HasExtension( if (enable->HasExtension(
builtin::Extension::kChromiumExperimentalFullPtrParameters)) { builtin::Extension::kChromiumExperimentalFullPtrParameters)) {
@ -78,7 +78,7 @@ struct PreservePadding::State {
} }
// Replace all assignments that include padding with decomposed versions. // Replace all assignments that include padding with decomposed versions.
ctx.ReplaceAll([&](const ast::AssignmentStatement* assign) -> const ast::Statement* { ctx.ReplaceAll([&](const AssignmentStatement* assign) -> const Statement* {
if (!assignments_to_transform.count(assign)) { if (!assignments_to_transform.count(assign)) {
return nullptr; return nullptr;
} }
@ -96,9 +96,9 @@ struct PreservePadding::State {
/// @param lhs the lhs expression (in the destination program) /// @param lhs the lhs expression (in the destination program)
/// @param rhs the rhs expression (in the destination program) /// @param rhs the rhs expression (in the destination program)
/// @returns the statement that performs the assignment /// @returns the statement that performs the assignment
const ast::Statement* MakeAssignment(const type::Type* ty, const Statement* MakeAssignment(const type::Type* ty,
const ast::Expression* lhs, const Expression* lhs,
const ast::Expression* rhs) { const Expression* rhs) {
if (!HasPadding(ty)) { if (!HasPadding(ty)) {
// No padding - use a regular assignment. // No padding - use a regular assignment.
return b.Assign(lhs, rhs); return b.Assign(lhs, rhs);
@ -120,7 +120,7 @@ struct PreservePadding::State {
EnableExtension(); EnableExtension();
auto helper = helpers.GetOrCreate(ty, [&]() { auto helper = helpers.GetOrCreate(ty, [&]() {
auto helper_name = b.Symbols().New("assign_and_preserve_padding"); auto helper_name = b.Symbols().New("assign_and_preserve_padding");
utils::Vector<const ast::Parameter*, 2> params = { utils::Vector<const Parameter*, 2> params = {
b.Param(kDestParamName, b.Param(kDestParamName,
b.ty.pointer(CreateASTTypeFor(ctx, ty), builtin::AddressSpace::kStorage, b.ty.pointer(CreateASTTypeFor(ctx, ty), builtin::AddressSpace::kStorage,
builtin::Access::kReadWrite)), builtin::Access::kReadWrite)),
@ -137,7 +137,7 @@ struct PreservePadding::State {
[&](const type::Array* arr) { [&](const type::Array* arr) {
// Call a helper function that uses a loop to assigns each element separately. // Call a helper function that uses a loop to assigns each element separately.
return call_helper([&]() { return call_helper([&]() {
utils::Vector<const ast::Statement*, 8> body; utils::Vector<const Statement*, 8> body;
auto* idx = b.Var("i", b.Expr(0_u)); auto* idx = b.Var("i", b.Expr(0_u));
body.Push( body.Push(
b.For(b.Decl(idx), b.LessThan(idx, u32(arr->ConstantCount().value())), b.For(b.Decl(idx), b.LessThan(idx, u32(arr->ConstantCount().value())),
@ -151,7 +151,7 @@ struct PreservePadding::State {
[&](const type::Matrix* mat) { [&](const type::Matrix* mat) {
// Call a helper function that assigns each column separately. // Call a helper function that assigns each column separately.
return call_helper([&]() { return call_helper([&]() {
utils::Vector<const ast::Statement*, 4> body; utils::Vector<const Statement*, 4> body;
for (uint32_t i = 0; i < mat->columns(); i++) { for (uint32_t i = 0; i < mat->columns(); i++) {
body.Push(MakeAssignment(mat->ColumnType(), body.Push(MakeAssignment(mat->ColumnType(),
b.IndexAccessor(b.Deref(kDestParamName), u32(i)), b.IndexAccessor(b.Deref(kDestParamName), u32(i)),
@ -163,7 +163,7 @@ struct PreservePadding::State {
[&](const type::Struct* str) { [&](const type::Struct* str) {
// Call a helper function that assigns each member separately. // Call a helper function that assigns each member separately.
return call_helper([&]() { return call_helper([&]() {
utils::Vector<const ast::Statement*, 8> body; utils::Vector<const Statement*, 8> body;
for (auto member : str->Members()) { for (auto member : str->Members()) {
auto name = member->Name().Name(); auto name = member->Name().Name();
body.Push(MakeAssignment(member->Type(), body.Push(MakeAssignment(member->Type(),

View File

@ -69,7 +69,7 @@ Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
} }
} }
if (auto* src_var_decl = expr->Stmt()->Declaration()->As<ast::VariableDeclStatement>()) { if (auto* src_var_decl = expr->Stmt()->Declaration()->As<VariableDeclStatement>()) {
if (src_var_decl->variable->initializer == expr->Declaration()) { if (src_var_decl->variable->initializer == expr->Declaration()) {
// This statement is just a variable declaration with the initializer as the // This statement is just a variable declaration with the initializer as the
// initializer value. This is what we're attempting to transform to, and so // initializer value. This is what we're attempting to transform to, and so
@ -84,7 +84,7 @@ Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
// A list of expressions that should be hoisted. // A list of expressions that should be hoisted.
utils::Vector<const sem::ValueExpression*, 32> to_hoist; utils::Vector<const sem::ValueExpression*, 32> to_hoist;
// A set of expressions that are constant, which _may_ need to be hoisted. // A set of expressions that are constant, which _may_ need to be hoisted.
utils::Hashset<const ast::Expression*, 32> const_chains; utils::Hashset<const Expression*, 32> const_chains;
// Walk the AST nodes. This order guarantees that leaf-expressions are visited first. // Walk the AST nodes. This order guarantees that leaf-expressions are visited first.
for (auto* node : src->ASTNodes().Objects()) { for (auto* node : src->ASTNodes().Objects()) {
@ -104,11 +104,9 @@ Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
// visit leaf-expressions first, this means the content of const_chains only // visit leaf-expressions first, this means the content of const_chains only
// contains the outer-most constant expressions. // contains the outer-most constant expressions.
auto* expr = sem->Declaration(); auto* expr = sem->Declaration();
bool ok = ast::TraverseExpressions( bool ok = TraverseExpressions(expr, b.Diagnostics(), [&](const Expression* child) {
expr, b.Diagnostics(), [&](const ast::Expression* child) {
const_chains.Remove(child); const_chains.Remove(child);
return child == expr ? ast::TraverseAction::Descend return child == expr ? TraverseAction::Descend : TraverseAction::Skip;
: ast::TraverseAction::Skip;
}); });
if (!ok) { if (!ok) {
return Program(std::move(b)); return Program(std::move(b));

View File

@ -97,49 +97,48 @@ struct DecomposeSideEffects : tint::utils::Castable<PromoteSideEffectsToDecl, Tr
// need to be hoisted to ensure order of evaluation, both those that give // need to be hoisted to ensure order of evaluation, both those that give
// side-effects, as well as those that receive, and returns a set of these // side-effects, as well as those that receive, and returns a set of these
// expressions. // expressions.
using ToHoistSet = std::unordered_set<const ast::Expression*>; using ToHoistSet = std::unordered_set<const Expression*>;
class DecomposeSideEffects::CollectHoistsState : public StateBase { class DecomposeSideEffects::CollectHoistsState : public StateBase {
// Expressions to hoist because they either cause or receive side-effects. // Expressions to hoist because they either cause or receive side-effects.
ToHoistSet to_hoist; ToHoistSet to_hoist;
// Used to mark expressions as not or no longer having side-effects. // Used to mark expressions as not or no longer having side-effects.
std::unordered_set<const ast::Expression*> no_side_effects; std::unordered_set<const Expression*> no_side_effects;
// Returns true if `expr` has side-effects. Unlike invoking // Returns true if `expr` has side-effects. Unlike invoking
// sem::ValueExpression::HasSideEffects(), this function takes into account whether // sem::ValueExpression::HasSideEffects(), this function takes into account whether
// `expr` has been hoisted, returning false in that case. Furthermore, it // `expr` has been hoisted, returning false in that case. Furthermore, it
// returns the correct result on parent expression nodes by traversing the // returns the correct result on parent expression nodes by traversing the
// expression tree, memoizing the results to ensure O(1) amortized lookup. // expression tree, memoizing the results to ensure O(1) amortized lookup.
bool HasSideEffects(const ast::Expression* expr) { bool HasSideEffects(const Expression* expr) {
if (no_side_effects.count(expr)) { if (no_side_effects.count(expr)) {
return false; return false;
} }
return Switch( return Switch(
expr, expr, [&](const CallExpression* e) -> bool { return sem.Get(e)->HasSideEffects(); },
[&](const ast::CallExpression* e) -> bool { return sem.Get(e)->HasSideEffects(); }, [&](const BinaryExpression* e) {
[&](const ast::BinaryExpression* e) {
if (HasSideEffects(e->lhs) || HasSideEffects(e->rhs)) { if (HasSideEffects(e->lhs) || HasSideEffects(e->rhs)) {
return true; return true;
} }
no_side_effects.insert(e); no_side_effects.insert(e);
return false; return false;
}, },
[&](const ast::IndexAccessorExpression* e) { [&](const IndexAccessorExpression* e) {
if (HasSideEffects(e->object) || HasSideEffects(e->index)) { if (HasSideEffects(e->object) || HasSideEffects(e->index)) {
return true; return true;
} }
no_side_effects.insert(e); no_side_effects.insert(e);
return false; return false;
}, },
[&](const ast::MemberAccessorExpression* e) { [&](const MemberAccessorExpression* e) {
if (HasSideEffects(e->object)) { if (HasSideEffects(e->object)) {
return true; return true;
} }
no_side_effects.insert(e); no_side_effects.insert(e);
return false; return false;
}, },
[&](const ast::BitcastExpression* e) { // [&](const BitcastExpression* e) { //
if (HasSideEffects(e->expr)) { if (HasSideEffects(e->expr)) {
return true; return true;
} }
@ -147,22 +146,22 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
return false; return false;
}, },
[&](const ast::UnaryOpExpression* e) { // [&](const UnaryOpExpression* e) { //
if (HasSideEffects(e->expr)) { if (HasSideEffects(e->expr)) {
return true; return true;
} }
no_side_effects.insert(e); no_side_effects.insert(e);
return false; return false;
}, },
[&](const ast::IdentifierExpression* e) { [&](const IdentifierExpression* e) {
no_side_effects.insert(e); no_side_effects.insert(e);
return false; return false;
}, },
[&](const ast::LiteralExpression* e) { [&](const LiteralExpression* e) {
no_side_effects.insert(e); no_side_effects.insert(e);
return false; return false;
}, },
[&](const ast::PhonyExpression* e) { [&](const PhonyExpression* e) {
no_side_effects.insert(e); no_side_effects.insert(e);
return false; return false;
}, },
@ -173,14 +172,14 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
} }
// Adds `e` to `to_hoist` for hoisting to a let later on. // Adds `e` to `to_hoist` for hoisting to a let later on.
void Hoist(const ast::Expression* e) { void Hoist(const Expression* e) {
no_side_effects.insert(e); no_side_effects.insert(e);
to_hoist.emplace(e); to_hoist.emplace(e);
} }
// Hoists any expressions in `maybe_hoist` and clears it // Hoists any expressions in `maybe_hoist` and clears it
template <size_t N> template <size_t N>
void Flush(tint::utils::Vector<const ast::Expression*, N>& maybe_hoist) { void Flush(tint::utils::Vector<const Expression*, N>& maybe_hoist) {
for (auto* m : maybe_hoist) { for (auto* m : maybe_hoist) {
Hoist(m); Hoist(m);
} }
@ -198,13 +197,13 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
// over-hoist the lhs expressions, as these may be be chained to refer to a // over-hoist the lhs expressions, as these may be be chained to refer to a
// single memory location. // single memory location.
template <size_t N> template <size_t N>
bool ProcessExpression(const ast::Expression* expr, bool ProcessExpression(const Expression* expr,
tint::utils::Vector<const ast::Expression*, N>& maybe_hoist) { tint::utils::Vector<const Expression*, N>& maybe_hoist) {
auto process = [&](const ast::Expression* e) -> bool { auto process = [&](const Expression* e) -> bool {
return ProcessExpression(e, maybe_hoist); return ProcessExpression(e, maybe_hoist);
}; };
auto default_process = [&](const ast::Expression* e) { auto default_process = [&](const Expression* e) {
auto maybe = process(e); auto maybe = process(e);
if (maybe) { if (maybe) {
maybe_hoist.Push(e); maybe_hoist.Push(e);
@ -215,7 +214,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
return false; return false;
}; };
auto binary_process = [&](const ast::Expression* lhs, const ast::Expression* rhs) { auto binary_process = [&](const Expression* lhs, const Expression* rhs) {
// If neither side causes side-effects, but at least one receives them, // If neither side causes side-effects, but at least one receives them,
// let parent node hoist. This avoids over-hoisting side-effect receivers // let parent node hoist. This avoids over-hoisting side-effect receivers
// of compound binary expressions (e.g. for "((a && b) && c) && f()", we // of compound binary expressions (e.g. for "((a && b) && c) && f()", we
@ -235,8 +234,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
return false; return false;
}; };
auto accessor_process = [&](const ast::Expression* lhs, auto accessor_process = [&](const Expression* lhs, const Expression* rhs = nullptr) {
const ast::Expression* rhs = nullptr) {
auto maybe = process(lhs); auto maybe = process(lhs);
// If lhs is a variable, let parent node hoist otherwise flush it right // If lhs is a variable, let parent node hoist otherwise flush it right
// away. This is to avoid over-hoisting the lhs of accessor chains (e.g. // away. This is to avoid over-hoisting the lhs of accessor chains (e.g.
@ -255,7 +253,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
return Switch( return Switch(
expr, expr,
[&](const ast::CallExpression* e) -> bool { [&](const CallExpression* e) -> bool {
// We eagerly flush any variables in maybe_hoist for the current // We eagerly flush any variables in maybe_hoist for the current
// call expression. Then we scope maybe_hoist to the processing of // call expression. Then we scope maybe_hoist to the processing of
// the call args. This ensures that given: g(c, a(0), d) we hoist // the call args. This ensures that given: g(c, a(0), d) we hoist
@ -276,7 +274,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
// no_side_effects() first. // no_side_effects() first.
return true; return true;
}, },
[&](const ast::IdentifierExpression* e) { [&](const IdentifierExpression* e) {
if (auto* sem_e = sem.GetVal(e)) { if (auto* sem_e = sem.GetVal(e)) {
if (auto* var_user = sem_e->UnwrapLoad()->As<sem::VariableUser>()) { if (auto* var_user = sem_e->UnwrapLoad()->As<sem::VariableUser>()) {
// Don't hoist constants. // Don't hoist constants.
@ -297,7 +295,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
} }
return false; return false;
}, },
[&](const ast::BinaryExpression* e) { [&](const BinaryExpression* e) {
if (e->IsLogical() && HasSideEffects(e)) { if (e->IsLogical() && HasSideEffects(e)) {
// Don't hoist children of logical binary expressions with // Don't hoist children of logical binary expressions with
// side-effects. These will be handled by DecomposeState. // side-effects. These will be handled by DecomposeState.
@ -307,27 +305,25 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
} }
return binary_process(e->lhs, e->rhs); return binary_process(e->lhs, e->rhs);
}, },
[&](const ast::BitcastExpression* e) { // [&](const BitcastExpression* e) { //
return process(e->expr); return process(e->expr);
}, },
[&](const ast::UnaryOpExpression* e) { // [&](const UnaryOpExpression* e) { //
auto r = process(e->expr); auto r = process(e->expr);
// Don't hoist address-of expressions. // Don't hoist address-of expressions.
// E.g. for "g(&b, a(0))", we hoist "a(0)" only. // E.g. for "g(&b, a(0))", we hoist "a(0)" only.
if (e->op == ast::UnaryOp::kAddressOf) { if (e->op == UnaryOp::kAddressOf) {
return false; return false;
} }
return r; return r;
}, },
[&](const ast::IndexAccessorExpression* e) { [&](const IndexAccessorExpression* e) { return accessor_process(e->object, e->index); },
return accessor_process(e->object, e->index); [&](const MemberAccessorExpression* e) { return accessor_process(e->object); },
}, [&](const LiteralExpression*) {
[&](const ast::MemberAccessorExpression* e) { return accessor_process(e->object); },
[&](const ast::LiteralExpression*) {
// Leaf // Leaf
return false; return false;
}, },
[&](const ast::PhonyExpression*) { [&](const PhonyExpression*) {
// Leaf // Leaf
return false; return false;
}, },
@ -338,12 +334,12 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
} }
// Starts the recursive processing of a statement's expression(s) to hoist side-effects to lets. // Starts the recursive processing of a statement's expression(s) to hoist side-effects to lets.
void ProcessExpression(const ast::Expression* expr) { void ProcessExpression(const Expression* expr) {
if (!expr) { if (!expr) {
return; return;
} }
tint::utils::Vector<const ast::Expression*, 8> maybe_hoist; tint::utils::Vector<const Expression*, 8> maybe_hoist;
ProcessExpression(expr, maybe_hoist); ProcessExpression(expr, maybe_hoist);
} }
@ -354,31 +350,31 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
// Traverse all statements, recursively processing their expression tree(s) // Traverse all statements, recursively processing their expression tree(s)
// to hoist side-effects to lets. // to hoist side-effects to lets.
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
auto* stmt = node->As<ast::Statement>(); auto* stmt = node->As<Statement>();
if (!stmt) { if (!stmt) {
continue; continue;
} }
Switch( Switch(
stmt, // stmt, //
[&](const ast::AssignmentStatement* s) { [&](const AssignmentStatement* s) {
tint::utils::Vector<const ast::Expression*, 8> maybe_hoist; tint::utils::Vector<const Expression*, 8> maybe_hoist;
ProcessExpression(s->lhs, maybe_hoist); ProcessExpression(s->lhs, maybe_hoist);
ProcessExpression(s->rhs, maybe_hoist); ProcessExpression(s->rhs, maybe_hoist);
}, },
[&](const ast::CallStatement* s) { // [&](const CallStatement* s) { //
ProcessExpression(s->expr); ProcessExpression(s->expr);
}, },
[&](const ast::ForLoopStatement* s) { ProcessExpression(s->condition); }, [&](const ForLoopStatement* s) { ProcessExpression(s->condition); },
[&](const ast::WhileStatement* s) { ProcessExpression(s->condition); }, [&](const WhileStatement* s) { ProcessExpression(s->condition); },
[&](const ast::IfStatement* s) { // [&](const IfStatement* s) { //
ProcessExpression(s->condition); ProcessExpression(s->condition);
}, },
[&](const ast::ReturnStatement* s) { // [&](const ReturnStatement* s) { //
ProcessExpression(s->value); ProcessExpression(s->value);
}, },
[&](const ast::SwitchStatement* s) { ProcessExpression(s->condition); }, [&](const SwitchStatement* s) { ProcessExpression(s->condition); },
[&](const ast::VariableDeclStatement* s) { [&](const VariableDeclStatement* s) {
ProcessExpression(s->variable->initializer); ProcessExpression(s->variable->initializer);
}); });
} }
@ -394,20 +390,20 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
ToHoistSet to_hoist; ToHoistSet to_hoist;
// Returns true if `binary_expr` should be decomposed for short-circuit eval. // Returns true if `binary_expr` should be decomposed for short-circuit eval.
bool IsLogicalWithSideEffects(const ast::BinaryExpression* binary_expr) { bool IsLogicalWithSideEffects(const BinaryExpression* binary_expr) {
return binary_expr->IsLogical() && (sem.GetVal(binary_expr->lhs)->HasSideEffects() || return binary_expr->IsLogical() && (sem.GetVal(binary_expr->lhs)->HasSideEffects() ||
sem.GetVal(binary_expr->rhs)->HasSideEffects()); sem.GetVal(binary_expr->rhs)->HasSideEffects());
} }
// Recursive function used to decompose an expression for short-circuit eval. // Recursive function used to decompose an expression for short-circuit eval.
template <size_t N> template <size_t N>
const ast::Expression* Decompose(const ast::Expression* expr, const Expression* Decompose(const Expression* expr,
tint::utils::Vector<const ast::Statement*, N>* curr_stmts) { tint::utils::Vector<const Statement*, N>* curr_stmts) {
// Helper to avoid passing in same args. // Helper to avoid passing in same args.
auto decompose = [&](auto& e) { return Decompose(e, curr_stmts); }; auto decompose = [&](auto& e) { return Decompose(e, curr_stmts); };
// Clones `expr`, possibly hoisting it to a let. // Clones `expr`, possibly hoisting it to a let.
auto clone_maybe_hoisted = [&](const ast::Expression* e) -> const ast::Expression* { auto clone_maybe_hoisted = [&](const Expression* e) -> const Expression* {
if (to_hoist.count(e)) { if (to_hoist.count(e)) {
auto name = b.Symbols().New(); auto name = b.Symbols().New();
auto* v = b.Let(name, ctx.Clone(e)); auto* v = b.Let(name, ctx.Clone(e));
@ -420,7 +416,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
return Switch( return Switch(
expr, expr,
[&](const ast::BinaryExpression* bin_expr) -> const ast::Expression* { [&](const BinaryExpression* bin_expr) -> const Expression* {
if (!IsLogicalWithSideEffects(bin_expr)) { if (!IsLogicalWithSideEffects(bin_expr)) {
// No short-circuit, emit usual binary expr // No short-circuit, emit usual binary expr
ctx.Replace(bin_expr->lhs, decompose(bin_expr->lhs)); ctx.Replace(bin_expr->lhs, decompose(bin_expr->lhs));
@ -461,16 +457,16 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
auto name = b.Sym(); auto name = b.Sym();
curr_stmts->Push(b.Decl(b.Var(name, decompose(bin_expr->lhs)))); curr_stmts->Push(b.Decl(b.Var(name, decompose(bin_expr->lhs))));
const ast::Expression* if_cond = nullptr; const Expression* if_cond = nullptr;
if (bin_expr->IsLogicalOr()) { if (bin_expr->IsLogicalOr()) {
if_cond = b.Not(name); if_cond = b.Not(name);
} else { } else {
if_cond = b.Expr(name); if_cond = b.Expr(name);
} }
const ast::BlockStatement* if_body = nullptr; const BlockStatement* if_body = nullptr;
{ {
tint::utils::Vector<const ast::Statement*, N> stmts; tint::utils::Vector<const Statement*, N> stmts;
TINT_SCOPED_ASSIGNMENT(curr_stmts, &stmts); TINT_SCOPED_ASSIGNMENT(curr_stmts, &stmts);
auto* new_rhs = decompose(bin_expr->rhs); auto* new_rhs = decompose(bin_expr->rhs);
curr_stmts->Push(b.Assign(name, new_rhs)); curr_stmts->Push(b.Assign(name, new_rhs));
@ -481,36 +477,36 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
return b.Expr(name); return b.Expr(name);
}, },
[&](const ast::IndexAccessorExpression* idx) { [&](const IndexAccessorExpression* idx) {
ctx.Replace(idx->object, decompose(idx->object)); ctx.Replace(idx->object, decompose(idx->object));
ctx.Replace(idx->index, decompose(idx->index)); ctx.Replace(idx->index, decompose(idx->index));
return clone_maybe_hoisted(idx); return clone_maybe_hoisted(idx);
}, },
[&](const ast::BitcastExpression* bitcast) { [&](const BitcastExpression* bitcast) {
ctx.Replace(bitcast->expr, decompose(bitcast->expr)); ctx.Replace(bitcast->expr, decompose(bitcast->expr));
return clone_maybe_hoisted(bitcast); return clone_maybe_hoisted(bitcast);
}, },
[&](const ast::CallExpression* call) { [&](const CallExpression* call) {
for (auto* a : call->args) { for (auto* a : call->args) {
ctx.Replace(a, decompose(a)); ctx.Replace(a, decompose(a));
} }
return clone_maybe_hoisted(call); return clone_maybe_hoisted(call);
}, },
[&](const ast::MemberAccessorExpression* member) { [&](const MemberAccessorExpression* member) {
ctx.Replace(member->object, decompose(member->object)); ctx.Replace(member->object, decompose(member->object));
return clone_maybe_hoisted(member); return clone_maybe_hoisted(member);
}, },
[&](const ast::UnaryOpExpression* unary) { [&](const UnaryOpExpression* unary) {
ctx.Replace(unary->expr, decompose(unary->expr)); ctx.Replace(unary->expr, decompose(unary->expr));
return clone_maybe_hoisted(unary); return clone_maybe_hoisted(unary);
}, },
[&](const ast::LiteralExpression* lit) { [&](const LiteralExpression* lit) {
return clone_maybe_hoisted(lit); // Leaf expression, just clone as is return clone_maybe_hoisted(lit); // Leaf expression, just clone as is
}, },
[&](const ast::IdentifierExpression* id) { [&](const IdentifierExpression* id) {
return clone_maybe_hoisted(id); // Leaf expression, just clone as is return clone_maybe_hoisted(id); // Leaf expression, just clone as is
}, },
[&](const ast::PhonyExpression* phony) { [&](const PhonyExpression* phony) {
return clone_maybe_hoisted(phony); // Leaf expression, just clone as is return clone_maybe_hoisted(phony); // Leaf expression, just clone as is
}, },
[&](Default) { [&](Default) {
@ -522,8 +518,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
// Inserts statements in `stmts` before `stmt` // Inserts statements in `stmts` before `stmt`
template <size_t N> template <size_t N>
void InsertBefore(tint::utils::Vector<const ast::Statement*, N>& stmts, void InsertBefore(tint::utils::Vector<const Statement*, N>& stmts, const Statement* stmt) {
const ast::Statement* stmt) {
if (!stmts.IsEmpty()) { if (!stmts.IsEmpty()) {
auto ip = utils::GetInsertionPoint(ctx, stmt); auto ip = utils::GetInsertionPoint(ctx, stmt);
for (auto* s : stmts) { for (auto* s : stmts) {
@ -534,86 +529,86 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
// Decomposes expressions of `stmt`, returning a replacement statement or // Decomposes expressions of `stmt`, returning a replacement statement or
// nullptr if not replacing it. // nullptr if not replacing it.
const ast::Statement* DecomposeStatement(const ast::Statement* stmt) { const Statement* DecomposeStatement(const Statement* stmt) {
return Switch( return Switch(
stmt, stmt,
[&](const ast::AssignmentStatement* s) -> const ast::Statement* { [&](const AssignmentStatement* s) -> const Statement* {
if (!sem.GetVal(s->lhs)->HasSideEffects() && if (!sem.GetVal(s->lhs)->HasSideEffects() &&
!sem.GetVal(s->rhs)->HasSideEffects()) { !sem.GetVal(s->rhs)->HasSideEffects()) {
return nullptr; return nullptr;
} }
// lhs before rhs // lhs before rhs
tint::utils::Vector<const ast::Statement*, 8> stmts; tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->lhs, Decompose(s->lhs, &stmts)); ctx.Replace(s->lhs, Decompose(s->lhs, &stmts));
ctx.Replace(s->rhs, Decompose(s->rhs, &stmts)); ctx.Replace(s->rhs, Decompose(s->rhs, &stmts));
InsertBefore(stmts, s); InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s); return ctx.CloneWithoutTransform(s);
}, },
[&](const ast::CallStatement* s) -> const ast::Statement* { [&](const CallStatement* s) -> const Statement* {
if (!sem.Get(s->expr)->HasSideEffects()) { if (!sem.Get(s->expr)->HasSideEffects()) {
return nullptr; return nullptr;
} }
tint::utils::Vector<const ast::Statement*, 8> stmts; tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->expr, Decompose(s->expr, &stmts)); ctx.Replace(s->expr, Decompose(s->expr, &stmts));
InsertBefore(stmts, s); InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s); return ctx.CloneWithoutTransform(s);
}, },
[&](const ast::ForLoopStatement* s) -> const ast::Statement* { [&](const ForLoopStatement* s) -> const Statement* {
if (!s->condition || !sem.GetVal(s->condition)->HasSideEffects()) { if (!s->condition || !sem.GetVal(s->condition)->HasSideEffects()) {
return nullptr; return nullptr;
} }
tint::utils::Vector<const ast::Statement*, 8> stmts; tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->condition, Decompose(s->condition, &stmts)); ctx.Replace(s->condition, Decompose(s->condition, &stmts));
InsertBefore(stmts, s); InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s); return ctx.CloneWithoutTransform(s);
}, },
[&](const ast::WhileStatement* s) -> const ast::Statement* { [&](const WhileStatement* s) -> const Statement* {
if (!sem.GetVal(s->condition)->HasSideEffects()) { if (!sem.GetVal(s->condition)->HasSideEffects()) {
return nullptr; return nullptr;
} }
tint::utils::Vector<const ast::Statement*, 8> stmts; tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->condition, Decompose(s->condition, &stmts)); ctx.Replace(s->condition, Decompose(s->condition, &stmts));
InsertBefore(stmts, s); InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s); return ctx.CloneWithoutTransform(s);
}, },
[&](const ast::IfStatement* s) -> const ast::Statement* { [&](const IfStatement* s) -> const Statement* {
if (!sem.GetVal(s->condition)->HasSideEffects()) { if (!sem.GetVal(s->condition)->HasSideEffects()) {
return nullptr; return nullptr;
} }
tint::utils::Vector<const ast::Statement*, 8> stmts; tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->condition, Decompose(s->condition, &stmts)); ctx.Replace(s->condition, Decompose(s->condition, &stmts));
InsertBefore(stmts, s); InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s); return ctx.CloneWithoutTransform(s);
}, },
[&](const ast::ReturnStatement* s) -> const ast::Statement* { [&](const ReturnStatement* s) -> const Statement* {
if (!s->value || !sem.GetVal(s->value)->HasSideEffects()) { if (!s->value || !sem.GetVal(s->value)->HasSideEffects()) {
return nullptr; return nullptr;
} }
tint::utils::Vector<const ast::Statement*, 8> stmts; tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->value, Decompose(s->value, &stmts)); ctx.Replace(s->value, Decompose(s->value, &stmts));
InsertBefore(stmts, s); InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s); return ctx.CloneWithoutTransform(s);
}, },
[&](const ast::SwitchStatement* s) -> const ast::Statement* { [&](const SwitchStatement* s) -> const Statement* {
if (!sem.Get(s->condition)) { if (!sem.Get(s->condition)) {
return nullptr; return nullptr;
} }
tint::utils::Vector<const ast::Statement*, 8> stmts; tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(s->condition, Decompose(s->condition, &stmts)); ctx.Replace(s->condition, Decompose(s->condition, &stmts));
InsertBefore(stmts, s); InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s); return ctx.CloneWithoutTransform(s);
}, },
[&](const ast::VariableDeclStatement* s) -> const ast::Statement* { [&](const VariableDeclStatement* s) -> const Statement* {
auto* var = s->variable; auto* var = s->variable;
if (!var->initializer || !sem.GetVal(var->initializer)->HasSideEffects()) { if (!var->initializer || !sem.GetVal(var->initializer)->HasSideEffects()) {
return nullptr; return nullptr;
} }
tint::utils::Vector<const ast::Statement*, 8> stmts; tint::utils::Vector<const Statement*, 8> stmts;
ctx.Replace(var->initializer, Decompose(var->initializer, &stmts)); ctx.Replace(var->initializer, Decompose(var->initializer, &stmts));
InsertBefore(stmts, s); InsertBefore(stmts, s);
return b.Decl(ctx.CloneWithoutTransform(var)); return b.Decl(ctx.CloneWithoutTransform(var));
}, },
[](Default) -> const ast::Statement* { [](Default) -> const Statement* {
// Other statement types don't have expressions // Other statement types don't have expressions
return nullptr; return nullptr;
}); });
@ -626,7 +621,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
void Run() { void Run() {
// We replace all BlockStatements as this allows us to iterate over the // We replace all BlockStatements as this allows us to iterate over the
// block statements and ctx.InsertBefore hoisted declarations on them. // block statements and ctx.InsertBefore hoisted declarations on them.
ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* { ctx.ReplaceAll([&](const BlockStatement* block) -> const Statement* {
for (auto* stmt : block->statements) { for (auto* stmt : block->statements) {
if (auto* new_stmt = DecomposeStatement(stmt)) { if (auto* new_stmt = DecomposeStatement(stmt)) {
ctx.Replace(stmt, new_stmt); ctx.Replace(stmt, new_stmt);
@ -634,7 +629,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
// Handle for loops, as they are the only other AST node that // Handle for loops, as they are the only other AST node that
// contains statements outside of BlockStatements. // contains statements outside of BlockStatements.
if (auto* fl = stmt->As<ast::ForLoopStatement>()) { if (auto* fl = stmt->As<ForLoopStatement>()) {
if (auto* new_stmt = DecomposeStatement(fl->initializer)) { if (auto* new_stmt = DecomposeStatement(fl->initializer)) {
ctx.Replace(fl->initializer, new_stmt); ctx.Replace(fl->initializer, new_stmt);
} }

View File

@ -45,7 +45,7 @@ struct RemoveContinueInSwitch::State {
bool made_changes = false; bool made_changes = false;
for (auto* node : src->ASTNodes().Objects()) { for (auto* node : src->ASTNodes().Objects()) {
auto* cont = node->As<ast::ContinueStatement>(); auto* cont = node->As<ContinueStatement>();
if (!cont) { if (!cont) {
continue; continue;
} }
@ -103,12 +103,12 @@ struct RemoveContinueInSwitch::State {
const sem::Info& sem = src->Sem(); const sem::Info& sem = src->Sem();
// Map of switch statement to 'tint_continue' variable. // Map of switch statement to 'tint_continue' variable.
std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name; std::unordered_map<const SwitchStatement*, Symbol> switch_to_cont_var_name;
// If `cont` is within a switch statement within a loop, returns a pointer to // If `cont` is within a switch statement within a loop, returns a pointer to
// that switch statement. // that switch statement.
static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem, static const SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
const ast::ContinueStatement* cont) { const ContinueStatement* cont) {
// Find whether first parent is a switch or a loop // Find whether first parent is a switch or a loop
auto* sem_stmt = sem.Get(cont); auto* sem_stmt = sem.Get(cont);
auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement, auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
@ -116,7 +116,7 @@ struct RemoveContinueInSwitch::State {
if (!sem_parent) { if (!sem_parent) {
return nullptr; return nullptr;
} }
return sem_parent->Declaration()->As<ast::SwitchStatement>(); return sem_parent->Declaration()->As<SwitchStatement>();
} }
}; };

View File

@ -53,14 +53,14 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
for (auto* node : src->ASTNodes().Objects()) { for (auto* node : src->ASTNodes().Objects()) {
Switch( Switch(
node, node,
[&](const ast::AssignmentStatement* stmt) { [&](const AssignmentStatement* stmt) {
if (stmt->lhs->Is<ast::PhonyExpression>()) { if (stmt->lhs->Is<PhonyExpression>()) {
made_changes = true; made_changes = true;
std::vector<const ast::Expression*> side_effects; std::vector<const Expression*> side_effects;
if (!ast::TraverseExpressions( if (!TraverseExpressions(
stmt->rhs, b.Diagnostics(), [&](const ast::CallExpression* expr) { stmt->rhs, b.Diagnostics(), [&](const CallExpression* expr) {
// ast::CallExpression may map to a function or builtin call // CallExpression may map to a function or builtin call
// (both may have side-effects), or a value constructor or value // (both may have side-effects), or a value constructor or value
// conversion (both do not have side effects). // conversion (both do not have side effects).
auto* call = sem.Get<sem::Call>(expr); auto* call = sem.Get<sem::Call>(expr);
@ -68,14 +68,14 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
// Semantic node must be a Materialize, in which case the // Semantic node must be a Materialize, in which case the
// expression was creation-time (compile time), so could not // expression was creation-time (compile time), so could not
// have side effects. Just skip. // have side effects. Just skip.
return ast::TraverseAction::Skip; return TraverseAction::Skip;
} }
if (call->Target()->IsAnyOf<sem::Function, sem::Builtin>() && if (call->Target()->IsAnyOf<sem::Function, sem::Builtin>() &&
call->HasSideEffects()) { call->HasSideEffects()) {
side_effects.push_back(expr); side_effects.push_back(expr);
return ast::TraverseAction::Skip; return TraverseAction::Skip;
} }
return ast::TraverseAction::Descend; return TraverseAction::Descend;
})) { })) {
return; return;
} }
@ -88,12 +88,12 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
} }
if (side_effects.size() == 1) { if (side_effects.size() == 1) {
if (auto* call_expr = side_effects[0]->As<ast::CallExpression>()) { if (auto* call_expr = side_effects[0]->As<CallExpression>()) {
// Phony assignment with single call side effect. // Phony assignment with single call side effect.
auto* call = sem.Get(call_expr)->Unwrap()->As<sem::Call>(); auto* call = sem.Get(call_expr)->Unwrap()->As<sem::Call>();
if (call->Target()->MustUse()) { if (call->Target()->MustUse()) {
// Replace phony assignment assignment to uniquely named let. // Replace phony assignment assignment to uniquely named let.
ctx.Replace<ast::Statement>(stmt, [&, call_expr] { // ctx.Replace<Statement>(stmt, [&, call_expr] { //
auto name = b.Symbols().New("tint_phony"); auto name = b.Symbols().New("tint_phony");
auto* rhs = ctx.Clone(call_expr); auto* rhs = ctx.Clone(call_expr);
return b.Decl(b.Let(name, rhs)); return b.Decl(b.Let(name, rhs));
@ -118,7 +118,7 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
} }
auto sink = sinks.GetOrCreate(sig, [&] { auto sink = sinks.GetOrCreate(sig, [&] {
auto name = b.Symbols().New("phony_sink"); auto name = b.Symbols().New("phony_sink");
utils::Vector<const ast::Parameter*, 8> params; utils::Vector<const Parameter*, 8> params;
for (auto* ty : sig) { for (auto* ty : sig) {
auto ast_ty = CreateASTTypeFor(ctx, ty); auto ast_ty = CreateASTTypeFor(ctx, ty);
params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty)); params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty));
@ -126,7 +126,7 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
b.Func(name, params, b.ty.void_(), {}); b.Func(name, params, b.ty.void_(), {});
return name; return name;
}); });
utils::Vector<const ast::Expression*, 8> args; utils::Vector<const Expression*, 8> args;
for (auto* arg : side_effects) { for (auto* arg : side_effects) {
args.Push(ctx.Clone(arg)); args.Push(ctx.Clone(arg));
} }
@ -134,7 +134,7 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
}); });
} }
}, },
[&](const ast::CallStatement* stmt) { [&](const CallStatement* stmt) {
// Remove call statements to const value-returning functions. // Remove call statements to const value-returning functions.
// TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects. // TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects.
auto* sem_expr = sem.Get(stmt->expr); auto* sem_expr = sem.Get(stmt->expr);

View File

@ -1265,10 +1265,10 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
} }
// Identifiers that need to keep their symbols preserved. // Identifiers that need to keep their symbols preserved.
utils::Hashset<const ast::Identifier*, 16> preserved_identifiers; utils::Hashset<const Identifier*, 16> preserved_identifiers;
for (auto* node : src->ASTNodes().Objects()) { for (auto* node : src->ASTNodes().Objects()) {
auto preserve_if_builtin_type = [&](const ast::Identifier* ident) { auto preserve_if_builtin_type = [&](const Identifier* ident) {
if (!global_decls.Contains(ident->symbol)) { if (!global_decls.Contains(ident->symbol)) {
preserved_identifiers.Add(ident); preserved_identifiers.Add(ident);
} }
@ -1276,7 +1276,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
Switch( Switch(
node, node,
[&](const ast::MemberAccessorExpression* accessor) { [&](const MemberAccessorExpression* accessor) {
auto* sem = src->Sem().Get(accessor)->UnwrapLoad(); auto* sem = src->Sem().Get(accessor)->UnwrapLoad();
if (sem->Is<sem::Swizzle>()) { if (sem->Is<sem::Swizzle>()) {
preserved_identifiers.Add(accessor->member); preserved_identifiers.Add(accessor->member);
@ -1288,19 +1288,19 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
} }
} }
}, },
[&](const ast::DiagnosticAttribute* diagnostic) { [&](const DiagnosticAttribute* diagnostic) {
if (auto* category = diagnostic->control.rule_name->category) { if (auto* category = diagnostic->control.rule_name->category) {
preserved_identifiers.Add(category); preserved_identifiers.Add(category);
} }
preserved_identifiers.Add(diagnostic->control.rule_name->name); preserved_identifiers.Add(diagnostic->control.rule_name->name);
}, },
[&](const ast::DiagnosticDirective* diagnostic) { [&](const DiagnosticDirective* diagnostic) {
if (auto* category = diagnostic->control.rule_name->category) { if (auto* category = diagnostic->control.rule_name->category) {
preserved_identifiers.Add(category); preserved_identifiers.Add(category);
} }
preserved_identifiers.Add(diagnostic->control.rule_name->name); preserved_identifiers.Add(diagnostic->control.rule_name->name);
}, },
[&](const ast::IdentifierExpression* expr) { [&](const IdentifierExpression* expr) {
Switch( Switch(
src->Sem().Get(expr), // src->Sem().Get(expr), //
[&](const sem::BuiltinEnumExpressionBase*) { [&](const sem::BuiltinEnumExpressionBase*) {
@ -1310,7 +1310,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
preserve_if_builtin_type(expr->identifier); preserve_if_builtin_type(expr->identifier);
}); });
}, },
[&](const ast::CallExpression* call) { [&](const CallExpression* call) {
Switch( Switch(
src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>()->Target(), src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>()->Target(),
[&](const sem::Builtin*) { [&](const sem::Builtin*) {
@ -1372,7 +1372,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
ProgramBuilder b; ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ false}; CloneContext ctx{&b, src, /* auto_clone_symbols */ false};
ctx.ReplaceAll([&](const ast::Identifier* ident) -> const ast::Identifier* { ctx.ReplaceAll([&](const Identifier* ident) -> const Identifier* {
const auto symbol = ident->symbol; const auto symbol = ident->symbol;
if (preserved_identifiers.Contains(ident) || !should_rename(symbol)) { if (preserved_identifiers.Contains(ident) || !should_rename(symbol)) {
return nullptr; // Preserve symbol return nullptr; // Preserve symbol
@ -1382,12 +1382,12 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
auto replacement = remappings.GetOrCreate(symbol, [&] { return b.Symbols().New(); }); auto replacement = remappings.GetOrCreate(symbol, [&] { return b.Symbols().New(); });
// Reconstruct the identifier // Reconstruct the identifier
if (auto* tmpl_ident = ident->As<ast::TemplatedIdentifier>()) { if (auto* tmpl_ident = ident->As<TemplatedIdentifier>()) {
auto args = ctx.Clone(tmpl_ident->arguments); auto args = ctx.Clone(tmpl_ident->arguments);
return ctx.dst->create<ast::TemplatedIdentifier>(ctx.Clone(ident->source), replacement, return ctx.dst->create<TemplatedIdentifier>(ctx.Clone(ident->source), replacement,
std::move(args), utils::Empty); std::move(args), utils::Empty);
} }
return ctx.dst->create<ast::Identifier>(ctx.Clone(ident->source), replacement); return ctx.dst->create<Identifier>(ctx.Clone(ident->source), replacement);
}); });
ctx.Clone(); ctx.Clone();

View File

@ -58,7 +58,7 @@ struct Robustness::State {
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
Switch( Switch(
node, // node, //
[&](const ast::IndexAccessorExpression* e) { [&](const IndexAccessorExpression* e) {
// obj[idx] // obj[idx]
// Array, matrix and vector indexing may require robustness transformation. // Array, matrix and vector indexing may require robustness transformation.
auto* expr = sem.Get(e)->Unwrap()->As<sem::IndexAccessorExpression>(); auto* expr = sem.Get(e)->Unwrap()->As<sem::IndexAccessorExpression>();
@ -73,7 +73,7 @@ struct Robustness::State {
break; break;
} }
}, },
[&](const ast::IdentifierExpression* e) { [&](const IdentifierExpression* e) {
// Identifiers may resolve to pointer lets, which may be predicated. // Identifiers may resolve to pointer lets, which may be predicated.
// Inspect. // Inspect.
if (auto* user = sem.Get<sem::VariableUser>(e)) { if (auto* user = sem.Get<sem::VariableUser>(e)) {
@ -86,42 +86,42 @@ struct Robustness::State {
} }
} }
}, },
[&](const ast::AccessorExpression* e) { [&](const AccessorExpression* e) {
// obj.member // obj.member
// Propagate the predication from the object to this expression. // Propagate the predication from the object to this expression.
if (auto pred = predicates.Get(e->object)) { if (auto pred = predicates.Get(e->object)) {
predicates.Add(e, *pred); predicates.Add(e, *pred);
} }
}, },
[&](const ast::UnaryOpExpression* e) { [&](const UnaryOpExpression* e) {
// Includes address-of, or indirection // Includes address-of, or indirection
// Propagate the predication from the inner expression to this expression. // Propagate the predication from the inner expression to this expression.
if (auto pred = predicates.Get(e->expr)) { if (auto pred = predicates.Get(e->expr)) {
predicates.Add(e, *pred); predicates.Add(e, *pred);
} }
}, },
[&](const ast::AssignmentStatement* s) { [&](const AssignmentStatement* s) {
if (auto pred = predicates.Get(s->lhs)) { if (auto pred = predicates.Get(s->lhs)) {
// Assignment target is predicated // Assignment target is predicated
// Replace statement with condition on the predicate // Replace statement with condition on the predicate
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s)))); ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
} }
}, },
[&](const ast::CompoundAssignmentStatement* s) { [&](const CompoundAssignmentStatement* s) {
if (auto pred = predicates.Get(s->lhs)) { if (auto pred = predicates.Get(s->lhs)) {
// Assignment expression is predicated // Assignment expression is predicated
// Replace statement with condition on the predicate // Replace statement with condition on the predicate
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s)))); ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
} }
}, },
[&](const ast::IncrementDecrementStatement* s) { [&](const IncrementDecrementStatement* s) {
if (auto pred = predicates.Get(s->lhs)) { if (auto pred = predicates.Get(s->lhs)) {
// Assignment expression is predicated // Assignment expression is predicated
// Replace statement with condition on the predicate // Replace statement with condition on the predicate
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s)))); ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
} }
}, },
[&](const ast::CallExpression* e) { [&](const CallExpression* e) {
if (auto* call = sem.Get<sem::Call>(e)) { if (auto* call = sem.Get<sem::Call>(e)) {
Switch( Switch(
call->Target(), // call->Target(), //
@ -163,7 +163,7 @@ struct Robustness::State {
// predicated_expr = expr; // predicated_expr = expr;
// } // }
// //
if (auto* expr = node->As<ast::Expression>()) { if (auto* expr = node->As<Expression>()) {
if (auto pred = predicates.Get(expr)) { if (auto pred = predicates.Get(expr)) {
// Expression is predicated // Expression is predicated
auto* sem_expr = sem.GetVal(expr); auto* sem_expr = sem.GetVal(expr);
@ -202,15 +202,15 @@ struct Robustness::State {
/// Alias to the source program's semantic info /// Alias to the source program's semantic info
const sem::Info& sem = ctx.src->Sem(); const sem::Info& sem = ctx.src->Sem();
/// Map of expression to predicate condition /// Map of expression to predicate condition
utils::Hashmap<const ast::Expression*, Symbol, 32> predicates{}; utils::Hashmap<const Expression*, Symbol, 32> predicates{};
/// @return the `u32` typed expression that represents the maximum indexable value for the index /// @return the `u32` typed expression that represents the maximum indexable value for the index
/// accessor @p expr, or nullptr if there is no robustness limit for this expression. /// accessor @p expr, or nullptr if there is no robustness limit for this expression.
const ast::Expression* DynamicLimitFor(const sem::IndexAccessorExpression* expr) { const Expression* DynamicLimitFor(const sem::IndexAccessorExpression* expr) {
auto* obj_type = expr->Object()->Type(); auto* obj_type = expr->Object()->Type();
return Switch( return Switch(
obj_type->UnwrapRef(), // obj_type->UnwrapRef(), //
[&](const type::Vector* vec) -> const ast::Expression* { [&](const type::Vector* vec) -> const Expression* {
if (expr->Index()->ConstantValue() || expr->Index()->Is<sem::Swizzle>()) { if (expr->Index()->ConstantValue() || expr->Index()->Is<sem::Swizzle>()) {
// Index and size is constant. // Index and size is constant.
// Validation will have rejected any OOB accesses. // Validation will have rejected any OOB accesses.
@ -218,7 +218,7 @@ struct Robustness::State {
} }
return b.Expr(u32(vec->Width() - 1u)); return b.Expr(u32(vec->Width() - 1u));
}, },
[&](const type::Matrix* mat) -> const ast::Expression* { [&](const type::Matrix* mat) -> const Expression* {
if (expr->Index()->ConstantValue()) { if (expr->Index()->ConstantValue()) {
// Index and size is constant. // Index and size is constant.
// Validation will have rejected any OOB accesses. // Validation will have rejected any OOB accesses.
@ -226,7 +226,7 @@ struct Robustness::State {
} }
return b.Expr(u32(mat->columns() - 1u)); return b.Expr(u32(mat->columns() - 1u));
}, },
[&](const type::Array* arr) -> const ast::Expression* { [&](const type::Array* arr) -> const Expression* {
if (arr->Count()->Is<type::RuntimeArrayCount>()) { if (arr->Count()->Is<type::RuntimeArrayCount>()) {
// Size is unknown until runtime. // Size is unknown until runtime.
// Must clamp, even if the index is constant. // Must clamp, even if the index is constant.
@ -248,7 +248,7 @@ struct Robustness::State {
type::Array::kErrExpectedConstantCount); type::Array::kErrExpectedConstantCount);
return nullptr; return nullptr;
}, },
[&](Default) -> const ast::Expression* { [&](Default) -> const Expression* {
TINT_ICE(Transform, b.Diagnostics()) TINT_ICE(Transform, b.Diagnostics())
<< "unhandled object type in robustness of array index: " << "unhandled object type in robustness of array index: "
<< obj_type->UnwrapRef()->FriendlyName(); << obj_type->UnwrapRef()->FriendlyName();
@ -350,7 +350,7 @@ struct Robustness::State {
/// Applies predication to the non-texture builtin call, if required. /// Applies predication to the non-texture builtin call, if required.
void MaybePredicateNonTextureBuiltin(const sem::Call* call, const sem::Builtin* builtin) { void MaybePredicateNonTextureBuiltin(const sem::Call* call, const sem::Builtin* builtin) {
// Gather the predications for the builtin arguments // Gather the predications for the builtin arguments
const ast::Expression* predicate = nullptr; const Expression* predicate = nullptr;
for (auto* arg : call->Declaration()->args) { for (auto* arg : call->Declaration()->args) {
if (auto pred = predicates.Get(arg)) { if (auto pred = predicates.Get(arg)) {
predicate = And(predicate, b.Expr(*pred)); predicate = And(predicate, b.Expr(*pred));
@ -393,7 +393,7 @@ struct Robustness::State {
auto* texture_arg = expr->args[static_cast<size_t>(texture_arg_idx)]; auto* texture_arg = expr->args[static_cast<size_t>(texture_arg_idx)];
// Build the builtin predicate from the arguments // Build the builtin predicate from the arguments
const ast::Expression* predicate = nullptr; const Expression* predicate = nullptr;
Symbol level_idx, num_levels; Symbol level_idx, num_levels;
if (level_arg_idx >= 0) { if (level_arg_idx >= 0) {
@ -554,7 +554,7 @@ struct Robustness::State {
} }
/// @returns a bitwise and of the two expressions, or the other expression if one is null. /// @returns a bitwise and of the two expressions, or the other expression if one is null.
const ast::Expression* And(const ast::Expression* lhs, const ast::Expression* rhs) { const Expression* And(const Expression* lhs, const Expression* rhs) {
if (lhs && rhs) { if (lhs && rhs) {
return b.And(lhs, rhs); return b.And(lhs, rhs);
} }
@ -568,11 +568,11 @@ struct Robustness::State {
/// predicate. /// predicate.
/// @param else_stmt - the statement to execute for the predication failure /// @param else_stmt - the statement to execute for the predication failure
void PredicateCall(const sem::Call* call, void PredicateCall(const sem::Call* call,
const ast::Expression* predicate, const Expression* predicate,
const ast::BlockStatement* else_stmt = nullptr) { const BlockStatement* else_stmt = nullptr) {
auto* expr = call->Declaration(); auto* expr = call->Declaration();
auto* stmt = call->Stmt(); auto* stmt = call->Stmt();
auto* call_stmt = stmt->Declaration()->As<ast::CallStatement>(); auto* call_stmt = stmt->Declaration()->As<CallStatement>();
if (call_stmt && call_stmt->expr == expr) { if (call_stmt && call_stmt->expr == expr) {
// Wrap the statement in an if-statement with the predicate condition. // Wrap the statement in an if-statement with the predicate condition.
hoist.Replace(stmt, b.If(predicate, b.Block(ctx.Clone(stmt->Declaration())), hoist.Replace(stmt, b.If(predicate, b.Block(ctx.Clone(stmt->Declaration())),
@ -646,7 +646,7 @@ struct Robustness::State {
} }
/// @returns a scalar or vector type with the element type @p scalar and width @p width /// @returns a scalar or vector type with the element type @p scalar and width @p width
ast::Type ScalarOrVecTy(ast::Type scalar, uint32_t width) const { Type ScalarOrVecTy(Type scalar, uint32_t width) const {
if (width > 1) { if (width > 1) {
return b.ty.vec(scalar, width); return b.ty.vec(scalar, width);
} }
@ -655,7 +655,7 @@ struct Robustness::State {
/// @returns a vector constructed with the scalar expression @p scalar if @p width > 1, /// @returns a vector constructed with the scalar expression @p scalar if @p width > 1,
/// otherwise returns @p scalar. /// otherwise returns @p scalar.
const ast::Expression* ScalarOrVec(const ast::Expression* scalar, uint32_t width) { const Expression* ScalarOrVec(const Expression* scalar, uint32_t width) {
if (width > 1) { if (width > 1) {
return b.Call(b.ty.vec<Infer>(width), scalar); return b.Call(b.ty.vec<Infer>(width), scalar);
} }
@ -664,13 +664,13 @@ struct Robustness::State {
/// @returns @p val cast to a `vecN<i32>`, where `N` is @p width, or cast to i32 if @p width /// @returns @p val cast to a `vecN<i32>`, where `N` is @p width, or cast to i32 if @p width
/// is 1. /// is 1.
const ast::CallExpression* CastToSigned(const ast::Expression* val, uint32_t width) { const CallExpression* CastToSigned(const Expression* val, uint32_t width) {
return b.Call(ScalarOrVecTy(b.ty.i32(), width), val); return b.Call(ScalarOrVecTy(b.ty.i32(), width), val);
} }
/// @returns @p val cast to a `vecN<u32>`, where `N` is @p width, or cast to u32 if @p width /// @returns @p val cast to a `vecN<u32>`, where `N` is @p width, or cast to u32 if @p width
/// is 1. /// is 1.
const ast::CallExpression* CastToUnsigned(const ast::Expression* val, uint32_t width) { const CallExpression* CastToUnsigned(const Expression* val, uint32_t width) {
return b.Call(ScalarOrVecTy(b.ty.u32(), width), val); return b.Call(ScalarOrVecTy(b.ty.u32(), width), val);
} }
}; };

View File

@ -41,7 +41,7 @@ struct PointerOp {
/// Zero: no pointer op on `expr` /// Zero: no pointer op on `expr`
int indirections = 0; int indirections = 0;
/// The expression being operated on /// The expression being operated on
const ast::Expression* expr = nullptr; const Expression* expr = nullptr;
}; };
} // namespace } // namespace
@ -64,29 +64,29 @@ struct SimplifyPointers::State {
/// expression. The function-like argument `cb` is called for each found. /// expression. The function-like argument `cb` is called for each found.
/// @param expr the expression to traverse /// @param expr the expression to traverse
/// @param cb a function-like object with the signature /// @param cb a function-like object with the signature
/// `void(const ast::Expression*)`, which is called for each array index /// `void(const Expression*)`, which is called for each array index
/// expression /// expression
template <typename F> template <typename F>
static void CollectSavedArrayIndices(const ast::Expression* expr, F&& cb) { static void CollectSavedArrayIndices(const Expression* expr, F&& cb) {
if (auto* a = expr->As<ast::IndexAccessorExpression>()) { if (auto* a = expr->As<IndexAccessorExpression>()) {
CollectSavedArrayIndices(a->object, cb); CollectSavedArrayIndices(a->object, cb);
if (!a->index->Is<ast::LiteralExpression>()) { if (!a->index->Is<LiteralExpression>()) {
cb(a->index); cb(a->index);
} }
return; return;
} }
if (auto* m = expr->As<ast::MemberAccessorExpression>()) { if (auto* m = expr->As<MemberAccessorExpression>()) {
CollectSavedArrayIndices(m->object, cb); CollectSavedArrayIndices(m->object, cb);
return; return;
} }
if (auto* u = expr->As<ast::UnaryOpExpression>()) { if (auto* u = expr->As<UnaryOpExpression>()) {
CollectSavedArrayIndices(u->expr, cb); CollectSavedArrayIndices(u->expr, cb);
return; return;
} }
// Note: Other ast::Expression types can be safely ignored as they cannot be // Note: Other Expression types can be safely ignored as they cannot be
// used to generate a reference or pointer. // used to generate a reference or pointer.
// See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers // See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers
} }
@ -95,16 +95,16 @@ struct SimplifyPointers::State {
/// indirection ops into a PointerOp. /// indirection ops into a PointerOp.
/// @param in the expression to walk /// @param in the expression to walk
/// @returns the reduced PointerOp /// @returns the reduced PointerOp
PointerOp Reduce(const ast::Expression* in) const { PointerOp Reduce(const Expression* in) const {
PointerOp op{0, in}; PointerOp op{0, in};
while (true) { while (true) {
if (auto* unary = op.expr->As<ast::UnaryOpExpression>()) { if (auto* unary = op.expr->As<UnaryOpExpression>()) {
switch (unary->op) { switch (unary->op) {
case ast::UnaryOp::kIndirection: case UnaryOp::kIndirection:
op.indirections++; op.indirections++;
op.expr = unary->expr; op.expr = unary->expr;
continue; continue;
case ast::UnaryOp::kAddressOf: case UnaryOp::kAddressOf:
op.indirections--; op.indirections--;
op.expr = unary->expr; op.expr = unary->expr;
continue; continue;
@ -115,7 +115,7 @@ struct SimplifyPointers::State {
if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) { if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) {
auto* var = user->Variable(); auto* var = user->Variable();
if (var->Is<sem::LocalVariable>() && // if (var->Is<sem::LocalVariable>() && //
var->Declaration()->Is<ast::Let>() && // var->Declaration()->Is<Let>() && //
var->Type()->Is<type::Pointer>()) { var->Type()->Is<type::Pointer>()) {
op.expr = var->Declaration()->initializer; op.expr = var->Declaration()->initializer;
continue; continue;
@ -129,7 +129,7 @@ struct SimplifyPointers::State {
/// @returns the new program or SkipTransform if the transform is not required /// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() { ApplyResult Run() {
// A map of saved expressions to their saved variable name // A map of saved expressions to their saved variable name
utils::Hashmap<const ast::Expression*, Symbol, 8> saved_vars; utils::Hashmap<const Expression*, Symbol, 8> saved_vars;
bool needs_transform = false; bool needs_transform = false;
for (auto* ty : ctx.src->Types()) { for (auto* ty : ctx.src->Types()) {
@ -146,8 +146,8 @@ struct SimplifyPointers::State {
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
Switch( Switch(
node, // node, //
[&](const ast::VariableDeclStatement* let) { [&](const VariableDeclStatement* let) {
if (!let->variable->Is<ast::Let>()) { if (!let->variable->Is<Let>()) {
return; // Not a `let` declaration. Ignore. return; // Not a `let` declaration. Ignore.
} }
@ -160,9 +160,9 @@ struct SimplifyPointers::State {
// Scan the initializer expression for array index expressions that need // Scan the initializer expression for array index expressions that need
// to be hoist to temporary "saved" variables. // to be hoist to temporary "saved" variables.
utils::Vector<const ast::VariableDeclStatement*, 8> saved; utils::Vector<const VariableDeclStatement*, 8> saved;
CollectSavedArrayIndices( CollectSavedArrayIndices(
var->Declaration()->initializer, [&](const ast::Expression* idx_expr) { var->Declaration()->initializer, [&](const Expression* idx_expr) {
// We have a sub-expression that needs to be saved. // We have a sub-expression that needs to be saved.
// Create a new variable // Create a new variable
auto saved_name = ctx.dst->Symbols().New( auto saved_name = ctx.dst->Symbols().New(
@ -205,8 +205,8 @@ struct SimplifyPointers::State {
// need for the original declaration to exist. Remove it. // need for the original declaration to exist. Remove it.
RemoveStatement(ctx, let); RemoveStatement(ctx, let);
}, },
[&](const ast::UnaryOpExpression* op) { [&](const UnaryOpExpression* op) {
if (op->op == ast::UnaryOp::kAddressOf) { if (op->op == UnaryOp::kAddressOf) {
// Transform can be skipped if no address-of operator is used, as there // Transform can be skipped if no address-of operator is used, as there
// will be no pointers that can be inlined. // will be no pointers that can be inlined.
needs_transform = true; needs_transform = true;
@ -218,7 +218,7 @@ struct SimplifyPointers::State {
return SkipTransform; return SkipTransform;
} }
// Register the ast::Expression transform handler. // Register the Expression transform handler.
// This performs two different transformations: // This performs two different transformations:
// * Identifiers that resolve to the pointer-typed `let` declarations are // * Identifiers that resolve to the pointer-typed `let` declarations are
// replaced with the recursively inlined initializer expression for the // replaced with the recursively inlined initializer expression for the
@ -226,7 +226,7 @@ struct SimplifyPointers::State {
// * Sub-expressions inside the pointer-typed `let` initializer expression // * Sub-expressions inside the pointer-typed `let` initializer expression
// that have been hoisted to a saved variable are replaced with the saved // that have been hoisted to a saved variable are replaced with the saved
// variable identifier. // variable identifier.
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* { ctx.ReplaceAll([&](const Expression* expr) -> const Expression* {
// Look to see if we need to swap this Expression with a saved variable. // Look to see if we need to swap this Expression with a saved variable.
if (auto saved_var = saved_vars.Find(expr)) { if (auto saved_var = saved_vars.Find(expr)) {
return ctx.dst->Expr(*saved_var); return ctx.dst->Expr(*saved_var);

View File

@ -45,7 +45,7 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
} }
// Find the target entry point. // Find the target entry point.
const ast::Function* entry_point = nullptr; const Function* entry_point = nullptr;
for (auto* f : src->AST().Functions()) { for (auto* f : src->AST().Functions()) {
if (!f->IsEntryPoint()) { if (!f->IsEntryPoint()) {
continue; continue;
@ -69,7 +69,7 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
for (auto* decl : src->AST().GlobalDeclarations()) { for (auto* decl : src->AST().GlobalDeclarations()) {
Switch( Switch(
decl, // decl, //
[&](const ast::TypeDecl* ty) { [&](const TypeDecl* ty) {
// Strip aliases that reference unused override declarations. // Strip aliases that reference unused override declarations.
if (auto* arr = sem.Get(ty)->As<type::Array>()) { if (auto* arr = sem.Get(ty)->As<type::Array>()) {
auto* refs = sem.TransitivelyReferencedOverrides(arr); auto* refs = sem.TransitivelyReferencedOverrides(arr);
@ -85,9 +85,9 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
// TODO(jrprice): Strip other unused types. // TODO(jrprice): Strip other unused types.
b.AST().AddTypeDecl(ctx.Clone(ty)); b.AST().AddTypeDecl(ctx.Clone(ty));
}, },
[&](const ast::Override* override) { [&](const Override* override) {
if (referenced_vars.Contains(sem.Get(override))) { if (referenced_vars.Contains(sem.Get(override))) {
if (!ast::HasAttribute<ast::IdAttribute>(override->attributes)) { if (!HasAttribute<IdAttribute>(override->attributes)) {
// If the override doesn't already have an @id() attribute, add one // If the override doesn't already have an @id() attribute, add one
// so that its allocated ID so that it won't be affected by other // so that its allocated ID so that it won't be affected by other
// stripped away overrides // stripped away overrides
@ -98,26 +98,24 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
b.AST().AddGlobalVariable(ctx.Clone(override)); b.AST().AddGlobalVariable(ctx.Clone(override));
} }
}, },
[&](const ast::Var* var) { [&](const Var* var) {
if (referenced_vars.Contains(sem.Get<sem::GlobalVariable>(var))) { if (referenced_vars.Contains(sem.Get<sem::GlobalVariable>(var))) {
b.AST().AddGlobalVariable(ctx.Clone(var)); b.AST().AddGlobalVariable(ctx.Clone(var));
} }
}, },
[&](const ast::Const* c) { [&](const Const* c) {
// Always keep 'const' declarations, as these can be used by attributes and array // Always keep 'const' declarations, as these can be used by attributes and array
// sizes, which are not tracked as transitively used by functions. They also don't // sizes, which are not tracked as transitively used by functions. They also don't
// typically get emitted by the backend unless they're actually used. // typically get emitted by the backend unless they're actually used.
b.AST().AddGlobalVariable(ctx.Clone(c)); b.AST().AddGlobalVariable(ctx.Clone(c));
}, },
[&](const ast::Function* func) { [&](const Function* func) {
if (sem.Get(func)->HasAncestorEntryPoint(entry_point->name->symbol)) { if (sem.Get(func)->HasAncestorEntryPoint(entry_point->name->symbol)) {
b.AST().AddFunction(ctx.Clone(func)); b.AST().AddFunction(ctx.Clone(func));
} }
}, },
[&](const ast::Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); }, [&](const Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); },
[&](const ast::DiagnosticDirective* d) { [&](const DiagnosticDirective* d) { b.AST().AddDiagnosticDirective(ctx.Clone(d)); },
b.AST().AddDiagnosticDirective(ctx.Clone(d));
},
[&](Default) { [&](Default) {
TINT_UNREACHABLE(Transform, b.Diagnostics()) TINT_UNREACHABLE(Transform, b.Diagnostics())
<< "unhandled global declaration: " << decl->TypeInfo().name; << "unhandled global declaration: " << decl->TypeInfo().name;

View File

@ -70,7 +70,7 @@ struct SpirvAtomic::State {
// Look for stub functions generated by the SPIR-V reader, which are used as placeholders // Look for stub functions generated by the SPIR-V reader, which are used as placeholders
// for atomic builtin calls. // for atomic builtin calls.
for (auto* fn : ctx.src->AST().Functions()) { for (auto* fn : ctx.src->AST().Functions()) {
if (auto* stub = ast::GetAttribute<Stub>(fn->attributes)) { if (auto* stub = GetAttribute<Stub>(fn->attributes)) {
auto* sem = ctx.src->Sem().Get(fn); auto* sem = ctx.src->Sem().Get(fn);
for (auto* call : sem->CallSites()) { for (auto* call : sem->CallSites()) {
@ -121,14 +121,14 @@ struct SpirvAtomic::State {
// If we need to change structure members, then fork them. // If we need to change structure members, then fork them.
if (!forked_structs.empty()) { if (!forked_structs.empty()) {
ctx.ReplaceAll([&](const ast::Struct* str) { ctx.ReplaceAll([&](const Struct* str) {
// Is `str` a structure we need to fork? // Is `str` a structure we need to fork?
auto* str_ty = ctx.src->Sem().Get(str); auto* str_ty = ctx.src->Sem().Get(str);
if (auto it = forked_structs.find(str_ty); it != forked_structs.end()) { if (auto it = forked_structs.find(str_ty); it != forked_structs.end()) {
const auto& forked = it->second; const auto& forked = it->second;
// Re-create the structure swapping in the atomic-flavoured members // Re-create the structure swapping in the atomic-flavoured members
utils::Vector<const ast::StructMember*, 8> members; utils::Vector<const StructMember*, 8> members;
members.Reserve(str->members.Length()); members.Reserve(str->members.Length());
for (size_t i = 0; i < str->members.Length(); i++) { for (size_t i = 0; i < str->members.Length(); i++) {
auto* member = str->members[i]; auto* member = str->members[i];
@ -187,14 +187,14 @@ struct SpirvAtomic::State {
atomic_expressions.Add(index->Object()); atomic_expressions.Add(index->Object());
}, },
[&](const sem::ValueExpression* e) { [&](const sem::ValueExpression* e) {
if (auto* unary = e->Declaration()->As<ast::UnaryOpExpression>()) { if (auto* unary = e->Declaration()->As<UnaryOpExpression>()) {
atomic_expressions.Add(ctx.src->Sem().GetVal(unary->expr)); atomic_expressions.Add(ctx.src->Sem().GetVal(unary->expr));
} }
}); });
} }
} }
ast::Type AtomicTypeFor(const type::Type* ty) { Type AtomicTypeFor(const type::Type* ty) {
return Switch( return Switch(
ty, // ty, //
[&](const type::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); }, [&](const type::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
@ -221,7 +221,7 @@ struct SpirvAtomic::State {
[&](const type::Reference* ref) { return AtomicTypeFor(ref->StoreType()); }, [&](const type::Reference* ref) { return AtomicTypeFor(ref->StoreType()); },
[&](Default) { [&](Default) {
TINT_ICE(Transform, b.Diagnostics()) << "unhandled type: " << ty->FriendlyName(); TINT_ICE(Transform, b.Diagnostics()) << "unhandled type: " << ty->FriendlyName();
return ast::Type{}; return Type{};
}); });
} }
@ -249,7 +249,7 @@ struct SpirvAtomic::State {
for (auto* vu : atomic_var->Users()) { for (auto* vu : atomic_var->Users()) {
Switch( Switch(
vu->Stmt()->Declaration(), vu->Stmt()->Declaration(),
[&](const ast::AssignmentStatement* assign) { [&](const AssignmentStatement* assign) {
auto* sem_lhs = ctx.src->Sem().GetVal(assign->lhs); auto* sem_lhs = ctx.src->Sem().GetVal(assign->lhs);
if (is_ref_to_atomic_var(sem_lhs)) { if (is_ref_to_atomic_var(sem_lhs)) {
ctx.Replace(assign, [=] { ctx.Replace(assign, [=] {
@ -272,7 +272,7 @@ struct SpirvAtomic::State {
return; return;
} }
}, },
[&](const ast::VariableDeclStatement* decl) { [&](const VariableDeclStatement* decl) {
auto* var = decl->variable; auto* var = decl->variable;
if (auto* sem_init = ctx.src->Sem().GetVal(var->initializer)) { if (auto* sem_init = ctx.src->Sem().GetVal(var->initializer)) {
if (is_ref_to_atomic_var(sem_init->UnwrapLoad())) { if (is_ref_to_atomic_var(sem_init->UnwrapLoad())) {
@ -293,7 +293,7 @@ struct SpirvAtomic::State {
SpirvAtomic::SpirvAtomic() = default; SpirvAtomic::SpirvAtomic() = default;
SpirvAtomic::~SpirvAtomic() = default; SpirvAtomic::~SpirvAtomic() = default;
SpirvAtomic::Stub::Stub(ProgramID pid, ast::NodeID nid, builtin::Function b) SpirvAtomic::Stub::Stub(ProgramID pid, NodeID nid, builtin::Function b)
: Base(pid, nid, utils::Empty), builtin(b) {} : Base(pid, nid, utils::Empty), builtin(b) {}
SpirvAtomic::Stub::~Stub() = default; SpirvAtomic::Stub::~Stub() = default;
std::string SpirvAtomic::Stub::InternalName() const { std::string SpirvAtomic::Stub::InternalName() const {

View File

@ -41,12 +41,12 @@ class SpirvAtomic final : public utils::Castable<SpirvAtomic, Transform> {
/// Stub is an attribute applied to stub SPIR-V reader generated functions that need to be /// Stub is an attribute applied to stub SPIR-V reader generated functions that need to be
/// translated to an atomic builtin. /// translated to an atomic builtin.
class Stub final : public utils::Castable<Stub, ast::InternalAttribute> { class Stub final : public utils::Castable<Stub, InternalAttribute> {
public: public:
/// @param pid the identifier of the program that owns this node /// @param pid the identifier of the program that owns this node
/// @param nid the unique node identifier /// @param nid the unique node identifier
/// @param builtin the atomic builtin this stub represents /// @param builtin the atomic builtin this stub represents
Stub(ProgramID pid, ast::NodeID nid, builtin::Function builtin); Stub(ProgramID pid, NodeID nid, builtin::Function builtin);
/// Destructor /// Destructor
~Stub() override; ~Stub() override;

View File

@ -102,7 +102,7 @@ struct Std140::State {
// Finally, replace all expression chains that used the authored types with those that // Finally, replace all expression chains that used the authored types with those that
// correctly use the forked types. // correctly use the forked types.
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* { ctx.ReplaceAll([&](const Expression* expr) -> const Expression* {
if (auto access = AccessChainFor(expr)) { if (auto access = AccessChainFor(expr)) {
if (!access->std140_mat_idx.has_value()) { if (!access->std140_mat_idx.has_value()) {
// loading a std140 type, which is not a whole or partial decomposed matrix // loading a std140 type, which is not a whole or partial decomposed matrix
@ -230,7 +230,7 @@ struct Std140::State {
// Map of structure member in src of a matrix type, to list of decomposed column // Map of structure member in src of a matrix type, to list of decomposed column
// members in ctx.dst. // members in ctx.dst.
utils::Hashmap<const type::StructMember*, utils::Vector<const ast::StructMember*, 4>, 8> utils::Hashmap<const type::StructMember*, utils::Vector<const StructMember*, 4>, 8>
std140_mat_members; std140_mat_members;
/// Describes a matrix that has been forked to a std140-structure holding the decomposed column /// Describes a matrix that has been forked to a std140-structure holding the decomposed column
@ -284,7 +284,7 @@ struct Std140::State {
if (str && str->UsedAs(builtin::AddressSpace::kUniform)) { if (str && str->UsedAs(builtin::AddressSpace::kUniform)) {
// Should this uniform buffer be forked for std140 usage? // Should this uniform buffer be forked for std140 usage?
bool fork_std140 = false; bool fork_std140 = false;
utils::Vector<const ast::StructMember*, 8> members; utils::Vector<const StructMember*, 8> members;
for (auto* member : str->Members()) { for (auto* member : str->Members()) {
if (auto* mat = member->Type()->As<type::Matrix>()) { if (auto* mat = member->Type()->As<type::Matrix>()) {
// Is this member a matrix that needs decomposition for std140-layout? // Is this member a matrix that needs decomposition for std140-layout?
@ -335,7 +335,7 @@ struct Std140::State {
// Create a new forked structure, and insert it just under the original // Create a new forked structure, and insert it just under the original
// structure. // structure.
auto name = b.Symbols().New(str->Name().Name() + "_std140"); auto name = b.Symbols().New(str->Name().Name() + "_std140");
auto* std140 = b.create<ast::Struct>(b.Ident(name), std::move(members), auto* std140 = b.create<Struct>(b.Ident(name), std::move(members),
ctx.Clone(str->Declaration()->attributes)); ctx.Clone(str->Declaration()->attributes));
ctx.InsertAfter(src->AST().GlobalDeclarations(), global, std140); ctx.InsertAfter(src->AST().GlobalDeclarations(), global, std140);
std140_structs.Add(str, name); std140_structs.Add(str, name);
@ -349,7 +349,7 @@ struct Std140::State {
/// Populates the #std140_uniforms set. /// Populates the #std140_uniforms set.
void ReplaceUniformVarTypes() { void ReplaceUniformVarTypes() {
for (auto* global : src->AST().GlobalVariables()) { for (auto* global : src->AST().GlobalVariables()) {
if (auto* var = global->As<ast::Var>()) { if (auto* var = global->As<Var>()) {
auto* v = sem.Get(var); auto* v = sem.Get(var);
if (v->AddressSpace() == builtin::AddressSpace::kUniform) { if (v->AddressSpace() == builtin::AddressSpace::kUniform) {
if (auto std140_ty = Std140Type(v->Type()->UnwrapRef())) { if (auto std140_ty = Std140Type(v->Type()->UnwrapRef())) {
@ -367,9 +367,7 @@ struct Std140::State {
/// @param str the structure that will hold the uniquely named member. /// @param str the structure that will hold the uniquely named member.
/// @param unsuffixed the common name prefix to use for the new members. /// @param unsuffixed the common name prefix to use for the new members.
/// @param count the number of members that need to be created. /// @param count the number of members that need to be created.
std::string PrefixForUniqueNames(const ast::Struct* str, std::string PrefixForUniqueNames(const Struct* str, Symbol unsuffixed, uint32_t count) const {
Symbol unsuffixed,
uint32_t count) const {
auto prefix = unsuffixed.Name(); auto prefix = unsuffixed.Name();
// Keep on inserting '_' between the unsuffixed name and the suffix numbers until the name // Keep on inserting '_' between the unsuffixed name and the suffix numbers until the name
// is unique. // is unique.
@ -400,14 +398,14 @@ struct Std140::State {
/// If the semantic type is not split for std140-layout, then nullptr is returned. /// If the semantic type is not split for std140-layout, then nullptr is returned.
/// @note will construct new std140 structures to hold decomposed matrices, populating /// @note will construct new std140 structures to hold decomposed matrices, populating
/// #std140_mats. /// #std140_mats.
ast::Type Std140Type(const type::Type* ty) { Type Std140Type(const type::Type* ty) {
return Switch( return Switch(
ty, // ty, //
[&](const type::Struct* str) { [&](const type::Struct* str) {
if (auto std140 = std140_structs.Find(str)) { if (auto std140 = std140_structs.Find(str)) {
return b.ty(*std140); return b.ty(*std140);
} }
return ast::Type{}; return Type{};
}, },
[&](const type::Matrix* mat) { [&](const type::Matrix* mat) {
if (MatrixNeedsDecomposing(mat)) { if (MatrixNeedsDecomposing(mat)) {
@ -426,13 +424,13 @@ struct Std140::State {
}); });
return b.ty(std140_mat.name); return b.ty(std140_mat.name);
} }
return ast::Type{}; return Type{};
}, },
[&](const type::Array* arr) { [&](const type::Array* arr) {
if (auto std140 = Std140Type(arr->ElemType())) { if (auto std140 = Std140Type(arr->ElemType())) {
utils::Vector<const ast::Attribute*, 1> attrs; utils::Vector<const Attribute*, 1> attrs;
if (!arr->IsStrideImplicit()) { if (!arr->IsStrideImplicit()) {
attrs.Push(b.create<ast::StrideAttribute>(arr->Stride())); attrs.Push(b.create<StrideAttribute>(arr->Stride()));
} }
auto count = arr->ConstantCount(); auto count = arr->ConstantCount();
if (TINT_UNLIKELY(!count)) { if (TINT_UNLIKELY(!count)) {
@ -446,7 +444,7 @@ struct Std140::State {
} }
return b.ty.array(std140, b.Expr(u32(count.value())), std::move(attrs)); return b.ty.array(std140, b.Expr(u32(count.value())), std::move(attrs));
} }
return ast::Type{}; return Type{};
}); });
} }
@ -455,7 +453,7 @@ struct Std140::State {
/// @param align the alignment in bytes of the matrix. /// @param align the alignment in bytes of the matrix.
/// @param size the size in bytes of the matrix. /// @param size the size in bytes of the matrix.
/// @returns a vector of decomposed matrix column vectors as structure members (in ctx.dst). /// @returns a vector of decomposed matrix column vectors as structure members (in ctx.dst).
utils::Vector<const ast::StructMember*, 4> DecomposedMatrixStructMembers( utils::Vector<const StructMember*, 4> DecomposedMatrixStructMembers(
const type::Matrix* mat, const type::Matrix* mat,
const std::string& name_prefix, const std::string& name_prefix,
uint32_t align, uint32_t align,
@ -463,9 +461,9 @@ struct Std140::State {
// Replace the member with column vectors. // Replace the member with column vectors.
const auto num_columns = mat->columns(); const auto num_columns = mat->columns();
// Build a struct member for each column of the matrix // Build a struct member for each column of the matrix
utils::Vector<const ast::StructMember*, 4> out; utils::Vector<const StructMember*, 4> out;
for (uint32_t i = 0; i < num_columns; i++) { for (uint32_t i = 0; i < num_columns; i++) {
utils::Vector<const ast::Attribute*, 1> attributes; utils::Vector<const Attribute*, 1> attributes;
if ((i == 0) && mat->Align() != align) { if ((i == 0) && mat->Align() != align) {
// The matrix was @align() annotated with a larger alignment // The matrix was @align() annotated with a larger alignment
// than the natural alignment for the matrix. This extra padding // than the natural alignment for the matrix. This extra padding
@ -493,7 +491,7 @@ struct Std140::State {
/// Walks the @p ast_expr, constructing and returning an AccessChain. /// Walks the @p ast_expr, constructing and returning an AccessChain.
/// @returns an AccessChain if the expression is an access to a std140-forked uniform buffer, /// @returns an AccessChain if the expression is an access to a std140-forked uniform buffer,
/// otherwise returns a std::nullopt. /// otherwise returns a std::nullopt.
std::optional<AccessChain> AccessChainFor(const ast::Expression* ast_expr) { std::optional<AccessChain> AccessChainFor(const Expression* ast_expr) {
auto* expr = sem.GetVal(ast_expr); auto* expr = sem.GetVal(ast_expr);
if (!expr) { if (!expr) {
return std::nullopt; return std::nullopt;
@ -576,10 +574,10 @@ struct Std140::State {
[&](const sem::ValueExpression* e) { [&](const sem::ValueExpression* e) {
// Walk past indirection and address-of unary ops. // Walk past indirection and address-of unary ops.
return Switch(e->Declaration(), // return Switch(e->Declaration(), //
[&](const ast::UnaryOpExpression* u) { [&](const UnaryOpExpression* u) {
switch (u->op) { switch (u->op) {
case ast::UnaryOp::kAddressOf: case UnaryOp::kAddressOf:
case ast::UnaryOp::kIndirection: case UnaryOp::kIndirection:
expr = sem.GetVal(u->expr); expr = sem.GetVal(u->expr);
return Action::kContinue; return Action::kContinue;
default: default:
@ -660,8 +658,8 @@ struct Std140::State {
/// Generates and returns an expression that loads the value from a std140 uniform buffer, /// Generates and returns an expression that loads the value from a std140 uniform buffer,
/// converting the final result to a non-std140 type. /// converting the final result to a non-std140 type.
/// @param chain the access chain from a uniform buffer to the value to load. /// @param chain the access chain from a uniform buffer to the value to load.
const ast::Expression* LoadWithConvert(const AccessChain& chain) { const Expression* LoadWithConvert(const AccessChain& chain) {
const ast::Expression* expr = nullptr; const Expression* expr = nullptr;
const type::Type* ty = nullptr; const type::Type* ty = nullptr;
auto dynamic_index = [&](size_t idx) { auto dynamic_index = [&](size_t idx) {
return ctx.Clone(chain.dynamic_indices[idx]->Declaration()); return ctx.Clone(chain.dynamic_indices[idx]->Declaration());
@ -678,7 +676,7 @@ struct Std140::State {
/// std140-forked type to the type @p ty. If @p expr is not a std140-forked type, then Convert() /// std140-forked type to the type @p ty. If @p expr is not a std140-forked type, then Convert()
/// will simply return @p expr. /// will simply return @p expr.
/// @returns the converted value expression. /// @returns the converted value expression.
const ast::Expression* Convert(const type::Type* ty, const ast::Expression* expr) { const Expression* Convert(const type::Type* ty, const Expression* expr) {
// Get an existing, or create a new function for converting the std140 type to ty. // Get an existing, or create a new function for converting the std140 type to ty.
auto fn = conv_fns.GetOrCreate(ty, [&] { auto fn = conv_fns.GetOrCreate(ty, [&] {
auto std140_ty = Std140Type(ty); auto std140_ty = Std140Type(ty);
@ -690,20 +688,20 @@ struct Std140::State {
// The converter function takes a single argument of the std140 type. // The converter function takes a single argument of the std140 type.
auto* param = b.Param("val", std140_ty); auto* param = b.Param("val", std140_ty);
utils::Vector<const ast::Statement*, 3> stmts; utils::Vector<const Statement*, 3> stmts;
Switch( Switch(
ty, // ty, //
[&](const type::Struct* str) { [&](const type::Struct* str) {
// Convert each of the structure members using either a converter function // Convert each of the structure members using either a converter function
// call, or by reassembling a std140 matrix from column vector members. // call, or by reassembling a std140 matrix from column vector members.
utils::Vector<const ast::Expression*, 8> args; utils::Vector<const Expression*, 8> args;
for (auto* member : str->Members()) { for (auto* member : str->Members()) {
if (auto col_members = std140_mat_members.Find(member)) { if (auto col_members = std140_mat_members.Find(member)) {
// std140 decomposed matrix. Reassemble. // std140 decomposed matrix. Reassemble.
auto mat_ty = CreateASTTypeFor(ctx, member->Type()); auto mat_ty = CreateASTTypeFor(ctx, member->Type());
auto mat_args = auto mat_args =
utils::Transform(*col_members, [&](const ast::StructMember* m) { utils::Transform(*col_members, [&](const StructMember* m) {
return b.MemberAccessor(param, m->name->symbol); return b.MemberAccessor(param, m->name->symbol);
}); });
args.Push(b.Call(mat_ty, std::move(mat_args))); args.Push(b.Call(mat_ty, std::move(mat_args)));
@ -719,7 +717,7 @@ struct Std140::State {
// Reassemble a std140 matrix from the structure of column vector members. // Reassemble a std140 matrix from the structure of column vector members.
auto std140_mat = std140_mats.Get(mat); auto std140_mat = std140_mats.Get(mat);
if (TINT_LIKELY(std140_mat)) { if (TINT_LIKELY(std140_mat)) {
utils::Vector<const ast::Expression*, 8> args; utils::Vector<const Expression*, 8> args;
// std140 decomposed matrix. Reassemble. // std140 decomposed matrix. Reassemble.
auto mat_ty = CreateASTTypeFor(ctx, mat); auto mat_ty = CreateASTTypeFor(ctx, mat);
auto mat_args = utils::Transform(std140_mat->columns, [&](Symbol name) { auto mat_args = utils::Transform(std140_mat->columns, [&](Symbol name) {
@ -782,7 +780,7 @@ struct Std140::State {
/// @param access the access chain from the uniform buffer to either the whole matrix or part of /// @param access the access chain from the uniform buffer to either the whole matrix or part of
/// the matrix (column, column-swizzle, or element). /// the matrix (column, column-swizzle, or element).
/// @returns the loaded value expression. /// @returns the loaded value expression.
const ast::Expression* LoadMatrixWithFn(const AccessChain& access) { const Expression* LoadMatrixWithFn(const AccessChain& access) {
// Get an existing, or create a new function for loading the uniform buffer value. // Get an existing, or create a new function for loading the uniform buffer value.
// This function is keyed off the uniform buffer variable and the access chain. // This function is keyed off the uniform buffer variable and the access chain.
auto fn = load_fns.GetOrCreate(LoadFnKey{access.var, access.indices}, [&] { auto fn = load_fns.GetOrCreate(LoadFnKey{access.var, access.indices}, [&] {
@ -810,14 +808,14 @@ struct Std140::State {
/// column-swizzle, or element). /// column-swizzle, or element).
/// @note The matrix column must be statically indexed to use this method. /// @note The matrix column must be statically indexed to use this method.
/// @returns the loaded value expression. /// @returns the loaded value expression.
const ast::Expression* LoadSubMatrixInline(const AccessChain& chain) { const Expression* LoadSubMatrixInline(const AccessChain& chain) {
// Method for generating dynamic index expressions. // Method for generating dynamic index expressions.
// As this is inline, we can just clone the expression. // As this is inline, we can just clone the expression.
auto dynamic_index = [&](size_t idx) { auto dynamic_index = [&](size_t idx) {
return ctx.Clone(chain.dynamic_indices[idx]->Declaration()); return ctx.Clone(chain.dynamic_indices[idx]->Declaration());
}; };
const ast::Expression* expr = nullptr; const Expression* expr = nullptr;
const type::Type* ty = nullptr; const type::Type* ty = nullptr;
// Build the expression up to, but not including the matrix member // Build the expression up to, but not including the matrix member
@ -891,7 +889,7 @@ struct Std140::State {
std::string name = "load"; std::string name = "load";
// The switch cases // The switch cases
utils::Vector<const ast::CaseStatement*, 4> cases; utils::Vector<const CaseStatement*, 4> cases;
// The function return type. // The function return type.
const type::Type* ret_ty = nullptr; const type::Type* ret_ty = nullptr;
@ -899,7 +897,7 @@ struct Std140::State {
// Build switch() cases for each column of the matrix // Build switch() cases for each column of the matrix
auto num_columns = chain.std140_mat_ty->columns(); auto num_columns = chain.std140_mat_ty->columns();
for (uint32_t column_idx = 0; column_idx < num_columns; column_idx++) { for (uint32_t column_idx = 0; column_idx < num_columns; column_idx++) {
const ast::Expression* expr = nullptr; const Expression* expr = nullptr;
const type::Type* ty = nullptr; const type::Type* ty = nullptr;
// Build the expression up to, but not including the matrix // Build the expression up to, but not including the matrix
@ -991,7 +989,7 @@ struct Std140::State {
return b.Expr(dynamic_index_params[idx]->name->symbol); return b.Expr(dynamic_index_params[idx]->name->symbol);
}; };
const ast::Expression* expr = nullptr; const Expression* expr = nullptr;
const type::Type* ty = nullptr; const type::Type* ty = nullptr;
std::string name = "load"; std::string name = "load";
@ -1005,13 +1003,13 @@ struct Std140::State {
name += "_" + access_name; name += "_" + access_name;
} }
utils::Vector<const ast::Statement*, 2> stmts; utils::Vector<const Statement*, 2> stmts;
// Create a temporary pointer to the structure that holds the matrix columns // Create a temporary pointer to the structure that holds the matrix columns
auto* let = b.Let("s", b.AddressOf(expr)); auto* let = b.Let("s", b.AddressOf(expr));
stmts.Push(b.Decl(let)); stmts.Push(b.Decl(let));
utils::Vector<const ast::MemberAccessorExpression*, 4> columns; utils::Vector<const MemberAccessorExpression*, 4> columns;
if (auto* str = tint::As<type::Struct>(ty)) { if (auto* str = tint::As<type::Struct>(ty)) {
// Structure member matrix. The columns are decomposed into the structure. // Structure member matrix. The columns are decomposed into the structure.
auto mat_member_idx = std::get<u32>(chain.indices[std140_mat_idx]); auto mat_member_idx = std::get<u32>(chain.indices[std140_mat_idx]);
@ -1053,7 +1051,7 @@ struct Std140::State {
/// Return type of BuildAccessExpr() /// Return type of BuildAccessExpr()
struct ExprTypeName { struct ExprTypeName {
/// The new, post-access expression /// The new, post-access expression
const ast::Expression* expr; const Expression* expr;
/// The type of #expr /// The type of #expr
const type::Type* type; const type::Type* type;
/// A name segment which can be used to build sensible names for helper functions /// A name segment which can be used to build sensible names for helper functions
@ -1067,11 +1065,11 @@ struct Std140::State {
/// @param dynamic_index a function that obtains the i'th dynamic index /// @param dynamic_index a function that obtains the i'th dynamic index
/// @returns a ExprTypeName which holds the new expression, new type and a name segment which /// @returns a ExprTypeName which holds the new expression, new type and a name segment which
/// can be used for creating helper function names. /// can be used for creating helper function names.
ExprTypeName BuildAccessExpr(const ast::Expression* lhs, ExprTypeName BuildAccessExpr(const Expression* lhs,
const type::Type* ty, const type::Type* ty,
const AccessChain& chain, const AccessChain& chain,
size_t index, size_t index,
std::function<const ast::Expression*(size_t)> dynamic_index) { std::function<const Expression*(size_t)> dynamic_index) {
auto& access = chain.indices[index]; auto& access = chain.indices[index];
if (std::get_if<UniformVariable>(&access)) { if (std::get_if<UniformVariable>(&access)) {

View File

@ -32,7 +32,7 @@ namespace {
bool ShouldRun(const Program* program) { bool ShouldRun(const Program* program) {
for (auto* node : program->AST().GlobalVariables()) { for (auto* node : program->AST().GlobalVariables()) {
if (node->Is<ast::Override>()) { if (node->Is<Override>()) {
return true; return true;
} }
} }
@ -61,12 +61,12 @@ Transform::ApplyResult SubstituteOverride::Apply(const Program* src,
return SkipTransform; return SkipTransform;
} }
ctx.ReplaceAll([&](const ast::Override* w) -> const ast::Const* { ctx.ReplaceAll([&](const Override* w) -> const Const* {
auto* sem = ctx.src->Sem().Get(w); auto* sem = ctx.src->Sem().Get(w);
auto source = ctx.Clone(w->source); auto source = ctx.Clone(w->source);
auto sym = ctx.Clone(w->name->symbol); auto sym = ctx.Clone(w->name->symbol);
ast::Type ty = w->type ? ctx.Clone(w->type) : ast::Type{}; Type ty = w->type ? ctx.Clone(w->type) : Type{};
// No replacement provided, just clone the override node as a const. // No replacement provided, just clone the override node as a const.
auto iter = data->map.find(sem->OverrideId()); auto iter = data->map.find(sem->OverrideId());
@ -102,7 +102,7 @@ Transform::ApplyResult SubstituteOverride::Apply(const Program* src,
// If the object is not materialized, and the 'override' variable is turned to a 'const', the // If the object is not materialized, and the 'override' variable is turned to a 'const', the
// resulting type of the index may change. See: crbug.com/tint/1697. // resulting type of the index may change. See: crbug.com/tint/1697.
ctx.ReplaceAll( ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* { [&](const IndexAccessorExpression* expr) -> const IndexAccessorExpression* {
if (auto* sem = src->Sem().Get(expr)) { if (auto* sem = src->Sem().Get(expr)) {
if (auto* access = sem->UnwrapMaterialize()->As<sem::IndexAccessorExpression>()) { if (auto* access = sem->UnwrapMaterialize()->As<sem::IndexAccessorExpression>()) {
if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() && if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() &&

View File

@ -85,18 +85,18 @@ struct Texture1DTo2D::State {
return SkipTransform; return SkipTransform;
} }
auto create_var = [&](const ast::Variable* v, ast::Type type) -> const ast::Variable* { auto create_var = [&](const Variable* v, Type type) -> const Variable* {
if (v->As<ast::Parameter>()) { if (v->As<Parameter>()) {
return ctx.dst->Param(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes)); return ctx.dst->Param(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes));
} else { } else {
return ctx.dst->Var(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes)); return ctx.dst->Var(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes));
} }
}; };
ctx.ReplaceAll([&](const ast::Variable* v) -> const ast::Variable* { ctx.ReplaceAll([&](const Variable* v) -> const Variable* {
const ast::Variable* r = Switch( const Variable* r = Switch(
sem.Get(v)->Type()->UnwrapRef(), sem.Get(v)->Type()->UnwrapRef(),
[&](const type::SampledTexture* tex) -> const ast::Variable* { [&](const type::SampledTexture* tex) -> const Variable* {
if (tex->dim() == type::TextureDimension::k1d) { if (tex->dim() == type::TextureDimension::k1d) {
auto type = ctx.dst->ty.sampled_texture(type::TextureDimension::k2d, auto type = ctx.dst->ty.sampled_texture(type::TextureDimension::k2d,
CreateASTTypeFor(ctx, tex->type())); CreateASTTypeFor(ctx, tex->type()));
@ -105,7 +105,7 @@ struct Texture1DTo2D::State {
return nullptr; return nullptr;
} }
}, },
[&](const type::StorageTexture* storage_tex) -> const ast::Variable* { [&](const type::StorageTexture* storage_tex) -> const Variable* {
if (storage_tex->dim() == type::TextureDimension::k1d) { if (storage_tex->dim() == type::TextureDimension::k1d) {
auto type = ctx.dst->ty.storage_texture(type::TextureDimension::k2d, auto type = ctx.dst->ty.storage_texture(type::TextureDimension::k2d,
storage_tex->texel_format(), storage_tex->texel_format(),
@ -119,7 +119,7 @@ struct Texture1DTo2D::State {
return r; return r;
}); });
ctx.ReplaceAll([&](const ast::CallExpression* c) -> const ast::Expression* { ctx.ReplaceAll([&](const CallExpression* c) -> const Expression* {
auto* call = sem.Get(c)->UnwrapMaterialize()->As<sem::Call>(); auto* call = sem.Get(c)->UnwrapMaterialize()->As<sem::Call>();
if (!call) { if (!call) {
return nullptr; return nullptr;
@ -141,7 +141,7 @@ struct Texture1DTo2D::State {
if (builtin->Type() == builtin::Function::kTextureDimensions) { if (builtin->Type() == builtin::Function::kTextureDimensions) {
// If this textureDimensions() call is in a CallStatement, we can leave it // If this textureDimensions() call is in a CallStatement, we can leave it
// unmodified since the return value will be dropped on the floor anyway. // unmodified since the return value will be dropped on the floor anyway.
if (call->Stmt()->Declaration()->Is<ast::CallStatement>()) { if (call->Stmt()->Declaration()->Is<CallStatement>()) {
return nullptr; return nullptr;
} }
auto* new_call = ctx.CloneWithoutTransform(c); auto* new_call = ctx.CloneWithoutTransform(c);
@ -153,14 +153,14 @@ struct Texture1DTo2D::State {
return nullptr; return nullptr;
} }
utils::Vector<const ast::Expression*, 8> args; utils::Vector<const Expression*, 8> args;
int index = 0; int index = 0;
for (auto* arg : c->args) { for (auto* arg : c->args) {
if (index == coords_index) { if (index == coords_index) {
auto* ctype = call->Arguments()[static_cast<size_t>(coords_index)]->Type(); auto* ctype = call->Arguments()[static_cast<size_t>(coords_index)]->Type();
auto* coords = c->args[static_cast<size_t>(coords_index)]; auto* coords = c->args[static_cast<size_t>(coords_index)];
const ast::LiteralExpression* half = nullptr; const LiteralExpression* half = nullptr;
if (ctype->is_integer_scalar()) { if (ctype->is_integer_scalar()) {
half = ctx.dst->Expr(0_a); half = ctx.dst->Expr(0_a);
} else { } else {

View File

@ -60,23 +60,23 @@ Output Transform::Run(const Program* src, const DataMap& data /* = {} */) const
return output; return output;
} }
void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) { void Transform::RemoveStatement(CloneContext& ctx, const Statement* stmt) {
auto* sem = ctx.src->Sem().Get(stmt); auto* sem = ctx.src->Sem().Get(stmt);
if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) { if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {
ctx.Remove(block->Declaration()->statements, stmt); ctx.Remove(block->Declaration()->statements, stmt);
return; return;
} }
if (TINT_LIKELY(tint::Is<sem::ForLoopStatement>(sem->Parent()))) { if (TINT_LIKELY(tint::Is<sem::ForLoopStatement>(sem->Parent()))) {
ctx.Replace(stmt, static_cast<ast::Expression*>(nullptr)); ctx.Replace(stmt, static_cast<Expression*>(nullptr));
return; return;
} }
TINT_ICE(Transform, ctx.dst->Diagnostics()) TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unable to remove statement from parent of type " << sem->TypeInfo().name; << "unable to remove statement from parent of type " << sem->TypeInfo().name;
} }
ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) { Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
if (ty->Is<type::Void>()) { if (ty->Is<type::Void>()) {
return ast::Type{}; return Type{};
} }
if (ty->Is<type::I32>()) { if (ty->Is<type::I32>()) {
return ctx.dst->ty.i32(); return ctx.dst->ty.i32();
@ -108,9 +108,9 @@ ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
} }
if (auto* a = ty->As<type::Array>()) { if (auto* a = ty->As<type::Array>()) {
auto el = CreateASTTypeFor(ctx, a->ElemType()); auto el = CreateASTTypeFor(ctx, a->ElemType());
utils::Vector<const ast::Attribute*, 1> attrs; utils::Vector<const Attribute*, 1> attrs;
if (!a->IsStrideImplicit()) { if (!a->IsStrideImplicit()) {
attrs.Push(ctx.dst->create<ast::StrideAttribute>(a->Stride())); attrs.Push(ctx.dst->create<StrideAttribute>(a->Stride()));
} }
if (a->Count()->Is<type::RuntimeArrayCount>()) { if (a->Count()->Is<type::RuntimeArrayCount>()) {
return ctx.dst->ty.array(el, std::move(attrs)); return ctx.dst->ty.array(el, std::move(attrs));
@ -125,7 +125,7 @@ ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
// See crbug.com/tint/1764. // See crbug.com/tint/1764.
// Look for a type alias for this array. // Look for a type alias for this array.
for (auto* type_decl : ctx.src->AST().TypeDecls()) { for (auto* type_decl : ctx.src->AST().TypeDecls()) {
if (auto* alias = type_decl->As<ast::Alias>()) { if (auto* alias = type_decl->As<Alias>()) {
if (ty == ctx.src->Sem().Get(alias)) { if (ty == ctx.src->Sem().Get(alias)) {
// Alias found. Use the alias name to ensure types compare equal. // Alias found. Use the alias name to ensure types compare equal.
return ctx.dst->ty(ctx.Clone(alias->name->symbol)); return ctx.dst->ty(ctx.Clone(alias->name->symbol));
@ -184,7 +184,7 @@ ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
} }
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
<< "Unhandled type: " << ty->TypeInfo().name; << "Unhandled type: " << ty->TypeInfo().name;
return ast::Type{}; return Type{};
} }
} // namespace tint::ast::transform } // namespace tint::ast::transform

View File

@ -181,11 +181,11 @@ class Transform : public utils::Castable<Transform> {
const DataMap& inputs, const DataMap& inputs,
DataMap& outputs) const = 0; DataMap& outputs) const = 0;
/// CreateASTTypeFor constructs new ast::Type that reconstructs the semantic type `ty`. /// CreateASTTypeFor constructs new Type that reconstructs the semantic type `ty`.
/// @param ctx the clone context /// @param ctx the clone context
/// @param ty the semantic type to reconstruct /// @param ty the semantic type to reconstruct
/// @returns an ast::Type that when resolved, will produce the semantic type `ty`. /// @returns an Type that when resolved, will produce the semantic type `ty`.
static ast::Type CreateASTTypeFor(CloneContext& ctx, const type::Type* ty); static Type CreateASTTypeFor(CloneContext& ctx, const type::Type* ty);
protected: protected:
/// Removes the statement `stmt` from the transformed program. /// Removes the statement `stmt` from the transformed program.
@ -193,7 +193,7 @@ class Transform : public utils::Castable<Transform> {
/// continuing of for-loops. /// continuing of for-loops.
/// @param ctx the clone context /// @param ctx the clone context
/// @param stmt the statement to remove when the program is cloned /// @param stmt the statement to remove when the program is cloned
static void RemoveStatement(CloneContext& ctx, const ast::Statement* stmt); static void RemoveStatement(CloneContext& ctx, const Statement* stmt);
}; };
} // namespace tint::ast::transform } // namespace tint::ast::transform

View File

@ -32,7 +32,7 @@ struct CreateASTTypeForTest : public testing::Test, public Transform {
return SkipTransform; return SkipTransform;
} }
ast::Type create(std::function<type::Type*(ProgramBuilder&)> create_sem_type) { Type create(std::function<type::Type*(ProgramBuilder&)> create_sem_type) {
ProgramBuilder sem_type_builder; ProgramBuilder sem_type_builder;
auto* sem_type = create_sem_type(sem_type_builder); auto* sem_type = create_sem_type(sem_type_builder);
Program program(std::move(sem_type_builder)); Program program(std::move(sem_type_builder));
@ -44,9 +44,7 @@ struct CreateASTTypeForTest : public testing::Test, public Transform {
}; };
TEST_F(CreateASTTypeForTest, Basic) { TEST_F(CreateASTTypeForTest, Basic) {
auto check = [&](ast::Type ty, const char* expect) { auto check = [&](Type ty, const char* expect) { CheckIdentifier(ty->identifier, expect); };
ast::CheckIdentifier(ty->identifier, expect);
};
check(create([](ProgramBuilder& b) { return b.create<type::I32>(); }), "i32"); check(create([](ProgramBuilder& b) { return b.create<type::I32>(); }), "i32");
check(create([](ProgramBuilder& b) { return b.create<type::U32>(); }), "u32"); check(create([](ProgramBuilder& b) { return b.create<type::U32>(); }), "u32");
@ -61,14 +59,14 @@ TEST_F(CreateASTTypeForTest, Matrix) {
return b.create<type::Matrix>(column_type, 3u); return b.create<type::Matrix>(column_type, 3u);
}); });
ast::CheckIdentifier(mat, ast::Template("mat3x2", "f32")); CheckIdentifier(mat, Template("mat3x2", "f32"));
} }
TEST_F(CreateASTTypeForTest, Vector) { TEST_F(CreateASTTypeForTest, Vector) {
auto vec = auto vec =
create([](ProgramBuilder& b) { return b.create<type::Vector>(b.create<type::F32>(), 2u); }); create([](ProgramBuilder& b) { return b.create<type::Vector>(b.create<type::F32>(), 2u); });
ast::CheckIdentifier(vec, ast::Template("vec2", "f32")); CheckIdentifier(vec, Template("vec2", "f32"));
} }
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) { TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
@ -77,8 +75,8 @@ TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
4u, 4u, 32u, 32u); 4u, 4u, 32u, 32u);
}); });
ast::CheckIdentifier(arr, ast::Template("array", "f32", 2_u)); CheckIdentifier(arr, Template("array", "f32", 2_u));
auto* tmpl_attr = arr->identifier->As<ast::TemplatedIdentifier>(); auto* tmpl_attr = arr->identifier->As<TemplatedIdentifier>();
ASSERT_NE(tmpl_attr, nullptr); ASSERT_NE(tmpl_attr, nullptr);
EXPECT_TRUE(tmpl_attr->attributes.IsEmpty()); EXPECT_TRUE(tmpl_attr->attributes.IsEmpty());
} }
@ -88,12 +86,12 @@ TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
return b.create<type::Array>(b.create<type::F32>(), b.create<type::ConstantArrayCount>(2u), return b.create<type::Array>(b.create<type::F32>(), b.create<type::ConstantArrayCount>(2u),
4u, 4u, 64u, 32u); 4u, 4u, 64u, 32u);
}); });
ast::CheckIdentifier(arr, ast::Template("array", "f32", 2_u)); CheckIdentifier(arr, Template("array", "f32", 2_u));
auto* tmpl_attr = arr->identifier->As<ast::TemplatedIdentifier>(); auto* tmpl_attr = arr->identifier->As<TemplatedIdentifier>();
ASSERT_NE(tmpl_attr, nullptr); ASSERT_NE(tmpl_attr, nullptr);
ASSERT_EQ(tmpl_attr->attributes.Length(), 1u); ASSERT_EQ(tmpl_attr->attributes.Length(), 1u);
ASSERT_TRUE(tmpl_attr->attributes[0]->Is<ast::StrideAttribute>()); ASSERT_TRUE(tmpl_attr->attributes[0]->Is<StrideAttribute>());
ASSERT_EQ(tmpl_attr->attributes[0]->As<ast::StrideAttribute>()->stride, 64u); ASSERT_EQ(tmpl_attr->attributes[0]->As<StrideAttribute>()->stride, 64u);
} }
// crbug.com/tint/1764 // crbug.com/tint/1764
@ -114,7 +112,7 @@ TEST_F(CreateASTTypeForTest, AliasedArrayWithComplexOverrideLength) {
CloneContext ctx(&ast_type_builder, &program, false); CloneContext ctx(&ast_type_builder, &program, false);
auto ast_ty = CreateASTTypeFor(ctx, arr_ty); auto ast_ty = CreateASTTypeFor(ctx, arr_ty);
ast::CheckIdentifier(ast_ty, "A"); CheckIdentifier(ast_ty, "A");
} }
TEST_F(CreateASTTypeForTest, Struct) { TEST_F(CreateASTTypeForTest, Struct) {
@ -124,7 +122,7 @@ TEST_F(CreateASTTypeForTest, Struct) {
4u /* size */, 4u /* size_no_padding */); 4u /* size */, 4u /* size_no_padding */);
}); });
ast::CheckIdentifier(str, "S"); CheckIdentifier(str, "S");
} }
TEST_F(CreateASTTypeForTest, PrivatePointer) { TEST_F(CreateASTTypeForTest, PrivatePointer) {
@ -133,7 +131,7 @@ TEST_F(CreateASTTypeForTest, PrivatePointer) {
builtin::Access::kReadWrite); builtin::Access::kReadWrite);
}); });
ast::CheckIdentifier(ptr, ast::Template("ptr", "private", "i32")); CheckIdentifier(ptr, Template("ptr", "private", "i32"));
} }
TEST_F(CreateASTTypeForTest, StorageReadWritePointer) { TEST_F(CreateASTTypeForTest, StorageReadWritePointer) {
@ -142,7 +140,7 @@ TEST_F(CreateASTTypeForTest, StorageReadWritePointer) {
builtin::Access::kReadWrite); builtin::Access::kReadWrite);
}); });
ast::CheckIdentifier(ptr, ast::Template("ptr", "storage", "i32", "read_write")); CheckIdentifier(ptr, Template("ptr", "storage", "i32", "read_write"));
} }
} // namespace } // namespace

View File

@ -74,7 +74,7 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
continue; continue;
} }
if (func_ast->PipelineStage() != ast::PipelineStage::kVertex) { if (func_ast->PipelineStage() != PipelineStage::kVertex) {
// Currently only vertex stage could have interstage output variables that need // Currently only vertex stage could have interstage output variables that need
// truncated. // truncated.
continue; continue;
@ -118,8 +118,8 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
old_shader_io_structs_to_new_struct_and_truncate_functions.GetOrCreate(str, [&] { old_shader_io_structs_to_new_struct_and_truncate_functions.GetOrCreate(str, [&] {
auto new_struct_sym = b.Symbols().New(); auto new_struct_sym = b.Symbols().New();
utils::Vector<const ast::StructMember*, 20> truncated_members; utils::Vector<const StructMember*, 20> truncated_members;
utils::Vector<const ast::Expression*, 20> initializer_exprs; utils::Vector<const Expression*, 20> initializer_exprs;
for (auto* member : str->Members()) { for (auto* member : str->Members()) {
if (omit_members.Contains(member)) { if (omit_members.Contains(member)) {
@ -155,7 +155,7 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
// Replace return statements with new truncated shader IO struct // Replace return statements with new truncated shader IO struct
ctx.ReplaceAll( ctx.ReplaceAll(
[&](const ast::ReturnStatement* return_statement) -> const ast::ReturnStatement* { [&](const ReturnStatement* return_statement) -> const ReturnStatement* {
auto* return_sem = sem.Get(return_statement); auto* return_sem = sem.Get(return_statement);
if (auto mapping_fn_sym = if (auto mapping_fn_sym =
entry_point_functions_to_truncate_functions.Find(return_sem->Function())) { entry_point_functions_to_truncate_functions.Find(return_sem->Function())) {
@ -168,11 +168,11 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
// Remove IO attributes from old shader IO struct which is not used as entry point output // Remove IO attributes from old shader IO struct which is not used as entry point output
// anymore. // anymore.
for (auto it : old_shader_io_structs_to_new_struct_and_truncate_functions) { for (auto it : old_shader_io_structs_to_new_struct_and_truncate_functions) {
const ast::Struct* struct_ty = it.key->Declaration(); const Struct* struct_ty = it.key->Declaration();
for (auto* member : struct_ty->members) { for (auto* member : struct_ty->members) {
for (auto* attr : member->attributes) { for (auto* attr : member->attributes) {
if (attr->IsAnyOf<ast::BuiltinAttribute, ast::LocationAttribute, if (attr->IsAnyOf<BuiltinAttribute, LocationAttribute, InterpolateAttribute,
ast::InterpolateAttribute, ast::InvariantAttribute>()) { InvariantAttribute>()) {
ctx.Remove(member->attributes, attr); ctx.Remove(member->attributes, attr);
} }
} }

View File

@ -50,29 +50,27 @@ struct Unshadow::State {
// Maps a variable to its new name. // Maps a variable to its new name.
utils::Hashmap<const sem::Variable*, Symbol, 8> renamed_to; utils::Hashmap<const sem::Variable*, Symbol, 8> renamed_to;
auto rename = [&](const sem::Variable* v) -> const ast::Variable* { auto rename = [&](const sem::Variable* v) -> const Variable* {
auto* decl = v->Declaration(); auto* decl = v->Declaration();
auto name = decl->name->symbol.Name(); auto name = decl->name->symbol.Name();
auto symbol = b.Symbols().New(name); auto symbol = b.Symbols().New(name);
renamed_to.Add(v, symbol); renamed_to.Add(v, symbol);
auto source = ctx.Clone(decl->source); auto source = ctx.Clone(decl->source);
auto type = decl->type ? ctx.Clone(decl->type) : ast::Type{}; auto type = decl->type ? ctx.Clone(decl->type) : Type{};
auto* initializer = ctx.Clone(decl->initializer); auto* initializer = ctx.Clone(decl->initializer);
auto attributes = ctx.Clone(decl->attributes); auto attributes = ctx.Clone(decl->attributes);
return Switch( return Switch(
decl, // decl, //
[&](const ast::Var* var) { [&](const Var* var) {
return b.Var(source, symbol, type, var->declared_address_space, return b.Var(source, symbol, type, var->declared_address_space,
var->declared_access, initializer, attributes); var->declared_access, initializer, attributes);
}, },
[&](const ast::Let*) { [&](const Let*) { return b.Let(source, symbol, type, initializer, attributes); },
return b.Let(source, symbol, type, initializer, attributes); [&](const Const*) {
},
[&](const ast::Const*) {
return b.Const(source, symbol, type, initializer, attributes); return b.Const(source, symbol, type, initializer, attributes);
}, },
[&](const ast::Parameter*) { // [&](const Parameter*) { //
return b.Param(source, symbol, type, attributes); return b.Param(source, symbol, type, attributes);
}, },
[&](Default) { [&](Default) {
@ -105,8 +103,7 @@ struct Unshadow::State {
return SkipTransform; return SkipTransform;
} }
ctx.ReplaceAll( ctx.ReplaceAll([&](const IdentifierExpression* ident) -> const IdentifierExpression* {
[&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* {
if (auto* sem_ident = sem.GetVal(ident)) { if (auto* sem_ident = sem.GetVal(ident)) {
if (auto* user = sem_ident->Unwrap()->As<sem::VariableUser>()) { if (auto* user = sem_ident->Unwrap()->As<sem::VariableUser>()) {
if (auto renamed = renamed_to.Find(user->Variable())) { if (auto renamed = renamed_to.Find(user->Variable())) {

View File

@ -20,10 +20,10 @@
namespace tint::ast::transform::utils { namespace tint::ast::transform::utils {
InsertionPoint GetInsertionPoint(CloneContext& ctx, const ast::Statement* stmt) { InsertionPoint GetInsertionPoint(CloneContext& ctx, const Statement* stmt) {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
auto& diag = ctx.dst->Diagnostics(); auto& diag = ctx.dst->Diagnostics();
using RetType = std::pair<const sem::BlockStatement*, const ast::Statement*>; using RetType = std::pair<const sem::BlockStatement*, const Statement*>;
if (auto* sem_stmt = sem.Get(stmt)) { if (auto* sem_stmt = sem.Get(stmt)) {
auto* parent = sem_stmt->Parent(); auto* parent = sem_stmt->Parent();

View File

@ -24,7 +24,7 @@ namespace tint::ast::transform::utils {
/// InsertionPoint is a pair of the block (`first`) within which, and the /// InsertionPoint is a pair of the block (`first`) within which, and the
/// statement (`second`) before or after which to insert. /// statement (`second`) before or after which to insert.
using InsertionPoint = std::pair<const sem::BlockStatement*, const ast::Statement*>; using InsertionPoint = std::pair<const sem::BlockStatement*, const Statement*>;
/// For the input statement, returns the block and statement within that /// For the input statement, returns the block and statement within that
/// block to insert before/after. If `stmt` is a for-loop continue statement, /// block to insert before/after. If `stmt` is a for-loop continue statement,
@ -32,7 +32,7 @@ using InsertionPoint = std::pair<const sem::BlockStatement*, const ast::Statemen
/// @param ctx the clone context /// @param ctx the clone context
/// @param stmt the statement to insert before or after /// @param stmt the statement to insert before or after
/// @return the insertion point /// @return the insertion point
InsertionPoint GetInsertionPoint(CloneContext& ctx, const ast::Statement* stmt); InsertionPoint GetInsertionPoint(CloneContext& ctx, const Statement* stmt);
} // namespace tint::ast::transform::utils } // namespace tint::ast::transform::utils

View File

@ -37,7 +37,7 @@ struct HoistToDeclBefore::State {
/// @copydoc HoistToDeclBefore::Add() /// @copydoc HoistToDeclBefore::Add()
bool Add(const sem::ValueExpression* before_expr, bool Add(const sem::ValueExpression* before_expr,
const ast::Expression* expr, const Expression* expr,
VariableKind kind, VariableKind kind,
const char* decl_name) { const char* decl_name) {
auto name = b.Symbols().New(decl_name); auto name = b.Symbols().New(decl_name);
@ -85,8 +85,8 @@ struct HoistToDeclBefore::State {
return true; return true;
} }
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*) /// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const Statement*)
bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) { bool InsertBefore(const sem::Statement* before_stmt, const Statement* stmt) {
if (stmt) { if (stmt) {
auto builder = [stmt] { return stmt; }; auto builder = [stmt] { return stmt; };
return InsertBeforeImpl(before_stmt, std::move(builder)); return InsertBeforeImpl(before_stmt, std::move(builder));
@ -99,8 +99,8 @@ struct HoistToDeclBefore::State {
return InsertBeforeImpl(before_stmt, std::move(builder)); return InsertBeforeImpl(before_stmt, std::move(builder));
} }
/// @copydoc HoistToDeclBefore::Replace(const sem::Statement* what, const ast::Statement* with) /// @copydoc HoistToDeclBefore::Replace(const sem::Statement* what, const Statement* with)
bool Replace(const sem::Statement* what, const ast::Statement* with) { bool Replace(const sem::Statement* what, const Statement* with) {
auto builder = [with] { return with; }; auto builder = [with] { return with; };
return Replace(what, std::move(builder)); return Replace(what, std::move(builder));
} }
@ -145,7 +145,7 @@ struct HoistToDeclBefore::State {
utils::Hashmap<const sem::WhileStatement*, LoopInfo, 4> while_loops; utils::Hashmap<const sem::WhileStatement*, LoopInfo, 4> while_loops;
/// 'else if' statements that need to be decomposed to 'else {if}' /// 'else if' statements that need to be decomposed to 'else {if}'
utils::Hashmap<const ast::IfStatement*, ElseIfInfo, 4> else_ifs; utils::Hashmap<const IfStatement*, ElseIfInfo, 4> else_ifs;
template <size_t N> template <size_t N>
static auto Build(const utils::Vector<StmtBuilder, N>& builders) { static auto Build(const utils::Vector<StmtBuilder, N>& builders) {
@ -181,7 +181,7 @@ struct HoistToDeclBefore::State {
/// automatically called. /// automatically called.
/// @warning the returned reference is invalid if this is called a second time, or the /// @warning the returned reference is invalid if this is called a second time, or the
/// #else_ifs map is mutated. /// #else_ifs map is mutated.
auto ElseIf(const ast::IfStatement* else_if) { auto ElseIf(const IfStatement* else_if) {
if (else_ifs.IsEmpty()) { if (else_ifs.IsEmpty()) {
RegisterElseIfTransform(); RegisterElseIfTransform();
} }
@ -190,7 +190,7 @@ struct HoistToDeclBefore::State {
/// Registers the handler for transforming for-loops based on the content of the #for_loops map. /// Registers the handler for transforming for-loops based on the content of the #for_loops map.
void RegisterForLoopTransform() const { void RegisterForLoopTransform() const {
ctx.ReplaceAll([&](const ast::ForLoopStatement* stmt) -> const ast::Statement* { ctx.ReplaceAll([&](const ForLoopStatement* stmt) -> const Statement* {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
if (auto* fl = sem.Get(stmt)) { if (auto* fl = sem.Get(stmt)) {
@ -205,9 +205,9 @@ struct HoistToDeclBefore::State {
if (auto* cond = for_loop->condition) { if (auto* cond = for_loop->condition) {
// !condition // !condition
auto* not_cond = auto* not_cond =
b.create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond)); b.create<UnaryOpExpression>(UnaryOp::kNot, ctx.Clone(cond));
// { break; } // { break; }
auto* break_body = b.Block(b.create<ast::BreakStatement>()); auto* break_body = b.Block(b.create<BreakStatement>());
// if (!condition) { break; } // if (!condition) { break; }
body_stmts.Push(b.If(not_cond, break_body)); body_stmts.Push(b.If(not_cond, break_body));
} }
@ -215,7 +215,7 @@ struct HoistToDeclBefore::State {
body_stmts.Push(ctx.Clone(for_loop->body)); body_stmts.Push(ctx.Clone(for_loop->body));
// Create the continuing block if there was one. // Create the continuing block if there was one.
const ast::BlockStatement* continuing = nullptr; const BlockStatement* continuing = nullptr;
if (auto* cont = for_loop->continuing) { if (auto* cont = for_loop->continuing) {
// Continuing block starts with any let declarations used by // Continuing block starts with any let declarations used by
// the continuing. // the continuing.
@ -249,7 +249,7 @@ struct HoistToDeclBefore::State {
/// map. /// map.
void RegisterWhileLoopTransform() const { void RegisterWhileLoopTransform() const {
// At least one while needs to be transformed into a loop. // At least one while needs to be transformed into a loop.
ctx.ReplaceAll([&](const ast::WhileStatement* stmt) -> const ast::Statement* { ctx.ReplaceAll([&](const WhileStatement* stmt) -> const Statement* {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
if (auto* w = sem.Get(stmt)) { if (auto* w = sem.Get(stmt)) {
@ -274,7 +274,7 @@ struct HoistToDeclBefore::State {
// Next emit the body // Next emit the body
body_stmts.Push(ctx.Clone(while_loop->body)); body_stmts.Push(ctx.Clone(while_loop->body));
const ast::BlockStatement* continuing = nullptr; const BlockStatement* continuing = nullptr;
auto* body = b.Block(body_stmts); auto* body = b.Block(body_stmts);
auto* loop = b.Loop(body, continuing); auto* loop = b.Loop(body, continuing);
@ -289,7 +289,7 @@ struct HoistToDeclBefore::State {
/// map. /// map.
void RegisterElseIfTransform() const { void RegisterElseIfTransform() const {
// Decompose 'else-if' statements into 'else { if }' blocks. // Decompose 'else-if' statements into 'else { if }' blocks.
ctx.ReplaceAll([&](const ast::IfStatement* stmt) -> const ast::Statement* { ctx.ReplaceAll([&](const IfStatement* stmt) -> const Statement* {
if (auto info = else_ifs.Find(stmt)) { if (auto info = else_ifs.Find(stmt)) {
// Build the else block's body statements, starting with let decls for the // Build the else block's body statements, starting with let decls for the
// conditional expression. // conditional expression.
@ -412,14 +412,13 @@ HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx) : state_(std::make_uniqu
HoistToDeclBefore::~HoistToDeclBefore() {} HoistToDeclBefore::~HoistToDeclBefore() {}
bool HoistToDeclBefore::Add(const sem::ValueExpression* before_expr, bool HoistToDeclBefore::Add(const sem::ValueExpression* before_expr,
const ast::Expression* expr, const Expression* expr,
VariableKind kind, VariableKind kind,
const char* decl_name) { const char* decl_name) {
return state_->Add(before_expr, expr, kind, decl_name); return state_->Add(before_expr, expr, kind, decl_name);
} }
bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt, bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt, const Statement* stmt) {
const ast::Statement* stmt) {
return state_->InsertBefore(before_stmt, stmt); return state_->InsertBefore(before_stmt, stmt);
} }
@ -428,7 +427,7 @@ bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt,
return state_->InsertBefore(before_stmt, builder); return state_->InsertBefore(before_stmt, builder);
} }
bool HoistToDeclBefore::Replace(const sem::Statement* what, const ast::Statement* with) { bool HoistToDeclBefore::Replace(const sem::Statement* what, const Statement* with) {
return state_->Replace(what, with); return state_->Replace(what, with);
} }

View File

@ -36,7 +36,7 @@ class HoistToDeclBefore {
~HoistToDeclBefore(); ~HoistToDeclBefore();
/// StmtBuilder is a builder of an AST statement /// StmtBuilder is a builder of an AST statement
using StmtBuilder = std::function<const ast::Statement*()>; using StmtBuilder = std::function<const Statement*()>;
/// VariableKind is either a var, let or const /// VariableKind is either a var, let or const
enum class VariableKind { enum class VariableKind {
@ -53,7 +53,7 @@ class HoistToDeclBefore {
/// @param decl_name optional name to use for the variable/constant name /// @param decl_name optional name to use for the variable/constant name
/// @return true on success /// @return true on success
bool Add(const sem::ValueExpression* before_expr, bool Add(const sem::ValueExpression* before_expr,
const ast::Expression* expr, const Expression* expr,
VariableKind kind, VariableKind kind,
const char* decl_name = ""); const char* decl_name = "");
@ -64,7 +64,7 @@ class HoistToDeclBefore {
/// @param before_stmt statement to insert @p stmt before /// @param before_stmt statement to insert @p stmt before
/// @param stmt statement to insert /// @param stmt statement to insert
/// @return true on success /// @return true on success
bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt); bool InsertBefore(const sem::Statement* before_stmt, const Statement* stmt);
/// Inserts the returned statement of @p builder before @p before_stmt, possibly converting /// Inserts the returned statement of @p builder before @p before_stmt, possibly converting
/// 'for-loop's to 'loop's if necessary. /// 'for-loop's to 'loop's if necessary.
@ -81,7 +81,7 @@ class HoistToDeclBefore {
/// @param what the statement to replace /// @param what the statement to replace
/// @param with the replacement statement /// @param with the replacement statement
/// @return true on success /// @return true on success
bool Replace(const sem::Statement* what, const ast::Statement* with); bool Replace(const sem::Statement* what, const Statement* with);
/// Replaces the statement @p what with the statement returned by @p stmt, possibly converting /// Replaces the statement @p what with the statement returned by @p stmt, possibly converting
/// 'for-loop's to 'loop's if necessary. /// 'for-loop's to 'loop's if necessary.

View File

@ -628,7 +628,7 @@ TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont) {
ProgramBuilder b; ProgramBuilder b;
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty); b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
auto* var = b.Decl(b.Var("a", b.Expr(1_i))); auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd); auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block()); auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s}); b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
@ -637,7 +637,7 @@ TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont) {
CloneContext ctx(&cloned_b, &original); CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx); HoistToDeclBefore hoistToDeclBefore(ctx);
auto* before_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>()); auto* before_stmt = ctx.src->Sem().Get(cont->As<Statement>());
auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo")); auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
hoistToDeclBefore.InsertBefore(before_stmt, new_stmt); hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
@ -679,7 +679,7 @@ TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont_Function) {
ProgramBuilder b; ProgramBuilder b;
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty); b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
auto* var = b.Decl(b.Var("a", b.Expr(1_i))); auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd); auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block()); auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s}); b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
@ -688,7 +688,7 @@ TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont_Function) {
CloneContext ctx(&cloned_b, &original); CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx); HoistToDeclBefore hoistToDeclBefore(ctx);
auto* before_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>()); auto* before_stmt = ctx.src->Sem().Get(cont->As<Statement>());
hoistToDeclBefore.InsertBefore(before_stmt, hoistToDeclBefore.InsertBefore(before_stmt,
[&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); }); [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
@ -1048,7 +1048,7 @@ TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont) {
ProgramBuilder b; ProgramBuilder b;
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty); b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
auto* var = b.Decl(b.Var("a", b.Expr(1_i))); auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd); auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block()); auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s}); b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
@ -1057,7 +1057,7 @@ TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont) {
CloneContext ctx(&cloned_b, &original); CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx); HoistToDeclBefore hoistToDeclBefore(ctx);
auto* target_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>()); auto* target_stmt = ctx.src->Sem().Get(cont->As<Statement>());
auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo")); auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
hoistToDeclBefore.Replace(target_stmt, new_stmt); hoistToDeclBefore.Replace(target_stmt, new_stmt);
@ -1098,7 +1098,7 @@ TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont_Function) {
ProgramBuilder b; ProgramBuilder b;
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty); b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
auto* var = b.Decl(b.Var("a", b.Expr(1_i))); auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd); auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block()); auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s}); b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
@ -1107,7 +1107,7 @@ TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont_Function) {
CloneContext ctx(&cloned_b, &original); CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx); HoistToDeclBefore hoistToDeclBefore(ctx);
auto* target_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>()); auto* target_stmt = ctx.src->Sem().Get(cont->As<Statement>());
hoistToDeclBefore.Replace(target_stmt, [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); }); hoistToDeclBefore.Replace(target_stmt, [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
ctx.Clone(); ctx.Clone();

View File

@ -37,7 +37,7 @@ Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,
// Extracts array and matrix values that are dynamically indexed to a // Extracts array and matrix values that are dynamically indexed to a
// temporary `var` local that is then indexed. // temporary `var` local that is then indexed.
auto dynamic_index_to_var = [&](const ast::IndexAccessorExpression* access_expr) { auto dynamic_index_to_var = [&](const IndexAccessorExpression* access_expr) {
auto* index_expr = access_expr->index; auto* index_expr = access_expr->index;
auto* object_expr = access_expr->object; auto* object_expr = access_expr->object;
auto& sem = src->Sem(); auto& sem = src->Sem();
@ -62,7 +62,7 @@ Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,
bool index_accessor_found = false; bool index_accessor_found = false;
for (auto* node : src->ASTNodes().Objects()) { for (auto* node : src->ASTNodes().Objects()) {
if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) { if (auto* access_expr = node->As<IndexAccessorExpression>()) {
if (!dynamic_index_to_var(access_expr)) { if (!dynamic_index_to_var(access_expr)) {
return Program(std::move(b)); return Program(std::move(b));
} }

View File

@ -70,7 +70,7 @@ Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
std::unordered_map<HelperFunctionKey, Symbol> matrix_convs; std::unordered_map<HelperFunctionKey, Symbol> matrix_convs;
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { ctx.ReplaceAll([&](const CallExpression* expr) -> const CallExpression* {
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* ty_conv = call->Target()->As<sem::ValueConversion>(); auto* ty_conv = call->Target()->As<sem::ValueConversion>();
if (!ty_conv) { if (!ty_conv) {
@ -102,7 +102,7 @@ Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
} }
auto build_vectorized_conversion_expression = [&](auto&& src_expression_builder) { auto build_vectorized_conversion_expression = [&](auto&& src_expression_builder) {
utils::Vector<const ast::Expression*, 4> columns; utils::Vector<const Expression*, 4> columns;
for (uint32_t c = 0; c < dst_type->columns(); c++) { for (uint32_t c = 0; c < dst_type->columns(); c++) {
auto* src_matrix_expr = src_expression_builder(); auto* src_matrix_expr = src_expression_builder();
auto* src_column_expr = b.IndexAccessor(src_matrix_expr, b.Expr(tint::AInt(c))); auto* src_column_expr = b.IndexAccessor(src_matrix_expr, b.Expr(tint::AInt(c)));

View File

@ -61,7 +61,7 @@ Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* s
std::unordered_map<const type::Matrix*, Symbol> scalar_inits; std::unordered_map<const type::Matrix*, Symbol> scalar_inits;
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* { ctx.ReplaceAll([&](const CallExpression* expr) -> const CallExpression* {
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>(); auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* ty_init = call->Target()->As<sem::ValueConstructor>(); auto* ty_init = call->Target()->As<sem::ValueConstructor>();
if (!ty_init) { if (!ty_init) {
@ -91,9 +91,9 @@ Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* s
// Constructs a matrix using vector columns, with the elements constructed using the // Constructs a matrix using vector columns, with the elements constructed using the
// 'element(uint32_t c, uint32_t r)' callback. // 'element(uint32_t c, uint32_t r)' callback.
auto build_mat = [&](auto&& element) { auto build_mat = [&](auto&& element) {
utils::Vector<const ast::Expression*, 4> columns; utils::Vector<const Expression*, 4> columns;
for (uint32_t c = 0; c < mat_type->columns(); c++) { for (uint32_t c = 0; c < mat_type->columns(); c++) {
utils::Vector<const ast::Expression*, 4> row_values; utils::Vector<const Expression*, 4> row_values;
for (uint32_t r = 0; r < mat_type->rows(); r++) { for (uint32_t r = 0; r < mat_type->rows(); r++) {
row_values.Push(element(c, r)); row_values.Push(element(c, r));
} }

View File

@ -238,9 +238,9 @@ struct VertexPulling::State {
/// @returns the new program or SkipTransform if the transform is not required /// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() { ApplyResult Run() {
// Find entry point // Find entry point
const ast::Function* func = nullptr; const Function* func = nullptr;
for (auto* fn : src->AST().Functions()) { for (auto* fn : src->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kVertex) { if (fn->PipelineStage() == PipelineStage::kVertex) {
if (func != nullptr) { if (func != nullptr) {
b.Diagnostics().add_error( b.Diagnostics().add_error(
diag::System::Transform, diag::System::Transform,
@ -264,18 +264,18 @@ struct VertexPulling::State {
} }
private: private:
/// LocationReplacement describes an ast::Variable replacement for a location input. /// LocationReplacement describes an Variable replacement for a location input.
struct LocationReplacement { struct LocationReplacement {
/// The variable to replace in the source Program /// The variable to replace in the source Program
ast::Variable* from; Variable* from;
/// The replacement to use in the target ProgramBuilder /// The replacement to use in the target ProgramBuilder
ast::Variable* to; Variable* to;
}; };
/// LocationInfo describes an input location /// LocationInfo describes an input location
struct LocationInfo { struct LocationInfo {
/// A builder that builds the expression that resolves to the (transformed) input location /// A builder that builds the expression that resolves to the (transformed) input location
std::function<const ast::Expression*()> expr; std::function<const Expression*()> expr;
/// The store type of the location variable /// The store type of the location variable
const type::Type* type; const type::Type* type;
}; };
@ -289,12 +289,12 @@ struct VertexPulling::State {
/// The clone context /// The clone context
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true}; CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
std::unordered_map<uint32_t, LocationInfo> location_info; std::unordered_map<uint32_t, LocationInfo> location_info;
std::function<const ast::Expression*()> vertex_index_expr = nullptr; std::function<const Expression*()> vertex_index_expr = nullptr;
std::function<const ast::Expression*()> instance_index_expr = nullptr; std::function<const Expression*()> instance_index_expr = nullptr;
Symbol pulling_position_name; Symbol pulling_position_name;
Symbol struct_buffer_name; Symbol struct_buffer_name;
std::unordered_map<uint32_t, Symbol> vertex_buffer_names; std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
utils::Vector<const ast::Parameter*, 8> new_function_parameters; utils::Vector<const Parameter*, 8> new_function_parameters;
/// Generate the vertex buffer binding name /// Generate the vertex buffer binding name
/// @param index index to append to buffer name /// @param index index to append to buffer name
@ -331,11 +331,11 @@ struct VertexPulling::State {
} }
/// Creates and returns the assignment to the variables from the buffers /// Creates and returns the assignment to the variables from the buffers
const ast::BlockStatement* CreateVertexPullingPreamble() { const BlockStatement* CreateVertexPullingPreamble() {
// Assign by looking at the vertex descriptor to find attributes with // Assign by looking at the vertex descriptor to find attributes with
// matching location. // matching location.
utils::Vector<const ast::Statement*, 8> stmts; utils::Vector<const Statement*, 8> stmts;
for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size(); ++buffer_idx) { for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size(); ++buffer_idx) {
const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx]; const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx];
@ -399,7 +399,7 @@ struct VertexPulling::State {
// Convert the fetched scalar/vector if WGSL variable is of `f16` types // Convert the fetched scalar/vector if WGSL variable is of `f16` types
if (var_dt.base_type == BaseWGSLType::kF16) { if (var_dt.base_type == BaseWGSLType::kF16) {
// The type of the same element number of base type of target WGSL variable // The type of the same element number of base type of target WGSL variable
ast::Type loaded_data_target_type; Type loaded_data_target_type;
if (fmt_dt.width == 1) { if (fmt_dt.width == 1) {
loaded_data_target_type = b.ty.f16(); loaded_data_target_type = b.ty.f16();
} else { } else {
@ -433,7 +433,7 @@ struct VertexPulling::State {
// The components of result vector variable, initialized with type-converted // The components of result vector variable, initialized with type-converted
// loaded data vector. // loaded data vector.
utils::Vector<const ast::Expression*, 8> values{fetch}; utils::Vector<const Expression*, 8> values{fetch};
// Add padding elements. The result must be of vector types of signed/unsigned // Add padding elements. The result must be of vector types of signed/unsigned
// integer or float, so use the abstract integer or abstract float value to do // integer or float, so use the abstract integer or abstract float value to do
@ -470,7 +470,7 @@ struct VertexPulling::State {
/// @param offset the byte offset of the data from `buffer_base` /// @param offset the byte offset of the data from `buffer_base`
/// @param buffer the index of the vertex buffer /// @param buffer the index of the vertex buffer
/// @param format the vertex format to read /// @param format the vertex format to read
const ast::Expression* Fetch(Symbol array_base, const Expression* Fetch(Symbol array_base,
uint32_t offset, uint32_t offset,
uint32_t buffer, uint32_t buffer,
VertexFormat format) { VertexFormat format) {
@ -679,11 +679,11 @@ struct VertexPulling::State {
/// @param buffer the index of the vertex buffer /// @param buffer the index of the vertex buffer
/// @param format VertexFormat::kUint32, VertexFormat::kSint32 or /// @param format VertexFormat::kUint32, VertexFormat::kSint32 or
/// VertexFormat::kFloat32 /// VertexFormat::kFloat32
const ast::Expression* LoadPrimitive(Symbol array_base, const Expression* LoadPrimitive(Symbol array_base,
uint32_t offset, uint32_t offset,
uint32_t buffer, uint32_t buffer,
VertexFormat format) { VertexFormat format) {
const ast::Expression* u = nullptr; const Expression* u = nullptr;
if ((offset & 3) == 0) { if ((offset & 3) == 0) {
// Aligned load. // Aligned load.
@ -734,14 +734,14 @@ struct VertexPulling::State {
/// @param base_type underlying AST type /// @param base_type underlying AST type
/// @param base_format underlying vertex format /// @param base_format underlying vertex format
/// @param count how many elements the vector has /// @param count how many elements the vector has
const ast::Expression* LoadVec(Symbol array_base, const Expression* LoadVec(Symbol array_base,
uint32_t offset, uint32_t offset,
uint32_t buffer, uint32_t buffer,
uint32_t element_stride, uint32_t element_stride,
ast::Type base_type, Type base_type,
VertexFormat base_format, VertexFormat base_format,
uint32_t count) { uint32_t count) {
utils::Vector<const ast::Expression*, 8> expr_list; utils::Vector<const Expression*, 8> expr_list;
for (uint32_t i = 0; i < count; ++i) { for (uint32_t i = 0; i < count; ++i) {
// Offset read position by element_stride for each component // Offset read position by element_stride for each component
uint32_t primitive_offset = offset + element_stride * i; uint32_t primitive_offset = offset + element_stride * i;
@ -756,8 +756,8 @@ struct VertexPulling::State {
/// vertex_index and instance_index builtins if present. /// vertex_index and instance_index builtins if present.
/// @param func the entry point function /// @param func the entry point function
/// @param param the parameter to process /// @param param the parameter to process
void ProcessNonStructParameter(const ast::Function* func, const ast::Parameter* param) { void ProcessNonStructParameter(const Function* func, const Parameter* param) {
if (ast::HasAttribute<ast::LocationAttribute>(param->attributes)) { if (HasAttribute<LocationAttribute>(param->attributes)) {
// Create a function-scope variable to replace the parameter. // Create a function-scope variable to replace the parameter.
auto func_var_sym = ctx.Clone(param->name->symbol); auto func_var_sym = ctx.Clone(param->name->symbol);
auto func_var_type = ctx.Clone(param->type); auto func_var_type = ctx.Clone(param->type);
@ -776,7 +776,7 @@ struct VertexPulling::State {
} }
location_info[sem->Location().value()] = info; location_info[sem->Location().value()] = info;
} else { } else {
auto* builtin_attr = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes); auto* builtin_attr = GetAttribute<BuiltinAttribute>(param->attributes);
if (TINT_UNLIKELY(!builtin_attr)) { if (TINT_UNLIKELY(!builtin_attr)) {
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter"; TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
return; return;
@ -804,21 +804,21 @@ struct VertexPulling::State {
/// @param func the entry point function /// @param func the entry point function
/// @param param the parameter to process /// @param param the parameter to process
/// @param struct_ty the structure type /// @param struct_ty the structure type
void ProcessStructParameter(const ast::Function* func, void ProcessStructParameter(const Function* func,
const ast::Parameter* param, const Parameter* param,
const ast::Struct* struct_ty) { const Struct* struct_ty) {
auto param_sym = ctx.Clone(param->name->symbol); auto param_sym = ctx.Clone(param->name->symbol);
// Process the struct members. // Process the struct members.
bool has_locations = false; bool has_locations = false;
utils::Vector<const ast::StructMember*, 8> members_to_clone; utils::Vector<const StructMember*, 8> members_to_clone;
for (auto* member : struct_ty->members) { for (auto* member : struct_ty->members) {
auto member_sym = ctx.Clone(member->name->symbol); auto member_sym = ctx.Clone(member->name->symbol);
std::function<const ast::Expression*()> member_expr = [this, param_sym, member_sym]() { std::function<const Expression*()> member_expr = [this, param_sym, member_sym]() {
return b.MemberAccessor(param_sym, member_sym); return b.MemberAccessor(param_sym, member_sym);
}; };
if (ast::HasAttribute<ast::LocationAttribute>(member->attributes)) { if (HasAttribute<LocationAttribute>(member->attributes)) {
// Capture mapping from location to struct member. // Capture mapping from location to struct member.
LocationInfo info; LocationInfo info;
info.expr = member_expr; info.expr = member_expr;
@ -830,7 +830,7 @@ struct VertexPulling::State {
location_info[sem->Attributes().location.value()] = info; location_info[sem->Attributes().location.value()] = info;
has_locations = true; has_locations = true;
} else { } else {
auto* builtin_attr = ast::GetAttribute<ast::BuiltinAttribute>(member->attributes); auto* builtin_attr = GetAttribute<BuiltinAttribute>(member->attributes);
if (TINT_UNLIKELY(!builtin_attr)) { if (TINT_UNLIKELY(!builtin_attr)) {
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter"; TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
return; return;
@ -858,7 +858,7 @@ struct VertexPulling::State {
if (!members_to_clone.IsEmpty()) { if (!members_to_clone.IsEmpty()) {
// Create a new struct without the location attributes. // Create a new struct without the location attributes.
utils::Vector<const ast::StructMember*, 8> new_members; utils::Vector<const StructMember*, 8> new_members;
for (auto* member : members_to_clone) { for (auto* member : members_to_clone) {
auto member_name = ctx.Clone(member->name); auto member_name = ctx.Clone(member->name);
auto member_type = ctx.Clone(member->type); auto member_type = ctx.Clone(member->type);
@ -883,7 +883,7 @@ struct VertexPulling::State {
/// Process an entry point function. /// Process an entry point function.
/// @param func the entry point function /// @param func the entry point function
void Process(const ast::Function* func) { void Process(const Function* func) {
if (func->body->Empty()) { if (func->body->Empty()) {
return; return;
} }
@ -936,8 +936,8 @@ struct VertexPulling::State {
auto attrs = ctx.Clone(func->attributes); auto attrs = ctx.Clone(func->attributes);
auto ret_attrs = ctx.Clone(func->return_type_attributes); auto ret_attrs = ctx.Clone(func->return_type_attributes);
auto* new_func = auto* new_func =
b.create<ast::Function>(func->source, b.Ident(func_sym), new_function_parameters, b.create<Function>(func->source, b.Ident(func_sym), new_function_parameters, ret_type,
ret_type, body, std::move(attrs), std::move(ret_attrs)); body, std::move(attrs), std::move(ret_attrs));
ctx.Replace(func, new_func); ctx.Replace(func, new_func);
} }
}; };

View File

@ -26,7 +26,7 @@ namespace {
bool ShouldRun(const Program* program) { bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) { for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::WhileStatement>()) { if (node->Is<WhileStatement>()) {
return true; return true;
} }
} }
@ -47,8 +47,8 @@ Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, Da
ProgramBuilder b; ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* { ctx.ReplaceAll([&](const WhileStatement* w) -> const Statement* {
utils::Vector<const ast::Statement*, 16> stmts; utils::Vector<const Statement*, 16> stmts;
auto* cond = w->condition; auto* cond = w->condition;
// !condition // !condition
@ -64,7 +64,7 @@ Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, Da
stmts.Push(ctx.Clone(stmt)); stmts.Push(ctx.Clone(stmt));
} }
const ast::BlockStatement* continuing = nullptr; const BlockStatement* continuing = nullptr;
auto* body = b.Block(stmts); auto* body = b.Block(stmts);
auto* loop = b.Loop(body, continuing); auto* loop = b.Loop(body, continuing);

View File

@ -36,7 +36,7 @@ namespace {
bool ShouldRun(const Program* program) { bool ShouldRun(const Program* program) {
for (auto* global : program->AST().GlobalVariables()) { for (auto* global : program->AST().GlobalVariables()) {
if (auto* var = global->As<ast::Var>()) { if (auto* var = global->As<Var>()) {
auto* v = program->Sem().Get(var); auto* v = program->Sem().Get(var);
if (v->AddressSpace() == builtin::AddressSpace::kWorkgroup) { if (v->AddressSpace() == builtin::AddressSpace::kWorkgroup) {
return true; return true;
@ -48,7 +48,7 @@ bool ShouldRun(const Program* program) {
} // namespace } // namespace
using StatementList = utils::Vector<const ast::Statement*, 8>; using StatementList = utils::Vector<const Statement*, 8>;
/// PIMPL state for the transform /// PIMPL state for the transform
struct ZeroInitWorkgroupMemory::State { struct ZeroInitWorkgroupMemory::State {
@ -132,10 +132,10 @@ struct ZeroInitWorkgroupMemory::State {
/// Run inserts the workgroup memory zero-initialization logic at the top of /// Run inserts the workgroup memory zero-initialization logic at the top of
/// the given function /// the given function
/// @param fn a compute shader entry point function /// @param fn a compute shader entry point function
void Run(const ast::Function* fn) { void Run(const Function* fn) {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
CalculateWorkgroupSize(ast::GetAttribute<ast::WorkgroupAttribute>(fn->attributes)); CalculateWorkgroupSize(GetAttribute<WorkgroupAttribute>(fn->attributes));
// Generate a list of statements to zero initialize each of the // Generate a list of statements to zero initialize each of the
// workgroup storage variables used by `fn`. This will populate #statements. // workgroup storage variables used by `fn`. This will populate #statements.
@ -160,7 +160,7 @@ struct ZeroInitWorkgroupMemory::State {
// parameter // parameter
std::function<const ast::Expression*()> local_index; std::function<const ast::Expression*()> local_index;
for (auto* param : fn->params) { for (auto* param : fn->params) {
if (auto* builtin_attr = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) { if (auto* builtin_attr = GetAttribute<BuiltinAttribute>(param->attributes)) {
auto builtin = sem.Get(builtin_attr)->Value(); auto builtin = sem.Get(builtin_attr)->Value();
if (builtin == builtin::BuiltinValue::kLocalInvocationIndex) { if (builtin == builtin::BuiltinValue::kLocalInvocationIndex) {
local_index = [=] { return b.Expr(ctx.Clone(param->name->symbol)); }; local_index = [=] { return b.Expr(ctx.Clone(param->name->symbol)); };
@ -231,7 +231,7 @@ struct ZeroInitWorkgroupMemory::State {
// } // }
auto idx = b.Symbols().New("idx"); auto idx = b.Symbols().New("idx");
auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index())); auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index()));
auto* cond = b.create<ast::BinaryExpression>(ast::BinaryOp::kLessThan, b.Expr(idx), auto* cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr(idx),
b.Expr(u32(num_iterations))); b.Expr(u32(num_iterations)));
auto* cont = b.Assign( auto* cont = b.Assign(
idx, b.Add(idx, workgroup_size_const ? b.Expr(u32(workgroup_size_const)) idx, b.Add(idx, workgroup_size_const ? b.Expr(u32(workgroup_size_const))
@ -251,8 +251,8 @@ struct ZeroInitWorkgroupMemory::State {
// if (local_index < num_iterations) { // if (local_index < num_iterations) {
// ... // ...
// } // }
auto* cond = b.create<ast::BinaryExpression>( auto* cond = b.create<BinaryExpression>(BinaryOp::kLessThan, local_index(),
ast::BinaryOp::kLessThan, local_index(), b.Expr(u32(num_iterations))); b.Expr(u32(num_iterations)));
auto block = DeclareArrayIndices(num_iterations, array_indices, auto block = DeclareArrayIndices(num_iterations, array_indices,
[&] { return b.Expr(local_index()); }); [&] { return b.Expr(local_index()); });
for (auto& s : stmts) { for (auto& s : stmts) {
@ -382,7 +382,7 @@ struct ZeroInitWorkgroupMemory::State {
for (auto index : array_indices) { for (auto index : array_indices) {
auto name = array_index_names.at(index); auto name = array_index_names.at(index);
auto* mod = (num_iterations > index.modulo) auto* mod = (num_iterations > index.modulo)
? b.create<ast::BinaryExpression>(ast::BinaryOp::kModulo, iteration(), ? b.create<BinaryExpression>(BinaryOp::kModulo, iteration(),
b.Expr(u32(index.modulo))) b.Expr(u32(index.modulo)))
: iteration(); : iteration();
auto* div = (index.division != 1u) ? b.Div(mod, u32(index.division)) : mod; auto* div = (index.division != 1u) ? b.Div(mod, u32(index.division)) : mod;
@ -395,7 +395,7 @@ struct ZeroInitWorkgroupMemory::State {
/// CalculateWorkgroupSize initializes the members #workgroup_size_const and /// CalculateWorkgroupSize initializes the members #workgroup_size_const and
/// #workgroup_size_expr with the linear workgroup size. /// #workgroup_size_expr with the linear workgroup size.
/// @param attr the workgroup attribute applied to the entry point function /// @param attr the workgroup attribute applied to the entry point function
void CalculateWorkgroupSize(const ast::WorkgroupAttribute* attr) { void CalculateWorkgroupSize(const WorkgroupAttribute* attr) {
bool is_signed = false; bool is_signed = false;
workgroup_size_const = 1u; workgroup_size_const = 1u;
workgroup_size_expr = nullptr; workgroup_size_expr = nullptr;
@ -471,7 +471,7 @@ Transform::ApplyResult ZeroInitWorkgroupMemory::Apply(const Program* src,
CloneContext ctx{&b, src, /* auto_clone_symbols */ true}; CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
for (auto* fn : src->AST().Functions()) { for (auto* fn : src->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kCompute) { if (fn->PipelineStage() == PipelineStage::kCompute) {
State{ctx}.Run(fn); State{ctx}.Run(fn);
} }
} }