tint/transform: Refactor transforms

Replace the ShouldRun() method with Apply() which will do the
transformation if it needs to be done, otherwise returns
'SkipTransform'.

This reduces a bunch of duplicated scanning between the old ShouldRun()
and Transform().

This change also adjusts code style to make the transforms more
consistent.

Change-Id: I9a6b10cb8b4ed62676b12ef30fb7764d363386c6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107681
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2022-11-03 08:41:19 +00:00 committed by Dawn LUCI CQ
parent de6db384aa
commit c6b381495d
111 changed files with 2132 additions and 2205 deletions

View File

@ -15,6 +15,7 @@
#include "src/tint/fuzzers/shuffle_transform.h"
#include <random>
#include <utility>
#include "src/tint/program_builder.h"
@ -22,15 +23,21 @@ namespace tint::fuzzers {
ShuffleTransform::ShuffleTransform(size_t seed) : seed_(seed) {}
void ShuffleTransform::Run(CloneContext& ctx,
const tint::transform::DataMap&,
tint::transform::DataMap&) const {
auto decls = ctx.src->AST().GlobalDeclarations();
transform::Transform::ApplyResult ShuffleTransform::Apply(const Program* src,
const transform::DataMap&,
transform::DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
auto decls = src->AST().GlobalDeclarations();
auto rng = std::mt19937_64{seed_};
std::shuffle(std::begin(decls), std::end(decls), rng);
for (auto* decl : decls) {
ctx.dst->AST().AddGlobalDeclaration(ctx.Clone(decl));
b.AST().AddGlobalDeclaration(ctx.Clone(decl));
}
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::fuzzers

View File

@ -20,16 +20,16 @@
namespace tint::fuzzers {
/// ShuffleTransform reorders the module scope declarations into a random order
class ShuffleTransform : public tint::transform::Transform {
class ShuffleTransform : public transform::Transform {
public:
/// Constructor
/// @param seed the random seed to use for the shuffling
explicit ShuffleTransform(size_t seed);
protected:
void Run(CloneContext& ctx,
const tint::transform::DataMap&,
tint::transform::DataMap&) const override;
/// @copydoc transform::Transform::Apply
ApplyResult Apply(const Program* program,
const transform::DataMap& inputs,
transform::DataMap& outputs) const override;
private:
size_t seed_;

View File

@ -23,6 +23,7 @@
#include "src/tint/ast/access.h"
#include "src/tint/ast/address_space.h"
#include "src/tint/ast/parameter.h"
#include "src/tint/sem/binding_point.h"
#include "src/tint/sem/expression.h"
#include "src/tint/sem/parameter_usage.h"
@ -212,6 +213,11 @@ class Parameter final : public Castable<Parameter, Variable> {
/// Destructor
~Parameter() override;
/// @returns the AST declaration node
const ast::Parameter* Declaration() const {
return static_cast<const ast::Parameter*>(Variable::Declaration());
}
/// @return the index of the parmeter in the function
uint32_t Index() const { return index_; }

View File

@ -31,21 +31,29 @@ AddBlockAttribute::AddBlockAttribute() = default;
AddBlockAttribute::~AddBlockAttribute() = default;
void AddBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem();
Transform::ApplyResult AddBlockAttribute::Apply(const Program* src,
const DataMap&,
DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
auto& sem = src->Sem();
// A map from a type in the source program to a block-decorated wrapper that contains it in the
// destination program.
utils::Hashmap<const sem::Type*, const ast::Struct*, 8> wrapper_structs;
// Process global 'var' declarations that are buffers.
for (auto* global : ctx.src->AST().GlobalVariables()) {
bool made_changes = false;
for (auto* global : src->AST().GlobalVariables()) {
auto* var = sem.Get(global);
if (!ast::IsHostShareable(var->AddressSpace())) {
// Not declared in a host-sharable address space
continue;
}
made_changes = true;
auto* ty = var->Type()->UnwrapRef();
auto* str = ty->As<sem::Struct>();
@ -61,33 +69,36 @@ void AddBlockAttribute::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
const char* kMemberName = "inner";
auto* wrapper = wrapper_structs.GetOrCreate(ty, [&] {
auto* block = ctx.dst->ASTNodes().Create<BlockAttribute>(ctx.dst->ID(),
ctx.dst->AllocateNodeID());
auto wrapper_name = ctx.src->Symbols().NameFor(global->symbol) + "_block";
auto* ret = ctx.dst->create<ast::Struct>(
ctx.dst->Symbols().New(wrapper_name),
utils::Vector{ctx.dst->Member(kMemberName, CreateASTTypeFor(ctx, ty))},
auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
auto wrapper_name = src->Symbols().NameFor(global->symbol) + "_block";
auto* ret = b.create<ast::Struct>(
b.Symbols().New(wrapper_name),
utils::Vector{b.Member(kMemberName, CreateASTTypeFor(ctx, ty))},
utils::Vector{block});
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), global, ret);
ctx.InsertBefore(src->AST().GlobalDeclarations(), global, ret);
return ret;
});
ctx.Replace(global->type, ctx.dst->ty.Of(wrapper));
ctx.Replace(global->type, b.ty.Of(wrapper));
// Insert a member accessor to get the original type from the wrapper at
// any usage of the original variable.
for (auto* user : var->Users()) {
ctx.Replace(user->Declaration(),
ctx.dst->MemberAccessor(ctx.Clone(global->symbol), kMemberName));
b.MemberAccessor(ctx.Clone(global->symbol), kMemberName));
}
} else {
// Add a block attribute to this struct directly.
auto* block = ctx.dst->ASTNodes().Create<BlockAttribute>(ctx.dst->ID(),
ctx.dst->AllocateNodeID());
auto* block = b.ASTNodes().Create<BlockAttribute>(b.ID(), b.AllocateNodeID());
ctx.InsertFront(str->Declaration()->attributes, block);
}
}
if (!made_changes) {
return SkipTransform;
}
ctx.Clone();
return Program(std::move(b));
}
AddBlockAttribute::BlockAttribute::BlockAttribute(ProgramID pid, ast::NodeID nid)

View File

@ -53,14 +53,10 @@ class AddBlockAttribute final : public Castable<AddBlockAttribute, Transform> {
/// Destructor
~AddBlockAttribute() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -23,12 +23,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::AddEmptyEntryPoint);
using namespace tint::number_suffixes; // NOLINT
namespace tint::transform {
namespace {
AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
bool AddEmptyEntryPoint::ShouldRun(const Program* program, const DataMap&) const {
bool ShouldRun(const Program* program) {
for (auto* func : program->AST().Functions()) {
if (func->IsEntryPoint()) {
return false;
@ -37,13 +34,30 @@ bool AddEmptyEntryPoint::ShouldRun(const Program* program, const DataMap&) const
return true;
}
void AddEmptyEntryPoint::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.dst->Func(ctx.dst->Symbols().New("unused_entry_point"), {}, ctx.dst->ty.void_(), {},
utils::Vector{
ctx.dst->Stage(ast::PipelineStage::kCompute),
ctx.dst->WorkgroupSize(1_i),
});
} // namespace
AddEmptyEntryPoint::AddEmptyEntryPoint() = default;
AddEmptyEntryPoint::~AddEmptyEntryPoint() = default;
Transform::ApplyResult AddEmptyEntryPoint::Apply(const Program* src,
const DataMap&,
DataMap&) const {
if (!ShouldRun(src)) {
return SkipTransform;
}
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
b.Func(b.Symbols().New("unused_entry_point"), {}, b.ty.void_(), {},
utils::Vector{
b.Stage(ast::PipelineStage::kCompute),
b.WorkgroupSize(1_i),
});
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -27,19 +27,10 @@ class AddEmptyEntryPoint final : public Castable<AddEmptyEntryPoint, Transform>
/// Destructor
~AddEmptyEntryPoint() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -31,13 +31,153 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Result);
namespace tint::transform {
namespace {
bool ShouldRun(const Program* program) {
for (auto* fn : program->AST().Functions()) {
if (auto* sem_fn = program->Sem().Get(fn)) {
for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
return true;
}
}
}
}
return false;
}
} // namespace
ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
/// The PIMPL state for this transform
/// PIMPL state for the transform
struct ArrayLengthFromUniform::State {
/// Constructor
/// @param program the source program
/// @param in the input transform data
/// @param out the output transform data
explicit State(const Program* program, const DataMap& in, DataMap& out)
: src(program), inputs(in), outputs(out) {}
/// Runs the transform
/// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() {
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
b.Diagnostics().add_error(diag::System::Transform,
"missing transform data for " +
std::string(TypeInfo::Of<ArrayLengthFromUniform>().name));
return Program(std::move(b));
}
if (!ShouldRun(ctx.src)) {
return SkipTransform;
}
const char* kBufferSizeMemberName = "buffer_size";
// Determine the size of the buffer size array.
uint32_t max_buffer_size_index = 0;
IterateArrayLengthOnStorageVar([&](const ast::CallExpression*, const sem::VariableUser*,
const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
return;
}
if (idx_itr->second > max_buffer_size_index) {
max_buffer_size_index = idx_itr->second;
}
});
// Get (or create, on first call) the uniform buffer that will receive the
// size of each storage buffer in the module.
const ast::Variable* buffer_size_ubo = nullptr;
auto get_ubo = [&]() {
if (!buffer_size_ubo) {
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
// We do this because UBOs require an element stride that is 16-byte
// aligned.
auto* buffer_size_struct = b.Structure(
b.Sym(), utils::Vector{
b.Member(kBufferSizeMemberName,
b.ty.array(b.ty.vec4(b.ty.u32()),
u32((max_buffer_size_index / 4) + 1))),
});
buffer_size_ubo =
b.GlobalVar(b.Sym(), b.ty.Of(buffer_size_struct), ast::AddressSpace::kUniform,
b.Group(AInt(cfg->ubo_binding.group)),
b.Binding(AInt(cfg->ubo_binding.binding)));
}
return buffer_size_ubo;
};
std::unordered_set<uint32_t> used_size_indices;
IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr,
const sem::VariableUser* storage_buffer_sem,
const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
return;
}
uint32_t size_index = idx_itr->second;
used_size_indices.insert(size_index);
// Load the total storage buffer size from the UBO.
uint32_t array_index = size_index / 4;
auto* vec_expr = b.IndexAccessor(
b.MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), u32(array_index));
uint32_t vec_index = size_index % 4;
auto* total_storage_buffer_size = b.IndexAccessor(vec_expr, u32(vec_index));
// Calculate actual array length
// total_storage_buffer_size - array_offset
// array_length = ----------------------------------------
// array_stride
const ast::Expression* total_size = total_storage_buffer_size;
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
const sem::Array* array_type = nullptr;
if (auto* str = storage_buffer_type->As<sem::Struct>()) {
// The variable is a struct, so subtract the byte offset of the array
// member.
auto* array_member_sem = str->Members().back();
array_type = array_member_sem->Type()->As<sem::Array>();
total_size = b.Sub(total_storage_buffer_size, u32(array_member_sem->Offset()));
} else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
array_type = arr;
} else {
TINT_ICE(Transform, b.Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
return;
}
auto* array_length = b.Div(total_size, u32(array_type->Stride()));
ctx.Replace(call_expr, array_length);
});
outputs.Add<Result>(used_size_indices);
ctx.Clone();
return Program(std::move(b));
}
private:
/// The source program
const Program* const src;
/// The transform inputs
const DataMap& inputs;
/// The transform outputs
DataMap& outputs;
/// The target program builder
ProgramBuilder b;
/// The clone context
CloneContext& ctx;
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Iterate over all arrayLength() builtins that operate on
/// storage buffer variables.
@ -48,10 +188,10 @@ struct ArrayLengthFromUniform::State {
/// sem::GlobalVariable for the storage buffer.
template <typename F>
void IterateArrayLengthOnStorageVar(F&& functor) {
auto& sem = ctx.src->Sem();
auto& sem = src->Sem();
// Find all calls to the arrayLength() builtin.
for (auto* node : ctx.src->ASTNodes().Objects()) {
for (auto* node : src->ASTNodes().Objects()) {
auto* call_expr = node->As<ast::CallExpression>();
if (!call_expr) {
continue;
@ -79,7 +219,7 @@ struct ArrayLengthFromUniform::State {
// arrayLength(&array_var)
auto* param = call_expr->args[0]->As<ast::UnaryOpExpression>();
if (!param || param->op != ast::UnaryOp::kAddressOf) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
@ -90,7 +230,7 @@ struct ArrayLengthFromUniform::State {
}
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
if (!storage_buffer_sem) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
@ -99,8 +239,7 @@ struct ArrayLengthFromUniform::State {
// Get the index to use for the buffer size array.
auto* var = tint::As<sem::GlobalVariable>(storage_buffer_sem->Variable());
if (!var) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "storage buffer is not a global variable";
TINT_ICE(Transform, b.Diagnostics()) << "storage buffer is not a global variable";
break;
}
functor(call_expr, storage_buffer_sem, var);
@ -108,117 +247,10 @@ struct ArrayLengthFromUniform::State {
}
};
bool ArrayLengthFromUniform::ShouldRun(const Program* program, const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (auto* sem_fn = program->Sem().Get(fn)) {
for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
return true;
}
}
}
}
return false;
}
void ArrayLengthFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const {
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
return;
}
const char* kBufferSizeMemberName = "buffer_size";
// Determine the size of the buffer size array.
uint32_t max_buffer_size_index = 0;
State{ctx}.IterateArrayLengthOnStorageVar(
[&](const ast::CallExpression*, const sem::VariableUser*, const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
return;
}
if (idx_itr->second > max_buffer_size_index) {
max_buffer_size_index = idx_itr->second;
}
});
// Get (or create, on first call) the uniform buffer that will receive the
// size of each storage buffer in the module.
const ast::Variable* buffer_size_ubo = nullptr;
auto get_ubo = [&]() {
if (!buffer_size_ubo) {
// Emit an array<vec4<u32>, N>, where N is 1/4 number of elements.
// We do this because UBOs require an element stride that is 16-byte
// aligned.
auto* buffer_size_struct = ctx.dst->Structure(
ctx.dst->Sym(),
utils::Vector{
ctx.dst->Member(kBufferSizeMemberName,
ctx.dst->ty.array(ctx.dst->ty.vec4(ctx.dst->ty.u32()),
u32((max_buffer_size_index / 4) + 1))),
});
buffer_size_ubo = ctx.dst->GlobalVar(ctx.dst->Sym(), ctx.dst->ty.Of(buffer_size_struct),
ast::AddressSpace::kUniform,
ctx.dst->Group(AInt(cfg->ubo_binding.group)),
ctx.dst->Binding(AInt(cfg->ubo_binding.binding)));
}
return buffer_size_ubo;
};
std::unordered_set<uint32_t> used_size_indices;
State{ctx}.IterateArrayLengthOnStorageVar([&](const ast::CallExpression* call_expr,
const sem::VariableUser* storage_buffer_sem,
const sem::GlobalVariable* var) {
auto binding = var->BindingPoint();
auto idx_itr = cfg->bindpoint_to_size_index.find(binding);
if (idx_itr == cfg->bindpoint_to_size_index.end()) {
return;
}
uint32_t size_index = idx_itr->second;
used_size_indices.insert(size_index);
// Load the total storage buffer size from the UBO.
uint32_t array_index = size_index / 4;
auto* vec_expr = ctx.dst->IndexAccessor(
ctx.dst->MemberAccessor(get_ubo()->symbol, kBufferSizeMemberName), u32(array_index));
uint32_t vec_index = size_index % 4;
auto* total_storage_buffer_size = ctx.dst->IndexAccessor(vec_expr, u32(vec_index));
// Calculate actual array length
// total_storage_buffer_size - array_offset
// array_length = ----------------------------------------
// array_stride
const ast::Expression* total_size = total_storage_buffer_size;
auto* storage_buffer_type = storage_buffer_sem->Type()->UnwrapRef();
const sem::Array* array_type = nullptr;
if (auto* str = storage_buffer_type->As<sem::Struct>()) {
// The variable is a struct, so subtract the byte offset of the array
// member.
auto* array_member_sem = str->Members().back();
array_type = array_member_sem->Type()->As<sem::Array>();
total_size = ctx.dst->Sub(total_storage_buffer_size, u32(array_member_sem->Offset()));
} else if (auto* arr = storage_buffer_type->As<sem::Array>()) {
array_type = arr;
} else {
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
return;
}
auto* array_length = ctx.dst->Div(total_size, u32(array_type->Stride()));
ctx.Replace(call_expr, array_length);
});
ctx.Clone();
outputs.Add<Result>(used_size_indices);
Transform::ApplyResult ArrayLengthFromUniform::Apply(const Program* src,
const DataMap& inputs,
DataMap& outputs) const {
return State{src, inputs, outputs}.Run();
}
ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) : ubo_binding(ubo_bp) {}

View File

@ -100,22 +100,12 @@ class ArrayLengthFromUniform final : public Castable<ArrayLengthFromUniform, Tra
std::unordered_set<uint32_t> used_size_indices;
};
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
private:
/// The PIMPL state for this transform
struct State;
};

View File

@ -28,7 +28,13 @@ using ArrayLengthFromUniformTest = TransformTest;
TEST_F(ArrayLengthFromUniformTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src, data));
}
TEST_F(ArrayLengthFromUniformTest, ShouldRunNoArrayLength) {
@ -45,7 +51,13 @@ fn main() {
}
)";
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src));
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
EXPECT_FALSE(ShouldRun<ArrayLengthFromUniform>(src, data));
}
TEST_F(ArrayLengthFromUniformTest, ShouldRunWithArrayLength) {
@ -63,7 +75,13 @@ fn main() {
}
)";
EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src));
ArrayLengthFromUniform::Config cfg({0, 30u});
cfg.bindpoint_to_size_index.emplace(sem::BindingPoint{0, 0}, 0);
DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
EXPECT_TRUE(ShouldRun<ArrayLengthFromUniform>(src, data));
}
TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {

View File

@ -40,19 +40,21 @@ BindingRemapper::Remappings::~Remappings() = default;
BindingRemapper::BindingRemapper() = default;
BindingRemapper::~BindingRemapper() = default;
bool BindingRemapper::ShouldRun(const Program*, const DataMap& inputs) const {
if (auto* remappings = inputs.Get<Remappings>()) {
return !remappings->binding_points.empty() || !remappings->access_controls.empty();
}
return false;
}
Transform::ApplyResult BindingRemapper::Apply(const Program* src,
const DataMap& inputs,
DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
auto* remappings = inputs.Get<Remappings>();
if (!remappings) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
return;
b.Diagnostics().add_error(diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return Program(std::move(b));
}
if (remappings->binding_points.empty() && remappings->access_controls.empty()) {
return SkipTransform;
}
// A set of post-remapped binding points that need to be decorated with a
@ -62,11 +64,11 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
if (remappings->allow_collisions) {
// Scan for binding point collisions generated by this transform.
// Populate all collisions in the `add_collision_attr` set.
for (auto* func_ast : ctx.src->AST().Functions()) {
for (auto* func_ast : src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) {
continue;
}
auto* func = ctx.src->Sem().Get(func_ast);
auto* func = src->Sem().Get(func_ast);
std::unordered_map<sem::BindingPoint, int> binding_point_counts;
for (auto* global : func->TransitivelyReferencedGlobals()) {
if (global->Declaration()->HasBindingPoint()) {
@ -90,9 +92,9 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
}
}
for (auto* var : ctx.src->AST().Globals<ast::Var>()) {
for (auto* var : src->AST().Globals<ast::Var>()) {
if (var->HasBindingPoint()) {
auto* global_sem = ctx.src->Sem().Get<sem::GlobalVariable>(var);
auto* global_sem = src->Sem().Get<sem::GlobalVariable>(var);
// The original binding point
BindingPoint from = global_sem->BindingPoint();
@ -106,8 +108,8 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
auto bp_it = remappings->binding_points.find(from);
if (bp_it != remappings->binding_points.end()) {
BindingPoint to = bp_it->second;
auto* new_group = ctx.dst->Group(AInt(to.group));
auto* new_binding = ctx.dst->Binding(AInt(to.binding));
auto* new_group = b.Group(AInt(to.group));
auto* new_binding = b.Binding(AInt(to.binding));
auto* old_group = ast::GetAttribute<ast::GroupAttribute>(var->attributes);
auto* old_binding = ast::GetAttribute<ast::BindingAttribute>(var->attributes);
@ -122,37 +124,37 @@ void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) co
if (ac_it != remappings->access_controls.end()) {
ast::Access ac = ac_it->second;
if (ac == ast::Access::kUndefined) {
ctx.dst->Diagnostics().add_error(
b.Diagnostics().add_error(
diag::System::Transform,
"invalid access mode (" + std::to_string(static_cast<uint32_t>(ac)) + ")");
return;
return Program(std::move(b));
}
auto* sem = ctx.src->Sem().Get(var);
auto* sem = src->Sem().Get(var);
if (sem->AddressSpace() != ast::AddressSpace::kStorage) {
ctx.dst->Diagnostics().add_error(
b.Diagnostics().add_error(
diag::System::Transform,
"cannot apply access control to variable with address space " +
std::string(utils::ToString(sem->AddressSpace())));
return;
return Program(std::move(b));
}
auto* ty = sem->Type()->UnwrapRef();
const ast::Type* inner_ty = CreateASTTypeFor(ctx, ty);
auto* new_var =
ctx.dst->Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty,
var->declared_address_space, ac, ctx.Clone(var->initializer),
ctx.Clone(var->attributes));
auto* new_var = b.Var(ctx.Clone(var->source), ctx.Clone(var->symbol), inner_ty,
var->declared_address_space, ac, ctx.Clone(var->initializer),
ctx.Clone(var->attributes));
ctx.Replace(var, new_var);
}
// Add `DisableValidationAttribute`s if required
if (add_collision_attr.count(bp)) {
auto* attribute = ctx.dst->Disable(ast::DisabledValidation::kBindingPointCollision);
auto* attribute = b.Disable(ast::DisabledValidation::kBindingPointCollision);
ctx.InsertBefore(var->attributes, *var->attributes.begin(), attribute);
}
}
}
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -67,19 +67,10 @@ class BindingRemapper final : public Castable<BindingRemapper, Transform> {
BindingRemapper();
~BindingRemapper() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -23,12 +23,6 @@ namespace {
using BindingRemapperTest = TransformTest;
TEST_F(BindingRemapperTest, ShouldRunNoRemappings) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<BindingRemapper>(src));
}
TEST_F(BindingRemapperTest, ShouldRunEmptyRemappings) {
auto* src = R"()";
@ -350,7 +344,7 @@ fn f() {
}
)";
auto* expect = src;
auto* expect = R"(error: missing transform data for tint::transform::BindingRemapper)";
auto got = Run<BindingRemapper>(src);

View File

@ -29,7 +29,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::BuiltinPolyfill::Config);
namespace tint::transform {
/// The PIMPL state for the BuiltinPolyfill transform
/// PIMPL state for the transform
struct BuiltinPolyfill::State {
/// Constructor
/// @param c the CloneContext
@ -604,193 +604,100 @@ BuiltinPolyfill::BuiltinPolyfill() = default;
BuiltinPolyfill::~BuiltinPolyfill() = default;
bool BuiltinPolyfill::ShouldRun(const Program* program, const DataMap& data) const {
if (auto* cfg = data.Get<Config>()) {
auto builtins = cfg->builtins;
auto& sem = program->Sem();
for (auto* node : program->ASTNodes().Objects()) {
if (auto* call = sem.Get<sem::Call>(node)) {
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
if (call->Stage() == sem::EvaluationStage::kConstant) {
continue; // Don't polyfill @const expressions
}
switch (builtin->Type()) {
case sem::BuiltinType::kAcosh:
if (builtins.acosh != Level::kNone) {
return true;
}
break;
case sem::BuiltinType::kAsinh:
if (builtins.asinh) {
return true;
}
break;
case sem::BuiltinType::kAtanh:
if (builtins.atanh != Level::kNone) {
return true;
}
break;
case sem::BuiltinType::kClamp:
if (builtins.clamp_int) {
auto& sig = builtin->Signature();
return sig.parameters[0]->Type()->is_integer_scalar_or_vector();
}
break;
case sem::BuiltinType::kCountLeadingZeros:
if (builtins.count_leading_zeros) {
return true;
}
break;
case sem::BuiltinType::kCountTrailingZeros:
if (builtins.count_trailing_zeros) {
return true;
}
break;
case sem::BuiltinType::kExtractBits:
if (builtins.extract_bits != Level::kNone) {
return true;
}
break;
case sem::BuiltinType::kFirstLeadingBit:
if (builtins.first_leading_bit) {
return true;
}
break;
case sem::BuiltinType::kFirstTrailingBit:
if (builtins.first_trailing_bit) {
return true;
}
break;
case sem::BuiltinType::kInsertBits:
if (builtins.insert_bits != Level::kNone) {
return true;
}
break;
case sem::BuiltinType::kSaturate:
if (builtins.saturate) {
return true;
}
break;
case sem::BuiltinType::kTextureSampleBaseClampToEdge:
if (builtins.texture_sample_base_clamp_to_edge_2d_f32) {
auto& sig = builtin->Signature();
auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
if (auto* stex = tex->Type()->As<sem::SampledTexture>()) {
return stex->type()->Is<sem::F32>();
}
}
break;
case sem::BuiltinType::kQuantizeToF16:
if (builtins.quantize_to_vec_f16) {
if (builtin->ReturnType()->Is<sem::Vector>()) {
return true;
}
}
break;
default:
break;
}
}
}
}
}
return false;
}
void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) const {
Transform::ApplyResult BuiltinPolyfill::Apply(const Program* src,
const DataMap& data,
DataMap&) const {
auto* cfg = data.Get<Config>();
if (!cfg) {
ctx.Clone();
return;
return SkipTransform;
}
std::unordered_map<const sem::Builtin*, Symbol> polyfills;
auto& builtins = cfg->builtins;
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
auto builtins = cfg->builtins;
State s{ctx, builtins};
if (auto* call = s.sem.Get<sem::Call>(expr)) {
utils::Hashmap<const sem::Builtin*, Symbol, 8> polyfills;
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
State s{ctx, builtins};
bool made_changes = false;
for (auto* node : src->ASTNodes().Objects()) {
if (auto* call = src->Sem().Get<sem::Call>(node)) {
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
if (call->Stage() == sem::EvaluationStage::kConstant) {
return nullptr; // Don't polyfill @const expressions
continue; // Don't polyfill @const expressions
}
Symbol polyfill;
switch (builtin->Type()) {
case sem::BuiltinType::kAcosh:
if (builtins.acosh != Level::kNone) {
polyfill = utils::GetOrCreate(
polyfills, builtin, [&] { return s.acosh(builtin->ReturnType()); });
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.acosh(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kAsinh:
if (builtins.asinh) {
polyfill = utils::GetOrCreate(
polyfills, builtin, [&] { return s.asinh(builtin->ReturnType()); });
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.asinh(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kAtanh:
if (builtins.atanh != Level::kNone) {
polyfill = utils::GetOrCreate(
polyfills, builtin, [&] { return s.atanh(builtin->ReturnType()); });
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.atanh(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kClamp:
if (builtins.clamp_int) {
auto& sig = builtin->Signature();
if (sig.parameters[0]->Type()->is_integer_scalar_or_vector()) {
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
return s.clampInteger(builtin->ReturnType());
});
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.clampInteger(builtin->ReturnType()); });
}
}
break;
case sem::BuiltinType::kCountLeadingZeros:
if (builtins.count_leading_zeros) {
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
polyfill = polyfills.GetOrCreate(builtin, [&] {
return s.countLeadingZeros(builtin->ReturnType());
});
}
break;
case sem::BuiltinType::kCountTrailingZeros:
if (builtins.count_trailing_zeros) {
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
polyfill = polyfills.GetOrCreate(builtin, [&] {
return s.countTrailingZeros(builtin->ReturnType());
});
}
break;
case sem::BuiltinType::kExtractBits:
if (builtins.extract_bits != Level::kNone) {
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
return s.extractBits(builtin->ReturnType());
});
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.extractBits(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kFirstLeadingBit:
if (builtins.first_leading_bit) {
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
return s.firstLeadingBit(builtin->ReturnType());
});
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.firstLeadingBit(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kFirstTrailingBit:
if (builtins.first_trailing_bit) {
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
return s.firstTrailingBit(builtin->ReturnType());
});
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.firstTrailingBit(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kInsertBits:
if (builtins.insert_bits != Level::kNone) {
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
return s.insertBits(builtin->ReturnType());
});
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.insertBits(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kSaturate:
if (builtins.saturate) {
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
return s.saturate(builtin->ReturnType());
});
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.saturate(builtin->ReturnType()); });
}
break;
case sem::BuiltinType::kTextureSampleBaseClampToEdge:
@ -799,7 +706,7 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons
auto* tex = sig.Parameter(sem::ParameterUsage::kTexture);
if (auto* stex = tex->Type()->As<sem::SampledTexture>()) {
if (stex->type()->Is<sem::F32>()) {
polyfill = utils::GetOrCreate(polyfills, builtin, [&] {
polyfill = polyfills.GetOrCreate(builtin, [&] {
return s.textureSampleBaseClampToEdge_2d_f32();
});
}
@ -809,8 +716,8 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons
case sem::BuiltinType::kQuantizeToF16:
if (builtins.quantize_to_vec_f16) {
if (auto* vec = builtin->ReturnType()->As<sem::Vector>()) {
polyfill = utils::GetOrCreate(polyfills, builtin,
[&] { return s.quantizeToF16(vec); });
polyfill = polyfills.GetOrCreate(
builtin, [&] { return s.quantizeToF16(vec); });
}
}
break;
@ -819,14 +726,20 @@ void BuiltinPolyfill::Run(CloneContext& ctx, const DataMap& data, DataMap&) cons
break;
}
if (polyfill.IsValid()) {
return s.b.Call(polyfill, ctx.Clone(call->Declaration()->args));
auto* replacement = s.b.Call(polyfill, ctx.Clone(call->Declaration()->args));
ctx.Replace(call->Declaration(), replacement);
made_changes = true;
}
}
}
return nullptr;
});
}
if (!made_changes) {
return SkipTransform;
}
ctx.Clone();
return Program(std::move(b));
}
BuiltinPolyfill::Config::Config(const Builtins& b) : builtins(b) {}

View File

@ -87,21 +87,13 @@ class BuiltinPolyfill final : public Castable<BuiltinPolyfill, Transform> {
const Builtins builtins;
};
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
protected:
private:
struct State;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -1561,7 +1561,8 @@ fn f() {
TEST_F(BuiltinPolyfillTest, DISABLED_InsertBits_ConstantExpression) {
auto* src = R"(
fn f() {
let r : i32 = insertBits(1234, 5678, 5u, 6u);
let v = 1234i;
let r : i32 = insertBits(v, 5678, 5u, 6u);
}
)";
@ -1975,10 +1976,6 @@ fn f() {
)";
auto* expect = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
@group(0) @binding(1) var s : sampler;
fn tint_textureSampleBaseClampToEdge(t : texture_2d<f32>, s : sampler, coord : vec2<f32>) -> vec4<f32> {
let dims = vec2<f32>(textureDimensions(t, 0));
let half_texel = (vec2<f32>(0.5) / dims);
@ -1986,6 +1983,10 @@ fn tint_textureSampleBaseClampToEdge(t : texture_2d<f32>, s : sampler, coord : v
return textureSampleLevel(t, s, clamped, 0);
}
@group(0) @binding(0) var t : texture_2d<f32>;
@group(0) @binding(1) var s : sampler;
fn f() {
let r = tint_textureSampleBaseClampToEdge(t, s, vec2<f32>(0.5));
}

View File

@ -40,6 +40,19 @@ namespace tint::transform {
namespace {
bool ShouldRun(const Program* program) {
for (auto* fn : program->AST().Functions()) {
if (auto* sem_fn = program->Sem().Get(fn)) {
for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
return true;
}
}
}
}
return false;
}
/// ArrayUsage describes a runtime array usage.
/// It is used as a key by the array_length_by_usage map.
struct ArrayUsage {
@ -73,21 +86,16 @@ const CalculateArrayLength::BufferSizeIntrinsic* CalculateArrayLength::BufferSiz
CalculateArrayLength::CalculateArrayLength() = default;
CalculateArrayLength::~CalculateArrayLength() = default;
bool CalculateArrayLength::ShouldRun(const Program* program, const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (auto* sem_fn = program->Sem().Get(fn)) {
for (auto* builtin : sem_fn->DirectlyCalledBuiltins()) {
if (builtin->Type() == sem::BuiltinType::kArrayLength) {
return true;
}
}
}
Transform::ApplyResult CalculateArrayLength::Apply(const Program* src,
const DataMap&,
DataMap&) const {
if (!ShouldRun(src)) {
return SkipTransform;
}
return false;
}
void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem();
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
auto& sem = src->Sem();
// get_buffer_size_intrinsic() emits the function decorated with
// BufferSizeIntrinsic that is transformed by the HLSL writer into a call to
@ -95,24 +103,20 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
std::unordered_map<const sem::Reference*, Symbol> buffer_size_intrinsics;
auto get_buffer_size_intrinsic = [&](const sem::Reference* buffer_type) {
return utils::GetOrCreate(buffer_size_intrinsics, buffer_type, [&] {
auto name = ctx.dst->Sym();
auto name = b.Sym();
auto* type = CreateASTTypeFor(ctx, buffer_type);
auto* disable_validation =
ctx.dst->Disable(ast::DisabledValidation::kFunctionParameter);
ctx.dst->AST().AddFunction(ctx.dst->create<ast::Function>(
auto* disable_validation = b.Disable(ast::DisabledValidation::kFunctionParameter);
b.AST().AddFunction(b.create<ast::Function>(
name,
utils::Vector{
ctx.dst->Param("buffer",
ctx.dst->ty.pointer(type, buffer_type->AddressSpace(),
buffer_type->Access()),
utils::Vector{disable_validation}),
ctx.dst->Param("result", ctx.dst->ty.pointer(ctx.dst->ty.u32(),
ast::AddressSpace::kFunction)),
b.Param("buffer",
b.ty.pointer(type, buffer_type->AddressSpace(), buffer_type->Access()),
utils::Vector{disable_validation}),
b.Param("result", b.ty.pointer(b.ty.u32(), ast::AddressSpace::kFunction)),
},
ctx.dst->ty.void_(), nullptr,
b.ty.void_(), nullptr,
utils::Vector{
ctx.dst->ASTNodes().Create<BufferSizeIntrinsic>(ctx.dst->ID(),
ctx.dst->AllocateNodeID()),
b.ASTNodes().Create<BufferSizeIntrinsic>(b.ID(), b.AllocateNodeID()),
},
utils::Empty));
@ -123,7 +127,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
std::unordered_map<ArrayUsage, Symbol, ArrayUsage::Hasher> array_length_by_usage;
// Find all the arrayLength() calls...
for (auto* node : ctx.src->ASTNodes().Objects()) {
for (auto* node : src->ASTNodes().Objects()) {
if (auto* call_expr = node->As<ast::CallExpression>()) {
auto* call = sem.Get(call_expr)->UnwrapMaterialize()->As<sem::Call>();
if (auto* builtin = call->Target()->As<sem::Builtin>()) {
@ -149,7 +153,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
auto* arg = call_expr->args[0];
auto* address_of = arg->As<ast::UnaryOpExpression>();
if (!address_of || address_of->op != ast::UnaryOp::kAddressOf) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "arrayLength() expected address-of, got " << arg->TypeInfo().name;
}
auto* storage_buffer_expr = address_of->expr;
@ -158,7 +162,7 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
}
auto* storage_buffer_sem = sem.Get<sem::VariableUser>(storage_buffer_expr);
if (!storage_buffer_sem) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "expected form of arrayLength argument to be &array_var or "
"&struct_var.array_member";
break;
@ -179,25 +183,24 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
// Construct the variable that'll hold the result of
// RWByteAddressBuffer.GetDimensions()
auto* buffer_size_result = ctx.dst->Decl(ctx.dst->Var(
ctx.dst->Sym(), ctx.dst->ty.u32(), ctx.dst->Expr(0_u)));
auto* buffer_size_result =
b.Decl(b.Var(b.Sym(), b.ty.u32(), b.Expr(0_u)));
// Call storage_buffer.GetDimensions(&buffer_size_result)
auto* call_get_dims = ctx.dst->CallStmt(ctx.dst->Call(
auto* call_get_dims = b.CallStmt(b.Call(
// BufferSizeIntrinsic(X, ARGS...) is
// translated to:
// X.GetDimensions(ARGS..) by the writer
buffer_size, ctx.dst->AddressOf(ctx.Clone(storage_buffer_expr)),
ctx.dst->AddressOf(
ctx.dst->Expr(buffer_size_result->variable->symbol))));
buffer_size, b.AddressOf(ctx.Clone(storage_buffer_expr)),
b.AddressOf(b.Expr(buffer_size_result->variable->symbol))));
// Calculate actual array length
// total_storage_buffer_size - array_offset
// array_length = ----------------------------------------
// array_stride
auto name = ctx.dst->Sym();
auto name = b.Sym();
const ast::Expression* total_size =
ctx.dst->Expr(buffer_size_result->variable);
b.Expr(buffer_size_result->variable);
const sem::Array* array_type = Switch(
storage_buffer_type->StoreType(),
@ -205,23 +208,21 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
// The variable is a struct, so subtract the byte offset of
// the array member.
auto* array_member_sem = str->Members().back();
total_size =
ctx.dst->Sub(total_size, u32(array_member_sem->Offset()));
total_size = b.Sub(total_size, u32(array_member_sem->Offset()));
return array_member_sem->Type()->As<sem::Array>();
},
[&](const sem::Array* arr) { return arr; });
if (!array_type) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "expected form of arrayLength argument to be "
"&array_var or &struct_var.array_member";
return name;
}
uint32_t array_stride = array_type->Size();
auto* array_length_var = ctx.dst->Decl(
ctx.dst->Let(name, ctx.dst->ty.u32(),
ctx.dst->Div(total_size, u32(array_stride))));
auto* array_length_var = b.Decl(
b.Let(name, b.ty.u32(), b.Div(total_size, u32(array_stride))));
// Insert the array length calculations at the top of the block
ctx.InsertBefore(block->statements, block->statements[0],
@ -234,13 +235,14 @@ void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
});
// Replace the call to arrayLength() with the array length variable
ctx.Replace(call_expr, ctx.dst->Expr(array_length));
ctx.Replace(call_expr, b.Expr(array_length));
}
}
}
}
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -59,19 +59,10 @@ class CalculateArrayLength final : public Castable<CalculateArrayLength, Transfo
/// Destructor
~CalculateArrayLength() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -123,7 +123,7 @@ bool HasSampleMask(utils::VectorRef<const ast::Attribute*> attrs) {
} // namespace
/// State holds the current transform state for a single entry point.
/// PIMPL state for the transform
struct CanonicalizeEntryPointIO::State {
/// OutputValue represents a shader result that the wrapper function produces.
struct OutputValue {
@ -770,17 +770,22 @@ struct CanonicalizeEntryPointIO::State {
}
};
void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
Transform::ApplyResult CanonicalizeEntryPointIO::Apply(const Program* src,
const DataMap& inputs,
DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
return;
b.Diagnostics().add_error(diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return Program(std::move(b));
}
// Remove entry point IO attributes from struct declarations.
// New structures will be created for each entry point, as necessary.
for (auto* ty : ctx.src->AST().TypeDecls()) {
for (auto* ty : src->AST().TypeDecls()) {
if (auto* struct_ty = ty->As<ast::Struct>()) {
for (auto* member : struct_ty->members) {
for (auto* attr : member->attributes) {
@ -792,7 +797,7 @@ void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, Dat
}
}
for (auto* func_ast : ctx.src->AST().Functions()) {
for (auto* func_ast : src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) {
continue;
}
@ -802,6 +807,7 @@ void CanonicalizeEntryPointIO::Run(CloneContext& ctx, const DataMap& inputs, Dat
}
ctx.Clone();
return Program(std::move(b));
}
CanonicalizeEntryPointIO::Config::Config(ShaderStyle style,

View File

@ -127,15 +127,12 @@ class CanonicalizeEntryPointIO final : public Castable<CanonicalizeEntryPointIO,
CanonicalizeEntryPointIO();
~CanonicalizeEntryPointIO() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
private:
struct State;
};

View File

@ -14,7 +14,7 @@
#include "src/tint/transform/clamp_frag_depth.h"
#include <utility>
#include <utility>
#include "src/tint/ast/attribute.h"
#include "src/tint/ast/builtin_attribute.h"
@ -64,12 +64,7 @@ bool ReturnsFragDepthInStruct(const sem::Info& sem, const ast::Function* fn) {
return false;
}
} // anonymous namespace
ClampFragDepth::ClampFragDepth() = default;
ClampFragDepth::~ClampFragDepth() = default;
bool ClampFragDepth::ShouldRun(const Program* program, const DataMap&) const {
bool ShouldRun(const Program* program) {
auto& sem = program->Sem();
for (auto* fn : program->AST().Functions()) {
@ -82,22 +77,33 @@ bool ClampFragDepth::ShouldRun(const Program* program, const DataMap&) const {
return false;
}
void ClampFragDepth::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
} // anonymous namespace
ClampFragDepth::ClampFragDepth() = default;
ClampFragDepth::~ClampFragDepth() = default;
Transform::ApplyResult ClampFragDepth::Apply(const Program* src, const DataMap&, DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
// Abort on any use of push constants in the module.
for (auto* global : ctx.src->AST().GlobalVariables()) {
for (auto* global : src->AST().GlobalVariables()) {
if (auto* var = global->As<ast::Var>()) {
if (var->declared_address_space == ast::AddressSpace::kPushConstant) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "ClampFragDepth doesn't know how to handle module that already use push "
"constants.";
return;
return Program(std::move(b));
}
}
}
auto& b = *ctx.dst;
auto& sem = ctx.src->Sem();
auto& sym = ctx.src->Symbols();
if (!ShouldRun(src)) {
return SkipTransform;
}
auto& sem = src->Sem();
auto& sym = src->Symbols();
// At least one entry-point needs clamping. Add the following to the module:
//
@ -197,6 +203,7 @@ void ClampFragDepth::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
});
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -61,19 +61,10 @@ class ClampFragDepth final : public Castable<ClampFragDepth, Transform> {
/// Destructor
~ClampFragDepth() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -47,10 +47,14 @@ CombineSamplers::BindingInfo::BindingInfo(const BindingMap& map,
CombineSamplers::BindingInfo::BindingInfo(const BindingInfo& other) = default;
CombineSamplers::BindingInfo::~BindingInfo() = default;
/// The PIMPL state for the CombineSamplers transform
/// PIMPL state for the transform
struct CombineSamplers::State {
/// The source program
const Program* const src;
/// The target program builder
ProgramBuilder b;
/// The clone context
CloneContext& ctx;
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// The binding info
const BindingInfo* binding_info;
@ -88,9 +92,9 @@ struct CombineSamplers::State {
}
/// Constructor
/// @param context the clone context
/// @param program the source program
/// @param info the binding map information
State(CloneContext& context, const BindingInfo* info) : ctx(context), binding_info(info) {}
State(const Program* program, const BindingInfo* info) : src(program), binding_info(info) {}
/// Creates a combined sampler global variables.
/// (Note this is actually a Texture node at the AST level, but it will be
@ -145,8 +149,9 @@ struct CombineSamplers::State {
}
}
/// Performs the transformation
void Run() {
/// Runs the transform
/// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() {
auto& sem = ctx.src->Sem();
// Remove all texture and sampler global variables. These will be replaced
@ -169,14 +174,14 @@ struct CombineSamplers::State {
// Rewrite all function signatures to use combined samplers, and remove
// separate textures & samplers. Create new combined globals where found.
ctx.ReplaceAll([&](const ast::Function* src) -> const ast::Function* {
if (auto* func = sem.Get(src)) {
auto pairs = func->TextureSamplerPairs();
ctx.ReplaceAll([&](const ast::Function* ast_fn) -> const ast::Function* {
if (auto* fn = sem.Get(ast_fn)) {
auto pairs = fn->TextureSamplerPairs();
if (pairs.IsEmpty()) {
return nullptr;
}
utils::Vector<const ast::Parameter*, 8> params;
for (auto pair : func->TextureSamplerPairs()) {
for (auto pair : fn->TextureSamplerPairs()) {
const sem::Variable* texture_var = pair.first;
const sem::Variable* sampler_var = pair.second;
std::string name =
@ -197,23 +202,23 @@ struct CombineSamplers::State {
auto* type = CreateCombinedASTTypeFor(texture_var, sampler_var);
auto* var = ctx.dst->Param(ctx.dst->Symbols().New(name), type);
params.Push(var);
function_combined_texture_samplers_[func][pair] = var;
function_combined_texture_samplers_[fn][pair] = var;
}
}
// Filter out separate textures and samplers from the original
// function signature.
for (auto* var : src->params) {
if (!sem.Get(var->type)->IsAnyOf<sem::Texture, sem::Sampler>()) {
params.Push(ctx.Clone(var));
for (auto* param : fn->Parameters()) {
if (!param->Type()->IsAnyOf<sem::Texture, sem::Sampler>()) {
params.Push(ctx.Clone(param->Declaration()));
}
}
// Create a new function signature that differs only in the parameter
// list.
auto symbol = ctx.Clone(src->symbol);
auto* return_type = ctx.Clone(src->return_type);
auto* body = ctx.Clone(src->body);
auto attributes = ctx.Clone(src->attributes);
auto return_type_attributes = ctx.Clone(src->return_type_attributes);
auto symbol = ctx.Clone(ast_fn->symbol);
auto* return_type = ctx.Clone(ast_fn->return_type);
auto* body = ctx.Clone(ast_fn->body);
auto attributes = ctx.Clone(ast_fn->attributes);
auto return_type_attributes = ctx.Clone(ast_fn->return_type_attributes);
return ctx.dst->create<ast::Function>(symbol, params, return_type, body,
std::move(attributes),
std::move(return_type_attributes));
@ -327,6 +332,7 @@ struct CombineSamplers::State {
});
ctx.Clone();
return Program(std::move(b));
}
};
@ -334,15 +340,18 @@ CombineSamplers::CombineSamplers() = default;
CombineSamplers::~CombineSamplers() = default;
void CombineSamplers::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
Transform::ApplyResult CombineSamplers::Apply(const Program* src,
const DataMap& inputs,
DataMap&) const {
auto* binding_info = inputs.Get<BindingInfo>();
if (!binding_info) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
return;
ProgramBuilder b;
b.Diagnostics().add_error(diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return Program(std::move(b));
}
State(ctx, binding_info).Run();
return State(src, binding_info).Run();
}
} // namespace tint::transform

View File

@ -88,17 +88,13 @@ class CombineSamplers final : public Castable<CombineSamplers, Transform> {
/// Destructor
~CombineSamplers() override;
protected:
/// The PIMPL state for this transform
struct State;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
private:
struct State;
};
} // namespace tint::transform

View File

@ -47,6 +47,18 @@ namespace tint::transform {
namespace {
bool ShouldRun(const Program* program) {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
if (var->AddressSpace() == ast::AddressSpace::kStorage ||
var->AddressSpace() == ast::AddressSpace::kUniform) {
return true;
}
}
}
return false;
}
/// Offset is a simple ast::Expression builder interface, used to build byte
/// offsets for storage and uniform buffer accesses.
struct Offset : Castable<Offset> {
@ -291,7 +303,7 @@ struct Store {
} // namespace
/// State holds the current transform state
/// PIMPL state for the transform
struct DecomposeMemoryAccess::State {
/// The clone context
CloneContext& ctx;
@ -477,7 +489,7 @@ struct DecomposeMemoryAccess::State {
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles storage and uniform.
// * Runtime-sized arrays are not loadable.
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "unexpected non-constant array count";
arr_cnt = 1;
}
@ -578,7 +590,7 @@ struct DecomposeMemoryAccess::State {
// * Override-expression counts can only be applied to workgroup
// arrays, and this method only handles storage and uniform.
// * Runtime-sized arrays are not storable.
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "unexpected non-constant array count";
arr_cnt = 1;
}
@ -808,21 +820,16 @@ bool DecomposeMemoryAccess::Intrinsic::IsAtomic() const {
DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
bool DecomposeMemoryAccess::ShouldRun(const Program* program, const DataMap&) const {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (auto* var = program->Sem().Get<sem::Variable>(decl)) {
if (var->AddressSpace() == ast::AddressSpace::kStorage ||
var->AddressSpace() == ast::AddressSpace::kUniform) {
return true;
}
}
Transform::ApplyResult DecomposeMemoryAccess::Apply(const Program* src,
const DataMap&,
DataMap&) const {
if (!ShouldRun(src)) {
return SkipTransform;
}
return false;
}
void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem();
auto& sem = src->Sem();
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
State state(ctx);
// Scan the AST nodes for storage and uniform buffer accesses. Complex
@ -833,7 +840,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con
// Inner-most expression nodes are guaranteed to be visited first because AST
// nodes are fully immutable and require their children to be constructed
// first so their pointer can be passed to the parent's initializer.
for (auto* node : ctx.src->ASTNodes().Objects()) {
for (auto* node : src->ASTNodes().Objects()) {
if (auto* ident = node->As<ast::IdentifierExpression>()) {
// X
if (auto* var = sem.Get<sem::VariableUser>(ident)) {
@ -1001,6 +1008,7 @@ void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) con
}
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -108,20 +108,12 @@ class DecomposeMemoryAccess final : public Castable<DecomposeMemoryAccess, Trans
/// Destructor
~DecomposeMemoryAccess() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
private:
struct State;
};

View File

@ -34,13 +34,7 @@ namespace {
using DecomposedArrays = std::unordered_map<const sem::Array*, Symbol>;
} // namespace
DecomposeStridedArray::DecomposeStridedArray() = default;
DecomposeStridedArray::~DecomposeStridedArray() = default;
bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) const {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* ast = node->As<ast::Array>()) {
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
@ -51,8 +45,22 @@ bool DecomposeStridedArray::ShouldRun(const Program* program, const DataMap&) co
return false;
}
void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
const auto& sem = ctx.src->Sem();
} // namespace
DecomposeStridedArray::DecomposeStridedArray() = default;
DecomposeStridedArray::~DecomposeStridedArray() = default;
Transform::ApplyResult DecomposeStridedArray::Apply(const Program* src,
const DataMap&,
DataMap&) const {
if (!ShouldRun(src)) {
return SkipTransform;
}
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
const auto& sem = src->Sem();
static constexpr const char* kMemberName = "el";
@ -69,23 +77,23 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con
if (auto* arr = sem.Get(ast)) {
if (!arr->IsStrideImplicit()) {
auto el_ty = utils::GetOrCreate(decomposed, arr, [&] {
auto name = ctx.dst->Symbols().New("strided_arr");
auto name = b.Symbols().New("strided_arr");
auto* member_ty = ctx.Clone(ast->type);
auto* member = ctx.dst->Member(kMemberName, member_ty,
utils::Vector{
ctx.dst->MemberSize(AInt(arr->Stride())),
});
ctx.dst->Structure(name, utils::Vector{member});
auto* member = b.Member(kMemberName, member_ty,
utils::Vector{
b.MemberSize(AInt(arr->Stride())),
});
b.Structure(name, utils::Vector{member});
return name;
});
auto* count = ctx.Clone(ast->count);
return ctx.dst->ty.array(ctx.dst->ty.type_name(el_ty), count);
return b.ty.array(b.ty.type_name(el_ty), count);
}
if (ast::GetAttribute<ast::StrideAttribute>(ast->attributes)) {
// Strip the @stride attribute
auto* ty = ctx.Clone(ast->type);
auto* count = ctx.Clone(ast->count);
return ctx.dst->ty.array(ty, count);
return b.ty.array(ty, count);
}
}
return nullptr;
@ -96,11 +104,11 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con
// to insert an additional member accessor for the single structure field.
// Example: `arr[i]` -> `arr[i].el`
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* idx) -> const ast::Expression* {
if (auto* ty = ctx.src->TypeOf(idx->object)) {
if (auto* ty = src->TypeOf(idx->object)) {
if (auto* arr = ty->UnwrapRef()->As<sem::Array>()) {
if (!arr->IsStrideImplicit()) {
auto* expr = ctx.CloneWithoutTransform(idx);
return ctx.dst->MemberAccessor(expr, kMemberName);
return b.MemberAccessor(expr, kMemberName);
}
}
}
@ -136,21 +144,23 @@ void DecomposeStridedArray::Run(CloneContext& ctx, const DataMap&, DataMap&) con
if (auto it = decomposed.find(arr); it != decomposed.end()) {
args.Reserve(expr->args.Length());
for (auto* arg : expr->args) {
args.Push(ctx.dst->Call(it->second, ctx.Clone(arg)));
args.Push(b.Call(it->second, ctx.Clone(arg)));
}
} else {
args = ctx.Clone(expr->args);
}
return target.type ? ctx.dst->Construct(target.type, std::move(args))
: ctx.dst->Call(target.name, std::move(args));
return target.type ? b.Construct(target.type, std::move(args))
: b.Call(target.name, std::move(args));
}
}
}
}
return nullptr;
});
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -35,19 +35,10 @@ class DecomposeStridedArray final : public Castable<DecomposeStridedArray, Trans
/// Destructor
~DecomposeStridedArray() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -53,24 +53,25 @@ struct MatrixInfo {
};
};
/// Return type of the callback function of GatherCustomStrideMatrixMembers
enum GatherResult { kContinue, kStop };
} // namespace
/// GatherCustomStrideMatrixMembers scans `program` for all matrix members of
/// storage and uniform structs, which are of a matrix type, and have a custom
/// matrix stride attribute. For each matrix member found, `callback` is called.
/// `callback` is a function with the signature:
/// GatherResult(const sem::StructMember* member,
/// sem::Matrix* matrix,
/// uint32_t stride)
/// If `callback` return GatherResult::kStop, then the scanning will immediately
/// terminate, and GatherCustomStrideMatrixMembers() will return, otherwise
/// scanning will continue.
template <typename F>
void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
for (auto* node : program->ASTNodes().Objects()) {
DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src,
const DataMap&,
DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
// Scan the program for all storage and uniform structure matrix members with
// a custom stride attribute. Replace these matrices with an equivalent array,
// and populate the `decomposed` map with the members that have been replaced.
utils::Hashmap<const ast::StructMember*, MatrixInfo, 8> decomposed;
for (auto* node : src->ASTNodes().Objects()) {
if (auto* str = node->As<ast::Struct>()) {
auto* str_ty = program->Sem().Get(str);
auto* str_ty = src->Sem().Get(str);
if (!str_ty->UsedAs(ast::AddressSpace::kUniform) &&
!str_ty->UsedAs(ast::AddressSpace::kStorage)) {
continue;
@ -89,46 +90,20 @@ void GatherCustomStrideMatrixMembers(const Program* program, F&& callback) {
if (matrix->ColumnStride() == stride) {
continue;
}
if (callback(member, matrix, stride) == GatherResult::kStop) {
return;
}
// We've got ourselves a struct member of a matrix type with a custom
// stride. Replace this with an array of column vectors.
MatrixInfo info{stride, matrix};
auto* replacement =
b.Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
ctx.Replace(member->Declaration(), replacement);
decomposed.Add(member->Declaration(), info);
}
}
}
}
} // namespace
DecomposeStridedMatrix::DecomposeStridedMatrix() = default;
DecomposeStridedMatrix::~DecomposeStridedMatrix() = default;
bool DecomposeStridedMatrix::ShouldRun(const Program* program, const DataMap&) const {
bool should_run = false;
GatherCustomStrideMatrixMembers(program,
[&](const sem::StructMember*, const sem::Matrix*, uint32_t) {
should_run = true;
return GatherResult::kStop;
});
return should_run;
}
void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
// Scan the program for all storage and uniform structure matrix members with
// a custom stride attribute. Replace these matrices with an equivalent array,
// and populate the `decomposed` map with the members that have been replaced.
std::unordered_map<const ast::StructMember*, MatrixInfo> decomposed;
GatherCustomStrideMatrixMembers(
ctx.src, [&](const sem::StructMember* member, const sem::Matrix* matrix, uint32_t stride) {
// We've got ourselves a struct member of a matrix type with a custom
// stride. Replace this with an array of column vectors.
MatrixInfo info{stride, matrix};
auto* replacement =
ctx.dst->Member(member->Offset(), ctx.Clone(member->Name()), info.array(ctx.dst));
ctx.Replace(member->Declaration(), replacement);
decomposed.emplace(member->Declaration(), info);
return GatherResult::kContinue;
});
if (decomposed.IsEmpty()) {
return SkipTransform;
}
// For all expressions where a single matrix column vector was indexed, we can
// preserve these without calling conversion functions.
@ -136,12 +111,11 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co
// ssbo.mat[2] -> ssbo.mat[2]
ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
auto it = decomposed.find(access->Member()->Declaration());
if (it != decomposed.end()) {
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr->object)) {
if (decomposed.Contains(access->Member()->Declaration())) {
auto* obj = ctx.CloneWithoutTransform(expr->object);
auto* idx = ctx.Clone(expr->index);
return ctx.dst->IndexAccessor(obj, idx);
return b.IndexAccessor(obj, idx);
}
}
return nullptr;
@ -154,39 +128,36 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co
// ssbo.mat = mat_to_arr(m)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> mat_to_arr;
ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* {
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
auto it = decomposed.find(access->Member()->Declaration());
if (it == decomposed.end()) {
return nullptr;
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(stmt->lhs)) {
if (auto* info = decomposed.Find(access->Member()->Declaration())) {
auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] {
auto name =
b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" +
std::to_string(info->matrix->rows()) + "_stride_" +
std::to_string(info->stride) + "_to_arr");
auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
auto array = [&] { return info->array(ctx.dst); };
auto mat = b.Sym("m");
utils::Vector<const ast::Expression*, 4> columns;
for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
columns.Push(b.IndexAccessor(mat, u32(i)));
}
b.Func(name,
utils::Vector{
b.Param(mat, matrix()),
},
array(),
utils::Vector{
b.Return(b.Construct(array(), columns)),
});
return name;
});
auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
auto* rhs = b.Call(fn, ctx.Clone(stmt->rhs));
return b.Assign(lhs, rhs);
}
MatrixInfo info = it->second;
auto fn = utils::GetOrCreate(mat_to_arr, info, [&] {
auto name =
ctx.dst->Symbols().New("mat" + std::to_string(info.matrix->columns()) + "x" +
std::to_string(info.matrix->rows()) + "_stride_" +
std::to_string(info.stride) + "_to_arr");
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
auto array = [&] { return info.array(ctx.dst); };
auto mat = ctx.dst->Sym("m");
utils::Vector<const ast::Expression*, 4> columns;
for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) {
columns.Push(ctx.dst->IndexAccessor(mat, u32(i)));
}
ctx.dst->Func(name,
utils::Vector{
ctx.dst->Param(mat, matrix()),
},
array(),
utils::Vector{
ctx.dst->Return(ctx.dst->Construct(array(), columns)),
});
return name;
});
auto* lhs = ctx.CloneWithoutTransform(stmt->lhs);
auto* rhs = ctx.dst->Call(fn, ctx.Clone(stmt->rhs));
return ctx.dst->Assign(lhs, rhs);
}
return nullptr;
});
@ -196,41 +167,40 @@ void DecomposeStridedMatrix::Run(CloneContext& ctx, const DataMap&, DataMap&) co
// m = arr_to_mat(ssbo.mat)
std::unordered_map<MatrixInfo, Symbol, MatrixInfo::Hasher> arr_to_mat;
ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* {
if (auto* access = ctx.src->Sem().Get<sem::StructMemberAccess>(expr)) {
auto it = decomposed.find(access->Member()->Declaration());
if (it == decomposed.end()) {
return nullptr;
if (auto* access = src->Sem().Get<sem::StructMemberAccess>(expr)) {
if (auto* info = decomposed.Find(access->Member()->Declaration())) {
auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] {
auto name =
b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) +
"x" + std::to_string(info->matrix->rows()) + "_stride_" +
std::to_string(info->stride));
auto matrix = [&] { return CreateASTTypeFor(ctx, info->matrix); };
auto array = [&] { return info->array(ctx.dst); };
auto arr = b.Sym("arr");
utils::Vector<const ast::Expression*, 4> columns;
for (uint32_t i = 0; i < static_cast<uint32_t>(info->matrix->columns()); i++) {
columns.Push(b.IndexAccessor(arr, u32(i)));
}
b.Func(name,
utils::Vector{
b.Param(arr, array()),
},
matrix(),
utils::Vector{
b.Return(b.Construct(matrix(), columns)),
});
return name;
});
return b.Call(fn, ctx.CloneWithoutTransform(expr));
}
MatrixInfo info = it->second;
auto fn = utils::GetOrCreate(arr_to_mat, info, [&] {
auto name = ctx.dst->Symbols().New(
"arr_to_mat" + std::to_string(info.matrix->columns()) + "x" +
std::to_string(info.matrix->rows()) + "_stride_" + std::to_string(info.stride));
auto matrix = [&] { return CreateASTTypeFor(ctx, info.matrix); };
auto array = [&] { return info.array(ctx.dst); };
auto arr = ctx.dst->Sym("arr");
utils::Vector<const ast::Expression*, 4> columns;
for (uint32_t i = 0; i < static_cast<uint32_t>(info.matrix->columns()); i++) {
columns.Push(ctx.dst->IndexAccessor(arr, u32(i)));
}
ctx.dst->Func(name,
utils::Vector{
ctx.dst->Param(arr, array()),
},
matrix(),
utils::Vector{
ctx.dst->Return(ctx.dst->Construct(matrix(), columns)),
});
return name;
});
return ctx.dst->Call(fn, ctx.CloneWithoutTransform(expr));
}
return nullptr;
});
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -35,19 +35,10 @@ class DecomposeStridedMatrix final : public Castable<DecomposeStridedMatrix, Tra
/// Destructor
~DecomposeStridedMatrix() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -27,14 +27,20 @@ DisableUniformityAnalysis::DisableUniformityAnalysis() = default;
DisableUniformityAnalysis::~DisableUniformityAnalysis() = default;
bool DisableUniformityAnalysis::ShouldRun(const Program* program, const DataMap&) const {
return !program->Sem().Module()->Extensions().Contains(
ast::Extension::kChromiumDisableUniformityAnalysis);
}
Transform::ApplyResult DisableUniformityAnalysis::Apply(const Program* src,
const DataMap&,
DataMap&) const {
if (src->Sem().Module()->Extensions().Contains(
ast::Extension::kChromiumDisableUniformityAnalysis)) {
return SkipTransform;
}
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
b.Enable(ast::Extension::kChromiumDisableUniformityAnalysis);
void DisableUniformityAnalysis::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.dst->Enable(ast::Extension::kChromiumDisableUniformityAnalysis);
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -27,19 +27,10 @@ class DisableUniformityAnalysis final : public Castable<DisableUniformityAnalysi
/// Destructor
~DisableUniformityAnalysis() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -31,11 +31,9 @@ using namespace tint::number_suffixes; // NOLINT
namespace tint::transform {
ExpandCompoundAssignment::ExpandCompoundAssignment() = default;
namespace {
ExpandCompoundAssignment::~ExpandCompoundAssignment() = default;
bool ExpandCompoundAssignment::ShouldRun(const Program* program, const DataMap&) const {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (node->IsAnyOf<ast::CompoundAssignmentStatement, ast::IncrementDecrementStatement>()) {
return true;
@ -44,21 +42,10 @@ bool ExpandCompoundAssignment::ShouldRun(const Program* program, const DataMap&)
return false;
}
namespace {
} // namespace
/// Internal class used to collect statement expansions during the transform.
class State {
private:
/// The clone context.
CloneContext& ctx;
/// The program builder.
ProgramBuilder& b;
/// The HoistToDeclBefore helper instance.
HoistToDeclBefore hoist_to_decl_before;
public:
/// PIMPL state for the transform
struct ExpandCompoundAssignment::State {
/// Constructor
/// @param context the clone context
explicit State(CloneContext& context) : ctx(context), b(*ctx.dst), hoist_to_decl_before(ctx) {}
@ -158,15 +145,32 @@ class State {
ctx.Replace(stmt, b.Assign(new_lhs(), value));
}
/// Finalize the transformation and clone the module.
void Finalize() { ctx.Clone(); }
private:
/// The clone context.
CloneContext& ctx;
/// The program builder.
ProgramBuilder& b;
/// The HoistToDeclBefore helper instance.
HoistToDeclBefore hoist_to_decl_before;
};
} // namespace
ExpandCompoundAssignment::ExpandCompoundAssignment() = default;
void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ExpandCompoundAssignment::~ExpandCompoundAssignment() = default;
Transform::ApplyResult ExpandCompoundAssignment::Apply(const Program* src,
const DataMap&,
DataMap&) const {
if (!ShouldRun(src)) {
return SkipTransform;
}
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
State state(ctx);
for (auto* node : ctx.src->ASTNodes().Objects()) {
for (auto* node : src->ASTNodes().Objects()) {
if (auto* assign = node->As<ast::CompoundAssignmentStatement>()) {
state.Expand(assign, assign->lhs, ctx.Clone(assign->rhs), assign->op);
} else if (auto* inc_dec = node->As<ast::IncrementDecrementStatement>()) {
@ -175,7 +179,9 @@ void ExpandCompoundAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&)
state.Expand(inc_dec, inc_dec->lhs, ctx.dst->Expr(1_a), op);
}
}
state.Finalize();
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -45,19 +45,13 @@ class ExpandCompoundAssignment final : public Castable<ExpandCompoundAssignment,
/// Destructor
~ExpandCompoundAssignment() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
private:
struct State;
};
} // namespace tint::transform

View File

@ -35,6 +35,15 @@ namespace {
constexpr char kFirstVertexName[] = "first_vertex_index";
constexpr char kFirstInstanceName[] = "first_instance_index";
bool ShouldRun(const Program* program) {
for (auto* fn : program->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
return true;
}
}
return false;
}
} // namespace
FirstIndexOffset::BindingPoint::BindingPoint() = default;
@ -49,16 +58,16 @@ FirstIndexOffset::Data::~Data() = default;
FirstIndexOffset::FirstIndexOffset() = default;
FirstIndexOffset::~FirstIndexOffset() = default;
bool FirstIndexOffset::ShouldRun(const Program* program, const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
return true;
}
Transform::ApplyResult FirstIndexOffset::Apply(const Program* src,
const DataMap& inputs,
DataMap& outputs) const {
if (!ShouldRun(src)) {
return SkipTransform;
}
return false;
}
void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
// Get the uniform buffer binding point
uint32_t ub_binding = binding_;
uint32_t ub_group = group_;
@ -115,17 +124,17 @@ void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& ou
if (has_vertex_or_instance_index) {
// Add uniform buffer members and calculate byte offsets
utils::Vector<const ast::StructMember*, 8> members;
members.Push(ctx.dst->Member(kFirstVertexName, ctx.dst->ty.u32()));
members.Push(ctx.dst->Member(kFirstInstanceName, ctx.dst->ty.u32()));
auto* struct_ = ctx.dst->Structure(ctx.dst->Sym(), std::move(members));
members.Push(b.Member(kFirstVertexName, b.ty.u32()));
members.Push(b.Member(kFirstInstanceName, b.ty.u32()));
auto* struct_ = b.Structure(b.Sym(), std::move(members));
// Create a global to hold the uniform buffer
Symbol buffer_name = ctx.dst->Sym();
ctx.dst->GlobalVar(buffer_name, ctx.dst->ty.Of(struct_), ast::AddressSpace::kUniform,
utils::Vector{
ctx.dst->Binding(AInt(ub_binding)),
ctx.dst->Group(AInt(ub_group)),
});
Symbol buffer_name = b.Sym();
b.GlobalVar(buffer_name, b.ty.Of(struct_), ast::AddressSpace::kUniform,
utils::Vector{
b.Binding(AInt(ub_binding)),
b.Group(AInt(ub_group)),
});
// Fix up all references to the builtins with the offsets
ctx.ReplaceAll([=, &ctx](const ast::Expression* expr) -> const ast::Expression* {
@ -150,9 +159,10 @@ void FirstIndexOffset::Run(CloneContext& ctx, const DataMap& inputs, DataMap& ou
});
}
ctx.Clone();
outputs.Add<Data>(has_vertex_or_instance_index);
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -103,19 +103,10 @@ class FirstIndexOffset final : public Castable<FirstIndexOffset, Transform> {
/// Destructor
~FirstIndexOffset() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
private:
uint32_t binding_ = 0;

View File

@ -14,17 +14,17 @@
#include "src/tint/transform/for_loop_to_loop.h"
#include <utility>
#include "src/tint/ast/break_statement.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::ForLoopToLoop);
namespace tint::transform {
ForLoopToLoop::ForLoopToLoop() = default;
namespace {
ForLoopToLoop::~ForLoopToLoop() = default;
bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::ForLoopStatement>()) {
return true;
@ -33,19 +33,31 @@ bool ForLoopToLoop::ShouldRun(const Program* program, const DataMap&) const {
return false;
}
void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
} // namespace
ForLoopToLoop::ForLoopToLoop() = default;
ForLoopToLoop::~ForLoopToLoop() = default;
Transform::ApplyResult ForLoopToLoop::Apply(const Program* src, const DataMap&, DataMap&) const {
if (!ShouldRun(src)) {
return SkipTransform;
}
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
ctx.ReplaceAll([&](const ast::ForLoopStatement* for_loop) -> const ast::Statement* {
utils::Vector<const ast::Statement*, 8> stmts;
if (auto* cond = for_loop->condition) {
// !condition
auto* not_cond =
ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
auto* not_cond = b.Not(ctx.Clone(cond));
// { break; }
auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
auto* break_body = b.Block(b.Break());
// if (!condition) { break; }
stmts.Push(ctx.dst->If(not_cond, break_body));
stmts.Push(b.If(not_cond, break_body));
}
for (auto* stmt : for_loop->body->statements) {
stmts.Push(ctx.Clone(stmt));
@ -53,20 +65,21 @@ void ForLoopToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
const ast::BlockStatement* continuing = nullptr;
if (auto* cont = for_loop->continuing) {
continuing = ctx.dst->Block(ctx.Clone(cont));
continuing = b.Block(ctx.Clone(cont));
}
auto* body = ctx.dst->Block(stmts);
auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
auto* body = b.Block(stmts);
auto* loop = b.Loop(body, continuing);
if (auto* init = for_loop->initializer) {
return ctx.dst->Block(ctx.Clone(init), loop);
return b.Block(ctx.Clone(init), loop);
}
return loop;
});
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -29,19 +29,10 @@ class ForLoopToLoop final : public Castable<ForLoopToLoop, Transform> {
/// Destructor
~ForLoopToLoop() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -32,70 +32,15 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::LocalizeStructArrayAssignment);
namespace tint::transform {
/// Private implementation of LocalizeStructArrayAssignment transform
class LocalizeStructArrayAssignment::State {
private:
CloneContext& ctx;
ProgramBuilder& b;
/// Returns true if `expr` contains an index accessor expression to a
/// structure member of array type.
bool ContainsStructArrayIndex(const ast::Expression* expr) {
bool result = false;
ast::TraverseExpressions(
expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
// Indexing using a runtime value?
auto* idx_sem = ctx.src->Sem().Get(ia->index);
if (!idx_sem->ConstantValue()) {
// Indexing a member access expr?
if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
// That accesses an array?
if (ctx.src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) {
result = true;
return ast::TraverseAction::Stop;
}
}
}
return ast::TraverseAction::Descend;
});
return result;
}
// Returns the type and address space of the originating variable of the lhs
// of the assignment statement.
// See https://www.w3.org/TR/WGSL/#originating-variable-section
std::pair<const sem::Type*, ast::AddressSpace> GetOriginatingTypeAndAddressSpace(
const ast::AssignmentStatement* assign_stmt) {
auto* source_var = ctx.src->Sem().Get(assign_stmt->lhs)->SourceVariable();
if (!source_var) {
TINT_ICE(Transform, b.Diagnostics())
<< "Unable to determine originating variable for lhs of assignment "
"statement";
return {};
}
auto* type = source_var->Type();
if (auto* ref = type->As<sem::Reference>()) {
return {ref->StoreType(), ref->AddressSpace()};
} else if (auto* ptr = type->As<sem::Pointer>()) {
return {ptr->StoreType(), ptr->AddressSpace()};
}
TINT_ICE(Transform, b.Diagnostics())
<< "Expecting to find variable of type pointer or reference on lhs "
"of assignment statement";
return {};
}
public:
/// PIMPL state for the transform
struct LocalizeStructArrayAssignment::State {
/// Constructor
/// @param ctx_in the CloneContext primed with the input program and
/// ProgramBuilder
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
/// @param program the source program
explicit State(const Program* program) : src(program) {}
/// Runs the transform
void Run() {
/// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() {
struct Shared {
bool process_nested_nodes = false;
utils::Vector<const ast::Statement*, 4> insert_before_stmts;
@ -189,6 +134,65 @@ class LocalizeStructArrayAssignment::State {
});
ctx.Clone();
return Program(std::move(b));
}
private:
/// The source program
const Program* const src;
/// The target program builder
ProgramBuilder b;
/// The clone context
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Returns true if `expr` contains an index accessor expression to a
/// structure member of array type.
bool ContainsStructArrayIndex(const ast::Expression* expr) {
bool result = false;
ast::TraverseExpressions(
expr, b.Diagnostics(), [&](const ast::IndexAccessorExpression* ia) {
// Indexing using a runtime value?
auto* idx_sem = src->Sem().Get(ia->index);
if (!idx_sem->ConstantValue()) {
// Indexing a member access expr?
if (auto* ma = ia->object->As<ast::MemberAccessorExpression>()) {
// That accesses an array?
if (src->TypeOf(ma)->UnwrapRef()->Is<sem::Array>()) {
result = true;
return ast::TraverseAction::Stop;
}
}
}
return ast::TraverseAction::Descend;
});
return result;
}
// Returns the type and address space of the originating variable of the lhs
// of the assignment statement.
// See https://www.w3.org/TR/WGSL/#originating-variable-section
std::pair<const sem::Type*, ast::AddressSpace> GetOriginatingTypeAndAddressSpace(
const ast::AssignmentStatement* assign_stmt) {
auto* source_var = src->Sem().Get(assign_stmt->lhs)->SourceVariable();
if (!source_var) {
TINT_ICE(Transform, b.Diagnostics())
<< "Unable to determine originating variable for lhs of assignment "
"statement";
return {};
}
auto* type = source_var->Type();
if (auto* ref = type->As<sem::Reference>()) {
return {ref->StoreType(), ref->AddressSpace()};
} else if (auto* ptr = type->As<sem::Pointer>()) {
return {ptr->StoreType(), ptr->AddressSpace()};
}
TINT_ICE(Transform, b.Diagnostics())
<< "Expecting to find variable of type pointer or reference on lhs "
"of assignment statement";
return {};
}
};
@ -196,9 +200,10 @@ LocalizeStructArrayAssignment::LocalizeStructArrayAssignment() = default;
LocalizeStructArrayAssignment::~LocalizeStructArrayAssignment() = default;
void LocalizeStructArrayAssignment::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State state(ctx);
state.Run();
Transform::ApplyResult LocalizeStructArrayAssignment::Apply(const Program* src,
const DataMap&,
DataMap&) const {
return State{src}.Run();
}
} // namespace tint::transform

View File

@ -36,17 +36,13 @@ class LocalizeStructArrayAssignment final
/// Destructor
~LocalizeStructArrayAssignment() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
private:
class State;
struct State;
};
} // namespace tint::transform

View File

@ -31,9 +31,9 @@ namespace tint::transform {
Manager::Manager() = default;
Manager::~Manager() = default;
Output Manager::Run(const Program* program, const DataMap& data) const {
const Program* in = program;
Transform::ApplyResult Manager::Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const {
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
auto print_program = [&](const char* msg, const Transform* transform) {
auto wgsl = Program::printer(in);
@ -46,34 +46,30 @@ Output Manager::Run(const Program* program, const DataMap& data) const {
};
#endif
Output out;
std::optional<Program> output;
for (const auto& transform : transforms_) {
if (!transform->ShouldRun(in, data)) {
TINT_IF_PRINT_PROGRAM(std::cout << "Skipping " << transform->TypeInfo().name
<< std::endl);
continue;
}
TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get()));
auto res = transform->Run(in, data);
out.program = std::move(res.program);
out.data.Add(std::move(res.data));
in = &out.program;
if (!in->IsValid()) {
TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get()));
return out;
}
if (auto result = transform->Apply(program, inputs, outputs)) {
output.emplace(std::move(result.value()));
program = &output.value();
if (transform == transforms_.back()) {
TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
if (!program->IsValid()) {
TINT_IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get()));
break;
}
if (transform == transforms_.back()) {
TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
}
} else {
TINT_IF_PRINT_PROGRAM(std::cout << "Skipped " << transform->TypeInfo().name
<< std::endl);
}
}
if (program == in) {
out.program = program->Clone();
}
return out;
return output;
}
} // namespace tint::transform

View File

@ -47,11 +47,10 @@ class Manager final : public Castable<Manager, Transform> {
transforms_.emplace_back(std::make_unique<T>(std::forward<ARGS>(args)...));
}
/// Runs the transforms on `program`, returning the transformation result.
/// @param program the source program to transform
/// @param data optional extra transform-specific input data
/// @returns the transformed program and diagnostics
Output Run(const Program* program, const DataMap& data = {}) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
private:
std::vector<std::unique_ptr<Transform>> transforms_;

View File

@ -65,15 +65,6 @@ MergeReturn::MergeReturn() = default;
MergeReturn::~MergeReturn() = default;
bool MergeReturn::ShouldRun(const Program* program, const DataMap&) const {
for (auto* func : program->AST().Functions()) {
if (NeedsTransform(program, func)) {
return true;
}
}
return false;
}
namespace {
/// Internal class used to during the transform.
@ -223,7 +214,12 @@ class State {
} // namespace
void MergeReturn::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
Transform::ApplyResult MergeReturn::Apply(const Program* src, const DataMap&, DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
bool made_changes = false;
for (auto* func : ctx.src->AST().Functions()) {
if (!NeedsTransform(ctx.src, func)) {
continue;
@ -231,9 +227,15 @@ void MergeReturn::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State state(ctx, func);
state.ProcessStatement(func->body);
made_changes = true;
}
if (!made_changes) {
return SkipTransform;
}
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -27,19 +27,10 @@ class MergeReturn final : public Castable<MergeReturn, Transform> {
/// Destructor
~MergeReturn() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -38,6 +38,15 @@ using WorkgroupParameterMemberList = utils::Vector<const ast::StructMember*, 8>;
// The name of the struct member for arrays that are wrapped in structures.
const char* kWrappedArrayMemberName = "arr";
bool ShouldRun(const Program* program) {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (decl->Is<ast::Variable>()) {
return true;
}
}
return false;
}
// Returns `true` if `type` is or contains a matrix type.
bool ContainsMatrix(const sem::Type* type) {
type = type->UnwrapRef();
@ -56,7 +65,7 @@ bool ContainsMatrix(const sem::Type* type) {
}
} // namespace
/// State holds the current transform state.
/// PIMPL state for the transform
struct ModuleScopeVarToEntryPointParam::State {
/// The clone context.
CloneContext& ctx;
@ -501,19 +510,20 @@ ModuleScopeVarToEntryPointParam::ModuleScopeVarToEntryPointParam() = default;
ModuleScopeVarToEntryPointParam::~ModuleScopeVarToEntryPointParam() = default;
bool ModuleScopeVarToEntryPointParam::ShouldRun(const Program* program, const DataMap&) const {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (decl->Is<ast::Variable>()) {
return true;
}
Transform::ApplyResult ModuleScopeVarToEntryPointParam::Apply(const Program* src,
const DataMap&,
DataMap&) const {
if (!ShouldRun(src)) {
return SkipTransform;
}
return false;
}
void ModuleScopeVarToEntryPointParam::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
State state{ctx};
state.Process();
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -69,20 +69,12 @@ class ModuleScopeVarToEntryPointParam final
/// Destructor
~ModuleScopeVarToEntryPointParam() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
private:
struct State;
};

View File

@ -31,6 +31,17 @@ using namespace tint::number_suffixes; // NOLINT
namespace tint::transform {
namespace {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* ty = node->As<ast::Type>()) {
if (program->Sem().Get<sem::ExternalTexture>(ty)) {
return true;
}
}
}
return false;
}
/// This struct stores symbols for new bindings created as a result of transforming a
/// texture_external instance.
struct NewBindingSymbols {
@ -40,7 +51,7 @@ struct NewBindingSymbols {
};
} // namespace
/// State holds the current transform state
/// PIMPL state for the transform
struct MultiplanarExternalTexture::State {
/// The clone context.
CloneContext& ctx;
@ -537,30 +548,26 @@ MultiplanarExternalTexture::NewBindingPoints::~NewBindingPoints() = default;
MultiplanarExternalTexture::MultiplanarExternalTexture() = default;
MultiplanarExternalTexture::~MultiplanarExternalTexture() = default;
bool MultiplanarExternalTexture::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* ty = node->As<ast::Type>()) {
if (program->Sem().Get<sem::ExternalTexture>(ty)) {
return true;
}
}
}
return false;
}
// Within this transform, an instance of a texture_external binding is unpacked into two
// texture_2d<f32> bindings representing two possible planes of a single texture and a uniform
// buffer binding representing a struct of parameters. Calls to texture builtins that contain a
// texture_external parameter will be transformed into a newly generated version of the function,
// which can perform the desired operation on a single RGBA plane or on separate Y and UV planes.
void MultiplanarExternalTexture::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
Transform::ApplyResult MultiplanarExternalTexture::Apply(const Program* src,
const DataMap& inputs,
DataMap&) const {
auto* new_binding_points = inputs.Get<NewBindingPoints>();
if (!ShouldRun(src)) {
return SkipTransform;
}
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
if (!new_binding_points) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"missing new binding point data for " + std::string(TypeInfo().name));
return;
b.Diagnostics().add_error(diag::System::Transform, "missing new binding point data for " +
std::string(TypeInfo().name));
return Program(std::move(b));
}
State state(ctx, new_binding_points);
@ -568,6 +575,7 @@ void MultiplanarExternalTexture::Run(CloneContext& ctx, const DataMap& inputs, D
state.Process();
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -80,21 +80,13 @@ class MultiplanarExternalTexture final : public Castable<MultiplanarExternalText
/// Destructor
~MultiplanarExternalTexture() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
protected:
private:
struct State;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -23,7 +23,11 @@ using MultiplanarExternalTextureTest = TransformTest;
TEST_F(MultiplanarExternalTextureTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src));
DataMap data;
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
EXPECT_FALSE(ShouldRun<MultiplanarExternalTexture>(src, data));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) {
@ -31,14 +35,22 @@ TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureAlias) {
type ET = texture_external;
)";
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
DataMap data;
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureGlobal) {
auto* src = R"(
@group(0) @binding(0) var ext_tex : texture_external;
)";
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
DataMap data;
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data));
}
TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) {
@ -46,7 +58,11 @@ TEST_F(MultiplanarExternalTextureTest, ShouldRunHasExternalTextureParam) {
fn f(ext_tex : texture_external) {}
)";
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src));
DataMap data;
data.Add<MultiplanarExternalTexture::NewBindingPoints>(
MultiplanarExternalTexture::BindingsMap{{{0, 0}, {{0, 1}, {0, 2}}}});
EXPECT_TRUE(ShouldRun<MultiplanarExternalTexture>(src, data));
}
// Running the transform without passing in data for the new bindings should result in an error.

View File

@ -29,6 +29,18 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::NumWorkgroupsFromUniform::Config);
namespace tint::transform {
namespace {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* attr = node->As<ast::BuiltinAttribute>()) {
if (attr->builtin == ast::BuiltinValue::kNumWorkgroups) {
return true;
}
}
}
return false;
}
/// Accessor describes the identifiers used in a member accessor that is being
/// used to retrieve the num_workgroups builtin from a parameter.
struct Accessor {
@ -44,41 +56,40 @@ struct Accessor {
size_t operator()(const Accessor& a) const { return utils::Hash(a.param, a.member); }
};
};
} // namespace
NumWorkgroupsFromUniform::NumWorkgroupsFromUniform() = default;
NumWorkgroupsFromUniform::~NumWorkgroupsFromUniform() = default;
bool NumWorkgroupsFromUniform::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* attr = node->As<ast::BuiltinAttribute>()) {
if (attr->builtin == ast::BuiltinValue::kNumWorkgroups) {
return true;
}
}
}
return false;
}
Transform::ApplyResult NumWorkgroupsFromUniform::Apply(const Program* src,
const DataMap& inputs,
DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
return;
b.Diagnostics().add_error(diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return Program(std::move(b));
}
if (!ShouldRun(src)) {
return SkipTransform;
}
const char* kNumWorkgroupsMemberName = "num_workgroups";
// Find all entry point parameters that declare the num_workgroups builtin.
std::unordered_set<Accessor, Accessor::Hasher> to_replace;
for (auto* func : ctx.src->AST().Functions()) {
for (auto* func : src->AST().Functions()) {
// num_workgroups is only valid for compute stages.
if (func->PipelineStage() != ast::PipelineStage::kCompute) {
continue;
}
for (auto* param : ctx.src->Sem().Get(func)->Parameters()) {
for (auto* param : src->Sem().Get(func)->Parameters()) {
// Because the CanonicalizeEntryPointIO transform has been run, builtins
// will only appear as struct members.
auto* str = param->Type()->As<sem::Struct>();
@ -108,7 +119,7 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat
// If this is the only member, remove the struct and parameter too.
if (str->Members().size() == 1) {
ctx.Remove(func->params, param->Declaration());
ctx.Remove(ctx.src->AST().GlobalDeclarations(), str->Declaration());
ctx.Remove(src->AST().GlobalDeclarations(), str->Declaration());
}
}
}
@ -119,11 +130,10 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat
const ast::Variable* num_workgroups_ubo = nullptr;
auto get_ubo = [&]() {
if (!num_workgroups_ubo) {
auto* num_workgroups_struct = ctx.dst->Structure(
ctx.dst->Sym(),
utils::Vector{
ctx.dst->Member(kNumWorkgroupsMemberName, ctx.dst->ty.vec3(ctx.dst->ty.u32())),
});
auto* num_workgroups_struct =
b.Structure(b.Sym(), utils::Vector{
b.Member(kNumWorkgroupsMemberName, b.ty.vec3(b.ty.u32())),
});
uint32_t group, binding;
if (cfg->ubo_binding.has_value()) {
@ -135,9 +145,9 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat
// plus 1, or group 0 if no resource bound.
group = 0;
for (auto* global : ctx.src->AST().GlobalVariables()) {
for (auto* global : src->AST().GlobalVariables()) {
if (global->HasBindingPoint()) {
auto* global_sem = ctx.src->Sem().Get<sem::GlobalVariable>(global);
auto* global_sem = src->Sem().Get<sem::GlobalVariable>(global);
auto binding_point = global_sem->BindingPoint();
if (binding_point.group >= group) {
group = binding_point.group + 1;
@ -148,16 +158,16 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat
binding = 0;
}
num_workgroups_ubo = ctx.dst->GlobalVar(
ctx.dst->Sym(), ctx.dst->ty.Of(num_workgroups_struct), ast::AddressSpace::kUniform,
ctx.dst->Group(AInt(group)), ctx.dst->Binding(AInt(binding)));
num_workgroups_ubo =
b.GlobalVar(b.Sym(), b.ty.Of(num_workgroups_struct), ast::AddressSpace::kUniform,
b.Group(AInt(group)), b.Binding(AInt(binding)));
}
return num_workgroups_ubo;
};
// Now replace all the places where the builtins are accessed with the value
// loaded from the uniform buffer.
for (auto* node : ctx.src->ASTNodes().Objects()) {
for (auto* node : src->ASTNodes().Objects()) {
auto* accessor = node->As<ast::MemberAccessorExpression>();
if (!accessor) {
continue;
@ -168,12 +178,12 @@ void NumWorkgroupsFromUniform::Run(CloneContext& ctx, const DataMap& inputs, Dat
}
if (to_replace.count({ident->symbol, accessor->member->symbol})) {
ctx.Replace(accessor,
ctx.dst->MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName));
ctx.Replace(accessor, b.MemberAccessor(get_ubo()->symbol, kNumWorkgroupsMemberName));
}
}
ctx.Clone();
return Program(std::move(b));
}
NumWorkgroupsFromUniform::Config::Config(std::optional<sem::BindingPoint> ubo_bp)

View File

@ -72,19 +72,10 @@ class NumWorkgroupsFromUniform final : public Castable<NumWorkgroupsFromUniform,
std::optional<sem::BindingPoint> ubo_binding;
};
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -28,7 +28,9 @@ using NumWorkgroupsFromUniformTest = TransformTest;
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src));
DataMap data;
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
EXPECT_FALSE(ShouldRun<NumWorkgroupsFromUniform>(src, data));
}
TEST_F(NumWorkgroupsFromUniformTest, ShouldRunHasNumWorkgroups) {
@ -38,7 +40,9 @@ fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
}
)";
EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src));
DataMap data;
data.Add<NumWorkgroupsFromUniform::Config>(sem::BindingPoint{0, 30u});
EXPECT_TRUE(ShouldRun<NumWorkgroupsFromUniform>(src, data));
}
TEST_F(NumWorkgroupsFromUniformTest, Error_MissingTransformData) {
@ -55,7 +59,6 @@ fn main(@builtin(num_workgroups) num_wgs : vec3<u32>) {
DataMap data;
data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
auto got = Run<Unshadow, CanonicalizeEntryPointIO, NumWorkgroupsFromUniform>(src, data);
EXPECT_EQ(expect, str(got));
}

View File

@ -33,14 +33,15 @@ using namespace tint::number_suffixes; // NOLINT
namespace tint::transform {
/// The PIMPL state for the PackedVec3 transform
/// PIMPL state for the transform
struct PackedVec3::State {
/// Constructor
/// @param c the CloneContext
explicit State(CloneContext& c) : ctx(c) {}
/// @param program the source program
explicit State(const Program* program) : src(program) {}
/// Runs the transform
void Run() {
/// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() {
// Packed vec3<T> struct members
utils::Hashset<const sem::StructMember*, 8> members;
@ -72,6 +73,10 @@ struct PackedVec3::State {
}
}
if (members.IsEmpty()) {
return SkipTransform;
}
// Walk the nodes, starting with the most deeply nested, finding all the AST expressions
// that load a whole packed vector (not a scalar / swizzle of the vector).
utils::Hashset<const sem::Expression*, 16> refs;
@ -137,36 +142,20 @@ struct PackedVec3::State {
}
ctx.Clone();
}
/// @returns true if this transform should be run for the given program
/// @param program the program to inspect
static bool ShouldRun(const Program* program) {
for (auto* decl : program->AST().GlobalDeclarations()) {
if (auto* str = program->Sem().Get<sem::Struct>(decl)) {
if (str->IsHostShareable()) {
for (auto* member : str->Members()) {
if (auto* vec = member->Type()->As<sem::Vector>()) {
if (vec->Width() == 3) {
return true;
}
}
}
}
}
}
return false;
return Program(std::move(b));
}
private:
/// The source program
const Program* const src;
/// The target program builder
ProgramBuilder b;
/// The clone context
CloneContext& ctx;
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Alias to the semantic info in ctx.src
const sem::Info& sem = ctx.src->Sem();
/// Alias to the symbols in ctx.src
const SymbolTable& sym = ctx.src->Symbols();
/// Alias to the ctx.dst program builder
ProgramBuilder& b = *ctx.dst;
};
PackedVec3::Attribute::Attribute(ProgramID pid, ast::NodeID nid) : Base(pid, nid) {}
@ -183,12 +172,8 @@ std::string PackedVec3::Attribute::InternalName() const {
PackedVec3::PackedVec3() = default;
PackedVec3::~PackedVec3() = default;
bool PackedVec3::ShouldRun(const Program* program, const DataMap&) const {
return State::ShouldRun(program);
}
void PackedVec3::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State(ctx).Run();
Transform::ApplyResult PackedVec3::Apply(const Program* src, const DataMap&, DataMap&) const {
return State{src}.Run();
}
} // namespace tint::transform

View File

@ -56,21 +56,13 @@ class PackedVec3 final : public Castable<PackedVec3, Transform> {
/// Destructor
~PackedVec3() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
private:
struct State;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -50,8 +50,10 @@ PadStructs::PadStructs() = default;
PadStructs::~PadStructs() = default;
void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem();
Transform::ApplyResult PadStructs::Apply(const Program* src, const DataMap&, DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
auto& sem = src->Sem();
std::unordered_map<const ast::Struct*, const ast::Struct*> replaced_structs;
utils::Hashset<const ast::StructMember*, 8> padding_members;
@ -65,7 +67,7 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
bool has_runtime_sized_array = false;
utils::Vector<const ast::StructMember*, 8> new_members;
for (auto* mem : str->Members()) {
auto name = ctx.src->Symbols().NameFor(mem->Name());
auto name = src->Symbols().NameFor(mem->Name());
if (offset < mem->Offset()) {
CreatePadding(&new_members, &padding_members, ctx.dst, mem->Offset() - offset);
@ -75,7 +77,7 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto* ty = mem->Type();
const ast::Type* type = CreateASTTypeFor(ctx, ty);
new_members.Push(ctx.dst->Member(name, type));
new_members.Push(b.Member(name, type));
uint32_t size = ty->Size();
if (ty->Is<sem::Struct>() && str->UsedAs(ast::AddressSpace::kUniform)) {
@ -97,8 +99,8 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
if (offset < struct_size && !has_runtime_sized_array) {
CreatePadding(&new_members, &padding_members, ctx.dst, struct_size - offset);
}
auto* new_struct = ctx.dst->create<ast::Struct>(ctx.Clone(ast_str->name),
std::move(new_members), utils::Empty);
auto* new_struct =
b.create<ast::Struct>(ctx.Clone(ast_str->name), std::move(new_members), utils::Empty);
replaced_structs[ast_str] = new_struct;
return new_struct;
});
@ -131,16 +133,17 @@ void PadStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto* arg = ast_call->args.begin();
for (auto* member : new_struct->members) {
if (padding_members.Contains(member)) {
new_args.Push(ctx.dst->Expr(0_u));
new_args.Push(b.Expr(0_u));
} else {
new_args.Push(ctx.Clone(*arg));
arg++;
}
}
return ctx.dst->Construct(CreateASTTypeFor(ctx, str), new_args);
return b.Construct(CreateASTTypeFor(ctx, str), new_args);
});
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -30,14 +30,10 @@ class PadStructs final : public Castable<PadStructs, Transform> {
/// Destructor
~PadStructs() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -13,6 +13,9 @@
// limitations under the License.
#include "src/tint/transform/promote_initializers_to_let.h"
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/call.h"
#include "src/tint/sem/statement.h"
@ -27,9 +30,16 @@ PromoteInitializersToLet::PromoteInitializersToLet() = default;
PromoteInitializersToLet::~PromoteInitializersToLet() = default;
void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
Transform::ApplyResult PromoteInitializersToLet::Apply(const Program* src,
const DataMap&,
DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
HoistToDeclBefore hoist_to_decl_before(ctx);
bool any_promoted = false;
// Hoists array and structure initializers to a constant variable, declared
// just before the statement of usage.
auto promote = [&](const sem::Expression* expr) {
@ -59,14 +69,15 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&)
return true;
}
any_promoted = true;
return hoist_to_decl_before.Add(expr, expr->Declaration(), true);
};
for (auto* node : ctx.src->ASTNodes().Objects()) {
for (auto* node : src->ASTNodes().Objects()) {
bool ok = Switch(
node, //
[&](const ast::CallExpression* expr) {
if (auto* sem = ctx.src->Sem().Get(expr)) {
if (auto* sem = src->Sem().Get(expr)) {
auto* ctor = sem->UnwrapMaterialize()->As<sem::Call>();
if (ctor->Target()->Is<sem::TypeInitializer>()) {
return promote(sem);
@ -75,7 +86,7 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&)
return true;
},
[&](const ast::IdentifierExpression* expr) {
if (auto* sem = ctx.src->Sem().Get(expr)) {
if (auto* sem = src->Sem().Get(expr)) {
if (auto* user = sem->UnwrapMaterialize()->As<sem::VariableUser>()) {
// Identifier resolves to a variable
if (auto* stmt = user->Stmt()) {
@ -96,13 +107,17 @@ void PromoteInitializersToLet::Run(CloneContext& ctx, const DataMap&, DataMap&)
return true;
},
[&](Default) { return true; });
if (!ok) {
return;
return Program(std::move(b));
}
}
if (!any_promoted) {
return SkipTransform;
}
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -33,14 +33,10 @@ class PromoteInitializersToLet final : public Castable<PromoteInitializersToLet,
/// Destructor
~PromoteInitializersToLet() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -53,34 +53,36 @@ class StateBase {
// to else {if}s so that the next transform, DecomposeSideEffects, can insert
// hoisted expressions above their current location.
struct SimplifySideEffectStatements : Castable<PromoteSideEffectsToDecl, Transform> {
class State;
void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override;
ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override;
};
class SimplifySideEffectStatements::State : public StateBase {
HoistToDeclBefore hoist_to_decl_before;
Transform::ApplyResult SimplifySideEffectStatements::Apply(const Program* src,
const DataMap&,
DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
public:
explicit State(CloneContext& ctx_in) : StateBase(ctx_in), hoist_to_decl_before(ctx_in) {}
bool made_changes = false;
void Run() {
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* expr = node->As<ast::Expression>()) {
auto* sem_expr = sem.Get(expr);
if (!sem_expr || !sem_expr->HasSideEffects()) {
continue;
}
hoist_to_decl_before.Prepare(sem_expr);
HoistToDeclBefore hoist_to_decl_before(ctx);
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* expr = node->As<ast::Expression>()) {
auto* sem_expr = src->Sem().Get(expr);
if (!sem_expr || !sem_expr->HasSideEffects()) {
continue;
}
}
ctx.Clone();
}
};
void SimplifySideEffectStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State state(ctx);
state.Run();
hoist_to_decl_before.Prepare(sem_expr);
made_changes = true;
}
}
if (!made_changes) {
return SkipTransform;
}
ctx.Clone();
return Program(std::move(b));
}
// Decomposes side-effecting expressions to ensure order of evaluation. This
@ -89,7 +91,7 @@ void SimplifySideEffectStatements::Run(CloneContext& ctx, const DataMap&, DataMa
struct DecomposeSideEffects : Castable<PromoteSideEffectsToDecl, Transform> {
class CollectHoistsState;
class DecomposeState;
void Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const override;
ApplyResult Apply(const Program* src, const DataMap& inputs, DataMap& outputs) const override;
};
// CollectHoistsState traverses the AST top-down, identifying which expressions
@ -667,12 +669,15 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
}
return nullptr;
});
ctx.Clone();
}
};
void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
Transform::ApplyResult DecomposeSideEffects::Apply(const Program* src,
const DataMap&,
DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
// First collect side-effecting expressions to hoist
CollectHoistsState collect_hoists_state{ctx};
auto to_hoist = collect_hoists_state.Run();
@ -680,6 +685,9 @@ void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
// Now decompose these expressions
DecomposeState decompose_state{ctx, std::move(to_hoist)};
decompose_state.Run();
ctx.Clone();
return Program(std::move(b));
}
} // namespace
@ -687,13 +695,13 @@ void DecomposeSideEffects::Run(CloneContext& ctx, const DataMap&, DataMap&) cons
PromoteSideEffectsToDecl::PromoteSideEffectsToDecl() = default;
PromoteSideEffectsToDecl::~PromoteSideEffectsToDecl() = default;
Output PromoteSideEffectsToDecl::Run(const Program* program, const DataMap& data) const {
Transform::ApplyResult PromoteSideEffectsToDecl::Apply(const Program* src,
const DataMap& inputs,
DataMap& outputs) const {
transform::Manager manager;
manager.Add<SimplifySideEffectStatements>();
manager.Add<DecomposeSideEffects>();
auto output = manager.Run(program, data);
return output;
return manager.Apply(src, inputs, outputs);
}
} // namespace tint::transform

View File

@ -31,12 +31,10 @@ class PromoteSideEffectsToDecl final : public Castable<PromoteSideEffectsToDecl,
/// Destructor
~PromoteSideEffectsToDecl() override;
protected:
/// Runs the transform on `program`, returning the transformation result.
/// @param program the source program to transform
/// @param data optional extra transform-specific data
/// @returns the transformation result
Output Run(const Program* program, const DataMap& data = {}) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -32,53 +32,19 @@
TINT_INSTANTIATE_TYPEINFO(tint::transform::RemoveContinueInSwitch);
namespace tint::transform {
namespace {
class State {
private:
CloneContext& ctx;
ProgramBuilder& b;
const sem::Info& sem;
// Map of switch statement to 'tint_continue' variable.
std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name;
// If `cont` is within a switch statement within a loop, returns a pointer to
// that switch statement.
static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
const ast::ContinueStatement* cont) {
// Find whether first parent is a switch or a loop
auto* sem_stmt = sem.Get(cont);
auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
sem::ForLoopStatement, sem::WhileStatement>();
if (!sem_parent) {
return nullptr;
}
return sem_parent->Declaration()->As<ast::SwitchStatement>();
}
public:
/// PIMPL state for the transform
struct RemoveContinueInSwitch::State {
/// Constructor
/// @param ctx_in the context
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
/// Returns true if this transform should be run for the given program
static bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
auto* stmt = node->As<ast::ContinueStatement>();
if (!stmt) {
continue;
}
if (GetParentSwitchInLoop(program->Sem(), stmt)) {
return true;
}
}
return false;
}
/// @param program the source program
explicit State(const Program* program) : src(program) {}
/// Runs the transform
void Run() {
for (auto* node : ctx.src->ASTNodes().Objects()) {
/// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() {
bool made_changes = false;
for (auto* node : src->ASTNodes().Objects()) {
auto* cont = node->As<ast::ContinueStatement>();
if (!cont) {
continue;
@ -90,6 +56,8 @@ class State {
continue;
}
made_changes = true;
auto cont_var_name =
tint::utils::GetOrCreate(switch_to_cont_var_name, switch_stmt, [&]() {
// Create and insert 'var tint_continue : bool = false;' before the
@ -116,22 +84,50 @@ class State {
ctx.Replace(cont, new_stmt);
}
if (!made_changes) {
return SkipTransform;
}
ctx.Clone();
return Program(std::move(b));
}
private:
/// The source program
const Program* const src;
/// The target program builder
ProgramBuilder b;
/// The clone context
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Alias to src->sem
const sem::Info& sem = src->Sem();
// Map of switch statement to 'tint_continue' variable.
std::unordered_map<const ast::SwitchStatement*, Symbol> switch_to_cont_var_name;
// If `cont` is within a switch statement within a loop, returns a pointer to
// that switch statement.
static const ast::SwitchStatement* GetParentSwitchInLoop(const sem::Info& sem,
const ast::ContinueStatement* cont) {
// Find whether first parent is a switch or a loop
auto* sem_stmt = sem.Get(cont);
auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
sem::ForLoopStatement, sem::WhileStatement>();
if (!sem_parent) {
return nullptr;
}
return sem_parent->Declaration()->As<ast::SwitchStatement>();
}
};
} // namespace
RemoveContinueInSwitch::RemoveContinueInSwitch() = default;
RemoveContinueInSwitch::~RemoveContinueInSwitch() = default;
bool RemoveContinueInSwitch::ShouldRun(const Program* program, const DataMap& /*data*/) const {
return State::ShouldRun(program);
}
void RemoveContinueInSwitch::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State state(ctx);
state.Run();
Transform::ApplyResult RemoveContinueInSwitch::Apply(const Program* src,
const DataMap&,
DataMap&) const {
State state(src);
return state.Run();
}
} // namespace tint::transform

View File

@ -31,19 +31,13 @@ class RemoveContinueInSwitch final : public Castable<RemoveContinueInSwitch, Tra
/// Destructor
~RemoveContinueInSwitch() override;
protected:
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
private:
struct State;
};
} // namespace tint::transform

View File

@ -41,34 +41,25 @@ RemovePhonies::RemovePhonies() = default;
RemovePhonies::~RemovePhonies() = default;
bool RemovePhonies::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::PhonyExpression>()) {
return true;
}
if (auto* stmt = node->As<ast::CallStatement>()) {
if (program->Sem().Get(stmt->expr)->ConstantValue() != nullptr) {
return true;
}
}
}
return false;
}
Transform::ApplyResult RemovePhonies::Apply(const Program* src, const DataMap&, DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
auto& sem = ctx.src->Sem();
auto& sem = src->Sem();
std::unordered_map<SinkSignature, Symbol, utils::Hasher<SinkSignature>> sinks;
utils::Hashmap<SinkSignature, Symbol, 8, utils::Hasher<SinkSignature>> sinks;
for (auto* node : ctx.src->ASTNodes().Objects()) {
bool made_changes = false;
for (auto* node : src->ASTNodes().Objects()) {
Switch(
node,
[&](const ast::AssignmentStatement* stmt) {
if (stmt->lhs->Is<ast::PhonyExpression>()) {
made_changes = true;
std::vector<const ast::Expression*> side_effects;
if (!ast::TraverseExpressions(
stmt->rhs, ctx.dst->Diagnostics(),
[&](const ast::CallExpression* expr) {
stmt->rhs, b.Diagnostics(), [&](const ast::CallExpression* expr) {
// ast::CallExpression may map to a function or builtin call
// (both may have side-effects), or a type initializer or
// type conversion (both do not have side effects).
@ -100,8 +91,7 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
if (auto* call = side_effects[0]->As<ast::CallExpression>()) {
// Phony assignment with single call side effect.
// Replace phony assignment with call.
ctx.Replace(stmt,
[&, call] { return ctx.dst->CallStmt(ctx.Clone(call)); });
ctx.Replace(stmt, [&, call] { return b.CallStmt(ctx.Clone(call)); });
return;
}
}
@ -114,22 +104,21 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
for (auto* arg : side_effects) {
sig.push_back(sem.Get(arg)->Type()->UnwrapRef());
}
auto sink = utils::GetOrCreate(sinks, sig, [&] {
auto name = ctx.dst->Symbols().New("phony_sink");
auto sink = sinks.GetOrCreate(sig, [&] {
auto name = b.Symbols().New("phony_sink");
utils::Vector<const ast::Parameter*, 8> params;
for (auto* ty : sig) {
auto* ast_ty = CreateASTTypeFor(ctx, ty);
params.Push(
ctx.dst->Param("p" + std::to_string(params.Length()), ast_ty));
params.Push(b.Param("p" + std::to_string(params.Length()), ast_ty));
}
ctx.dst->Func(name, params, ctx.dst->ty.void_(), {});
b.Func(name, params, b.ty.void_(), {});
return name;
});
utils::Vector<const ast::Expression*, 8> args;
for (auto* arg : side_effects) {
args.Push(ctx.Clone(arg));
}
return ctx.dst->CallStmt(ctx.dst->Call(sink, args));
return b.CallStmt(b.Call(sink, args));
});
}
},
@ -138,12 +127,18 @@ void RemovePhonies::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
// TODO(crbug.com/tint/1637): Remove if `stmt->expr` has no side-effects.
auto* sem_expr = sem.Get(stmt->expr);
if ((sem_expr->ConstantValue() != nullptr) && !sem_expr->HasSideEffects()) {
made_changes = true;
ctx.Remove(sem.Get(stmt)->Block()->Declaration()->statements, stmt);
}
});
}
if (!made_changes) {
return SkipTransform;
}
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -33,19 +33,10 @@ class RemovePhonies final : public Castable<RemovePhonies, Transform> {
/// Destructor
~RemovePhonies() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -36,27 +36,28 @@ RemoveUnreachableStatements::RemoveUnreachableStatements() = default;
RemoveUnreachableStatements::~RemoveUnreachableStatements() = default;
bool RemoveUnreachableStatements::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* stmt = program->Sem().Get<sem::Statement>(node)) {
Transform::ApplyResult RemoveUnreachableStatements::Apply(const Program* src,
const DataMap&,
DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
bool made_changes = false;
for (auto* node : src->ASTNodes().Objects()) {
if (auto* stmt = src->Sem().Get<sem::Statement>(node)) {
if (!stmt->IsReachable()) {
return true;
RemoveStatement(ctx, stmt->Declaration());
made_changes = true;
}
}
}
return false;
}
void RemoveUnreachableStatements::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* stmt = ctx.src->Sem().Get<sem::Statement>(node)) {
if (!stmt->IsReachable()) {
RemoveStatement(ctx, stmt->Declaration());
}
}
if (!made_changes) {
return SkipTransform;
}
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -32,19 +32,10 @@ class RemoveUnreachableStatements final : public Castable<RemoveUnreachableState
/// Destructor
~RemoveUnreachableStatements() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -1252,39 +1252,31 @@ Renamer::Config::~Config() = default;
Renamer::Renamer() = default;
Renamer::~Renamer() = default;
Output Renamer::Run(const Program* in, const DataMap& inputs) const {
ProgramBuilder out;
// Disable auto-cloning of symbols, since we want to rename them.
CloneContext ctx(&out, in, false);
Transform::ApplyResult Renamer::Apply(const Program* src,
const DataMap& inputs,
DataMap& outputs) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ false};
// Swizzles, builtin calls and builtin structure members need to keep their
// symbols preserved.
std::unordered_set<const ast::IdentifierExpression*> preserve;
for (auto* node : in->ASTNodes().Objects()) {
utils::Hashset<const ast::IdentifierExpression*, 8> preserve;
for (auto* node : src->ASTNodes().Objects()) {
if (auto* member = node->As<ast::MemberAccessorExpression>()) {
auto* sem = in->Sem().Get(member);
if (!sem) {
TINT_ICE(Transform, out.Diagnostics())
<< "MemberAccessorExpression has no semantic info";
continue;
}
auto* sem = src->Sem().Get(member);
if (sem->Is<sem::Swizzle>()) {
preserve.emplace(member->member);
} else if (auto* str_expr = in->Sem().Get(member->structure)) {
preserve.Add(member->member);
} else if (auto* str_expr = src->Sem().Get(member->structure)) {
if (auto* ty = str_expr->Type()->UnwrapRef()->As<sem::Struct>()) {
if (ty->Declaration() == nullptr) { // Builtin structure
preserve.emplace(member->member);
preserve.Add(member->member);
}
}
}
} else if (auto* call = node->As<ast::CallExpression>()) {
auto* sem = in->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>();
if (!sem) {
TINT_ICE(Transform, out.Diagnostics()) << "CallExpression has no semantic info";
continue;
}
auto* sem = src->Sem().Get(call)->UnwrapMaterialize()->As<sem::Call>();
if (sem->Target()->Is<sem::Builtin>()) {
preserve.emplace(call->target.name);
preserve.Add(call->target.name);
}
}
}
@ -1300,7 +1292,7 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) const {
}
ctx.ReplaceAll([&](Symbol sym_in) {
auto name_in = ctx.src->Symbols().NameFor(sym_in);
auto name_in = src->Symbols().NameFor(sym_in);
if (preserve_unicode || text::utf8::IsASCII(name_in)) {
switch (target) {
case Target::kAll:
@ -1343,17 +1335,20 @@ Output Renamer::Run(const Program* in, const DataMap& inputs) const {
});
ctx.ReplaceAll([&](const ast::IdentifierExpression* ident) -> const ast::IdentifierExpression* {
if (preserve.count(ident)) {
if (preserve.Contains(ident)) {
auto sym_in = ident->symbol;
auto str = in->Symbols().NameFor(sym_in);
auto sym_out = out.Symbols().Register(str);
auto str = src->Symbols().NameFor(sym_in);
auto sym_out = b.Symbols().Register(str);
return ctx.dst->create<ast::IdentifierExpression>(ctx.Clone(ident->source), sym_out);
}
return nullptr; // Clone ident. Uses the symbol remapping above.
});
ctx.Clone();
return Output(Program(std::move(out)), std::make_unique<Data>(std::move(remappings)));
ctx.Clone(); // Must come before the std::move()
outputs.Add<Data>(std::move(remappings));
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -85,11 +85,10 @@ class Renamer final : public Castable<Renamer, Transform> {
/// Destructor
~Renamer() override;
/// Runs the transform on `program`, returning the transformation result.
/// @param program the source program to transform
/// @param data optional extra transform-specific input data
/// @returns the transformation result
Output Run(const Program* program, const DataMap& data = {}) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -33,36 +33,48 @@ using namespace tint::number_suffixes; // NOLINT
namespace tint::transform {
/// State holds the current transform state
/// PIMPL state for the transform
struct Robustness::State {
/// The clone context
CloneContext& ctx;
/// Constructor
/// @param program the source program
/// @param omitted the omitted address spaces
State(const Program* program, std::unordered_set<ast::AddressSpace>&& omitted)
: src(program), omitted_address_spaces(std::move(omitted)) {}
/// Set of address spacees to not apply the transform to
std::unordered_set<ast::AddressSpace> omitted_classes;
/// Applies the transformation state to `ctx`.
void Transform() {
/// Runs the transform
/// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() {
ctx.ReplaceAll([&](const ast::IndexAccessorExpression* expr) { return Transform(expr); });
ctx.ReplaceAll([&](const ast::CallExpression* expr) { return Transform(expr); });
ctx.Clone();
return Program(std::move(b));
}
private:
/// The source program
const Program* const src;
/// The target program builder
ProgramBuilder b;
/// The clone context
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Set of address spaces to not apply the transform to
std::unordered_set<ast::AddressSpace> omitted_address_spaces;
/// Apply bounds clamping to array, vector and matrix indexing
/// @param expr the array, vector or matrix index expression
/// @return the clamped replacement expression, or nullptr if `expr` should be cloned without
/// changes.
const ast::IndexAccessorExpression* Transform(const ast::IndexAccessorExpression* expr) {
auto* sem =
ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::IndexAccessorExpression>();
auto* sem = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::IndexAccessorExpression>();
auto* ret_type = sem->Type();
auto* ref = ret_type->As<sem::Reference>();
if (ref && omitted_classes.count(ref->AddressSpace()) != 0) {
if (ref && omitted_address_spaces.count(ref->AddressSpace()) != 0) {
return nullptr;
}
ProgramBuilder& b = *ctx.dst;
// idx return the cloned index expression, as a u32.
auto idx = [&]() -> const ast::Expression* {
auto* i = ctx.Clone(expr->index);
@ -109,8 +121,8 @@ struct Robustness::State {
} else {
// Note: Don't be tempted to use the array override variable as an expression
// here, the name might be shadowed!
ctx.dst->Diagnostics().add_error(diag::System::Transform,
sem::Array::kErrExpectedConstantCount);
b.Diagnostics().add_error(diag::System::Transform,
sem::Array::kErrExpectedConstantCount);
return nullptr;
}
@ -119,7 +131,7 @@ struct Robustness::State {
[&](Default) {
TINT_ICE(Transform, b.Diagnostics())
<< "unhandled object type in robustness of array index: "
<< ctx.src->FriendlyName(ret_type->UnwrapRef());
<< src->FriendlyName(ret_type->UnwrapRef());
return nullptr;
});
@ -127,9 +139,9 @@ struct Robustness::State {
return nullptr; // Clamping not needed
}
auto src = ctx.Clone(expr->source);
auto* obj = ctx.Clone(expr->object);
return b.IndexAccessor(src, obj, clamped_idx);
auto idx_src = ctx.Clone(expr->source);
auto* idx_obj = ctx.Clone(expr->object);
return b.IndexAccessor(idx_src, idx_obj, clamped_idx);
}
/// @param type builtin type
@ -145,15 +157,13 @@ struct Robustness::State {
/// @return the clamped replacement call expression, or nullptr if `expr`
/// should be cloned without changes.
const ast::CallExpression* Transform(const ast::CallExpression* expr) {
auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* call_target = call->Target();
auto* builtin = call_target->As<sem::Builtin>();
if (!builtin || !TextureBuiltinNeedsClamping(builtin->Type())) {
return nullptr; // No transform, just clone.
}
ProgramBuilder& b = *ctx.dst;
// Indices of the mandatory texture and coords parameters, and the optional
// array and level parameters.
auto& signature = builtin->Signature();
@ -261,7 +271,7 @@ struct Robustness::State {
// Clamp the level argument, if provided
if (level_idx >= 0) {
auto* arg = expr->args[static_cast<size_t>(level_idx)];
ctx.Replace(arg, level_arg ? level_arg() : ctx.dst->Expr(0_a));
ctx.Replace(arg, level_arg ? level_arg() : b.Expr(0_a));
}
return nullptr; // Clone, which will use the argument replacements above.
@ -276,28 +286,27 @@ Robustness::Config& Robustness::Config::operator=(const Config&) = default;
Robustness::Robustness() = default;
Robustness::~Robustness() = default;
void Robustness::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
Transform::ApplyResult Robustness::Apply(const Program* src,
const DataMap& inputs,
DataMap&) const {
Config cfg;
if (auto* cfg_data = inputs.Get<Config>()) {
cfg = *cfg_data;
}
std::unordered_set<ast::AddressSpace> omitted_classes;
for (auto sc : cfg.omitted_classes) {
std::unordered_set<ast::AddressSpace> omitted_address_spaces;
for (auto sc : cfg.omitted_address_spaces) {
switch (sc) {
case AddressSpace::kUniform:
omitted_classes.insert(ast::AddressSpace::kUniform);
omitted_address_spaces.insert(ast::AddressSpace::kUniform);
break;
case AddressSpace::kStorage:
omitted_classes.insert(ast::AddressSpace::kStorage);
omitted_address_spaces.insert(ast::AddressSpace::kStorage);
break;
}
}
State state{ctx, std::move(omitted_classes)};
state.Transform();
ctx.Clone();
return State{src, std::move(omitted_address_spaces)}.Run();
}
} // namespace tint::transform

View File

@ -54,9 +54,9 @@ class Robustness final : public Castable<Robustness, Transform> {
/// @returns this Config
Config& operator=(const Config&);
/// Address spacees to omit from apply the transform to.
/// Address spaces to omit from apply the transform to.
/// This allows for optimizing on hardware that provide safe accesses.
std::unordered_set<AddressSpace> omitted_classes;
std::unordered_set<AddressSpace> omitted_address_spaces;
};
/// Constructor
@ -64,14 +64,10 @@ class Robustness final : public Castable<Robustness, Transform> {
/// Destructor
~Robustness() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
private:
struct State;

View File

@ -1274,7 +1274,7 @@ fn f() {
)";
Robustness::Config cfg;
cfg.omitted_classes.insert(Robustness::AddressSpace::kStorage);
cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kStorage);
DataMap data;
data.Add<Robustness::Config>(cfg);
@ -1325,7 +1325,7 @@ fn f() {
)";
Robustness::Config cfg;
cfg.omitted_classes.insert(Robustness::AddressSpace::kUniform);
cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kUniform);
DataMap data;
data.Add<Robustness::Config>(cfg);
@ -1376,8 +1376,8 @@ fn f() {
)";
Robustness::Config cfg;
cfg.omitted_classes.insert(Robustness::AddressSpace::kStorage);
cfg.omitted_classes.insert(Robustness::AddressSpace::kUniform);
cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kStorage);
cfg.omitted_address_spaces.insert(Robustness::AddressSpace::kUniform);
DataMap data;
data.Add<Robustness::Config>(cfg);

View File

@ -45,14 +45,18 @@ struct PointerOp {
} // namespace
/// The PIMPL state for the SimplifyPointers transform
/// PIMPL state for the transform
struct SimplifyPointers::State {
/// The source program
const Program* const src;
/// The target program builder
ProgramBuilder b;
/// The clone context
CloneContext& ctx;
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Constructor
/// @param context the clone context
explicit State(CloneContext& context) : ctx(context) {}
/// @param program the source program
explicit State(const Program* program) : src(program) {}
/// Traverses the expression `expr` looking for non-literal array indexing
/// expressions that would affect the computed address of a pointer
@ -120,10 +124,11 @@ struct SimplifyPointers::State {
}
}
/// Performs the transformation
void Run() {
/// Runs the transform
/// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() {
// A map of saved expressions to their saved variable name
std::unordered_map<const ast::Expression*, Symbol> saved_vars;
utils::Hashmap<const ast::Expression*, Symbol, 8> saved_vars;
// Register the ast::Expression transform handler.
// This performs two different transformations:
@ -135,9 +140,8 @@ struct SimplifyPointers::State {
// variable identifier.
ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* {
// Look to see if we need to swap this Expression with a saved variable.
auto it = saved_vars.find(expr);
if (it != saved_vars.end()) {
return ctx.dst->Expr(it->second);
if (auto* saved_var = saved_vars.Find(expr)) {
return ctx.dst->Expr(*saved_var);
}
// Reduce the expression, folding away chains of address-of / indirections
@ -174,7 +178,7 @@ struct SimplifyPointers::State {
// Scan the initializer expression for array index expressions that need
// to be hoist to temporary "saved" variables.
std::vector<const ast::VariableDeclStatement*> saved;
utils::Vector<const ast::VariableDeclStatement*, 8> saved;
CollectSavedArrayIndices(
var->Declaration()->initializer, [&](const ast::Expression* idx_expr) {
// We have a sub-expression that needs to be saved.
@ -182,18 +186,18 @@ struct SimplifyPointers::State {
auto saved_name = ctx.dst->Symbols().New(
ctx.src->Symbols().NameFor(var->Declaration()->symbol) + "_save");
auto* decl = ctx.dst->Decl(ctx.dst->Let(saved_name, ctx.Clone(idx_expr)));
saved.emplace_back(decl);
saved.Push(decl);
// Record the substitution of `idx_expr` to the saved variable
// with the symbol `saved_name`. This will be used by the
// ReplaceAll() handler above.
saved_vars.emplace(idx_expr, saved_name);
saved_vars.Add(idx_expr, saved_name);
});
// Find the place to insert the saved declarations.
// Special care needs to be made for lets declared as the initializer
// part of for-loops. In this case the block will hold the for-loop
// statement, not the let.
if (!saved.empty()) {
if (!saved.IsEmpty()) {
auto* stmt = ctx.src->Sem().Get(let);
auto* block = stmt->Block();
// Find the statement owned by the block (either the let decl or a
@ -219,7 +223,9 @@ struct SimplifyPointers::State {
RemoveStatement(ctx, let);
}
}
ctx.Clone();
return Program(std::move(b));
}
};
@ -227,8 +233,8 @@ SimplifyPointers::SimplifyPointers() = default;
SimplifyPointers::~SimplifyPointers() = default;
void SimplifyPointers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State(ctx).Run();
Transform::ApplyResult SimplifyPointers::Apply(const Program* src, const DataMap&, DataMap&) const {
return State(src).Run();
}
} // namespace tint::transform

View File

@ -39,16 +39,13 @@ class SimplifyPointers final : public Castable<SimplifyPointers, Transform> {
/// Destructor
~SimplifyPointers() override;
protected:
struct State;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
private:
struct State;
};
} // namespace tint::transform

View File

@ -30,33 +30,37 @@ SingleEntryPoint::SingleEntryPoint() = default;
SingleEntryPoint::~SingleEntryPoint() = default;
void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
Transform::ApplyResult SingleEntryPoint::Apply(const Program* src,
const DataMap& inputs,
DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform, "missing transform data for " + std::string(TypeInfo().name));
return;
b.Diagnostics().add_error(diag::System::Transform,
"missing transform data for " + std::string(TypeInfo().name));
return Program(std::move(b));
}
// Find the target entry point.
const ast::Function* entry_point = nullptr;
for (auto* f : ctx.src->AST().Functions()) {
for (auto* f : src->AST().Functions()) {
if (!f->IsEntryPoint()) {
continue;
}
if (ctx.src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) {
if (src->Symbols().NameFor(f->symbol) == cfg->entry_point_name) {
entry_point = f;
break;
}
}
if (entry_point == nullptr) {
ctx.dst->Diagnostics().add_error(diag::System::Transform,
"entry point '" + cfg->entry_point_name + "' not found");
return;
b.Diagnostics().add_error(diag::System::Transform,
"entry point '" + cfg->entry_point_name + "' not found");
return Program(std::move(b));
}
auto& sem = ctx.src->Sem();
auto& sem = src->Sem();
// Build set of referenced module-scope variables for faster lookups later.
std::unordered_set<const ast::Variable*> referenced_vars;
@ -66,12 +70,12 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c
// Clone any module-scope variables, types, and functions that are statically referenced by the
// target entry point.
for (auto* decl : ctx.src->AST().GlobalDeclarations()) {
for (auto* decl : src->AST().GlobalDeclarations()) {
Switch(
decl, //
[&](const ast::TypeDecl* ty) {
// TODO(jrprice): Strip unused types.
ctx.dst->AST().AddTypeDecl(ctx.Clone(ty));
b.AST().AddTypeDecl(ctx.Clone(ty));
},
[&](const ast::Override* override) {
if (referenced_vars.count(override)) {
@ -80,37 +84,39 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c
// so that its allocated ID so that it won't be affected by other
// stripped away overrides
auto* global = sem.Get(override);
const auto* id = ctx.dst->Id(global->OverrideId());
const auto* id = b.Id(global->OverrideId());
ctx.InsertFront(override->attributes, id);
}
ctx.dst->AST().AddGlobalVariable(ctx.Clone(override));
b.AST().AddGlobalVariable(ctx.Clone(override));
}
},
[&](const ast::Var* var) {
if (referenced_vars.count(var)) {
ctx.dst->AST().AddGlobalVariable(ctx.Clone(var));
b.AST().AddGlobalVariable(ctx.Clone(var));
}
},
[&](const ast::Const* c) {
// Always keep 'const' declarations, as these can be used by attributes and array
// sizes, which are not tracked as transitively used by functions. They also don't
// typically get emitted by the backend unless they're actually used.
ctx.dst->AST().AddGlobalVariable(ctx.Clone(c));
b.AST().AddGlobalVariable(ctx.Clone(c));
},
[&](const ast::Function* func) {
if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol)) {
ctx.dst->AST().AddFunction(ctx.Clone(func));
b.AST().AddFunction(ctx.Clone(func));
}
},
[&](const ast::Enable* ext) { ctx.dst->AST().AddEnable(ctx.Clone(ext)); },
[&](const ast::Enable* ext) { b.AST().AddEnable(ctx.Clone(ext)); },
[&](Default) {
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
TINT_UNREACHABLE(Transform, b.Diagnostics())
<< "unhandled global declaration: " << decl->TypeInfo().name;
});
}
// Clone the entry point.
ctx.dst->AST().AddFunction(ctx.Clone(entry_point));
b.AST().AddFunction(ctx.Clone(entry_point));
return Program(std::move(b));
}
SingleEntryPoint::Config::Config(std::string entry_point) : entry_point_name(entry_point) {}

View File

@ -53,14 +53,10 @@ class SingleEntryPoint final : public Castable<SingleEntryPoint, Transform> {
/// Destructor
~SingleEntryPoint() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -37,7 +37,7 @@ namespace tint::transform {
using namespace tint::number_suffixes; // NOLINT
/// Private implementation of transform
/// PIMPL state for the transform
struct SpirvAtomic::State {
private:
/// A struct that has been forked because a subset of members were made atomic.
@ -46,19 +46,24 @@ struct SpirvAtomic::State {
std::unordered_set<size_t> atomic_members;
};
CloneContext& ctx;
ProgramBuilder& b = *ctx.dst;
/// The source program
const Program* const src;
/// The target program builder
ProgramBuilder b;
/// The clone context
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
std::unordered_map<const ast::Struct*, ForkedStruct> forked_structs;
std::unordered_set<const sem::Variable*> atomic_variables;
utils::UniqueVector<const sem::Expression*, 8> atomic_expressions;
public:
/// Constructor
/// @param c the clone context
explicit State(CloneContext& c) : ctx(c) {}
/// @param program the source program
explicit State(const Program* program) : src(program) {}
/// Runs the transform
void Run() {
/// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() {
// Look for stub functions generated by the SPIR-V reader, which are used as placeholders
// for atomic builtin calls.
for (auto* fn : ctx.src->AST().Functions()) {
@ -102,6 +107,10 @@ struct SpirvAtomic::State {
}
}
if (atomic_expressions.IsEmpty()) {
return SkipTransform;
}
// Transform all variables and structure members that were used in atomic operations as
// atomic types. This propagates up originating expression chains.
ProcessAtomicExpressions();
@ -143,6 +152,7 @@ struct SpirvAtomic::State {
ReplaceLoadsAndStores();
ctx.Clone();
return Program(std::move(b));
}
private:
@ -297,17 +307,8 @@ const SpirvAtomic::Stub* SpirvAtomic::Stub::Clone(CloneContext* ctx) const {
ctx->dst->AllocateNodeID(), builtin);
}
bool SpirvAtomic::ShouldRun(const Program* program, const DataMap&) const {
for (auto* fn : program->AST().Functions()) {
if (ast::HasAttribute<Stub>(fn->attributes)) {
return true;
}
}
return false;
}
void SpirvAtomic::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State{ctx}.Run();
Transform::ApplyResult SpirvAtomic::Apply(const Program* src, const DataMap&, DataMap&) const {
return State{src}.Run();
}
} // namespace tint::transform

View File

@ -63,21 +63,13 @@ class SpirvAtomic final : public Castable<SpirvAtomic, Transform> {
const sem::BuiltinType builtin;
};
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
protected:
private:
struct State;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -77,14 +77,20 @@ struct Hasher<DynamicIndex> {
namespace tint::transform {
/// The PIMPL state for the Std140 transform
/// PIMPL state for the transform
struct Std140::State {
/// Constructor
/// @param c the CloneContext
explicit State(CloneContext& c) : ctx(c) {}
/// @param program the source program
explicit State(const Program* program) : src(program) {}
/// Runs the transform
void Run() {
/// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() {
if (!ShouldRun()) {
// Transform is not required
return SkipTransform;
}
// Begin by creating forked types for any type that is used as a uniform buffer, that
// either directly or transitively contains a matrix that needs splitting for std140 layout.
ForkTypes();
@ -116,11 +122,11 @@ struct Std140::State {
});
ctx.Clone();
return Program(std::move(b));
}
/// @returns true if this transform should be run for the given program
/// @param program the program to inspect
static bool ShouldRun(const Program* program) {
bool ShouldRun() const {
// Returns true if the type needs to be forked for std140 usage.
auto needs_fork = [&](const sem::Type* ty) {
while (auto* arr = ty->As<sem::Array>()) {
@ -135,7 +141,7 @@ struct Std140::State {
};
// Scan structures for members that need forking
for (auto* ty : program->Types()) {
for (auto* ty : src->Types()) {
if (auto* str = ty->As<sem::Struct>()) {
if (str->UsedAs(ast::AddressSpace::kUniform)) {
for (auto* member : str->Members()) {
@ -148,8 +154,8 @@ struct Std140::State {
}
// Scan uniform variables that have types that need forking
for (auto* decl : program->AST().GlobalVariables()) {
auto* global = program->Sem().Get(decl);
for (auto* decl : src->AST().GlobalVariables()) {
auto* global = src->Sem().Get(decl);
if (global->AddressSpace() == ast::AddressSpace::kUniform) {
if (needs_fork(global->Type()->UnwrapRef())) {
return true;
@ -197,14 +203,16 @@ struct Std140::State {
}
};
/// The source program
const Program* const src;
/// The target program builder
ProgramBuilder b;
/// The clone context
CloneContext& ctx;
/// Alias to the semantic info in ctx.src
const sem::Info& sem = ctx.src->Sem();
/// Alias to the symbols in ctx.src
const SymbolTable& sym = ctx.src->Symbols();
/// Alias to the ctx.dst program builder
ProgramBuilder& b = *ctx.dst;
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Alias to the semantic info in src
const sem::Info& sem = src->Sem();
/// Alias to the symbols in src
const SymbolTable& sym = src->Symbols();
/// Map of load function signature, to the generated function
utils::Hashmap<LoadFnKey, Symbol, 8, LoadFnKey::Hasher> load_fns;
@ -218,7 +226,7 @@ struct Std140::State {
// Map of original structure to 'std140' forked structure
utils::Hashmap<const sem::Struct*, Symbol, 8> std140_structs;
// Map of structure member in ctx.src of a matrix type, to list of decomposed column
// Map of structure member in src of a matrix type, to list of decomposed column
// members in ctx.dst.
utils::Hashmap<const sem::StructMember*, utils::Vector<const ast::StructMember*, 4>, 8>
std140_mat_members;
@ -232,7 +240,7 @@ struct Std140::State {
utils::Vector<Symbol, 4> columns;
};
// Map of matrix type in ctx.src, to decomposed column structure in ctx.dst.
// Map of matrix type in src, to decomposed column structure in ctx.dst.
utils::Hashmap<const sem::Matrix*, Std140Matrix, 8> std140_mats;
/// AccessChain describes a chain of access expressions to uniform buffer variable.
@ -266,7 +274,7 @@ struct Std140::State {
/// map (via Std140Type()).
void ForkTypes() {
// For each module scope declaration...
for (auto* global : ctx.src->Sem().Module()->DependencyOrderedDeclarations()) {
for (auto* global : src->Sem().Module()->DependencyOrderedDeclarations()) {
// Check to see if this is a structure used by a uniform buffer...
auto* str = sem.Get<sem::Struct>(global);
if (str && str->UsedAs(ast::AddressSpace::kUniform)) {
@ -317,7 +325,7 @@ struct Std140::State {
if (fork_std140) {
// Clone any members that have not already been cloned.
for (auto& member : members) {
if (member->program_id == ctx.src->ID()) {
if (member->program_id == src->ID()) {
member = ctx.Clone(member);
}
}
@ -326,7 +334,7 @@ struct Std140::State {
auto name = b.Symbols().New(sym.NameFor(str->Name()) + "_std140");
auto* std140 = b.create<ast::Struct>(name, std::move(members),
ctx.Clone(str->Declaration()->attributes));
ctx.InsertAfter(ctx.src->AST().GlobalDeclarations(), global, std140);
ctx.InsertAfter(src->AST().GlobalDeclarations(), global, std140);
std140_structs.Add(str, name);
}
}
@ -337,14 +345,13 @@ struct Std140::State {
/// type that has been forked for std140-layout.
/// Populates the #std140_uniforms set.
void ReplaceUniformVarTypes() {
for (auto* global : ctx.src->AST().GlobalVariables()) {
for (auto* global : src->AST().GlobalVariables()) {
if (auto* var = global->As<ast::Var>()) {
if (var->declared_address_space == ast::AddressSpace::kUniform) {
auto* v = sem.Get(var);
if (auto* std140_ty = Std140Type(v->Type()->UnwrapRef())) {
ctx.Replace(global->type, std140_ty);
std140_uniforms.Add(v);
continue;
}
}
}
@ -404,7 +411,7 @@ struct Std140::State {
auto std140_mat = std140_mats.GetOrCreate(mat, [&] {
auto name = b.Symbols().New("mat" + std::to_string(mat->columns()) + "x" +
std::to_string(mat->rows()) + "_" +
ctx.src->FriendlyName(mat->type()));
src->FriendlyName(mat->type()));
auto members =
DecomposedMatrixStructMembers(mat, "col", mat->Align(), mat->Size());
b.Structure(name, members);
@ -421,7 +428,7 @@ struct Std140::State {
if (auto* std140 = Std140Type(arr->ElemType())) {
utils::Vector<const ast::Attribute*, 1> attrs;
if (!arr->IsStrideImplicit()) {
attrs.Push(ctx.dst->create<ast::StrideAttribute>(arr->Stride()));
attrs.Push(b.create<ast::StrideAttribute>(arr->Stride()));
}
auto count = arr->ConstantCount();
if (!count) {
@ -429,7 +436,7 @@ struct Std140::State {
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles types transitively used as uniform buffers.
// * Runtime-sized arrays cannot be used in uniform buffers.
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "unexpected non-constant array count";
count = 1;
}
@ -440,7 +447,7 @@ struct Std140::State {
});
}
/// @param mat the matrix to decompose (in ctx.src)
/// @param mat the matrix to decompose (in src)
/// @param name_prefix the name prefix to apply to each of the returned column vector members.
/// @param align the alignment in bytes of the matrix.
/// @param size the size in bytes of the matrix.
@ -473,7 +480,7 @@ struct Std140::State {
// Build the member
const auto col_name = name_prefix + std::to_string(i);
const auto* col_ty = CreateASTTypeFor(ctx, mat->ColumnType());
const auto* col_member = ctx.dst->Member(col_name, col_ty, std::move(attributes));
const auto* col_member = b.Member(col_name, col_ty, std::move(attributes));
// Record the member for std140_mat_members
out.Push(col_member);
}
@ -618,7 +625,7 @@ struct Std140::State {
/// @returns a name suffix for a std140 -> non-std140 conversion function based on the type
/// being converted.
const std::string ConvertSuffix(const sem::Type* ty) const {
const std::string ConvertSuffix(const sem::Type* ty) {
return Switch(
ty, //
[&](const sem::Struct* str) { return sym.NameFor(str->Name()); },
@ -629,8 +636,7 @@ struct Std140::State {
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles types transitively used as uniform buffers.
// * Runtime-sized arrays cannot be used in uniform buffers.
TINT_ICE(Transform, ctx.dst->Diagnostics())
<< "unexpected non-constant array count";
TINT_ICE(Transform, b.Diagnostics()) << "unexpected non-constant array count";
count = 1;
}
return "arr" + std::to_string(count.value()) + "_" + ConvertSuffix(arr->ElemType());
@ -642,7 +648,7 @@ struct Std140::State {
[&](const sem::F32*) { return "f32"; },
[&](Default) {
TINT_ICE(Transform, b.Diagnostics())
<< "unhandled type for conversion name: " << ctx.src->FriendlyName(ty);
<< "unhandled type for conversion name: " << src->FriendlyName(ty);
return "";
});
}
@ -718,8 +724,7 @@ struct Std140::State {
stmts.Push(b.Return(b.Construct(mat_ty, std::move(mat_args))));
} else {
TINT_ICE(Transform, b.Diagnostics())
<< "failed to find std140 matrix info for: "
<< ctx.src->FriendlyName(ty);
<< "failed to find std140 matrix info for: " << src->FriendlyName(ty);
}
}, //
[&](const sem::Array* arr) {
@ -736,7 +741,7 @@ struct Std140::State {
// * Override-expression counts can only be applied to workgroup arrays, and
// this method only handles types transitively used as uniform buffers.
// * Runtime-sized arrays cannot be used in uniform buffers.
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "unexpected non-constant array count";
count = 1;
}
@ -749,7 +754,7 @@ struct Std140::State {
},
[&](Default) {
TINT_ICE(Transform, b.Diagnostics())
<< "unhandled type for conversion: " << ctx.src->FriendlyName(ty);
<< "unhandled type for conversion: " << src->FriendlyName(ty);
});
// Generate the function
@ -1063,7 +1068,7 @@ struct Std140::State {
if (std::get_if<UniformVariable>(&access)) {
const auto* expr = b.Expr(ctx.Clone(chain.var->Declaration()->symbol));
const auto name = ctx.src->Symbols().NameFor(chain.var->Declaration()->symbol);
const auto name = src->Symbols().NameFor(chain.var->Declaration()->symbol);
ty = chain.var->Type()->UnwrapRef();
return {expr, ty, name};
}
@ -1090,7 +1095,7 @@ struct Std140::State {
}, //
[&](Default) -> ExprTypeName {
TINT_ICE(Transform, b.Diagnostics())
<< "unhandled type for access chain: " << ctx.src->FriendlyName(ty);
<< "unhandled type for access chain: " << src->FriendlyName(ty);
return {};
});
}
@ -1104,14 +1109,14 @@ struct Std140::State {
for (auto el : *swizzle) {
rhs += xyzw[el];
}
auto swizzle_ty = ctx.src->Types().Find<sem::Vector>(
auto swizzle_ty = src->Types().Find<sem::Vector>(
vec->type(), static_cast<uint32_t>(swizzle->Length()));
auto* expr = b.MemberAccessor(lhs, rhs);
return {expr, swizzle_ty, rhs};
}, //
[&](Default) -> ExprTypeName {
TINT_ICE(Transform, b.Diagnostics())
<< "unhandled type for access chain: " << ctx.src->FriendlyName(ty);
<< "unhandled type for access chain: " << src->FriendlyName(ty);
return {};
});
}
@ -1140,7 +1145,7 @@ struct Std140::State {
}, //
[&](Default) -> ExprTypeName {
TINT_ICE(Transform, b.Diagnostics())
<< "unhandled type for access chain: " << ctx.src->FriendlyName(ty);
<< "unhandled type for access chain: " << src->FriendlyName(ty);
return {};
});
}
@ -1150,12 +1155,8 @@ Std140::Std140() = default;
Std140::~Std140() = default;
bool Std140::ShouldRun(const Program* program, const DataMap&) const {
return State::ShouldRun(program);
}
void Std140::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State(ctx).Run();
Transform::ApplyResult Std140::Apply(const Program* src, const DataMap&, DataMap&) const {
return State(src).Run();
}
} // namespace tint::transform

View File

@ -34,21 +34,13 @@ class Std140 final : public Castable<Std140, Transform> {
/// Destructor
~Std140() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
private:
struct State;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -15,6 +15,7 @@
#include "src/tint/transform/substitute_override.h"
#include <functional>
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/builtin.h"
@ -25,12 +26,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride);
TINT_INSTANTIATE_TYPEINFO(tint::transform::SubstituteOverride::Config);
namespace tint::transform {
namespace {
SubstituteOverride::SubstituteOverride() = default;
SubstituteOverride::~SubstituteOverride() = default;
bool SubstituteOverride::ShouldRun(const Program* program, const DataMap&) const {
bool ShouldRun(const Program* program) {
for (auto* node : program->AST().GlobalVariables()) {
if (node->Is<ast::Override>()) {
return true;
@ -39,18 +37,32 @@ bool SubstituteOverride::ShouldRun(const Program* program, const DataMap&) const
return false;
}
void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&) const {
} // namespace
SubstituteOverride::SubstituteOverride() = default;
SubstituteOverride::~SubstituteOverride() = default;
Transform::ApplyResult SubstituteOverride::Apply(const Program* src,
const DataMap& config,
DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
const auto* data = config.Get<Config>();
if (!data) {
ctx.dst->Diagnostics().add_error(diag::System::Transform,
"Missing override substitution data");
return;
b.Diagnostics().add_error(diag::System::Transform, "Missing override substitution data");
return Program(std::move(b));
}
if (!ShouldRun(ctx.src)) {
return SkipTransform;
}
ctx.ReplaceAll([&](const ast::Override* w) -> const ast::Const* {
auto* sem = ctx.src->Sem().Get(w);
auto src = ctx.Clone(w->source);
auto source = ctx.Clone(w->source);
auto sym = ctx.Clone(w->symbol);
auto* ty = ctx.Clone(w->type);
@ -58,30 +70,30 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&)
auto iter = data->map.find(sem->OverrideId());
if (iter == data->map.end()) {
if (!w->initializer) {
ctx.dst->Diagnostics().add_error(
b.Diagnostics().add_error(
diag::System::Transform,
"Initializer not provided for override, and override not overridden.");
return nullptr;
}
return ctx.dst->Const(src, sym, ty, ctx.Clone(w->initializer));
return b.Const(source, sym, ty, ctx.Clone(w->initializer));
}
auto value = iter->second;
auto* ctor = Switch(
sem->Type(),
[&](const sem::Bool*) { return ctx.dst->Expr(!std::equal_to<double>()(value, 0.0)); },
[&](const sem::I32*) { return ctx.dst->Expr(i32(value)); },
[&](const sem::U32*) { return ctx.dst->Expr(u32(value)); },
[&](const sem::F32*) { return ctx.dst->Expr(f32(value)); },
[&](const sem::F16*) { return ctx.dst->Expr(f16(value)); });
[&](const sem::Bool*) { return b.Expr(!std::equal_to<double>()(value, 0.0)); },
[&](const sem::I32*) { return b.Expr(i32(value)); },
[&](const sem::U32*) { return b.Expr(u32(value)); },
[&](const sem::F32*) { return b.Expr(f32(value)); },
[&](const sem::F16*) { return b.Expr(f16(value)); });
if (!ctor) {
ctx.dst->Diagnostics().add_error(diag::System::Transform,
"Failed to create override-expression");
b.Diagnostics().add_error(diag::System::Transform,
"Failed to create override-expression");
return nullptr;
}
return ctx.dst->Const(src, sym, ty, ctor);
return b.Const(source, sym, ty, ctor);
});
// Ensure that objects that are indexed with an override-expression are materialized.
@ -89,11 +101,10 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&)
// resulting type of the index may change. See: crbug.com/tint/1697.
ctx.ReplaceAll(
[&](const ast::IndexAccessorExpression* expr) -> const ast::IndexAccessorExpression* {
if (auto* sem = ctx.src->Sem().Get(expr)) {
if (auto* sem = src->Sem().Get(expr)) {
if (auto* access = sem->UnwrapMaterialize()->As<sem::IndexAccessorExpression>()) {
if (access->Object()->UnwrapMaterialize()->Type()->HoldsAbstract() &&
access->Index()->Stage() == sem::EvaluationStage::kOverride) {
auto& b = *ctx.dst;
auto* obj = b.Call(sem::str(sem::BuiltinType::kTintMaterialize),
ctx.Clone(expr->object));
return b.IndexAccessor(obj, ctx.Clone(expr->index));
@ -104,6 +115,7 @@ void SubstituteOverride::Run(CloneContext& ctx, const DataMap& config, DataMap&)
});
ctx.Clone();
return Program(std::move(b));
}
SubstituteOverride::Config::Config() = default;

View File

@ -75,19 +75,10 @@ class SubstituteOverride final : public Castable<SubstituteOverride, Transform>
/// Destructor
~SubstituteOverride() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -122,7 +122,18 @@ class TransformTestBase : public BASE {
}
const Transform& t = TRANSFORM();
return t.ShouldRun(&program, data);
DataMap outputs;
auto result = t.Apply(&program, data, outputs);
if (!result) {
return false;
}
if (!result->IsValid()) {
ADD_FAILURE() << "Apply() called by ShouldRun() returned errors: "
<< result->Diagnostics().str();
return true;
}
return result.has_value();
}
/// @param in the input WGSL source

View File

@ -46,24 +46,19 @@ Output::Output(Program&& p) : program(std::move(p)) {}
Transform::Transform() = default;
Transform::~Transform() = default;
Output Transform::Run(const Program* program, const DataMap& data /* = {} */) const {
ProgramBuilder builder;
CloneContext ctx(&builder, program);
Output Transform::Run(const Program* src, const DataMap& data /* = {} */) const {
Output output;
Run(ctx, data, output.data);
output.program = Program(std::move(builder));
if (auto program = Apply(src, data, output.data)) {
output.program = std::move(program.value());
} else {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
ctx.Clone();
output.program = Program(std::move(b));
}
return output;
}
void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics())
<< "Transform::Run() unimplemented for " << TypeInfo().name;
}
bool Transform::ShouldRun(const Program*, const DataMap&) const {
return true;
}
void Transform::RemoveStatement(CloneContext& ctx, const ast::Statement* stmt) {
auto* sem = ctx.src->Sem().Get(stmt);
if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {

View File

@ -158,26 +158,30 @@ class Transform : public Castable<Transform> {
/// Destructor
~Transform() override;
/// Runs the transform on `program`, returning the transformation result.
/// Runs the transform on @p program, returning the transformation result or a clone of
/// @p program.
/// @param program the source program to transform
/// @param data optional extra transform-specific input data
/// @returns the transformation result
virtual Output Run(const Program* program, const DataMap& data = {}) const;
Output Run(const Program* program, const DataMap& data = {}) const;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
virtual bool ShouldRun(const Program* program, const DataMap& data = {}) const;
/// The return value of Apply().
/// If SkipTransform (std::nullopt), then the transform is not needed to be run.
using ApplyResult = std::optional<Program>;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// Value returned from Apply() to indicate that the transform does not need to be run
static inline constexpr std::nullopt_t SkipTransform = std::nullopt;
/// Runs the transform on `program`, return.
/// @param program the input program
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
virtual void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const;
/// @returns a transformed program, or std::nullopt if the transform didn't need to run.
virtual ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const = 0;
protected:
/// Removes the statement `stmt` from the transformed program.
/// RemoveStatement handles edge cases, like statements in the initializer and
/// continuing of for-loops.

View File

@ -23,7 +23,9 @@ namespace {
// Inherit from Transform so we have access to protected methods
struct CreateASTTypeForTest : public testing::Test, public Transform {
Output Run(const Program*, const DataMap&) const override { return {}; }
ApplyResult Apply(const Program*, const DataMap&, DataMap&) const override {
return SkipTransform;
}
const ast::Type* create(std::function<sem::Type*(ProgramBuilder&)> create_sem_type) {
ProgramBuilder sem_type_builder;

View File

@ -28,27 +28,32 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::Unshadow);
namespace tint::transform {
/// The PIMPL state for the Unshadow transform
/// PIMPL state for the transform
struct Unshadow::State {
/// The source program
const Program* const src;
/// The target program builder
ProgramBuilder b;
/// The clone context
CloneContext& ctx;
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
/// Constructor
/// @param context the clone context
explicit State(CloneContext& context) : ctx(context) {}
/// @param program the source program
explicit State(const Program* program) : src(program) {}
/// Performs the transformation
void Run() {
auto& sem = ctx.src->Sem();
/// Runs the transform
/// @returns the new program or SkipTransform if the transform is not required
Transform::ApplyResult Run() {
auto& sem = src->Sem();
// Maps a variable to its new name.
std::unordered_map<const sem::Variable*, Symbol> renamed_to;
utils::Hashmap<const sem::Variable*, Symbol, 8> renamed_to;
auto rename = [&](const sem::Variable* v) -> const ast::Variable* {
auto* decl = v->Declaration();
auto name = ctx.src->Symbols().NameFor(decl->symbol);
auto symbol = ctx.dst->Symbols().New(name);
renamed_to.emplace(v, symbol);
auto name = src->Symbols().NameFor(decl->symbol);
auto symbol = b.Symbols().New(name);
renamed_to.Add(v, symbol);
auto source = ctx.Clone(decl->source);
auto* type = ctx.Clone(decl->type);
@ -57,20 +62,20 @@ struct Unshadow::State {
return Switch(
decl, //
[&](const ast::Var* var) {
return ctx.dst->Var(source, symbol, type, var->declared_address_space,
var->declared_access, initializer, attributes);
return b.Var(source, symbol, type, var->declared_address_space,
var->declared_access, initializer, attributes);
},
[&](const ast::Let*) {
return ctx.dst->Let(source, symbol, type, initializer, attributes);
return b.Let(source, symbol, type, initializer, attributes);
},
[&](const ast::Const*) {
return ctx.dst->Const(source, symbol, type, initializer, attributes);
return b.Const(source, symbol, type, initializer, attributes);
},
[&](const ast::Parameter*) {
return ctx.dst->Param(source, symbol, type, attributes);
[&](const ast::Parameter*) { //
return b.Param(source, symbol, type, attributes);
},
[&](Default) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "unexpected variable type: " << decl->TypeInfo().name;
return nullptr;
});
@ -92,14 +97,15 @@ struct Unshadow::State {
ctx.ReplaceAll(
[&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* {
if (auto* user = sem.Get<sem::VariableUser>(ident)) {
auto it = renamed_to.find(user->Variable());
if (it != renamed_to.end()) {
return ctx.dst->Expr(it->second);
if (auto* renamed = renamed_to.Find(user->Variable())) {
return b.Expr(*renamed);
}
}
return nullptr;
});
ctx.Clone();
return Program(std::move(b));
}
};
@ -107,8 +113,8 @@ Unshadow::Unshadow() = default;
Unshadow::~Unshadow() = default;
void Unshadow::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
State(ctx).Run();
Transform::ApplyResult Unshadow::Apply(const Program* src, const DataMap&, DataMap&) const {
return State(src).Run();
}
} // namespace tint::transform

View File

@ -29,16 +29,13 @@ class Unshadow final : public Castable<Unshadow, Transform> {
/// Destructor
~Unshadow() override;
protected:
struct State;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
private:
struct State;
};
} // namespace tint::transform

View File

@ -35,7 +35,51 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::UnwindDiscardFunctions);
namespace tint::transform {
namespace {
class State {
bool ShouldRun(const Program* program) {
auto& sem = program->Sem();
for (auto* f : program->AST().Functions()) {
if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) {
return true;
}
}
return false;
}
} // namespace
/// PIMPL state for the transform
struct UnwindDiscardFunctions::State {
/// Constructor
/// @param ctx_in the context
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
/// Runs the transform
void Run() {
ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* {
// Iterate block statements and replace them as needed.
for (auto* stmt : block->statements) {
if (auto* new_stmt = Statement(stmt)) {
ctx.Replace(stmt, new_stmt);
}
// Handle for loops, as they are the only other AST node that
// contains statements outside of BlockStatements.
if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
if (auto* new_stmt = Statement(fl->initializer)) {
ctx.Replace(fl->initializer, new_stmt);
}
if (auto* new_stmt = Statement(fl->continuing)) {
// NOTE: Should never reach here as we cannot discard in a
// continuing block.
ctx.Replace(fl->continuing, new_stmt);
}
}
}
return nullptr;
});
}
private:
CloneContext& ctx;
ProgramBuilder& b;
@ -163,7 +207,7 @@ class State {
// Returns true if `stmt` is a for-loop initializer statement.
bool IsForLoopInitStatement(const ast::Statement* stmt) {
if (auto* sem_stmt = sem.Get(stmt)) {
if (auto* sem_fl = As<sem::ForLoopStatement>(sem_stmt->Parent())) {
if (auto* sem_fl = tint::As<sem::ForLoopStatement>(sem_stmt->Parent())) {
return sem_fl->Declaration()->initializer == stmt;
}
}
@ -305,60 +349,26 @@ class State {
return TryInsertAfter(s, sem_expr);
});
}
public:
/// Constructor
/// @param ctx_in the context
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst), sem(ctx_in.src->Sem()) {}
/// Runs the transform
void Run() {
ctx.ReplaceAll([&](const ast::BlockStatement* block) -> const ast::Statement* {
// Iterate block statements and replace them as needed.
for (auto* stmt : block->statements) {
if (auto* new_stmt = Statement(stmt)) {
ctx.Replace(stmt, new_stmt);
}
// Handle for loops, as they are the only other AST node that
// contains statements outside of BlockStatements.
if (auto* fl = stmt->As<ast::ForLoopStatement>()) {
if (auto* new_stmt = Statement(fl->initializer)) {
ctx.Replace(fl->initializer, new_stmt);
}
if (auto* new_stmt = Statement(fl->continuing)) {
// NOTE: Should never reach here as we cannot discard in a
// continuing block.
ctx.Replace(fl->continuing, new_stmt);
}
}
}
return nullptr;
});
ctx.Clone();
}
};
} // namespace
UnwindDiscardFunctions::UnwindDiscardFunctions() = default;
UnwindDiscardFunctions::~UnwindDiscardFunctions() = default;
void UnwindDiscardFunctions::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
Transform::ApplyResult UnwindDiscardFunctions::Apply(const Program* src,
const DataMap&,
DataMap&) const {
if (!ShouldRun(src)) {
return SkipTransform;
}
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
State state(ctx);
state.Run();
}
bool UnwindDiscardFunctions::ShouldRun(const Program* program, const DataMap& /*data*/) const {
auto& sem = program->Sem();
for (auto* f : program->AST().Functions()) {
if (sem.Get(f)->Behaviors().Contains(sem::Behavior::kDiscard)) {
return true;
}
}
return false;
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -44,19 +44,13 @@ class UnwindDiscardFunctions final : public Castable<UnwindDiscardFunctions, Tra
/// Destructor
~UnwindDiscardFunctions() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
private:
struct State;
};
} // namespace tint::transform

View File

@ -30,7 +30,59 @@
namespace tint::transform {
/// Private implementation of HoistToDeclBefore transform
class HoistToDeclBefore::State {
struct HoistToDeclBefore::State {
/// Constructor
/// @param ctx_in the clone context
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
/// @copydoc HoistToDeclBefore::Add()
bool Add(const sem::Expression* before_expr,
const ast::Expression* expr,
bool as_let,
const char* decl_name) {
auto name = b.Symbols().New(decl_name);
if (as_let) {
auto builder = [this, expr, name] {
return b.Decl(b.Let(name, ctx.CloneWithoutTransform(expr)));
};
if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
return false;
}
} else {
auto builder = [this, expr, name] {
return b.Decl(b.Var(name, ctx.CloneWithoutTransform(expr)));
};
if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
return false;
}
}
// Replace the initializer expression with a reference to the let
ctx.Replace(expr, b.Expr(name));
return true;
}
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*)
bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) {
if (stmt) {
auto builder = [stmt] { return stmt; };
return InsertBeforeImpl(before_stmt, std::move(builder));
}
return InsertBeforeImpl(before_stmt, Decompose{});
}
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&)
bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) {
return InsertBeforeImpl(before_stmt, std::move(builder));
}
/// @copydoc HoistToDeclBefore::Prepare()
bool Prepare(const sem::Expression* before_expr) {
return InsertBefore(before_expr->Stmt(), nullptr);
}
private:
CloneContext& ctx;
ProgramBuilder& b;
@ -215,6 +267,8 @@ class HoistToDeclBefore::State {
template <typename BUILDER>
bool InsertBeforeImpl(const sem::Statement* before_stmt, BUILDER&& builder) {
(void)builder; // Avoid 'unused parameter' warning due to 'if constexpr'
auto* ip = before_stmt->Declaration();
auto* else_if = before_stmt->As<sem::IfStatement>();
@ -299,58 +353,6 @@ class HoistToDeclBefore::State {
<< "unhandled expression parent statement type: " << parent->TypeInfo().name;
return false;
}
public:
/// Constructor
/// @param ctx_in the clone context
explicit State(CloneContext& ctx_in) : ctx(ctx_in), b(*ctx_in.dst) {}
/// @copydoc HoistToDeclBefore::Add()
bool Add(const sem::Expression* before_expr,
const ast::Expression* expr,
bool as_let,
const char* decl_name) {
auto name = b.Symbols().New(decl_name);
if (as_let) {
auto builder = [this, expr, name] {
return b.Decl(b.Let(name, ctx.CloneWithoutTransform(expr)));
};
if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
return false;
}
} else {
auto builder = [this, expr, name] {
return b.Decl(b.Var(name, ctx.CloneWithoutTransform(expr)));
};
if (!InsertBeforeImpl(before_expr->Stmt(), std::move(builder))) {
return false;
}
}
// Replace the initializer expression with a reference to the let
ctx.Replace(expr, b.Expr(name));
return true;
}
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const ast::Statement*)
bool InsertBefore(const sem::Statement* before_stmt, const ast::Statement* stmt) {
if (stmt) {
auto builder = [stmt] { return stmt; };
return InsertBeforeImpl(before_stmt, std::move(builder));
}
return InsertBeforeImpl(before_stmt, Decompose{});
}
/// @copydoc HoistToDeclBefore::InsertBefore(const sem::Statement*, const StmtBuilder&)
bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder) {
return InsertBeforeImpl(before_stmt, std::move(builder));
}
/// @copydoc HoistToDeclBefore::Prepare()
bool Prepare(const sem::Expression* before_expr) {
return InsertBefore(before_expr->Stmt(), nullptr);
}
};
HoistToDeclBefore::HoistToDeclBefore(CloneContext& ctx) : state_(std::make_unique<State>(ctx)) {}

View File

@ -77,7 +77,7 @@ class HoistToDeclBefore {
bool Prepare(const sem::Expression* before_expr);
private:
class State;
struct State;
std::unique_ptr<State> state_;
};

View File

@ -13,6 +13,9 @@
// limitations under the License.
#include "src/tint/transform/var_for_dynamic_index.h"
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
@ -22,7 +25,12 @@ VarForDynamicIndex::VarForDynamicIndex() = default;
VarForDynamicIndex::~VarForDynamicIndex() = default;
void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
Transform::ApplyResult VarForDynamicIndex::Apply(const Program* src,
const DataMap&,
DataMap&) const {
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
HoistToDeclBefore hoist_to_decl_before(ctx);
// Extracts array and matrix values that are dynamically indexed to a
@ -30,7 +38,7 @@ void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const
auto dynamic_index_to_var = [&](const ast::IndexAccessorExpression* access_expr) {
auto* index_expr = access_expr->index;
auto* object_expr = access_expr->object;
auto& sem = ctx.src->Sem();
auto& sem = src->Sem();
if (sem.Get(index_expr)->ConstantValue()) {
// Index expression resolves to a compile time value.
@ -49,15 +57,21 @@ void VarForDynamicIndex::Run(CloneContext& ctx, const DataMap&, DataMap&) const
return hoist_to_decl_before.Add(indexed, object_expr, false, "var_for_index");
};
for (auto* node : ctx.src->ASTNodes().Objects()) {
bool index_accessor_found = false;
for (auto* node : src->ASTNodes().Objects()) {
if (auto* access_expr = node->As<ast::IndexAccessorExpression>()) {
if (!dynamic_index_to_var(access_expr)) {
return;
return Program(std::move(b));
}
index_accessor_found = true;
}
}
if (!index_accessor_found) {
return SkipTransform;
}
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -31,14 +31,10 @@ class VarForDynamicIndex : public Transform {
/// Destructor
~VarForDynamicIndex() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -30,11 +30,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeMatrixConversions);
namespace tint::transform {
VectorizeMatrixConversions::VectorizeMatrixConversions() = default;
namespace {
VectorizeMatrixConversions::~VectorizeMatrixConversions() = default;
bool VectorizeMatrixConversions::ShouldRun(const Program* program, const DataMap&) const {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* sem = program->Sem().Get<sem::Expression>(node)) {
if (auto* call = sem->UnwrapMaterialize()->As<sem::Call>()) {
@ -50,14 +48,29 @@ bool VectorizeMatrixConversions::ShouldRun(const Program* program, const DataMap
return false;
}
void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
} // namespace
VectorizeMatrixConversions::VectorizeMatrixConversions() = default;
VectorizeMatrixConversions::~VectorizeMatrixConversions() = default;
Transform::ApplyResult VectorizeMatrixConversions::Apply(const Program* src,
const DataMap&,
DataMap&) const {
if (!ShouldRun(src)) {
return SkipTransform;
}
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
using HelperFunctionKey =
utils::UnorderedKeyWrapper<std::tuple<const sem::Matrix*, const sem::Matrix*>>;
std::unordered_map<HelperFunctionKey, Symbol> matrix_convs;
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* ty_conv = call->Target()->As<sem::TypeConversion>();
if (!ty_conv) {
return nullptr;
@ -72,16 +85,16 @@ void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap&
return nullptr;
}
auto& src = args[0];
auto& matrix = args[0];
auto* src_type = args[0]->Type()->UnwrapRef()->As<sem::Matrix>();
auto* src_type = matrix->Type()->UnwrapRef()->As<sem::Matrix>();
if (!src_type) {
return nullptr;
}
// The source and destination type of a matrix conversion must have a same shape.
if (!(src_type->rows() == dst_type->rows() && src_type->columns() == dst_type->columns())) {
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "source and destination matrix has different shape in matrix conversion";
return nullptr;
}
@ -90,47 +103,45 @@ void VectorizeMatrixConversions::Run(CloneContext& ctx, const DataMap&, DataMap&
utils::Vector<const ast::Expression*, 4> columns;
for (uint32_t c = 0; c < dst_type->columns(); c++) {
auto* src_matrix_expr = src_expression_builder();
auto* src_column_expr =
ctx.dst->IndexAccessor(src_matrix_expr, ctx.dst->Expr(tint::AInt(c)));
columns.Push(ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()),
src_column_expr));
auto* src_column_expr = b.IndexAccessor(src_matrix_expr, b.Expr(tint::AInt(c)));
columns.Push(
b.Construct(CreateASTTypeFor(ctx, dst_type->ColumnType()), src_column_expr));
}
return ctx.dst->Construct(CreateASTTypeFor(ctx, dst_type), columns);
return b.Construct(CreateASTTypeFor(ctx, dst_type), columns);
};
// Replace the matrix conversion to column vector conversions and a matrix construction.
if (!src->HasSideEffects()) {
if (!matrix->HasSideEffects()) {
// Simply use the argument's declaration if it has no side effects.
return build_vectorized_conversion_expression([&]() { //
return ctx.Clone(src->Declaration());
return ctx.Clone(matrix->Declaration());
});
} else {
// If has side effects, use a helper function.
auto fn =
utils::GetOrCreate(matrix_convs, HelperFunctionKey{{src_type, dst_type}}, [&] {
auto name =
ctx.dst->Symbols().New("convert_mat" + std::to_string(src_type->columns()) +
"x" + std::to_string(src_type->rows()) + "_" +
ctx.dst->FriendlyName(src_type->type()) + "_" +
ctx.dst->FriendlyName(dst_type->type()));
ctx.dst->Func(
name,
utils::Vector{
ctx.dst->Param("value", CreateASTTypeFor(ctx, src_type)),
},
CreateASTTypeFor(ctx, dst_type),
utils::Vector{
ctx.dst->Return(build_vectorized_conversion_expression([&]() { //
return ctx.dst->Expr("value");
})),
});
auto name = b.Symbols().New(
"convert_mat" + std::to_string(src_type->columns()) + "x" +
std::to_string(src_type->rows()) + "_" + b.FriendlyName(src_type->type()) +
"_" + b.FriendlyName(dst_type->type()));
b.Func(name,
utils::Vector{
b.Param("value", CreateASTTypeFor(ctx, src_type)),
},
CreateASTTypeFor(ctx, dst_type),
utils::Vector{
b.Return(build_vectorized_conversion_expression([&]() { //
return b.Expr("value");
})),
});
return name;
});
return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration()));
return b.Call(fn, ctx.Clone(args[0]->Declaration()));
}
});
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -28,19 +28,10 @@ class VectorizeMatrixConversions final : public Castable<VectorizeMatrixConversi
/// Destructor
~VectorizeMatrixConversions() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -27,12 +27,9 @@
TINT_INSTANTIATE_TYPEINFO(tint::transform::VectorizeScalarMatrixInitializers);
namespace tint::transform {
namespace {
VectorizeScalarMatrixInitializers::VectorizeScalarMatrixInitializers() = default;
VectorizeScalarMatrixInitializers::~VectorizeScalarMatrixInitializers() = default;
bool VectorizeScalarMatrixInitializers::ShouldRun(const Program* program, const DataMap&) const {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (auto* call = program->Sem().Get<sem::Call>(node)) {
if (call->Target()->Is<sem::TypeInitializer>() && call->Type()->Is<sem::Matrix>()) {
@ -46,11 +43,26 @@ bool VectorizeScalarMatrixInitializers::ShouldRun(const Program* program, const
return false;
}
void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
} // namespace
VectorizeScalarMatrixInitializers::VectorizeScalarMatrixInitializers() = default;
VectorizeScalarMatrixInitializers::~VectorizeScalarMatrixInitializers() = default;
Transform::ApplyResult VectorizeScalarMatrixInitializers::Apply(const Program* src,
const DataMap&,
DataMap&) const {
if (!ShouldRun(src)) {
return SkipTransform;
}
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
std::unordered_map<const sem::Matrix*, Symbol> scalar_inits;
ctx.ReplaceAll([&](const ast::CallExpression* expr) -> const ast::CallExpression* {
auto* call = ctx.src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* call = src->Sem().Get(expr)->UnwrapMaterialize()->As<sem::Call>();
auto* ty_init = call->Target()->As<sem::TypeInitializer>();
if (!ty_init) {
return nullptr;
@ -87,10 +99,10 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D
}
// Construct the column vector.
columns.Push(ctx.dst->vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(),
std::move(row_values)));
columns.Push(b.vec(CreateASTTypeFor(ctx, mat_type->type()), mat_type->rows(),
std::move(row_values)));
}
return ctx.dst->Construct(CreateASTTypeFor(ctx, mat_type), columns);
return b.Construct(CreateASTTypeFor(ctx, mat_type), columns);
};
if (args.Length() == 1) {
@ -98,23 +110,22 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D
// This is done to ensure that the single argument value is only evaluated once, and
// with the correct expression evaluation order.
auto fn = utils::GetOrCreate(scalar_inits, mat_type, [&] {
auto name =
ctx.dst->Symbols().New("build_mat" + std::to_string(mat_type->columns()) + "x" +
std::to_string(mat_type->rows()));
ctx.dst->Func(name,
utils::Vector{
// Single scalar parameter
ctx.dst->Param("value", CreateASTTypeFor(ctx, mat_type->type())),
},
CreateASTTypeFor(ctx, mat_type),
utils::Vector{
ctx.dst->Return(build_mat([&](uint32_t, uint32_t) { //
return ctx.dst->Expr("value");
})),
});
auto name = b.Symbols().New("build_mat" + std::to_string(mat_type->columns()) +
"x" + std::to_string(mat_type->rows()));
b.Func(name,
utils::Vector{
// Single scalar parameter
b.Param("value", CreateASTTypeFor(ctx, mat_type->type())),
},
CreateASTTypeFor(ctx, mat_type),
utils::Vector{
b.Return(build_mat([&](uint32_t, uint32_t) { //
return b.Expr("value");
})),
});
return name;
});
return ctx.dst->Call(fn, ctx.Clone(args[0]->Declaration()));
return b.Call(fn, ctx.Clone(args[0]->Declaration()));
}
if (args.Length() == mat_type->columns() * mat_type->rows()) {
@ -123,12 +134,13 @@ void VectorizeScalarMatrixInitializers::Run(CloneContext& ctx, const DataMap&, D
});
}
TINT_ICE(Transform, ctx.dst->Diagnostics())
TINT_ICE(Transform, b.Diagnostics())
<< "matrix initializer has unexpected number of arguments";
return nullptr;
});
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

View File

@ -29,19 +29,10 @@ class VectorizeScalarMatrixInitializers final
/// Destructor
~VectorizeScalarMatrixInitializers() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
};
} // namespace tint::transform

View File

@ -201,13 +201,46 @@ DataType DataTypeOf(VertexFormat format) {
return {BaseType::kInvalid, 0};
}
struct State {
State(CloneContext& context, const VertexPulling::Config& c) : ctx(context), cfg(c) {}
State(const State&) = default;
~State() = default;
} // namespace
/// LocationReplacement describes an ast::Variable replacement for a
/// location input.
/// PIMPL state for the transform
struct VertexPulling::State {
/// Constructor
/// @param program the source program
/// @param c the VertexPulling config
State(const Program* program, const VertexPulling::Config& c) : src(program), cfg(c) {}
/// Runs the transform
/// @returns the new program or SkipTransform if the transform is not required
ApplyResult Run() {
// Find entry point
const ast::Function* func = nullptr;
for (auto* fn : src->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
if (func != nullptr) {
b.Diagnostics().add_error(
diag::System::Transform,
"VertexPulling found more than one vertex entry point");
return Program(std::move(b));
}
func = fn;
}
}
if (func == nullptr) {
b.Diagnostics().add_error(diag::System::Transform,
"Vertex stage entry point not found");
return Program(std::move(b));
}
AddVertexStorageBuffers();
Process(func);
ctx.Clone();
return Program(std::move(b));
}
private:
/// LocationReplacement describes an ast::Variable replacement for a location input.
struct LocationReplacement {
/// The variable to replace in the source Program
ast::Variable* from;
@ -215,13 +248,22 @@ struct State {
ast::Variable* to;
};
/// LocationInfo describes an input location
struct LocationInfo {
/// A builder that builds the expression that resolves to the (transformed) input location
std::function<const ast::Expression*()> expr;
/// The store type of the location variable
const sem::Type* type;
};
CloneContext& ctx;
/// The source program
const Program* const src;
/// The transform config
VertexPulling::Config const cfg;
/// The target program builder
ProgramBuilder b;
/// The clone context
CloneContext ctx = {&b, src, /* auto_clone_symbols */ true};
std::unordered_map<uint32_t, LocationInfo> location_info;
std::function<const ast::Expression*()> vertex_index_expr = nullptr;
std::function<const ast::Expression*()> instance_index_expr = nullptr;
@ -235,7 +277,7 @@ struct State {
Symbol GetVertexBufferName(uint32_t index) {
return utils::GetOrCreate(vertex_buffer_names, index, [&] {
static const char kVertexBufferNamePrefix[] = "tint_pulling_vertex_buffer_";
return ctx.dst->Symbols().New(kVertexBufferNamePrefix + std::to_string(index));
return b.Symbols().New(kVertexBufferNamePrefix + std::to_string(index));
});
}
@ -243,7 +285,7 @@ struct State {
Symbol GetStructBufferName() {
if (!struct_buffer_name.IsValid()) {
static const char kStructBufferName[] = "tint_vertex_data";
struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName);
struct_buffer_name = b.Symbols().New(kStructBufferName);
}
return struct_buffer_name;
}
@ -252,21 +294,19 @@ struct State {
void AddVertexStorageBuffers() {
// Creating the struct type
static const char kStructName[] = "TintVertexData";
auto* struct_type =
ctx.dst->Structure(ctx.dst->Symbols().New(kStructName),
utils::Vector{
ctx.dst->Member(GetStructBufferName(), ctx.dst->ty.array<u32>()),
});
auto* struct_type = b.Structure(b.Symbols().New(kStructName),
utils::Vector{
b.Member(GetStructBufferName(), b.ty.array<u32>()),
});
for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
// The decorated variable with struct type
ctx.dst->GlobalVar(GetVertexBufferName(i), ctx.dst->ty.Of(struct_type),
ast::AddressSpace::kStorage, ast::Access::kRead,
ctx.dst->Binding(AInt(i)), ctx.dst->Group(AInt(cfg.pulling_group)));
b.GlobalVar(GetVertexBufferName(i), b.ty.Of(struct_type), ast::AddressSpace::kStorage,
ast::Access::kRead, b.Binding(AInt(i)), b.Group(AInt(cfg.pulling_group)));
}
}
/// Creates and returns the assignment to the variables from the buffers
ast::BlockStatement* CreateVertexPullingPreamble() {
const ast::BlockStatement* CreateVertexPullingPreamble() {
// Assign by looking at the vertex descriptor to find attributes with
// matching location.
@ -276,7 +316,7 @@ struct State {
const VertexBufferLayoutDescriptor& buffer_layout = cfg.vertex_state[buffer_idx];
if ((buffer_layout.array_stride & 3) != 0) {
ctx.dst->Diagnostics().add_error(
b.Diagnostics().add_error(
diag::System::Transform,
"WebGPU requires that vertex stride must be a multiple of 4 bytes, "
"but VertexPulling array stride for buffer " +
@ -292,15 +332,15 @@ struct State {
// buffer_array_base is the base array offset for all the vertex
// attributes. These are units of uint (4 bytes).
auto buffer_array_base =
ctx.dst->Symbols().New("buffer_array_base_" + std::to_string(buffer_idx));
b.Symbols().New("buffer_array_base_" + std::to_string(buffer_idx));
auto* attribute_offset = index_expr;
if (buffer_layout.array_stride != 4) {
attribute_offset = ctx.dst->Mul(index_expr, u32(buffer_layout.array_stride / 4u));
attribute_offset = b.Mul(index_expr, u32(buffer_layout.array_stride / 4u));
}
// let pulling_offset_n = <attribute_offset>
stmts.Push(ctx.dst->Decl(ctx.dst->Let(buffer_array_base, attribute_offset)));
stmts.Push(b.Decl(b.Let(buffer_array_base, attribute_offset)));
for (const VertexAttributeDescriptor& attribute_desc : buffer_layout.attributes) {
auto it = location_info.find(attribute_desc.shader_location);
@ -320,8 +360,8 @@ struct State {
err << "VertexAttributeDescriptor for location "
<< std::to_string(attribute_desc.shader_location) << " has format "
<< attribute_desc.format << " but shader expects "
<< var.type->FriendlyName(ctx.src->Symbols());
ctx.dst->Diagnostics().add_error(diag::System::Transform, err.str());
<< var.type->FriendlyName(src->Symbols());
b.Diagnostics().add_error(diag::System::Transform, err.str());
return nullptr;
}
@ -337,16 +377,16 @@ struct State {
// WGSL variable vector width is smaller than the loaded vector width
switch (var_dt.width) {
case 1:
value = ctx.dst->MemberAccessor(fetch, "x");
value = b.MemberAccessor(fetch, "x");
break;
case 2:
value = ctx.dst->MemberAccessor(fetch, "xy");
value = b.MemberAccessor(fetch, "xy");
break;
case 3:
value = ctx.dst->MemberAccessor(fetch, "xyz");
value = b.MemberAccessor(fetch, "xyz");
break;
default:
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.width;
TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.width;
return nullptr;
}
} else if (var_dt.width > fmt_dt.width) {
@ -355,32 +395,32 @@ struct State {
utils::Vector<const ast::Expression*, 8> values{fetch};
switch (var_dt.base_type) {
case BaseType::kI32:
ty = ctx.dst->ty.i32();
ty = b.ty.i32();
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
values.Push(ctx.dst->Expr((i == 3) ? 1_i : 0_i));
values.Push(b.Expr((i == 3) ? 1_i : 0_i));
}
break;
case BaseType::kU32:
ty = ctx.dst->ty.u32();
ty = b.ty.u32();
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
values.Push(ctx.dst->Expr((i == 3) ? 1_u : 0_u));
values.Push(b.Expr((i == 3) ? 1_u : 0_u));
}
break;
case BaseType::kF32:
ty = ctx.dst->ty.f32();
ty = b.ty.f32();
for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
values.Push(ctx.dst->Expr((i == 3) ? 1_f : 0_f));
values.Push(b.Expr((i == 3) ? 1_f : 0_f));
}
break;
default:
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics()) << var_dt.base_type;
TINT_UNREACHABLE(Transform, b.Diagnostics()) << var_dt.base_type;
return nullptr;
}
value = ctx.dst->Construct(ctx.dst->ty.vec(ty, var_dt.width), values);
value = b.Construct(b.ty.vec(ty, var_dt.width), values);
}
// Assign the value to the WGSL variable
stmts.Push(ctx.dst->Assign(var.expr(), value));
stmts.Push(b.Assign(var.expr(), value));
}
}
@ -388,7 +428,7 @@ struct State {
return nullptr;
}
return ctx.dst->create<ast::BlockStatement>(std::move(stmts));
return b.Block(std::move(stmts));
}
/// Generates an expression reading from a buffer a specific format.
@ -407,7 +447,7 @@ struct State {
};
// Returns a i32 loaded from buffer_base + offset.
auto load_i32 = [&] { return ctx.dst->Bitcast<i32>(load_u32()); };
auto load_i32 = [&] { return b.Bitcast<i32>(load_u32()); };
// Returns a u32 loaded from buffer_base + offset + 4.
auto load_next_u32 = [&] {
@ -415,7 +455,7 @@ struct State {
};
// Returns a i32 loaded from buffer_base + offset + 4.
auto load_next_i32 = [&] { return ctx.dst->Bitcast<i32>(load_next_u32()); };
auto load_next_i32 = [&] { return b.Bitcast<i32>(load_next_u32()); };
// Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
// The low 16 bits are 0.
@ -427,17 +467,17 @@ struct State {
LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32);
switch (offset & 3) {
case 0:
return ctx.dst->Shl(low_u32, 16_u);
return b.Shl(low_u32, 16_u);
case 1:
return ctx.dst->And(ctx.dst->Shl(low_u32, 8_u), 0xffff0000_u);
return b.And(b.Shl(low_u32, 8_u), 0xffff0000_u);
case 2:
return ctx.dst->And(low_u32, 0xffff0000_u);
return b.And(low_u32, 0xffff0000_u);
default: { // 3:
auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
VertexFormat::kUint32);
auto* shr = ctx.dst->Shr(low_u32, 8_u);
auto* shl = ctx.dst->Shl(high_u32, 24_u);
return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff0000_u);
auto* shr = b.Shr(low_u32, 8_u);
auto* shl = b.Shl(high_u32, 24_u);
return b.And(b.Or(shl, shr), 0xffff0000_u);
}
}
};
@ -450,24 +490,24 @@ struct State {
LoadPrimitive(array_base, low_u32_offset, buffer, VertexFormat::kUint32);
switch (offset & 3) {
case 0:
return ctx.dst->And(low_u32, 0xffff_u);
return b.And(low_u32, 0xffff_u);
case 1:
return ctx.dst->And(ctx.dst->Shr(low_u32, 8_u), 0xffff_u);
return b.And(b.Shr(low_u32, 8_u), 0xffff_u);
case 2:
return ctx.dst->Shr(low_u32, 16_u);
return b.Shr(low_u32, 16_u);
default: { // 3:
auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
VertexFormat::kUint32);
auto* shr = ctx.dst->Shr(low_u32, 24_u);
auto* shl = ctx.dst->Shl(high_u32, 8_u);
return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff_u);
auto* shr = b.Shr(low_u32, 24_u);
auto* shl = b.Shl(high_u32, 8_u);
return b.And(b.Or(shl, shr), 0xffff_u);
}
}
};
// Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
// The low 16 bits are 0.
auto load_i16_h = [&] { return ctx.dst->Bitcast<i32>(load_u16_h()); };
auto load_i16_h = [&] { return b.Bitcast<i32>(load_u16_h()); };
// Assumptions are made that alignment must be at least as large as the size
// of a single component.
@ -480,128 +520,121 @@ struct State {
// Vectors of basic primitives
case VertexFormat::kUint32x2:
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
VertexFormat::kUint32, 2);
return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 2);
case VertexFormat::kUint32x3:
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
VertexFormat::kUint32, 3);
return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 3);
case VertexFormat::kUint32x4:
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
VertexFormat::kUint32, 4);
return LoadVec(array_base, offset, buffer, 4, b.ty.u32(), VertexFormat::kUint32, 4);
case VertexFormat::kSint32x2:
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
VertexFormat::kSint32, 2);
return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 2);
case VertexFormat::kSint32x3:
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
VertexFormat::kSint32, 3);
return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 3);
case VertexFormat::kSint32x4:
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
VertexFormat::kSint32, 4);
return LoadVec(array_base, offset, buffer, 4, b.ty.i32(), VertexFormat::kSint32, 4);
case VertexFormat::kFloat32x2:
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
VertexFormat::kFloat32, 2);
return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
2);
case VertexFormat::kFloat32x3:
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
VertexFormat::kFloat32, 3);
return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
3);
case VertexFormat::kFloat32x4:
return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
VertexFormat::kFloat32, 4);
return LoadVec(array_base, offset, buffer, 4, b.ty.f32(), VertexFormat::kFloat32,
4);
case VertexFormat::kUint8x2: {
// yyxx0000, yyxx0000
auto* u16s = ctx.dst->vec2<u32>(load_u16_h());
auto* u16s = b.vec2<u32>(load_u16_h());
// xx000000, yyxx0000
auto* shl = ctx.dst->Shl(u16s, ctx.dst->vec2<u32>(8_u, 0_u));
auto* shl = b.Shl(u16s, b.vec2<u32>(8_u, 0_u));
// 000000xx, 000000yy
return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24_u));
return b.Shr(shl, b.vec2<u32>(24_u));
}
case VertexFormat::kUint8x4: {
// wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
auto* u32s = ctx.dst->vec4<u32>(load_u32());
auto* u32s = b.vec4<u32>(load_u32());
// xx000000, yyxx0000, zzyyxx00, wwzzyyxx
auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec4<u32>(24_u, 16_u, 8_u, 0_u));
auto* shl = b.Shl(u32s, b.vec4<u32>(24_u, 16_u, 8_u, 0_u));
// 000000xx, 000000yy, 000000zz, 000000ww
return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24_u));
return b.Shr(shl, b.vec4<u32>(24_u));
}
case VertexFormat::kUint16x2: {
// yyyyxxxx, yyyyxxxx
auto* u32s = ctx.dst->vec2<u32>(load_u32());
auto* u32s = b.vec2<u32>(load_u32());
// xxxx0000, yyyyxxxx
auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec2<u32>(16_u, 0_u));
auto* shl = b.Shl(u32s, b.vec2<u32>(16_u, 0_u));
// 0000xxxx, 0000yyyy
return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16_u));
return b.Shr(shl, b.vec2<u32>(16_u));
}
case VertexFormat::kUint16x4: {
// yyyyxxxx, wwwwzzzz
auto* u32s = ctx.dst->vec2<u32>(load_u32(), load_next_u32());
auto* u32s = b.vec2<u32>(load_u32(), load_next_u32());
// yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
auto* xxyy = ctx.dst->MemberAccessor(u32s, "xxyy");
auto* xxyy = b.MemberAccessor(u32s, "xxyy");
// xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16_u, 0_u, 16_u, 0_u));
auto* shl = b.Shl(xxyy, b.vec4<u32>(16_u, 0_u, 16_u, 0_u));
// 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16_u));
return b.Shr(shl, b.vec4<u32>(16_u));
}
case VertexFormat::kSint8x2: {
// yyxx0000, yyxx0000
auto* i16s = ctx.dst->vec2<i32>(load_i16_h());
auto* i16s = b.vec2<i32>(load_i16_h());
// xx000000, yyxx0000
auto* shl = ctx.dst->Shl(i16s, ctx.dst->vec2<u32>(8_u, 0_u));
auto* shl = b.Shl(i16s, b.vec2<u32>(8_u, 0_u));
// ssssssxx, ssssssyy
return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24_u));
return b.Shr(shl, b.vec2<u32>(24_u));
}
case VertexFormat::kSint8x4: {
// wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
auto* i32s = ctx.dst->vec4<i32>(load_i32());
auto* i32s = b.vec4<i32>(load_i32());
// xx000000, yyxx0000, zzyyxx00, wwzzyyxx
auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec4<u32>(24_u, 16_u, 8_u, 0_u));
auto* shl = b.Shl(i32s, b.vec4<u32>(24_u, 16_u, 8_u, 0_u));
// ssssssxx, ssssssyy, sssssszz, ssssssww
return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24_u));
return b.Shr(shl, b.vec4<u32>(24_u));
}
case VertexFormat::kSint16x2: {
// yyyyxxxx, yyyyxxxx
auto* i32s = ctx.dst->vec2<i32>(load_i32());
auto* i32s = b.vec2<i32>(load_i32());
// xxxx0000, yyyyxxxx
auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec2<u32>(16_u, 0_u));
auto* shl = b.Shl(i32s, b.vec2<u32>(16_u, 0_u));
// ssssxxxx, ssssyyyy
return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16_u));
return b.Shr(shl, b.vec2<u32>(16_u));
}
case VertexFormat::kSint16x4: {
// yyyyxxxx, wwwwzzzz
auto* i32s = ctx.dst->vec2<i32>(load_i32(), load_next_i32());
auto* i32s = b.vec2<i32>(load_i32(), load_next_i32());
// yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
auto* xxyy = ctx.dst->MemberAccessor(i32s, "xxyy");
auto* xxyy = b.MemberAccessor(i32s, "xxyy");
// xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16_u, 0_u, 16_u, 0_u));
auto* shl = b.Shl(xxyy, b.vec4<u32>(16_u, 0_u, 16_u, 0_u));
// ssssxxxx, ssssyyyy, sssszzzz, sssswwww
return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16_u));
return b.Shr(shl, b.vec4<u32>(16_u));
}
case VertexFormat::kUnorm8x2:
return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8unorm", load_u16_l()), "xy");
return b.MemberAccessor(b.Call("unpack4x8unorm", load_u16_l()), "xy");
case VertexFormat::kSnorm8x2:
return ctx.dst->MemberAccessor(ctx.dst->Call("unpack4x8snorm", load_u16_l()), "xy");
return b.MemberAccessor(b.Call("unpack4x8snorm", load_u16_l()), "xy");
case VertexFormat::kUnorm8x4:
return ctx.dst->Call("unpack4x8unorm", load_u32());
return b.Call("unpack4x8unorm", load_u32());
case VertexFormat::kSnorm8x4:
return ctx.dst->Call("unpack4x8snorm", load_u32());
return b.Call("unpack4x8snorm", load_u32());
case VertexFormat::kUnorm16x2:
return ctx.dst->Call("unpack2x16unorm", load_u32());
return b.Call("unpack2x16unorm", load_u32());
case VertexFormat::kSnorm16x2:
return ctx.dst->Call("unpack2x16snorm", load_u32());
return b.Call("unpack2x16snorm", load_u32());
case VertexFormat::kFloat16x2:
return ctx.dst->Call("unpack2x16float", load_u32());
return b.Call("unpack2x16float", load_u32());
case VertexFormat::kUnorm16x4:
return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16unorm", load_u32()),
ctx.dst->Call("unpack2x16unorm", load_next_u32()));
return b.vec4<f32>(b.Call("unpack2x16unorm", load_u32()),
b.Call("unpack2x16unorm", load_next_u32()));
case VertexFormat::kSnorm16x4:
return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16snorm", load_u32()),
ctx.dst->Call("unpack2x16snorm", load_next_u32()));
return b.vec4<f32>(b.Call("unpack2x16snorm", load_u32()),
b.Call("unpack2x16snorm", load_next_u32()));
case VertexFormat::kFloat16x4:
return ctx.dst->vec4<f32>(ctx.dst->Call("unpack2x16float", load_u32()),
ctx.dst->Call("unpack2x16float", load_next_u32()));
return b.vec4<f32>(b.Call("unpack2x16float", load_u32()),
b.Call("unpack2x16float", load_next_u32()));
}
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
<< "format " << static_cast<int>(format);
TINT_UNREACHABLE(Transform, b.Diagnostics()) << "format " << static_cast<int>(format);
return nullptr;
}
@ -623,12 +656,12 @@ struct State {
const ast ::Expression* index = nullptr;
if (offset > 0) {
index = ctx.dst->Add(array_base, u32(offset / 4));
index = b.Add(array_base, u32(offset / 4));
} else {
index = ctx.dst->Expr(array_base);
index = b.Expr(array_base);
}
u = ctx.dst->IndexAccessor(
ctx.dst->MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index);
u = b.IndexAccessor(
b.MemberAccessor(GetVertexBufferName(buffer), GetStructBufferName()), index);
} else {
// Unaligned load
@ -639,22 +672,22 @@ struct State {
uint32_t shift = 8u * (offset & 3u);
auto* low_shr = ctx.dst->Shr(low, u32(shift));
auto* high_shl = ctx.dst->Shl(high, u32(32u - shift));
u = ctx.dst->Or(low_shr, high_shl);
auto* low_shr = b.Shr(low, u32(shift));
auto* high_shl = b.Shl(high, u32(32u - shift));
u = b.Or(low_shr, high_shl);
}
switch (format) {
case VertexFormat::kUint32:
return u;
case VertexFormat::kSint32:
return ctx.dst->Bitcast(ctx.dst->ty.i32(), u);
return b.Bitcast(b.ty.i32(), u);
case VertexFormat::kFloat32:
return ctx.dst->Bitcast(ctx.dst->ty.f32(), u);
return b.Bitcast(b.ty.f32(), u);
default:
break;
}
TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
TINT_UNREACHABLE(Transform, b.Diagnostics())
<< "invalid format for LoadPrimitive" << static_cast<int>(format);
return nullptr;
}
@ -682,8 +715,7 @@ struct State {
expr_list.Push(LoadPrimitive(array_base, primitive_offset, buffer, base_format));
}
return ctx.dst->Construct(ctx.dst->create<ast::Vector>(base_type, count),
std::move(expr_list));
return b.Construct(b.create<ast::Vector>(base_type, count), std::move(expr_list));
}
/// Process a non-struct entry point parameter.
@ -696,34 +728,30 @@ struct State {
// Create a function-scope variable to replace the parameter.
auto func_var_sym = ctx.Clone(param->symbol);
auto* func_var_type = ctx.Clone(param->type);
auto* func_var = ctx.dst->Var(func_var_sym, func_var_type);
ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
auto* func_var = b.Var(func_var_sym, func_var_type);
ctx.InsertFront(func->body->statements, b.Decl(func_var));
// Capture mapping from location to the new variable.
LocationInfo info;
info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); };
info.expr = [this, func_var]() { return b.Expr(func_var); };
auto* sem = ctx.src->Sem().Get<sem::Parameter>(param);
auto* sem = src->Sem().Get<sem::Parameter>(param);
info.type = sem->Type();
if (!sem->Location().has_value()) {
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Location missing value";
TINT_ICE(Transform, b.Diagnostics()) << "Location missing value";
return;
}
location_info[sem->Location().value()] = info;
} else if (auto* builtin = ast::GetAttribute<ast::BuiltinAttribute>(param->attributes)) {
// Check for existing vertex_index and instance_index builtins.
if (builtin->builtin == ast::BuiltinValue::kVertexIndex) {
vertex_index_expr = [this, param]() {
return ctx.dst->Expr(ctx.Clone(param->symbol));
};
vertex_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); };
} else if (builtin->builtin == ast::BuiltinValue::kInstanceIndex) {
instance_index_expr = [this, param]() {
return ctx.dst->Expr(ctx.Clone(param->symbol));
};
instance_index_expr = [this, param]() { return b.Expr(ctx.Clone(param->symbol)); };
}
new_function_parameters.Push(ctx.Clone(param));
} else {
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter";
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
}
}
@ -746,7 +774,7 @@ struct State {
for (auto* member : struct_ty->members) {
auto member_sym = ctx.Clone(member->symbol);
std::function<const ast::Expression*()> member_expr = [this, param_sym, member_sym]() {
return ctx.dst->MemberAccessor(param_sym, member_sym);
return b.MemberAccessor(param_sym, member_sym);
};
if (ast::HasAttribute<ast::LocationAttribute>(member->attributes)) {
@ -754,7 +782,7 @@ struct State {
LocationInfo info;
info.expr = member_expr;
auto* sem = ctx.src->Sem().Get(member);
auto* sem = src->Sem().Get(member);
info.type = sem->Type();
TINT_ASSERT(Transform, sem->Location().has_value());
@ -770,7 +798,7 @@ struct State {
}
members_to_clone.Push(member);
} else {
TINT_ICE(Transform, ctx.dst->Diagnostics()) << "Invalid entry point parameter";
TINT_ICE(Transform, b.Diagnostics()) << "Invalid entry point parameter";
}
}
@ -781,8 +809,8 @@ struct State {
}
// Create a function-scope variable to replace the parameter.
auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type));
ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
auto* func_var = b.Var(param_sym, ctx.Clone(param->type));
ctx.InsertFront(func->body->statements, b.Decl(func_var));
if (!members_to_clone.IsEmpty()) {
// Create a new struct without the location attributes.
@ -791,20 +819,20 @@ struct State {
auto member_sym = ctx.Clone(member->symbol);
auto* member_type = ctx.Clone(member->type);
auto member_attrs = ctx.Clone(member->attributes);
new_members.Push(ctx.dst->Member(member_sym, member_type, std::move(member_attrs)));
new_members.Push(b.Member(member_sym, member_type, std::move(member_attrs)));
}
auto* new_struct = ctx.dst->Structure(ctx.dst->Sym(), new_members);
auto* new_struct = b.Structure(b.Sym(), new_members);
// Create a new function parameter with this struct.
auto* new_param = ctx.dst->Param(ctx.dst->Sym(), ctx.dst->ty.Of(new_struct));
auto* new_param = b.Param(b.Sym(), b.ty.Of(new_struct));
new_function_parameters.Push(new_param);
// Copy values from the new parameter to the function-scope variable.
for (auto* member : members_to_clone) {
auto member_name = ctx.Clone(member->symbol);
ctx.InsertFront(func->body->statements,
ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name),
ctx.dst->MemberAccessor(new_param, member_name)));
b.Assign(b.MemberAccessor(func_var, member_name),
b.MemberAccessor(new_param, member_name)));
}
}
}
@ -818,7 +846,7 @@ struct State {
// Process entry point parameters.
for (auto* param : func->params) {
auto* sem = ctx.src->Sem().Get(param);
auto* sem = src->Sem().Get(param);
if (auto* str = sem->Type()->As<sem::Struct>()) {
ProcessStructParameter(func, param, str->Declaration());
} else {
@ -830,11 +858,11 @@ struct State {
if (!vertex_index_expr) {
for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
if (layout.step_mode == VertexStepMode::kVertex) {
auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index");
new_function_parameters.Push(ctx.dst->Param(
name, ctx.dst->ty.u32(),
utils::Vector{ctx.dst->Builtin(ast::BuiltinValue::kVertexIndex)}));
vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); };
auto name = b.Symbols().New("tint_pulling_vertex_index");
new_function_parameters.Push(
b.Param(name, b.ty.u32(),
utils::Vector{b.Builtin(ast::BuiltinValue::kVertexIndex)}));
vertex_index_expr = [this, name]() { return b.Expr(name); };
break;
}
}
@ -842,11 +870,11 @@ struct State {
if (!instance_index_expr) {
for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
if (layout.step_mode == VertexStepMode::kInstance) {
auto name = ctx.dst->Symbols().New("tint_pulling_instance_index");
new_function_parameters.Push(ctx.dst->Param(
name, ctx.dst->ty.u32(),
utils::Vector{ctx.dst->Builtin(ast::BuiltinValue::kInstanceIndex)}));
instance_index_expr = [this, name]() { return ctx.dst->Expr(name); };
auto name = b.Symbols().New("tint_pulling_instance_index");
new_function_parameters.Push(
b.Param(name, b.ty.u32(),
utils::Vector{b.Builtin(ast::BuiltinValue::kInstanceIndex)}));
instance_index_expr = [this, name]() { return b.Expr(name); };
break;
}
}
@ -864,53 +892,24 @@ struct State {
auto attrs = ctx.Clone(func->attributes);
auto ret_attrs = ctx.Clone(func->return_type_attributes);
auto* new_func =
ctx.dst->create<ast::Function>(func->source, func_sym, new_function_parameters,
ret_type, body, std::move(attrs), std::move(ret_attrs));
b.create<ast::Function>(func->source, func_sym, new_function_parameters, ret_type, body,
std::move(attrs), std::move(ret_attrs));
ctx.Replace(func, new_func);
}
};
} // namespace
VertexPulling::VertexPulling() = default;
VertexPulling::~VertexPulling() = default;
void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) const {
Transform::ApplyResult VertexPulling::Apply(const Program* src,
const DataMap& inputs,
DataMap&) const {
auto cfg = cfg_;
if (auto* cfg_data = inputs.Get<Config>()) {
cfg = *cfg_data;
}
// Find entry point
const ast::Function* func = nullptr;
for (auto* fn : ctx.src->AST().Functions()) {
if (fn->PipelineStage() == ast::PipelineStage::kVertex) {
if (func != nullptr) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform,
"VertexPulling found more than one vertex entry point");
return;
}
func = fn;
}
}
if (func == nullptr) {
ctx.dst->Diagnostics().add_error(diag::System::Transform,
"Vertex stage entry point not found");
return;
}
// TODO(idanr): Need to check shader locations in descriptor cover all
// attributes
// TODO(idanr): Make sure we covered all error cases, to guarantee the
// following stages will pass
State state{ctx, cfg};
state.AddVertexStorageBuffers();
state.Process(func);
ctx.Clone();
return State{src, cfg}.Run();
}
VertexPulling::Config::Config() = default;

View File

@ -171,16 +171,14 @@ class VertexPulling final : public Castable<VertexPulling, Transform> {
/// Destructor
~VertexPulling() override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
/// @copydoc Transform::Apply
ApplyResult Apply(const Program* program,
const DataMap& inputs,
DataMap& outputs) const override;
private:
struct State;
Config cfg_;
};

View File

@ -14,18 +14,17 @@
#include "src/tint/transform/while_to_loop.h"
#include <utility>
#include "src/tint/ast/break_statement.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::WhileToLoop);
namespace tint::transform {
namespace {
WhileToLoop::WhileToLoop() = default;
WhileToLoop::~WhileToLoop() = default;
bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const {
bool ShouldRun(const Program* program) {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::WhileStatement>()) {
return true;
@ -34,20 +33,32 @@ bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const {
return false;
}
void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
} // namespace
WhileToLoop::WhileToLoop() = default;
WhileToLoop::~WhileToLoop() = default;
Transform::ApplyResult WhileToLoop::Apply(const Program* src, const DataMap&, DataMap&) const {
if (!ShouldRun(src)) {
return SkipTransform;
}
ProgramBuilder b;
CloneContext ctx{&b, src, /* auto_clone_symbols */ true};
ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* {
utils::Vector<const ast::Statement*, 16> stmts;
auto* cond = w->condition;
// !condition
auto* not_cond =
ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
auto* not_cond = b.Not(ctx.Clone(cond));
// { break; }
auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
auto* break_body = b.Block(b.Break());
// if (!condition) { break; }
stmts.Push(ctx.dst->If(not_cond, break_body));
stmts.Push(b.If(not_cond, break_body));
for (auto* stmt : w->body->statements) {
stmts.Push(ctx.Clone(stmt));
@ -55,13 +66,14 @@ void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
const ast::BlockStatement* continuing = nullptr;
auto* body = ctx.dst->Block(stmts);
auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
auto* body = b.Block(stmts);
auto* loop = b.Loop(body, continuing);
return loop;
});
ctx.Clone();
return Program(std::move(b));
}
} // namespace tint::transform

Some files were not shown because too many files have changed in this diff Show More