Add while statement parsing.

This CL adds parsing for the WGSL `while` statement.

Bug: tint:1425
Change-Id: Ibce5e28568935ca4f51b5ac33e7a60af7a916b4a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/93540
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair
2022-06-16 12:01:27 +00:00
committed by Dawn LUCI CQ
parent d10f3f4437
commit 49d1a2d950
60 changed files with 2151 additions and 13 deletions

View File

@@ -27,6 +27,7 @@
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/variable.h"
#include "src/tint/sem/while_statement.h"
#include "src/tint/transform/manager.h"
#include "src/tint/transform/utils/get_insertion_point.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
@@ -383,6 +384,7 @@ class DecomposeSideEffects::CollectHoistsState : public StateBase {
ProcessStatement(s->expr);
},
[&](const ast::ForLoopStatement* s) { ProcessStatement(s->condition); },
[&](const ast::WhileStatement* s) { ProcessStatement(s->condition); },
[&](const ast::IfStatement* s) { //
ProcessStatement(s->condition);
},
@@ -578,6 +580,15 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s);
},
[&](const ast::WhileStatement* s) -> const ast::Statement* {
if (!sem.Get(s->condition)->HasSideEffects()) {
return nullptr;
}
ast::StatementList stmts;
ctx.Replace(s->condition, Decompose(s->condition, &stmts));
InsertBefore(stmts, s);
return ctx.CloneWithoutTransform(s);
},
[&](const ast::IfStatement* s) -> const ast::Statement* {
if (!sem.Get(s->condition)->HasSideEffects()) {
return nullptr;

View File

@@ -999,6 +999,45 @@ fn f() {
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InWhileCond) {
auto* src = R"(
fn a(i : i32) -> i32 {
return i;
}
fn f() {
var b = 1;
while(a(0) + b > 0) {
var marker = 0;
}
}
)";
auto* expect = R"(
fn a(i : i32) -> i32 {
return i;
}
fn f() {
var b = 1;
loop {
let tint_symbol = a(0);
if (!(((tint_symbol + b) > 0))) {
break;
}
{
var marker = 0;
}
}
}
)";
DataMap data;
auto got = Run<PromoteSideEffectsToDecl>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Arith_InElseIf) {
auto* src = R"(
fn a(i : i32) -> i32 {
@@ -2299,6 +2338,48 @@ fn f() {
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InWhileCond) {
auto* src = R"(
fn a(i : i32) -> bool {
return true;
}
fn f() {
var b = true;
while(a(0) && b) {
var marker = 0;
}
}
)";
auto* expect = R"(
fn a(i : i32) -> bool {
return true;
}
fn f() {
var b = true;
loop {
var tint_symbol = a(0);
if (tint_symbol) {
tint_symbol = b;
}
if (!(tint_symbol)) {
break;
}
{
var marker = 0;
}
}
}
)";
DataMap data;
auto got = Run<PromoteSideEffectsToDecl>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, Binary_Logical_InElseIf) {
auto* src = R"(
fn a(i : i32) -> bool {

View File

@@ -25,6 +25,7 @@
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/loop_statement.h"
#include "src/tint/sem/switch_statement.h"
#include "src/tint/sem/while_statement.h"
#include "src/tint/transform/utils/get_insertion_point.h"
#include "src/tint/utils/map.h"
@@ -49,7 +50,7 @@ class State {
// Find whether first parent is a switch or a loop
auto* sem_stmt = sem.Get(cont);
auto* sem_parent = sem_stmt->FindFirstParent<sem::SwitchStatement, sem::LoopBlockStatement,
sem::ForLoopStatement>();
sem::ForLoopStatement, sem::WhileStatement>();
if (!sem_parent) {
return nullptr;
}

View File

@@ -559,5 +559,59 @@ fn f() {
EXPECT_EQ(expect, str(got));
}
TEST_F(RemoveContinueInSwitchTest, While) {
auto* src = R"(
fn f() {
var i = 0;
while (i < 4) {
let marker1 = 0;
switch(i) {
case 0: {
continue;
break;
}
default: {
break;
}
}
let marker2 = 0;
break;
}
}
)";
auto* expect = R"(
fn f() {
var i = 0;
while((i < 4)) {
let marker1 = 0;
var tint_continue : bool = false;
switch(i) {
case 0: {
{
tint_continue = true;
break;
}
break;
}
default: {
break;
}
}
if (tint_continue) {
continue;
}
let marker2 = 0;
break;
}
}
)";
DataMap data;
auto got = Run<RemoveContinueInSwitch>(src, data);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace tint::transform

View File

@@ -262,6 +262,15 @@ class State {
}
return nullptr;
},
[&](const ast::WhileStatement* s) -> const ast::Statement* {
if (MayDiscard(sem.Get(s->condition))) {
TINT_ICE(Transform, b.Diagnostics())
<< "Unexpected WhileStatement condition that may discard. "
"Make sure transform::PromoteSideEffectsToDecl was run "
"first.";
}
return nullptr;
},
[&](const ast::IfStatement* s) -> const ast::Statement* {
auto* sem_expr = sem.Get(s->condition);
if (!MayDiscard(sem_expr)) {

View File

@@ -800,6 +800,67 @@ fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, While_Cond) {
auto* src = R"(
fn f() -> i32 {
if (true) {
discard;
}
return 42;
}
@fragment
fn main(@builtin(position) coord_in: vec4<f32>) -> @location(0) vec4<f32> {
let marker1 = 0;
while (f() == 42) {
let marker2 = 0;
break;
}
return vec4<f32>();
}
)";
auto* expect = R"(
var<private> tint_discard : bool = false;
fn f() -> i32 {
if (true) {
tint_discard = true;
return i32();
}
return 42;
}
fn tint_discard_func() {
discard;
}
@fragment
fn main(@builtin(position) coord_in : vec4<f32>) -> @location(0) vec4<f32> {
let marker1 = 0;
loop {
let tint_symbol = f();
if (tint_discard) {
tint_discard_func();
return vec4<f32>();
}
if (!((tint_symbol == 42))) {
break;
}
{
let marker2 = 0;
break;
}
}
return vec4<f32>();
}
)";
DataMap data;
auto got = Run<PromoteSideEffectsToDecl, UnwindDiscardFunctions>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(UnwindDiscardFunctionsTest, Switch) {
auto* src = R"(
fn f() -> i32 {

View File

@@ -22,6 +22,7 @@
#include "src/tint/sem/if_statement.h"
#include "src/tint/sem/reference.h"
#include "src/tint/sem/variable.h"
#include "src/tint/sem/while_statement.h"
#include "src/tint/utils/reverse.h"
namespace tint::transform {
@@ -46,7 +47,10 @@ class HoistToDeclBefore::State {
};
/// For-loops that need to be decomposed to loops.
std::unordered_map<const sem::ForLoopStatement*, LoopInfo> loops;
std::unordered_map<const sem::ForLoopStatement*, LoopInfo> for_loops;
/// Whiles that need to be decomposed to loops.
std::unordered_map<const sem::WhileStatement*, LoopInfo> while_loops;
/// 'else if' statements that need to be decomposed to 'else {if}'
std::unordered_map<const ast::IfStatement*, ElseIfInfo> else_ifs;
@@ -55,7 +59,7 @@ class HoistToDeclBefore::State {
// registered declaration statements before the condition or continuing
// statement.
void ForLoopsToLoops() {
if (loops.empty()) {
if (for_loops.empty()) {
return;
}
@@ -64,7 +68,7 @@ class HoistToDeclBefore::State {
auto& sem = ctx.src->Sem();
if (auto* fl = sem.Get(stmt)) {
if (auto it = loops.find(fl); it != loops.end()) {
if (auto it = for_loops.find(fl); it != for_loops.end()) {
auto& info = it->second;
auto* for_loop = fl->Declaration();
// For-loop needs to be decomposed to a loop.
@@ -108,6 +112,51 @@ class HoistToDeclBefore::State {
});
}
// Converts any while-loops marked for conversion to loops, inserting
// registered declaration statements before the condition.
void WhilesToLoops() {
if (while_loops.empty()) {
return;
}
// At least one while needs to be transformed into a loop.
ctx.ReplaceAll([&](const ast::WhileStatement* stmt) -> const ast::Statement* {
auto& sem = ctx.src->Sem();
if (auto* w = sem.Get(stmt)) {
if (auto it = while_loops.find(w); it != while_loops.end()) {
auto& info = it->second;
auto* while_loop = w->Declaration();
// While needs to be decomposed to a loop.
// Build the loop body's statements.
// Start with any let declarations for the conditional
// expression.
auto body_stmts = info.cond_decls;
// Emit the condition as:
// if (!cond) { break; }
auto* cond = while_loop->condition;
// !condition
auto* not_cond =
b.create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
// { break; }
auto* break_body = b.Block(b.create<ast::BreakStatement>());
// if (!condition) { break; }
body_stmts.emplace_back(b.If(not_cond, break_body));
// Next emit the body
body_stmts.emplace_back(ctx.Clone(while_loop->body));
const ast::BlockStatement* continuing = nullptr;
auto* body = b.Block(body_stmts);
auto* loop = b.Loop(body, continuing);
return loop;
}
}
return nullptr;
});
}
void ElseIfsToElseWithNestedIfs() {
// Decompose 'else-if' statements into 'else { if }' blocks.
ctx.ReplaceAll([&](const ast::IfStatement* else_if) -> const ast::Statement* {
@@ -192,7 +241,19 @@ class HoistToDeclBefore::State {
// For-loop needs to be decomposed to a loop.
// Index the map to convert this for-loop, even if `stmt` is nullptr.
auto& decls = loops[fl].cond_decls;
auto& decls = for_loops[fl].cond_decls;
if (stmt) {
decls.emplace_back(stmt);
}
return true;
}
if (auto* w = before_stmt->As<sem::WhileStatement>()) {
// Insertion point is a while condition.
// While needs to be decomposed to a loop.
// Index the map to convert this while, even if `stmt` is nullptr.
auto& decls = while_loops[w].cond_decls;
if (stmt) {
decls.emplace_back(stmt);
}
@@ -227,7 +288,7 @@ class HoistToDeclBefore::State {
// For-loop needs to be decomposed to a loop.
// Index the map to convert this for-loop, even if `stmt` is nullptr.
auto& decls = loops[fl].cont_decls;
auto& decls = for_loops[fl].cont_decls;
if (stmt) {
decls.emplace_back(stmt);
}
@@ -257,6 +318,7 @@ class HoistToDeclBefore::State {
/// @return true on success
bool Apply() {
ForLoopsToLoops();
WhilesToLoops();
ElseIfsToElseWithNestedIfs();
return true;
}

View File

@@ -175,6 +175,47 @@ fn f() {
EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, WhileCond) {
// fn f() {
// var a : bool;
// while(a) {
// }
// }
ProgramBuilder b;
auto* var = b.Decl(b.Var("a", b.ty.bool_()));
auto* expr = b.Expr("a");
auto* s = b.While(expr, b.Block());
b.Func("f", {}, b.ty.void_(), {var, s});
Program original(std::move(b));
ProgramBuilder cloned_b;
CloneContext ctx(&cloned_b, &original);
HoistToDeclBefore hoistToDeclBefore(ctx);
auto* sem_expr = ctx.src->Sem().Get(expr);
hoistToDeclBefore.Add(sem_expr, expr, true);
hoistToDeclBefore.Apply();
ctx.Clone();
Program cloned(std::move(cloned_b));
auto* expect = R"(
fn f() {
var a : bool;
loop {
let tint_symbol = a;
if (!(tint_symbol)) {
break;
}
{
}
}
}
)";
EXPECT_EQ(expect, str(cloned));
}
TEST_F(HoistToDeclBeforeTest, ElseIf) {
// fn f() {
// var a : bool;

View File

@@ -0,0 +1,67 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/while_to_loop.h"
#include "src/tint/ast/break_statement.h"
#include "src/tint/program_builder.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::WhileToLoop);
namespace tint::transform {
WhileToLoop::WhileToLoop() = default;
WhileToLoop::~WhileToLoop() = default;
bool WhileToLoop::ShouldRun(const Program* program, const DataMap&) const {
for (auto* node : program->ASTNodes().Objects()) {
if (node->Is<ast::WhileStatement>()) {
return true;
}
}
return false;
}
void WhileToLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) const {
ctx.ReplaceAll([&](const ast::WhileStatement* w) -> const ast::Statement* {
ast::StatementList stmts;
auto* cond = w->condition;
// !condition
auto* not_cond =
ctx.dst->create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, ctx.Clone(cond));
// { break; }
auto* break_body = ctx.dst->Block(ctx.dst->create<ast::BreakStatement>());
// if (!condition) { break; }
stmts.emplace_back(ctx.dst->If(not_cond, break_body));
for (auto* stmt : w->body->statements) {
stmts.emplace_back(ctx.Clone(stmt));
}
const ast::BlockStatement* continuing = nullptr;
auto* body = ctx.dst->Block(stmts);
auto* loop = ctx.dst->create<ast::LoopStatement>(body, continuing);
return loop;
});
ctx.Clone();
}
} // namespace tint::transform

View File

@@ -0,0 +1,49 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_WHILE_TO_LOOP_H_
#define SRC_TINT_TRANSFORM_WHILE_TO_LOOP_H_
#include "src/tint/transform/transform.h"
namespace tint::transform {
/// WhileToLoop is a Transform that converts a while statement into a loop
/// statement. This is required by the SPIR-V writer.
class WhileToLoop final : public Castable<WhileToLoop, Transform> {
public:
/// Constructor
WhileToLoop();
/// Destructor
~WhileToLoop() override;
/// @param program the program to inspect
/// @param data optional extra transform-specific input data
/// @returns true if this transform should be run for the given program
bool ShouldRun(const Program* program, const DataMap& data = {}) const override;
protected:
/// Runs the transform using the CloneContext built for transforming a
/// program. Run() is responsible for calling Clone() on the CloneContext.
/// @param ctx the CloneContext primed with the input program and
/// ProgramBuilder
/// @param inputs optional extra transform-specific input data
/// @param outputs optional extra transform-specific output data
void Run(CloneContext& ctx, const DataMap& inputs, DataMap& outputs) const override;
};
} // namespace tint::transform
#endif // SRC_TINT_TRANSFORM_WHILE_TO_LOOP_H_

View File

@@ -0,0 +1,129 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/while_to_loop.h"
#include "src/tint/transform/test_helper.h"
namespace tint::transform {
namespace {
using WhileToLoopTest = TransformTest;
TEST_F(WhileToLoopTest, ShouldRunEmptyModule) {
auto* src = R"()";
EXPECT_FALSE(ShouldRun<WhileToLoop>(src));
}
TEST_F(WhileToLoopTest, ShouldRunHasWhile) {
auto* src = R"(
fn f() {
while (true) {
break;
}
}
)";
EXPECT_TRUE(ShouldRun<WhileToLoop>(src));
}
TEST_F(WhileToLoopTest, EmptyModule) {
auto* src = "";
auto* expect = src;
auto got = Run<WhileToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test an empty for loop.
TEST_F(WhileToLoopTest, Empty) {
auto* src = R"(
fn f() {
while (true) {
break;
}
}
)";
auto* expect = R"(
fn f() {
loop {
if (!(true)) {
break;
}
break;
}
}
)";
auto got = Run<WhileToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a for loop with non-empty body.
TEST_F(WhileToLoopTest, Body) {
auto* src = R"(
fn f() {
while (true) {
discard;
}
}
)";
auto* expect = R"(
fn f() {
loop {
if (!(true)) {
break;
}
discard;
}
}
)";
auto got = Run<WhileToLoop>(src);
EXPECT_EQ(expect, str(got));
}
// Test a loop with a break condition
TEST_F(WhileToLoopTest, BreakCondition) {
auto* src = R"(
fn f() {
while (0 == 1) {
}
}
)";
auto* expect = R"(
fn f() {
loop {
if (!((0 == 1))) {
break;
}
}
}
)";
auto got = Run<WhileToLoop>(src);
EXPECT_EQ(expect, str(got));
}
} // namespace
} // namespace tint::transform