Resolver: Validation for continuing blocks

Check they do not contain returns, discards
Check they do not directly contain continues, however a nested loop can have its own continue.

Bug: chromium:1229976
Change-Id: Ia3c4ac118ffdaa6cca6025366c19f9897718c930
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58384
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-07-17 18:09:26 +00:00 committed by Tint LUCI CQ
parent 725159c17c
commit cd7eb4f968
7 changed files with 335 additions and 71 deletions

View File

@ -105,6 +105,14 @@ class CloneContext;
/// To construct a Program, populate the builder and then `std::move` it to a /// To construct a Program, populate the builder and then `std::move` it to a
/// Program. /// Program.
class ProgramBuilder { class ProgramBuilder {
/// A helper used to disable overloads if the first type in `TYPES` is a
/// Source. Used to avoid ambiguities in overloads that take a Source as the
/// first parameter and those that perfectly-forward the first argument.
template <typename... TYPES>
using DisableIfSource = traits::EnableIfIsNotType<
traits::Decay<traits::NthTypeOf<0, TYPES..., void>>,
Source>;
/// VarOptionals is a helper for accepting a number of optional, extra /// VarOptionals is a helper for accepting a number of optional, extra
/// arguments for Var() and Global(). /// arguments for Var() and Global().
struct VarOptionals { struct VarOptionals {
@ -1383,7 +1391,7 @@ class ProgramBuilder {
/// global variable with the ast::Module. /// global variable with the ast::Module.
template <typename NAME, template <typename NAME,
typename... OPTIONAL, typename... OPTIONAL,
traits::EnableIfIsNotType<traits::Decay<NAME>, Source>* = nullptr> typename = DisableIfSource<NAME>>
ast::Variable* Global(NAME&& name, ast::Variable* Global(NAME&& name,
const ast::Type* type, const ast::Type* type,
OPTIONAL&&... optional) { OPTIONAL&&... optional) {
@ -1504,9 +1512,7 @@ class ProgramBuilder {
/// @param args the function call arguments /// @param args the function call arguments
/// @returns a `ast::CallExpression` to the function `func`, with the /// @returns a `ast::CallExpression` to the function `func`, with the
/// arguments of `args` converted to `ast::Expression`s using `Expr()`. /// arguments of `args` converted to `ast::Expression`s using `Expr()`.
template <typename NAME, template <typename NAME, typename... ARGS, typename = DisableIfSource<NAME>>
typename... ARGS,
traits::EnableIfIsNotType<traits::Decay<NAME>, Source>* = nullptr>
ast::CallExpression* Call(NAME&& func, ARGS&&... args) { ast::CallExpression* Call(NAME&& func, ARGS&&... args) {
return create<ast::CallExpression>(Expr(func), return create<ast::CallExpression>(Expr(func),
ExprList(std::forward<ARGS>(args)...)); ExprList(std::forward<ARGS>(args)...));
@ -1781,7 +1787,7 @@ class ProgramBuilder {
/// Creates an ast::ReturnStatement with the given return value /// Creates an ast::ReturnStatement with the given return value
/// @param val the return value /// @param val the return value
/// @returns the return statement pointer /// @returns the return statement pointer
template <typename EXPR> template <typename EXPR, typename = DisableIfSource<EXPR>>
ast::ReturnStatement* Return(EXPR&& val) { ast::ReturnStatement* Return(EXPR&& val) {
return create<ast::ReturnStatement>(Expr(std::forward<EXPR>(val))); return create<ast::ReturnStatement>(Expr(std::forward<EXPR>(val)));
} }
@ -1886,12 +1892,22 @@ class ProgramBuilder {
} }
/// Creates a ast::BlockStatement with input statements /// Creates a ast::BlockStatement with input statements
/// @param source the source information for the block
/// @param statements statements of block /// @param statements statements of block
/// @returns the block statement pointer /// @returns the block statement pointer
template <typename... Statements> template <typename... Statements>
ast::BlockStatement* Block(Statements&&... statements) { ast::BlockStatement* Block(const Source& source, Statements&&... statements) {
return create<ast::BlockStatement>( return create<ast::BlockStatement>(
ast::StatementList{std::forward<Statements>(statements)...}); source, ast::StatementList{std::forward<Statements>(statements)...});
}
/// Creates a ast::BlockStatement with input statements
/// @param statements statements of block
/// @returns the block statement pointer
template <typename... STATEMENTS, typename = DisableIfSource<STATEMENTS...>>
ast::BlockStatement* Block(STATEMENTS&&... statements) {
return create<ast::BlockStatement>(
ast::StatementList{std::forward<STATEMENTS>(statements)...});
} }
/// Creates a ast::ElseStatement with input condition and body /// Creates a ast::ElseStatement with input condition and body

View File

@ -2062,11 +2062,18 @@ bool Resolver::Statement(ast::Statement* stmt) {
} }
if (stmt->Is<ast::ContinueStatement>()) { if (stmt->Is<ast::ContinueStatement>()) {
// Set if we've hit the first continue statement in our parent loop // Set if we've hit the first continue statement in our parent loop
if (auto* loop_block = if (auto* block =
current_block_->FindFirstParent<sem::LoopBlockStatement>()) { current_block_->FindFirstParent<
if (loop_block->FirstContinue() == size_t(~0)) { sem::LoopBlockStatement, sem::LoopContinuingBlockStatement>()) {
const_cast<sem::LoopBlockStatement*>(loop_block) if (auto* loop_block = block->As<sem::LoopBlockStatement>()) {
->SetFirstContinue(loop_block->Decls().size()); if (loop_block->FirstContinue() == size_t(~0)) {
const_cast<sem::LoopBlockStatement*>(loop_block)
->SetFirstContinue(loop_block->Decls().size());
}
} else {
AddError("continuing blocks must not contain a continue statement",
stmt->source());
return false;
} }
} else { } else {
AddError("continue statement must be in a loop", stmt->source()); AddError("continue statement must be in a loop", stmt->source());
@ -2076,6 +2083,17 @@ bool Resolver::Statement(ast::Statement* stmt) {
return true; return true;
} }
if (stmt->Is<ast::DiscardStatement>()) { if (stmt->Is<ast::DiscardStatement>()) {
if (auto* continuing =
sem_statement
->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
AddError("continuing blocks must not contain a discard statement",
stmt->source());
if (continuing != sem_statement->Parent()) {
AddNote("see continuing block here",
continuing->Declaration()->source());
}
return false;
}
return true; return true;
} }
if (stmt->Is<ast::FallthroughStatement>()) { if (stmt->Is<ast::FallthroughStatement>()) {
@ -4110,6 +4128,17 @@ bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) {
return false; return false;
} }
auto* sem = builder_->Sem().Get(ret);
if (auto* continuing =
sem->FindFirstParent<sem::LoopContinuingBlockStatement>()) {
AddError("continuing blocks must not contain a return statement",
ret->source());
if (continuing != sem->Parent()) {
AddNote("see continuing block here", continuing->Declaration()->source());
}
return false;
}
return true; return true;
} }

View File

@ -20,6 +20,7 @@
#include "src/ast/break_statement.h" #include "src/ast/break_statement.h"
#include "src/ast/call_statement.h" #include "src/ast/call_statement.h"
#include "src/ast/continue_statement.h" #include "src/ast/continue_statement.h"
#include "src/ast/discard_statement.h"
#include "src/ast/if_statement.h" #include "src/ast/if_statement.h"
#include "src/ast/intrinsic_texture_helper_test.h" #include "src/ast/intrinsic_texture_helper_test.h"
#include "src/ast/loop_statement.h" #include "src/ast/loop_statement.h"
@ -650,6 +651,123 @@ TEST_F(ResolverTest, Stmt_Loop_ContinueInLoopBodyAfterDecl_UsageInContinuing) {
EXPECT_TRUE(r()->Resolve()); EXPECT_TRUE(r()->Resolve());
} }
TEST_F(ResolverTest, Stmt_Loop_ReturnInContinuing_Direct) {
// loop {
// continuing {
// return;
// }
// }
WrapInFunction(Loop( // loop
Block(), // loop block
Block( // loop continuing block
Return(Source{{12, 34}}))));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(12:34 error: continuing blocks must not contain a return statement)");
}
TEST_F(ResolverTest, Stmt_Loop_ReturnInContinuing_Indirect) {
// loop {
// continuing {
// loop {
// return;
// }
// }
// }
WrapInFunction(Loop( // outer loop
Block(), // outer loop block
Block(Source{{56, 78}}, // outer loop continuing block
Loop( // inner loop
Block( // inner loop block
Return(Source{{12, 34}}))))));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(12:34 error: continuing blocks must not contain a return statement
56:78 note: see continuing block here)");
}
TEST_F(ResolverTest, Stmt_Loop_DiscardInContinuing_Direct) {
// loop {
// continuing {
// discard;
// }
// }
WrapInFunction(Loop( // loop
Block(), // loop block
Block( // loop continuing block
create<ast::DiscardStatement>(Source{{12, 34}}))));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(12:34 error: continuing blocks must not contain a discard statement)");
}
TEST_F(ResolverTest, Stmt_Loop_DiscardInContinuing_Indirect) {
// loop {
// continuing {
// loop { discard; }
// }
// }
WrapInFunction(Loop( // outer loop
Block(), // outer loop block
Block(Source{{56, 78}}, // outer loop continuing block
Loop( // inner loop
Block( // inner loop block
create<ast::DiscardStatement>(Source{{12, 34}}))))));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(12:34 error: continuing blocks must not contain a discard statement
56:78 note: see continuing block here)");
}
TEST_F(ResolverTest, Stmt_Loop_ContinueInContinuing_Direct) {
// loop {
// continuing {
// continue;
// }
// }
WrapInFunction(Loop( // loop
Block(), // loop block
Block(Source{{56, 78}}, // loop continuing block
create<ast::ContinueStatement>(Source{{12, 34}}))));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: continuing blocks must not contain a continue statement");
}
TEST_F(ResolverTest, Stmt_Loop_ContinueInContinuing_Indirect) {
// loop {
// continuing {
// loop {
// continue;
// }
// }
// }
WrapInFunction(Loop( // outer loop
Block(), // outer loop block
Block( // outer loop continuing block
Loop( // inner loop
Block( // inner loop block
create<ast::ContinueStatement>(Source{{12, 34}}))))));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTest, Stmt_ForLoop_CondIsNotBool) { TEST_F(ResolverTest, Stmt_ForLoop_CondIsNotBool) {
// for (; 1.0f; ) { // for (; 1.0f; ) {
// } // }

View File

@ -34,6 +34,30 @@ namespace sem {
/// Forward declaration /// Forward declaration
class CompoundStatement; class CompoundStatement;
namespace detail {
/// FindFirstParentReturn is a traits helper for determining the return type for
/// the template member function Statement::FindFirstParent().
/// For zero or multiple template arguments, FindFirstParentReturn::type
/// resolves to CompoundStatement.
template <typename... TYPES>
struct FindFirstParentReturn {
/// The pointer type returned by Statement::FindFirstParent()
using type = CompoundStatement;
};
/// A specialization of FindFirstParentReturn for a single template argument.
/// FindFirstParentReturn::type resolves to the single template argument.
template <typename T>
struct FindFirstParentReturn<T> {
/// The pointer type returned by Statement::FindFirstParent()
using type = T;
};
template <typename... TYPES>
using FindFirstParentReturnType =
typename FindFirstParentReturn<TYPES...>::type;
} // namespace detail
/// Statement holds the semantic information for a statement. /// Statement holds the semantic information for a statement.
class Statement : public Castable<Statement, Node> { class Statement : public Castable<Statement, Node> {
public: public:
@ -49,16 +73,18 @@ class Statement : public Castable<Statement, Node> {
const CompoundStatement* Parent() const { return parent_; } const CompoundStatement* Parent() const { return parent_; }
/// @returns the closest enclosing parent that satisfies the given predicate, /// @returns the closest enclosing parent that satisfies the given predicate,
/// which may be the statement itself, or nullptr if no match is found /// which may be the statement itself, or nullptr if no match is found.
/// @param pred a predicate that the resulting block must satisfy /// @param pred a predicate that the resulting block must satisfy
template <typename Pred> template <typename Pred>
const CompoundStatement* FindFirstParent(Pred&& pred) const; const CompoundStatement* FindFirstParent(Pred&& pred) const;
/// @returns the statement itself if it matches the template type `T`, /// @returns the closest enclosing parent that is of one of the types in
/// otherwise the nearest enclosing statement that matches `T`, or nullptr if /// `TYPES`, which may be the statement itself, or nullptr if no match is
/// there is none. /// found. If `TYPES` is a single template argument, the return type is a
template <typename T> /// pointer to that template argument type, otherwise a CompoundStatement
const T* FindFirstParent() const; /// pointer is returned.
template <typename... TYPES>
const detail::FindFirstParentReturnType<TYPES...>* FindFirstParent() const;
/// @return the closest enclosing block for this statement /// @return the closest enclosing block for this statement
const BlockStatement* Block() const; const BlockStatement* Block() const;
@ -99,17 +125,32 @@ const CompoundStatement* Statement::FindFirstParent(Pred&& pred) const {
return curr; return curr;
} }
template <typename T> template <typename... TYPES>
const T* Statement::FindFirstParent() const { const detail::FindFirstParentReturnType<TYPES...>* Statement::FindFirstParent()
if (auto* p = As<T>()) { const {
return p; using ReturnType = detail::FindFirstParentReturnType<TYPES...>;
} if (sizeof...(TYPES) == 1) {
const auto* curr = parent_; if (auto* p = As<ReturnType>()) {
while (curr) {
if (auto* p = curr->As<T>()) {
return p; return p;
} }
curr = curr->Parent(); const auto* curr = parent_;
while (curr) {
if (auto* p = curr->As<ReturnType>()) {
return p;
}
curr = curr->Parent();
}
} else {
if (IsAnyOf<TYPES...>()) {
return As<ReturnType>();
}
const auto* curr = parent_;
while (curr) {
if (curr->IsAnyOf<TYPES...>()) {
return curr->As<ReturnType>();
}
curr = curr->Parent();
}
} }
return nullptr; return nullptr;
} }

View File

@ -41,8 +41,10 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_Loop) {
} }
TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopWithContinuing) { TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopWithContinuing) {
Func("a_statement", {}, ty.void_(), {});
auto* body = Block(create<ast::DiscardStatement>()); auto* body = Block(create<ast::DiscardStatement>());
auto* continuing = Block(Return()); auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
auto* l = Loop(body, continuing); auto* l = Loop(body, continuing);
WrapInFunction(l); WrapInFunction(l);
@ -55,18 +57,20 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopWithContinuing) {
EXPECT_EQ(gen.result(), R"( while (true) { EXPECT_EQ(gen.result(), R"( while (true) {
discard; discard;
{ {
return; a_statement();
} }
} }
)"); )");
} }
TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) { TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) {
Func("a_statement", {}, ty.void_(), {});
Global("lhs", ty.f32(), ast::StorageClass::kPrivate); Global("lhs", ty.f32(), ast::StorageClass::kPrivate);
Global("rhs", ty.f32(), ast::StorageClass::kPrivate); Global("rhs", ty.f32(), ast::StorageClass::kPrivate);
auto* body = Block(create<ast::DiscardStatement>()); auto* body = Block(create<ast::DiscardStatement>());
auto* continuing = Block(Return()); auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
auto* inner = Loop(body, continuing); auto* inner = Loop(body, continuing);
body = Block(inner); body = Block(inner);
@ -88,7 +92,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_LoopNestedWithContinuing) {
while (true) { while (true) {
discard; discard;
{ {
return; a_statement();
} }
} }
{ {
@ -153,7 +157,10 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoop) {
// return; // return;
// } // }
auto* f = For(nullptr, nullptr, nullptr, Block(Return())); Func("a_statement", {}, ty.void_(), {});
auto* f = For(nullptr, nullptr, nullptr,
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -163,7 +170,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoop) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
for(; ; ) { for(; ; ) {
return; a_statement();
} }
} }
)"); )");
@ -174,7 +181,10 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleInit) {
// return; // return;
// } // }
auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr, Block(Return())); Func("a_statement", {}, ty.void_(), {});
auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr,
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -184,7 +194,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleInit) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
for(int i = 0; ; ) { for(int i = 0; ; ) {
return; a_statement();
} }
} }
)"); )");
@ -194,10 +204,12 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtInit) {
// for(var b = true && false; ; ) { // for(var b = true && false; ; ) {
// return; // return;
// } // }
Func("a_statement", {}, ty.void_(), {});
auto* multi_stmt = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, auto* multi_stmt = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
Expr(true), Expr(false)); Expr(true), Expr(false));
auto* f = For(Decl(Var("b", nullptr, multi_stmt)), nullptr, nullptr, auto* f = For(Decl(Var("b", nullptr, multi_stmt)), nullptr, nullptr,
Block(Return())); Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -212,7 +224,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtInit) {
} }
bool b = (tint_tmp); bool b = (tint_tmp);
for(; ; ) { for(; ; ) {
return; a_statement();
} }
} }
)"); )");
@ -223,7 +235,10 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleCond) {
// return; // return;
// } // }
auto* f = For(nullptr, true, nullptr, Block(Return())); Func("a_statement", {}, ty.void_(), {});
auto* f = For(nullptr, true, nullptr,
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -233,7 +248,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleCond) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
for(; true; ) { for(; true; ) {
return; a_statement();
} }
} }
)"); )");
@ -244,9 +259,12 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtCond) {
// return; // return;
// } // }
Func("a_statement", {}, ty.void_(), {});
auto* multi_stmt = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, auto* multi_stmt = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
Expr(true), Expr(false)); Expr(true), Expr(false));
auto* f = For(nullptr, multi_stmt, nullptr, Block(Return())); auto* f = For(nullptr, multi_stmt, nullptr,
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -261,7 +279,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtCond) {
tint_tmp = false; tint_tmp = false;
} }
if (!((tint_tmp))) { break; } if (!((tint_tmp))) { break; }
return; a_statement();
} }
} }
)"); )");
@ -272,8 +290,11 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleCont) {
// return; // return;
// } // }
Func("a_statement", {}, ty.void_(), {});
auto* v = Decl(Var("i", ty.i32())); auto* v = Decl(Var("i", ty.i32()));
auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)), Block(Return())); auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)),
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(v, f); WrapInFunction(v, f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -283,7 +304,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleCont) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
for(; ; i = (i + 1)) { for(; ; i = (i + 1)) {
return; a_statement();
} }
} }
)"); )");
@ -294,10 +315,13 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtCont) {
// return; // return;
// } // }
Func("a_statement", {}, ty.void_(), {});
auto* multi_stmt = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, auto* multi_stmt = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
Expr(true), Expr(false)); Expr(true), Expr(false));
auto* v = Decl(Var("i", ty.bool_())); auto* v = Decl(Var("i", ty.bool_()));
auto* f = For(nullptr, nullptr, Assign("i", multi_stmt), Block(Return())); auto* f = For(nullptr, nullptr, Assign("i", multi_stmt),
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(v, f); WrapInFunction(v, f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -307,7 +331,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtCont) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
while (true) { while (true) {
return; a_statement();
bool tint_tmp = true; bool tint_tmp = true;
if (tint_tmp) { if (tint_tmp) {
tint_tmp = false; tint_tmp = false;
@ -323,8 +347,10 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleInitCondCont) {
// return; // return;
// } // }
Func("a_statement", {}, ty.void_(), {});
auto* f = For(Decl(Var("i", ty.i32())), true, Assign("i", Add("i", 1)), auto* f = For(Decl(Var("i", ty.i32())), true, Assign("i", Add("i", 1)),
Block(Return())); Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -334,7 +360,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithSimpleInitCondCont) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
for(int i = 0; true; i = (i + 1)) { for(int i = 0; true; i = (i + 1)) {
return; a_statement();
} }
} }
)"); )");
@ -344,6 +370,8 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtInitCondCont) {
// for(var i = true && false; true && false; i = true && false) { // for(var i = true && false; true && false; i = true && false) {
// return; // return;
// } // }
Func("a_statement", {}, ty.void_(), {});
auto* multi_stmt_a = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, auto* multi_stmt_a = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
Expr(true), Expr(false)); Expr(true), Expr(false));
auto* multi_stmt_b = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, auto* multi_stmt_b = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd,
@ -352,7 +380,8 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtInitCondCont) {
Expr(true), Expr(false)); Expr(true), Expr(false));
auto* f = For(Decl(Var("i", nullptr, multi_stmt_a)), multi_stmt_b, auto* f = For(Decl(Var("i", nullptr, multi_stmt_a)), multi_stmt_b,
Assign("i", multi_stmt_c), Block(Return())); Assign("i", multi_stmt_c),
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -372,7 +401,7 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtInitCondCont) {
tint_tmp_1 = false; tint_tmp_1 = false;
} }
if (!((tint_tmp_1))) { break; } if (!((tint_tmp_1))) { break; }
return; a_statement();
bool tint_tmp_2 = true; bool tint_tmp_2 = true;
if (tint_tmp_2) { if (tint_tmp_2) {
tint_tmp_2 = false; tint_tmp_2 = false;

View File

@ -40,8 +40,10 @@ TEST_F(MslGeneratorImplTest, Emit_Loop) {
} }
TEST_F(MslGeneratorImplTest, Emit_LoopWithContinuing) { TEST_F(MslGeneratorImplTest, Emit_LoopWithContinuing) {
Func("a_statement", {}, ty.void_(), {});
auto* body = Block(create<ast::DiscardStatement>()); auto* body = Block(create<ast::DiscardStatement>());
auto* continuing = Block(Return()); auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
auto* l = Loop(body, continuing); auto* l = Loop(body, continuing);
WrapInFunction(l); WrapInFunction(l);
@ -53,18 +55,20 @@ TEST_F(MslGeneratorImplTest, Emit_LoopWithContinuing) {
EXPECT_EQ(gen.result(), R"( while (true) { EXPECT_EQ(gen.result(), R"( while (true) {
discard_fragment(); discard_fragment();
{ {
return; a_statement();
} }
} }
)"); )");
} }
TEST_F(MslGeneratorImplTest, Emit_LoopNestedWithContinuing) { TEST_F(MslGeneratorImplTest, Emit_LoopNestedWithContinuing) {
Func("a_statement", {}, ty.void_(), {});
Global("lhs", ty.f32(), ast::StorageClass::kPrivate); Global("lhs", ty.f32(), ast::StorageClass::kPrivate);
Global("rhs", ty.f32(), ast::StorageClass::kPrivate); Global("rhs", ty.f32(), ast::StorageClass::kPrivate);
auto* body = Block(create<ast::DiscardStatement>()); auto* body = Block(create<ast::DiscardStatement>());
auto* continuing = Block(Return()); auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
auto* inner = Loop(body, continuing); auto* inner = Loop(body, continuing);
body = Block(inner); body = Block(inner);
@ -83,7 +87,7 @@ TEST_F(MslGeneratorImplTest, Emit_LoopNestedWithContinuing) {
while (true) { while (true) {
discard_fragment(); discard_fragment();
{ {
return; a_statement();
} }
} }
{ {
@ -146,7 +150,10 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoop) {
// return; // return;
// } // }
auto* f = For(nullptr, nullptr, nullptr, Block(Return())); Func("a_statement", {}, ty.void_(), {});
auto* f = For(nullptr, nullptr, nullptr,
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -155,7 +162,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoop) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( for(; ; ) { EXPECT_EQ(gen.result(), R"( for(; ; ) {
return; a_statement();
} }
)"); )");
} }
@ -165,7 +172,10 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleInit) {
// return; // return;
// } // }
auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr, Block(Return())); Func("a_statement", {}, ty.void_(), {});
auto* f = For(Decl(Var("i", ty.i32())), nullptr, nullptr,
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -174,7 +184,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleInit) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( for(int i = 0; ; ) { EXPECT_EQ(gen.result(), R"( for(int i = 0; ; ) {
return; a_statement();
} }
)"); )");
} }
@ -184,9 +194,13 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInit) {
// for({ignore(1); ignore(2);}; ; ) { // for({ignore(1); ignore(2);}; ; ) {
// return; // return;
// } // }
Func("a_statement", {}, ty.void_(), {});
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup); Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt = Block(Ignore(1), Ignore(2)); auto* multi_stmt = Block(Ignore(1), Ignore(2));
auto* f = For(multi_stmt, nullptr, nullptr, Block(Return())); auto* f = For(multi_stmt, nullptr, nullptr,
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -200,7 +214,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInit) {
(void) 2; (void) 2;
} }
for(; ; ) { for(; ; ) {
return; a_statement();
} }
} }
)"); )");
@ -211,7 +225,10 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCond) {
// return; // return;
// } // }
auto* f = For(nullptr, true, nullptr, Block(Return())); Func("a_statement", {}, ty.void_(), {});
auto* f = For(nullptr, true, nullptr,
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -220,7 +237,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCond) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( for(; true; ) { EXPECT_EQ(gen.result(), R"( for(; true; ) {
return; a_statement();
} }
)"); )");
} }
@ -230,8 +247,11 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCont) {
// return; // return;
// } // }
Func("a_statement", {}, ty.void_(), {});
auto* v = Decl(Var("i", ty.i32())); auto* v = Decl(Var("i", ty.i32()));
auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)), Block(Return())); auto* f = For(nullptr, nullptr, Assign("i", Add("i", 1)),
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(v, f); WrapInFunction(v, f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -240,7 +260,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCont) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( for(; ; i = (i + 1)) { EXPECT_EQ(gen.result(), R"( for(; ; i = (i + 1)) {
return; a_statement();
} }
)"); )");
} }
@ -251,9 +271,12 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCont) {
// return; // return;
// } // }
Func("a_statement", {}, ty.void_(), {});
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup); Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt = Block(Ignore(1), Ignore(2)); auto* multi_stmt = Block(Ignore(1), Ignore(2));
auto* f = For(nullptr, nullptr, multi_stmt, Block(Return())); auto* f = For(nullptr, nullptr, multi_stmt,
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -262,7 +285,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCont) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( while (true) { EXPECT_EQ(gen.result(), R"( while (true) {
return; a_statement();
{ {
(void) 1; (void) 1;
(void) 2; (void) 2;
@ -276,8 +299,10 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleInitCondCont) {
// return; // return;
// } // }
Func("a_statement", {}, ty.void_(), {});
auto* f = For(Decl(Var("i", ty.i32())), true, Assign("i", Add("i", 1)), auto* f = For(Decl(Var("i", ty.i32())), true, Assign("i", Add("i", 1)),
Block(Return())); Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -286,7 +311,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleInitCondCont) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( for(int i = 0; true; i = (i + 1)) { EXPECT_EQ(gen.result(), R"( for(int i = 0; true; i = (i + 1)) {
return; a_statement();
} }
)"); )");
} }
@ -296,10 +321,14 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInitCondCont) {
// for({ ignore(1); ignore(2); }; true; { ignore(3); ignore(4); }) { // for({ ignore(1); ignore(2); }; true; { ignore(3); ignore(4); }) {
// return; // return;
// } // }
Func("a_statement", {}, ty.void_(), {});
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup); Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt_a = Block(Ignore(1), Ignore(2)); auto* multi_stmt_a = Block(Ignore(1), Ignore(2));
auto* multi_stmt_b = Block(Ignore(3), Ignore(4)); auto* multi_stmt_b = Block(Ignore(3), Ignore(4));
auto* f = For(multi_stmt_a, Expr(true), multi_stmt_b, Block(Return())); auto* f = For(multi_stmt_a, Expr(true), multi_stmt_b,
Block(create<ast::CallStatement>(Call("a_statement"))));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -314,7 +343,7 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInitCondCont) {
} }
while (true) { while (true) {
if (!(true)) { break; } if (!(true)) { break; }
return; a_statement();
{ {
(void) 3; (void) 3;
(void) 4; (void) 4;

View File

@ -40,8 +40,10 @@ TEST_F(WgslGeneratorImplTest, Emit_Loop) {
} }
TEST_F(WgslGeneratorImplTest, Emit_LoopWithContinuing) { TEST_F(WgslGeneratorImplTest, Emit_LoopWithContinuing) {
Func("a_statement", {}, ty.void_(), {});
auto* body = Block(create<ast::DiscardStatement>()); auto* body = Block(create<ast::DiscardStatement>());
auto* continuing = Block(Return()); auto* continuing = Block(create<ast::CallStatement>(Call("a_statement")));
auto* l = Loop(body, continuing); auto* l = Loop(body, continuing);
WrapInFunction(l); WrapInFunction(l);
@ -55,7 +57,7 @@ TEST_F(WgslGeneratorImplTest, Emit_LoopWithContinuing) {
discard; discard;
continuing { continuing {
return; a_statement();
} }
} }
)"); )");