Program: Track what transforms have been applied

Allows transforms to assert their dependencies have been run before they
are.

Also allows the backends to validate that their sanitizers have been run
before they're used.

Change-Id: I1e97afe06f9e7371283bade54bbb2e2c41f87a00
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/55442
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2021-06-25 10:26:26 +00:00 committed by Tint LUCI CQ
parent 5b9203cc05
commit b5cd10c7bd
64 changed files with 575 additions and 313 deletions

View File

@ -860,7 +860,7 @@ if(${TINT_BUILD_TESTS})
if(${TINT_BUILD_WGSL_READER} AND ${TINT_BUILD_WGSL_WRITER}) if(${TINT_BUILD_WGSL_READER} AND ${TINT_BUILD_WGSL_WRITER})
list(APPEND TINT_TEST_SRCS list(APPEND TINT_TEST_SRCS
transform/array_length_from_uniform.cc transform/array_length_from_uniform_test.cc
transform/binding_remapper_test.cc transform/binding_remapper_test.cc
transform/bound_array_accessors_test.cc transform/bound_array_accessors_test.cc
transform/calculate_array_length_test.cc transform/calculate_array_length_test.cc

View File

@ -57,6 +57,7 @@ Symbol CloneContext::Clone(Symbol s) {
void CloneContext::Clone() { void CloneContext::Clone() {
dst->AST().Copy(this, &src->AST()); dst->AST().Copy(this, &src->AST());
dst->SetTransformApplied(src->TransformsApplied());
} }
ast::FunctionList CloneContext::Clone(const ast::FunctionList& v) { ast::FunctionList CloneContext::Clone(const ast::FunctionList& v) {

View File

@ -33,6 +33,7 @@ Program::Program(Program&& program)
sem_(std::move(program.sem_)), sem_(std::move(program.sem_)),
symbols_(std::move(program.symbols_)), symbols_(std::move(program.symbols_)),
diagnostics_(std::move(program.diagnostics_)), diagnostics_(std::move(program.diagnostics_)),
transforms_applied_(std::move(program.transforms_applied_)),
is_valid_(program.is_valid_) { is_valid_(program.is_valid_) {
program.AssertNotMoved(); program.AssertNotMoved();
program.moved_ = true; program.moved_ = true;
@ -57,6 +58,7 @@ Program::Program(ProgramBuilder&& builder) {
sem_ = std::move(builder.Sem()); sem_ = std::move(builder.Sem());
symbols_ = std::move(builder.Symbols()); symbols_ = std::move(builder.Symbols());
diagnostics_.add(std::move(builder.Diagnostics())); diagnostics_.add(std::move(builder.Diagnostics()));
transforms_applied_ = builder.TransformsApplied();
builder.MarkAsMoved(); builder.MarkAsMoved();
if (!is_valid_ && !diagnostics_.contains_errors()) { if (!is_valid_ && !diagnostics_.contains_errors()) {
@ -80,6 +82,7 @@ Program& Program::operator=(Program&& program) {
sem_ = std::move(program.sem_); sem_ = std::move(program.sem_);
symbols_ = std::move(program.symbols_); symbols_ = std::move(program.symbols_);
diagnostics_ = std::move(program.diagnostics_); diagnostics_ = std::move(program.diagnostics_);
transforms_applied_ = std::move(program.transforms_applied_);
is_valid_ = program.is_valid_; is_valid_ = program.is_valid_;
return *this; return *this;
} }

View File

@ -16,6 +16,7 @@
#define SRC_PROGRAM_H_ #define SRC_PROGRAM_H_
#include <string> #include <string>
#include <unordered_set>
#include "src/ast/function.h" #include "src/ast/function.h"
#include "src/program_id.h" #include "src/program_id.h"
@ -125,6 +126,25 @@ class Program {
/// information /// information
bool IsValid() const; bool IsValid() const;
/// @return the TypeInfo pointers of all transforms that have been applied to
/// this program.
std::unordered_set<const TypeInfo*> TransformsApplied() const {
return transforms_applied_;
}
/// @param transform the TypeInfo of the transform
/// @returns true if the transform with the given TypeInfo was applied to the
/// Program
bool HasTransformApplied(const TypeInfo* transform) const {
return transforms_applied_.count(transform);
}
/// @returns true if the transform of type `T` was applied.
template <typename T>
bool HasTransformApplied() const {
return HasTransformApplied(&TypeInfo::Of<T>());
}
/// Helper for returning the resolved semantic type of the expression `expr`. /// Helper for returning the resolved semantic type of the expression `expr`.
/// @param expr the AST expression /// @param expr the AST expression
/// @return the resolved semantic type for the expression, or nullptr if the /// @return the resolved semantic type for the expression, or nullptr if the
@ -180,6 +200,7 @@ class Program {
sem::Info sem_; sem::Info sem_;
SymbolTable symbols_{id_}; SymbolTable symbols_{id_};
diag::List diagnostics_; diag::List diagnostics_;
std::unordered_set<const TypeInfo*> transforms_applied_;
bool is_valid_ = false; // Not valid until it is built bool is_valid_ = false; // Not valid until it is built
bool moved_ = false; bool moved_ = false;
}; };

View File

@ -37,7 +37,9 @@ ProgramBuilder::ProgramBuilder(ProgramBuilder&& rhs)
sem_nodes_(std::move(rhs.sem_nodes_)), sem_nodes_(std::move(rhs.sem_nodes_)),
ast_(rhs.ast_), ast_(rhs.ast_),
sem_(std::move(rhs.sem_)), sem_(std::move(rhs.sem_)),
symbols_(std::move(rhs.symbols_)) { symbols_(std::move(rhs.symbols_)),
diagnostics_(std::move(rhs.diagnostics_)),
transforms_applied_(std::move(rhs.transforms_applied_)) {
rhs.MarkAsMoved(); rhs.MarkAsMoved();
} }
@ -53,6 +55,8 @@ ProgramBuilder& ProgramBuilder::operator=(ProgramBuilder&& rhs) {
ast_ = rhs.ast_; ast_ = rhs.ast_;
sem_ = std::move(rhs.sem_); sem_ = std::move(rhs.sem_);
symbols_ = std::move(rhs.symbols_); symbols_ = std::move(rhs.symbols_);
diagnostics_ = std::move(rhs.diagnostics_);
transforms_applied_ = std::move(rhs.transforms_applied_);
return *this; return *this;
} }
@ -65,6 +69,7 @@ ProgramBuilder ProgramBuilder::Wrap(const Program* program) {
builder.sem_ = sem::Info::Wrap(program->Sem()); builder.sem_ = sem::Info::Wrap(program->Sem());
builder.symbols_ = program->Symbols(); builder.symbols_ = program->Symbols();
builder.diagnostics_ = program->Diagnostics(); builder.diagnostics_ = program->Diagnostics();
builder.transforms_applied_ = program->TransformsApplied();
return builder; return builder;
} }

View File

@ -16,6 +16,7 @@
#define SRC_PROGRAM_BUILDER_H_ #define SRC_PROGRAM_BUILDER_H_
#include <string> #include <string>
#include <unordered_set>
#include <utility> #include <utility>
#include "src/ast/alias.h" #include "src/ast/alias.h"
@ -84,13 +85,14 @@
#error "internal tint header being #included from tint.h" #error "internal tint header being #included from tint.h"
#endif #endif
namespace tint {
// Forward declarations // Forward declarations
namespace tint {
namespace ast { namespace ast {
class VariableDeclStatement; class VariableDeclStatement;
} // namespace ast } // namespace ast
} // namespace tint
namespace tint {
class CloneContext; class CloneContext;
/// ProgramBuilder is a mutable builder for a Program. /// ProgramBuilder is a mutable builder for a Program.
@ -2039,6 +2041,40 @@ class ProgramBuilder {
source_ = Source(loc); source_ = Source(loc);
} }
/// Marks that the given transform has been applied to this program.
/// @param transform the transform that has been applied
void SetTransformApplied(const CastableBase* transform) {
transforms_applied_.emplace(&transform->TypeInfo());
}
/// Marks that the given transform `T` has been applied to this program.
template <typename T>
void SetTransformApplied() {
transforms_applied_.emplace(&TypeInfo::Of<T>());
}
/// Marks that the transforms with the given TypeInfos have been applied to
/// this program.
/// @param transforms the set of transform TypeInfos that has been applied
void SetTransformApplied(
const std::unordered_set<const TypeInfo*>& transforms) {
for (auto* transform : transforms) {
transforms_applied_.emplace(transform);
}
}
/// @returns true if the transform of type `T` was applied.
template <typename T>
bool HasTransformApplied() {
return transforms_applied_.count(&TypeInfo::Of<T>());
}
/// @return the TypeInfo pointers of all transforms that have been applied to
/// this program.
std::unordered_set<const TypeInfo*> TransformsApplied() const {
return transforms_applied_;
}
/// Helper for returning the resolved semantic type of the expression `expr`. /// Helper for returning the resolved semantic type of the expression `expr`.
/// @note As the Resolver is run when the Program is built, this will only be /// @note As the Resolver is run when the Program is built, this will only be
/// useful for the Resolver itself and tests that use their own Resolver. /// useful for the Resolver itself and tests that use their own Resolver.
@ -2125,6 +2161,7 @@ class ProgramBuilder {
sem::Info sem_; sem::Info sem_;
SymbolTable symbols_{id_}; SymbolTable symbols_{id_};
diag::List diagnostics_; diag::List diagnostics_;
std::unordered_set<const TypeInfo*> transforms_applied_;
/// The source to use when creating AST nodes without providing a Source as /// The source to use when creating AST nodes without providing a Source as
/// the first argument. /// the first argument.

View File

@ -21,7 +21,10 @@
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/sem/call.h" #include "src/sem/call.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/simplify.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform);
TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Config); TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Config);
TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Result); TINT_INSTANTIATE_TYPEINFO(tint::transform::ArrayLengthFromUniform::Result);
@ -31,16 +34,19 @@ namespace transform {
ArrayLengthFromUniform::ArrayLengthFromUniform() = default; ArrayLengthFromUniform::ArrayLengthFromUniform() = default;
ArrayLengthFromUniform::~ArrayLengthFromUniform() = default; ArrayLengthFromUniform::~ArrayLengthFromUniform() = default;
Output ArrayLengthFromUniform::Run(const Program* in, const DataMap& data) { void ArrayLengthFromUniform::Run(CloneContext& ctx,
ProgramBuilder out; const DataMap& inputs,
CloneContext ctx(&out, in); DataMap& outputs) {
if (!Requires<InlinePointerLets, Simplify>(ctx)) {
return;
}
auto* cfg = data.Get<Config>(); auto* cfg = inputs.Get<Config>();
if (cfg == nullptr) { if (cfg == nullptr) {
out.Diagnostics().add_error( ctx.dst->Diagnostics().add_error(
diag::System::Transform, diag::System::Transform,
"missing transform data for ArrayLengthFromUniform"); "missing transform data for ArrayLengthFromUniform");
return Output(Program(std::move(out))); return;
} }
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
@ -149,8 +155,7 @@ Output ArrayLengthFromUniform::Run(const Program* in, const DataMap& data) {
ctx.Clone(); ctx.Clone();
return Output{Program(std::move(out)), outputs.Add<Result>(buffer_size_ubo ? true : false);
std::make_unique<Result>(buffer_size_ubo ? true : false)};
} }
ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp) ArrayLengthFromUniform::Config::Config(sem::BindingPoint ubo_bp)

View File

@ -49,7 +49,8 @@ namespace transform {
/// This transform assumes that the `InlinePointerLets` and `Simplify` /// This transform assumes that the `InlinePointerLets` and `Simplify`
/// transforms have been run before it so that arguments to the arrayLength /// transforms have been run before it so that arguments to the arrayLength
/// builtin always have the form `&resource.array`. /// builtin always have the form `&resource.array`.
class ArrayLengthFromUniform : public Transform { class ArrayLengthFromUniform
: public Castable<ArrayLengthFromUniform, Transform> {
public: public:
/// Constructor /// Constructor
ArrayLengthFromUniform(); ArrayLengthFromUniform();
@ -91,11 +92,14 @@ class ArrayLengthFromUniform : public Transform {
bool const needs_buffer_sizes; bool const needs_buffer_sizes;
}; };
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
}; };
} // namespace transform } // namespace transform

View File

@ -16,6 +16,8 @@
#include <utility> #include <utility>
#include "src/transform/inline_pointer_lets.h"
#include "src/transform/simplify.h"
#include "src/transform/test_helper.h" #include "src/transform/test_helper.h"
namespace tint { namespace tint {
@ -29,7 +31,31 @@ TEST_F(ArrayLengthFromUniformTest, Error_MissingTransformData) {
auto* expect = "error: missing transform data for ArrayLengthFromUniform"; auto* expect = "error: missing transform data for ArrayLengthFromUniform";
auto got = Run<ArrayLengthFromUniform>(src); auto got = Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ArrayLengthFromUniformTest, Error_MissingInlinePointerLets) {
auto* src = "";
auto* expect =
"error: tint::transform::ArrayLengthFromUniform depends on "
"tint::transform::InlinePointerLets but the dependency was not run";
auto got = Run<Simplify, ArrayLengthFromUniform>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(ArrayLengthFromUniformTest, Error_MissingSimplify) {
auto* src = "";
auto* expect =
"error: tint::transform::ArrayLengthFromUniform depends on "
"tint::transform::Simplify but the dependency was not run";
auto got = Run<InlinePointerLets, ArrayLengthFromUniform>(src);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
@ -78,7 +104,8 @@ fn main() {
DataMap data; DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg)); data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<ArrayLengthFromUniform>(src, data); auto got =
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
EXPECT_TRUE( EXPECT_TRUE(
@ -131,7 +158,8 @@ fn main() {
DataMap data; DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg)); data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<ArrayLengthFromUniform>(src, data); auto got =
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
EXPECT_TRUE( EXPECT_TRUE(
@ -203,7 +231,8 @@ fn main() {
DataMap data; DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg)); data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<ArrayLengthFromUniform>(src, data); auto got =
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
EXPECT_TRUE( EXPECT_TRUE(
@ -232,7 +261,8 @@ fn main() {
DataMap data; DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg)); data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<ArrayLengthFromUniform>(src, data); auto got =
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(src, str(got)); EXPECT_EQ(src, str(got));
EXPECT_FALSE( EXPECT_FALSE(
@ -273,7 +303,8 @@ fn main() {
DataMap data; DataMap data;
data.Add<ArrayLengthFromUniform::Config>(std::move(cfg)); data.Add<ArrayLengthFromUniform::Config>(std::move(cfg));
auto got = Run<ArrayLengthFromUniform>(src, data); auto got =
Run<InlinePointerLets, Simplify, ArrayLengthFromUniform>(src, data);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }

View File

@ -22,6 +22,7 @@
#include "src/sem/function.h" #include "src/sem/function.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::BindingRemapper);
TINT_INSTANTIATE_TYPEINFO(tint::transform::BindingRemapper::Remappings); TINT_INSTANTIATE_TYPEINFO(tint::transform::BindingRemapper::Remappings);
namespace tint { namespace tint {
@ -40,14 +41,13 @@ BindingRemapper::Remappings::~Remappings() = default;
BindingRemapper::BindingRemapper() = default; BindingRemapper::BindingRemapper() = default;
BindingRemapper::~BindingRemapper() = default; BindingRemapper::~BindingRemapper() = default;
Output BindingRemapper::Run(const Program* in, const DataMap& datamap) { void BindingRemapper::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) {
ProgramBuilder out; auto* remappings = inputs.Get<Remappings>();
auto* remappings = datamap.Get<Remappings>();
if (!remappings) { if (!remappings) {
out.Diagnostics().add_error( ctx.dst->Diagnostics().add_error(
diag::System::Transform, diag::System::Transform,
"BindingRemapper did not find the remapping data"); "BindingRemapper did not find the remapping data");
return Output(Program(std::move(out))); return;
} }
// A set of post-remapped binding points that need to be decorated with a // A set of post-remapped binding points that need to be decorated with a
@ -57,11 +57,11 @@ Output BindingRemapper::Run(const Program* in, const DataMap& datamap) {
if (remappings->allow_collisions) { if (remappings->allow_collisions) {
// Scan for binding point collisions generated by this transform. // Scan for binding point collisions generated by this transform.
// Populate all collisions in the `add_collision_deco` set. // Populate all collisions in the `add_collision_deco` set.
for (auto* func_ast : in->AST().Functions()) { for (auto* func_ast : ctx.src->AST().Functions()) {
if (!func_ast->IsEntryPoint()) { if (!func_ast->IsEntryPoint()) {
continue; continue;
} }
auto* func = in->Sem().Get(func_ast); auto* func = ctx.src->Sem().Get(func_ast);
std::unordered_map<sem::BindingPoint, int> binding_point_counts; std::unordered_map<sem::BindingPoint, int> binding_point_counts;
for (auto* var : func->ReferencedModuleVariables()) { for (auto* var : func->ReferencedModuleVariables()) {
if (auto binding_point = var->Declaration()->binding_point()) { if (auto binding_point = var->Declaration()->binding_point()) {
@ -85,9 +85,7 @@ Output BindingRemapper::Run(const Program* in, const DataMap& datamap) {
} }
} }
CloneContext ctx(&out, in); for (auto* var : ctx.src->AST().GlobalVariables()) {
for (auto* var : in->AST().GlobalVariables()) {
if (auto binding_point = var->binding_point()) { if (auto binding_point = var->binding_point()) {
// The original binding point // The original binding point
BindingPoint from{binding_point.group->value(), BindingPoint from{binding_point.group->value(),
@ -102,8 +100,8 @@ Output BindingRemapper::Run(const Program* in, const DataMap& datamap) {
auto bp_it = remappings->binding_points.find(from); auto bp_it = remappings->binding_points.find(from);
if (bp_it != remappings->binding_points.end()) { if (bp_it != remappings->binding_points.end()) {
BindingPoint to = bp_it->second; BindingPoint to = bp_it->second;
auto* new_group = out.create<ast::GroupDecoration>(to.group); auto* new_group = ctx.dst->create<ast::GroupDecoration>(to.group);
auto* new_binding = out.create<ast::BindingDecoration>(to.binding); auto* new_binding = ctx.dst->create<ast::BindingDecoration>(to.binding);
ctx.Replace(binding_point.group, new_group); ctx.Replace(binding_point.group, new_group);
ctx.Replace(binding_point.binding, new_binding); ctx.Replace(binding_point.binding, new_binding);
@ -114,7 +112,7 @@ Output BindingRemapper::Run(const Program* in, const DataMap& datamap) {
auto ac_it = remappings->access_controls.find(from); auto ac_it = remappings->access_controls.find(from);
if (ac_it != remappings->access_controls.end()) { if (ac_it != remappings->access_controls.end()) {
ast::Access ac = ac_it->second; ast::Access ac = ac_it->second;
auto* ty = in->Sem().Get(var)->Type()->UnwrapRef(); auto* ty = ctx.src->Sem().Get(var)->Type()->UnwrapRef();
ast::Type* inner_ty = CreateASTTypeFor(&ctx, ty); ast::Type* inner_ty = CreateASTTypeFor(&ctx, ty);
auto* new_var = ctx.dst->create<ast::Variable>( auto* new_var = ctx.dst->create<ast::Variable>(
ctx.Clone(var->source()), ctx.Clone(var->symbol()), ctx.Clone(var->source()), ctx.Clone(var->symbol()),
@ -133,8 +131,8 @@ Output BindingRemapper::Run(const Program* in, const DataMap& datamap) {
} }
} }
} }
ctx.Clone(); ctx.Clone();
return Output(Program(std::move(out)));
} }
} // namespace transform } // namespace transform

View File

@ -29,7 +29,7 @@ using BindingPoint = sem::BindingPoint;
/// BindingRemapper is a transform used to remap resource binding points and /// BindingRemapper is a transform used to remap resource binding points and
/// access controls. /// access controls.
class BindingRemapper : public Transform { class BindingRemapper : public Castable<BindingRemapper, Transform> {
public: public:
/// BindingPoints is a map of old binding point to new binding point /// BindingPoints is a map of old binding point to new binding point
using BindingPoints = std::unordered_map<BindingPoint, BindingPoint>; using BindingPoints = std::unordered_map<BindingPoint, BindingPoint>;
@ -68,11 +68,14 @@ class BindingRemapper : public Transform {
BindingRemapper(); BindingRemapper();
~BindingRemapper() override; ~BindingRemapper() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
}; };
} // namespace transform } // namespace transform

View File

@ -20,20 +20,20 @@
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/sem/expression.h" #include "src/sem/expression.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::BoundArrayAccessors);
namespace tint { namespace tint {
namespace transform { namespace transform {
BoundArrayAccessors::BoundArrayAccessors() = default; BoundArrayAccessors::BoundArrayAccessors() = default;
BoundArrayAccessors::~BoundArrayAccessors() = default; BoundArrayAccessors::~BoundArrayAccessors() = default;
Output BoundArrayAccessors::Run(const Program* in, const DataMap&) { void BoundArrayAccessors::Run(CloneContext& ctx, const DataMap&, DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
ctx.ReplaceAll([&](ast::ArrayAccessorExpression* expr) { ctx.ReplaceAll([&](ast::ArrayAccessorExpression* expr) {
return Transform(expr, &ctx); return Transform(expr, &ctx);
}); });
ctx.Clone(); ctx.Clone();
return Output(Program(std::move(out)));
} }
ast::ArrayAccessorExpression* BoundArrayAccessors::Transform( ast::ArrayAccessorExpression* BoundArrayAccessors::Transform(

View File

@ -25,18 +25,21 @@ namespace transform {
/// the bounds of the array. Any access before the start of the array will clamp /// the bounds of the array. Any access before the start of the array will clamp
/// to zero and any access past the end of the array will clamp to /// to zero and any access past the end of the array will clamp to
/// (array length - 1). /// (array length - 1).
class BoundArrayAccessors : public Transform { class BoundArrayAccessors : public Castable<BoundArrayAccessors, Transform> {
public: public:
/// Constructor /// Constructor
BoundArrayAccessors(); BoundArrayAccessors();
/// Destructor /// Destructor
~BoundArrayAccessors() override; ~BoundArrayAccessors() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
private: private:
ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr, ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr,

View File

@ -27,6 +27,7 @@
#include "src/utils/get_or_create.h" #include "src/utils/get_or_create.h"
#include "src/utils/hash.h" #include "src/utils/hash.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::CalculateArrayLength);
TINT_INSTANTIATE_TYPEINFO( TINT_INSTANTIATE_TYPEINFO(
tint::transform::CalculateArrayLength::BufferSizeIntrinsic); tint::transform::CalculateArrayLength::BufferSizeIntrinsic);
@ -69,10 +70,7 @@ CalculateArrayLength::BufferSizeIntrinsic::Clone(CloneContext* ctx) const {
CalculateArrayLength::CalculateArrayLength() = default; CalculateArrayLength::CalculateArrayLength() = default;
CalculateArrayLength::~CalculateArrayLength() = default; CalculateArrayLength::~CalculateArrayLength() = default;
Output CalculateArrayLength::Run(const Program* in, const DataMap&) { void CalculateArrayLength::Run(CloneContext& ctx, const DataMap&, DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
// get_buffer_size_intrinsic() emits the function decorated with // get_buffer_size_intrinsic() emits the function decorated with
@ -232,8 +230,6 @@ Output CalculateArrayLength::Run(const Program* in, const DataMap&) {
} }
ctx.Clone(); ctx.Clone();
return Output{Program(std::move(out))};
} }
} // namespace transform } // namespace transform

View File

@ -29,7 +29,7 @@ namespace transform {
/// CalculateArrayLength is a transform used to replace calls to arrayLength() /// CalculateArrayLength is a transform used to replace calls to arrayLength()
/// with a value calculated from the size of the storage buffer. /// with a value calculated from the size of the storage buffer.
class CalculateArrayLength : public Transform { class CalculateArrayLength : public Castable<CalculateArrayLength, Transform> {
public: public:
/// BufferSizeIntrinsic is an InternalDecoration that's applied to intrinsic /// BufferSizeIntrinsic is an InternalDecoration that's applied to intrinsic
/// functions used to obtain the runtime size of a storage buffer. /// functions used to obtain the runtime size of a storage buffer.
@ -56,11 +56,14 @@ class CalculateArrayLength : public Transform {
/// Destructor /// Destructor
~CalculateArrayLength() override; ~CalculateArrayLength() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
}; };
} // namespace transform } // namespace transform

View File

@ -24,6 +24,7 @@
#include "src/sem/struct.h" #include "src/sem/struct.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO);
TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO::Config); TINT_INSTANTIATE_TYPEINFO(tint::transform::CanonicalizeEntryPointIO::Config);
namespace tint { namespace tint {
@ -62,16 +63,15 @@ bool StructMemberComparator(const ast::StructMember* a,
} // namespace } // namespace
Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap& data) { void CanonicalizeEntryPointIO::Run(CloneContext& ctx,
ProgramBuilder out; const DataMap& inputs,
CloneContext ctx(&out, in); DataMap&) {
auto* cfg = inputs.Get<Config>();
auto* cfg = data.Get<Config>();
if (cfg == nullptr) { if (cfg == nullptr) {
out.Diagnostics().add_error( ctx.dst->Diagnostics().add_error(
diag::System::Transform, diag::System::Transform,
"missing transform data for CanonicalizeEntryPointIO"); "missing transform data for CanonicalizeEntryPointIO");
return Output(Program(std::move(out))); return;
} }
// Strip entry point IO decorations from struct declarations. // Strip entry point IO decorations from struct declarations.
@ -375,7 +375,6 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap& data) {
} }
ctx.Clone(); ctx.Clone();
return Output(Program(std::move(out)));
} }
CanonicalizeEntryPointIO::Config::Config(BuiltinStyle builtins, CanonicalizeEntryPointIO::Config::Config(BuiltinStyle builtins,

View File

@ -68,7 +68,8 @@ namespace transform {
/// return retval; /// return retval;
/// } /// }
/// ``` /// ```
class CanonicalizeEntryPointIO : public Transform { class CanonicalizeEntryPointIO
: public Castable<CanonicalizeEntryPointIO, Transform> {
public: public:
/// BuiltinStyle is an enumerator of different ways to emit builtins. /// BuiltinStyle is an enumerator of different ways to emit builtins.
enum class BuiltinStyle { enum class BuiltinStyle {
@ -102,11 +103,14 @@ class CanonicalizeEntryPointIO : public Transform {
CanonicalizeEntryPointIO(); CanonicalizeEntryPointIO();
~CanonicalizeEntryPointIO() override; ~CanonicalizeEntryPointIO() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
}; };
} // namespace transform } // namespace transform

View File

@ -38,6 +38,7 @@
#include "src/utils/get_or_create.h" #include "src/utils/get_or_create.h"
#include "src/utils/hash.h" #include "src/utils/hash.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess);
TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess::Intrinsic); TINT_INSTANTIATE_TYPEINFO(tint::transform::DecomposeMemoryAccess::Intrinsic);
namespace tint { namespace tint {
@ -790,10 +791,7 @@ DecomposeMemoryAccess::Intrinsic* DecomposeMemoryAccess::Intrinsic::Clone(
DecomposeMemoryAccess::DecomposeMemoryAccess() = default; DecomposeMemoryAccess::DecomposeMemoryAccess() = default;
DecomposeMemoryAccess::~DecomposeMemoryAccess() = default; DecomposeMemoryAccess::~DecomposeMemoryAccess() = default;
Output DecomposeMemoryAccess::Run(const Program* in, const DataMap&) { void DecomposeMemoryAccess::Run(CloneContext& ctx, const DataMap&, DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
State state; State state;
@ -987,7 +985,6 @@ Output DecomposeMemoryAccess::Run(const Program* in, const DataMap&) {
} }
ctx.Clone(); ctx.Clone();
return Output{Program(std::move(out))};
} }
} // namespace transform } // namespace transform

View File

@ -30,7 +30,8 @@ namespace transform {
/// DecomposeMemoryAccess is a transform used to replace storage and uniform /// DecomposeMemoryAccess is a transform used to replace storage and uniform
/// buffer accesses with a combination of load, store or atomic functions on /// buffer accesses with a combination of load, store or atomic functions on
/// primitive types. /// primitive types.
class DecomposeMemoryAccess : public Transform { class DecomposeMemoryAccess
: public Castable<DecomposeMemoryAccess, Transform> {
public: public:
/// Intrinsic is an InternalDecoration that's used to decorate a stub function /// Intrinsic is an InternalDecoration that's used to decorate a stub function
/// so that the HLSL transforms this into calls to /// so that the HLSL transforms this into calls to
@ -103,11 +104,14 @@ class DecomposeMemoryAccess : public Transform {
/// Destructor /// Destructor
~DecomposeMemoryAccess() override; ~DecomposeMemoryAccess() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
struct State; struct State;
}; };

View File

@ -18,15 +18,17 @@
#include "src/sem/call.h" #include "src/sem/call.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::ExternalTextureTransform);
namespace tint { namespace tint {
namespace transform { namespace transform {
ExternalTextureTransform::ExternalTextureTransform() = default; ExternalTextureTransform::ExternalTextureTransform() = default;
ExternalTextureTransform::~ExternalTextureTransform() = default; ExternalTextureTransform::~ExternalTextureTransform() = default;
Output ExternalTextureTransform::Run(const Program* in, const DataMap&) { void ExternalTextureTransform::Run(CloneContext& ctx,
ProgramBuilder out; const DataMap&,
CloneContext ctx(&out, in); DataMap&) {
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
// Within this transform, usages of texture_external are replaced with a // Within this transform, usages of texture_external are replaced with a
@ -105,7 +107,7 @@ Output ExternalTextureTransform::Run(const Program* in, const DataMap&) {
// Scan the AST nodes for external texture declarations. // Scan the AST nodes for external texture declarations.
for (auto* node : ctx.src->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* var = node->As<ast::Variable>()) { if (auto* var = node->As<ast::Variable>()) {
if (Is<ast::ExternalTexture>(var->type())) { if (::tint::Is<ast::ExternalTexture>(var->type())) {
// Replace a single-plane external texture with a 2D, f32 sampled // Replace a single-plane external texture with a 2D, f32 sampled
// texture. // texture.
auto* newType = ctx.dst->ty.sampled_texture(ast::TextureDimension::k2d, auto* newType = ctx.dst->ty.sampled_texture(ast::TextureDimension::k2d,
@ -125,7 +127,6 @@ Output ExternalTextureTransform::Run(const Program* in, const DataMap&) {
} }
ctx.Clone(); ctx.Clone();
return Output{Program(std::move(out))};
} }
} // namespace transform } // namespace transform

View File

@ -27,18 +27,22 @@ namespace transform {
/// This allows us to share SPIR-V/HLSL writer paths for sampled textures /// This allows us to share SPIR-V/HLSL writer paths for sampled textures
/// instead of adding dedicated writer paths for external textures. /// instead of adding dedicated writer paths for external textures.
/// ExternalTextureTransform performs this transformation. /// ExternalTextureTransform performs this transformation.
class ExternalTextureTransform : public Transform { class ExternalTextureTransform
: public Castable<ExternalTextureTransform, Transform> {
public: public:
/// Constructor /// Constructor
ExternalTextureTransform(); ExternalTextureTransform();
/// Destructor /// Destructor
~ExternalTextureTransform() override; ~ExternalTextureTransform() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
}; };
} // namespace transform } // namespace transform

View File

@ -25,6 +25,7 @@
#include "src/sem/struct.h" #include "src/sem/struct.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset);
TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset::BindingPoint); TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset::BindingPoint);
TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset::Data); TINT_INSTANTIATE_TYPEINFO(tint::transform::FirstIndexOffset::Data);
@ -57,18 +58,17 @@ FirstIndexOffset::Data::~Data() = default;
FirstIndexOffset::FirstIndexOffset() = default; FirstIndexOffset::FirstIndexOffset() = default;
FirstIndexOffset::~FirstIndexOffset() = default; FirstIndexOffset::~FirstIndexOffset() = default;
Output FirstIndexOffset::Run(const Program* in, const DataMap& data) { void FirstIndexOffset::Run(CloneContext& ctx,
const DataMap& inputs,
DataMap& outputs) {
// Get the uniform buffer binding point // Get the uniform buffer binding point
uint32_t ub_binding = binding_; uint32_t ub_binding = binding_;
uint32_t ub_group = group_; uint32_t ub_group = group_;
if (auto* binding_point = data.Get<BindingPoint>()) { if (auto* binding_point = inputs.Get<BindingPoint>()) {
ub_binding = binding_point->binding; ub_binding = binding_point->binding;
ub_group = binding_point->group; ub_group = binding_point->group;
} }
ProgramBuilder out;
CloneContext ctx(&out, in);
// Map of builtin usages // Map of builtin usages
std::unordered_map<const sem::Variable*, const char*> builtin_vars; std::unordered_map<const sem::Variable*, const char*> builtin_vars;
std::unordered_map<const sem::StructMember*, const char*> builtin_members; std::unordered_map<const sem::StructMember*, const char*> builtin_members;
@ -78,7 +78,7 @@ Output FirstIndexOffset::Run(const Program* in, const DataMap& data) {
// Traverse the AST scanning for builtin accesses via variables (includes // Traverse the AST scanning for builtin accesses via variables (includes
// parameters) or structure member accesses. // parameters) or structure member accesses.
for (auto* node : in->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* var = node->As<ast::Variable>()) { if (auto* var = node->As<ast::Variable>()) {
for (ast::Decoration* dec : var->decorations()) { for (ast::Decoration* dec : var->decorations()) {
if (auto* builtin_dec = dec->As<ast::BuiltinDecoration>()) { if (auto* builtin_dec = dec->As<ast::BuiltinDecoration>()) {
@ -173,10 +173,8 @@ Output FirstIndexOffset::Run(const Program* in, const DataMap& data) {
ctx.Clone(); ctx.Clone();
return Output( outputs.Add<Data>(has_vertex_index, has_instance_index, vertex_index_offset,
Program(std::move(out)), instance_index_offset);
std::make_unique<Data>(has_vertex_index, has_instance_index,
vertex_index_offset, instance_index_offset));
} }
} // namespace transform } // namespace transform

View File

@ -55,7 +55,7 @@ namespace transform {
/// return vert_idx; /// return vert_idx;
/// } /// }
/// ///
class FirstIndexOffset : public Transform { class FirstIndexOffset : public Castable<FirstIndexOffset, Transform> {
public: public:
/// BindingPoint is consumed by the FirstIndexOffset transform. /// BindingPoint is consumed by the FirstIndexOffset transform.
/// BindingPoint specifies the binding point of the first index uniform /// BindingPoint specifies the binding point of the first index uniform
@ -112,11 +112,14 @@ class FirstIndexOffset : public Transform {
/// Destructor /// Destructor
~FirstIndexOffset() override; ~FirstIndexOffset() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
private: private:
uint32_t binding_ = 0; uint32_t binding_ = 0;

View File

@ -20,6 +20,8 @@
#include "src/program_builder.h" #include "src/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::FoldConstants);
namespace tint { namespace tint {
namespace { namespace {
@ -318,10 +320,7 @@ FoldConstants::FoldConstants() = default;
FoldConstants::~FoldConstants() = default; FoldConstants::~FoldConstants() = default;
Output FoldConstants::Run(const Program* in, const DataMap&) { void FoldConstants::Run(CloneContext& ctx, const DataMap&, DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
ExprToValue expr_to_value; ExprToValue expr_to_value;
// Visit inner expressions before outer expressions // Visit inner expressions before outer expressions
@ -345,8 +344,6 @@ Output FoldConstants::Run(const Program* in, const DataMap&) {
} }
ctx.Clone(); ctx.Clone();
return Output(Program(std::move(out)));
} }
} // namespace transform } // namespace transform

View File

@ -21,7 +21,7 @@ namespace tint {
namespace transform { namespace transform {
/// FoldConstants transforms the AST by folding constant expressions /// FoldConstants transforms the AST by folding constant expressions
class FoldConstants : public Transform { class FoldConstants : public Castable<FoldConstants, Transform> {
public: public:
/// Constructor /// Constructor
FoldConstants(); FoldConstants();
@ -29,11 +29,14 @@ class FoldConstants : public Transform {
/// Destructor /// Destructor
~FoldConstants() override; ~FoldConstants() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
}; };
} // namespace transform } // namespace transform

View File

@ -29,6 +29,8 @@
#include "src/transform/wrap_arrays_in_structs.h" #include "src/transform/wrap_arrays_in_structs.h"
#include "src/transform/zero_init_workgroup_memory.h" #include "src/transform/zero_init_workgroup_memory.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Hlsl);
namespace tint { namespace tint {
namespace transform { namespace transform {
@ -69,6 +71,7 @@ Output Hlsl::Run(const Program* in, const DataMap&) {
CloneContext ctx(&builder, &out.program); CloneContext ctx(&builder, &out.program);
AddEmptyEntryPoint(ctx); AddEmptyEntryPoint(ctx);
ctx.Clone(); ctx.Clone();
builder.SetTransformApplied(this);
return Output{Program(std::move(builder))}; return Output{Program(std::move(builder))};
} }

View File

@ -27,7 +27,7 @@ namespace transform {
/// Hlsl is a transform used to sanitize a Program for use with the Hlsl writer. /// Hlsl is a transform used to sanitize a Program for use with the Hlsl writer.
/// Passing a non-sanitized Program to the Hlsl writer will result in undefined /// Passing a non-sanitized Program to the Hlsl writer will result in undefined
/// behavior. /// behavior.
class Hlsl : public Transform { class Hlsl : public Castable<Hlsl, Transform> {
public: public:
/// Constructor /// Constructor
Hlsl(); Hlsl();

View File

@ -25,6 +25,8 @@
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/utils/scoped_assignment.h" #include "src/utils/scoped_assignment.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::InlinePointerLets);
namespace tint { namespace tint {
namespace transform { namespace transform {
namespace { namespace {
@ -79,10 +81,7 @@ InlinePointerLets::InlinePointerLets() = default;
InlinePointerLets::~InlinePointerLets() = default; InlinePointerLets::~InlinePointerLets() = default;
Output InlinePointerLets::Run(const Program* in, const DataMap&) { void InlinePointerLets::Run(CloneContext& ctx, const DataMap&, DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
// If not null, current_ptr_let is the current PtrLet being operated on. // If not null, current_ptr_let is the current PtrLet being operated on.
PtrLet* current_ptr_let = nullptr; PtrLet* current_ptr_let = nullptr;
// A map of the AST `let` variable to the PtrLet // A map of the AST `let` variable to the PtrLet
@ -107,7 +106,7 @@ Output InlinePointerLets::Run(const Program* in, const DataMap&) {
} }
} }
if (auto* ident = expr->As<ast::IdentifierExpression>()) { if (auto* ident = expr->As<ast::IdentifierExpression>()) {
if (auto* vu = in->Sem().Get<sem::VariableUser>(ident)) { if (auto* vu = ctx.src->Sem().Get<sem::VariableUser>(ident)) {
auto* var = vu->Variable()->Declaration(); auto* var = vu->Variable()->Declaration();
auto it = ptr_lets.find(var); auto it = ptr_lets.find(var);
if (it != ptr_lets.end()) { if (it != ptr_lets.end()) {
@ -130,13 +129,13 @@ Output InlinePointerLets::Run(const Program* in, const DataMap&) {
// Find all the pointer-typed `let` declarations. // Find all the pointer-typed `let` declarations.
// Note that these must be function-scoped, as module-scoped `let`s are not // Note that these must be function-scoped, as module-scoped `let`s are not
// permitted. // permitted.
for (auto* node : in->ASTNodes().Objects()) { for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* let = node->As<ast::VariableDeclStatement>()) { if (auto* let = node->As<ast::VariableDeclStatement>()) {
if (!let->variable()->is_const()) { if (!let->variable()->is_const()) {
continue; // Not a `let` declaration. Ignore. continue; // Not a `let` declaration. Ignore.
} }
auto* var = in->Sem().Get(let->variable()); auto* var = ctx.src->Sem().Get(let->variable());
if (!var->Type()->Is<sem::Pointer>()) { if (!var->Type()->Is<sem::Pointer>()) {
continue; // Not a pointer type. Ignore. continue; // Not a pointer type. Ignore.
} }
@ -183,8 +182,6 @@ Output InlinePointerLets::Run(const Program* in, const DataMap&) {
} }
ctx.Clone(); ctx.Clone();
return Output(Program(std::move(out)));
} }
} // namespace transform } // namespace transform

View File

@ -31,7 +31,7 @@ namespace transform {
/// Note: InlinePointerLets does not operate on module-scope `let`s, as these /// Note: InlinePointerLets does not operate on module-scope `let`s, as these
/// cannot be pointers: https://gpuweb.github.io/gpuweb/wgsl/#module-constants /// cannot be pointers: https://gpuweb.github.io/gpuweb/wgsl/#module-constants
/// `A module-scope let-declared constant must be of atomic-free plain type.` /// `A module-scope let-declared constant must be of atomic-free plain type.`
class InlinePointerLets : public Transform { class InlinePointerLets : public Castable<InlinePointerLets, Transform> {
public: public:
/// Constructor /// Constructor
InlinePointerLets(); InlinePointerLets();
@ -39,11 +39,14 @@ class InlinePointerLets : public Transform {
/// Destructor /// Destructor
~InlinePointerLets() override; ~InlinePointerLets() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
}; };
} // namespace transform } // namespace transform

View File

@ -14,6 +14,8 @@
#include "src/transform/manager.h" #include "src/transform/manager.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Manager);
namespace tint { namespace tint {
namespace transform { namespace transform {

View File

@ -28,7 +28,7 @@ namespace transform {
/// The inner transforms will execute in the appended order. /// The inner transforms will execute in the appended order.
/// If any inner transform fails the manager will return immediately and /// If any inner transform fails the manager will return immediately and
/// the error can be retrieved with the Output's diagnostics. /// the error can be retrieved with the Output's diagnostics.
class Manager : public Transform { class Manager : public Castable<Manager, Transform> {
public: public:
/// Constructor /// Constructor
Manager(); Manager();

View File

@ -36,6 +36,7 @@
#include "src/transform/wrap_arrays_in_structs.h" #include "src/transform/wrap_arrays_in_structs.h"
#include "src/transform/zero_init_workgroup_memory.h" #include "src/transform/zero_init_workgroup_memory.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl::Config); TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl::Config);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl::Result); TINT_INSTANTIATE_TYPEINFO(tint::transform::Msl::Result);
@ -103,6 +104,8 @@ Output Msl::Run(const Program* in, const DataMap& inputs) {
auto result = std::make_unique<Result>( auto result = std::make_unique<Result>(
out.data.Get<ArrayLengthFromUniform::Result>()->needs_buffer_sizes); out.data.Get<ArrayLengthFromUniform::Result>()->needs_buffer_sizes);
builder.SetTransformApplied(this);
return Output{Program(std::move(builder)), std::move(result)}; return Output{Program(std::move(builder)), std::move(result)};
} }

View File

@ -23,7 +23,7 @@ namespace transform {
/// Msl is a transform used to sanitize a Program for use with the Msl writer. /// Msl is a transform used to sanitize a Program for use with the Msl writer.
/// Passing a non-sanitized Program to the Msl writer will result in undefined /// Passing a non-sanitized Program to the Msl writer will result in undefined
/// behavior. /// behavior.
class Msl : public Transform { class Msl : public Castable<Msl, Transform> {
public: public:
/// The default buffer slot to use for the storage buffer size buffer. /// The default buffer slot to use for the storage buffer size buffer.
const uint32_t kDefaultBufferSizeUniformIndex = 30; const uint32_t kDefaultBufferSizeUniformIndex = 30;

View File

@ -22,6 +22,8 @@
#include "src/sem/expression.h" #include "src/sem/expression.h"
#include "src/utils/get_or_create.h" #include "src/utils/get_or_create.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::PadArrayElements);
namespace tint { namespace tint {
namespace transform { namespace transform {
namespace { namespace {
@ -89,10 +91,7 @@ PadArrayElements::PadArrayElements() = default;
PadArrayElements::~PadArrayElements() = default; PadArrayElements::~PadArrayElements() = default;
Output PadArrayElements::Run(const Program* in, const DataMap&) { void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
std::unordered_map<const sem::Array*, ArrayBuilder> padded_arrays; std::unordered_map<const sem::Array*, ArrayBuilder> padded_arrays;
@ -149,8 +148,6 @@ Output PadArrayElements::Run(const Program* in, const DataMap&) {
}); });
ctx.Clone(); ctx.Clone();
return Output(Program(std::move(out)));
} }
} // namespace transform } // namespace transform

View File

@ -30,7 +30,7 @@ namespace transform {
/// structure element type. /// structure element type.
/// This transform helps with backends that cannot directly return arrays or use /// This transform helps with backends that cannot directly return arrays or use
/// them as parameters. /// them as parameters.
class PadArrayElements : public Transform { class PadArrayElements : public Castable<PadArrayElements, Transform> {
public: public:
/// Constructor /// Constructor
PadArrayElements(); PadArrayElements();
@ -38,11 +38,14 @@ class PadArrayElements : public Transform {
/// Destructor /// Destructor
~PadArrayElements() override; ~PadArrayElements() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
}; };
} // namespace transform } // namespace transform

View File

@ -21,6 +21,8 @@
#include "src/sem/expression.h" #include "src/sem/expression.h"
#include "src/sem/statement.h" #include "src/sem/statement.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::PromoteInitializersToConstVar);
namespace tint { namespace tint {
namespace transform { namespace transform {
@ -28,10 +30,9 @@ PromoteInitializersToConstVar::PromoteInitializersToConstVar() = default;
PromoteInitializersToConstVar::~PromoteInitializersToConstVar() = default; PromoteInitializersToConstVar::~PromoteInitializersToConstVar() = default;
Output PromoteInitializersToConstVar::Run(const Program* in, const DataMap&) { void PromoteInitializersToConstVar::Run(CloneContext& ctx,
ProgramBuilder out; const DataMap&,
CloneContext ctx(&out, in); DataMap&) {
// Scan the AST nodes for array and structure initializers which // Scan the AST nodes for array and structure initializers which
// need to be promoted to their own constant declaration. // need to be promoted to their own constant declaration.
@ -100,8 +101,6 @@ Output PromoteInitializersToConstVar::Run(const Program* in, const DataMap&) {
} }
ctx.Clone(); ctx.Clone();
return Output(Program(std::move(out)));
} }
} // namespace transform } // namespace transform

View File

@ -23,7 +23,8 @@ namespace transform {
/// A transform that hoists the array and structure initializers to a constant /// A transform that hoists the array and structure initializers to a constant
/// variable, declared just before the statement of usage. See /// variable, declared just before the statement of usage. See
/// crbug.com/tint/406 for more details. /// crbug.com/tint/406 for more details.
class PromoteInitializersToConstVar : public Transform { class PromoteInitializersToConstVar
: public Castable<PromoteInitializersToConstVar, Transform> {
public: public:
/// Constructor /// Constructor
PromoteInitializersToConstVar(); PromoteInitializersToConstVar();
@ -31,11 +32,14 @@ class PromoteInitializersToConstVar : public Transform {
/// Destructor /// Destructor
~PromoteInitializersToConstVar() override; ~PromoteInitializersToConstVar() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
}; };
} // namespace transform } // namespace transform

View File

@ -22,7 +22,9 @@
#include "src/sem/call.h" #include "src/sem/call.h"
#include "src/sem/member_accessor_expression.h" #include "src/sem/member_accessor_expression.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Renamer);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Renamer::Data); TINT_INSTANTIATE_TYPEINFO(tint::transform::Renamer::Data);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Renamer::Config);
namespace tint { namespace tint {
namespace transform { namespace transform {
@ -835,18 +837,18 @@ const char* kReservedKeywordsMSL[] = {"access",
} // namespace } // namespace
Renamer::Data::Data(Remappings&& r) : remappings(std::move(r)) {} Renamer::Data::Data(Remappings&& r) : remappings(std::move(r)) {}
Renamer::Data::Data(const Data&) = default; Renamer::Data::Data(const Data&) = default;
Renamer::Data::~Data() = default; Renamer::Data::~Data() = default;
Renamer::Renamer() : cfg_{} {} Renamer::Config::Config(Target t) : target(t) {}
Renamer::Config::Config(const Config&) = default;
Renamer::Renamer(const Config& config) : cfg_(config) {} Renamer::Config::~Config() = default;
Renamer::Renamer() : deprecated_cfg_(Target::kAll) {}
Renamer::Renamer(const Config& config) : deprecated_cfg_(config) {}
Renamer::~Renamer() = default; Renamer::~Renamer() = default;
Output Renamer::Run(const Program* in, const DataMap&) { Output Renamer::Run(const Program* in, const DataMap& inputs) {
ProgramBuilder out; ProgramBuilder out;
// Disable auto-cloning of symbols, since we want to rename them. // Disable auto-cloning of symbols, since we want to rename them.
CloneContext ctx(&out, in, false); CloneContext ctx(&out, in, false);
@ -879,9 +881,14 @@ Output Renamer::Run(const Program* in, const DataMap&) {
Data::Remappings remappings; Data::Remappings remappings;
auto* cfg = inputs.Get<Config>();
if (!cfg) {
cfg = &deprecated_cfg_;
}
ctx.ReplaceAll([&](Symbol sym_in) { ctx.ReplaceAll([&](Symbol sym_in) {
auto name_in = ctx.src->Symbols().NameFor(sym_in); auto name_in = ctx.src->Symbols().NameFor(sym_in);
switch (cfg_.target) { switch (cfg->target) {
case Target::kAll: case Target::kAll:
// Always rename. // Always rename.
break; break;

View File

@ -24,7 +24,7 @@ namespace tint {
namespace transform { namespace transform {
/// Renamer is a Transform that renames all the symbols in a program. /// Renamer is a Transform that renames all the symbols in a program.
class Renamer : public Transform { class Renamer : public Castable<Renamer, Transform> {
public: public:
/// Data is outputted by the Renamer transform. /// Data is outputted by the Renamer transform.
/// Data holds information about shader usage and constant buffer offsets. /// Data holds information about shader usage and constant buffer offsets.
@ -57,16 +57,27 @@ class Renamer : public Transform {
}; };
/// Configuration options for the transform /// Configuration options for the transform
struct Config { struct Config : public Castable<Config, transform::Data> {
/// Constructor
/// @param tgt the targets to rename
explicit Config(Target tgt);
/// Copy constructor
Config(const Config&);
/// Destructor
~Config() override;
/// The targets to rename /// The targets to rename
Target target = Target::kAll; Target const target = Target::kAll;
}; };
/// Constructor using a default configuration /// Constructor using a the configuration provided in the input Data
Renamer(); Renamer();
/// Constructor /// Constructor
/// @param config the configuration for the transform /// @param config the configuration for the transform
/// [DEPRECATED] Pass Config as input Data
explicit Renamer(const Config& config); explicit Renamer(const Config& config);
/// Destructor /// Destructor
@ -79,7 +90,7 @@ class Renamer : public Transform {
Output Run(const Program* program, const DataMap& data = {}) override; Output Run(const Program* program, const DataMap& data = {}) override;
private: private:
Config const cfg_; Config const deprecated_cfg_;
}; };
} // namespace transform } // namespace transform

View File

@ -207,8 +207,9 @@ fn frag_main() {
} }
)"; )";
Renamer::Config config{Renamer::Target::kHlslKeywords}; DataMap inputs;
auto got = Run(src, std::make_unique<Renamer>(config)); inputs.Add<Renamer::Config>(Renamer::Target::kHlslKeywords);
auto got = Run<Renamer>(src, inputs);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
@ -231,8 +232,9 @@ fn frag_main() {
} }
)"; )";
Renamer::Config config{Renamer::Target::kMslKeywords}; DataMap inputs;
auto got = Run(src, std::make_unique<Renamer>(config)); inputs.Add<Renamer::Config>(Renamer::Target::kMslKeywords);
auto got = Run<Renamer>(src, inputs);
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }

View File

@ -25,6 +25,8 @@
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/utils/scoped_assignment.h" #include "src/utils/scoped_assignment.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Simplify);
namespace tint { namespace tint {
namespace transform { namespace transform {
@ -32,10 +34,7 @@ Simplify::Simplify() = default;
Simplify::~Simplify() = default; Simplify::~Simplify() = default;
Output Simplify::Run(const Program* in, const DataMap&) { void Simplify::Run(CloneContext& ctx, const DataMap&, DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
ctx.ReplaceAll([&](ast::Expression* expr) -> ast::Expression* { ctx.ReplaceAll([&](ast::Expression* expr) -> ast::Expression* {
if (auto* outer = expr->As<ast::UnaryOpExpression>()) { if (auto* outer = expr->As<ast::UnaryOpExpression>()) {
if (auto* inner = outer->expr()->As<ast::UnaryOpExpression>()) { if (auto* inner = outer->expr()->As<ast::UnaryOpExpression>()) {
@ -55,8 +54,6 @@ Output Simplify::Run(const Program* in, const DataMap&) {
}); });
ctx.Clone(); ctx.Clone();
return Output(Program(std::move(out)));
} }
} // namespace transform } // namespace transform

View File

@ -28,7 +28,7 @@ namespace transform {
/// Simplify currently optimizes the following: /// Simplify currently optimizes the following:
/// `&(*(expr))` => `expr` /// `&(*(expr))` => `expr`
/// `*(&(expr))` => `expr` /// `*(&(expr))` => `expr`
class Simplify : public Transform { class Simplify : public Castable<Simplify, Transform> {
public: public:
/// Constructor /// Constructor
Simplify(); Simplify();
@ -36,11 +36,14 @@ class Simplify : public Transform {
/// Destructor /// Destructor
~Simplify() override; ~Simplify() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
}; };
} // namespace transform } // namespace transform

View File

@ -21,6 +21,7 @@
#include "src/sem/function.h" #include "src/sem/function.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::SingleEntryPoint);
TINT_INSTANTIATE_TYPEINFO(tint::transform::SingleEntryPoint::Config); TINT_INSTANTIATE_TYPEINFO(tint::transform::SingleEntryPoint::Config);
namespace tint { namespace tint {
@ -30,69 +31,63 @@ SingleEntryPoint::SingleEntryPoint() = default;
SingleEntryPoint::~SingleEntryPoint() = default; SingleEntryPoint::~SingleEntryPoint() = default;
Output SingleEntryPoint::Run(const Program* in, const DataMap& data) { void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) {
ProgramBuilder out; auto* cfg = inputs.Get<Config>();
auto* cfg = data.Get<Config>();
if (cfg == nullptr) { if (cfg == nullptr) {
out.Diagnostics().add_error(diag::System::Transform, ctx.dst->Diagnostics().add_error(
"missing transform data for SingleEntryPoint"); diag::System::Transform, "missing transform data for SingleEntryPoint");
return Output(Program(std::move(out))); return;
} }
// Find the target entry point. // Find the target entry point.
ast::Function* entry_point = nullptr; ast::Function* entry_point = nullptr;
for (auto* f : in->AST().Functions()) { for (auto* f : ctx.src->AST().Functions()) {
if (!f->IsEntryPoint()) { if (!f->IsEntryPoint()) {
continue; continue;
} }
if (in->Symbols().NameFor(f->symbol()) == cfg->entry_point_name) { if (ctx.src->Symbols().NameFor(f->symbol()) == cfg->entry_point_name) {
entry_point = f; entry_point = f;
break; break;
} }
} }
if (entry_point == nullptr) { if (entry_point == nullptr) {
out.Diagnostics().add_error( ctx.dst->Diagnostics().add_error(
diag::System::Transform, diag::System::Transform,
"entry point '" + cfg->entry_point_name + "' not found"); "entry point '" + cfg->entry_point_name + "' not found");
return Output(Program(std::move(out))); return;
} }
CloneContext ctx(&out, in); auto& sem = ctx.src->Sem();
auto* sem = in->Sem().Get(entry_point);
// Build set of referenced module-scope variables for faster lookups later. // Build set of referenced module-scope variables for faster lookups later.
std::unordered_set<const ast::Variable*> referenced_vars; std::unordered_set<const ast::Variable*> referenced_vars;
for (auto* var : sem->ReferencedModuleVariables()) { for (auto* var : sem.Get(entry_point)->ReferencedModuleVariables()) {
referenced_vars.emplace(var->Declaration()); referenced_vars.emplace(var->Declaration());
} }
// Clone any module-scope variables, types, and functions that are statically // Clone any module-scope variables, types, and functions that are statically
// referenced by the target entry point. // referenced by the target entry point.
for (auto* decl : in->AST().GlobalDeclarations()) { for (auto* decl : ctx.src->AST().GlobalDeclarations()) {
if (auto* ty = decl->As<ast::TypeDecl>()) { if (auto* ty = decl->As<ast::TypeDecl>()) {
// TODO(jrprice): Strip unused types. // TODO(jrprice): Strip unused types.
out.AST().AddTypeDecl(ctx.Clone(ty)); ctx.dst->AST().AddTypeDecl(ctx.Clone(ty));
} else if (auto* var = decl->As<ast::Variable>()) { } else if (auto* var = decl->As<ast::Variable>()) {
if (var->is_const() || referenced_vars.count(var)) { if (var->is_const() || referenced_vars.count(var)) {
out.AST().AddGlobalVariable(ctx.Clone(var)); ctx.dst->AST().AddGlobalVariable(ctx.Clone(var));
} }
} else if (auto* func = decl->As<ast::Function>()) { } else if (auto* func = decl->As<ast::Function>()) {
if (in->Sem().Get(func)->HasAncestorEntryPoint(entry_point->symbol())) { if (sem.Get(func)->HasAncestorEntryPoint(entry_point->symbol())) {
out.AST().AddFunction(ctx.Clone(func)); ctx.dst->AST().AddFunction(ctx.Clone(func));
} }
} else { } else {
TINT_UNREACHABLE(Transform, out.Diagnostics()) TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
<< "unhandled global declaration: " << decl->TypeInfo().name; << "unhandled global declaration: " << decl->TypeInfo().name;
return Output(Program(std::move(out))); return;
} }
} }
// Clone the entry point. // Clone the entry point.
out.AST().AddFunction(ctx.Clone(entry_point)); ctx.dst->AST().AddFunction(ctx.Clone(entry_point));
return Output(Program(std::move(out)));
} }
SingleEntryPoint::Config::Config(std::string entry_point) SingleEntryPoint::Config::Config(std::string entry_point)

View File

@ -26,7 +26,7 @@ namespace transform {
/// ///
/// All module-scope variables, types, and functions that are not used by the /// All module-scope variables, types, and functions that are not used by the
/// target entry point will also be removed. /// target entry point will also be removed.
class SingleEntryPoint : public Transform { class SingleEntryPoint : public Castable<SingleEntryPoint, Transform> {
public: public:
/// Configuration options for the transform /// Configuration options for the transform
struct Config : public Castable<Config, Data> { struct Config : public Castable<Config, Data> {
@ -54,11 +54,14 @@ class SingleEntryPoint : public Transform {
/// Destructor /// Destructor
~SingleEntryPoint() override; ~SingleEntryPoint() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
}; };
} // namespace transform } // namespace transform

View File

@ -34,6 +34,7 @@
#include "src/transform/simplify.h" #include "src/transform/simplify.h"
#include "src/transform/zero_init_workgroup_memory.h" #include "src/transform/zero_init_workgroup_memory.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Spirv);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Spirv::Config); TINT_INSTANTIATE_TYPEINFO(tint::transform::Spirv::Config);
namespace tint { namespace tint {
@ -70,6 +71,7 @@ Output Spirv::Run(const Program* in, const DataMap& data) {
} }
ctx2.Clone(); ctx2.Clone();
out2.SetTransformApplied(this);
return Output{Program(std::move(out2))}; return Output{Program(std::move(out2))};
} }

View File

@ -29,7 +29,7 @@ namespace transform {
/// Spirv is a transform used to sanitize a Program for use with the Spirv /// Spirv is a transform used to sanitize a Program for use with the Spirv
/// writer. Passing a non-sanitized Program to the Spirv writer will result in /// writer. Passing a non-sanitized Program to the Spirv writer will result in
/// undefined behavior. /// undefined behavior.
class Spirv : public Transform { class Spirv : public Castable<Spirv, Transform> {
public: public:
/// Configuration options for the transform. /// Configuration options for the transform.
struct Config : public Castable<Config, Data> { struct Config : public Castable<Config, Data> {

View File

@ -32,32 +32,6 @@ namespace transform {
template <typename BASE> template <typename BASE>
class TransformTestBase : public BASE { class TransformTestBase : public BASE {
public: public:
/// Transforms and returns the WGSL source `in`, transformed using
/// `transforms`.
/// @param in the input WGSL source
/// @param transforms the list of transforms to apply
/// @param data the optional DataMap to pass to Transform::Run()
/// @return the transformed output
Output Run(std::string in,
std::vector<std::unique_ptr<transform::Transform>> transforms,
const DataMap& data = {}) {
auto file = std::make_unique<Source::File>("test", in);
auto program = reader::wgsl::Parse(file.get());
// Keep this pointer alive after Transform() returns
files_.emplace_back(std::move(file));
if (!program.IsValid()) {
return Output(std::move(program));
}
Manager manager;
for (auto& transform : transforms) {
manager.append(std::move(transform));
}
return manager.Run(&program, data);
}
/// Transforms and returns the WGSL source `in`, transformed using /// Transforms and returns the WGSL source `in`, transformed using
/// `transform`. /// `transform`.
/// @param transform the transform to apply /// @param transform the transform to apply
@ -77,9 +51,24 @@ class TransformTestBase : public BASE {
/// @param in the input WGSL source /// @param in the input WGSL source
/// @param data the optional DataMap to pass to Transform::Run() /// @param data the optional DataMap to pass to Transform::Run()
/// @return the transformed output /// @return the transformed output
template <typename TRANSFORM> template <typename... TRANSFORMS>
Output Run(std::string in, const DataMap& data = {}) { Output Run(std::string in, const DataMap& data = {}) {
return Run(std::move(in), std::make_unique<TRANSFORM>(), data); auto file = std::make_unique<Source::File>("test", in);
auto program = reader::wgsl::Parse(file.get());
// Keep this pointer alive after Transform() returns
files_.emplace_back(std::move(file));
if (!program.IsValid()) {
return Output(std::move(program));
}
Manager manager;
for (auto* transform_ptr :
std::initializer_list<Transform*>{new TRANSFORMS()...}) {
manager.append(std::unique_ptr<Transform>(transform_ptr));
}
return manager.Run(&program, data);
} }
/// @param output the output of the transform /// @param output the output of the transform

View File

@ -15,11 +15,13 @@
#include "src/transform/transform.h" #include "src/transform/transform.h"
#include <algorithm> #include <algorithm>
#include <string>
#include "src/program_builder.h" #include "src/program_builder.h"
#include "src/sem/atomic_type.h" #include "src/sem/atomic_type.h"
#include "src/sem/reference_type.h" #include "src/sem/reference_type.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::Transform);
TINT_INSTANTIATE_TYPEINFO(tint::transform::Data); TINT_INSTANTIATE_TYPEINFO(tint::transform::Data);
namespace tint { namespace tint {
@ -40,6 +42,35 @@ Output::Output(Program&& p) : program(std::move(p)) {}
Transform::Transform() = default; Transform::Transform() = default;
Transform::~Transform() = default; Transform::~Transform() = default;
Output Transform::Run(const Program* program, const DataMap& data /* = {} */) {
ProgramBuilder builder;
CloneContext ctx(&builder, program);
Output output;
Run(ctx, data, output.data);
builder.SetTransformApplied(this);
output.program = Program(std::move(builder));
return output;
}
void Transform::Run(CloneContext& ctx, const DataMap&, DataMap&) {
TINT_UNIMPLEMENTED(Transform, ctx.dst->Diagnostics())
<< "Transform::Run() unimplemented for " << TypeInfo().name;
}
bool Transform::Requires(CloneContext& ctx,
std::initializer_list<const ::tint::TypeInfo*> deps) {
for (auto* dep : deps) {
if (!ctx.src->HasTransformApplied(dep)) {
ctx.dst->Diagnostics().add_error(
diag::System::Transform, std::string(TypeInfo().name) +
" depends on " + std::string(dep->name) +
" but the dependency was not run");
return false;
}
}
return true;
}
ast::Function* Transform::CloneWithStatementsAtStart( ast::Function* Transform::CloneWithStatementsAtStart(
CloneContext* ctx, CloneContext* ctx,
ast::Function* in, ast::Function* in,

View File

@ -19,6 +19,7 @@
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include "src/castable.h"
#include "src/program.h" #include "src/program.h"
namespace tint { namespace tint {
@ -145,20 +146,45 @@ class Output {
}; };
/// Interface for Program transforms /// Interface for Program transforms
class Transform { class Transform : public Castable<Transform> {
public: public:
/// Constructor /// Constructor
Transform(); Transform();
/// Destructor /// Destructor
virtual ~Transform(); ~Transform() override;
/// Runs the transform on `program`, returning the transformation result. /// Runs the transform on `program`, returning the transformation result.
/// @param program the source program to transform /// @param program the source program to transform
/// @param data optional extra transform-specific input data /// @param data optional extra transform-specific input data
/// @returns the transformation result /// @returns the transformation result
virtual Output Run(const Program* program, const DataMap& data = {}) = 0; virtual Output Run(const Program* program, const DataMap& data = {});
protected: protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
virtual void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs);
/// Requires appends an error diagnostic to `ctx.dst` if the template type
/// transforms were not already run on `ctx.src`.
/// @param ctx the CloneContext
/// @returns true if all dependency transforms have been run
template <typename... TRANSFORMS>
bool Requires(CloneContext& ctx) {
return Requires(ctx, {&::tint::TypeInfo::Of<TRANSFORMS>()...});
}
/// Requires appends an error diagnostic to `ctx.dst` if the list of
/// Transforms were not already run on `ctx.src`.
/// @param ctx the CloneContext
/// @param deps the list of Transform TypeInfos
/// @returns true if all dependency transforms have been run
bool Requires(CloneContext& ctx,
std::initializer_list<const ::tint::TypeInfo*> deps);
/// Clones the function `in` adding `statements` to the beginning of the /// Clones the function `in` adding `statements` to the beginning of the
/// cloned function body. /// cloned function body.
/// @param ctx the clone context /// @param ctx the clone context

View File

@ -24,6 +24,7 @@
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/utils/get_or_create.h" #include "src/utils/get_or_create.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling);
TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config); TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config);
namespace tint { namespace tint {
@ -456,21 +457,20 @@ struct State {
VertexPulling::VertexPulling() = default; VertexPulling::VertexPulling() = default;
VertexPulling::~VertexPulling() = default; VertexPulling::~VertexPulling() = default;
Output VertexPulling::Run(const Program* in, const DataMap& data) { void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) {
ProgramBuilder out;
auto cfg = cfg_; auto cfg = cfg_;
if (auto* cfg_data = data.Get<Config>()) { if (auto* cfg_data = inputs.Get<Config>()) {
cfg = *cfg_data; cfg = *cfg_data;
} }
// Find entry point // Find entry point
auto* func = in->AST().Functions().Find( auto* func = ctx.src->AST().Functions().Find(
in->Symbols().Get(cfg.entry_point_name), ast::PipelineStage::kVertex); ctx.src->Symbols().Get(cfg.entry_point_name),
ast::PipelineStage::kVertex);
if (func == nullptr) { if (func == nullptr) {
out.Diagnostics().add_error(diag::System::Transform, ctx.dst->Diagnostics().add_error(diag::System::Transform,
"Vertex stage entry point not found"); "Vertex stage entry point not found");
return Output(Program(std::move(out))); return;
} }
// TODO(idanr): Need to check shader locations in descriptor cover all // TODO(idanr): Need to check shader locations in descriptor cover all
@ -479,15 +479,11 @@ Output VertexPulling::Run(const Program* in, const DataMap& data) {
// TODO(idanr): Make sure we covered all error cases, to guarantee the // TODO(idanr): Make sure we covered all error cases, to guarantee the
// following stages will pass // following stages will pass
CloneContext ctx(&out, in);
State state{ctx, cfg}; State state{ctx, cfg};
state.AddVertexStorageBuffers(); state.AddVertexStorageBuffers();
state.Process(func); state.Process(func);
ctx.Clone(); ctx.Clone();
return Output(Program(std::move(out)));
} }
VertexPulling::Config::Config() = default; VertexPulling::Config::Config() = default;

View File

@ -130,7 +130,7 @@ using VertexStateDescriptor = std::vector<VertexBufferLayoutDescriptor>;
/// code, but these are types that the data may arrive as. We need to convert /// code, but these are types that the data may arrive as. We need to convert
/// these smaller types into the base types such as `f32` and `u32` for the /// these smaller types into the base types such as `f32` and `u32` for the
/// shader to use. /// shader to use.
class VertexPulling : public Transform { class VertexPulling : public Castable<VertexPulling, Transform> {
public: public:
/// Configuration options for the transform /// Configuration options for the transform
struct Config : public Castable<Config, Data> { struct Config : public Castable<Config, Data> {
@ -164,11 +164,14 @@ class VertexPulling : public Transform {
/// Destructor /// Destructor
~VertexPulling() override; ~VertexPulling() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
private: private:
Config cfg_; Config cfg_;

View File

@ -21,6 +21,8 @@
#include "src/sem/expression.h" #include "src/sem/expression.h"
#include "src/utils/get_or_create.h" #include "src/utils/get_or_create.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::WrapArraysInStructs);
namespace tint { namespace tint {
namespace transform { namespace transform {
@ -33,10 +35,7 @@ WrapArraysInStructs::WrapArraysInStructs() = default;
WrapArraysInStructs::~WrapArraysInStructs() = default; WrapArraysInStructs::~WrapArraysInStructs() = default;
Output WrapArraysInStructs::Run(const Program* in, const DataMap&) { void WrapArraysInStructs::Run(CloneContext& ctx, const DataMap&, DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
std::unordered_map<const sem::Array*, WrappedArrayInfo> wrapped_arrays; std::unordered_map<const sem::Array*, WrappedArrayInfo> wrapped_arrays;
@ -60,8 +59,8 @@ Output WrapArraysInStructs::Run(const Program* in, const DataMap&) {
// Fix up array accessors so `a[1]` becomes `a.arr[1]` // Fix up array accessors so `a[1]` becomes `a.arr[1]`
ctx.ReplaceAll([&](ast::ArrayAccessorExpression* accessor) ctx.ReplaceAll([&](ast::ArrayAccessorExpression* accessor)
-> ast::ArrayAccessorExpression* { -> ast::ArrayAccessorExpression* {
if (auto* array = if (auto* array = ::tint::As<sem::Array>(
As<sem::Array>(sem.Get(accessor->array())->Type()->UnwrapRef())) { sem.Get(accessor->array())->Type()->UnwrapRef())) {
if (wrapper(array)) { if (wrapper(array)) {
// Array is wrapped in a structure. Emit a member accessor to get // Array is wrapped in a structure. Emit a member accessor to get
// to the actual array. // to the actual array.
@ -76,7 +75,8 @@ Output WrapArraysInStructs::Run(const Program* in, const DataMap&) {
// Fix up array constructors so `A(1,2)` becomes `tint_array_wrapper(A(1,2))` // Fix up array constructors so `A(1,2)` becomes `tint_array_wrapper(A(1,2))`
ctx.ReplaceAll([&](ast::TypeConstructorExpression* ctor) -> ast::Expression* { ctx.ReplaceAll([&](ast::TypeConstructorExpression* ctor) -> ast::Expression* {
if (auto* array = As<sem::Array>(sem.Get(ctor)->Type()->UnwrapRef())) { if (auto* array =
::tint::As<sem::Array>(sem.Get(ctor)->Type()->UnwrapRef())) {
if (auto w = wrapper(array)) { if (auto w = wrapper(array)) {
// Wrap the array type constructor with another constructor for // Wrap the array type constructor with another constructor for
// the wrapper // the wrapper
@ -91,8 +91,6 @@ Output WrapArraysInStructs::Run(const Program* in, const DataMap&) {
}); });
ctx.Clone(); ctx.Clone();
return Output(Program(std::move(out)));
} }
WrapArraysInStructs::WrappedArrayInfo WrapArraysInStructs::WrapArray( WrapArraysInStructs::WrappedArrayInfo WrapArraysInStructs::WrapArray(

View File

@ -36,7 +36,7 @@ namespace transform {
/// wrapping. /// wrapping.
/// This transform helps with backends that cannot directly return arrays or use /// This transform helps with backends that cannot directly return arrays or use
/// them as parameters. /// them as parameters.
class WrapArraysInStructs : public Transform { class WrapArraysInStructs : public Castable<WrapArraysInStructs, Transform> {
public: public:
/// Constructor /// Constructor
WrapArraysInStructs(); WrapArraysInStructs();
@ -44,11 +44,14 @@ class WrapArraysInStructs : public Transform {
/// Destructor /// Destructor
~WrapArraysInStructs() override; ~WrapArraysInStructs() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
private: private:
struct WrappedArrayInfo { struct WrappedArrayInfo {

View File

@ -23,6 +23,8 @@
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/utils/get_or_create.h" #include "src/utils/get_or_create.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::ZeroInitWorkgroupMemory);
namespace tint { namespace tint {
namespace transform { namespace transform {
@ -111,13 +113,10 @@ ZeroInitWorkgroupMemory::ZeroInitWorkgroupMemory() = default;
ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default; ZeroInitWorkgroupMemory::~ZeroInitWorkgroupMemory() = default;
Output ZeroInitWorkgroupMemory::Run(const Program* in, const DataMap&) { void ZeroInitWorkgroupMemory::Run(CloneContext& ctx, const DataMap&, DataMap&) {
ProgramBuilder out;
CloneContext ctx(&out, in);
auto& sem = ctx.src->Sem(); auto& sem = ctx.src->Sem();
for (auto* ast_func : in->AST().Functions()) { for (auto* ast_func : ctx.src->AST().Functions()) {
if (!ast_func->IsEntryPoint()) { if (!ast_func->IsEntryPoint()) {
continue; continue;
} }
@ -192,8 +191,6 @@ Output ZeroInitWorkgroupMemory::Run(const Program* in, const DataMap&) {
} }
ctx.Clone(); ctx.Clone();
return Output(Program(std::move(out)));
} }
} // namespace transform } // namespace transform

View File

@ -23,7 +23,8 @@ namespace transform {
/// ZeroInitWorkgroupMemory is a transform that injects code at the top of entry /// ZeroInitWorkgroupMemory is a transform that injects code at the top of entry
/// points to zero-initialize workgroup memory used by that entry point (and all /// points to zero-initialize workgroup memory used by that entry point (and all
/// transitive functions called by that entry point) /// transitive functions called by that entry point)
class ZeroInitWorkgroupMemory : public Transform { class ZeroInitWorkgroupMemory
: public Castable<ZeroInitWorkgroupMemory, Transform> {
public: public:
/// Constructor /// Constructor
ZeroInitWorkgroupMemory(); ZeroInitWorkgroupMemory();
@ -31,11 +32,14 @@ class ZeroInitWorkgroupMemory : public Transform {
/// Destructor /// Destructor
~ZeroInitWorkgroupMemory() override; ~ZeroInitWorkgroupMemory() override;
/// Runs the transform on `program`, returning the transformation result. protected:
/// @param program the source program to transform /// Runs the transform using the CloneContext built for transforming a
/// @param data optional extra transform-specific input data /// program. Run() is responsible for calling Clone() on the CloneContext.
/// @returns the transformation result /// @param ctx the CloneContext primed with the input program and
Output Run(const Program* program, const DataMap& data = {}) override; /// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) override;
private: private:
struct State; struct State;

View File

@ -36,6 +36,7 @@
#include "src/sem/struct.h" #include "src/sem/struct.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/transform/calculate_array_length.h" #include "src/transform/calculate_array_length.h"
#include "src/transform/hlsl.h"
#include "src/utils/scoped_assignment.h" #include "src/utils/scoped_assignment.h"
#include "src/writer/append_vector.h" #include "src/writer/append_vector.h"
#include "src/writer/float_to_string.h" #include "src/writer/float_to_string.h"
@ -114,6 +115,14 @@ GeneratorImpl::GeneratorImpl(const Program* program)
GeneratorImpl::~GeneratorImpl() = default; GeneratorImpl::~GeneratorImpl() = default;
bool GeneratorImpl::Generate(std::ostream& out) { bool GeneratorImpl::Generate(std::ostream& out) {
if (!builder_.HasTransformApplied<transform::Hlsl>()) {
diagnostics_.add_error(
diag::System::Writer,
"HLSL writer requires the transform::Hlsl sanitizer to have been "
"applied to the input program");
return false;
}
std::stringstream pending; std::stringstream pending;
const TypeInfo* last_kind = nullptr; const TypeInfo* last_kind = nullptr;

View File

@ -21,6 +21,16 @@ namespace {
using HlslGeneratorImplTest = TestHelper; using HlslGeneratorImplTest = TestHelper;
TEST_F(HlslGeneratorImplTest, ErrorIfSanitizerNotRun) {
auto program = std::make_unique<Program>(std::move(*this));
GeneratorImpl gen(program.get());
EXPECT_FALSE(gen.Generate(out));
EXPECT_EQ(
gen.error(),
"error: HLSL writer requires the transform::Hlsl sanitizer to have been "
"applied to the input program");
}
TEST_F(HlslGeneratorImplTest, Generate) { TEST_F(HlslGeneratorImplTest, Generate) {
Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{}, Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{},
ast::DecorationList{}); ast::DecorationList{});

View File

@ -44,15 +44,17 @@ class TestHelperBase : public BODY, public ProgramBuilder {
if (gen_) { if (gen_) {
return *gen_; return *gen_;
} }
diag::Formatter formatter; // Fake that the HLSL sanitizer has been applied, so that we can unit test
// the writer without it erroring.
SetTransformApplied<transform::Hlsl>();
[&]() { [&]() {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n" ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< formatter.format(Diagnostics()); << diag::Formatter().format(Diagnostics());
}(); }();
program = std::make_unique<Program>(std::move(*this)); program = std::make_unique<Program>(std::move(*this));
[&]() { [&]() {
ASSERT_TRUE(program->IsValid()) ASSERT_TRUE(program->IsValid())
<< formatter.format(program->Diagnostics()); << diag::Formatter().format(program->Diagnostics());
}(); }();
gen_ = std::make_unique<GeneratorImpl>(program.get()); gen_ = std::make_unique<GeneratorImpl>(program.get());
return *gen_; return *gen_;

View File

@ -50,6 +50,7 @@
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/sem/vector_type.h" #include "src/sem/vector_type.h"
#include "src/sem/void_type.h" #include "src/sem/void_type.h"
#include "src/transform/msl.h"
#include "src/utils/scoped_assignment.h" #include "src/utils/scoped_assignment.h"
#include "src/writer/float_to_string.h" #include "src/writer/float_to_string.h"
@ -75,6 +76,14 @@ GeneratorImpl::GeneratorImpl(const Program* program)
GeneratorImpl::~GeneratorImpl() = default; GeneratorImpl::~GeneratorImpl() = default;
bool GeneratorImpl::Generate() { bool GeneratorImpl::Generate() {
if (!program_->HasTransformApplied<transform::Msl>()) {
diagnostics_.add_error(
diag::System::Writer,
"MSL writer requires the transform::Msl sanitizer to have been "
"applied to the input program");
return false;
}
out_ << "#include <metal_stdlib>" << std::endl << std::endl; out_ << "#include <metal_stdlib>" << std::endl << std::endl;
out_ << "using namespace metal;" << std::endl; out_ << "using namespace metal;" << std::endl;

View File

@ -22,6 +22,16 @@ namespace {
using MslGeneratorImplTest = TestHelper; using MslGeneratorImplTest = TestHelper;
TEST_F(MslGeneratorImplTest, ErrorIfSanitizerNotRun) {
auto program = std::make_unique<Program>(std::move(*this));
GeneratorImpl gen(program.get());
EXPECT_FALSE(gen.Generate());
EXPECT_EQ(
gen.error(),
"error: MSL writer requires the transform::Msl sanitizer to have been "
"applied to the input program");
}
TEST_F(MslGeneratorImplTest, Generate) { TEST_F(MslGeneratorImplTest, Generate) {
Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{}, Func("my_func", ast::VariableList{}, ty.void_(), ast::StatementList{},
ast::DecorationList{ ast::DecorationList{

View File

@ -43,6 +43,9 @@ class TestHelperBase : public BASE, public ProgramBuilder {
if (gen_) { if (gen_) {
return *gen_; return *gen_;
} }
// Fake that the MSL sanitizer has been applied, so that we can unit test
// the writer without it erroring.
SetTransformApplied<transform::Msl>();
[&]() { [&]() {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n" ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< diag::Formatter().format(Diagnostics()); << diag::Formatter().format(Diagnostics());

View File

@ -35,6 +35,7 @@
#include "src/sem/struct.h" #include "src/sem/struct.h"
#include "src/sem/variable.h" #include "src/sem/variable.h"
#include "src/sem/vector_type.h" #include "src/sem/vector_type.h"
#include "src/transform/spirv.h"
#include "src/utils/get_or_create.h" #include "src/utils/get_or_create.h"
#include "src/writer/append_vector.h" #include "src/writer/append_vector.h"
@ -258,6 +259,13 @@ Builder::Builder(const Program* program)
Builder::~Builder() = default; Builder::~Builder() = default;
bool Builder::Build() { bool Builder::Build() {
if (!builder_.HasTransformApplied<transform::Spirv>()) {
error_ =
"SPIR-V writer requires the transform::Spirv sanitizer to have been "
"applied to the input program";
return false;
}
push_capability(SpvCapabilityShader); push_capability(SpvCapabilityShader);
push_memory_model(spv::Op::OpMemoryModel, push_memory_model(spv::Op::OpMemoryModel,

View File

@ -22,6 +22,16 @@ namespace {
using BuilderTest = TestHelper; using BuilderTest = TestHelper;
TEST_F(BuilderTest, ErrorIfSanitizerNotRun) {
auto program = std::make_unique<Program>(std::move(*this));
spirv::Builder b(program.get());
EXPECT_FALSE(b.Build());
EXPECT_EQ(
b.error(),
"SPIR-V writer requires the transform::Spirv sanitizer to have been "
"applied to the input program");
}
TEST_F(BuilderTest, InsertsPreamble) { TEST_F(BuilderTest, InsertsPreamble) {
spirv::Builder& b = Build(); spirv::Builder& b = Build();

View File

@ -43,6 +43,9 @@ class TestHelperBase : public ProgramBuilder, public BASE {
if (spirv_builder) { if (spirv_builder) {
return *spirv_builder; return *spirv_builder;
} }
// Fake that the SPIR-V sanitizer has been applied, so that we can unit test
// the writer without it erroring.
SetTransformApplied<transform::Spirv>();
[&]() { [&]() {
ASSERT_TRUE(IsValid()) << "Builder program is not valid\n" ASSERT_TRUE(IsValid()) << "Builder program is not valid\n"
<< diag::Formatter().format(Diagnostics()); << diag::Formatter().format(Diagnostics());