[tint] Remove ast:: prefixes from AST transforms
Now that AST transforms live in the AST namespace, these prefixes are no longer necessary. Change-Id: I658746ac04220075653ec57d6dc998947232d8bc Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/132425 Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
parent
4ae03fa8d0
commit
2b7406ad55
|
@ -41,7 +41,7 @@ Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
|
||||||
|
|
||||||
// A map from a type in the source program to a block-decorated wrapper that contains it in the
|
// A map from a type in the source program to a block-decorated wrapper that contains it in the
|
||||||
// destination program.
|
// destination program.
|
||||||
utils::Hashmap<const type::Type*, const ast::Struct*, 8> wrapper_structs;
|
utils::Hashmap<const type::Type*, const Struct*, 8> wrapper_structs;
|
||||||
|
|
||||||
// Process global 'var' declarations that are buffers.
|
// Process global 'var' declarations that are buffers.
|
||||||
bool made_changes = false;
|
bool made_changes = false;
|
||||||
|
@ -71,7 +71,7 @@ Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
|
||||||
auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] {
|
auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] {
|
||||||
auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
|
auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
|
||||||
auto wrapper_name = global->name->symbol.Name() + "_block";
|
auto wrapper_name = global->name->symbol.Name() + "_block";
|
||||||
auto* ret = b.create<ast::Struct>(
|
auto* ret = b.create<Struct>(
|
||||||
b.Ident(b.Symbols().New(wrapper_name)),
|
b.Ident(b.Symbols().New(wrapper_name)),
|
||||||
utils::Vector{b.Member(kMemberName, CreateASTTypeFor(ctx, ty))},
|
utils::Vector{b.Member(kMemberName, CreateASTTypeFor(ctx, ty))},
|
||||||
utils::Vector{block});
|
utils::Vector{block});
|
||||||
|
@ -101,7 +101,7 @@ Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
|
||||||
return Program(std::move(b));
|
return Program(std::move(b));
|
||||||
}
|
}
|
||||||
|
|
||||||
AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, ast::NodeID nid)
|
AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, NodeID nid)
|
||||||
: Base(pid, nid, utils::Empty) {}
|
: Base(pid, nid, utils::Empty) {}
|
||||||
AddBlockAttribute::BlockAttribute::~BlockAttribute() = default;
|
AddBlockAttribute::BlockAttribute::~BlockAttribute() = default;
|
||||||
std::string AddBlockAttribute::BlockAttribute::InternalName() const {
|
std::string AddBlockAttribute::BlockAttribute::InternalName() const {
|
||||||
|
|
|
@ -28,12 +28,12 @@ class AddBlockAttribute final : public utils::Castable<AddBlockAttribute, Transf
|
||||||
public:
|
public:
|
||||||
/// BlockAttribute is an InternalAttribute that is used to decorate a
|
/// BlockAttribute is an InternalAttribute that is used to decorate a
|
||||||
// structure that is used as a buffer in SPIR-V or GLSL.
|
// structure that is used as a buffer in SPIR-V or GLSL.
|
||||||
class BlockAttribute final : public utils::Castable<BlockAttribute, ast::InternalAttribute> {
|
class BlockAttribute final : public utils::Castable<BlockAttribute, InternalAttribute> {
|
||||||
public:
|
public:
|
||||||
/// Constructor
|
/// Constructor
|
||||||
/// @param program_id the identifier of the program that owns this node
|
/// @param program_id the identifier of the program that owns this node
|
||||||
/// @param nid the unique node identifier
|
/// @param nid the unique node identifier
|
||||||
BlockAttribute(ProgramID program_id, ast::NodeID nid);
|
BlockAttribute(ProgramID program_id, NodeID nid);
|
||||||
/// Destructor
|
/// Destructor
|
||||||
~BlockAttribute() override;
|
~BlockAttribute() override;
|
||||||
|
|
||||||
|
|
|
@ -52,7 +52,7 @@ Transform::ApplyResult AddEmptyEntryPoint::Apply(const Program* src,
|
||||||
|
|
||||||
b.Func(b.Symbols().New("unused_entry_point"), {}, b.ty.void_(), {},
|
b.Func(b.Symbols().New("unused_entry_point"), {}, b.ty.void_(), {},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -81,8 +81,8 @@ struct ArrayLengthFromUniform::State {
|
||||||
// Determine the size of the buffer size array.
|
// Determine the size of the buffer size array.
|
||||||
uint32_t max_buffer_size_index = 0;
|
uint32_t max_buffer_size_index = 0;
|
||||||
|
|
||||||
IterateArrayLengthOnStorageVar([&](const ast::CallExpression*, const sem::VariableUser*,
|
IterateArrayLengthOnStorageVar(
|
||||||
const sem::GlobalVariable* var) {
|
[&](const CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) {
|
||||||
if (auto binding = var->BindingPoint()) {
|
if (auto binding = var->BindingPoint()) {
|
||||||
auto idx_itr = cfg->bindpoint_to_size_index.find(*binding);
|
auto idx_itr = cfg->bindpoint_to_size_index.find(*binding);
|
||||||
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
|
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
|
||||||
|
@ -96,7 +96,7 @@ struct ArrayLengthFromUniform::State {
|
||||||
|
|
||||||
// Get (or create, on first call) the uniform buffer that will receive the
|
// Get (or create, on first call) the uniform buffer that will receive the
|
||||||
// size of each storage buffer in the module.
|
// size of each storage buffer in the module.
|
||||||
const ast::Variable* buffer_size_ubo = nullptr;
|
const Variable* buffer_size_ubo = nullptr;
|
||||||
auto get_ubo = [&]() {
|
auto get_ubo = [&]() {
|
||||||
if (!buffer_size_ubo) {
|
if (!buffer_size_ubo) {
|
||||||
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
|
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
|
||||||
|
@ -118,7 +118,7 @@ struct ArrayLengthFromUniform::State {
|
||||||
|
|
||||||
std::unordered_set<uint32_t> used_size_indices;
|
std::unordered_set<uint32_t> used_size_indices;
|
||||||
|
|
||||||
IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr,
|
IterateArrayLengthOnStorageVar([&](const CallExpression* call_expr,
|
||||||
const sem::VariableUser* storage_buffer_sem,
|
const sem::VariableUser* storage_buffer_sem,
|
||||||
const sem::GlobalVariable* var) {
|
const sem::GlobalVariable* var) {
|
||||||
auto binding = var->BindingPoint();
|
auto binding = var->BindingPoint();
|
||||||
|
@ -144,7 +144,7 @@ struct ArrayLengthFromUniform::State {
|
||||||
// total_storage_buffer_size - array_offset
|
// total_storage_buffer_size - array_offset
|
||||||
// array_length = ----------------------------------------
|
// array_length = ----------------------------------------
|
||||||
// array_stride
|
// array_stride
|
||||||
const ast::Expression* total_size = total_storage_buffer_size;
|
const Expression* total_size = total_storage_buffer_size;
|
||||||
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
|
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
|
||||||
const type::Array* array_type = nullptr;
|
const type::Array* array_type = nullptr;
|
||||||
if (auto* str = storage_buffer_type->As<type::Struct>()) {
|
if (auto* str = storage_buffer_type->As<type::Struct>()) {
|
||||||
|
@ -186,9 +186,9 @@ struct ArrayLengthFromUniform::State {
|
||||||
|
|
||||||
/// Iterate over all arrayLength() builtins that operate on
|
/// Iterate over all arrayLength() builtins that operate on
|
||||||
/// storage buffer variables.
|
/// storage buffer variables.
|
||||||
/// @param functor of type void(const ast::CallExpression*, const
|
/// @param functor of type void(const CallExpression*, const
|
||||||
/// sem::VariableUser, const sem::GlobalVariable*). It takes in an
|
/// sem::VariableUser, const sem::GlobalVariable*). It takes in an
|
||||||
/// ast::CallExpression of the arrayLength call expression node, a
|
/// CallExpression of the arrayLength call expression node, a
|
||||||
/// sem::VariableUser of the used storage buffer variable, and the
|
/// sem::VariableUser of the used storage buffer variable, and the
|
||||||
/// sem::GlobalVariable for the storage buffer.
|
/// sem::GlobalVariable for the storage buffer.
|
||||||
template <typename F>
|
template <typename F>
|
||||||
|
@ -197,7 +197,7 @@ struct ArrayLengthFromUniform::State {
|
||||||
|
|
||||||
// Find all calls to the arrayLength() builtin.
|
// Find all calls to the arrayLength() builtin.
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
auto* call_expr = node->As<ast::CallExpression>();
|
auto* call_expr = node->As<CallExpression>();
|
||||||
if (!call_expr) {
|
if (!call_expr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -208,7 +208,7 @@ struct ArrayLengthFromUniform::State {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) {
|
if (auto* call_stmt = call->Stmt()->Declaration()->As<CallStatement>()) {
|
||||||
if (call_stmt->expr == call_expr) {
|
if (call_stmt->expr == call_expr) {
|
||||||
// arrayLength() is used as a statement.
|
// arrayLength() is used as a statement.
|
||||||
// The argument expression must be side-effect free, so just drop the statement.
|
// The argument expression must be side-effect free, so just drop the statement.
|
||||||
|
@ -222,15 +222,15 @@ struct ArrayLengthFromUniform::State {
|
||||||
// call has one of two forms:
|
// call has one of two forms:
|
||||||
// arrayLength(&struct_var.array_member)
|
// arrayLength(&struct_var.array_member)
|
||||||
// arrayLength(&array_var)
|
// arrayLength(&array_var)
|
||||||
auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
|
auto* param = call_expr->args[0]->As<UnaryOpExpression>();
|
||||||
if (TINT_UNLIKELY(!param || param->op != ast::UnaryOp::kAddressOf)) {
|
if (TINT_UNLIKELY(!param || param->op != UnaryOp::kAddressOf)) {
|
||||||
TINT_ICE(Transform, b.Diagnostics())
|
TINT_ICE(Transform, b.Diagnostics())
|
||||||
<< "expected form of arrayLength argument to be &array_var or "
|
<< "expected form of arrayLength argument to be &array_var or "
|
||||||
"&struct_var.array_member";
|
"&struct_var.array_member";
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
auto* storage_buffer_expr = param->expr;
|
auto* storage_buffer_expr = param->expr;
|
||||||
if (auto* accessor = param->expr->As<ast::MemberAccessorExpression>()) {
|
if (auto* accessor = param->expr->As<MemberAccessorExpression>()) {
|
||||||
storage_buffer_expr = accessor->object;
|
storage_buffer_expr = accessor->object;
|
||||||
}
|
}
|
||||||
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
|
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
|
||||||
|
|
|
@ -90,7 +90,7 @@ Transform::ApplyResult BindingRemapper::Apply(const Program* src,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto* var : src->AST().Globals<ast::Var>()) {
|
for (auto* var : src->AST().Globals<Var>()) {
|
||||||
if (var->HasBindingPoint()) {
|
if (var->HasBindingPoint()) {
|
||||||
auto* global_sem = src->Sem().Get<sem::GlobalVariable>(var);
|
auto* global_sem = src->Sem().Get<sem::GlobalVariable>(var);
|
||||||
|
|
||||||
|
@ -109,8 +109,8 @@ Transform::ApplyResult BindingRemapper::Apply(const Program* src,
|
||||||
auto* new_group = b.Group(AInt(to.group));
|
auto* new_group = b.Group(AInt(to.group));
|
||||||
auto* new_binding = b.Binding(AInt(to.binding));
|
auto* new_binding = b.Binding(AInt(to.binding));
|
||||||
|
|
||||||
auto* old_group = ast::GetAttribute<ast::GroupAttribute>(var->attributes);
|
auto* old_group = GetAttribute<GroupAttribute>(var->attributes);
|
||||||
auto* old_binding = ast::GetAttribute<ast::BindingAttribute>(var->attributes);
|
auto* old_binding = GetAttribute<BindingAttribute>(var->attributes);
|
||||||
|
|
||||||
ctx.Replace(old_group, new_group);
|
ctx.Replace(old_group, new_group);
|
||||||
ctx.Replace(old_binding, new_binding);
|
ctx.Replace(old_binding, new_binding);
|
||||||
|
@ -139,7 +139,7 @@ Transform::ApplyResult BindingRemapper::Apply(const Program* src,
|
||||||
auto* ty = sem->Type()->UnwrapRef();
|
auto* ty = sem->Type()->UnwrapRef();
|
||||||
auto inner_ty = CreateASTTypeFor(ctx, ty);
|
auto inner_ty = CreateASTTypeFor(ctx, ty);
|
||||||
auto* new_var =
|
auto* new_var =
|
||||||
b.create<ast::Var>(ctx.Clone(var->source), // source
|
b.create<Var>(ctx.Clone(var->source), // source
|
||||||
b.Ident(ctx.Clone(var->name->symbol)), // name
|
b.Ident(ctx.Clone(var->name->symbol)), // name
|
||||||
inner_ty, // type
|
inner_ty, // type
|
||||||
ctx.Clone(var->declared_address_space), // address space
|
ctx.Clone(var->declared_address_space), // address space
|
||||||
|
@ -151,7 +151,7 @@ Transform::ApplyResult BindingRemapper::Apply(const Program* src,
|
||||||
|
|
||||||
// Add `DisableValidationAttribute`s if required
|
// Add `DisableValidationAttribute`s if required
|
||||||
if (add_collision_attr.count(bp)) {
|
if (add_collision_attr.count(bp)) {
|
||||||
auto* attribute = b.Disable(ast::DisabledValidation::kBindingPointCollision);
|
auto* attribute = b.Disable(DisabledValidation::kBindingPointCollision);
|
||||||
ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
|
ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::BuiltinPolyfill::Config);
|
||||||
namespace tint::ast::transform {
|
namespace tint::ast::transform {
|
||||||
|
|
||||||
/// BinaryOpSignature is tuple of a binary op, LHS type and RHS type
|
/// BinaryOpSignature is tuple of a binary op, LHS type and RHS type
|
||||||
using BinaryOpSignature = std::tuple<ast::BinaryOp, const type::Type*, const type::Type*>;
|
using BinaryOpSignature = std::tuple<BinaryOp, const type::Type*, const type::Type*>;
|
||||||
|
|
||||||
/// PIMPL state for the transform
|
/// PIMPL state for the transform
|
||||||
struct BuiltinPolyfill::State {
|
struct BuiltinPolyfill::State {
|
||||||
|
@ -60,16 +60,16 @@ struct BuiltinPolyfill::State {
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
Switch(
|
Switch(
|
||||||
node, //
|
node, //
|
||||||
[&](const ast::CallExpression* expr) { Call(expr); },
|
[&](const CallExpression* expr) { Call(expr); },
|
||||||
[&](const ast::BinaryExpression* bin_op) {
|
[&](const BinaryExpression* bin_op) {
|
||||||
if (auto* s = src->Sem().Get(bin_op);
|
if (auto* s = src->Sem().Get(bin_op);
|
||||||
!s || s->Stage() == sem::EvaluationStage::kConstant ||
|
!s || s->Stage() == sem::EvaluationStage::kConstant ||
|
||||||
s->Stage() == sem::EvaluationStage::kNotEvaluated) {
|
s->Stage() == sem::EvaluationStage::kNotEvaluated) {
|
||||||
return; // Don't polyfill @const expressions
|
return; // Don't polyfill @const expressions
|
||||||
}
|
}
|
||||||
switch (bin_op->op) {
|
switch (bin_op->op) {
|
||||||
case ast::BinaryOp::kShiftLeft:
|
case BinaryOp::kShiftLeft:
|
||||||
case ast::BinaryOp::kShiftRight: {
|
case BinaryOp::kShiftRight: {
|
||||||
if (cfg.builtins.bitshift_modulo) {
|
if (cfg.builtins.bitshift_modulo) {
|
||||||
ctx.Replace(bin_op,
|
ctx.Replace(bin_op,
|
||||||
[this, bin_op] { return BitshiftModulo(bin_op); });
|
[this, bin_op] { return BitshiftModulo(bin_op); });
|
||||||
|
@ -77,7 +77,7 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case ast::BinaryOp::kDivide: {
|
case BinaryOp::kDivide: {
|
||||||
if (cfg.builtins.int_div_mod) {
|
if (cfg.builtins.int_div_mod) {
|
||||||
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||||
if (lhs_ty->is_integer_scalar_or_vector()) {
|
if (lhs_ty->is_integer_scalar_or_vector()) {
|
||||||
|
@ -88,7 +88,7 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case ast::BinaryOp::kModulo: {
|
case BinaryOp::kModulo: {
|
||||||
if (cfg.builtins.int_div_mod) {
|
if (cfg.builtins.int_div_mod) {
|
||||||
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||||
if (lhs_ty->is_integer_scalar_or_vector()) {
|
if (lhs_ty->is_integer_scalar_or_vector()) {
|
||||||
|
@ -111,7 +111,7 @@ struct BuiltinPolyfill::State {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::Expression* expr) {
|
[&](const Expression* expr) {
|
||||||
if (cfg.builtins.bgra8unorm) {
|
if (cfg.builtins.bgra8unorm) {
|
||||||
if (auto* ty_expr = src->Sem().Get<sem::TypeExpression>(expr)) {
|
if (auto* ty_expr = src->Sem().Get<sem::TypeExpression>(expr)) {
|
||||||
if (auto* tex = ty_expr->Type()->As<type::StorageTexture>()) {
|
if (auto* tex = ty_expr->Type()->As<type::StorageTexture>()) {
|
||||||
|
@ -170,15 +170,15 @@ struct BuiltinPolyfill::State {
|
||||||
auto name = b.Symbols().New("tint_acosh");
|
auto name = b.Symbols().New("tint_acosh");
|
||||||
uint32_t width = WidthOf(ty);
|
uint32_t width = WidthOf(ty);
|
||||||
|
|
||||||
auto V = [&](AFloat value) -> const ast::Expression* {
|
auto V = [&](AFloat value) -> const Expression* {
|
||||||
const ast::Expression* expr = b.Expr(value);
|
const Expression* expr = b.Expr(value);
|
||||||
if (width == 1) {
|
if (width == 1) {
|
||||||
return expr;
|
return expr;
|
||||||
}
|
}
|
||||||
return b.Call(T(ty), expr);
|
return b.Call(T(ty), expr);
|
||||||
};
|
};
|
||||||
|
|
||||||
utils::Vector<const ast::Statement*, 4> body;
|
utils::Vector<const Statement*, 4> body;
|
||||||
switch (cfg.builtins.acosh) {
|
switch (cfg.builtins.acosh) {
|
||||||
case Level::kFull:
|
case Level::kFull:
|
||||||
// return log(x + sqrt(x*x - 1));
|
// return log(x + sqrt(x*x - 1));
|
||||||
|
@ -224,15 +224,15 @@ struct BuiltinPolyfill::State {
|
||||||
auto name = b.Symbols().New("tint_atanh");
|
auto name = b.Symbols().New("tint_atanh");
|
||||||
uint32_t width = WidthOf(ty);
|
uint32_t width = WidthOf(ty);
|
||||||
|
|
||||||
auto V = [&](AFloat value) -> const ast::Expression* {
|
auto V = [&](AFloat value) -> const Expression* {
|
||||||
const ast::Expression* expr = b.Expr(value);
|
const Expression* expr = b.Expr(value);
|
||||||
if (width == 1) {
|
if (width == 1) {
|
||||||
return expr;
|
return expr;
|
||||||
}
|
}
|
||||||
return b.Call(T(ty), expr);
|
return b.Call(T(ty), expr);
|
||||||
};
|
};
|
||||||
|
|
||||||
utils::Vector<const ast::Statement*, 1> body;
|
utils::Vector<const Statement*, 1> body;
|
||||||
switch (cfg.builtins.atanh) {
|
switch (cfg.builtins.atanh) {
|
||||||
case Level::kFull:
|
case Level::kFull:
|
||||||
// return log((1+x) / (1-x)) * 0.5
|
// return log((1+x) / (1-x)) * 0.5
|
||||||
|
@ -290,7 +290,7 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
return b.ty.vec<u32>(width);
|
return b.ty.vec<u32>(width);
|
||||||
};
|
};
|
||||||
auto V = [&](uint32_t value) -> const ast::Expression* {
|
auto V = [&](uint32_t value) -> const Expression* {
|
||||||
return ScalarOrVector(width, u32(value));
|
return ScalarOrVector(width, u32(value));
|
||||||
};
|
};
|
||||||
b.Func(
|
b.Func(
|
||||||
|
@ -348,10 +348,10 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
return b.ty.vec<u32>(width);
|
return b.ty.vec<u32>(width);
|
||||||
};
|
};
|
||||||
auto V = [&](uint32_t value) -> const ast::Expression* {
|
auto V = [&](uint32_t value) -> const Expression* {
|
||||||
return ScalarOrVector(width, u32(value));
|
return ScalarOrVector(width, u32(value));
|
||||||
};
|
};
|
||||||
auto B = [&](const ast::Expression* value) -> const ast::Expression* {
|
auto B = [&](const Expression* value) -> const Expression* {
|
||||||
if (width == 1) {
|
if (width == 1) {
|
||||||
return b.Call<bool>(value);
|
return b.Call<bool>(value);
|
||||||
}
|
}
|
||||||
|
@ -402,14 +402,14 @@ struct BuiltinPolyfill::State {
|
||||||
|
|
||||||
constexpr uint32_t W = 32u; // 32-bit
|
constexpr uint32_t W = 32u; // 32-bit
|
||||||
|
|
||||||
auto vecN_u32 = [&](const ast::Expression* value) -> const ast::Expression* {
|
auto vecN_u32 = [&](const Expression* value) -> const Expression* {
|
||||||
if (width == 1) {
|
if (width == 1) {
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
return b.Call(b.ty.vec<u32>(width), value);
|
return b.Call(b.ty.vec<u32>(width), value);
|
||||||
};
|
};
|
||||||
|
|
||||||
utils::Vector<const ast::Statement*, 8> body{
|
utils::Vector<const Statement*, 8> body{
|
||||||
b.Decl(b.Let("s", b.Call("min", "offset", u32(W)))),
|
b.Decl(b.Let("s", b.Call("min", "offset", u32(W)))),
|
||||||
b.Decl(b.Let("e", b.Call("min", u32(W), b.Add("s", "count")))),
|
b.Decl(b.Let("e", b.Call("min", u32(W), b.Add("s", "count")))),
|
||||||
};
|
};
|
||||||
|
@ -465,17 +465,17 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
return b.ty.vec<u32>(width);
|
return b.ty.vec<u32>(width);
|
||||||
};
|
};
|
||||||
auto V = [&](uint32_t value) -> const ast::Expression* {
|
auto V = [&](uint32_t value) -> const Expression* {
|
||||||
return ScalarOrVector(width, u32(value));
|
return ScalarOrVector(width, u32(value));
|
||||||
};
|
};
|
||||||
auto B = [&](const ast::Expression* value) -> const ast::Expression* {
|
auto B = [&](const Expression* value) -> const Expression* {
|
||||||
if (width == 1) {
|
if (width == 1) {
|
||||||
return b.Call<bool>(value);
|
return b.Call<bool>(value);
|
||||||
}
|
}
|
||||||
return b.Call(b.ty.vec<bool>(width), value);
|
return b.Call(b.ty.vec<bool>(width), value);
|
||||||
};
|
};
|
||||||
|
|
||||||
const ast::Expression* x = nullptr;
|
const Expression* x = nullptr;
|
||||||
if (ty->is_unsigned_integer_scalar_or_vector()) {
|
if (ty->is_unsigned_integer_scalar_or_vector()) {
|
||||||
x = b.Expr("v");
|
x = b.Expr("v");
|
||||||
} else {
|
} else {
|
||||||
|
@ -537,10 +537,10 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
return b.ty.vec<u32>(width);
|
return b.ty.vec<u32>(width);
|
||||||
};
|
};
|
||||||
auto V = [&](uint32_t value) -> const ast::Expression* {
|
auto V = [&](uint32_t value) -> const Expression* {
|
||||||
return ScalarOrVector(width, u32(value));
|
return ScalarOrVector(width, u32(value));
|
||||||
};
|
};
|
||||||
auto B = [&](const ast::Expression* value) -> const ast::Expression* {
|
auto B = [&](const Expression* value) -> const Expression* {
|
||||||
if (width == 1) {
|
if (width == 1) {
|
||||||
return b.Call<bool>(value);
|
return b.Call<bool>(value);
|
||||||
}
|
}
|
||||||
|
@ -599,8 +599,8 @@ struct BuiltinPolyfill::State {
|
||||||
|
|
||||||
constexpr uint32_t W = 32u; // 32-bit
|
constexpr uint32_t W = 32u; // 32-bit
|
||||||
|
|
||||||
auto V = [&](auto value) -> const ast::Expression* {
|
auto V = [&](auto value) -> const Expression* {
|
||||||
const ast::Expression* expr = b.Expr(value);
|
const Expression* expr = b.Expr(value);
|
||||||
if (!ty->is_unsigned_integer_scalar_or_vector()) {
|
if (!ty->is_unsigned_integer_scalar_or_vector()) {
|
||||||
expr = b.Call<i32>(expr);
|
expr = b.Call<i32>(expr);
|
||||||
}
|
}
|
||||||
|
@ -609,7 +609,7 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
return expr;
|
return expr;
|
||||||
};
|
};
|
||||||
auto U = [&](auto value) -> const ast::Expression* {
|
auto U = [&](auto value) -> const Expression* {
|
||||||
if (width == 1) {
|
if (width == 1) {
|
||||||
return b.Expr(value);
|
return b.Expr(value);
|
||||||
}
|
}
|
||||||
|
@ -638,7 +638,7 @@ struct BuiltinPolyfill::State {
|
||||||
// return ((select(T(), n << offset, offset < 32u) & mask) | (v & ~(mask)));
|
// return ((select(T(), n << offset, offset < 32u) & mask) | (v & ~(mask)));
|
||||||
// }
|
// }
|
||||||
|
|
||||||
utils::Vector<const ast::Statement*, 8> body;
|
utils::Vector<const Statement*, 8> body;
|
||||||
|
|
||||||
switch (cfg.builtins.insert_bits) {
|
switch (cfg.builtins.insert_bits) {
|
||||||
case Level::kFull:
|
case Level::kFull:
|
||||||
|
@ -788,7 +788,7 @@ struct BuiltinPolyfill::State {
|
||||||
/// @return the polyfill function name
|
/// @return the polyfill function name
|
||||||
Symbol quantizeToF16(const type::Vector* vec) {
|
Symbol quantizeToF16(const type::Vector* vec) {
|
||||||
auto name = b.Symbols().New("tint_quantizeToF16");
|
auto name = b.Symbols().New("tint_quantizeToF16");
|
||||||
utils::Vector<const ast::Expression*, 4> args;
|
utils::Vector<const Expression*, 4> args;
|
||||||
for (uint32_t i = 0; i < vec->Width(); i++) {
|
for (uint32_t i = 0; i < vec->Width(); i++) {
|
||||||
args.Push(b.Call("quantizeToF16", b.IndexAccessor("v", u32(i))));
|
args.Push(b.Call("quantizeToF16", b.IndexAccessor("v", u32(i))));
|
||||||
}
|
}
|
||||||
|
@ -880,29 +880,29 @@ struct BuiltinPolyfill::State {
|
||||||
/// the RHS is modulo the bit-width of the LHS.
|
/// the RHS is modulo the bit-width of the LHS.
|
||||||
/// @param bin_op the original BinaryExpression
|
/// @param bin_op the original BinaryExpression
|
||||||
/// @return the polyfill value for bitshift operation
|
/// @return the polyfill value for bitshift operation
|
||||||
const ast::Expression* BitshiftModulo(const ast::BinaryExpression* bin_op) {
|
const Expression* BitshiftModulo(const BinaryExpression* bin_op) {
|
||||||
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||||
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
|
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
|
||||||
auto* lhs_el_ty = type::Type::DeepestElementOf(lhs_ty);
|
auto* lhs_el_ty = type::Type::DeepestElementOf(lhs_ty);
|
||||||
const ast::Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
|
const Expression* mask = b.Expr(AInt(lhs_el_ty->Size() * 8 - 1));
|
||||||
if (rhs_ty->Is<type::Vector>()) {
|
if (rhs_ty->Is<type::Vector>()) {
|
||||||
mask = b.Call(CreateASTTypeFor(ctx, rhs_ty), mask);
|
mask = b.Call(CreateASTTypeFor(ctx, rhs_ty), mask);
|
||||||
}
|
}
|
||||||
auto* lhs = ctx.Clone(bin_op->lhs);
|
auto* lhs = ctx.Clone(bin_op->lhs);
|
||||||
auto* rhs = b.And(ctx.Clone(bin_op->rhs), mask);
|
auto* rhs = b.And(ctx.Clone(bin_op->rhs), mask);
|
||||||
return b.create<ast::BinaryExpression>(ctx.Clone(bin_op->source), bin_op->op, lhs, rhs);
|
return b.create<BinaryExpression>(ctx.Clone(bin_op->source), bin_op->op, lhs, rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Builds the polyfill inline expression for a integer divide or modulo, preventing DBZs and
|
/// Builds the polyfill inline expression for a integer divide or modulo, preventing DBZs and
|
||||||
/// integer overflows.
|
/// integer overflows.
|
||||||
/// @param bin_op the original BinaryExpression
|
/// @param bin_op the original BinaryExpression
|
||||||
/// @return the polyfill divide or modulo
|
/// @return the polyfill divide or modulo
|
||||||
const ast::Expression* IntDivMod(const ast::BinaryExpression* bin_op) {
|
const Expression* IntDivMod(const BinaryExpression* bin_op) {
|
||||||
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||||
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
|
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
|
||||||
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
|
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
|
||||||
auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
|
auto fn = binary_op_polyfills.GetOrCreate(sig, [&] {
|
||||||
const bool is_div = bin_op->op == ast::BinaryOp::kDivide;
|
const bool is_div = bin_op->op == BinaryOp::kDivide;
|
||||||
|
|
||||||
uint32_t lhs_width = 1;
|
uint32_t lhs_width = 1;
|
||||||
uint32_t rhs_width = 1;
|
uint32_t rhs_width = 1;
|
||||||
|
@ -914,7 +914,7 @@ struct BuiltinPolyfill::State {
|
||||||
const char* lhs = "lhs";
|
const char* lhs = "lhs";
|
||||||
const char* rhs = "rhs";
|
const char* rhs = "rhs";
|
||||||
|
|
||||||
utils::Vector<const ast::Statement*, 4> body;
|
utils::Vector<const Statement*, 4> body;
|
||||||
|
|
||||||
if (lhs_width < width) {
|
if (lhs_width < width) {
|
||||||
// lhs is scalar, rhs is vector. Convert lhs to vector.
|
// lhs is scalar, rhs is vector. Convert lhs to vector.
|
||||||
|
@ -934,8 +934,8 @@ struct BuiltinPolyfill::State {
|
||||||
if (lhs_ty->is_signed_integer_scalar_or_vector()) {
|
if (lhs_ty->is_signed_integer_scalar_or_vector()) {
|
||||||
const auto bits = lhs_el_ty->Size() * 8;
|
const auto bits = lhs_el_ty->Size() * 8;
|
||||||
auto min_int = AInt(AInt::kLowestValue >> (AInt::kNumBits - bits));
|
auto min_int = AInt(AInt::kLowestValue >> (AInt::kNumBits - bits));
|
||||||
const ast::Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int));
|
const Expression* lhs_is_min = b.Equal(lhs, ScalarOrVector(width, min_int));
|
||||||
const ast::Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a));
|
const Expression* rhs_is_minus_one = b.Equal(rhs, ScalarOrVector(width, -1_a));
|
||||||
// use_one = rhs_is_zero | ((lhs == MIN_INT) & (rhs == -1))
|
// use_one = rhs_is_zero | ((lhs == MIN_INT) & (rhs == -1))
|
||||||
auto* use_one = b.Or(rhs_is_zero, b.And(lhs_is_min, rhs_is_minus_one));
|
auto* use_one = b.Or(rhs_is_zero, b.And(lhs_is_min, rhs_is_minus_one));
|
||||||
|
|
||||||
|
@ -992,7 +992,7 @@ struct BuiltinPolyfill::State {
|
||||||
/// Builds the polyfill inline expression for a precise float modulo, as defined in the spec.
|
/// Builds the polyfill inline expression for a precise float modulo, as defined in the spec.
|
||||||
/// @param bin_op the original BinaryExpression
|
/// @param bin_op the original BinaryExpression
|
||||||
/// @return the polyfill divide or modulo
|
/// @return the polyfill divide or modulo
|
||||||
const ast::Expression* PreciseFloatMod(const ast::BinaryExpression* bin_op) {
|
const Expression* PreciseFloatMod(const BinaryExpression* bin_op) {
|
||||||
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
auto* lhs_ty = src->TypeOf(bin_op->lhs)->UnwrapRef();
|
||||||
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
|
auto* rhs_ty = src->TypeOf(bin_op->rhs)->UnwrapRef();
|
||||||
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
|
BinaryOpSignature sig{bin_op->op, lhs_ty, rhs_ty};
|
||||||
|
@ -1007,7 +1007,7 @@ struct BuiltinPolyfill::State {
|
||||||
const char* lhs = "lhs";
|
const char* lhs = "lhs";
|
||||||
const char* rhs = "rhs";
|
const char* rhs = "rhs";
|
||||||
|
|
||||||
utils::Vector<const ast::Statement*, 4> body;
|
utils::Vector<const Statement*, 4> body;
|
||||||
|
|
||||||
if (lhs_width < width) {
|
if (lhs_width < width) {
|
||||||
// lhs is scalar, rhs is vector. Convert lhs to vector.
|
// lhs is scalar, rhs is vector. Convert lhs to vector.
|
||||||
|
@ -1042,7 +1042,7 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @returns the AST type for the given sem type
|
/// @returns the AST type for the given sem type
|
||||||
ast::Type T(const type::Type* ty) { return CreateASTTypeFor(ctx, ty); }
|
Type T(const type::Type* ty) { return CreateASTTypeFor(ctx, ty); }
|
||||||
|
|
||||||
/// @returns 1 if `ty` is not a vector, otherwise the vector width
|
/// @returns 1 if `ty` is not a vector, otherwise the vector width
|
||||||
uint32_t WidthOf(const type::Type* ty) const {
|
uint32_t WidthOf(const type::Type* ty) const {
|
||||||
|
@ -1055,7 +1055,7 @@ struct BuiltinPolyfill::State {
|
||||||
/// @returns a scalar or vector with the given width, with each element with
|
/// @returns a scalar or vector with the given width, with each element with
|
||||||
/// the given value.
|
/// the given value.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
const ast::Expression* ScalarOrVector(uint32_t width, T value) {
|
const Expression* ScalarOrVector(uint32_t width, T value) {
|
||||||
if (width == 1) {
|
if (width == 1) {
|
||||||
return b.Expr(value);
|
return b.Expr(value);
|
||||||
}
|
}
|
||||||
|
@ -1063,7 +1063,7 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename To>
|
template <typename To>
|
||||||
const ast::Expression* CastScalarOrVector(uint32_t width, const ast::Expression* e) {
|
const Expression* CastScalarOrVector(uint32_t width, const Expression* e) {
|
||||||
if (width == 1) {
|
if (width == 1) {
|
||||||
return b.Call(b.ty.Of<To>(), e);
|
return b.Call(b.ty.Of<To>(), e);
|
||||||
}
|
}
|
||||||
|
@ -1071,7 +1071,7 @@ struct BuiltinPolyfill::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Examines the call expression @p expr, applying any necessary polyfill transforms
|
/// Examines the call expression @p expr, applying any necessary polyfill transforms
|
||||||
void Call(const ast::CallExpression* expr) {
|
void Call(const CallExpression* expr) {
|
||||||
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||||
if (!call || call->Stage() == sem::EvaluationStage::kConstant ||
|
if (!call || call->Stage() == sem::EvaluationStage::kConstant ||
|
||||||
call->Stage() == sem::EvaluationStage::kNotEvaluated) {
|
call->Stage() == sem::EvaluationStage::kNotEvaluated) {
|
||||||
|
@ -1207,7 +1207,7 @@ struct BuiltinPolyfill::State {
|
||||||
size_t value_idx = static_cast<size_t>(
|
size_t value_idx = static_cast<size_t>(
|
||||||
sig.IndexOf(sem::ParameterUsage::kValue));
|
sig.IndexOf(sem::ParameterUsage::kValue));
|
||||||
ctx.Replace(expr, [this, expr, value_idx] {
|
ctx.Replace(expr, [this, expr, value_idx] {
|
||||||
utils::Vector<const ast::Expression*, 3> args;
|
utils::Vector<const Expression*, 3> args;
|
||||||
for (auto* arg : expr->args) {
|
for (auto* arg : expr->args) {
|
||||||
arg = ctx.Clone(arg);
|
arg = ctx.Clone(arg);
|
||||||
if (args.Length() == value_idx) { // value
|
if (args.Length() == value_idx) { // value
|
||||||
|
|
|
@ -57,7 +57,7 @@ bool ShouldRun(const Program* program) {
|
||||||
/// ArrayUsage describes a runtime array usage.
|
/// ArrayUsage describes a runtime array usage.
|
||||||
/// It is used as a key by the array_length_by_usage map.
|
/// It is used as a key by the array_length_by_usage map.
|
||||||
struct ArrayUsage {
|
struct ArrayUsage {
|
||||||
ast::BlockStatement const* const block;
|
BlockStatement const* const block;
|
||||||
sem::Variable const* const buffer;
|
sem::Variable const* const buffer;
|
||||||
bool operator==(const ArrayUsage& rhs) const {
|
bool operator==(const ArrayUsage& rhs) const {
|
||||||
return block == rhs.block && buffer == rhs.buffer;
|
return block == rhs.block && buffer == rhs.buffer;
|
||||||
|
@ -71,7 +71,7 @@ struct ArrayUsage {
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid, ast::NodeID nid)
|
CalculateArrayLength::BufferSizeIntrinsic::BufferSizeIntrinsic(ProgramID pid, NodeID nid)
|
||||||
: Base(pid, nid, utils::Empty) {}
|
: Base(pid, nid, utils::Empty) {}
|
||||||
CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
|
CalculateArrayLength::BufferSizeIntrinsic::~BufferSizeIntrinsic() = default;
|
||||||
std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const {
|
std::string CalculateArrayLength::BufferSizeIntrinsic::InternalName() const {
|
||||||
|
@ -106,7 +106,7 @@ Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
|
||||||
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
|
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
|
||||||
auto name = b.Sym();
|
auto name = b.Sym();
|
||||||
auto type = CreateASTTypeFor(ctx, buffer_type);
|
auto type = CreateASTTypeFor(ctx, buffer_type);
|
||||||
auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter);
|
auto* disable_validation = b.Disable(DisabledValidation::kFunctionParameter);
|
||||||
b.Func(
|
b.Func(
|
||||||
name,
|
name,
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
|
@ -128,13 +128,13 @@ Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
|
||||||
|
|
||||||
// Find all the arrayLength() calls...
|
// Find all the arrayLength() calls...
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
if (auto* call_expr = node->As<ast::CallExpression>()) {
|
if (auto* call_expr = node->As<CallExpression>()) {
|
||||||
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
|
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||||
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
|
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
|
||||||
if (builtin->Type() == builtin::Function::kArrayLength) {
|
if (builtin->Type() == builtin::Function::kArrayLength) {
|
||||||
// We're dealing with an arrayLength() call
|
// We're dealing with an arrayLength() call
|
||||||
|
|
||||||
if (auto* call_stmt = call->Stmt()->Declaration()->As<ast::CallStatement>()) {
|
if (auto* call_stmt = call->Stmt()->Declaration()->As<CallStatement>()) {
|
||||||
if (call_stmt->expr == call_expr) {
|
if (call_stmt->expr == call_expr) {
|
||||||
// arrayLength() is used as a statement.
|
// arrayLength() is used as a statement.
|
||||||
// The argument expression must be side-effect free, so just drop the
|
// The argument expression must be side-effect free, so just drop the
|
||||||
|
@ -151,13 +151,13 @@ Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
|
||||||
// arrayLength(&struct_var.array_member)
|
// arrayLength(&struct_var.array_member)
|
||||||
// arrayLength(&array_var)
|
// arrayLength(&array_var)
|
||||||
auto* arg = call_expr->args[0];
|
auto* arg = call_expr->args[0];
|
||||||
auto* address_of = arg->As<ast::UnaryOpExpression>();
|
auto* address_of = arg->As<UnaryOpExpression>();
|
||||||
if (TINT_UNLIKELY(!address_of || address_of->op != ast::UnaryOp::kAddressOf)) {
|
if (TINT_UNLIKELY(!address_of || address_of->op != UnaryOp::kAddressOf)) {
|
||||||
TINT_ICE(Transform, b.Diagnostics())
|
TINT_ICE(Transform, b.Diagnostics())
|
||||||
<< "arrayLength() expected address-of, got " << arg->TypeInfo().name;
|
<< "arrayLength() expected address-of, got " << arg->TypeInfo().name;
|
||||||
}
|
}
|
||||||
auto* storage_buffer_expr = address_of->expr;
|
auto* storage_buffer_expr = address_of->expr;
|
||||||
if (auto* accessor = storage_buffer_expr->As<ast::MemberAccessorExpression>()) {
|
if (auto* accessor = storage_buffer_expr->As<MemberAccessorExpression>()) {
|
||||||
storage_buffer_expr = accessor->object;
|
storage_buffer_expr = accessor->object;
|
||||||
}
|
}
|
||||||
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
|
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
|
||||||
|
@ -199,8 +199,7 @@ Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
|
||||||
// array_length = ----------------------------------------
|
// array_length = ----------------------------------------
|
||||||
// array_stride
|
// array_stride
|
||||||
auto name = b.Sym();
|
auto name = b.Sym();
|
||||||
const ast::Expression* total_size =
|
const Expression* total_size = b.Expr(buffer_size_result->variable);
|
||||||
b.Expr(buffer_size_result->variable);
|
|
||||||
|
|
||||||
const type::Array* array_type = Switch(
|
const type::Array* array_type = Switch(
|
||||||
storage_buffer_type->StoreType(),
|
storage_buffer_type->StoreType(),
|
||||||
|
|
|
@ -37,12 +37,12 @@ class CalculateArrayLength final : public utils::Castable<CalculateArrayLength,
|
||||||
/// BufferSizeIntrinsic is an InternalAttribute that's applied to intrinsic
|
/// BufferSizeIntrinsic is an InternalAttribute that's applied to intrinsic
|
||||||
/// functions used to obtain the runtime size of a storage buffer.
|
/// functions used to obtain the runtime size of a storage buffer.
|
||||||
class BufferSizeIntrinsic final
|
class BufferSizeIntrinsic final
|
||||||
: public utils::Castable<BufferSizeIntrinsic, ast::InternalAttribute> {
|
: public utils::Castable<BufferSizeIntrinsic, InternalAttribute> {
|
||||||
public:
|
public:
|
||||||
/// Constructor
|
/// Constructor
|
||||||
/// @param program_id the identifier of the program that owns this node
|
/// @param program_id the identifier of the program that owns this node
|
||||||
/// @param nid the unique node identifier
|
/// @param nid the unique node identifier
|
||||||
BufferSizeIntrinsic(ProgramID program_id, ast::NodeID nid);
|
BufferSizeIntrinsic(ProgramID program_id, NodeID nid);
|
||||||
/// Destructor
|
/// Destructor
|
||||||
~BufferSizeIntrinsic() override;
|
~BufferSizeIntrinsic() override;
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ namespace {
|
||||||
/// Info for a struct member
|
/// Info for a struct member
|
||||||
struct MemberInfo {
|
struct MemberInfo {
|
||||||
/// The struct member item
|
/// The struct member item
|
||||||
const ast::StructMember* member;
|
const StructMember* member;
|
||||||
/// The struct member location if provided
|
/// The struct member location if provided
|
||||||
std::optional<uint32_t> location;
|
std::optional<uint32_t> location;
|
||||||
};
|
};
|
||||||
|
@ -83,9 +83,9 @@ uint32_t BuiltinOrder(builtin::BuiltinValue builtin) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if `attr` is a shader IO attribute.
|
// Returns true if `attr` is a shader IO attribute.
|
||||||
bool IsShaderIOAttribute(const ast::Attribute* attr) {
|
bool IsShaderIOAttribute(const Attribute* attr) {
|
||||||
return attr->IsAnyOf<ast::BuiltinAttribute, ast::InterpolateAttribute, ast::InvariantAttribute,
|
return attr
|
||||||
ast::LocationAttribute>();
|
->IsAnyOf<BuiltinAttribute, InterpolateAttribute, InvariantAttribute, LocationAttribute>();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -97,11 +97,11 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// The name of the output value.
|
/// The name of the output value.
|
||||||
std::string name;
|
std::string name;
|
||||||
/// The type of the output value.
|
/// The type of the output value.
|
||||||
ast::Type type;
|
Type type;
|
||||||
/// The shader IO attributes.
|
/// The shader IO attributes.
|
||||||
utils::Vector<const ast::Attribute*, 8> attributes;
|
utils::Vector<const Attribute*, 8> attributes;
|
||||||
/// The value itself.
|
/// The value itself.
|
||||||
const ast::Expression* value;
|
const Expression* value;
|
||||||
/// The output location.
|
/// The output location.
|
||||||
std::optional<uint32_t> location;
|
std::optional<uint32_t> location;
|
||||||
};
|
};
|
||||||
|
@ -111,29 +111,29 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// The transform config.
|
/// The transform config.
|
||||||
CanonicalizeEntryPointIO::Config const cfg;
|
CanonicalizeEntryPointIO::Config const cfg;
|
||||||
/// The entry point function (AST).
|
/// The entry point function (AST).
|
||||||
const ast::Function* func_ast;
|
const Function* func_ast;
|
||||||
/// The entry point function (SEM).
|
/// The entry point function (SEM).
|
||||||
const sem::Function* func_sem;
|
const sem::Function* func_sem;
|
||||||
|
|
||||||
/// The new entry point wrapper function's parameters.
|
/// The new entry point wrapper function's parameters.
|
||||||
utils::Vector<const ast::Parameter*, 8> wrapper_ep_parameters;
|
utils::Vector<const Parameter*, 8> wrapper_ep_parameters;
|
||||||
|
|
||||||
/// The members of the wrapper function's struct parameter.
|
/// The members of the wrapper function's struct parameter.
|
||||||
utils::Vector<MemberInfo, 8> wrapper_struct_param_members;
|
utils::Vector<MemberInfo, 8> wrapper_struct_param_members;
|
||||||
/// The name of the wrapper function's struct parameter.
|
/// The name of the wrapper function's struct parameter.
|
||||||
Symbol wrapper_struct_param_name;
|
Symbol wrapper_struct_param_name;
|
||||||
/// The parameters that will be passed to the original function.
|
/// The parameters that will be passed to the original function.
|
||||||
utils::Vector<const ast::Expression*, 8> inner_call_parameters;
|
utils::Vector<const Expression*, 8> inner_call_parameters;
|
||||||
/// The members of the wrapper function's struct return type.
|
/// The members of the wrapper function's struct return type.
|
||||||
utils::Vector<MemberInfo, 8> wrapper_struct_output_members;
|
utils::Vector<MemberInfo, 8> wrapper_struct_output_members;
|
||||||
/// The wrapper function output values.
|
/// The wrapper function output values.
|
||||||
utils::Vector<OutputValue, 8> wrapper_output_values;
|
utils::Vector<OutputValue, 8> wrapper_output_values;
|
||||||
/// The body of the wrapper function.
|
/// The body of the wrapper function.
|
||||||
utils::Vector<const ast::Statement*, 8> wrapper_body;
|
utils::Vector<const Statement*, 8> wrapper_body;
|
||||||
/// Input names used by the entrypoint
|
/// Input names used by the entrypoint
|
||||||
std::unordered_set<std::string> input_names;
|
std::unordered_set<std::string> input_names;
|
||||||
/// A map of cloned attribute to builtin value
|
/// A map of cloned attribute to builtin value
|
||||||
utils::Hashmap<const ast::BuiltinAttribute*, builtin::BuiltinValue, 16> builtin_attrs;
|
utils::Hashmap<const BuiltinAttribute*, builtin::BuiltinValue, 16> builtin_attrs;
|
||||||
|
|
||||||
/// Constructor
|
/// Constructor
|
||||||
/// @param context the clone context
|
/// @param context the clone context
|
||||||
|
@ -141,7 +141,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// @param function the entry point function
|
/// @param function the entry point function
|
||||||
State(CloneContext& context,
|
State(CloneContext& context,
|
||||||
const CanonicalizeEntryPointIO::Config& config,
|
const CanonicalizeEntryPointIO::Config& config,
|
||||||
const ast::Function* function)
|
const Function* function)
|
||||||
: ctx(context), cfg(config), func_ast(function), func_sem(ctx.src->Sem().Get(function)) {}
|
: ctx(context), cfg(config), func_ast(function), func_sem(ctx.src->Sem().Get(function)) {}
|
||||||
|
|
||||||
/// Clones the attributes from @p in and adds it to @p out. If @p in is a builtin attribute,
|
/// Clones the attributes from @p in and adds it to @p out. If @p in is a builtin attribute,
|
||||||
|
@ -149,12 +149,11 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// @param in the attribute to clone
|
/// @param in the attribute to clone
|
||||||
/// @param out the output Attributes
|
/// @param out the output Attributes
|
||||||
template <size_t N>
|
template <size_t N>
|
||||||
void CloneAttribute(const ast::Attribute* in, utils::Vector<const ast::Attribute*, N>& out) {
|
void CloneAttribute(const Attribute* in, utils::Vector<const Attribute*, N>& out) {
|
||||||
auto* cloned = ctx.Clone(in);
|
auto* cloned = ctx.Clone(in);
|
||||||
out.Push(cloned);
|
out.Push(cloned);
|
||||||
if (auto* builtin = in->As<ast::BuiltinAttribute>()) {
|
if (auto* builtin = in->As<BuiltinAttribute>()) {
|
||||||
builtin_attrs.Add(cloned->As<ast::BuiltinAttribute>(),
|
builtin_attrs.Add(cloned->As<BuiltinAttribute>(), ctx.src->Sem().Get(builtin)->Value());
|
||||||
ctx.src->Sem().Get(builtin)->Value());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,12 +162,11 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// @param do_interpolate whether to clone InterpolateAttribute
|
/// @param do_interpolate whether to clone InterpolateAttribute
|
||||||
/// @return the cloned attributes
|
/// @return the cloned attributes
|
||||||
template <size_t N>
|
template <size_t N>
|
||||||
auto CloneShaderIOAttributes(const utils::Vector<const ast::Attribute*, N> in,
|
auto CloneShaderIOAttributes(const utils::Vector<const Attribute*, N> in, bool do_interpolate) {
|
||||||
bool do_interpolate) {
|
utils::Vector<const Attribute*, N> out;
|
||||||
utils::Vector<const ast::Attribute*, N> out;
|
|
||||||
for (auto* attr : in) {
|
for (auto* attr : in) {
|
||||||
if (IsShaderIOAttribute(attr) &&
|
if (IsShaderIOAttribute(attr) &&
|
||||||
(do_interpolate || !attr->template Is<ast::InterpolateAttribute>())) {
|
(do_interpolate || !attr->template Is<InterpolateAttribute>())) {
|
||||||
CloneAttribute(attr, out);
|
CloneAttribute(attr, out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -177,7 +175,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
|
|
||||||
/// @param attr the input attribute
|
/// @param attr the input attribute
|
||||||
/// @returns the builtin value of the attribute
|
/// @returns the builtin value of the attribute
|
||||||
builtin::BuiltinValue BuiltinOf(const ast::BuiltinAttribute* attr) {
|
builtin::BuiltinValue BuiltinOf(const BuiltinAttribute* attr) {
|
||||||
if (attr->program_id == ctx.dst->ID()) {
|
if (attr->program_id == ctx.dst->ID()) {
|
||||||
// attr belongs to the target program.
|
// attr belongs to the target program.
|
||||||
// Obtain the builtin value from #builtin_attrs.
|
// Obtain the builtin value from #builtin_attrs.
|
||||||
|
@ -197,8 +195,8 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// @param attrs the input attribute list
|
/// @param attrs the input attribute list
|
||||||
/// @returns the builtin value if any of the attributes in @p attrs is a builtin attribute,
|
/// @returns the builtin value if any of the attributes in @p attrs is a builtin attribute,
|
||||||
/// otherwise builtin::BuiltinValue::kUndefined
|
/// otherwise builtin::BuiltinValue::kUndefined
|
||||||
builtin::BuiltinValue BuiltinOf(utils::VectorRef<const ast::Attribute*> attrs) {
|
builtin::BuiltinValue BuiltinOf(utils::VectorRef<const Attribute*> attrs) {
|
||||||
if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(attrs)) {
|
if (auto* builtin = GetAttribute<BuiltinAttribute>(attrs)) {
|
||||||
return BuiltinOf(builtin);
|
return BuiltinOf(builtin);
|
||||||
}
|
}
|
||||||
return builtin::BuiltinValue::kUndefined;
|
return builtin::BuiltinValue::kUndefined;
|
||||||
|
@ -219,10 +217,10 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// @param location the location if provided
|
/// @param location the location if provided
|
||||||
/// @param attrs the attributes to apply to the shader input
|
/// @param attrs the attributes to apply to the shader input
|
||||||
/// @returns an expression which evaluates to the value of the shader input
|
/// @returns an expression which evaluates to the value of the shader input
|
||||||
const ast::Expression* AddInput(std::string name,
|
const Expression* AddInput(std::string name,
|
||||||
const type::Type* type,
|
const type::Type* type,
|
||||||
std::optional<uint32_t> location,
|
std::optional<uint32_t> location,
|
||||||
utils::Vector<const ast::Attribute*, 8> attrs) {
|
utils::Vector<const Attribute*, 8> attrs) {
|
||||||
auto ast_type = CreateASTTypeFor(ctx, type);
|
auto ast_type = CreateASTTypeFor(ctx, type);
|
||||||
|
|
||||||
auto builtin_attr = BuiltinOf(attrs);
|
auto builtin_attr = BuiltinOf(attrs);
|
||||||
|
@ -233,17 +231,16 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
// https://www.khronos.org/registry/vulkan/specs/1.3-extensions/man/html/StandaloneSpirv.html#VUID-StandaloneSpirv-Flat-04744
|
// https://www.khronos.org/registry/vulkan/specs/1.3-extensions/man/html/StandaloneSpirv.html#VUID-StandaloneSpirv-Flat-04744
|
||||||
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is
|
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is
|
||||||
// required for integers.
|
// required for integers.
|
||||||
if (func_ast->PipelineStage() == ast::PipelineStage::kFragment &&
|
if (func_ast->PipelineStage() == PipelineStage::kFragment &&
|
||||||
type->is_integer_scalar_or_vector() &&
|
type->is_integer_scalar_or_vector() && !HasAttribute<InterpolateAttribute>(attrs) &&
|
||||||
!ast::HasAttribute<ast::InterpolateAttribute>(attrs) &&
|
(HasAttribute<LocationAttribute>(attrs) ||
|
||||||
(ast::HasAttribute<ast::LocationAttribute>(attrs) ||
|
|
||||||
cfg.shader_style == ShaderStyle::kSpirv)) {
|
cfg.shader_style == ShaderStyle::kSpirv)) {
|
||||||
attrs.Push(ctx.dst->Interpolate(builtin::InterpolationType::kFlat,
|
attrs.Push(ctx.dst->Interpolate(builtin::InterpolationType::kFlat,
|
||||||
builtin::InterpolationSampling::kUndefined));
|
builtin::InterpolationSampling::kUndefined));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable validation for use of the `input` address space.
|
// Disable validation for use of the `input` address space.
|
||||||
attrs.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace));
|
attrs.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
|
||||||
|
|
||||||
// In GLSL, if it's a builtin, override the name with the
|
// In GLSL, if it's a builtin, override the name with the
|
||||||
// corresponding gl_ builtin name
|
// corresponding gl_ builtin name
|
||||||
|
@ -255,7 +252,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
auto symbol = ctx.dst->Symbols().New(name);
|
auto symbol = ctx.dst->Symbols().New(name);
|
||||||
|
|
||||||
// Create the global variable and use its value for the shader input.
|
// Create the global variable and use its value for the shader input.
|
||||||
const ast::Expression* value = ctx.dst->Expr(symbol);
|
const Expression* value = ctx.dst->Expr(symbol);
|
||||||
|
|
||||||
if (builtin_attr != builtin::BuiltinValue::kUndefined) {
|
if (builtin_attr != builtin::BuiltinValue::kUndefined) {
|
||||||
if (cfg.shader_style == ShaderStyle::kGlsl) {
|
if (cfg.shader_style == ShaderStyle::kGlsl) {
|
||||||
|
@ -296,18 +293,17 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
void AddOutput(std::string name,
|
void AddOutput(std::string name,
|
||||||
const type::Type* type,
|
const type::Type* type,
|
||||||
std::optional<uint32_t> location,
|
std::optional<uint32_t> location,
|
||||||
utils::Vector<const ast::Attribute*, 8> attrs,
|
utils::Vector<const Attribute*, 8> attrs,
|
||||||
const ast::Expression* value) {
|
const Expression* value) {
|
||||||
auto builtin_attr = BuiltinOf(attrs);
|
auto builtin_attr = BuiltinOf(attrs);
|
||||||
// Vulkan requires that integer user-defined vertex outputs are always decorated with
|
// Vulkan requires that integer user-defined vertex outputs are always decorated with
|
||||||
// `Flat`.
|
// `Flat`.
|
||||||
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is required
|
// TODO(crbug.com/tint/1224): Remove this once a flat interpolation attribute is required
|
||||||
// for integers.
|
// for integers.
|
||||||
if (cfg.shader_style == ShaderStyle::kSpirv &&
|
if (cfg.shader_style == ShaderStyle::kSpirv &&
|
||||||
func_ast->PipelineStage() == ast::PipelineStage::kVertex &&
|
func_ast->PipelineStage() == PipelineStage::kVertex &&
|
||||||
type->is_integer_scalar_or_vector() &&
|
type->is_integer_scalar_or_vector() && HasAttribute<LocationAttribute>(attrs) &&
|
||||||
ast::HasAttribute<ast::LocationAttribute>(attrs) &&
|
!HasAttribute<InterpolateAttribute>(attrs)) {
|
||||||
!ast::HasAttribute<ast::InterpolateAttribute>(attrs)) {
|
|
||||||
attrs.Push(ctx.dst->Interpolate(builtin::InterpolationType::kFlat,
|
attrs.Push(ctx.dst->Interpolate(builtin::InterpolationType::kFlat,
|
||||||
builtin::InterpolationSampling::kUndefined));
|
builtin::InterpolationSampling::kUndefined));
|
||||||
}
|
}
|
||||||
|
@ -338,14 +334,14 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// @param param the original function parameter
|
/// @param param the original function parameter
|
||||||
void ProcessNonStructParameter(const sem::Parameter* param) {
|
void ProcessNonStructParameter(const sem::Parameter* param) {
|
||||||
// Do not add interpolation attributes on vertex input
|
// Do not add interpolation attributes on vertex input
|
||||||
bool do_interpolate = func_ast->PipelineStage() != ast::PipelineStage::kVertex;
|
bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kVertex;
|
||||||
// Remove the shader IO attributes from the inner function parameter, and attach them to the
|
// Remove the shader IO attributes from the inner function parameter, and attach them to the
|
||||||
// new object instead.
|
// new object instead.
|
||||||
utils::Vector<const ast::Attribute*, 8> attributes;
|
utils::Vector<const Attribute*, 8> attributes;
|
||||||
for (auto* attr : param->Declaration()->attributes) {
|
for (auto* attr : param->Declaration()->attributes) {
|
||||||
if (IsShaderIOAttribute(attr)) {
|
if (IsShaderIOAttribute(attr)) {
|
||||||
ctx.Remove(param->Declaration()->attributes, attr);
|
ctx.Remove(param->Declaration()->attributes, attr);
|
||||||
if ((do_interpolate || !attr->Is<ast::InterpolateAttribute>())) {
|
if ((do_interpolate || !attr->Is<InterpolateAttribute>())) {
|
||||||
CloneAttribute(attr, attributes);
|
CloneAttribute(attr, attributes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -363,13 +359,13 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// @param param the original function parameter
|
/// @param param the original function parameter
|
||||||
void ProcessStructParameter(const sem::Parameter* param) {
|
void ProcessStructParameter(const sem::Parameter* param) {
|
||||||
// Do not add interpolation attributes on vertex input
|
// Do not add interpolation attributes on vertex input
|
||||||
bool do_interpolate = func_ast->PipelineStage() != ast::PipelineStage::kVertex;
|
bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kVertex;
|
||||||
|
|
||||||
auto* str = param->Type()->As<sem::Struct>();
|
auto* str = param->Type()->As<sem::Struct>();
|
||||||
|
|
||||||
// Recreate struct members in the outer entry point and build an initializer
|
// Recreate struct members in the outer entry point and build an initializer
|
||||||
// list to pass them through to the inner function.
|
// list to pass them through to the inner function.
|
||||||
utils::Vector<const ast::Expression*, 8> inner_struct_values;
|
utils::Vector<const Expression*, 8> inner_struct_values;
|
||||||
for (auto* member : str->Members()) {
|
for (auto* member : str->Members()) {
|
||||||
if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) {
|
if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) {
|
||||||
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
|
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "nested IO struct";
|
||||||
|
@ -397,7 +393,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// @param original_result the result object produced by the original function
|
/// @param original_result the result object produced by the original function
|
||||||
void ProcessReturnType(const type::Type* inner_ret_type, Symbol original_result) {
|
void ProcessReturnType(const type::Type* inner_ret_type, Symbol original_result) {
|
||||||
// Do not add interpolation attributes on fragment output
|
// Do not add interpolation attributes on fragment output
|
||||||
bool do_interpolate = func_ast->PipelineStage() != ast::PipelineStage::kFragment;
|
bool do_interpolate = func_ast->PipelineStage() != PipelineStage::kFragment;
|
||||||
if (auto* str = inner_ret_type->As<sem::Struct>()) {
|
if (auto* str = inner_ret_type->As<sem::Struct>()) {
|
||||||
for (auto* member : str->Members()) {
|
for (auto* member : str->Members()) {
|
||||||
if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) {
|
if (TINT_UNLIKELY(member->Type()->Is<type::Struct>())) {
|
||||||
|
@ -456,7 +452,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// Create an expression for gl_Position.[component]
|
/// Create an expression for gl_Position.[component]
|
||||||
/// @param component the component of gl_Position to access
|
/// @param component the component of gl_Position to access
|
||||||
/// @returns the new expression
|
/// @returns the new expression
|
||||||
const ast::Expression* GLPosition(const char* component) {
|
const Expression* GLPosition(const char* component) {
|
||||||
Symbol pos = ctx.dst->Symbols().Register("gl_Position");
|
Symbol pos = ctx.dst->Symbols().Register("gl_Position");
|
||||||
Symbol c = ctx.dst->Symbols().Register(component);
|
Symbol c = ctx.dst->Symbols().Register(component);
|
||||||
return ctx.dst->MemberAccessor(ctx.dst->Expr(pos), c);
|
return ctx.dst->MemberAccessor(ctx.dst->Expr(pos), c);
|
||||||
|
@ -469,10 +465,10 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// @param b another struct member
|
/// @param b another struct member
|
||||||
/// @returns true if a comes before b
|
/// @returns true if a comes before b
|
||||||
bool StructMemberComparator(const MemberInfo& a, const MemberInfo& b) {
|
bool StructMemberComparator(const MemberInfo& a, const MemberInfo& b) {
|
||||||
auto* a_loc = ast::GetAttribute<ast::LocationAttribute>(a.member->attributes);
|
auto* a_loc = GetAttribute<LocationAttribute>(a.member->attributes);
|
||||||
auto* b_loc = ast::GetAttribute<ast::LocationAttribute>(b.member->attributes);
|
auto* b_loc = GetAttribute<LocationAttribute>(b.member->attributes);
|
||||||
auto* a_blt = ast::GetAttribute<ast::BuiltinAttribute>(a.member->attributes);
|
auto* a_blt = GetAttribute<BuiltinAttribute>(a.member->attributes);
|
||||||
auto* b_blt = ast::GetAttribute<ast::BuiltinAttribute>(b.member->attributes);
|
auto* b_blt = GetAttribute<BuiltinAttribute>(b.member->attributes);
|
||||||
if (a_loc) {
|
if (a_loc) {
|
||||||
if (!b_loc) {
|
if (!b_loc) {
|
||||||
// `a` has location attribute and `b` does not: `a` goes first.
|
// `a` has location attribute and `b` does not: `a` goes first.
|
||||||
|
@ -497,15 +493,15 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(),
|
std::sort(wrapper_struct_param_members.begin(), wrapper_struct_param_members.end(),
|
||||||
[&](auto& a, auto& b) { return StructMemberComparator(a, b); });
|
[&](auto& a, auto& b) { return StructMemberComparator(a, b); });
|
||||||
|
|
||||||
utils::Vector<const ast::StructMember*, 8> members;
|
utils::Vector<const StructMember*, 8> members;
|
||||||
for (auto& mem : wrapper_struct_param_members) {
|
for (auto& mem : wrapper_struct_param_members) {
|
||||||
members.Push(mem.member);
|
members.Push(mem.member);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the new struct type.
|
// Create the new struct type.
|
||||||
auto struct_name = ctx.dst->Sym();
|
auto struct_name = ctx.dst->Sym();
|
||||||
auto* in_struct = ctx.dst->create<ast::Struct>(ctx.dst->Ident(struct_name),
|
auto* in_struct =
|
||||||
std::move(members), utils::Empty);
|
ctx.dst->create<Struct>(ctx.dst->Ident(struct_name), std::move(members), utils::Empty);
|
||||||
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
|
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, in_struct);
|
||||||
|
|
||||||
// Create a new function parameter using this struct type.
|
// Create a new function parameter using this struct type.
|
||||||
|
@ -515,8 +511,8 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
|
|
||||||
/// Create and return the wrapper function's struct result object.
|
/// Create and return the wrapper function's struct result object.
|
||||||
/// @returns the struct type
|
/// @returns the struct type
|
||||||
ast::Struct* CreateOutputStruct() {
|
Struct* CreateOutputStruct() {
|
||||||
utils::Vector<const ast::Statement*, 8> assignments;
|
utils::Vector<const Statement*, 8> assignments;
|
||||||
|
|
||||||
auto wrapper_result = ctx.dst->Symbols().New("wrapper_result");
|
auto wrapper_result = ctx.dst->Symbols().New("wrapper_result");
|
||||||
|
|
||||||
|
@ -544,13 +540,13 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(),
|
std::sort(wrapper_struct_output_members.begin(), wrapper_struct_output_members.end(),
|
||||||
[&](auto& a, auto& b) { return StructMemberComparator(a, b); });
|
[&](auto& a, auto& b) { return StructMemberComparator(a, b); });
|
||||||
|
|
||||||
utils::Vector<const ast::StructMember*, 8> members;
|
utils::Vector<const StructMember*, 8> members;
|
||||||
for (auto& mem : wrapper_struct_output_members) {
|
for (auto& mem : wrapper_struct_output_members) {
|
||||||
members.Push(mem.member);
|
members.Push(mem.member);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the new struct type.
|
// Create the new struct type.
|
||||||
auto* out_struct = ctx.dst->create<ast::Struct>(ctx.dst->Ident(ctx.dst->Sym()),
|
auto* out_struct = ctx.dst->create<Struct>(ctx.dst->Ident(ctx.dst->Sym()),
|
||||||
std::move(members), utils::Empty);
|
std::move(members), utils::Empty);
|
||||||
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
|
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func_ast, out_struct);
|
||||||
|
|
||||||
|
@ -570,12 +566,12 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
for (auto& outval : wrapper_output_values) {
|
for (auto& outval : wrapper_output_values) {
|
||||||
// Disable validation for use of the `output` address space.
|
// Disable validation for use of the `output` address space.
|
||||||
auto attributes = std::move(outval.attributes);
|
auto attributes = std::move(outval.attributes);
|
||||||
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace));
|
attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
|
||||||
|
|
||||||
// Create the global variable and assign it the output value.
|
// Create the global variable and assign it the output value.
|
||||||
auto name = ctx.dst->Symbols().New(outval.name);
|
auto name = ctx.dst->Symbols().New(outval.name);
|
||||||
ast::Type type = outval.type;
|
Type type = outval.type;
|
||||||
const ast::Expression* lhs = ctx.dst->Expr(name);
|
const Expression* lhs = ctx.dst->Expr(name);
|
||||||
if (BuiltinOf(attributes) == builtin::BuiltinValue::kSampleMask) {
|
if (BuiltinOf(attributes) == builtin::BuiltinValue::kSampleMask) {
|
||||||
// Vulkan requires the type of a SampleMask builtin to be an array.
|
// Vulkan requires the type of a SampleMask builtin to be an array.
|
||||||
// Declare it as array<u32, 1> and then store to the first element.
|
// Declare it as array<u32, 1> and then store to the first element.
|
||||||
|
@ -589,7 +585,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
|
|
||||||
// Recreate the original function without entry point attributes and call it.
|
// Recreate the original function without entry point attributes and call it.
|
||||||
/// @returns the inner function call expression
|
/// @returns the inner function call expression
|
||||||
const ast::CallExpression* CallInnerFunction() {
|
const CallExpression* CallInnerFunction() {
|
||||||
Symbol inner_name;
|
Symbol inner_name;
|
||||||
if (cfg.shader_style == ShaderStyle::kGlsl) {
|
if (cfg.shader_style == ShaderStyle::kGlsl) {
|
||||||
// In GLSL, clone the original entry point name, as the wrapper will be
|
// In GLSL, clone the original entry point name, as the wrapper will be
|
||||||
|
@ -606,9 +602,9 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
// The parameter attributes will have already been stripped during
|
// The parameter attributes will have already been stripped during
|
||||||
// processing.
|
// processing.
|
||||||
auto* inner_function =
|
auto* inner_function =
|
||||||
ctx.dst->create<ast::Function>(ctx.dst->Ident(inner_name), ctx.Clone(func_ast->params),
|
ctx.dst->create<Function>(ctx.dst->Ident(inner_name), ctx.Clone(func_ast->params),
|
||||||
ctx.Clone(func_ast->return_type),
|
ctx.Clone(func_ast->return_type), ctx.Clone(func_ast->body),
|
||||||
ctx.Clone(func_ast->body), utils::Empty, utils::Empty);
|
utils::Empty, utils::Empty);
|
||||||
ctx.Replace(func_ast, inner_function);
|
ctx.Replace(func_ast, inner_function);
|
||||||
|
|
||||||
// Call the function.
|
// Call the function.
|
||||||
|
@ -619,12 +615,11 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
void Process() {
|
void Process() {
|
||||||
bool needs_fixed_sample_mask = false;
|
bool needs_fixed_sample_mask = false;
|
||||||
bool needs_vertex_point_size = false;
|
bool needs_vertex_point_size = false;
|
||||||
if (func_ast->PipelineStage() == ast::PipelineStage::kFragment &&
|
if (func_ast->PipelineStage() == PipelineStage::kFragment &&
|
||||||
cfg.fixed_sample_mask != 0xFFFFFFFF) {
|
cfg.fixed_sample_mask != 0xFFFFFFFF) {
|
||||||
needs_fixed_sample_mask = true;
|
needs_fixed_sample_mask = true;
|
||||||
}
|
}
|
||||||
if (func_ast->PipelineStage() == ast::PipelineStage::kVertex &&
|
if (func_ast->PipelineStage() == PipelineStage::kVertex && cfg.emit_vertex_point_size) {
|
||||||
cfg.emit_vertex_point_size) {
|
|
||||||
needs_vertex_point_size = true;
|
needs_vertex_point_size = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -656,7 +651,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
auto* call_inner = CallInnerFunction();
|
auto* call_inner = CallInnerFunction();
|
||||||
|
|
||||||
// Process the return type, and start building the wrapper function body.
|
// Process the return type, and start building the wrapper function body.
|
||||||
std::function<ast::Type()> wrapper_ret_type = [&] { return ctx.dst->ty.void_(); };
|
std::function<Type()> wrapper_ret_type = [&] { return ctx.dst->ty.void_(); };
|
||||||
if (func_sem->ReturnType()->Is<type::Void>()) {
|
if (func_sem->ReturnType()->Is<type::Void>()) {
|
||||||
// The function call is just a statement with no result.
|
// The function call is just a statement with no result.
|
||||||
wrapper_body.Push(ctx.dst->CallStmt(call_inner));
|
wrapper_body.Push(ctx.dst->CallStmt(call_inner));
|
||||||
|
@ -693,10 +688,10 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cfg.shader_style == ShaderStyle::kGlsl &&
|
if (cfg.shader_style == ShaderStyle::kGlsl &&
|
||||||
func_ast->PipelineStage() == ast::PipelineStage::kVertex) {
|
func_ast->PipelineStage() == PipelineStage::kVertex) {
|
||||||
auto* pos_y = GLPosition("y");
|
auto* pos_y = GLPosition("y");
|
||||||
auto* negate_pos_y =
|
auto* negate_pos_y =
|
||||||
ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNegation, GLPosition("y"));
|
ctx.dst->create<UnaryOpExpression>(UnaryOp::kNegation, GLPosition("y"));
|
||||||
wrapper_body.Push(ctx.dst->Assign(pos_y, negate_pos_y));
|
wrapper_body.Push(ctx.dst->Assign(pos_y, negate_pos_y));
|
||||||
|
|
||||||
auto* two_z = ctx.dst->Mul(ctx.dst->Expr(2_f), GLPosition("z"));
|
auto* two_z = ctx.dst->Mul(ctx.dst->Expr(2_f), GLPosition("z"));
|
||||||
|
@ -714,7 +709,7 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
name = ctx.Clone(func_ast->name->symbol);
|
name = ctx.Clone(func_ast->name->symbol);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* wrapper_func = ctx.dst->create<ast::Function>(
|
auto* wrapper_func = ctx.dst->create<Function>(
|
||||||
ctx.dst->Ident(name), wrapper_ep_parameters, ctx.dst->ty(wrapper_ret_type()),
|
ctx.dst->Ident(name), wrapper_ep_parameters, ctx.dst->ty(wrapper_ret_type()),
|
||||||
ctx.dst->Block(wrapper_body), ctx.Clone(func_ast->attributes), utils::Empty);
|
ctx.dst->Block(wrapper_body), ctx.Clone(func_ast->attributes), utils::Empty);
|
||||||
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast, wrapper_func);
|
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), func_ast, wrapper_func);
|
||||||
|
@ -726,14 +721,14 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// @param address_space the address space (input or output)
|
/// @param address_space the address space (input or output)
|
||||||
/// @returns the gl_ string corresponding to that builtin
|
/// @returns the gl_ string corresponding to that builtin
|
||||||
const char* GLSLBuiltinToString(builtin::BuiltinValue builtin,
|
const char* GLSLBuiltinToString(builtin::BuiltinValue builtin,
|
||||||
ast::PipelineStage stage,
|
PipelineStage stage,
|
||||||
builtin::AddressSpace address_space) {
|
builtin::AddressSpace address_space) {
|
||||||
switch (builtin) {
|
switch (builtin) {
|
||||||
case builtin::BuiltinValue::kPosition:
|
case builtin::BuiltinValue::kPosition:
|
||||||
switch (stage) {
|
switch (stage) {
|
||||||
case ast::PipelineStage::kVertex:
|
case PipelineStage::kVertex:
|
||||||
return "gl_Position";
|
return "gl_Position";
|
||||||
case ast::PipelineStage::kFragment:
|
case PipelineStage::kFragment:
|
||||||
return "gl_FragCoord";
|
return "gl_FragCoord";
|
||||||
default:
|
default:
|
||||||
return "";
|
return "";
|
||||||
|
@ -775,9 +770,9 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// @param ast_type (inout) the incoming WGSL and outgoing GLSL types
|
/// @param ast_type (inout) the incoming WGSL and outgoing GLSL types
|
||||||
/// @returns an expression representing the GLSL builtin converted to what
|
/// @returns an expression representing the GLSL builtin converted to what
|
||||||
/// WGSL expects
|
/// WGSL expects
|
||||||
const ast::Expression* FromGLSLBuiltin(builtin::BuiltinValue builtin,
|
const Expression* FromGLSLBuiltin(builtin::BuiltinValue builtin,
|
||||||
const ast::Expression* value,
|
const Expression* value,
|
||||||
ast::Type& ast_type) {
|
Type& ast_type) {
|
||||||
switch (builtin) {
|
switch (builtin) {
|
||||||
case builtin::BuiltinValue::kVertexIndex:
|
case builtin::BuiltinValue::kVertexIndex:
|
||||||
case builtin::BuiltinValue::kInstanceIndex:
|
case builtin::BuiltinValue::kInstanceIndex:
|
||||||
|
@ -805,8 +800,8 @@ struct CanonicalizeEntryPointIO::State {
|
||||||
/// @param value the value to convert
|
/// @param value the value to convert
|
||||||
/// @param type (out) the type to which the value was converted
|
/// @param type (out) the type to which the value was converted
|
||||||
/// @returns the converted value which can be assigned to the GLSL builtin
|
/// @returns the converted value which can be assigned to the GLSL builtin
|
||||||
const ast::Expression* ToGLSLBuiltin(builtin::BuiltinValue builtin,
|
const Expression* ToGLSLBuiltin(builtin::BuiltinValue builtin,
|
||||||
const ast::Expression* value,
|
const Expression* value,
|
||||||
const type::Type*& type) {
|
const type::Type*& type) {
|
||||||
switch (builtin) {
|
switch (builtin) {
|
||||||
case builtin::BuiltinValue::kVertexIndex:
|
case builtin::BuiltinValue::kVertexIndex:
|
||||||
|
@ -839,7 +834,7 @@ Transform::ApplyResult CanonicalizeEntryPointIO::Apply(const Program* src,
|
||||||
// Remove entry point IO attributes from struct declarations.
|
// Remove entry point IO attributes from struct declarations.
|
||||||
// New structures will be created for each entry point, as necessary.
|
// New structures will be created for each entry point, as necessary.
|
||||||
for (auto* ty : src->AST().TypeDecls()) {
|
for (auto* ty : src->AST().TypeDecls()) {
|
||||||
if (auto* struct_ty = ty->As<ast::Struct>()) {
|
if (auto* struct_ty = ty->As<Struct>()) {
|
||||||
for (auto* member : struct_ty->members) {
|
for (auto* member : struct_ty->members) {
|
||||||
for (auto* attr : member->attributes) {
|
for (auto* attr : member->attributes) {
|
||||||
if (IsShaderIOAttribute(attr)) {
|
if (IsShaderIOAttribute(attr)) {
|
||||||
|
|
|
@ -51,7 +51,7 @@ struct ClampFragDepth::State {
|
||||||
Transform::ApplyResult Run() {
|
Transform::ApplyResult Run() {
|
||||||
// Abort on any use of push constants in the module.
|
// Abort on any use of push constants in the module.
|
||||||
for (auto* global : src->AST().GlobalVariables()) {
|
for (auto* global : src->AST().GlobalVariables()) {
|
||||||
if (auto* var = global->As<ast::Var>()) {
|
if (auto* var = global->As<Var>()) {
|
||||||
auto* v = src->Sem().Get(var);
|
auto* v = src->Sem().Get(var);
|
||||||
if (TINT_UNLIKELY(v->AddressSpace() == builtin::AddressSpace::kPushConstant)) {
|
if (TINT_UNLIKELY(v->AddressSpace() == builtin::AddressSpace::kPushConstant)) {
|
||||||
TINT_ICE(Transform, b.Diagnostics())
|
TINT_ICE(Transform, b.Diagnostics())
|
||||||
|
@ -101,14 +101,14 @@ struct ClampFragDepth::State {
|
||||||
|
|
||||||
// Map of io struct to helper function to return the structure with the depth clamping
|
// Map of io struct to helper function to return the structure with the depth clamping
|
||||||
// applied.
|
// applied.
|
||||||
utils::Hashmap<const ast::Struct*, Symbol, 4u> io_structs_clamp_helpers;
|
utils::Hashmap<const Struct*, Symbol, 4u> io_structs_clamp_helpers;
|
||||||
|
|
||||||
// Register a callback that will be called for each visted AST function.
|
// Register a callback that will be called for each visted AST function.
|
||||||
// This call wraps the cloning of the function's statements, and will assign to
|
// This call wraps the cloning of the function's statements, and will assign to
|
||||||
// `returns_frag_depth_as_value` or `returns_frag_depth_as_struct_helper` if the function's
|
// `returns_frag_depth_as_value` or `returns_frag_depth_as_struct_helper` if the function's
|
||||||
// return value requires depth clamping.
|
// return value requires depth clamping.
|
||||||
ctx.ReplaceAll([&](const ast::Function* fn) {
|
ctx.ReplaceAll([&](const Function* fn) {
|
||||||
if (fn->PipelineStage() != ast::PipelineStage::kFragment) {
|
if (fn->PipelineStage() != PipelineStage::kFragment) {
|
||||||
return ctx.CloneWithoutTransform(fn);
|
return ctx.CloneWithoutTransform(fn);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,9 +129,9 @@ struct ClampFragDepth::State {
|
||||||
auto fn_sym =
|
auto fn_sym =
|
||||||
b.Symbols().New("clamp_frag_depth_" + struct_ty->name->symbol.Name());
|
b.Symbols().New("clamp_frag_depth_" + struct_ty->name->symbol.Name());
|
||||||
|
|
||||||
utils::Vector<const ast::Expression*, 8u> initializer_args;
|
utils::Vector<const Expression*, 8u> initializer_args;
|
||||||
for (auto* member : struct_ty->members) {
|
for (auto* member : struct_ty->members) {
|
||||||
const ast::Expression* arg =
|
const Expression* arg =
|
||||||
b.MemberAccessor("s", ctx.Clone(member->name->symbol));
|
b.MemberAccessor("s", ctx.Clone(member->name->symbol));
|
||||||
if (ContainsFragDepth(member->attributes)) {
|
if (ContainsFragDepth(member->attributes)) {
|
||||||
arg = b.Call(base_fn_sym, arg);
|
arg = b.Call(base_fn_sym, arg);
|
||||||
|
@ -154,7 +154,7 @@ struct ClampFragDepth::State {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Replace the return statements `return expr` with `return clamp_frag_depth(expr)`.
|
// Replace the return statements `return expr` with `return clamp_frag_depth(expr)`.
|
||||||
ctx.ReplaceAll([&](const ast::ReturnStatement* stmt) -> const ast::ReturnStatement* {
|
ctx.ReplaceAll([&](const ReturnStatement* stmt) -> const ReturnStatement* {
|
||||||
if (returns_frag_depth_as_value) {
|
if (returns_frag_depth_as_value) {
|
||||||
return b.Return(stmt->source, b.Call(base_fn_sym, ctx.Clone(stmt->value)));
|
return b.Return(stmt->source, b.Call(base_fn_sym, ctx.Clone(stmt->value)));
|
||||||
}
|
}
|
||||||
|
@ -173,7 +173,7 @@ struct ClampFragDepth::State {
|
||||||
/// @returns true if the transform should run
|
/// @returns true if the transform should run
|
||||||
bool ShouldRun() {
|
bool ShouldRun() {
|
||||||
for (auto* fn : src->AST().Functions()) {
|
for (auto* fn : src->AST().Functions()) {
|
||||||
if (fn->PipelineStage() == ast::PipelineStage::kFragment &&
|
if (fn->PipelineStage() == PipelineStage::kFragment &&
|
||||||
(ReturnsFragDepthAsValue(fn) || ReturnsFragDepthInStruct(fn))) {
|
(ReturnsFragDepthAsValue(fn) || ReturnsFragDepthInStruct(fn))) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -183,9 +183,9 @@ struct ClampFragDepth::State {
|
||||||
}
|
}
|
||||||
/// @param attrs the attributes to examine
|
/// @param attrs the attributes to examine
|
||||||
/// @returns true if @p attrs contains a `@builtin(frag_depth)` attribute
|
/// @returns true if @p attrs contains a `@builtin(frag_depth)` attribute
|
||||||
bool ContainsFragDepth(utils::VectorRef<const ast::Attribute*> attrs) {
|
bool ContainsFragDepth(utils::VectorRef<const Attribute*> attrs) {
|
||||||
for (auto* attribute : attrs) {
|
for (auto* attribute : attrs) {
|
||||||
if (auto* builtin_attr = attribute->As<ast::BuiltinAttribute>()) {
|
if (auto* builtin_attr = attribute->As<BuiltinAttribute>()) {
|
||||||
auto builtin = sem.Get(builtin_attr)->Value();
|
auto builtin = sem.Get(builtin_attr)->Value();
|
||||||
if (builtin == builtin::BuiltinValue::kFragDepth) {
|
if (builtin == builtin::BuiltinValue::kFragDepth) {
|
||||||
return true;
|
return true;
|
||||||
|
@ -198,14 +198,14 @@ struct ClampFragDepth::State {
|
||||||
|
|
||||||
/// @param fn the function to examine
|
/// @param fn the function to examine
|
||||||
/// @returns true if @p fn has a return type with a `@builtin(frag_depth)` attribute
|
/// @returns true if @p fn has a return type with a `@builtin(frag_depth)` attribute
|
||||||
bool ReturnsFragDepthAsValue(const ast::Function* fn) {
|
bool ReturnsFragDepthAsValue(const Function* fn) {
|
||||||
return ContainsFragDepth(fn->return_type_attributes);
|
return ContainsFragDepth(fn->return_type_attributes);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @param fn the function to examine
|
/// @param fn the function to examine
|
||||||
/// @returns true if @p fn has a return structure with a `@builtin(frag_depth)` attribute on one
|
/// @returns true if @p fn has a return structure with a `@builtin(frag_depth)` attribute on one
|
||||||
/// of the members
|
/// of the members
|
||||||
bool ReturnsFragDepthInStruct(const ast::Function* fn) {
|
bool ReturnsFragDepthInStruct(const Function* fn) {
|
||||||
if (auto* struct_ty = sem.Get(fn)->ReturnType()->As<sem::Struct>()) {
|
if (auto* struct_ty = sem.Get(fn)->ReturnType()->As<sem::Struct>()) {
|
||||||
for (auto* member : struct_ty->Members()) {
|
for (auto* member : struct_ty->Members()) {
|
||||||
if (ContainsFragDepth(member->Declaration()->attributes)) {
|
if (ContainsFragDepth(member->Declaration()->attributes)) {
|
||||||
|
|
|
@ -61,7 +61,7 @@ struct CombineSamplers::State {
|
||||||
|
|
||||||
/// Map from a texture/sampler pair to the corresponding combined sampler
|
/// Map from a texture/sampler pair to the corresponding combined sampler
|
||||||
/// variable
|
/// variable
|
||||||
using CombinedTextureSamplerMap = std::unordered_map<sem::VariablePair, const ast::Variable*>;
|
using CombinedTextureSamplerMap = std::unordered_map<sem::VariablePair, const Variable*>;
|
||||||
|
|
||||||
/// Use sem::BindingPoint without scope.
|
/// Use sem::BindingPoint without scope.
|
||||||
using BindingPoint = sem::BindingPoint;
|
using BindingPoint = sem::BindingPoint;
|
||||||
|
@ -79,15 +79,14 @@ struct CombineSamplers::State {
|
||||||
/// references (one comparison sampler, one regular). These are also used as
|
/// references (one comparison sampler, one regular). These are also used as
|
||||||
/// temporary sampler parameters to the texture builtins to satisfy the WGSL
|
/// temporary sampler parameters to the texture builtins to satisfy the WGSL
|
||||||
/// resolver, but are then ignored and removed by the GLSL writer.
|
/// resolver, but are then ignored and removed by the GLSL writer.
|
||||||
const ast::Variable* placeholder_samplers_[2] = {};
|
const Variable* placeholder_samplers_[2] = {};
|
||||||
|
|
||||||
/// Group and binding attributes used by all combined sampler globals.
|
/// Group and binding attributes used by all combined sampler globals.
|
||||||
/// Group 0 and binding 0 are used, with collisions disabled.
|
/// Group 0 and binding 0 are used, with collisions disabled.
|
||||||
/// @returns the newly-created attribute list
|
/// @returns the newly-created attribute list
|
||||||
auto Attributes() const {
|
auto Attributes() const {
|
||||||
utils::Vector<const ast::Attribute*, 3> attributes{ctx.dst->Group(0_a),
|
utils::Vector<const Attribute*, 3> attributes{ctx.dst->Group(0_a), ctx.dst->Binding(0_a)};
|
||||||
ctx.dst->Binding(0_a)};
|
attributes.Push(ctx.dst->Disable(DisabledValidation::kBindingPointCollision));
|
||||||
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision));
|
|
||||||
return attributes;
|
return attributes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,7 +102,7 @@ struct CombineSamplers::State {
|
||||||
/// @param sampler_var the sampler (global) variable
|
/// @param sampler_var the sampler (global) variable
|
||||||
/// @param name the default name to use (may be overridden by map lookup)
|
/// @param name the default name to use (may be overridden by map lookup)
|
||||||
/// @returns the newly-created global variable
|
/// @returns the newly-created global variable
|
||||||
const ast::Variable* CreateCombinedGlobal(const sem::Variable* texture_var,
|
const Variable* CreateCombinedGlobal(const sem::Variable* texture_var,
|
||||||
const sem::Variable* sampler_var,
|
const sem::Variable* sampler_var,
|
||||||
std::string name) {
|
std::string name) {
|
||||||
SamplerTexturePair bp_pair;
|
SamplerTexturePair bp_pair;
|
||||||
|
@ -115,7 +114,7 @@ struct CombineSamplers::State {
|
||||||
if (it != binding_info->binding_map.end()) {
|
if (it != binding_info->binding_map.end()) {
|
||||||
name = it->second;
|
name = it->second;
|
||||||
}
|
}
|
||||||
ast::Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
|
Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
|
||||||
Symbol symbol = ctx.dst->Symbols().New(name);
|
Symbol symbol = ctx.dst->Symbols().New(name);
|
||||||
return ctx.dst->GlobalVar(symbol, type, Attributes());
|
return ctx.dst->GlobalVar(symbol, type, Attributes());
|
||||||
}
|
}
|
||||||
|
@ -123,8 +122,8 @@ struct CombineSamplers::State {
|
||||||
/// Creates placeholder global sampler variables.
|
/// Creates placeholder global sampler variables.
|
||||||
/// @param kind the sampler kind to create for
|
/// @param kind the sampler kind to create for
|
||||||
/// @returns the newly-created global variable
|
/// @returns the newly-created global variable
|
||||||
const ast::Variable* CreatePlaceholder(type::SamplerKind kind) {
|
const Variable* CreatePlaceholder(type::SamplerKind kind) {
|
||||||
ast::Type type = ctx.dst->ty.sampler(kind);
|
Type type = ctx.dst->ty.sampler(kind);
|
||||||
const char* name = kind == type::SamplerKind::kComparisonSampler
|
const char* name = kind == type::SamplerKind::kComparisonSampler
|
||||||
? "placeholder_comparison_sampler"
|
? "placeholder_comparison_sampler"
|
||||||
: "placeholder_sampler";
|
: "placeholder_sampler";
|
||||||
|
@ -132,13 +131,13 @@ struct CombineSamplers::State {
|
||||||
return ctx.dst->GlobalVar(symbol, type, Attributes());
|
return ctx.dst->GlobalVar(symbol, type, Attributes());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates ast::Identifier for a given texture and sampler variable pair.
|
/// Creates Identifier for a given texture and sampler variable pair.
|
||||||
/// Depth textures with no samplers are turned into the corresponding
|
/// Depth textures with no samplers are turned into the corresponding
|
||||||
/// f32 texture (e.g., texture_depth_2d -> texture_2d<f32>).
|
/// f32 texture (e.g., texture_depth_2d -> texture_2d<f32>).
|
||||||
/// @param texture the texture variable of interest
|
/// @param texture the texture variable of interest
|
||||||
/// @param sampler the texture variable of interest
|
/// @param sampler the texture variable of interest
|
||||||
/// @returns the newly-created type
|
/// @returns the newly-created type
|
||||||
ast::Type CreateCombinedASTTypeFor(const sem::Variable* texture, const sem::Variable* sampler) {
|
Type CreateCombinedASTTypeFor(const sem::Variable* texture, const sem::Variable* sampler) {
|
||||||
const type::Type* texture_type = texture->Type()->UnwrapRef();
|
const type::Type* texture_type = texture->Type()->UnwrapRef();
|
||||||
const type::DepthTexture* depth = texture_type->As<type::DepthTexture>();
|
const type::DepthTexture* depth = texture_type->As<type::DepthTexture>();
|
||||||
if (depth && !sampler) {
|
if (depth && !sampler) {
|
||||||
|
@ -163,8 +162,7 @@ struct CombineSamplers::State {
|
||||||
ctx.Remove(ctx.src->AST().GlobalDeclarations(), global);
|
ctx.Remove(ctx.src->AST().GlobalDeclarations(), global);
|
||||||
} else if (auto binding_point = global_sem->BindingPoint()) {
|
} else if (auto binding_point = global_sem->BindingPoint()) {
|
||||||
if (binding_point->group == 0 && binding_point->binding == 0) {
|
if (binding_point->group == 0 && binding_point->binding == 0) {
|
||||||
auto* attribute =
|
auto* attribute = ctx.dst->Disable(DisabledValidation::kBindingPointCollision);
|
||||||
ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
|
|
||||||
ctx.InsertFront(global->attributes, attribute);
|
ctx.InsertFront(global->attributes, attribute);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -172,13 +170,13 @@ struct CombineSamplers::State {
|
||||||
|
|
||||||
// Rewrite all function signatures to use combined samplers, and remove
|
// Rewrite all function signatures to use combined samplers, and remove
|
||||||
// separate textures & samplers. Create new combined globals where found.
|
// separate textures & samplers. Create new combined globals where found.
|
||||||
ctx.ReplaceAll([&](const ast::Function* ast_fn) -> const ast::Function* {
|
ctx.ReplaceAll([&](const Function* ast_fn) -> const Function* {
|
||||||
if (auto* fn = sem.Get(ast_fn)) {
|
if (auto* fn = sem.Get(ast_fn)) {
|
||||||
auto pairs = fn->TextureSamplerPairs();
|
auto pairs = fn->TextureSamplerPairs();
|
||||||
if (pairs.IsEmpty()) {
|
if (pairs.IsEmpty()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
utils::Vector<const ast::Parameter*, 8> params;
|
utils::Vector<const Parameter*, 8> params;
|
||||||
for (auto pair : fn->TextureSamplerPairs()) {
|
for (auto pair : fn->TextureSamplerPairs()) {
|
||||||
const sem::Variable* texture_var = pair.first;
|
const sem::Variable* texture_var = pair.first;
|
||||||
const sem::Variable* sampler_var = pair.second;
|
const sem::Variable* sampler_var = pair.second;
|
||||||
|
@ -195,7 +193,7 @@ struct CombineSamplers::State {
|
||||||
} else {
|
} else {
|
||||||
// Either texture or sampler (or both) is a function parameter;
|
// Either texture or sampler (or both) is a function parameter;
|
||||||
// add a new function parameter to represent the combined sampler.
|
// add a new function parameter to represent the combined sampler.
|
||||||
ast::Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
|
Type type = CreateCombinedASTTypeFor(texture_var, sampler_var);
|
||||||
auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type);
|
auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type);
|
||||||
params.Push(var);
|
params.Push(var);
|
||||||
function_combined_texture_samplers_[fn][pair] = var;
|
function_combined_texture_samplers_[fn][pair] = var;
|
||||||
|
@ -215,7 +213,7 @@ struct CombineSamplers::State {
|
||||||
auto* body = ctx.Clone(ast_fn->body);
|
auto* body = ctx.Clone(ast_fn->body);
|
||||||
auto attributes = ctx.Clone(ast_fn->attributes);
|
auto attributes = ctx.Clone(ast_fn->attributes);
|
||||||
auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes);
|
auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes);
|
||||||
return ctx.dst->create<ast::Function>(name, params, return_type, body,
|
return ctx.dst->create<Function>(name, params, return_type, body,
|
||||||
std::move(attributes),
|
std::move(attributes),
|
||||||
std::move(return_type_attributes));
|
std::move(return_type_attributes));
|
||||||
}
|
}
|
||||||
|
@ -225,9 +223,9 @@ struct CombineSamplers::State {
|
||||||
// Replace all function call expressions containing texture or
|
// Replace all function call expressions containing texture or
|
||||||
// sampler parameters to use the current function's combined samplers or
|
// sampler parameters to use the current function's combined samplers or
|
||||||
// the combined global samplers, as appropriate.
|
// the combined global samplers, as appropriate.
|
||||||
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::Expression* {
|
ctx.ReplaceAll([&](const CallExpression* expr) -> const Expression* {
|
||||||
if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
|
if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
|
||||||
utils::Vector<const ast::Expression*, 8> args;
|
utils::Vector<const Expression*, 8> args;
|
||||||
// Replace all texture builtin calls.
|
// Replace all texture builtin calls.
|
||||||
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
|
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
|
||||||
const auto& signature = builtin->Signature();
|
const auto& signature = builtin->Signature();
|
||||||
|
@ -254,7 +252,7 @@ struct CombineSamplers::State {
|
||||||
for (auto* arg : expr->args) {
|
for (auto* arg : expr->args) {
|
||||||
auto* type = ctx.src->TypeOf(arg)->UnwrapRef();
|
auto* type = ctx.src->TypeOf(arg)->UnwrapRef();
|
||||||
if (type->Is<type::Texture>()) {
|
if (type->Is<type::Texture>()) {
|
||||||
const ast::Variable* var =
|
const Variable* var =
|
||||||
IsGlobal(new_pair)
|
IsGlobal(new_pair)
|
||||||
? global_combined_texture_samplers_[new_pair]
|
? global_combined_texture_samplers_[new_pair]
|
||||||
: function_combined_texture_samplers_[call->Stmt()->Function()]
|
: function_combined_texture_samplers_[call->Stmt()->Function()]
|
||||||
|
@ -263,7 +261,7 @@ struct CombineSamplers::State {
|
||||||
} else if (auto* sampler_type = type->As<type::Sampler>()) {
|
} else if (auto* sampler_type = type->As<type::Sampler>()) {
|
||||||
type::SamplerKind kind = sampler_type->kind();
|
type::SamplerKind kind = sampler_type->kind();
|
||||||
int index = (kind == type::SamplerKind::kSampler) ? 0 : 1;
|
int index = (kind == type::SamplerKind::kSampler) ? 0 : 1;
|
||||||
const ast::Variable*& p = placeholder_samplers_[index];
|
const Variable*& p = placeholder_samplers_[index];
|
||||||
if (!p) {
|
if (!p) {
|
||||||
p = CreatePlaceholder(kind);
|
p = CreatePlaceholder(kind);
|
||||||
}
|
}
|
||||||
|
@ -272,10 +270,10 @@ struct CombineSamplers::State {
|
||||||
args.Push(ctx.Clone(arg));
|
args.Push(ctx.Clone(arg));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const ast::Expression* value = ctx.dst->Call(ctx.Clone(expr->target), args);
|
const Expression* value = ctx.dst->Call(ctx.Clone(expr->target), args);
|
||||||
if (builtin->Type() == builtin::Function::kTextureLoad &&
|
if (builtin->Type() == builtin::Function::kTextureLoad &&
|
||||||
texture_var->Type()->UnwrapRef()->Is<type::DepthTexture>() &&
|
texture_var->Type()->UnwrapRef()->Is<type::DepthTexture>() &&
|
||||||
!call->Stmt()->Declaration()->Is<ast::CallStatement>()) {
|
!call->Stmt()->Declaration()->Is<CallStatement>()) {
|
||||||
value = ctx.dst->MemberAccessor(value, "x");
|
value = ctx.dst->MemberAccessor(value, "x");
|
||||||
}
|
}
|
||||||
return value;
|
return value;
|
||||||
|
@ -307,7 +305,7 @@ struct CombineSamplers::State {
|
||||||
// If both texture and sampler are (now) global, pass that
|
// If both texture and sampler are (now) global, pass that
|
||||||
// global variable to the callee. Otherwise use the caller's
|
// global variable to the callee. Otherwise use the caller's
|
||||||
// function parameter for this pair.
|
// function parameter for this pair.
|
||||||
const ast::Variable* var =
|
const Variable* var =
|
||||||
IsGlobal(new_pair)
|
IsGlobal(new_pair)
|
||||||
? global_combined_texture_samplers_[new_pair]
|
? global_combined_texture_samplers_[new_pair]
|
||||||
: function_combined_texture_samplers_[call->Stmt()->Function()]
|
: function_combined_texture_samplers_[call->Stmt()->Function()]
|
||||||
|
|
|
@ -60,21 +60,21 @@ bool ShouldRun(const Program* program) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Offset is a simple ast::Expression builder interface, used to build byte
|
/// Offset is a simple Expression builder interface, used to build byte
|
||||||
/// offsets for storage and uniform buffer accesses.
|
/// offsets for storage and uniform buffer accesses.
|
||||||
struct Offset : utils::Castable<Offset> {
|
struct Offset : utils::Castable<Offset> {
|
||||||
/// @returns builds and returns the ast::Expression in `ctx.dst`
|
/// @returns builds and returns the Expression in `ctx.dst`
|
||||||
virtual const ast::Expression* Build(CloneContext& ctx) const = 0;
|
virtual const Expression* Build(CloneContext& ctx) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// OffsetExpr is an implementation of Offset that clones and casts the given
|
/// OffsetExpr is an implementation of Offset that clones and casts the given
|
||||||
/// expression to `u32`.
|
/// expression to `u32`.
|
||||||
struct OffsetExpr : Offset {
|
struct OffsetExpr : Offset {
|
||||||
const ast::Expression* const expr = nullptr;
|
const Expression* const expr = nullptr;
|
||||||
|
|
||||||
explicit OffsetExpr(const ast::Expression* e) : expr(e) {}
|
explicit OffsetExpr(const Expression* e) : expr(e) {}
|
||||||
|
|
||||||
const ast::Expression* Build(CloneContext& ctx) const override {
|
const Expression* Build(CloneContext& ctx) const override {
|
||||||
auto* type = ctx.src->Sem().GetVal(expr)->Type()->UnwrapRef();
|
auto* type = ctx.src->Sem().GetVal(expr)->Type()->UnwrapRef();
|
||||||
auto* res = ctx.Clone(expr);
|
auto* res = ctx.Clone(expr);
|
||||||
if (!type->Is<type::U32>()) {
|
if (!type->Is<type::U32>()) {
|
||||||
|
@ -91,7 +91,7 @@ struct OffsetLiteral final : utils::Castable<OffsetLiteral, Offset> {
|
||||||
|
|
||||||
explicit OffsetLiteral(uint32_t lit) : literal(lit) {}
|
explicit OffsetLiteral(uint32_t lit) : literal(lit) {}
|
||||||
|
|
||||||
const ast::Expression* Build(CloneContext& ctx) const override {
|
const Expression* Build(CloneContext& ctx) const override {
|
||||||
return ctx.dst->Expr(u32(literal));
|
return ctx.dst->Expr(u32(literal));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -99,12 +99,12 @@ struct OffsetLiteral final : utils::Castable<OffsetLiteral, Offset> {
|
||||||
/// OffsetBinOp is an implementation of Offset that constructs a binary-op of
|
/// OffsetBinOp is an implementation of Offset that constructs a binary-op of
|
||||||
/// two Offsets.
|
/// two Offsets.
|
||||||
struct OffsetBinOp : Offset {
|
struct OffsetBinOp : Offset {
|
||||||
ast::BinaryOp op;
|
BinaryOp op;
|
||||||
Offset const* lhs = nullptr;
|
Offset const* lhs = nullptr;
|
||||||
Offset const* rhs = nullptr;
|
Offset const* rhs = nullptr;
|
||||||
|
|
||||||
const ast::Expression* Build(CloneContext& ctx) const override {
|
const Expression* Build(CloneContext& ctx) const override {
|
||||||
return ctx.dst->create<ast::BinaryExpression>(op, lhs->Build(ctx), rhs->Build(ctx));
|
return ctx.dst->create<BinaryExpression>(op, lhs->Build(ctx), rhs->Build(ctx));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -313,7 +313,7 @@ struct BufferAccess {
|
||||||
|
|
||||||
/// Store describes a single storage or uniform buffer write
|
/// Store describes a single storage or uniform buffer write
|
||||||
struct Store {
|
struct Store {
|
||||||
const ast::AssignmentStatement* assignment; // The AST assignment statement
|
const AssignmentStatement* assignment; // The AST assignment statement
|
||||||
BufferAccess target; // The target for the write
|
BufferAccess target; // The target for the write
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -330,9 +330,9 @@ struct DecomposeMemoryAccess::State {
|
||||||
/// expressions chain the access.
|
/// expressions chain the access.
|
||||||
/// Subset of #expression_order, as expressions are not removed from
|
/// Subset of #expression_order, as expressions are not removed from
|
||||||
/// #expression_order.
|
/// #expression_order.
|
||||||
std::unordered_map<const ast::Expression*, BufferAccess> accesses;
|
std::unordered_map<const Expression*, BufferAccess> accesses;
|
||||||
/// The visited order of AST expressions (superset of #accesses)
|
/// The visited order of AST expressions (superset of #accesses)
|
||||||
std::vector<const ast::Expression*> expression_order;
|
std::vector<const Expression*> expression_order;
|
||||||
/// [buffer-type, element-type] -> load function name
|
/// [buffer-type, element-type] -> load function name
|
||||||
std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs;
|
std::unordered_map<LoadStoreKey, Symbol, LoadStoreKey::Hasher> load_funcs;
|
||||||
/// [buffer-type, element-type] -> store function name
|
/// [buffer-type, element-type] -> store function name
|
||||||
|
@ -353,9 +353,9 @@ struct DecomposeMemoryAccess::State {
|
||||||
const Offset* ToOffset(uint32_t offset) { return offsets_.Create<OffsetLiteral>(offset); }
|
const Offset* ToOffset(uint32_t offset) { return offsets_.Create<OffsetLiteral>(offset); }
|
||||||
|
|
||||||
/// @param expr the expression to convert to an Offset
|
/// @param expr the expression to convert to an Offset
|
||||||
/// @returns an Offset for the given ast::Expression
|
/// @returns an Offset for the given Expression
|
||||||
const Offset* ToOffset(const ast::Expression* expr) {
|
const Offset* ToOffset(const Expression* expr) {
|
||||||
if (auto* lit = expr->As<ast::IntLiteralExpression>()) {
|
if (auto* lit = expr->As<IntLiteralExpression>()) {
|
||||||
if (lit->value >= 0) {
|
if (lit->value >= 0) {
|
||||||
return offsets_.Create<OffsetLiteral>(static_cast<uint32_t>(lit->value));
|
return offsets_.Create<OffsetLiteral>(static_cast<uint32_t>(lit->value));
|
||||||
}
|
}
|
||||||
|
@ -390,7 +390,7 @@ struct DecomposeMemoryAccess::State {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto* out = offsets_.Create<OffsetBinOp>();
|
auto* out = offsets_.Create<OffsetBinOp>();
|
||||||
out->op = ast::BinaryOp::kAdd;
|
out->op = BinaryOp::kAdd;
|
||||||
out->lhs = lhs;
|
out->lhs = lhs;
|
||||||
out->rhs = rhs;
|
out->rhs = rhs;
|
||||||
return out;
|
return out;
|
||||||
|
@ -422,7 +422,7 @@ struct DecomposeMemoryAccess::State {
|
||||||
return offsets_.Create<OffsetLiteral>(lhs_lit->literal * rhs_lit->literal);
|
return offsets_.Create<OffsetLiteral>(lhs_lit->literal * rhs_lit->literal);
|
||||||
}
|
}
|
||||||
auto* out = offsets_.Create<OffsetBinOp>();
|
auto* out = offsets_.Create<OffsetBinOp>();
|
||||||
out->op = ast::BinaryOp::kMultiply;
|
out->op = BinaryOp::kMultiply;
|
||||||
out->lhs = lhs;
|
out->lhs = lhs;
|
||||||
out->rhs = rhs;
|
out->rhs = rhs;
|
||||||
return out;
|
return out;
|
||||||
|
@ -432,7 +432,7 @@ struct DecomposeMemoryAccess::State {
|
||||||
/// to #expression_order.
|
/// to #expression_order.
|
||||||
/// @param expr the expression that performs the access
|
/// @param expr the expression that performs the access
|
||||||
/// @param access the access
|
/// @param access the access
|
||||||
void AddAccess(const ast::Expression* expr, const BufferAccess& access) {
|
void AddAccess(const Expression* expr, const BufferAccess& access) {
|
||||||
TINT_ASSERT(Transform, access.type);
|
TINT_ASSERT(Transform, access.type);
|
||||||
accesses.emplace(expr, access);
|
accesses.emplace(expr, access);
|
||||||
expression_order.emplace_back(expr);
|
expression_order.emplace_back(expr);
|
||||||
|
@ -443,7 +443,7 @@ struct DecomposeMemoryAccess::State {
|
||||||
/// `node`, an invalid BufferAccess is returned.
|
/// `node`, an invalid BufferAccess is returned.
|
||||||
/// @param node the expression that performed an access
|
/// @param node the expression that performed an access
|
||||||
/// @return the BufferAccess for the given expression
|
/// @return the BufferAccess for the given expression
|
||||||
BufferAccess TakeAccess(const ast::Expression* node) {
|
BufferAccess TakeAccess(const Expression* node) {
|
||||||
auto lhs_it = accesses.find(node);
|
auto lhs_it = accesses.find(node);
|
||||||
if (lhs_it == accesses.end()) {
|
if (lhs_it == accesses.end()) {
|
||||||
return {};
|
return {};
|
||||||
|
@ -475,7 +475,7 @@ struct DecomposeMemoryAccess::State {
|
||||||
b.Func(name, params, el_ast_ty, nullptr,
|
b.Func(name, params, el_ast_ty, nullptr,
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
intrinsic,
|
intrinsic,
|
||||||
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
|
b.Disable(DisabledValidation::kFunctionHasNoBody),
|
||||||
});
|
});
|
||||||
} else if (auto* arr_ty = el_ty->As<type::Array>()) {
|
} else if (auto* arr_ty = el_ty->As<type::Array>()) {
|
||||||
// fn load_func(buffer : buf_ty, offset : u32) -> array<T, N> {
|
// fn load_func(buffer : buf_ty, offset : u32) -> array<T, N> {
|
||||||
|
@ -498,8 +498,8 @@ struct DecomposeMemoryAccess::State {
|
||||||
TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count";
|
TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count";
|
||||||
arr_cnt = 1;
|
arr_cnt = 1;
|
||||||
}
|
}
|
||||||
auto* for_cond = b.create<ast::BinaryExpression>(
|
auto* for_cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr(i),
|
||||||
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_cnt.value())));
|
b.Expr(u32(arr_cnt.value())));
|
||||||
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
|
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
|
||||||
auto* arr_el = b.IndexAccessor(arr, i);
|
auto* arr_el = b.IndexAccessor(arr, i);
|
||||||
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
|
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
|
||||||
|
@ -514,7 +514,7 @@ struct DecomposeMemoryAccess::State {
|
||||||
b.Return(arr),
|
b.Return(arr),
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
utils::Vector<const ast::Expression*, 8> values;
|
utils::Vector<const Expression*, 8> values;
|
||||||
if (auto* mat_ty = el_ty->As<type::Matrix>()) {
|
if (auto* mat_ty = el_ty->As<type::Matrix>()) {
|
||||||
auto* vec_ty = mat_ty->ColumnType();
|
auto* vec_ty = mat_ty->ColumnType();
|
||||||
Symbol load = LoadFunc(vec_ty, address_space, buffer);
|
Symbol load = LoadFunc(vec_ty, address_space, buffer);
|
||||||
|
@ -557,10 +557,10 @@ struct DecomposeMemoryAccess::State {
|
||||||
b.Func(name, params, b.ty.void_(), nullptr,
|
b.Func(name, params, b.ty.void_(), nullptr,
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
intrinsic,
|
intrinsic,
|
||||||
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
|
b.Disable(DisabledValidation::kFunctionHasNoBody),
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
auto body = Switch<utils::Vector<const ast::Statement*, 8>>(
|
auto body = Switch<utils::Vector<const Statement*, 8>>(
|
||||||
el_ty, //
|
el_ty, //
|
||||||
[&](const type::Array* arr_ty) {
|
[&](const type::Array* arr_ty) {
|
||||||
// fn store_func(buffer : buf_ty, offset : u32, value : el_ty) {
|
// fn store_func(buffer : buf_ty, offset : u32, value : el_ty) {
|
||||||
|
@ -585,8 +585,8 @@ struct DecomposeMemoryAccess::State {
|
||||||
<< "unexpected non-constant array count";
|
<< "unexpected non-constant array count";
|
||||||
arr_cnt = 1;
|
arr_cnt = 1;
|
||||||
}
|
}
|
||||||
auto* for_cond = b.create<ast::BinaryExpression>(
|
auto* for_cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr(i),
|
||||||
ast::BinaryOp::kLessThan, b.Expr(i), b.Expr(u32(arr_cnt.value())));
|
b.Expr(u32(arr_cnt.value())));
|
||||||
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
|
auto* for_cont = b.Assign(i, b.Add(i, 1_u));
|
||||||
auto* arr_el = b.IndexAccessor(array, i);
|
auto* arr_el = b.IndexAccessor(array, i);
|
||||||
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
|
auto* el_offset = b.Add(b.Expr("offset"), b.Mul(i, u32(arr_ty->Stride())));
|
||||||
|
@ -598,7 +598,7 @@ struct DecomposeMemoryAccess::State {
|
||||||
[&](const type::Matrix* mat_ty) {
|
[&](const type::Matrix* mat_ty) {
|
||||||
auto* vec_ty = mat_ty->ColumnType();
|
auto* vec_ty = mat_ty->ColumnType();
|
||||||
Symbol store = StoreFunc(vec_ty, buffer);
|
Symbol store = StoreFunc(vec_ty, buffer);
|
||||||
utils::Vector<const ast::Statement*, 4> stmts;
|
utils::Vector<const Statement*, 4> stmts;
|
||||||
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
|
for (uint32_t i = 0; i < mat_ty->columns(); i++) {
|
||||||
auto* offset = b.Add("offset", u32(i * mat_ty->ColumnStride()));
|
auto* offset = b.Add("offset", u32(i * mat_ty->ColumnStride()));
|
||||||
auto* element = b.IndexAccessor("value", u32(i));
|
auto* element = b.IndexAccessor("value", u32(i));
|
||||||
|
@ -608,7 +608,7 @@ struct DecomposeMemoryAccess::State {
|
||||||
return stmts;
|
return stmts;
|
||||||
},
|
},
|
||||||
[&](const type::Struct* str) {
|
[&](const type::Struct* str) {
|
||||||
utils::Vector<const ast::Statement*, 8> stmts;
|
utils::Vector<const Statement*, 8> stmts;
|
||||||
for (auto* member : str->Members()) {
|
for (auto* member : str->Members()) {
|
||||||
auto* offset = b.Add("offset", u32(member->Offset()));
|
auto* offset = b.Add("offset", u32(member->Offset()));
|
||||||
auto* element = b.MemberAccessor("value", ctx.Clone(member->Name()));
|
auto* element = b.MemberAccessor("value", ctx.Clone(member->Name()));
|
||||||
|
@ -656,14 +656,14 @@ struct DecomposeMemoryAccess::State {
|
||||||
<< el_ty->TypeInfo().name;
|
<< el_ty->TypeInfo().name;
|
||||||
}
|
}
|
||||||
|
|
||||||
ast::Type ret_ty;
|
Type ret_ty;
|
||||||
|
|
||||||
// For intrinsics that return a struct, there is no AST node for it, so create one now.
|
// For intrinsics that return a struct, there is no AST node for it, so create one now.
|
||||||
if (intrinsic->Type() == builtin::Function::kAtomicCompareExchangeWeak) {
|
if (intrinsic->Type() == builtin::Function::kAtomicCompareExchangeWeak) {
|
||||||
auto* str = intrinsic->ReturnType()->As<type::Struct>();
|
auto* str = intrinsic->ReturnType()->As<type::Struct>();
|
||||||
TINT_ASSERT(Transform, str);
|
TINT_ASSERT(Transform, str);
|
||||||
|
|
||||||
utils::Vector<const ast::StructMember*, 8> ast_members;
|
utils::Vector<const StructMember*, 8> ast_members;
|
||||||
ast_members.Reserve(str->Members().Length());
|
ast_members.Reserve(str->Members().Length());
|
||||||
for (auto& m : str->Members()) {
|
for (auto& m : str->Members()) {
|
||||||
ast_members.Push(
|
ast_members.Push(
|
||||||
|
@ -681,7 +681,7 @@ struct DecomposeMemoryAccess::State {
|
||||||
b.Func(name, std::move(params), ret_ty, nullptr,
|
b.Func(name, std::move(params), ret_ty, nullptr,
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
atomic,
|
atomic,
|
||||||
b.Disable(ast::DisabledValidation::kFunctionHasNoBody),
|
b.Disable(DisabledValidation::kFunctionHasNoBody),
|
||||||
});
|
});
|
||||||
return name;
|
return name;
|
||||||
});
|
});
|
||||||
|
@ -689,11 +689,11 @@ struct DecomposeMemoryAccess::State {
|
||||||
};
|
};
|
||||||
|
|
||||||
DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID pid,
|
DecomposeMemoryAccess::Intrinsic::Intrinsic(ProgramID pid,
|
||||||
ast::NodeID nid,
|
NodeID nid,
|
||||||
Op o,
|
Op o,
|
||||||
DataType ty,
|
DataType ty,
|
||||||
builtin::AddressSpace as,
|
builtin::AddressSpace as,
|
||||||
const ast::IdentifierExpression* buf)
|
const IdentifierExpression* buf)
|
||||||
: Base(pid, nid, utils::Vector{buf}), op(o), type(ty), address_space(as) {}
|
: Base(pid, nid, utils::Vector{buf}), op(o), type(ty), address_space(as) {}
|
||||||
DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default;
|
DecomposeMemoryAccess::Intrinsic::~Intrinsic() = default;
|
||||||
std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
|
std::string DecomposeMemoryAccess::Intrinsic::InternalName() const {
|
||||||
|
@ -804,7 +804,7 @@ bool DecomposeMemoryAccess::Intrinsic::IsAtomic() const {
|
||||||
return op != Op::kLoad && op != Op::kStore;
|
return op != Op::kLoad && op != Op::kStore;
|
||||||
}
|
}
|
||||||
|
|
||||||
const ast::IdentifierExpression* DecomposeMemoryAccess::Intrinsic::Buffer() const {
|
const IdentifierExpression* DecomposeMemoryAccess::Intrinsic::Buffer() const {
|
||||||
return dependencies[0];
|
return dependencies[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -832,7 +832,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
|
||||||
// nodes are fully immutable and require their children to be constructed
|
// nodes are fully immutable and require their children to be constructed
|
||||||
// first so their pointer can be passed to the parent's initializer.
|
// first so their pointer can be passed to the parent's initializer.
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
if (auto* ident = node->As<ast::IdentifierExpression>()) {
|
if (auto* ident = node->As<IdentifierExpression>()) {
|
||||||
// X
|
// X
|
||||||
if (auto* sem_ident = sem.GetVal(ident)) {
|
if (auto* sem_ident = sem.GetVal(ident)) {
|
||||||
if (auto* user = sem_ident->UnwrapLoad()->As<sem::VariableUser>()) {
|
if (auto* user = sem_ident->UnwrapLoad()->As<sem::VariableUser>()) {
|
||||||
|
@ -852,7 +852,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* accessor = node->As<ast::MemberAccessorExpression>()) {
|
if (auto* accessor = node->As<MemberAccessorExpression>()) {
|
||||||
// X.Y
|
// X.Y
|
||||||
auto* accessor_sem = sem.Get(accessor)->UnwrapLoad();
|
auto* accessor_sem = sem.Get(accessor)->UnwrapLoad();
|
||||||
if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
|
if (auto* swizzle = accessor_sem->As<sem::Swizzle>()) {
|
||||||
|
@ -882,7 +882,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* accessor = node->As<ast::IndexAccessorExpression>()) {
|
if (auto* accessor = node->As<IndexAccessorExpression>()) {
|
||||||
if (auto access = state.TakeAccess(accessor->object)) {
|
if (auto access = state.TakeAccess(accessor->object)) {
|
||||||
// X[Y]
|
// X[Y]
|
||||||
if (auto* arr = access.type->As<type::Array>()) {
|
if (auto* arr = access.type->As<type::Array>()) {
|
||||||
|
@ -915,8 +915,8 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* op = node->As<ast::UnaryOpExpression>()) {
|
if (auto* op = node->As<UnaryOpExpression>()) {
|
||||||
if (op->op == ast::UnaryOp::kAddressOf) {
|
if (op->op == UnaryOp::kAddressOf) {
|
||||||
// &X
|
// &X
|
||||||
if (auto access = state.TakeAccess(op->expr)) {
|
if (auto access = state.TakeAccess(op->expr)) {
|
||||||
// HLSL does not support pointers, so just take the access from the
|
// HLSL does not support pointers, so just take the access from the
|
||||||
|
@ -927,7 +927,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* assign = node->As<ast::AssignmentStatement>()) {
|
if (auto* assign = node->As<AssignmentStatement>()) {
|
||||||
// X = Y
|
// X = Y
|
||||||
// Move the LHS access to a store.
|
// Move the LHS access to a store.
|
||||||
if (auto lhs = state.TakeAccess(assign->lhs)) {
|
if (auto lhs = state.TakeAccess(assign->lhs)) {
|
||||||
|
@ -935,7 +935,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* call_expr = node->As<ast::CallExpression>()) {
|
if (auto* call_expr = node->As<CallExpression>()) {
|
||||||
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
|
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||||
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
|
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
|
||||||
if (builtin->Type() == builtin::Function::kArrayLength) {
|
if (builtin->Type() == builtin::Function::kArrayLength) {
|
||||||
|
@ -953,7 +953,7 @@ Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
|
||||||
auto buffer = ctx.Clone(access.var->Declaration()->name->symbol);
|
auto buffer = ctx.Clone(access.var->Declaration()->name->symbol);
|
||||||
Symbol func = state.AtomicFunc(el_ty, builtin, buffer);
|
Symbol func = state.AtomicFunc(el_ty, builtin, buffer);
|
||||||
|
|
||||||
utils::Vector<const ast::Expression*, 8> args{offset};
|
utils::Vector<const Expression*, 8> args{offset};
|
||||||
for (size_t i = 1; i < call_expr->args.Length(); i++) {
|
for (size_t i = 1; i < call_expr->args.Length(); i++) {
|
||||||
auto* arg = call_expr->args[i];
|
auto* arg = call_expr->args[i];
|
||||||
args.Push(ctx.Clone(arg));
|
args.Push(ctx.Clone(arg));
|
||||||
|
|
|
@ -35,7 +35,7 @@ class DecomposeMemoryAccess final : public utils::Castable<DecomposeMemoryAccess
|
||||||
/// transforms this into calls to
|
/// transforms this into calls to
|
||||||
/// `[RW]ByteAddressBuffer.Load[N]()` or `[RW]ByteAddressBuffer.Store[N]()`,
|
/// `[RW]ByteAddressBuffer.Load[N]()` or `[RW]ByteAddressBuffer.Store[N]()`,
|
||||||
/// with a possible cast.
|
/// with a possible cast.
|
||||||
class Intrinsic final : public utils::Castable<Intrinsic, ast::InternalAttribute> {
|
class Intrinsic final : public utils::Castable<Intrinsic, InternalAttribute> {
|
||||||
public:
|
public:
|
||||||
/// Intrinsic op
|
/// Intrinsic op
|
||||||
enum class Op {
|
enum class Op {
|
||||||
|
@ -82,11 +82,11 @@ class DecomposeMemoryAccess final : public utils::Castable<DecomposeMemoryAccess
|
||||||
/// @param address_space the address space of the buffer
|
/// @param address_space the address space of the buffer
|
||||||
/// @param buffer the storage or uniform buffer identifier
|
/// @param buffer the storage or uniform buffer identifier
|
||||||
Intrinsic(ProgramID pid,
|
Intrinsic(ProgramID pid,
|
||||||
ast::NodeID nid,
|
NodeID nid,
|
||||||
Op o,
|
Op o,
|
||||||
DataType type,
|
DataType type,
|
||||||
builtin::AddressSpace address_space,
|
builtin::AddressSpace address_space,
|
||||||
const ast::IdentifierExpression* buffer);
|
const IdentifierExpression* buffer);
|
||||||
/// Destructor
|
/// Destructor
|
||||||
~Intrinsic() override;
|
~Intrinsic() override;
|
||||||
|
|
||||||
|
@ -103,7 +103,7 @@ class DecomposeMemoryAccess final : public utils::Castable<DecomposeMemoryAccess
|
||||||
bool IsAtomic() const;
|
bool IsAtomic() const;
|
||||||
|
|
||||||
/// @return the buffer that this intrinsic operates on
|
/// @return the buffer that this intrinsic operates on
|
||||||
const ast::IdentifierExpression* Buffer() const;
|
const IdentifierExpression* Buffer() const;
|
||||||
|
|
||||||
/// The op of the intrinsic
|
/// The op of the intrinsic
|
||||||
const Op op;
|
const Op op;
|
||||||
|
|
|
@ -37,8 +37,8 @@ using DecomposedArrays = std::unordered_map<const type::Array*, Symbol>;
|
||||||
|
|
||||||
bool ShouldRun(const Program* program) {
|
bool ShouldRun(const Program* program) {
|
||||||
for (auto* node : program->ASTNodes().Objects()) {
|
for (auto* node : program->ASTNodes().Objects()) {
|
||||||
if (auto* ident = node->As<ast::TemplatedIdentifier>()) {
|
if (auto* ident = node->As<TemplatedIdentifier>()) {
|
||||||
if (ast::GetAttribute<ast::StrideAttribute>(ident->attributes)) {
|
if (GetAttribute<StrideAttribute>(ident->attributes)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -74,8 +74,8 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
|
||||||
// stride for the array element type, then replace the array element type with
|
// stride for the array element type, then replace the array element type with
|
||||||
// a structure, holding a single field with a @size attribute equal to the
|
// a structure, holding a single field with a @size attribute equal to the
|
||||||
// array stride.
|
// array stride.
|
||||||
ctx.ReplaceAll([&](const ast::IdentifierExpression* expr) -> const ast::IdentifierExpression* {
|
ctx.ReplaceAll([&](const IdentifierExpression* expr) -> const IdentifierExpression* {
|
||||||
auto* ident = expr->identifier->As<ast::TemplatedIdentifier>();
|
auto* ident = expr->identifier->As<TemplatedIdentifier>();
|
||||||
if (!ident) {
|
if (!ident) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -90,8 +90,8 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
|
||||||
if (!arr->IsStrideImplicit()) {
|
if (!arr->IsStrideImplicit()) {
|
||||||
auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
|
auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
|
||||||
auto name = b.Symbols().New("strided_arr");
|
auto name = b.Symbols().New("strided_arr");
|
||||||
auto* member_ty = ctx.Clone(ident->arguments[0]->As<ast::IdentifierExpression>());
|
auto* member_ty = ctx.Clone(ident->arguments[0]->As<IdentifierExpression>());
|
||||||
auto* member = b.Member(kMemberName, ast::Type{member_ty},
|
auto* member = b.Member(kMemberName, Type{member_ty},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.MemberSize(AInt(arr->Stride())),
|
b.MemberSize(AInt(arr->Stride())),
|
||||||
});
|
});
|
||||||
|
@ -105,14 +105,14 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
|
||||||
return b.Expr(b.ty.array(b.ty(el_ty)));
|
return b.Expr(b.ty.array(b.ty(el_ty)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (ast::GetAttribute<ast::StrideAttribute>(ident->attributes)) {
|
if (GetAttribute<StrideAttribute>(ident->attributes)) {
|
||||||
// Strip the @stride attribute
|
// Strip the @stride attribute
|
||||||
auto* ty = ctx.Clone(ident->arguments[0]->As<ast::IdentifierExpression>());
|
auto* ty = ctx.Clone(ident->arguments[0]->As<IdentifierExpression>());
|
||||||
if (ident->arguments.Length() > 1) {
|
if (ident->arguments.Length() > 1) {
|
||||||
auto* count = ctx.Clone(ident->arguments[1]);
|
auto* count = ctx.Clone(ident->arguments[1]);
|
||||||
return b.Expr(b.ty.array(ast::Type{ty}, count));
|
return b.Expr(b.ty.array(Type{ty}, count));
|
||||||
} else {
|
} else {
|
||||||
return b.Expr(b.ty.array(ast::Type{ty}));
|
return b.Expr(b.ty.array(Type{ty}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -122,7 +122,7 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
|
||||||
// element changed to a single field structure. These expressions are adjusted
|
// element changed to a single field structure. These expressions are adjusted
|
||||||
// to insert an additional member accessor for the single structure field.
|
// to insert an additional member accessor for the single structure field.
|
||||||
// Example: `arr[i]` -> `arr[i].el`
|
// Example: `arr[i]` -> `arr[i].el`
|
||||||
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* {
|
ctx.ReplaceAll([&](const IndexAccessorExpression* idx) -> const Expression* {
|
||||||
if (auto* ty = src->TypeOf(idx->object)) {
|
if (auto* ty = src->TypeOf(idx->object)) {
|
||||||
if (auto* arr = ty->UnwrapRef()->As<type::Array>()) {
|
if (auto* arr = ty->UnwrapRef()->As<type::Array>()) {
|
||||||
if (!arr->IsStrideImplicit()) {
|
if (!arr->IsStrideImplicit()) {
|
||||||
|
@ -140,7 +140,7 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
|
||||||
// `@stride(32) array<i32, 3>(1, 2, 3)`
|
// `@stride(32) array<i32, 3>(1, 2, 3)`
|
||||||
// ->
|
// ->
|
||||||
// `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))`
|
// `array<strided_arr, 3>(strided_arr(1), strided_arr(2), strided_arr(3))`
|
||||||
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::Expression* {
|
ctx.ReplaceAll([&](const CallExpression* expr) -> const Expression* {
|
||||||
if (!expr->args.IsEmpty()) {
|
if (!expr->args.IsEmpty()) {
|
||||||
if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
|
if (auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>()) {
|
||||||
if (auto* ctor = call->Target()->As<sem::ValueConstructor>()) {
|
if (auto* ctor = call->Target()->As<sem::ValueConstructor>()) {
|
||||||
|
@ -153,7 +153,7 @@ Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
|
||||||
|
|
||||||
auto* target = ctx.Clone(expr->target);
|
auto* target = ctx.Clone(expr->target);
|
||||||
|
|
||||||
utils::Vector<const ast::Expression*, 8> args;
|
utils::Vector<const Expression*, 8> args;
|
||||||
if (auto it = decomposed.find(arr); it != decomposed.end()) {
|
if (auto it = decomposed.find(arr); it != decomposed.end()) {
|
||||||
args.Reserve(expr->args.Length());
|
args.Reserve(expr->args.Length());
|
||||||
for (auto* arg : expr->args) {
|
for (auto* arg : expr->args) {
|
||||||
|
|
|
@ -101,7 +101,7 @@ TEST_F(DecomposeStridedArrayTest, PrivateDefaultStridedArray) {
|
||||||
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))),
|
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -145,7 +145,7 @@ TEST_F(DecomposeStridedArrayTest, PrivateStridedArray) {
|
||||||
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))),
|
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor("arr", 1_i))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -195,7 +195,7 @@ TEST_F(DecomposeStridedArrayTest, ReadUniformStridedArray) {
|
||||||
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
|
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -253,7 +253,7 @@ TEST_F(DecomposeStridedArrayTest, ReadUniformDefaultStridedArray) {
|
||||||
b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 2_i))),
|
b.IndexAccessor(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 2_i))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -303,7 +303,7 @@ TEST_F(DecomposeStridedArrayTest, ReadStorageStridedArray) {
|
||||||
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
|
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -357,7 +357,7 @@ TEST_F(DecomposeStridedArrayTest, ReadStorageDefaultStridedArray) {
|
||||||
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
|
b.Decl(b.Let("b", b.ty.f32(), b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -410,7 +410,7 @@ TEST_F(DecomposeStridedArrayTest, WriteStorageStridedArray) {
|
||||||
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
|
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -472,7 +472,7 @@ TEST_F(DecomposeStridedArrayTest, WriteStorageDefaultStridedArray) {
|
||||||
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
|
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -531,7 +531,7 @@ TEST_F(DecomposeStridedArrayTest, ReadWriteViaPointerLets) {
|
||||||
b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), 5_f),
|
b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), 5_f),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -593,7 +593,7 @@ TEST_F(DecomposeStridedArrayTest, PrivateAliasedStridedArray) {
|
||||||
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
|
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "a"), 1_i), 5_f),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -696,7 +696,7 @@ TEST_F(DecomposeStridedArrayTest, PrivateNestedStridedArray) {
|
||||||
5_f),
|
5_f),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ struct MatrixInfo {
|
||||||
const type::Matrix* matrix = nullptr;
|
const type::Matrix* matrix = nullptr;
|
||||||
|
|
||||||
/// @returns the identifier of an array that holds an vector column for each row of the matrix.
|
/// @returns the identifier of an array that holds an vector column for each row of the matrix.
|
||||||
ast::Type array(ProgramBuilder* b) const {
|
Type array(ProgramBuilder* b) const {
|
||||||
return b->ty.array(b->ty.vec<f32>(matrix->rows()), u32(matrix->columns()),
|
return b->ty.array(b->ty.vec<f32>(matrix->rows()), u32(matrix->columns()),
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b->Stride(stride),
|
b->Stride(stride),
|
||||||
|
@ -72,7 +72,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
|
||||||
// and populate the `decomposed` map with the members that have been replaced.
|
// and populate the `decomposed` map with the members that have been replaced.
|
||||||
utils::Hashmap<const type::StructMember*, MatrixInfo, 8> decomposed;
|
utils::Hashmap<const type::StructMember*, MatrixInfo, 8> decomposed;
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
if (auto* str = node->As<ast::Struct>()) {
|
if (auto* str = node->As<Struct>()) {
|
||||||
auto* str_ty = src->Sem().Get(str);
|
auto* str_ty = src->Sem().Get(str);
|
||||||
if (!str_ty->UsedAs(builtin::AddressSpace::kUniform) &&
|
if (!str_ty->UsedAs(builtin::AddressSpace::kUniform) &&
|
||||||
!str_ty->UsedAs(builtin::AddressSpace::kStorage)) {
|
!str_ty->UsedAs(builtin::AddressSpace::kStorage)) {
|
||||||
|
@ -83,8 +83,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
|
||||||
if (!matrix) {
|
if (!matrix) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto* attr =
|
auto* attr = GetAttribute<StrideAttribute>(member->Declaration()->attributes);
|
||||||
ast::GetAttribute<ast::StrideAttribute>(member->Declaration()->attributes);
|
|
||||||
if (!attr) {
|
if (!attr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -111,8 +110,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
|
||||||
// preserve these without calling conversion functions.
|
// preserve these without calling conversion functions.
|
||||||
// Example:
|
// Example:
|
||||||
// ssbo.mat[2] -> ssbo.mat[2]
|
// ssbo.mat[2] -> ssbo.mat[2]
|
||||||
ctx.ReplaceAll(
|
ctx.ReplaceAll([&](const IndexAccessorExpression* expr) -> const IndexAccessorExpression* {
|
||||||
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
|
|
||||||
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
|
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
|
||||||
if (decomposed.Contains(access->Member())) {
|
if (decomposed.Contains(access->Member())) {
|
||||||
auto* obj = ctx.CloneWithoutTransform(expr->object);
|
auto* obj = ctx.CloneWithoutTransform(expr->object);
|
||||||
|
@ -129,7 +127,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
|
||||||
// Example:
|
// Example:
|
||||||
// ssbo.mat = mat_to_arr(m)
|
// ssbo.mat = mat_to_arr(m)
|
||||||
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
|
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
|
||||||
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
|
ctx.ReplaceAll([&](const AssignmentStatement* stmt) -> const Statement* {
|
||||||
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
|
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
|
||||||
if (auto info = decomposed.Find(access->Member())) {
|
if (auto info = decomposed.Find(access->Member())) {
|
||||||
auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
|
auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
|
||||||
|
@ -142,7 +140,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
|
||||||
auto array = [&] { return info->array(ctx.dst); };
|
auto array = [&] { return info->array(ctx.dst); };
|
||||||
|
|
||||||
auto mat = b.Sym("m");
|
auto mat = b.Sym("m");
|
||||||
utils::Vector<const ast::Expression*, 4> columns;
|
utils::Vector<const Expression*, 4> columns;
|
||||||
for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
|
for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
|
||||||
columns.Push(b.IndexAccessor(mat, u32(i)));
|
columns.Push(b.IndexAccessor(mat, u32(i)));
|
||||||
}
|
}
|
||||||
|
@ -168,7 +166,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
|
||||||
// matrix type. Example:
|
// matrix type. Example:
|
||||||
// m = arr_to_mat(ssbo.mat)
|
// m = arr_to_mat(ssbo.mat)
|
||||||
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
|
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
|
||||||
ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
|
ctx.ReplaceAll([&](const MemberAccessorExpression* expr) -> const Expression* {
|
||||||
if (auto* access = src->Sem().Get(expr)->UnwrapLoad()->As<sem::StructMemberAccess>()) {
|
if (auto* access = src->Sem().Get(expr)->UnwrapLoad()->As<sem::StructMemberAccess>()) {
|
||||||
if (auto info = decomposed.Find(access->Member())) {
|
if (auto info = decomposed.Find(access->Member())) {
|
||||||
auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
|
auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
|
||||||
|
@ -181,7 +179,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
|
||||||
auto array = [&] { return info->array(ctx.dst); };
|
auto array = [&] { return info->array(ctx.dst); };
|
||||||
|
|
||||||
auto arr = b.Sym("arr");
|
auto arr = b.Sym("arr");
|
||||||
utils::Vector<const ast::Expression*, 4> columns;
|
utils::Vector<const Expression*, 4> columns;
|
||||||
for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
|
for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
|
||||||
columns.Push(b.IndexAccessor(arr, u32(i)));
|
columns.Push(b.IndexAccessor(arr, u32(i)));
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,13 +67,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
|
||||||
// let x : mat2x2<f32> = s.m;
|
// let x : mat2x2<f32> = s.m;
|
||||||
// }
|
// }
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
auto* S = b.Structure(
|
auto* S =
|
||||||
"S", utils::Vector{
|
b.Structure("S", utils::Vector{
|
||||||
b.Member("m", b.ty.mat2x2<f32>(),
|
b.Member("m", b.ty.mat2x2<f32>(),
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.MemberOffset(16_u),
|
b.MemberOffset(16_u),
|
||||||
b.create<ast::StrideAttribute>(32u),
|
b.create<StrideAttribute>(32u),
|
||||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
|
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
|
||||||
|
@ -82,7 +82,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix) {
|
||||||
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -124,13 +124,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
|
||||||
// let x : vec2<f32> = s.m[1];
|
// let x : vec2<f32> = s.m[1];
|
||||||
// }
|
// }
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
auto* S = b.Structure(
|
auto* S =
|
||||||
"S", utils::Vector{
|
b.Structure("S", utils::Vector{
|
||||||
b.Member("m", b.ty.mat2x2<f32>(),
|
b.Member("m", b.ty.mat2x2<f32>(),
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.MemberOffset(16_u),
|
b.MemberOffset(16_u),
|
||||||
b.create<ast::StrideAttribute>(32u),
|
b.create<StrideAttribute>(32u),
|
||||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
|
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
|
||||||
|
@ -140,7 +140,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformColumn) {
|
||||||
b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
|
b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -178,13 +178,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
|
||||||
// let x : mat2x2<f32> = s.m;
|
// let x : mat2x2<f32> = s.m;
|
||||||
// }
|
// }
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
auto* S = b.Structure(
|
auto* S =
|
||||||
"S", utils::Vector{
|
b.Structure("S", utils::Vector{
|
||||||
b.Member("m", b.ty.mat2x2<f32>(),
|
b.Member("m", b.ty.mat2x2<f32>(),
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.MemberOffset(16_u),
|
b.MemberOffset(16_u),
|
||||||
b.create<ast::StrideAttribute>(8u),
|
b.create<StrideAttribute>(8u),
|
||||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
|
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kUniform, b.Group(0_a), b.Binding(0_a));
|
||||||
|
@ -193,7 +193,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadUniformMatrix_DefaultStride) {
|
||||||
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -232,13 +232,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
|
||||||
// let x : mat2x2<f32> = s.m;
|
// let x : mat2x2<f32> = s.m;
|
||||||
// }
|
// }
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
auto* S = b.Structure(
|
auto* S =
|
||||||
"S", utils::Vector{
|
b.Structure("S", utils::Vector{
|
||||||
b.Member("m", b.ty.mat2x2<f32>(),
|
b.Member("m", b.ty.mat2x2<f32>(),
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.MemberOffset(8_u),
|
b.MemberOffset(8_u),
|
||||||
b.create<ast::StrideAttribute>(32u),
|
b.create<StrideAttribute>(32u),
|
||||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
|
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
|
||||||
|
@ -248,7 +248,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadStorageMatrix) {
|
||||||
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -290,13 +290,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
|
||||||
// let x : vec2<f32> = s.m[1];
|
// let x : vec2<f32> = s.m[1];
|
||||||
// }
|
// }
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
auto* S = b.Structure(
|
auto* S =
|
||||||
"S", utils::Vector{
|
b.Structure("S", utils::Vector{
|
||||||
b.Member("m", b.ty.mat2x2<f32>(),
|
b.Member("m", b.ty.mat2x2<f32>(),
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.MemberOffset(16_u),
|
b.MemberOffset(16_u),
|
||||||
b.create<ast::StrideAttribute>(32u),
|
b.create<StrideAttribute>(32u),
|
||||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
|
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
|
||||||
|
@ -307,7 +307,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadStorageColumn) {
|
||||||
b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
|
b.Decl(b.Let("x", b.ty.vec2<f32>(), b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -345,13 +345,13 @@ TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
|
||||||
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
|
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
|
||||||
// }
|
// }
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
auto* S = b.Structure(
|
auto* S =
|
||||||
"S", utils::Vector{
|
b.Structure("S", utils::Vector{
|
||||||
b.Member("m", b.ty.mat2x2<f32>(),
|
b.Member("m", b.ty.mat2x2<f32>(),
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.MemberOffset(8_u),
|
b.MemberOffset(8_u),
|
||||||
b.create<ast::StrideAttribute>(32u),
|
b.create<StrideAttribute>(32u),
|
||||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
|
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
|
||||||
|
@ -362,7 +362,7 @@ TEST_F(DecomposeStridedMatrixTest, WriteStorageMatrix) {
|
||||||
b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))),
|
b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -404,13 +404,13 @@ TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
|
||||||
// s.m[1] = vec2<f32>(1.0, 2.0);
|
// s.m[1] = vec2<f32>(1.0, 2.0);
|
||||||
// }
|
// }
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
auto* S = b.Structure(
|
auto* S =
|
||||||
"S", utils::Vector{
|
b.Structure("S", utils::Vector{
|
||||||
b.Member("m", b.ty.mat2x2<f32>(),
|
b.Member("m", b.ty.mat2x2<f32>(),
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.MemberOffset(8_u),
|
b.MemberOffset(8_u),
|
||||||
b.create<ast::StrideAttribute>(32u),
|
b.create<StrideAttribute>(32u),
|
||||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
|
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
|
||||||
|
@ -420,7 +420,7 @@ TEST_F(DecomposeStridedMatrixTest, WriteStorageColumn) {
|
||||||
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i), b.vec2<f32>(1_f, 2_f)),
|
b.Assign(b.IndexAccessor(b.MemberAccessor("s", "m"), 1_i), b.vec2<f32>(1_f, 2_f)),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -464,13 +464,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
|
||||||
// (*b)[1] = vec2<f32>(5.0, 6.0);
|
// (*b)[1] = vec2<f32>(5.0, 6.0);
|
||||||
// }
|
// }
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
auto* S = b.Structure(
|
auto* S =
|
||||||
"S", utils::Vector{
|
b.Structure("S", utils::Vector{
|
||||||
b.Member("m", b.ty.mat2x2<f32>(),
|
b.Member("m", b.ty.mat2x2<f32>(),
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.MemberOffset(8_u),
|
b.MemberOffset(8_u),
|
||||||
b.create<ast::StrideAttribute>(32u),
|
b.create<StrideAttribute>(32u),
|
||||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
|
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kStorage, builtin::Access::kReadWrite,
|
||||||
|
@ -486,7 +486,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadWriteViaPointerLets) {
|
||||||
b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), b.vec2<f32>(5_f, 6_f)),
|
b.Assign(b.IndexAccessor(b.Deref("b"), 1_i), b.vec2<f32>(5_f, 6_f)),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -536,13 +536,13 @@ TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
|
||||||
// let x : mat2x2<f32> = s.m;
|
// let x : mat2x2<f32> = s.m;
|
||||||
// }
|
// }
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
auto* S = b.Structure(
|
auto* S =
|
||||||
"S", utils::Vector{
|
b.Structure("S", utils::Vector{
|
||||||
b.Member("m", b.ty.mat2x2<f32>(),
|
b.Member("m", b.ty.mat2x2<f32>(),
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.MemberOffset(8_u),
|
b.MemberOffset(8_u),
|
||||||
b.create<ast::StrideAttribute>(32u),
|
b.create<StrideAttribute>(32u),
|
||||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kPrivate);
|
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kPrivate);
|
||||||
|
@ -551,7 +551,7 @@ TEST_F(DecomposeStridedMatrixTest, ReadPrivateMatrix) {
|
||||||
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
b.Decl(b.Let("x", b.ty.mat2x2<f32>(), b.MemberAccessor("s", "m"))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -590,13 +590,13 @@ TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
|
||||||
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
|
// s.m = mat2x2<f32>(vec2<f32>(1.0, 2.0), vec2<f32>(3.0, 4.0));
|
||||||
// }
|
// }
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
auto* S = b.Structure(
|
auto* S =
|
||||||
"S", utils::Vector{
|
b.Structure("S", utils::Vector{
|
||||||
b.Member("m", b.ty.mat2x2<f32>(),
|
b.Member("m", b.ty.mat2x2<f32>(),
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.MemberOffset(8_u),
|
b.MemberOffset(8_u),
|
||||||
b.create<ast::StrideAttribute>(32u),
|
b.create<StrideAttribute>(32u),
|
||||||
b.Disable(ast::DisabledValidation::kIgnoreStrideAttribute),
|
b.Disable(DisabledValidation::kIgnoreStrideAttribute),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kPrivate);
|
b.GlobalVar("s", b.ty.Of(S), builtin::AddressSpace::kPrivate);
|
||||||
|
@ -606,7 +606,7 @@ TEST_F(DecomposeStridedMatrixTest, WritePrivateMatrix) {
|
||||||
b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))),
|
b.mat2x2<f32>(b.vec2<f32>(1_f, 2_f), b.vec2<f32>(3_f, 4_f))),
|
||||||
},
|
},
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
b.Stage(ast::PipelineStage::kCompute),
|
b.Stage(PipelineStage::kCompute),
|
||||||
b.WorkgroupSize(1_i),
|
b.WorkgroupSize(1_i),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -85,9 +85,8 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
|
||||||
b.GlobalVar(flag, builtin::AddressSpace::kPrivate, b.Expr(false));
|
b.GlobalVar(flag, builtin::AddressSpace::kPrivate, b.Expr(false));
|
||||||
|
|
||||||
// Replace all discard statements with a statement that marks the invocation as discarded.
|
// Replace all discard statements with a statement that marks the invocation as discarded.
|
||||||
ctx.ReplaceAll([&](const ast::DiscardStatement*) -> const ast::Statement* {
|
ctx.ReplaceAll(
|
||||||
return b.Assign(flag, b.Expr(true));
|
[&](const DiscardStatement*) -> const Statement* { return b.Assign(flag, b.Expr(true)); });
|
||||||
});
|
|
||||||
|
|
||||||
// Insert a conditional discard at the end of each entry point that does not end with a return.
|
// Insert a conditional discard at the end of each entry point that does not end with a return.
|
||||||
for (auto* func : functions_to_process) {
|
for (auto* func : functions_to_process) {
|
||||||
|
@ -111,7 +110,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
|
||||||
node,
|
node,
|
||||||
|
|
||||||
// Mask assignments to storage buffer variables.
|
// Mask assignments to storage buffer variables.
|
||||||
[&](const ast::AssignmentStatement* assign) {
|
[&](const AssignmentStatement* assign) {
|
||||||
// Skip writes in functions that are not called from shaders that discard.
|
// Skip writes in functions that are not called from shaders that discard.
|
||||||
auto* func = sem.Get(assign)->Function();
|
auto* func = sem.Get(assign)->Function();
|
||||||
if (functions_to_process.count(func) == 0) {
|
if (functions_to_process.count(func) == 0) {
|
||||||
|
@ -119,7 +118,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip phony assignments.
|
// Skip phony assignments.
|
||||||
if (assign->lhs->Is<ast::PhonyExpression>()) {
|
if (assign->lhs->Is<PhonyExpression>()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,7 +143,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Mask builtins that write to host-visible memory.
|
// Mask builtins that write to host-visible memory.
|
||||||
[&](const ast::CallExpression* call) {
|
[&](const CallExpression* call) {
|
||||||
auto* sem_call = sem.Get<sem::Call>(call);
|
auto* sem_call = sem.Get<sem::Call>(call);
|
||||||
auto* stmt = sem_call ? sem_call->Stmt() : nullptr;
|
auto* stmt = sem_call ? sem_call->Stmt() : nullptr;
|
||||||
auto* func = stmt ? stmt->Function() : nullptr;
|
auto* func = stmt ? stmt->Function() : nullptr;
|
||||||
|
@ -161,7 +160,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
|
||||||
} else if (builtin->IsAtomic() &&
|
} else if (builtin->IsAtomic() &&
|
||||||
builtin->Type() != builtin::Function::kAtomicLoad) {
|
builtin->Type() != builtin::Function::kAtomicLoad) {
|
||||||
// A call to an atomic builtin can be a statement or an expression.
|
// A call to an atomic builtin can be a statement or an expression.
|
||||||
if (auto* call_stmt = stmt->Declaration()->As<ast::CallStatement>();
|
if (auto* call_stmt = stmt->Declaration()->As<CallStatement>();
|
||||||
call_stmt && call_stmt->expr == call) {
|
call_stmt && call_stmt->expr == call) {
|
||||||
// This call is a statement.
|
// This call is a statement.
|
||||||
// Wrap it inside a conditional block.
|
// Wrap it inside a conditional block.
|
||||||
|
@ -178,8 +177,8 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
|
||||||
// }
|
// }
|
||||||
// let y = x + tmp;
|
// let y = x + tmp;
|
||||||
auto result = b.Sym();
|
auto result = b.Sym();
|
||||||
ast::Type result_ty;
|
Type result_ty;
|
||||||
const ast::Statement* masked_call = nullptr;
|
const Statement* masked_call = nullptr;
|
||||||
if (builtin->Type() == builtin::Function::kAtomicCompareExchangeWeak) {
|
if (builtin->Type() == builtin::Function::kAtomicCompareExchangeWeak) {
|
||||||
// Special case for atomicCompareExchangeWeak as we cannot name its
|
// Special case for atomicCompareExchangeWeak as we cannot name its
|
||||||
// result type. We have to declare an equivalent struct and copy the
|
// result type. We have to declare an equivalent struct and copy the
|
||||||
|
@ -232,7 +231,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Insert a conditional discard before all return statements in entry points.
|
// Insert a conditional discard before all return statements in entry points.
|
||||||
[&](const ast::ReturnStatement* ret) {
|
[&](const ReturnStatement* ret) {
|
||||||
auto* func = sem.Get(ret)->Function();
|
auto* func = sem.Get(ret)->Function();
|
||||||
if (func->Declaration()->IsEntryPoint() && functions_to_process.count(func)) {
|
if (func->Declaration()->IsEntryPoint() && functions_to_process.count(func)) {
|
||||||
auto* discard = b.If(flag, b.Block(b.Discard()));
|
auto* discard = b.If(flag, b.Block(b.Discard()));
|
||||||
|
|
|
@ -450,19 +450,19 @@ struct DirectVariableAccess::State {
|
||||||
|
|
||||||
Switch(
|
Switch(
|
||||||
variable->Declaration(),
|
variable->Declaration(),
|
||||||
[&](const ast::Var*) {
|
[&](const Var*) {
|
||||||
if (variable->AddressSpace() != builtin::AddressSpace::kHandle) {
|
if (variable->AddressSpace() != builtin::AddressSpace::kHandle) {
|
||||||
// Start a new access chain for the non-handle 'var' access
|
// Start a new access chain for the non-handle 'var' access
|
||||||
create_new_chain();
|
create_new_chain();
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::Parameter*) {
|
[&](const Parameter*) {
|
||||||
if (variable->Type()->Is<type::Pointer>()) {
|
if (variable->Type()->Is<type::Pointer>()) {
|
||||||
// Start a new access chain for the pointer parameter access
|
// Start a new access chain for the pointer parameter access
|
||||||
create_new_chain();
|
create_new_chain();
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::Let*) {
|
[&](const Let*) {
|
||||||
if (variable->Type()->Is<type::Pointer>()) {
|
if (variable->Type()->Is<type::Pointer>()) {
|
||||||
// variable is a pointer-let.
|
// variable is a pointer-let.
|
||||||
auto* init = sem.GetVal(variable->Declaration()->initializer);
|
auto* init = sem.GetVal(variable->Declaration()->initializer);
|
||||||
|
@ -494,11 +494,10 @@ struct DirectVariableAccess::State {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const sem::ValueExpression* e) {
|
[&](const sem::ValueExpression* e) {
|
||||||
if (auto* unary = e->Declaration()->As<ast::UnaryOpExpression>()) {
|
if (auto* unary = e->Declaration()->As<UnaryOpExpression>()) {
|
||||||
// Unary op.
|
// Unary op.
|
||||||
// If this is a '&' or '*', simply move the chain to the unary op expression.
|
// If this is a '&' or '*', simply move the chain to the unary op expression.
|
||||||
if (unary->op == ast::UnaryOp::kAddressOf ||
|
if (unary->op == UnaryOp::kAddressOf || unary->op == UnaryOp::kIndirection) {
|
||||||
unary->op == ast::UnaryOp::kIndirection) {
|
|
||||||
take_chain(sem.GetVal(unary->expr));
|
take_chain(sem.GetVal(unary->expr));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -529,7 +528,7 @@ struct DirectVariableAccess::State {
|
||||||
|
|
||||||
if (auto* idx_variable_user = idx->UnwrapMaterialize()->As<sem::VariableUser>()) {
|
if (auto* idx_variable_user = idx->UnwrapMaterialize()->As<sem::VariableUser>()) {
|
||||||
auto* idx_variable = idx_variable_user->Variable();
|
auto* idx_variable = idx_variable_user->Variable();
|
||||||
if (idx_variable->Declaration()->IsAnyOf<ast::Let, ast::Parameter>()) {
|
if (idx_variable->Declaration()->IsAnyOf<Let, Parameter>()) {
|
||||||
// Dynamic index is an immutable variable
|
// Dynamic index is an immutable variable
|
||||||
continue; // Hoisting not required.
|
continue; // Hoisting not required.
|
||||||
}
|
}
|
||||||
|
@ -557,7 +556,7 @@ struct DirectVariableAccess::State {
|
||||||
/// * Casts the resulting expression to a u32 if @p cast_to_u32 is true, and the expression type
|
/// * Casts the resulting expression to a u32 if @p cast_to_u32 is true, and the expression type
|
||||||
/// isn't implicitly usable as a u32. This is to help feed the expression into a
|
/// isn't implicitly usable as a u32. This is to help feed the expression into a
|
||||||
/// `array<u32, N>` argument passed to a callee variant function.
|
/// `array<u32, N>` argument passed to a callee variant function.
|
||||||
const ast::Expression* BuildDynamicIndex(const sem::ValueExpression* idx, bool cast_to_u32) {
|
const Expression* BuildDynamicIndex(const sem::ValueExpression* idx, bool cast_to_u32) {
|
||||||
if (auto* val = idx->ConstantValue()) {
|
if (auto* val = idx->ConstantValue()) {
|
||||||
// Expression evaluated to a constant value. Just emit that constant.
|
// Expression evaluated to a constant value. Just emit that constant.
|
||||||
return b.Expr(val->ValueAs<AInt>());
|
return b.Expr(val->ValueAs<AInt>());
|
||||||
|
@ -808,7 +807,7 @@ struct DirectVariableAccess::State {
|
||||||
// many variant functions, keep a record of the last created variant, and explicitly add
|
// many variant functions, keep a record of the last created variant, and explicitly add
|
||||||
// this to the module if it isn't the last. We'll return the last created variant,
|
// this to the module if it isn't the last. We'll return the last created variant,
|
||||||
// taking the place of the original function.
|
// taking the place of the original function.
|
||||||
const ast::Function* pending_variant = nullptr;
|
const Function* pending_variant = nullptr;
|
||||||
|
|
||||||
// For each variant of fn...
|
// For each variant of fn...
|
||||||
for (auto variant_it : fn_info->SortedVariants()) {
|
for (auto variant_it : fn_info->SortedVariants()) {
|
||||||
|
@ -827,7 +826,7 @@ struct DirectVariableAccess::State {
|
||||||
// Pointer parameters in the 'uniform', 'storage' or 'workgroup' address space are
|
// Pointer parameters in the 'uniform', 'storage' or 'workgroup' address space are
|
||||||
// either replaced with an array of dynamic indices, or are dropped (if there are no
|
// either replaced with an array of dynamic indices, or are dropped (if there are no
|
||||||
// dynamic indices).
|
// dynamic indices).
|
||||||
utils::Vector<const ast::Parameter*, 8> params;
|
utils::Vector<const Parameter*, 8> params;
|
||||||
for (auto* param : fn->Parameters()) {
|
for (auto* param : fn->Parameters()) {
|
||||||
if (auto incoming_shape = variant_sig.Find(param)) {
|
if (auto incoming_shape = variant_sig.Find(param)) {
|
||||||
auto& symbols = *variant.ptr_param_symbols.Find(param);
|
auto& symbols = *variant.ptr_param_symbols.Find(param);
|
||||||
|
@ -856,7 +855,7 @@ struct DirectVariableAccess::State {
|
||||||
auto attrs = ctx.Clone(fn->Declaration()->attributes);
|
auto attrs = ctx.Clone(fn->Declaration()->attributes);
|
||||||
auto ret_attrs = ctx.Clone(fn->Declaration()->return_type_attributes);
|
auto ret_attrs = ctx.Clone(fn->Declaration()->return_type_attributes);
|
||||||
pending_variant =
|
pending_variant =
|
||||||
b.create<ast::Function>(b.Ident(variant.name), std::move(params), ret_ty, body,
|
b.create<Function>(b.Ident(variant.name), std::move(params), ret_ty, body,
|
||||||
std::move(attrs), std::move(ret_attrs));
|
std::move(attrs), std::move(ret_attrs));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -877,7 +876,7 @@ struct DirectVariableAccess::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build the new call expressions's arguments.
|
// Build the new call expressions's arguments.
|
||||||
utils::Vector<const ast::Expression*, 8> new_args;
|
utils::Vector<const Expression*, 8> new_args;
|
||||||
for (size_t arg_idx = 0; arg_idx < call->Arguments().Length(); arg_idx++) {
|
for (size_t arg_idx = 0; arg_idx < call->Arguments().Length(); arg_idx++) {
|
||||||
auto* arg = call->Arguments()[arg_idx];
|
auto* arg = call->Arguments()[arg_idx];
|
||||||
auto* param = call->Target()->Parameters()[arg_idx];
|
auto* param = call->Target()->Parameters()[arg_idx];
|
||||||
|
@ -915,7 +914,7 @@ struct DirectVariableAccess::State {
|
||||||
// Get or create the dynamic indices array.
|
// Get or create the dynamic indices array.
|
||||||
if (auto dyn_idx_arr_ty = DynamicIndexArrayType(full_indices)) {
|
if (auto dyn_idx_arr_ty = DynamicIndexArrayType(full_indices)) {
|
||||||
// Build an array of dynamic indices to pass as the replacement for the pointer.
|
// Build an array of dynamic indices to pass as the replacement for the pointer.
|
||||||
utils::Vector<const ast::Expression*, 8> dyn_idx_args;
|
utils::Vector<const Expression*, 8> dyn_idx_args;
|
||||||
if (auto* root_param = chain->root.variable->As<sem::Parameter>()) {
|
if (auto* root_param = chain->root.variable->As<sem::Parameter>()) {
|
||||||
// Access chain originates from a pointer parameter.
|
// Access chain originates from a pointer parameter.
|
||||||
if (auto incoming_chain =
|
if (auto incoming_chain =
|
||||||
|
@ -985,7 +984,7 @@ struct DirectVariableAccess::State {
|
||||||
/// let.
|
/// let.
|
||||||
void TransformAccessChainExpressions() {
|
void TransformAccessChainExpressions() {
|
||||||
// Register a custom handler for all non-function call expressions
|
// Register a custom handler for all non-function call expressions
|
||||||
ctx.ReplaceAll([this](const ast::Expression* ast_expr) -> const ast::Expression* {
|
ctx.ReplaceAll([this](const Expression* ast_expr) -> const Expression* {
|
||||||
if (!clone_state->current_variant) {
|
if (!clone_state->current_variant) {
|
||||||
// Expression does not belong to a function variant.
|
// Expression does not belong to a function variant.
|
||||||
return nullptr; // Just clone the expression.
|
return nullptr; // Just clone the expression.
|
||||||
|
@ -1065,7 +1064,7 @@ struct DirectVariableAccess::State {
|
||||||
|
|
||||||
/// @returns the type alias used to hold the dynamic indices for @p shape, declaring a new alias
|
/// @returns the type alias used to hold the dynamic indices for @p shape, declaring a new alias
|
||||||
/// if this is the first call for the given shape.
|
/// if this is the first call for the given shape.
|
||||||
ast::Type DynamicIndexArrayType(const AccessShape& shape) {
|
Type DynamicIndexArrayType(const AccessShape& shape) {
|
||||||
auto name = dynamic_index_array_aliases.GetOrCreate(shape, [&] {
|
auto name = dynamic_index_array_aliases.GetOrCreate(shape, [&] {
|
||||||
// Count the number of dynamic indices
|
// Count the number of dynamic indices
|
||||||
uint32_t num_dyn_indices = shape.NumDynamicIndices();
|
uint32_t num_dyn_indices = shape.NumDynamicIndices();
|
||||||
|
@ -1076,7 +1075,7 @@ struct DirectVariableAccess::State {
|
||||||
b.Alias(symbol, b.ty.array(b.ty.u32(), u32(num_dyn_indices)));
|
b.Alias(symbol, b.ty.array(b.ty.u32(), u32(num_dyn_indices)));
|
||||||
return symbol;
|
return symbol;
|
||||||
});
|
});
|
||||||
return name.IsValid() ? b.ty(name) : ast::Type{};
|
return name.IsValid() ? b.ty(name) : Type{};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @returns a name describing the given shape
|
/// @returns a name describing the given shape
|
||||||
|
@ -1113,7 +1112,7 @@ struct DirectVariableAccess::State {
|
||||||
/// Builds an expresion to the root of an access, returning the new expression.
|
/// Builds an expresion to the root of an access, returning the new expression.
|
||||||
/// @param root the AccessRoot
|
/// @param root the AccessRoot
|
||||||
/// @param deref if true, the returned expression will always be a reference type.
|
/// @param deref if true, the returned expression will always be a reference type.
|
||||||
const ast::Expression* BuildAccessRootExpr(const AccessRoot& root, bool deref) {
|
const Expression* BuildAccessRootExpr(const AccessRoot& root, bool deref) {
|
||||||
if (auto* param = root.variable->As<sem::Parameter>()) {
|
if (auto* param = root.variable->As<sem::Parameter>()) {
|
||||||
if (auto symbols = clone_state->current_variant->ptr_param_symbols.Find(param)) {
|
if (auto symbols = clone_state->current_variant->ptr_param_symbols.Find(param)) {
|
||||||
if (deref) {
|
if (deref) {
|
||||||
|
@ -1123,7 +1122,7 @@ struct DirectVariableAccess::State {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const ast::Expression* expr = b.Expr(ctx.Clone(root.variable->Declaration()->name->symbol));
|
const Expression* expr = b.Expr(ctx.Clone(root.variable->Declaration()->name->symbol));
|
||||||
if (deref) {
|
if (deref) {
|
||||||
if (root.variable->Type()->Is<type::Pointer>()) {
|
if (root.variable->Type()->Is<type::Pointer>()) {
|
||||||
expr = b.Deref(expr);
|
expr = b.Deref(expr);
|
||||||
|
@ -1137,10 +1136,9 @@ struct DirectVariableAccess::State {
|
||||||
/// @param expr the input expression
|
/// @param expr the input expression
|
||||||
/// @param access the access to perform on the current expression
|
/// @param access the access to perform on the current expression
|
||||||
/// @param dynamic_index a function that obtains the i'th dynamic index
|
/// @param dynamic_index a function that obtains the i'th dynamic index
|
||||||
const ast::Expression* BuildAccessExpr(
|
const Expression* BuildAccessExpr(const Expression* expr,
|
||||||
const ast::Expression* expr,
|
|
||||||
const AccessOp& access,
|
const AccessOp& access,
|
||||||
std::function<const ast::Expression*(size_t)> dynamic_index) {
|
std::function<const Expression*(size_t)> dynamic_index) {
|
||||||
if (auto* dyn_idx = std::get_if<DynamicIndex>(&access)) {
|
if (auto* dyn_idx = std::get_if<DynamicIndex>(&access)) {
|
||||||
/// The access uses a dynamic (runtime-expression) index.
|
/// The access uses a dynamic (runtime-expression) index.
|
||||||
auto* idx = dynamic_index(dyn_idx->slot);
|
auto* idx = dynamic_index(dyn_idx->slot);
|
||||||
|
|
|
@ -35,7 +35,7 @@ namespace {
|
||||||
|
|
||||||
bool ShouldRun(const Program* program) {
|
bool ShouldRun(const Program* program) {
|
||||||
for (auto* node : program->ASTNodes().Objects()) {
|
for (auto* node : program->ASTNodes().Objects()) {
|
||||||
if (node->IsAnyOf<ast::CompoundAssignmentStatement, ast::IncrementDecrementStatement>()) {
|
if (node->IsAnyOf<CompoundAssignmentStatement, IncrementDecrementStatement>()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -58,17 +58,14 @@ struct ExpandCompoundAssignment::State {
|
||||||
/// @param lhs the lhs expression from the source statement
|
/// @param lhs the lhs expression from the source statement
|
||||||
/// @param rhs the rhs expression in the destination module
|
/// @param rhs the rhs expression in the destination module
|
||||||
/// @param op the binary operator
|
/// @param op the binary operator
|
||||||
void Expand(const ast::Statement* stmt,
|
void Expand(const Statement* stmt, const Expression* lhs, const Expression* rhs, BinaryOp op) {
|
||||||
const ast::Expression* lhs,
|
|
||||||
const ast::Expression* rhs,
|
|
||||||
ast::BinaryOp op) {
|
|
||||||
// Helper function to create the new LHS expression. This will be called
|
// Helper function to create the new LHS expression. This will be called
|
||||||
// twice when building the non-compound assignment statement, so must
|
// twice when building the non-compound assignment statement, so must
|
||||||
// not produce expressions that cause side effects.
|
// not produce expressions that cause side effects.
|
||||||
std::function<const ast::Expression*()> new_lhs;
|
std::function<const Expression*()> new_lhs;
|
||||||
|
|
||||||
// Helper function to create a variable that is a pointer to `expr`.
|
// Helper function to create a variable that is a pointer to `expr`.
|
||||||
auto hoist_pointer_to = [&](const ast::Expression* expr) {
|
auto hoist_pointer_to = [&](const Expression* expr) {
|
||||||
auto name = b.Sym();
|
auto name = b.Sym();
|
||||||
auto* ptr = b.AddressOf(ctx.Clone(expr));
|
auto* ptr = b.AddressOf(ctx.Clone(expr));
|
||||||
auto* decl = b.Decl(b.Let(name, ptr));
|
auto* decl = b.Decl(b.Let(name, ptr));
|
||||||
|
@ -77,7 +74,7 @@ struct ExpandCompoundAssignment::State {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Helper function to hoist `expr` to a let declaration.
|
// Helper function to hoist `expr` to a let declaration.
|
||||||
auto hoist_expr_to_let = [&](const ast::Expression* expr) {
|
auto hoist_expr_to_let = [&](const Expression* expr) {
|
||||||
auto name = b.Sym();
|
auto name = b.Sym();
|
||||||
auto* decl = b.Decl(b.Let(name, ctx.Clone(expr)));
|
auto* decl = b.Decl(b.Let(name, ctx.Clone(expr)));
|
||||||
hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl);
|
hoist_to_decl_before.InsertBefore(ctx.src->Sem().Get(stmt), decl);
|
||||||
|
@ -85,7 +82,7 @@ struct ExpandCompoundAssignment::State {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Helper function that returns `true` if the type of `expr` is a vector.
|
// Helper function that returns `true` if the type of `expr` is a vector.
|
||||||
auto is_vec = [&](const ast::Expression* expr) {
|
auto is_vec = [&](const Expression* expr) {
|
||||||
if (auto* val_expr = ctx.src->Sem().GetVal(expr)) {
|
if (auto* val_expr = ctx.src->Sem().GetVal(expr)) {
|
||||||
return val_expr->Type()->UnwrapRef()->Is<type::Vector>();
|
return val_expr->Type()->UnwrapRef()->Is<type::Vector>();
|
||||||
}
|
}
|
||||||
|
@ -96,10 +93,10 @@ struct ExpandCompoundAssignment::State {
|
||||||
// LHS that we can evaluate twice.
|
// LHS that we can evaluate twice.
|
||||||
// We need to special case compound assignments to vector components since
|
// We need to special case compound assignments to vector components since
|
||||||
// we cannot take the address of a vector component.
|
// we cannot take the address of a vector component.
|
||||||
auto* index_accessor = lhs->As<ast::IndexAccessorExpression>();
|
auto* index_accessor = lhs->As<IndexAccessorExpression>();
|
||||||
auto* member_accessor = lhs->As<ast::MemberAccessorExpression>();
|
auto* member_accessor = lhs->As<MemberAccessorExpression>();
|
||||||
if (lhs->Is<ast::IdentifierExpression>() ||
|
if (lhs->Is<IdentifierExpression>() ||
|
||||||
(member_accessor && member_accessor->object->Is<ast::IdentifierExpression>())) {
|
(member_accessor && member_accessor->object->Is<IdentifierExpression>())) {
|
||||||
// This is the simple case with no side effects, so we can just use the
|
// This is the simple case with no side effects, so we can just use the
|
||||||
// original LHS expression directly.
|
// original LHS expression directly.
|
||||||
// Before:
|
// Before:
|
||||||
|
@ -144,7 +141,7 @@ struct ExpandCompoundAssignment::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Replace the statement with a regular assignment statement.
|
// Replace the statement with a regular assignment statement.
|
||||||
auto* value = b.create<ast::BinaryExpression>(op, new_lhs(), rhs);
|
auto* value = b.create<BinaryExpression>(op, new_lhs(), rhs);
|
||||||
ctx.Replace(stmt, b.Assign(new_lhs(), value));
|
ctx.Replace(stmt, b.Assign(new_lhs(), value));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,11 +171,11 @@ Transform::ApplyResult ExpandCompoundAssignment::Apply(const Program* src,
|
||||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||||
State state(ctx);
|
State state(ctx);
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) {
|
if (auto* assign = node->As<CompoundAssignmentStatement>()) {
|
||||||
state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op);
|
state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op);
|
||||||
} else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) {
|
} else if (auto* inc_dec = node->As<IncrementDecrementStatement>()) {
|
||||||
// For increment/decrement statements, `i++` becomes `i = i + 1`.
|
// For increment/decrement statements, `i++` becomes `i = i + 1`.
|
||||||
auto op = inc_dec->increment ? ast::BinaryOp::kAdd : ast::BinaryOp::kSubtract;
|
auto op = inc_dec->increment ? BinaryOp::kAdd : BinaryOp::kSubtract;
|
||||||
state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op);
|
state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ constexpr char kFirstInstanceName[] = "first_instance_index";
|
||||||
|
|
||||||
bool ShouldRun(const Program* program) {
|
bool ShouldRun(const Program* program) {
|
||||||
for (auto* fn : program->AST().Functions()) {
|
for (auto* fn : program->AST().Functions()) {
|
||||||
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
|
if (fn->PipelineStage() == PipelineStage::kVertex) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -86,9 +86,9 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
|
||||||
// Traverse the AST scanning for builtin accesses via variables (includes
|
// Traverse the AST scanning for builtin accesses via variables (includes
|
||||||
// parameters) or structure member accesses.
|
// parameters) or structure member accesses.
|
||||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||||
if (auto* var = node->As<ast::Variable>()) {
|
if (auto* var = node->As<Variable>()) {
|
||||||
for (auto* attr : var->attributes) {
|
for (auto* attr : var->attributes) {
|
||||||
if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
|
if (auto* builtin_attr = attr->As<BuiltinAttribute>()) {
|
||||||
builtin::BuiltinValue builtin = src->Sem().Get(builtin_attr)->Value();
|
builtin::BuiltinValue builtin = src->Sem().Get(builtin_attr)->Value();
|
||||||
if (builtin == builtin::BuiltinValue::kVertexIndex) {
|
if (builtin == builtin::BuiltinValue::kVertexIndex) {
|
||||||
auto* sem_var = ctx.src->Sem().Get(var);
|
auto* sem_var = ctx.src->Sem().Get(var);
|
||||||
|
@ -103,9 +103,9 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (auto* member = node->As<ast::StructMember>()) {
|
if (auto* member = node->As<StructMember>()) {
|
||||||
for (auto* attr : member->attributes) {
|
for (auto* attr : member->attributes) {
|
||||||
if (auto* builtin_attr = attr->As<ast::BuiltinAttribute>()) {
|
if (auto* builtin_attr = attr->As<BuiltinAttribute>()) {
|
||||||
builtin::BuiltinValue builtin = src->Sem().Get(builtin_attr)->Value();
|
builtin::BuiltinValue builtin = src->Sem().Get(builtin_attr)->Value();
|
||||||
if (builtin == builtin::BuiltinValue::kVertexIndex) {
|
if (builtin == builtin::BuiltinValue::kVertexIndex) {
|
||||||
auto* sem_mem = ctx.src->Sem().Get(member);
|
auto* sem_mem = ctx.src->Sem().Get(member);
|
||||||
|
@ -124,7 +124,7 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
|
||||||
|
|
||||||
if (has_vertex_or_instance_index) {
|
if (has_vertex_or_instance_index) {
|
||||||
// Add uniform buffer members and calculate byte offsets
|
// Add uniform buffer members and calculate byte offsets
|
||||||
utils::Vector<const ast::StructMember*, 8> members;
|
utils::Vector<const StructMember*, 8> members;
|
||||||
members.Push(b.Member(kFirstVertexName, b.ty.u32()));
|
members.Push(b.Member(kFirstVertexName, b.ty.u32()));
|
||||||
members.Push(b.Member(kFirstInstanceName, b.ty.u32()));
|
members.Push(b.Member(kFirstInstanceName, b.ty.u32()));
|
||||||
auto* struct_ = b.Structure(b.Sym(), std::move(members));
|
auto* struct_ = b.Structure(b.Sym(), std::move(members));
|
||||||
|
@ -138,7 +138,7 @@ Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Fix up all references to the builtins with the offsets
|
// Fix up all references to the builtins with the offsets
|
||||||
ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* {
|
ctx.ReplaceAll([=, &ctx](const Expression* expr) -> const Expression* {
|
||||||
if (auto* sem = ctx.src->Sem().GetVal(expr)) {
|
if (auto* sem = ctx.src->Sem().GetVal(expr)) {
|
||||||
if (auto* user = sem->UnwrapLoad()->As<sem::VariableUser>()) {
|
if (auto* user = sem->UnwrapLoad()->As<sem::VariableUser>()) {
|
||||||
auto it = builtin_vars.find(user->Variable());
|
auto it = builtin_vars.find(user->Variable());
|
||||||
|
|
|
@ -26,7 +26,7 @@ namespace {
|
||||||
|
|
||||||
bool ShouldRun(const Program* program) {
|
bool ShouldRun(const Program* program) {
|
||||||
for (auto* node : program->ASTNodes().Objects()) {
|
for (auto* node : program->ASTNodes().Objects()) {
|
||||||
if (node->Is<ast::ForLoopStatement>()) {
|
if (node->Is<ForLoopStatement>()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -47,8 +47,8 @@ Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&,
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||||
|
|
||||||
ctx.ReplaceAll([&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {
|
ctx.ReplaceAll([&](const ForLoopStatement* for_loop) -> const Statement* {
|
||||||
utils::Vector<const ast::Statement*, 8> stmts;
|
utils::Vector<const Statement*, 8> stmts;
|
||||||
if (auto* cond = for_loop->condition) {
|
if (auto* cond = for_loop->condition) {
|
||||||
// !condition
|
// !condition
|
||||||
auto* not_cond = b.Not(ctx.Clone(cond));
|
auto* not_cond = b.Not(ctx.Clone(cond));
|
||||||
|
@ -63,7 +63,7 @@ Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&,
|
||||||
stmts.Push(ctx.Clone(stmt));
|
stmts.Push(ctx.Clone(stmt));
|
||||||
}
|
}
|
||||||
|
|
||||||
const ast::BlockStatement* continuing = nullptr;
|
const BlockStatement* continuing = nullptr;
|
||||||
if (auto* cont = for_loop->continuing) {
|
if (auto* cont = for_loop->continuing) {
|
||||||
continuing = b.Block(ctx.Clone(cont));
|
continuing = b.Block(ctx.Clone(cont));
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,14 +43,14 @@ struct LocalizeStructArrayAssignment::State {
|
||||||
ApplyResult Run() {
|
ApplyResult Run() {
|
||||||
struct Shared {
|
struct Shared {
|
||||||
bool process_nested_nodes = false;
|
bool process_nested_nodes = false;
|
||||||
utils::Vector<const ast::Statement*, 4> insert_before_stmts;
|
utils::Vector<const Statement*, 4> insert_before_stmts;
|
||||||
utils::Vector<const ast::Statement*, 4> insert_after_stmts;
|
utils::Vector<const Statement*, 4> insert_after_stmts;
|
||||||
} s;
|
} s;
|
||||||
|
|
||||||
bool made_changes = false;
|
bool made_changes = false;
|
||||||
|
|
||||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||||
if (auto* assign_stmt = node->As<ast::AssignmentStatement>()) {
|
if (auto* assign_stmt = node->As<AssignmentStatement>()) {
|
||||||
// Process if it's an assignment statement to a dynamically indexed array
|
// Process if it's an assignment statement to a dynamically indexed array
|
||||||
// within a struct on a function or private storage variable. This
|
// within a struct on a function or private storage variable. This
|
||||||
// specific use-case is what FXC fails to compile with:
|
// specific use-case is what FXC fails to compile with:
|
||||||
|
@ -70,7 +70,7 @@ struct LocalizeStructArrayAssignment::State {
|
||||||
// Reset shared state for this assignment statement
|
// Reset shared state for this assignment statement
|
||||||
s = Shared{};
|
s = Shared{};
|
||||||
|
|
||||||
const ast::Expression* new_lhs = nullptr;
|
const Expression* new_lhs = nullptr;
|
||||||
{
|
{
|
||||||
TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true);
|
TINT_SCOPED_ASSIGNMENT(s.process_nested_nodes, true);
|
||||||
new_lhs = ctx.Clone(assign_stmt->lhs);
|
new_lhs = ctx.Clone(assign_stmt->lhs);
|
||||||
|
@ -98,14 +98,13 @@ struct LocalizeStructArrayAssignment::State {
|
||||||
return SkipTransform;
|
return SkipTransform;
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.ReplaceAll(
|
ctx.ReplaceAll([&](const IndexAccessorExpression* index_access) -> const Expression* {
|
||||||
[&](const ast::IndexAccessorExpression* index_access) -> const ast::Expression* {
|
|
||||||
if (!s.process_nested_nodes) {
|
if (!s.process_nested_nodes) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Indexing a member access expr?
|
// Indexing a member access expr?
|
||||||
auto* mem_access = index_access->object->As<ast::MemberAccessorExpression>();
|
auto* mem_access = index_access->object->As<MemberAccessorExpression>();
|
||||||
if (!mem_access) {
|
if (!mem_access) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -136,7 +135,7 @@ struct LocalizeStructArrayAssignment::State {
|
||||||
// e.g. *(tint_symbol) = tint_symbol_1;
|
// e.g. *(tint_symbol) = tint_symbol_1;
|
||||||
auto* assign_rhs_to_temp = b.Assign(b.Deref(mem_access_ptr), tmp_var);
|
auto* assign_rhs_to_temp = b.Assign(b.Deref(mem_access_ptr), tmp_var);
|
||||||
{
|
{
|
||||||
utils::Vector<const ast::Statement*, 8> stmts{assign_rhs_to_temp};
|
utils::Vector<const Statement*, 8> stmts{assign_rhs_to_temp};
|
||||||
for (auto* stmt : s.insert_after_stmts) {
|
for (auto* stmt : s.insert_after_stmts) {
|
||||||
stmts.Push(stmt);
|
stmts.Push(stmt);
|
||||||
}
|
}
|
||||||
|
@ -160,23 +159,22 @@ struct LocalizeStructArrayAssignment::State {
|
||||||
|
|
||||||
/// Returns true if `expr` contains an index accessor expression to a
|
/// Returns true if `expr` contains an index accessor expression to a
|
||||||
/// structure member of array type.
|
/// structure member of array type.
|
||||||
bool ContainsStructArrayIndex(const ast::Expression* expr) {
|
bool ContainsStructArrayIndex(const Expression* expr) {
|
||||||
bool result = false;
|
bool result = false;
|
||||||
ast::TraverseExpressions(
|
TraverseExpressions(expr, b.Diagnostics(), [&](const IndexAccessorExpression* ia) {
|
||||||
expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
|
|
||||||
// Indexing using a runtime value?
|
// Indexing using a runtime value?
|
||||||
auto* idx_sem = src->Sem().GetVal(ia->index);
|
auto* idx_sem = src->Sem().GetVal(ia->index);
|
||||||
if (!idx_sem->ConstantValue()) {
|
if (!idx_sem->ConstantValue()) {
|
||||||
// Indexing a member access expr?
|
// Indexing a member access expr?
|
||||||
if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
|
if (auto* ma = ia->object->As<MemberAccessorExpression>()) {
|
||||||
// That accesses an array?
|
// That accesses an array?
|
||||||
if (src->TypeOf(ma)->UnwrapRef()->Is<type::Array>()) {
|
if (src->TypeOf(ma)->UnwrapRef()->Is<type::Array>()) {
|
||||||
result = true;
|
result = true;
|
||||||
return ast::TraverseAction::Stop;
|
return TraverseAction::Stop;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ast::TraverseAction::Descend;
|
return TraverseAction::Descend;
|
||||||
});
|
});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -186,7 +184,7 @@ struct LocalizeStructArrayAssignment::State {
|
||||||
// of the assignment statement.
|
// of the assignment statement.
|
||||||
// See https://www.w3.org/TR/WGSL/#originating-variable-section
|
// See https://www.w3.org/TR/WGSL/#originating-variable-section
|
||||||
std::pair<const type::Type*, builtin::AddressSpace> GetOriginatingTypeAndAddressSpace(
|
std::pair<const type::Type*, builtin::AddressSpace> GetOriginatingTypeAndAddressSpace(
|
||||||
const ast::AssignmentStatement* assign_stmt) {
|
const AssignmentStatement* assign_stmt) {
|
||||||
auto* root_ident = src->Sem().GetVal(assign_stmt->lhs)->RootIdentifier();
|
auto* root_ident = src->Sem().GetVal(assign_stmt->lhs)->RootIdentifier();
|
||||||
if (TINT_UNLIKELY(!root_ident)) {
|
if (TINT_UNLIKELY(!root_ident)) {
|
||||||
TINT_ICE(Transform, b.Diagnostics())
|
TINT_ICE(Transform, b.Diagnostics())
|
||||||
|
|
|
@ -30,12 +30,12 @@ namespace tint::ast::transform {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
/// Returns `true` if `stmt` has the behavior `behavior`.
|
/// Returns `true` if `stmt` has the behavior `behavior`.
|
||||||
bool HasBehavior(const Program* program, const ast::Statement* stmt, sem::Behavior behavior) {
|
bool HasBehavior(const Program* program, const Statement* stmt, sem::Behavior behavior) {
|
||||||
return program->Sem().Get(stmt)->Behaviors().Contains(behavior);
|
return program->Sem().Get(stmt)->Behaviors().Contains(behavior);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns `true` if `func` needs to be transformed.
|
/// Returns `true` if `func` needs to be transformed.
|
||||||
bool NeedsTransform(const Program* program, const ast::Function* func) {
|
bool NeedsTransform(const Program* program, const Function* func) {
|
||||||
// Entry points and intrinsic declarations never need transforming.
|
// Entry points and intrinsic declarations never need transforming.
|
||||||
if (func->IsEntryPoint() || func->body == nullptr) {
|
if (func->IsEntryPoint() || func->body == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -49,7 +49,7 @@ bool NeedsTransform(const Program* program, const ast::Function* func) {
|
||||||
if (HasBehavior(program, s, sem::Behavior::kReturn)) {
|
if (HasBehavior(program, s, sem::Behavior::kReturn)) {
|
||||||
// If this statement is itself a return, it will be the only exit point,
|
// If this statement is itself a return, it will be the only exit point,
|
||||||
// so no need to apply the transform to the function.
|
// so no need to apply the transform to the function.
|
||||||
if (s->Is<ast::ReturnStatement>()) {
|
if (s->Is<ReturnStatement>()) {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
// Apply the transform in all other cases.
|
// Apply the transform in all other cases.
|
||||||
|
@ -78,7 +78,7 @@ class State {
|
||||||
ProgramBuilder& b;
|
ProgramBuilder& b;
|
||||||
|
|
||||||
/// The function.
|
/// The function.
|
||||||
const ast::Function* function;
|
const Function* function;
|
||||||
|
|
||||||
/// The symbol for the return flag variable.
|
/// The symbol for the return flag variable.
|
||||||
Symbol flag;
|
Symbol flag;
|
||||||
|
@ -92,32 +92,32 @@ class State {
|
||||||
public:
|
public:
|
||||||
/// Constructor
|
/// Constructor
|
||||||
/// @param context the clone context
|
/// @param context the clone context
|
||||||
State(CloneContext& context, const ast::Function* func)
|
State(CloneContext& context, const Function* func)
|
||||||
: ctx(context), b(*ctx.dst), function(func) {}
|
: ctx(context), b(*ctx.dst), function(func) {}
|
||||||
|
|
||||||
/// Process a statement (recursively).
|
/// Process a statement (recursively).
|
||||||
void ProcessStatement(const ast::Statement* stmt) {
|
void ProcessStatement(const Statement* stmt) {
|
||||||
if (stmt == nullptr || !HasBehavior(ctx.src, stmt, sem::Behavior::kReturn)) {
|
if (stmt == nullptr || !HasBehavior(ctx.src, stmt, sem::Behavior::kReturn)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Switch(
|
Switch(
|
||||||
stmt, [&](const ast::BlockStatement* block) { ProcessBlock(block); },
|
stmt, [&](const BlockStatement* block) { ProcessBlock(block); },
|
||||||
[&](const ast::CaseStatement* c) { ProcessStatement(c->body); },
|
[&](const CaseStatement* c) { ProcessStatement(c->body); },
|
||||||
[&](const ast::ForLoopStatement* f) {
|
[&](const ForLoopStatement* f) {
|
||||||
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
|
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
|
||||||
ProcessStatement(f->body);
|
ProcessStatement(f->body);
|
||||||
},
|
},
|
||||||
[&](const ast::IfStatement* i) {
|
[&](const IfStatement* i) {
|
||||||
ProcessStatement(i->body);
|
ProcessStatement(i->body);
|
||||||
ProcessStatement(i->else_statement);
|
ProcessStatement(i->else_statement);
|
||||||
},
|
},
|
||||||
[&](const ast::LoopStatement* l) {
|
[&](const LoopStatement* l) {
|
||||||
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
|
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
|
||||||
ProcessStatement(l->body);
|
ProcessStatement(l->body);
|
||||||
},
|
},
|
||||||
[&](const ast::ReturnStatement* r) {
|
[&](const ReturnStatement* r) {
|
||||||
utils::Vector<const ast::Statement*, 3> stmts;
|
utils::Vector<const Statement*, 3> stmts;
|
||||||
// Set the return flag to signal that we have hit a return.
|
// Set the return flag to signal that we have hit a return.
|
||||||
stmts.Push(b.Assign(b.Expr(flag), true));
|
stmts.Push(b.Assign(b.Expr(flag), true));
|
||||||
if (r->value) {
|
if (r->value) {
|
||||||
|
@ -130,25 +130,25 @@ class State {
|
||||||
}
|
}
|
||||||
ctx.Replace(r, b.Block(std::move(stmts)));
|
ctx.Replace(r, b.Block(std::move(stmts)));
|
||||||
},
|
},
|
||||||
[&](const ast::SwitchStatement* s) {
|
[&](const SwitchStatement* s) {
|
||||||
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
|
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
|
||||||
for (auto* c : s->body) {
|
for (auto* c : s->body) {
|
||||||
ProcessStatement(c);
|
ProcessStatement(c);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::WhileStatement* w) {
|
[&](const WhileStatement* w) {
|
||||||
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
|
TINT_SCOPED_ASSIGNMENT(is_in_loop_or_switch, true);
|
||||||
ProcessStatement(w->body);
|
ProcessStatement(w->body);
|
||||||
},
|
},
|
||||||
[&](Default) { TINT_ICE(Transform, b.Diagnostics()) << "unhandled statement type"; });
|
[&](Default) { TINT_ICE(Transform, b.Diagnostics()) << "unhandled statement type"; });
|
||||||
}
|
}
|
||||||
|
|
||||||
void ProcessBlock(const ast::BlockStatement* block) {
|
void ProcessBlock(const BlockStatement* block) {
|
||||||
// We will rebuild the contents of the block statement.
|
// We will rebuild the contents of the block statement.
|
||||||
// We may introduce conditionals around statements that follow a statement with the
|
// We may introduce conditionals around statements that follow a statement with the
|
||||||
// `Return` behavior, so build a stack of statement lists that represent the new
|
// `Return` behavior, so build a stack of statement lists that represent the new
|
||||||
// (potentially nested) conditional blocks.
|
// (potentially nested) conditional blocks.
|
||||||
utils::Vector<utils::Vector<const ast::Statement*, 8>, 8> new_stmts({{}});
|
utils::Vector<utils::Vector<const Statement*, 8>, 8> new_stmts({{}});
|
||||||
|
|
||||||
// Insert variables for the return flag and return value at the top of the function.
|
// Insert variables for the return flag and return value at the top of the function.
|
||||||
if (block == function->body) {
|
if (block == function->body) {
|
||||||
|
@ -173,8 +173,7 @@ class State {
|
||||||
if (is_in_loop_or_switch) {
|
if (is_in_loop_or_switch) {
|
||||||
// We're in a loop/switch, and so we would have inserted a `break`.
|
// We're in a loop/switch, and so we would have inserted a `break`.
|
||||||
// If we've just come out of a loop/switch statement, we need to `break` again.
|
// If we've just come out of a loop/switch statement, we need to `break` again.
|
||||||
if (s->IsAnyOf<ast::LoopStatement, ast::ForLoopStatement,
|
if (s->IsAnyOf<LoopStatement, ForLoopStatement, SwitchStatement>()) {
|
||||||
ast::SwitchStatement>()) {
|
|
||||||
// If the loop only has the 'Return' behavior, we can just unconditionally
|
// If the loop only has the 'Return' behavior, we can just unconditionally
|
||||||
// break. Otherwise check the return flag.
|
// break. Otherwise check the return flag.
|
||||||
if (HasBehavior(ctx.src, s, sem::Behavior::kNext)) {
|
if (HasBehavior(ctx.src, s, sem::Behavior::kNext)) {
|
||||||
|
@ -194,7 +193,7 @@ class State {
|
||||||
|
|
||||||
// Descend the stack of new block statements, wrapping them in conditionals.
|
// Descend the stack of new block statements, wrapping them in conditionals.
|
||||||
while (new_stmts.Length() > 1) {
|
while (new_stmts.Length() > 1) {
|
||||||
const ast::IfStatement* i = nullptr;
|
const IfStatement* i = nullptr;
|
||||||
if (new_stmts.Back().Length() > 0) {
|
if (new_stmts.Back().Length() > 0) {
|
||||||
i = b.If(b.Not(b.Expr(flag)), b.Block(new_stmts.Back()));
|
i = b.If(b.Not(b.Expr(flag)), b.Block(new_stmts.Back()));
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,14 +33,14 @@ TINT_INSTANTIATE_TYPEINFO(tint::ast::transform::ModuleScopeVarToEntryPointParam)
|
||||||
namespace tint::ast::transform {
|
namespace tint::ast::transform {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using StructMemberList = utils::Vector<const ast::StructMember*, 8>;
|
using StructMemberList = utils::Vector<const StructMember*, 8>;
|
||||||
|
|
||||||
// The name of the struct member for arrays that are wrapped in structures.
|
// The name of the struct member for arrays that are wrapped in structures.
|
||||||
const char* kWrappedArrayMemberName = "arr";
|
const char* kWrappedArrayMemberName = "arr";
|
||||||
|
|
||||||
bool ShouldRun(const Program* program) {
|
bool ShouldRun(const Program* program) {
|
||||||
for (auto* decl : program->AST().GlobalDeclarations()) {
|
for (auto* decl : program->AST().GlobalDeclarations()) {
|
||||||
if (decl->Is<ast::Variable>()) {
|
if (decl->Is<Variable>()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -110,7 +110,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
/// @param workgroup_parameter_members reference to a list of a workgroup struct members
|
/// @param workgroup_parameter_members reference to a list of a workgroup struct members
|
||||||
/// @param is_pointer output signalling whether the replacement is a pointer
|
/// @param is_pointer output signalling whether the replacement is a pointer
|
||||||
/// @param is_wrapped output signalling whether the replacement is wrapped in a struct
|
/// @param is_wrapped output signalling whether the replacement is wrapped in a struct
|
||||||
void ProcessVariableInEntryPoint(const ast::Function* func,
|
void ProcessVariableInEntryPoint(const Function* func,
|
||||||
const sem::Variable* var,
|
const sem::Variable* var,
|
||||||
Symbol new_var_symbol,
|
Symbol new_var_symbol,
|
||||||
std::function<Symbol()> workgroup_param,
|
std::function<Symbol()> workgroup_param,
|
||||||
|
@ -128,7 +128,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
// For a texture or sampler variable, redeclare it as an entry point parameter.
|
// For a texture or sampler variable, redeclare it as an entry point parameter.
|
||||||
// Disable entry point parameter validation.
|
// Disable entry point parameter validation.
|
||||||
auto* disable_validation =
|
auto* disable_validation =
|
||||||
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter);
|
ctx.dst->Disable(DisabledValidation::kEntryPointParameter);
|
||||||
auto attrs = ctx.Clone(var->Declaration()->attributes);
|
auto attrs = ctx.Clone(var->Declaration()->attributes);
|
||||||
attrs.Push(disable_validation);
|
attrs.Push(disable_validation);
|
||||||
auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
|
auto* param = ctx.dst->Param(new_var_symbol, store_type(), attrs);
|
||||||
|
@ -141,8 +141,8 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
// Variables into the Storage and Uniform address spaces are redeclared as entry
|
// Variables into the Storage and Uniform address spaces are redeclared as entry
|
||||||
// point parameters with a pointer type.
|
// point parameters with a pointer type.
|
||||||
auto attributes = ctx.Clone(var->Declaration()->attributes);
|
auto attributes = ctx.Clone(var->Declaration()->attributes);
|
||||||
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter));
|
attributes.Push(ctx.dst->Disable(DisabledValidation::kEntryPointParameter));
|
||||||
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace));
|
attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
|
||||||
|
|
||||||
auto param_type = store_type();
|
auto param_type = store_type();
|
||||||
if (auto* arr = ty->As<type::Array>();
|
if (auto* arr = ty->As<type::Array>();
|
||||||
|
@ -190,7 +190,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
is_pointer = true;
|
is_pointer = true;
|
||||||
} else {
|
} else {
|
||||||
auto* disable_validation =
|
auto* disable_validation =
|
||||||
ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace);
|
ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace);
|
||||||
auto* initializer = ctx.Clone(var->Declaration()->initializer);
|
auto* initializer = ctx.Clone(var->Declaration()->initializer);
|
||||||
auto* local_var = ctx.dst->Var(new_var_symbol, store_type(), sc, initializer,
|
auto* local_var = ctx.dst->Var(new_var_symbol, store_type(), sc, initializer,
|
||||||
utils::Vector{disable_validation});
|
utils::Vector{disable_validation});
|
||||||
|
@ -218,7 +218,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
/// @param var the variable
|
/// @param var the variable
|
||||||
/// @param new_var_symbol the symbol to use for the replacement
|
/// @param new_var_symbol the symbol to use for the replacement
|
||||||
/// @param is_pointer output signalling whether the replacement is a pointer or not
|
/// @param is_pointer output signalling whether the replacement is a pointer or not
|
||||||
void ProcessVariableInUserFunction(const ast::Function* func,
|
void ProcessVariableInUserFunction(const Function* func,
|
||||||
const sem::Variable* var,
|
const sem::Variable* var,
|
||||||
Symbol new_var_symbol,
|
Symbol new_var_symbol,
|
||||||
bool& is_pointer) {
|
bool& is_pointer) {
|
||||||
|
@ -247,7 +247,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use a pointer for non-handle types.
|
// Use a pointer for non-handle types.
|
||||||
utils::Vector<const ast::Attribute*, 2> attributes;
|
utils::Vector<const Attribute*, 2> attributes;
|
||||||
if (!ty->is_handle()) {
|
if (!ty->is_handle()) {
|
||||||
param_type = sc == builtin::AddressSpace::kStorage
|
param_type = sc == builtin::AddressSpace::kStorage
|
||||||
? ctx.dst->ty.pointer(param_type, sc, var->Access())
|
? ctx.dst->ty.pointer(param_type, sc, var->Access())
|
||||||
|
@ -255,9 +255,8 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
is_pointer = true;
|
is_pointer = true;
|
||||||
|
|
||||||
// Disable validation of the parameter's address space and of arguments passed to it.
|
// Disable validation of the parameter's address space and of arguments passed to it.
|
||||||
attributes.Push(ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace));
|
attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace));
|
||||||
attributes.Push(
|
attributes.Push(ctx.dst->Disable(DisabledValidation::kIgnoreInvalidPointerArgument));
|
||||||
ctx.dst->Disable(ast::DisabledValidation::kIgnoreInvalidPointerArgument));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redeclare the variable as a parameter.
|
// Redeclare the variable as a parameter.
|
||||||
|
@ -271,18 +270,18 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
/// @param new_var the symbol to use for replacement
|
/// @param new_var the symbol to use for replacement
|
||||||
/// @param is_pointer true if `new_var` is a pointer to the new variable
|
/// @param is_pointer true if `new_var` is a pointer to the new variable
|
||||||
/// @param member_name if valid, the name of the struct member that holds this variable
|
/// @param member_name if valid, the name of the struct member that holds this variable
|
||||||
void ReplaceUsesInFunction(const ast::Function* func,
|
void ReplaceUsesInFunction(const Function* func,
|
||||||
const sem::Variable* var,
|
const sem::Variable* var,
|
||||||
Symbol new_var,
|
Symbol new_var,
|
||||||
bool is_pointer,
|
bool is_pointer,
|
||||||
Symbol member_name) {
|
Symbol member_name) {
|
||||||
for (auto* user : var->Users()) {
|
for (auto* user : var->Users()) {
|
||||||
if (user->Stmt()->Function()->Declaration() == func) {
|
if (user->Stmt()->Function()->Declaration() == func) {
|
||||||
const ast::Expression* expr = ctx.dst->Expr(new_var);
|
const Expression* expr = ctx.dst->Expr(new_var);
|
||||||
if (is_pointer) {
|
if (is_pointer) {
|
||||||
// If this identifier is used by an address-of operator, just remove the
|
// If this identifier is used by an address-of operator, just remove the
|
||||||
// address-of instead of adding a deref, since we already have a pointer.
|
// address-of instead of adding a deref, since we already have a pointer.
|
||||||
auto* ident = user->Declaration()->As<ast::IdentifierExpression>();
|
auto* ident = user->Declaration()->As<IdentifierExpression>();
|
||||||
if (ident_to_address_of_.count(ident) && !member_name.IsValid()) {
|
if (ident_to_address_of_.count(ident) && !member_name.IsValid()) {
|
||||||
ctx.Replace(ident_to_address_of_[ident], expr);
|
ctx.Replace(ident_to_address_of_[ident], expr);
|
||||||
continue;
|
continue;
|
||||||
|
@ -302,19 +301,19 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
/// Process the module.
|
/// Process the module.
|
||||||
void Process() {
|
void Process() {
|
||||||
// Predetermine the list of function calls that need to be replaced.
|
// Predetermine the list of function calls that need to be replaced.
|
||||||
using CallList = utils::Vector<const ast::CallExpression*, 8>;
|
using CallList = utils::Vector<const CallExpression*, 8>;
|
||||||
std::unordered_map<const ast::Function*, CallList> calls_to_replace;
|
std::unordered_map<const Function*, CallList> calls_to_replace;
|
||||||
|
|
||||||
utils::Vector<const ast::Function*, 8> functions_to_process;
|
utils::Vector<const Function*, 8> functions_to_process;
|
||||||
|
|
||||||
// Collect private variables into a single structure.
|
// Collect private variables into a single structure.
|
||||||
StructMemberList private_struct_members;
|
StructMemberList private_struct_members;
|
||||||
utils::Vector<std::function<const ast::AssignmentStatement*()>, 4> private_initializers;
|
utils::Vector<std::function<const AssignmentStatement*()>, 4> private_initializers;
|
||||||
std::unordered_set<const ast::Function*> uses_privates;
|
std::unordered_set<const Function*> uses_privates;
|
||||||
|
|
||||||
// Build a list of functions that transitively reference any module-scope variables.
|
// Build a list of functions that transitively reference any module-scope variables.
|
||||||
for (auto* decl : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) {
|
for (auto* decl : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) {
|
||||||
if (auto* var = decl->As<ast::Var>()) {
|
if (auto* var = decl->As<Var>()) {
|
||||||
auto* sem_var = ctx.src->Sem().Get(var);
|
auto* sem_var = ctx.src->Sem().Get(var);
|
||||||
if (sem_var->AddressSpace() == builtin::AddressSpace::kPrivate) {
|
if (sem_var->AddressSpace() == builtin::AddressSpace::kPrivate) {
|
||||||
// Create a member in the private variable struct.
|
// Create a member in the private variable struct.
|
||||||
|
@ -335,7 +334,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* func_ast = decl->As<ast::Function>();
|
auto* func_ast = decl->As<Function>();
|
||||||
if (!func_ast) {
|
if (!func_ast) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -376,11 +375,11 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
// TODO(jrprice): We should add support for bidirectional SEM tree traversal so that we can
|
// TODO(jrprice): We should add support for bidirectional SEM tree traversal so that we can
|
||||||
// do this on the fly instead.
|
// do this on the fly instead.
|
||||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||||
auto* address_of = node->As<ast::UnaryOpExpression>();
|
auto* address_of = node->As<UnaryOpExpression>();
|
||||||
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
|
if (!address_of || address_of->op != UnaryOp::kAddressOf) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (auto* ident = address_of->expr->As<ast::IdentifierExpression>()) {
|
if (auto* ident = address_of->expr->As<IdentifierExpression>()) {
|
||||||
ident_to_address_of_[ident] = address_of;
|
ident_to_address_of_[ident] = address_of;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -414,11 +413,11 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
if (uses_privates.count(func_ast)) {
|
if (uses_privates.count(func_ast)) {
|
||||||
if (is_entry_point) {
|
if (is_entry_point) {
|
||||||
// Create a local declaration for the private variable struct.
|
// Create a local declaration for the private variable struct.
|
||||||
auto* var = ctx.dst->Var(
|
auto* var =
|
||||||
PrivateStructVariableName(), ctx.dst->ty(PrivateStructName()),
|
ctx.dst->Var(PrivateStructVariableName(), ctx.dst->ty(PrivateStructName()),
|
||||||
builtin::AddressSpace::kPrivate,
|
builtin::AddressSpace::kPrivate,
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace),
|
ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace),
|
||||||
});
|
});
|
||||||
ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(var));
|
ctx.InsertFront(func_ast->body->statements, ctx.dst->Decl(var));
|
||||||
|
|
||||||
|
@ -482,7 +481,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
// Allow pointer aliasing if needed.
|
// Allow pointer aliasing if needed.
|
||||||
if (needs_pointer_aliasing) {
|
if (needs_pointer_aliasing) {
|
||||||
ctx.InsertBack(func_ast->attributes,
|
ctx.InsertBack(func_ast->attributes,
|
||||||
ctx.dst->Disable(ast::DisabledValidation::kIgnorePointerAliasing));
|
ctx.dst->Disable(DisabledValidation::kIgnorePointerAliasing));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!workgroup_parameter_members.IsEmpty()) {
|
if (!workgroup_parameter_members.IsEmpty()) {
|
||||||
|
@ -492,11 +491,11 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
ctx.dst->Structure(ctx.dst->Sym(), std::move(workgroup_parameter_members));
|
ctx.dst->Structure(ctx.dst->Sym(), std::move(workgroup_parameter_members));
|
||||||
auto param_type =
|
auto param_type =
|
||||||
ctx.dst->ty.pointer(ctx.dst->ty.Of(str), builtin::AddressSpace::kWorkgroup);
|
ctx.dst->ty.pointer(ctx.dst->ty.Of(str), builtin::AddressSpace::kWorkgroup);
|
||||||
auto* param = ctx.dst->Param(
|
auto* param =
|
||||||
workgroup_param(), param_type,
|
ctx.dst->Param(workgroup_param(), param_type,
|
||||||
utils::Vector{
|
utils::Vector{
|
||||||
ctx.dst->Disable(ast::DisabledValidation::kEntryPointParameter),
|
ctx.dst->Disable(DisabledValidation::kEntryPointParameter),
|
||||||
ctx.dst->Disable(ast::DisabledValidation::kIgnoreAddressSpace),
|
ctx.dst->Disable(DisabledValidation::kIgnoreAddressSpace),
|
||||||
});
|
});
|
||||||
ctx.InsertFront(func_ast->params, param);
|
ctx.InsertFront(func_ast->params, param);
|
||||||
}
|
}
|
||||||
|
@ -508,7 +507,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
|
|
||||||
// Pass the private variable struct pointer if needed.
|
// Pass the private variable struct pointer if needed.
|
||||||
if (uses_privates.count(target_sem->Declaration())) {
|
if (uses_privates.count(target_sem->Declaration())) {
|
||||||
const ast::Expression* arg = ctx.dst->Expr(PrivateStructVariableName());
|
const Expression* arg = ctx.dst->Expr(PrivateStructVariableName());
|
||||||
if (is_entry_point) {
|
if (is_entry_point) {
|
||||||
arg = ctx.dst->AddressOf(arg);
|
arg = ctx.dst->AddressOf(arg);
|
||||||
}
|
}
|
||||||
|
@ -531,7 +530,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
|
|
||||||
auto new_var = it->second;
|
auto new_var = it->second;
|
||||||
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
|
bool is_handle = target_var->Type()->UnwrapRef()->is_handle();
|
||||||
const ast::Expression* arg = ctx.dst->Expr(new_var.symbol);
|
const Expression* arg = ctx.dst->Expr(new_var.symbol);
|
||||||
if (new_var.is_wrapped) {
|
if (new_var.is_wrapped) {
|
||||||
// The variable is wrapped in a struct, so we need to pass a pointer to the
|
// The variable is wrapped in a struct, so we need to pass a pointer to the
|
||||||
// struct member instead.
|
// struct member instead.
|
||||||
|
@ -577,8 +576,7 @@ struct ModuleScopeVarToEntryPointParam::State {
|
||||||
std::unordered_set<const sem::Struct*> cloned_structs_;
|
std::unordered_set<const sem::Struct*> cloned_structs_;
|
||||||
|
|
||||||
// Map from identifier expression to the address-of expression that uses it.
|
// Map from identifier expression to the address-of expression that uses it.
|
||||||
std::unordered_map<const ast::IdentifierExpression*, const ast::UnaryOpExpression*>
|
std::unordered_map<const IdentifierExpression*, const UnaryOpExpression*> ident_to_address_of_;
|
||||||
ident_to_address_of_;
|
|
||||||
|
|
||||||
// The name of the structure that contains all the module-scope private variables.
|
// The name of the structure that contains all the module-scope private variables.
|
||||||
Symbol private_struct_name;
|
Symbol private_struct_name;
|
||||||
|
|
|
@ -143,7 +143,7 @@ struct MultiplanarExternalTexture::State {
|
||||||
|
|
||||||
// Replace the original texture_external binding with a texture_2d<f32> binding.
|
// Replace the original texture_external binding with a texture_2d<f32> binding.
|
||||||
auto cloned_attributes = ctx.Clone(global->attributes);
|
auto cloned_attributes = ctx.Clone(global->attributes);
|
||||||
const ast::Expression* cloned_initializer = ctx.Clone(global->initializer);
|
const Expression* cloned_initializer = ctx.Clone(global->initializer);
|
||||||
|
|
||||||
auto* replacement =
|
auto* replacement =
|
||||||
b.Var(syms.plane_0, b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32()),
|
b.Var(syms.plane_0, b.ty.sampled_texture(type::TextureDimension::k2d, b.ty.f32()),
|
||||||
|
@ -153,7 +153,7 @@ struct MultiplanarExternalTexture::State {
|
||||||
|
|
||||||
// We must update all the texture_external parameters for user declared functions.
|
// We must update all the texture_external parameters for user declared functions.
|
||||||
for (auto* fn : ctx.src->AST().Functions()) {
|
for (auto* fn : ctx.src->AST().Functions()) {
|
||||||
for (const ast::Variable* param : fn->params) {
|
for (const Variable* param : fn->params) {
|
||||||
if (auto* sem_var = sem.Get(param)) {
|
if (auto* sem_var = sem.Get(param)) {
|
||||||
if (!sem_var->Type()->UnwrapRef()->Is<type::ExternalTexture>()) {
|
if (!sem_var->Type()->UnwrapRef()->Is<type::ExternalTexture>()) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -184,7 +184,7 @@ struct MultiplanarExternalTexture::State {
|
||||||
|
|
||||||
// Transform the external texture builtin calls into calls to the external texture
|
// Transform the external texture builtin calls into calls to the external texture
|
||||||
// functions.
|
// functions.
|
||||||
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
|
ctx.ReplaceAll([&](const CallExpression* expr) -> const CallExpression* {
|
||||||
auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
auto* call = sem.Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||||
auto* builtin = call->Target()->As<sem::Builtin>();
|
auto* builtin = call->Target()->As<sem::Builtin>();
|
||||||
|
|
||||||
|
@ -305,10 +305,10 @@ struct MultiplanarExternalTexture::State {
|
||||||
/// @param call_type determines which function body to generate
|
/// @param call_type determines which function body to generate
|
||||||
/// @returns a statement list that makes of the body of the chosen function
|
/// @returns a statement list that makes of the body of the chosen function
|
||||||
auto buildTextureBuiltinBody(builtin::Function call_type) {
|
auto buildTextureBuiltinBody(builtin::Function call_type) {
|
||||||
utils::Vector<const ast::Statement*, 16> stmts;
|
utils::Vector<const Statement*, 16> stmts;
|
||||||
const ast::CallExpression* single_plane_call = nullptr;
|
const CallExpression* single_plane_call = nullptr;
|
||||||
const ast::CallExpression* plane_0_call = nullptr;
|
const CallExpression* plane_0_call = nullptr;
|
||||||
const ast::CallExpression* plane_1_call = nullptr;
|
const CallExpression* plane_1_call = nullptr;
|
||||||
switch (call_type) {
|
switch (call_type) {
|
||||||
case builtin::Function::kTextureSampleBaseClampToEdge:
|
case builtin::Function::kTextureSampleBaseClampToEdge:
|
||||||
stmts.Push(b.Decl(b.Let(
|
stmts.Push(b.Decl(b.Let(
|
||||||
|
@ -395,9 +395,9 @@ struct MultiplanarExternalTexture::State {
|
||||||
/// @param expr the call expression being transformed
|
/// @param expr the call expression being transformed
|
||||||
/// @param syms the expanded symbols to be used in the new call
|
/// @param syms the expanded symbols to be used in the new call
|
||||||
/// @returns a call expression to textureSampleExternal
|
/// @returns a call expression to textureSampleExternal
|
||||||
const ast::CallExpression* createTextureSampleBaseClampToEdge(const ast::CallExpression* expr,
|
const CallExpression* createTextureSampleBaseClampToEdge(const CallExpression* expr,
|
||||||
NewBindingSymbols syms) {
|
NewBindingSymbols syms) {
|
||||||
const ast::Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
|
const Expression* plane_0_binding_param = ctx.Clone(expr->args[0]);
|
||||||
|
|
||||||
if (TINT_UNLIKELY(expr->args.Length() != 3)) {
|
if (TINT_UNLIKELY(expr->args.Length() != 3)) {
|
||||||
TINT_ICE(Transform, b.Diagnostics())
|
TINT_ICE(Transform, b.Diagnostics())
|
||||||
|
@ -443,7 +443,7 @@ struct MultiplanarExternalTexture::State {
|
||||||
/// @param call the call expression being transformed
|
/// @param call the call expression being transformed
|
||||||
/// @param syms the expanded symbols to be used in the new call
|
/// @param syms the expanded symbols to be used in the new call
|
||||||
/// @returns a call expression to textureLoadExternal
|
/// @returns a call expression to textureLoadExternal
|
||||||
const ast::CallExpression* createTextureLoad(const sem::Call* call, NewBindingSymbols syms) {
|
const CallExpression* createTextureLoad(const sem::Call* call, NewBindingSymbols syms) {
|
||||||
if (TINT_UNLIKELY(call->Arguments().Length() != 2)) {
|
if (TINT_UNLIKELY(call->Arguments().Length() != 2)) {
|
||||||
TINT_ICE(Transform, b.Diagnostics())
|
TINT_ICE(Transform, b.Diagnostics())
|
||||||
<< "expected textureLoad call with a texture_external to have 2 arguments, found "
|
<< "expected textureLoad call with a texture_external to have 2 arguments, found "
|
||||||
|
|
|
@ -33,7 +33,7 @@ namespace {
|
||||||
|
|
||||||
bool ShouldRun(const Program* program) {
|
bool ShouldRun(const Program* program) {
|
||||||
for (auto* node : program->ASTNodes().Objects()) {
|
for (auto* node : program->ASTNodes().Objects()) {
|
||||||
if (auto* attr = node->As<ast::BuiltinAttribute>()) {
|
if (auto* attr = node->As<BuiltinAttribute>()) {
|
||||||
if (program->Sem().Get(attr)->Value() == builtin::BuiltinValue::kNumWorkgroups) {
|
if (program->Sem().Get(attr)->Value() == builtin::BuiltinValue::kNumWorkgroups) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -86,7 +86,7 @@ Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
|
||||||
std::unordered_set<Accessor, Accessor::Hasher> to_replace;
|
std::unordered_set<Accessor, Accessor::Hasher> to_replace;
|
||||||
for (auto* func : src->AST().Functions()) {
|
for (auto* func : src->AST().Functions()) {
|
||||||
// num_workgroups is only valid for compute stages.
|
// num_workgroups is only valid for compute stages.
|
||||||
if (func->PipelineStage() != ast::PipelineStage::kCompute) {
|
if (func->PipelineStage() != PipelineStage::kCompute) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
|
||||||
|
|
||||||
// Get (or create, on first call) the uniform buffer that will receive the
|
// Get (or create, on first call) the uniform buffer that will receive the
|
||||||
// number of workgroups.
|
// number of workgroups.
|
||||||
const ast::Variable* num_workgroups_ubo = nullptr;
|
const Variable* num_workgroups_ubo = nullptr;
|
||||||
auto get_ubo = [&]() {
|
auto get_ubo = [&]() {
|
||||||
if (!num_workgroups_ubo) {
|
if (!num_workgroups_ubo) {
|
||||||
auto* num_workgroups_struct =
|
auto* num_workgroups_struct =
|
||||||
|
@ -166,11 +166,11 @@ Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
|
||||||
// Now replace all the places where the builtins are accessed with the value
|
// Now replace all the places where the builtins are accessed with the value
|
||||||
// loaded from the uniform buffer.
|
// loaded from the uniform buffer.
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
auto* accessor = node->As<ast::MemberAccessorExpression>();
|
auto* accessor = node->As<MemberAccessorExpression>();
|
||||||
if (!accessor) {
|
if (!accessor) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto* ident = accessor->object->As<ast::IdentifierExpression>();
|
auto* ident = accessor->object->As<IdentifierExpression>();
|
||||||
if (!ident) {
|
if (!ident) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
|
@ -94,7 +94,7 @@ struct PackedVec3::State {
|
||||||
/// Create a `__packed_vec3` type with the same element type as `ty`.
|
/// Create a `__packed_vec3` type with the same element type as `ty`.
|
||||||
/// @param ty a three-element vector type
|
/// @param ty a three-element vector type
|
||||||
/// @returns the new AST type
|
/// @returns the new AST type
|
||||||
ast::Type MakePackedVec3(const type::Type* ty) {
|
Type MakePackedVec3(const type::Type* ty) {
|
||||||
auto* vec = ty->As<type::Vector>();
|
auto* vec = ty->As<type::Vector>();
|
||||||
TINT_ASSERT(Transform, vec != nullptr && vec->Width() == 3);
|
TINT_ASSERT(Transform, vec != nullptr && vec->Width() == 3);
|
||||||
return b.ty(builtin::Builtin::kPackedVec3, CreateASTTypeFor(ctx, vec->type()));
|
return b.ty(builtin::Builtin::kPackedVec3, CreateASTTypeFor(ctx, vec->type()));
|
||||||
|
@ -109,10 +109,10 @@ struct PackedVec3::State {
|
||||||
/// @param ty the type to rewrite
|
/// @param ty the type to rewrite
|
||||||
/// @param array_element `true` if this is being called for the element of an array
|
/// @param array_element `true` if this is being called for the element of an array
|
||||||
/// @returns the new AST type, or nullptr if rewriting was not necessary
|
/// @returns the new AST type, or nullptr if rewriting was not necessary
|
||||||
ast::Type RewriteType(const type::Type* ty, bool array_element = false) {
|
Type RewriteType(const type::Type* ty, bool array_element = false) {
|
||||||
return Switch(
|
return Switch(
|
||||||
ty,
|
ty,
|
||||||
[&](const type::Vector* vec) -> ast::Type {
|
[&](const type::Vector* vec) -> Type {
|
||||||
if (IsVec3(vec)) {
|
if (IsVec3(vec)) {
|
||||||
if (array_element) {
|
if (array_element) {
|
||||||
// Create a struct with a single `__packed_vec3` member.
|
// Create a struct with a single `__packed_vec3` member.
|
||||||
|
@ -134,7 +134,7 @@ struct PackedVec3::State {
|
||||||
}
|
}
|
||||||
return {};
|
return {};
|
||||||
},
|
},
|
||||||
[&](const type::Matrix* mat) -> ast::Type {
|
[&](const type::Matrix* mat) -> Type {
|
||||||
// Rewrite the matrix as an array of columns that use the aligned wrapper struct.
|
// Rewrite the matrix as an array of columns that use the aligned wrapper struct.
|
||||||
auto new_col_type = RewriteType(mat->ColumnType(), /* array_element */ true);
|
auto new_col_type = RewriteType(mat->ColumnType(), /* array_element */ true);
|
||||||
if (new_col_type) {
|
if (new_col_type) {
|
||||||
|
@ -142,11 +142,11 @@ struct PackedVec3::State {
|
||||||
}
|
}
|
||||||
return {};
|
return {};
|
||||||
},
|
},
|
||||||
[&](const type::Array* arr) -> ast::Type {
|
[&](const type::Array* arr) -> Type {
|
||||||
// Rewrite the array with the modified element type.
|
// Rewrite the array with the modified element type.
|
||||||
auto new_type = RewriteType(arr->ElemType(), /* array_element */ true);
|
auto new_type = RewriteType(arr->ElemType(), /* array_element */ true);
|
||||||
if (new_type) {
|
if (new_type) {
|
||||||
utils::Vector<const ast::Attribute*, 1> attrs;
|
utils::Vector<const Attribute*, 1> attrs;
|
||||||
if (arr->Count()->Is<type::RuntimeArrayCount>()) {
|
if (arr->Count()->Is<type::RuntimeArrayCount>()) {
|
||||||
return b.ty.array(new_type, std::move(attrs));
|
return b.ty.array(new_type, std::move(attrs));
|
||||||
} else if (auto count = arr->ConstantCount()) {
|
} else if (auto count = arr->ConstantCount()) {
|
||||||
|
@ -159,21 +159,21 @@ struct PackedVec3::State {
|
||||||
}
|
}
|
||||||
return {};
|
return {};
|
||||||
},
|
},
|
||||||
[&](const type::Struct* str) -> ast::Type {
|
[&](const type::Struct* str) -> Type {
|
||||||
if (ContainsVec3(str)) {
|
if (ContainsVec3(str)) {
|
||||||
auto name = rewritten_structs.GetOrCreate(str, [&]() {
|
auto name = rewritten_structs.GetOrCreate(str, [&]() {
|
||||||
utils::Vector<const ast::StructMember*, 4> members;
|
utils::Vector<const StructMember*, 4> members;
|
||||||
for (auto* member : str->Members()) {
|
for (auto* member : str->Members()) {
|
||||||
// If the member type contains a vec3, rewrite it.
|
// If the member type contains a vec3, rewrite it.
|
||||||
auto new_type = RewriteType(member->Type());
|
auto new_type = RewriteType(member->Type());
|
||||||
if (new_type) {
|
if (new_type) {
|
||||||
// Copy the member attributes.
|
// Copy the member attributes.
|
||||||
bool needs_align = true;
|
bool needs_align = true;
|
||||||
utils::Vector<const ast::Attribute*, 4> attributes;
|
utils::Vector<const Attribute*, 4> attributes;
|
||||||
if (auto* sem_mem = member->As<sem::StructMember>()) {
|
if (auto* sem_mem = member->As<sem::StructMember>()) {
|
||||||
for (auto* attr : sem_mem->Declaration()->attributes) {
|
for (auto* attr : sem_mem->Declaration()->attributes) {
|
||||||
if (attr->IsAnyOf<ast::StructMemberAlignAttribute,
|
if (attr->IsAnyOf<StructMemberAlignAttribute,
|
||||||
ast::StructMemberOffsetAttribute>()) {
|
StructMemberOffsetAttribute>()) {
|
||||||
needs_align = false;
|
needs_align = false;
|
||||||
}
|
}
|
||||||
attributes.Push(ctx.Clone(attr));
|
attributes.Push(ctx.Clone(attr));
|
||||||
|
@ -219,12 +219,12 @@ struct PackedVec3::State {
|
||||||
Symbol MakePackUnpackHelper(
|
Symbol MakePackUnpackHelper(
|
||||||
const char* name_prefix,
|
const char* name_prefix,
|
||||||
const type::Type* ty,
|
const type::Type* ty,
|
||||||
const std::function<const ast::Expression*(const ast::Expression*, const type::Type*)>&
|
const std::function<const Expression*(const Expression*, const type::Type*)>&
|
||||||
pack_or_unpack_element,
|
pack_or_unpack_element,
|
||||||
const std::function<ast::Type()>& in_type,
|
const std::function<Type()>& in_type,
|
||||||
const std::function<ast::Type()>& out_type) {
|
const std::function<Type()>& out_type) {
|
||||||
// Allocate a variable to hold the return value of the function.
|
// Allocate a variable to hold the return value of the function.
|
||||||
utils::Vector<const ast::Statement*, 4> statements;
|
utils::Vector<const Statement*, 4> statements;
|
||||||
statements.Push(b.Decl(b.Var("result", out_type())));
|
statements.Push(b.Decl(b.Var("result", out_type())));
|
||||||
|
|
||||||
// Helper that generates a loop to copy and pack/unpack elements of an array to the result:
|
// Helper that generates a loop to copy and pack/unpack elements of an array to the result:
|
||||||
|
@ -256,7 +256,7 @@ struct PackedVec3::State {
|
||||||
[&](const type::Struct* str) {
|
[&](const type::Struct* str) {
|
||||||
// Copy the struct members over one at a time, packing/unpacking as necessary.
|
// Copy the struct members over one at a time, packing/unpacking as necessary.
|
||||||
for (auto* member : str->Members()) {
|
for (auto* member : str->Members()) {
|
||||||
const ast::Expression* element =
|
const Expression* element =
|
||||||
b.MemberAccessor("in", b.Ident(ctx.Clone(member->Name())));
|
b.MemberAccessor("in", b.Ident(ctx.Clone(member->Name())));
|
||||||
if (ContainsVec3(member->Type())) {
|
if (ContainsVec3(member->Type())) {
|
||||||
element = pack_or_unpack_element(element, member->Type());
|
element = pack_or_unpack_element(element, member->Type());
|
||||||
|
@ -280,16 +280,16 @@ struct PackedVec3::State {
|
||||||
/// @param expr the composite value expression to unpack
|
/// @param expr the composite value expression to unpack
|
||||||
/// @param ty the unpacked type
|
/// @param ty the unpacked type
|
||||||
/// @returns an expression that holds the unpacked value
|
/// @returns an expression that holds the unpacked value
|
||||||
const ast::Expression* UnpackComposite(const ast::Expression* expr, const type::Type* ty) {
|
const Expression* UnpackComposite(const Expression* expr, const type::Type* ty) {
|
||||||
auto helper = unpack_helpers.GetOrCreate(ty, [&]() {
|
auto helper = unpack_helpers.GetOrCreate(ty, [&]() {
|
||||||
return MakePackUnpackHelper(
|
return MakePackUnpackHelper(
|
||||||
"tint_unpack_vec3_in_composite", ty,
|
"tint_unpack_vec3_in_composite", ty,
|
||||||
[&](const ast::Expression* element,
|
[&](const Expression* element,
|
||||||
const type::Type* element_type) -> const ast::Expression* {
|
const type::Type* element_type) -> const Expression* {
|
||||||
if (element_type->Is<type::Vector>()) {
|
if (element_type->Is<type::Vector>()) {
|
||||||
// Unpack a `__packed_vec3` by casting it to a regular vec3.
|
// Unpack a `__packed_vec3` by casting it to a regular vec3.
|
||||||
// If it is an array element, extract the vector from the wrapper struct.
|
// If it is an array element, extract the vector from the wrapper struct.
|
||||||
if (element->Is<ast::IndexAccessorExpression>()) {
|
if (element->Is<IndexAccessorExpression>()) {
|
||||||
element = b.MemberAccessor(element, kStructMemberName);
|
element = b.MemberAccessor(element, kStructMemberName);
|
||||||
}
|
}
|
||||||
return b.Call(CreateASTTypeFor(ctx, element_type), element);
|
return b.Call(CreateASTTypeFor(ctx, element_type), element);
|
||||||
|
@ -308,17 +308,17 @@ struct PackedVec3::State {
|
||||||
/// @param expr the composite value expression to pack
|
/// @param expr the composite value expression to pack
|
||||||
/// @param ty the unpacked type
|
/// @param ty the unpacked type
|
||||||
/// @returns an expression that holds the packed value
|
/// @returns an expression that holds the packed value
|
||||||
const ast::Expression* PackComposite(const ast::Expression* expr, const type::Type* ty) {
|
const Expression* PackComposite(const Expression* expr, const type::Type* ty) {
|
||||||
auto helper = pack_helpers.GetOrCreate(ty, [&]() {
|
auto helper = pack_helpers.GetOrCreate(ty, [&]() {
|
||||||
return MakePackUnpackHelper(
|
return MakePackUnpackHelper(
|
||||||
"tint_pack_vec3_in_composite", ty,
|
"tint_pack_vec3_in_composite", ty,
|
||||||
[&](const ast::Expression* element,
|
[&](const Expression* element,
|
||||||
const type::Type* element_type) -> const ast::Expression* {
|
const type::Type* element_type) -> const Expression* {
|
||||||
if (element_type->Is<type::Vector>()) {
|
if (element_type->Is<type::Vector>()) {
|
||||||
// Pack a vector element by casting it to a packed_vec3.
|
// Pack a vector element by casting it to a packed_vec3.
|
||||||
// If it is an array element, construct a wrapper struct.
|
// If it is an array element, construct a wrapper struct.
|
||||||
auto* packed = b.Call(MakePackedVec3(element_type), element);
|
auto* packed = b.Call(MakePackedVec3(element_type), element);
|
||||||
if (element->Is<ast::IndexAccessorExpression>()) {
|
if (element->Is<IndexAccessorExpression>()) {
|
||||||
packed = b.Call(RewriteType(element_type, true), packed);
|
packed = b.Call(RewriteType(element_type, true), packed);
|
||||||
}
|
}
|
||||||
return packed;
|
return packed;
|
||||||
|
@ -400,7 +400,7 @@ struct PackedVec3::State {
|
||||||
},
|
},
|
||||||
[&](const sem::Statement* stmt) {
|
[&](const sem::Statement* stmt) {
|
||||||
// Pack the RHS of assignment statements that are writing to packed types.
|
// Pack the RHS of assignment statements that are writing to packed types.
|
||||||
if (auto* assign = stmt->Declaration()->As<ast::AssignmentStatement>()) {
|
if (auto* assign = stmt->Declaration()->As<AssignmentStatement>()) {
|
||||||
auto* lhs = sem.GetVal(assign->lhs);
|
auto* lhs = sem.GetVal(assign->lhs);
|
||||||
auto* rhs = sem.GetVal(assign->rhs);
|
auto* rhs = sem.GetVal(assign->rhs);
|
||||||
if (!ContainsVec3(rhs->Type()) ||
|
if (!ContainsVec3(rhs->Type()) ||
|
||||||
|
@ -463,7 +463,7 @@ struct PackedVec3::State {
|
||||||
for (auto* expr : to_unpack_sorted) {
|
for (auto* expr : to_unpack_sorted) {
|
||||||
TINT_ASSERT(Transform, ContainsVec3(expr->Type()));
|
TINT_ASSERT(Transform, ContainsVec3(expr->Type()));
|
||||||
auto* packed = ctx.Clone(expr->Declaration());
|
auto* packed = ctx.Clone(expr->Declaration());
|
||||||
const ast::Expression* unpacked = nullptr;
|
const Expression* unpacked = nullptr;
|
||||||
if (IsVec3(expr->Type())) {
|
if (IsVec3(expr->Type())) {
|
||||||
if (expr->UnwrapLoad()->Is<sem::IndexAccessorExpression>()) {
|
if (expr->UnwrapLoad()->Is<sem::IndexAccessorExpression>()) {
|
||||||
// If we are unpacking a vec3 from an array element, extract the vector from the
|
// If we are unpacking a vec3 from an array element, extract the vector from the
|
||||||
|
@ -484,7 +484,7 @@ struct PackedVec3::State {
|
||||||
for (auto* expr : to_pack_sorted) {
|
for (auto* expr : to_pack_sorted) {
|
||||||
TINT_ASSERT(Transform, ContainsVec3(expr->Type()));
|
TINT_ASSERT(Transform, ContainsVec3(expr->Type()));
|
||||||
auto* unpacked = ctx.Clone(expr->Declaration());
|
auto* unpacked = ctx.Clone(expr->Declaration());
|
||||||
const ast::Expression* packed = nullptr;
|
const Expression* packed = nullptr;
|
||||||
if (IsVec3(expr->Type())) {
|
if (IsVec3(expr->Type())) {
|
||||||
// Cast the regular vec3 to a packed vector type.
|
// Cast the regular vec3 to a packed vector type.
|
||||||
packed = b.Call(MakePackedVec3(expr->Type()), unpacked);
|
packed = b.Call(MakePackedVec3(expr->Type()), unpacked);
|
||||||
|
|
|
@ -33,8 +33,8 @@ namespace tint::ast::transform {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void CreatePadding(utils::Vector<const ast::StructMember*, 8>* new_members,
|
void CreatePadding(utils::Vector<const StructMember*, 8>* new_members,
|
||||||
utils::Hashset<const ast::StructMember*, 8>* padding_members,
|
utils::Hashset<const StructMember*, 8>* padding_members,
|
||||||
ProgramBuilder* b,
|
ProgramBuilder* b,
|
||||||
uint32_t bytes) {
|
uint32_t bytes) {
|
||||||
const size_t count = bytes / 4u;
|
const size_t count = bytes / 4u;
|
||||||
|
@ -59,17 +59,17 @@ Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, Dat
|
||||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||||
auto& sem = src->Sem();
|
auto& sem = src->Sem();
|
||||||
|
|
||||||
std::unordered_map<const ast::Struct*, const ast::Struct*> replaced_structs;
|
std::unordered_map<const Struct*, const Struct*> replaced_structs;
|
||||||
utils::Hashset<const ast::StructMember*, 8> padding_members;
|
utils::Hashset<const StructMember*, 8> padding_members;
|
||||||
|
|
||||||
ctx.ReplaceAll([&](const ast::Struct* ast_str) -> const ast::Struct* {
|
ctx.ReplaceAll([&](const Struct* ast_str) -> const Struct* {
|
||||||
auto* str = sem.Get(ast_str);
|
auto* str = sem.Get(ast_str);
|
||||||
if (!str || !str->IsHostShareable()) {
|
if (!str || !str->IsHostShareable()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
uint32_t offset = 0;
|
uint32_t offset = 0;
|
||||||
bool has_runtime_sized_array = false;
|
bool has_runtime_sized_array = false;
|
||||||
utils::Vector<const ast::StructMember*, 8> new_members;
|
utils::Vector<const StructMember*, 8> new_members;
|
||||||
for (auto* mem : str->Members()) {
|
for (auto* mem : str->Members()) {
|
||||||
auto name = mem->Name().Name();
|
auto name = mem->Name().Name();
|
||||||
|
|
||||||
|
@ -104,19 +104,18 @@ Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, Dat
|
||||||
CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset);
|
CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
utils::Vector<const ast::Attribute*, 1> struct_attribs;
|
utils::Vector<const Attribute*, 1> struct_attribs;
|
||||||
if (!padding_members.IsEmpty()) {
|
if (!padding_members.IsEmpty()) {
|
||||||
struct_attribs =
|
struct_attribs = utils::Vector{b.Disable(DisabledValidation::kIgnoreStructMemberLimit)};
|
||||||
utils::Vector{b.Disable(ast::DisabledValidation::kIgnoreStructMemberLimit)};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* new_struct = b.create<ast::Struct>(ctx.Clone(ast_str->name), std::move(new_members),
|
auto* new_struct = b.create<Struct>(ctx.Clone(ast_str->name), std::move(new_members),
|
||||||
std::move(struct_attribs));
|
std::move(struct_attribs));
|
||||||
replaced_structs[ast_str] = new_struct;
|
replaced_structs[ast_str] = new_struct;
|
||||||
return new_struct;
|
return new_struct;
|
||||||
});
|
});
|
||||||
|
|
||||||
ctx.ReplaceAll([&](const ast::CallExpression* ast_call) -> const ast::CallExpression* {
|
ctx.ReplaceAll([&](const CallExpression* ast_call) -> const CallExpression* {
|
||||||
if (ast_call->args.Length() == 0) {
|
if (ast_call->args.Length() == 0) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -139,7 +138,7 @@ Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, Dat
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
utils::Vector<const ast::Expression*, 8> new_args;
|
utils::Vector<const Expression*, 8> new_args;
|
||||||
|
|
||||||
auto* arg = ast_call->args.begin();
|
auto* arg = ast_call->args.begin();
|
||||||
for (auto* member : new_struct->members) {
|
for (auto* member : new_struct->members) {
|
||||||
|
|
|
@ -44,13 +44,13 @@ struct PreservePadding::State {
|
||||||
/// @returns the ApplyResult
|
/// @returns the ApplyResult
|
||||||
ApplyResult Run() {
|
ApplyResult Run() {
|
||||||
// Gather a list of assignments that need to be transformed.
|
// Gather a list of assignments that need to be transformed.
|
||||||
std::unordered_set<const ast::AssignmentStatement*> assignments_to_transform;
|
std::unordered_set<const AssignmentStatement*> assignments_to_transform;
|
||||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||||
Switch(
|
Switch(
|
||||||
node, //
|
node, //
|
||||||
[&](const ast::AssignmentStatement* assign) {
|
[&](const AssignmentStatement* assign) {
|
||||||
auto* ty = sem.GetVal(assign->lhs)->Type();
|
auto* ty = sem.GetVal(assign->lhs)->Type();
|
||||||
if (assign->lhs->Is<ast::PhonyExpression>()) {
|
if (assign->lhs->Is<PhonyExpression>()) {
|
||||||
// Ignore phony assignment.
|
// Ignore phony assignment.
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -65,7 +65,7 @@ struct PreservePadding::State {
|
||||||
assignments_to_transform.insert(assign);
|
assignments_to_transform.insert(assign);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::Enable* enable) {
|
[&](const Enable* enable) {
|
||||||
// Check if the full pointer parameters extension is already enabled.
|
// Check if the full pointer parameters extension is already enabled.
|
||||||
if (enable->HasExtension(
|
if (enable->HasExtension(
|
||||||
builtin::Extension::kChromiumExperimentalFullPtrParameters)) {
|
builtin::Extension::kChromiumExperimentalFullPtrParameters)) {
|
||||||
|
@ -78,7 +78,7 @@ struct PreservePadding::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Replace all assignments that include padding with decomposed versions.
|
// Replace all assignments that include padding with decomposed versions.
|
||||||
ctx.ReplaceAll([&](const ast::AssignmentStatement* assign) -> const ast::Statement* {
|
ctx.ReplaceAll([&](const AssignmentStatement* assign) -> const Statement* {
|
||||||
if (!assignments_to_transform.count(assign)) {
|
if (!assignments_to_transform.count(assign)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -96,9 +96,9 @@ struct PreservePadding::State {
|
||||||
/// @param lhs the lhs expression (in the destination program)
|
/// @param lhs the lhs expression (in the destination program)
|
||||||
/// @param rhs the rhs expression (in the destination program)
|
/// @param rhs the rhs expression (in the destination program)
|
||||||
/// @returns the statement that performs the assignment
|
/// @returns the statement that performs the assignment
|
||||||
const ast::Statement* MakeAssignment(const type::Type* ty,
|
const Statement* MakeAssignment(const type::Type* ty,
|
||||||
const ast::Expression* lhs,
|
const Expression* lhs,
|
||||||
const ast::Expression* rhs) {
|
const Expression* rhs) {
|
||||||
if (!HasPadding(ty)) {
|
if (!HasPadding(ty)) {
|
||||||
// No padding - use a regular assignment.
|
// No padding - use a regular assignment.
|
||||||
return b.Assign(lhs, rhs);
|
return b.Assign(lhs, rhs);
|
||||||
|
@ -120,7 +120,7 @@ struct PreservePadding::State {
|
||||||
EnableExtension();
|
EnableExtension();
|
||||||
auto helper = helpers.GetOrCreate(ty, [&]() {
|
auto helper = helpers.GetOrCreate(ty, [&]() {
|
||||||
auto helper_name = b.Symbols().New("assign_and_preserve_padding");
|
auto helper_name = b.Symbols().New("assign_and_preserve_padding");
|
||||||
utils::Vector<const ast::Parameter*, 2> params = {
|
utils::Vector<const Parameter*, 2> params = {
|
||||||
b.Param(kDestParamName,
|
b.Param(kDestParamName,
|
||||||
b.ty.pointer(CreateASTTypeFor(ctx, ty), builtin::AddressSpace::kStorage,
|
b.ty.pointer(CreateASTTypeFor(ctx, ty), builtin::AddressSpace::kStorage,
|
||||||
builtin::Access::kReadWrite)),
|
builtin::Access::kReadWrite)),
|
||||||
|
@ -137,7 +137,7 @@ struct PreservePadding::State {
|
||||||
[&](const type::Array* arr) {
|
[&](const type::Array* arr) {
|
||||||
// Call a helper function that uses a loop to assigns each element separately.
|
// Call a helper function that uses a loop to assigns each element separately.
|
||||||
return call_helper([&]() {
|
return call_helper([&]() {
|
||||||
utils::Vector<const ast::Statement*, 8> body;
|
utils::Vector<const Statement*, 8> body;
|
||||||
auto* idx = b.Var("i", b.Expr(0_u));
|
auto* idx = b.Var("i", b.Expr(0_u));
|
||||||
body.Push(
|
body.Push(
|
||||||
b.For(b.Decl(idx), b.LessThan(idx, u32(arr->ConstantCount().value())),
|
b.For(b.Decl(idx), b.LessThan(idx, u32(arr->ConstantCount().value())),
|
||||||
|
@ -151,7 +151,7 @@ struct PreservePadding::State {
|
||||||
[&](const type::Matrix* mat) {
|
[&](const type::Matrix* mat) {
|
||||||
// Call a helper function that assigns each column separately.
|
// Call a helper function that assigns each column separately.
|
||||||
return call_helper([&]() {
|
return call_helper([&]() {
|
||||||
utils::Vector<const ast::Statement*, 4> body;
|
utils::Vector<const Statement*, 4> body;
|
||||||
for (uint32_t i = 0; i < mat->columns(); i++) {
|
for (uint32_t i = 0; i < mat->columns(); i++) {
|
||||||
body.Push(MakeAssignment(mat->ColumnType(),
|
body.Push(MakeAssignment(mat->ColumnType(),
|
||||||
b.IndexAccessor(b.Deref(kDestParamName), u32(i)),
|
b.IndexAccessor(b.Deref(kDestParamName), u32(i)),
|
||||||
|
@ -163,7 +163,7 @@ struct PreservePadding::State {
|
||||||
[&](const type::Struct* str) {
|
[&](const type::Struct* str) {
|
||||||
// Call a helper function that assigns each member separately.
|
// Call a helper function that assigns each member separately.
|
||||||
return call_helper([&]() {
|
return call_helper([&]() {
|
||||||
utils::Vector<const ast::Statement*, 8> body;
|
utils::Vector<const Statement*, 8> body;
|
||||||
for (auto member : str->Members()) {
|
for (auto member : str->Members()) {
|
||||||
auto name = member->Name().Name();
|
auto name = member->Name().Name();
|
||||||
body.Push(MakeAssignment(member->Type(),
|
body.Push(MakeAssignment(member->Type(),
|
||||||
|
|
|
@ -69,7 +69,7 @@ Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* src_var_decl = expr->Stmt()->Declaration()->As<ast::VariableDeclStatement>()) {
|
if (auto* src_var_decl = expr->Stmt()->Declaration()->As<VariableDeclStatement>()) {
|
||||||
if (src_var_decl->variable->initializer == expr->Declaration()) {
|
if (src_var_decl->variable->initializer == expr->Declaration()) {
|
||||||
// This statement is just a variable declaration with the initializer as the
|
// This statement is just a variable declaration with the initializer as the
|
||||||
// initializer value. This is what we're attempting to transform to, and so
|
// initializer value. This is what we're attempting to transform to, and so
|
||||||
|
@ -84,7 +84,7 @@ Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
|
||||||
// A list of expressions that should be hoisted.
|
// A list of expressions that should be hoisted.
|
||||||
utils::Vector<const sem::ValueExpression*, 32> to_hoist;
|
utils::Vector<const sem::ValueExpression*, 32> to_hoist;
|
||||||
// A set of expressions that are constant, which _may_ need to be hoisted.
|
// A set of expressions that are constant, which _may_ need to be hoisted.
|
||||||
utils::Hashset<const ast::Expression*, 32> const_chains;
|
utils::Hashset<const Expression*, 32> const_chains;
|
||||||
|
|
||||||
// Walk the AST nodes. This order guarantees that leaf-expressions are visited first.
|
// Walk the AST nodes. This order guarantees that leaf-expressions are visited first.
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
|
@ -104,11 +104,9 @@ Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
|
||||||
// visit leaf-expressions first, this means the content of const_chains only
|
// visit leaf-expressions first, this means the content of const_chains only
|
||||||
// contains the outer-most constant expressions.
|
// contains the outer-most constant expressions.
|
||||||
auto* expr = sem->Declaration();
|
auto* expr = sem->Declaration();
|
||||||
bool ok = ast::TraverseExpressions(
|
bool ok = TraverseExpressions(expr, b.Diagnostics(), [&](const Expression* child) {
|
||||||
expr, b.Diagnostics(), [&](const ast::Expression* child) {
|
|
||||||
const_chains.Remove(child);
|
const_chains.Remove(child);
|
||||||
return child == expr ? ast::TraverseAction::Descend
|
return child == expr ? TraverseAction::Descend : TraverseAction::Skip;
|
||||||
: ast::TraverseAction::Skip;
|
|
||||||
});
|
});
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
return Program(std::move(b));
|
return Program(std::move(b));
|
||||||
|
|
|
@ -97,49 +97,48 @@ struct DecomposeSideEffects : tint::utils::Castable<PromoteSideEffectsToDecl, Tr
|
||||||
// need to be hoisted to ensure order of evaluation, both those that give
|
// need to be hoisted to ensure order of evaluation, both those that give
|
||||||
// side-effects, as well as those that receive, and returns a set of these
|
// side-effects, as well as those that receive, and returns a set of these
|
||||||
// expressions.
|
// expressions.
|
||||||
using ToHoistSet = std::unordered_set<const ast::Expression*>;
|
using ToHoistSet = std::unordered_set<const Expression*>;
|
||||||
class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
||||||
// Expressions to hoist because they either cause or receive side-effects.
|
// Expressions to hoist because they either cause or receive side-effects.
|
||||||
ToHoistSet to_hoist;
|
ToHoistSet to_hoist;
|
||||||
|
|
||||||
// Used to mark expressions as not or no longer having side-effects.
|
// Used to mark expressions as not or no longer having side-effects.
|
||||||
std::unordered_set<const ast::Expression*> no_side_effects;
|
std::unordered_set<const Expression*> no_side_effects;
|
||||||
|
|
||||||
// Returns true if `expr` has side-effects. Unlike invoking
|
// Returns true if `expr` has side-effects. Unlike invoking
|
||||||
// sem::ValueExpression::HasSideEffects(), this function takes into account whether
|
// sem::ValueExpression::HasSideEffects(), this function takes into account whether
|
||||||
// `expr` has been hoisted, returning false in that case. Furthermore, it
|
// `expr` has been hoisted, returning false in that case. Furthermore, it
|
||||||
// returns the correct result on parent expression nodes by traversing the
|
// returns the correct result on parent expression nodes by traversing the
|
||||||
// expression tree, memoizing the results to ensure O(1) amortized lookup.
|
// expression tree, memoizing the results to ensure O(1) amortized lookup.
|
||||||
bool HasSideEffects(const ast::Expression* expr) {
|
bool HasSideEffects(const Expression* expr) {
|
||||||
if (no_side_effects.count(expr)) {
|
if (no_side_effects.count(expr)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return Switch(
|
return Switch(
|
||||||
expr,
|
expr, [&](const CallExpression* e) -> bool { return sem.Get(e)->HasSideEffects(); },
|
||||||
[&](const ast::CallExpression* e) -> bool { return sem.Get(e)->HasSideEffects(); },
|
[&](const BinaryExpression* e) {
|
||||||
[&](const ast::BinaryExpression* e) {
|
|
||||||
if (HasSideEffects(e->lhs) || HasSideEffects(e->rhs)) {
|
if (HasSideEffects(e->lhs) || HasSideEffects(e->rhs)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
no_side_effects.insert(e);
|
no_side_effects.insert(e);
|
||||||
return false;
|
return false;
|
||||||
},
|
},
|
||||||
[&](const ast::IndexAccessorExpression* e) {
|
[&](const IndexAccessorExpression* e) {
|
||||||
if (HasSideEffects(e->object) || HasSideEffects(e->index)) {
|
if (HasSideEffects(e->object) || HasSideEffects(e->index)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
no_side_effects.insert(e);
|
no_side_effects.insert(e);
|
||||||
return false;
|
return false;
|
||||||
},
|
},
|
||||||
[&](const ast::MemberAccessorExpression* e) {
|
[&](const MemberAccessorExpression* e) {
|
||||||
if (HasSideEffects(e->object)) {
|
if (HasSideEffects(e->object)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
no_side_effects.insert(e);
|
no_side_effects.insert(e);
|
||||||
return false;
|
return false;
|
||||||
},
|
},
|
||||||
[&](const ast::BitcastExpression* e) { //
|
[&](const BitcastExpression* e) { //
|
||||||
if (HasSideEffects(e->expr)) {
|
if (HasSideEffects(e->expr)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -147,22 +146,22 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
||||||
return false;
|
return false;
|
||||||
},
|
},
|
||||||
|
|
||||||
[&](const ast::UnaryOpExpression* e) { //
|
[&](const UnaryOpExpression* e) { //
|
||||||
if (HasSideEffects(e->expr)) {
|
if (HasSideEffects(e->expr)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
no_side_effects.insert(e);
|
no_side_effects.insert(e);
|
||||||
return false;
|
return false;
|
||||||
},
|
},
|
||||||
[&](const ast::IdentifierExpression* e) {
|
[&](const IdentifierExpression* e) {
|
||||||
no_side_effects.insert(e);
|
no_side_effects.insert(e);
|
||||||
return false;
|
return false;
|
||||||
},
|
},
|
||||||
[&](const ast::LiteralExpression* e) {
|
[&](const LiteralExpression* e) {
|
||||||
no_side_effects.insert(e);
|
no_side_effects.insert(e);
|
||||||
return false;
|
return false;
|
||||||
},
|
},
|
||||||
[&](const ast::PhonyExpression* e) {
|
[&](const PhonyExpression* e) {
|
||||||
no_side_effects.insert(e);
|
no_side_effects.insert(e);
|
||||||
return false;
|
return false;
|
||||||
},
|
},
|
||||||
|
@ -173,14 +172,14 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adds `e` to `to_hoist` for hoisting to a let later on.
|
// Adds `e` to `to_hoist` for hoisting to a let later on.
|
||||||
void Hoist(const ast::Expression* e) {
|
void Hoist(const Expression* e) {
|
||||||
no_side_effects.insert(e);
|
no_side_effects.insert(e);
|
||||||
to_hoist.emplace(e);
|
to_hoist.emplace(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hoists any expressions in `maybe_hoist` and clears it
|
// Hoists any expressions in `maybe_hoist` and clears it
|
||||||
template <size_t N>
|
template <size_t N>
|
||||||
void Flush(tint::utils::Vector<const ast::Expression*, N>& maybe_hoist) {
|
void Flush(tint::utils::Vector<const Expression*, N>& maybe_hoist) {
|
||||||
for (auto* m : maybe_hoist) {
|
for (auto* m : maybe_hoist) {
|
||||||
Hoist(m);
|
Hoist(m);
|
||||||
}
|
}
|
||||||
|
@ -198,13 +197,13 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
||||||
// over-hoist the lhs expressions, as these may be be chained to refer to a
|
// over-hoist the lhs expressions, as these may be be chained to refer to a
|
||||||
// single memory location.
|
// single memory location.
|
||||||
template <size_t N>
|
template <size_t N>
|
||||||
bool ProcessExpression(const ast::Expression* expr,
|
bool ProcessExpression(const Expression* expr,
|
||||||
tint::utils::Vector<const ast::Expression*, N>& maybe_hoist) {
|
tint::utils::Vector<const Expression*, N>& maybe_hoist) {
|
||||||
auto process = [&](const ast::Expression* e) -> bool {
|
auto process = [&](const Expression* e) -> bool {
|
||||||
return ProcessExpression(e, maybe_hoist);
|
return ProcessExpression(e, maybe_hoist);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto default_process = [&](const ast::Expression* e) {
|
auto default_process = [&](const Expression* e) {
|
||||||
auto maybe = process(e);
|
auto maybe = process(e);
|
||||||
if (maybe) {
|
if (maybe) {
|
||||||
maybe_hoist.Push(e);
|
maybe_hoist.Push(e);
|
||||||
|
@ -215,7 +214,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto binary_process = [&](const ast::Expression* lhs, const ast::Expression* rhs) {
|
auto binary_process = [&](const Expression* lhs, const Expression* rhs) {
|
||||||
// If neither side causes side-effects, but at least one receives them,
|
// If neither side causes side-effects, but at least one receives them,
|
||||||
// let parent node hoist. This avoids over-hoisting side-effect receivers
|
// let parent node hoist. This avoids over-hoisting side-effect receivers
|
||||||
// of compound binary expressions (e.g. for "((a && b) && c) && f()", we
|
// of compound binary expressions (e.g. for "((a && b) && c) && f()", we
|
||||||
|
@ -235,8 +234,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto accessor_process = [&](const ast::Expression* lhs,
|
auto accessor_process = [&](const Expression* lhs, const Expression* rhs = nullptr) {
|
||||||
const ast::Expression* rhs = nullptr) {
|
|
||||||
auto maybe = process(lhs);
|
auto maybe = process(lhs);
|
||||||
// If lhs is a variable, let parent node hoist otherwise flush it right
|
// If lhs is a variable, let parent node hoist otherwise flush it right
|
||||||
// away. This is to avoid over-hoisting the lhs of accessor chains (e.g.
|
// away. This is to avoid over-hoisting the lhs of accessor chains (e.g.
|
||||||
|
@ -255,7 +253,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
||||||
|
|
||||||
return Switch(
|
return Switch(
|
||||||
expr,
|
expr,
|
||||||
[&](const ast::CallExpression* e) -> bool {
|
[&](const CallExpression* e) -> bool {
|
||||||
// We eagerly flush any variables in maybe_hoist for the current
|
// We eagerly flush any variables in maybe_hoist for the current
|
||||||
// call expression. Then we scope maybe_hoist to the processing of
|
// call expression. Then we scope maybe_hoist to the processing of
|
||||||
// the call args. This ensures that given: g(c, a(0), d) we hoist
|
// the call args. This ensures that given: g(c, a(0), d) we hoist
|
||||||
|
@ -276,7 +274,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
||||||
// no_side_effects() first.
|
// no_side_effects() first.
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
[&](const ast::IdentifierExpression* e) {
|
[&](const IdentifierExpression* e) {
|
||||||
if (auto* sem_e = sem.GetVal(e)) {
|
if (auto* sem_e = sem.GetVal(e)) {
|
||||||
if (auto* var_user = sem_e->UnwrapLoad()->As<sem::VariableUser>()) {
|
if (auto* var_user = sem_e->UnwrapLoad()->As<sem::VariableUser>()) {
|
||||||
// Don't hoist constants.
|
// Don't hoist constants.
|
||||||
|
@ -297,7 +295,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
},
|
},
|
||||||
[&](const ast::BinaryExpression* e) {
|
[&](const BinaryExpression* e) {
|
||||||
if (e->IsLogical() && HasSideEffects(e)) {
|
if (e->IsLogical() && HasSideEffects(e)) {
|
||||||
// Don't hoist children of logical binary expressions with
|
// Don't hoist children of logical binary expressions with
|
||||||
// side-effects. These will be handled by DecomposeState.
|
// side-effects. These will be handled by DecomposeState.
|
||||||
|
@ -307,27 +305,25 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
||||||
}
|
}
|
||||||
return binary_process(e->lhs, e->rhs);
|
return binary_process(e->lhs, e->rhs);
|
||||||
},
|
},
|
||||||
[&](const ast::BitcastExpression* e) { //
|
[&](const BitcastExpression* e) { //
|
||||||
return process(e->expr);
|
return process(e->expr);
|
||||||
},
|
},
|
||||||
[&](const ast::UnaryOpExpression* e) { //
|
[&](const UnaryOpExpression* e) { //
|
||||||
auto r = process(e->expr);
|
auto r = process(e->expr);
|
||||||
// Don't hoist address-of expressions.
|
// Don't hoist address-of expressions.
|
||||||
// E.g. for "g(&b, a(0))", we hoist "a(0)" only.
|
// E.g. for "g(&b, a(0))", we hoist "a(0)" only.
|
||||||
if (e->op == ast::UnaryOp::kAddressOf) {
|
if (e->op == UnaryOp::kAddressOf) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return r;
|
return r;
|
||||||
},
|
},
|
||||||
[&](const ast::IndexAccessorExpression* e) {
|
[&](const IndexAccessorExpression* e) { return accessor_process(e->object, e->index); },
|
||||||
return accessor_process(e->object, e->index);
|
[&](const MemberAccessorExpression* e) { return accessor_process(e->object); },
|
||||||
},
|
[&](const LiteralExpression*) {
|
||||||
[&](const ast::MemberAccessorExpression* e) { return accessor_process(e->object); },
|
|
||||||
[&](const ast::LiteralExpression*) {
|
|
||||||
// Leaf
|
// Leaf
|
||||||
return false;
|
return false;
|
||||||
},
|
},
|
||||||
[&](const ast::PhonyExpression*) {
|
[&](const PhonyExpression*) {
|
||||||
// Leaf
|
// Leaf
|
||||||
return false;
|
return false;
|
||||||
},
|
},
|
||||||
|
@ -338,12 +334,12 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Starts the recursive processing of a statement's expression(s) to hoist side-effects to lets.
|
// Starts the recursive processing of a statement's expression(s) to hoist side-effects to lets.
|
||||||
void ProcessExpression(const ast::Expression* expr) {
|
void ProcessExpression(const Expression* expr) {
|
||||||
if (!expr) {
|
if (!expr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
tint::utils::Vector<const ast::Expression*, 8> maybe_hoist;
|
tint::utils::Vector<const Expression*, 8> maybe_hoist;
|
||||||
ProcessExpression(expr, maybe_hoist);
|
ProcessExpression(expr, maybe_hoist);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -354,31 +350,31 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
|
||||||
// Traverse all statements, recursively processing their expression tree(s)
|
// Traverse all statements, recursively processing their expression tree(s)
|
||||||
// to hoist side-effects to lets.
|
// to hoist side-effects to lets.
|
||||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||||
auto* stmt = node->As<ast::Statement>();
|
auto* stmt = node->As<Statement>();
|
||||||
if (!stmt) {
|
if (!stmt) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
Switch(
|
Switch(
|
||||||
stmt, //
|
stmt, //
|
||||||
[&](const ast::AssignmentStatement* s) {
|
[&](const AssignmentStatement* s) {
|
||||||
tint::utils::Vector<const ast::Expression*, 8> maybe_hoist;
|
tint::utils::Vector<const Expression*, 8> maybe_hoist;
|
||||||
ProcessExpression(s->lhs, maybe_hoist);
|
ProcessExpression(s->lhs, maybe_hoist);
|
||||||
ProcessExpression(s->rhs, maybe_hoist);
|
ProcessExpression(s->rhs, maybe_hoist);
|
||||||
},
|
},
|
||||||
[&](const ast::CallStatement* s) { //
|
[&](const CallStatement* s) { //
|
||||||
ProcessExpression(s->expr);
|
ProcessExpression(s->expr);
|
||||||
},
|
},
|
||||||
[&](const ast::ForLoopStatement* s) { ProcessExpression(s->condition); },
|
[&](const ForLoopStatement* s) { ProcessExpression(s->condition); },
|
||||||
[&](const ast::WhileStatement* s) { ProcessExpression(s->condition); },
|
[&](const WhileStatement* s) { ProcessExpression(s->condition); },
|
||||||
[&](const ast::IfStatement* s) { //
|
[&](const IfStatement* s) { //
|
||||||
ProcessExpression(s->condition);
|
ProcessExpression(s->condition);
|
||||||
},
|
},
|
||||||
[&](const ast::ReturnStatement* s) { //
|
[&](const ReturnStatement* s) { //
|
||||||
ProcessExpression(s->value);
|
ProcessExpression(s->value);
|
||||||
},
|
},
|
||||||
[&](const ast::SwitchStatement* s) { ProcessExpression(s->condition); },
|
[&](const SwitchStatement* s) { ProcessExpression(s->condition); },
|
||||||
[&](const ast::VariableDeclStatement* s) {
|
[&](const VariableDeclStatement* s) {
|
||||||
ProcessExpression(s->variable->initializer);
|
ProcessExpression(s->variable->initializer);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -394,20 +390,20 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
|
||||||
ToHoistSet to_hoist;
|
ToHoistSet to_hoist;
|
||||||
|
|
||||||
// Returns true if `binary_expr` should be decomposed for short-circuit eval.
|
// Returns true if `binary_expr` should be decomposed for short-circuit eval.
|
||||||
bool IsLogicalWithSideEffects(const ast::BinaryExpression* binary_expr) {
|
bool IsLogicalWithSideEffects(const BinaryExpression* binary_expr) {
|
||||||
return binary_expr->IsLogical() && (sem.GetVal(binary_expr->lhs)->HasSideEffects() ||
|
return binary_expr->IsLogical() && (sem.GetVal(binary_expr->lhs)->HasSideEffects() ||
|
||||||
sem.GetVal(binary_expr->rhs)->HasSideEffects());
|
sem.GetVal(binary_expr->rhs)->HasSideEffects());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recursive function used to decompose an expression for short-circuit eval.
|
// Recursive function used to decompose an expression for short-circuit eval.
|
||||||
template <size_t N>
|
template <size_t N>
|
||||||
const ast::Expression* Decompose(const ast::Expression* expr,
|
const Expression* Decompose(const Expression* expr,
|
||||||
tint::utils::Vector<const ast::Statement*, N>* curr_stmts) {
|
tint::utils::Vector<const Statement*, N>* curr_stmts) {
|
||||||
// Helper to avoid passing in same args.
|
// Helper to avoid passing in same args.
|
||||||
auto decompose = [&](auto& e) { return Decompose(e, curr_stmts); };
|
auto decompose = [&](auto& e) { return Decompose(e, curr_stmts); };
|
||||||
|
|
||||||
// Clones `expr`, possibly hoisting it to a let.
|
// Clones `expr`, possibly hoisting it to a let.
|
||||||
auto clone_maybe_hoisted = [&](const ast::Expression* e) -> const ast::Expression* {
|
auto clone_maybe_hoisted = [&](const Expression* e) -> const Expression* {
|
||||||
if (to_hoist.count(e)) {
|
if (to_hoist.count(e)) {
|
||||||
auto name = b.Symbols().New();
|
auto name = b.Symbols().New();
|
||||||
auto* v = b.Let(name, ctx.Clone(e));
|
auto* v = b.Let(name, ctx.Clone(e));
|
||||||
|
@ -420,7 +416,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
|
||||||
|
|
||||||
return Switch(
|
return Switch(
|
||||||
expr,
|
expr,
|
||||||
[&](const ast::BinaryExpression* bin_expr) -> const ast::Expression* {
|
[&](const BinaryExpression* bin_expr) -> const Expression* {
|
||||||
if (!IsLogicalWithSideEffects(bin_expr)) {
|
if (!IsLogicalWithSideEffects(bin_expr)) {
|
||||||
// No short-circuit, emit usual binary expr
|
// No short-circuit, emit usual binary expr
|
||||||
ctx.Replace(bin_expr->lhs, decompose(bin_expr->lhs));
|
ctx.Replace(bin_expr->lhs, decompose(bin_expr->lhs));
|
||||||
|
@ -461,16 +457,16 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
|
||||||
auto name = b.Sym();
|
auto name = b.Sym();
|
||||||
curr_stmts->Push(b.Decl(b.Var(name, decompose(bin_expr->lhs))));
|
curr_stmts->Push(b.Decl(b.Var(name, decompose(bin_expr->lhs))));
|
||||||
|
|
||||||
const ast::Expression* if_cond = nullptr;
|
const Expression* if_cond = nullptr;
|
||||||
if (bin_expr->IsLogicalOr()) {
|
if (bin_expr->IsLogicalOr()) {
|
||||||
if_cond = b.Not(name);
|
if_cond = b.Not(name);
|
||||||
} else {
|
} else {
|
||||||
if_cond = b.Expr(name);
|
if_cond = b.Expr(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
const ast::BlockStatement* if_body = nullptr;
|
const BlockStatement* if_body = nullptr;
|
||||||
{
|
{
|
||||||
tint::utils::Vector<const ast::Statement*, N> stmts;
|
tint::utils::Vector<const Statement*, N> stmts;
|
||||||
TINT_SCOPED_ASSIGNMENT(curr_stmts, &stmts);
|
TINT_SCOPED_ASSIGNMENT(curr_stmts, &stmts);
|
||||||
auto* new_rhs = decompose(bin_expr->rhs);
|
auto* new_rhs = decompose(bin_expr->rhs);
|
||||||
curr_stmts->Push(b.Assign(name, new_rhs));
|
curr_stmts->Push(b.Assign(name, new_rhs));
|
||||||
|
@ -481,36 +477,36 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
|
||||||
|
|
||||||
return b.Expr(name);
|
return b.Expr(name);
|
||||||
},
|
},
|
||||||
[&](const ast::IndexAccessorExpression* idx) {
|
[&](const IndexAccessorExpression* idx) {
|
||||||
ctx.Replace(idx->object, decompose(idx->object));
|
ctx.Replace(idx->object, decompose(idx->object));
|
||||||
ctx.Replace(idx->index, decompose(idx->index));
|
ctx.Replace(idx->index, decompose(idx->index));
|
||||||
return clone_maybe_hoisted(idx);
|
return clone_maybe_hoisted(idx);
|
||||||
},
|
},
|
||||||
[&](const ast::BitcastExpression* bitcast) {
|
[&](const BitcastExpression* bitcast) {
|
||||||
ctx.Replace(bitcast->expr, decompose(bitcast->expr));
|
ctx.Replace(bitcast->expr, decompose(bitcast->expr));
|
||||||
return clone_maybe_hoisted(bitcast);
|
return clone_maybe_hoisted(bitcast);
|
||||||
},
|
},
|
||||||
[&](const ast::CallExpression* call) {
|
[&](const CallExpression* call) {
|
||||||
for (auto* a : call->args) {
|
for (auto* a : call->args) {
|
||||||
ctx.Replace(a, decompose(a));
|
ctx.Replace(a, decompose(a));
|
||||||
}
|
}
|
||||||
return clone_maybe_hoisted(call);
|
return clone_maybe_hoisted(call);
|
||||||
},
|
},
|
||||||
[&](const ast::MemberAccessorExpression* member) {
|
[&](const MemberAccessorExpression* member) {
|
||||||
ctx.Replace(member->object, decompose(member->object));
|
ctx.Replace(member->object, decompose(member->object));
|
||||||
return clone_maybe_hoisted(member);
|
return clone_maybe_hoisted(member);
|
||||||
},
|
},
|
||||||
[&](const ast::UnaryOpExpression* unary) {
|
[&](const UnaryOpExpression* unary) {
|
||||||
ctx.Replace(unary->expr, decompose(unary->expr));
|
ctx.Replace(unary->expr, decompose(unary->expr));
|
||||||
return clone_maybe_hoisted(unary);
|
return clone_maybe_hoisted(unary);
|
||||||
},
|
},
|
||||||
[&](const ast::LiteralExpression* lit) {
|
[&](const LiteralExpression* lit) {
|
||||||
return clone_maybe_hoisted(lit); // Leaf expression, just clone as is
|
return clone_maybe_hoisted(lit); // Leaf expression, just clone as is
|
||||||
},
|
},
|
||||||
[&](const ast::IdentifierExpression* id) {
|
[&](const IdentifierExpression* id) {
|
||||||
return clone_maybe_hoisted(id); // Leaf expression, just clone as is
|
return clone_maybe_hoisted(id); // Leaf expression, just clone as is
|
||||||
},
|
},
|
||||||
[&](const ast::PhonyExpression* phony) {
|
[&](const PhonyExpression* phony) {
|
||||||
return clone_maybe_hoisted(phony); // Leaf expression, just clone as is
|
return clone_maybe_hoisted(phony); // Leaf expression, just clone as is
|
||||||
},
|
},
|
||||||
[&](Default) {
|
[&](Default) {
|
||||||
|
@ -522,8 +518,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
|
||||||
|
|
||||||
// Inserts statements in `stmts` before `stmt`
|
// Inserts statements in `stmts` before `stmt`
|
||||||
template <size_t N>
|
template <size_t N>
|
||||||
void InsertBefore(tint::utils::Vector<const ast::Statement*, N>& stmts,
|
void InsertBefore(tint::utils::Vector<const Statement*, N>& stmts, const Statement* stmt) {
|
||||||
const ast::Statement* stmt) {
|
|
||||||
if (!stmts.IsEmpty()) {
|
if (!stmts.IsEmpty()) {
|
||||||
auto ip = utils::GetInsertionPoint(ctx, stmt);
|
auto ip = utils::GetInsertionPoint(ctx, stmt);
|
||||||
for (auto* s : stmts) {
|
for (auto* s : stmts) {
|
||||||
|
@ -534,86 +529,86 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
|
||||||
|
|
||||||
// Decomposes expressions of `stmt`, returning a replacement statement or
|
// Decomposes expressions of `stmt`, returning a replacement statement or
|
||||||
// nullptr if not replacing it.
|
// nullptr if not replacing it.
|
||||||
const ast::Statement* DecomposeStatement(const ast::Statement* stmt) {
|
const Statement* DecomposeStatement(const Statement* stmt) {
|
||||||
return Switch(
|
return Switch(
|
||||||
stmt,
|
stmt,
|
||||||
[&](const ast::AssignmentStatement* s) -> const ast::Statement* {
|
[&](const AssignmentStatement* s) -> const Statement* {
|
||||||
if (!sem.GetVal(s->lhs)->HasSideEffects() &&
|
if (!sem.GetVal(s->lhs)->HasSideEffects() &&
|
||||||
!sem.GetVal(s->rhs)->HasSideEffects()) {
|
!sem.GetVal(s->rhs)->HasSideEffects()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
// lhs before rhs
|
// lhs before rhs
|
||||||
tint::utils::Vector<const ast::Statement*, 8> stmts;
|
tint::utils::Vector<const Statement*, 8> stmts;
|
||||||
ctx.Replace(s->lhs, Decompose(s->lhs, &stmts));
|
ctx.Replace(s->lhs, Decompose(s->lhs, &stmts));
|
||||||
ctx.Replace(s->rhs, Decompose(s->rhs, &stmts));
|
ctx.Replace(s->rhs, Decompose(s->rhs, &stmts));
|
||||||
InsertBefore(stmts, s);
|
InsertBefore(stmts, s);
|
||||||
return ctx.CloneWithoutTransform(s);
|
return ctx.CloneWithoutTransform(s);
|
||||||
},
|
},
|
||||||
[&](const ast::CallStatement* s) -> const ast::Statement* {
|
[&](const CallStatement* s) -> const Statement* {
|
||||||
if (!sem.Get(s->expr)->HasSideEffects()) {
|
if (!sem.Get(s->expr)->HasSideEffects()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
tint::utils::Vector<const ast::Statement*, 8> stmts;
|
tint::utils::Vector<const Statement*, 8> stmts;
|
||||||
ctx.Replace(s->expr, Decompose(s->expr, &stmts));
|
ctx.Replace(s->expr, Decompose(s->expr, &stmts));
|
||||||
InsertBefore(stmts, s);
|
InsertBefore(stmts, s);
|
||||||
return ctx.CloneWithoutTransform(s);
|
return ctx.CloneWithoutTransform(s);
|
||||||
},
|
},
|
||||||
[&](const ast::ForLoopStatement* s) -> const ast::Statement* {
|
[&](const ForLoopStatement* s) -> const Statement* {
|
||||||
if (!s->condition || !sem.GetVal(s->condition)->HasSideEffects()) {
|
if (!s->condition || !sem.GetVal(s->condition)->HasSideEffects()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
tint::utils::Vector<const ast::Statement*, 8> stmts;
|
tint::utils::Vector<const Statement*, 8> stmts;
|
||||||
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
|
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
|
||||||
InsertBefore(stmts, s);
|
InsertBefore(stmts, s);
|
||||||
return ctx.CloneWithoutTransform(s);
|
return ctx.CloneWithoutTransform(s);
|
||||||
},
|
},
|
||||||
[&](const ast::WhileStatement* s) -> const ast::Statement* {
|
[&](const WhileStatement* s) -> const Statement* {
|
||||||
if (!sem.GetVal(s->condition)->HasSideEffects()) {
|
if (!sem.GetVal(s->condition)->HasSideEffects()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
tint::utils::Vector<const ast::Statement*, 8> stmts;
|
tint::utils::Vector<const Statement*, 8> stmts;
|
||||||
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
|
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
|
||||||
InsertBefore(stmts, s);
|
InsertBefore(stmts, s);
|
||||||
return ctx.CloneWithoutTransform(s);
|
return ctx.CloneWithoutTransform(s);
|
||||||
},
|
},
|
||||||
[&](const ast::IfStatement* s) -> const ast::Statement* {
|
[&](const IfStatement* s) -> const Statement* {
|
||||||
if (!sem.GetVal(s->condition)->HasSideEffects()) {
|
if (!sem.GetVal(s->condition)->HasSideEffects()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
tint::utils::Vector<const ast::Statement*, 8> stmts;
|
tint::utils::Vector<const Statement*, 8> stmts;
|
||||||
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
|
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
|
||||||
InsertBefore(stmts, s);
|
InsertBefore(stmts, s);
|
||||||
return ctx.CloneWithoutTransform(s);
|
return ctx.CloneWithoutTransform(s);
|
||||||
},
|
},
|
||||||
[&](const ast::ReturnStatement* s) -> const ast::Statement* {
|
[&](const ReturnStatement* s) -> const Statement* {
|
||||||
if (!s->value || !sem.GetVal(s->value)->HasSideEffects()) {
|
if (!s->value || !sem.GetVal(s->value)->HasSideEffects()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
tint::utils::Vector<const ast::Statement*, 8> stmts;
|
tint::utils::Vector<const Statement*, 8> stmts;
|
||||||
ctx.Replace(s->value, Decompose(s->value, &stmts));
|
ctx.Replace(s->value, Decompose(s->value, &stmts));
|
||||||
InsertBefore(stmts, s);
|
InsertBefore(stmts, s);
|
||||||
return ctx.CloneWithoutTransform(s);
|
return ctx.CloneWithoutTransform(s);
|
||||||
},
|
},
|
||||||
[&](const ast::SwitchStatement* s) -> const ast::Statement* {
|
[&](const SwitchStatement* s) -> const Statement* {
|
||||||
if (!sem.Get(s->condition)) {
|
if (!sem.Get(s->condition)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
tint::utils::Vector<const ast::Statement*, 8> stmts;
|
tint::utils::Vector<const Statement*, 8> stmts;
|
||||||
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
|
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
|
||||||
InsertBefore(stmts, s);
|
InsertBefore(stmts, s);
|
||||||
return ctx.CloneWithoutTransform(s);
|
return ctx.CloneWithoutTransform(s);
|
||||||
},
|
},
|
||||||
[&](const ast::VariableDeclStatement* s) -> const ast::Statement* {
|
[&](const VariableDeclStatement* s) -> const Statement* {
|
||||||
auto* var = s->variable;
|
auto* var = s->variable;
|
||||||
if (!var->initializer || !sem.GetVal(var->initializer)->HasSideEffects()) {
|
if (!var->initializer || !sem.GetVal(var->initializer)->HasSideEffects()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
tint::utils::Vector<const ast::Statement*, 8> stmts;
|
tint::utils::Vector<const Statement*, 8> stmts;
|
||||||
ctx.Replace(var->initializer, Decompose(var->initializer, &stmts));
|
ctx.Replace(var->initializer, Decompose(var->initializer, &stmts));
|
||||||
InsertBefore(stmts, s);
|
InsertBefore(stmts, s);
|
||||||
return b.Decl(ctx.CloneWithoutTransform(var));
|
return b.Decl(ctx.CloneWithoutTransform(var));
|
||||||
},
|
},
|
||||||
[](Default) -> const ast::Statement* {
|
[](Default) -> const Statement* {
|
||||||
// Other statement types don't have expressions
|
// Other statement types don't have expressions
|
||||||
return nullptr;
|
return nullptr;
|
||||||
});
|
});
|
||||||
|
@ -626,7 +621,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
|
||||||
void Run() {
|
void Run() {
|
||||||
// We replace all BlockStatements as this allows us to iterate over the
|
// We replace all BlockStatements as this allows us to iterate over the
|
||||||
// block statements and ctx.InsertBefore hoisted declarations on them.
|
// block statements and ctx.InsertBefore hoisted declarations on them.
|
||||||
ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* {
|
ctx.ReplaceAll([&](const BlockStatement* block) -> const Statement* {
|
||||||
for (auto* stmt : block->statements) {
|
for (auto* stmt : block->statements) {
|
||||||
if (auto* new_stmt = DecomposeStatement(stmt)) {
|
if (auto* new_stmt = DecomposeStatement(stmt)) {
|
||||||
ctx.Replace(stmt, new_stmt);
|
ctx.Replace(stmt, new_stmt);
|
||||||
|
@ -634,7 +629,7 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
|
||||||
|
|
||||||
// Handle for loops, as they are the only other AST node that
|
// Handle for loops, as they are the only other AST node that
|
||||||
// contains statements outside of BlockStatements.
|
// contains statements outside of BlockStatements.
|
||||||
if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
|
if (auto* fl = stmt->As<ForLoopStatement>()) {
|
||||||
if (auto* new_stmt = DecomposeStatement(fl->initializer)) {
|
if (auto* new_stmt = DecomposeStatement(fl->initializer)) {
|
||||||
ctx.Replace(fl->initializer, new_stmt);
|
ctx.Replace(fl->initializer, new_stmt);
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ struct RemoveContinueInSwitch::State {
|
||||||
bool made_changes = false;
|
bool made_changes = false;
|
||||||
|
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
auto* cont = node->As<ast::ContinueStatement>();
|
auto* cont = node->As<ContinueStatement>();
|
||||||
if (!cont) {
|
if (!cont) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -103,12 +103,12 @@ struct RemoveContinueInSwitch::State {
|
||||||
const sem::Info& sem = src->Sem();
|
const sem::Info& sem = src->Sem();
|
||||||
|
|
||||||
// Map of switch statement to 'tint_continue' variable.
|
// Map of switch statement to 'tint_continue' variable.
|
||||||
std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name;
|
std::unordered_map<const SwitchStatement*, Symbol> switch_to_cont_var_name;
|
||||||
|
|
||||||
// If `cont` is within a switch statement within a loop, returns a pointer to
|
// If `cont` is within a switch statement within a loop, returns a pointer to
|
||||||
// that switch statement.
|
// that switch statement.
|
||||||
static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
|
static const SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
|
||||||
const ast::ContinueStatement* cont) {
|
const ContinueStatement* cont) {
|
||||||
// Find whether first parent is a switch or a loop
|
// Find whether first parent is a switch or a loop
|
||||||
auto* sem_stmt = sem.Get(cont);
|
auto* sem_stmt = sem.Get(cont);
|
||||||
auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
|
auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
|
||||||
|
@ -116,7 +116,7 @@ struct RemoveContinueInSwitch::State {
|
||||||
if (!sem_parent) {
|
if (!sem_parent) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return sem_parent->Declaration()->As<ast::SwitchStatement>();
|
return sem_parent->Declaration()->As<SwitchStatement>();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -53,14 +53,14 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
Switch(
|
Switch(
|
||||||
node,
|
node,
|
||||||
[&](const ast::AssignmentStatement* stmt) {
|
[&](const AssignmentStatement* stmt) {
|
||||||
if (stmt->lhs->Is<ast::PhonyExpression>()) {
|
if (stmt->lhs->Is<PhonyExpression>()) {
|
||||||
made_changes = true;
|
made_changes = true;
|
||||||
|
|
||||||
std::vector<const ast::Expression*> side_effects;
|
std::vector<const Expression*> side_effects;
|
||||||
if (!ast::TraverseExpressions(
|
if (!TraverseExpressions(
|
||||||
stmt->rhs, b.Diagnostics(), [&](const ast::CallExpression* expr) {
|
stmt->rhs, b.Diagnostics(), [&](const CallExpression* expr) {
|
||||||
// ast::CallExpression may map to a function or builtin call
|
// CallExpression may map to a function or builtin call
|
||||||
// (both may have side-effects), or a value constructor or value
|
// (both may have side-effects), or a value constructor or value
|
||||||
// conversion (both do not have side effects).
|
// conversion (both do not have side effects).
|
||||||
auto* call = sem.Get<sem::Call>(expr);
|
auto* call = sem.Get<sem::Call>(expr);
|
||||||
|
@ -68,14 +68,14 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
|
||||||
// Semantic node must be a Materialize, in which case the
|
// Semantic node must be a Materialize, in which case the
|
||||||
// expression was creation-time (compile time), so could not
|
// expression was creation-time (compile time), so could not
|
||||||
// have side effects. Just skip.
|
// have side effects. Just skip.
|
||||||
return ast::TraverseAction::Skip;
|
return TraverseAction::Skip;
|
||||||
}
|
}
|
||||||
if (call->Target()->IsAnyOf<sem::Function, sem::Builtin>() &&
|
if (call->Target()->IsAnyOf<sem::Function, sem::Builtin>() &&
|
||||||
call->HasSideEffects()) {
|
call->HasSideEffects()) {
|
||||||
side_effects.push_back(expr);
|
side_effects.push_back(expr);
|
||||||
return ast::TraverseAction::Skip;
|
return TraverseAction::Skip;
|
||||||
}
|
}
|
||||||
return ast::TraverseAction::Descend;
|
return TraverseAction::Descend;
|
||||||
})) {
|
})) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -88,12 +88,12 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (side_effects.size() == 1) {
|
if (side_effects.size() == 1) {
|
||||||
if (auto* call_expr = side_effects[0]->As<ast::CallExpression>()) {
|
if (auto* call_expr = side_effects[0]->As<CallExpression>()) {
|
||||||
// Phony assignment with single call side effect.
|
// Phony assignment with single call side effect.
|
||||||
auto* call = sem.Get(call_expr)->Unwrap()->As<sem::Call>();
|
auto* call = sem.Get(call_expr)->Unwrap()->As<sem::Call>();
|
||||||
if (call->Target()->MustUse()) {
|
if (call->Target()->MustUse()) {
|
||||||
// Replace phony assignment assignment to uniquely named let.
|
// Replace phony assignment assignment to uniquely named let.
|
||||||
ctx.Replace<ast::Statement>(stmt, [&, call_expr] { //
|
ctx.Replace<Statement>(stmt, [&, call_expr] { //
|
||||||
auto name = b.Symbols().New("tint_phony");
|
auto name = b.Symbols().New("tint_phony");
|
||||||
auto* rhs = ctx.Clone(call_expr);
|
auto* rhs = ctx.Clone(call_expr);
|
||||||
return b.Decl(b.Let(name, rhs));
|
return b.Decl(b.Let(name, rhs));
|
||||||
|
@ -118,7 +118,7 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
|
||||||
}
|
}
|
||||||
auto sink = sinks.GetOrCreate(sig, [&] {
|
auto sink = sinks.GetOrCreate(sig, [&] {
|
||||||
auto name = b.Symbols().New("phony_sink");
|
auto name = b.Symbols().New("phony_sink");
|
||||||
utils::Vector<const ast::Parameter*, 8> params;
|
utils::Vector<const Parameter*, 8> params;
|
||||||
for (auto* ty : sig) {
|
for (auto* ty : sig) {
|
||||||
auto ast_ty = CreateASTTypeFor(ctx, ty);
|
auto ast_ty = CreateASTTypeFor(ctx, ty);
|
||||||
params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty));
|
params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty));
|
||||||
|
@ -126,7 +126,7 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
|
||||||
b.Func(name, params, b.ty.void_(), {});
|
b.Func(name, params, b.ty.void_(), {});
|
||||||
return name;
|
return name;
|
||||||
});
|
});
|
||||||
utils::Vector<const ast::Expression*, 8> args;
|
utils::Vector<const Expression*, 8> args;
|
||||||
for (auto* arg : side_effects) {
|
for (auto* arg : side_effects) {
|
||||||
args.Push(ctx.Clone(arg));
|
args.Push(ctx.Clone(arg));
|
||||||
}
|
}
|
||||||
|
@ -134,7 +134,7 @@ Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::CallStatement* stmt) {
|
[&](const CallStatement* stmt) {
|
||||||
// Remove call statements to const value-returning functions.
|
// Remove call statements to const value-returning functions.
|
||||||
// TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects.
|
// TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects.
|
||||||
auto* sem_expr = sem.Get(stmt->expr);
|
auto* sem_expr = sem.Get(stmt->expr);
|
||||||
|
|
|
@ -1265,10 +1265,10 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Identifiers that need to keep their symbols preserved.
|
// Identifiers that need to keep their symbols preserved.
|
||||||
utils::Hashset<const ast::Identifier*, 16> preserved_identifiers;
|
utils::Hashset<const Identifier*, 16> preserved_identifiers;
|
||||||
|
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
auto preserve_if_builtin_type = [&](const ast::Identifier* ident) {
|
auto preserve_if_builtin_type = [&](const Identifier* ident) {
|
||||||
if (!global_decls.Contains(ident->symbol)) {
|
if (!global_decls.Contains(ident->symbol)) {
|
||||||
preserved_identifiers.Add(ident);
|
preserved_identifiers.Add(ident);
|
||||||
}
|
}
|
||||||
|
@ -1276,7 +1276,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
|
||||||
|
|
||||||
Switch(
|
Switch(
|
||||||
node,
|
node,
|
||||||
[&](const ast::MemberAccessorExpression* accessor) {
|
[&](const MemberAccessorExpression* accessor) {
|
||||||
auto* sem = src->Sem().Get(accessor)->UnwrapLoad();
|
auto* sem = src->Sem().Get(accessor)->UnwrapLoad();
|
||||||
if (sem->Is<sem::Swizzle>()) {
|
if (sem->Is<sem::Swizzle>()) {
|
||||||
preserved_identifiers.Add(accessor->member);
|
preserved_identifiers.Add(accessor->member);
|
||||||
|
@ -1288,19 +1288,19 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::DiagnosticAttribute* diagnostic) {
|
[&](const DiagnosticAttribute* diagnostic) {
|
||||||
if (auto* category = diagnostic->control.rule_name->category) {
|
if (auto* category = diagnostic->control.rule_name->category) {
|
||||||
preserved_identifiers.Add(category);
|
preserved_identifiers.Add(category);
|
||||||
}
|
}
|
||||||
preserved_identifiers.Add(diagnostic->control.rule_name->name);
|
preserved_identifiers.Add(diagnostic->control.rule_name->name);
|
||||||
},
|
},
|
||||||
[&](const ast::DiagnosticDirective* diagnostic) {
|
[&](const DiagnosticDirective* diagnostic) {
|
||||||
if (auto* category = diagnostic->control.rule_name->category) {
|
if (auto* category = diagnostic->control.rule_name->category) {
|
||||||
preserved_identifiers.Add(category);
|
preserved_identifiers.Add(category);
|
||||||
}
|
}
|
||||||
preserved_identifiers.Add(diagnostic->control.rule_name->name);
|
preserved_identifiers.Add(diagnostic->control.rule_name->name);
|
||||||
},
|
},
|
||||||
[&](const ast::IdentifierExpression* expr) {
|
[&](const IdentifierExpression* expr) {
|
||||||
Switch(
|
Switch(
|
||||||
src->Sem().Get(expr), //
|
src->Sem().Get(expr), //
|
||||||
[&](const sem::BuiltinEnumExpressionBase*) {
|
[&](const sem::BuiltinEnumExpressionBase*) {
|
||||||
|
@ -1310,7 +1310,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
|
||||||
preserve_if_builtin_type(expr->identifier);
|
preserve_if_builtin_type(expr->identifier);
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[&](const ast::CallExpression* call) {
|
[&](const CallExpression* call) {
|
||||||
Switch(
|
Switch(
|
||||||
src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>()->Target(),
|
src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>()->Target(),
|
||||||
[&](const sem::Builtin*) {
|
[&](const sem::Builtin*) {
|
||||||
|
@ -1372,7 +1372,7 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ false};
|
CloneContext ctx{&b, src, /* auto_clone_symbols */ false};
|
||||||
|
|
||||||
ctx.ReplaceAll([&](const ast::Identifier* ident) -> const ast::Identifier* {
|
ctx.ReplaceAll([&](const Identifier* ident) -> const Identifier* {
|
||||||
const auto symbol = ident->symbol;
|
const auto symbol = ident->symbol;
|
||||||
if (preserved_identifiers.Contains(ident) || !should_rename(symbol)) {
|
if (preserved_identifiers.Contains(ident) || !should_rename(symbol)) {
|
||||||
return nullptr; // Preserve symbol
|
return nullptr; // Preserve symbol
|
||||||
|
@ -1382,12 +1382,12 @@ Transform::ApplyResult Renamer::Apply(const Program* src,
|
||||||
auto replacement = remappings.GetOrCreate(symbol, [&] { return b.Symbols().New(); });
|
auto replacement = remappings.GetOrCreate(symbol, [&] { return b.Symbols().New(); });
|
||||||
|
|
||||||
// Reconstruct the identifier
|
// Reconstruct the identifier
|
||||||
if (auto* tmpl_ident = ident->As<ast::TemplatedIdentifier>()) {
|
if (auto* tmpl_ident = ident->As<TemplatedIdentifier>()) {
|
||||||
auto args = ctx.Clone(tmpl_ident->arguments);
|
auto args = ctx.Clone(tmpl_ident->arguments);
|
||||||
return ctx.dst->create<ast::TemplatedIdentifier>(ctx.Clone(ident->source), replacement,
|
return ctx.dst->create<TemplatedIdentifier>(ctx.Clone(ident->source), replacement,
|
||||||
std::move(args), utils::Empty);
|
std::move(args), utils::Empty);
|
||||||
}
|
}
|
||||||
return ctx.dst->create<ast::Identifier>(ctx.Clone(ident->source), replacement);
|
return ctx.dst->create<Identifier>(ctx.Clone(ident->source), replacement);
|
||||||
});
|
});
|
||||||
|
|
||||||
ctx.Clone();
|
ctx.Clone();
|
||||||
|
|
|
@ -58,7 +58,7 @@ struct Robustness::State {
|
||||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||||
Switch(
|
Switch(
|
||||||
node, //
|
node, //
|
||||||
[&](const ast::IndexAccessorExpression* e) {
|
[&](const IndexAccessorExpression* e) {
|
||||||
// obj[idx]
|
// obj[idx]
|
||||||
// Array, matrix and vector indexing may require robustness transformation.
|
// Array, matrix and vector indexing may require robustness transformation.
|
||||||
auto* expr = sem.Get(e)->Unwrap()->As<sem::IndexAccessorExpression>();
|
auto* expr = sem.Get(e)->Unwrap()->As<sem::IndexAccessorExpression>();
|
||||||
|
@ -73,7 +73,7 @@ struct Robustness::State {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::IdentifierExpression* e) {
|
[&](const IdentifierExpression* e) {
|
||||||
// Identifiers may resolve to pointer lets, which may be predicated.
|
// Identifiers may resolve to pointer lets, which may be predicated.
|
||||||
// Inspect.
|
// Inspect.
|
||||||
if (auto* user = sem.Get<sem::VariableUser>(e)) {
|
if (auto* user = sem.Get<sem::VariableUser>(e)) {
|
||||||
|
@ -86,42 +86,42 @@ struct Robustness::State {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::AccessorExpression* e) {
|
[&](const AccessorExpression* e) {
|
||||||
// obj.member
|
// obj.member
|
||||||
// Propagate the predication from the object to this expression.
|
// Propagate the predication from the object to this expression.
|
||||||
if (auto pred = predicates.Get(e->object)) {
|
if (auto pred = predicates.Get(e->object)) {
|
||||||
predicates.Add(e, *pred);
|
predicates.Add(e, *pred);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::UnaryOpExpression* e) {
|
[&](const UnaryOpExpression* e) {
|
||||||
// Includes address-of, or indirection
|
// Includes address-of, or indirection
|
||||||
// Propagate the predication from the inner expression to this expression.
|
// Propagate the predication from the inner expression to this expression.
|
||||||
if (auto pred = predicates.Get(e->expr)) {
|
if (auto pred = predicates.Get(e->expr)) {
|
||||||
predicates.Add(e, *pred);
|
predicates.Add(e, *pred);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::AssignmentStatement* s) {
|
[&](const AssignmentStatement* s) {
|
||||||
if (auto pred = predicates.Get(s->lhs)) {
|
if (auto pred = predicates.Get(s->lhs)) {
|
||||||
// Assignment target is predicated
|
// Assignment target is predicated
|
||||||
// Replace statement with condition on the predicate
|
// Replace statement with condition on the predicate
|
||||||
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
|
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::CompoundAssignmentStatement* s) {
|
[&](const CompoundAssignmentStatement* s) {
|
||||||
if (auto pred = predicates.Get(s->lhs)) {
|
if (auto pred = predicates.Get(s->lhs)) {
|
||||||
// Assignment expression is predicated
|
// Assignment expression is predicated
|
||||||
// Replace statement with condition on the predicate
|
// Replace statement with condition on the predicate
|
||||||
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
|
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::IncrementDecrementStatement* s) {
|
[&](const IncrementDecrementStatement* s) {
|
||||||
if (auto pred = predicates.Get(s->lhs)) {
|
if (auto pred = predicates.Get(s->lhs)) {
|
||||||
// Assignment expression is predicated
|
// Assignment expression is predicated
|
||||||
// Replace statement with condition on the predicate
|
// Replace statement with condition on the predicate
|
||||||
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
|
ctx.Replace(s, b.If(*pred, b.Block(ctx.Clone(s))));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::CallExpression* e) {
|
[&](const CallExpression* e) {
|
||||||
if (auto* call = sem.Get<sem::Call>(e)) {
|
if (auto* call = sem.Get<sem::Call>(e)) {
|
||||||
Switch(
|
Switch(
|
||||||
call->Target(), //
|
call->Target(), //
|
||||||
|
@ -163,7 +163,7 @@ struct Robustness::State {
|
||||||
// predicated_expr = expr;
|
// predicated_expr = expr;
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
if (auto* expr = node->As<ast::Expression>()) {
|
if (auto* expr = node->As<Expression>()) {
|
||||||
if (auto pred = predicates.Get(expr)) {
|
if (auto pred = predicates.Get(expr)) {
|
||||||
// Expression is predicated
|
// Expression is predicated
|
||||||
auto* sem_expr = sem.GetVal(expr);
|
auto* sem_expr = sem.GetVal(expr);
|
||||||
|
@ -202,15 +202,15 @@ struct Robustness::State {
|
||||||
/// Alias to the source program's semantic info
|
/// Alias to the source program's semantic info
|
||||||
const sem::Info& sem = ctx.src->Sem();
|
const sem::Info& sem = ctx.src->Sem();
|
||||||
/// Map of expression to predicate condition
|
/// Map of expression to predicate condition
|
||||||
utils::Hashmap<const ast::Expression*, Symbol, 32> predicates{};
|
utils::Hashmap<const Expression*, Symbol, 32> predicates{};
|
||||||
|
|
||||||
/// @return the `u32` typed expression that represents the maximum indexable value for the index
|
/// @return the `u32` typed expression that represents the maximum indexable value for the index
|
||||||
/// accessor @p expr, or nullptr if there is no robustness limit for this expression.
|
/// accessor @p expr, or nullptr if there is no robustness limit for this expression.
|
||||||
const ast::Expression* DynamicLimitFor(const sem::IndexAccessorExpression* expr) {
|
const Expression* DynamicLimitFor(const sem::IndexAccessorExpression* expr) {
|
||||||
auto* obj_type = expr->Object()->Type();
|
auto* obj_type = expr->Object()->Type();
|
||||||
return Switch(
|
return Switch(
|
||||||
obj_type->UnwrapRef(), //
|
obj_type->UnwrapRef(), //
|
||||||
[&](const type::Vector* vec) -> const ast::Expression* {
|
[&](const type::Vector* vec) -> const Expression* {
|
||||||
if (expr->Index()->ConstantValue() || expr->Index()->Is<sem::Swizzle>()) {
|
if (expr->Index()->ConstantValue() || expr->Index()->Is<sem::Swizzle>()) {
|
||||||
// Index and size is constant.
|
// Index and size is constant.
|
||||||
// Validation will have rejected any OOB accesses.
|
// Validation will have rejected any OOB accesses.
|
||||||
|
@ -218,7 +218,7 @@ struct Robustness::State {
|
||||||
}
|
}
|
||||||
return b.Expr(u32(vec->Width() - 1u));
|
return b.Expr(u32(vec->Width() - 1u));
|
||||||
},
|
},
|
||||||
[&](const type::Matrix* mat) -> const ast::Expression* {
|
[&](const type::Matrix* mat) -> const Expression* {
|
||||||
if (expr->Index()->ConstantValue()) {
|
if (expr->Index()->ConstantValue()) {
|
||||||
// Index and size is constant.
|
// Index and size is constant.
|
||||||
// Validation will have rejected any OOB accesses.
|
// Validation will have rejected any OOB accesses.
|
||||||
|
@ -226,7 +226,7 @@ struct Robustness::State {
|
||||||
}
|
}
|
||||||
return b.Expr(u32(mat->columns() - 1u));
|
return b.Expr(u32(mat->columns() - 1u));
|
||||||
},
|
},
|
||||||
[&](const type::Array* arr) -> const ast::Expression* {
|
[&](const type::Array* arr) -> const Expression* {
|
||||||
if (arr->Count()->Is<type::RuntimeArrayCount>()) {
|
if (arr->Count()->Is<type::RuntimeArrayCount>()) {
|
||||||
// Size is unknown until runtime.
|
// Size is unknown until runtime.
|
||||||
// Must clamp, even if the index is constant.
|
// Must clamp, even if the index is constant.
|
||||||
|
@ -248,7 +248,7 @@ struct Robustness::State {
|
||||||
type::Array::kErrExpectedConstantCount);
|
type::Array::kErrExpectedConstantCount);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
},
|
},
|
||||||
[&](Default) -> const ast::Expression* {
|
[&](Default) -> const Expression* {
|
||||||
TINT_ICE(Transform, b.Diagnostics())
|
TINT_ICE(Transform, b.Diagnostics())
|
||||||
<< "unhandled object type in robustness of array index: "
|
<< "unhandled object type in robustness of array index: "
|
||||||
<< obj_type->UnwrapRef()->FriendlyName();
|
<< obj_type->UnwrapRef()->FriendlyName();
|
||||||
|
@ -350,7 +350,7 @@ struct Robustness::State {
|
||||||
/// Applies predication to the non-texture builtin call, if required.
|
/// Applies predication to the non-texture builtin call, if required.
|
||||||
void MaybePredicateNonTextureBuiltin(const sem::Call* call, const sem::Builtin* builtin) {
|
void MaybePredicateNonTextureBuiltin(const sem::Call* call, const sem::Builtin* builtin) {
|
||||||
// Gather the predications for the builtin arguments
|
// Gather the predications for the builtin arguments
|
||||||
const ast::Expression* predicate = nullptr;
|
const Expression* predicate = nullptr;
|
||||||
for (auto* arg : call->Declaration()->args) {
|
for (auto* arg : call->Declaration()->args) {
|
||||||
if (auto pred = predicates.Get(arg)) {
|
if (auto pred = predicates.Get(arg)) {
|
||||||
predicate = And(predicate, b.Expr(*pred));
|
predicate = And(predicate, b.Expr(*pred));
|
||||||
|
@ -393,7 +393,7 @@ struct Robustness::State {
|
||||||
auto* texture_arg = expr->args[static_cast<size_t>(texture_arg_idx)];
|
auto* texture_arg = expr->args[static_cast<size_t>(texture_arg_idx)];
|
||||||
|
|
||||||
// Build the builtin predicate from the arguments
|
// Build the builtin predicate from the arguments
|
||||||
const ast::Expression* predicate = nullptr;
|
const Expression* predicate = nullptr;
|
||||||
|
|
||||||
Symbol level_idx, num_levels;
|
Symbol level_idx, num_levels;
|
||||||
if (level_arg_idx >= 0) {
|
if (level_arg_idx >= 0) {
|
||||||
|
@ -554,7 +554,7 @@ struct Robustness::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @returns a bitwise and of the two expressions, or the other expression if one is null.
|
/// @returns a bitwise and of the two expressions, or the other expression if one is null.
|
||||||
const ast::Expression* And(const ast::Expression* lhs, const ast::Expression* rhs) {
|
const Expression* And(const Expression* lhs, const Expression* rhs) {
|
||||||
if (lhs && rhs) {
|
if (lhs && rhs) {
|
||||||
return b.And(lhs, rhs);
|
return b.And(lhs, rhs);
|
||||||
}
|
}
|
||||||
|
@ -568,11 +568,11 @@ struct Robustness::State {
|
||||||
/// predicate.
|
/// predicate.
|
||||||
/// @param else_stmt - the statement to execute for the predication failure
|
/// @param else_stmt - the statement to execute for the predication failure
|
||||||
void PredicateCall(const sem::Call* call,
|
void PredicateCall(const sem::Call* call,
|
||||||
const ast::Expression* predicate,
|
const Expression* predicate,
|
||||||
const ast::BlockStatement* else_stmt = nullptr) {
|
const BlockStatement* else_stmt = nullptr) {
|
||||||
auto* expr = call->Declaration();
|
auto* expr = call->Declaration();
|
||||||
auto* stmt = call->Stmt();
|
auto* stmt = call->Stmt();
|
||||||
auto* call_stmt = stmt->Declaration()->As<ast::CallStatement>();
|
auto* call_stmt = stmt->Declaration()->As<CallStatement>();
|
||||||
if (call_stmt && call_stmt->expr == expr) {
|
if (call_stmt && call_stmt->expr == expr) {
|
||||||
// Wrap the statement in an if-statement with the predicate condition.
|
// Wrap the statement in an if-statement with the predicate condition.
|
||||||
hoist.Replace(stmt, b.If(predicate, b.Block(ctx.Clone(stmt->Declaration())),
|
hoist.Replace(stmt, b.If(predicate, b.Block(ctx.Clone(stmt->Declaration())),
|
||||||
|
@ -646,7 +646,7 @@ struct Robustness::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @returns a scalar or vector type with the element type @p scalar and width @p width
|
/// @returns a scalar or vector type with the element type @p scalar and width @p width
|
||||||
ast::Type ScalarOrVecTy(ast::Type scalar, uint32_t width) const {
|
Type ScalarOrVecTy(Type scalar, uint32_t width) const {
|
||||||
if (width > 1) {
|
if (width > 1) {
|
||||||
return b.ty.vec(scalar, width);
|
return b.ty.vec(scalar, width);
|
||||||
}
|
}
|
||||||
|
@ -655,7 +655,7 @@ struct Robustness::State {
|
||||||
|
|
||||||
/// @returns a vector constructed with the scalar expression @p scalar if @p width > 1,
|
/// @returns a vector constructed with the scalar expression @p scalar if @p width > 1,
|
||||||
/// otherwise returns @p scalar.
|
/// otherwise returns @p scalar.
|
||||||
const ast::Expression* ScalarOrVec(const ast::Expression* scalar, uint32_t width) {
|
const Expression* ScalarOrVec(const Expression* scalar, uint32_t width) {
|
||||||
if (width > 1) {
|
if (width > 1) {
|
||||||
return b.Call(b.ty.vec<Infer>(width), scalar);
|
return b.Call(b.ty.vec<Infer>(width), scalar);
|
||||||
}
|
}
|
||||||
|
@ -664,13 +664,13 @@ struct Robustness::State {
|
||||||
|
|
||||||
/// @returns @p val cast to a `vecN<i32>`, where `N` is @p width, or cast to i32 if @p width
|
/// @returns @p val cast to a `vecN<i32>`, where `N` is @p width, or cast to i32 if @p width
|
||||||
/// is 1.
|
/// is 1.
|
||||||
const ast::CallExpression* CastToSigned(const ast::Expression* val, uint32_t width) {
|
const CallExpression* CastToSigned(const Expression* val, uint32_t width) {
|
||||||
return b.Call(ScalarOrVecTy(b.ty.i32(), width), val);
|
return b.Call(ScalarOrVecTy(b.ty.i32(), width), val);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @returns @p val cast to a `vecN<u32>`, where `N` is @p width, or cast to u32 if @p width
|
/// @returns @p val cast to a `vecN<u32>`, where `N` is @p width, or cast to u32 if @p width
|
||||||
/// is 1.
|
/// is 1.
|
||||||
const ast::CallExpression* CastToUnsigned(const ast::Expression* val, uint32_t width) {
|
const CallExpression* CastToUnsigned(const Expression* val, uint32_t width) {
|
||||||
return b.Call(ScalarOrVecTy(b.ty.u32(), width), val);
|
return b.Call(ScalarOrVecTy(b.ty.u32(), width), val);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -41,7 +41,7 @@ struct PointerOp {
|
||||||
/// Zero: no pointer op on `expr`
|
/// Zero: no pointer op on `expr`
|
||||||
int indirections = 0;
|
int indirections = 0;
|
||||||
/// The expression being operated on
|
/// The expression being operated on
|
||||||
const ast::Expression* expr = nullptr;
|
const Expression* expr = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -64,29 +64,29 @@ struct SimplifyPointers::State {
|
||||||
/// expression. The function-like argument `cb` is called for each found.
|
/// expression. The function-like argument `cb` is called for each found.
|
||||||
/// @param expr the expression to traverse
|
/// @param expr the expression to traverse
|
||||||
/// @param cb a function-like object with the signature
|
/// @param cb a function-like object with the signature
|
||||||
/// `void(const ast::Expression*)`, which is called for each array index
|
/// `void(const Expression*)`, which is called for each array index
|
||||||
/// expression
|
/// expression
|
||||||
template <typename F>
|
template <typename F>
|
||||||
static void CollectSavedArrayIndices(const ast::Expression* expr, F&& cb) {
|
static void CollectSavedArrayIndices(const Expression* expr, F&& cb) {
|
||||||
if (auto* a = expr->As<ast::IndexAccessorExpression>()) {
|
if (auto* a = expr->As<IndexAccessorExpression>()) {
|
||||||
CollectSavedArrayIndices(a->object, cb);
|
CollectSavedArrayIndices(a->object, cb);
|
||||||
if (!a->index->Is<ast::LiteralExpression>()) {
|
if (!a->index->Is<LiteralExpression>()) {
|
||||||
cb(a->index);
|
cb(a->index);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* m = expr->As<ast::MemberAccessorExpression>()) {
|
if (auto* m = expr->As<MemberAccessorExpression>()) {
|
||||||
CollectSavedArrayIndices(m->object, cb);
|
CollectSavedArrayIndices(m->object, cb);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto* u = expr->As<ast::UnaryOpExpression>()) {
|
if (auto* u = expr->As<UnaryOpExpression>()) {
|
||||||
CollectSavedArrayIndices(u->expr, cb);
|
CollectSavedArrayIndices(u->expr, cb);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: Other ast::Expression types can be safely ignored as they cannot be
|
// Note: Other Expression types can be safely ignored as they cannot be
|
||||||
// used to generate a reference or pointer.
|
// used to generate a reference or pointer.
|
||||||
// See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers
|
// See https://gpuweb.github.io/gpuweb/wgsl/#forming-references-and-pointers
|
||||||
}
|
}
|
||||||
|
@ -95,16 +95,16 @@ struct SimplifyPointers::State {
|
||||||
/// indirection ops into a PointerOp.
|
/// indirection ops into a PointerOp.
|
||||||
/// @param in the expression to walk
|
/// @param in the expression to walk
|
||||||
/// @returns the reduced PointerOp
|
/// @returns the reduced PointerOp
|
||||||
PointerOp Reduce(const ast::Expression* in) const {
|
PointerOp Reduce(const Expression* in) const {
|
||||||
PointerOp op{0, in};
|
PointerOp op{0, in};
|
||||||
while (true) {
|
while (true) {
|
||||||
if (auto* unary = op.expr->As<ast::UnaryOpExpression>()) {
|
if (auto* unary = op.expr->As<UnaryOpExpression>()) {
|
||||||
switch (unary->op) {
|
switch (unary->op) {
|
||||||
case ast::UnaryOp::kIndirection:
|
case UnaryOp::kIndirection:
|
||||||
op.indirections++;
|
op.indirections++;
|
||||||
op.expr = unary->expr;
|
op.expr = unary->expr;
|
||||||
continue;
|
continue;
|
||||||
case ast::UnaryOp::kAddressOf:
|
case UnaryOp::kAddressOf:
|
||||||
op.indirections--;
|
op.indirections--;
|
||||||
op.expr = unary->expr;
|
op.expr = unary->expr;
|
||||||
continue;
|
continue;
|
||||||
|
@ -115,7 +115,7 @@ struct SimplifyPointers::State {
|
||||||
if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) {
|
if (auto* user = ctx.src->Sem().Get<sem::VariableUser>(op.expr)) {
|
||||||
auto* var = user->Variable();
|
auto* var = user->Variable();
|
||||||
if (var->Is<sem::LocalVariable>() && //
|
if (var->Is<sem::LocalVariable>() && //
|
||||||
var->Declaration()->Is<ast::Let>() && //
|
var->Declaration()->Is<Let>() && //
|
||||||
var->Type()->Is<type::Pointer>()) {
|
var->Type()->Is<type::Pointer>()) {
|
||||||
op.expr = var->Declaration()->initializer;
|
op.expr = var->Declaration()->initializer;
|
||||||
continue;
|
continue;
|
||||||
|
@ -129,7 +129,7 @@ struct SimplifyPointers::State {
|
||||||
/// @returns the new program or SkipTransform if the transform is not required
|
/// @returns the new program or SkipTransform if the transform is not required
|
||||||
ApplyResult Run() {
|
ApplyResult Run() {
|
||||||
// A map of saved expressions to their saved variable name
|
// A map of saved expressions to their saved variable name
|
||||||
utils::Hashmap<const ast::Expression*, Symbol, 8> saved_vars;
|
utils::Hashmap<const Expression*, Symbol, 8> saved_vars;
|
||||||
|
|
||||||
bool needs_transform = false;
|
bool needs_transform = false;
|
||||||
for (auto* ty : ctx.src->Types()) {
|
for (auto* ty : ctx.src->Types()) {
|
||||||
|
@ -146,8 +146,8 @@ struct SimplifyPointers::State {
|
||||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||||
Switch(
|
Switch(
|
||||||
node, //
|
node, //
|
||||||
[&](const ast::VariableDeclStatement* let) {
|
[&](const VariableDeclStatement* let) {
|
||||||
if (!let->variable->Is<ast::Let>()) {
|
if (!let->variable->Is<Let>()) {
|
||||||
return; // Not a `let` declaration. Ignore.
|
return; // Not a `let` declaration. Ignore.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,9 +160,9 @@ struct SimplifyPointers::State {
|
||||||
|
|
||||||
// Scan the initializer expression for array index expressions that need
|
// Scan the initializer expression for array index expressions that need
|
||||||
// to be hoist to temporary "saved" variables.
|
// to be hoist to temporary "saved" variables.
|
||||||
utils::Vector<const ast::VariableDeclStatement*, 8> saved;
|
utils::Vector<const VariableDeclStatement*, 8> saved;
|
||||||
CollectSavedArrayIndices(
|
CollectSavedArrayIndices(
|
||||||
var->Declaration()->initializer, [&](const ast::Expression* idx_expr) {
|
var->Declaration()->initializer, [&](const Expression* idx_expr) {
|
||||||
// We have a sub-expression that needs to be saved.
|
// We have a sub-expression that needs to be saved.
|
||||||
// Create a new variable
|
// Create a new variable
|
||||||
auto saved_name = ctx.dst->Symbols().New(
|
auto saved_name = ctx.dst->Symbols().New(
|
||||||
|
@ -205,8 +205,8 @@ struct SimplifyPointers::State {
|
||||||
// need for the original declaration to exist. Remove it.
|
// need for the original declaration to exist. Remove it.
|
||||||
RemoveStatement(ctx, let);
|
RemoveStatement(ctx, let);
|
||||||
},
|
},
|
||||||
[&](const ast::UnaryOpExpression* op) {
|
[&](const UnaryOpExpression* op) {
|
||||||
if (op->op == ast::UnaryOp::kAddressOf) {
|
if (op->op == UnaryOp::kAddressOf) {
|
||||||
// Transform can be skipped if no address-of operator is used, as there
|
// Transform can be skipped if no address-of operator is used, as there
|
||||||
// will be no pointers that can be inlined.
|
// will be no pointers that can be inlined.
|
||||||
needs_transform = true;
|
needs_transform = true;
|
||||||
|
@ -218,7 +218,7 @@ struct SimplifyPointers::State {
|
||||||
return SkipTransform;
|
return SkipTransform;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register the ast::Expression transform handler.
|
// Register the Expression transform handler.
|
||||||
// This performs two different transformations:
|
// This performs two different transformations:
|
||||||
// * Identifiers that resolve to the pointer-typed `let` declarations are
|
// * Identifiers that resolve to the pointer-typed `let` declarations are
|
||||||
// replaced with the recursively inlined initializer expression for the
|
// replaced with the recursively inlined initializer expression for the
|
||||||
|
@ -226,7 +226,7 @@ struct SimplifyPointers::State {
|
||||||
// * Sub-expressions inside the pointer-typed `let` initializer expression
|
// * Sub-expressions inside the pointer-typed `let` initializer expression
|
||||||
// that have been hoisted to a saved variable are replaced with the saved
|
// that have been hoisted to a saved variable are replaced with the saved
|
||||||
// variable identifier.
|
// variable identifier.
|
||||||
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
|
ctx.ReplaceAll([&](const Expression* expr) -> const Expression* {
|
||||||
// Look to see if we need to swap this Expression with a saved variable.
|
// Look to see if we need to swap this Expression with a saved variable.
|
||||||
if (auto saved_var = saved_vars.Find(expr)) {
|
if (auto saved_var = saved_vars.Find(expr)) {
|
||||||
return ctx.dst->Expr(*saved_var);
|
return ctx.dst->Expr(*saved_var);
|
||||||
|
|
|
@ -45,7 +45,7 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the target entry point.
|
// Find the target entry point.
|
||||||
const ast::Function* entry_point = nullptr;
|
const Function* entry_point = nullptr;
|
||||||
for (auto* f : src->AST().Functions()) {
|
for (auto* f : src->AST().Functions()) {
|
||||||
if (!f->IsEntryPoint()) {
|
if (!f->IsEntryPoint()) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -69,7 +69,7 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
|
||||||
for (auto* decl : src->AST().GlobalDeclarations()) {
|
for (auto* decl : src->AST().GlobalDeclarations()) {
|
||||||
Switch(
|
Switch(
|
||||||
decl, //
|
decl, //
|
||||||
[&](const ast::TypeDecl* ty) {
|
[&](const TypeDecl* ty) {
|
||||||
// Strip aliases that reference unused override declarations.
|
// Strip aliases that reference unused override declarations.
|
||||||
if (auto* arr = sem.Get(ty)->As<type::Array>()) {
|
if (auto* arr = sem.Get(ty)->As<type::Array>()) {
|
||||||
auto* refs = sem.TransitivelyReferencedOverrides(arr);
|
auto* refs = sem.TransitivelyReferencedOverrides(arr);
|
||||||
|
@ -85,9 +85,9 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
|
||||||
// TODO(jrprice): Strip other unused types.
|
// TODO(jrprice): Strip other unused types.
|
||||||
b.AST().AddTypeDecl(ctx.Clone(ty));
|
b.AST().AddTypeDecl(ctx.Clone(ty));
|
||||||
},
|
},
|
||||||
[&](const ast::Override* override) {
|
[&](const Override* override) {
|
||||||
if (referenced_vars.Contains(sem.Get(override))) {
|
if (referenced_vars.Contains(sem.Get(override))) {
|
||||||
if (!ast::HasAttribute<ast::IdAttribute>(override->attributes)) {
|
if (!HasAttribute<IdAttribute>(override->attributes)) {
|
||||||
// If the override doesn't already have an @id() attribute, add one
|
// If the override doesn't already have an @id() attribute, add one
|
||||||
// so that its allocated ID so that it won't be affected by other
|
// so that its allocated ID so that it won't be affected by other
|
||||||
// stripped away overrides
|
// stripped away overrides
|
||||||
|
@ -98,26 +98,24 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
|
||||||
b.AST().AddGlobalVariable(ctx.Clone(override));
|
b.AST().AddGlobalVariable(ctx.Clone(override));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::Var* var) {
|
[&](const Var* var) {
|
||||||
if (referenced_vars.Contains(sem.Get<sem::GlobalVariable>(var))) {
|
if (referenced_vars.Contains(sem.Get<sem::GlobalVariable>(var))) {
|
||||||
b.AST().AddGlobalVariable(ctx.Clone(var));
|
b.AST().AddGlobalVariable(ctx.Clone(var));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::Const* c) {
|
[&](const Const* c) {
|
||||||
// Always keep 'const' declarations, as these can be used by attributes and array
|
// Always keep 'const' declarations, as these can be used by attributes and array
|
||||||
// sizes, which are not tracked as transitively used by functions. They also don't
|
// sizes, which are not tracked as transitively used by functions. They also don't
|
||||||
// typically get emitted by the backend unless they're actually used.
|
// typically get emitted by the backend unless they're actually used.
|
||||||
b.AST().AddGlobalVariable(ctx.Clone(c));
|
b.AST().AddGlobalVariable(ctx.Clone(c));
|
||||||
},
|
},
|
||||||
[&](const ast::Function* func) {
|
[&](const Function* func) {
|
||||||
if (sem.Get(func)->HasAncestorEntryPoint(entry_point->name->symbol)) {
|
if (sem.Get(func)->HasAncestorEntryPoint(entry_point->name->symbol)) {
|
||||||
b.AST().AddFunction(ctx.Clone(func));
|
b.AST().AddFunction(ctx.Clone(func));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); },
|
[&](const Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); },
|
||||||
[&](const ast::DiagnosticDirective* d) {
|
[&](const DiagnosticDirective* d) { b.AST().AddDiagnosticDirective(ctx.Clone(d)); },
|
||||||
b.AST().AddDiagnosticDirective(ctx.Clone(d));
|
|
||||||
},
|
|
||||||
[&](Default) {
|
[&](Default) {
|
||||||
TINT_UNREACHABLE(Transform, b.Diagnostics())
|
TINT_UNREACHABLE(Transform, b.Diagnostics())
|
||||||
<< "unhandled global declaration: " << decl->TypeInfo().name;
|
<< "unhandled global declaration: " << decl->TypeInfo().name;
|
||||||
|
|
|
@ -70,7 +70,7 @@ struct SpirvAtomic::State {
|
||||||
// Look for stub functions generated by the SPIR-V reader, which are used as placeholders
|
// Look for stub functions generated by the SPIR-V reader, which are used as placeholders
|
||||||
// for atomic builtin calls.
|
// for atomic builtin calls.
|
||||||
for (auto* fn : ctx.src->AST().Functions()) {
|
for (auto* fn : ctx.src->AST().Functions()) {
|
||||||
if (auto* stub = ast::GetAttribute<Stub>(fn->attributes)) {
|
if (auto* stub = GetAttribute<Stub>(fn->attributes)) {
|
||||||
auto* sem = ctx.src->Sem().Get(fn);
|
auto* sem = ctx.src->Sem().Get(fn);
|
||||||
|
|
||||||
for (auto* call : sem->CallSites()) {
|
for (auto* call : sem->CallSites()) {
|
||||||
|
@ -121,14 +121,14 @@ struct SpirvAtomic::State {
|
||||||
|
|
||||||
// If we need to change structure members, then fork them.
|
// If we need to change structure members, then fork them.
|
||||||
if (!forked_structs.empty()) {
|
if (!forked_structs.empty()) {
|
||||||
ctx.ReplaceAll([&](const ast::Struct* str) {
|
ctx.ReplaceAll([&](const Struct* str) {
|
||||||
// Is `str` a structure we need to fork?
|
// Is `str` a structure we need to fork?
|
||||||
auto* str_ty = ctx.src->Sem().Get(str);
|
auto* str_ty = ctx.src->Sem().Get(str);
|
||||||
if (auto it = forked_structs.find(str_ty); it != forked_structs.end()) {
|
if (auto it = forked_structs.find(str_ty); it != forked_structs.end()) {
|
||||||
const auto& forked = it->second;
|
const auto& forked = it->second;
|
||||||
|
|
||||||
// Re-create the structure swapping in the atomic-flavoured members
|
// Re-create the structure swapping in the atomic-flavoured members
|
||||||
utils::Vector<const ast::StructMember*, 8> members;
|
utils::Vector<const StructMember*, 8> members;
|
||||||
members.Reserve(str->members.Length());
|
members.Reserve(str->members.Length());
|
||||||
for (size_t i = 0; i < str->members.Length(); i++) {
|
for (size_t i = 0; i < str->members.Length(); i++) {
|
||||||
auto* member = str->members[i];
|
auto* member = str->members[i];
|
||||||
|
@ -187,14 +187,14 @@ struct SpirvAtomic::State {
|
||||||
atomic_expressions.Add(index->Object());
|
atomic_expressions.Add(index->Object());
|
||||||
},
|
},
|
||||||
[&](const sem::ValueExpression* e) {
|
[&](const sem::ValueExpression* e) {
|
||||||
if (auto* unary = e->Declaration()->As<ast::UnaryOpExpression>()) {
|
if (auto* unary = e->Declaration()->As<UnaryOpExpression>()) {
|
||||||
atomic_expressions.Add(ctx.src->Sem().GetVal(unary->expr));
|
atomic_expressions.Add(ctx.src->Sem().GetVal(unary->expr));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ast::Type AtomicTypeFor(const type::Type* ty) {
|
Type AtomicTypeFor(const type::Type* ty) {
|
||||||
return Switch(
|
return Switch(
|
||||||
ty, //
|
ty, //
|
||||||
[&](const type::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
|
[&](const type::I32*) { return b.ty.atomic(CreateASTTypeFor(ctx, ty)); },
|
||||||
|
@ -221,7 +221,7 @@ struct SpirvAtomic::State {
|
||||||
[&](const type::Reference* ref) { return AtomicTypeFor(ref->StoreType()); },
|
[&](const type::Reference* ref) { return AtomicTypeFor(ref->StoreType()); },
|
||||||
[&](Default) {
|
[&](Default) {
|
||||||
TINT_ICE(Transform, b.Diagnostics()) << "unhandled type: " << ty->FriendlyName();
|
TINT_ICE(Transform, b.Diagnostics()) << "unhandled type: " << ty->FriendlyName();
|
||||||
return ast::Type{};
|
return Type{};
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -249,7 +249,7 @@ struct SpirvAtomic::State {
|
||||||
for (auto* vu : atomic_var->Users()) {
|
for (auto* vu : atomic_var->Users()) {
|
||||||
Switch(
|
Switch(
|
||||||
vu->Stmt()->Declaration(),
|
vu->Stmt()->Declaration(),
|
||||||
[&](const ast::AssignmentStatement* assign) {
|
[&](const AssignmentStatement* assign) {
|
||||||
auto* sem_lhs = ctx.src->Sem().GetVal(assign->lhs);
|
auto* sem_lhs = ctx.src->Sem().GetVal(assign->lhs);
|
||||||
if (is_ref_to_atomic_var(sem_lhs)) {
|
if (is_ref_to_atomic_var(sem_lhs)) {
|
||||||
ctx.Replace(assign, [=] {
|
ctx.Replace(assign, [=] {
|
||||||
|
@ -272,7 +272,7 @@ struct SpirvAtomic::State {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const ast::VariableDeclStatement* decl) {
|
[&](const VariableDeclStatement* decl) {
|
||||||
auto* var = decl->variable;
|
auto* var = decl->variable;
|
||||||
if (auto* sem_init = ctx.src->Sem().GetVal(var->initializer)) {
|
if (auto* sem_init = ctx.src->Sem().GetVal(var->initializer)) {
|
||||||
if (is_ref_to_atomic_var(sem_init->UnwrapLoad())) {
|
if (is_ref_to_atomic_var(sem_init->UnwrapLoad())) {
|
||||||
|
@ -293,7 +293,7 @@ struct SpirvAtomic::State {
|
||||||
SpirvAtomic::SpirvAtomic() = default;
|
SpirvAtomic::SpirvAtomic() = default;
|
||||||
SpirvAtomic::~SpirvAtomic() = default;
|
SpirvAtomic::~SpirvAtomic() = default;
|
||||||
|
|
||||||
SpirvAtomic::Stub::Stub(ProgramID pid, ast::NodeID nid, builtin::Function b)
|
SpirvAtomic::Stub::Stub(ProgramID pid, NodeID nid, builtin::Function b)
|
||||||
: Base(pid, nid, utils::Empty), builtin(b) {}
|
: Base(pid, nid, utils::Empty), builtin(b) {}
|
||||||
SpirvAtomic::Stub::~Stub() = default;
|
SpirvAtomic::Stub::~Stub() = default;
|
||||||
std::string SpirvAtomic::Stub::InternalName() const {
|
std::string SpirvAtomic::Stub::InternalName() const {
|
||||||
|
|
|
@ -41,12 +41,12 @@ class SpirvAtomic final : public utils::Castable<SpirvAtomic, Transform> {
|
||||||
|
|
||||||
/// Stub is an attribute applied to stub SPIR-V reader generated functions that need to be
|
/// Stub is an attribute applied to stub SPIR-V reader generated functions that need to be
|
||||||
/// translated to an atomic builtin.
|
/// translated to an atomic builtin.
|
||||||
class Stub final : public utils::Castable<Stub, ast::InternalAttribute> {
|
class Stub final : public utils::Castable<Stub, InternalAttribute> {
|
||||||
public:
|
public:
|
||||||
/// @param pid the identifier of the program that owns this node
|
/// @param pid the identifier of the program that owns this node
|
||||||
/// @param nid the unique node identifier
|
/// @param nid the unique node identifier
|
||||||
/// @param builtin the atomic builtin this stub represents
|
/// @param builtin the atomic builtin this stub represents
|
||||||
Stub(ProgramID pid, ast::NodeID nid, builtin::Function builtin);
|
Stub(ProgramID pid, NodeID nid, builtin::Function builtin);
|
||||||
/// Destructor
|
/// Destructor
|
||||||
~Stub() override;
|
~Stub() override;
|
||||||
|
|
||||||
|
|
|
@ -102,7 +102,7 @@ struct Std140::State {
|
||||||
|
|
||||||
// Finally, replace all expression chains that used the authored types with those that
|
// Finally, replace all expression chains that used the authored types with those that
|
||||||
// correctly use the forked types.
|
// correctly use the forked types.
|
||||||
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
|
ctx.ReplaceAll([&](const Expression* expr) -> const Expression* {
|
||||||
if (auto access = AccessChainFor(expr)) {
|
if (auto access = AccessChainFor(expr)) {
|
||||||
if (!access->std140_mat_idx.has_value()) {
|
if (!access->std140_mat_idx.has_value()) {
|
||||||
// loading a std140 type, which is not a whole or partial decomposed matrix
|
// loading a std140 type, which is not a whole or partial decomposed matrix
|
||||||
|
@ -230,7 +230,7 @@ struct Std140::State {
|
||||||
|
|
||||||
// Map of structure member in src of a matrix type, to list of decomposed column
|
// Map of structure member in src of a matrix type, to list of decomposed column
|
||||||
// members in ctx.dst.
|
// members in ctx.dst.
|
||||||
utils::Hashmap<const type::StructMember*, utils::Vector<const ast::StructMember*, 4>, 8>
|
utils::Hashmap<const type::StructMember*, utils::Vector<const StructMember*, 4>, 8>
|
||||||
std140_mat_members;
|
std140_mat_members;
|
||||||
|
|
||||||
/// Describes a matrix that has been forked to a std140-structure holding the decomposed column
|
/// Describes a matrix that has been forked to a std140-structure holding the decomposed column
|
||||||
|
@ -284,7 +284,7 @@ struct Std140::State {
|
||||||
if (str && str->UsedAs(builtin::AddressSpace::kUniform)) {
|
if (str && str->UsedAs(builtin::AddressSpace::kUniform)) {
|
||||||
// Should this uniform buffer be forked for std140 usage?
|
// Should this uniform buffer be forked for std140 usage?
|
||||||
bool fork_std140 = false;
|
bool fork_std140 = false;
|
||||||
utils::Vector<const ast::StructMember*, 8> members;
|
utils::Vector<const StructMember*, 8> members;
|
||||||
for (auto* member : str->Members()) {
|
for (auto* member : str->Members()) {
|
||||||
if (auto* mat = member->Type()->As<type::Matrix>()) {
|
if (auto* mat = member->Type()->As<type::Matrix>()) {
|
||||||
// Is this member a matrix that needs decomposition for std140-layout?
|
// Is this member a matrix that needs decomposition for std140-layout?
|
||||||
|
@ -335,7 +335,7 @@ struct Std140::State {
|
||||||
// Create a new forked structure, and insert it just under the original
|
// Create a new forked structure, and insert it just under the original
|
||||||
// structure.
|
// structure.
|
||||||
auto name = b.Symbols().New(str->Name().Name() + "_std140");
|
auto name = b.Symbols().New(str->Name().Name() + "_std140");
|
||||||
auto* std140 = b.create<ast::Struct>(b.Ident(name), std::move(members),
|
auto* std140 = b.create<Struct>(b.Ident(name), std::move(members),
|
||||||
ctx.Clone(str->Declaration()->attributes));
|
ctx.Clone(str->Declaration()->attributes));
|
||||||
ctx.InsertAfter(src->AST().GlobalDeclarations(), global, std140);
|
ctx.InsertAfter(src->AST().GlobalDeclarations(), global, std140);
|
||||||
std140_structs.Add(str, name);
|
std140_structs.Add(str, name);
|
||||||
|
@ -349,7 +349,7 @@ struct Std140::State {
|
||||||
/// Populates the #std140_uniforms set.
|
/// Populates the #std140_uniforms set.
|
||||||
void ReplaceUniformVarTypes() {
|
void ReplaceUniformVarTypes() {
|
||||||
for (auto* global : src->AST().GlobalVariables()) {
|
for (auto* global : src->AST().GlobalVariables()) {
|
||||||
if (auto* var = global->As<ast::Var>()) {
|
if (auto* var = global->As<Var>()) {
|
||||||
auto* v = sem.Get(var);
|
auto* v = sem.Get(var);
|
||||||
if (v->AddressSpace() == builtin::AddressSpace::kUniform) {
|
if (v->AddressSpace() == builtin::AddressSpace::kUniform) {
|
||||||
if (auto std140_ty = Std140Type(v->Type()->UnwrapRef())) {
|
if (auto std140_ty = Std140Type(v->Type()->UnwrapRef())) {
|
||||||
|
@ -367,9 +367,7 @@ struct Std140::State {
|
||||||
/// @param str the structure that will hold the uniquely named member.
|
/// @param str the structure that will hold the uniquely named member.
|
||||||
/// @param unsuffixed the common name prefix to use for the new members.
|
/// @param unsuffixed the common name prefix to use for the new members.
|
||||||
/// @param count the number of members that need to be created.
|
/// @param count the number of members that need to be created.
|
||||||
std::string PrefixForUniqueNames(const ast::Struct* str,
|
std::string PrefixForUniqueNames(const Struct* str, Symbol unsuffixed, uint32_t count) const {
|
||||||
Symbol unsuffixed,
|
|
||||||
uint32_t count) const {
|
|
||||||
auto prefix = unsuffixed.Name();
|
auto prefix = unsuffixed.Name();
|
||||||
// Keep on inserting '_' between the unsuffixed name and the suffix numbers until the name
|
// Keep on inserting '_' between the unsuffixed name and the suffix numbers until the name
|
||||||
// is unique.
|
// is unique.
|
||||||
|
@ -400,14 +398,14 @@ struct Std140::State {
|
||||||
/// If the semantic type is not split for std140-layout, then nullptr is returned.
|
/// If the semantic type is not split for std140-layout, then nullptr is returned.
|
||||||
/// @note will construct new std140 structures to hold decomposed matrices, populating
|
/// @note will construct new std140 structures to hold decomposed matrices, populating
|
||||||
/// #std140_mats.
|
/// #std140_mats.
|
||||||
ast::Type Std140Type(const type::Type* ty) {
|
Type Std140Type(const type::Type* ty) {
|
||||||
return Switch(
|
return Switch(
|
||||||
ty, //
|
ty, //
|
||||||
[&](const type::Struct* str) {
|
[&](const type::Struct* str) {
|
||||||
if (auto std140 = std140_structs.Find(str)) {
|
if (auto std140 = std140_structs.Find(str)) {
|
||||||
return b.ty(*std140);
|
return b.ty(*std140);
|
||||||
}
|
}
|
||||||
return ast::Type{};
|
return Type{};
|
||||||
},
|
},
|
||||||
[&](const type::Matrix* mat) {
|
[&](const type::Matrix* mat) {
|
||||||
if (MatrixNeedsDecomposing(mat)) {
|
if (MatrixNeedsDecomposing(mat)) {
|
||||||
|
@ -426,13 +424,13 @@ struct Std140::State {
|
||||||
});
|
});
|
||||||
return b.ty(std140_mat.name);
|
return b.ty(std140_mat.name);
|
||||||
}
|
}
|
||||||
return ast::Type{};
|
return Type{};
|
||||||
},
|
},
|
||||||
[&](const type::Array* arr) {
|
[&](const type::Array* arr) {
|
||||||
if (auto std140 = Std140Type(arr->ElemType())) {
|
if (auto std140 = Std140Type(arr->ElemType())) {
|
||||||
utils::Vector<const ast::Attribute*, 1> attrs;
|
utils::Vector<const Attribute*, 1> attrs;
|
||||||
if (!arr->IsStrideImplicit()) {
|
if (!arr->IsStrideImplicit()) {
|
||||||
attrs.Push(b.create<ast::StrideAttribute>(arr->Stride()));
|
attrs.Push(b.create<StrideAttribute>(arr->Stride()));
|
||||||
}
|
}
|
||||||
auto count = arr->ConstantCount();
|
auto count = arr->ConstantCount();
|
||||||
if (TINT_UNLIKELY(!count)) {
|
if (TINT_UNLIKELY(!count)) {
|
||||||
|
@ -446,7 +444,7 @@ struct Std140::State {
|
||||||
}
|
}
|
||||||
return b.ty.array(std140, b.Expr(u32(count.value())), std::move(attrs));
|
return b.ty.array(std140, b.Expr(u32(count.value())), std::move(attrs));
|
||||||
}
|
}
|
||||||
return ast::Type{};
|
return Type{};
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -455,7 +453,7 @@ struct Std140::State {
|
||||||
/// @param align the alignment in bytes of the matrix.
|
/// @param align the alignment in bytes of the matrix.
|
||||||
/// @param size the size in bytes of the matrix.
|
/// @param size the size in bytes of the matrix.
|
||||||
/// @returns a vector of decomposed matrix column vectors as structure members (in ctx.dst).
|
/// @returns a vector of decomposed matrix column vectors as structure members (in ctx.dst).
|
||||||
utils::Vector<const ast::StructMember*, 4> DecomposedMatrixStructMembers(
|
utils::Vector<const StructMember*, 4> DecomposedMatrixStructMembers(
|
||||||
const type::Matrix* mat,
|
const type::Matrix* mat,
|
||||||
const std::string& name_prefix,
|
const std::string& name_prefix,
|
||||||
uint32_t align,
|
uint32_t align,
|
||||||
|
@ -463,9 +461,9 @@ struct Std140::State {
|
||||||
// Replace the member with column vectors.
|
// Replace the member with column vectors.
|
||||||
const auto num_columns = mat->columns();
|
const auto num_columns = mat->columns();
|
||||||
// Build a struct member for each column of the matrix
|
// Build a struct member for each column of the matrix
|
||||||
utils::Vector<const ast::StructMember*, 4> out;
|
utils::Vector<const StructMember*, 4> out;
|
||||||
for (uint32_t i = 0; i < num_columns; i++) {
|
for (uint32_t i = 0; i < num_columns; i++) {
|
||||||
utils::Vector<const ast::Attribute*, 1> attributes;
|
utils::Vector<const Attribute*, 1> attributes;
|
||||||
if ((i == 0) && mat->Align() != align) {
|
if ((i == 0) && mat->Align() != align) {
|
||||||
// The matrix was @align() annotated with a larger alignment
|
// The matrix was @align() annotated with a larger alignment
|
||||||
// than the natural alignment for the matrix. This extra padding
|
// than the natural alignment for the matrix. This extra padding
|
||||||
|
@ -493,7 +491,7 @@ struct Std140::State {
|
||||||
/// Walks the @p ast_expr, constructing and returning an AccessChain.
|
/// Walks the @p ast_expr, constructing and returning an AccessChain.
|
||||||
/// @returns an AccessChain if the expression is an access to a std140-forked uniform buffer,
|
/// @returns an AccessChain if the expression is an access to a std140-forked uniform buffer,
|
||||||
/// otherwise returns a std::nullopt.
|
/// otherwise returns a std::nullopt.
|
||||||
std::optional<AccessChain> AccessChainFor(const ast::Expression* ast_expr) {
|
std::optional<AccessChain> AccessChainFor(const Expression* ast_expr) {
|
||||||
auto* expr = sem.GetVal(ast_expr);
|
auto* expr = sem.GetVal(ast_expr);
|
||||||
if (!expr) {
|
if (!expr) {
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
@ -576,10 +574,10 @@ struct Std140::State {
|
||||||
[&](const sem::ValueExpression* e) {
|
[&](const sem::ValueExpression* e) {
|
||||||
// Walk past indirection and address-of unary ops.
|
// Walk past indirection and address-of unary ops.
|
||||||
return Switch(e->Declaration(), //
|
return Switch(e->Declaration(), //
|
||||||
[&](const ast::UnaryOpExpression* u) {
|
[&](const UnaryOpExpression* u) {
|
||||||
switch (u->op) {
|
switch (u->op) {
|
||||||
case ast::UnaryOp::kAddressOf:
|
case UnaryOp::kAddressOf:
|
||||||
case ast::UnaryOp::kIndirection:
|
case UnaryOp::kIndirection:
|
||||||
expr = sem.GetVal(u->expr);
|
expr = sem.GetVal(u->expr);
|
||||||
return Action::kContinue;
|
return Action::kContinue;
|
||||||
default:
|
default:
|
||||||
|
@ -660,8 +658,8 @@ struct Std140::State {
|
||||||
/// Generates and returns an expression that loads the value from a std140 uniform buffer,
|
/// Generates and returns an expression that loads the value from a std140 uniform buffer,
|
||||||
/// converting the final result to a non-std140 type.
|
/// converting the final result to a non-std140 type.
|
||||||
/// @param chain the access chain from a uniform buffer to the value to load.
|
/// @param chain the access chain from a uniform buffer to the value to load.
|
||||||
const ast::Expression* LoadWithConvert(const AccessChain& chain) {
|
const Expression* LoadWithConvert(const AccessChain& chain) {
|
||||||
const ast::Expression* expr = nullptr;
|
const Expression* expr = nullptr;
|
||||||
const type::Type* ty = nullptr;
|
const type::Type* ty = nullptr;
|
||||||
auto dynamic_index = [&](size_t idx) {
|
auto dynamic_index = [&](size_t idx) {
|
||||||
return ctx.Clone(chain.dynamic_indices[idx]->Declaration());
|
return ctx.Clone(chain.dynamic_indices[idx]->Declaration());
|
||||||
|
@ -678,7 +676,7 @@ struct Std140::State {
|
||||||
/// std140-forked type to the type @p ty. If @p expr is not a std140-forked type, then Convert()
|
/// std140-forked type to the type @p ty. If @p expr is not a std140-forked type, then Convert()
|
||||||
/// will simply return @p expr.
|
/// will simply return @p expr.
|
||||||
/// @returns the converted value expression.
|
/// @returns the converted value expression.
|
||||||
const ast::Expression* Convert(const type::Type* ty, const ast::Expression* expr) {
|
const Expression* Convert(const type::Type* ty, const Expression* expr) {
|
||||||
// Get an existing, or create a new function for converting the std140 type to ty.
|
// Get an existing, or create a new function for converting the std140 type to ty.
|
||||||
auto fn = conv_fns.GetOrCreate(ty, [&] {
|
auto fn = conv_fns.GetOrCreate(ty, [&] {
|
||||||
auto std140_ty = Std140Type(ty);
|
auto std140_ty = Std140Type(ty);
|
||||||
|
@ -690,20 +688,20 @@ struct Std140::State {
|
||||||
// The converter function takes a single argument of the std140 type.
|
// The converter function takes a single argument of the std140 type.
|
||||||
auto* param = b.Param("val", std140_ty);
|
auto* param = b.Param("val", std140_ty);
|
||||||
|
|
||||||
utils::Vector<const ast::Statement*, 3> stmts;
|
utils::Vector<const Statement*, 3> stmts;
|
||||||
|
|
||||||
Switch(
|
Switch(
|
||||||
ty, //
|
ty, //
|
||||||
[&](const type::Struct* str) {
|
[&](const type::Struct* str) {
|
||||||
// Convert each of the structure members using either a converter function
|
// Convert each of the structure members using either a converter function
|
||||||
// call, or by reassembling a std140 matrix from column vector members.
|
// call, or by reassembling a std140 matrix from column vector members.
|
||||||
utils::Vector<const ast::Expression*, 8> args;
|
utils::Vector<const Expression*, 8> args;
|
||||||
for (auto* member : str->Members()) {
|
for (auto* member : str->Members()) {
|
||||||
if (auto col_members = std140_mat_members.Find(member)) {
|
if (auto col_members = std140_mat_members.Find(member)) {
|
||||||
// std140 decomposed matrix. Reassemble.
|
// std140 decomposed matrix. Reassemble.
|
||||||
auto mat_ty = CreateASTTypeFor(ctx, member->Type());
|
auto mat_ty = CreateASTTypeFor(ctx, member->Type());
|
||||||
auto mat_args =
|
auto mat_args =
|
||||||
utils::Transform(*col_members, [&](const ast::StructMember* m) {
|
utils::Transform(*col_members, [&](const StructMember* m) {
|
||||||
return b.MemberAccessor(param, m->name->symbol);
|
return b.MemberAccessor(param, m->name->symbol);
|
||||||
});
|
});
|
||||||
args.Push(b.Call(mat_ty, std::move(mat_args)));
|
args.Push(b.Call(mat_ty, std::move(mat_args)));
|
||||||
|
@ -719,7 +717,7 @@ struct Std140::State {
|
||||||
// Reassemble a std140 matrix from the structure of column vector members.
|
// Reassemble a std140 matrix from the structure of column vector members.
|
||||||
auto std140_mat = std140_mats.Get(mat);
|
auto std140_mat = std140_mats.Get(mat);
|
||||||
if (TINT_LIKELY(std140_mat)) {
|
if (TINT_LIKELY(std140_mat)) {
|
||||||
utils::Vector<const ast::Expression*, 8> args;
|
utils::Vector<const Expression*, 8> args;
|
||||||
// std140 decomposed matrix. Reassemble.
|
// std140 decomposed matrix. Reassemble.
|
||||||
auto mat_ty = CreateASTTypeFor(ctx, mat);
|
auto mat_ty = CreateASTTypeFor(ctx, mat);
|
||||||
auto mat_args = utils::Transform(std140_mat->columns, [&](Symbol name) {
|
auto mat_args = utils::Transform(std140_mat->columns, [&](Symbol name) {
|
||||||
|
@ -782,7 +780,7 @@ struct Std140::State {
|
||||||
/// @param access the access chain from the uniform buffer to either the whole matrix or part of
|
/// @param access the access chain from the uniform buffer to either the whole matrix or part of
|
||||||
/// the matrix (column, column-swizzle, or element).
|
/// the matrix (column, column-swizzle, or element).
|
||||||
/// @returns the loaded value expression.
|
/// @returns the loaded value expression.
|
||||||
const ast::Expression* LoadMatrixWithFn(const AccessChain& access) {
|
const Expression* LoadMatrixWithFn(const AccessChain& access) {
|
||||||
// Get an existing, or create a new function for loading the uniform buffer value.
|
// Get an existing, or create a new function for loading the uniform buffer value.
|
||||||
// This function is keyed off the uniform buffer variable and the access chain.
|
// This function is keyed off the uniform buffer variable and the access chain.
|
||||||
auto fn = load_fns.GetOrCreate(LoadFnKey{access.var, access.indices}, [&] {
|
auto fn = load_fns.GetOrCreate(LoadFnKey{access.var, access.indices}, [&] {
|
||||||
|
@ -810,14 +808,14 @@ struct Std140::State {
|
||||||
/// column-swizzle, or element).
|
/// column-swizzle, or element).
|
||||||
/// @note The matrix column must be statically indexed to use this method.
|
/// @note The matrix column must be statically indexed to use this method.
|
||||||
/// @returns the loaded value expression.
|
/// @returns the loaded value expression.
|
||||||
const ast::Expression* LoadSubMatrixInline(const AccessChain& chain) {
|
const Expression* LoadSubMatrixInline(const AccessChain& chain) {
|
||||||
// Method for generating dynamic index expressions.
|
// Method for generating dynamic index expressions.
|
||||||
// As this is inline, we can just clone the expression.
|
// As this is inline, we can just clone the expression.
|
||||||
auto dynamic_index = [&](size_t idx) {
|
auto dynamic_index = [&](size_t idx) {
|
||||||
return ctx.Clone(chain.dynamic_indices[idx]->Declaration());
|
return ctx.Clone(chain.dynamic_indices[idx]->Declaration());
|
||||||
};
|
};
|
||||||
|
|
||||||
const ast::Expression* expr = nullptr;
|
const Expression* expr = nullptr;
|
||||||
const type::Type* ty = nullptr;
|
const type::Type* ty = nullptr;
|
||||||
|
|
||||||
// Build the expression up to, but not including the matrix member
|
// Build the expression up to, but not including the matrix member
|
||||||
|
@ -891,7 +889,7 @@ struct Std140::State {
|
||||||
std::string name = "load";
|
std::string name = "load";
|
||||||
|
|
||||||
// The switch cases
|
// The switch cases
|
||||||
utils::Vector<const ast::CaseStatement*, 4> cases;
|
utils::Vector<const CaseStatement*, 4> cases;
|
||||||
|
|
||||||
// The function return type.
|
// The function return type.
|
||||||
const type::Type* ret_ty = nullptr;
|
const type::Type* ret_ty = nullptr;
|
||||||
|
@ -899,7 +897,7 @@ struct Std140::State {
|
||||||
// Build switch() cases for each column of the matrix
|
// Build switch() cases for each column of the matrix
|
||||||
auto num_columns = chain.std140_mat_ty->columns();
|
auto num_columns = chain.std140_mat_ty->columns();
|
||||||
for (uint32_t column_idx = 0; column_idx < num_columns; column_idx++) {
|
for (uint32_t column_idx = 0; column_idx < num_columns; column_idx++) {
|
||||||
const ast::Expression* expr = nullptr;
|
const Expression* expr = nullptr;
|
||||||
const type::Type* ty = nullptr;
|
const type::Type* ty = nullptr;
|
||||||
|
|
||||||
// Build the expression up to, but not including the matrix
|
// Build the expression up to, but not including the matrix
|
||||||
|
@ -991,7 +989,7 @@ struct Std140::State {
|
||||||
return b.Expr(dynamic_index_params[idx]->name->symbol);
|
return b.Expr(dynamic_index_params[idx]->name->symbol);
|
||||||
};
|
};
|
||||||
|
|
||||||
const ast::Expression* expr = nullptr;
|
const Expression* expr = nullptr;
|
||||||
const type::Type* ty = nullptr;
|
const type::Type* ty = nullptr;
|
||||||
std::string name = "load";
|
std::string name = "load";
|
||||||
|
|
||||||
|
@ -1005,13 +1003,13 @@ struct Std140::State {
|
||||||
name += "_" + access_name;
|
name += "_" + access_name;
|
||||||
}
|
}
|
||||||
|
|
||||||
utils::Vector<const ast::Statement*, 2> stmts;
|
utils::Vector<const Statement*, 2> stmts;
|
||||||
|
|
||||||
// Create a temporary pointer to the structure that holds the matrix columns
|
// Create a temporary pointer to the structure that holds the matrix columns
|
||||||
auto* let = b.Let("s", b.AddressOf(expr));
|
auto* let = b.Let("s", b.AddressOf(expr));
|
||||||
stmts.Push(b.Decl(let));
|
stmts.Push(b.Decl(let));
|
||||||
|
|
||||||
utils::Vector<const ast::MemberAccessorExpression*, 4> columns;
|
utils::Vector<const MemberAccessorExpression*, 4> columns;
|
||||||
if (auto* str = tint::As<type::Struct>(ty)) {
|
if (auto* str = tint::As<type::Struct>(ty)) {
|
||||||
// Structure member matrix. The columns are decomposed into the structure.
|
// Structure member matrix. The columns are decomposed into the structure.
|
||||||
auto mat_member_idx = std::get<u32>(chain.indices[std140_mat_idx]);
|
auto mat_member_idx = std::get<u32>(chain.indices[std140_mat_idx]);
|
||||||
|
@ -1053,7 +1051,7 @@ struct Std140::State {
|
||||||
/// Return type of BuildAccessExpr()
|
/// Return type of BuildAccessExpr()
|
||||||
struct ExprTypeName {
|
struct ExprTypeName {
|
||||||
/// The new, post-access expression
|
/// The new, post-access expression
|
||||||
const ast::Expression* expr;
|
const Expression* expr;
|
||||||
/// The type of #expr
|
/// The type of #expr
|
||||||
const type::Type* type;
|
const type::Type* type;
|
||||||
/// A name segment which can be used to build sensible names for helper functions
|
/// A name segment which can be used to build sensible names for helper functions
|
||||||
|
@ -1067,11 +1065,11 @@ struct Std140::State {
|
||||||
/// @param dynamic_index a function that obtains the i'th dynamic index
|
/// @param dynamic_index a function that obtains the i'th dynamic index
|
||||||
/// @returns a ExprTypeName which holds the new expression, new type and a name segment which
|
/// @returns a ExprTypeName which holds the new expression, new type and a name segment which
|
||||||
/// can be used for creating helper function names.
|
/// can be used for creating helper function names.
|
||||||
ExprTypeName BuildAccessExpr(const ast::Expression* lhs,
|
ExprTypeName BuildAccessExpr(const Expression* lhs,
|
||||||
const type::Type* ty,
|
const type::Type* ty,
|
||||||
const AccessChain& chain,
|
const AccessChain& chain,
|
||||||
size_t index,
|
size_t index,
|
||||||
std::function<const ast::Expression*(size_t)> dynamic_index) {
|
std::function<const Expression*(size_t)> dynamic_index) {
|
||||||
auto& access = chain.indices[index];
|
auto& access = chain.indices[index];
|
||||||
|
|
||||||
if (std::get_if<UniformVariable>(&access)) {
|
if (std::get_if<UniformVariable>(&access)) {
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace {
|
||||||
|
|
||||||
bool ShouldRun(const Program* program) {
|
bool ShouldRun(const Program* program) {
|
||||||
for (auto* node : program->AST().GlobalVariables()) {
|
for (auto* node : program->AST().GlobalVariables()) {
|
||||||
if (node->Is<ast::Override>()) {
|
if (node->Is<Override>()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -61,12 +61,12 @@ Transform::ApplyResult SubstituteOverride::Apply(const Program* src,
|
||||||
return SkipTransform;
|
return SkipTransform;
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.ReplaceAll([&](const ast::Override* w) -> const ast::Const* {
|
ctx.ReplaceAll([&](const Override* w) -> const Const* {
|
||||||
auto* sem = ctx.src->Sem().Get(w);
|
auto* sem = ctx.src->Sem().Get(w);
|
||||||
|
|
||||||
auto source = ctx.Clone(w->source);
|
auto source = ctx.Clone(w->source);
|
||||||
auto sym = ctx.Clone(w->name->symbol);
|
auto sym = ctx.Clone(w->name->symbol);
|
||||||
ast::Type ty = w->type ? ctx.Clone(w->type) : ast::Type{};
|
Type ty = w->type ? ctx.Clone(w->type) : Type{};
|
||||||
|
|
||||||
// No replacement provided, just clone the override node as a const.
|
// No replacement provided, just clone the override node as a const.
|
||||||
auto iter = data->map.find(sem->OverrideId());
|
auto iter = data->map.find(sem->OverrideId());
|
||||||
|
@ -102,7 +102,7 @@ Transform::ApplyResult SubstituteOverride::Apply(const Program* src,
|
||||||
// If the object is not materialized, and the 'override' variable is turned to a 'const', the
|
// If the object is not materialized, and the 'override' variable is turned to a 'const', the
|
||||||
// resulting type of the index may change. See: crbug.com/tint/1697.
|
// resulting type of the index may change. See: crbug.com/tint/1697.
|
||||||
ctx.ReplaceAll(
|
ctx.ReplaceAll(
|
||||||
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
|
[&](const IndexAccessorExpression* expr) -> const IndexAccessorExpression* {
|
||||||
if (auto* sem = src->Sem().Get(expr)) {
|
if (auto* sem = src->Sem().Get(expr)) {
|
||||||
if (auto* access = sem->UnwrapMaterialize()->As<sem::IndexAccessorExpression>()) {
|
if (auto* access = sem->UnwrapMaterialize()->As<sem::IndexAccessorExpression>()) {
|
||||||
if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() &&
|
if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() &&
|
||||||
|
|
|
@ -85,18 +85,18 @@ struct Texture1DTo2D::State {
|
||||||
return SkipTransform;
|
return SkipTransform;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto create_var = [&](const ast::Variable* v, ast::Type type) -> const ast::Variable* {
|
auto create_var = [&](const Variable* v, Type type) -> const Variable* {
|
||||||
if (v->As<ast::Parameter>()) {
|
if (v->As<Parameter>()) {
|
||||||
return ctx.dst->Param(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes));
|
return ctx.dst->Param(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes));
|
||||||
} else {
|
} else {
|
||||||
return ctx.dst->Var(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes));
|
return ctx.dst->Var(ctx.Clone(v->name->symbol), type, ctx.Clone(v->attributes));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
ctx.ReplaceAll([&](const ast::Variable* v) -> const ast::Variable* {
|
ctx.ReplaceAll([&](const Variable* v) -> const Variable* {
|
||||||
const ast::Variable* r = Switch(
|
const Variable* r = Switch(
|
||||||
sem.Get(v)->Type()->UnwrapRef(),
|
sem.Get(v)->Type()->UnwrapRef(),
|
||||||
[&](const type::SampledTexture* tex) -> const ast::Variable* {
|
[&](const type::SampledTexture* tex) -> const Variable* {
|
||||||
if (tex->dim() == type::TextureDimension::k1d) {
|
if (tex->dim() == type::TextureDimension::k1d) {
|
||||||
auto type = ctx.dst->ty.sampled_texture(type::TextureDimension::k2d,
|
auto type = ctx.dst->ty.sampled_texture(type::TextureDimension::k2d,
|
||||||
CreateASTTypeFor(ctx, tex->type()));
|
CreateASTTypeFor(ctx, tex->type()));
|
||||||
|
@ -105,7 +105,7 @@ struct Texture1DTo2D::State {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[&](const type::StorageTexture* storage_tex) -> const ast::Variable* {
|
[&](const type::StorageTexture* storage_tex) -> const Variable* {
|
||||||
if (storage_tex->dim() == type::TextureDimension::k1d) {
|
if (storage_tex->dim() == type::TextureDimension::k1d) {
|
||||||
auto type = ctx.dst->ty.storage_texture(type::TextureDimension::k2d,
|
auto type = ctx.dst->ty.storage_texture(type::TextureDimension::k2d,
|
||||||
storage_tex->texel_format(),
|
storage_tex->texel_format(),
|
||||||
|
@ -119,7 +119,7 @@ struct Texture1DTo2D::State {
|
||||||
return r;
|
return r;
|
||||||
});
|
});
|
||||||
|
|
||||||
ctx.ReplaceAll([&](const ast::CallExpression* c) -> const ast::Expression* {
|
ctx.ReplaceAll([&](const CallExpression* c) -> const Expression* {
|
||||||
auto* call = sem.Get(c)->UnwrapMaterialize()->As<sem::Call>();
|
auto* call = sem.Get(c)->UnwrapMaterialize()->As<sem::Call>();
|
||||||
if (!call) {
|
if (!call) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -141,7 +141,7 @@ struct Texture1DTo2D::State {
|
||||||
if (builtin->Type() == builtin::Function::kTextureDimensions) {
|
if (builtin->Type() == builtin::Function::kTextureDimensions) {
|
||||||
// If this textureDimensions() call is in a CallStatement, we can leave it
|
// If this textureDimensions() call is in a CallStatement, we can leave it
|
||||||
// unmodified since the return value will be dropped on the floor anyway.
|
// unmodified since the return value will be dropped on the floor anyway.
|
||||||
if (call->Stmt()->Declaration()->Is<ast::CallStatement>()) {
|
if (call->Stmt()->Declaration()->Is<CallStatement>()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto* new_call = ctx.CloneWithoutTransform(c);
|
auto* new_call = ctx.CloneWithoutTransform(c);
|
||||||
|
@ -153,14 +153,14 @@ struct Texture1DTo2D::State {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
utils::Vector<const ast::Expression*, 8> args;
|
utils::Vector<const Expression*, 8> args;
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (auto* arg : c->args) {
|
for (auto* arg : c->args) {
|
||||||
if (index == coords_index) {
|
if (index == coords_index) {
|
||||||
auto* ctype = call->Arguments()[static_cast<size_t>(coords_index)]->Type();
|
auto* ctype = call->Arguments()[static_cast<size_t>(coords_index)]->Type();
|
||||||
auto* coords = c->args[static_cast<size_t>(coords_index)];
|
auto* coords = c->args[static_cast<size_t>(coords_index)];
|
||||||
|
|
||||||
const ast::LiteralExpression* half = nullptr;
|
const LiteralExpression* half = nullptr;
|
||||||
if (ctype->is_integer_scalar()) {
|
if (ctype->is_integer_scalar()) {
|
||||||
half = ctx.dst->Expr(0_a);
|
half = ctx.dst->Expr(0_a);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -60,23 +60,23 @@ Output Transform::Run(const Program* src, const DataMap& data /* = {} */) const
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) {
|
void Transform::RemoveStatement(CloneContext& ctx, const Statement* stmt) {
|
||||||
auto* sem = ctx.src->Sem().Get(stmt);
|
auto* sem = ctx.src->Sem().Get(stmt);
|
||||||
if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {
|
if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {
|
||||||
ctx.Remove(block->Declaration()->statements, stmt);
|
ctx.Remove(block->Declaration()->statements, stmt);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (TINT_LIKELY(tint::Is<sem::ForLoopStatement>(sem->Parent()))) {
|
if (TINT_LIKELY(tint::Is<sem::ForLoopStatement>(sem->Parent()))) {
|
||||||
ctx.Replace(stmt, static_cast<ast::Expression*>(nullptr));
|
ctx.Replace(stmt, static_cast<Expression*>(nullptr));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||||
<< "unable to remove statement from parent of type " << sem->TypeInfo().name;
|
<< "unable to remove statement from parent of type " << sem->TypeInfo().name;
|
||||||
}
|
}
|
||||||
|
|
||||||
ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
|
Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
|
||||||
if (ty->Is<type::Void>()) {
|
if (ty->Is<type::Void>()) {
|
||||||
return ast::Type{};
|
return Type{};
|
||||||
}
|
}
|
||||||
if (ty->Is<type::I32>()) {
|
if (ty->Is<type::I32>()) {
|
||||||
return ctx.dst->ty.i32();
|
return ctx.dst->ty.i32();
|
||||||
|
@ -108,9 +108,9 @@ ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
|
||||||
}
|
}
|
||||||
if (auto* a = ty->As<type::Array>()) {
|
if (auto* a = ty->As<type::Array>()) {
|
||||||
auto el = CreateASTTypeFor(ctx, a->ElemType());
|
auto el = CreateASTTypeFor(ctx, a->ElemType());
|
||||||
utils::Vector<const ast::Attribute*, 1> attrs;
|
utils::Vector<const Attribute*, 1> attrs;
|
||||||
if (!a->IsStrideImplicit()) {
|
if (!a->IsStrideImplicit()) {
|
||||||
attrs.Push(ctx.dst->create<ast::StrideAttribute>(a->Stride()));
|
attrs.Push(ctx.dst->create<StrideAttribute>(a->Stride()));
|
||||||
}
|
}
|
||||||
if (a->Count()->Is<type::RuntimeArrayCount>()) {
|
if (a->Count()->Is<type::RuntimeArrayCount>()) {
|
||||||
return ctx.dst->ty.array(el, std::move(attrs));
|
return ctx.dst->ty.array(el, std::move(attrs));
|
||||||
|
@ -125,7 +125,7 @@ ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
|
||||||
// See crbug.com/tint/1764.
|
// See crbug.com/tint/1764.
|
||||||
// Look for a type alias for this array.
|
// Look for a type alias for this array.
|
||||||
for (auto* type_decl : ctx.src->AST().TypeDecls()) {
|
for (auto* type_decl : ctx.src->AST().TypeDecls()) {
|
||||||
if (auto* alias = type_decl->As<ast::Alias>()) {
|
if (auto* alias = type_decl->As<Alias>()) {
|
||||||
if (ty == ctx.src->Sem().Get(alias)) {
|
if (ty == ctx.src->Sem().Get(alias)) {
|
||||||
// Alias found. Use the alias name to ensure types compare equal.
|
// Alias found. Use the alias name to ensure types compare equal.
|
||||||
return ctx.dst->ty(ctx.Clone(alias->name->symbol));
|
return ctx.dst->ty(ctx.Clone(alias->name->symbol));
|
||||||
|
@ -184,7 +184,7 @@ ast::Type Transform::CreateASTTypeFor(CloneContext& ctx, const type::Type* ty) {
|
||||||
}
|
}
|
||||||
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
|
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
|
||||||
<< "Unhandled type: " << ty->TypeInfo().name;
|
<< "Unhandled type: " << ty->TypeInfo().name;
|
||||||
return ast::Type{};
|
return Type{};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tint::ast::transform
|
} // namespace tint::ast::transform
|
||||||
|
|
|
@ -181,11 +181,11 @@ class Transform : public utils::Castable<Transform> {
|
||||||
const DataMap& inputs,
|
const DataMap& inputs,
|
||||||
DataMap& outputs) const = 0;
|
DataMap& outputs) const = 0;
|
||||||
|
|
||||||
/// CreateASTTypeFor constructs new ast::Type that reconstructs the semantic type `ty`.
|
/// CreateASTTypeFor constructs new Type that reconstructs the semantic type `ty`.
|
||||||
/// @param ctx the clone context
|
/// @param ctx the clone context
|
||||||
/// @param ty the semantic type to reconstruct
|
/// @param ty the semantic type to reconstruct
|
||||||
/// @returns an ast::Type that when resolved, will produce the semantic type `ty`.
|
/// @returns an Type that when resolved, will produce the semantic type `ty`.
|
||||||
static ast::Type CreateASTTypeFor(CloneContext& ctx, const type::Type* ty);
|
static Type CreateASTTypeFor(CloneContext& ctx, const type::Type* ty);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// Removes the statement `stmt` from the transformed program.
|
/// Removes the statement `stmt` from the transformed program.
|
||||||
|
@ -193,7 +193,7 @@ class Transform : public utils::Castable<Transform> {
|
||||||
/// continuing of for-loops.
|
/// continuing of for-loops.
|
||||||
/// @param ctx the clone context
|
/// @param ctx the clone context
|
||||||
/// @param stmt the statement to remove when the program is cloned
|
/// @param stmt the statement to remove when the program is cloned
|
||||||
static void RemoveStatement(CloneContext& ctx, const ast::Statement* stmt);
|
static void RemoveStatement(CloneContext& ctx, const Statement* stmt);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tint::ast::transform
|
} // namespace tint::ast::transform
|
||||||
|
|
|
@ -32,7 +32,7 @@ struct CreateASTTypeForTest : public testing::Test, public Transform {
|
||||||
return SkipTransform;
|
return SkipTransform;
|
||||||
}
|
}
|
||||||
|
|
||||||
ast::Type create(std::function<type::Type*(ProgramBuilder&)> create_sem_type) {
|
Type create(std::function<type::Type*(ProgramBuilder&)> create_sem_type) {
|
||||||
ProgramBuilder sem_type_builder;
|
ProgramBuilder sem_type_builder;
|
||||||
auto* sem_type = create_sem_type(sem_type_builder);
|
auto* sem_type = create_sem_type(sem_type_builder);
|
||||||
Program program(std::move(sem_type_builder));
|
Program program(std::move(sem_type_builder));
|
||||||
|
@ -44,9 +44,7 @@ struct CreateASTTypeForTest : public testing::Test, public Transform {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(CreateASTTypeForTest, Basic) {
|
TEST_F(CreateASTTypeForTest, Basic) {
|
||||||
auto check = [&](ast::Type ty, const char* expect) {
|
auto check = [&](Type ty, const char* expect) { CheckIdentifier(ty->identifier, expect); };
|
||||||
ast::CheckIdentifier(ty->identifier, expect);
|
|
||||||
};
|
|
||||||
|
|
||||||
check(create([](ProgramBuilder& b) { return b.create<type::I32>(); }), "i32");
|
check(create([](ProgramBuilder& b) { return b.create<type::I32>(); }), "i32");
|
||||||
check(create([](ProgramBuilder& b) { return b.create<type::U32>(); }), "u32");
|
check(create([](ProgramBuilder& b) { return b.create<type::U32>(); }), "u32");
|
||||||
|
@ -61,14 +59,14 @@ TEST_F(CreateASTTypeForTest, Matrix) {
|
||||||
return b.create<type::Matrix>(column_type, 3u);
|
return b.create<type::Matrix>(column_type, 3u);
|
||||||
});
|
});
|
||||||
|
|
||||||
ast::CheckIdentifier(mat, ast::Template("mat3x2", "f32"));
|
CheckIdentifier(mat, Template("mat3x2", "f32"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CreateASTTypeForTest, Vector) {
|
TEST_F(CreateASTTypeForTest, Vector) {
|
||||||
auto vec =
|
auto vec =
|
||||||
create([](ProgramBuilder& b) { return b.create<type::Vector>(b.create<type::F32>(), 2u); });
|
create([](ProgramBuilder& b) { return b.create<type::Vector>(b.create<type::F32>(), 2u); });
|
||||||
|
|
||||||
ast::CheckIdentifier(vec, ast::Template("vec2", "f32"));
|
CheckIdentifier(vec, Template("vec2", "f32"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
|
TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
|
||||||
|
@ -77,8 +75,8 @@ TEST_F(CreateASTTypeForTest, ArrayImplicitStride) {
|
||||||
4u, 4u, 32u, 32u);
|
4u, 4u, 32u, 32u);
|
||||||
});
|
});
|
||||||
|
|
||||||
ast::CheckIdentifier(arr, ast::Template("array", "f32", 2_u));
|
CheckIdentifier(arr, Template("array", "f32", 2_u));
|
||||||
auto* tmpl_attr = arr->identifier->As<ast::TemplatedIdentifier>();
|
auto* tmpl_attr = arr->identifier->As<TemplatedIdentifier>();
|
||||||
ASSERT_NE(tmpl_attr, nullptr);
|
ASSERT_NE(tmpl_attr, nullptr);
|
||||||
EXPECT_TRUE(tmpl_attr->attributes.IsEmpty());
|
EXPECT_TRUE(tmpl_attr->attributes.IsEmpty());
|
||||||
}
|
}
|
||||||
|
@ -88,12 +86,12 @@ TEST_F(CreateASTTypeForTest, ArrayNonImplicitStride) {
|
||||||
return b.create<type::Array>(b.create<type::F32>(), b.create<type::ConstantArrayCount>(2u),
|
return b.create<type::Array>(b.create<type::F32>(), b.create<type::ConstantArrayCount>(2u),
|
||||||
4u, 4u, 64u, 32u);
|
4u, 4u, 64u, 32u);
|
||||||
});
|
});
|
||||||
ast::CheckIdentifier(arr, ast::Template("array", "f32", 2_u));
|
CheckIdentifier(arr, Template("array", "f32", 2_u));
|
||||||
auto* tmpl_attr = arr->identifier->As<ast::TemplatedIdentifier>();
|
auto* tmpl_attr = arr->identifier->As<TemplatedIdentifier>();
|
||||||
ASSERT_NE(tmpl_attr, nullptr);
|
ASSERT_NE(tmpl_attr, nullptr);
|
||||||
ASSERT_EQ(tmpl_attr->attributes.Length(), 1u);
|
ASSERT_EQ(tmpl_attr->attributes.Length(), 1u);
|
||||||
ASSERT_TRUE(tmpl_attr->attributes[0]->Is<ast::StrideAttribute>());
|
ASSERT_TRUE(tmpl_attr->attributes[0]->Is<StrideAttribute>());
|
||||||
ASSERT_EQ(tmpl_attr->attributes[0]->As<ast::StrideAttribute>()->stride, 64u);
|
ASSERT_EQ(tmpl_attr->attributes[0]->As<StrideAttribute>()->stride, 64u);
|
||||||
}
|
}
|
||||||
|
|
||||||
// crbug.com/tint/1764
|
// crbug.com/tint/1764
|
||||||
|
@ -114,7 +112,7 @@ TEST_F(CreateASTTypeForTest, AliasedArrayWithComplexOverrideLength) {
|
||||||
|
|
||||||
CloneContext ctx(&ast_type_builder, &program, false);
|
CloneContext ctx(&ast_type_builder, &program, false);
|
||||||
auto ast_ty = CreateASTTypeFor(ctx, arr_ty);
|
auto ast_ty = CreateASTTypeFor(ctx, arr_ty);
|
||||||
ast::CheckIdentifier(ast_ty, "A");
|
CheckIdentifier(ast_ty, "A");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CreateASTTypeForTest, Struct) {
|
TEST_F(CreateASTTypeForTest, Struct) {
|
||||||
|
@ -124,7 +122,7 @@ TEST_F(CreateASTTypeForTest, Struct) {
|
||||||
4u /* size */, 4u /* size_no_padding */);
|
4u /* size */, 4u /* size_no_padding */);
|
||||||
});
|
});
|
||||||
|
|
||||||
ast::CheckIdentifier(str, "S");
|
CheckIdentifier(str, "S");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CreateASTTypeForTest, PrivatePointer) {
|
TEST_F(CreateASTTypeForTest, PrivatePointer) {
|
||||||
|
@ -133,7 +131,7 @@ TEST_F(CreateASTTypeForTest, PrivatePointer) {
|
||||||
builtin::Access::kReadWrite);
|
builtin::Access::kReadWrite);
|
||||||
});
|
});
|
||||||
|
|
||||||
ast::CheckIdentifier(ptr, ast::Template("ptr", "private", "i32"));
|
CheckIdentifier(ptr, Template("ptr", "private", "i32"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CreateASTTypeForTest, StorageReadWritePointer) {
|
TEST_F(CreateASTTypeForTest, StorageReadWritePointer) {
|
||||||
|
@ -142,7 +140,7 @@ TEST_F(CreateASTTypeForTest, StorageReadWritePointer) {
|
||||||
builtin::Access::kReadWrite);
|
builtin::Access::kReadWrite);
|
||||||
});
|
});
|
||||||
|
|
||||||
ast::CheckIdentifier(ptr, ast::Template("ptr", "storage", "i32", "read_write"));
|
CheckIdentifier(ptr, Template("ptr", "storage", "i32", "read_write"));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -74,7 +74,7 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (func_ast->PipelineStage() != ast::PipelineStage::kVertex) {
|
if (func_ast->PipelineStage() != PipelineStage::kVertex) {
|
||||||
// Currently only vertex stage could have interstage output variables that need
|
// Currently only vertex stage could have interstage output variables that need
|
||||||
// truncated.
|
// truncated.
|
||||||
continue;
|
continue;
|
||||||
|
@ -118,8 +118,8 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
|
||||||
old_shader_io_structs_to_new_struct_and_truncate_functions.GetOrCreate(str, [&] {
|
old_shader_io_structs_to_new_struct_and_truncate_functions.GetOrCreate(str, [&] {
|
||||||
auto new_struct_sym = b.Symbols().New();
|
auto new_struct_sym = b.Symbols().New();
|
||||||
|
|
||||||
utils::Vector<const ast::StructMember*, 20> truncated_members;
|
utils::Vector<const StructMember*, 20> truncated_members;
|
||||||
utils::Vector<const ast::Expression*, 20> initializer_exprs;
|
utils::Vector<const Expression*, 20> initializer_exprs;
|
||||||
|
|
||||||
for (auto* member : str->Members()) {
|
for (auto* member : str->Members()) {
|
||||||
if (omit_members.Contains(member)) {
|
if (omit_members.Contains(member)) {
|
||||||
|
@ -155,7 +155,7 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
|
||||||
|
|
||||||
// Replace return statements with new truncated shader IO struct
|
// Replace return statements with new truncated shader IO struct
|
||||||
ctx.ReplaceAll(
|
ctx.ReplaceAll(
|
||||||
[&](const ast::ReturnStatement* return_statement) -> const ast::ReturnStatement* {
|
[&](const ReturnStatement* return_statement) -> const ReturnStatement* {
|
||||||
auto* return_sem = sem.Get(return_statement);
|
auto* return_sem = sem.Get(return_statement);
|
||||||
if (auto mapping_fn_sym =
|
if (auto mapping_fn_sym =
|
||||||
entry_point_functions_to_truncate_functions.Find(return_sem->Function())) {
|
entry_point_functions_to_truncate_functions.Find(return_sem->Function())) {
|
||||||
|
@ -168,11 +168,11 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src,
|
||||||
// Remove IO attributes from old shader IO struct which is not used as entry point output
|
// Remove IO attributes from old shader IO struct which is not used as entry point output
|
||||||
// anymore.
|
// anymore.
|
||||||
for (auto it : old_shader_io_structs_to_new_struct_and_truncate_functions) {
|
for (auto it : old_shader_io_structs_to_new_struct_and_truncate_functions) {
|
||||||
const ast::Struct* struct_ty = it.key->Declaration();
|
const Struct* struct_ty = it.key->Declaration();
|
||||||
for (auto* member : struct_ty->members) {
|
for (auto* member : struct_ty->members) {
|
||||||
for (auto* attr : member->attributes) {
|
for (auto* attr : member->attributes) {
|
||||||
if (attr->IsAnyOf<ast::BuiltinAttribute, ast::LocationAttribute,
|
if (attr->IsAnyOf<BuiltinAttribute, LocationAttribute, InterpolateAttribute,
|
||||||
ast::InterpolateAttribute, ast::InvariantAttribute>()) {
|
InvariantAttribute>()) {
|
||||||
ctx.Remove(member->attributes, attr);
|
ctx.Remove(member->attributes, attr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,29 +50,27 @@ struct Unshadow::State {
|
||||||
// Maps a variable to its new name.
|
// Maps a variable to its new name.
|
||||||
utils::Hashmap<const sem::Variable*, Symbol, 8> renamed_to;
|
utils::Hashmap<const sem::Variable*, Symbol, 8> renamed_to;
|
||||||
|
|
||||||
auto rename = [&](const sem::Variable* v) -> const ast::Variable* {
|
auto rename = [&](const sem::Variable* v) -> const Variable* {
|
||||||
auto* decl = v->Declaration();
|
auto* decl = v->Declaration();
|
||||||
auto name = decl->name->symbol.Name();
|
auto name = decl->name->symbol.Name();
|
||||||
auto symbol = b.Symbols().New(name);
|
auto symbol = b.Symbols().New(name);
|
||||||
renamed_to.Add(v, symbol);
|
renamed_to.Add(v, symbol);
|
||||||
|
|
||||||
auto source = ctx.Clone(decl->source);
|
auto source = ctx.Clone(decl->source);
|
||||||
auto type = decl->type ? ctx.Clone(decl->type) : ast::Type{};
|
auto type = decl->type ? ctx.Clone(decl->type) : Type{};
|
||||||
auto* initializer = ctx.Clone(decl->initializer);
|
auto* initializer = ctx.Clone(decl->initializer);
|
||||||
auto attributes = ctx.Clone(decl->attributes);
|
auto attributes = ctx.Clone(decl->attributes);
|
||||||
return Switch(
|
return Switch(
|
||||||
decl, //
|
decl, //
|
||||||
[&](const ast::Var* var) {
|
[&](const Var* var) {
|
||||||
return b.Var(source, symbol, type, var->declared_address_space,
|
return b.Var(source, symbol, type, var->declared_address_space,
|
||||||
var->declared_access, initializer, attributes);
|
var->declared_access, initializer, attributes);
|
||||||
},
|
},
|
||||||
[&](const ast::Let*) {
|
[&](const Let*) { return b.Let(source, symbol, type, initializer, attributes); },
|
||||||
return b.Let(source, symbol, type, initializer, attributes);
|
[&](const Const*) {
|
||||||
},
|
|
||||||
[&](const ast::Const*) {
|
|
||||||
return b.Const(source, symbol, type, initializer, attributes);
|
return b.Const(source, symbol, type, initializer, attributes);
|
||||||
},
|
},
|
||||||
[&](const ast::Parameter*) { //
|
[&](const Parameter*) { //
|
||||||
return b.Param(source, symbol, type, attributes);
|
return b.Param(source, symbol, type, attributes);
|
||||||
},
|
},
|
||||||
[&](Default) {
|
[&](Default) {
|
||||||
|
@ -105,8 +103,7 @@ struct Unshadow::State {
|
||||||
return SkipTransform;
|
return SkipTransform;
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.ReplaceAll(
|
ctx.ReplaceAll([&](const IdentifierExpression* ident) -> const IdentifierExpression* {
|
||||||
[&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* {
|
|
||||||
if (auto* sem_ident = sem.GetVal(ident)) {
|
if (auto* sem_ident = sem.GetVal(ident)) {
|
||||||
if (auto* user = sem_ident->Unwrap()->As<sem::VariableUser>()) {
|
if (auto* user = sem_ident->Unwrap()->As<sem::VariableUser>()) {
|
||||||
if (auto renamed = renamed_to.Find(user->Variable())) {
|
if (auto renamed = renamed_to.Find(user->Variable())) {
|
||||||
|
|
|
@ -20,10 +20,10 @@
|
||||||
|
|
||||||
namespace tint::ast::transform::utils {
|
namespace tint::ast::transform::utils {
|
||||||
|
|
||||||
InsertionPoint GetInsertionPoint(CloneContext& ctx, const ast::Statement* stmt) {
|
InsertionPoint GetInsertionPoint(CloneContext& ctx, const Statement* stmt) {
|
||||||
auto& sem = ctx.src->Sem();
|
auto& sem = ctx.src->Sem();
|
||||||
auto& diag = ctx.dst->Diagnostics();
|
auto& diag = ctx.dst->Diagnostics();
|
||||||
using RetType = std::pair<const sem::BlockStatement*, const ast::Statement*>;
|
using RetType = std::pair<const sem::BlockStatement*, const Statement*>;
|
||||||
|
|
||||||
if (auto* sem_stmt = sem.Get(stmt)) {
|
if (auto* sem_stmt = sem.Get(stmt)) {
|
||||||
auto* parent = sem_stmt->Parent();
|
auto* parent = sem_stmt->Parent();
|
||||||
|
|
|
@ -24,7 +24,7 @@ namespace tint::ast::transform::utils {
|
||||||
|
|
||||||
/// InsertionPoint is a pair of the block (`first`) within which, and the
|
/// InsertionPoint is a pair of the block (`first`) within which, and the
|
||||||
/// statement (`second`) before or after which to insert.
|
/// statement (`second`) before or after which to insert.
|
||||||
using InsertionPoint = std::pair<const sem::BlockStatement*, const ast::Statement*>;
|
using InsertionPoint = std::pair<const sem::BlockStatement*, const Statement*>;
|
||||||
|
|
||||||
/// For the input statement, returns the block and statement within that
|
/// For the input statement, returns the block and statement within that
|
||||||
/// block to insert before/after. If `stmt` is a for-loop continue statement,
|
/// block to insert before/after. If `stmt` is a for-loop continue statement,
|
||||||
|
@ -32,7 +32,7 @@ using InsertionPoint = std::pair<const sem::BlockStatement*, const ast::Statemen
|
||||||
/// @param ctx the clone context
|
/// @param ctx the clone context
|
||||||
/// @param stmt the statement to insert before or after
|
/// @param stmt the statement to insert before or after
|
||||||
/// @return the insertion point
|
/// @return the insertion point
|
||||||
InsertionPoint GetInsertionPoint(CloneContext& ctx, const ast::Statement* stmt);
|
InsertionPoint GetInsertionPoint(CloneContext& ctx, const Statement* stmt);
|
||||||
|
|
||||||
} // namespace tint::ast::transform::utils
|
} // namespace tint::ast::transform::utils
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ struct HoistToDeclBefore::State {
|
||||||
|
|
||||||
/// @copydoc HoistToDeclBefore::Add()
|
/// @copydoc HoistToDeclBefore::Add()
|
||||||
bool Add(const sem::ValueExpression* before_expr,
|
bool Add(const sem::ValueExpression* before_expr,
|
||||||
const ast::Expression* expr,
|
const Expression* expr,
|
||||||
VariableKind kind,
|
VariableKind kind,
|
||||||
const char* decl_name) {
|
const char* decl_name) {
|
||||||
auto name = b.Symbols().New(decl_name);
|
auto name = b.Symbols().New(decl_name);
|
||||||
|
@ -85,8 +85,8 @@ struct HoistToDeclBefore::State {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*)
|
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const Statement*)
|
||||||
bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) {
|
bool InsertBefore(const sem::Statement* before_stmt, const Statement* stmt) {
|
||||||
if (stmt) {
|
if (stmt) {
|
||||||
auto builder = [stmt] { return stmt; };
|
auto builder = [stmt] { return stmt; };
|
||||||
return InsertBeforeImpl(before_stmt, std::move(builder));
|
return InsertBeforeImpl(before_stmt, std::move(builder));
|
||||||
|
@ -99,8 +99,8 @@ struct HoistToDeclBefore::State {
|
||||||
return InsertBeforeImpl(before_stmt, std::move(builder));
|
return InsertBeforeImpl(before_stmt, std::move(builder));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @copydoc HoistToDeclBefore::Replace(const sem::Statement* what, const ast::Statement* with)
|
/// @copydoc HoistToDeclBefore::Replace(const sem::Statement* what, const Statement* with)
|
||||||
bool Replace(const sem::Statement* what, const ast::Statement* with) {
|
bool Replace(const sem::Statement* what, const Statement* with) {
|
||||||
auto builder = [with] { return with; };
|
auto builder = [with] { return with; };
|
||||||
return Replace(what, std::move(builder));
|
return Replace(what, std::move(builder));
|
||||||
}
|
}
|
||||||
|
@ -145,7 +145,7 @@ struct HoistToDeclBefore::State {
|
||||||
utils::Hashmap<const sem::WhileStatement*, LoopInfo, 4> while_loops;
|
utils::Hashmap<const sem::WhileStatement*, LoopInfo, 4> while_loops;
|
||||||
|
|
||||||
/// 'else if' statements that need to be decomposed to 'else {if}'
|
/// 'else if' statements that need to be decomposed to 'else {if}'
|
||||||
utils::Hashmap<const ast::IfStatement*, ElseIfInfo, 4> else_ifs;
|
utils::Hashmap<const IfStatement*, ElseIfInfo, 4> else_ifs;
|
||||||
|
|
||||||
template <size_t N>
|
template <size_t N>
|
||||||
static auto Build(const utils::Vector<StmtBuilder, N>& builders) {
|
static auto Build(const utils::Vector<StmtBuilder, N>& builders) {
|
||||||
|
@ -181,7 +181,7 @@ struct HoistToDeclBefore::State {
|
||||||
/// automatically called.
|
/// automatically called.
|
||||||
/// @warning the returned reference is invalid if this is called a second time, or the
|
/// @warning the returned reference is invalid if this is called a second time, or the
|
||||||
/// #else_ifs map is mutated.
|
/// #else_ifs map is mutated.
|
||||||
auto ElseIf(const ast::IfStatement* else_if) {
|
auto ElseIf(const IfStatement* else_if) {
|
||||||
if (else_ifs.IsEmpty()) {
|
if (else_ifs.IsEmpty()) {
|
||||||
RegisterElseIfTransform();
|
RegisterElseIfTransform();
|
||||||
}
|
}
|
||||||
|
@ -190,7 +190,7 @@ struct HoistToDeclBefore::State {
|
||||||
|
|
||||||
/// Registers the handler for transforming for-loops based on the content of the #for_loops map.
|
/// Registers the handler for transforming for-loops based on the content of the #for_loops map.
|
||||||
void RegisterForLoopTransform() const {
|
void RegisterForLoopTransform() const {
|
||||||
ctx.ReplaceAll([&](const ast::ForLoopStatement* stmt) -> const ast::Statement* {
|
ctx.ReplaceAll([&](const ForLoopStatement* stmt) -> const Statement* {
|
||||||
auto& sem = ctx.src->Sem();
|
auto& sem = ctx.src->Sem();
|
||||||
|
|
||||||
if (auto* fl = sem.Get(stmt)) {
|
if (auto* fl = sem.Get(stmt)) {
|
||||||
|
@ -205,9 +205,9 @@ struct HoistToDeclBefore::State {
|
||||||
if (auto* cond = for_loop->condition) {
|
if (auto* cond = for_loop->condition) {
|
||||||
// !condition
|
// !condition
|
||||||
auto* not_cond =
|
auto* not_cond =
|
||||||
b.create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
|
b.create<UnaryOpExpression>(UnaryOp::kNot, ctx.Clone(cond));
|
||||||
// { break; }
|
// { break; }
|
||||||
auto* break_body = b.Block(b.create<ast::BreakStatement>());
|
auto* break_body = b.Block(b.create<BreakStatement>());
|
||||||
// if (!condition) { break; }
|
// if (!condition) { break; }
|
||||||
body_stmts.Push(b.If(not_cond, break_body));
|
body_stmts.Push(b.If(not_cond, break_body));
|
||||||
}
|
}
|
||||||
|
@ -215,7 +215,7 @@ struct HoistToDeclBefore::State {
|
||||||
body_stmts.Push(ctx.Clone(for_loop->body));
|
body_stmts.Push(ctx.Clone(for_loop->body));
|
||||||
|
|
||||||
// Create the continuing block if there was one.
|
// Create the continuing block if there was one.
|
||||||
const ast::BlockStatement* continuing = nullptr;
|
const BlockStatement* continuing = nullptr;
|
||||||
if (auto* cont = for_loop->continuing) {
|
if (auto* cont = for_loop->continuing) {
|
||||||
// Continuing block starts with any let declarations used by
|
// Continuing block starts with any let declarations used by
|
||||||
// the continuing.
|
// the continuing.
|
||||||
|
@ -249,7 +249,7 @@ struct HoistToDeclBefore::State {
|
||||||
/// map.
|
/// map.
|
||||||
void RegisterWhileLoopTransform() const {
|
void RegisterWhileLoopTransform() const {
|
||||||
// At least one while needs to be transformed into a loop.
|
// At least one while needs to be transformed into a loop.
|
||||||
ctx.ReplaceAll([&](const ast::WhileStatement* stmt) -> const ast::Statement* {
|
ctx.ReplaceAll([&](const WhileStatement* stmt) -> const Statement* {
|
||||||
auto& sem = ctx.src->Sem();
|
auto& sem = ctx.src->Sem();
|
||||||
|
|
||||||
if (auto* w = sem.Get(stmt)) {
|
if (auto* w = sem.Get(stmt)) {
|
||||||
|
@ -274,7 +274,7 @@ struct HoistToDeclBefore::State {
|
||||||
// Next emit the body
|
// Next emit the body
|
||||||
body_stmts.Push(ctx.Clone(while_loop->body));
|
body_stmts.Push(ctx.Clone(while_loop->body));
|
||||||
|
|
||||||
const ast::BlockStatement* continuing = nullptr;
|
const BlockStatement* continuing = nullptr;
|
||||||
|
|
||||||
auto* body = b.Block(body_stmts);
|
auto* body = b.Block(body_stmts);
|
||||||
auto* loop = b.Loop(body, continuing);
|
auto* loop = b.Loop(body, continuing);
|
||||||
|
@ -289,7 +289,7 @@ struct HoistToDeclBefore::State {
|
||||||
/// map.
|
/// map.
|
||||||
void RegisterElseIfTransform() const {
|
void RegisterElseIfTransform() const {
|
||||||
// Decompose 'else-if' statements into 'else { if }' blocks.
|
// Decompose 'else-if' statements into 'else { if }' blocks.
|
||||||
ctx.ReplaceAll([&](const ast::IfStatement* stmt) -> const ast::Statement* {
|
ctx.ReplaceAll([&](const IfStatement* stmt) -> const Statement* {
|
||||||
if (auto info = else_ifs.Find(stmt)) {
|
if (auto info = else_ifs.Find(stmt)) {
|
||||||
// Build the else block's body statements, starting with let decls for the
|
// Build the else block's body statements, starting with let decls for the
|
||||||
// conditional expression.
|
// conditional expression.
|
||||||
|
@ -412,14 +412,13 @@ HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx) : state_(std::make_uniqu
|
||||||
HoistToDeclBefore::~HoistToDeclBefore() {}
|
HoistToDeclBefore::~HoistToDeclBefore() {}
|
||||||
|
|
||||||
bool HoistToDeclBefore::Add(const sem::ValueExpression* before_expr,
|
bool HoistToDeclBefore::Add(const sem::ValueExpression* before_expr,
|
||||||
const ast::Expression* expr,
|
const Expression* expr,
|
||||||
VariableKind kind,
|
VariableKind kind,
|
||||||
const char* decl_name) {
|
const char* decl_name) {
|
||||||
return state_->Add(before_expr, expr, kind, decl_name);
|
return state_->Add(before_expr, expr, kind, decl_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt,
|
bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt, const Statement* stmt) {
|
||||||
const ast::Statement* stmt) {
|
|
||||||
return state_->InsertBefore(before_stmt, stmt);
|
return state_->InsertBefore(before_stmt, stmt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -428,7 +427,7 @@ bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt,
|
||||||
return state_->InsertBefore(before_stmt, builder);
|
return state_->InsertBefore(before_stmt, builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HoistToDeclBefore::Replace(const sem::Statement* what, const ast::Statement* with) {
|
bool HoistToDeclBefore::Replace(const sem::Statement* what, const Statement* with) {
|
||||||
return state_->Replace(what, with);
|
return state_->Replace(what, with);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ class HoistToDeclBefore {
|
||||||
~HoistToDeclBefore();
|
~HoistToDeclBefore();
|
||||||
|
|
||||||
/// StmtBuilder is a builder of an AST statement
|
/// StmtBuilder is a builder of an AST statement
|
||||||
using StmtBuilder = std::function<const ast::Statement*()>;
|
using StmtBuilder = std::function<const Statement*()>;
|
||||||
|
|
||||||
/// VariableKind is either a var, let or const
|
/// VariableKind is either a var, let or const
|
||||||
enum class VariableKind {
|
enum class VariableKind {
|
||||||
|
@ -53,7 +53,7 @@ class HoistToDeclBefore {
|
||||||
/// @param decl_name optional name to use for the variable/constant name
|
/// @param decl_name optional name to use for the variable/constant name
|
||||||
/// @return true on success
|
/// @return true on success
|
||||||
bool Add(const sem::ValueExpression* before_expr,
|
bool Add(const sem::ValueExpression* before_expr,
|
||||||
const ast::Expression* expr,
|
const Expression* expr,
|
||||||
VariableKind kind,
|
VariableKind kind,
|
||||||
const char* decl_name = "");
|
const char* decl_name = "");
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ class HoistToDeclBefore {
|
||||||
/// @param before_stmt statement to insert @p stmt before
|
/// @param before_stmt statement to insert @p stmt before
|
||||||
/// @param stmt statement to insert
|
/// @param stmt statement to insert
|
||||||
/// @return true on success
|
/// @return true on success
|
||||||
bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt);
|
bool InsertBefore(const sem::Statement* before_stmt, const Statement* stmt);
|
||||||
|
|
||||||
/// Inserts the returned statement of @p builder before @p before_stmt, possibly converting
|
/// Inserts the returned statement of @p builder before @p before_stmt, possibly converting
|
||||||
/// 'for-loop's to 'loop's if necessary.
|
/// 'for-loop's to 'loop's if necessary.
|
||||||
|
@ -81,7 +81,7 @@ class HoistToDeclBefore {
|
||||||
/// @param what the statement to replace
|
/// @param what the statement to replace
|
||||||
/// @param with the replacement statement
|
/// @param with the replacement statement
|
||||||
/// @return true on success
|
/// @return true on success
|
||||||
bool Replace(const sem::Statement* what, const ast::Statement* with);
|
bool Replace(const sem::Statement* what, const Statement* with);
|
||||||
|
|
||||||
/// Replaces the statement @p what with the statement returned by @p stmt, possibly converting
|
/// Replaces the statement @p what with the statement returned by @p stmt, possibly converting
|
||||||
/// 'for-loop's to 'loop's if necessary.
|
/// 'for-loop's to 'loop's if necessary.
|
||||||
|
|
|
@ -628,7 +628,7 @@ TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont) {
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
|
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
|
||||||
auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
|
auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
|
||||||
auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd);
|
auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
|
||||||
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
|
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
|
||||||
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
|
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
|
||||||
|
|
||||||
|
@ -637,7 +637,7 @@ TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont) {
|
||||||
CloneContext ctx(&cloned_b, &original);
|
CloneContext ctx(&cloned_b, &original);
|
||||||
|
|
||||||
HoistToDeclBefore hoistToDeclBefore(ctx);
|
HoistToDeclBefore hoistToDeclBefore(ctx);
|
||||||
auto* before_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>());
|
auto* before_stmt = ctx.src->Sem().Get(cont->As<Statement>());
|
||||||
auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
|
auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
|
||||||
hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
|
hoistToDeclBefore.InsertBefore(before_stmt, new_stmt);
|
||||||
|
|
||||||
|
@ -679,7 +679,7 @@ TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont_Function) {
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
|
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
|
||||||
auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
|
auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
|
||||||
auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd);
|
auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
|
||||||
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
|
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
|
||||||
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
|
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
|
||||||
|
|
||||||
|
@ -688,7 +688,7 @@ TEST_F(HoistToDeclBeforeTest, InsertBefore_ForLoopCont_Function) {
|
||||||
CloneContext ctx(&cloned_b, &original);
|
CloneContext ctx(&cloned_b, &original);
|
||||||
|
|
||||||
HoistToDeclBefore hoistToDeclBefore(ctx);
|
HoistToDeclBefore hoistToDeclBefore(ctx);
|
||||||
auto* before_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>());
|
auto* before_stmt = ctx.src->Sem().Get(cont->As<Statement>());
|
||||||
hoistToDeclBefore.InsertBefore(before_stmt,
|
hoistToDeclBefore.InsertBefore(before_stmt,
|
||||||
[&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
|
[&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
|
||||||
|
|
||||||
|
@ -1048,7 +1048,7 @@ TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont) {
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
|
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
|
||||||
auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
|
auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
|
||||||
auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd);
|
auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
|
||||||
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
|
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
|
||||||
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
|
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
|
||||||
|
|
||||||
|
@ -1057,7 +1057,7 @@ TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont) {
|
||||||
CloneContext ctx(&cloned_b, &original);
|
CloneContext ctx(&cloned_b, &original);
|
||||||
|
|
||||||
HoistToDeclBefore hoistToDeclBefore(ctx);
|
HoistToDeclBefore hoistToDeclBefore(ctx);
|
||||||
auto* target_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>());
|
auto* target_stmt = ctx.src->Sem().Get(cont->As<Statement>());
|
||||||
auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
|
auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo"));
|
||||||
hoistToDeclBefore.Replace(target_stmt, new_stmt);
|
hoistToDeclBefore.Replace(target_stmt, new_stmt);
|
||||||
|
|
||||||
|
@ -1098,7 +1098,7 @@ TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont_Function) {
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
|
b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty);
|
||||||
auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
|
auto* var = b.Decl(b.Var("a", b.Expr(1_i)));
|
||||||
auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd);
|
auto* cont = b.CompoundAssign("a", b.Expr(1_i), BinaryOp::kAdd);
|
||||||
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
|
auto* s = b.For(nullptr, b.Expr(true), cont, b.Block());
|
||||||
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
|
b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s});
|
||||||
|
|
||||||
|
@ -1107,7 +1107,7 @@ TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont_Function) {
|
||||||
CloneContext ctx(&cloned_b, &original);
|
CloneContext ctx(&cloned_b, &original);
|
||||||
|
|
||||||
HoistToDeclBefore hoistToDeclBefore(ctx);
|
HoistToDeclBefore hoistToDeclBefore(ctx);
|
||||||
auto* target_stmt = ctx.src->Sem().Get(cont->As<ast::Statement>());
|
auto* target_stmt = ctx.src->Sem().Get(cont->As<Statement>());
|
||||||
hoistToDeclBefore.Replace(target_stmt, [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
|
hoistToDeclBefore.Replace(target_stmt, [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); });
|
||||||
|
|
||||||
ctx.Clone();
|
ctx.Clone();
|
||||||
|
|
|
@ -37,7 +37,7 @@ Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,
|
||||||
|
|
||||||
// Extracts array and matrix values that are dynamically indexed to a
|
// Extracts array and matrix values that are dynamically indexed to a
|
||||||
// temporary `var` local that is then indexed.
|
// temporary `var` local that is then indexed.
|
||||||
auto dynamic_index_to_var = [&](const ast::IndexAccessorExpression* access_expr) {
|
auto dynamic_index_to_var = [&](const IndexAccessorExpression* access_expr) {
|
||||||
auto* index_expr = access_expr->index;
|
auto* index_expr = access_expr->index;
|
||||||
auto* object_expr = access_expr->object;
|
auto* object_expr = access_expr->object;
|
||||||
auto& sem = src->Sem();
|
auto& sem = src->Sem();
|
||||||
|
@ -62,7 +62,7 @@ Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,
|
||||||
|
|
||||||
bool index_accessor_found = false;
|
bool index_accessor_found = false;
|
||||||
for (auto* node : src->ASTNodes().Objects()) {
|
for (auto* node : src->ASTNodes().Objects()) {
|
||||||
if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) {
|
if (auto* access_expr = node->As<IndexAccessorExpression>()) {
|
||||||
if (!dynamic_index_to_var(access_expr)) {
|
if (!dynamic_index_to_var(access_expr)) {
|
||||||
return Program(std::move(b));
|
return Program(std::move(b));
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,7 +70,7 @@ Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
|
||||||
|
|
||||||
std::unordered_map<HelperFunctionKey, Symbol> matrix_convs;
|
std::unordered_map<HelperFunctionKey, Symbol> matrix_convs;
|
||||||
|
|
||||||
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
|
ctx.ReplaceAll([&](const CallExpression* expr) -> const CallExpression* {
|
||||||
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||||
auto* ty_conv = call->Target()->As<sem::ValueConversion>();
|
auto* ty_conv = call->Target()->As<sem::ValueConversion>();
|
||||||
if (!ty_conv) {
|
if (!ty_conv) {
|
||||||
|
@ -102,7 +102,7 @@ Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
|
||||||
}
|
}
|
||||||
|
|
||||||
auto build_vectorized_conversion_expression = [&](auto&& src_expression_builder) {
|
auto build_vectorized_conversion_expression = [&](auto&& src_expression_builder) {
|
||||||
utils::Vector<const ast::Expression*, 4> columns;
|
utils::Vector<const Expression*, 4> columns;
|
||||||
for (uint32_t c = 0; c < dst_type->columns(); c++) {
|
for (uint32_t c = 0; c < dst_type->columns(); c++) {
|
||||||
auto* src_matrix_expr = src_expression_builder();
|
auto* src_matrix_expr = src_expression_builder();
|
||||||
auto* src_column_expr = b.IndexAccessor(src_matrix_expr, b.Expr(tint::AInt(c)));
|
auto* src_column_expr = b.IndexAccessor(src_matrix_expr, b.Expr(tint::AInt(c)));
|
||||||
|
|
|
@ -61,7 +61,7 @@ Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* s
|
||||||
|
|
||||||
std::unordered_map<const type::Matrix*, Symbol> scalar_inits;
|
std::unordered_map<const type::Matrix*, Symbol> scalar_inits;
|
||||||
|
|
||||||
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
|
ctx.ReplaceAll([&](const CallExpression* expr) -> const CallExpression* {
|
||||||
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||||
auto* ty_init = call->Target()->As<sem::ValueConstructor>();
|
auto* ty_init = call->Target()->As<sem::ValueConstructor>();
|
||||||
if (!ty_init) {
|
if (!ty_init) {
|
||||||
|
@ -91,9 +91,9 @@ Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* s
|
||||||
// Constructs a matrix using vector columns, with the elements constructed using the
|
// Constructs a matrix using vector columns, with the elements constructed using the
|
||||||
// 'element(uint32_t c, uint32_t r)' callback.
|
// 'element(uint32_t c, uint32_t r)' callback.
|
||||||
auto build_mat = [&](auto&& element) {
|
auto build_mat = [&](auto&& element) {
|
||||||
utils::Vector<const ast::Expression*, 4> columns;
|
utils::Vector<const Expression*, 4> columns;
|
||||||
for (uint32_t c = 0; c < mat_type->columns(); c++) {
|
for (uint32_t c = 0; c < mat_type->columns(); c++) {
|
||||||
utils::Vector<const ast::Expression*, 4> row_values;
|
utils::Vector<const Expression*, 4> row_values;
|
||||||
for (uint32_t r = 0; r < mat_type->rows(); r++) {
|
for (uint32_t r = 0; r < mat_type->rows(); r++) {
|
||||||
row_values.Push(element(c, r));
|
row_values.Push(element(c, r));
|
||||||
}
|
}
|
||||||
|
|
|
@ -238,9 +238,9 @@ struct VertexPulling::State {
|
||||||
/// @returns the new program or SkipTransform if the transform is not required
|
/// @returns the new program or SkipTransform if the transform is not required
|
||||||
ApplyResult Run() {
|
ApplyResult Run() {
|
||||||
// Find entry point
|
// Find entry point
|
||||||
const ast::Function* func = nullptr;
|
const Function* func = nullptr;
|
||||||
for (auto* fn : src->AST().Functions()) {
|
for (auto* fn : src->AST().Functions()) {
|
||||||
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
|
if (fn->PipelineStage() == PipelineStage::kVertex) {
|
||||||
if (func != nullptr) {
|
if (func != nullptr) {
|
||||||
b.Diagnostics().add_error(
|
b.Diagnostics().add_error(
|
||||||
diag::System::Transform,
|
diag::System::Transform,
|
||||||
|
@ -264,18 +264,18 @@ struct VertexPulling::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// LocationReplacement describes an ast::Variable replacement for a location input.
|
/// LocationReplacement describes an Variable replacement for a location input.
|
||||||
struct LocationReplacement {
|
struct LocationReplacement {
|
||||||
/// The variable to replace in the source Program
|
/// The variable to replace in the source Program
|
||||||
ast::Variable* from;
|
Variable* from;
|
||||||
/// The replacement to use in the target ProgramBuilder
|
/// The replacement to use in the target ProgramBuilder
|
||||||
ast::Variable* to;
|
Variable* to;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// LocationInfo describes an input location
|
/// LocationInfo describes an input location
|
||||||
struct LocationInfo {
|
struct LocationInfo {
|
||||||
/// A builder that builds the expression that resolves to the (transformed) input location
|
/// A builder that builds the expression that resolves to the (transformed) input location
|
||||||
std::function<const ast::Expression*()> expr;
|
std::function<const Expression*()> expr;
|
||||||
/// The store type of the location variable
|
/// The store type of the location variable
|
||||||
const type::Type* type;
|
const type::Type* type;
|
||||||
};
|
};
|
||||||
|
@ -289,12 +289,12 @@ struct VertexPulling::State {
|
||||||
/// The clone context
|
/// The clone context
|
||||||
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
||||||
std::unordered_map<uint32_t, LocationInfo> location_info;
|
std::unordered_map<uint32_t, LocationInfo> location_info;
|
||||||
std::function<const ast::Expression*()> vertex_index_expr = nullptr;
|
std::function<const Expression*()> vertex_index_expr = nullptr;
|
||||||
std::function<const ast::Expression*()> instance_index_expr = nullptr;
|
std::function<const Expression*()> instance_index_expr = nullptr;
|
||||||
Symbol pulling_position_name;
|
Symbol pulling_position_name;
|
||||||
Symbol struct_buffer_name;
|
Symbol struct_buffer_name;
|
||||||
std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
|
std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
|
||||||
utils::Vector<const ast::Parameter*, 8> new_function_parameters;
|
utils::Vector<const Parameter*, 8> new_function_parameters;
|
||||||
|
|
||||||
/// Generate the vertex buffer binding name
|
/// Generate the vertex buffer binding name
|
||||||
/// @param index index to append to buffer name
|
/// @param index index to append to buffer name
|
||||||
|
@ -331,11 +331,11 @@ struct VertexPulling::State {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates and returns the assignment to the variables from the buffers
|
/// Creates and returns the assignment to the variables from the buffers
|
||||||
const ast::BlockStatement* CreateVertexPullingPreamble() {
|
const BlockStatement* CreateVertexPullingPreamble() {
|
||||||
// Assign by looking at the vertex descriptor to find attributes with
|
// Assign by looking at the vertex descriptor to find attributes with
|
||||||
// matching location.
|
// matching location.
|
||||||
|
|
||||||
utils::Vector<const ast::Statement*, 8> stmts;
|
utils::Vector<const Statement*, 8> stmts;
|
||||||
|
|
||||||
for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size(); ++buffer_idx) {
|
for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size(); ++buffer_idx) {
|
||||||
const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx];
|
const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx];
|
||||||
|
@ -399,7 +399,7 @@ struct VertexPulling::State {
|
||||||
// Convert the fetched scalar/vector if WGSL variable is of `f16` types
|
// Convert the fetched scalar/vector if WGSL variable is of `f16` types
|
||||||
if (var_dt.base_type == BaseWGSLType::kF16) {
|
if (var_dt.base_type == BaseWGSLType::kF16) {
|
||||||
// The type of the same element number of base type of target WGSL variable
|
// The type of the same element number of base type of target WGSL variable
|
||||||
ast::Type loaded_data_target_type;
|
Type loaded_data_target_type;
|
||||||
if (fmt_dt.width == 1) {
|
if (fmt_dt.width == 1) {
|
||||||
loaded_data_target_type = b.ty.f16();
|
loaded_data_target_type = b.ty.f16();
|
||||||
} else {
|
} else {
|
||||||
|
@ -433,7 +433,7 @@ struct VertexPulling::State {
|
||||||
|
|
||||||
// The components of result vector variable, initialized with type-converted
|
// The components of result vector variable, initialized with type-converted
|
||||||
// loaded data vector.
|
// loaded data vector.
|
||||||
utils::Vector<const ast::Expression*, 8> values{fetch};
|
utils::Vector<const Expression*, 8> values{fetch};
|
||||||
|
|
||||||
// Add padding elements. The result must be of vector types of signed/unsigned
|
// Add padding elements. The result must be of vector types of signed/unsigned
|
||||||
// integer or float, so use the abstract integer or abstract float value to do
|
// integer or float, so use the abstract integer or abstract float value to do
|
||||||
|
@ -470,7 +470,7 @@ struct VertexPulling::State {
|
||||||
/// @param offset the byte offset of the data from `buffer_base`
|
/// @param offset the byte offset of the data from `buffer_base`
|
||||||
/// @param buffer the index of the vertex buffer
|
/// @param buffer the index of the vertex buffer
|
||||||
/// @param format the vertex format to read
|
/// @param format the vertex format to read
|
||||||
const ast::Expression* Fetch(Symbol array_base,
|
const Expression* Fetch(Symbol array_base,
|
||||||
uint32_t offset,
|
uint32_t offset,
|
||||||
uint32_t buffer,
|
uint32_t buffer,
|
||||||
VertexFormat format) {
|
VertexFormat format) {
|
||||||
|
@ -679,11 +679,11 @@ struct VertexPulling::State {
|
||||||
/// @param buffer the index of the vertex buffer
|
/// @param buffer the index of the vertex buffer
|
||||||
/// @param format VertexFormat::kUint32, VertexFormat::kSint32 or
|
/// @param format VertexFormat::kUint32, VertexFormat::kSint32 or
|
||||||
/// VertexFormat::kFloat32
|
/// VertexFormat::kFloat32
|
||||||
const ast::Expression* LoadPrimitive(Symbol array_base,
|
const Expression* LoadPrimitive(Symbol array_base,
|
||||||
uint32_t offset,
|
uint32_t offset,
|
||||||
uint32_t buffer,
|
uint32_t buffer,
|
||||||
VertexFormat format) {
|
VertexFormat format) {
|
||||||
const ast::Expression* u = nullptr;
|
const Expression* u = nullptr;
|
||||||
if ((offset & 3) == 0) {
|
if ((offset & 3) == 0) {
|
||||||
// Aligned load.
|
// Aligned load.
|
||||||
|
|
||||||
|
@ -734,14 +734,14 @@ struct VertexPulling::State {
|
||||||
/// @param base_type underlying AST type
|
/// @param base_type underlying AST type
|
||||||
/// @param base_format underlying vertex format
|
/// @param base_format underlying vertex format
|
||||||
/// @param count how many elements the vector has
|
/// @param count how many elements the vector has
|
||||||
const ast::Expression* LoadVec(Symbol array_base,
|
const Expression* LoadVec(Symbol array_base,
|
||||||
uint32_t offset,
|
uint32_t offset,
|
||||||
uint32_t buffer,
|
uint32_t buffer,
|
||||||
uint32_t element_stride,
|
uint32_t element_stride,
|
||||||
ast::Type base_type,
|
Type base_type,
|
||||||
VertexFormat base_format,
|
VertexFormat base_format,
|
||||||
uint32_t count) {
|
uint32_t count) {
|
||||||
utils::Vector<const ast::Expression*, 8> expr_list;
|
utils::Vector<const Expression*, 8> expr_list;
|
||||||
for (uint32_t i = 0; i < count; ++i) {
|
for (uint32_t i = 0; i < count; ++i) {
|
||||||
// Offset read position by element_stride for each component
|
// Offset read position by element_stride for each component
|
||||||
uint32_t primitive_offset = offset + element_stride * i;
|
uint32_t primitive_offset = offset + element_stride * i;
|
||||||
|
@ -756,8 +756,8 @@ struct VertexPulling::State {
|
||||||
/// vertex_index and instance_index builtins if present.
|
/// vertex_index and instance_index builtins if present.
|
||||||
/// @param func the entry point function
|
/// @param func the entry point function
|
||||||
/// @param param the parameter to process
|
/// @param param the parameter to process
|
||||||
void ProcessNonStructParameter(const ast::Function* func, const ast::Parameter* param) {
|
void ProcessNonStructParameter(const Function* func, const Parameter* param) {
|
||||||
if (ast::HasAttribute<ast::LocationAttribute>(param->attributes)) {
|
if (HasAttribute<LocationAttribute>(param->attributes)) {
|
||||||
// Create a function-scope variable to replace the parameter.
|
// Create a function-scope variable to replace the parameter.
|
||||||
auto func_var_sym = ctx.Clone(param->name->symbol);
|
auto func_var_sym = ctx.Clone(param->name->symbol);
|
||||||
auto func_var_type = ctx.Clone(param->type);
|
auto func_var_type = ctx.Clone(param->type);
|
||||||
|
@ -776,7 +776,7 @@ struct VertexPulling::State {
|
||||||
}
|
}
|
||||||
location_info[sem->Location().value()] = info;
|
location_info[sem->Location().value()] = info;
|
||||||
} else {
|
} else {
|
||||||
auto* builtin_attr = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes);
|
auto* builtin_attr = GetAttribute<BuiltinAttribute>(param->attributes);
|
||||||
if (TINT_UNLIKELY(!builtin_attr)) {
|
if (TINT_UNLIKELY(!builtin_attr)) {
|
||||||
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
|
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
|
||||||
return;
|
return;
|
||||||
|
@ -804,21 +804,21 @@ struct VertexPulling::State {
|
||||||
/// @param func the entry point function
|
/// @param func the entry point function
|
||||||
/// @param param the parameter to process
|
/// @param param the parameter to process
|
||||||
/// @param struct_ty the structure type
|
/// @param struct_ty the structure type
|
||||||
void ProcessStructParameter(const ast::Function* func,
|
void ProcessStructParameter(const Function* func,
|
||||||
const ast::Parameter* param,
|
const Parameter* param,
|
||||||
const ast::Struct* struct_ty) {
|
const Struct* struct_ty) {
|
||||||
auto param_sym = ctx.Clone(param->name->symbol);
|
auto param_sym = ctx.Clone(param->name->symbol);
|
||||||
|
|
||||||
// Process the struct members.
|
// Process the struct members.
|
||||||
bool has_locations = false;
|
bool has_locations = false;
|
||||||
utils::Vector<const ast::StructMember*, 8> members_to_clone;
|
utils::Vector<const StructMember*, 8> members_to_clone;
|
||||||
for (auto* member : struct_ty->members) {
|
for (auto* member : struct_ty->members) {
|
||||||
auto member_sym = ctx.Clone(member->name->symbol);
|
auto member_sym = ctx.Clone(member->name->symbol);
|
||||||
std::function<const ast::Expression*()> member_expr = [this, param_sym, member_sym]() {
|
std::function<const Expression*()> member_expr = [this, param_sym, member_sym]() {
|
||||||
return b.MemberAccessor(param_sym, member_sym);
|
return b.MemberAccessor(param_sym, member_sym);
|
||||||
};
|
};
|
||||||
|
|
||||||
if (ast::HasAttribute<ast::LocationAttribute>(member->attributes)) {
|
if (HasAttribute<LocationAttribute>(member->attributes)) {
|
||||||
// Capture mapping from location to struct member.
|
// Capture mapping from location to struct member.
|
||||||
LocationInfo info;
|
LocationInfo info;
|
||||||
info.expr = member_expr;
|
info.expr = member_expr;
|
||||||
|
@ -830,7 +830,7 @@ struct VertexPulling::State {
|
||||||
location_info[sem->Attributes().location.value()] = info;
|
location_info[sem->Attributes().location.value()] = info;
|
||||||
has_locations = true;
|
has_locations = true;
|
||||||
} else {
|
} else {
|
||||||
auto* builtin_attr = ast::GetAttribute<ast::BuiltinAttribute>(member->attributes);
|
auto* builtin_attr = GetAttribute<BuiltinAttribute>(member->attributes);
|
||||||
if (TINT_UNLIKELY(!builtin_attr)) {
|
if (TINT_UNLIKELY(!builtin_attr)) {
|
||||||
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
|
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
|
||||||
return;
|
return;
|
||||||
|
@ -858,7 +858,7 @@ struct VertexPulling::State {
|
||||||
|
|
||||||
if (!members_to_clone.IsEmpty()) {
|
if (!members_to_clone.IsEmpty()) {
|
||||||
// Create a new struct without the location attributes.
|
// Create a new struct without the location attributes.
|
||||||
utils::Vector<const ast::StructMember*, 8> new_members;
|
utils::Vector<const StructMember*, 8> new_members;
|
||||||
for (auto* member : members_to_clone) {
|
for (auto* member : members_to_clone) {
|
||||||
auto member_name = ctx.Clone(member->name);
|
auto member_name = ctx.Clone(member->name);
|
||||||
auto member_type = ctx.Clone(member->type);
|
auto member_type = ctx.Clone(member->type);
|
||||||
|
@ -883,7 +883,7 @@ struct VertexPulling::State {
|
||||||
|
|
||||||
/// Process an entry point function.
|
/// Process an entry point function.
|
||||||
/// @param func the entry point function
|
/// @param func the entry point function
|
||||||
void Process(const ast::Function* func) {
|
void Process(const Function* func) {
|
||||||
if (func->body->Empty()) {
|
if (func->body->Empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -936,8 +936,8 @@ struct VertexPulling::State {
|
||||||
auto attrs = ctx.Clone(func->attributes);
|
auto attrs = ctx.Clone(func->attributes);
|
||||||
auto ret_attrs = ctx.Clone(func->return_type_attributes);
|
auto ret_attrs = ctx.Clone(func->return_type_attributes);
|
||||||
auto* new_func =
|
auto* new_func =
|
||||||
b.create<ast::Function>(func->source, b.Ident(func_sym), new_function_parameters,
|
b.create<Function>(func->source, b.Ident(func_sym), new_function_parameters, ret_type,
|
||||||
ret_type, body, std::move(attrs), std::move(ret_attrs));
|
body, std::move(attrs), std::move(ret_attrs));
|
||||||
ctx.Replace(func, new_func);
|
ctx.Replace(func, new_func);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -26,7 +26,7 @@ namespace {
|
||||||
|
|
||||||
bool ShouldRun(const Program* program) {
|
bool ShouldRun(const Program* program) {
|
||||||
for (auto* node : program->ASTNodes().Objects()) {
|
for (auto* node : program->ASTNodes().Objects()) {
|
||||||
if (node->Is<ast::WhileStatement>()) {
|
if (node->Is<WhileStatement>()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -47,8 +47,8 @@ Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, Da
|
||||||
ProgramBuilder b;
|
ProgramBuilder b;
|
||||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||||
|
|
||||||
ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* {
|
ctx.ReplaceAll([&](const WhileStatement* w) -> const Statement* {
|
||||||
utils::Vector<const ast::Statement*, 16> stmts;
|
utils::Vector<const Statement*, 16> stmts;
|
||||||
auto* cond = w->condition;
|
auto* cond = w->condition;
|
||||||
|
|
||||||
// !condition
|
// !condition
|
||||||
|
@ -64,7 +64,7 @@ Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, Da
|
||||||
stmts.Push(ctx.Clone(stmt));
|
stmts.Push(ctx.Clone(stmt));
|
||||||
}
|
}
|
||||||
|
|
||||||
const ast::BlockStatement* continuing = nullptr;
|
const BlockStatement* continuing = nullptr;
|
||||||
|
|
||||||
auto* body = b.Block(stmts);
|
auto* body = b.Block(stmts);
|
||||||
auto* loop = b.Loop(body, continuing);
|
auto* loop = b.Loop(body, continuing);
|
||||||
|
|
|
@ -36,7 +36,7 @@ namespace {
|
||||||
|
|
||||||
bool ShouldRun(const Program* program) {
|
bool ShouldRun(const Program* program) {
|
||||||
for (auto* global : program->AST().GlobalVariables()) {
|
for (auto* global : program->AST().GlobalVariables()) {
|
||||||
if (auto* var = global->As<ast::Var>()) {
|
if (auto* var = global->As<Var>()) {
|
||||||
auto* v = program->Sem().Get(var);
|
auto* v = program->Sem().Get(var);
|
||||||
if (v->AddressSpace() == builtin::AddressSpace::kWorkgroup) {
|
if (v->AddressSpace() == builtin::AddressSpace::kWorkgroup) {
|
||||||
return true;
|
return true;
|
||||||
|
@ -48,7 +48,7 @@ bool ShouldRun(const Program* program) {
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
using StatementList = utils::Vector<const ast::Statement*, 8>;
|
using StatementList = utils::Vector<const Statement*, 8>;
|
||||||
|
|
||||||
/// PIMPL state for the transform
|
/// PIMPL state for the transform
|
||||||
struct ZeroInitWorkgroupMemory::State {
|
struct ZeroInitWorkgroupMemory::State {
|
||||||
|
@ -132,10 +132,10 @@ struct ZeroInitWorkgroupMemory::State {
|
||||||
/// Run inserts the workgroup memory zero-initialization logic at the top of
|
/// Run inserts the workgroup memory zero-initialization logic at the top of
|
||||||
/// the given function
|
/// the given function
|
||||||
/// @param fn a compute shader entry point function
|
/// @param fn a compute shader entry point function
|
||||||
void Run(const ast::Function* fn) {
|
void Run(const Function* fn) {
|
||||||
auto& sem = ctx.src->Sem();
|
auto& sem = ctx.src->Sem();
|
||||||
|
|
||||||
CalculateWorkgroupSize(ast::GetAttribute<ast::WorkgroupAttribute>(fn->attributes));
|
CalculateWorkgroupSize(GetAttribute<WorkgroupAttribute>(fn->attributes));
|
||||||
|
|
||||||
// Generate a list of statements to zero initialize each of the
|
// Generate a list of statements to zero initialize each of the
|
||||||
// workgroup storage variables used by `fn`. This will populate #statements.
|
// workgroup storage variables used by `fn`. This will populate #statements.
|
||||||
|
@ -160,7 +160,7 @@ struct ZeroInitWorkgroupMemory::State {
|
||||||
// parameter
|
// parameter
|
||||||
std::function<const ast::Expression*()> local_index;
|
std::function<const ast::Expression*()> local_index;
|
||||||
for (auto* param : fn->params) {
|
for (auto* param : fn->params) {
|
||||||
if (auto* builtin_attr = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) {
|
if (auto* builtin_attr = GetAttribute<BuiltinAttribute>(param->attributes)) {
|
||||||
auto builtin = sem.Get(builtin_attr)->Value();
|
auto builtin = sem.Get(builtin_attr)->Value();
|
||||||
if (builtin == builtin::BuiltinValue::kLocalInvocationIndex) {
|
if (builtin == builtin::BuiltinValue::kLocalInvocationIndex) {
|
||||||
local_index = [=] { return b.Expr(ctx.Clone(param->name->symbol)); };
|
local_index = [=] { return b.Expr(ctx.Clone(param->name->symbol)); };
|
||||||
|
@ -231,7 +231,7 @@ struct ZeroInitWorkgroupMemory::State {
|
||||||
// }
|
// }
|
||||||
auto idx = b.Symbols().New("idx");
|
auto idx = b.Symbols().New("idx");
|
||||||
auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index()));
|
auto* init = b.Decl(b.Var(idx, b.ty.u32(), local_index()));
|
||||||
auto* cond = b.create<ast::BinaryExpression>(ast::BinaryOp::kLessThan, b.Expr(idx),
|
auto* cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr(idx),
|
||||||
b.Expr(u32(num_iterations)));
|
b.Expr(u32(num_iterations)));
|
||||||
auto* cont = b.Assign(
|
auto* cont = b.Assign(
|
||||||
idx, b.Add(idx, workgroup_size_const ? b.Expr(u32(workgroup_size_const))
|
idx, b.Add(idx, workgroup_size_const ? b.Expr(u32(workgroup_size_const))
|
||||||
|
@ -251,8 +251,8 @@ struct ZeroInitWorkgroupMemory::State {
|
||||||
// if (local_index < num_iterations) {
|
// if (local_index < num_iterations) {
|
||||||
// ...
|
// ...
|
||||||
// }
|
// }
|
||||||
auto* cond = b.create<ast::BinaryExpression>(
|
auto* cond = b.create<BinaryExpression>(BinaryOp::kLessThan, local_index(),
|
||||||
ast::BinaryOp::kLessThan, local_index(), b.Expr(u32(num_iterations)));
|
b.Expr(u32(num_iterations)));
|
||||||
auto block = DeclareArrayIndices(num_iterations, array_indices,
|
auto block = DeclareArrayIndices(num_iterations, array_indices,
|
||||||
[&] { return b.Expr(local_index()); });
|
[&] { return b.Expr(local_index()); });
|
||||||
for (auto& s : stmts) {
|
for (auto& s : stmts) {
|
||||||
|
@ -382,7 +382,7 @@ struct ZeroInitWorkgroupMemory::State {
|
||||||
for (auto index : array_indices) {
|
for (auto index : array_indices) {
|
||||||
auto name = array_index_names.at(index);
|
auto name = array_index_names.at(index);
|
||||||
auto* mod = (num_iterations > index.modulo)
|
auto* mod = (num_iterations > index.modulo)
|
||||||
? b.create<ast::BinaryExpression>(ast::BinaryOp::kModulo, iteration(),
|
? b.create<BinaryExpression>(BinaryOp::kModulo, iteration(),
|
||||||
b.Expr(u32(index.modulo)))
|
b.Expr(u32(index.modulo)))
|
||||||
: iteration();
|
: iteration();
|
||||||
auto* div = (index.division != 1u) ? b.Div(mod, u32(index.division)) : mod;
|
auto* div = (index.division != 1u) ? b.Div(mod, u32(index.division)) : mod;
|
||||||
|
@ -395,7 +395,7 @@ struct ZeroInitWorkgroupMemory::State {
|
||||||
/// CalculateWorkgroupSize initializes the members #workgroup_size_const and
|
/// CalculateWorkgroupSize initializes the members #workgroup_size_const and
|
||||||
/// #workgroup_size_expr with the linear workgroup size.
|
/// #workgroup_size_expr with the linear workgroup size.
|
||||||
/// @param attr the workgroup attribute applied to the entry point function
|
/// @param attr the workgroup attribute applied to the entry point function
|
||||||
void CalculateWorkgroupSize(const ast::WorkgroupAttribute* attr) {
|
void CalculateWorkgroupSize(const WorkgroupAttribute* attr) {
|
||||||
bool is_signed = false;
|
bool is_signed = false;
|
||||||
workgroup_size_const = 1u;
|
workgroup_size_const = 1u;
|
||||||
workgroup_size_expr = nullptr;
|
workgroup_size_expr = nullptr;
|
||||||
|
@ -471,7 +471,7 @@ Transform::ApplyResult ZeroInitWorkgroupMemory::Apply(const Program* src,
|
||||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||||
|
|
||||||
for (auto* fn : src->AST().Functions()) {
|
for (auto* fn : src->AST().Functions()) {
|
||||||
if (fn->PipelineStage() == ast::PipelineStage::kCompute) {
|
if (fn->PipelineStage() == PipelineStage::kCompute) {
|
||||||
State{ctx}.Run(fn);
|
State{ctx}.Run(fn);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue