Add while statement parsing.

This CL adds parsing for the WGSL `while` statement.

Bug: tint:1425
Change-Id: Ibce5e28568935ca4f51b5ac33e7a60af7a916b4a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/93540
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2022-06-16 12:01:27 +00:00 committed by Dawn LUCI CQ
parent d10f3f4437
commit 49d1a2d950
60 changed files with 2151 additions and 13 deletions

View File

@ -58,6 +58,26 @@ sem::ForLoopStatement {
}
```
## while
WGSL:
```
while (condition) {
statement;
}
```
Semantic tree:
```
sem::WhileStatement {
sem::Expression condition
sem::LoopBlockStatement {
sem::Statement statement
}
}
```
## loop
WGSL:

View File

@ -336,6 +336,8 @@ libtint_source_set("libtint_core_all_src") {
"ast/vector.h",
"ast/void.cc",
"ast/void.h",
"ast/while_statement.cc",
"ast/while_statement.h",
"ast/workgroup_attribute.cc",
"ast/workgroup_attribute.h",
"castable.cc",
@ -436,6 +438,7 @@ libtint_source_set("libtint_core_all_src") {
"sem/u32.h",
"sem/vector.h",
"sem/void.h",
"sem/while_statement.h",
"source.cc",
"source.h",
"symbol.cc",
@ -523,6 +526,8 @@ libtint_source_set("libtint_core_all_src") {
"transform/vectorize_scalar_matrix_constructors.h",
"transform/vertex_pulling.cc",
"transform/vertex_pulling.h",
"transform/while_to_loop.cc",
"transform/while_to_loop.h",
"transform/wrap_arrays_in_structs.cc",
"transform/wrap_arrays_in_structs.h",
"transform/zero_init_workgroup_memory.cc",
@ -666,6 +671,8 @@ libtint_source_set("libtint_sem_src") {
"sem/vector.h",
"sem/void.cc",
"sem/void.h",
"sem/while_statement.cc",
"sem/while_statement.h",
]
public_deps = [ ":libtint_core_all_src" ]

View File

@ -223,6 +223,8 @@ set(TINT_LIB_SRCS
ast/vector.h
ast/void.cc
ast/void.h
ast/while_statement.cc
ast/while_statement.h
ast/workgroup_attribute.cc
ast/workgroup_attribute.h
castable.cc
@ -365,6 +367,8 @@ set(TINT_LIB_SRCS
sem/vector.h
sem/void.cc
sem/void.h
sem/while_statement.cc
sem/while_statement.h
symbol_table.cc
symbol_table.h
symbol.cc
@ -450,6 +454,8 @@ set(TINT_LIB_SRCS
transform/vectorize_scalar_matrix_constructors.h
transform/vertex_pulling.cc
transform/vertex_pulling.h
transform/while_to_loop.cc
transform/while_to_loop.h
transform/wrap_arrays_in_structs.cc
transform/wrap_arrays_in_structs.h
transform/zero_init_workgroup_memory.cc
@ -743,6 +749,7 @@ if(TINT_BUILD_TESTS)
ast/variable_decl_statement_test.cc
ast/variable_test.cc
ast/vector_test.cc
ast/while_statement_test.cc
ast/workgroup_attribute_test.cc
castable_test.cc
clone_context_test.cc
@ -987,6 +994,7 @@ if(TINT_BUILD_TESTS)
reader/wgsl/parser_impl_variable_ident_decl_test.cc
reader/wgsl/parser_impl_variable_stmt_test.cc
reader/wgsl/parser_impl_variable_qualifier_test.cc
reader/wgsl/parser_impl_while_stmt_test.cc
reader/wgsl/token_test.cc
)
endif()
@ -1102,6 +1110,7 @@ if(TINT_BUILD_TESTS)
transform/var_for_dynamic_index_test.cc
transform/vectorize_scalar_matrix_constructors_test.cc
transform/vertex_pulling_test.cc
transform/while_to_loop_test.cc
transform/wrap_arrays_in_structs_test.cc
transform/zero_init_workgroup_memory_test.cc
transform/utils/get_insertion_point_test.cc

View File

@ -0,0 +1,48 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ast/while_statement.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::ast::WhileStatement);
namespace tint::ast {
WhileStatement::WhileStatement(ProgramID pid,
const Source& src,
const Expression* cond,
const BlockStatement* b)
: Base(pid, src), condition(cond), body(b) {
TINT_ASSERT(AST, cond);
TINT_ASSERT(AST, body);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, condition, program_id);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(AST, body, program_id);
}
WhileStatement::WhileStatement(WhileStatement&&) = default;
WhileStatement::~WhileStatement() = default;
const WhileStatement* WhileStatement::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto src = ctx->Clone(source);
auto* cond = ctx->Clone(condition);
auto* b = ctx->Clone(body);
return ctx->dst->create<WhileStatement>(src, cond, b);
}
} // namespace tint::ast

View File

@ -0,0 +1,55 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_AST_WHILE_STATEMENT_H_
#define SRC_TINT_AST_WHILE_STATEMENT_H_
#include "src/tint/ast/block_statement.h"
namespace tint::ast {
class Expression;
/// A while loop statement
class WhileStatement final : public Castable<WhileStatement, Statement> {
public:
/// Constructor
/// @param program_id the identifier of the program that owns this node
/// @param source the for loop statement source
/// @param condition the optional loop condition expression
/// @param body the loop body
WhileStatement(ProgramID program_id,
Source const& source,
const Expression* condition,
const BlockStatement* body);
/// Move constructor
WhileStatement(WhileStatement&&);
~WhileStatement() override;
/// Clones this node and all transitive child nodes using the `CloneContext`
/// `ctx`.
/// @param ctx the clone context
/// @return the newly cloned node
const WhileStatement* Clone(CloneContext* ctx) const override;
/// The condition expression
const Expression* const condition;
/// The loop body block
const BlockStatement* const body;
};
} // namespace tint::ast
#endif // SRC_TINT_AST_WHILE_STATEMENT_H_

View File

@ -0,0 +1,85 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest-spi.h"
#include "src/tint/ast/binary_expression.h"
#include "src/tint/ast/test_helper.h"
using namespace tint::number_suffixes; // NOLINT
namespace tint::ast {
namespace {
using WhileStatementTest = TestHelper;
TEST_F(WhileStatementTest, Creation) {
auto* cond = create<BinaryExpression>(BinaryOp::kLessThan, Expr("i"), Expr(5_u));
auto* body = Block(Return());
auto* l = While(cond, body);
EXPECT_EQ(l->condition, cond);
EXPECT_EQ(l->body, body);
}
TEST_F(WhileStatementTest, Creation_WithSource) {
auto* cond = create<BinaryExpression>(BinaryOp::kLessThan, Expr("i"), Expr(5_u));
auto* body = Block(Return());
auto* l = While(Source{{20u, 2u}}, cond, body);
auto src = l->source;
EXPECT_EQ(src.range.begin.line, 20u);
EXPECT_EQ(src.range.begin.column, 2u);
}
TEST_F(WhileStatementTest, Assert_Null_Cond) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
auto* body = b.Block();
b.While(nullptr, body);
},
"internal compiler error");
}
TEST_F(WhileStatementTest, Assert_Null_Body) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
auto* cond = b.create<BinaryExpression>(BinaryOp::kLessThan, b.Expr("i"), b.Expr(5_u));
b.While(cond, nullptr);
},
"internal compiler error");
}
TEST_F(WhileStatementTest, Assert_DifferentProgramID_Condition) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b1;
ProgramBuilder b2;
b1.While(b2.Expr(true), b1.Block());
},
"internal compiler error");
}
TEST_F(WhileStatementTest, Assert_DifferentProgramID_Body) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b1;
ProgramBuilder b2;
b1.While(b1.Expr(true), b2.Block());
},
"internal compiler error");
}
} // namespace
} // namespace tint::ast

View File

@ -76,6 +76,7 @@
#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/ast/vector.h"
#include "src/tint/ast/void.h"
#include "src/tint/ast/while_statement.h"
#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/number.h"
#include "src/tint/program.h"
@ -2339,6 +2340,27 @@ class ProgramBuilder {
return create<ast::ForLoopStatement>(init, Expr(std::forward<COND>(cond)), cont, body);
}
/// Creates a ast::WhileStatement with input body and condition.
/// @param source the source information
/// @param cond the loop condition
/// @param body the loop body
/// @returns the while statement pointer
template <typename COND>
const ast::WhileStatement* While(const Source& source,
COND&& cond,
const ast::BlockStatement* body) {
return create<ast::WhileStatement>(source, Expr(std::forward<COND>(cond)), body);
}
/// Creates a ast::WhileStatement with given condition and body.
/// @param cond the condition
/// @param body the loop body
/// @returns the while loop statement pointer
template <typename COND>
const ast::WhileStatement* While(COND&& cond, const ast::BlockStatement* body) {
return create<ast::WhileStatement>(Expr(std::forward<COND>(cond)), body);
}
/// Creates a ast::VariableDeclStatement for the input variable
/// @param source the source information
/// @param var the variable to wrap in a decl statement

View File

@ -1271,6 +1271,9 @@ Token Lexer::check_keyword(const Source& source, std::string_view str) {
if (str == "vec4") {
return {Token::Type::kVec4, source, "vec4"};
}
if (str == "while") {
return {Token::Type::kWhile, source, "while"};
}
if (str == "workgroup") {
return {Token::Type::kWorkgroup, source, "workgroup"};
}

View File

@ -990,6 +990,7 @@ INSTANTIATE_TEST_SUITE_P(
TokenData{"vec2", Token::Type::kVec2},
TokenData{"vec3", Token::Type::kVec3},
TokenData{"vec4", Token::Type::kVec4},
TokenData{"while", Token::Type::kWhile},
TokenData{"workgroup", Token::Type::kWorkgroup}));
} // namespace

View File

@ -1597,6 +1597,7 @@ Expect<ast::StatementList> ParserImpl::expect_statements() {
// | switch_stmt
// | loop_stmt
// | for_stmt
// | while_stmt
// | non_block_statement
// : return_stmt SEMICOLON
// | func_call_stmt SEMICOLON
@ -1654,6 +1655,14 @@ Maybe<const ast::Statement*> ParserImpl::statement() {
return stmt_for.value;
}
auto stmt_while = while_stmt();
if (stmt_while.errored) {
return Failure::kErrored;
}
if (stmt_while.matched) {
return stmt_while.value;
}
if (peek_is(Token::Type::kBraceLeft)) {
auto body = expect_body_stmt();
if (body.errored) {
@ -2191,6 +2200,30 @@ Maybe<const ast::ForLoopStatement*> ParserImpl::for_stmt() {
create<ast::BlockStatement>(stmts.value));
}
// while_statement
// : WHILE expression compound_statement
Maybe<const ast::WhileStatement*> ParserImpl::while_stmt() {
Source source;
if (!match(Token::Type::kWhile, &source)) {
return Failure::kNoMatch;
}
auto condition = logical_or_expression();
if (condition.errored) {
return Failure::kErrored;
}
if (!condition.matched) {
return add_error(peek(), "unable to parse while condition expression");
}
auto body = expect_body_stmt();
if (body.errored) {
return Failure::kErrored;
}
return create<ast::WhileStatement>(source, condition.value, body.value);
}
// func_call_stmt
// : IDENT argument_expression_list
Maybe<const ast::CallStatement*> ParserImpl::func_call_stmt() {

View File

@ -527,6 +527,9 @@ class ParserImpl {
/// Parses a `for_stmt` grammar element
/// @returns the parsed for loop or nullptr
Maybe<const ast::ForLoopStatement*> for_stmt();
/// Parses a `while_stmt` grammar element
/// @returns the parsed while loop or nullptr
Maybe<const ast::WhileStatement*> while_stmt();
/// Parses a `continuing_stmt` grammar element
/// @returns the parsed statements
Maybe<const ast::BlockStatement*> continuing_stmt();

View File

@ -103,8 +103,7 @@ INSTANTIATE_TEST_SUITE_P(ParserImplReservedKeywordTest,
"unless",
"using",
"vec",
"void",
"while"));
"void"));
} // namespace
} // namespace tint::reader::wgsl

View File

@ -0,0 +1,157 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/reader/wgsl/parser_impl_test_helper.h"
#include "src/tint/ast/discard_statement.h"
namespace tint::reader::wgsl {
namespace {
using WhileStmtTest = ParserImplTest;
// Test an empty while loop.
TEST_F(WhileStmtTest, Empty) {
auto p = parser("while (true) { }");
auto wl = p->while_stmt();
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_FALSE(wl.errored);
ASSERT_TRUE(wl.matched);
EXPECT_TRUE(Is<ast::Expression>(wl->condition));
EXPECT_TRUE(wl->body->Empty());
}
// Test a while loop with non-empty body.
TEST_F(WhileStmtTest, Body) {
auto p = parser("while (true) { discard; }");
auto wl = p->while_stmt();
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_FALSE(wl.errored);
ASSERT_TRUE(wl.matched);
EXPECT_TRUE(Is<ast::Expression>(wl->condition));
ASSERT_EQ(wl->body->statements.size(), 1u);
EXPECT_TRUE(wl->body->statements[0]->Is<ast::DiscardStatement>());
}
// Test a while loop with complex condition.
TEST_F(WhileStmtTest, ComplexCondition) {
auto p = parser("while ((a + 1 - 2) == 3) { }");
auto wl = p->while_stmt();
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_FALSE(wl.errored);
ASSERT_TRUE(wl.matched);
EXPECT_TRUE(Is<ast::Expression>(wl->condition));
EXPECT_TRUE(wl->body->Empty());
}
// Test a while loop with no brackets.
TEST_F(WhileStmtTest, NoBrackets) {
auto p = parser("while (a + 1 - 2) == 3 { }");
auto wl = p->while_stmt();
EXPECT_FALSE(p->has_error()) << p->error();
EXPECT_FALSE(wl.errored);
ASSERT_TRUE(wl.matched);
EXPECT_TRUE(Is<ast::BinaryExpression>(wl->condition));
EXPECT_TRUE(wl->body->Empty());
}
class WhileStmtErrorTest : public ParserImplTest {
public:
void TestForWithError(std::string for_str, std::string error_str) {
auto p_for = parser(for_str);
auto e_for = p_for->while_stmt();
EXPECT_FALSE(e_for.matched);
EXPECT_TRUE(e_for.errored);
EXPECT_TRUE(p_for->has_error());
ASSERT_EQ(e_for.value, nullptr);
EXPECT_EQ(p_for->error(), error_str);
}
};
// Test a while loop with missing left parenthesis is invalid.
TEST_F(WhileStmtErrorTest, MissingLeftParen) {
std::string while_str = "while bool) { }";
std::string error_str = "1:11: expected '(' for type constructor";
TestForWithError(while_str, error_str);
}
// Test a while loop with missing condition is invalid.
TEST_F(WhileStmtErrorTest, MissingFirstSemicolon) {
std::string while_str = "while () {}";
std::string error_str = "1:8: unable to parse expression";
TestForWithError(while_str, error_str);
}
// Test a while loop with missing right parenthesis is invalid.
TEST_F(WhileStmtErrorTest, MissingRightParen) {
std::string while_str = "while (true {}";
std::string error_str = "1:13: expected ')'";
TestForWithError(while_str, error_str);
}
// Test a while loop with missing left brace is invalid.
TEST_F(WhileStmtErrorTest, MissingLeftBrace) {
std::string while_str = "while (true) }";
std::string error_str = "1:14: expected '{'";
TestForWithError(while_str, error_str);
}
// Test a for loop with missing right brace is invalid.
TEST_F(WhileStmtErrorTest, MissingRightBrace) {
std::string while_str = "while (true) {";
std::string error_str = "1:15: expected '}'";
TestForWithError(while_str, error_str);
}
// Test a while loop with an invalid break condition.
TEST_F(WhileStmtErrorTest, InvalidBreakConditionAsExpression) {
std::string while_str = "while ((0 == 1) { }";
std::string error_str = "1:17: expected ')'";
TestForWithError(while_str, error_str);
}
// Test a while loop with a break condition not matching
// logical_or_expression.
TEST_F(WhileStmtErrorTest, InvalidBreakConditionMatch) {
std::string while_str = "while (var i: i32 = 0) { }";
std::string error_str = "1:8: unable to parse expression";
TestForWithError(while_str, error_str);
}
// Test a while loop with an invalid body.
TEST_F(WhileStmtErrorTest, InvalidBody) {
std::string while_str = "while (true) { let x: i32; }";
std::string error_str = "1:26: expected '=' for let declaration";
TestForWithError(while_str, error_str);
}
// Test a for loop with a body not matching statements
TEST_F(WhileStmtErrorTest, InvalidBodyMatch) {
std::string while_str = "while (true) { fn main() {} }";
std::string error_str = "1:16: expected '}'";
TestForWithError(while_str, error_str);
}
} // namespace
} // namespace tint::reader::wgsl

View File

@ -263,6 +263,8 @@ std::string_view Token::TypeToName(Type type) {
return "vec3";
case Token::Type::kVec4:
return "vec4";
case Token::Type::kWhile:
return "while";
case Token::Type::kWorkgroup:
return "workgroup";
}

View File

@ -274,6 +274,8 @@ class Token {
kVec3,
/// A 'vec4'
kVec4,
/// A 'while'
kWhile,
/// A 'workgroup'
kWorkgroup,
};

View File

@ -21,6 +21,7 @@
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/loop_statement.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/while_statement.h"
using namespace tint::number_suffixes; // NOLINT
@ -239,6 +240,55 @@ TEST_F(ResolverCompoundStatementTest, ForLoop) {
}
}
TEST_F(ResolverCompoundStatementTest, While) {
// fn F() {
// while (true) {
// return;
// }
// }
auto* cond = Expr(true);
auto* stmt = Return();
auto* body = Block(stmt);
auto* while_ = While(cond, body);
auto* f = Func("W", {}, ty.void_(), {while_});
ASSERT_TRUE(r()->Resolve()) << r()->error();
{
auto* s = Sem().Get(while_);
ASSERT_NE(s, nullptr);
EXPECT_EQ(Sem().Get(body)->Parent(), s);
EXPECT_TRUE(s->Is<sem::WhileStatement>());
EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::FunctionBlockStatement>());
EXPECT_EQ(s->Parent(), s->Block());
}
{ // Condition expression's statement is the while itself
auto* e = Sem().Get(cond);
ASSERT_NE(e, nullptr);
auto* s = e->Stmt();
ASSERT_NE(s, nullptr);
ASSERT_TRUE(Is<sem::WhileStatement>(s));
ASSERT_NE(s->Parent(), nullptr);
EXPECT_EQ(s->Parent(), s->Block());
EXPECT_EQ(s->Parent(), s->FindFirstParent<sem::FunctionBlockStatement>());
EXPECT_TRUE(Is<sem::FunctionBlockStatement>(s->Block()));
}
{
auto* s = Sem().Get(stmt);
ASSERT_NE(s, nullptr);
ASSERT_NE(s->Block(), nullptr);
EXPECT_EQ(s->Parent(), s->Block());
EXPECT_EQ(s->Block(), s->FindFirstParent<sem::LoopBlockStatement>());
EXPECT_TRUE(Is<sem::WhileStatement>(s->Parent()->Parent()));
EXPECT_EQ(s->Block()->Parent(), s->FindFirstParent<sem::WhileStatement>());
ASSERT_TRUE(Is<sem::FunctionBlockStatement>(s->Block()->Parent()->Parent()));
EXPECT_EQ(s->Block()->Parent()->Parent(),
s->FindFirstParent<sem::FunctionBlockStatement>());
EXPECT_EQ(s->Function()->Declaration(), f);
EXPECT_EQ(s->Block()->Parent()->Parent()->Parent(), nullptr);
}
}
TEST_F(ResolverCompoundStatementTest, If) {
// fn F() {
// if (cond_a) {

View File

@ -263,6 +263,12 @@ class DependencyScanner {
TraverseExpression(v->variable->constructor);
Declare(v->variable->symbol, v->variable);
},
[&](const ast::WhileStatement* w) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
TraverseExpression(w->condition);
TraverseStatement(w->body);
},
[&](Default) {
if (!stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
ast::DiscardStatement, ast::FallthroughStatement>()) {

View File

@ -1245,6 +1245,9 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
Assign(V, V), //
Block( //
Assign(V, V))), //
While(Equal(V, V), //
Block( //
Assign(V, V))), //
Loop(Block(Assign(V, V)), //
Block(Assign(V, V))), //
Switch(V, //

View File

@ -49,6 +49,7 @@
#include "src/tint/ast/unary_op_expression.h"
#include "src/tint/ast/variable_decl_statement.h"
#include "src/tint/ast/vector.h"
#include "src/tint/ast/while_statement.h"
#include "src/tint/ast/workgroup_attribute.h"
#include "src/tint/resolver/uniformity.h"
#include "src/tint/sem/abstract_float.h"
@ -77,6 +78,7 @@
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/variable.h"
#include "src/tint/sem/while_statement.h"
#include "src/tint/utils/defer.h"
#include "src/tint/utils/math.h"
#include "src/tint/utils/reverse.h"
@ -854,6 +856,7 @@ sem::Statement* Resolver::Statement(const ast::Statement* stmt) {
[&](const ast::BlockStatement* b) { return BlockStatement(b); },
[&](const ast::ForLoopStatement* l) { return ForLoopStatement(l); },
[&](const ast::LoopStatement* l) { return LoopStatement(l); },
[&](const ast::WhileStatement* w) { return WhileStatement(w); },
[&](const ast::IfStatement* i) { return IfStatement(i); },
[&](const ast::SwitchStatement* s) { return SwitchStatement(s); },
@ -1039,6 +1042,39 @@ sem::ForLoopStatement* Resolver::ForLoopStatement(const ast::ForLoopStatement* s
});
}
sem::WhileStatement* Resolver::WhileStatement(const ast::WhileStatement* stmt) {
auto* sem =
builder_->create<sem::WhileStatement>(stmt, current_compound_statement_, current_function_);
return StatementScope(stmt, sem, [&] {
auto& behaviors = sem->Behaviors();
auto* cond = Expression(stmt->condition);
if (!cond) {
return false;
}
sem->SetCondition(cond);
behaviors.Add(cond->Behaviors());
Mark(stmt->body);
auto* body = builder_->create<sem::LoopBlockStatement>(
stmt->body, current_compound_statement_, current_function_);
if (!StatementScope(stmt->body, body, [&] { return Statements(stmt->body->statements); })) {
return false;
}
behaviors.Add(body->Behaviors());
// Always consider the while as having a 'next' behaviour because it has
// a condition. We don't check if the condition will terminate but it isn't
// valid to have an infinite loop in a WGSL program, so a non-terminating
// condition is already an invalid program.
behaviors.Add(sem::Behavior::kNext);
behaviors.Remove(sem::Behavior::kBreak, sem::Behavior::kContinue);
return validator_.WhileStatement(sem);
});
}
sem::Expression* Resolver::Expression(const ast::Expression* root) {
std::vector<const ast::Expression*> sorted;
constexpr size_t kMaxExpressionDepth = 512U;

View File

@ -54,6 +54,7 @@ class ReturnStatement;
class SwitchStatement;
class UnaryOpExpression;
class Variable;
class WhileStatement;
} // namespace tint::ast
namespace tint::sem {
class Array;
@ -67,6 +68,7 @@ class LoopStatement;
class Statement;
class SwitchStatement;
class TypeConstructor;
class WhileStatement;
} // namespace tint::sem
namespace tint::resolver {
@ -233,6 +235,7 @@ class Resolver {
sem::Statement* DiscardStatement(const ast::DiscardStatement*);
sem::Statement* FallthroughStatement(const ast::FallthroughStatement*);
sem::ForLoopStatement* ForLoopStatement(const ast::ForLoopStatement*);
sem::WhileStatement* WhileStatement(const ast::WhileStatement*);
sem::GlobalVariable* GlobalVariable(const ast::Variable*);
sem::Statement* Parameter(const ast::Variable*);
sem::IfStatement* IfStatement(const ast::IfStatement*);

View File

@ -20,6 +20,7 @@
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/while_statement.h"
using namespace tint::number_suffixes; // NOLINT
@ -314,6 +315,56 @@ TEST_F(ResolverBehaviorTest, StmtForLoopEmpty_CondCallFuncMayDiscard) {
EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtWhileBreak) {
auto* stmt = While(Expr(true), Block(Break()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behavior::kNext);
}
TEST_F(ResolverBehaviorTest, StmtWhileDiscard) {
auto* stmt = While(Expr(true), Block(Discard()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtWhileReturn) {
auto* stmt = While(Expr(true), Block(Return()));
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kReturn, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtWhileEmpty_CondTrue) {
auto* stmt = While(Expr(true), Block());
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtWhileEmpty_CondCallFuncMayDiscard) {
auto* stmt = While(Equal(Call("DiscardOrNext"), 1_i), Block());
WrapInFunction(stmt);
ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(stmt);
EXPECT_EQ(sem->Behaviors(), sem::Behaviors(sem::Behavior::kDiscard, sem::Behavior::kNext));
}
TEST_F(ResolverBehaviorTest, StmtIfTrue_ThenEmptyBlock) {
auto* stmt = If(true, Block());
WrapInFunction(stmt);

View File

@ -35,6 +35,7 @@
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/variable.h"
#include "src/tint/sem/while_statement.h"
#include "src/tint/utils/block_allocator.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/unique_vector.h"
@ -491,7 +492,7 @@ class UniformityGraph {
// Find the loop or switch statement that we are in.
auto* parent = sem_.Get(b)
->FindFirstParent<sem::SwitchStatement, sem::LoopStatement,
sem::ForLoopStatement>();
sem::ForLoopStatement, sem::WhileStatement>();
TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent));
auto& info = current_function_->loop_switch_infos.at(parent);
@ -535,8 +536,9 @@ class UniformityGraph {
[&](const ast::ContinueStatement* c) {
// Find the loop statement that we are in.
auto* parent =
sem_.Get(c)->FindFirstParent<sem::LoopStatement, sem::ForLoopStatement>();
auto* parent = sem_.Get(c)
->FindFirstParent<sem::LoopStatement, sem::ForLoopStatement,
sem::WhileStatement>();
TINT_ASSERT(Resolver, current_function_->loop_switch_infos.count(parent));
auto& info = current_function_->loop_switch_infos.at(parent);
@ -638,6 +640,68 @@ class UniformityGraph {
}
},
[&](const ast::WhileStatement* w) {
auto* sem_loop = sem_.Get(w);
auto* cfx = CreateNode("loop_start");
auto* cf_start = cf;
auto& info = current_function_->loop_switch_infos[sem_loop];
info.type = "whileloop";
// Create input nodes for any variables declared before this loop.
for (auto* v : current_function_->local_var_decls) {
auto name = builder_->Symbols().NameFor(v->Declaration()->symbol);
auto* in_node = CreateNode(name + "_value_forloop_in");
in_node->AddEdge(current_function_->variables.Get(v));
info.var_in_nodes[v] = in_node;
current_function_->variables.Set(v, in_node);
}
// Insert the condition at the start of the loop body.
{
auto [cf_cond, v] = ProcessExpression(cfx, w->condition);
auto* cf_condition_end = CreateNode("while_condition_CFend", w);
cf_condition_end->affects_control_flow = true;
cf_condition_end->AddEdge(v);
cf_start = cf_condition_end;
}
// Propagate assignments to the loop exit nodes.
for (auto* var : current_function_->local_var_decls) {
auto* exit_node = utils::GetOrCreate(info.var_exit_nodes, var, [&]() {
auto name = builder_->Symbols().NameFor(var->Declaration()->symbol);
return CreateNode(name + "_value_" + info.type + "_exit");
});
exit_node->AddEdge(current_function_->variables.Get(var));
}
auto* cf1 = ProcessStatement(cf_start, w->body);
cfx->AddEdge(cf1);
cfx->AddEdge(cf);
// Add edges from variable loop input nodes to their values at the end of the loop.
for (auto v : info.var_in_nodes) {
auto* in_node = v.second;
auto* out_node = current_function_->variables.Get(v.first);
if (out_node != in_node) {
in_node->AddEdge(out_node);
}
}
// Set each variable's exit node as its value in the outer scope.
for (auto v : info.var_exit_nodes) {
current_function_->variables.Set(v.first, v.second);
}
current_function_->loop_switch_infos.erase(sem_loop);
if (sem_loop->Behaviors() == sem::Behaviors{sem::Behavior::kNext}) {
return cf;
} else {
return cfx;
}
},
[&](const ast::IfStatement* i) {
auto* sem_if = sem_.Get(i);
auto [_, v_cond] = ProcessExpression(cf, i->condition);

View File

@ -2311,6 +2311,304 @@ fn foo() {
RunTest(src, true);
}
TEST_F(UniformityAnalysisTest, While_CallInside_UniformCondition) {
std::string src = R"(
@group(0) @binding(0) var<storage, read> n : i32;
fn foo() {
var i = 0;
while (i < n) {
workgroupBarrier();
i = i + 1;
}
}
)";
RunTest(src, true);
}
TEST_F(UniformityAnalysisTest, While_CallInside_NonUniformCondition) {
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> n : i32;
fn foo() {
var i = 0;
while (i < n) {
workgroupBarrier();
i = i + 1;
}
}
)";
RunTest(src, false);
EXPECT_EQ(error_,
R"(test:7:5 warning: 'workgroupBarrier' must only be called from uniform control flow
workgroupBarrier();
^^^^^^^^^^^^^^^^
test:6:3 note: control flow depends on non-uniform value
while (i < n) {
^^^^^
test:6:14 note: reading from read_write storage buffer 'n' may result in a non-uniform value
while (i < n) {
^
)");
}
TEST_F(UniformityAnalysisTest, While_VarBecomesNonUniformInLoopAfterBarrier) {
// Use a variable for a conditional barrier in a loop, and then assign a non-uniform value to
// that variable later in that loop.
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn foo() {
var v = 0;
var i = 0;
while (i < 10) {
if (v == 0) {
workgroupBarrier();
break;
}
v = non_uniform;
i++;
}
}
)";
RunTest(src, false);
EXPECT_EQ(error_,
R"(test:9:7 warning: 'workgroupBarrier' must only be called from uniform control flow
workgroupBarrier();
^^^^^^^^^^^^^^^^
test:8:5 note: control flow depends on non-uniform value
if (v == 0) {
^^
test:13:9 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
v = non_uniform;
^^^^^^^^^^^
)");
}
TEST_F(UniformityAnalysisTest, While_ConditionalAssignNonUniformWithBreak_BarrierInLoop) {
// In a conditional block, assign a non-uniform value and then break, then use a variable for a
// conditional barrier later in the loop.
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn foo() {
var v = 0;
var i = 0;
while (i < 10) {
if (true) {
v = non_uniform;
break;
}
if (v == 0) {
workgroupBarrier();
}
i++;
}
}
)";
RunTest(src, true);
}
TEST_F(UniformityAnalysisTest, While_ConditionalAssignNonUniformWithBreak_BarrierAfterLoop) {
// In a conditional block, assign a non-uniform value and then break, then use a variable for a
// conditional barrier after the loop.
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn foo() {
var v = 0;
var i = 0;
while (i < 10) {
if (true) {
v = non_uniform;
break;
}
v = 5;
i++;
}
if (v == 0) {
workgroupBarrier();
}
}
)";
RunTest(src, false);
EXPECT_EQ(error_,
R"(test:17:5 warning: 'workgroupBarrier' must only be called from uniform control flow
workgroupBarrier();
^^^^^^^^^^^^^^^^
test:16:3 note: control flow depends on non-uniform value
if (v == 0) {
^^
test:9:11 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
v = non_uniform;
^^^^^^^^^^^
)");
}
TEST_F(UniformityAnalysisTest, While_VarRemainsNonUniformAtLoopEnd_BarrierAfterLoop) {
// Assign a non-uniform value, assign a uniform value before all explicit break points but leave
// the value non-uniform at loop exit, then use a variable for a conditional barrier after the
// loop.
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn foo() {
var v = 0;
var i = 0;
while (i < 10) {
if (true) {
v = 5;
break;
}
v = non_uniform;
if (true) {
v = 6;
break;
}
i++;
}
if (v == 0) {
workgroupBarrier();
}
}
)";
RunTest(src, false);
EXPECT_EQ(error_,
R"(test:23:5 warning: 'workgroupBarrier' must only be called from uniform control flow
workgroupBarrier();
^^^^^^^^^^^^^^^^
test:22:3 note: control flow depends on non-uniform value
if (v == 0) {
^^
test:13:9 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
v = non_uniform;
^^^^^^^^^^^
)");
}
TEST_F(UniformityAnalysisTest, While_VarBecomesNonUniformBeforeConditionalContinue_BarrierAtStart) {
// Use a variable for a conditional barrier in a loop, assign a non-uniform value to
// that variable later in that loop, then perform a conditional continue before assigning a
// uniform value to that variable.
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn foo() {
var v = 0;
var i = 0;
while (i < 10) {
if (v == 0) {
workgroupBarrier();
break;
}
v = non_uniform;
if (true) {
continue;
}
v = 5;
i++;
}
}
)";
RunTest(src, false);
EXPECT_EQ(error_,
R"(test:9:7 warning: 'workgroupBarrier' must only be called from uniform control flow
workgroupBarrier();
^^^^^^^^^^^^^^^^
test:8:5 note: control flow depends on non-uniform value
if (v == 0) {
^^
test:13:9 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
v = non_uniform;
^^^^^^^^^^^
)");
}
TEST_F(UniformityAnalysisTest, While_VarBecomesNonUniformBeforeConditionalContinue) {
// Use a variable for a conditional barrier in a loop, assign a non-uniform value to
// that variable later in that loop, then perform a conditional continue before assigning a
// uniform value to that variable.
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> non_uniform : i32;
fn foo() {
var v = 0;
var i = 0;
while (i < 10) {
if (v == 0) {
workgroupBarrier();
break;
}
v = non_uniform;
if (true) {
continue;
}
v = 5;
i++;
}
}
)";
RunTest(src, false);
EXPECT_EQ(error_,
R"(test:9:7 warning: 'workgroupBarrier' must only be called from uniform control flow
workgroupBarrier();
^^^^^^^^^^^^^^^^
test:8:5 note: control flow depends on non-uniform value
if (v == 0) {
^^
test:13:9 note: reading from read_write storage buffer 'non_uniform' may result in a non-uniform value
v = non_uniform;
^^^^^^^^^^^
)");
}
TEST_F(UniformityAnalysisTest, While_NonUniformCondition_Reconverge) {
// Loops reconverge at exit, so test that we can call workgroupBarrier() after a loop that has a
// non-uniform condition.
std::string src = R"(
@group(0) @binding(0) var<storage, read_write> n : i32;
fn foo() {
var i = 0;
while (i < n) {
}
workgroupBarrier();
i = i + 1;
}
)";
RunTest(src, true);
}
} // namespace LoopTest
////////////////////////////////////////////////////////////////////////////////

View File

@ -986,6 +986,26 @@ TEST_F(ResolverTest, Stmt_ForLoop_CondIsNotBool) {
EXPECT_EQ(r()->error(), "12:34 error: for-loop condition must be bool, got f32");
}
TEST_F(ResolverTest, Stmt_While_CondIsBoolRef) {
// var cond : bool = false;
// while (cond) {
// }
auto* cond = Var("cond", ty.bool_(), Expr(false));
WrapInFunction(Decl(cond), While("cond", Block()));
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTest, Stmt_While_CondIsNotBool) {
// while (1.0f) {
// }
WrapInFunction(While(Expr(Source{{12, 34}}, 1_f), Block()));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: while condition must be bool, got f32");
}
TEST_F(ResolverValidationTest, Stmt_ContinueInLoop) {
WrapInFunction(Loop(Block(If(false, Block(Break())), //
Continue(Source{{12, 34}}))));

View File

@ -72,6 +72,7 @@
#include "src/tint/sem/type_constructor.h"
#include "src/tint/sem/type_conversion.h"
#include "src/tint/sem/variable.h"
#include "src/tint/sem/while_statement.h"
#include "src/tint/utils/defer.h"
#include "src/tint/utils/map.h"
#include "src/tint/utils/math.h"
@ -237,6 +238,11 @@ const ast::Statement* Validator::ClosestContinuing(bool stop_at_loop,
break;
}
}
if (Is<sem::WhileStatement>(s->Parent())) {
if (stop_at_loop) {
break;
}
}
}
return nullptr;
}
@ -1460,6 +1466,22 @@ bool Validator::ForLoopStatement(const sem::ForLoopStatement* stmt) const {
return true;
}
bool Validator::WhileStatement(const sem::WhileStatement* stmt) const {
if (stmt->Behaviors().Empty()) {
AddError("while does not exit", stmt->Declaration()->source.Begin());
return false;
}
if (auto* cond = stmt->Condition()) {
auto* cond_ty = cond->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {
AddError("while condition must be bool, got " + sem_.TypeNameOf(cond_ty),
stmt->Condition()->Declaration()->source);
return false;
}
}
return true;
}
bool Validator::IfStatement(const sem::IfStatement* stmt) const {
auto* cond_ty = stmt->Condition()->Type()->UnwrapRef();
if (!cond_ty->Is<sem::Bool>()) {

View File

@ -44,6 +44,7 @@ class ReturnStatement;
class SwitchStatement;
class UnaryOpExpression;
class Variable;
class WhileStatement;
} // namespace tint::ast
namespace tint::sem {
class Array;
@ -58,6 +59,7 @@ class Materialize;
class Statement;
class SwitchStatement;
class TypeConstructor;
class WhileStatement;
} // namespace tint::sem
namespace tint::resolver {
@ -207,6 +209,11 @@ class Validator {
/// @returns true on success, false otherwise
bool ForLoopStatement(const sem::ForLoopStatement* stmt) const;
/// Validates a while loop
/// @param stmt the while statement to validate
/// @returns true on success, false otherwise
bool WhileStatement(const sem::WhileStatement* stmt) const;
/// Validates a fallthrough statement
/// @param stmt the fallthrough to validate
/// @returns true on success, false otherwise

View File

@ -34,6 +34,7 @@ class SwitchStatement;
class Type;
class TypeDecl;
class Variable;
class WhileStatement;
} // namespace tint::ast
namespace tint::sem {
class Array;
@ -50,6 +51,7 @@ class StructMember;
class SwitchStatement;
class Type;
class Variable;
class WhileStatement;
} // namespace tint::sem
namespace tint::sem {
@ -74,6 +76,7 @@ struct TypeMappings {
Type* operator()(ast::Type*);
Type* operator()(ast::TypeDecl*);
Variable* operator()(ast::Variable*);
WhileStatement* operator()(ast::WhileStatement*);
//! @endcond
};

View File

@ -0,0 +1,34 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/sem/while_statement.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::sem::WhileStatement);
namespace tint::sem {
WhileStatement::WhileStatement(const ast::WhileStatement* declaration,
const CompoundStatement* parent,
const sem::Function* function)
: Base(declaration, parent, function) {}
WhileStatement::~WhileStatement() = default;
const ast::WhileStatement* WhileStatement::Declaration() const {
return static_cast<const ast::WhileStatement*>(Base::Declaration());
}
} // namespace tint::sem

View File

@ -0,0 +1,60 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0(the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_SEM_WHILE_STATEMENT_H_
#define SRC_TINT_SEM_WHILE_STATEMENT_H_
#include "src/tint/sem/statement.h"
// Forward declarations
namespace tint::ast {
class WhileStatement;
} // namespace tint::ast
namespace tint::sem {
class Expression;
} // namespace tint::sem
namespace tint::sem {
/// Holds semantic information about a while statement
class WhileStatement final : public Castable<WhileStatement, CompoundStatement> {
public:
/// Constructor
/// @param declaration the AST node for this while statement
/// @param parent the owning statement
/// @param function the owning function
WhileStatement(const ast::WhileStatement* declaration,
const CompoundStatement* parent,
const sem::Function* function);
/// Destructor
~WhileStatement() override;
/// @returns the AST node
const ast::WhileStatement* Declaration() const;
/// @returns the whilecondition expression
const Expression* Condition() const { return condition_; }
/// Sets the while condition expression
/// @param condition the while condition expression
void SetCondition(const Expression* condition) { condition_ = condition; }
private:
const Expression* condition_ = nullptr;
};
} // namespace tint::sem
#endif // SRC_TINT_SEM_WHILE_STATEMENT_H_

View File

@ -27,6 +27,7 @@
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/variable.h"
#include "src/tint/sem/while_statement.h"
#include "src/tint/transform/manager.h"
#include "src/tint/transform/utils/get_insertion_point.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
@ -383,6 +384,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
ProcessStatement(s->expr);
},
[&](const ast::ForLoopStatement* s) { ProcessStatement(s->condition); },
[&](const ast::WhileStatement* s) { ProcessStatement(s->condition); },
[&](const ast::IfStatement* s) { //
ProcessStatement(s->condition);
},
@ -578,6 +580,15 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s);
},
[&](const ast::WhileStatement* s) -> const ast::Statement* {
if (!sem.Get(s->condition)->HasSideEffects()) {
return nullptr;
}
ast::StatementList stmts;
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s);
},
[&](const ast::IfStatement* s) -> const ast::Statement* {
if (!sem.Get(s->condition)->HasSideEffects()) {
return nullptr;

View File

@ -999,6 +999,45 @@ fn f() {
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InWhileCond) {
auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
fn f() {
var b = 1;
while(a(0) + b > 0) {
var marker = 0;
}
}
)";
auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
fn f() {
var b = 1;
loop {
let tint_symbol = a(0);
if (!(((tint_symbol + b) > 0))) {
break;
}
{
var marker = 0;
}
}
}
)";
DataMap data;
auto got = Run<PromoteSideEffectsToDecl>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InElseIf) {
auto* src = R"(
fn a(i : i32) -> i32 {
@ -2299,6 +2338,48 @@ fn f() {
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InWhileCond) {
auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
fn f() {
var b = true;
while(a(0) && b) {
var marker = 0;
}
}
)";
auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
fn f() {
var b = true;
loop {
var tint_symbol = a(0);
if (tint_symbol) {
tint_symbol = b;
}
if (!(tint_symbol)) {
break;
}
{
var marker = 0;
}
}
}
)";
DataMap data;
auto got = Run<PromoteSideEffectsToDecl>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InElseIf) {
auto* src = R"(
fn a(i : i32) -> bool {

View File

@ -25,6 +25,7 @@
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/loop_statement.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/while_statement.h"
#include "src/tint/transform/utils/get_insertion_point.h"
#include "src/tint/utils/map.h"
@ -49,7 +50,7 @@ class State {
// Find whether first parent is a switch or a loop
auto* sem_stmt = sem.Get(cont);
auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
sem::ForLoopStatement>();
sem::ForLoopStatement, sem::WhileStatement>();
if (!sem_parent) {
return nullptr;
}

View File

@ -559,5 +559,59 @@ fn f() {
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveContinueInSwitchTest, While) {
auto* src = R"(
fn f() {
var i = 0;
while (i < 4) {
let marker1 = 0;
switch(i) {
case 0: {
continue;
break;
}
default: {
break;
}
}
let marker2 = 0;
break;
}
}
)";
auto* expect = R"(
fn f() {
var i = 0;
while((i < 4)) {
let marker1 = 0;
var tint_continue : bool = false;
switch(i) {
case 0: {
{
tint_continue = true;
break;
}
break;
}
default: {
break;
}
}
if (tint_continue) {
continue;
}
let marker2 = 0;
break;
}
}
)";
DataMap data;
auto got = Run<RemoveContinueInSwitch>(src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace tint::transform

View File

@ -262,6 +262,15 @@ class State {
}
return nullptr;
},
[&](const ast::WhileStatement* s) -> const ast::Statement* {
if (MayDiscard(sem.Get(s->condition))) {
TINT_ICE(Transform, b.Diagnostics())
<< "Unexpected WhileStatement condition that may discard. "
"Make sure transform::PromoteSideEffectsToDecl was run "
"first.";
}
return nullptr;
},
[&](const ast::IfStatement* s) -> const ast::Statement* {
auto* sem_expr = sem.Get(s->condition);
if (!MayDiscard(sem_expr)) {

View File

@ -800,6 +800,67 @@ fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, While_Cond) {
auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
}
return 42;
}
@fragment
fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
let marker1 = 0;
while (f() == 42) {
let marker2 = 0;
break;
}
return vec4<f32>();
}
)";
auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
if (true) {
tint_discard = true;
return i32();
}
return 42;
}
fn tint_discard_func() {
discard;
}
@fragment
fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
let marker1 = 0;
loop {
let tint_symbol = f();
if (tint_discard) {
tint_discard_func();
return vec4<f32>();
}
if (!((tint_symbol == 42))) {
break;
}
{
let marker2 = 0;
break;
}
}
return vec4<f32>();
}
)";
DataMap data;
auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Switch) {
auto* src = R"(
fn f() -> i32 {

View File

@ -22,6 +22,7 @@
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/reference.h"
#include "src/tint/sem/variable.h"
#include "src/tint/sem/while_statement.h"
#include "src/tint/utils/reverse.h"
namespace tint::transform {
@ -46,7 +47,10 @@ class HoistToDeclBefore::State {
};
/// For-loops that need to be decomposed to loops.
std::unordered_map<const sem::ForLoopStatement*, LoopInfo> loops;
std::unordered_map<const sem::ForLoopStatement*, LoopInfo> for_loops;
/// Whiles that need to be decomposed to loops.
std::unordered_map<const sem::WhileStatement*, LoopInfo> while_loops;
/// 'else if' statements that need to be decomposed to 'else {if}'
std::unordered_map<const ast::IfStatement*, ElseIfInfo> else_ifs;
@ -55,7 +59,7 @@ class HoistToDeclBefore::State {
// registered declaration statements before the condition or continuing
// statement.
void ForLoopsToLoops() {
if (loops.empty()) {
if (for_loops.empty()) {
return;
}
@ -64,7 +68,7 @@ class HoistToDeclBefore::State {
auto& sem = ctx.src->Sem();
if (auto* fl = sem.Get(stmt)) {
if (auto it = loops.find(fl); it != loops.end()) {
if (auto it = for_loops.find(fl); it != for_loops.end()) {
auto& info = it->second;
auto* for_loop = fl->Declaration();
// For-loop needs to be decomposed to a loop.
@ -108,6 +112,51 @@ class HoistToDeclBefore::State {
});
}
// Converts any while-loops marked for conversion to loops, inserting
// registered declaration statements before the condition.
void WhilesToLoops() {
if (while_loops.empty()) {
return;
}
// At least one while needs to be transformed into a loop.
ctx.ReplaceAll([&](const ast::WhileStatement* stmt) -> const ast::Statement* {
auto& sem = ctx.src->Sem();
if (auto* w = sem.Get(stmt)) {
if (auto it = while_loops.find(w); it != while_loops.end()) {
auto& info = it->second;
auto* while_loop = w->Declaration();
// While needs to be decomposed to a loop.
// Build the loop body's statements.
// Start with any let declarations for the conditional
// expression.
auto body_stmts = info.cond_decls;
// Emit the condition as:
// if (!cond) { break; }
auto* cond = while_loop->condition;
// !condition
auto* not_cond =
b.create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
// { break; }
auto* break_body = b.Block(b.create<ast::BreakStatement>());
// if (!condition) { break; }
body_stmts.emplace_back(b.If(not_cond, break_body));
// Next emit the body
body_stmts.emplace_back(ctx.Clone(while_loop->body));
const ast::BlockStatement* continuing = nullptr;
auto* body = b.Block(body_stmts);
auto* loop = b.Loop(body, continuing);
return loop;
}
}
return nullptr;
});
}
void ElseIfsToElseWithNestedIfs() {
// Decompose 'else-if' statements into 'else { if }' blocks.
ctx.ReplaceAll([&](const ast::IfStatement* else_if) -> const ast::Statement* {
@ -192,7 +241,19 @@ class HoistToDeclBefore::State {
// For-loop needs to be decomposed to a loop.
// Index the map to convert this for-loop, even if `stmt` is nullptr.
auto& decls = loops[fl].cond_decls;
auto& decls = for_loops[fl].cond_decls;
if (stmt) {
decls.emplace_back(stmt);
}
return true;
}
if (auto* w = before_stmt->As<sem::WhileStatement>()) {
// Insertion point is a while condition.
// While needs to be decomposed to a loop.
// Index the map to convert this while, even if `stmt` is nullptr.
auto& decls = while_loops[w].cond_decls;
if (stmt) {
decls.emplace_back(stmt);
}
@ -227,7 +288,7 @@ class HoistToDeclBefore::State {
// For-loop needs to be decomposed to a loop.
// Index the map to convert this for-loop, even if `stmt` is nullptr.
auto& decls = loops[fl].cont_decls;
auto& decls = for_loops[fl].cont_decls;
if (stmt) {
decls.emplace_back(stmt);
}
@ -257,6 +318,7 @@ class HoistToDeclBefore::State {
/// @return true on success
bool Apply() {
ForLoopsToLoops();
WhilesToLoops();
ElseIfsToElseWithNestedIfs();
return true;
}

View File

@ -175,6 +175,47 @@ fn f() {
EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, WhileCond) {
// fn f() {
// var a : bool;
// while(a) {
// }
// }
ProgramBuilder b;
auto* var = b.Decl(b.Var("a", b.ty.bool_()));
auto* expr = b.Expr("a");
auto* s = b.While(expr, b.Block());
b.Func("f", {}, b.ty.void_(), {var, s});
Program original(std::move(b));
ProgramBuilder cloned_b;
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
hoistToDeclBefore.Add(sem_expr, expr, true);
hoistToDeclBefore.Apply();
ctx.Clone();
Program cloned(std::move(cloned_b));
auto* expect = R"(
fn f() {
var a : bool;
loop {
let tint_symbol = a;
if (!(tint_symbol)) {
break;
}
{
}
}
}
)";
EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, ElseIf) {
// fn f() {
// var a : bool;

View File

@ -0,0 +1,67 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/while_to_loop.h"
#include "src/tint/ast/break_statement.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::WhileToLoop);
namespace tint::transform {
WhileToLoop::WhileToLoop() = default;
WhileToLoop::~WhileToLoop() = default;
bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::WhileStatement>()) {
return true;
}
}
return false;
}
void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* {
ast::StatementList stmts;
auto* cond = w->condition;
// !condition
auto* not_cond =
ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
// { break; }
auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
// if (!condition) { break; }
stmts.emplace_back(ctx.dst->If(not_cond, break_body));
for (auto* stmt : w->body->statements) {
stmts.emplace_back(ctx.Clone(stmt));
}
const ast::BlockStatement* continuing = nullptr;
auto* body = ctx.dst->Block(stmts);
auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
return loop;
});
ctx.Clone();
}
} // namespace tint::transform

View File

@ -0,0 +1,49 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_WHILE_TO_LOOP_H_
#define SRC_TINT_TRANSFORM_WHILE_TO_LOOP_H_
#include "src/tint/transform/transform.h"
namespace tint::transform {
/// WhileToLoop is a Transform that converts a while statement into a loop
/// statement. This is required by the SPIR-V writer.
class WhileToLoop final : public Castable<WhileToLoop, Transform> {
public:
/// Constructor
WhileToLoop();
/// Destructor
~WhileToLoop() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
#endif // SRC_TINT_TRANSFORM_WHILE_TO_LOOP_H_

View File

@ -0,0 +1,129 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/while_to_loop.h"
#include "src/tint/transform/test_helper.h"
namespace tint::transform {
namespace {
using WhileToLoopTest = TransformTest;
TEST_F(WhileToLoopTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<WhileToLoop>(src));
}
TEST_F(WhileToLoopTest, ShouldRunHasWhile) {
auto* src = R"(
fn f() {
while (true) {
break;
}
}
)";
EXPECT_TRUE(ShouldRun<WhileToLoop>(src));
}
TEST_F(WhileToLoopTest, EmptyModule) {
auto* src = "";
auto* expect = src;
auto got = Run<WhileToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test an empty for loop.
TEST_F(WhileToLoopTest, Empty) {
auto* src = R"(
fn f() {
while (true) {
break;
}
}
)";
auto* expect = R"(
fn f() {
loop {
if (!(true)) {
break;
}
break;
}
}
)";
auto got = Run<WhileToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a for loop with non-empty body.
TEST_F(WhileToLoopTest, Body) {
auto* src = R"(
fn f() {
while (true) {
discard;
}
}
)";
auto* expect = R"(
fn f() {
loop {
if (!(true)) {
break;
}
discard;
}
}
)";
auto got = Run<WhileToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a loop with a break condition
TEST_F(WhileToLoopTest, BreakCondition) {
auto* src = R"(
fn f() {
while (0 == 1) {
}
}
)";
auto* expect = R"(
fn f() {
loop {
if (!((0 == 1))) {
break;
}
}
}
)";
auto got = Run<WhileToLoop>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace tint::transform

View File

@ -2523,6 +2523,54 @@ bool GeneratorImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
return true;
}
bool GeneratorImpl::EmitWhile(const ast::WhileStatement* stmt) {
TextBuffer cond_pre;
std::stringstream cond_buf;
{
auto* cond = stmt->condition;
TINT_SCOPED_ASSIGNMENT(current_buffer_, &cond_pre);
if (!EmitExpression(cond_buf, cond)) {
return false;
}
}
// If the whilehas a multi-statement conditional, then we cannot emit this
// as a regular while in GLSL. Instead we need to generate a `while(true)` loop.
bool emit_as_loop = cond_pre.lines.size() > 0;
if (emit_as_loop) {
line() << "while (true) {";
increment_indent();
TINT_DEFER({
decrement_indent();
line() << "}";
});
current_buffer_->Append(cond_pre);
line() << "if (!(" << cond_buf.str() << ")) { break; }";
if (!EmitStatements(stmt->body->statements)) {
return false;
}
} else {
// While can be generated.
{
auto out = line();
out << "while";
{
ScopedParen sp(out);
out << cond_buf.str();
}
out << " {";
}
if (!EmitStatementsWithIndent(stmt->body->statements)) {
return false;
}
line() << "}";
}
return true;
}
bool GeneratorImpl::EmitMemberAccessor(std::ostream& out,
const ast::MemberAccessorExpression* expr) {
if (!EmitExpression(out, expr->structure)) {
@ -2591,6 +2639,9 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
if (auto* l = stmt->As<ast::ForLoopStatement>()) {
return EmitForLoop(l);
}
if (auto* l = stmt->As<ast::WhileStatement>()) {
return EmitWhile(l);
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
return EmitReturn(r);
}

View File

@ -357,6 +357,10 @@ class GeneratorImpl : public TextGenerator {
/// @param stmt the statement to emit
/// @returns true if the statement was emitted
bool EmitForLoop(const ast::ForLoopStatement* stmt);
/// Handles a while statement
/// @param stmt the statement to emit
/// @returns true if the statement was emitted
bool EmitWhile(const ast::WhileStatement* stmt);
/// Handles generating an identifier expression
/// @param out the output of the expression stream
/// @param expr the identifier expression

View File

@ -381,5 +381,52 @@ TEST_F(GlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtInitCondCont) {
)");
}
TEST_F(GlslGeneratorImplTest_Loop, Emit_While) {
// while(true) {
// return;
// }
auto* f = While(Expr(true), Block(Return()));
WrapInFunction(f);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( while(true) {
return;
}
)");
}
TEST_F(GlslGeneratorImplTest_Loop, Emit_WhileWithMultiStmtCond) {
// while(true && false) {
// return;
// }
Func("a_statement", {}, ty.void_(), {});
auto* multi_stmt =
create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false));
auto* f = While(multi_stmt, Block(CallStmt(Call("a_statement"))));
WrapInFunction(f);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( while (true) {
bool tint_tmp = true;
if (tint_tmp) {
tint_tmp = false;
}
if (!((tint_tmp))) { break; }
a_statement();
}
)");
}
} // namespace
} // namespace tint::writer::glsl

View File

@ -3481,6 +3481,53 @@ bool GeneratorImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
return true;
}
bool GeneratorImpl::EmitWhile(const ast::WhileStatement* stmt) {
TextBuffer cond_pre;
std::stringstream cond_buf;
{
auto* cond = stmt->condition;
TINT_SCOPED_ASSIGNMENT(current_buffer_, &cond_pre);
if (!EmitExpression(cond_buf, cond)) {
return false;
}
}
// If the while has a multi-statement conditional, then we cannot emit this
// as a regular while in HLSL. Instead we need to generate a `while(true)` loop.
bool emit_as_loop = cond_pre.lines.size() > 0;
if (emit_as_loop) {
line() << LoopAttribute() << "while (true) {";
increment_indent();
TINT_DEFER({
decrement_indent();
line() << "}";
});
current_buffer_->Append(cond_pre);
line() << "if (!(" << cond_buf.str() << ")) { break; }";
if (!EmitStatements(stmt->body->statements)) {
return false;
}
} else {
// While can be generated.
{
auto out = line();
out << LoopAttribute() << "while";
{
ScopedParen sp(out);
out << cond_buf.str();
}
out << " {";
}
if (!EmitStatementsWithIndent(stmt->body->statements)) {
return false;
}
line() << "}";
}
return true;
}
bool GeneratorImpl::EmitMemberAccessor(std::ostream& out,
const ast::MemberAccessorExpression* expr) {
if (!EmitExpression(out, expr->structure)) {
@ -3551,6 +3598,9 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
[&](const ast::ForLoopStatement* l) { //
return EmitForLoop(l);
},
[&](const ast::WhileStatement* l) { //
return EmitWhile(l);
},
[&](const ast::ReturnStatement* r) { //
return EmitReturn(r);
},

View File

@ -353,6 +353,10 @@ class GeneratorImpl : public TextGenerator {
/// @param stmt the statement to emit
/// @returns true if the statement was emitted
bool EmitForLoop(const ast::ForLoopStatement* stmt);
/// Handles a while statement
/// @param stmt the statement to emit
/// @returns true if the statement was emitted
bool EmitWhile(const ast::WhileStatement* stmt);
/// Handles generating an identifier expression
/// @param out the output of the expression stream
/// @param expr the identifier expression

View File

@ -373,5 +373,50 @@ TEST_F(HlslGeneratorImplTest_Loop, Emit_ForLoopWithMultiStmtInitCondCont) {
)");
}
TEST_F(HlslGeneratorImplTest_Loop, Emit_While) {
// while(true) {
// return;
// }
auto* f = While(Expr(true), Block(Return()));
WrapInFunction(f);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( [loop] while(true) {
return;
}
)");
}
TEST_F(HlslGeneratorImplTest_Loop, Emit_WhileWithMultiStmtCond) {
// while(true && false) {
// return;
// }
auto* multi_stmt =
create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false));
auto* f = While(multi_stmt, Block(Return()));
WrapInFunction(f);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( [loop] while (true) {
bool tint_tmp = true;
if (tint_tmp) {
tint_tmp = false;
}
if (!((tint_tmp))) { break; }
return;
}
)");
}
} // namespace
} // namespace tint::writer::hlsl

View File

@ -2124,6 +2124,53 @@ bool GeneratorImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
return true;
}
bool GeneratorImpl::EmitWhile(const ast::WhileStatement* stmt) {
TextBuffer cond_pre;
std::stringstream cond_buf;
{
auto* cond = stmt->condition;
TINT_SCOPED_ASSIGNMENT(current_buffer_, &cond_pre);
if (!EmitExpression(cond_buf, cond)) {
return false;
}
}
// If the while has a multi-statement conditional, then we cannot emit this
// as a regular while in MSL. Instead we need to generate a `while(true)` loop.
bool emit_as_loop = cond_pre.lines.size() > 0;
if (emit_as_loop) {
line() << "while (true) {";
increment_indent();
TINT_DEFER({
decrement_indent();
line() << "}";
});
current_buffer_->Append(cond_pre);
line() << "if (!(" << cond_buf.str() << ")) { break; }";
if (!EmitStatements(stmt->body->statements)) {
return false;
}
} else {
// While can be generated.
{
auto out = line();
out << "while";
{
ScopedParen sp(out);
out << cond_buf.str();
}
out << " {";
}
if (!EmitStatementsWithIndent(stmt->body->statements)) {
return false;
}
line() << "}";
}
return true;
}
bool GeneratorImpl::EmitDiscard(const ast::DiscardStatement*) {
// TODO(dsinclair): Verify this is correct when the discard semantics are
// defined for WGSL (https://github.com/gpuweb/gpuweb/issues/361)
@ -2280,6 +2327,9 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
[&](const ast::ForLoopStatement* l) { //
return EmitForLoop(l);
},
[&](const ast::WhileStatement* l) { //
return EmitWhile(l);
},
[&](const ast::ReturnStatement* r) { //
return EmitReturn(r);
},

View File

@ -270,6 +270,10 @@ class GeneratorImpl : public TextGenerator {
/// @param stmt the statement to emit
/// @returns true if the statement was emitted
bool EmitForLoop(const ast::ForLoopStatement* stmt);
/// Handles a while statement
/// @param stmt the statement to emit
/// @returns true if the statement was emitted
bool EmitWhile(const ast::WhileStatement* stmt);
/// Handles a member accessor expression
/// @param out the output of the expression stream
/// @param expr the member accessor expression

View File

@ -344,5 +344,45 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInitCondCont) {
)");
}
TEST_F(MslGeneratorImplTest, Emit_While) {
// while(true) {
// return;
// }
auto* f = While(Expr(true), Block(Return()));
WrapInFunction(f);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( while(true) {
return;
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_WhileWithMultiCond) {
// while(true && false) {
// return;
// }
auto* multi_stmt =
create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false));
auto* f = While(multi_stmt, Block(Return()));
WrapInFunction(f);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( while((true && false)) {
return;
}
)");
}
} // namespace
} // namespace tint::writer::msl

View File

@ -32,6 +32,7 @@
#include "src/tint/transform/unwind_discard_functions.h"
#include "src/tint/transform/var_for_dynamic_index.h"
#include "src/tint/transform/vectorize_scalar_matrix_constructors.h"
#include "src/tint/transform/while_to_loop.h"
#include "src/tint/transform/zero_init_workgroup_memory.h"
#include "src/tint/writer/generate_external_texture_bindings.h"
@ -74,7 +75,7 @@ SanitizedResult Sanitize(const Program* in, const Options& options) {
manager.Add<transform::SimplifyPointers>(); // Required for arrayLength()
manager.Add<transform::VectorizeScalarMatrixConstructors>();
manager.Add<transform::ForLoopToLoop>(); // Must come after
// ZeroInitWorkgroupMemory
manager.Add<transform::WhileToLoop>(); // ZeroInitWorkgroupMemory
manager.Add<transform::CanonicalizeEntryPointIO>();
manager.Add<transform::AddEmptyEntryPoint>();
manager.Add<transform::AddSpirvBlockAttribute>();

View File

@ -919,6 +919,7 @@ bool GeneratorImpl::EmitStatement(const ast::Statement* stmt) {
[&](const ast::IncrementDecrementStatement* l) { return EmitIncrementDecrement(l); },
[&](const ast::LoopStatement* l) { return EmitLoop(l); },
[&](const ast::ForLoopStatement* l) { return EmitForLoop(l); },
[&](const ast::WhileStatement* l) { return EmitWhile(l); },
[&](const ast::ReturnStatement* r) { return EmitReturn(r); },
[&](const ast::SwitchStatement* s) { return EmitSwitch(s); },
[&](const ast::VariableDeclStatement* v) { return EmitVariable(line(), v->variable); },
@ -1181,6 +1182,30 @@ bool GeneratorImpl::EmitForLoop(const ast::ForLoopStatement* stmt) {
return true;
}
bool GeneratorImpl::EmitWhile(const ast::WhileStatement* stmt) {
{
auto out = line();
out << "while";
{
ScopedParen sp(out);
auto* cond = stmt->condition;
if (!EmitExpression(out, cond)) {
return false;
}
}
out << " {";
}
if (!EmitStatementsWithIndent(stmt->body->statements)) {
return false;
}
line() << "}";
return true;
}
bool GeneratorImpl::EmitReturn(const ast::ReturnStatement* stmt) {
auto out = line();
out << "return";

View File

@ -152,6 +152,10 @@ class GeneratorImpl : public TextGenerator {
/// @param stmt the statement to emit
/// @returns true if the statement was emtited
bool EmitForLoop(const ast::ForLoopStatement* stmt);
/// Handles a while statement
/// @param stmt the statement to emit
/// @returns true if the statement was emtited
bool EmitWhile(const ast::WhileStatement* stmt);
/// Handles a member accessor expression
/// @param out the output of the expression stream
/// @param expr the member accessor expression

View File

@ -198,5 +198,45 @@ TEST_F(WgslGeneratorImplTest, Emit_ForLoopWithMultiStmtInitCondCont) {
)");
}
TEST_F(WgslGeneratorImplTest, Emit_While) {
// while(true) {
// return;
// }
auto* f = While(Expr(true), Block(Return()));
WrapInFunction(f);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( while(true) {
return;
}
)");
}
TEST_F(WgslGeneratorImplTest, Emit_WhileMultiCond) {
// while(true && false) {
// return;
// }
auto* multi_stmt =
create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, Expr(true), Expr(false));
auto* f = While(multi_stmt, Block(Return()));
WrapInFunction(f);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( while((true && false)) {
return;
}
)");
}
} // namespace
} // namespace tint::writer::wgsl

View File

@ -211,6 +211,7 @@ tint_unittests_source_set("tint_unittests_ast_src") {
"../../src/tint/ast/variable_decl_statement_test.cc",
"../../src/tint/ast/variable_test.cc",
"../../src/tint/ast/vector_test.cc",
"../../src/tint/ast/while_statement_test.cc",
"../../src/tint/ast/workgroup_attribute_test.cc",
]
}
@ -307,8 +308,8 @@ tint_unittests_source_set("tint_unittests_sem_src") {
"../../src/tint/sem/sem_struct_test.cc",
"../../src/tint/sem/storage_texture_test.cc",
"../../src/tint/sem/texture_test.cc",
"../../src/tint/sem/type_test.cc",
"../../src/tint/sem/type_manager_test.cc",
"../../src/tint/sem/type_test.cc",
"../../src/tint/sem/u32_test.cc",
"../../src/tint/sem/vector_test.cc",
]
@ -359,6 +360,7 @@ tint_unittests_source_set("tint_unittests_transform_src") {
"../../src/tint/transform/var_for_dynamic_index_test.cc",
"../../src/tint/transform/vectorize_scalar_matrix_constructors_test.cc",
"../../src/tint/transform/vertex_pulling_test.cc",
"../../src/tint/transform/while_to_loop_test.cc",
"../../src/tint/transform/wrap_arrays_in_structs_test.cc",
"../../src/tint/transform/zero_init_workgroup_memory_test.cc",
]
@ -552,6 +554,7 @@ tint_unittests_source_set("tint_unittests_wgsl_reader_src") {
"../../src/tint/reader/wgsl/parser_impl_variable_ident_decl_test.cc",
"../../src/tint/reader/wgsl/parser_impl_variable_qualifier_test.cc",
"../../src/tint/reader/wgsl/parser_impl_variable_stmt_test.cc",
"../../src/tint/reader/wgsl/parser_impl_while_stmt_test.cc",
"../../src/tint/reader/wgsl/parser_test.cc",
"../../src/tint/reader/wgsl/token_test.cc",
]

View File

@ -0,0 +1,7 @@
fn f() -> i32 {
var i : i32;
while (i < 4) {
i = i + 1;
}
return i;
}

View File

@ -0,0 +1,14 @@
#version 310 es
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void unused_entry_point() {
return;
}
int f() {
int i = 0;
while((i < 4)) {
i = (i + 1);
}
return i;
}

View File

@ -0,0 +1,12 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
int f() {
int i = 0;
[loop] while((i < 4)) {
i = (i + 1);
}
return i;
}

View File

@ -0,0 +1,11 @@
#include <metal_stdlib>
using namespace metal;
int f() {
int i = 0;
while((i < 4)) {
i = as_type<int>((as_type<uint>(i) + as_type<uint>(1)));
}
return i;
}

View File

@ -0,0 +1,51 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 27
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %unused_entry_point "unused_entry_point"
OpName %f "f"
OpName %i "i"
%void = OpTypeVoid
%1 = OpTypeFunction %void
%int = OpTypeInt 32 1
%5 = OpTypeFunction %int
%_ptr_Function_int = OpTypePointer Function %int
%11 = OpConstantNull %int
%int_4 = OpConstant %int 4
%bool = OpTypeBool
%int_1 = OpConstant %int 1
%unused_entry_point = OpFunction %void None %1
%4 = OpLabel
OpReturn
OpFunctionEnd
%f = OpFunction %int None %5
%8 = OpLabel
%i = OpVariable %_ptr_Function_int Function %11
OpBranch %12
%12 = OpLabel
OpLoopMerge %13 %14 None
OpBranch %15
%15 = OpLabel
%17 = OpLoad %int %i
%19 = OpSLessThan %bool %17 %int_4
%16 = OpLogicalNot %bool %19
OpSelectionMerge %21 None
OpBranchConditional %16 %22 %21
%22 = OpLabel
OpBranch %13
%21 = OpLabel
%23 = OpLoad %int %i
%25 = OpIAdd %int %23 %int_1
OpStore %i %25
OpBranch %14
%14 = OpLabel
OpBranch %12
%13 = OpLabel
%26 = OpLoad %int %i
OpReturnValue %26
OpFunctionEnd

View File

@ -0,0 +1,7 @@
fn f() -> i32 {
var i : i32;
while((i < 4)) {
i = (i + 1);
}
return i;
}