From d408f2465ae7b39f4510af2392c22914043fa71c Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Mon, 14 Dec 2020 20:31:17 +0000 Subject: [PATCH] Remove BlockStatement::insert() Bug: tint:396 Bug: tint:390 Change-Id: I719b84804164fa801ded505ed56717948f06c7a7 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35502 Commit-Queue: Ben Clayton Reviewed-by: dan sinclair --- src/ast/block_statement.h | 16 +- src/ast/block_statement_test.cc | 18 -- src/ast/module.cc | 21 ++- src/ast/module.h | 21 ++- src/reader/wgsl/parser_impl.cc | 27 ++- src/reader/wgsl/parser_impl.h | 2 +- src/reader/wgsl/parser_impl_for_stmt_test.cc | 7 +- .../wgsl/parser_impl_statements_test.cc | 4 +- src/transform/bound_array_accessors.cc | 8 +- src/transform/emit_vertex_point_size.cc | 72 ++++--- src/transform/first_index_offset.cc | 95 +++++----- src/transform/transform.cc | 19 ++ src/transform/transform.h | 12 ++ src/transform/vertex_pulling.cc | 178 ++++++++++-------- src/transform/vertex_pulling.h | 32 ++-- src/transform/vertex_pulling_test.cc | 116 ++++++------ 16 files changed, 349 insertions(+), 299 deletions(-) diff --git a/src/ast/block_statement.h b/src/ast/block_statement.h index 9b5417af16..d485bdf7a6 100644 --- a/src/ast/block_statement.h +++ b/src/ast/block_statement.h @@ -35,14 +35,6 @@ class BlockStatement : public Castable { BlockStatement(BlockStatement&&); ~BlockStatement() override; - /// Insert a statement to the block - /// @param index the index to insert at - /// @param stmt the statement to insert - void insert(size_t index, Statement* stmt) { - auto offset = static_cast(index); - statements_.insert(statements_.begin() + offset, stmt); - } - /// @returns true if the block is empty bool empty() const { return statements_.empty(); } /// @returns the number of statements directly in the block @@ -60,16 +52,12 @@ class BlockStatement : public Castable { /// Retrieves the statement at `idx` /// @param idx the index. The index is not bounds checked. /// @returns the statement at `idx` - const Statement* get(size_t idx) const { return statements_[idx]; } + Statement* get(size_t idx) const { return statements_[idx]; } /// Retrieves the statement at `idx` /// @param idx the index. The index is not bounds checked. /// @returns the statement at `idx` - Statement* operator[](size_t idx) { return statements_[idx]; } - /// Retrieves the statement at `idx` - /// @param idx the index. The index is not bounds checked. - /// @returns the statement at `idx` - const Statement* operator[](size_t idx) const { return statements_[idx]; } + Statement* operator[](size_t idx) const { return statements_[idx]; } /// @returns the beginning iterator StatementList::const_iterator begin() const { return statements_.begin(); } diff --git a/src/ast/block_statement_test.cc b/src/ast/block_statement_test.cc index 71b71d20d8..78134c6ca3 100644 --- a/src/ast/block_statement_test.cc +++ b/src/ast/block_statement_test.cc @@ -37,24 +37,6 @@ TEST_F(BlockStatementTest, Creation) { EXPECT_EQ(b[0], ptr); } -TEST_F(BlockStatementTest, Creation_WithInsert) { - auto* s1 = create(Source{}); - auto* s2 = create(Source{}); - auto* s3 = create(Source{}); - - BlockStatement b(Source{}, StatementList{}); - b.insert(0, s1); - b.insert(0, s2); - b.insert(1, s3); - - // |b| should contain s2, s3, s1 - - ASSERT_EQ(b.size(), 3u); - EXPECT_EQ(b[0], s2); - EXPECT_EQ(b[1], s3); - EXPECT_EQ(b[2], s1); -} - TEST_F(BlockStatementTest, Creation_WithSource) { BlockStatement b(Source{Source::Location{20, 2}}, ast::StatementList{}); auto src = b.source(); diff --git a/src/ast/module.cc b/src/ast/module.cc index 6d5cf58618..a357554e18 100644 --- a/src/ast/module.cc +++ b/src/ast/module.cc @@ -33,15 +33,30 @@ Module::~Module() = default; Module Module::Clone() { Module out; CloneContext ctx(&out); - Clone(&ctx); + + // Symbol table must be cloned first so that the resulting module has the + // symbols before we start the tree mutations. + ctx.mod->symbol_table_ = symbol_table_; + + CloneUsing(&ctx); return out; } -void Module::Clone(CloneContext* ctx) { +Module Module::Clone(const std::function& init) { + Module out; + CloneContext ctx(&out); + // Symbol table must be cloned first so that the resulting module has the // symbols before we start the tree mutations. - ctx->mod->symbol_table_ = symbol_table_; + ctx.mod->symbol_table_ = symbol_table_; + init(&ctx); + + CloneUsing(&ctx); + return out; +} + +void Module::CloneUsing(CloneContext* ctx) { for (auto* ty : constructed_types_) { ctx->mod->constructed_types_.emplace_back(ctx->Clone(ty)); } diff --git a/src/ast/module.h b/src/ast/module.h index 1facd79d8b..e5d153db7e 100644 --- a/src/ast/module.h +++ b/src/ast/module.h @@ -15,6 +15,7 @@ #ifndef SRC_AST_MODULE_H_ #define SRC_AST_MODULE_H_ +#include #include #include #include @@ -55,13 +56,11 @@ class Module { /// @return a deep copy of this module Module Clone(); - /// Clone this module into `ctx->mod` using the provided CloneContext - /// The module will be cloned in this order: - /// * Constructed types - /// * Global variables - /// * Functions - /// @param ctx the clone context - void Clone(CloneContext* ctx); + /// @param init a callback function to configure the CloneContex before + /// cloning any of the module's state + /// @return a deep copy of this module, calling `init` to first initialize the + /// context. + Module Clone(const std::function& init); /// Add a global variable to the module /// @param var the variable to add @@ -181,6 +180,14 @@ class Module { private: Module(const Module&) = delete; + /// Clone this module into `ctx->mod` using the provided CloneContext + /// The module will be cloned in this order: + /// * Constructed types + /// * Global variables + /// * Functions + /// @param ctx the clone context + void CloneUsing(CloneContext* ctx); + SymbolTable symbol_table_; VariableList global_variables_; // The constructed types are owned by the type manager diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc index 6219b32f12..663bdc6d45 100644 --- a/src/reader/wgsl/parser_impl.cc +++ b/src/reader/wgsl/parser_impl.cc @@ -1418,7 +1418,12 @@ Expect ParserImpl::expect_builtin() { // body_stmt // : BRACKET_LEFT statements BRACKET_RIGHT Expect ParserImpl::expect_body_stmt() { - return expect_brace_block("", [&] { return expect_statements(); }); + return expect_brace_block("", [&]() -> Expect { + auto stmts = expect_statements(); + if (stmts.errored) + return Failure::kErrored; + return create(Source{}, stmts.value); + }); } // paren_rhs_stmt @@ -1437,7 +1442,7 @@ Expect ParserImpl::expect_paren_rhs_stmt() { // statements // : statement* -Expect ParserImpl::expect_statements() { +Expect ParserImpl::expect_statements() { bool errored = false; ast::StatementList stmts; @@ -1455,7 +1460,7 @@ Expect ParserImpl::expect_statements() { if (errored) return Failure::kErrored; - return create(Source{}, stmts); + return stmts; } // statement @@ -1859,15 +1864,16 @@ Maybe ParserImpl::loop_stmt() { return Failure::kNoMatch; return expect_brace_block("loop", [&]() -> Maybe { - auto body = expect_statements(); - if (body.errored) + auto stmts = expect_statements(); + if (stmts.errored) return Failure::kErrored; auto continuing = continuing_stmt(); if (continuing.errored) return Failure::kErrored; - return create(source, body.value, continuing.value); + auto* body = create(source, stmts.value); + return create(source, body, continuing.value); }); } @@ -1958,9 +1964,9 @@ Maybe ParserImpl::for_stmt() { if (header.errored) return Failure::kErrored; - auto body = + auto stmts = expect_brace_block("for loop", [&] { return expect_statements(); }); - if (body.errored) + if (stmts.errored) return Failure::kErrored; // The for statement is a syntactic sugar on top of the loop statement. @@ -1980,7 +1986,7 @@ Maybe ParserImpl::for_stmt() { auto* break_if_not_condition = create(not_condition->source(), not_condition, break_body, ast::ElseStatementList{}); - body->insert(0, break_if_not_condition); + stmts.value.insert(stmts.value.begin(), break_if_not_condition); } ast::BlockStatement* continuing_body = nullptr; @@ -1991,7 +1997,8 @@ Maybe ParserImpl::for_stmt() { }); } - auto* loop = create(source, body.value, continuing_body); + auto* body = create(source, stmts.value); + auto* loop = create(source, body, continuing_body); if (header->initializer != nullptr) { return create(source, ast::StatementList{ diff --git a/src/reader/wgsl/parser_impl.h b/src/reader/wgsl/parser_impl.h index 1bdd73ba1d..c523d06d51 100644 --- a/src/reader/wgsl/parser_impl.h +++ b/src/reader/wgsl/parser_impl.h @@ -468,7 +468,7 @@ class ParserImpl { Expect expect_paren_rhs_stmt(); /// Parses a `statements` grammar element /// @returns the statements parsed - Expect expect_statements(); + Expect expect_statements(); /// Parses a `statement` grammar element /// @returns the parsed statement or nullptr Maybe statement(); diff --git a/src/reader/wgsl/parser_impl_for_stmt_test.cc b/src/reader/wgsl/parser_impl_for_stmt_test.cc index 2172e21495..b112779509 100644 --- a/src/reader/wgsl/parser_impl_for_stmt_test.cc +++ b/src/reader/wgsl/parser_impl_for_stmt_test.cc @@ -15,6 +15,7 @@ #include #include "gtest/gtest.h" +#include "src/ast/block_statement.h" #include "src/reader/wgsl/parser_impl.h" #include "src/reader/wgsl/parser_impl_test_helper.h" @@ -30,15 +31,15 @@ class ForStmtTest : public ParserImplTest { auto e_loop = p_loop->expect_statements(); EXPECT_FALSE(e_loop.errored); EXPECT_FALSE(p_loop->has_error()) << p_loop->error(); - ASSERT_NE(e_loop.value, nullptr); auto p_for = parser(for_str); auto e_for = p_for->expect_statements(); EXPECT_FALSE(e_for.errored); EXPECT_FALSE(p_for->has_error()) << p_for->error(); - ASSERT_NE(e_for.value, nullptr); - EXPECT_EQ(e_loop->str(), e_for->str()); + std::string loop = ast::BlockStatement({}, e_loop.value).str(); + std::string for_ = ast::BlockStatement({}, e_for.value).str(); + EXPECT_EQ(loop, for_); } }; diff --git a/src/reader/wgsl/parser_impl_statements_test.cc b/src/reader/wgsl/parser_impl_statements_test.cc index 88b6012fcb..d33176c5ee 100644 --- a/src/reader/wgsl/parser_impl_statements_test.cc +++ b/src/reader/wgsl/parser_impl_statements_test.cc @@ -29,8 +29,8 @@ TEST_F(ParserImplTest, Statements) { EXPECT_FALSE(e.errored); EXPECT_FALSE(p->has_error()) << p->error(); ASSERT_EQ(e->size(), 2u); - EXPECT_TRUE(e->get(0)->Is()); - EXPECT_TRUE(e->get(1)->Is()); + EXPECT_TRUE(e.value[0]->Is()); + EXPECT_TRUE(e.value[1]->Is()); } TEST_F(ParserImplTest, Statements_Empty) { diff --git a/src/transform/bound_array_accessors.cc b/src/transform/bound_array_accessors.cc index d68da7c1f5..3adf6b7d26 100644 --- a/src/transform/bound_array_accessors.cc +++ b/src/transform/bound_array_accessors.cc @@ -55,11 +55,11 @@ BoundArrayAccessors::~BoundArrayAccessors() = default; Transform::Output BoundArrayAccessors::Run(ast::Module* mod) { Output out; - ast::CloneContext ctx(&out.module); - ctx.ReplaceAll([&](ast::ArrayAccessorExpression* expr) { - return Transform(expr, &ctx, &out.diagnostics); + out.module = mod->Clone([&](ast::CloneContext* ctx) { + ctx->ReplaceAll([&, ctx](ast::ArrayAccessorExpression* expr) { + return Transform(expr, ctx, &out.diagnostics); + }); }); - mod->Clone(&ctx); return out; } diff --git a/src/transform/emit_vertex_point_size.cc b/src/transform/emit_vertex_point_size.cc index 1d22a26ea0..d8640d815c 100644 --- a/src/transform/emit_vertex_point_size.cc +++ b/src/transform/emit_vertex_point_size.cc @@ -19,6 +19,7 @@ #include "src/ast/assignment_statement.h" #include "src/ast/block_statement.h" +#include "src/ast/clone_context.h" #include "src/ast/float_literal.h" #include "src/ast/identifier_expression.h" #include "src/ast/scalar_constructor_expression.h" @@ -39,45 +40,56 @@ EmitVertexPointSize::~EmitVertexPointSize() = default; Transform::Output EmitVertexPointSize::Run(ast::Module* in) { Output out; - out.module = in->Clone(); - auto* mod = &out.module; - if (!mod->HasStage(ast::PipelineStage::kVertex)) { + if (!in->HasStage(ast::PipelineStage::kVertex)) { // If the module doesn't have any vertex stages, then there's nothing to do. + out.module = in->Clone(); return out; } - auto* f32 = mod->create(); + tint::ast::AssignmentStatement* pointsize_assign = nullptr; + auto get_pointsize_assign = [&pointsize_assign](ast::Module* mod) { + if (pointsize_assign != nullptr) { + return pointsize_assign; + } - // Declare the pointsize builtin output variable. - auto* pointsize_var = - mod->create(Source{}, // source - kPointSizeVar, // name - ast::StorageClass::kOutput, // storage_class - f32, // type - false, // is_const - nullptr, // constructor - ast::VariableDecorationList{ - // decorations - mod->create( - ast::Builtin::kPointSize, Source{}), - }); - mod->AddGlobalVariable(pointsize_var); + auto* f32 = mod->create(); - // Build the AST expression & statement for assigning pointsize one. - auto* one = mod->create( - Source{}, mod->create(Source{}, f32, 1.0f)); - auto* pointsize_ident = mod->create( - Source{}, mod->RegisterSymbol(kPointSizeVar), kPointSizeVar); - auto* pointsize_assign = - mod->create(Source{}, pointsize_ident, one); + // Declare the pointsize builtin output variable. + auto* pointsize_var = + mod->create(Source{}, // source + kPointSizeVar, // name + ast::StorageClass::kOutput, // storage_class + f32, // type + false, // is_const + nullptr, // constructor + ast::VariableDecorationList{ + // decorations + mod->create( + ast::Builtin::kPointSize, Source{}), + }); + mod->AddGlobalVariable(pointsize_var); + + // Build the AST expression & statement for assigning pointsize one. + auto* one = mod->create( + Source{}, mod->create(Source{}, f32, 1.0f)); + auto* pointsize_ident = mod->create( + Source{}, mod->RegisterSymbol(kPointSizeVar), kPointSizeVar); + pointsize_assign = + mod->create(Source{}, pointsize_ident, one); + return pointsize_assign; + }; // Add the pointsize assignment statement to the front of all vertex stages. - for (auto* func : mod->functions()) { - if (func->pipeline_stage() == ast::PipelineStage::kVertex) { - func->body()->insert(0, pointsize_assign); - } - } + out.module = in->Clone([&](ast::CloneContext* ctx) { + ctx->ReplaceAll([&, ctx](ast::Function* func) -> ast::Function* { + if (func->pipeline_stage() != ast::PipelineStage::kVertex) { + return nullptr; // Just clone func + } + return CloneWithStatementsAtStart(ctx, func, + {get_pointsize_assign(ctx->mod)}); + }); + }); return out; } diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc index 97c7670c72..807a2c6924 100644 --- a/src/transform/first_index_offset.cc +++ b/src/transform/first_index_offset.cc @@ -112,70 +112,63 @@ Transform::Output FirstIndexOffset::Run(ast::Module* in) { std::string vertex_index_name; std::string instance_index_name; - Output out; - // Lazilly construct the UniformBuffer on first call to // maybe_create_buffer_var() ast::Variable* buffer_var = nullptr; - auto maybe_create_buffer_var = [&] { + auto maybe_create_buffer_var = [&](ast::Module* mod) { if (buffer_var == nullptr) { - buffer_var = AddUniformBuffer(&out.module); + buffer_var = AddUniformBuffer(mod); } }; // Clone the AST, renaming the kVertexIdx and kInstanceIdx builtins, and add // a CreateFirstIndexOffset() statement to each function that uses one of // these builtins. - ast::CloneContext ctx(&out.module); - ctx.ReplaceAll([&](ast::Variable* var) -> ast::Variable* { - for (ast::VariableDecoration* dec : var->decorations()) { - if (auto* blt_dec = dec->As()) { - ast::Builtin blt_type = blt_dec->value(); - if (blt_type == ast::Builtin::kVertexIdx) { - vertex_index_name = var->name(); - has_vertex_index_ = true; - return clone_variable_with_new_name(&ctx, var, - kIndexOffsetPrefix + var->name()); - } else if (blt_type == ast::Builtin::kInstanceIdx) { - instance_index_name = var->name(); - has_instance_index_ = true; - return clone_variable_with_new_name(&ctx, var, - kIndexOffsetPrefix + var->name()); - } - } - } - return nullptr; // Just clone var - }); - ctx.ReplaceAll( // Note: This happens in the same pass as the rename above - // which determines the original builtin variable names, - // but this should be fine, as variables are cloned first. - [&](ast::Function* func) -> ast::Function* { - maybe_create_buffer_var(); - if (buffer_var == nullptr) { - return nullptr; // no transform need, just clone func - } - ast::StatementList statements; - for (const auto& data : func->local_referenced_builtin_variables()) { - if (data.second->value() == ast::Builtin::kVertexIdx) { - statements.emplace_back(CreateFirstIndexOffset( - vertex_index_name, kFirstVertexName, buffer_var, ctx.mod)); - } else if (data.second->value() == ast::Builtin::kInstanceIdx) { - statements.emplace_back(CreateFirstIndexOffset( - instance_index_name, kFirstInstanceName, buffer_var, ctx.mod)); + + Output out; + out.module = in->Clone([&](ast::CloneContext* ctx) { + ctx->ReplaceAll([&, ctx](ast::Variable* var) -> ast::Variable* { + for (ast::VariableDecoration* dec : var->decorations()) { + if (auto* blt_dec = dec->As()) { + ast::Builtin blt_type = blt_dec->value(); + if (blt_type == ast::Builtin::kVertexIdx) { + vertex_index_name = var->name(); + has_vertex_index_ = true; + return clone_variable_with_new_name( + ctx, var, kIndexOffsetPrefix + var->name()); + } else if (blt_type == ast::Builtin::kInstanceIdx) { + instance_index_name = var->name(); + has_instance_index_ = true; + return clone_variable_with_new_name( + ctx, var, kIndexOffsetPrefix + var->name()); } } - for (auto* s : *func->body()) { - statements.emplace_back(ctx.Clone(s)); - } - return ctx.mod->create( - ctx.Clone(func->source()), func->symbol(), func->name(), - ctx.Clone(func->params()), ctx.Clone(func->return_type()), - ctx.mod->create( - ctx.Clone(func->body()->source()), statements), - ctx.Clone(func->decorations())); - }); + } + return nullptr; // Just clone var + }); + ctx->ReplaceAll( // Note: This happens in the same pass as the rename above + // which determines the original builtin variable names, + // but this should be fine, as variables are cloned first. + [&, ctx](ast::Function* func) -> ast::Function* { + maybe_create_buffer_var(ctx->mod); + if (buffer_var == nullptr) { + return nullptr; // no transform need, just clone func + } + ast::StatementList statements; + for (const auto& data : func->local_referenced_builtin_variables()) { + if (data.second->value() == ast::Builtin::kVertexIdx) { + statements.emplace_back(CreateFirstIndexOffset( + vertex_index_name, kFirstVertexName, buffer_var, ctx->mod)); + } else if (data.second->value() == ast::Builtin::kInstanceIdx) { + statements.emplace_back(CreateFirstIndexOffset( + instance_index_name, kFirstInstanceName, buffer_var, + ctx->mod)); + } + } + return CloneWithStatementsAtStart(ctx, func, statements); + }); + }); - in->Clone(&ctx); return out; } diff --git a/src/transform/transform.cc b/src/transform/transform.cc index f6bdfd23c8..a03b94394d 100644 --- a/src/transform/transform.cc +++ b/src/transform/transform.cc @@ -14,11 +14,30 @@ #include "src/transform/transform.h" +#include "src/ast/block_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/function.h" + namespace tint { namespace transform { Transform::Transform() = default; Transform::~Transform() = default; +ast::Function* Transform::CloneWithStatementsAtStart( + ast::CloneContext* ctx, + ast::Function* in, + ast::StatementList statements) { + for (auto* s : *in->body()) { + statements.emplace_back(ctx->Clone(s)); + } + return ctx->mod->create( + ctx->Clone(in->source()), in->symbol(), in->name(), + ctx->Clone(in->params()), ctx->Clone(in->return_type()), + ctx->mod->create(ctx->Clone(in->body()->source()), + statements), + ctx->Clone(in->decorations())); +} + } // namespace transform } // namespace tint diff --git a/src/transform/transform.h b/src/transform/transform.h index 2a68467114..211be8a9a4 100644 --- a/src/transform/transform.h +++ b/src/transform/transform.h @@ -48,6 +48,18 @@ class Transform { /// @param module the source module to transform /// @returns the transformation result virtual Output Run(ast::Module* module) = 0; + + protected: + /// Clones the function `in` adding `statements` to the beginning of the + /// cloned function body. + /// @param ctx the clone context + /// @param in the function to clone + /// @param statements the statements to prepend to `in`'s body + /// @return the cloned function + static ast::Function* CloneWithStatementsAtStart( + ast::CloneContext* ctx, + ast::Function* in, + ast::StatementList statements); }; } // namespace transform diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc index 163d875e61..74feab3f5c 100644 --- a/src/transform/vertex_pulling.cc +++ b/src/transform/vertex_pulling.cc @@ -20,6 +20,7 @@ #include "src/ast/assignment_statement.h" #include "src/ast/binary_expression.h" #include "src/ast/bitcast_expression.h" +#include "src/ast/clone_context.h" #include "src/ast/member_accessor_expression.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/stride_decoration.h" @@ -69,27 +70,24 @@ void VertexPulling::SetPullingBufferBindingSet(uint32_t number) { } Transform::Output VertexPulling::Run(ast::Module* in) { - Output out; - out.module = in->Clone(); - - ast::Module* mod = &out.module; - // Check SetVertexState was called if (!cfg.vertex_state_set) { diag::Diagnostic err; err.severity = diag::Severity::Error; err.message = "SetVertexState not called"; + Output out; out.diagnostics.add(std::move(err)); return out; } // Find entry point - auto* func = mod->FindFunctionBySymbolAndStage( - mod->GetSymbol(cfg.entry_point_name), ast::PipelineStage::kVertex); + auto* func = in->FindFunctionBySymbolAndStage( + in->GetSymbol(cfg.entry_point_name), ast::PipelineStage::kVertex); if (func == nullptr) { diag::Diagnostic err; err.severity = diag::Severity::Error; err.message = "Vertex stage entry point not found"; + Output out; out.diagnostics.add(std::move(err)); return out; } @@ -99,13 +97,22 @@ Transform::Output VertexPulling::Run(ast::Module* in) { // TODO(idanr): Make sure we covered all error cases, to guarantee the // following stages will pass + Output out; + out.module = in->Clone([&](ast::CloneContext* ctx) { + State state{in, ctx->mod, cfg}; + state.FindOrInsertVertexIndexIfUsed(); + state.FindOrInsertInstanceIndexIfUsed(); + state.ConvertVertexInputVariablesToPrivate(); + state.AddVertexStorageBuffers(); - State state{mod, cfg}; - state.FindOrInsertVertexIndexIfUsed(); - state.FindOrInsertInstanceIndexIfUsed(); - state.ConvertVertexInputVariablesToPrivate(); - state.AddVertexStorageBuffers(); - func->body()->insert(0, state.CreateVertexPullingPreamble()); + ctx->ReplaceAll([func, ctx, state](ast::Function* f) -> ast::Function* { + if (f == func) { + return CloneWithStatementsAtStart( + ctx, f, {state.CreateVertexPullingPreamble()}); + } + return nullptr; // Just clone func + }); + }); return out; } @@ -114,11 +121,14 @@ VertexPulling::Config::Config() = default; VertexPulling::Config::Config(const Config&) = default; VertexPulling::Config::~Config() = default; -VertexPulling::State::State(ast::Module* m, const Config& c) : mod(m), cfg(c) {} +VertexPulling::State::State(ast::Module* i, ast::Module* o, const Config& c) + : in(i), out(o), cfg(c) {} + +VertexPulling::State::State(const State&) = default; VertexPulling::State::~State() = default; -std::string VertexPulling::State::GetVertexBufferName(uint32_t index) { +std::string VertexPulling::State::GetVertexBufferName(uint32_t index) const { return kVertexBufferNamePrefix + std::to_string(index); } @@ -135,7 +145,7 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() { } // Look for an existing vertex index builtin - for (auto* v : mod->global_variables()) { + for (auto* v : in->global_variables()) { if (v->storage_class() != ast::StorageClass::kInput) { continue; } @@ -154,7 +164,7 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() { vertex_index_name = kDefaultVertexIndexName; auto* var = - mod->create(Source{}, // source + out->create(Source{}, // source vertex_index_name, // name ast::StorageClass::kInput, // storage_class GetI32Type(), // type @@ -162,11 +172,11 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() { nullptr, // constructor ast::VariableDecorationList{ // decorations - mod->create( + out->create( ast::Builtin::kVertexIdx, Source{}), }); - mod->AddGlobalVariable(var); + out->AddGlobalVariable(var); } void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() { @@ -182,7 +192,7 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() { } // Look for an existing instance index builtin - for (auto* v : mod->global_variables()) { + for (auto* v : in->global_variables()) { if (v->storage_class() != ast::StorageClass::kInput) { continue; } @@ -201,7 +211,7 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() { instance_index_name = kDefaultInstanceIndexName; auto* var = - mod->create(Source{}, // source + out->create(Source{}, // source instance_index_name, // name ast::StorageClass::kInput, // storage_class GetI32Type(), // type @@ -209,14 +219,14 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() { nullptr, // constructor ast::VariableDecorationList{ // decorations - mod->create( + out->create( ast::Builtin::kInstanceIdx, Source{}), }); - mod->AddGlobalVariable(var); + out->AddGlobalVariable(var); } void VertexPulling::State::ConvertVertexInputVariablesToPrivate() { - for (auto*& v : mod->global_variables()) { + for (auto*& v : in->global_variables()) { if (v->storage_class() != ast::StorageClass::kInput) { continue; } @@ -227,7 +237,7 @@ void VertexPulling::State::ConvertVertexInputVariablesToPrivate() { // This is where the replacement happens. Expressions use identifier // strings instead of pointers, so we don't need to update any other // place in the AST. - v = mod->create( + v = out->create( Source{}, // source v->name(), // name ast::StorageClass::kPrivate, // storage_class @@ -245,31 +255,31 @@ void VertexPulling::State::ConvertVertexInputVariablesToPrivate() { void VertexPulling::State::AddVertexStorageBuffers() { // TODO(idanr): Make this readonly https://github.com/gpuweb/gpuweb/issues/935 // The array inside the struct definition - auto* internal_array_type = mod->create( + auto* internal_array_type = out->create( GetU32Type(), 0, ast::ArrayDecorationList{ - mod->create(4u, Source{}), + out->create(4u, Source{}), }); // Creating the struct type ast::StructMemberList members; ast::StructMemberDecorationList member_dec; member_dec.push_back( - mod->create(0u, Source{})); + out->create(0u, Source{})); - members.push_back(mod->create( + members.push_back(out->create( Source{}, kStructBufferName, internal_array_type, std::move(member_dec))); ast::StructDecorationList decos; - decos.push_back(mod->create(Source{})); + decos.push_back(out->create(Source{})); - auto* struct_type = mod->create( - mod->RegisterSymbol(kStructName), kStructName, - mod->create(Source{}, std::move(members), std::move(decos))); + auto* struct_type = out->create( + out->RegisterSymbol(kStructName), kStructName, + out->create(Source{}, std::move(members), std::move(decos))); for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) { // The decorated variable with struct type - auto* var = mod->create( + auto* var = out->create( Source{}, // source GetVertexBufferName(i), // name ast::StorageClass::kStorageBuffer, // storage_class @@ -278,23 +288,23 @@ void VertexPulling::State::AddVertexStorageBuffers() { nullptr, // constructor ast::VariableDecorationList{ // decorations - mod->create(i, Source{}), - mod->create(cfg.pulling_set, Source{}), + out->create(i, Source{}), + out->create(cfg.pulling_set, Source{}), }); - mod->AddGlobalVariable(var); + out->AddGlobalVariable(var); } - mod->AddConstructedType(struct_type); + out->AddConstructedType(struct_type); } -ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() { +ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const { // Assign by looking at the vertex descriptor to find attributes with matching // location. ast::StatementList stmts; // Declare the |kPullingPosVarName| variable in the shader - auto* pos_declaration = mod->create( - Source{}, mod->create( + auto* pos_declaration = out->create( + Source{}, out->create( Source{}, // source kPullingPosVarName, // name ast::StorageClass::kFunction, // storage_class @@ -323,45 +333,46 @@ ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() { ? vertex_index_name : instance_index_name; // Identifier to index by - auto* index_identifier = mod->create( - Source{}, mod->RegisterSymbol(name), name); + auto* index_identifier = out->create( + Source{}, out->RegisterSymbol(name), name); // An expression for the start of the read in the buffer in bytes - auto* pos_value = mod->create( + auto* pos_value = out->create( Source{}, ast::BinaryOp::kAdd, - mod->create( + out->create( Source{}, ast::BinaryOp::kMultiply, index_identifier, GenUint(static_cast(buffer_layout.array_stride))), GenUint(static_cast(attribute_desc.offset))); // Update position of the read - auto* set_pos_expr = mod->create( + auto* set_pos_expr = out->create( Source{}, CreatePullingPositionIdent(), pos_value); stmts.emplace_back(set_pos_expr); - stmts.emplace_back(mod->create( + stmts.emplace_back(out->create( Source{}, - mod->create( - Source{}, mod->RegisterSymbol(v->name()), v->name()), + out->create( + Source{}, out->RegisterSymbol(v->name()), v->name()), AccessByFormat(i, attribute_desc.format))); } } - return mod->create(Source{}, stmts); + return out->create(Source{}, stmts); } -ast::Expression* VertexPulling::State::GenUint(uint32_t value) { - return mod->create( - Source{}, mod->create(Source{}, GetU32Type(), value)); +ast::Expression* VertexPulling::State::GenUint(uint32_t value) const { + return out->create( + Source{}, out->create(Source{}, GetU32Type(), value)); } -ast::Expression* VertexPulling::State::CreatePullingPositionIdent() { - return mod->create( - Source{}, mod->RegisterSymbol(kPullingPosVarName), kPullingPosVarName); +ast::Expression* VertexPulling::State::CreatePullingPositionIdent() const { + return out->create( + Source{}, out->RegisterSymbol(kPullingPosVarName), kPullingPosVarName); } -ast::Expression* VertexPulling::State::AccessByFormat(uint32_t buffer, - VertexFormat format) { +ast::Expression* VertexPulling::State::AccessByFormat( + uint32_t buffer, + VertexFormat format) const { // TODO(idanr): this doesn't account for the format of the attribute in the // shader. ex: vec in shader, and attribute claims VertexFormat::Float4 // right now, we would try to assign a vec4 to this attribute, but we @@ -388,43 +399,44 @@ ast::Expression* VertexPulling::State::AccessByFormat(uint32_t buffer, } ast::Expression* VertexPulling::State::AccessU32(uint32_t buffer, - ast::Expression* pos) { + ast::Expression* pos) const { // Here we divide by 4, since the buffer is uint32 not uint8. The input buffer // has byte offsets for each attribute, and we will convert it to u32 indexes // by dividing. Then, that element is going to be read, and if needed, // unpacked into an appropriate variable. All reads should end up here as a // base case. auto vbuf_name = GetVertexBufferName(buffer); - return mod->create( + return out->create( Source{}, - mod->create( + out->create( Source{}, - mod->create( - Source{}, mod->RegisterSymbol(vbuf_name), vbuf_name), - mod->create( - Source{}, mod->RegisterSymbol(kStructBufferName), + out->create( + Source{}, out->RegisterSymbol(vbuf_name), vbuf_name), + out->create( + Source{}, out->RegisterSymbol(kStructBufferName), kStructBufferName)), - mod->create(Source{}, ast::BinaryOp::kDivide, pos, + out->create(Source{}, ast::BinaryOp::kDivide, pos, GenUint(4))); } ast::Expression* VertexPulling::State::AccessI32(uint32_t buffer, - ast::Expression* pos) { + ast::Expression* pos) const { // as reinterprets bits - return mod->create(Source{}, GetI32Type(), + return out->create(Source{}, GetI32Type(), AccessU32(buffer, pos)); } ast::Expression* VertexPulling::State::AccessF32(uint32_t buffer, - ast::Expression* pos) { + ast::Expression* pos) const { // as reinterprets bits - return mod->create(Source{}, GetF32Type(), + return out->create(Source{}, GetF32Type(), AccessU32(buffer, pos)); } -ast::Expression* VertexPulling::State::AccessPrimitive(uint32_t buffer, - ast::Expression* pos, - VertexFormat format) { +ast::Expression* VertexPulling::State::AccessPrimitive( + uint32_t buffer, + ast::Expression* pos, + VertexFormat format) const { // This function uses a position expression to read, rather than using the // position variable. This allows us to read from offset positions relative to // |kPullingPosVarName|. We can't call AccessByFormat because it reads only @@ -445,31 +457,31 @@ ast::Expression* VertexPulling::State::AccessVec(uint32_t buffer, uint32_t element_stride, ast::type::Type* base_type, VertexFormat base_format, - uint32_t count) { + uint32_t count) const { ast::ExpressionList expr_list; for (uint32_t i = 0; i < count; ++i) { // Offset read position by element_stride for each component - auto* cur_pos = mod->create( + auto* cur_pos = out->create( Source{}, ast::BinaryOp::kAdd, CreatePullingPositionIdent(), GenUint(element_stride * i)); expr_list.push_back(AccessPrimitive(buffer, cur_pos, base_format)); } - return mod->create( - Source{}, mod->create(base_type, count), + return out->create( + Source{}, out->create(base_type, count), std::move(expr_list)); } -ast::type::Type* VertexPulling::State::GetU32Type() { - return mod->create(); +ast::type::Type* VertexPulling::State::GetU32Type() const { + return out->create(); } -ast::type::Type* VertexPulling::State::GetI32Type() { - return mod->create(); +ast::type::Type* VertexPulling::State::GetI32Type() const { + return out->create(); } -ast::type::Type* VertexPulling::State::GetF32Type() { - return mod->create(); +ast::type::Type* VertexPulling::State::GetF32Type() const { + return out->create(); } VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default; diff --git a/src/transform/vertex_pulling.h b/src/transform/vertex_pulling.h index 4f55b1fef4..6d4f23c1b8 100644 --- a/src/transform/vertex_pulling.h +++ b/src/transform/vertex_pulling.h @@ -178,12 +178,13 @@ class VertexPulling : public Transform { Config cfg; struct State { - State(ast::Module* m, const Config& c); + State(ast::Module* in, ast::Module* out, const Config& c); + explicit State(const State&); ~State(); /// Generate the vertex buffer binding name /// @param index index to append to buffer name - std::string GetVertexBufferName(uint32_t index); + std::string GetVertexBufferName(uint32_t index) const; /// Inserts vertex_idx binding, or finds the existing one void FindOrInsertVertexIndexIfUsed(); @@ -198,36 +199,36 @@ class VertexPulling : public Transform { void AddVertexStorageBuffers(); /// Creates and returns the assignment to the variables from the buffers - ast::BlockStatement* CreateVertexPullingPreamble(); + ast::BlockStatement* CreateVertexPullingPreamble() const; /// Generates an expression holding a constant uint /// @param value uint value - ast::Expression* GenUint(uint32_t value); + ast::Expression* GenUint(uint32_t value) const; /// Generates an expression to read the shader value `kPullingPosVarName` - ast::Expression* CreatePullingPositionIdent(); + ast::Expression* CreatePullingPositionIdent() const; /// Generates an expression reading from a buffer a specific format. /// This reads the value wherever `kPullingPosVarName` points to at the time /// of the read. /// @param buffer the index of the vertex buffer /// @param format the format to read - ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format); + ast::Expression* AccessByFormat(uint32_t buffer, VertexFormat format) const; /// Generates an expression reading a uint32 from a vertex buffer /// @param buffer the index of the vertex buffer /// @param pos an expression for the position of the access, in bytes - ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos); + ast::Expression* AccessU32(uint32_t buffer, ast::Expression* pos) const; /// Generates an expression reading an int32 from a vertex buffer /// @param buffer the index of the vertex buffer /// @param pos an expression for the position of the access, in bytes - ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos); + ast::Expression* AccessI32(uint32_t buffer, ast::Expression* pos) const; /// Generates an expression reading a float from a vertex buffer /// @param buffer the index of the vertex buffer /// @param pos an expression for the position of the access, in bytes - ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos); + ast::Expression* AccessF32(uint32_t buffer, ast::Expression* pos) const; /// Generates an expression reading a basic type (u32, i32, f32) from a /// vertex buffer @@ -236,7 +237,7 @@ class VertexPulling : public Transform { /// @param format the underlying vertex format ast::Expression* AccessPrimitive(uint32_t buffer, ast::Expression* pos, - VertexFormat format); + VertexFormat format) const; /// Generates an expression reading a vec2/3/4 from a vertex buffer. /// This reads the value wherever `kPullingPosVarName` points to at the time @@ -250,14 +251,15 @@ class VertexPulling : public Transform { uint32_t element_stride, ast::type::Type* base_type, VertexFormat base_format, - uint32_t count); + uint32_t count) const; // Used to grab corresponding types from the type manager - ast::type::Type* GetU32Type(); - ast::type::Type* GetI32Type(); - ast::type::Type* GetF32Type(); + ast::type::Type* GetU32Type() const; + ast::type::Type* GetI32Type() const; + ast::type::Type* GetF32Type() const; - ast::Module* const mod; + ast::Module* const in; + ast::Module* const out; Config const cfg; std::unordered_map location_to_var; diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc index 30145ce9c1..818853f356 100644 --- a/src/transform/vertex_pulling_test.cc +++ b/src/transform/vertex_pulling_test.cc @@ -175,11 +175,6 @@ TEST_F(VertexPullingTest, OneAttribute) { [[block]] StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4} } - Variable{ - var_a - private - __f32 - } Variable{ Decorations{ BuiltinDecoration{vertex_idx} @@ -197,6 +192,11 @@ TEST_F(VertexPullingTest, OneAttribute) { storage_buffer __struct_TintVertexData } + Variable{ + var_a + private + __f32 + } Function main -> __void StageDecoration{vertex} () @@ -262,11 +262,6 @@ TEST_F(VertexPullingTest, OneInstancedAttribute) { [[block]] StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4} } - Variable{ - var_a - private - __f32 - } Variable{ Decorations{ BuiltinDecoration{instance_idx} @@ -284,6 +279,11 @@ TEST_F(VertexPullingTest, OneInstancedAttribute) { storage_buffer __struct_TintVertexData } + Variable{ + var_a + private + __f32 + } Function main -> __void StageDecoration{vertex} () @@ -349,11 +349,6 @@ TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) { [[block]] StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4} } - Variable{ - var_a - private - __f32 - } Variable{ Decorations{ BuiltinDecoration{vertex_idx} @@ -371,6 +366,11 @@ TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) { storage_buffer __struct_TintVertexData } + Variable{ + var_a + private + __f32 + } Function main -> __void StageDecoration{vertex} () @@ -465,6 +465,24 @@ TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) { [[block]] StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4} } + Variable{ + Decorations{ + BindingDecoration{0} + SetDecoration{4} + } + _tint_pulling_vertex_buffer_0 + storage_buffer + __struct_TintVertexData + } + Variable{ + Decorations{ + BindingDecoration{1} + SetDecoration{4} + } + _tint_pulling_vertex_buffer_1 + storage_buffer + __struct_TintVertexData + } Variable{ var_a private @@ -491,24 +509,6 @@ TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) { in __i32 } - Variable{ - Decorations{ - BindingDecoration{0} - SetDecoration{4} - } - _tint_pulling_vertex_buffer_0 - storage_buffer - __struct_TintVertexData - } - Variable{ - Decorations{ - BindingDecoration{1} - SetDecoration{4} - } - _tint_pulling_vertex_buffer_1 - storage_buffer - __struct_TintVertexData - } Function main -> __void StageDecoration{vertex} () @@ -607,16 +607,6 @@ TEST_F(VertexPullingTest, TwoAttributesSameBuffer) { [[block]] StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4} } - Variable{ - var_a - private - __f32 - } - Variable{ - var_b - private - __array__f32_4 - } Variable{ Decorations{ BuiltinDecoration{vertex_idx} @@ -634,6 +624,16 @@ TEST_F(VertexPullingTest, TwoAttributesSameBuffer) { storage_buffer __struct_TintVertexData } + Variable{ + var_a + private + __f32 + } + Variable{ + var_b + private + __array__f32_4 + } Function main -> __void StageDecoration{vertex} () @@ -794,21 +794,6 @@ TEST_F(VertexPullingTest, FloatVectorAttributes) { [[block]] StructMember{[[ offset 0 ]] _tint_vertex_data: __array__u32_stride_4} } - Variable{ - var_a - private - __array__f32_2 - } - Variable{ - var_b - private - __array__f32_3 - } - Variable{ - var_c - private - __array__f32_4 - } Variable{ Decorations{ BuiltinDecoration{vertex_idx} @@ -844,6 +829,21 @@ TEST_F(VertexPullingTest, FloatVectorAttributes) { storage_buffer __struct_TintVertexData } + Variable{ + var_a + private + __array__f32_2 + } + Variable{ + var_b + private + __array__f32_3 + } + Variable{ + var_c + private + __array__f32_4 + } Function main -> __void StageDecoration{vertex} ()