transform: API cleanup

Transforms are meant to be reusable. Those that hold state cannot be used concurrently.
State leakage between runs is dangerous.
To fix this:
* Add transform::Data - A new base class for extra information emitted by transforms.
* Add transform::DataMap - A container of Data, keyed by type.
* Add a transform::DataMap field to Transform::Output.
* Have FirstIndexOffset emit a FirstIndexOffset::Data.
* Deprecate the getters on the transform.

Mutability of the transform config is also dangerous as setters can be called while a transform is actively running on another thread.
To fix:
* Expose a VertexPulling::Config structure and add a constructor that accepts this.
* Deprecate the setters on VertexPulling.

Also deprecate Transform::Output::diagnostics.
Put all the transform diagnostics into the returned Program. Reduces error handling of the client.

Change-Id: Ibd228dc2fbf004ede4720e2d6019c024bc5934d1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/42264
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-02-24 15:55:24 +00:00 committed by Commit Bot service account
parent 94b36c3e86
commit eb496d0a4d
20 changed files with 437 additions and 194 deletions

View File

@ -144,7 +144,7 @@ int CommonFuzzer::Run(const uint8_t* data, size_t size) {
if (transform_manager_) {
auto out = transform_manager_->Run(&program);
if (out.diagnostics.contains_errors()) {
if (!out.program.IsValid()) {
return 0;
}

View File

@ -28,8 +28,8 @@
namespace {
[[noreturn]]
void TintInternalCompilerErrorReporter(const tint::diag::List& diagnostics) {
[[noreturn]] void TintInternalCompilerErrorReporter(
const tint::diag::List& diagnostics) {
auto printer = tint::diag::Printer::create(stderr, true);
tint::diag::Formatter{}.format(diagnostics, printer.get());
exit(1);
@ -569,8 +569,8 @@ int main(int argc, const char** argv) {
}
auto out = transform_manager.Run(program.get());
if (out.diagnostics.contains_errors()) {
diag_formatter.format(out.diagnostics, diag_printer.get());
if (!out.program.IsValid()) {
diag_formatter.format(out.program.Diagnostics(), diag_printer.get());
return 1;
}

View File

@ -58,19 +58,19 @@ BoundArrayAccessors::~BoundArrayAccessors() = default;
Transform::Output BoundArrayAccessors::Run(const Program* in) {
ProgramBuilder out;
diag::List diagnostics;
CloneContext(&out, in)
.ReplaceAll([&](CloneContext* ctx, ast::ArrayAccessorExpression* expr) {
return Transform(expr, ctx, &diagnostics);
return Transform(expr, ctx);
})
.Clone();
return Output(Program(std::move(out)), std::move(diagnostics));
return Output(Program(std::move(out)));
}
ast::ArrayAccessorExpression* BoundArrayAccessors::Transform(
ast::ArrayAccessorExpression* expr,
CloneContext* ctx,
diag::List* diags) {
CloneContext* ctx) {
auto& diags = ctx->dst->Diagnostics();
auto* ret_type = ctx->src->Sem().Get(expr->array())->Type()->UnwrapAll();
if (!ret_type->Is<type::Array>() && !ret_type->Is<type::Matrix>() &&
!ret_type->Is<type::Vector>()) {
@ -103,7 +103,7 @@ ast::ArrayAccessorExpression* BoundArrayAccessors::Transform(
auto* limit = b.Sub(arr_len, b.Expr(1u));
new_idx = b.Call("min", b.Construct<u32>(ctx->Clone(old_idx)), limit);
} else {
diags->add_error("invalid 0 size", expr->source());
diags.add_error("invalid 0 size", expr->source());
return nullptr;
}
} else if (auto* c = old_idx->As<ast::ScalarConstructorExpression>()) {
@ -115,8 +115,8 @@ ast::ArrayAccessorExpression* BoundArrayAccessors::Transform(
} else if (auto* uint = lit->As<ast::UintLiteral>()) {
new_idx = b.Expr(std::min(uint->value(), size - 1));
} else {
diags->add_error("unknown scalar constructor type for accessor",
expr->source());
diags.add_error("unknown scalar constructor type for accessor",
expr->source());
return nullptr;
}
} else {

View File

@ -45,8 +45,7 @@ class BoundArrayAccessors : public Transform {
private:
ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr,
CloneContext* ctx,
diag::List* diags);
CloneContext* ctx);
};
} // namespace transform

View File

@ -45,7 +45,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Array_Idx_Nested_Scalar) {
@ -75,7 +75,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Array_Idx_Scalar) {
@ -97,7 +97,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Array_Idx_Expr) {
@ -123,7 +123,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Array_Idx_Negative) {
@ -145,7 +145,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Array_Idx_OutOfBounds) {
@ -167,7 +167,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Vector_Idx_Scalar) {
@ -189,7 +189,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Vector_Idx_Expr) {
@ -215,7 +215,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Vector_Idx_Negative) {
@ -237,7 +237,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Vector_Idx_OutOfBounds) {
@ -259,7 +259,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Scalar) {
@ -281,7 +281,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Column) {
@ -307,7 +307,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Row) {
@ -333,7 +333,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Negative_Column) {
@ -355,7 +355,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Negative_Row) {
@ -377,7 +377,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Matrix_Idx_OutOfBounds_Column) {
@ -399,7 +399,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(BoundArrayAccessorsTest, Matrix_Idx_OutOfBounds_Row) {
@ -421,7 +421,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
// TODO(dsinclair): Implement when constant_id exists
@ -488,7 +488,7 @@ fn f() -> void {
auto got = Transform<BoundArrayAccessors>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
// TODO(dsinclair): Clamp atomics when available.

View File

@ -54,7 +54,7 @@ fn non_entry_b() -> void {
auto got = Transform<EmitVertexPointSize>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
@ -87,7 +87,7 @@ fn non_entry_b() -> void {
auto got = Transform<EmitVertexPointSize>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(EmitVertexPointSizeTest, NonVertexStage) {
@ -113,7 +113,7 @@ fn compute_entry() -> void {
auto got = Transform<EmitVertexPointSize>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
} // namespace

View File

@ -15,6 +15,7 @@
#include "src/transform/first_index_offset.h"
#include <cassert>
#include <memory>
#include <utility>
#include "src/ast/array_accessor_expression.h"
@ -54,6 +55,8 @@
#include "src/type/u32_type.h"
#include "src/type_determiner.h"
TINT_INSTANTIATE_CLASS_ID(tint::transform::FirstIndexOffset::Data);
namespace tint {
namespace transform {
namespace {
@ -80,20 +83,34 @@ ast::Variable* clone_variable_with_new_name(CloneContext* ctx,
} // namespace
FirstIndexOffset::Data::Data(bool has_vtx_index,
bool has_inst_index,
uint32_t first_vtx_offset,
uint32_t first_idx_offset)
: has_vertex_index(has_vtx_index),
has_instance_index(has_inst_index),
first_vertex_offset(first_vtx_offset),
first_index_offset(first_idx_offset) {}
FirstIndexOffset::Data::Data(const Data&) = default;
FirstIndexOffset::Data::~Data() = default;
FirstIndexOffset::FirstIndexOffset(uint32_t binding, uint32_t group)
: binding_(binding), group_(group) {}
FirstIndexOffset::~FirstIndexOffset() = default;
Transform::Output FirstIndexOffset::Run(const Program* in) {
ProgramBuilder out;
// First do a quick check to see if the transform has already been applied.
for (ast::Variable* var : in->AST().GlobalVariables()) {
if (auto* dec_var = var->As<ast::Variable>()) {
if (dec_var->symbol() == in->Symbols().Get(kBufferName)) {
Output out;
out.diagnostics.add_error(
out.Diagnostics().add_error(
"First index offset transform has already been applied.");
return out;
return Output(Program(std::move(out)));
}
}
}
@ -101,7 +118,7 @@ Transform::Output FirstIndexOffset::Run(const Program* in) {
Symbol vertex_index_sym;
Symbol instance_index_sym;
// Lazilly construct the UniformBuffer on first call to
// Lazily construct the UniformBuffer on first call to
// maybe_create_buffer_var()
ast::Variable* buffer_var = nullptr;
auto maybe_create_buffer_var = [&](ProgramBuilder* dst) {
@ -114,7 +131,6 @@ Transform::Output FirstIndexOffset::Run(const Program* in) {
// add a CreateFirstIndexOffset() statement to each function that uses one of
// these builtins.
ProgramBuilder out;
CloneContext(&out, in)
.ReplaceAll([&](CloneContext* ctx, ast::Variable* var) -> ast::Variable* {
for (ast::VariableDecoration* dec : var->decorations()) {
@ -163,7 +179,10 @@ Transform::Output FirstIndexOffset::Run(const Program* in) {
})
.Clone();
return Output(Program(std::move(out)));
return Output(
Program(std::move(out)),
std::make_unique<Data>(has_vertex_index_, has_instance_index_,
vertex_index_offset_, instance_index_offset_));
}
bool FirstIndexOffset::HasVertexIndex() {

View File

@ -62,6 +62,34 @@ namespace transform {
///
class FirstIndexOffset : public Transform {
public:
/// Data holds information about shader usage and constant buffer offsets.
struct Data : public Castable<Data, transform::Data> {
/// Constructor
/// @param has_vtx_index True if the shader uses vertex_index
/// @param has_inst_index True if the shader uses instance_index
/// @param first_vtx_offset Offset of first vertex into constant buffer
/// @param first_idx_offset Offset of first instance into constant buffer
Data(bool has_vtx_index,
bool has_inst_index,
uint32_t first_vtx_offset,
uint32_t first_idx_offset);
/// Copy constructor
Data(const Data&);
/// Destructor
~Data() override;
/// True if the shader uses vertex_index
bool const has_vertex_index;
/// True if the shader uses instance_index
bool const has_instance_index;
/// Offset of first vertex into constant buffer
uint32_t const first_vertex_offset;
/// Offset of first instance into constant buffer
uint32_t const first_index_offset;
};
/// Constructor
/// @param binding the binding() for firstVertex/Instance uniform
/// @param group the group() for firstVertex/Instance uniform
@ -73,15 +101,19 @@ class FirstIndexOffset : public Transform {
/// @returns the transformation result
Output Run(const Program* program) override;
/// [DEPRECATED] - Use Data
/// @returns whether shader uses vertex_index
bool HasVertexIndex();
/// [DEPRECATED] - Use Data
/// @returns whether shader uses instance_index
bool HasInstanceIndex();
/// [DEPRECATED] - Use Data
/// @returns offset of firstVertex into constant buffer
uint32_t GetFirstVertexOffset();
/// [DEPRECATED] - Use Data
/// @returns offset of firstInstance into constant buffer
uint32_t GetFirstInstanceOffset();

View File

@ -40,8 +40,8 @@ fn entry() -> void {
}
)";
auto* expect = R"(manager().Run() errored:
error: First index offset transform has already been applied.)";
auto* expect =
"error: First index offset transform has already been applied.";
std::vector<std::unique_ptr<transform::Transform>> transforms;
transforms.emplace_back(std::make_unique<FirstIndexOffset>(0, 0));
@ -49,7 +49,15 @@ error: First index offset transform has already been applied.)";
auto got = Transform(src, std::move(transforms));
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, true);
EXPECT_EQ(data->has_instance_index, false);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_index_offset, 0u);
}
TEST_F(FirstIndexOffsetTest, EmptyModule) {
@ -58,7 +66,15 @@ TEST_F(FirstIndexOffsetTest, EmptyModule) {
auto got = Transform<FirstIndexOffset>(src, 0, 0);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, false);
EXPECT_EQ(data->has_instance_index, false);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_index_offset, 0u);
}
TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) {
@ -99,7 +115,15 @@ fn entry() -> void {
auto got = Transform<FirstIndexOffset>(src, 1, 2);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, true);
EXPECT_EQ(data->has_instance_index, false);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_index_offset, 0u);
}
TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) {
@ -140,7 +164,15 @@ fn entry() -> void {
auto got = Transform<FirstIndexOffset>(src, 1, 7);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, false);
EXPECT_EQ(data->has_instance_index, true);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_index_offset, 0u);
}
TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) {
@ -187,7 +219,15 @@ fn entry() -> void {
auto got = Transform<FirstIndexOffset>(src, 1, 2);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, true);
EXPECT_EQ(data->has_instance_index, true);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_index_offset, 4u);
}
TEST_F(FirstIndexOffsetTest, NestedCalls) {
@ -236,7 +276,15 @@ fn entry() -> void {
auto got = Transform<FirstIndexOffset>(src, 1, 2);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
auto* data = got.data.Get<FirstIndexOffset::Data>();
ASSERT_NE(data, nullptr);
EXPECT_EQ(data->has_vertex_index, true);
EXPECT_EQ(data->has_instance_index, false);
EXPECT_EQ(data->first_vertex_offset, 0u);
EXPECT_EQ(data->first_index_offset, 0u);
}
} // namespace

View File

@ -52,7 +52,7 @@ fn main() -> void {
auto got = Transform<Hlsl>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(HlslTest, PromoteArrayInitializerToConstVar_ArrayInArray) {
@ -75,7 +75,7 @@ fn main() -> void {
auto got = Transform<Hlsl>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(HlslTest, PromoteArrayInitializerToConstVar_NoChangeOnArrayVarDecl) {
@ -92,7 +92,7 @@ const module_arr : array<f32, 4> = array<f32, 4>(0.0, 1.0, 2.0, 3.0);
auto got = Transform<Hlsl>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(HlslTest, PromoteArrayInitializerToConstVar_Bug406) {
@ -146,7 +146,7 @@ fn main() -> void {
auto got = Transform<Hlsl>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
} // namespace

View File

@ -29,8 +29,8 @@ Transform::Output Manager::Run(const Program* program) {
for (auto& transform : transforms_) {
auto res = transform->Run(program);
out.program = std::move(res.program);
out.diagnostics.add(std::move(res.diagnostics));
if (out.diagnostics.contains_errors()) {
out.data.Add(std::move(res.data));
if (!out.program.IsValid()) {
return out;
}
program = &out.program;

View File

@ -55,7 +55,7 @@ fn main() -> void {
auto got = Transform<Spirv>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(SpirvTest, HandleSampleMaskBuiltins_FunctionArg) {
@ -99,7 +99,7 @@ fn main() -> void {
auto got = Transform<Spirv>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
} // namespace

View File

@ -38,35 +38,59 @@ class TransformTest : public testing::Test {
/// `transforms`.
/// @param in the input WGSL source
/// @param transforms the list of transforms to apply
/// @return the transformed WGSL output
std::string Transform(
/// @return the transformed output
Transform::Output Transform(
std::string in,
std::vector<std::unique_ptr<transform::Transform>> transforms) {
diag::Formatter::Style style;
style.print_newline_at_end = false;
Source::File file("test", in);
auto program = reader::wgsl::Parse(&file);
if (!program.IsValid()) {
return diag::Formatter(style).format(program.Diagnostics());
return Transform::Output(std::move(program));
}
{
Manager manager;
for (auto& transform : transforms) {
manager.append(std::move(transform));
}
auto result = manager.Run(&program);
Manager manager;
for (auto& transform : transforms) {
manager.append(std::move(transform));
}
return manager.Run(&program);
}
if (result.diagnostics.contains_errors()) {
return "manager().Run() errored:\n" +
diag::Formatter(style).format(result.diagnostics);
}
program = std::move(result.program);
/// Transforms and returns the WGSL source `in`, transformed using
/// `transform`.
/// @param transform the transform to apply
/// @param in the input WGSL source
/// @return the transformed output
Transform::Output Transform(std::string in,
std::unique_ptr<transform::Transform> transform) {
std::vector<std::unique_ptr<transform::Transform>> transforms;
transforms.emplace_back(std::move(transform));
return Transform(std::move(in), std::move(transforms));
}
/// Transforms and returns the WGSL source `in`, transformed using
/// a transform of type `TRANSFORM`.
/// @param in the input WGSL source
/// @param args the TRANSFORM constructor arguments
/// @return the transformed output
template <typename TRANSFORM, typename... ARGS>
Transform::Output Transform(std::string in, ARGS&&... args) {
return Transform(std::move(in),
std::make_unique<TRANSFORM>(std::forward<ARGS>(args)...));
}
/// @param output the output of the transform
/// @returns the output program as a WGSL string, or an error string if the
/// program is not valid.
std::string str(const Transform::Output& output) {
diag::Formatter::Style style;
style.print_newline_at_end = false;
if (!output.program.IsValid()) {
return diag::Formatter(style).format(output.program.Diagnostics());
}
writer::wgsl::Generator generator(&program);
writer::wgsl::Generator generator(&output.program);
if (!generator.Generate()) {
return "WGSL writer failed:\n" + generator.error();
}
@ -84,29 +108,6 @@ class TransformTest : public testing::Test {
}
return "\n" + res + "\n";
}
/// Transforms and returns the WGSL source `in`, transformed using
/// `transform`.
/// @param transform the transform to apply
/// @param in the input WGSL source
/// @return the transformed WGSL output
std::string Transform(std::string in,
std::unique_ptr<transform::Transform> transform) {
std::vector<std::unique_ptr<transform::Transform>> transforms;
transforms.emplace_back(std::move(transform));
return Transform(std::move(in), std::move(transforms));
}
/// Transforms and returns the WGSL source `in`, transformed using
/// a transform of type `TRANSFORM`.
/// @param in the input WGSL source
/// @param args the TRANSFORM constructor arguments
/// @return the transformed WGSL output
template <typename TRANSFORM, typename... ARGS>
std::string Transform(std::string in, ARGS&&... args) {
return Transform(std::move(in),
std::make_unique<TRANSFORM>(std::forward<ARGS>(args)...));
}
};
} // namespace transform

View File

@ -19,16 +19,27 @@
#include "src/clone_context.h"
#include "src/program_builder.h"
TINT_INSTANTIATE_CLASS_ID(tint::transform::Data);
namespace tint {
namespace transform {
Data::Data() = default;
Data::Data(const Data&) = default;
Data::~Data() = default;
DataMap::DataMap() = default;
DataMap::DataMap(DataMap&&) = default;
DataMap::~DataMap() = default;
Transform::Output::Output() = default;
Transform::Output::Output(Program&& p) : program(std::move(p)) {}
Transform::Output::Output(Program&& p, diag::List&& d)
: program(std::move(p)), diagnostics(std::move(d)) {}
Transform::Transform() = default;
Transform::~Transform() = default;

View File

@ -17,6 +17,8 @@
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include "src/diagnostic/diagnostic.h"
@ -25,6 +27,84 @@
namespace tint {
namespace transform {
/// Data is the base class for transforms that emit extra output information
/// along with a Program.
class Data : public Castable<Data> {
public:
/// Constructor
Data();
/// Copy constructor
Data(const Data&);
/// Destructor
~Data() override;
};
/// DataMap is a map of Data unique pointers keyed by the Data's ClassID.
class DataMap {
public:
/// Constructor
DataMap();
/// Move constructor
DataMap(DataMap&&);
/// Constructor
/// @param data_unique_ptrs a variadic list of additional data unique_ptrs
/// produced by the transform
template <typename... DATA>
explicit DataMap(DATA... data_unique_ptrs) {
PutAll(std::forward<DATA>(data_unique_ptrs)...);
}
/// Destructor
~DataMap();
/// Adds the data into DataMap keyed by the ClassID of type T.
/// @param data the data to add to the DataMap
template <typename T>
void Put(std::unique_ptr<T>&& data) {
static_assert(std::is_base_of<Data, T>::value,
"T does not derive from Data");
map_[ClassID::Of<T>()] = std::move(data);
}
/// @returns a pointer to the Data placed into the DataMap with a call to
/// Put()
template <typename T>
T const* Get() const {
auto it = map_.find(ClassID::Of<T>());
if (it == map_.end()) {
return nullptr;
}
return static_cast<T*>(it->second.get());
}
/// Add moves all the data from other into this DataMap
/// @param other the DataMap to move into this DataMap
void Add(DataMap&& other) {
for (auto& it : other.map_) {
map_.emplace(it.first, std::move(it.second));
}
other.map_.clear();
}
private:
template <typename T0>
void PutAll(T0&& first) {
Put(std::forward<T0>(first));
}
template <typename T0, typename... Tn>
void PutAll(T0&& first, Tn&&... remainder) {
Put(std::forward<T0>(first));
PutAll(std::forward<Tn>(remainder)...);
}
std::unordered_map<ClassID, std::unique_ptr<Data>> map_;
};
/// Interface for Program transforms
class Transform {
public:
@ -34,7 +114,8 @@ class Transform {
virtual ~Transform();
/// The return type of Run()
struct Output {
class Output {
public:
/// Constructor
Output();
@ -43,13 +124,21 @@ class Transform {
explicit Output(Program&& program);
/// Constructor
/// @param program the program to move into this Output
/// @param diags the list of diagnostics to move into this Output
Output(Program&& program, diag::List&& diags);
/// @param program_ the program to move into this Output
/// @param data_ a variadic list of additional data unique_ptrs produced by
/// the transform
template <typename... DATA>
Output(Program&& program_, DATA... data_)
: program(std::move(program_)), data(std::forward<DATA>(data_)...) {}
/// The transformed program. May be empty on error.
Program program;
/// Extra output generated by the transforms.
DataMap data;
/// Diagnostics raised while running the Transform.
/// [DEPRECATED] Use `program.Diagnostics()`
diag::List diagnostics;
};

View File

@ -57,11 +57,15 @@ static const char kDefaultInstanceIndexName[] = "_tint_pulling_instance_index";
} // namespace
VertexPulling::VertexPulling() = default;
VertexPulling::VertexPulling(const Config& config)
: cfg(config), vertex_state_set(true) {}
VertexPulling::~VertexPulling() = default;
void VertexPulling::SetVertexState(const VertexStateDescriptor& vertex_state) {
cfg.vertex_state = vertex_state;
cfg.vertex_state_set = true;
vertex_state_set = true;
}
void VertexPulling::SetEntryPoint(std::string entry_point) {
@ -77,20 +81,20 @@ void VertexPulling::SetPullingBufferBindingSet(uint32_t number) {
}
Transform::Output VertexPulling::Run(const Program* in) {
ProgramBuilder out;
// Check SetVertexState was called
if (!cfg.vertex_state_set) {
Output out;
out.diagnostics.add_error("SetVertexState not called");
return out;
if (!vertex_state_set) {
out.Diagnostics().add_error("SetVertexState not called");
return Output(Program(std::move(out)));
}
// Find entry point
auto* func = in->AST().Functions().Find(
in->Symbols().Get(cfg.entry_point_name), ast::PipelineStage::kVertex);
if (func == nullptr) {
Output out;
out.diagnostics.add_error("Vertex stage entry point not found");
return out;
out.Diagnostics().add_error("Vertex stage entry point not found");
return Output(Program(std::move(out)));
}
// TODO(idanr): Need to check shader locations in descriptor cover all
@ -98,7 +102,6 @@ Transform::Output VertexPulling::Run(const Program* in) {
// TODO(idanr): Make sure we covered all error cases, to guarantee the
// following stages will pass
ProgramBuilder out;
CloneContext ctx(&out, in);
State state{ctx, cfg};
@ -123,7 +126,9 @@ Transform::Output VertexPulling::Run(const Program* in) {
}
VertexPulling::Config::Config() = default;
VertexPulling::Config::Config(const Config&) = default;
VertexPulling::Config::~Config() = default;
VertexPulling::State::State(CloneContext& context, const Config& c)

View File

@ -136,26 +136,54 @@ using VertexStateDescriptor = std::vector<VertexBufferLayoutDescriptor>;
/// shader to use.
class VertexPulling : public Transform {
public:
/// Configuration options for the transform
struct Config {
/// Constructor
Config();
/// Copy constructor
Config(const Config&);
/// Destructor
~Config();
/// The entry point to add assignments into
std::string entry_point_name;
/// The vertex state descriptor, containing info about attributes
VertexStateDescriptor vertex_state;
/// The "group" we will put all our vertex buffers into (as storage buffers)
/// Default to 4 as it is past the limits of user-accessible groups
uint32_t pulling_group = 4u;
};
/// Constructor
VertexPulling();
/// Constructor
/// @param config the configuration options for the transform
explicit VertexPulling(const Config& config);
/// Destructor
~VertexPulling() override;
/// Sets the vertex state descriptor, containing info about attributes
/// [DEPRECATED] Use the VertexPulling(const Config&)
/// @param vertex_state the vertex state descriptor
void SetVertexState(const VertexStateDescriptor& vertex_state);
/// Sets the entry point to add assignments into
/// [DEPRECATED] Use the VertexPulling(const Config&)
/// @param entry_point the vertex stage entry point
void SetEntryPoint(std::string entry_point);
/// Sets the "set" we will put all our vertex buffers into (as storage
/// buffers)
/// DEPRECATED
/// [DEPRECATED] Use the VertexPulling(const Config&)
/// @param number the set number we will use
void SetPullingBufferBindingSet(uint32_t number);
/// Sets the "group" we will put all our vertex buffers into (as storage
/// buffers)
/// [DEPRECATED] Use the VertexPulling(const Config&)
/// @param number the group number we will use
void SetPullingBufferBindingGroup(uint32_t number);
@ -165,19 +193,8 @@ class VertexPulling : public Transform {
Output Run(const Program* program) override;
private:
struct Config {
Config();
Config(const Config&);
~Config();
std::string entry_point_name;
VertexStateDescriptor vertex_state;
bool vertex_state_set = false;
// Default to 4 as it is past the limits of user-accessible groups
uint32_t pulling_group = 4u;
};
Config cfg;
bool vertex_state_set = false;
struct State {
State(CloneContext& ctx, const Config& c);

View File

@ -30,26 +30,25 @@ TEST_F(VertexPullingTest, Error_NoVertexState) {
fn main() -> void {}
)";
auto* expect = R"(manager().Run() errored:
error: SetVertexState not called)";
auto* expect = "error: SetVertexState not called";
auto got = Transform<VertexPulling>(src);
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, Error_NoEntryPoint) {
auto* src = "";
auto* expect = R"(manager().Run() errored:
error: Vertex stage entry point not found)";
auto* expect = "error: Vertex stage entry point not found";
auto transform = std::make_unique<VertexPulling>();
transform->SetVertexState({});
VertexPulling::Config cfg;
auto transform = std::make_unique<VertexPulling>(cfg);
auto got = Transform(src, std::move(transform));
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, Error_InvalidEntryPoint) {
@ -58,16 +57,16 @@ TEST_F(VertexPullingTest, Error_InvalidEntryPoint) {
fn main() -> void {}
)";
auto* expect = R"(manager().Run() errored:
error: Vertex stage entry point not found)";
auto* expect = "error: Vertex stage entry point not found";
auto transform = std::make_unique<VertexPulling>();
transform->SetVertexState({});
transform->SetEntryPoint("_");
VertexPulling::Config cfg;
cfg.entry_point_name = "_";
auto transform = std::make_unique<VertexPulling>(cfg);
auto got = Transform(src, std::move(transform));
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, Error_EntryPointWrongStage) {
@ -76,16 +75,16 @@ TEST_F(VertexPullingTest, Error_EntryPointWrongStage) {
fn main() -> void {}
)";
auto* expect = R"(manager().Run() errored:
error: Vertex stage entry point not found)";
auto* expect = "error: Vertex stage entry point not found";
auto transform = std::make_unique<VertexPulling>();
transform->SetVertexState({});
transform->SetEntryPoint("main");
VertexPulling::Config cfg;
cfg.entry_point_name = "main";
auto transform = std::make_unique<VertexPulling>(cfg);
auto got = Transform(src, std::move(transform));
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, BasicModule) {
@ -109,13 +108,14 @@ fn main() -> void {
}
)";
auto transform = std::make_unique<VertexPulling>();
transform->SetVertexState({});
transform->SetEntryPoint("main");
VertexPulling::Config cfg;
cfg.entry_point_name = "main";
auto transform = std::make_unique<VertexPulling>(cfg);
auto got = Transform(src, std::move(transform));
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, OneAttribute) {
@ -149,14 +149,16 @@ fn main() -> void {
}
)";
auto transform = std::make_unique<VertexPulling>();
transform->SetVertexState(
{{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}});
transform->SetEntryPoint("main");
VertexPulling::Config cfg;
cfg.vertex_state = {
{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}};
cfg.entry_point_name = "main";
auto transform = std::make_unique<VertexPulling>(cfg);
auto got = Transform(src, std::move(transform));
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, OneInstancedAttribute) {
@ -190,14 +192,16 @@ fn main() -> void {
}
)";
auto transform = std::make_unique<VertexPulling>();
transform->SetVertexState(
{{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}});
transform->SetEntryPoint("main");
VertexPulling::Config cfg;
cfg.vertex_state = {
{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}};
cfg.entry_point_name = "main";
auto transform = std::make_unique<VertexPulling>(cfg);
auto got = Transform(src, std::move(transform));
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) {
@ -231,15 +235,17 @@ fn main() -> void {
}
)";
auto transform = std::make_unique<VertexPulling>();
transform->SetVertexState(
{{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}});
transform->SetPullingBufferBindingSet(5);
transform->SetEntryPoint("main");
VertexPulling::Config cfg;
cfg.vertex_state = {
{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}};
cfg.pulling_group = 5;
cfg.entry_point_name = "main";
auto transform = std::make_unique<VertexPulling>(cfg);
auto got = Transform(src, std::move(transform));
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
// We expect the transform to use an existing builtin variables if it finds them
@ -285,15 +291,26 @@ fn main() -> void {
}
)";
auto transform = std::make_unique<VertexPulling>();
transform->SetVertexState(
{{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}},
{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 1}}}}});
transform->SetEntryPoint("main");
VertexPulling::Config cfg;
cfg.vertex_state = {{
{
4,
InputStepMode::kVertex,
{{VertexFormat::kF32, 0, 0}},
},
{
4,
InputStepMode::kInstance,
{{VertexFormat::kF32, 0, 1}},
},
}};
cfg.entry_point_name = "main";
auto transform = std::make_unique<VertexPulling>(cfg);
auto got = Transform(src, std::move(transform));
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
@ -332,16 +349,18 @@ fn main() -> void {
}
)";
auto transform = std::make_unique<VertexPulling>();
transform->SetVertexState(
{{{16,
InputStepMode::kVertex,
{{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}});
transform->SetEntryPoint("main");
VertexPulling::Config cfg;
cfg.vertex_state = {
{{16,
InputStepMode::kVertex,
{{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}};
cfg.entry_point_name = "main";
auto transform = std::make_unique<VertexPulling>(cfg);
auto got = Transform(src, std::move(transform));
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
TEST_F(VertexPullingTest, FloatVectorAttributes) {
@ -389,16 +408,19 @@ fn main() -> void {
}
)";
auto transform = std::make_unique<VertexPulling>();
transform->SetVertexState(
{{{8, InputStepMode::kVertex, {{VertexFormat::kVec2F32, 0, 0}}},
{12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}},
{16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}}});
transform->SetEntryPoint("main");
VertexPulling::Config cfg;
cfg.vertex_state = {{
{8, InputStepMode::kVertex, {{VertexFormat::kVec2F32, 0, 0}}},
{12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}},
{16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}},
}};
cfg.entry_point_name = "main";
auto transform = std::make_unique<VertexPulling>(cfg);
auto got = Transform(src, std::move(transform));
EXPECT_EQ(expect, got);
EXPECT_EQ(expect, str(got));
}
} // namespace

View File

@ -104,8 +104,8 @@ class TestHelperBase : public BODY, public ProgramBuilder {
}();
auto result = transform::Hlsl().Run(program.get());
[&]() {
ASSERT_FALSE(result.diagnostics.contains_errors())
<< formatter.format(result.diagnostics);
ASSERT_TRUE(result.program.IsValid())
<< formatter.format(result.program.Diagnostics());
}();
*program = std::move(result.program);
gen_ = std::make_unique<GeneratorImpl>(program.get());

View File

@ -81,8 +81,8 @@ class TestHelperBase : public ProgramBuilder, public BASE {
}();
auto result = transform::Spirv().Run(program.get());
[&]() {
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
ASSERT_TRUE(result.program.IsValid())
<< diag::Formatter().format(result.program.Diagnostics());
}();
*program = std::move(result.program);
spirv_builder = std::make_unique<spirv::Builder>(program.get());