diff --git a/src/ast/block_statement.h b/src/ast/block_statement.h index c7c15877d1..f1be730374 100644 --- a/src/ast/block_statement.h +++ b/src/ast/block_statement.h @@ -45,6 +45,11 @@ class BlockStatement : public Statement { /// @returns the number of statements directly in the block size_t size() const { return statements_.size(); } + /// Retrieves the statement at |idx| + /// @param idx the index. The index is not bounds checked. + /// @returns the statement at |idx| + ast::Statement* get(size_t idx) { return statements_[idx].get(); } + /// Retrieves the statement at |idx| /// @param idx the index. The index is not bounds checked. /// @returns the statement at |idx| diff --git a/src/ast/function.cc b/src/ast/function.cc index 380d9a60d9..6fe5a3212b 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -29,7 +29,8 @@ Function::Function(const std::string& name, : Node(), name_(name), params_(std::move(params)), - return_type_(return_type) {} + return_type_(return_type), + body_(std::make_unique()) {} Function::Function(const Source& source, const std::string& name, @@ -38,7 +39,8 @@ Function::Function(const Source& source, : Node(source), name_(name), params_(std::move(params)), - return_type_(return_type) {} + return_type_(return_type), + body_(std::make_unique()) {} Function::Function(Function&&) = default; @@ -154,16 +156,20 @@ void Function::add_ancestor_entry_point(const std::string& ep) { ancestor_entry_points_.push_back(ep); } +void Function::set_body(StatementList body) { + for (auto& stmt : body) { + body_->append(std::move(stmt)); + } +} + bool Function::IsValid() const { for (const auto& param : params_) { if (param == nullptr || !param->IsValid()) return false; } - for (const auto& stmt : body_) { - if (stmt == nullptr || !stmt->IsValid()) - return false; + if (body_ == nullptr || !body_->IsValid()) { + return false; } - if (name_.length() == 0) { return false; } @@ -194,7 +200,7 @@ void Function::to_str(std::ostream& out, size_t indent) const { make_indent(out, indent); out << "{" << std::endl; - for (const auto& stmt : body_) + for (const auto& stmt : *body_) stmt->to_str(out, indent + 2); make_indent(out, indent); diff --git a/src/ast/function.h b/src/ast/function.h index b130ddc765..a6dbe1877f 100644 --- a/src/ast/function.h +++ b/src/ast/function.h @@ -22,6 +22,7 @@ #include #include "src/ast/binding_decoration.h" +#include "src/ast/block_statement.h" #include "src/ast/builtin_decoration.h" #include "src/ast/expression.h" #include "src/ast/location_decoration.h" @@ -123,9 +124,14 @@ class Function : public Node { /// Sets the body of the function /// @param body the function body - void set_body(StatementList body) { body_ = std::move(body); } + void set_body(StatementList body); + /// Sets the body of the function + /// @param body the function body + void set_body(std::unique_ptr body) { + body_ = std::move(body); + } /// @returns the function body - const StatementList& body() const { return body_; } + BlockStatement* body() const { return body_.get(); } /// @returns true if the name and type are both present bool IsValid() const override; @@ -144,7 +150,7 @@ class Function : public Node { std::string name_; VariableList params_; type::Type* return_type_ = nullptr; - StatementList body_; + std::unique_ptr body_; std::vector referenced_module_vars_; std::vector ancestor_entry_points_; }; diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc index a2c6471c96..7988388c05 100644 --- a/src/ast/function_test.cc +++ b/src/ast/function_test.cc @@ -188,11 +188,11 @@ TEST_F(FunctionTest, IsValid) { params.push_back( std::make_unique("var", StorageClass::kNone, &i32)); - StatementList body; - body.push_back(std::make_unique()); + auto block = std::make_unique(); + block->append(std::make_unique()); Function f("func", std::move(params), &void_type); - f.set_body(std::move(body)); + f.set_body(std::move(block)); EXPECT_TRUE(f.IsValid()); } @@ -251,12 +251,12 @@ TEST_F(FunctionTest, IsValid_NullBodyStatement) { params.push_back( std::make_unique("var", StorageClass::kNone, &i32)); - StatementList body; - body.push_back(std::make_unique()); - body.push_back(nullptr); + auto block = std::make_unique(); + block->append(std::make_unique()); + block->append(nullptr); Function f("func", std::move(params), &void_type); - f.set_body(std::move(body)); + f.set_body(std::move(block)); EXPECT_FALSE(f.IsValid()); } @@ -268,12 +268,12 @@ TEST_F(FunctionTest, IsValid_InvalidBodyStatement) { params.push_back( std::make_unique("var", StorageClass::kNone, &i32)); - StatementList body; - body.push_back(std::make_unique()); - body.push_back(nullptr); + auto block = std::make_unique(); + block->append(std::make_unique()); + block->append(nullptr); Function f("func", std::move(params), &void_type); - f.set_body(std::move(body)); + f.set_body(std::move(block)); EXPECT_FALSE(f.IsValid()); } @@ -281,11 +281,11 @@ TEST_F(FunctionTest, ToStr) { type::VoidType void_type; type::I32Type i32; - StatementList body; - body.push_back(std::make_unique()); + auto block = std::make_unique(); + block->append(std::make_unique()); Function f("func", {}, &void_type); - f.set_body(std::move(body)); + f.set_body(std::move(block)); std::ostringstream out; f.to_str(out, 2); @@ -305,11 +305,11 @@ TEST_F(FunctionTest, ToStr_WithParams) { params.push_back( std::make_unique("var", StorageClass::kNone, &i32)); - StatementList body; - body.push_back(std::make_unique()); + auto block = std::make_unique(); + block->append(std::make_unique()); Function f("func", std::move(params), &void_type); - f.set_body(std::move(body)); + f.set_body(std::move(block)); std::ostringstream out; f.to_str(out, 2); diff --git a/src/reader/wgsl/parser_impl_function_decl_test.cc b/src/reader/wgsl/parser_impl_function_decl_test.cc index 271634e8f7..fe8e3c7d7e 100644 --- a/src/reader/wgsl/parser_impl_function_decl_test.cc +++ b/src/reader/wgsl/parser_impl_function_decl_test.cc @@ -40,8 +40,9 @@ TEST_F(ParserImplTest, FunctionDecl) { ASSERT_NE(f->return_type(), nullptr); EXPECT_TRUE(f->return_type()->IsVoid()); - ASSERT_EQ(f->body().size(), 1u); - EXPECT_TRUE(f->body()[0]->IsReturn()); + auto* body = f->body(); + ASSERT_EQ(body->size(), 1u); + EXPECT_TRUE(body->get(0)->IsReturn()); } TEST_F(ParserImplTest, FunctionDecl_InvalidHeader) { diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 075eed5e5a..c7f9bcc17f 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -734,10 +734,9 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable_Const) { std::make_unique("my_var", ast::StorageClass::kNone, &f32); var->set_is_const(true); - ast::StatementList body; - body.push_back(std::make_unique(std::move(var))); - - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique( std::move(my_var), std::make_unique("my_var"))); @@ -756,12 +755,12 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) { auto my_var = std::make_unique("my_var"); auto* my_var_ptr = my_var.get(); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("my_var", ast::StorageClass::kNone, &f32))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::move(my_var), std::make_unique("my_var"))); @@ -823,17 +822,17 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables) { std::make_unique("my_func", std::move(params), &f32); auto* func_ptr = func.get(); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("out_var"), std::make_unique("in_var"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("wg_var"), std::make_unique("wg_var"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("sb_var"), std::make_unique("sb_var"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("priv_var"), std::make_unique("priv_var"))); func->set_body(std::move(body)); @@ -882,17 +881,17 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) { auto func = std::make_unique("my_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("out_var"), std::make_unique("in_var"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("wg_var"), std::make_unique("wg_var"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("sb_var"), std::make_unique("sb_var"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("priv_var"), std::make_unique("priv_var"))); func->set_body(std::move(body)); @@ -901,7 +900,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) { auto func2 = std::make_unique("func", std::move(params), &f32); auto* func2_ptr = func2.get(); - body.push_back(std::make_unique( + + body = std::make_unique(); + body->append(std::make_unique( std::make_unique("out_var"), std::make_unique( std::make_unique("my_func"), @@ -933,9 +934,9 @@ TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) { std::make_unique("my_func", std::move(params), &f32); auto* func_ptr = func.get(); - ast::StatementList body; - body.push_back(std::make_unique(std::move(var))); - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique( std::make_unique("var"), std::make_unique( std::make_unique(&f32, 1.f)))); @@ -1990,9 +1991,10 @@ TEST_F(TypeDeterminerTest, StorageClass_SetsIfMissing) { auto func = std::make_unique("func", ast::VariableList{}, &i32); - ast::StatementList stmts; - stmts.push_back(std::move(stmt)); - func->set_body(std::move(stmts)); + + auto body = std::make_unique(); + body->append(std::move(stmt)); + func->set_body(std::move(body)); mod()->AddFunction(std::move(func)); @@ -2011,9 +2013,10 @@ TEST_F(TypeDeterminerTest, StorageClass_DoesNotSetOnConst) { auto func = std::make_unique("func", ast::VariableList{}, &i32); - ast::StatementList stmts; - stmts.push_back(std::move(stmt)); - func->set_body(std::move(stmts)); + + auto body = std::make_unique(); + body->append(std::move(stmt)); + func->set_body(std::move(body)); mod()->AddFunction(std::move(func)); @@ -2030,9 +2033,10 @@ TEST_F(TypeDeterminerTest, StorageClass_NonFunctionClassError) { auto func = std::make_unique("func", ast::VariableList{}, &i32); - ast::StatementList stmts; - stmts.push_back(std::move(stmt)); - func->set_body(std::move(stmts)); + + auto body = std::make_unique(); + body->append(std::move(stmt)); + func->set_body(std::move(body)); mod()->AddFunction(std::move(func)); @@ -3963,13 +3967,14 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints) { auto func_b = std::make_unique("b", std::move(params), &f32); auto* func_b_ptr = func_b.get(); - ast::StatementList body; + auto body = std::make_unique(); func_b->set_body(std::move(body)); auto func_c = std::make_unique("c", std::move(params), &f32); auto* func_c_ptr = func_c.get(); - body.push_back(std::make_unique( + body = std::make_unique(); + body->append(std::make_unique( std::make_unique("second"), std::make_unique( std::make_unique("b"), @@ -3979,7 +3984,8 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints) { auto func_a = std::make_unique("a", std::move(params), &f32); auto* func_a_ptr = func_a.get(); - body.push_back(std::make_unique( + body = std::make_unique(); + body->append(std::make_unique( std::make_unique("first"), std::make_unique( std::make_unique("c"), @@ -3990,12 +3996,13 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints) { std::make_unique("ep_1_func", std::move(params), &f32); auto* ep_1_func_ptr = ep_1_func.get(); - body.push_back(std::make_unique( + body = std::make_unique(); + body->append(std::make_unique( std::make_unique("call_a"), std::make_unique( std::make_unique("a"), ast::ExpressionList{}))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("call_b"), std::make_unique( std::make_unique("b"), @@ -4006,7 +4013,8 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints) { std::make_unique("ep_2_func", std::move(params), &f32); auto* ep_2_func_ptr = ep_2_func.get(); - body.push_back(std::make_unique( + body = std::make_unique(); + body->append(std::make_unique( std::make_unique("call_c"), std::make_unique( std::make_unique("c"), diff --git a/src/validator_impl.cc b/src/validator_impl.cc index 9195baf8c1..3645c96095 100644 --- a/src/validator_impl.cc +++ b/src/validator_impl.cc @@ -43,11 +43,20 @@ bool ValidatorImpl::ValidateFunctions(const ast::FunctionList& funcs) { } bool ValidatorImpl::ValidateFunction(const ast::Function& func) { - if (!ValidateStatements(func.body())) + if (!ValidateStatements(*(func.body()))) return false; return true; } +bool ValidatorImpl::ValidateStatements(const ast::BlockStatement& block) { + for (const auto& stmt : block) { + if (!ValidateStatement(*(stmt.get()))) { + return false; + } + } + return true; +} + bool ValidatorImpl::ValidateStatements(const ast::StatementList& stmts) { for (const auto& stmt : stmts) { if (!ValidateStatement(*(stmt.get()))) { diff --git a/src/validator_impl.h b/src/validator_impl.h index 27ffb8dda3..750fb66b47 100644 --- a/src/validator_impl.h +++ b/src/validator_impl.h @@ -54,6 +54,10 @@ class ValidatorImpl { /// @param func the function to check /// @returns true if the validation was successful bool ValidateFunction(const ast::Function& func); + /// Validates a block of statements + /// @param block the statements to check + /// @returns true if the validation was successful + bool ValidateStatements(const ast::BlockStatement& block); /// Validates a set of statements /// @param stmts the statements to check /// @returns true if the validation was successful diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index f0588b64a3..79fd9ab4d2 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -1231,11 +1231,11 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, } } - out_ << ")"; + out_ << ") "; current_ep_name_ = ep_name; - if (!EmitStatementBlockAndNewline(func->body())) { + if (!EmitBlockAndNewline(func->body())) { return false; } @@ -1395,7 +1395,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) { } generating_entry_point_ = true; - for (const auto& s : func->body()) { + for (const auto& s : *(func->body())) { if (!EmitStatement(s.get())) { return false; } @@ -1578,8 +1578,6 @@ bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) { } bool GeneratorImpl::EmitBlock(ast::BlockStatement* stmt) { - make_indent(); - out_ << "{" << std::endl; increment_indent(); @@ -1604,6 +1602,15 @@ bool GeneratorImpl::EmitBlockAndNewline(ast::BlockStatement* stmt) { return result; } +bool GeneratorImpl::EmitIndentedBlockAndNewline(ast::BlockStatement* stmt) { + make_indent(); + const bool result = EmitBlock(stmt); + if (result) { + out_ << std::endl; + } + return result; +} + bool GeneratorImpl::EmitStatementBlock(const ast::StatementList& statements) { out_ << " {" << std::endl; @@ -1636,7 +1643,7 @@ bool GeneratorImpl::EmitStatement(ast::Statement* stmt) { return EmitAssign(stmt->AsAssign()); } if (stmt->IsBlock()) { - return EmitBlockAndNewline(stmt->AsBlock()); + return EmitIndentedBlockAndNewline(stmt->AsBlock()); } if (stmt->IsBreak()) { return EmitBreak(stmt->AsBreak()); diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index 1bc5178a2d..4178a30e7a 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h @@ -80,6 +80,10 @@ class GeneratorImpl : public TextGenerator { /// Handles a block statement with a newline at the end /// @param stmt the statement to emit /// @returns true if the statement was emitted successfully + bool EmitIndentedBlockAndNewline(ast::BlockStatement* stmt); + /// Handles a block statement with a newline at the end + /// @param stmt the statement to emit + /// @returns true if the statement was emitted successfully bool EmitBlockAndNewline(ast::BlockStatement* stmt); /// Handles a break statement /// @param stmt the statement to emit diff --git a/src/writer/msl/generator_impl_block_test.cc b/src/writer/msl/generator_impl_block_test.cc index 9be5b60255..f901c4b038 100644 --- a/src/writer/msl/generator_impl_block_test.cc +++ b/src/writer/msl/generator_impl_block_test.cc @@ -50,7 +50,7 @@ TEST_F(MslGeneratorImplTest, Emit_Block_WithoutNewline) { g.increment_indent(); ASSERT_TRUE(g.EmitBlock(&b)) << g.error(); - EXPECT_EQ(g.result(), R"( { + EXPECT_EQ(g.result(), R"({ discard_fragment(); })"); } diff --git a/src/writer/msl/generator_impl_entry_point_test.cc b/src/writer/msl/generator_impl_entry_point_test.cc index ec237dad11..8c85901f43 100644 --- a/src/writer/msl/generator_impl_entry_point_test.cc +++ b/src/writer/msl/generator_impl_entry_point_test.cc @@ -73,11 +73,11 @@ TEST_F(MslGeneratorImplTest, EmitEntryPointData_Vertex_Input) { auto func = std::make_unique("vtx_main", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("foo"), std::make_unique("foo"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("bar"), std::make_unique("bar"))); func->set_body(std::move(body)); @@ -139,11 +139,11 @@ TEST_F(MslGeneratorImplTest, EmitEntryPointData_Vertex_Output) { auto func = std::make_unique("vtx_main", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("foo"), std::make_unique("foo"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("bar"), std::make_unique("bar"))); func->set_body(std::move(body)); @@ -205,11 +205,11 @@ TEST_F(MslGeneratorImplTest, EmitEntryPointData_Fragment_Input) { auto func = std::make_unique("frag_main", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("foo"), std::make_unique("foo"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("bar"), std::make_unique("bar"))); func->set_body(std::move(body)); @@ -271,11 +271,11 @@ TEST_F(MslGeneratorImplTest, EmitEntryPointData_Fragment_Output) { auto func = std::make_unique("frag_main", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("foo"), std::make_unique("foo"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("bar"), std::make_unique("bar"))); func->set_body(std::move(body)); @@ -334,11 +334,11 @@ TEST_F(MslGeneratorImplTest, EmitEntryPointData_Compute_Input) { auto func = std::make_unique("comp_main", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("foo"), std::make_unique("foo"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("bar"), std::make_unique("bar"))); func->set_body(std::move(body)); @@ -392,11 +392,11 @@ TEST_F(MslGeneratorImplTest, EmitEntryPointData_Compute_Output) { auto func = std::make_unique("comp_main", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("foo"), std::make_unique("foo"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("bar"), std::make_unique("bar"))); func->set_body(std::move(body)); @@ -460,8 +460,8 @@ TEST_F(MslGeneratorImplTest, EmitEntryPointData_Builtins) { auto func = std::make_unique("frag_main", std::move(params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("depth"), std::make_unique( std::make_unique("coord"), diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index 0fb8de8599..4a810e65e7 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -53,8 +53,8 @@ TEST_F(MslGeneratorImplTest, Emit_Function) { auto func = std::make_unique("my_func", ast::VariableList{}, &void_type); - ast::StatementList body; - body.push_back(std::make_unique()); + auto body = std::make_unique(); + body->append(std::make_unique()); func->set_body(std::move(body)); ast::Module m; @@ -79,8 +79,8 @@ TEST_F(MslGeneratorImplTest, Emit_Function_Name_Collision) { auto func = std::make_unique("main", ast::VariableList{}, &void_type); - ast::StatementList body; - body.push_back(std::make_unique()); + auto body = std::make_unique(); + body->append(std::make_unique()); func->set_body(std::move(body)); ast::Module m; @@ -113,8 +113,8 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithParams) { auto func = std::make_unique("my_func", std::move(params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique()); + auto body = std::make_unique(); + body->append(std::make_unique()); func->set_body(std::move(body)); ast::Module m; @@ -184,11 +184,11 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithInOutVars) { auto func = std::make_unique("frag_main", std::move(params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("bar"), std::make_unique("foo"))); - body.push_back(std::make_unique()); + body->append(std::make_unique()); func->set_body(std::move(body)); mod.AddFunction(std::move(func)); @@ -254,13 +254,13 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithInOut_Builtins) { auto func = std::make_unique("frag_main", std::move(params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("depth"), std::make_unique( std::make_unique("coord"), std::make_unique("x")))); - body.push_back(std::make_unique()); + body->append(std::make_unique()); func->set_body(std::move(body)); mod.AddFunction(std::move(func)); @@ -319,9 +319,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_With_Uniform) { std::make_unique("coord"), std::make_unique("x"))); - ast::StatementList body; - body.push_back(std::make_unique(std::move(var))); - body.push_back(std::make_unique()); + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); func->set_body(std::move(body)); mod.AddFunction(std::move(func)); @@ -375,9 +375,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_With_StorageBuffer) { std::make_unique("coord"), std::make_unique("x"))); - ast::StatementList body; - body.push_back(std::make_unique(std::move(var))); - body.push_back(std::make_unique()); + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); func->set_body(std::move(body)); mod.AddFunction(std::move(func)); @@ -439,14 +439,14 @@ TEST_F(MslGeneratorImplTest, auto sub_func = std::make_unique("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("bar"), std::make_unique("foo"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("val"), std::make_unique("param"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("foo"))); sub_func->set_body(std::move(body)); @@ -458,12 +458,14 @@ TEST_F(MslGeneratorImplTest, ast::ExpressionList expr; expr.push_back(std::make_unique( std::make_unique(&f32, 1.0f))); - body.push_back(std::make_unique( + + body = std::make_unique(); + body->append(std::make_unique( std::make_unique("bar"), std::make_unique( std::make_unique("sub_func"), std::move(expr)))); - body.push_back(std::make_unique()); + body->append(std::make_unique()); func_1->set_body(std::move(body)); mod.AddFunction(std::move(func_1)); @@ -530,8 +532,8 @@ TEST_F(MslGeneratorImplTest, auto sub_func = std::make_unique("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("param"))); sub_func->set_body(std::move(body)); @@ -543,12 +545,14 @@ TEST_F(MslGeneratorImplTest, ast::ExpressionList expr; expr.push_back(std::make_unique( std::make_unique(&f32, 1.0f))); - body.push_back(std::make_unique( + + body = std::make_unique(); + body->append(std::make_unique( std::make_unique("depth"), std::make_unique( std::make_unique("sub_func"), std::move(expr)))); - body.push_back(std::make_unique()); + body->append(std::make_unique()); func_1->set_body(std::move(body)); mod.AddFunction(std::move(func_1)); @@ -617,13 +621,13 @@ TEST_F(MslGeneratorImplTest, auto sub_func = std::make_unique("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("depth"), std::make_unique( std::make_unique("coord"), std::make_unique("x")))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("param"))); sub_func->set_body(std::move(body)); @@ -635,12 +639,14 @@ TEST_F(MslGeneratorImplTest, ast::ExpressionList expr; expr.push_back(std::make_unique( std::make_unique(&f32, 1.0f))); - body.push_back(std::make_unique( + + body = std::make_unique(); + body->append(std::make_unique( std::make_unique("depth"), std::make_unique( std::make_unique("sub_func"), std::move(expr)))); - body.push_back(std::make_unique()); + body->append(std::make_unique()); func_1->set_body(std::move(body)); mod.AddFunction(std::move(func_1)); @@ -700,8 +706,8 @@ TEST_F(MslGeneratorImplTest, Emit_Function_Called_By_EntryPoint_With_Uniform) { auto sub_func = std::make_unique("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique( std::make_unique("coord"), std::make_unique("x")))); @@ -722,8 +728,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_Called_By_EntryPoint_With_Uniform) { std::make_unique("sub_func"), std::move(expr))); - body.push_back(std::make_unique(std::move(var))); - body.push_back(std::make_unique()); + body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); func->set_body(std::move(body)); mod.AddFunction(std::move(func)); @@ -778,8 +785,8 @@ TEST_F(MslGeneratorImplTest, auto sub_func = std::make_unique("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique( std::make_unique("coord"), std::make_unique("x")))); @@ -800,8 +807,9 @@ TEST_F(MslGeneratorImplTest, std::make_unique("sub_func"), std::move(expr))); - body.push_back(std::make_unique(std::move(var))); - body.push_back(std::make_unique()); + body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); func->set_body(std::move(body)); mod.AddFunction(std::move(func)); @@ -857,11 +865,11 @@ TEST_F(MslGeneratorImplTest, Emit_Function_Called_Two_EntryPoints_WithGlobals) { auto sub_func = std::make_unique("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("bar"), std::make_unique("foo"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("foo"))); sub_func->set_body(std::move(body)); @@ -870,12 +878,13 @@ TEST_F(MslGeneratorImplTest, Emit_Function_Called_Two_EntryPoints_WithGlobals) { auto func_1 = std::make_unique("frag_1_main", std::move(params), &void_type); - body.push_back(std::make_unique( + body = std::make_unique(); + body->append(std::make_unique( std::make_unique("bar"), std::make_unique( std::make_unique("sub_func"), ast::ExpressionList{}))); - body.push_back(std::make_unique()); + body->append(std::make_unique()); func_1->set_body(std::move(body)); mod.AddFunction(std::move(func_1)); @@ -956,15 +965,15 @@ TEST_F(MslGeneratorImplTest, auto func_1 = std::make_unique("frag_1_main", std::move(params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("bar"), std::make_unique( std::make_unique(&f32, 1.0f)))); ast::StatementList list; list.push_back(std::make_unique()); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique( ast::BinaryOp::kEqual, std::make_unique( @@ -973,7 +982,7 @@ TEST_F(MslGeneratorImplTest, std::make_unique(&i32, 1))), std::move(list))); - body.push_back(std::make_unique()); + body->append(std::make_unique()); func_1->set_body(std::move(body)); mod.AddFunction(std::move(func_1)); @@ -1017,8 +1026,8 @@ TEST_F(MslGeneratorImplTest, auto sub_func = std::make_unique("sub_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique( std::make_unique(&f32, 1.0)))); sub_func->set_body(std::move(body)); @@ -1028,15 +1037,15 @@ TEST_F(MslGeneratorImplTest, auto func_1 = std::make_unique("frag_1_main", std::move(params), &void_type); - body.push_back(std::make_unique( - std::make_unique("foo", ast::StorageClass::kFunction, - &f32))); - body.back()->AsVariableDecl()->variable()->set_constructor( - std::make_unique( - std::make_unique("sub_func"), - ast::ExpressionList{})); + auto var = std::make_unique( + "foo", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("sub_func"), + ast::ExpressionList{})); - body.push_back(std::make_unique()); + body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); func_1->set_body(std::move(body)); mod.AddFunction(std::move(func_1)); @@ -1126,8 +1135,8 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithArrayParams) { auto func = std::make_unique("my_func", std::move(params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique()); + auto body = std::make_unique(); + body->append(std::make_unique()); func->set_body(std::move(body)); ast::Module m; diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 94914ce357..00404814cf 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -426,7 +426,7 @@ bool Builder::GenerateFunction(ast::Function* func) { push_function(Function{definition_inst, result_op(), std::move(params)}); - for (const auto& stmt : func->body()) { + for (const auto& stmt : *(func->body())) { if (!GenerateStatement(stmt.get())) { return false; } diff --git a/src/writer/spirv/builder_call_test.cc b/src/writer/spirv/builder_call_test.cc index 5499f596e9..5b78cbff30 100644 --- a/src/writer/spirv/builder_call_test.cc +++ b/src/writer/spirv/builder_call_test.cc @@ -142,8 +142,8 @@ TEST_F(BuilderTest, Expression_Call) { ast::Function a_func("a_func", std::move(func_params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique( ast::BinaryOp::kAdd, std::make_unique("a"), std::make_unique("b")))); @@ -210,8 +210,8 @@ TEST_F(BuilderTest, Statement_Call) { ast::Function a_func("a_func", std::move(func_params), &void_type); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique( ast::BinaryOp::kAdd, std::make_unique("a"), std::make_unique("b")))); diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc index 34dc0f81c9..04176c1654 100644 --- a/src/writer/spirv/builder_entry_point_test.cc +++ b/src/writer/spirv/builder_entry_point_test.cc @@ -168,15 +168,16 @@ TEST_F(BuilderTest, EntryPoint_WithUsedInterfaceIds) { ast::type::VoidType void_type; ast::Function func("main", {}, &void_type); - ast::StatementList body; - body.push_back(std::make_unique( + + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("my_out"), std::make_unique("my_in"))); - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("my_wg"), std::make_unique("my_wg"))); // Add duplicate usages so we show they don't get output multiple times. - body.push_back(std::make_unique( + body->append(std::make_unique( std::make_unique("my_out"), std::make_unique("my_in"))); func.set_body(std::move(body)); diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc index a0fb76f3a7..700d8fe944 100644 --- a/src/writer/spirv/builder_function_test.cc +++ b/src/writer/spirv/builder_function_test.cc @@ -67,8 +67,8 @@ TEST_F(BuilderTest, Function_WithParams) { ast::Function func("a_func", std::move(params), &f32); - ast::StatementList body; - body.push_back(std::make_unique( + auto body = std::make_unique(); + body->append(std::make_unique( std::make_unique("a"))); func.set_body(std::move(body)); @@ -93,8 +93,8 @@ OpFunctionEnd TEST_F(BuilderTest, Function_WithBody) { ast::type::VoidType void_type; - ast::StatementList body; - body.push_back(std::make_unique()); + auto body = std::make_unique(); + body->append(std::make_unique()); ast::Function func("a_func", {}, &void_type); func.set_body(std::move(body)); diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index 373ef4081c..3fdabeeb40 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -346,7 +346,8 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { return false; } - return EmitStatementBlockAndNewline(func->body()); + out_ << " "; + return EmitBlockAndNewline(func->body()); } bool GeneratorImpl::EmitType(ast::type::Type* type) { @@ -600,8 +601,6 @@ bool GeneratorImpl::EmitUnaryOp(ast::UnaryOpExpression* expr) { } bool GeneratorImpl::EmitBlock(ast::BlockStatement* stmt) { - make_indent(); - out_ << "{" << std::endl; increment_indent(); @@ -618,6 +617,15 @@ bool GeneratorImpl::EmitBlock(ast::BlockStatement* stmt) { return true; } +bool GeneratorImpl::EmitIndentedBlockAndNewline(ast::BlockStatement* stmt) { + make_indent(); + const bool result = EmitBlock(stmt); + if (result) { + out_ << std::endl; + } + return result; +} + bool GeneratorImpl::EmitBlockAndNewline(ast::BlockStatement* stmt) { const bool result = EmitBlock(stmt); if (result) { @@ -658,7 +666,7 @@ bool GeneratorImpl::EmitStatement(ast::Statement* stmt) { return EmitAssign(stmt->AsAssign()); } if (stmt->IsBlock()) { - return EmitBlockAndNewline(stmt->AsBlock()); + return EmitIndentedBlockAndNewline(stmt->AsBlock()); } if (stmt->IsBreak()) { return EmitBreak(stmt->AsBreak()); diff --git a/src/writer/wgsl/generator_impl.h b/src/writer/wgsl/generator_impl.h index 6c211cf47e..50e3141a3a 100644 --- a/src/writer/wgsl/generator_impl.h +++ b/src/writer/wgsl/generator_impl.h @@ -74,6 +74,10 @@ class GeneratorImpl : public TextGenerator { /// Handles a block statement with a newline at the end /// @param stmt the statement to emit /// @returns true if the statement was emitted successfully + bool EmitIndentedBlockAndNewline(ast::BlockStatement* stmt); + /// Handles a block statement with a newline at the end + /// @param stmt the statement to emit + /// @returns true if the statement was emitted successfully bool EmitBlockAndNewline(ast::BlockStatement* stmt); /// Handles a break statement /// @param stmt the statement to emit diff --git a/src/writer/wgsl/generator_impl_block_test.cc b/src/writer/wgsl/generator_impl_block_test.cc index bc822738b2..43c0286bd1 100644 --- a/src/writer/wgsl/generator_impl_block_test.cc +++ b/src/writer/wgsl/generator_impl_block_test.cc @@ -48,7 +48,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Block_WithoutNewline) { g.increment_indent(); ASSERT_TRUE(g.EmitBlock(&b)) << g.error(); - EXPECT_EQ(g.result(), R"( { + EXPECT_EQ(g.result(), R"({ discard; })"); } diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc index c650e93453..0240a6277a 100644 --- a/src/writer/wgsl/generator_impl_function_test.cc +++ b/src/writer/wgsl/generator_impl_function_test.cc @@ -30,9 +30,9 @@ namespace { using WgslGeneratorImplTest = testing::Test; TEST_F(WgslGeneratorImplTest, Emit_Function) { - ast::StatementList body; - body.push_back(std::make_unique()); - body.push_back(std::make_unique()); + auto body = std::make_unique(); + body->append(std::make_unique()); + body->append(std::make_unique()); ast::type::VoidType void_type; ast::Function func("my_func", {}, &void_type); @@ -50,9 +50,9 @@ TEST_F(WgslGeneratorImplTest, Emit_Function) { } TEST_F(WgslGeneratorImplTest, Emit_Function_WithParams) { - ast::StatementList body; - body.push_back(std::make_unique()); - body.push_back(std::make_unique()); + auto body = std::make_unique(); + body->append(std::make_unique()); + body->append(std::make_unique()); ast::type::F32Type f32; ast::type::I32Type i32;