tint/transform: Refactor transforms
Replace the ShouldRun() method with Apply() which will do the transformation if it needs to be done, otherwise returns 'SkipTransform'. This reduces a bunch of duplicated scanning between the old ShouldRun() and Transform(). This change also adjusts code style to make the transforms more consistent. Change-Id: I9a6b10cb8b4ed62676b12ef30fb7764d363386c6 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107681 Reviewed-by: James Price <jrprice@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
parent
de6db384aa
commit
c6b381495d
|
@ -15,6 +15,7 @@
|
|||
#include "src/tint/fuzzers/shuffle_transform.h"
|
||||
|
||||
#include <random>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
|
||||
|
@ -22,15 +23,21 @@ namespace tint::fuzzers {
|
|||
|
||||
ShuffleTransform::ShuffleTransform(size_t seed) : seed_(seed) {}
|
||||
|
||||
void ShuffleTransform::Run(CloneContext& ctx,
|
||||
const tint::transform::DataMap&,
|
||||
tint::transform::DataMap&) const {
|
||||
auto decls = ctx.src->AST().GlobalDeclarations();
|
||||
transform::Transform::ApplyResult ShuffleTransform::Apply(const Program* src,
|
||||
const transform::DataMap&,
|
||||
transform::DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
auto decls = src->AST().GlobalDeclarations();
|
||||
auto rng = std::mt19937_64{seed_};
|
||||
std::shuffle(std::begin(decls), std::end(decls), rng);
|
||||
for (auto* decl : decls) {
|
||||
ctx.dst->AST().AddGlobalDeclaration(ctx.Clone(decl));
|
||||
b.AST().AddGlobalDeclaration(ctx.Clone(decl));
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::fuzzers
|
||||
|
|
|
@ -20,16 +20,16 @@
|
|||
namespace tint::fuzzers {
|
||||
|
||||
/// ShuffleTransform reorders the module scope declarations into a random order
|
||||
class ShuffleTransform : public tint::transform::Transform {
|
||||
class ShuffleTransform : public transform::Transform {
|
||||
public:
|
||||
/// Constructor
|
||||
/// @param seed the random seed to use for the shuffling
|
||||
explicit ShuffleTransform(size_t seed);
|
||||
|
||||
protected:
|
||||
void Run(CloneContext& ctx,
|
||||
const tint::transform::DataMap&,
|
||||
tint::transform::DataMap&) const override;
|
||||
/// @copydoc transform::Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const transform::DataMap& inputs,
|
||||
transform::DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
size_t seed_;
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
#include "src/tint/ast/access.h"
|
||||
#include "src/tint/ast/address_space.h"
|
||||
#include "src/tint/ast/parameter.h"
|
||||
#include "src/tint/sem/binding_point.h"
|
||||
#include "src/tint/sem/expression.h"
|
||||
#include "src/tint/sem/parameter_usage.h"
|
||||
|
@ -212,6 +213,11 @@ class Parameter final : public Castable<Parameter, Variable> {
|
|||
/// Destructor
|
||||
~Parameter() override;
|
||||
|
||||
/// @returns the AST declaration node
|
||||
const ast::Parameter* Declaration() const {
|
||||
return static_cast<const ast::Parameter*>(Variable::Declaration());
|
||||
}
|
||||
|
||||
/// @return the index of the parmeter in the function
|
||||
uint32_t Index() const { return index_; }
|
||||
|
||||
|
|
|
@ -31,21 +31,29 @@ AddBlockAttribute::AddBlockAttribute() = default;
|
|||
|
||||
AddBlockAttribute::~AddBlockAttribute() = default;
|
||||
|
||||
void AddBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
auto& sem = src->Sem();
|
||||
|
||||
// A map from a type in the source program to a block-decorated wrapper that contains it in the
|
||||
// destination program.
|
||||
utils::Hashmap<const sem::Type*, const ast::Struct*, 8> wrapper_structs;
|
||||
|
||||
// Process global 'var' declarations that are buffers.
|
||||
for (auto* global : ctx.src->AST().GlobalVariables()) {
|
||||
bool made_changes = false;
|
||||
for (auto* global : src->AST().GlobalVariables()) {
|
||||
auto* var = sem.Get(global);
|
||||
if (!ast::IsHostShareable(var->AddressSpace())) {
|
||||
// Not declared in a host-sharable address space
|
||||
continue;
|
||||
}
|
||||
|
||||
made_changes = true;
|
||||
|
||||
auto* ty = var->Type()->UnwrapRef();
|
||||
auto* str = ty->As<sem::Struct>();
|
||||
|
||||
|
@ -61,33 +69,36 @@ void AddBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
const char* kMemberName = "inner";
|
||||
|
||||
auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] {
|
||||
auto* block = ctx.dst->ASTNodes().Create<BlockAttribute>(ctx.dst->ID(),
|
||||
ctx.dst->AllocateNodeID());
|
||||
auto wrapper_name = ctx.src->Symbols().NameFor(global->symbol) + "_block";
|
||||
auto* ret = ctx.dst->create<ast::Struct>(
|
||||
ctx.dst->Symbols().New(wrapper_name),
|
||||
utils::Vector{ctx.dst->Member(kMemberName, CreateASTTypeFor(ctx, ty))},
|
||||
auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
|
||||
auto wrapper_name = src->Symbols().NameFor(global->symbol) + "_block";
|
||||
auto* ret = b.create<ast::Struct>(
|
||||
b.Symbols().New(wrapper_name),
|
||||
utils::Vector{b.Member(kMemberName, CreateASTTypeFor(ctx, ty))},
|
||||
utils::Vector{block});
|
||||
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), global, ret);
|
||||
ctx.InsertBefore(src->AST().GlobalDeclarations(), global, ret);
|
||||
return ret;
|
||||
});
|
||||
ctx.Replace(global->type, ctx.dst->ty.Of(wrapper));
|
||||
ctx.Replace(global->type, b.ty.Of(wrapper));
|
||||
|
||||
// Insert a member accessor to get the original type from the wrapper at
|
||||
// any usage of the original variable.
|
||||
for (auto* user : var->Users()) {
|
||||
ctx.Replace(user->Declaration(),
|
||||
ctx.dst->MemberAccessor(ctx.Clone(global->symbol), kMemberName));
|
||||
b.MemberAccessor(ctx.Clone(global->symbol), kMemberName));
|
||||
}
|
||||
} else {
|
||||
// Add a block attribute to this struct directly.
|
||||
auto* block = ctx.dst->ASTNodes().Create<BlockAttribute>(ctx.dst->ID(),
|
||||
ctx.dst->AllocateNodeID());
|
||||
auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
|
||||
ctx.InsertFront(str->Declaration()->attributes, block);
|
||||
}
|
||||
}
|
||||
|
||||
if (!made_changes) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, ast::NodeID nid)
|
||||
|
|
|
@ -53,14 +53,10 @@ class AddBlockAttribute final : public Castable<AddBlockAttribute, Transform> {
|
|||
/// Destructor
|
||||
~AddBlockAttribute() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -23,12 +23,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::AddEmptyEntryPoint);
|
|||
using namespace tint::number_suffixes; // NOLINT
|
||||
|
||||
namespace tint::transform {
|
||||
namespace {
|
||||
|
||||
AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
|
||||
|
||||
AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
|
||||
|
||||
bool AddEmptyEntryPoint::ShouldRun(const Program* program, const DataMap&) const {
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* func : program->AST().Functions()) {
|
||||
if (func->IsEntryPoint()) {
|
||||
return false;
|
||||
|
@ -37,13 +34,30 @@ bool AddEmptyEntryPoint::ShouldRun(const Program* program, const DataMap&) const
|
|||
return true;
|
||||
}
|
||||
|
||||
void AddEmptyEntryPoint::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {}, ctx.dst->ty.void_(), {},
|
||||
utils::Vector{
|
||||
ctx.dst->Stage(ast::PipelineStage::kCompute),
|
||||
ctx.dst->WorkgroupSize(1_i),
|
||||
});
|
||||
} // namespace
|
||||
|
||||
AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
|
||||
|
||||
AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
|
||||
|
||||
Transform::ApplyResult AddEmptyEntryPoint::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
b.Func(b.Symbols().New("unused_entry_point"), {}, b.ty.void_(), {},
|
||||
utils::Vector{
|
||||
b.Stage(ast::PipelineStage::kCompute),
|
||||
b.WorkgroupSize(1_i),
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -27,19 +27,10 @@ class AddEmptyEntryPoint final : public Castable<AddEmptyEntryPoint, Transform>
|
|||
/// Destructor
|
||||
~AddEmptyEntryPoint() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -31,13 +31,153 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Result);
|
|||
|
||||
namespace tint::transform {
|
||||
|
||||
namespace {
|
||||
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (auto* sem_fn = program->Sem().Get(fn)) {
|
||||
for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
|
||||
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
|
||||
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
|
||||
|
||||
/// The PIMPL state for this transform
|
||||
/// PIMPL state for the transform
|
||||
struct ArrayLengthFromUniform::State {
|
||||
/// Constructor
|
||||
/// @param program the source program
|
||||
/// @param in the input transform data
|
||||
/// @param out the output transform data
|
||||
explicit State(const Program* program, const DataMap& in, DataMap& out)
|
||||
: src(program), inputs(in), outputs(out) {}
|
||||
|
||||
/// Runs the transform
|
||||
/// @returns the new program or SkipTransform if the transform is not required
|
||||
ApplyResult Run() {
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
if (cfg == nullptr) {
|
||||
b.Diagnostics().add_error(diag::System::Transform,
|
||||
"missing transform data for " +
|
||||
std::string(TypeInfo::Of<ArrayLengthFromUniform>().name));
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
if (!ShouldRun(ctx.src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
const char* kBufferSizeMemberName = "buffer_size";
|
||||
|
||||
// Determine the size of the buffer size array.
|
||||
uint32_t max_buffer_size_index = 0;
|
||||
|
||||
IterateArrayLengthOnStorageVar([&](const ast::CallExpression*, const sem::VariableUser*,
|
||||
const sem::GlobalVariable* var) {
|
||||
auto binding = var->BindingPoint();
|
||||
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
|
||||
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
|
||||
return;
|
||||
}
|
||||
if (idx_itr->second > max_buffer_size_index) {
|
||||
max_buffer_size_index = idx_itr->second;
|
||||
}
|
||||
});
|
||||
|
||||
// Get (or create, on first call) the uniform buffer that will receive the
|
||||
// size of each storage buffer in the module.
|
||||
const ast::Variable* buffer_size_ubo = nullptr;
|
||||
auto get_ubo = [&]() {
|
||||
if (!buffer_size_ubo) {
|
||||
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
|
||||
// We do this because UBOs require an element stride that is 16-byte
|
||||
// aligned.
|
||||
auto* buffer_size_struct = b.Structure(
|
||||
b.Sym(), utils::Vector{
|
||||
b.Member(kBufferSizeMemberName,
|
||||
b.ty.array(b.ty.vec4(b.ty.u32()),
|
||||
u32((max_buffer_size_index / 4) + 1))),
|
||||
});
|
||||
buffer_size_ubo =
|
||||
b.GlobalVar(b.Sym(), b.ty.Of(buffer_size_struct), ast::AddressSpace::kUniform,
|
||||
b.Group(AInt(cfg->ubo_binding.group)),
|
||||
b.Binding(AInt(cfg->ubo_binding.binding)));
|
||||
}
|
||||
return buffer_size_ubo;
|
||||
};
|
||||
|
||||
std::unordered_set<uint32_t> used_size_indices;
|
||||
|
||||
IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr,
|
||||
const sem::VariableUser* storage_buffer_sem,
|
||||
const sem::GlobalVariable* var) {
|
||||
auto binding = var->BindingPoint();
|
||||
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
|
||||
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t size_index = idx_itr->second;
|
||||
used_size_indices.insert(size_index);
|
||||
|
||||
// Load the total storage buffer size from the UBO.
|
||||
uint32_t array_index = size_index / 4;
|
||||
auto* vec_expr = b.IndexAccessor(
|
||||
b.MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), u32(array_index));
|
||||
uint32_t vec_index = size_index % 4;
|
||||
auto* total_storage_buffer_size = b.IndexAccessor(vec_expr, u32(vec_index));
|
||||
|
||||
// Calculate actual array length
|
||||
// total_storage_buffer_size - array_offset
|
||||
// array_length = ----------------------------------------
|
||||
// array_stride
|
||||
const ast::Expression* total_size = total_storage_buffer_size;
|
||||
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
|
||||
const sem::Array* array_type = nullptr;
|
||||
if (auto* str = storage_buffer_type->As<sem::Struct>()) {
|
||||
// The variable is a struct, so subtract the byte offset of the array
|
||||
// member.
|
||||
auto* array_member_sem = str->Members().back();
|
||||
array_type = array_member_sem->Type()->As<sem::Array>();
|
||||
total_size = b.Sub(total_storage_buffer_size, u32(array_member_sem->Offset()));
|
||||
} else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
|
||||
array_type = arr;
|
||||
} else {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
return;
|
||||
}
|
||||
auto* array_length = b.Div(total_size, u32(array_type->Stride()));
|
||||
|
||||
ctx.Replace(call_expr, array_length);
|
||||
});
|
||||
|
||||
outputs.Add<Result>(used_size_indices);
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
private:
|
||||
/// The source program
|
||||
const Program* const src;
|
||||
/// The transform inputs
|
||||
const DataMap& inputs;
|
||||
/// The transform outputs
|
||||
DataMap& outputs;
|
||||
/// The target program builder
|
||||
ProgramBuilder b;
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
/// Iterate over all arrayLength() builtins that operate on
|
||||
/// storage buffer variables.
|
||||
|
@ -48,10 +188,10 @@ struct ArrayLengthFromUniform::State {
|
|||
/// sem::GlobalVariable for the storage buffer.
|
||||
template <typename F>
|
||||
void IterateArrayLengthOnStorageVar(F&& functor) {
|
||||
auto& sem = ctx.src->Sem();
|
||||
auto& sem = src->Sem();
|
||||
|
||||
// Find all calls to the arrayLength() builtin.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
auto* call_expr = node->As<ast::CallExpression>();
|
||||
if (!call_expr) {
|
||||
continue;
|
||||
|
@ -79,7 +219,7 @@ struct ArrayLengthFromUniform::State {
|
|||
// arrayLength(&array_var)
|
||||
auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
|
||||
if (!param || param->op != ast::UnaryOp::kAddressOf) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
break;
|
||||
|
@ -90,7 +230,7 @@ struct ArrayLengthFromUniform::State {
|
|||
}
|
||||
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
|
||||
if (!storage_buffer_sem) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
break;
|
||||
|
@ -99,8 +239,7 @@ struct ArrayLengthFromUniform::State {
|
|||
// Get the index to use for the buffer size array.
|
||||
auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
|
||||
if (!var) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "storage buffer is not a global variable";
|
||||
TINT_ICE(Transform, b.Diagnostics()) << "storage buffer is not a global variable";
|
||||
break;
|
||||
}
|
||||
functor(call_expr, storage_buffer_sem, var);
|
||||
|
@ -108,117 +247,10 @@ struct ArrayLengthFromUniform::State {
|
|||
}
|
||||
};
|
||||
|
||||
bool ArrayLengthFromUniform::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (auto* sem_fn = program->Sem().Get(fn)) {
|
||||
for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
|
||||
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ArrayLengthFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const {
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
if (cfg == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
|
||||
return;
|
||||
}
|
||||
|
||||
const char* kBufferSizeMemberName = "buffer_size";
|
||||
|
||||
// Determine the size of the buffer size array.
|
||||
uint32_t max_buffer_size_index = 0;
|
||||
|
||||
State{ctx}.IterateArrayLengthOnStorageVar(
|
||||
[&](const ast::CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) {
|
||||
auto binding = var->BindingPoint();
|
||||
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
|
||||
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
|
||||
return;
|
||||
}
|
||||
if (idx_itr->second > max_buffer_size_index) {
|
||||
max_buffer_size_index = idx_itr->second;
|
||||
}
|
||||
});
|
||||
|
||||
// Get (or create, on first call) the uniform buffer that will receive the
|
||||
// size of each storage buffer in the module.
|
||||
const ast::Variable* buffer_size_ubo = nullptr;
|
||||
auto get_ubo = [&]() {
|
||||
if (!buffer_size_ubo) {
|
||||
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
|
||||
// We do this because UBOs require an element stride that is 16-byte
|
||||
// aligned.
|
||||
auto* buffer_size_struct = ctx.dst->Structure(
|
||||
ctx.dst->Sym(),
|
||||
utils::Vector{
|
||||
ctx.dst->Member(kBufferSizeMemberName,
|
||||
ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
|
||||
u32((max_buffer_size_index / 4) + 1))),
|
||||
});
|
||||
buffer_size_ubo = ctx.dst->GlobalVar(ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
|
||||
ast::AddressSpace::kUniform,
|
||||
ctx.dst->Group(AInt(cfg->ubo_binding.group)),
|
||||
ctx.dst->Binding(AInt(cfg->ubo_binding.binding)));
|
||||
}
|
||||
return buffer_size_ubo;
|
||||
};
|
||||
|
||||
std::unordered_set<uint32_t> used_size_indices;
|
||||
|
||||
State{ctx}.IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr,
|
||||
const sem::VariableUser* storage_buffer_sem,
|
||||
const sem::GlobalVariable* var) {
|
||||
auto binding = var->BindingPoint();
|
||||
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
|
||||
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t size_index = idx_itr->second;
|
||||
used_size_indices.insert(size_index);
|
||||
|
||||
// Load the total storage buffer size from the UBO.
|
||||
uint32_t array_index = size_index / 4;
|
||||
auto* vec_expr = ctx.dst->IndexAccessor(
|
||||
ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), u32(array_index));
|
||||
uint32_t vec_index = size_index % 4;
|
||||
auto* total_storage_buffer_size = ctx.dst->IndexAccessor(vec_expr, u32(vec_index));
|
||||
|
||||
// Calculate actual array length
|
||||
// total_storage_buffer_size - array_offset
|
||||
// array_length = ----------------------------------------
|
||||
// array_stride
|
||||
const ast::Expression* total_size = total_storage_buffer_size;
|
||||
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
|
||||
const sem::Array* array_type = nullptr;
|
||||
if (auto* str = storage_buffer_type->As<sem::Struct>()) {
|
||||
// The variable is a struct, so subtract the byte offset of the array
|
||||
// member.
|
||||
auto* array_member_sem = str->Members().back();
|
||||
array_type = array_member_sem->Type()->As<sem::Array>();
|
||||
total_size = ctx.dst->Sub(total_storage_buffer_size, u32(array_member_sem->Offset()));
|
||||
} else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
|
||||
array_type = arr;
|
||||
} else {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
return;
|
||||
}
|
||||
auto* array_length = ctx.dst->Div(total_size, u32(array_type->Stride()));
|
||||
|
||||
ctx.Replace(call_expr, array_length);
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
|
||||
outputs.Add<Result>(used_size_indices);
|
||||
Transform::ApplyResult ArrayLengthFromUniform::Apply(const Program* src,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const {
|
||||
return State{src, inputs, outputs}.Run();
|
||||
}
|
||||
|
||||
ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) : ubo_binding(ubo_bp) {}
|
||||
|
|
|
@ -100,22 +100,12 @@ class ArrayLengthFromUniform final : public Castable<ArrayLengthFromUniform, Tra
|
|||
std::unordered_set<uint32_t> used_size_indices;
|
||||
};
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
/// The PIMPL state for this transform
|
||||
struct State;
|
||||
};
|
||||
|
||||
|
|
|
@ -28,7 +28,13 @@ using ArrayLengthFromUniformTest = TransformTest;
|
|||
TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) {
|
||||
|
@ -45,7 +51,13 @@ fn main() {
|
|||
}
|
||||
)";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) {
|
||||
|
@ -63,7 +75,13 @@ fn main() {
|
|||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src));
|
||||
ArrayLengthFromUniform::Config cfg({0, 30u});
|
||||
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
|
||||
|
||||
DataMap data;
|
||||
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
|
||||
|
||||
EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
|
||||
|
|
|
@ -40,19 +40,21 @@ BindingRemapper::Remappings::~Remappings() = default;
|
|||
BindingRemapper::BindingRemapper() = default;
|
||||
BindingRemapper::~BindingRemapper() = default;
|
||||
|
||||
bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const {
|
||||
if (auto* remappings = inputs.Get<Remappings>()) {
|
||||
return !remappings->binding_points.empty() || !remappings->access_controls.empty();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
Transform::ApplyResult BindingRemapper::Apply(const Program* src,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
|
||||
auto* remappings = inputs.Get<Remappings>();
|
||||
if (!remappings) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
|
||||
return;
|
||||
b.Diagnostics().add_error(diag::System::Transform,
|
||||
"missing transform data for " + std::string(TypeInfo().name));
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
if (remappings->binding_points.empty() && remappings->access_controls.empty()) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
// A set of post-remapped binding points that need to be decorated with a
|
||||
|
@ -62,11 +64,11 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
|
|||
if (remappings->allow_collisions) {
|
||||
// Scan for binding point collisions generated by this transform.
|
||||
// Populate all collisions in the `add_collision_attr` set.
|
||||
for (auto* func_ast : ctx.src->AST().Functions()) {
|
||||
for (auto* func_ast : src->AST().Functions()) {
|
||||
if (!func_ast->IsEntryPoint()) {
|
||||
continue;
|
||||
}
|
||||
auto* func = ctx.src->Sem().Get(func_ast);
|
||||
auto* func = src->Sem().Get(func_ast);
|
||||
std::unordered_map<sem::BindingPoint, int> binding_point_counts;
|
||||
for (auto* global : func->TransitivelyReferencedGlobals()) {
|
||||
if (global->Declaration()->HasBindingPoint()) {
|
||||
|
@ -90,9 +92,9 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
|
|||
}
|
||||
}
|
||||
|
||||
for (auto* var : ctx.src->AST().Globals<ast::Var>()) {
|
||||
for (auto* var : src->AST().Globals<ast::Var>()) {
|
||||
if (var->HasBindingPoint()) {
|
||||
auto* global_sem = ctx.src->Sem().Get<sem::GlobalVariable>(var);
|
||||
auto* global_sem = src->Sem().Get<sem::GlobalVariable>(var);
|
||||
|
||||
// The original binding point
|
||||
BindingPoint from = global_sem->BindingPoint();
|
||||
|
@ -106,8 +108,8 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
|
|||
auto bp_it = remappings->binding_points.find(from);
|
||||
if (bp_it != remappings->binding_points.end()) {
|
||||
BindingPoint to = bp_it->second;
|
||||
auto* new_group = ctx.dst->Group(AInt(to.group));
|
||||
auto* new_binding = ctx.dst->Binding(AInt(to.binding));
|
||||
auto* new_group = b.Group(AInt(to.group));
|
||||
auto* new_binding = b.Binding(AInt(to.binding));
|
||||
|
||||
auto* old_group = ast::GetAttribute<ast::GroupAttribute>(var->attributes);
|
||||
auto* old_binding = ast::GetAttribute<ast::BindingAttribute>(var->attributes);
|
||||
|
@ -122,37 +124,37 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
|
|||
if (ac_it != remappings->access_controls.end()) {
|
||||
ast::Access ac = ac_it->second;
|
||||
if (ac == ast::Access::kUndefined) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
b.Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"invalid access mode (" + std::to_string(static_cast<uint32_t>(ac)) + ")");
|
||||
return;
|
||||
return Program(std::move(b));
|
||||
}
|
||||
auto* sem = ctx.src->Sem().Get(var);
|
||||
auto* sem = src->Sem().Get(var);
|
||||
if (sem->AddressSpace() != ast::AddressSpace::kStorage) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
b.Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"cannot apply access control to variable with address space " +
|
||||
std::string(utils::ToString(sem->AddressSpace())));
|
||||
return;
|
||||
return Program(std::move(b));
|
||||
}
|
||||
auto* ty = sem->Type()->UnwrapRef();
|
||||
const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
|
||||
auto* new_var =
|
||||
ctx.dst->Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty,
|
||||
var->declared_address_space, ac, ctx.Clone(var->initializer),
|
||||
ctx.Clone(var->attributes));
|
||||
auto* new_var = b.Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty,
|
||||
var->declared_address_space, ac, ctx.Clone(var->initializer),
|
||||
ctx.Clone(var->attributes));
|
||||
ctx.Replace(var, new_var);
|
||||
}
|
||||
|
||||
// Add `DisableValidationAttribute`s if required
|
||||
if (add_collision_attr.count(bp)) {
|
||||
auto* attribute = ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
|
||||
auto* attribute = b.Disable(ast::DisabledValidation::kBindingPointCollision);
|
||||
ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -67,19 +67,10 @@ class BindingRemapper final : public Castable<BindingRemapper, Transform> {
|
|||
BindingRemapper();
|
||||
~BindingRemapper() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -23,12 +23,6 @@ namespace {
|
|||
|
||||
using BindingRemapperTest = TransformTest;
|
||||
|
||||
TEST_F(BindingRemapperTest, ShouldRunNoRemappings) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<BindingRemapper>(src));
|
||||
}
|
||||
|
||||
TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) {
|
||||
auto* src = R"()";
|
||||
|
||||
|
@ -350,7 +344,7 @@ fn f() {
|
|||
}
|
||||
)";
|
||||
|
||||
auto* expect = src;
|
||||
auto* expect = R"(error: missing transform data for tint::transform::BindingRemapper)";
|
||||
|
||||
auto got = Run<BindingRemapper>(src);
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::BuiltinPolyfill::Config);
|
|||
|
||||
namespace tint::transform {
|
||||
|
||||
/// The PIMPL state for the BuiltinPolyfill transform
|
||||
/// PIMPL state for the transform
|
||||
struct BuiltinPolyfill::State {
|
||||
/// Constructor
|
||||
/// @param c the CloneContext
|
||||
|
@ -604,193 +604,100 @@ BuiltinPolyfill::BuiltinPolyfill() = default;
|
|||
|
||||
BuiltinPolyfill::~BuiltinPolyfill() = default;
|
||||
|
||||
bool BuiltinPolyfill::ShouldRun(const Program* program, const DataMap& data) const {
|
||||
if (auto* cfg = data.Get<Config>()) {
|
||||
auto builtins = cfg->builtins;
|
||||
auto& sem = program->Sem();
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* call = sem.Get<sem::Call>(node)) {
|
||||
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
|
||||
if (call->Stage() == sem::EvaluationStage::kConstant) {
|
||||
continue; // Don't polyfill @const expressions
|
||||
}
|
||||
switch (builtin->Type()) {
|
||||
case sem::BuiltinType::kAcosh:
|
||||
if (builtins.acosh != Level::kNone) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kAsinh:
|
||||
if (builtins.asinh) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kAtanh:
|
||||
if (builtins.atanh != Level::kNone) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kClamp:
|
||||
if (builtins.clamp_int) {
|
||||
auto& sig = builtin->Signature();
|
||||
return sig.parameters[0]->Type()->is_integer_scalar_or_vector();
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kCountLeadingZeros:
|
||||
if (builtins.count_leading_zeros) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kCountTrailingZeros:
|
||||
if (builtins.count_trailing_zeros) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kExtractBits:
|
||||
if (builtins.extract_bits != Level::kNone) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kFirstLeadingBit:
|
||||
if (builtins.first_leading_bit) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kFirstTrailingBit:
|
||||
if (builtins.first_trailing_bit) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kInsertBits:
|
||||
if (builtins.insert_bits != Level::kNone) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kSaturate:
|
||||
if (builtins.saturate) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kTextureSampleBaseClampToEdge:
|
||||
if (builtins.texture_sample_base_clamp_to_edge_2d_f32) {
|
||||
auto& sig = builtin->Signature();
|
||||
auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
|
||||
if (auto* stex = tex->Type()->As<sem::SampledTexture>()) {
|
||||
return stex->type()->Is<sem::F32>();
|
||||
}
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kQuantizeToF16:
|
||||
if (builtins.quantize_to_vec_f16) {
|
||||
if (builtin->ReturnType()->Is<sem::Vector>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) const {
|
||||
Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
|
||||
const DataMap& data,
|
||||
DataMap&) const {
|
||||
auto* cfg = data.Get<Config>();
|
||||
if (!cfg) {
|
||||
ctx.Clone();
|
||||
return;
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
std::unordered_map<const sem::Builtin*, Symbol> polyfills;
|
||||
auto& builtins = cfg->builtins;
|
||||
|
||||
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
|
||||
auto builtins = cfg->builtins;
|
||||
State s{ctx, builtins};
|
||||
if (auto* call = s.sem.Get<sem::Call>(expr)) {
|
||||
utils::Hashmap<const sem::Builtin*, Symbol, 8> polyfills;
|
||||
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
State s{ctx, builtins};
|
||||
|
||||
bool made_changes = false;
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
if (auto* call = src->Sem().Get<sem::Call>(node)) {
|
||||
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
|
||||
if (call->Stage() == sem::EvaluationStage::kConstant) {
|
||||
return nullptr; // Don't polyfill @const expressions
|
||||
continue; // Don't polyfill @const expressions
|
||||
}
|
||||
Symbol polyfill;
|
||||
switch (builtin->Type()) {
|
||||
case sem::BuiltinType::kAcosh:
|
||||
if (builtins.acosh != Level::kNone) {
|
||||
polyfill = utils::GetOrCreate(
|
||||
polyfills, builtin, [&] { return s.acosh(builtin->ReturnType()); });
|
||||
polyfill = polyfills.GetOrCreate(
|
||||
builtin, [&] { return s.acosh(builtin->ReturnType()); });
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kAsinh:
|
||||
if (builtins.asinh) {
|
||||
polyfill = utils::GetOrCreate(
|
||||
polyfills, builtin, [&] { return s.asinh(builtin->ReturnType()); });
|
||||
polyfill = polyfills.GetOrCreate(
|
||||
builtin, [&] { return s.asinh(builtin->ReturnType()); });
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kAtanh:
|
||||
if (builtins.atanh != Level::kNone) {
|
||||
polyfill = utils::GetOrCreate(
|
||||
polyfills, builtin, [&] { return s.atanh(builtin->ReturnType()); });
|
||||
polyfill = polyfills.GetOrCreate(
|
||||
builtin, [&] { return s.atanh(builtin->ReturnType()); });
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kClamp:
|
||||
if (builtins.clamp_int) {
|
||||
auto& sig = builtin->Signature();
|
||||
if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) {
|
||||
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
|
||||
return s.clampInteger(builtin->ReturnType());
|
||||
});
|
||||
polyfill = polyfills.GetOrCreate(
|
||||
builtin, [&] { return s.clampInteger(builtin->ReturnType()); });
|
||||
}
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kCountLeadingZeros:
|
||||
if (builtins.count_leading_zeros) {
|
||||
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
|
||||
polyfill = polyfills.GetOrCreate(builtin, [&] {
|
||||
return s.countLeadingZeros(builtin->ReturnType());
|
||||
});
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kCountTrailingZeros:
|
||||
if (builtins.count_trailing_zeros) {
|
||||
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
|
||||
polyfill = polyfills.GetOrCreate(builtin, [&] {
|
||||
return s.countTrailingZeros(builtin->ReturnType());
|
||||
});
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kExtractBits:
|
||||
if (builtins.extract_bits != Level::kNone) {
|
||||
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
|
||||
return s.extractBits(builtin->ReturnType());
|
||||
});
|
||||
polyfill = polyfills.GetOrCreate(
|
||||
builtin, [&] { return s.extractBits(builtin->ReturnType()); });
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kFirstLeadingBit:
|
||||
if (builtins.first_leading_bit) {
|
||||
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
|
||||
return s.firstLeadingBit(builtin->ReturnType());
|
||||
});
|
||||
polyfill = polyfills.GetOrCreate(
|
||||
builtin, [&] { return s.firstLeadingBit(builtin->ReturnType()); });
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kFirstTrailingBit:
|
||||
if (builtins.first_trailing_bit) {
|
||||
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
|
||||
return s.firstTrailingBit(builtin->ReturnType());
|
||||
});
|
||||
polyfill = polyfills.GetOrCreate(
|
||||
builtin, [&] { return s.firstTrailingBit(builtin->ReturnType()); });
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kInsertBits:
|
||||
if (builtins.insert_bits != Level::kNone) {
|
||||
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
|
||||
return s.insertBits(builtin->ReturnType());
|
||||
});
|
||||
polyfill = polyfills.GetOrCreate(
|
||||
builtin, [&] { return s.insertBits(builtin->ReturnType()); });
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kSaturate:
|
||||
if (builtins.saturate) {
|
||||
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
|
||||
return s.saturate(builtin->ReturnType());
|
||||
});
|
||||
polyfill = polyfills.GetOrCreate(
|
||||
builtin, [&] { return s.saturate(builtin->ReturnType()); });
|
||||
}
|
||||
break;
|
||||
case sem::BuiltinType::kTextureSampleBaseClampToEdge:
|
||||
|
@ -799,7 +706,7 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons
|
|||
auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
|
||||
if (auto* stex = tex->Type()->As<sem::SampledTexture>()) {
|
||||
if (stex->type()->Is<sem::F32>()) {
|
||||
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
|
||||
polyfill = polyfills.GetOrCreate(builtin, [&] {
|
||||
return s.textureSampleBaseClampToEdge_2d_f32();
|
||||
});
|
||||
}
|
||||
|
@ -809,8 +716,8 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons
|
|||
case sem::BuiltinType::kQuantizeToF16:
|
||||
if (builtins.quantize_to_vec_f16) {
|
||||
if (auto* vec = builtin->ReturnType()->As<sem::Vector>()) {
|
||||
polyfill = utils::GetOrCreate(polyfills, builtin,
|
||||
[&] { return s.quantizeToF16(vec); });
|
||||
polyfill = polyfills.GetOrCreate(
|
||||
builtin, [&] { return s.quantizeToF16(vec); });
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
@ -819,14 +726,20 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons
|
|||
break;
|
||||
}
|
||||
if (polyfill.IsValid()) {
|
||||
return s.b.Call(polyfill, ctx.Clone(call->Declaration()->args));
|
||||
auto* replacement = s.b.Call(polyfill, ctx.Clone(call->Declaration()->args));
|
||||
ctx.Replace(call->Declaration(), replacement);
|
||||
made_changes = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
if (!made_changes) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
BuiltinPolyfill::Config::Config(const Builtins& b) : builtins(b) {}
|
||||
|
|
|
@ -87,21 +87,13 @@ class BuiltinPolyfill final : public Castable<BuiltinPolyfill, Transform> {
|
|||
const Builtins builtins;
|
||||
};
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
protected:
|
||||
private:
|
||||
struct State;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -1561,7 +1561,8 @@ fn f() {
|
|||
TEST_F(BuiltinPolyfillTest, DISABLED_InsertBits_ConstantExpression) {
|
||||
auto* src = R"(
|
||||
fn f() {
|
||||
let r : i32 = insertBits(1234, 5678, 5u, 6u);
|
||||
let v = 1234i;
|
||||
let r : i32 = insertBits(v, 5678, 5u, 6u);
|
||||
}
|
||||
)";
|
||||
|
||||
|
@ -1975,10 +1976,6 @@ fn f() {
|
|||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
@group(0) @binding(0) var t : texture_2d<f32>;
|
||||
|
||||
@group(0) @binding(1) var s : sampler;
|
||||
|
||||
fn tint_textureSampleBaseClampToEdge(t : texture_2d<f32>, s : sampler, coord : vec2<f32>) -> vec4<f32> {
|
||||
let dims = vec2<f32>(textureDimensions(t, 0));
|
||||
let half_texel = (vec2<f32>(0.5) / dims);
|
||||
|
@ -1986,6 +1983,10 @@ fn tint_textureSampleBaseClampToEdge(t : texture_2d<f32>, s : sampler, coord : v
|
|||
return textureSampleLevel(t, s, clamped, 0);
|
||||
}
|
||||
|
||||
@group(0) @binding(0) var t : texture_2d<f32>;
|
||||
|
||||
@group(0) @binding(1) var s : sampler;
|
||||
|
||||
fn f() {
|
||||
let r = tint_textureSampleBaseClampToEdge(t, s, vec2<f32>(0.5));
|
||||
}
|
||||
|
|
|
@ -40,6 +40,19 @@ namespace tint::transform {
|
|||
|
||||
namespace {
|
||||
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (auto* sem_fn = program->Sem().Get(fn)) {
|
||||
for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
|
||||
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// ArrayUsage describes a runtime array usage.
|
||||
/// It is used as a key by the array_length_by_usage map.
|
||||
struct ArrayUsage {
|
||||
|
@ -73,21 +86,16 @@ const CalculateArrayLength::BufferSizeIntrinsic* CalculateArrayLength::BufferSiz
|
|||
CalculateArrayLength::CalculateArrayLength() = default;
|
||||
CalculateArrayLength::~CalculateArrayLength() = default;
|
||||
|
||||
bool CalculateArrayLength::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (auto* sem_fn = program->Sem().Get(fn)) {
|
||||
for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
|
||||
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
auto& sem = src->Sem();
|
||||
|
||||
// get_buffer_size_intrinsic() emits the function decorated with
|
||||
// BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
|
||||
|
@ -95,24 +103,20 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
|
|||
std::unordered_map<const sem::Reference*, Symbol> buffer_size_intrinsics;
|
||||
auto get_buffer_size_intrinsic = [&](const sem::Reference* buffer_type) {
|
||||
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
|
||||
auto name = ctx.dst->Sym();
|
||||
auto name = b.Sym();
|
||||
auto* type = CreateASTTypeFor(ctx, buffer_type);
|
||||
auto* disable_validation =
|
||||
ctx.dst->Disable(ast::DisabledValidation::kFunctionParameter);
|
||||
ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>(
|
||||
auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter);
|
||||
b.AST().AddFunction(b.create<ast::Function>(
|
||||
name,
|
||||
utils::Vector{
|
||||
ctx.dst->Param("buffer",
|
||||
ctx.dst->ty.pointer(type, buffer_type->AddressSpace(),
|
||||
buffer_type->Access()),
|
||||
utils::Vector{disable_validation}),
|
||||
ctx.dst->Param("result", ctx.dst->ty.pointer(ctx.dst->ty.u32(),
|
||||
ast::AddressSpace::kFunction)),
|
||||
b.Param("buffer",
|
||||
b.ty.pointer(type, buffer_type->AddressSpace(), buffer_type->Access()),
|
||||
utils::Vector{disable_validation}),
|
||||
b.Param("result", b.ty.pointer(b.ty.u32(), ast::AddressSpace::kFunction)),
|
||||
},
|
||||
ctx.dst->ty.void_(), nullptr,
|
||||
b.ty.void_(), nullptr,
|
||||
utils::Vector{
|
||||
ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID(),
|
||||
ctx.dst->AllocateNodeID()),
|
||||
b.ASTNodes().Create<BufferSizeIntrinsic>(b.ID(), b.AllocateNodeID()),
|
||||
},
|
||||
utils::Empty));
|
||||
|
||||
|
@ -123,7 +127,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
|
|||
std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher> array_length_by_usage;
|
||||
|
||||
// Find all the arrayLength() calls...
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
if (auto* call_expr = node->As<ast::CallExpression>()) {
|
||||
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
|
||||
|
@ -149,7 +153,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
|
|||
auto* arg = call_expr->args[0];
|
||||
auto* address_of = arg->As<ast::UnaryOpExpression>();
|
||||
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "arrayLength() expected address-of, got " << arg->TypeInfo().name;
|
||||
}
|
||||
auto* storage_buffer_expr = address_of->expr;
|
||||
|
@ -158,7 +162,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
|
|||
}
|
||||
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
|
||||
if (!storage_buffer_sem) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "expected form of arrayLength argument to be &array_var or "
|
||||
"&struct_var.array_member";
|
||||
break;
|
||||
|
@ -179,25 +183,24 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
|
|||
|
||||
// Construct the variable that'll hold the result of
|
||||
// RWByteAddressBuffer.GetDimensions()
|
||||
auto* buffer_size_result = ctx.dst->Decl(ctx.dst->Var(
|
||||
ctx.dst->Sym(), ctx.dst->ty.u32(), ctx.dst->Expr(0_u)));
|
||||
auto* buffer_size_result =
|
||||
b.Decl(b.Var(b.Sym(), b.ty.u32(), b.Expr(0_u)));
|
||||
|
||||
// Call storage_buffer.GetDimensions(&buffer_size_result)
|
||||
auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call(
|
||||
auto* call_get_dims = b.CallStmt(b.Call(
|
||||
// BufferSizeIntrinsic(X, ARGS...) is
|
||||
// translated to:
|
||||
// X.GetDimensions(ARGS..) by the writer
|
||||
buffer_size, ctx.dst->AddressOf(ctx.Clone(storage_buffer_expr)),
|
||||
ctx.dst->AddressOf(
|
||||
ctx.dst->Expr(buffer_size_result->variable->symbol))));
|
||||
buffer_size, b.AddressOf(ctx.Clone(storage_buffer_expr)),
|
||||
b.AddressOf(b.Expr(buffer_size_result->variable->symbol))));
|
||||
|
||||
// Calculate actual array length
|
||||
// total_storage_buffer_size - array_offset
|
||||
// array_length = ----------------------------------------
|
||||
// array_stride
|
||||
auto name = ctx.dst->Sym();
|
||||
auto name = b.Sym();
|
||||
const ast::Expression* total_size =
|
||||
ctx.dst->Expr(buffer_size_result->variable);
|
||||
b.Expr(buffer_size_result->variable);
|
||||
|
||||
const sem::Array* array_type = Switch(
|
||||
storage_buffer_type->StoreType(),
|
||||
|
@ -205,23 +208,21 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
|
|||
// The variable is a struct, so subtract the byte offset of
|
||||
// the array member.
|
||||
auto* array_member_sem = str->Members().back();
|
||||
total_size =
|
||||
ctx.dst->Sub(total_size, u32(array_member_sem->Offset()));
|
||||
total_size = b.Sub(total_size, u32(array_member_sem->Offset()));
|
||||
return array_member_sem->Type()->As<sem::Array>();
|
||||
},
|
||||
[&](const sem::Array* arr) { return arr; });
|
||||
|
||||
if (!array_type) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "expected form of arrayLength argument to be "
|
||||
"&array_var or &struct_var.array_member";
|
||||
return name;
|
||||
}
|
||||
|
||||
uint32_t array_stride = array_type->Size();
|
||||
auto* array_length_var = ctx.dst->Decl(
|
||||
ctx.dst->Let(name, ctx.dst->ty.u32(),
|
||||
ctx.dst->Div(total_size, u32(array_stride))));
|
||||
auto* array_length_var = b.Decl(
|
||||
b.Let(name, b.ty.u32(), b.Div(total_size, u32(array_stride))));
|
||||
|
||||
// Insert the array length calculations at the top of the block
|
||||
ctx.InsertBefore(block->statements, block->statements[0],
|
||||
|
@ -234,13 +235,14 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
|
|||
});
|
||||
|
||||
// Replace the call to arrayLength() with the array length variable
|
||||
ctx.Replace(call_expr, ctx.dst->Expr(array_length));
|
||||
ctx.Replace(call_expr, b.Expr(array_length));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -59,19 +59,10 @@ class CalculateArrayLength final : public Castable<CalculateArrayLength, Transfo
|
|||
/// Destructor
|
||||
~CalculateArrayLength() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -123,7 +123,7 @@ bool HasSampleMask(utils::VectorRef<const ast::Attribute*> attrs) {
|
|||
|
||||
} // namespace
|
||||
|
||||
/// State holds the current transform state for a single entry point.
|
||||
/// PIMPL state for the transform
|
||||
struct CanonicalizeEntryPointIO::State {
|
||||
/// OutputValue represents a shader result that the wrapper function produces.
|
||||
struct OutputValue {
|
||||
|
@ -770,17 +770,22 @@ struct CanonicalizeEntryPointIO::State {
|
|||
}
|
||||
};
|
||||
|
||||
void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
|
||||
Transform::ApplyResult CanonicalizeEntryPointIO::Apply(const Program* src,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
if (cfg == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
|
||||
return;
|
||||
b.Diagnostics().add_error(diag::System::Transform,
|
||||
"missing transform data for " + std::string(TypeInfo().name));
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
// Remove entry point IO attributes from struct declarations.
|
||||
// New structures will be created for each entry point, as necessary.
|
||||
for (auto* ty : ctx.src->AST().TypeDecls()) {
|
||||
for (auto* ty : src->AST().TypeDecls()) {
|
||||
if (auto* struct_ty = ty->As<ast::Struct>()) {
|
||||
for (auto* member : struct_ty->members) {
|
||||
for (auto* attr : member->attributes) {
|
||||
|
@ -792,7 +797,7 @@ void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, Dat
|
|||
}
|
||||
}
|
||||
|
||||
for (auto* func_ast : ctx.src->AST().Functions()) {
|
||||
for (auto* func_ast : src->AST().Functions()) {
|
||||
if (!func_ast->IsEntryPoint()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -802,6 +807,7 @@ void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, Dat
|
|||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,
|
||||
|
|
|
@ -127,15 +127,12 @@ class CanonicalizeEntryPointIO final : public Castable<CanonicalizeEntryPointIO,
|
|||
CanonicalizeEntryPointIO();
|
||||
~CanonicalizeEntryPointIO() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
struct State;
|
||||
};
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
#include "src/tint/transform/clamp_frag_depth.h"
|
||||
|
||||
#include <utility>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/ast/attribute.h"
|
||||
#include "src/tint/ast/builtin_attribute.h"
|
||||
|
@ -64,12 +64,7 @@ bool ReturnsFragDepthInStruct(const sem::Info& sem, const ast::Function* fn) {
|
|||
return false;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
ClampFragDepth::ClampFragDepth() = default;
|
||||
ClampFragDepth::~ClampFragDepth() = default;
|
||||
|
||||
bool ClampFragDepth::ShouldRun(const Program* program, const DataMap&) const {
|
||||
bool ShouldRun(const Program* program) {
|
||||
auto& sem = program->Sem();
|
||||
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
|
@ -82,22 +77,33 @@ bool ClampFragDepth::ShouldRun(const Program* program, const DataMap&) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
void ClampFragDepth::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
} // anonymous namespace
|
||||
|
||||
ClampFragDepth::ClampFragDepth() = default;
|
||||
ClampFragDepth::~ClampFragDepth() = default;
|
||||
|
||||
Transform::ApplyResult ClampFragDepth::Apply(const Program* src, const DataMap&, DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
// Abort on any use of push constants in the module.
|
||||
for (auto* global : ctx.src->AST().GlobalVariables()) {
|
||||
for (auto* global : src->AST().GlobalVariables()) {
|
||||
if (auto* var = global->As<ast::Var>()) {
|
||||
if (var->declared_address_space == ast::AddressSpace::kPushConstant) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "ClampFragDepth doesn't know how to handle module that already use push "
|
||||
"constants.";
|
||||
return;
|
||||
return Program(std::move(b));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto& b = *ctx.dst;
|
||||
auto& sem = ctx.src->Sem();
|
||||
auto& sym = ctx.src->Symbols();
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
auto& sem = src->Sem();
|
||||
auto& sym = src->Symbols();
|
||||
|
||||
// At least one entry-point needs clamping. Add the following to the module:
|
||||
//
|
||||
|
@ -197,6 +203,7 @@ void ClampFragDepth::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -61,19 +61,10 @@ class ClampFragDepth final : public Castable<ClampFragDepth, Transform> {
|
|||
/// Destructor
|
||||
~ClampFragDepth() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -47,10 +47,14 @@ CombineSamplers::BindingInfo::BindingInfo(const BindingMap& map,
|
|||
CombineSamplers::BindingInfo::BindingInfo(const BindingInfo& other) = default;
|
||||
CombineSamplers::BindingInfo::~BindingInfo() = default;
|
||||
|
||||
/// The PIMPL state for the CombineSamplers transform
|
||||
/// PIMPL state for the transform
|
||||
struct CombineSamplers::State {
|
||||
/// The source program
|
||||
const Program* const src;
|
||||
/// The target program builder
|
||||
ProgramBuilder b;
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
/// The binding info
|
||||
const BindingInfo* binding_info;
|
||||
|
@ -88,9 +92,9 @@ struct CombineSamplers::State {
|
|||
}
|
||||
|
||||
/// Constructor
|
||||
/// @param context the clone context
|
||||
/// @param program the source program
|
||||
/// @param info the binding map information
|
||||
State(CloneContext& context, const BindingInfo* info) : ctx(context), binding_info(info) {}
|
||||
State(const Program* program, const BindingInfo* info) : src(program), binding_info(info) {}
|
||||
|
||||
/// Creates a combined sampler global variables.
|
||||
/// (Note this is actually a Texture node at the AST level, but it will be
|
||||
|
@ -145,8 +149,9 @@ struct CombineSamplers::State {
|
|||
}
|
||||
}
|
||||
|
||||
/// Performs the transformation
|
||||
void Run() {
|
||||
/// Runs the transform
|
||||
/// @returns the new program or SkipTransform if the transform is not required
|
||||
ApplyResult Run() {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
// Remove all texture and sampler global variables. These will be replaced
|
||||
|
@ -169,14 +174,14 @@ struct CombineSamplers::State {
|
|||
|
||||
// Rewrite all function signatures to use combined samplers, and remove
|
||||
// separate textures & samplers. Create new combined globals where found.
|
||||
ctx.ReplaceAll([&](const ast::Function* src) -> const ast::Function* {
|
||||
if (auto* func = sem.Get(src)) {
|
||||
auto pairs = func->TextureSamplerPairs();
|
||||
ctx.ReplaceAll([&](const ast::Function* ast_fn) -> const ast::Function* {
|
||||
if (auto* fn = sem.Get(ast_fn)) {
|
||||
auto pairs = fn->TextureSamplerPairs();
|
||||
if (pairs.IsEmpty()) {
|
||||
return nullptr;
|
||||
}
|
||||
utils::Vector<const ast::Parameter*, 8> params;
|
||||
for (auto pair : func->TextureSamplerPairs()) {
|
||||
for (auto pair : fn->TextureSamplerPairs()) {
|
||||
const sem::Variable* texture_var = pair.first;
|
||||
const sem::Variable* sampler_var = pair.second;
|
||||
std::string name =
|
||||
|
@ -197,23 +202,23 @@ struct CombineSamplers::State {
|
|||
auto* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
|
||||
auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type);
|
||||
params.Push(var);
|
||||
function_combined_texture_samplers_[func][pair] = var;
|
||||
function_combined_texture_samplers_[fn][pair] = var;
|
||||
}
|
||||
}
|
||||
// Filter out separate textures and samplers from the original
|
||||
// function signature.
|
||||
for (auto* var : src->params) {
|
||||
if (!sem.Get(var->type)->IsAnyOf<sem::Texture, sem::Sampler>()) {
|
||||
params.Push(ctx.Clone(var));
|
||||
for (auto* param : fn->Parameters()) {
|
||||
if (!param->Type()->IsAnyOf<sem::Texture, sem::Sampler>()) {
|
||||
params.Push(ctx.Clone(param->Declaration()));
|
||||
}
|
||||
}
|
||||
// Create a new function signature that differs only in the parameter
|
||||
// list.
|
||||
auto symbol = ctx.Clone(src->symbol);
|
||||
auto* return_type = ctx.Clone(src->return_type);
|
||||
auto* body = ctx.Clone(src->body);
|
||||
auto attributes = ctx.Clone(src->attributes);
|
||||
auto return_type_attributes = ctx.Clone(src->return_type_attributes);
|
||||
auto symbol = ctx.Clone(ast_fn->symbol);
|
||||
auto* return_type = ctx.Clone(ast_fn->return_type);
|
||||
auto* body = ctx.Clone(ast_fn->body);
|
||||
auto attributes = ctx.Clone(ast_fn->attributes);
|
||||
auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes);
|
||||
return ctx.dst->create<ast::Function>(symbol, params, return_type, body,
|
||||
std::move(attributes),
|
||||
std::move(return_type_attributes));
|
||||
|
@ -327,6 +332,7 @@ struct CombineSamplers::State {
|
|||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -334,15 +340,18 @@ CombineSamplers::CombineSamplers() = default;
|
|||
|
||||
CombineSamplers::~CombineSamplers() = default;
|
||||
|
||||
void CombineSamplers::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
|
||||
Transform::ApplyResult CombineSamplers::Apply(const Program* src,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
auto* binding_info = inputs.Get<BindingInfo>();
|
||||
if (!binding_info) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
|
||||
return;
|
||||
ProgramBuilder b;
|
||||
b.Diagnostics().add_error(diag::System::Transform,
|
||||
"missing transform data for " + std::string(TypeInfo().name));
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
State(ctx, binding_info).Run();
|
||||
return State(src, binding_info).Run();
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -88,17 +88,13 @@ class CombineSamplers final : public Castable<CombineSamplers, Transform> {
|
|||
/// Destructor
|
||||
~CombineSamplers() override;
|
||||
|
||||
protected:
|
||||
/// The PIMPL state for this transform
|
||||
struct State;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
private:
|
||||
struct State;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -47,6 +47,18 @@ namespace tint::transform {
|
|||
|
||||
namespace {
|
||||
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* decl : program->AST().GlobalDeclarations()) {
|
||||
if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
|
||||
if (var->AddressSpace() == ast::AddressSpace::kStorage ||
|
||||
var->AddressSpace() == ast::AddressSpace::kUniform) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Offset is a simple ast::Expression builder interface, used to build byte
|
||||
/// offsets for storage and uniform buffer accesses.
|
||||
struct Offset : Castable<Offset> {
|
||||
|
@ -291,7 +303,7 @@ struct Store {
|
|||
|
||||
} // namespace
|
||||
|
||||
/// State holds the current transform state
|
||||
/// PIMPL state for the transform
|
||||
struct DecomposeMemoryAccess::State {
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
|
@ -477,7 +489,7 @@ struct DecomposeMemoryAccess::State {
|
|||
// * Override-expression counts can only be applied to workgroup arrays, and
|
||||
// this method only handles storage and uniform.
|
||||
// * Runtime-sized arrays are not loadable.
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unexpected non-constant array count";
|
||||
arr_cnt = 1;
|
||||
}
|
||||
|
@ -578,7 +590,7 @@ struct DecomposeMemoryAccess::State {
|
|||
// * Override-expression counts can only be applied to workgroup
|
||||
// arrays, and this method only handles storage and uniform.
|
||||
// * Runtime-sized arrays are not storable.
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unexpected non-constant array count";
|
||||
arr_cnt = 1;
|
||||
}
|
||||
|
@ -808,21 +820,16 @@ bool DecomposeMemoryAccess::Intrinsic::IsAtomic() const {
|
|||
DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
|
||||
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
|
||||
|
||||
bool DecomposeMemoryAccess::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* decl : program->AST().GlobalDeclarations()) {
|
||||
if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
|
||||
if (var->AddressSpace() == ast::AddressSpace::kStorage ||
|
||||
var->AddressSpace() == ast::AddressSpace::kUniform) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
|
||||
auto& sem = src->Sem();
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
State state(ctx);
|
||||
|
||||
// Scan the AST nodes for storage and uniform buffer accesses. Complex
|
||||
|
@ -833,7 +840,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con
|
|||
// Inner-most expression nodes are guaranteed to be visited first because AST
|
||||
// nodes are fully immutable and require their children to be constructed
|
||||
// first so their pointer can be passed to the parent's initializer.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
if (auto* ident = node->As<ast::IdentifierExpression>()) {
|
||||
// X
|
||||
if (auto* var = sem.Get<sem::VariableUser>(ident)) {
|
||||
|
@ -1001,6 +1008,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con
|
|||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -108,20 +108,12 @@ class DecomposeMemoryAccess final : public Castable<DecomposeMemoryAccess, Trans
|
|||
/// Destructor
|
||||
~DecomposeMemoryAccess() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
struct State;
|
||||
};
|
||||
|
||||
|
|
|
@ -34,13 +34,7 @@ namespace {
|
|||
|
||||
using DecomposedArrays = std::unordered_map<const sem::Array*, Symbol>;
|
||||
|
||||
} // namespace
|
||||
|
||||
DecomposeStridedArray::DecomposeStridedArray() = default;
|
||||
|
||||
DecomposeStridedArray::~DecomposeStridedArray() = default;
|
||||
|
||||
bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) const {
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* ast = node->As<ast::Array>()) {
|
||||
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
|
||||
|
@ -51,8 +45,22 @@ bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) co
|
|||
return false;
|
||||
}
|
||||
|
||||
void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
const auto& sem = ctx.src->Sem();
|
||||
} // namespace
|
||||
|
||||
DecomposeStridedArray::DecomposeStridedArray() = default;
|
||||
|
||||
DecomposeStridedArray::~DecomposeStridedArray() = default;
|
||||
|
||||
Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
const auto& sem = src->Sem();
|
||||
|
||||
static constexpr const char* kMemberName = "el";
|
||||
|
||||
|
@ -69,23 +77,23 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con
|
|||
if (auto* arr = sem.Get(ast)) {
|
||||
if (!arr->IsStrideImplicit()) {
|
||||
auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
|
||||
auto name = ctx.dst->Symbols().New("strided_arr");
|
||||
auto name = b.Symbols().New("strided_arr");
|
||||
auto* member_ty = ctx.Clone(ast->type);
|
||||
auto* member = ctx.dst->Member(kMemberName, member_ty,
|
||||
utils::Vector{
|
||||
ctx.dst->MemberSize(AInt(arr->Stride())),
|
||||
});
|
||||
ctx.dst->Structure(name, utils::Vector{member});
|
||||
auto* member = b.Member(kMemberName, member_ty,
|
||||
utils::Vector{
|
||||
b.MemberSize(AInt(arr->Stride())),
|
||||
});
|
||||
b.Structure(name, utils::Vector{member});
|
||||
return name;
|
||||
});
|
||||
auto* count = ctx.Clone(ast->count);
|
||||
return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count);
|
||||
return b.ty.array(b.ty.type_name(el_ty), count);
|
||||
}
|
||||
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
|
||||
// Strip the @stride attribute
|
||||
auto* ty = ctx.Clone(ast->type);
|
||||
auto* count = ctx.Clone(ast->count);
|
||||
return ctx.dst->ty.array(ty, count);
|
||||
return b.ty.array(ty, count);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -96,11 +104,11 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con
|
|||
// to insert an additional member accessor for the single structure field.
|
||||
// Example: `arr[i]` -> `arr[i].el`
|
||||
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* {
|
||||
if (auto* ty = ctx.src->TypeOf(idx->object)) {
|
||||
if (auto* ty = src->TypeOf(idx->object)) {
|
||||
if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
|
||||
if (!arr->IsStrideImplicit()) {
|
||||
auto* expr = ctx.CloneWithoutTransform(idx);
|
||||
return ctx.dst->MemberAccessor(expr, kMemberName);
|
||||
return b.MemberAccessor(expr, kMemberName);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -136,21 +144,23 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con
|
|||
if (auto it = decomposed.find(arr); it != decomposed.end()) {
|
||||
args.Reserve(expr->args.Length());
|
||||
for (auto* arg : expr->args) {
|
||||
args.Push(ctx.dst->Call(it->second, ctx.Clone(arg)));
|
||||
args.Push(b.Call(it->second, ctx.Clone(arg)));
|
||||
}
|
||||
} else {
|
||||
args = ctx.Clone(expr->args);
|
||||
}
|
||||
|
||||
return target.type ? ctx.dst->Construct(target.type, std::move(args))
|
||||
: ctx.dst->Call(target.name, std::move(args));
|
||||
return target.type ? b.Construct(target.type, std::move(args))
|
||||
: b.Call(target.name, std::move(args));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -35,19 +35,10 @@ class DecomposeStridedArray final : public Castable<DecomposeStridedArray, Trans
|
|||
/// Destructor
|
||||
~DecomposeStridedArray() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -53,24 +53,25 @@ struct MatrixInfo {
|
|||
};
|
||||
};
|
||||
|
||||
/// Return type of the callback function of GatherCustomStrideMatrixMembers
|
||||
enum GatherResult { kContinue, kStop };
|
||||
} // namespace
|
||||
|
||||
/// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
|
||||
/// storage and uniform structs, which are of a matrix type, and have a custom
|
||||
/// matrix stride attribute. For each matrix member found, `callback` is called.
|
||||
/// `callback` is a function with the signature:
|
||||
/// GatherResult(const sem::StructMember* member,
|
||||
/// sem::Matrix* matrix,
|
||||
/// uint32_t stride)
|
||||
/// If `callback` return GatherResult::kStop, then the scanning will immediately
|
||||
/// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
|
||||
/// scanning will continue.
|
||||
template <typename F>
|
||||
void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
|
||||
|
||||
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
|
||||
|
||||
Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
// Scan the program for all storage and uniform structure matrix members with
|
||||
// a custom stride attribute. Replace these matrices with an equivalent array,
|
||||
// and populate the `decomposed` map with the members that have been replaced.
|
||||
utils::Hashmap<const ast::StructMember*, MatrixInfo, 8> decomposed;
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
if (auto* str = node->As<ast::Struct>()) {
|
||||
auto* str_ty = program->Sem().Get(str);
|
||||
auto* str_ty = src->Sem().Get(str);
|
||||
if (!str_ty->UsedAs(ast::AddressSpace::kUniform) &&
|
||||
!str_ty->UsedAs(ast::AddressSpace::kStorage)) {
|
||||
continue;
|
||||
|
@ -89,46 +90,20 @@ void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
|
|||
if (matrix->ColumnStride() == stride) {
|
||||
continue;
|
||||
}
|
||||
if (callback(member, matrix, stride) == GatherResult::kStop) {
|
||||
return;
|
||||
}
|
||||
// We've got ourselves a struct member of a matrix type with a custom
|
||||
// stride. Replace this with an array of column vectors.
|
||||
MatrixInfo info{stride, matrix};
|
||||
auto* replacement =
|
||||
b.Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
|
||||
ctx.Replace(member->Declaration(), replacement);
|
||||
decomposed.Add(member->Declaration(), info);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
|
||||
|
||||
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
|
||||
|
||||
bool DecomposeStridedMatrix::ShouldRun(const Program* program, const DataMap&) const {
|
||||
bool should_run = false;
|
||||
GatherCustomStrideMatrixMembers(program,
|
||||
[&](const sem::StructMember*, const sem::Matrix*, uint32_t) {
|
||||
should_run = true;
|
||||
return GatherResult::kStop;
|
||||
});
|
||||
return should_run;
|
||||
}
|
||||
|
||||
void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
// Scan the program for all storage and uniform structure matrix members with
|
||||
// a custom stride attribute. Replace these matrices with an equivalent array,
|
||||
// and populate the `decomposed` map with the members that have been replaced.
|
||||
std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed;
|
||||
GatherCustomStrideMatrixMembers(
|
||||
ctx.src, [&](const sem::StructMember* member, const sem::Matrix* matrix, uint32_t stride) {
|
||||
// We've got ourselves a struct member of a matrix type with a custom
|
||||
// stride. Replace this with an array of column vectors.
|
||||
MatrixInfo info{stride, matrix};
|
||||
auto* replacement =
|
||||
ctx.dst->Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
|
||||
ctx.Replace(member->Declaration(), replacement);
|
||||
decomposed.emplace(member->Declaration(), info);
|
||||
return GatherResult::kContinue;
|
||||
});
|
||||
if (decomposed.IsEmpty()) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
// For all expressions where a single matrix column vector was indexed, we can
|
||||
// preserve these without calling conversion functions.
|
||||
|
@ -136,12 +111,11 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co
|
|||
// ssbo.mat[2] -> ssbo.mat[2]
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
|
||||
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it != decomposed.end()) {
|
||||
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
|
||||
if (decomposed.Contains(access->Member()->Declaration())) {
|
||||
auto* obj = ctx.CloneWithoutTransform(expr->object);
|
||||
auto* idx = ctx.Clone(expr->index);
|
||||
return ctx.dst->IndexAccessor(obj, idx);
|
||||
return b.IndexAccessor(obj, idx);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
|
@ -154,39 +128,36 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co
|
|||
// ssbo.mat = mat_to_arr(m)
|
||||
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
|
||||
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
|
||||
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it == decomposed.end()) {
|
||||
return nullptr;
|
||||
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
|
||||
if (auto* info = decomposed.Find(access->Member()->Declaration())) {
|
||||
auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
|
||||
auto name =
|
||||
b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
|
||||
std::to_string(info->matrix->rows()) + "_stride_" +
|
||||
std::to_string(info->stride) + "_to_arr");
|
||||
|
||||
auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
|
||||
auto array = [&] { return info->array(ctx.dst); };
|
||||
|
||||
auto mat = b.Sym("m");
|
||||
utils::Vector<const ast::Expression*, 4> columns;
|
||||
for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
|
||||
columns.Push(b.IndexAccessor(mat, u32(i)));
|
||||
}
|
||||
b.Func(name,
|
||||
utils::Vector{
|
||||
b.Param(mat, matrix()),
|
||||
},
|
||||
array(),
|
||||
utils::Vector{
|
||||
b.Return(b.Construct(array(), columns)),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
|
||||
auto* rhs = b.Call(fn, ctx.Clone(stmt->rhs));
|
||||
return b.Assign(lhs, rhs);
|
||||
}
|
||||
MatrixInfo info = it->second;
|
||||
auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
|
||||
auto name =
|
||||
ctx.dst->Symbols().New("mat" + std::to_string(info.matrix->columns()) + "x" +
|
||||
std::to_string(info.matrix->rows()) + "_stride_" +
|
||||
std::to_string(info.stride) + "_to_arr");
|
||||
|
||||
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
|
||||
auto array = [&] { return info.array(ctx.dst); };
|
||||
|
||||
auto mat = ctx.dst->Sym("m");
|
||||
utils::Vector<const ast::Expression*, 4> columns;
|
||||
for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) {
|
||||
columns.Push(ctx.dst->IndexAccessor(mat, u32(i)));
|
||||
}
|
||||
ctx.dst->Func(name,
|
||||
utils::Vector{
|
||||
ctx.dst->Param(mat, matrix()),
|
||||
},
|
||||
array(),
|
||||
utils::Vector{
|
||||
ctx.dst->Return(ctx.dst->Construct(array(), columns)),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
|
||||
auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs));
|
||||
return ctx.dst->Assign(lhs, rhs);
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
@ -196,41 +167,40 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co
|
|||
// m = arr_to_mat(ssbo.mat)
|
||||
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
|
||||
ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
|
||||
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
|
||||
auto it = decomposed.find(access->Member()->Declaration());
|
||||
if (it == decomposed.end()) {
|
||||
return nullptr;
|
||||
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr)) {
|
||||
if (auto* info = decomposed.Find(access->Member()->Declaration())) {
|
||||
auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
|
||||
auto name =
|
||||
b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +
|
||||
"x" + std::to_string(info->matrix->rows()) + "_stride_" +
|
||||
std::to_string(info->stride));
|
||||
|
||||
auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
|
||||
auto array = [&] { return info->array(ctx.dst); };
|
||||
|
||||
auto arr = b.Sym("arr");
|
||||
utils::Vector<const ast::Expression*, 4> columns;
|
||||
for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
|
||||
columns.Push(b.IndexAccessor(arr, u32(i)));
|
||||
}
|
||||
b.Func(name,
|
||||
utils::Vector{
|
||||
b.Param(arr, array()),
|
||||
},
|
||||
matrix(),
|
||||
utils::Vector{
|
||||
b.Return(b.Construct(matrix(), columns)),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
return b.Call(fn, ctx.CloneWithoutTransform(expr));
|
||||
}
|
||||
MatrixInfo info = it->second;
|
||||
auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
|
||||
auto name = ctx.dst->Symbols().New(
|
||||
"arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
|
||||
std::to_string(info.matrix->rows()) + "_stride_" + std::to_string(info.stride));
|
||||
|
||||
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
|
||||
auto array = [&] { return info.array(ctx.dst); };
|
||||
|
||||
auto arr = ctx.dst->Sym("arr");
|
||||
utils::Vector<const ast::Expression*, 4> columns;
|
||||
for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) {
|
||||
columns.Push(ctx.dst->IndexAccessor(arr, u32(i)));
|
||||
}
|
||||
ctx.dst->Func(name,
|
||||
utils::Vector{
|
||||
ctx.dst->Param(arr, array()),
|
||||
},
|
||||
matrix(),
|
||||
utils::Vector{
|
||||
ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -35,19 +35,10 @@ class DecomposeStridedMatrix final : public Castable<DecomposeStridedMatrix, Tra
|
|||
/// Destructor
|
||||
~DecomposeStridedMatrix() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -27,14 +27,20 @@ DisableUniformityAnalysis::DisableUniformityAnalysis() = default;
|
|||
|
||||
DisableUniformityAnalysis::~DisableUniformityAnalysis() = default;
|
||||
|
||||
bool DisableUniformityAnalysis::ShouldRun(const Program* program, const DataMap&) const {
|
||||
return !program->Sem().Module()->Extensions().Contains(
|
||||
ast::Extension::kChromiumDisableUniformityAnalysis);
|
||||
}
|
||||
Transform::ApplyResult DisableUniformityAnalysis::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
if (src->Sem().Module()->Extensions().Contains(
|
||||
ast::Extension::kChromiumDisableUniformityAnalysis)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
b.Enable(ast::Extension::kChromiumDisableUniformityAnalysis);
|
||||
|
||||
void DisableUniformityAnalysis::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
ctx.dst->Enable(ast::Extension::kChromiumDisableUniformityAnalysis);
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -27,19 +27,10 @@ class DisableUniformityAnalysis final : public Castable<DisableUniformityAnalysi
|
|||
/// Destructor
|
||||
~DisableUniformityAnalysis() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -31,11 +31,9 @@ using namespace tint::number_suffixes; // NOLINT
|
|||
|
||||
namespace tint::transform {
|
||||
|
||||
ExpandCompoundAssignment::ExpandCompoundAssignment() = default;
|
||||
namespace {
|
||||
|
||||
ExpandCompoundAssignment::~ExpandCompoundAssignment() = default;
|
||||
|
||||
bool ExpandCompoundAssignment::ShouldRun(const Program* program, const DataMap&) const {
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (node->IsAnyOf<ast::CompoundAssignmentStatement, ast::IncrementDecrementStatement>()) {
|
||||
return true;
|
||||
|
@ -44,21 +42,10 @@ bool ExpandCompoundAssignment::ShouldRun(const Program* program, const DataMap&)
|
|||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
} // namespace
|
||||
|
||||
/// Internal class used to collect statement expansions during the transform.
|
||||
class State {
|
||||
private:
|
||||
/// The clone context.
|
||||
CloneContext& ctx;
|
||||
|
||||
/// The program builder.
|
||||
ProgramBuilder& b;
|
||||
|
||||
/// The HoistToDeclBefore helper instance.
|
||||
HoistToDeclBefore hoist_to_decl_before;
|
||||
|
||||
public:
|
||||
/// PIMPL state for the transform
|
||||
struct ExpandCompoundAssignment::State {
|
||||
/// Constructor
|
||||
/// @param context the clone context
|
||||
explicit State(CloneContext& context) : ctx(context), b(*ctx.dst), hoist_to_decl_before(ctx) {}
|
||||
|
@ -158,15 +145,32 @@ class State {
|
|||
ctx.Replace(stmt, b.Assign(new_lhs(), value));
|
||||
}
|
||||
|
||||
/// Finalize the transformation and clone the module.
|
||||
void Finalize() { ctx.Clone(); }
|
||||
private:
|
||||
/// The clone context.
|
||||
CloneContext& ctx;
|
||||
|
||||
/// The program builder.
|
||||
ProgramBuilder& b;
|
||||
|
||||
/// The HoistToDeclBefore helper instance.
|
||||
HoistToDeclBefore hoist_to_decl_before;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
ExpandCompoundAssignment::ExpandCompoundAssignment() = default;
|
||||
|
||||
void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
ExpandCompoundAssignment::~ExpandCompoundAssignment() = default;
|
||||
|
||||
Transform::ApplyResult ExpandCompoundAssignment::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
State state(ctx);
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) {
|
||||
state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op);
|
||||
} else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) {
|
||||
|
@ -175,7 +179,9 @@ void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&)
|
|||
state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op);
|
||||
}
|
||||
}
|
||||
state.Finalize();
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -45,19 +45,13 @@ class ExpandCompoundAssignment final : public Castable<ExpandCompoundAssignment,
|
|||
/// Destructor
|
||||
~ExpandCompoundAssignment() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
private:
|
||||
struct State;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -35,6 +35,15 @@ namespace {
|
|||
constexpr char kFirstVertexName[] = "first_vertex_index";
|
||||
constexpr char kFirstInstanceName[] = "first_instance_index";
|
||||
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
FirstIndexOffset::BindingPoint::BindingPoint() = default;
|
||||
|
@ -49,16 +58,16 @@ FirstIndexOffset::Data::~Data() = default;
|
|||
FirstIndexOffset::FirstIndexOffset() = default;
|
||||
FirstIndexOffset::~FirstIndexOffset() = default;
|
||||
|
||||
bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
|
||||
return true;
|
||||
}
|
||||
Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const {
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
// Get the uniform buffer binding point
|
||||
uint32_t ub_binding = binding_;
|
||||
uint32_t ub_group = group_;
|
||||
|
@ -115,17 +124,17 @@ void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& ou
|
|||
if (has_vertex_or_instance_index) {
|
||||
// Add uniform buffer members and calculate byte offsets
|
||||
utils::Vector<const ast::StructMember*, 8> members;
|
||||
members.Push(ctx.dst->Member(kFirstVertexName, ctx.dst->ty.u32()));
|
||||
members.Push(ctx.dst->Member(kFirstInstanceName, ctx.dst->ty.u32()));
|
||||
auto* struct_ = ctx.dst->Structure(ctx.dst->Sym(), std::move(members));
|
||||
members.Push(b.Member(kFirstVertexName, b.ty.u32()));
|
||||
members.Push(b.Member(kFirstInstanceName, b.ty.u32()));
|
||||
auto* struct_ = b.Structure(b.Sym(), std::move(members));
|
||||
|
||||
// Create a global to hold the uniform buffer
|
||||
Symbol buffer_name = ctx.dst->Sym();
|
||||
ctx.dst->GlobalVar(buffer_name, ctx.dst->ty.Of(struct_), ast::AddressSpace::kUniform,
|
||||
utils::Vector{
|
||||
ctx.dst->Binding(AInt(ub_binding)),
|
||||
ctx.dst->Group(AInt(ub_group)),
|
||||
});
|
||||
Symbol buffer_name = b.Sym();
|
||||
b.GlobalVar(buffer_name, b.ty.Of(struct_), ast::AddressSpace::kUniform,
|
||||
utils::Vector{
|
||||
b.Binding(AInt(ub_binding)),
|
||||
b.Group(AInt(ub_group)),
|
||||
});
|
||||
|
||||
// Fix up all references to the builtins with the offsets
|
||||
ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* {
|
||||
|
@ -150,9 +159,10 @@ void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& ou
|
|||
});
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
|
||||
outputs.Add<Data>(has_vertex_or_instance_index);
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -103,19 +103,10 @@ class FirstIndexOffset final : public Castable<FirstIndexOffset, Transform> {
|
|||
/// Destructor
|
||||
~FirstIndexOffset() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
uint32_t binding_ = 0;
|
||||
|
|
|
@ -14,17 +14,17 @@
|
|||
|
||||
#include "src/tint/transform/for_loop_to_loop.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/ast/break_statement.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::ForLoopToLoop);
|
||||
|
||||
namespace tint::transform {
|
||||
ForLoopToLoop::ForLoopToLoop() = default;
|
||||
namespace {
|
||||
|
||||
ForLoopToLoop::~ForLoopToLoop() = default;
|
||||
|
||||
bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const {
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (node->Is<ast::ForLoopStatement>()) {
|
||||
return true;
|
||||
|
@ -33,19 +33,31 @@ bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
} // namespace
|
||||
|
||||
ForLoopToLoop::ForLoopToLoop() = default;
|
||||
|
||||
ForLoopToLoop::~ForLoopToLoop() = default;
|
||||
|
||||
Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&, DataMap&) const {
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
ctx.ReplaceAll([&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {
|
||||
utils::Vector<const ast::Statement*, 8> stmts;
|
||||
if (auto* cond = for_loop->condition) {
|
||||
// !condition
|
||||
auto* not_cond =
|
||||
ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
|
||||
auto* not_cond = b.Not(ctx.Clone(cond));
|
||||
|
||||
// { break; }
|
||||
auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
|
||||
auto* break_body = b.Block(b.Break());
|
||||
|
||||
// if (!condition) { break; }
|
||||
stmts.Push(ctx.dst->If(not_cond, break_body));
|
||||
stmts.Push(b.If(not_cond, break_body));
|
||||
}
|
||||
for (auto* stmt : for_loop->body->statements) {
|
||||
stmts.Push(ctx.Clone(stmt));
|
||||
|
@ -53,20 +65,21 @@ void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
|
||||
const ast::BlockStatement* continuing = nullptr;
|
||||
if (auto* cont = for_loop->continuing) {
|
||||
continuing = ctx.dst->Block(ctx.Clone(cont));
|
||||
continuing = b.Block(ctx.Clone(cont));
|
||||
}
|
||||
|
||||
auto* body = ctx.dst->Block(stmts);
|
||||
auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
|
||||
auto* body = b.Block(stmts);
|
||||
auto* loop = b.Loop(body, continuing);
|
||||
|
||||
if (auto* init = for_loop->initializer) {
|
||||
return ctx.dst->Block(ctx.Clone(init), loop);
|
||||
return b.Block(ctx.Clone(init), loop);
|
||||
}
|
||||
|
||||
return loop;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -29,19 +29,10 @@ class ForLoopToLoop final : public Castable<ForLoopToLoop, Transform> {
|
|||
/// Destructor
|
||||
~ForLoopToLoop() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -32,70 +32,15 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::LocalizeStructArrayAssignment);
|
|||
|
||||
namespace tint::transform {
|
||||
|
||||
/// Private implementation of LocalizeStructArrayAssignment transform
|
||||
class LocalizeStructArrayAssignment::State {
|
||||
private:
|
||||
CloneContext& ctx;
|
||||
ProgramBuilder& b;
|
||||
|
||||
/// Returns true if `expr` contains an index accessor expression to a
|
||||
/// structure member of array type.
|
||||
bool ContainsStructArrayIndex(const ast::Expression* expr) {
|
||||
bool result = false;
|
||||
ast::TraverseExpressions(
|
||||
expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
|
||||
// Indexing using a runtime value?
|
||||
auto* idx_sem = ctx.src->Sem().Get(ia->index);
|
||||
if (!idx_sem->ConstantValue()) {
|
||||
// Indexing a member access expr?
|
||||
if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
|
||||
// That accesses an array?
|
||||
if (ctx.src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) {
|
||||
result = true;
|
||||
return ast::TraverseAction::Stop;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ast::TraverseAction::Descend;
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the type and address space of the originating variable of the lhs
|
||||
// of the assignment statement.
|
||||
// See https://www.w3.org/TR/WGSL/#originating-variable-section
|
||||
std::pair<const sem::Type*, ast::AddressSpace> GetOriginatingTypeAndAddressSpace(
|
||||
const ast::AssignmentStatement* assign_stmt) {
|
||||
auto* source_var = ctx.src->Sem().Get(assign_stmt->lhs)->SourceVariable();
|
||||
if (!source_var) {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "Unable to determine originating variable for lhs of assignment "
|
||||
"statement";
|
||||
return {};
|
||||
}
|
||||
|
||||
auto* type = source_var->Type();
|
||||
if (auto* ref = type->As<sem::Reference>()) {
|
||||
return {ref->StoreType(), ref->AddressSpace()};
|
||||
} else if (auto* ptr = type->As<sem::Pointer>()) {
|
||||
return {ptr->StoreType(), ptr->AddressSpace()};
|
||||
}
|
||||
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "Expecting to find variable of type pointer or reference on lhs "
|
||||
"of assignment statement";
|
||||
return {};
|
||||
}
|
||||
|
||||
public:
|
||||
/// PIMPL state for the transform
|
||||
struct LocalizeStructArrayAssignment::State {
|
||||
/// Constructor
|
||||
/// @param ctx_in the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
|
||||
/// @param program the source program
|
||||
explicit State(const Program* program) : src(program) {}
|
||||
|
||||
/// Runs the transform
|
||||
void Run() {
|
||||
/// @returns the new program or SkipTransform if the transform is not required
|
||||
ApplyResult Run() {
|
||||
struct Shared {
|
||||
bool process_nested_nodes = false;
|
||||
utils::Vector<const ast::Statement*, 4> insert_before_stmts;
|
||||
|
@ -189,6 +134,65 @@ class LocalizeStructArrayAssignment::State {
|
|||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
private:
|
||||
/// The source program
|
||||
const Program* const src;
|
||||
/// The target program builder
|
||||
ProgramBuilder b;
|
||||
/// The clone context
|
||||
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
/// Returns true if `expr` contains an index accessor expression to a
|
||||
/// structure member of array type.
|
||||
bool ContainsStructArrayIndex(const ast::Expression* expr) {
|
||||
bool result = false;
|
||||
ast::TraverseExpressions(
|
||||
expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
|
||||
// Indexing using a runtime value?
|
||||
auto* idx_sem = src->Sem().Get(ia->index);
|
||||
if (!idx_sem->ConstantValue()) {
|
||||
// Indexing a member access expr?
|
||||
if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
|
||||
// That accesses an array?
|
||||
if (src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) {
|
||||
result = true;
|
||||
return ast::TraverseAction::Stop;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ast::TraverseAction::Descend;
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the type and address space of the originating variable of the lhs
|
||||
// of the assignment statement.
|
||||
// See https://www.w3.org/TR/WGSL/#originating-variable-section
|
||||
std::pair<const sem::Type*, ast::AddressSpace> GetOriginatingTypeAndAddressSpace(
|
||||
const ast::AssignmentStatement* assign_stmt) {
|
||||
auto* source_var = src->Sem().Get(assign_stmt->lhs)->SourceVariable();
|
||||
if (!source_var) {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "Unable to determine originating variable for lhs of assignment "
|
||||
"statement";
|
||||
return {};
|
||||
}
|
||||
|
||||
auto* type = source_var->Type();
|
||||
if (auto* ref = type->As<sem::Reference>()) {
|
||||
return {ref->StoreType(), ref->AddressSpace()};
|
||||
} else if (auto* ptr = type->As<sem::Pointer>()) {
|
||||
return {ptr->StoreType(), ptr->AddressSpace()};
|
||||
}
|
||||
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "Expecting to find variable of type pointer or reference on lhs "
|
||||
"of assignment statement";
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -196,9 +200,10 @@ LocalizeStructArrayAssignment::LocalizeStructArrayAssignment() = default;
|
|||
|
||||
LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default;
|
||||
|
||||
void LocalizeStructArrayAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
State state(ctx);
|
||||
state.Run();
|
||||
Transform::ApplyResult LocalizeStructArrayAssignment::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
return State{src}.Run();
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -36,17 +36,13 @@ class LocalizeStructArrayAssignment final
|
|||
/// Destructor
|
||||
~LocalizeStructArrayAssignment() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
class State;
|
||||
struct State;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -31,9 +31,9 @@ namespace tint::transform {
|
|||
Manager::Manager() = default;
|
||||
Manager::~Manager() = default;
|
||||
|
||||
Output Manager::Run(const Program* program, const DataMap& data) const {
|
||||
const Program* in = program;
|
||||
|
||||
Transform::ApplyResult Manager::Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const {
|
||||
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
|
||||
auto print_program = [&](const char* msg, const Transform* transform) {
|
||||
auto wgsl = Program::printer(in);
|
||||
|
@ -46,34 +46,30 @@ Output Manager::Run(const Program* program, const DataMap& data) const {
|
|||
};
|
||||
#endif
|
||||
|
||||
Output out;
|
||||
std::optional<Program> output;
|
||||
|
||||
for (const auto& transform : transforms_) {
|
||||
if (!transform->ShouldRun(in, data)) {
|
||||
TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " << transform->TypeInfo().name
|
||||
<< std::endl);
|
||||
continue;
|
||||
}
|
||||
TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get()));
|
||||
|
||||
auto res = transform->Run(in, data);
|
||||
out.program = std::move(res.program);
|
||||
out.data.Add(std::move(res.data));
|
||||
in = &out.program;
|
||||
if (!in->IsValid()) {
|
||||
TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get()));
|
||||
return out;
|
||||
}
|
||||
if (auto result = transform->Apply(program, inputs, outputs)) {
|
||||
output.emplace(std::move(result.value()));
|
||||
program = &output.value();
|
||||
|
||||
if (transform == transforms_.back()) {
|
||||
TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
|
||||
if (!program->IsValid()) {
|
||||
TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get()));
|
||||
break;
|
||||
}
|
||||
|
||||
if (transform == transforms_.back()) {
|
||||
TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
|
||||
}
|
||||
} else {
|
||||
TINT_IF_PRINT_PROGRAM(std::cout << "Skipped " << transform->TypeInfo().name
|
||||
<< std::endl);
|
||||
}
|
||||
}
|
||||
|
||||
if (program == in) {
|
||||
out.program = program->Clone();
|
||||
}
|
||||
|
||||
return out;
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -47,11 +47,10 @@ class Manager final : public Castable<Manager, Transform> {
|
|||
transforms_.emplace_back(std::make_unique<T>(std::forward<ARGS>(args)...));
|
||||
}
|
||||
|
||||
/// Runs the transforms on `program`, returning the transformation result.
|
||||
/// @param program the source program to transform
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns the transformed program and diagnostics
|
||||
Output Run(const Program* program, const DataMap& data = {}) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<Transform>> transforms_;
|
||||
|
|
|
@ -65,15 +65,6 @@ MergeReturn::MergeReturn() = default;
|
|||
|
||||
MergeReturn::~MergeReturn() = default;
|
||||
|
||||
bool MergeReturn::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* func : program->AST().Functions()) {
|
||||
if (NeedsTransform(program, func)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Internal class used to during the transform.
|
||||
|
@ -223,7 +214,12 @@ class State {
|
|||
|
||||
} // namespace
|
||||
|
||||
void MergeReturn::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
Transform::ApplyResult MergeReturn::Apply(const Program* src, const DataMap&, DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
bool made_changes = false;
|
||||
|
||||
for (auto* func : ctx.src->AST().Functions()) {
|
||||
if (!NeedsTransform(ctx.src, func)) {
|
||||
continue;
|
||||
|
@ -231,9 +227,15 @@ void MergeReturn::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
|
||||
State state(ctx, func);
|
||||
state.ProcessStatement(func->body);
|
||||
made_changes = true;
|
||||
}
|
||||
|
||||
if (!made_changes) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -27,19 +27,10 @@ class MergeReturn final : public Castable<MergeReturn, Transform> {
|
|||
/// Destructor
|
||||
~MergeReturn() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -38,6 +38,15 @@ using WorkgroupParameterMemberList = utils::Vector<const ast::StructMember*, 8>;
|
|||
// The name of the struct member for arrays that are wrapped in structures.
|
||||
const char* kWrappedArrayMemberName = "arr";
|
||||
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* decl : program->AST().GlobalDeclarations()) {
|
||||
if (decl->Is<ast::Variable>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns `true` if `type` is or contains a matrix type.
|
||||
bool ContainsMatrix(const sem::Type* type) {
|
||||
type = type->UnwrapRef();
|
||||
|
@ -56,7 +65,7 @@ bool ContainsMatrix(const sem::Type* type) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
/// State holds the current transform state.
|
||||
/// PIMPL state for the transform
|
||||
struct ModuleScopeVarToEntryPointParam::State {
|
||||
/// The clone context.
|
||||
CloneContext& ctx;
|
||||
|
@ -501,19 +510,20 @@ ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
|
|||
|
||||
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
|
||||
|
||||
bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* decl : program->AST().GlobalDeclarations()) {
|
||||
if (decl->Is<ast::Variable>()) {
|
||||
return true;
|
||||
}
|
||||
Transform::ApplyResult ModuleScopeVarToEntryPointParam::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
State state{ctx};
|
||||
state.Process();
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -69,20 +69,12 @@ class ModuleScopeVarToEntryPointParam final
|
|||
/// Destructor
|
||||
~ModuleScopeVarToEntryPointParam() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
struct State;
|
||||
};
|
||||
|
||||
|
|
|
@ -31,6 +31,17 @@ using namespace tint::number_suffixes; // NOLINT
|
|||
namespace tint::transform {
|
||||
namespace {
|
||||
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* ty = node->As<ast::Type>()) {
|
||||
if (program->Sem().Get<sem::ExternalTexture>(ty)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// This struct stores symbols for new bindings created as a result of transforming a
|
||||
/// texture_external instance.
|
||||
struct NewBindingSymbols {
|
||||
|
@ -40,7 +51,7 @@ struct NewBindingSymbols {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
/// State holds the current transform state
|
||||
/// PIMPL state for the transform
|
||||
struct MultiplanarExternalTexture::State {
|
||||
/// The clone context.
|
||||
CloneContext& ctx;
|
||||
|
@ -537,30 +548,26 @@ MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default;
|
|||
MultiplanarExternalTexture::MultiplanarExternalTexture() = default;
|
||||
MultiplanarExternalTexture::~MultiplanarExternalTexture() = default;
|
||||
|
||||
bool MultiplanarExternalTexture::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* ty = node->As<ast::Type>()) {
|
||||
if (program->Sem().Get<sem::ExternalTexture>(ty)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Within this transform, an instance of a texture_external binding is unpacked into two
|
||||
// texture_2d<f32> bindings representing two possible planes of a single texture and a uniform
|
||||
// buffer binding representing a struct of parameters. Calls to texture builtins that contain a
|
||||
// texture_external parameter will be transformed into a newly generated version of the function,
|
||||
// which can perform the desired operation on a single RGBA plane or on separate Y and UV planes.
|
||||
void MultiplanarExternalTexture::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
|
||||
Transform::ApplyResult MultiplanarExternalTexture::Apply(const Program* src,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
auto* new_binding_points = inputs.Get<NewBindingPoints>();
|
||||
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
if (!new_binding_points) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"missing new binding point data for " + std::string(TypeInfo().name));
|
||||
return;
|
||||
b.Diagnostics().add_error(diag::System::Transform, "missing new binding point data for " +
|
||||
std::string(TypeInfo().name));
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
State state(ctx, new_binding_points);
|
||||
|
@ -568,6 +575,7 @@ void MultiplanarExternalTexture::Run(CloneContext& ctx, const DataMap& inputs, D
|
|||
state.Process();
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -80,21 +80,13 @@ class MultiplanarExternalTexture final : public Castable<MultiplanarExternalText
|
|||
/// Destructor
|
||||
~MultiplanarExternalTexture() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
protected:
|
||||
private:
|
||||
struct State;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -23,7 +23,11 @@ using MultiplanarExternalTextureTest = TransformTest;
|
|||
TEST_F(MultiplanarExternalTextureTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src));
|
||||
DataMap data;
|
||||
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
|
||||
MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
|
||||
|
||||
EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) {
|
||||
|
@ -31,14 +35,22 @@ TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) {
|
|||
type ET = texture_external;
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
|
||||
DataMap data;
|
||||
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
|
||||
MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
|
||||
|
||||
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data));
|
||||
}
|
||||
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) {
|
||||
auto* src = R"(
|
||||
@group(0) @binding(0) var ext_tex : texture_external;
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
|
||||
DataMap data;
|
||||
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
|
||||
MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
|
||||
|
||||
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) {
|
||||
|
@ -46,7 +58,11 @@ TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) {
|
|||
fn f(ext_tex : texture_external) {}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
|
||||
DataMap data;
|
||||
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
|
||||
MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
|
||||
|
||||
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data));
|
||||
}
|
||||
|
||||
// Running the transform without passing in data for the new bindings should result in an error.
|
||||
|
|
|
@ -29,6 +29,18 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform::Config);
|
|||
|
||||
namespace tint::transform {
|
||||
namespace {
|
||||
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* attr = node->As<ast::BuiltinAttribute>()) {
|
||||
if (attr->builtin == ast::BuiltinValue::kNumWorkgroups) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Accessor describes the identifiers used in a member accessor that is being
|
||||
/// used to retrieve the num_workgroups builtin from a parameter.
|
||||
struct Accessor {
|
||||
|
@ -44,41 +56,40 @@ struct Accessor {
|
|||
size_t operator()(const Accessor& a) const { return utils::Hash(a.param, a.member); }
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
|
||||
NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
|
||||
|
||||
bool NumWorkgroupsFromUniform::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* attr = node->As<ast::BuiltinAttribute>()) {
|
||||
if (attr->builtin == ast::BuiltinValue::kNumWorkgroups) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
if (cfg == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
|
||||
return;
|
||||
b.Diagnostics().add_error(diag::System::Transform,
|
||||
"missing transform data for " + std::string(TypeInfo().name));
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
const char* kNumWorkgroupsMemberName = "num_workgroups";
|
||||
|
||||
// Find all entry point parameters that declare the num_workgroups builtin.
|
||||
std::unordered_set<Accessor, Accessor::Hasher> to_replace;
|
||||
for (auto* func : ctx.src->AST().Functions()) {
|
||||
for (auto* func : src->AST().Functions()) {
|
||||
// num_workgroups is only valid for compute stages.
|
||||
if (func->PipelineStage() != ast::PipelineStage::kCompute) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto* param : ctx.src->Sem().Get(func)->Parameters()) {
|
||||
for (auto* param : src->Sem().Get(func)->Parameters()) {
|
||||
// Because the CanonicalizeEntryPointIO transform has been run, builtins
|
||||
// will only appear as struct members.
|
||||
auto* str = param->Type()->As<sem::Struct>();
|
||||
|
@ -108,7 +119,7 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat
|
|||
// If this is the only member, remove the struct and parameter too.
|
||||
if (str->Members().size() == 1) {
|
||||
ctx.Remove(func->params, param->Declaration());
|
||||
ctx.Remove(ctx.src->AST().GlobalDeclarations(), str->Declaration());
|
||||
ctx.Remove(src->AST().GlobalDeclarations(), str->Declaration());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -119,11 +130,10 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat
|
|||
const ast::Variable* num_workgroups_ubo = nullptr;
|
||||
auto get_ubo = [&]() {
|
||||
if (!num_workgroups_ubo) {
|
||||
auto* num_workgroups_struct = ctx.dst->Structure(
|
||||
ctx.dst->Sym(),
|
||||
utils::Vector{
|
||||
ctx.dst->Member(kNumWorkgroupsMemberName, ctx.dst->ty.vec3(ctx.dst->ty.u32())),
|
||||
});
|
||||
auto* num_workgroups_struct =
|
||||
b.Structure(b.Sym(), utils::Vector{
|
||||
b.Member(kNumWorkgroupsMemberName, b.ty.vec3(b.ty.u32())),
|
||||
});
|
||||
|
||||
uint32_t group, binding;
|
||||
if (cfg->ubo_binding.has_value()) {
|
||||
|
@ -135,9 +145,9 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat
|
|||
// plus 1, or group 0 if no resource bound.
|
||||
group = 0;
|
||||
|
||||
for (auto* global : ctx.src->AST().GlobalVariables()) {
|
||||
for (auto* global : src->AST().GlobalVariables()) {
|
||||
if (global->HasBindingPoint()) {
|
||||
auto* global_sem = ctx.src->Sem().Get<sem::GlobalVariable>(global);
|
||||
auto* global_sem = src->Sem().Get<sem::GlobalVariable>(global);
|
||||
auto binding_point = global_sem->BindingPoint();
|
||||
if (binding_point.group >= group) {
|
||||
group = binding_point.group + 1;
|
||||
|
@ -148,16 +158,16 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat
|
|||
binding = 0;
|
||||
}
|
||||
|
||||
num_workgroups_ubo = ctx.dst->GlobalVar(
|
||||
ctx.dst->Sym(), ctx.dst->ty.Of(num_workgroups_struct), ast::AddressSpace::kUniform,
|
||||
ctx.dst->Group(AInt(group)), ctx.dst->Binding(AInt(binding)));
|
||||
num_workgroups_ubo =
|
||||
b.GlobalVar(b.Sym(), b.ty.Of(num_workgroups_struct), ast::AddressSpace::kUniform,
|
||||
b.Group(AInt(group)), b.Binding(AInt(binding)));
|
||||
}
|
||||
return num_workgroups_ubo;
|
||||
};
|
||||
|
||||
// Now replace all the places where the builtins are accessed with the value
|
||||
// loaded from the uniform buffer.
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
auto* accessor = node->As<ast::MemberAccessorExpression>();
|
||||
if (!accessor) {
|
||||
continue;
|
||||
|
@ -168,12 +178,12 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat
|
|||
}
|
||||
|
||||
if (to_replace.count({ident->symbol, accessor->member->symbol})) {
|
||||
ctx.Replace(accessor,
|
||||
ctx.dst->MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName));
|
||||
ctx.Replace(accessor, b.MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName));
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
NumWorkgroupsFromUniform::Config::Config(std::optional<sem::BindingPoint> ubo_bp)
|
||||
|
|
|
@ -72,19 +72,10 @@ class NumWorkgroupsFromUniform final : public Castable<NumWorkgroupsFromUniform,
|
|||
std::optional<sem::BindingPoint> ubo_binding;
|
||||
};
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -28,7 +28,9 @@ using NumWorkgroupsFromUniformTest = TransformTest;
|
|||
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) {
|
||||
auto* src = R"()";
|
||||
|
||||
EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src));
|
||||
DataMap data;
|
||||
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
|
||||
EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) {
|
||||
|
@ -38,7 +40,9 @@ fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
|
|||
}
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src));
|
||||
DataMap data;
|
||||
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
|
||||
EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src, data));
|
||||
}
|
||||
|
||||
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
|
||||
|
@ -55,7 +59,6 @@ fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
|
|||
DataMap data;
|
||||
data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
|
||||
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
|
|
|
@ -33,14 +33,15 @@ using namespace tint::number_suffixes; // NOLINT
|
|||
|
||||
namespace tint::transform {
|
||||
|
||||
/// The PIMPL state for the PackedVec3 transform
|
||||
/// PIMPL state for the transform
|
||||
struct PackedVec3::State {
|
||||
/// Constructor
|
||||
/// @param c the CloneContext
|
||||
explicit State(CloneContext& c) : ctx(c) {}
|
||||
/// @param program the source program
|
||||
explicit State(const Program* program) : src(program) {}
|
||||
|
||||
/// Runs the transform
|
||||
void Run() {
|
||||
/// @returns the new program or SkipTransform if the transform is not required
|
||||
ApplyResult Run() {
|
||||
// Packed vec3<T> struct members
|
||||
utils::Hashset<const sem::StructMember*, 8> members;
|
||||
|
||||
|
@ -72,6 +73,10 @@ struct PackedVec3::State {
|
|||
}
|
||||
}
|
||||
|
||||
if (members.IsEmpty()) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
// Walk the nodes, starting with the most deeply nested, finding all the AST expressions
|
||||
// that load a whole packed vector (not a scalar / swizzle of the vector).
|
||||
utils::Hashset<const sem::Expression*, 16> refs;
|
||||
|
@ -137,36 +142,20 @@ struct PackedVec3::State {
|
|||
}
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
|
||||
/// @returns true if this transform should be run for the given program
|
||||
/// @param program the program to inspect
|
||||
static bool ShouldRun(const Program* program) {
|
||||
for (auto* decl : program->AST().GlobalDeclarations()) {
|
||||
if (auto* str = program->Sem().Get<sem::Struct>(decl)) {
|
||||
if (str->IsHostShareable()) {
|
||||
for (auto* member : str->Members()) {
|
||||
if (auto* vec = member->Type()->As<sem::Vector>()) {
|
||||
if (vec->Width() == 3) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
private:
|
||||
/// The source program
|
||||
const Program* const src;
|
||||
/// The target program builder
|
||||
ProgramBuilder b;
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
||||
/// Alias to the semantic info in ctx.src
|
||||
const sem::Info& sem = ctx.src->Sem();
|
||||
/// Alias to the symbols in ctx.src
|
||||
const SymbolTable& sym = ctx.src->Symbols();
|
||||
/// Alias to the ctx.dst program builder
|
||||
ProgramBuilder& b = *ctx.dst;
|
||||
};
|
||||
|
||||
PackedVec3::Attribute::Attribute(ProgramID pid, ast::NodeID nid) : Base(pid, nid) {}
|
||||
|
@ -183,12 +172,8 @@ std::string PackedVec3::Attribute::InternalName() const {
|
|||
PackedVec3::PackedVec3() = default;
|
||||
PackedVec3::~PackedVec3() = default;
|
||||
|
||||
bool PackedVec3::ShouldRun(const Program* program, const DataMap&) const {
|
||||
return State::ShouldRun(program);
|
||||
}
|
||||
|
||||
void PackedVec3::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
State(ctx).Run();
|
||||
Transform::ApplyResult PackedVec3::Apply(const Program* src, const DataMap&, DataMap&) const {
|
||||
return State{src}.Run();
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -56,21 +56,13 @@ class PackedVec3 final : public Castable<PackedVec3, Transform> {
|
|||
/// Destructor
|
||||
~PackedVec3() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
struct State;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -50,8 +50,10 @@ PadStructs::PadStructs() = default;
|
|||
|
||||
PadStructs::~PadStructs() = default;
|
||||
|
||||
void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
auto& sem = src->Sem();
|
||||
|
||||
std::unordered_map<const ast::Struct*, const ast::Struct*> replaced_structs;
|
||||
utils::Hashset<const ast::StructMember*, 8> padding_members;
|
||||
|
@ -65,7 +67,7 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
bool has_runtime_sized_array = false;
|
||||
utils::Vector<const ast::StructMember*, 8> new_members;
|
||||
for (auto* mem : str->Members()) {
|
||||
auto name = ctx.src->Symbols().NameFor(mem->Name());
|
||||
auto name = src->Symbols().NameFor(mem->Name());
|
||||
|
||||
if (offset < mem->Offset()) {
|
||||
CreatePadding(&new_members, &padding_members, ctx.dst, mem->Offset() - offset);
|
||||
|
@ -75,7 +77,7 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
auto* ty = mem->Type();
|
||||
const ast::Type* type = CreateASTTypeFor(ctx, ty);
|
||||
|
||||
new_members.Push(ctx.dst->Member(name, type));
|
||||
new_members.Push(b.Member(name, type));
|
||||
|
||||
uint32_t size = ty->Size();
|
||||
if (ty->Is<sem::Struct>() && str->UsedAs(ast::AddressSpace::kUniform)) {
|
||||
|
@ -97,8 +99,8 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
if (offset < struct_size && !has_runtime_sized_array) {
|
||||
CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset);
|
||||
}
|
||||
auto* new_struct = ctx.dst->create<ast::Struct>(ctx.Clone(ast_str->name),
|
||||
std::move(new_members), utils::Empty);
|
||||
auto* new_struct =
|
||||
b.create<ast::Struct>(ctx.Clone(ast_str->name), std::move(new_members), utils::Empty);
|
||||
replaced_structs[ast_str] = new_struct;
|
||||
return new_struct;
|
||||
});
|
||||
|
@ -131,16 +133,17 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
auto* arg = ast_call->args.begin();
|
||||
for (auto* member : new_struct->members) {
|
||||
if (padding_members.Contains(member)) {
|
||||
new_args.Push(ctx.dst->Expr(0_u));
|
||||
new_args.Push(b.Expr(0_u));
|
||||
} else {
|
||||
new_args.Push(ctx.Clone(*arg));
|
||||
arg++;
|
||||
}
|
||||
}
|
||||
return ctx.dst->Construct(CreateASTTypeFor(ctx, str), new_args);
|
||||
return b.Construct(CreateASTTypeFor(ctx, str), new_args);
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -30,14 +30,10 @@ class PadStructs final : public Castable<PadStructs, Transform> {
|
|||
/// Destructor
|
||||
~PadStructs() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -13,6 +13,9 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/promote_initializers_to_let.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/call.h"
|
||||
#include "src/tint/sem/statement.h"
|
||||
|
@ -27,9 +30,16 @@ PromoteInitializersToLet::PromoteInitializersToLet() = default;
|
|||
|
||||
PromoteInitializersToLet::~PromoteInitializersToLet() = default;
|
||||
|
||||
void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
HoistToDeclBefore hoist_to_decl_before(ctx);
|
||||
|
||||
bool any_promoted = false;
|
||||
|
||||
// Hoists array and structure initializers to a constant variable, declared
|
||||
// just before the statement of usage.
|
||||
auto promote = [&](const sem::Expression* expr) {
|
||||
|
@ -59,14 +69,15 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&)
|
|||
return true;
|
||||
}
|
||||
|
||||
any_promoted = true;
|
||||
return hoist_to_decl_before.Add(expr, expr->Declaration(), true);
|
||||
};
|
||||
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
bool ok = Switch(
|
||||
node, //
|
||||
[&](const ast::CallExpression* expr) {
|
||||
if (auto* sem = ctx.src->Sem().Get(expr)) {
|
||||
if (auto* sem = src->Sem().Get(expr)) {
|
||||
auto* ctor = sem->UnwrapMaterialize()->As<sem::Call>();
|
||||
if (ctor->Target()->Is<sem::TypeInitializer>()) {
|
||||
return promote(sem);
|
||||
|
@ -75,7 +86,7 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&)
|
|||
return true;
|
||||
},
|
||||
[&](const ast::IdentifierExpression* expr) {
|
||||
if (auto* sem = ctx.src->Sem().Get(expr)) {
|
||||
if (auto* sem = src->Sem().Get(expr)) {
|
||||
if (auto* user = sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
|
||||
// Identifier resolves to a variable
|
||||
if (auto* stmt = user->Stmt()) {
|
||||
|
@ -96,13 +107,17 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&)
|
|||
return true;
|
||||
},
|
||||
[&](Default) { return true; });
|
||||
|
||||
if (!ok) {
|
||||
return;
|
||||
return Program(std::move(b));
|
||||
}
|
||||
}
|
||||
|
||||
if (!any_promoted) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -33,14 +33,10 @@ class PromoteInitializersToLet final : public Castable<PromoteInitializersToLet,
|
|||
/// Destructor
|
||||
~PromoteInitializersToLet() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -53,34 +53,36 @@ class StateBase {
|
|||
// to else {if}s so that the next transform, DecomposeSideEffects, can insert
|
||||
// hoisted expressions above their current location.
|
||||
struct SimplifySideEffectStatements : Castable<PromoteSideEffectsToDecl, Transform> {
|
||||
class State;
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override;
|
||||
ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
class SimplifySideEffectStatements::State : public StateBase {
|
||||
HoistToDeclBefore hoist_to_decl_before;
|
||||
Transform::ApplyResult SimplifySideEffectStatements::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
public:
|
||||
explicit State(CloneContext& ctx_in) : StateBase(ctx_in), hoist_to_decl_before(ctx_in) {}
|
||||
bool made_changes = false;
|
||||
|
||||
void Run() {
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* expr = node->As<ast::Expression>()) {
|
||||
auto* sem_expr = sem.Get(expr);
|
||||
if (!sem_expr || !sem_expr->HasSideEffects()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
hoist_to_decl_before.Prepare(sem_expr);
|
||||
HoistToDeclBefore hoist_to_decl_before(ctx);
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* expr = node->As<ast::Expression>()) {
|
||||
auto* sem_expr = src->Sem().Get(expr);
|
||||
if (!sem_expr || !sem_expr->HasSideEffects()) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
ctx.Clone();
|
||||
}
|
||||
};
|
||||
|
||||
void SimplifySideEffectStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
State state(ctx);
|
||||
state.Run();
|
||||
hoist_to_decl_before.Prepare(sem_expr);
|
||||
made_changes = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!made_changes) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
// Decomposes side-effecting expressions to ensure order of evaluation. This
|
||||
|
@ -89,7 +91,7 @@ void SimplifySideEffectStatements::Run(CloneContext& ctx, const DataMap&, DataMa
|
|||
struct DecomposeSideEffects : Castable<PromoteSideEffectsToDecl, Transform> {
|
||||
class CollectHoistsState;
|
||||
class DecomposeState;
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override;
|
||||
ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
// CollectHoistsState traverses the AST top-down, identifying which expressions
|
||||
|
@ -667,12 +669,15 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
|
|||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
};
|
||||
|
||||
void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
Transform::ApplyResult DecomposeSideEffects::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
// First collect side-effecting expressions to hoist
|
||||
CollectHoistsState collect_hoists_state{ctx};
|
||||
auto to_hoist = collect_hoists_state.Run();
|
||||
|
@ -680,6 +685,9 @@ void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
|
|||
// Now decompose these expressions
|
||||
DecomposeState decompose_state{ctx, std::move(to_hoist)};
|
||||
decompose_state.Run();
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -687,13 +695,13 @@ void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
|
|||
PromoteSideEffectsToDecl::PromoteSideEffectsToDecl() = default;
|
||||
PromoteSideEffectsToDecl::~PromoteSideEffectsToDecl() = default;
|
||||
|
||||
Output PromoteSideEffectsToDecl::Run(const Program* program, const DataMap& data) const {
|
||||
Transform::ApplyResult PromoteSideEffectsToDecl::Apply(const Program* src,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const {
|
||||
transform::Manager manager;
|
||||
manager.Add<SimplifySideEffectStatements>();
|
||||
manager.Add<DecomposeSideEffects>();
|
||||
|
||||
auto output = manager.Run(program, data);
|
||||
return output;
|
||||
return manager.Apply(src, inputs, outputs);
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -31,12 +31,10 @@ class PromoteSideEffectsToDecl final : public Castable<PromoteSideEffectsToDecl,
|
|||
/// Destructor
|
||||
~PromoteSideEffectsToDecl() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform on `program`, returning the transformation result.
|
||||
/// @param program the source program to transform
|
||||
/// @param data optional extra transform-specific data
|
||||
/// @returns the transformation result
|
||||
Output Run(const Program* program, const DataMap& data = {}) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -32,53 +32,19 @@
|
|||
TINT_INSTANTIATE_TYPEINFO(tint::transform::RemoveContinueInSwitch);
|
||||
|
||||
namespace tint::transform {
|
||||
namespace {
|
||||
|
||||
class State {
|
||||
private:
|
||||
CloneContext& ctx;
|
||||
ProgramBuilder& b;
|
||||
const sem::Info& sem;
|
||||
|
||||
// Map of switch statement to 'tint_continue' variable.
|
||||
std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name;
|
||||
|
||||
// If `cont` is within a switch statement within a loop, returns a pointer to
|
||||
// that switch statement.
|
||||
static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
|
||||
const ast::ContinueStatement* cont) {
|
||||
// Find whether first parent is a switch or a loop
|
||||
auto* sem_stmt = sem.Get(cont);
|
||||
auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
|
||||
sem::ForLoopStatement, sem::WhileStatement>();
|
||||
if (!sem_parent) {
|
||||
return nullptr;
|
||||
}
|
||||
return sem_parent->Declaration()->As<ast::SwitchStatement>();
|
||||
}
|
||||
|
||||
public:
|
||||
/// PIMPL state for the transform
|
||||
struct RemoveContinueInSwitch::State {
|
||||
/// Constructor
|
||||
/// @param ctx_in the context
|
||||
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
|
||||
|
||||
/// Returns true if this transform should be run for the given program
|
||||
static bool ShouldRun(const Program* program) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
auto* stmt = node->As<ast::ContinueStatement>();
|
||||
if (!stmt) {
|
||||
continue;
|
||||
}
|
||||
if (GetParentSwitchInLoop(program->Sem(), stmt)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
/// @param program the source program
|
||||
explicit State(const Program* program) : src(program) {}
|
||||
|
||||
/// Runs the transform
|
||||
void Run() {
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
/// @returns the new program or SkipTransform if the transform is not required
|
||||
ApplyResult Run() {
|
||||
bool made_changes = false;
|
||||
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
auto* cont = node->As<ast::ContinueStatement>();
|
||||
if (!cont) {
|
||||
continue;
|
||||
|
@ -90,6 +56,8 @@ class State {
|
|||
continue;
|
||||
}
|
||||
|
||||
made_changes = true;
|
||||
|
||||
auto cont_var_name =
|
||||
tint::utils::GetOrCreate(switch_to_cont_var_name, switch_stmt, [&]() {
|
||||
// Create and insert 'var tint_continue : bool = false;' before the
|
||||
|
@ -116,22 +84,50 @@ class State {
|
|||
ctx.Replace(cont, new_stmt);
|
||||
}
|
||||
|
||||
if (!made_changes) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
private:
|
||||
/// The source program
|
||||
const Program* const src;
|
||||
/// The target program builder
|
||||
ProgramBuilder b;
|
||||
/// The clone context
|
||||
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
||||
/// Alias to src->sem
|
||||
const sem::Info& sem = src->Sem();
|
||||
|
||||
// Map of switch statement to 'tint_continue' variable.
|
||||
std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name;
|
||||
|
||||
// If `cont` is within a switch statement within a loop, returns a pointer to
|
||||
// that switch statement.
|
||||
static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
|
||||
const ast::ContinueStatement* cont) {
|
||||
// Find whether first parent is a switch or a loop
|
||||
auto* sem_stmt = sem.Get(cont);
|
||||
auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
|
||||
sem::ForLoopStatement, sem::WhileStatement>();
|
||||
if (!sem_parent) {
|
||||
return nullptr;
|
||||
}
|
||||
return sem_parent->Declaration()->As<ast::SwitchStatement>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
RemoveContinueInSwitch::RemoveContinueInSwitch() = default;
|
||||
RemoveContinueInSwitch::~RemoveContinueInSwitch() = default;
|
||||
|
||||
bool RemoveContinueInSwitch::ShouldRun(const Program* program, const DataMap& /*data*/) const {
|
||||
return State::ShouldRun(program);
|
||||
}
|
||||
|
||||
void RemoveContinueInSwitch::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
State state(ctx);
|
||||
state.Run();
|
||||
Transform::ApplyResult RemoveContinueInSwitch::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
State state(src);
|
||||
return state.Run();
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -31,19 +31,13 @@ class RemoveContinueInSwitch final : public Castable<RemoveContinueInSwitch, Tra
|
|||
/// Destructor
|
||||
~RemoveContinueInSwitch() override;
|
||||
|
||||
protected:
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
private:
|
||||
struct State;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -41,34 +41,25 @@ RemovePhonies::RemovePhonies() = default;
|
|||
|
||||
RemovePhonies::~RemovePhonies() = default;
|
||||
|
||||
bool RemovePhonies::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (node->Is<ast::PhonyExpression>()) {
|
||||
return true;
|
||||
}
|
||||
if (auto* stmt = node->As<ast::CallStatement>()) {
|
||||
if (program->Sem().Get(stmt->expr)->ConstantValue() != nullptr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&, DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
auto& sem = ctx.src->Sem();
|
||||
auto& sem = src->Sem();
|
||||
|
||||
std::unordered_map<SinkSignature, Symbol, utils::Hasher<SinkSignature>> sinks;
|
||||
utils::Hashmap<SinkSignature, Symbol, 8, utils::Hasher<SinkSignature>> sinks;
|
||||
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
bool made_changes = false;
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
Switch(
|
||||
node,
|
||||
[&](const ast::AssignmentStatement* stmt) {
|
||||
if (stmt->lhs->Is<ast::PhonyExpression>()) {
|
||||
made_changes = true;
|
||||
|
||||
std::vector<const ast::Expression*> side_effects;
|
||||
if (!ast::TraverseExpressions(
|
||||
stmt->rhs, ctx.dst->Diagnostics(),
|
||||
[&](const ast::CallExpression* expr) {
|
||||
stmt->rhs, b.Diagnostics(), [&](const ast::CallExpression* expr) {
|
||||
// ast::CallExpression may map to a function or builtin call
|
||||
// (both may have side-effects), or a type initializer or
|
||||
// type conversion (both do not have side effects).
|
||||
|
@ -100,8 +91,7 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
if (auto* call = side_effects[0]->As<ast::CallExpression>()) {
|
||||
// Phony assignment with single call side effect.
|
||||
// Replace phony assignment with call.
|
||||
ctx.Replace(stmt,
|
||||
[&, call] { return ctx.dst->CallStmt(ctx.Clone(call)); });
|
||||
ctx.Replace(stmt, [&, call] { return b.CallStmt(ctx.Clone(call)); });
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -114,22 +104,21 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
for (auto* arg : side_effects) {
|
||||
sig.push_back(sem.Get(arg)->Type()->UnwrapRef());
|
||||
}
|
||||
auto sink = utils::GetOrCreate(sinks, sig, [&] {
|
||||
auto name = ctx.dst->Symbols().New("phony_sink");
|
||||
auto sink = sinks.GetOrCreate(sig, [&] {
|
||||
auto name = b.Symbols().New("phony_sink");
|
||||
utils::Vector<const ast::Parameter*, 8> params;
|
||||
for (auto* ty : sig) {
|
||||
auto* ast_ty = CreateASTTypeFor(ctx, ty);
|
||||
params.Push(
|
||||
ctx.dst->Param("p" + std::to_string(params.Length()), ast_ty));
|
||||
params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty));
|
||||
}
|
||||
ctx.dst->Func(name, params, ctx.dst->ty.void_(), {});
|
||||
b.Func(name, params, b.ty.void_(), {});
|
||||
return name;
|
||||
});
|
||||
utils::Vector<const ast::Expression*, 8> args;
|
||||
for (auto* arg : side_effects) {
|
||||
args.Push(ctx.Clone(arg));
|
||||
}
|
||||
return ctx.dst->CallStmt(ctx.dst->Call(sink, args));
|
||||
return b.CallStmt(b.Call(sink, args));
|
||||
});
|
||||
}
|
||||
},
|
||||
|
@ -138,12 +127,18 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
// TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects.
|
||||
auto* sem_expr = sem.Get(stmt->expr);
|
||||
if ((sem_expr->ConstantValue() != nullptr) && !sem_expr->HasSideEffects()) {
|
||||
made_changes = true;
|
||||
ctx.Remove(sem.Get(stmt)->Block()->Declaration()->statements, stmt);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (!made_changes) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -33,19 +33,10 @@ class RemovePhonies final : public Castable<RemovePhonies, Transform> {
|
|||
/// Destructor
|
||||
~RemovePhonies() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -36,27 +36,28 @@ RemoveUnreachableStatements::RemoveUnreachableStatements() = default;
|
|||
|
||||
RemoveUnreachableStatements::~RemoveUnreachableStatements() = default;
|
||||
|
||||
bool RemoveUnreachableStatements::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* stmt = program->Sem().Get<sem::Statement>(node)) {
|
||||
Transform::ApplyResult RemoveUnreachableStatements::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
bool made_changes = false;
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
if (auto* stmt = src->Sem().Get<sem::Statement>(node)) {
|
||||
if (!stmt->IsReachable()) {
|
||||
return true;
|
||||
RemoveStatement(ctx, stmt->Declaration());
|
||||
made_changes = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void RemoveUnreachableStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
if (auto* stmt = ctx.src->Sem().Get<sem::Statement>(node)) {
|
||||
if (!stmt->IsReachable()) {
|
||||
RemoveStatement(ctx, stmt->Declaration());
|
||||
}
|
||||
}
|
||||
if (!made_changes) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -32,19 +32,10 @@ class RemoveUnreachableStatements final : public Castable<RemoveUnreachableState
|
|||
/// Destructor
|
||||
~RemoveUnreachableStatements() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -1252,39 +1252,31 @@ Renamer::Config::~Config() = default;
|
|||
Renamer::Renamer() = default;
|
||||
Renamer::~Renamer() = default;
|
||||
|
||||
Output Renamer::Run(const Program* in, const DataMap& inputs) const {
|
||||
ProgramBuilder out;
|
||||
// Disable auto-cloning of symbols, since we want to rename them.
|
||||
CloneContext ctx(&out, in, false);
|
||||
Transform::ApplyResult Renamer::Apply(const Program* src,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ false};
|
||||
|
||||
// Swizzles, builtin calls and builtin structure members need to keep their
|
||||
// symbols preserved.
|
||||
std::unordered_set<const ast::IdentifierExpression*> preserve;
|
||||
for (auto* node : in->ASTNodes().Objects()) {
|
||||
utils::Hashset<const ast::IdentifierExpression*, 8> preserve;
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
if (auto* member = node->As<ast::MemberAccessorExpression>()) {
|
||||
auto* sem = in->Sem().Get(member);
|
||||
if (!sem) {
|
||||
TINT_ICE(Transform, out.Diagnostics())
|
||||
<< "MemberAccessorExpression has no semantic info";
|
||||
continue;
|
||||
}
|
||||
auto* sem = src->Sem().Get(member);
|
||||
if (sem->Is<sem::Swizzle>()) {
|
||||
preserve.emplace(member->member);
|
||||
} else if (auto* str_expr = in->Sem().Get(member->structure)) {
|
||||
preserve.Add(member->member);
|
||||
} else if (auto* str_expr = src->Sem().Get(member->structure)) {
|
||||
if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) {
|
||||
if (ty->Declaration() == nullptr) { // Builtin structure
|
||||
preserve.emplace(member->member);
|
||||
preserve.Add(member->member);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (auto* call = node->As<ast::CallExpression>()) {
|
||||
auto* sem = in->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>();
|
||||
if (!sem) {
|
||||
TINT_ICE(Transform, out.Diagnostics()) << "CallExpression has no semantic info";
|
||||
continue;
|
||||
}
|
||||
auto* sem = src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>();
|
||||
if (sem->Target()->Is<sem::Builtin>()) {
|
||||
preserve.emplace(call->target.name);
|
||||
preserve.Add(call->target.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1300,7 +1292,7 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) const {
|
|||
}
|
||||
|
||||
ctx.ReplaceAll([&](Symbol sym_in) {
|
||||
auto name_in = ctx.src->Symbols().NameFor(sym_in);
|
||||
auto name_in = src->Symbols().NameFor(sym_in);
|
||||
if (preserve_unicode || text::utf8::IsASCII(name_in)) {
|
||||
switch (target) {
|
||||
case Target::kAll:
|
||||
|
@ -1343,17 +1335,20 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) const {
|
|||
});
|
||||
|
||||
ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* {
|
||||
if (preserve.count(ident)) {
|
||||
if (preserve.Contains(ident)) {
|
||||
auto sym_in = ident->symbol;
|
||||
auto str = in->Symbols().NameFor(sym_in);
|
||||
auto sym_out = out.Symbols().Register(str);
|
||||
auto str = src->Symbols().NameFor(sym_in);
|
||||
auto sym_out = b.Symbols().Register(str);
|
||||
return ctx.dst->create<ast::IdentifierExpression>(ctx.Clone(ident->source), sym_out);
|
||||
}
|
||||
return nullptr; // Clone ident. Uses the symbol remapping above.
|
||||
});
|
||||
ctx.Clone();
|
||||
|
||||
return Output(Program(std::move(out)), std::make_unique<Data>(std::move(remappings)));
|
||||
ctx.Clone(); // Must come before the std::move()
|
||||
|
||||
outputs.Add<Data>(std::move(remappings));
|
||||
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -85,11 +85,10 @@ class Renamer final : public Castable<Renamer, Transform> {
|
|||
/// Destructor
|
||||
~Renamer() override;
|
||||
|
||||
/// Runs the transform on `program`, returning the transformation result.
|
||||
/// @param program the source program to transform
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns the transformation result
|
||||
Output Run(const Program* program, const DataMap& data = {}) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -33,36 +33,48 @@ using namespace tint::number_suffixes; // NOLINT
|
|||
|
||||
namespace tint::transform {
|
||||
|
||||
/// State holds the current transform state
|
||||
/// PIMPL state for the transform
|
||||
struct Robustness::State {
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
/// Constructor
|
||||
/// @param program the source program
|
||||
/// @param omitted the omitted address spaces
|
||||
State(const Program* program, std::unordered_set<ast::AddressSpace>&& omitted)
|
||||
: src(program), omitted_address_spaces(std::move(omitted)) {}
|
||||
|
||||
/// Set of address spacees to not apply the transform to
|
||||
std::unordered_set<ast::AddressSpace> omitted_classes;
|
||||
|
||||
/// Applies the transformation state to `ctx`.
|
||||
void Transform() {
|
||||
/// Runs the transform
|
||||
/// @returns the new program or SkipTransform if the transform is not required
|
||||
ApplyResult Run() {
|
||||
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr) { return Transform(expr); });
|
||||
ctx.ReplaceAll([&](const ast::CallExpression* expr) { return Transform(expr); });
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
private:
|
||||
/// The source program
|
||||
const Program* const src;
|
||||
/// The target program builder
|
||||
ProgramBuilder b;
|
||||
/// The clone context
|
||||
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
/// Set of address spaces to not apply the transform to
|
||||
std::unordered_set<ast::AddressSpace> omitted_address_spaces;
|
||||
|
||||
/// Apply bounds clamping to array, vector and matrix indexing
|
||||
/// @param expr the array, vector or matrix index expression
|
||||
/// @return the clamped replacement expression, or nullptr if `expr` should be cloned without
|
||||
/// changes.
|
||||
const ast::IndexAccessorExpression* Transform(const ast::IndexAccessorExpression* expr) {
|
||||
auto* sem =
|
||||
ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::IndexAccessorExpression>();
|
||||
auto* sem = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::IndexAccessorExpression>();
|
||||
auto* ret_type = sem->Type();
|
||||
|
||||
auto* ref = ret_type->As<sem::Reference>();
|
||||
if (ref && omitted_classes.count(ref->AddressSpace()) != 0) {
|
||||
if (ref && omitted_address_spaces.count(ref->AddressSpace()) != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ProgramBuilder& b = *ctx.dst;
|
||||
|
||||
// idx return the cloned index expression, as a u32.
|
||||
auto idx = [&]() -> const ast::Expression* {
|
||||
auto* i = ctx.Clone(expr->index);
|
||||
|
@ -109,8 +121,8 @@ struct Robustness::State {
|
|||
} else {
|
||||
// Note: Don't be tempted to use the array override variable as an expression
|
||||
// here, the name might be shadowed!
|
||||
ctx.dst->Diagnostics().add_error(diag::System::Transform,
|
||||
sem::Array::kErrExpectedConstantCount);
|
||||
b.Diagnostics().add_error(diag::System::Transform,
|
||||
sem::Array::kErrExpectedConstantCount);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -119,7 +131,7 @@ struct Robustness::State {
|
|||
[&](Default) {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unhandled object type in robustness of array index: "
|
||||
<< ctx.src->FriendlyName(ret_type->UnwrapRef());
|
||||
<< src->FriendlyName(ret_type->UnwrapRef());
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
|
@ -127,9 +139,9 @@ struct Robustness::State {
|
|||
return nullptr; // Clamping not needed
|
||||
}
|
||||
|
||||
auto src = ctx.Clone(expr->source);
|
||||
auto* obj = ctx.Clone(expr->object);
|
||||
return b.IndexAccessor(src, obj, clamped_idx);
|
||||
auto idx_src = ctx.Clone(expr->source);
|
||||
auto* idx_obj = ctx.Clone(expr->object);
|
||||
return b.IndexAccessor(idx_src, idx_obj, clamped_idx);
|
||||
}
|
||||
|
||||
/// @param type builtin type
|
||||
|
@ -145,15 +157,13 @@ struct Robustness::State {
|
|||
/// @return the clamped replacement call expression, or nullptr if `expr`
|
||||
/// should be cloned without changes.
|
||||
const ast::CallExpression* Transform(const ast::CallExpression* expr) {
|
||||
auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||
auto* call_target = call->Target();
|
||||
auto* builtin = call_target->As<sem::Builtin>();
|
||||
if (!builtin || !TextureBuiltinNeedsClamping(builtin->Type())) {
|
||||
return nullptr; // No transform, just clone.
|
||||
}
|
||||
|
||||
ProgramBuilder& b = *ctx.dst;
|
||||
|
||||
// Indices of the mandatory texture and coords parameters, and the optional
|
||||
// array and level parameters.
|
||||
auto& signature = builtin->Signature();
|
||||
|
@ -261,7 +271,7 @@ struct Robustness::State {
|
|||
// Clamp the level argument, if provided
|
||||
if (level_idx >= 0) {
|
||||
auto* arg = expr->args[static_cast<size_t>(level_idx)];
|
||||
ctx.Replace(arg, level_arg ? level_arg() : ctx.dst->Expr(0_a));
|
||||
ctx.Replace(arg, level_arg ? level_arg() : b.Expr(0_a));
|
||||
}
|
||||
|
||||
return nullptr; // Clone, which will use the argument replacements above.
|
||||
|
@ -276,28 +286,27 @@ Robustness::Config& Robustness::Config::operator=(const Config&) = default;
|
|||
Robustness::Robustness() = default;
|
||||
Robustness::~Robustness() = default;
|
||||
|
||||
void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
|
||||
Transform::ApplyResult Robustness::Apply(const Program* src,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
Config cfg;
|
||||
if (auto* cfg_data = inputs.Get<Config>()) {
|
||||
cfg = *cfg_data;
|
||||
}
|
||||
|
||||
std::unordered_set<ast::AddressSpace> omitted_classes;
|
||||
for (auto sc : cfg.omitted_classes) {
|
||||
std::unordered_set<ast::AddressSpace> omitted_address_spaces;
|
||||
for (auto sc : cfg.omitted_address_spaces) {
|
||||
switch (sc) {
|
||||
case AddressSpace::kUniform:
|
||||
omitted_classes.insert(ast::AddressSpace::kUniform);
|
||||
omitted_address_spaces.insert(ast::AddressSpace::kUniform);
|
||||
break;
|
||||
case AddressSpace::kStorage:
|
||||
omitted_classes.insert(ast::AddressSpace::kStorage);
|
||||
omitted_address_spaces.insert(ast::AddressSpace::kStorage);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
State state{ctx, std::move(omitted_classes)};
|
||||
|
||||
state.Transform();
|
||||
ctx.Clone();
|
||||
return State{src, std::move(omitted_address_spaces)}.Run();
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -54,9 +54,9 @@ class Robustness final : public Castable<Robustness, Transform> {
|
|||
/// @returns this Config
|
||||
Config& operator=(const Config&);
|
||||
|
||||
/// Address spacees to omit from apply the transform to.
|
||||
/// Address spaces to omit from apply the transform to.
|
||||
/// This allows for optimizing on hardware that provide safe accesses.
|
||||
std::unordered_set<AddressSpace> omitted_classes;
|
||||
std::unordered_set<AddressSpace> omitted_address_spaces;
|
||||
};
|
||||
|
||||
/// Constructor
|
||||
|
@ -64,14 +64,10 @@ class Robustness final : public Castable<Robustness, Transform> {
|
|||
/// Destructor
|
||||
~Robustness() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
struct State;
|
||||
|
|
|
@ -1274,7 +1274,7 @@ fn f() {
|
|||
)";
|
||||
|
||||
Robustness::Config cfg;
|
||||
cfg.omitted_classes.insert(Robustness::AddressSpace::kStorage);
|
||||
cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kStorage);
|
||||
|
||||
DataMap data;
|
||||
data.Add<Robustness::Config>(cfg);
|
||||
|
@ -1325,7 +1325,7 @@ fn f() {
|
|||
)";
|
||||
|
||||
Robustness::Config cfg;
|
||||
cfg.omitted_classes.insert(Robustness::AddressSpace::kUniform);
|
||||
cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kUniform);
|
||||
|
||||
DataMap data;
|
||||
data.Add<Robustness::Config>(cfg);
|
||||
|
@ -1376,8 +1376,8 @@ fn f() {
|
|||
)";
|
||||
|
||||
Robustness::Config cfg;
|
||||
cfg.omitted_classes.insert(Robustness::AddressSpace::kStorage);
|
||||
cfg.omitted_classes.insert(Robustness::AddressSpace::kUniform);
|
||||
cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kStorage);
|
||||
cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kUniform);
|
||||
|
||||
DataMap data;
|
||||
data.Add<Robustness::Config>(cfg);
|
||||
|
|
|
@ -45,14 +45,18 @@ struct PointerOp {
|
|||
|
||||
} // namespace
|
||||
|
||||
/// The PIMPL state for the SimplifyPointers transform
|
||||
/// PIMPL state for the transform
|
||||
struct SimplifyPointers::State {
|
||||
/// The source program
|
||||
const Program* const src;
|
||||
/// The target program builder
|
||||
ProgramBuilder b;
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
/// Constructor
|
||||
/// @param context the clone context
|
||||
explicit State(CloneContext& context) : ctx(context) {}
|
||||
/// @param program the source program
|
||||
explicit State(const Program* program) : src(program) {}
|
||||
|
||||
/// Traverses the expression `expr` looking for non-literal array indexing
|
||||
/// expressions that would affect the computed address of a pointer
|
||||
|
@ -120,10 +124,11 @@ struct SimplifyPointers::State {
|
|||
}
|
||||
}
|
||||
|
||||
/// Performs the transformation
|
||||
void Run() {
|
||||
/// Runs the transform
|
||||
/// @returns the new program or SkipTransform if the transform is not required
|
||||
ApplyResult Run() {
|
||||
// A map of saved expressions to their saved variable name
|
||||
std::unordered_map<const ast::Expression*, Symbol> saved_vars;
|
||||
utils::Hashmap<const ast::Expression*, Symbol, 8> saved_vars;
|
||||
|
||||
// Register the ast::Expression transform handler.
|
||||
// This performs two different transformations:
|
||||
|
@ -135,9 +140,8 @@ struct SimplifyPointers::State {
|
|||
// variable identifier.
|
||||
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
|
||||
// Look to see if we need to swap this Expression with a saved variable.
|
||||
auto it = saved_vars.find(expr);
|
||||
if (it != saved_vars.end()) {
|
||||
return ctx.dst->Expr(it->second);
|
||||
if (auto* saved_var = saved_vars.Find(expr)) {
|
||||
return ctx.dst->Expr(*saved_var);
|
||||
}
|
||||
|
||||
// Reduce the expression, folding away chains of address-of / indirections
|
||||
|
@ -174,7 +178,7 @@ struct SimplifyPointers::State {
|
|||
|
||||
// Scan the initializer expression for array index expressions that need
|
||||
// to be hoist to temporary "saved" variables.
|
||||
std::vector<const ast::VariableDeclStatement*> saved;
|
||||
utils::Vector<const ast::VariableDeclStatement*, 8> saved;
|
||||
CollectSavedArrayIndices(
|
||||
var->Declaration()->initializer, [&](const ast::Expression* idx_expr) {
|
||||
// We have a sub-expression that needs to be saved.
|
||||
|
@ -182,18 +186,18 @@ struct SimplifyPointers::State {
|
|||
auto saved_name = ctx.dst->Symbols().New(
|
||||
ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save");
|
||||
auto* decl = ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr)));
|
||||
saved.emplace_back(decl);
|
||||
saved.Push(decl);
|
||||
// Record the substitution of `idx_expr` to the saved variable
|
||||
// with the symbol `saved_name`. This will be used by the
|
||||
// ReplaceAll() handler above.
|
||||
saved_vars.emplace(idx_expr, saved_name);
|
||||
saved_vars.Add(idx_expr, saved_name);
|
||||
});
|
||||
|
||||
// Find the place to insert the saved declarations.
|
||||
// Special care needs to be made for lets declared as the initializer
|
||||
// part of for-loops. In this case the block will hold the for-loop
|
||||
// statement, not the let.
|
||||
if (!saved.empty()) {
|
||||
if (!saved.IsEmpty()) {
|
||||
auto* stmt = ctx.src->Sem().Get(let);
|
||||
auto* block = stmt->Block();
|
||||
// Find the statement owned by the block (either the let decl or a
|
||||
|
@ -219,7 +223,9 @@ struct SimplifyPointers::State {
|
|||
RemoveStatement(ctx, let);
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -227,8 +233,8 @@ SimplifyPointers::SimplifyPointers() = default;
|
|||
|
||||
SimplifyPointers::~SimplifyPointers() = default;
|
||||
|
||||
void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
State(ctx).Run();
|
||||
Transform::ApplyResult SimplifyPointers::Apply(const Program* src, const DataMap&, DataMap&) const {
|
||||
return State(src).Run();
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -39,16 +39,13 @@ class SimplifyPointers final : public Castable<SimplifyPointers, Transform> {
|
|||
/// Destructor
|
||||
~SimplifyPointers() override;
|
||||
|
||||
protected:
|
||||
struct State;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
private:
|
||||
struct State;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -30,33 +30,37 @@ SingleEntryPoint::SingleEntryPoint() = default;
|
|||
|
||||
SingleEntryPoint::~SingleEntryPoint() = default;
|
||||
|
||||
void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
|
||||
Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
auto* cfg = inputs.Get<Config>();
|
||||
if (cfg == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
|
||||
|
||||
return;
|
||||
b.Diagnostics().add_error(diag::System::Transform,
|
||||
"missing transform data for " + std::string(TypeInfo().name));
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
// Find the target entry point.
|
||||
const ast::Function* entry_point = nullptr;
|
||||
for (auto* f : ctx.src->AST().Functions()) {
|
||||
for (auto* f : src->AST().Functions()) {
|
||||
if (!f->IsEntryPoint()) {
|
||||
continue;
|
||||
}
|
||||
if (ctx.src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) {
|
||||
if (src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) {
|
||||
entry_point = f;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (entry_point == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(diag::System::Transform,
|
||||
"entry point '" + cfg->entry_point_name + "' not found");
|
||||
return;
|
||||
b.Diagnostics().add_error(diag::System::Transform,
|
||||
"entry point '" + cfg->entry_point_name + "' not found");
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
auto& sem = ctx.src->Sem();
|
||||
auto& sem = src->Sem();
|
||||
|
||||
// Build set of referenced module-scope variables for faster lookups later.
|
||||
std::unordered_set<const ast::Variable*> referenced_vars;
|
||||
|
@ -66,12 +70,12 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c
|
|||
|
||||
// Clone any module-scope variables, types, and functions that are statically referenced by the
|
||||
// target entry point.
|
||||
for (auto* decl : ctx.src->AST().GlobalDeclarations()) {
|
||||
for (auto* decl : src->AST().GlobalDeclarations()) {
|
||||
Switch(
|
||||
decl, //
|
||||
[&](const ast::TypeDecl* ty) {
|
||||
// TODO(jrprice): Strip unused types.
|
||||
ctx.dst->AST().AddTypeDecl(ctx.Clone(ty));
|
||||
b.AST().AddTypeDecl(ctx.Clone(ty));
|
||||
},
|
||||
[&](const ast::Override* override) {
|
||||
if (referenced_vars.count(override)) {
|
||||
|
@ -80,37 +84,39 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c
|
|||
// so that its allocated ID so that it won't be affected by other
|
||||
// stripped away overrides
|
||||
auto* global = sem.Get(override);
|
||||
const auto* id = ctx.dst->Id(global->OverrideId());
|
||||
const auto* id = b.Id(global->OverrideId());
|
||||
ctx.InsertFront(override->attributes, id);
|
||||
}
|
||||
ctx.dst->AST().AddGlobalVariable(ctx.Clone(override));
|
||||
b.AST().AddGlobalVariable(ctx.Clone(override));
|
||||
}
|
||||
},
|
||||
[&](const ast::Var* var) {
|
||||
if (referenced_vars.count(var)) {
|
||||
ctx.dst->AST().AddGlobalVariable(ctx.Clone(var));
|
||||
b.AST().AddGlobalVariable(ctx.Clone(var));
|
||||
}
|
||||
},
|
||||
[&](const ast::Const* c) {
|
||||
// Always keep 'const' declarations, as these can be used by attributes and array
|
||||
// sizes, which are not tracked as transitively used by functions. They also don't
|
||||
// typically get emitted by the backend unless they're actually used.
|
||||
ctx.dst->AST().AddGlobalVariable(ctx.Clone(c));
|
||||
b.AST().AddGlobalVariable(ctx.Clone(c));
|
||||
},
|
||||
[&](const ast::Function* func) {
|
||||
if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) {
|
||||
ctx.dst->AST().AddFunction(ctx.Clone(func));
|
||||
b.AST().AddFunction(ctx.Clone(func));
|
||||
}
|
||||
},
|
||||
[&](const ast::Enable* ext) { ctx.dst->AST().AddEnable(ctx.Clone(ext)); },
|
||||
[&](const ast::Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); },
|
||||
[&](Default) {
|
||||
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_UNREACHABLE(Transform, b.Diagnostics())
|
||||
<< "unhandled global declaration: " << decl->TypeInfo().name;
|
||||
});
|
||||
}
|
||||
|
||||
// Clone the entry point.
|
||||
ctx.dst->AST().AddFunction(ctx.Clone(entry_point));
|
||||
b.AST().AddFunction(ctx.Clone(entry_point));
|
||||
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
SingleEntryPoint::Config::Config(std::string entry_point) : entry_point_name(entry_point) {}
|
||||
|
|
|
@ -53,14 +53,10 @@ class SingleEntryPoint final : public Castable<SingleEntryPoint, Transform> {
|
|||
/// Destructor
|
||||
~SingleEntryPoint() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -37,7 +37,7 @@ namespace tint::transform {
|
|||
|
||||
using namespace tint::number_suffixes; // NOLINT
|
||||
|
||||
/// Private implementation of transform
|
||||
/// PIMPL state for the transform
|
||||
struct SpirvAtomic::State {
|
||||
private:
|
||||
/// A struct that has been forked because a subset of members were made atomic.
|
||||
|
@ -46,19 +46,24 @@ struct SpirvAtomic::State {
|
|||
std::unordered_set<size_t> atomic_members;
|
||||
};
|
||||
|
||||
CloneContext& ctx;
|
||||
ProgramBuilder& b = *ctx.dst;
|
||||
/// The source program
|
||||
const Program* const src;
|
||||
/// The target program builder
|
||||
ProgramBuilder b;
|
||||
/// The clone context
|
||||
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
||||
std::unordered_map<const ast::Struct*, ForkedStruct> forked_structs;
|
||||
std::unordered_set<const sem::Variable*> atomic_variables;
|
||||
utils::UniqueVector<const sem::Expression*, 8> atomic_expressions;
|
||||
|
||||
public:
|
||||
/// Constructor
|
||||
/// @param c the clone context
|
||||
explicit State(CloneContext& c) : ctx(c) {}
|
||||
/// @param program the source program
|
||||
explicit State(const Program* program) : src(program) {}
|
||||
|
||||
/// Runs the transform
|
||||
void Run() {
|
||||
/// @returns the new program or SkipTransform if the transform is not required
|
||||
ApplyResult Run() {
|
||||
// Look for stub functions generated by the SPIR-V reader, which are used as placeholders
|
||||
// for atomic builtin calls.
|
||||
for (auto* fn : ctx.src->AST().Functions()) {
|
||||
|
@ -102,6 +107,10 @@ struct SpirvAtomic::State {
|
|||
}
|
||||
}
|
||||
|
||||
if (atomic_expressions.IsEmpty()) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
// Transform all variables and structure members that were used in atomic operations as
|
||||
// atomic types. This propagates up originating expression chains.
|
||||
ProcessAtomicExpressions();
|
||||
|
@ -143,6 +152,7 @@ struct SpirvAtomic::State {
|
|||
ReplaceLoadsAndStores();
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -297,17 +307,8 @@ const SpirvAtomic::Stub* SpirvAtomic::Stub::Clone(CloneContext* ctx) const {
|
|||
ctx->dst->AllocateNodeID(), builtin);
|
||||
}
|
||||
|
||||
bool SpirvAtomic::ShouldRun(const Program* program, const DataMap&) const {
|
||||
for (auto* fn : program->AST().Functions()) {
|
||||
if (ast::HasAttribute<Stub>(fn->attributes)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void SpirvAtomic::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
State{ctx}.Run();
|
||||
Transform::ApplyResult SpirvAtomic::Apply(const Program* src, const DataMap&, DataMap&) const {
|
||||
return State{src}.Run();
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -63,21 +63,13 @@ class SpirvAtomic final : public Castable<SpirvAtomic, Transform> {
|
|||
const sem::BuiltinType builtin;
|
||||
};
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
protected:
|
||||
private:
|
||||
struct State;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -77,14 +77,20 @@ struct Hasher<DynamicIndex> {
|
|||
|
||||
namespace tint::transform {
|
||||
|
||||
/// The PIMPL state for the Std140 transform
|
||||
/// PIMPL state for the transform
|
||||
struct Std140::State {
|
||||
/// Constructor
|
||||
/// @param c the CloneContext
|
||||
explicit State(CloneContext& c) : ctx(c) {}
|
||||
/// @param program the source program
|
||||
explicit State(const Program* program) : src(program) {}
|
||||
|
||||
/// Runs the transform
|
||||
void Run() {
|
||||
/// @returns the new program or SkipTransform if the transform is not required
|
||||
ApplyResult Run() {
|
||||
if (!ShouldRun()) {
|
||||
// Transform is not required
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
// Begin by creating forked types for any type that is used as a uniform buffer, that
|
||||
// either directly or transitively contains a matrix that needs splitting for std140 layout.
|
||||
ForkTypes();
|
||||
|
@ -116,11 +122,11 @@ struct Std140::State {
|
|||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
/// @returns true if this transform should be run for the given program
|
||||
/// @param program the program to inspect
|
||||
static bool ShouldRun(const Program* program) {
|
||||
bool ShouldRun() const {
|
||||
// Returns true if the type needs to be forked for std140 usage.
|
||||
auto needs_fork = [&](const sem::Type* ty) {
|
||||
while (auto* arr = ty->As<sem::Array>()) {
|
||||
|
@ -135,7 +141,7 @@ struct Std140::State {
|
|||
};
|
||||
|
||||
// Scan structures for members that need forking
|
||||
for (auto* ty : program->Types()) {
|
||||
for (auto* ty : src->Types()) {
|
||||
if (auto* str = ty->As<sem::Struct>()) {
|
||||
if (str->UsedAs(ast::AddressSpace::kUniform)) {
|
||||
for (auto* member : str->Members()) {
|
||||
|
@ -148,8 +154,8 @@ struct Std140::State {
|
|||
}
|
||||
|
||||
// Scan uniform variables that have types that need forking
|
||||
for (auto* decl : program->AST().GlobalVariables()) {
|
||||
auto* global = program->Sem().Get(decl);
|
||||
for (auto* decl : src->AST().GlobalVariables()) {
|
||||
auto* global = src->Sem().Get(decl);
|
||||
if (global->AddressSpace() == ast::AddressSpace::kUniform) {
|
||||
if (needs_fork(global->Type()->UnwrapRef())) {
|
||||
return true;
|
||||
|
@ -197,14 +203,16 @@ struct Std140::State {
|
|||
}
|
||||
};
|
||||
|
||||
/// The source program
|
||||
const Program* const src;
|
||||
/// The target program builder
|
||||
ProgramBuilder b;
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
/// Alias to the semantic info in ctx.src
|
||||
const sem::Info& sem = ctx.src->Sem();
|
||||
/// Alias to the symbols in ctx.src
|
||||
const SymbolTable& sym = ctx.src->Symbols();
|
||||
/// Alias to the ctx.dst program builder
|
||||
ProgramBuilder& b = *ctx.dst;
|
||||
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
||||
/// Alias to the semantic info in src
|
||||
const sem::Info& sem = src->Sem();
|
||||
/// Alias to the symbols in src
|
||||
const SymbolTable& sym = src->Symbols();
|
||||
|
||||
/// Map of load function signature, to the generated function
|
||||
utils::Hashmap<LoadFnKey, Symbol, 8, LoadFnKey::Hasher> load_fns;
|
||||
|
@ -218,7 +226,7 @@ struct Std140::State {
|
|||
// Map of original structure to 'std140' forked structure
|
||||
utils::Hashmap<const sem::Struct*, Symbol, 8> std140_structs;
|
||||
|
||||
// Map of structure member in ctx.src of a matrix type, to list of decomposed column
|
||||
// Map of structure member in src of a matrix type, to list of decomposed column
|
||||
// members in ctx.dst.
|
||||
utils::Hashmap<const sem::StructMember*, utils::Vector<const ast::StructMember*, 4>, 8>
|
||||
std140_mat_members;
|
||||
|
@ -232,7 +240,7 @@ struct Std140::State {
|
|||
utils::Vector<Symbol, 4> columns;
|
||||
};
|
||||
|
||||
// Map of matrix type in ctx.src, to decomposed column structure in ctx.dst.
|
||||
// Map of matrix type in src, to decomposed column structure in ctx.dst.
|
||||
utils::Hashmap<const sem::Matrix*, Std140Matrix, 8> std140_mats;
|
||||
|
||||
/// AccessChain describes a chain of access expressions to uniform buffer variable.
|
||||
|
@ -266,7 +274,7 @@ struct Std140::State {
|
|||
/// map (via Std140Type()).
|
||||
void ForkTypes() {
|
||||
// For each module scope declaration...
|
||||
for (auto* global : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) {
|
||||
for (auto* global : src->Sem().Module()->DependencyOrderedDeclarations()) {
|
||||
// Check to see if this is a structure used by a uniform buffer...
|
||||
auto* str = sem.Get<sem::Struct>(global);
|
||||
if (str && str->UsedAs(ast::AddressSpace::kUniform)) {
|
||||
|
@ -317,7 +325,7 @@ struct Std140::State {
|
|||
if (fork_std140) {
|
||||
// Clone any members that have not already been cloned.
|
||||
for (auto& member : members) {
|
||||
if (member->program_id == ctx.src->ID()) {
|
||||
if (member->program_id == src->ID()) {
|
||||
member = ctx.Clone(member);
|
||||
}
|
||||
}
|
||||
|
@ -326,7 +334,7 @@ struct Std140::State {
|
|||
auto name = b.Symbols().New(sym.NameFor(str->Name()) + "_std140");
|
||||
auto* std140 = b.create<ast::Struct>(name, std::move(members),
|
||||
ctx.Clone(str->Declaration()->attributes));
|
||||
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), global, std140);
|
||||
ctx.InsertAfter(src->AST().GlobalDeclarations(), global, std140);
|
||||
std140_structs.Add(str, name);
|
||||
}
|
||||
}
|
||||
|
@ -337,14 +345,13 @@ struct Std140::State {
|
|||
/// type that has been forked for std140-layout.
|
||||
/// Populates the #std140_uniforms set.
|
||||
void ReplaceUniformVarTypes() {
|
||||
for (auto* global : ctx.src->AST().GlobalVariables()) {
|
||||
for (auto* global : src->AST().GlobalVariables()) {
|
||||
if (auto* var = global->As<ast::Var>()) {
|
||||
if (var->declared_address_space == ast::AddressSpace::kUniform) {
|
||||
auto* v = sem.Get(var);
|
||||
if (auto* std140_ty = Std140Type(v->Type()->UnwrapRef())) {
|
||||
ctx.Replace(global->type, std140_ty);
|
||||
std140_uniforms.Add(v);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -404,7 +411,7 @@ struct Std140::State {
|
|||
auto std140_mat = std140_mats.GetOrCreate(mat, [&] {
|
||||
auto name = b.Symbols().New("mat" + std::to_string(mat->columns()) + "x" +
|
||||
std::to_string(mat->rows()) + "_" +
|
||||
ctx.src->FriendlyName(mat->type()));
|
||||
src->FriendlyName(mat->type()));
|
||||
auto members =
|
||||
DecomposedMatrixStructMembers(mat, "col", mat->Align(), mat->Size());
|
||||
b.Structure(name, members);
|
||||
|
@ -421,7 +428,7 @@ struct Std140::State {
|
|||
if (auto* std140 = Std140Type(arr->ElemType())) {
|
||||
utils::Vector<const ast::Attribute*, 1> attrs;
|
||||
if (!arr->IsStrideImplicit()) {
|
||||
attrs.Push(ctx.dst->create<ast::StrideAttribute>(arr->Stride()));
|
||||
attrs.Push(b.create<ast::StrideAttribute>(arr->Stride()));
|
||||
}
|
||||
auto count = arr->ConstantCount();
|
||||
if (!count) {
|
||||
|
@ -429,7 +436,7 @@ struct Std140::State {
|
|||
// * Override-expression counts can only be applied to workgroup arrays, and
|
||||
// this method only handles types transitively used as uniform buffers.
|
||||
// * Runtime-sized arrays cannot be used in uniform buffers.
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unexpected non-constant array count";
|
||||
count = 1;
|
||||
}
|
||||
|
@ -440,7 +447,7 @@ struct Std140::State {
|
|||
});
|
||||
}
|
||||
|
||||
/// @param mat the matrix to decompose (in ctx.src)
|
||||
/// @param mat the matrix to decompose (in src)
|
||||
/// @param name_prefix the name prefix to apply to each of the returned column vector members.
|
||||
/// @param align the alignment in bytes of the matrix.
|
||||
/// @param size the size in bytes of the matrix.
|
||||
|
@ -473,7 +480,7 @@ struct Std140::State {
|
|||
// Build the member
|
||||
const auto col_name = name_prefix + std::to_string(i);
|
||||
const auto* col_ty = CreateASTTypeFor(ctx, mat->ColumnType());
|
||||
const auto* col_member = ctx.dst->Member(col_name, col_ty, std::move(attributes));
|
||||
const auto* col_member = b.Member(col_name, col_ty, std::move(attributes));
|
||||
// Record the member for std140_mat_members
|
||||
out.Push(col_member);
|
||||
}
|
||||
|
@ -618,7 +625,7 @@ struct Std140::State {
|
|||
|
||||
/// @returns a name suffix for a std140 -> non-std140 conversion function based on the type
|
||||
/// being converted.
|
||||
const std::string ConvertSuffix(const sem::Type* ty) const {
|
||||
const std::string ConvertSuffix(const sem::Type* ty) {
|
||||
return Switch(
|
||||
ty, //
|
||||
[&](const sem::Struct* str) { return sym.NameFor(str->Name()); },
|
||||
|
@ -629,8 +636,7 @@ struct Std140::State {
|
|||
// * Override-expression counts can only be applied to workgroup arrays, and
|
||||
// this method only handles types transitively used as uniform buffers.
|
||||
// * Runtime-sized arrays cannot be used in uniform buffers.
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
<< "unexpected non-constant array count";
|
||||
TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count";
|
||||
count = 1;
|
||||
}
|
||||
return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType());
|
||||
|
@ -642,7 +648,7 @@ struct Std140::State {
|
|||
[&](const sem::F32*) { return "f32"; },
|
||||
[&](Default) {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unhandled type for conversion name: " << ctx.src->FriendlyName(ty);
|
||||
<< "unhandled type for conversion name: " << src->FriendlyName(ty);
|
||||
return "";
|
||||
});
|
||||
}
|
||||
|
@ -718,8 +724,7 @@ struct Std140::State {
|
|||
stmts.Push(b.Return(b.Construct(mat_ty, std::move(mat_args))));
|
||||
} else {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "failed to find std140 matrix info for: "
|
||||
<< ctx.src->FriendlyName(ty);
|
||||
<< "failed to find std140 matrix info for: " << src->FriendlyName(ty);
|
||||
}
|
||||
}, //
|
||||
[&](const sem::Array* arr) {
|
||||
|
@ -736,7 +741,7 @@ struct Std140::State {
|
|||
// * Override-expression counts can only be applied to workgroup arrays, and
|
||||
// this method only handles types transitively used as uniform buffers.
|
||||
// * Runtime-sized arrays cannot be used in uniform buffers.
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unexpected non-constant array count";
|
||||
count = 1;
|
||||
}
|
||||
|
@ -749,7 +754,7 @@ struct Std140::State {
|
|||
},
|
||||
[&](Default) {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unhandled type for conversion: " << ctx.src->FriendlyName(ty);
|
||||
<< "unhandled type for conversion: " << src->FriendlyName(ty);
|
||||
});
|
||||
|
||||
// Generate the function
|
||||
|
@ -1063,7 +1068,7 @@ struct Std140::State {
|
|||
|
||||
if (std::get_if<UniformVariable>(&access)) {
|
||||
const auto* expr = b.Expr(ctx.Clone(chain.var->Declaration()->symbol));
|
||||
const auto name = ctx.src->Symbols().NameFor(chain.var->Declaration()->symbol);
|
||||
const auto name = src->Symbols().NameFor(chain.var->Declaration()->symbol);
|
||||
ty = chain.var->Type()->UnwrapRef();
|
||||
return {expr, ty, name};
|
||||
}
|
||||
|
@ -1090,7 +1095,7 @@ struct Std140::State {
|
|||
}, //
|
||||
[&](Default) -> ExprTypeName {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unhandled type for access chain: " << ctx.src->FriendlyName(ty);
|
||||
<< "unhandled type for access chain: " << src->FriendlyName(ty);
|
||||
return {};
|
||||
});
|
||||
}
|
||||
|
@ -1104,14 +1109,14 @@ struct Std140::State {
|
|||
for (auto el : *swizzle) {
|
||||
rhs += xyzw[el];
|
||||
}
|
||||
auto swizzle_ty = ctx.src->Types().Find<sem::Vector>(
|
||||
auto swizzle_ty = src->Types().Find<sem::Vector>(
|
||||
vec->type(), static_cast<uint32_t>(swizzle->Length()));
|
||||
auto* expr = b.MemberAccessor(lhs, rhs);
|
||||
return {expr, swizzle_ty, rhs};
|
||||
}, //
|
||||
[&](Default) -> ExprTypeName {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unhandled type for access chain: " << ctx.src->FriendlyName(ty);
|
||||
<< "unhandled type for access chain: " << src->FriendlyName(ty);
|
||||
return {};
|
||||
});
|
||||
}
|
||||
|
@ -1140,7 +1145,7 @@ struct Std140::State {
|
|||
}, //
|
||||
[&](Default) -> ExprTypeName {
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unhandled type for access chain: " << ctx.src->FriendlyName(ty);
|
||||
<< "unhandled type for access chain: " << src->FriendlyName(ty);
|
||||
return {};
|
||||
});
|
||||
}
|
||||
|
@ -1150,12 +1155,8 @@ Std140::Std140() = default;
|
|||
|
||||
Std140::~Std140() = default;
|
||||
|
||||
bool Std140::ShouldRun(const Program* program, const DataMap&) const {
|
||||
return State::ShouldRun(program);
|
||||
}
|
||||
|
||||
void Std140::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
State(ctx).Run();
|
||||
Transform::ApplyResult Std140::Apply(const Program* src, const DataMap&, DataMap&) const {
|
||||
return State(src).Run();
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -34,21 +34,13 @@ class Std140 final : public Castable<Std140, Transform> {
|
|||
/// Destructor
|
||||
~Std140() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
struct State;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "src/tint/transform/substitute_override.h"
|
||||
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/sem/builtin.h"
|
||||
|
@ -25,12 +26,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride);
|
|||
TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride::Config);
|
||||
|
||||
namespace tint::transform {
|
||||
namespace {
|
||||
|
||||
SubstituteOverride::SubstituteOverride() = default;
|
||||
|
||||
SubstituteOverride::~SubstituteOverride() = default;
|
||||
|
||||
bool SubstituteOverride::ShouldRun(const Program* program, const DataMap&) const {
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* node : program->AST().GlobalVariables()) {
|
||||
if (node->Is<ast::Override>()) {
|
||||
return true;
|
||||
|
@ -39,18 +37,32 @@ bool SubstituteOverride::ShouldRun(const Program* program, const DataMap&) const
|
|||
return false;
|
||||
}
|
||||
|
||||
void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) const {
|
||||
} // namespace
|
||||
|
||||
SubstituteOverride::SubstituteOverride() = default;
|
||||
|
||||
SubstituteOverride::~SubstituteOverride() = default;
|
||||
|
||||
Transform::ApplyResult SubstituteOverride::Apply(const Program* src,
|
||||
const DataMap& config,
|
||||
DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
const auto* data = config.Get<Config>();
|
||||
if (!data) {
|
||||
ctx.dst->Diagnostics().add_error(diag::System::Transform,
|
||||
"Missing override substitution data");
|
||||
return;
|
||||
b.Diagnostics().add_error(diag::System::Transform, "Missing override substitution data");
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
if (!ShouldRun(ctx.src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ctx.ReplaceAll([&](const ast::Override* w) -> const ast::Const* {
|
||||
auto* sem = ctx.src->Sem().Get(w);
|
||||
|
||||
auto src = ctx.Clone(w->source);
|
||||
auto source = ctx.Clone(w->source);
|
||||
auto sym = ctx.Clone(w->symbol);
|
||||
auto* ty = ctx.Clone(w->type);
|
||||
|
||||
|
@ -58,30 +70,30 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&)
|
|||
auto iter = data->map.find(sem->OverrideId());
|
||||
if (iter == data->map.end()) {
|
||||
if (!w->initializer) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
b.Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"Initializer not provided for override, and override not overridden.");
|
||||
return nullptr;
|
||||
}
|
||||
return ctx.dst->Const(src, sym, ty, ctx.Clone(w->initializer));
|
||||
return b.Const(source, sym, ty, ctx.Clone(w->initializer));
|
||||
}
|
||||
|
||||
auto value = iter->second;
|
||||
auto* ctor = Switch(
|
||||
sem->Type(),
|
||||
[&](const sem::Bool*) { return ctx.dst->Expr(!std::equal_to<double>()(value, 0.0)); },
|
||||
[&](const sem::I32*) { return ctx.dst->Expr(i32(value)); },
|
||||
[&](const sem::U32*) { return ctx.dst->Expr(u32(value)); },
|
||||
[&](const sem::F32*) { return ctx.dst->Expr(f32(value)); },
|
||||
[&](const sem::F16*) { return ctx.dst->Expr(f16(value)); });
|
||||
[&](const sem::Bool*) { return b.Expr(!std::equal_to<double>()(value, 0.0)); },
|
||||
[&](const sem::I32*) { return b.Expr(i32(value)); },
|
||||
[&](const sem::U32*) { return b.Expr(u32(value)); },
|
||||
[&](const sem::F32*) { return b.Expr(f32(value)); },
|
||||
[&](const sem::F16*) { return b.Expr(f16(value)); });
|
||||
|
||||
if (!ctor) {
|
||||
ctx.dst->Diagnostics().add_error(diag::System::Transform,
|
||||
"Failed to create override-expression");
|
||||
b.Diagnostics().add_error(diag::System::Transform,
|
||||
"Failed to create override-expression");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return ctx.dst->Const(src, sym, ty, ctor);
|
||||
return b.Const(source, sym, ty, ctor);
|
||||
});
|
||||
|
||||
// Ensure that objects that are indexed with an override-expression are materialized.
|
||||
|
@ -89,11 +101,10 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&)
|
|||
// resulting type of the index may change. See: crbug.com/tint/1697.
|
||||
ctx.ReplaceAll(
|
||||
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
|
||||
if (auto* sem = ctx.src->Sem().Get(expr)) {
|
||||
if (auto* sem = src->Sem().Get(expr)) {
|
||||
if (auto* access = sem->UnwrapMaterialize()->As<sem::IndexAccessorExpression>()) {
|
||||
if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() &&
|
||||
access->Index()->Stage() == sem::EvaluationStage::kOverride) {
|
||||
auto& b = *ctx.dst;
|
||||
auto* obj = b.Call(sem::str(sem::BuiltinType::kTintMaterialize),
|
||||
ctx.Clone(expr->object));
|
||||
return b.IndexAccessor(obj, ctx.Clone(expr->index));
|
||||
|
@ -104,6 +115,7 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&)
|
|||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
SubstituteOverride::Config::Config() = default;
|
||||
|
|
|
@ -75,19 +75,10 @@ class SubstituteOverride final : public Castable<SubstituteOverride, Transform>
|
|||
/// Destructor
|
||||
~SubstituteOverride() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -122,7 +122,18 @@ class TransformTestBase : public BASE {
|
|||
}
|
||||
|
||||
const Transform& t = TRANSFORM();
|
||||
return t.ShouldRun(&program, data);
|
||||
|
||||
DataMap outputs;
|
||||
auto result = t.Apply(&program, data, outputs);
|
||||
if (!result) {
|
||||
return false;
|
||||
}
|
||||
if (!result->IsValid()) {
|
||||
ADD_FAILURE() << "Apply() called by ShouldRun() returned errors: "
|
||||
<< result->Diagnostics().str();
|
||||
return true;
|
||||
}
|
||||
return result.has_value();
|
||||
}
|
||||
|
||||
/// @param in the input WGSL source
|
||||
|
|
|
@ -46,24 +46,19 @@ Output::Output(Program&& p) : program(std::move(p)) {}
|
|||
Transform::Transform() = default;
|
||||
Transform::~Transform() = default;
|
||||
|
||||
Output Transform::Run(const Program* program, const DataMap& data /* = {} */) const {
|
||||
ProgramBuilder builder;
|
||||
CloneContext ctx(&builder, program);
|
||||
Output Transform::Run(const Program* src, const DataMap& data /* = {} */) const {
|
||||
Output output;
|
||||
Run(ctx, data, output.data);
|
||||
output.program = Program(std::move(builder));
|
||||
if (auto program = Apply(src, data, output.data)) {
|
||||
output.program = std::move(program.value());
|
||||
} else {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
ctx.Clone();
|
||||
output.program = Program(std::move(b));
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics())
|
||||
<< "Transform::Run() unimplemented for " << TypeInfo().name;
|
||||
}
|
||||
|
||||
bool Transform::ShouldRun(const Program*, const DataMap&) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) {
|
||||
auto* sem = ctx.src->Sem().Get(stmt);
|
||||
if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {
|
||||
|
|
|
@ -158,26 +158,30 @@ class Transform : public Castable<Transform> {
|
|||
/// Destructor
|
||||
~Transform() override;
|
||||
|
||||
/// Runs the transform on `program`, returning the transformation result.
|
||||
/// Runs the transform on @p program, returning the transformation result or a clone of
|
||||
/// @p program.
|
||||
/// @param program the source program to transform
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns the transformation result
|
||||
virtual Output Run(const Program* program, const DataMap& data = {}) const;
|
||||
Output Run(const Program* program, const DataMap& data = {}) const;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
virtual bool ShouldRun(const Program* program, const DataMap& data = {}) const;
|
||||
/// The return value of Apply().
|
||||
/// If SkipTransform (std::nullopt), then the transform is not needed to be run.
|
||||
using ApplyResult = std::optional<Program>;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// Value returned from Apply() to indicate that the transform does not need to be run
|
||||
static inline constexpr std::nullopt_t SkipTransform = std::nullopt;
|
||||
|
||||
/// Runs the transform on `program`, return.
|
||||
/// @param program the input program
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
virtual void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const;
|
||||
/// @returns a transformed program, or std::nullopt if the transform didn't need to run.
|
||||
virtual ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const = 0;
|
||||
|
||||
protected:
|
||||
/// Removes the statement `stmt` from the transformed program.
|
||||
/// RemoveStatement handles edge cases, like statements in the initializer and
|
||||
/// continuing of for-loops.
|
||||
|
|
|
@ -23,7 +23,9 @@ namespace {
|
|||
|
||||
// Inherit from Transform so we have access to protected methods
|
||||
struct CreateASTTypeForTest : public testing::Test, public Transform {
|
||||
Output Run(const Program*, const DataMap&) const override { return {}; }
|
||||
ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
const ast::Type* create(std::function<sem::Type*(ProgramBuilder&)> create_sem_type) {
|
||||
ProgramBuilder sem_type_builder;
|
||||
|
|
|
@ -28,27 +28,32 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::Unshadow);
|
|||
|
||||
namespace tint::transform {
|
||||
|
||||
/// The PIMPL state for the Unshadow transform
|
||||
/// PIMPL state for the transform
|
||||
struct Unshadow::State {
|
||||
/// The source program
|
||||
const Program* const src;
|
||||
/// The target program builder
|
||||
ProgramBuilder b;
|
||||
/// The clone context
|
||||
CloneContext& ctx;
|
||||
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
/// Constructor
|
||||
/// @param context the clone context
|
||||
explicit State(CloneContext& context) : ctx(context) {}
|
||||
/// @param program the source program
|
||||
explicit State(const Program* program) : src(program) {}
|
||||
|
||||
/// Performs the transformation
|
||||
void Run() {
|
||||
auto& sem = ctx.src->Sem();
|
||||
/// Runs the transform
|
||||
/// @returns the new program or SkipTransform if the transform is not required
|
||||
Transform::ApplyResult Run() {
|
||||
auto& sem = src->Sem();
|
||||
|
||||
// Maps a variable to its new name.
|
||||
std::unordered_map<const sem::Variable*, Symbol> renamed_to;
|
||||
utils::Hashmap<const sem::Variable*, Symbol, 8> renamed_to;
|
||||
|
||||
auto rename = [&](const sem::Variable* v) -> const ast::Variable* {
|
||||
auto* decl = v->Declaration();
|
||||
auto name = ctx.src->Symbols().NameFor(decl->symbol);
|
||||
auto symbol = ctx.dst->Symbols().New(name);
|
||||
renamed_to.emplace(v, symbol);
|
||||
auto name = src->Symbols().NameFor(decl->symbol);
|
||||
auto symbol = b.Symbols().New(name);
|
||||
renamed_to.Add(v, symbol);
|
||||
|
||||
auto source = ctx.Clone(decl->source);
|
||||
auto* type = ctx.Clone(decl->type);
|
||||
|
@ -57,20 +62,20 @@ struct Unshadow::State {
|
|||
return Switch(
|
||||
decl, //
|
||||
[&](const ast::Var* var) {
|
||||
return ctx.dst->Var(source, symbol, type, var->declared_address_space,
|
||||
var->declared_access, initializer, attributes);
|
||||
return b.Var(source, symbol, type, var->declared_address_space,
|
||||
var->declared_access, initializer, attributes);
|
||||
},
|
||||
[&](const ast::Let*) {
|
||||
return ctx.dst->Let(source, symbol, type, initializer, attributes);
|
||||
return b.Let(source, symbol, type, initializer, attributes);
|
||||
},
|
||||
[&](const ast::Const*) {
|
||||
return ctx.dst->Const(source, symbol, type, initializer, attributes);
|
||||
return b.Const(source, symbol, type, initializer, attributes);
|
||||
},
|
||||
[&](const ast::Parameter*) {
|
||||
return ctx.dst->Param(source, symbol, type, attributes);
|
||||
[&](const ast::Parameter*) { //
|
||||
return b.Param(source, symbol, type, attributes);
|
||||
},
|
||||
[&](Default) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "unexpected variable type: " << decl->TypeInfo().name;
|
||||
return nullptr;
|
||||
});
|
||||
|
@ -92,14 +97,15 @@ struct Unshadow::State {
|
|||
ctx.ReplaceAll(
|
||||
[&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* {
|
||||
if (auto* user = sem.Get<sem::VariableUser>(ident)) {
|
||||
auto it = renamed_to.find(user->Variable());
|
||||
if (it != renamed_to.end()) {
|
||||
return ctx.dst->Expr(it->second);
|
||||
if (auto* renamed = renamed_to.Find(user->Variable())) {
|
||||
return b.Expr(*renamed);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -107,8 +113,8 @@ Unshadow::Unshadow() = default;
|
|||
|
||||
Unshadow::~Unshadow() = default;
|
||||
|
||||
void Unshadow::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
State(ctx).Run();
|
||||
Transform::ApplyResult Unshadow::Apply(const Program* src, const DataMap&, DataMap&) const {
|
||||
return State(src).Run();
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -29,16 +29,13 @@ class Unshadow final : public Castable<Unshadow, Transform> {
|
|||
/// Destructor
|
||||
~Unshadow() override;
|
||||
|
||||
protected:
|
||||
struct State;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
private:
|
||||
struct State;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -35,7 +35,51 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::UnwindDiscardFunctions);
|
|||
namespace tint::transform {
|
||||
namespace {
|
||||
|
||||
class State {
|
||||
bool ShouldRun(const Program* program) {
|
||||
auto& sem = program->Sem();
|
||||
for (auto* f : program->AST().Functions()) {
|
||||
if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// PIMPL state for the transform
|
||||
struct UnwindDiscardFunctions::State {
|
||||
/// Constructor
|
||||
/// @param ctx_in the context
|
||||
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
|
||||
|
||||
/// Runs the transform
|
||||
void Run() {
|
||||
ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* {
|
||||
// Iterate block statements and replace them as needed.
|
||||
for (auto* stmt : block->statements) {
|
||||
if (auto* new_stmt = Statement(stmt)) {
|
||||
ctx.Replace(stmt, new_stmt);
|
||||
}
|
||||
|
||||
// Handle for loops, as they are the only other AST node that
|
||||
// contains statements outside of BlockStatements.
|
||||
if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
|
||||
if (auto* new_stmt = Statement(fl->initializer)) {
|
||||
ctx.Replace(fl->initializer, new_stmt);
|
||||
}
|
||||
if (auto* new_stmt = Statement(fl->continuing)) {
|
||||
// NOTE: Should never reach here as we cannot discard in a
|
||||
// continuing block.
|
||||
ctx.Replace(fl->continuing, new_stmt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
CloneContext& ctx;
|
||||
ProgramBuilder& b;
|
||||
|
@ -163,7 +207,7 @@ class State {
|
|||
// Returns true if `stmt` is a for-loop initializer statement.
|
||||
bool IsForLoopInitStatement(const ast::Statement* stmt) {
|
||||
if (auto* sem_stmt = sem.Get(stmt)) {
|
||||
if (auto* sem_fl = As<sem::ForLoopStatement>(sem_stmt->Parent())) {
|
||||
if (auto* sem_fl = tint::As<sem::ForLoopStatement>(sem_stmt->Parent())) {
|
||||
return sem_fl->Declaration()->initializer == stmt;
|
||||
}
|
||||
}
|
||||
|
@ -305,60 +349,26 @@ class State {
|
|||
return TryInsertAfter(s, sem_expr);
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
/// Constructor
|
||||
/// @param ctx_in the context
|
||||
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
|
||||
|
||||
/// Runs the transform
|
||||
void Run() {
|
||||
ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* {
|
||||
// Iterate block statements and replace them as needed.
|
||||
for (auto* stmt : block->statements) {
|
||||
if (auto* new_stmt = Statement(stmt)) {
|
||||
ctx.Replace(stmt, new_stmt);
|
||||
}
|
||||
|
||||
// Handle for loops, as they are the only other AST node that
|
||||
// contains statements outside of BlockStatements.
|
||||
if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
|
||||
if (auto* new_stmt = Statement(fl->initializer)) {
|
||||
ctx.Replace(fl->initializer, new_stmt);
|
||||
}
|
||||
if (auto* new_stmt = Statement(fl->continuing)) {
|
||||
// NOTE: Should never reach here as we cannot discard in a
|
||||
// continuing block.
|
||||
ctx.Replace(fl->continuing, new_stmt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
UnwindDiscardFunctions::UnwindDiscardFunctions() = default;
|
||||
UnwindDiscardFunctions::~UnwindDiscardFunctions() = default;
|
||||
|
||||
void UnwindDiscardFunctions::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
Transform::ApplyResult UnwindDiscardFunctions::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
State state(ctx);
|
||||
state.Run();
|
||||
}
|
||||
|
||||
bool UnwindDiscardFunctions::ShouldRun(const Program* program, const DataMap& /*data*/) const {
|
||||
auto& sem = program->Sem();
|
||||
for (auto* f : program->AST().Functions()) {
|
||||
if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -44,19 +44,13 @@ class UnwindDiscardFunctions final : public Castable<UnwindDiscardFunctions, Tra
|
|||
/// Destructor
|
||||
~UnwindDiscardFunctions() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
private:
|
||||
struct State;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -30,7 +30,59 @@
|
|||
namespace tint::transform {
|
||||
|
||||
/// Private implementation of HoistToDeclBefore transform
|
||||
class HoistToDeclBefore::State {
|
||||
struct HoistToDeclBefore::State {
|
||||
/// Constructor
|
||||
/// @param ctx_in the clone context
|
||||
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
|
||||
|
||||
/// @copydoc HoistToDeclBefore::Add()
|
||||
bool Add(const sem::Expression* before_expr,
|
||||
const ast::Expression* expr,
|
||||
bool as_let,
|
||||
const char* decl_name) {
|
||||
auto name = b.Symbols().New(decl_name);
|
||||
|
||||
if (as_let) {
|
||||
auto builder = [this, expr, name] {
|
||||
return b.Decl(b.Let(name, ctx.CloneWithoutTransform(expr)));
|
||||
};
|
||||
if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto builder = [this, expr, name] {
|
||||
return b.Decl(b.Var(name, ctx.CloneWithoutTransform(expr)));
|
||||
};
|
||||
if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the initializer expression with a reference to the let
|
||||
ctx.Replace(expr, b.Expr(name));
|
||||
return true;
|
||||
}
|
||||
|
||||
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*)
|
||||
bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) {
|
||||
if (stmt) {
|
||||
auto builder = [stmt] { return stmt; };
|
||||
return InsertBeforeImpl(before_stmt, std::move(builder));
|
||||
}
|
||||
return InsertBeforeImpl(before_stmt, Decompose{});
|
||||
}
|
||||
|
||||
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&)
|
||||
bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) {
|
||||
return InsertBeforeImpl(before_stmt, std::move(builder));
|
||||
}
|
||||
|
||||
/// @copydoc HoistToDeclBefore::Prepare()
|
||||
bool Prepare(const sem::Expression* before_expr) {
|
||||
return InsertBefore(before_expr->Stmt(), nullptr);
|
||||
}
|
||||
|
||||
private:
|
||||
CloneContext& ctx;
|
||||
ProgramBuilder& b;
|
||||
|
||||
|
@ -215,6 +267,8 @@ class HoistToDeclBefore::State {
|
|||
|
||||
template <typename BUILDER>
|
||||
bool InsertBeforeImpl(const sem::Statement* before_stmt, BUILDER&& builder) {
|
||||
(void)builder; // Avoid 'unused parameter' warning due to 'if constexpr'
|
||||
|
||||
auto* ip = before_stmt->Declaration();
|
||||
|
||||
auto* else_if = before_stmt->As<sem::IfStatement>();
|
||||
|
@ -299,58 +353,6 @@ class HoistToDeclBefore::State {
|
|||
<< "unhandled expression parent statement type: " << parent->TypeInfo().name;
|
||||
return false;
|
||||
}
|
||||
|
||||
public:
|
||||
/// Constructor
|
||||
/// @param ctx_in the clone context
|
||||
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
|
||||
|
||||
/// @copydoc HoistToDeclBefore::Add()
|
||||
bool Add(const sem::Expression* before_expr,
|
||||
const ast::Expression* expr,
|
||||
bool as_let,
|
||||
const char* decl_name) {
|
||||
auto name = b.Symbols().New(decl_name);
|
||||
|
||||
if (as_let) {
|
||||
auto builder = [this, expr, name] {
|
||||
return b.Decl(b.Let(name, ctx.CloneWithoutTransform(expr)));
|
||||
};
|
||||
if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto builder = [this, expr, name] {
|
||||
return b.Decl(b.Var(name, ctx.CloneWithoutTransform(expr)));
|
||||
};
|
||||
if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the initializer expression with a reference to the let
|
||||
ctx.Replace(expr, b.Expr(name));
|
||||
return true;
|
||||
}
|
||||
|
||||
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*)
|
||||
bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) {
|
||||
if (stmt) {
|
||||
auto builder = [stmt] { return stmt; };
|
||||
return InsertBeforeImpl(before_stmt, std::move(builder));
|
||||
}
|
||||
return InsertBeforeImpl(before_stmt, Decompose{});
|
||||
}
|
||||
|
||||
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&)
|
||||
bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) {
|
||||
return InsertBeforeImpl(before_stmt, std::move(builder));
|
||||
}
|
||||
|
||||
/// @copydoc HoistToDeclBefore::Prepare()
|
||||
bool Prepare(const sem::Expression* before_expr) {
|
||||
return InsertBefore(before_expr->Stmt(), nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx) : state_(std::make_unique<State>(ctx)) {}
|
||||
|
|
|
@ -77,7 +77,7 @@ class HoistToDeclBefore {
|
|||
bool Prepare(const sem::Expression* before_expr);
|
||||
|
||||
private:
|
||||
class State;
|
||||
struct State;
|
||||
std::unique_ptr<State> state_;
|
||||
};
|
||||
|
||||
|
|
|
@ -13,6 +13,9 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include "src/tint/transform/var_for_dynamic_index.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/program_builder.h"
|
||||
#include "src/tint/transform/utils/hoist_to_decl_before.h"
|
||||
|
||||
|
@ -22,7 +25,12 @@ VarForDynamicIndex::VarForDynamicIndex() = default;
|
|||
|
||||
VarForDynamicIndex::~VarForDynamicIndex() = default;
|
||||
|
||||
void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
HoistToDeclBefore hoist_to_decl_before(ctx);
|
||||
|
||||
// Extracts array and matrix values that are dynamically indexed to a
|
||||
|
@ -30,7 +38,7 @@ void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const
|
|||
auto dynamic_index_to_var = [&](const ast::IndexAccessorExpression* access_expr) {
|
||||
auto* index_expr = access_expr->index;
|
||||
auto* object_expr = access_expr->object;
|
||||
auto& sem = ctx.src->Sem();
|
||||
auto& sem = src->Sem();
|
||||
|
||||
if (sem.Get(index_expr)->ConstantValue()) {
|
||||
// Index expression resolves to a compile time value.
|
||||
|
@ -49,15 +57,21 @@ void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const
|
|||
return hoist_to_decl_before.Add(indexed, object_expr, false, "var_for_index");
|
||||
};
|
||||
|
||||
for (auto* node : ctx.src->ASTNodes().Objects()) {
|
||||
bool index_accessor_found = false;
|
||||
for (auto* node : src->ASTNodes().Objects()) {
|
||||
if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) {
|
||||
if (!dynamic_index_to_var(access_expr)) {
|
||||
return;
|
||||
return Program(std::move(b));
|
||||
}
|
||||
index_accessor_found = true;
|
||||
}
|
||||
}
|
||||
if (!index_accessor_found) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -31,14 +31,10 @@ class VarForDynamicIndex : public Transform {
|
|||
/// Destructor
|
||||
~VarForDynamicIndex() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -30,11 +30,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeMatrixConversions);
|
|||
|
||||
namespace tint::transform {
|
||||
|
||||
VectorizeMatrixConversions::VectorizeMatrixConversions() = default;
|
||||
namespace {
|
||||
|
||||
VectorizeMatrixConversions::~VectorizeMatrixConversions() = default;
|
||||
|
||||
bool VectorizeMatrixConversions::ShouldRun(const Program* program, const DataMap&) const {
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* sem = program->Sem().Get<sem::Expression>(node)) {
|
||||
if (auto* call = sem->UnwrapMaterialize()->As<sem::Call>()) {
|
||||
|
@ -50,14 +48,29 @@ bool VectorizeMatrixConversions::ShouldRun(const Program* program, const DataMap
|
|||
return false;
|
||||
}
|
||||
|
||||
void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
} // namespace
|
||||
|
||||
VectorizeMatrixConversions::VectorizeMatrixConversions() = default;
|
||||
|
||||
VectorizeMatrixConversions::~VectorizeMatrixConversions() = default;
|
||||
|
||||
Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
using HelperFunctionKey =
|
||||
utils::UnorderedKeyWrapper<std::tuple<const sem::Matrix*, const sem::Matrix*>>;
|
||||
|
||||
std::unordered_map<HelperFunctionKey, Symbol> matrix_convs;
|
||||
|
||||
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
|
||||
auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||
auto* ty_conv = call->Target()->As<sem::TypeConversion>();
|
||||
if (!ty_conv) {
|
||||
return nullptr;
|
||||
|
@ -72,16 +85,16 @@ void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap&
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
auto& src = args[0];
|
||||
auto& matrix = args[0];
|
||||
|
||||
auto* src_type = args[0]->Type()->UnwrapRef()->As<sem::Matrix>();
|
||||
auto* src_type = matrix->Type()->UnwrapRef()->As<sem::Matrix>();
|
||||
if (!src_type) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// The source and destination type of a matrix conversion must have a same shape.
|
||||
if (!(src_type->rows() == dst_type->rows() && src_type->columns() == dst_type->columns())) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "source and destination matrix has different shape in matrix conversion";
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -90,47 +103,45 @@ void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap&
|
|||
utils::Vector<const ast::Expression*, 4> columns;
|
||||
for (uint32_t c = 0; c < dst_type->columns(); c++) {
|
||||
auto* src_matrix_expr = src_expression_builder();
|
||||
auto* src_column_expr =
|
||||
ctx.dst->IndexAccessor(src_matrix_expr, ctx.dst->Expr(tint::AInt(c)));
|
||||
columns.Push(ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()),
|
||||
src_column_expr));
|
||||
auto* src_column_expr = b.IndexAccessor(src_matrix_expr, b.Expr(tint::AInt(c)));
|
||||
columns.Push(
|
||||
b.Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()), src_column_expr));
|
||||
}
|
||||
return ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type), columns);
|
||||
return b.Construct(CreateASTTypeFor(ctx, dst_type), columns);
|
||||
};
|
||||
|
||||
// Replace the matrix conversion to column vector conversions and a matrix construction.
|
||||
if (!src->HasSideEffects()) {
|
||||
if (!matrix->HasSideEffects()) {
|
||||
// Simply use the argument's declaration if it has no side effects.
|
||||
return build_vectorized_conversion_expression([&]() { //
|
||||
return ctx.Clone(src->Declaration());
|
||||
return ctx.Clone(matrix->Declaration());
|
||||
});
|
||||
} else {
|
||||
// If has side effects, use a helper function.
|
||||
auto fn =
|
||||
utils::GetOrCreate(matrix_convs, HelperFunctionKey{{src_type, dst_type}}, [&] {
|
||||
auto name =
|
||||
ctx.dst->Symbols().New("convert_mat" + std::to_string(src_type->columns()) +
|
||||
"x" + std::to_string(src_type->rows()) + "_" +
|
||||
ctx.dst->FriendlyName(src_type->type()) + "_" +
|
||||
ctx.dst->FriendlyName(dst_type->type()));
|
||||
ctx.dst->Func(
|
||||
name,
|
||||
utils::Vector{
|
||||
ctx.dst->Param("value", CreateASTTypeFor(ctx, src_type)),
|
||||
},
|
||||
CreateASTTypeFor(ctx, dst_type),
|
||||
utils::Vector{
|
||||
ctx.dst->Return(build_vectorized_conversion_expression([&]() { //
|
||||
return ctx.dst->Expr("value");
|
||||
})),
|
||||
});
|
||||
auto name = b.Symbols().New(
|
||||
"convert_mat" + std::to_string(src_type->columns()) + "x" +
|
||||
std::to_string(src_type->rows()) + "_" + b.FriendlyName(src_type->type()) +
|
||||
"_" + b.FriendlyName(dst_type->type()));
|
||||
b.Func(name,
|
||||
utils::Vector{
|
||||
b.Param("value", CreateASTTypeFor(ctx, src_type)),
|
||||
},
|
||||
CreateASTTypeFor(ctx, dst_type),
|
||||
utils::Vector{
|
||||
b.Return(build_vectorized_conversion_expression([&]() { //
|
||||
return b.Expr("value");
|
||||
})),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration()));
|
||||
return b.Call(fn, ctx.Clone(args[0]->Declaration()));
|
||||
}
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -28,19 +28,10 @@ class VectorizeMatrixConversions final : public Castable<VectorizeMatrixConversi
|
|||
/// Destructor
|
||||
~VectorizeMatrixConversions() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -27,12 +27,9 @@
|
|||
TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixInitializers);
|
||||
|
||||
namespace tint::transform {
|
||||
namespace {
|
||||
|
||||
VectorizeScalarMatrixInitializers::VectorizeScalarMatrixInitializers() = default;
|
||||
|
||||
VectorizeScalarMatrixInitializers::~VectorizeScalarMatrixInitializers() = default;
|
||||
|
||||
bool VectorizeScalarMatrixInitializers::ShouldRun(const Program* program, const DataMap&) const {
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (auto* call = program->Sem().Get<sem::Call>(node)) {
|
||||
if (call->Target()->Is<sem::TypeInitializer>() && call->Type()->Is<sem::Matrix>()) {
|
||||
|
@ -46,11 +43,26 @@ bool VectorizeScalarMatrixInitializers::ShouldRun(const Program* program, const
|
|||
return false;
|
||||
}
|
||||
|
||||
void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
} // namespace
|
||||
|
||||
VectorizeScalarMatrixInitializers::VectorizeScalarMatrixInitializers() = default;
|
||||
|
||||
VectorizeScalarMatrixInitializers::~VectorizeScalarMatrixInitializers() = default;
|
||||
|
||||
Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* src,
|
||||
const DataMap&,
|
||||
DataMap&) const {
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
std::unordered_map<const sem::Matrix*, Symbol> scalar_inits;
|
||||
|
||||
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
|
||||
auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
|
||||
auto* ty_init = call->Target()->As<sem::TypeInitializer>();
|
||||
if (!ty_init) {
|
||||
return nullptr;
|
||||
|
@ -87,10 +99,10 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D
|
|||
}
|
||||
|
||||
// Construct the column vector.
|
||||
columns.Push(ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(),
|
||||
std::move(row_values)));
|
||||
columns.Push(b.vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(),
|
||||
std::move(row_values)));
|
||||
}
|
||||
return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
|
||||
return b.Construct(CreateASTTypeFor(ctx, mat_type), columns);
|
||||
};
|
||||
|
||||
if (args.Length() == 1) {
|
||||
|
@ -98,23 +110,22 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D
|
|||
// This is done to ensure that the single argument value is only evaluated once, and
|
||||
// with the correct expression evaluation order.
|
||||
auto fn = utils::GetOrCreate(scalar_inits, mat_type, [&] {
|
||||
auto name =
|
||||
ctx.dst->Symbols().New("build_mat" + std::to_string(mat_type->columns()) + "x" +
|
||||
std::to_string(mat_type->rows()));
|
||||
ctx.dst->Func(name,
|
||||
utils::Vector{
|
||||
// Single scalar parameter
|
||||
ctx.dst->Param("value", CreateASTTypeFor(ctx, mat_type->type())),
|
||||
},
|
||||
CreateASTTypeFor(ctx, mat_type),
|
||||
utils::Vector{
|
||||
ctx.dst->Return(build_mat([&](uint32_t, uint32_t) { //
|
||||
return ctx.dst->Expr("value");
|
||||
})),
|
||||
});
|
||||
auto name = b.Symbols().New("build_mat" + std::to_string(mat_type->columns()) +
|
||||
"x" + std::to_string(mat_type->rows()));
|
||||
b.Func(name,
|
||||
utils::Vector{
|
||||
// Single scalar parameter
|
||||
b.Param("value", CreateASTTypeFor(ctx, mat_type->type())),
|
||||
},
|
||||
CreateASTTypeFor(ctx, mat_type),
|
||||
utils::Vector{
|
||||
b.Return(build_mat([&](uint32_t, uint32_t) { //
|
||||
return b.Expr("value");
|
||||
})),
|
||||
});
|
||||
return name;
|
||||
});
|
||||
return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration()));
|
||||
return b.Call(fn, ctx.Clone(args[0]->Declaration()));
|
||||
}
|
||||
|
||||
if (args.Length() == mat_type->columns() * mat_type->rows()) {
|
||||
|
@ -123,12 +134,13 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D
|
|||
});
|
||||
}
|
||||
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_ICE(Transform, b.Diagnostics())
|
||||
<< "matrix initializer has unexpected number of arguments";
|
||||
return nullptr;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -29,19 +29,10 @@ class VectorizeScalarMatrixInitializers final
|
|||
/// Destructor
|
||||
~VectorizeScalarMatrixInitializers() override;
|
||||
|
||||
/// @param program the program to inspect
|
||||
/// @param data optional extra transform-specific input data
|
||||
/// @returns true if this transform should be run for the given program
|
||||
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
};
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
|
@ -201,13 +201,46 @@ DataType DataTypeOf(VertexFormat format) {
|
|||
return {BaseType::kInvalid, 0};
|
||||
}
|
||||
|
||||
struct State {
|
||||
State(CloneContext& context, const VertexPulling::Config& c) : ctx(context), cfg(c) {}
|
||||
State(const State&) = default;
|
||||
~State() = default;
|
||||
} // namespace
|
||||
|
||||
/// LocationReplacement describes an ast::Variable replacement for a
|
||||
/// location input.
|
||||
/// PIMPL state for the transform
|
||||
struct VertexPulling::State {
|
||||
/// Constructor
|
||||
/// @param program the source program
|
||||
/// @param c the VertexPulling config
|
||||
State(const Program* program, const VertexPulling::Config& c) : src(program), cfg(c) {}
|
||||
|
||||
/// Runs the transform
|
||||
/// @returns the new program or SkipTransform if the transform is not required
|
||||
ApplyResult Run() {
|
||||
// Find entry point
|
||||
const ast::Function* func = nullptr;
|
||||
for (auto* fn : src->AST().Functions()) {
|
||||
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
|
||||
if (func != nullptr) {
|
||||
b.Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"VertexPulling found more than one vertex entry point");
|
||||
return Program(std::move(b));
|
||||
}
|
||||
func = fn;
|
||||
}
|
||||
}
|
||||
if (func == nullptr) {
|
||||
b.Diagnostics().add_error(diag::System::Transform,
|
||||
"Vertex stage entry point not found");
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
AddVertexStorageBuffers();
|
||||
Process(func);
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
private:
|
||||
/// LocationReplacement describes an ast::Variable replacement for a location input.
|
||||
struct LocationReplacement {
|
||||
/// The variable to replace in the source Program
|
||||
ast::Variable* from;
|
||||
|
@ -215,13 +248,22 @@ struct State {
|
|||
ast::Variable* to;
|
||||
};
|
||||
|
||||
/// LocationInfo describes an input location
|
||||
struct LocationInfo {
|
||||
/// A builder that builds the expression that resolves to the (transformed) input location
|
||||
std::function<const ast::Expression*()> expr;
|
||||
/// The store type of the location variable
|
||||
const sem::Type* type;
|
||||
};
|
||||
|
||||
CloneContext& ctx;
|
||||
/// The source program
|
||||
const Program* const src;
|
||||
/// The transform config
|
||||
VertexPulling::Config const cfg;
|
||||
/// The target program builder
|
||||
ProgramBuilder b;
|
||||
/// The clone context
|
||||
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
|
||||
std::unordered_map<uint32_t, LocationInfo> location_info;
|
||||
std::function<const ast::Expression*()> vertex_index_expr = nullptr;
|
||||
std::function<const ast::Expression*()> instance_index_expr = nullptr;
|
||||
|
@ -235,7 +277,7 @@ struct State {
|
|||
Symbol GetVertexBufferName(uint32_t index) {
|
||||
return utils::GetOrCreate(vertex_buffer_names, index, [&] {
|
||||
static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_";
|
||||
return ctx.dst->Symbols().New(kVertexBufferNamePrefix + std::to_string(index));
|
||||
return b.Symbols().New(kVertexBufferNamePrefix + std::to_string(index));
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -243,7 +285,7 @@ struct State {
|
|||
Symbol GetStructBufferName() {
|
||||
if (!struct_buffer_name.IsValid()) {
|
||||
static const char kStructBufferName[] = "tint_vertex_data";
|
||||
struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName);
|
||||
struct_buffer_name = b.Symbols().New(kStructBufferName);
|
||||
}
|
||||
return struct_buffer_name;
|
||||
}
|
||||
|
@ -252,21 +294,19 @@ struct State {
|
|||
void AddVertexStorageBuffers() {
|
||||
// Creating the struct type
|
||||
static const char kStructName[] = "TintVertexData";
|
||||
auto* struct_type =
|
||||
ctx.dst->Structure(ctx.dst->Symbols().New(kStructName),
|
||||
utils::Vector{
|
||||
ctx.dst->Member(GetStructBufferName(), ctx.dst->ty.array<u32>()),
|
||||
});
|
||||
auto* struct_type = b.Structure(b.Symbols().New(kStructName),
|
||||
utils::Vector{
|
||||
b.Member(GetStructBufferName(), b.ty.array<u32>()),
|
||||
});
|
||||
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
|
||||
// The decorated variable with struct type
|
||||
ctx.dst->GlobalVar(GetVertexBufferName(i), ctx.dst->ty.Of(struct_type),
|
||||
ast::AddressSpace::kStorage, ast::Access::kRead,
|
||||
ctx.dst->Binding(AInt(i)), ctx.dst->Group(AInt(cfg.pulling_group)));
|
||||
b.GlobalVar(GetVertexBufferName(i), b.ty.Of(struct_type), ast::AddressSpace::kStorage,
|
||||
ast::Access::kRead, b.Binding(AInt(i)), b.Group(AInt(cfg.pulling_group)));
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates and returns the assignment to the variables from the buffers
|
||||
ast::BlockStatement* CreateVertexPullingPreamble() {
|
||||
const ast::BlockStatement* CreateVertexPullingPreamble() {
|
||||
// Assign by looking at the vertex descriptor to find attributes with
|
||||
// matching location.
|
||||
|
||||
|
@ -276,7 +316,7 @@ struct State {
|
|||
const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx];
|
||||
|
||||
if ((buffer_layout.array_stride & 3) != 0) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
b.Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"WebGPU requires that vertex stride must be a multiple of 4 bytes, "
|
||||
"but VertexPulling array stride for buffer " +
|
||||
|
@ -292,15 +332,15 @@ struct State {
|
|||
// buffer_array_base is the base array offset for all the vertex
|
||||
// attributes. These are units of uint (4 bytes).
|
||||
auto buffer_array_base =
|
||||
ctx.dst->Symbols().New("buffer_array_base_" + std::to_string(buffer_idx));
|
||||
b.Symbols().New("buffer_array_base_" + std::to_string(buffer_idx));
|
||||
|
||||
auto* attribute_offset = index_expr;
|
||||
if (buffer_layout.array_stride != 4) {
|
||||
attribute_offset = ctx.dst->Mul(index_expr, u32(buffer_layout.array_stride / 4u));
|
||||
attribute_offset = b.Mul(index_expr, u32(buffer_layout.array_stride / 4u));
|
||||
}
|
||||
|
||||
// let pulling_offset_n = <attribute_offset>
|
||||
stmts.Push(ctx.dst->Decl(ctx.dst->Let(buffer_array_base, attribute_offset)));
|
||||
stmts.Push(b.Decl(b.Let(buffer_array_base, attribute_offset)));
|
||||
|
||||
for (const VertexAttributeDescriptor& attribute_desc : buffer_layout.attributes) {
|
||||
auto it = location_info.find(attribute_desc.shader_location);
|
||||
|
@ -320,8 +360,8 @@ struct State {
|
|||
err << "VertexAttributeDescriptor for location "
|
||||
<< std::to_string(attribute_desc.shader_location) << " has format "
|
||||
<< attribute_desc.format << " but shader expects "
|
||||
<< var.type->FriendlyName(ctx.src->Symbols());
|
||||
ctx.dst->Diagnostics().add_error(diag::System::Transform, err.str());
|
||||
<< var.type->FriendlyName(src->Symbols());
|
||||
b.Diagnostics().add_error(diag::System::Transform, err.str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -337,16 +377,16 @@ struct State {
|
|||
// WGSL variable vector width is smaller than the loaded vector width
|
||||
switch (var_dt.width) {
|
||||
case 1:
|
||||
value = ctx.dst->MemberAccessor(fetch, "x");
|
||||
value = b.MemberAccessor(fetch, "x");
|
||||
break;
|
||||
case 2:
|
||||
value = ctx.dst->MemberAccessor(fetch, "xy");
|
||||
value = b.MemberAccessor(fetch, "xy");
|
||||
break;
|
||||
case 3:
|
||||
value = ctx.dst->MemberAccessor(fetch, "xyz");
|
||||
value = b.MemberAccessor(fetch, "xyz");
|
||||
break;
|
||||
default:
|
||||
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.width;
|
||||
TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.width;
|
||||
return nullptr;
|
||||
}
|
||||
} else if (var_dt.width > fmt_dt.width) {
|
||||
|
@ -355,32 +395,32 @@ struct State {
|
|||
utils::Vector<const ast::Expression*, 8> values{fetch};
|
||||
switch (var_dt.base_type) {
|
||||
case BaseType::kI32:
|
||||
ty = ctx.dst->ty.i32();
|
||||
ty = b.ty.i32();
|
||||
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
|
||||
values.Push(ctx.dst->Expr((i == 3) ? 1_i : 0_i));
|
||||
values.Push(b.Expr((i == 3) ? 1_i : 0_i));
|
||||
}
|
||||
break;
|
||||
case BaseType::kU32:
|
||||
ty = ctx.dst->ty.u32();
|
||||
ty = b.ty.u32();
|
||||
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
|
||||
values.Push(ctx.dst->Expr((i == 3) ? 1_u : 0_u));
|
||||
values.Push(b.Expr((i == 3) ? 1_u : 0_u));
|
||||
}
|
||||
break;
|
||||
case BaseType::kF32:
|
||||
ty = ctx.dst->ty.f32();
|
||||
ty = b.ty.f32();
|
||||
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
|
||||
values.Push(ctx.dst->Expr((i == 3) ? 1_f : 0_f));
|
||||
values.Push(b.Expr((i == 3) ? 1_f : 0_f));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.base_type;
|
||||
TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.base_type;
|
||||
return nullptr;
|
||||
}
|
||||
value = ctx.dst->Construct(ctx.dst->ty.vec(ty, var_dt.width), values);
|
||||
value = b.Construct(b.ty.vec(ty, var_dt.width), values);
|
||||
}
|
||||
|
||||
// Assign the value to the WGSL variable
|
||||
stmts.Push(ctx.dst->Assign(var.expr(), value));
|
||||
stmts.Push(b.Assign(var.expr(), value));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -388,7 +428,7 @@ struct State {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
return ctx.dst->create<ast::BlockStatement>(std::move(stmts));
|
||||
return b.Block(std::move(stmts));
|
||||
}
|
||||
|
||||
/// Generates an expression reading from a buffer a specific format.
|
||||
|
@ -407,7 +447,7 @@ struct State {
|
|||
};
|
||||
|
||||
// Returns a i32 loaded from buffer_base + offset.
|
||||
auto load_i32 = [&] { return ctx.dst->Bitcast<i32>(load_u32()); };
|
||||
auto load_i32 = [&] { return b.Bitcast<i32>(load_u32()); };
|
||||
|
||||
// Returns a u32 loaded from buffer_base + offset + 4.
|
||||
auto load_next_u32 = [&] {
|
||||
|
@ -415,7 +455,7 @@ struct State {
|
|||
};
|
||||
|
||||
// Returns a i32 loaded from buffer_base + offset + 4.
|
||||
auto load_next_i32 = [&] { return ctx.dst->Bitcast<i32>(load_next_u32()); };
|
||||
auto load_next_i32 = [&] { return b.Bitcast<i32>(load_next_u32()); };
|
||||
|
||||
// Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
|
||||
// The low 16 bits are 0.
|
||||
|
@ -427,17 +467,17 @@ struct State {
|
|||
LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32);
|
||||
switch (offset & 3) {
|
||||
case 0:
|
||||
return ctx.dst->Shl(low_u32, 16_u);
|
||||
return b.Shl(low_u32, 16_u);
|
||||
case 1:
|
||||
return ctx.dst->And(ctx.dst->Shl(low_u32, 8_u), 0xffff0000_u);
|
||||
return b.And(b.Shl(low_u32, 8_u), 0xffff0000_u);
|
||||
case 2:
|
||||
return ctx.dst->And(low_u32, 0xffff0000_u);
|
||||
return b.And(low_u32, 0xffff0000_u);
|
||||
default: { // 3:
|
||||
auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
|
||||
VertexFormat::kUint32);
|
||||
auto* shr = ctx.dst->Shr(low_u32, 8_u);
|
||||
auto* shl = ctx.dst->Shl(high_u32, 24_u);
|
||||
return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff0000_u);
|
||||
auto* shr = b.Shr(low_u32, 8_u);
|
||||
auto* shl = b.Shl(high_u32, 24_u);
|
||||
return b.And(b.Or(shl, shr), 0xffff0000_u);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -450,24 +490,24 @@ struct State {
|
|||
LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32);
|
||||
switch (offset & 3) {
|
||||
case 0:
|
||||
return ctx.dst->And(low_u32, 0xffff_u);
|
||||
return b.And(low_u32, 0xffff_u);
|
||||
case 1:
|
||||
return ctx.dst->And(ctx.dst->Shr(low_u32, 8_u), 0xffff_u);
|
||||
return b.And(b.Shr(low_u32, 8_u), 0xffff_u);
|
||||
case 2:
|
||||
return ctx.dst->Shr(low_u32, 16_u);
|
||||
return b.Shr(low_u32, 16_u);
|
||||
default: { // 3:
|
||||
auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
|
||||
VertexFormat::kUint32);
|
||||
auto* shr = ctx.dst->Shr(low_u32, 24_u);
|
||||
auto* shl = ctx.dst->Shl(high_u32, 8_u);
|
||||
return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff_u);
|
||||
auto* shr = b.Shr(low_u32, 24_u);
|
||||
auto* shl = b.Shl(high_u32, 8_u);
|
||||
return b.And(b.Or(shl, shr), 0xffff_u);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
|
||||
// The low 16 bits are 0.
|
||||
auto load_i16_h = [&] { return ctx.dst->Bitcast<i32>(load_u16_h()); };
|
||||
auto load_i16_h = [&] { return b.Bitcast<i32>(load_u16_h()); };
|
||||
|
||||
// Assumptions are made that alignment must be at least as large as the size
|
||||
// of a single component.
|
||||
|
@ -480,128 +520,121 @@ struct State {
|
|||
|
||||
// Vectors of basic primitives
|
||||
case VertexFormat::kUint32x2:
|
||||
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
|
||||
VertexFormat::kUint32, 2);
|
||||
return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 2);
|
||||
case VertexFormat::kUint32x3:
|
||||
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
|
||||
VertexFormat::kUint32, 3);
|
||||
return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 3);
|
||||
case VertexFormat::kUint32x4:
|
||||
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
|
||||
VertexFormat::kUint32, 4);
|
||||
return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 4);
|
||||
case VertexFormat::kSint32x2:
|
||||
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
|
||||
VertexFormat::kSint32, 2);
|
||||
return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 2);
|
||||
case VertexFormat::kSint32x3:
|
||||
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
|
||||
VertexFormat::kSint32, 3);
|
||||
return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 3);
|
||||
case VertexFormat::kSint32x4:
|
||||
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
|
||||
VertexFormat::kSint32, 4);
|
||||
return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 4);
|
||||
case VertexFormat::kFloat32x2:
|
||||
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
|
||||
VertexFormat::kFloat32, 2);
|
||||
return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
|
||||
2);
|
||||
case VertexFormat::kFloat32x3:
|
||||
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
|
||||
VertexFormat::kFloat32, 3);
|
||||
return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
|
||||
3);
|
||||
case VertexFormat::kFloat32x4:
|
||||
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
|
||||
VertexFormat::kFloat32, 4);
|
||||
return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
|
||||
4);
|
||||
|
||||
case VertexFormat::kUint8x2: {
|
||||
// yyxx0000, yyxx0000
|
||||
auto* u16s = ctx.dst->vec2<u32>(load_u16_h());
|
||||
auto* u16s = b.vec2<u32>(load_u16_h());
|
||||
// xx000000, yyxx0000
|
||||
auto* shl = ctx.dst->Shl(u16s, ctx.dst->vec2<u32>(8_u, 0_u));
|
||||
auto* shl = b.Shl(u16s, b.vec2<u32>(8_u, 0_u));
|
||||
// 000000xx, 000000yy
|
||||
return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24_u));
|
||||
return b.Shr(shl, b.vec2<u32>(24_u));
|
||||
}
|
||||
case VertexFormat::kUint8x4: {
|
||||
// wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
|
||||
auto* u32s = ctx.dst->vec4<u32>(load_u32());
|
||||
auto* u32s = b.vec4<u32>(load_u32());
|
||||
// xx000000, yyxx0000, zzyyxx00, wwzzyyxx
|
||||
auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec4<u32>(24_u, 16_u, 8_u, 0_u));
|
||||
auto* shl = b.Shl(u32s, b.vec4<u32>(24_u, 16_u, 8_u, 0_u));
|
||||
// 000000xx, 000000yy, 000000zz, 000000ww
|
||||
return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24_u));
|
||||
return b.Shr(shl, b.vec4<u32>(24_u));
|
||||
}
|
||||
case VertexFormat::kUint16x2: {
|
||||
// yyyyxxxx, yyyyxxxx
|
||||
auto* u32s = ctx.dst->vec2<u32>(load_u32());
|
||||
auto* u32s = b.vec2<u32>(load_u32());
|
||||
// xxxx0000, yyyyxxxx
|
||||
auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec2<u32>(16_u, 0_u));
|
||||
auto* shl = b.Shl(u32s, b.vec2<u32>(16_u, 0_u));
|
||||
// 0000xxxx, 0000yyyy
|
||||
return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16_u));
|
||||
return b.Shr(shl, b.vec2<u32>(16_u));
|
||||
}
|
||||
case VertexFormat::kUint16x4: {
|
||||
// yyyyxxxx, wwwwzzzz
|
||||
auto* u32s = ctx.dst->vec2<u32>(load_u32(), load_next_u32());
|
||||
auto* u32s = b.vec2<u32>(load_u32(), load_next_u32());
|
||||
// yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
|
||||
auto* xxyy = ctx.dst->MemberAccessor(u32s, "xxyy");
|
||||
auto* xxyy = b.MemberAccessor(u32s, "xxyy");
|
||||
// xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
|
||||
auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16_u, 0_u, 16_u, 0_u));
|
||||
auto* shl = b.Shl(xxyy, b.vec4<u32>(16_u, 0_u, 16_u, 0_u));
|
||||
// 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
|
||||
return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16_u));
|
||||
return b.Shr(shl, b.vec4<u32>(16_u));
|
||||
}
|
||||
case VertexFormat::kSint8x2: {
|
||||
// yyxx0000, yyxx0000
|
||||
auto* i16s = ctx.dst->vec2<i32>(load_i16_h());
|
||||
auto* i16s = b.vec2<i32>(load_i16_h());
|
||||
// xx000000, yyxx0000
|
||||
auto* shl = ctx.dst->Shl(i16s, ctx.dst->vec2<u32>(8_u, 0_u));
|
||||
auto* shl = b.Shl(i16s, b.vec2<u32>(8_u, 0_u));
|
||||
// ssssssxx, ssssssyy
|
||||
return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24_u));
|
||||
return b.Shr(shl, b.vec2<u32>(24_u));
|
||||
}
|
||||
case VertexFormat::kSint8x4: {
|
||||
// wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
|
||||
auto* i32s = ctx.dst->vec4<i32>(load_i32());
|
||||
auto* i32s = b.vec4<i32>(load_i32());
|
||||
// xx000000, yyxx0000, zzyyxx00, wwzzyyxx
|
||||
auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec4<u32>(24_u, 16_u, 8_u, 0_u));
|
||||
auto* shl = b.Shl(i32s, b.vec4<u32>(24_u, 16_u, 8_u, 0_u));
|
||||
// ssssssxx, ssssssyy, sssssszz, ssssssww
|
||||
return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24_u));
|
||||
return b.Shr(shl, b.vec4<u32>(24_u));
|
||||
}
|
||||
case VertexFormat::kSint16x2: {
|
||||
// yyyyxxxx, yyyyxxxx
|
||||
auto* i32s = ctx.dst->vec2<i32>(load_i32());
|
||||
auto* i32s = b.vec2<i32>(load_i32());
|
||||
// xxxx0000, yyyyxxxx
|
||||
auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec2<u32>(16_u, 0_u));
|
||||
auto* shl = b.Shl(i32s, b.vec2<u32>(16_u, 0_u));
|
||||
// ssssxxxx, ssssyyyy
|
||||
return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16_u));
|
||||
return b.Shr(shl, b.vec2<u32>(16_u));
|
||||
}
|
||||
case VertexFormat::kSint16x4: {
|
||||
// yyyyxxxx, wwwwzzzz
|
||||
auto* i32s = ctx.dst->vec2<i32>(load_i32(), load_next_i32());
|
||||
auto* i32s = b.vec2<i32>(load_i32(), load_next_i32());
|
||||
// yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
|
||||
auto* xxyy = ctx.dst->MemberAccessor(i32s, "xxyy");
|
||||
auto* xxyy = b.MemberAccessor(i32s, "xxyy");
|
||||
// xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
|
||||
auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16_u, 0_u, 16_u, 0_u));
|
||||
auto* shl = b.Shl(xxyy, b.vec4<u32>(16_u, 0_u, 16_u, 0_u));
|
||||
// ssssxxxx, ssssyyyy, sssszzzz, sssswwww
|
||||
return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16_u));
|
||||
return b.Shr(shl, b.vec4<u32>(16_u));
|
||||
}
|
||||
case VertexFormat::kUnorm8x2:
|
||||
return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8unorm", load_u16_l()), "xy");
|
||||
return b.MemberAccessor(b.Call("unpack4x8unorm", load_u16_l()), "xy");
|
||||
case VertexFormat::kSnorm8x2:
|
||||
return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8snorm", load_u16_l()), "xy");
|
||||
return b.MemberAccessor(b.Call("unpack4x8snorm", load_u16_l()), "xy");
|
||||
case VertexFormat::kUnorm8x4:
|
||||
return ctx.dst->Call("unpack4x8unorm", load_u32());
|
||||
return b.Call("unpack4x8unorm", load_u32());
|
||||
case VertexFormat::kSnorm8x4:
|
||||
return ctx.dst->Call("unpack4x8snorm", load_u32());
|
||||
return b.Call("unpack4x8snorm", load_u32());
|
||||
case VertexFormat::kUnorm16x2:
|
||||
return ctx.dst->Call("unpack2x16unorm", load_u32());
|
||||
return b.Call("unpack2x16unorm", load_u32());
|
||||
case VertexFormat::kSnorm16x2:
|
||||
return ctx.dst->Call("unpack2x16snorm", load_u32());
|
||||
return b.Call("unpack2x16snorm", load_u32());
|
||||
case VertexFormat::kFloat16x2:
|
||||
return ctx.dst->Call("unpack2x16float", load_u32());
|
||||
return b.Call("unpack2x16float", load_u32());
|
||||
case VertexFormat::kUnorm16x4:
|
||||
return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16unorm", load_u32()),
|
||||
ctx.dst->Call("unpack2x16unorm", load_next_u32()));
|
||||
return b.vec4<f32>(b.Call("unpack2x16unorm", load_u32()),
|
||||
b.Call("unpack2x16unorm", load_next_u32()));
|
||||
case VertexFormat::kSnorm16x4:
|
||||
return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16snorm", load_u32()),
|
||||
ctx.dst->Call("unpack2x16snorm", load_next_u32()));
|
||||
return b.vec4<f32>(b.Call("unpack2x16snorm", load_u32()),
|
||||
b.Call("unpack2x16snorm", load_next_u32()));
|
||||
case VertexFormat::kFloat16x4:
|
||||
return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16float", load_u32()),
|
||||
ctx.dst->Call("unpack2x16float", load_next_u32()));
|
||||
return b.vec4<f32>(b.Call("unpack2x16float", load_u32()),
|
||||
b.Call("unpack2x16float", load_next_u32()));
|
||||
}
|
||||
|
||||
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
|
||||
<< "format " << static_cast<int>(format);
|
||||
TINT_UNREACHABLE(Transform, b.Diagnostics()) << "format " << static_cast<int>(format);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
@ -623,12 +656,12 @@ struct State {
|
|||
|
||||
const ast ::Expression* index = nullptr;
|
||||
if (offset > 0) {
|
||||
index = ctx.dst->Add(array_base, u32(offset / 4));
|
||||
index = b.Add(array_base, u32(offset / 4));
|
||||
} else {
|
||||
index = ctx.dst->Expr(array_base);
|
||||
index = b.Expr(array_base);
|
||||
}
|
||||
u = ctx.dst->IndexAccessor(
|
||||
ctx.dst->MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index);
|
||||
u = b.IndexAccessor(
|
||||
b.MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index);
|
||||
|
||||
} else {
|
||||
// Unaligned load
|
||||
|
@ -639,22 +672,22 @@ struct State {
|
|||
|
||||
uint32_t shift = 8u * (offset & 3u);
|
||||
|
||||
auto* low_shr = ctx.dst->Shr(low, u32(shift));
|
||||
auto* high_shl = ctx.dst->Shl(high, u32(32u - shift));
|
||||
u = ctx.dst->Or(low_shr, high_shl);
|
||||
auto* low_shr = b.Shr(low, u32(shift));
|
||||
auto* high_shl = b.Shl(high, u32(32u - shift));
|
||||
u = b.Or(low_shr, high_shl);
|
||||
}
|
||||
|
||||
switch (format) {
|
||||
case VertexFormat::kUint32:
|
||||
return u;
|
||||
case VertexFormat::kSint32:
|
||||
return ctx.dst->Bitcast(ctx.dst->ty.i32(), u);
|
||||
return b.Bitcast(b.ty.i32(), u);
|
||||
case VertexFormat::kFloat32:
|
||||
return ctx.dst->Bitcast(ctx.dst->ty.f32(), u);
|
||||
return b.Bitcast(b.ty.f32(), u);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
|
||||
TINT_UNREACHABLE(Transform, b.Diagnostics())
|
||||
<< "invalid format for LoadPrimitive" << static_cast<int>(format);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -682,8 +715,7 @@ struct State {
|
|||
expr_list.Push(LoadPrimitive(array_base, primitive_offset, buffer, base_format));
|
||||
}
|
||||
|
||||
return ctx.dst->Construct(ctx.dst->create<ast::Vector>(base_type, count),
|
||||
std::move(expr_list));
|
||||
return b.Construct(b.create<ast::Vector>(base_type, count), std::move(expr_list));
|
||||
}
|
||||
|
||||
/// Process a non-struct entry point parameter.
|
||||
|
@ -696,34 +728,30 @@ struct State {
|
|||
// Create a function-scope variable to replace the parameter.
|
||||
auto func_var_sym = ctx.Clone(param->symbol);
|
||||
auto* func_var_type = ctx.Clone(param->type);
|
||||
auto* func_var = ctx.dst->Var(func_var_sym, func_var_type);
|
||||
ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
|
||||
auto* func_var = b.Var(func_var_sym, func_var_type);
|
||||
ctx.InsertFront(func->body->statements, b.Decl(func_var));
|
||||
// Capture mapping from location to the new variable.
|
||||
LocationInfo info;
|
||||
info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); };
|
||||
info.expr = [this, func_var]() { return b.Expr(func_var); };
|
||||
|
||||
auto* sem = ctx.src->Sem().Get<sem::Parameter>(param);
|
||||
auto* sem = src->Sem().Get<sem::Parameter>(param);
|
||||
info.type = sem->Type();
|
||||
|
||||
if (!sem->Location().has_value()) {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Location missing value";
|
||||
TINT_ICE(Transform, b.Diagnostics()) << "Location missing value";
|
||||
return;
|
||||
}
|
||||
location_info[sem->Location().value()] = info;
|
||||
} else if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) {
|
||||
// Check for existing vertex_index and instance_index builtins.
|
||||
if (builtin->builtin == ast::BuiltinValue::kVertexIndex) {
|
||||
vertex_index_expr = [this, param]() {
|
||||
return ctx.dst->Expr(ctx.Clone(param->symbol));
|
||||
};
|
||||
vertex_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); };
|
||||
} else if (builtin->builtin == ast::BuiltinValue::kInstanceIndex) {
|
||||
instance_index_expr = [this, param]() {
|
||||
return ctx.dst->Expr(ctx.Clone(param->symbol));
|
||||
};
|
||||
instance_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); };
|
||||
}
|
||||
new_function_parameters.Push(ctx.Clone(param));
|
||||
} else {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter";
|
||||
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -746,7 +774,7 @@ struct State {
|
|||
for (auto* member : struct_ty->members) {
|
||||
auto member_sym = ctx.Clone(member->symbol);
|
||||
std::function<const ast::Expression*()> member_expr = [this, param_sym, member_sym]() {
|
||||
return ctx.dst->MemberAccessor(param_sym, member_sym);
|
||||
return b.MemberAccessor(param_sym, member_sym);
|
||||
};
|
||||
|
||||
if (ast::HasAttribute<ast::LocationAttribute>(member->attributes)) {
|
||||
|
@ -754,7 +782,7 @@ struct State {
|
|||
LocationInfo info;
|
||||
info.expr = member_expr;
|
||||
|
||||
auto* sem = ctx.src->Sem().Get(member);
|
||||
auto* sem = src->Sem().Get(member);
|
||||
info.type = sem->Type();
|
||||
|
||||
TINT_ASSERT(Transform, sem->Location().has_value());
|
||||
|
@ -770,7 +798,7 @@ struct State {
|
|||
}
|
||||
members_to_clone.Push(member);
|
||||
} else {
|
||||
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter";
|
||||
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -781,8 +809,8 @@ struct State {
|
|||
}
|
||||
|
||||
// Create a function-scope variable to replace the parameter.
|
||||
auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type));
|
||||
ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
|
||||
auto* func_var = b.Var(param_sym, ctx.Clone(param->type));
|
||||
ctx.InsertFront(func->body->statements, b.Decl(func_var));
|
||||
|
||||
if (!members_to_clone.IsEmpty()) {
|
||||
// Create a new struct without the location attributes.
|
||||
|
@ -791,20 +819,20 @@ struct State {
|
|||
auto member_sym = ctx.Clone(member->symbol);
|
||||
auto* member_type = ctx.Clone(member->type);
|
||||
auto member_attrs = ctx.Clone(member->attributes);
|
||||
new_members.Push(ctx.dst->Member(member_sym, member_type, std::move(member_attrs)));
|
||||
new_members.Push(b.Member(member_sym, member_type, std::move(member_attrs)));
|
||||
}
|
||||
auto* new_struct = ctx.dst->Structure(ctx.dst->Sym(), new_members);
|
||||
auto* new_struct = b.Structure(b.Sym(), new_members);
|
||||
|
||||
// Create a new function parameter with this struct.
|
||||
auto* new_param = ctx.dst->Param(ctx.dst->Sym(), ctx.dst->ty.Of(new_struct));
|
||||
auto* new_param = b.Param(b.Sym(), b.ty.Of(new_struct));
|
||||
new_function_parameters.Push(new_param);
|
||||
|
||||
// Copy values from the new parameter to the function-scope variable.
|
||||
for (auto* member : members_to_clone) {
|
||||
auto member_name = ctx.Clone(member->symbol);
|
||||
ctx.InsertFront(func->body->statements,
|
||||
ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name),
|
||||
ctx.dst->MemberAccessor(new_param, member_name)));
|
||||
b.Assign(b.MemberAccessor(func_var, member_name),
|
||||
b.MemberAccessor(new_param, member_name)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -818,7 +846,7 @@ struct State {
|
|||
|
||||
// Process entry point parameters.
|
||||
for (auto* param : func->params) {
|
||||
auto* sem = ctx.src->Sem().Get(param);
|
||||
auto* sem = src->Sem().Get(param);
|
||||
if (auto* str = sem->Type()->As<sem::Struct>()) {
|
||||
ProcessStructParameter(func, param, str->Declaration());
|
||||
} else {
|
||||
|
@ -830,11 +858,11 @@ struct State {
|
|||
if (!vertex_index_expr) {
|
||||
for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
|
||||
if (layout.step_mode == VertexStepMode::kVertex) {
|
||||
auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index");
|
||||
new_function_parameters.Push(ctx.dst->Param(
|
||||
name, ctx.dst->ty.u32(),
|
||||
utils::Vector{ctx.dst->Builtin(ast::BuiltinValue::kVertexIndex)}));
|
||||
vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); };
|
||||
auto name = b.Symbols().New("tint_pulling_vertex_index");
|
||||
new_function_parameters.Push(
|
||||
b.Param(name, b.ty.u32(),
|
||||
utils::Vector{b.Builtin(ast::BuiltinValue::kVertexIndex)}));
|
||||
vertex_index_expr = [this, name]() { return b.Expr(name); };
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -842,11 +870,11 @@ struct State {
|
|||
if (!instance_index_expr) {
|
||||
for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
|
||||
if (layout.step_mode == VertexStepMode::kInstance) {
|
||||
auto name = ctx.dst->Symbols().New("tint_pulling_instance_index");
|
||||
new_function_parameters.Push(ctx.dst->Param(
|
||||
name, ctx.dst->ty.u32(),
|
||||
utils::Vector{ctx.dst->Builtin(ast::BuiltinValue::kInstanceIndex)}));
|
||||
instance_index_expr = [this, name]() { return ctx.dst->Expr(name); };
|
||||
auto name = b.Symbols().New("tint_pulling_instance_index");
|
||||
new_function_parameters.Push(
|
||||
b.Param(name, b.ty.u32(),
|
||||
utils::Vector{b.Builtin(ast::BuiltinValue::kInstanceIndex)}));
|
||||
instance_index_expr = [this, name]() { return b.Expr(name); };
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -864,53 +892,24 @@ struct State {
|
|||
auto attrs = ctx.Clone(func->attributes);
|
||||
auto ret_attrs = ctx.Clone(func->return_type_attributes);
|
||||
auto* new_func =
|
||||
ctx.dst->create<ast::Function>(func->source, func_sym, new_function_parameters,
|
||||
ret_type, body, std::move(attrs), std::move(ret_attrs));
|
||||
b.create<ast::Function>(func->source, func_sym, new_function_parameters, ret_type, body,
|
||||
std::move(attrs), std::move(ret_attrs));
|
||||
ctx.Replace(func, new_func);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
VertexPulling::VertexPulling() = default;
|
||||
VertexPulling::~VertexPulling() = default;
|
||||
|
||||
void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
|
||||
Transform::ApplyResult VertexPulling::Apply(const Program* src,
|
||||
const DataMap& inputs,
|
||||
DataMap&) const {
|
||||
auto cfg = cfg_;
|
||||
if (auto* cfg_data = inputs.Get<Config>()) {
|
||||
cfg = *cfg_data;
|
||||
}
|
||||
|
||||
// Find entry point
|
||||
const ast::Function* func = nullptr;
|
||||
for (auto* fn : ctx.src->AST().Functions()) {
|
||||
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
|
||||
if (func != nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(
|
||||
diag::System::Transform,
|
||||
"VertexPulling found more than one vertex entry point");
|
||||
return;
|
||||
}
|
||||
func = fn;
|
||||
}
|
||||
}
|
||||
if (func == nullptr) {
|
||||
ctx.dst->Diagnostics().add_error(diag::System::Transform,
|
||||
"Vertex stage entry point not found");
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(idanr): Need to check shader locations in descriptor cover all
|
||||
// attributes
|
||||
|
||||
// TODO(idanr): Make sure we covered all error cases, to guarantee the
|
||||
// following stages will pass
|
||||
|
||||
State state{ctx, cfg};
|
||||
state.AddVertexStorageBuffers();
|
||||
state.Process(func);
|
||||
|
||||
ctx.Clone();
|
||||
return State{src, cfg}.Run();
|
||||
}
|
||||
|
||||
VertexPulling::Config::Config() = default;
|
||||
|
|
|
@ -171,16 +171,14 @@ class VertexPulling final : public Castable<VertexPulling, Transform> {
|
|||
/// Destructor
|
||||
~VertexPulling() override;
|
||||
|
||||
protected:
|
||||
/// Runs the transform using the CloneContext built for transforming a
|
||||
/// program. Run() is responsible for calling Clone() on the CloneContext.
|
||||
/// @param ctx the CloneContext primed with the input program and
|
||||
/// ProgramBuilder
|
||||
/// @param inputs optional extra transform-specific input data
|
||||
/// @param outputs optional extra transform-specific output data
|
||||
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
|
||||
/// @copydoc Transform::Apply
|
||||
ApplyResult Apply(const Program* program,
|
||||
const DataMap& inputs,
|
||||
DataMap& outputs) const override;
|
||||
|
||||
private:
|
||||
struct State;
|
||||
|
||||
Config cfg_;
|
||||
};
|
||||
|
||||
|
|
|
@ -14,18 +14,17 @@
|
|||
|
||||
#include "src/tint/transform/while_to_loop.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/tint/ast/break_statement.h"
|
||||
#include "src/tint/program_builder.h"
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::transform::WhileToLoop);
|
||||
|
||||
namespace tint::transform {
|
||||
namespace {
|
||||
|
||||
WhileToLoop::WhileToLoop() = default;
|
||||
|
||||
WhileToLoop::~WhileToLoop() = default;
|
||||
|
||||
bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const {
|
||||
bool ShouldRun(const Program* program) {
|
||||
for (auto* node : program->ASTNodes().Objects()) {
|
||||
if (node->Is<ast::WhileStatement>()) {
|
||||
return true;
|
||||
|
@ -34,20 +33,32 @@ bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
||||
} // namespace
|
||||
|
||||
WhileToLoop::WhileToLoop() = default;
|
||||
|
||||
WhileToLoop::~WhileToLoop() = default;
|
||||
|
||||
Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, DataMap&) const {
|
||||
if (!ShouldRun(src)) {
|
||||
return SkipTransform;
|
||||
}
|
||||
|
||||
ProgramBuilder b;
|
||||
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
|
||||
|
||||
ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* {
|
||||
utils::Vector<const ast::Statement*, 16> stmts;
|
||||
auto* cond = w->condition;
|
||||
|
||||
// !condition
|
||||
auto* not_cond =
|
||||
ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
|
||||
auto* not_cond = b.Not(ctx.Clone(cond));
|
||||
|
||||
// { break; }
|
||||
auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
|
||||
auto* break_body = b.Block(b.Break());
|
||||
|
||||
// if (!condition) { break; }
|
||||
stmts.Push(ctx.dst->If(not_cond, break_body));
|
||||
stmts.Push(b.If(not_cond, break_body));
|
||||
|
||||
for (auto* stmt : w->body->statements) {
|
||||
stmts.Push(ctx.Clone(stmt));
|
||||
|
@ -55,13 +66,14 @@ void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
|
|||
|
||||
const ast::BlockStatement* continuing = nullptr;
|
||||
|
||||
auto* body = ctx.dst->Block(stmts);
|
||||
auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
|
||||
auto* body = b.Block(stmts);
|
||||
auto* loop = b.Loop(body, continuing);
|
||||
|
||||
return loop;
|
||||
});
|
||||
|
||||
ctx.Clone();
|
||||
return Program(std::move(b));
|
||||
}
|
||||
|
||||
} // namespace tint::transform
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue