From cd7eb4f968510147e5d9eec3076fbeab3c16c352 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Sat, 17 Jul 2021 18:09:26 +0000 Subject: [PATCH] Resolver: Validation for continuing blocks Check they do not contain returns, discards Check they do not directly contain continues, however a nested loop can have its own continue. Bug: chromium:1229976 Change-Id: Ia3c4ac118ffdaa6cca6025366c19f9897718c930 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58384 Kokoro: Kokoro Reviewed-by: David Neto Commit-Queue: David Neto Auto-Submit: Ben Clayton --- src/program_builder.h | 30 +++-- src/resolver/resolver.cc | 39 ++++++- src/resolver/validation_test.cc | 118 ++++++++++++++++++++ src/sem/statement.h | 71 +++++++++--- src/writer/hlsl/generator_impl_loop_test.cc | 73 ++++++++---- src/writer/msl/generator_impl_loop_test.cc | 69 ++++++++---- src/writer/wgsl/generator_impl_loop_test.cc | 6 +- 7 files changed, 335 insertions(+), 71 deletions(-) diff --git a/src/program_builder.h b/src/program_builder.h index 70b356d948..d4a25cfe52 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -105,6 +105,14 @@ class CloneContext; /// To construct a Program, populate the builder and then `std::move` it to a /// Program. class ProgramBuilder { + /// A helper used to disable overloads if the first type in `TYPES` is a + /// Source. Used to avoid ambiguities in overloads that take a Source as the + /// first parameter and those that perfectly-forward the first argument. + template + using DisableIfSource = traits::EnableIfIsNotType< + traits::Decay>, + Source>; + /// VarOptionals is a helper for accepting a number of optional, extra /// arguments for Var() and Global(). struct VarOptionals { @@ -1383,7 +1391,7 @@ class ProgramBuilder { /// global variable with the ast::Module. template , Source>* = nullptr> + typename = DisableIfSource> ast::Variable* Global(NAME&& name, const ast::Type* type, OPTIONAL&&... optional) { @@ -1504,9 +1512,7 @@ class ProgramBuilder { /// @param args the function call arguments /// @returns a `ast::CallExpression` to the function `func`, with the /// arguments of `args` converted to `ast::Expression`s using `Expr()`. - template , Source>* = nullptr> + template > ast::CallExpression* Call(NAME&& func, ARGS&&... args) { return create(Expr(func), ExprList(std::forward(args)...)); @@ -1781,7 +1787,7 @@ class ProgramBuilder { /// Creates an ast::ReturnStatement with the given return value /// @param val the return value /// @returns the return statement pointer - template + template > ast::ReturnStatement* Return(EXPR&& val) { return create(Expr(std::forward(val))); } @@ -1886,12 +1892,22 @@ class ProgramBuilder { } /// Creates a ast::BlockStatement with input statements + /// @param source the source information for the block /// @param statements statements of block /// @returns the block statement pointer template - ast::BlockStatement* Block(Statements&&... statements) { + ast::BlockStatement* Block(const Source& source, Statements&&... statements) { return create( - ast::StatementList{std::forward(statements)...}); + source, ast::StatementList{std::forward(statements)...}); + } + + /// Creates a ast::BlockStatement with input statements + /// @param statements statements of block + /// @returns the block statement pointer + template > + ast::BlockStatement* Block(STATEMENTS&&... statements) { + return create( + ast::StatementList{std::forward(statements)...}); } /// Creates a ast::ElseStatement with input condition and body diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 8288a3f6b4..ca5ea27ed3 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -2062,11 +2062,18 @@ bool Resolver::Statement(ast::Statement* stmt) { } if (stmt->Is()) { // Set if we've hit the first continue statement in our parent loop - if (auto* loop_block = - current_block_->FindFirstParent()) { - if (loop_block->FirstContinue() == size_t(~0)) { - const_cast(loop_block) - ->SetFirstContinue(loop_block->Decls().size()); + if (auto* block = + current_block_->FindFirstParent< + sem::LoopBlockStatement, sem::LoopContinuingBlockStatement>()) { + if (auto* loop_block = block->As()) { + if (loop_block->FirstContinue() == size_t(~0)) { + const_cast(loop_block) + ->SetFirstContinue(loop_block->Decls().size()); + } + } else { + AddError("continuing blocks must not contain a continue statement", + stmt->source()); + return false; } } else { AddError("continue statement must be in a loop", stmt->source()); @@ -2076,6 +2083,17 @@ bool Resolver::Statement(ast::Statement* stmt) { return true; } if (stmt->Is()) { + if (auto* continuing = + sem_statement + ->FindFirstParent()) { + AddError("continuing blocks must not contain a discard statement", + stmt->source()); + if (continuing != sem_statement->Parent()) { + AddNote("see continuing block here", + continuing->Declaration()->source()); + } + return false; + } return true; } if (stmt->Is()) { @@ -4110,6 +4128,17 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) { return false; } + auto* sem = builder_->Sem().Get(ret); + if (auto* continuing = + sem->FindFirstParent()) { + AddError("continuing blocks must not contain a return statement", + ret->source()); + if (continuing != sem->Parent()) { + AddNote("see continuing block here", continuing->Declaration()->source()); + } + return false; + } + return true; } diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc index 573629abd2..6bb426fd74 100644 --- a/src/resolver/validation_test.cc +++ b/src/resolver/validation_test.cc @@ -20,6 +20,7 @@ #include "src/ast/break_statement.h" #include "src/ast/call_statement.h" #include "src/ast/continue_statement.h" +#include "src/ast/discard_statement.h" #include "src/ast/if_statement.h" #include "src/ast/intrinsic_texture_helper_test.h" #include "src/ast/loop_statement.h" @@ -650,6 +651,123 @@ TEST_F(ResolverTest, Stmt_Loop_ContinueInLoopBodyAfterDecl_UsageInContinuing) { EXPECT_TRUE(r()->Resolve()); } +TEST_F(ResolverTest, Stmt_Loop_ReturnInContinuing_Direct) { + // loop { + // continuing { + // return; + // } + // } + + WrapInFunction(Loop( // loop + Block(), // loop block + Block( // loop continuing block + Return(Source{{12, 34}})))); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + R"(12:34 error: continuing blocks must not contain a return statement)"); +} + +TEST_F(ResolverTest, Stmt_Loop_ReturnInContinuing_Indirect) { + // loop { + // continuing { + // loop { + // return; + // } + // } + // } + + WrapInFunction(Loop( // outer loop + Block(), // outer loop block + Block(Source{{56, 78}}, // outer loop continuing block + Loop( // inner loop + Block( // inner loop block + Return(Source{{12, 34}})))))); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + R"(12:34 error: continuing blocks must not contain a return statement +56:78 note: see continuing block here)"); +} + +TEST_F(ResolverTest, Stmt_Loop_DiscardInContinuing_Direct) { + // loop { + // continuing { + // discard; + // } + // } + + WrapInFunction(Loop( // loop + Block(), // loop block + Block( // loop continuing block + create(Source{{12, 34}})))); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + R"(12:34 error: continuing blocks must not contain a discard statement)"); +} + +TEST_F(ResolverTest, Stmt_Loop_DiscardInContinuing_Indirect) { + // loop { + // continuing { + // loop { discard; } + // } + // } + + WrapInFunction(Loop( // outer loop + Block(), // outer loop block + Block(Source{{56, 78}}, // outer loop continuing block + Loop( // inner loop + Block( // inner loop block + create(Source{{12, 34}})))))); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + R"(12:34 error: continuing blocks must not contain a discard statement +56:78 note: see continuing block here)"); +} + +TEST_F(ResolverTest, Stmt_Loop_ContinueInContinuing_Direct) { + // loop { + // continuing { + // continue; + // } + // } + + WrapInFunction(Loop( // loop + Block(), // loop block + Block(Source{{56, 78}}, // loop continuing block + create(Source{{12, 34}})))); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + "12:34 error: continuing blocks must not contain a continue statement"); +} + +TEST_F(ResolverTest, Stmt_Loop_ContinueInContinuing_Indirect) { + // loop { + // continuing { + // loop { + // continue; + // } + // } + // } + + WrapInFunction(Loop( // outer loop + Block(), // outer loop block + Block( // outer loop continuing block + Loop( // inner loop + Block( // inner loop block + create(Source{{12, 34}})))))); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + TEST_F(ResolverTest, Stmt_ForLoop_CondIsNotBool) { // for (; 1.0f; ) { // } diff --git a/src/sem/statement.h b/src/sem/statement.h index 821a8a5f9a..0449c3aeb3 100644 --- a/src/sem/statement.h +++ b/src/sem/statement.h @@ -34,6 +34,30 @@ namespace sem { /// Forward declaration class CompoundStatement; +namespace detail { +/// FindFirstParentReturn is a traits helper for determining the return type for +/// the template member function Statement::FindFirstParent(). +/// For zero or multiple template arguments, FindFirstParentReturn::type +/// resolves to CompoundStatement. +template +struct FindFirstParentReturn { + /// The pointer type returned by Statement::FindFirstParent() + using type = CompoundStatement; +}; + +/// A specialization of FindFirstParentReturn for a single template argument. +/// FindFirstParentReturn::type resolves to the single template argument. +template +struct FindFirstParentReturn { + /// The pointer type returned by Statement::FindFirstParent() + using type = T; +}; + +template +using FindFirstParentReturnType = + typename FindFirstParentReturn::type; +} // namespace detail + /// Statement holds the semantic information for a statement. class Statement : public Castable { public: @@ -49,16 +73,18 @@ class Statement : public Castable { const CompoundStatement* Parent() const { return parent_; } /// @returns the closest enclosing parent that satisfies the given predicate, - /// which may be the statement itself, or nullptr if no match is found + /// which may be the statement itself, or nullptr if no match is found. /// @param pred a predicate that the resulting block must satisfy template const CompoundStatement* FindFirstParent(Pred&& pred) const; - /// @returns the statement itself if it matches the template type `T`, - /// otherwise the nearest enclosing statement that matches `T`, or nullptr if - /// there is none. - template - const T* FindFirstParent() const; + /// @returns the closest enclosing parent that is of one of the types in + /// `TYPES`, which may be the statement itself, or nullptr if no match is + /// found. If `TYPES` is a single template argument, the return type is a + /// pointer to that template argument type, otherwise a CompoundStatement + /// pointer is returned. + template + const detail::FindFirstParentReturnType* FindFirstParent() const; /// @return the closest enclosing block for this statement const BlockStatement* Block() const; @@ -99,17 +125,32 @@ const CompoundStatement* Statement::FindFirstParent(Pred&& pred) const { return curr; } -template -const T* Statement::FindFirstParent() const { - if (auto* p = As()) { - return p; - } - const auto* curr = parent_; - while (curr) { - if (auto* p = curr->As()) { +template +const detail::FindFirstParentReturnType* Statement::FindFirstParent() + const { + using ReturnType = detail::FindFirstParentReturnType; + if (sizeof...(TYPES) == 1) { + if (auto* p = As()) { return p; } - curr = curr->Parent(); + const auto* curr = parent_; + while (curr) { + if (auto* p = curr->As()) { + return p; + } + curr = curr->Parent(); + } + } else { + if (IsAnyOf()) { + return As(); + } + const auto* curr = parent_; + while (curr) { + if (curr->IsAnyOf()) { + return curr->As(); + } + curr = curr->Parent(); + } } return nullptr; } diff --git a/src/writer/hlsl/generator_impl_loop_test.cc b/src/writer/hlsl/generator_impl_loop_test.cc index b9d006f064..87b4a3b2d7 100644 --- a/src/writer/hlsl/generator_impl_loop_test.cc +++ b/src/writer/hlsl/generator_impl_loop_test.cc @@ -41,8 +41,10 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_Loop) { } TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopWithContinuing) { + Func("a_statement", {}, ty.void_(), {}); + auto* body = Block(create()); - auto* continuing = Block(Return()); + auto* continuing = Block(create(Call("a_statement"))); auto* l = Loop(body, continuing); WrapInFunction(l); @@ -55,18 +57,20 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopWithContinuing) { EXPECT_EQ(gen.result(), R"( while (true) { discard; { - return; + a_statement(); } } )"); } TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) { + Func("a_statement", {}, ty.void_(), {}); + Global("lhs", ty.f32(), ast::StorageClass::kPrivate); Global("rhs", ty.f32(), ast::StorageClass::kPrivate); auto* body = Block(create()); - auto* continuing = Block(Return()); + auto* continuing = Block(create(Call("a_statement"))); auto* inner = Loop(body, continuing); body = Block(inner); @@ -88,7 +92,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) { while (true) { discard; { - return; + a_statement(); } } { @@ -153,7 +157,10 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoop) { // return; // } - auto* f = For(nullptr, nullptr, nullptr, Block(Return())); + Func("a_statement", {}, ty.void_(), {}); + + auto* f = For(nullptr, nullptr, nullptr, + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -163,7 +170,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoop) { ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); EXPECT_EQ(gen.result(), R"( { for(; ; ) { - return; + a_statement(); } } )"); @@ -174,7 +181,10 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleInit) { // return; // } - auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr, Block(Return())); + Func("a_statement", {}, ty.void_(), {}); + + auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr, + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -184,7 +194,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleInit) { ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); EXPECT_EQ(gen.result(), R"( { for(int i = 0; ; ) { - return; + a_statement(); } } )"); @@ -194,10 +204,12 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtInit) { // for(var b = true && false; ; ) { // return; // } + Func("a_statement", {}, ty.void_(), {}); + auto* multi_stmt = create(ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false)); auto* f = For(Decl(Var("b", nullptr, multi_stmt)), nullptr, nullptr, - Block(Return())); + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -212,7 +224,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtInit) { } bool b = (tint_tmp); for(; ; ) { - return; + a_statement(); } } )"); @@ -223,7 +235,10 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleCond) { // return; // } - auto* f = For(nullptr, true, nullptr, Block(Return())); + Func("a_statement", {}, ty.void_(), {}); + + auto* f = For(nullptr, true, nullptr, + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -233,7 +248,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleCond) { ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); EXPECT_EQ(gen.result(), R"( { for(; true; ) { - return; + a_statement(); } } )"); @@ -244,9 +259,12 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtCond) { // return; // } + Func("a_statement", {}, ty.void_(), {}); + auto* multi_stmt = create(ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false)); - auto* f = For(nullptr, multi_stmt, nullptr, Block(Return())); + auto* f = For(nullptr, multi_stmt, nullptr, + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -261,7 +279,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtCond) { tint_tmp = false; } if (!((tint_tmp))) { break; } - return; + a_statement(); } } )"); @@ -272,8 +290,11 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleCont) { // return; // } + Func("a_statement", {}, ty.void_(), {}); + auto* v = Decl(Var("i", ty.i32())); - auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)), Block(Return())); + auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)), + Block(create(Call("a_statement")))); WrapInFunction(v, f); GeneratorImpl& gen = Build(); @@ -283,7 +304,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleCont) { ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); EXPECT_EQ(gen.result(), R"( { for(; ; i = (i + 1)) { - return; + a_statement(); } } )"); @@ -294,10 +315,13 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtCont) { // return; // } + Func("a_statement", {}, ty.void_(), {}); + auto* multi_stmt = create(ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false)); auto* v = Decl(Var("i", ty.bool_())); - auto* f = For(nullptr, nullptr, Assign("i", multi_stmt), Block(Return())); + auto* f = For(nullptr, nullptr, Assign("i", multi_stmt), + Block(create(Call("a_statement")))); WrapInFunction(v, f); GeneratorImpl& gen = Build(); @@ -307,7 +331,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtCont) { ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); EXPECT_EQ(gen.result(), R"( { while (true) { - return; + a_statement(); bool tint_tmp = true; if (tint_tmp) { tint_tmp = false; @@ -323,8 +347,10 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleInitCondCont) { // return; // } + Func("a_statement", {}, ty.void_(), {}); + auto* f = For(Decl(Var("i", ty.i32())), true, Assign("i", Add("i", 1)), - Block(Return())); + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -334,7 +360,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleInitCondCont) { ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); EXPECT_EQ(gen.result(), R"( { for(int i = 0; true; i = (i + 1)) { - return; + a_statement(); } } )"); @@ -344,6 +370,8 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtInitCondCont) { // for(var i = true && false; true && false; i = true && false) { // return; // } + Func("a_statement", {}, ty.void_(), {}); + auto* multi_stmt_a = create(ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false)); auto* multi_stmt_b = create(ast::BinaryOp::kLogicalAnd, @@ -352,7 +380,8 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtInitCondCont) { Expr(true), Expr(false)); auto* f = For(Decl(Var("i", nullptr, multi_stmt_a)), multi_stmt_b, - Assign("i", multi_stmt_c), Block(Return())); + Assign("i", multi_stmt_c), + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -372,7 +401,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtInitCondCont) { tint_tmp_1 = false; } if (!((tint_tmp_1))) { break; } - return; + a_statement(); bool tint_tmp_2 = true; if (tint_tmp_2) { tint_tmp_2 = false; diff --git a/src/writer/msl/generator_impl_loop_test.cc b/src/writer/msl/generator_impl_loop_test.cc index 5c32822644..178df076ee 100644 --- a/src/writer/msl/generator_impl_loop_test.cc +++ b/src/writer/msl/generator_impl_loop_test.cc @@ -40,8 +40,10 @@ TEST_F(MslGeneratorImplTest, Emit_Loop) { } TEST_F(MslGeneratorImplTest, Emit_LoopWithContinuing) { + Func("a_statement", {}, ty.void_(), {}); + auto* body = Block(create()); - auto* continuing = Block(Return()); + auto* continuing = Block(create(Call("a_statement"))); auto* l = Loop(body, continuing); WrapInFunction(l); @@ -53,18 +55,20 @@ TEST_F(MslGeneratorImplTest, Emit_LoopWithContinuing) { EXPECT_EQ(gen.result(), R"( while (true) { discard_fragment(); { - return; + a_statement(); } } )"); } TEST_F(MslGeneratorImplTest, Emit_LoopNestedWithContinuing) { + Func("a_statement", {}, ty.void_(), {}); + Global("lhs", ty.f32(), ast::StorageClass::kPrivate); Global("rhs", ty.f32(), ast::StorageClass::kPrivate); auto* body = Block(create()); - auto* continuing = Block(Return()); + auto* continuing = Block(create(Call("a_statement"))); auto* inner = Loop(body, continuing); body = Block(inner); @@ -83,7 +87,7 @@ TEST_F(MslGeneratorImplTest, Emit_LoopNestedWithContinuing) { while (true) { discard_fragment(); { - return; + a_statement(); } } { @@ -146,7 +150,10 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoop) { // return; // } - auto* f = For(nullptr, nullptr, nullptr, Block(Return())); + Func("a_statement", {}, ty.void_(), {}); + + auto* f = For(nullptr, nullptr, nullptr, + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -155,7 +162,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoop) { ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); EXPECT_EQ(gen.result(), R"( for(; ; ) { - return; + a_statement(); } )"); } @@ -165,7 +172,10 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleInit) { // return; // } - auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr, Block(Return())); + Func("a_statement", {}, ty.void_(), {}); + + auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr, + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -174,7 +184,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleInit) { ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); EXPECT_EQ(gen.result(), R"( for(int i = 0; ; ) { - return; + a_statement(); } )"); } @@ -184,9 +194,13 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInit) { // for({ignore(1); ignore(2);}; ; ) { // return; // } + + Func("a_statement", {}, ty.void_(), {}); + Global("a", ty.atomic(), ast::StorageClass::kWorkgroup); auto* multi_stmt = Block(Ignore(1), Ignore(2)); - auto* f = For(multi_stmt, nullptr, nullptr, Block(Return())); + auto* f = For(multi_stmt, nullptr, nullptr, + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -200,7 +214,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInit) { (void) 2; } for(; ; ) { - return; + a_statement(); } } )"); @@ -211,7 +225,10 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCond) { // return; // } - auto* f = For(nullptr, true, nullptr, Block(Return())); + Func("a_statement", {}, ty.void_(), {}); + + auto* f = For(nullptr, true, nullptr, + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -220,7 +237,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCond) { ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); EXPECT_EQ(gen.result(), R"( for(; true; ) { - return; + a_statement(); } )"); } @@ -230,8 +247,11 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCont) { // return; // } + Func("a_statement", {}, ty.void_(), {}); + auto* v = Decl(Var("i", ty.i32())); - auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)), Block(Return())); + auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)), + Block(create(Call("a_statement")))); WrapInFunction(v, f); GeneratorImpl& gen = Build(); @@ -240,7 +260,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCont) { ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); EXPECT_EQ(gen.result(), R"( for(; ; i = (i + 1)) { - return; + a_statement(); } )"); } @@ -251,9 +271,12 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCont) { // return; // } + Func("a_statement", {}, ty.void_(), {}); + Global("a", ty.atomic(), ast::StorageClass::kWorkgroup); auto* multi_stmt = Block(Ignore(1), Ignore(2)); - auto* f = For(nullptr, nullptr, multi_stmt, Block(Return())); + auto* f = For(nullptr, nullptr, multi_stmt, + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -262,7 +285,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCont) { ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); EXPECT_EQ(gen.result(), R"( while (true) { - return; + a_statement(); { (void) 1; (void) 2; @@ -276,8 +299,10 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleInitCondCont) { // return; // } + Func("a_statement", {}, ty.void_(), {}); + auto* f = For(Decl(Var("i", ty.i32())), true, Assign("i", Add("i", 1)), - Block(Return())); + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -286,7 +311,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleInitCondCont) { ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); EXPECT_EQ(gen.result(), R"( for(int i = 0; true; i = (i + 1)) { - return; + a_statement(); } )"); } @@ -296,10 +321,14 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInitCondCont) { // for({ ignore(1); ignore(2); }; true; { ignore(3); ignore(4); }) { // return; // } + + Func("a_statement", {}, ty.void_(), {}); + Global("a", ty.atomic(), ast::StorageClass::kWorkgroup); auto* multi_stmt_a = Block(Ignore(1), Ignore(2)); auto* multi_stmt_b = Block(Ignore(3), Ignore(4)); - auto* f = For(multi_stmt_a, Expr(true), multi_stmt_b, Block(Return())); + auto* f = For(multi_stmt_a, Expr(true), multi_stmt_b, + Block(create(Call("a_statement")))); WrapInFunction(f); GeneratorImpl& gen = Build(); @@ -314,7 +343,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInitCondCont) { } while (true) { if (!(true)) { break; } - return; + a_statement(); { (void) 3; (void) 4; diff --git a/src/writer/wgsl/generator_impl_loop_test.cc b/src/writer/wgsl/generator_impl_loop_test.cc index 6c413a89a0..1a28857dd8 100644 --- a/src/writer/wgsl/generator_impl_loop_test.cc +++ b/src/writer/wgsl/generator_impl_loop_test.cc @@ -40,8 +40,10 @@ TEST_F(WgslGeneratorImplTest, Emit_Loop) { } TEST_F(WgslGeneratorImplTest, Emit_LoopWithContinuing) { + Func("a_statement", {}, ty.void_(), {}); + auto* body = Block(create()); - auto* continuing = Block(Return()); + auto* continuing = Block(create(Call("a_statement"))); auto* l = Loop(body, continuing); WrapInFunction(l); @@ -55,7 +57,7 @@ TEST_F(WgslGeneratorImplTest, Emit_LoopWithContinuing) { discard; continuing { - return; + a_statement(); } } )");