diff --git a/fuzzers/tint_common_fuzzer.cc b/fuzzers/tint_common_fuzzer.cc index 9e68902aef..a1ed3894c9 100644 --- a/fuzzers/tint_common_fuzzer.cc +++ b/fuzzers/tint_common_fuzzer.cc @@ -144,7 +144,7 @@ int CommonFuzzer::Run(const uint8_t* data, size_t size) { if (transform_manager_) { auto out = transform_manager_->Run(&program); - if (out.diagnostics.contains_errors()) { + if (!out.program.IsValid()) { return 0; } diff --git a/samples/main.cc b/samples/main.cc index b847ffe51e..0d9d0b3ac0 100644 --- a/samples/main.cc +++ b/samples/main.cc @@ -28,8 +28,8 @@ namespace { -[[noreturn]] -void TintInternalCompilerErrorReporter(const tint::diag::List& diagnostics) { +[[noreturn]] void TintInternalCompilerErrorReporter( + const tint::diag::List& diagnostics) { auto printer = tint::diag::Printer::create(stderr, true); tint::diag::Formatter{}.format(diagnostics, printer.get()); exit(1); @@ -569,8 +569,8 @@ int main(int argc, const char** argv) { } auto out = transform_manager.Run(program.get()); - if (out.diagnostics.contains_errors()) { - diag_formatter.format(out.diagnostics, diag_printer.get()); + if (!out.program.IsValid()) { + diag_formatter.format(out.program.Diagnostics(), diag_printer.get()); return 1; } diff --git a/src/transform/bound_array_accessors.cc b/src/transform/bound_array_accessors.cc index 72799be1f9..4cb38227d0 100644 --- a/src/transform/bound_array_accessors.cc +++ b/src/transform/bound_array_accessors.cc @@ -58,19 +58,19 @@ BoundArrayAccessors::~BoundArrayAccessors() = default; Transform::Output BoundArrayAccessors::Run(const Program* in) { ProgramBuilder out; - diag::List diagnostics; CloneContext(&out, in) .ReplaceAll([&](CloneContext* ctx, ast::ArrayAccessorExpression* expr) { - return Transform(expr, ctx, &diagnostics); + return Transform(expr, ctx); }) .Clone(); - return Output(Program(std::move(out)), std::move(diagnostics)); + return Output(Program(std::move(out))); } ast::ArrayAccessorExpression* BoundArrayAccessors::Transform( ast::ArrayAccessorExpression* expr, - CloneContext* ctx, - diag::List* diags) { + CloneContext* ctx) { + auto& diags = ctx->dst->Diagnostics(); + auto* ret_type = ctx->src->Sem().Get(expr->array())->Type()->UnwrapAll(); if (!ret_type->Is() && !ret_type->Is() && !ret_type->Is()) { @@ -103,7 +103,7 @@ ast::ArrayAccessorExpression* BoundArrayAccessors::Transform( auto* limit = b.Sub(arr_len, b.Expr(1u)); new_idx = b.Call("min", b.Construct(ctx->Clone(old_idx)), limit); } else { - diags->add_error("invalid 0 size", expr->source()); + diags.add_error("invalid 0 size", expr->source()); return nullptr; } } else if (auto* c = old_idx->As()) { @@ -115,8 +115,8 @@ ast::ArrayAccessorExpression* BoundArrayAccessors::Transform( } else if (auto* uint = lit->As()) { new_idx = b.Expr(std::min(uint->value(), size - 1)); } else { - diags->add_error("unknown scalar constructor type for accessor", - expr->source()); + diags.add_error("unknown scalar constructor type for accessor", + expr->source()); return nullptr; } } else { diff --git a/src/transform/bound_array_accessors.h b/src/transform/bound_array_accessors.h index 899e7d3ef5..22e0ffd2b0 100644 --- a/src/transform/bound_array_accessors.h +++ b/src/transform/bound_array_accessors.h @@ -45,8 +45,7 @@ class BoundArrayAccessors : public Transform { private: ast::ArrayAccessorExpression* Transform(ast::ArrayAccessorExpression* expr, - CloneContext* ctx, - diag::List* diags); + CloneContext* ctx); }; } // namespace transform diff --git a/src/transform/bound_array_accessors_test.cc b/src/transform/bound_array_accessors_test.cc index 783bb98645..58def8f6e8 100644 --- a/src/transform/bound_array_accessors_test.cc +++ b/src/transform/bound_array_accessors_test.cc @@ -45,7 +45,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Array_Idx_Nested_Scalar) { @@ -75,7 +75,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Array_Idx_Scalar) { @@ -97,7 +97,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Array_Idx_Expr) { @@ -123,7 +123,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Array_Idx_Negative) { @@ -145,7 +145,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Array_Idx_OutOfBounds) { @@ -167,7 +167,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Vector_Idx_Scalar) { @@ -189,7 +189,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Vector_Idx_Expr) { @@ -215,7 +215,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Vector_Idx_Negative) { @@ -237,7 +237,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Vector_Idx_OutOfBounds) { @@ -259,7 +259,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Scalar) { @@ -281,7 +281,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Column) { @@ -307,7 +307,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Row) { @@ -333,7 +333,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Negative_Column) { @@ -355,7 +355,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Negative_Row) { @@ -377,7 +377,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Matrix_Idx_OutOfBounds_Column) { @@ -399,7 +399,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(BoundArrayAccessorsTest, Matrix_Idx_OutOfBounds_Row) { @@ -421,7 +421,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } // TODO(dsinclair): Implement when constant_id exists @@ -488,7 +488,7 @@ fn f() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } // TODO(dsinclair): Clamp atomics when available. diff --git a/src/transform/emit_vertex_point_size_test.cc b/src/transform/emit_vertex_point_size_test.cc index 9a4a35958c..5fc7dbae99 100644 --- a/src/transform/emit_vertex_point_size_test.cc +++ b/src/transform/emit_vertex_point_size_test.cc @@ -54,7 +54,7 @@ fn non_entry_b() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) { @@ -87,7 +87,7 @@ fn non_entry_b() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(EmitVertexPointSizeTest, NonVertexStage) { @@ -113,7 +113,7 @@ fn compute_entry() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } } // namespace diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc index 6afd9fdc19..29fe53f8c3 100644 --- a/src/transform/first_index_offset.cc +++ b/src/transform/first_index_offset.cc @@ -15,6 +15,7 @@ #include "src/transform/first_index_offset.h" #include +#include #include #include "src/ast/array_accessor_expression.h" @@ -54,6 +55,8 @@ #include "src/type/u32_type.h" #include "src/type_determiner.h" +TINT_INSTANTIATE_CLASS_ID(tint::transform::FirstIndexOffset::Data); + namespace tint { namespace transform { namespace { @@ -80,20 +83,34 @@ ast::Variable* clone_variable_with_new_name(CloneContext* ctx, } // namespace +FirstIndexOffset::Data::Data(bool has_vtx_index, + bool has_inst_index, + uint32_t first_vtx_offset, + uint32_t first_idx_offset) + : has_vertex_index(has_vtx_index), + has_instance_index(has_inst_index), + first_vertex_offset(first_vtx_offset), + first_index_offset(first_idx_offset) {} + +FirstIndexOffset::Data::Data(const Data&) = default; + +FirstIndexOffset::Data::~Data() = default; + FirstIndexOffset::FirstIndexOffset(uint32_t binding, uint32_t group) : binding_(binding), group_(group) {} FirstIndexOffset::~FirstIndexOffset() = default; Transform::Output FirstIndexOffset::Run(const Program* in) { + ProgramBuilder out; + // First do a quick check to see if the transform has already been applied. for (ast::Variable* var : in->AST().GlobalVariables()) { if (auto* dec_var = var->As()) { if (dec_var->symbol() == in->Symbols().Get(kBufferName)) { - Output out; - out.diagnostics.add_error( + out.Diagnostics().add_error( "First index offset transform has already been applied."); - return out; + return Output(Program(std::move(out))); } } } @@ -101,7 +118,7 @@ Transform::Output FirstIndexOffset::Run(const Program* in) { Symbol vertex_index_sym; Symbol instance_index_sym; - // Lazilly construct the UniformBuffer on first call to + // Lazily construct the UniformBuffer on first call to // maybe_create_buffer_var() ast::Variable* buffer_var = nullptr; auto maybe_create_buffer_var = [&](ProgramBuilder* dst) { @@ -114,7 +131,6 @@ Transform::Output FirstIndexOffset::Run(const Program* in) { // add a CreateFirstIndexOffset() statement to each function that uses one of // these builtins. - ProgramBuilder out; CloneContext(&out, in) .ReplaceAll([&](CloneContext* ctx, ast::Variable* var) -> ast::Variable* { for (ast::VariableDecoration* dec : var->decorations()) { @@ -163,7 +179,10 @@ Transform::Output FirstIndexOffset::Run(const Program* in) { }) .Clone(); - return Output(Program(std::move(out))); + return Output( + Program(std::move(out)), + std::make_unique(has_vertex_index_, has_instance_index_, + vertex_index_offset_, instance_index_offset_)); } bool FirstIndexOffset::HasVertexIndex() { diff --git a/src/transform/first_index_offset.h b/src/transform/first_index_offset.h index b8983ae81a..dfb0353f58 100644 --- a/src/transform/first_index_offset.h +++ b/src/transform/first_index_offset.h @@ -62,6 +62,34 @@ namespace transform { /// class FirstIndexOffset : public Transform { public: + /// Data holds information about shader usage and constant buffer offsets. + struct Data : public Castable { + /// Constructor + /// @param has_vtx_index True if the shader uses vertex_index + /// @param has_inst_index True if the shader uses instance_index + /// @param first_vtx_offset Offset of first vertex into constant buffer + /// @param first_idx_offset Offset of first instance into constant buffer + Data(bool has_vtx_index, + bool has_inst_index, + uint32_t first_vtx_offset, + uint32_t first_idx_offset); + + /// Copy constructor + Data(const Data&); + + /// Destructor + ~Data() override; + + /// True if the shader uses vertex_index + bool const has_vertex_index; + /// True if the shader uses instance_index + bool const has_instance_index; + /// Offset of first vertex into constant buffer + uint32_t const first_vertex_offset; + /// Offset of first instance into constant buffer + uint32_t const first_index_offset; + }; + /// Constructor /// @param binding the binding() for firstVertex/Instance uniform /// @param group the group() for firstVertex/Instance uniform @@ -73,15 +101,19 @@ class FirstIndexOffset : public Transform { /// @returns the transformation result Output Run(const Program* program) override; + /// [DEPRECATED] - Use Data /// @returns whether shader uses vertex_index bool HasVertexIndex(); + /// [DEPRECATED] - Use Data /// @returns whether shader uses instance_index bool HasInstanceIndex(); + /// [DEPRECATED] - Use Data /// @returns offset of firstVertex into constant buffer uint32_t GetFirstVertexOffset(); + /// [DEPRECATED] - Use Data /// @returns offset of firstInstance into constant buffer uint32_t GetFirstInstanceOffset(); diff --git a/src/transform/first_index_offset_test.cc b/src/transform/first_index_offset_test.cc index 977a7166a7..c102c9d556 100644 --- a/src/transform/first_index_offset_test.cc +++ b/src/transform/first_index_offset_test.cc @@ -40,8 +40,8 @@ fn entry() -> void { } )"; - auto* expect = R"(manager().Run() errored: -error: First index offset transform has already been applied.)"; + auto* expect = + "error: First index offset transform has already been applied."; std::vector> transforms; transforms.emplace_back(std::make_unique(0, 0)); @@ -49,7 +49,15 @@ error: First index offset transform has already been applied.)"; auto got = Transform(src, std::move(transforms)); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); + + auto* data = got.data.Get(); + + ASSERT_NE(data, nullptr); + EXPECT_EQ(data->has_vertex_index, true); + EXPECT_EQ(data->has_instance_index, false); + EXPECT_EQ(data->first_vertex_offset, 0u); + EXPECT_EQ(data->first_index_offset, 0u); } TEST_F(FirstIndexOffsetTest, EmptyModule) { @@ -58,7 +66,15 @@ TEST_F(FirstIndexOffsetTest, EmptyModule) { auto got = Transform(src, 0, 0); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); + + auto* data = got.data.Get(); + + ASSERT_NE(data, nullptr); + EXPECT_EQ(data->has_vertex_index, false); + EXPECT_EQ(data->has_instance_index, false); + EXPECT_EQ(data->first_vertex_offset, 0u); + EXPECT_EQ(data->first_index_offset, 0u); } TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) { @@ -99,7 +115,15 @@ fn entry() -> void { auto got = Transform(src, 1, 2); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); + + auto* data = got.data.Get(); + + ASSERT_NE(data, nullptr); + EXPECT_EQ(data->has_vertex_index, true); + EXPECT_EQ(data->has_instance_index, false); + EXPECT_EQ(data->first_vertex_offset, 0u); + EXPECT_EQ(data->first_index_offset, 0u); } TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) { @@ -140,7 +164,15 @@ fn entry() -> void { auto got = Transform(src, 1, 7); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); + + auto* data = got.data.Get(); + + ASSERT_NE(data, nullptr); + EXPECT_EQ(data->has_vertex_index, false); + EXPECT_EQ(data->has_instance_index, true); + EXPECT_EQ(data->first_vertex_offset, 0u); + EXPECT_EQ(data->first_index_offset, 0u); } TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) { @@ -187,7 +219,15 @@ fn entry() -> void { auto got = Transform(src, 1, 2); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); + + auto* data = got.data.Get(); + + ASSERT_NE(data, nullptr); + EXPECT_EQ(data->has_vertex_index, true); + EXPECT_EQ(data->has_instance_index, true); + EXPECT_EQ(data->first_vertex_offset, 0u); + EXPECT_EQ(data->first_index_offset, 4u); } TEST_F(FirstIndexOffsetTest, NestedCalls) { @@ -236,7 +276,15 @@ fn entry() -> void { auto got = Transform(src, 1, 2); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); + + auto* data = got.data.Get(); + + ASSERT_NE(data, nullptr); + EXPECT_EQ(data->has_vertex_index, true); + EXPECT_EQ(data->has_instance_index, false); + EXPECT_EQ(data->first_vertex_offset, 0u); + EXPECT_EQ(data->first_index_offset, 0u); } } // namespace diff --git a/src/transform/hlsl_test.cc b/src/transform/hlsl_test.cc index 38a4e69dae..3eca722253 100644 --- a/src/transform/hlsl_test.cc +++ b/src/transform/hlsl_test.cc @@ -52,7 +52,7 @@ fn main() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(HlslTest, PromoteArrayInitializerToConstVar_ArrayInArray) { @@ -75,7 +75,7 @@ fn main() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(HlslTest, PromoteArrayInitializerToConstVar_NoChangeOnArrayVarDecl) { @@ -92,7 +92,7 @@ const module_arr : array = array(0.0, 1.0, 2.0, 3.0); auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(HlslTest, PromoteArrayInitializerToConstVar_Bug406) { @@ -146,7 +146,7 @@ fn main() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } } // namespace diff --git a/src/transform/manager.cc b/src/transform/manager.cc index 31fe9ac290..e91a54f186 100644 --- a/src/transform/manager.cc +++ b/src/transform/manager.cc @@ -29,8 +29,8 @@ Transform::Output Manager::Run(const Program* program) { for (auto& transform : transforms_) { auto res = transform->Run(program); out.program = std::move(res.program); - out.diagnostics.add(std::move(res.diagnostics)); - if (out.diagnostics.contains_errors()) { + out.data.Add(std::move(res.data)); + if (!out.program.IsValid()) { return out; } program = &out.program; diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc index 53bdc52435..4740a419d7 100644 --- a/src/transform/spirv_test.cc +++ b/src/transform/spirv_test.cc @@ -55,7 +55,7 @@ fn main() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(SpirvTest, HandleSampleMaskBuiltins_FunctionArg) { @@ -99,7 +99,7 @@ fn main() -> void { auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } } // namespace diff --git a/src/transform/test_helper.h b/src/transform/test_helper.h index f52c5e741b..ae78d51ec8 100644 --- a/src/transform/test_helper.h +++ b/src/transform/test_helper.h @@ -38,35 +38,59 @@ class TransformTest : public testing::Test { /// `transforms`. /// @param in the input WGSL source /// @param transforms the list of transforms to apply - /// @return the transformed WGSL output - std::string Transform( + /// @return the transformed output + Transform::Output Transform( std::string in, std::vector> transforms) { - diag::Formatter::Style style; - style.print_newline_at_end = false; - Source::File file("test", in); auto program = reader::wgsl::Parse(&file); if (!program.IsValid()) { - return diag::Formatter(style).format(program.Diagnostics()); + return Transform::Output(std::move(program)); } - { - Manager manager; - for (auto& transform : transforms) { - manager.append(std::move(transform)); - } - auto result = manager.Run(&program); + Manager manager; + for (auto& transform : transforms) { + manager.append(std::move(transform)); + } + return manager.Run(&program); + } - if (result.diagnostics.contains_errors()) { - return "manager().Run() errored:\n" + - diag::Formatter(style).format(result.diagnostics); - } - program = std::move(result.program); + /// Transforms and returns the WGSL source `in`, transformed using + /// `transform`. + /// @param transform the transform to apply + /// @param in the input WGSL source + /// @return the transformed output + Transform::Output Transform(std::string in, + std::unique_ptr transform) { + std::vector> transforms; + transforms.emplace_back(std::move(transform)); + return Transform(std::move(in), std::move(transforms)); + } + + /// Transforms and returns the WGSL source `in`, transformed using + /// a transform of type `TRANSFORM`. + /// @param in the input WGSL source + /// @param args the TRANSFORM constructor arguments + /// @return the transformed output + template + Transform::Output Transform(std::string in, ARGS&&... args) { + return Transform(std::move(in), + std::make_unique(std::forward(args)...)); + } + + /// @param output the output of the transform + /// @returns the output program as a WGSL string, or an error string if the + /// program is not valid. + std::string str(const Transform::Output& output) { + diag::Formatter::Style style; + style.print_newline_at_end = false; + + if (!output.program.IsValid()) { + return diag::Formatter(style).format(output.program.Diagnostics()); } - writer::wgsl::Generator generator(&program); + writer::wgsl::Generator generator(&output.program); if (!generator.Generate()) { return "WGSL writer failed:\n" + generator.error(); } @@ -84,29 +108,6 @@ class TransformTest : public testing::Test { } return "\n" + res + "\n"; } - - /// Transforms and returns the WGSL source `in`, transformed using - /// `transform`. - /// @param transform the transform to apply - /// @param in the input WGSL source - /// @return the transformed WGSL output - std::string Transform(std::string in, - std::unique_ptr transform) { - std::vector> transforms; - transforms.emplace_back(std::move(transform)); - return Transform(std::move(in), std::move(transforms)); - } - - /// Transforms and returns the WGSL source `in`, transformed using - /// a transform of type `TRANSFORM`. - /// @param in the input WGSL source - /// @param args the TRANSFORM constructor arguments - /// @return the transformed WGSL output - template - std::string Transform(std::string in, ARGS&&... args) { - return Transform(std::move(in), - std::make_unique(std::forward(args)...)); - } }; } // namespace transform diff --git a/src/transform/transform.cc b/src/transform/transform.cc index c037cdba54..ed368ae058 100644 --- a/src/transform/transform.cc +++ b/src/transform/transform.cc @@ -19,16 +19,27 @@ #include "src/clone_context.h" #include "src/program_builder.h" +TINT_INSTANTIATE_CLASS_ID(tint::transform::Data); + namespace tint { namespace transform { +Data::Data() = default; + +Data::Data(const Data&) = default; + +Data::~Data() = default; + +DataMap::DataMap() = default; + +DataMap::DataMap(DataMap&&) = default; + +DataMap::~DataMap() = default; + Transform::Output::Output() = default; Transform::Output::Output(Program&& p) : program(std::move(p)) {} -Transform::Output::Output(Program&& p, diag::List&& d) - : program(std::move(p)), diagnostics(std::move(d)) {} - Transform::Transform() = default; Transform::~Transform() = default; diff --git a/src/transform/transform.h b/src/transform/transform.h index 3516259183..c598b83651 100644 --- a/src/transform/transform.h +++ b/src/transform/transform.h @@ -17,6 +17,8 @@ #include #include +#include +#include #include #include "src/diagnostic/diagnostic.h" @@ -25,6 +27,84 @@ namespace tint { namespace transform { +/// Data is the base class for transforms that emit extra output information +/// along with a Program. +class Data : public Castable { + public: + /// Constructor + Data(); + + /// Copy constructor + Data(const Data&); + + /// Destructor + ~Data() override; +}; + +/// DataMap is a map of Data unique pointers keyed by the Data's ClassID. +class DataMap { + public: + /// Constructor + DataMap(); + + /// Move constructor + DataMap(DataMap&&); + + /// Constructor + /// @param data_unique_ptrs a variadic list of additional data unique_ptrs + /// produced by the transform + template + explicit DataMap(DATA... data_unique_ptrs) { + PutAll(std::forward(data_unique_ptrs)...); + } + + /// Destructor + ~DataMap(); + + /// Adds the data into DataMap keyed by the ClassID of type T. + /// @param data the data to add to the DataMap + template + void Put(std::unique_ptr&& data) { + static_assert(std::is_base_of::value, + "T does not derive from Data"); + map_[ClassID::Of()] = std::move(data); + } + + /// @returns a pointer to the Data placed into the DataMap with a call to + /// Put() + template + T const* Get() const { + auto it = map_.find(ClassID::Of()); + if (it == map_.end()) { + return nullptr; + } + return static_cast(it->second.get()); + } + + /// Add moves all the data from other into this DataMap + /// @param other the DataMap to move into this DataMap + void Add(DataMap&& other) { + for (auto& it : other.map_) { + map_.emplace(it.first, std::move(it.second)); + } + other.map_.clear(); + } + + private: + template + void PutAll(T0&& first) { + Put(std::forward(first)); + } + + template + void PutAll(T0&& first, Tn&&... remainder) { + Put(std::forward(first)); + PutAll(std::forward(remainder)...); + } + + std::unordered_map> map_; +}; + /// Interface for Program transforms class Transform { public: @@ -34,7 +114,8 @@ class Transform { virtual ~Transform(); /// The return type of Run() - struct Output { + class Output { + public: /// Constructor Output(); @@ -43,13 +124,21 @@ class Transform { explicit Output(Program&& program); /// Constructor - /// @param program the program to move into this Output - /// @param diags the list of diagnostics to move into this Output - Output(Program&& program, diag::List&& diags); + /// @param program_ the program to move into this Output + /// @param data_ a variadic list of additional data unique_ptrs produced by + /// the transform + template + Output(Program&& program_, DATA... data_) + : program(std::move(program_)), data(std::forward(data_)...) {} /// The transformed program. May be empty on error. Program program; + + /// Extra output generated by the transforms. + DataMap data; + /// Diagnostics raised while running the Transform. + /// [DEPRECATED] Use `program.Diagnostics()` diag::List diagnostics; }; diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc index f052bd44aa..273806f7ce 100644 --- a/src/transform/vertex_pulling.cc +++ b/src/transform/vertex_pulling.cc @@ -57,11 +57,15 @@ static const char kDefaultInstanceIndexName[] = "_tint_pulling_instance_index"; } // namespace VertexPulling::VertexPulling() = default; + +VertexPulling::VertexPulling(const Config& config) + : cfg(config), vertex_state_set(true) {} + VertexPulling::~VertexPulling() = default; void VertexPulling::SetVertexState(const VertexStateDescriptor& vertex_state) { cfg.vertex_state = vertex_state; - cfg.vertex_state_set = true; + vertex_state_set = true; } void VertexPulling::SetEntryPoint(std::string entry_point) { @@ -77,20 +81,20 @@ void VertexPulling::SetPullingBufferBindingSet(uint32_t number) { } Transform::Output VertexPulling::Run(const Program* in) { + ProgramBuilder out; + // Check SetVertexState was called - if (!cfg.vertex_state_set) { - Output out; - out.diagnostics.add_error("SetVertexState not called"); - return out; + if (!vertex_state_set) { + out.Diagnostics().add_error("SetVertexState not called"); + return Output(Program(std::move(out))); } // Find entry point auto* func = in->AST().Functions().Find( in->Symbols().Get(cfg.entry_point_name), ast::PipelineStage::kVertex); if (func == nullptr) { - Output out; - out.diagnostics.add_error("Vertex stage entry point not found"); - return out; + out.Diagnostics().add_error("Vertex stage entry point not found"); + return Output(Program(std::move(out))); } // TODO(idanr): Need to check shader locations in descriptor cover all @@ -98,7 +102,6 @@ Transform::Output VertexPulling::Run(const Program* in) { // TODO(idanr): Make sure we covered all error cases, to guarantee the // following stages will pass - ProgramBuilder out; CloneContext ctx(&out, in); State state{ctx, cfg}; @@ -123,7 +126,9 @@ Transform::Output VertexPulling::Run(const Program* in) { } VertexPulling::Config::Config() = default; + VertexPulling::Config::Config(const Config&) = default; + VertexPulling::Config::~Config() = default; VertexPulling::State::State(CloneContext& context, const Config& c) diff --git a/src/transform/vertex_pulling.h b/src/transform/vertex_pulling.h index 215052e458..a9bdb52f78 100644 --- a/src/transform/vertex_pulling.h +++ b/src/transform/vertex_pulling.h @@ -136,26 +136,54 @@ using VertexStateDescriptor = std::vector; /// shader to use. class VertexPulling : public Transform { public: + /// Configuration options for the transform + struct Config { + /// Constructor + Config(); + + /// Copy constructor + Config(const Config&); + + /// Destructor + ~Config(); + + /// The entry point to add assignments into + std::string entry_point_name; + + /// The vertex state descriptor, containing info about attributes + VertexStateDescriptor vertex_state; + + /// The "group" we will put all our vertex buffers into (as storage buffers) + /// Default to 4 as it is past the limits of user-accessible groups + uint32_t pulling_group = 4u; + }; + /// Constructor VertexPulling(); + + /// Constructor + /// @param config the configuration options for the transform + explicit VertexPulling(const Config& config); + /// Destructor ~VertexPulling() override; /// Sets the vertex state descriptor, containing info about attributes + /// [DEPRECATED] Use the VertexPulling(const Config&) /// @param vertex_state the vertex state descriptor void SetVertexState(const VertexStateDescriptor& vertex_state); - /// Sets the entry point to add assignments into + /// [DEPRECATED] Use the VertexPulling(const Config&) /// @param entry_point the vertex stage entry point void SetEntryPoint(std::string entry_point); - /// Sets the "set" we will put all our vertex buffers into (as storage /// buffers) - /// DEPRECATED + /// [DEPRECATED] Use the VertexPulling(const Config&) /// @param number the set number we will use void SetPullingBufferBindingSet(uint32_t number); /// Sets the "group" we will put all our vertex buffers into (as storage /// buffers) + /// [DEPRECATED] Use the VertexPulling(const Config&) /// @param number the group number we will use void SetPullingBufferBindingGroup(uint32_t number); @@ -165,19 +193,8 @@ class VertexPulling : public Transform { Output Run(const Program* program) override; private: - struct Config { - Config(); - Config(const Config&); - ~Config(); - - std::string entry_point_name; - VertexStateDescriptor vertex_state; - bool vertex_state_set = false; - // Default to 4 as it is past the limits of user-accessible groups - uint32_t pulling_group = 4u; - }; - Config cfg; + bool vertex_state_set = false; struct State { State(CloneContext& ctx, const Config& c); diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc index c81a00e931..cab1d4dca8 100644 --- a/src/transform/vertex_pulling_test.cc +++ b/src/transform/vertex_pulling_test.cc @@ -30,26 +30,25 @@ TEST_F(VertexPullingTest, Error_NoVertexState) { fn main() -> void {} )"; - auto* expect = R"(manager().Run() errored: -error: SetVertexState not called)"; + auto* expect = "error: SetVertexState not called"; auto got = Transform(src); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(VertexPullingTest, Error_NoEntryPoint) { auto* src = ""; - auto* expect = R"(manager().Run() errored: -error: Vertex stage entry point not found)"; + auto* expect = "error: Vertex stage entry point not found"; - auto transform = std::make_unique(); - transform->SetVertexState({}); + VertexPulling::Config cfg; + + auto transform = std::make_unique(cfg); auto got = Transform(src, std::move(transform)); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(VertexPullingTest, Error_InvalidEntryPoint) { @@ -58,16 +57,16 @@ TEST_F(VertexPullingTest, Error_InvalidEntryPoint) { fn main() -> void {} )"; - auto* expect = R"(manager().Run() errored: -error: Vertex stage entry point not found)"; + auto* expect = "error: Vertex stage entry point not found"; - auto transform = std::make_unique(); - transform->SetVertexState({}); - transform->SetEntryPoint("_"); + VertexPulling::Config cfg; + cfg.entry_point_name = "_"; + + auto transform = std::make_unique(cfg); auto got = Transform(src, std::move(transform)); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(VertexPullingTest, Error_EntryPointWrongStage) { @@ -76,16 +75,16 @@ TEST_F(VertexPullingTest, Error_EntryPointWrongStage) { fn main() -> void {} )"; - auto* expect = R"(manager().Run() errored: -error: Vertex stage entry point not found)"; + auto* expect = "error: Vertex stage entry point not found"; - auto transform = std::make_unique(); - transform->SetVertexState({}); - transform->SetEntryPoint("main"); + VertexPulling::Config cfg; + cfg.entry_point_name = "main"; + + auto transform = std::make_unique(cfg); auto got = Transform(src, std::move(transform)); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(VertexPullingTest, BasicModule) { @@ -109,13 +108,14 @@ fn main() -> void { } )"; - auto transform = std::make_unique(); - transform->SetVertexState({}); - transform->SetEntryPoint("main"); + VertexPulling::Config cfg; + cfg.entry_point_name = "main"; + + auto transform = std::make_unique(cfg); auto got = Transform(src, std::move(transform)); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(VertexPullingTest, OneAttribute) { @@ -149,14 +149,16 @@ fn main() -> void { } )"; - auto transform = std::make_unique(); - transform->SetVertexState( - {{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}}); - transform->SetEntryPoint("main"); + VertexPulling::Config cfg; + cfg.vertex_state = { + {{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}}; + cfg.entry_point_name = "main"; + + auto transform = std::make_unique(cfg); auto got = Transform(src, std::move(transform)); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(VertexPullingTest, OneInstancedAttribute) { @@ -190,14 +192,16 @@ fn main() -> void { } )"; - auto transform = std::make_unique(); - transform->SetVertexState( - {{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}}); - transform->SetEntryPoint("main"); + VertexPulling::Config cfg; + cfg.vertex_state = { + {{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}}; + cfg.entry_point_name = "main"; + + auto transform = std::make_unique(cfg); auto got = Transform(src, std::move(transform)); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) { @@ -231,15 +235,17 @@ fn main() -> void { } )"; - auto transform = std::make_unique(); - transform->SetVertexState( - {{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}}); - transform->SetPullingBufferBindingSet(5); - transform->SetEntryPoint("main"); + VertexPulling::Config cfg; + cfg.vertex_state = { + {{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}}; + cfg.pulling_group = 5; + cfg.entry_point_name = "main"; + + auto transform = std::make_unique(cfg); auto got = Transform(src, std::move(transform)); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } // We expect the transform to use an existing builtin variables if it finds them @@ -285,15 +291,26 @@ fn main() -> void { } )"; - auto transform = std::make_unique(); - transform->SetVertexState( - {{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}, - {4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 1}}}}}); - transform->SetEntryPoint("main"); + VertexPulling::Config cfg; + cfg.vertex_state = {{ + { + 4, + InputStepMode::kVertex, + {{VertexFormat::kF32, 0, 0}}, + }, + { + 4, + InputStepMode::kInstance, + {{VertexFormat::kF32, 0, 1}}, + }, + }}; + cfg.entry_point_name = "main"; + + auto transform = std::make_unique(cfg); auto got = Transform(src, std::move(transform)); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(VertexPullingTest, TwoAttributesSameBuffer) { @@ -332,16 +349,18 @@ fn main() -> void { } )"; - auto transform = std::make_unique(); - transform->SetVertexState( - {{{16, - InputStepMode::kVertex, - {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}}); - transform->SetEntryPoint("main"); + VertexPulling::Config cfg; + cfg.vertex_state = { + {{16, + InputStepMode::kVertex, + {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}}; + cfg.entry_point_name = "main"; + + auto transform = std::make_unique(cfg); auto got = Transform(src, std::move(transform)); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } TEST_F(VertexPullingTest, FloatVectorAttributes) { @@ -389,16 +408,19 @@ fn main() -> void { } )"; - auto transform = std::make_unique(); - transform->SetVertexState( - {{{8, InputStepMode::kVertex, {{VertexFormat::kVec2F32, 0, 0}}}, - {12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}}, - {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}}}); - transform->SetEntryPoint("main"); + VertexPulling::Config cfg; + cfg.vertex_state = {{ + {8, InputStepMode::kVertex, {{VertexFormat::kVec2F32, 0, 0}}}, + {12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}}, + {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}, + }}; + cfg.entry_point_name = "main"; + + auto transform = std::make_unique(cfg); auto got = Transform(src, std::move(transform)); - EXPECT_EQ(expect, got); + EXPECT_EQ(expect, str(got)); } } // namespace diff --git a/src/writer/hlsl/test_helper.h b/src/writer/hlsl/test_helper.h index 2b169f7ad0..3c6cff7d30 100644 --- a/src/writer/hlsl/test_helper.h +++ b/src/writer/hlsl/test_helper.h @@ -104,8 +104,8 @@ class TestHelperBase : public BODY, public ProgramBuilder { }(); auto result = transform::Hlsl().Run(program.get()); [&]() { - ASSERT_FALSE(result.diagnostics.contains_errors()) - << formatter.format(result.diagnostics); + ASSERT_TRUE(result.program.IsValid()) + << formatter.format(result.program.Diagnostics()); }(); *program = std::move(result.program); gen_ = std::make_unique(program.get()); diff --git a/src/writer/spirv/test_helper.h b/src/writer/spirv/test_helper.h index ba5e51a239..bd053ab02f 100644 --- a/src/writer/spirv/test_helper.h +++ b/src/writer/spirv/test_helper.h @@ -81,8 +81,8 @@ class TestHelperBase : public ProgramBuilder, public BASE { }(); auto result = transform::Spirv().Run(program.get()); [&]() { - ASSERT_FALSE(result.diagnostics.contains_errors()) - << diag::Formatter().format(result.diagnostics); + ASSERT_TRUE(result.program.IsValid()) + << diag::Formatter().format(result.program.Diagnostics()); }(); *program = std::move(result.program); spirv_builder = std::make_unique(program.get());