diff --git a/src/ast/module.h b/src/ast/module.h index 6b08b4f063..59a67b0d50 100644 --- a/src/ast/module.h +++ b/src/ast/module.h @@ -56,6 +56,10 @@ class Module { Module Clone(); /// Clone this module into `ctx->mod` using the provided CloneContext + /// The module will be cloned in this order: + /// * Constructed types + /// * Global variables + /// * Functions /// @param ctx the clone context void Clone(CloneContext* ctx); diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc index fcca0c66d4..37014d0cd9 100644 --- a/src/transform/first_index_offset.cc +++ b/src/transform/first_index_offset.cc @@ -25,6 +25,7 @@ #include "src/ast/builtin_decoration.h" #include "src/ast/call_statement.h" #include "src/ast/case_statement.h" +#include "src/ast/clone_context.h" #include "src/ast/constructor_expression.h" #include "src/ast/decorated_variable.h" #include "src/ast/else_statement.h" @@ -61,6 +62,20 @@ constexpr char kFirstVertexName[] = "tint_first_vertex_index"; constexpr char kFirstInstanceName[] = "tint_first_instance_index"; constexpr char kIndexOffsetPrefix[] = "tint_first_index_offset_"; +ast::DecoratedVariable* clone_variable_with_new_name(ast::CloneContext* ctx, + ast::DecoratedVariable* in, + std::string new_name) { + auto* var = ctx->mod->create(ctx->Clone(in->source()), + new_name, in->storage_class(), + ctx->Clone(in->type())); + var->set_is_const(in->is_const()); + var->set_constructor(ctx->Clone(in->constructor())); + + auto* out = ctx->mod->create(var); + out->set_decorations(ctx->Clone(in->decorations())); + return out; +} + } // namespace FirstIndexOffset::FirstIndexOffset(uint32_t binding, uint32_t set) @@ -69,17 +84,29 @@ FirstIndexOffset::FirstIndexOffset(uint32_t binding, uint32_t set) FirstIndexOffset::~FirstIndexOffset() = default; Transform::Output FirstIndexOffset::Run(ast::Module* in) { - Output out; - out.module = in->Clone(); - auto* mod = &out.module; + // First do a quick check to see if the transform has already been applied. + for (ast::Variable* var : in->global_variables()) { + if (auto* dec_var = var->As()) { + if (dec_var->name() == kBufferName) { + diag::Diagnostic err; + err.message = "First index offset transform has already been applied."; + err.severity = diag::Severity::Error; + Output out; + out.diagnostics.add(std::move(err)); + return out; + } + } + } // Running TypeDeterminer as we require local_referenced_builtin_variables() - // to be populated - TypeDeterminer td(mod); + // to be populated. TODO(bclayton) - it should not be necessary to re-run the + // type determiner if semantic information is already generated. Remove. + TypeDeterminer td(in); if (!td.Determine()) { diag::Diagnostic err; err.severity = diag::Severity::Error; err.message = td.error(); + Output out; out.diagnostics.add(std::move(err)); return out; } @@ -87,51 +114,69 @@ Transform::Output FirstIndexOffset::Run(ast::Module* in) { std::string vertex_index_name; std::string instance_index_name; - for (ast::Variable* var : mod->global_variables()) { - if (auto* dec_var = var->As()) { - if (dec_var->name() == kBufferName) { - diag::Diagnostic err; - err.message = "First index offset transform has already been applied."; - err.severity = diag::Severity::Error; - out.diagnostics.add(std::move(err)); - return out; - } + Output out; - for (ast::VariableDecoration* dec : dec_var->decorations()) { - if (auto* blt_dec = dec->As()) { - ast::Builtin blt_type = blt_dec->value(); - if (blt_type == ast::Builtin::kVertexIdx) { - vertex_index_name = var->name(); - var->set_name(kIndexOffsetPrefix + var->name()); - has_vertex_index_ = true; - } else if (blt_type == ast::Builtin::kInstanceIdx) { - instance_index_name = var->name(); - var->set_name(kIndexOffsetPrefix + var->name()); - has_instance_index_ = true; - } + // Lazilly construct the UniformBuffer on first call to + // maybe_create_buffer_var() + ast::Variable* buffer_var = nullptr; + auto maybe_create_buffer_var = [&] { + if (buffer_var == nullptr) { + buffer_var = AddUniformBuffer(&out.module); + } + }; + + // Clone the AST, renaming the kVertexIdx and kInstanceIdx builtins, and add + // a CreateFirstIndexOffset() statement to each function that uses one of + // these builtins. + ast::CloneContext ctx(&out.module); + ctx.ReplaceAll([&](ast::DecoratedVariable* var) -> ast::DecoratedVariable* { + for (ast::VariableDecoration* dec : var->decorations()) { + if (auto* blt_dec = dec->As()) { + ast::Builtin blt_type = blt_dec->value(); + if (blt_type == ast::Builtin::kVertexIdx) { + vertex_index_name = var->name(); + has_vertex_index_ = true; + return clone_variable_with_new_name(&ctx, var, + kIndexOffsetPrefix + var->name()); + } else if (blt_type == ast::Builtin::kInstanceIdx) { + instance_index_name = var->name(); + has_instance_index_ = true; + return clone_variable_with_new_name(&ctx, var, + kIndexOffsetPrefix + var->name()); } } } - } - - if (!has_vertex_index_ && !has_instance_index_) { - return out; - } - - ast::Variable* buffer_var = AddUniformBuffer(mod); - - for (ast::Function* func : mod->functions()) { - for (const auto& data : func->local_referenced_builtin_variables()) { - if (data.second->value() == ast::Builtin::kVertexIdx) { - AddFirstIndexOffset(vertex_index_name, kFirstVertexName, buffer_var, - func, mod); - } else if (data.second->value() == ast::Builtin::kInstanceIdx) { - AddFirstIndexOffset(instance_index_name, kFirstInstanceName, buffer_var, - func, mod); - } - } - } + return nullptr; // Just clone var + }); + ctx.ReplaceAll( // Note: This happens in the same pass as the rename above + // which determines the original builtin variable names, + // but this should be fine, as variables are cloned first. + [&](ast::Function* func) -> ast::Function* { + maybe_create_buffer_var(); + if (buffer_var == nullptr) { + return nullptr; // no transform need, just clone func + } + auto* body = ctx.mod->create( + ctx.Clone(func->body()->source())); + for (const auto& data : func->local_referenced_builtin_variables()) { + if (data.second->value() == ast::Builtin::kVertexIdx) { + body->append(CreateFirstIndexOffset( + vertex_index_name, kFirstVertexName, buffer_var, ctx.mod)); + } else if (data.second->value() == ast::Builtin::kInstanceIdx) { + body->append(CreateFirstIndexOffset( + instance_index_name, kFirstInstanceName, buffer_var, ctx.mod)); + } + } + for (auto* s : *func->body()) { + body->append(ctx.Clone(s)); + } + return ctx.mod->create( + ctx.Clone(func->source()), func->name(), ctx.Clone(func->params()), + ctx.Clone(func->return_type()), ctx.Clone(body), + ctx.Clone(func->decorations())); + }); + in->Clone(&ctx); return out; } @@ -187,12 +232,10 @@ ast::Variable* FirstIndexOffset::AddUniformBuffer(ast::Module* mod) { auto* idx_var = mod->create(mod->create( Source{}, kBufferName, ast::StorageClass::kUniform, struct_type)); - - ast::VariableDecorationList decorations; - decorations.push_back( - mod->create(binding_, Source{})); - decorations.push_back(mod->create(set_, Source{})); - idx_var->set_decorations(std::move(decorations)); + idx_var->set_decorations({ + mod->create(binding_, Source{}), + mod->create(set_, Source{}), + }); mod->AddGlobalVariable(idx_var); @@ -201,11 +244,11 @@ ast::Variable* FirstIndexOffset::AddUniformBuffer(ast::Module* mod) { return idx_var; } -void FirstIndexOffset::AddFirstIndexOffset(const std::string& original_name, - const std::string& field_name, - ast::Variable* buffer_var, - ast::Function* func, - ast::Module* mod) { +ast::VariableDeclStatement* FirstIndexOffset::CreateFirstIndexOffset( + const std::string& original_name, + const std::string& field_name, + ast::Variable* buffer_var, + ast::Module* mod) { auto* buffer = mod->create(buffer_var->name()); auto* var = mod->create(Source{}, original_name, ast::StorageClass::kNone, @@ -217,8 +260,7 @@ void FirstIndexOffset::AddFirstIndexOffset(const std::string& original_name, mod->create(kIndexOffsetPrefix + var->name()), mod->create( buffer, mod->create(field_name)))); - func->body()->insert(0, - mod->create(std::move(var))); + return mod->create(var); } } // namespace transform diff --git a/src/transform/first_index_offset.h b/src/transform/first_index_offset.h index 163bfbb090..873ffc822b 100644 --- a/src/transform/first_index_offset.h +++ b/src/transform/first_index_offset.h @@ -18,6 +18,7 @@ #include #include "src/ast/module.h" +#include "src/ast/variable_decl_statement.h" #include "src/transform/transform.h" namespace tint { @@ -94,12 +95,12 @@ class FirstIndexOffset : public Transform { /// @param original_name the name of the original builtin used in function /// @param field_name name of field in firstVertex/Instance buffer /// @param buffer_var variable of firstVertex/Instance buffer - /// @param func function to modify - void AddFirstIndexOffset(const std::string& original_name, - const std::string& field_name, - ast::Variable* buffer_var, - ast::Function* func, - ast::Module* module); + /// @param module the target module to contain the new ast nodes + ast::VariableDeclStatement* CreateFirstIndexOffset( + const std::string& original_name, + const std::string& field_name, + ast::Variable* buffer_var, + ast::Module* module); uint32_t binding_; uint32_t set_; diff --git a/src/transform/first_index_offset_test.cc b/src/transform/first_index_offset_test.cc index 3e20748574..b1dbdd0650 100644 --- a/src/transform/first_index_offset_test.cc +++ b/src/transform/first_index_offset_test.cc @@ -76,6 +76,8 @@ TEST_F(FirstIndexOffsetTest, Error_AlreadyTransformed) { struct Builder : public ModuleBuilder { void Build() override { AddBuiltinInput("vert_idx", ast::Builtin::kVertexIdx); + AddFunction("test")->body()->append(create( + Source{}, create("vert_idx"))); } }; @@ -106,15 +108,16 @@ TEST_F(FirstIndexOffsetTest, EmptyModule) { ASSERT_FALSE(result.diagnostics.contains_errors()) << diag::Formatter().format(result.diagnostics); - EXPECT_EQ("Module{\n}\n", result.module.to_str()); + auto got = result.module.to_str(); + auto* expected = "Module{\n}\n"; + EXPECT_EQ(got, expected); } TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) { struct Builder : public ModuleBuilder { void Build() override { AddBuiltinInput("vert_idx", ast::Builtin::kVertexIdx); - ast::Function* func = AddFunction("test"); - func->body()->append(create( + AddFunction("test")->body()->append(create( Source{}, create("vert_idx"))); } }; @@ -131,7 +134,9 @@ TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) { ASSERT_FALSE(result.diagnostics.contains_errors()) << diag::Formatter().format(result.diagnostics); - EXPECT_EQ(R"(Module{ + auto got = result.module.to_str(); + auto* expected = + R"(Module{ TintFirstIndexOffsetData Struct{ [[block]] StructMember{[[ offset 0 ]] tint_first_vertex_index: __u32} @@ -180,14 +185,16 @@ TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) { } } } -)", - result.module.to_str()); +)"; + EXPECT_EQ(got, expected); } TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) { struct Builder : public ModuleBuilder { void Build() override { AddBuiltinInput("inst_idx", ast::Builtin::kInstanceIdx); + AddFunction("test")->body()->append(create( + Source{}, create("inst_idx"))); } }; @@ -202,7 +209,9 @@ TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) { ASSERT_FALSE(result.diagnostics.contains_errors()) << diag::Formatter().format(result.diagnostics); - EXPECT_EQ(R"(Module{ + + auto got = result.module.to_str(); + auto* expected = R"(Module{ TintFirstIndexOffsetData Struct{ [[block]] StructMember{[[ offset 0 ]] tint_first_instance_index: __u32} @@ -224,9 +233,35 @@ TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) { uniform __struct_TintFirstIndexOffsetData } + Function test -> __u32 + () + { + VariableDeclStatement{ + VariableConst{ + inst_idx + none + __u32 + { + Binary[__u32]{ + Identifier[__ptr_in__u32]{tint_first_index_offset_inst_idx} + add + MemberAccessor[__ptr_uniform__u32]{ + Identifier[__ptr_uniform__struct_TintFirstIndexOffsetData]{tint_first_index_data} + Identifier[not set]{tint_first_instance_index} + } + } + } + } + } + Return{ + { + Identifier[__u32]{inst_idx} + } + } + } } -)", - result.module.to_str()); +)"; + EXPECT_EQ(got, expected); } TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) { @@ -234,6 +269,8 @@ TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) { void Build() override { AddBuiltinInput("inst_idx", ast::Builtin::kInstanceIdx); AddBuiltinInput("vert_idx", ast::Builtin::kVertexIdx); + AddFunction("test")->body()->append( + create(Source{}, Expr(1u))); } }; @@ -251,7 +288,9 @@ TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) { ASSERT_FALSE(result.diagnostics.contains_errors()) << diag::Formatter().format(result.diagnostics); - EXPECT_EQ(R"(Module{ + + auto got = result.module.to_str(); + auto* expected = R"(Module{ TintFirstIndexOffsetData Struct{ [[block]] StructMember{[[ offset 0 ]] tint_first_vertex_index: __u32} @@ -282,9 +321,18 @@ TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) { uniform __struct_TintFirstIndexOffsetData } + Function test -> __u32 + () + { + Return{ + { + ScalarConstructor[__u32]{1} + } + } + } } -)", - result.module.to_str()); +)"; + EXPECT_EQ(got, expected); EXPECT_TRUE(transform_ptr->HasVertexIndex()); EXPECT_EQ(transform_ptr->GetFirstVertexOffset(), 0u); @@ -321,7 +369,9 @@ TEST_F(FirstIndexOffsetTest, NestedCalls) { ASSERT_FALSE(result.diagnostics.contains_errors()) << diag::Formatter().format(result.diagnostics); - EXPECT_EQ(R"(Module{ + + auto got = result.module.to_str(); + auto* expected = R"(Module{ TintFirstIndexOffsetData Struct{ [[block]] StructMember{[[ offset 0 ]] tint_first_vertex_index: __u32} @@ -383,8 +433,8 @@ TEST_F(FirstIndexOffsetTest, NestedCalls) { } } } -)", - result.module.to_str()); +)"; + EXPECT_EQ(got, expected); } } // namespace