PromoteSideEffectsToDecl: add decomposing 'else if's to 'if { else }'

Just as we do with for loops that need decomposing to loops, we must
also decompose 'else if's to 'else { if }' so that we can insert decls
above the condition.

Bug: tint:1300
Change-Id: Ia16f1cf351964817587d353e58a02d9ae6f8386c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/77500
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Antonio Maiorano 2022-01-24 15:18:59 +00:00 committed by Tint LUCI CQ
parent 8db439d848
commit 4183a574b0
7 changed files with 390 additions and 2 deletions

View File

@ -18,6 +18,7 @@
#include "src/resolver/resolver_test_helper.h" #include "src/resolver/resolver_test_helper.h"
#include "src/sem/expression.h" #include "src/sem/expression.h"
#include "src/sem/for_loop_statement.h" #include "src/sem/for_loop_statement.h"
#include "src/sem/if_statement.h"
namespace tint { namespace tint {
namespace resolver { namespace resolver {

View File

@ -29,6 +29,10 @@ IfStatement::IfStatement(const ast::IfStatement* declaration,
IfStatement::~IfStatement() = default; IfStatement::~IfStatement() = default;
const ast::IfStatement* IfStatement::Declaration() const {
return static_cast<const ast::IfStatement*>(Base::Declaration());
}
ElseStatement::ElseStatement(const ast::ElseStatement* declaration, ElseStatement::ElseStatement(const ast::ElseStatement* declaration,
const CompoundStatement* parent, const CompoundStatement* parent,
const sem::Function* function) const sem::Function* function)

View File

@ -45,6 +45,9 @@ class IfStatement : public Castable<IfStatement, CompoundStatement> {
/// Destructor /// Destructor
~IfStatement() override; ~IfStatement() override;
/// @returns the AST node
const ast::IfStatement* Declaration() const;
/// @returns the if-statement condition expression /// @returns the if-statement condition expression
const Expression* Condition() const { return condition_; } const Expression* Condition() const { return condition_; }

View File

@ -23,8 +23,10 @@ namespace tint {
namespace ast { namespace ast {
class CallExpression; class CallExpression;
class Expression; class Expression;
class ElseStatement;
class ForLoopStatement; class ForLoopStatement;
class Function; class Function;
class IfStatement;
class MemberAccessorExpression; class MemberAccessorExpression;
class Node; class Node;
class Statement; class Statement;
@ -40,8 +42,10 @@ namespace sem {
class Array; class Array;
class Call; class Call;
class Expression; class Expression;
class ElseStatement;
class ForLoopStatement; class ForLoopStatement;
class Function; class Function;
class IfStatement;
class MemberAccessorExpression; class MemberAccessorExpression;
class Node; class Node;
class Statement; class Statement;
@ -58,8 +62,10 @@ struct TypeMappings {
//! @cond Doxygen_Suppress //! @cond Doxygen_Suppress
Call* operator()(ast::CallExpression*); Call* operator()(ast::CallExpression*);
Expression* operator()(ast::Expression*); Expression* operator()(ast::Expression*);
ElseStatement* operator()(ast::ElseStatement*);
ForLoopStatement* operator()(ast::ForLoopStatement*); ForLoopStatement* operator()(ast::ForLoopStatement*);
Function* operator()(ast::Function*); Function* operator()(ast::Function*);
IfStatement* operator()(ast::IfStatement*);
MemberAccessorExpression* operator()(ast::MemberAccessorExpression*); MemberAccessorExpression* operator()(ast::MemberAccessorExpression*);
Node* operator()(ast::Node*); Node* operator()(ast::Node*);
Statement* operator()(ast::Statement*); Statement* operator()(ast::Statement*);

View File

@ -23,8 +23,10 @@
#include "src/sem/call.h" #include "src/sem/call.h"
#include "src/sem/expression.h" #include "src/sem/expression.h"
#include "src/sem/for_loop_statement.h" #include "src/sem/for_loop_statement.h"
#include "src/sem/if_statement.h"
#include "src/sem/statement.h" #include "src/sem/statement.h"
#include "src/sem/type_constructor.h" #include "src/sem/type_constructor.h"
#include "src/utils/reverse.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::PromoteSideEffectsToDecl); TINT_INSTANTIATE_TYPEINFO(tint::transform::PromoteSideEffectsToDecl);
TINT_INSTANTIATE_TYPEINFO(tint::transform::PromoteSideEffectsToDecl::Config); TINT_INSTANTIATE_TYPEINFO(tint::transform::PromoteSideEffectsToDecl::Config);
@ -47,16 +49,41 @@ class PromoteSideEffectsToDecl::State {
ast::StatementList cont_decls; ast::StatementList cont_decls;
}; };
/// Holds information about 'if's with 'else-if' statements that need to be
/// decomposed into 'if {else}' so that declaration statements can be inserted
/// before the condition expression.
struct IfInfo {
/// Info for each else-if that needs decomposing
struct ElseIfInfo {
/// Decls to insert before condition
ast::StatementList cond_decls;
};
/// 'else if's that need to be decomposed to 'else { if }'
std::unordered_map<const sem::ElseStatement*, ElseIfInfo> else_ifs;
};
// For-loops that need to be decomposed to loops. // For-loops that need to be decomposed to loops.
std::unordered_map<const sem::ForLoopStatement*, LoopInfo> loops; std::unordered_map<const sem::ForLoopStatement*, LoopInfo> loops;
/// If statements with 'else if's that need to be decomposed to 'else { if }'
std::unordered_map<const sem::IfStatement*, IfInfo> ifs;
// Inserts `decl` before `sem_expr`, possibly marking a for-loop to be // Inserts `decl` before `sem_expr`, possibly marking a for-loop to be
// converted to a loop. // converted to a loop, or an else-if to an else { if }..
bool InsertBefore(const sem::Expression* sem_expr, bool InsertBefore(const sem::Expression* sem_expr,
const ast::VariableDeclStatement* decl) { const ast::VariableDeclStatement* decl) {
auto* sem_stmt = sem_expr->Stmt(); auto* sem_stmt = sem_expr->Stmt();
auto* stmt = sem_stmt->Declaration(); auto* stmt = sem_stmt->Declaration();
if (auto* else_if = sem_stmt->As<sem::ElseStatement>()) {
// Expression used in 'else if' condition.
// Need to convert 'else if' to 'else { if }'.
auto& if_info = ifs[else_if->Parent()->As<sem::IfStatement>()];
if_info.else_ifs[else_if].cond_decls.push_back(decl);
return true;
}
if (auto* fl = sem_stmt->As<sem::ForLoopStatement>()) { if (auto* fl = sem_stmt->As<sem::ForLoopStatement>()) {
// Expression used in for-loop condition. // Expression used in for-loop condition.
// For-loop needs to be decomposed to a loop. // For-loop needs to be decomposed to a loop.
@ -241,6 +268,80 @@ class PromoteSideEffectsToDecl::State {
}); });
} }
void ElseIfsToElseWithNestedIfs() {
if (ifs.empty()) {
return;
}
ctx.ReplaceAll([&](const ast::IfStatement* if_stmt) //
-> const ast::IfStatement* {
auto& sem = ctx.src->Sem();
auto* sem_if = sem.Get(if_stmt);
if (!sem_if) {
return nullptr;
}
auto it = ifs.find(sem_if);
if (it == ifs.end()) {
return nullptr;
}
auto& if_info = it->second;
// This if statement has "else if"s that need to be converted to "else
// { if }"s
ast::ElseStatementList next_else_stmts;
next_else_stmts.reserve(if_stmt->else_statements.size());
for (auto* else_stmt : utils::Reverse(if_stmt->else_statements)) {
if (else_stmt->condition == nullptr) {
// The last 'else', keep as is
next_else_stmts.insert(next_else_stmts.begin(), ctx.Clone(else_stmt));
} else {
auto* sem_else_if = sem.Get(else_stmt);
auto it2 = if_info.else_ifs.find(sem_else_if);
if (it2 == if_info.else_ifs.end()) {
// 'else if' we don't need to modify (no decls to insert), so
// keep as is
next_else_stmts.insert(next_else_stmts.begin(),
ctx.Clone(else_stmt));
} else {
// 'else if' we need to replace with 'else <decls> { if }'
auto& else_if_info = it2->second;
// Build the else body's statements, starting with let decls for
// the conditional expression
auto& body_stmts = else_if_info.cond_decls;
// Build nested if
body_stmts.emplace_back(b.If(ctx.Clone(else_stmt->condition),
ctx.Clone(else_stmt->body),
next_else_stmts));
// Build else
auto* else_with_nested_if = b.Else(b.Block(body_stmts));
// This will be used in parent if (either another nested if, or
// top-level if)
next_else_stmts = {else_with_nested_if};
}
}
}
// Build a new top-level if with new else statements
if (next_else_stmts.empty()) {
TINT_ICE(Transform, b.Diagnostics())
<< "Expected else statements to insert into new if";
}
auto* new_if = b.If(ctx.Clone(if_stmt->condition),
ctx.Clone(if_stmt->body), next_else_stmts);
return new_if;
});
}
public: public:
/// Constructor /// Constructor
/// @param ctx_in the CloneContext primed with the input program and /// @param ctx_in the CloneContext primed with the input program and
@ -286,6 +387,7 @@ class PromoteSideEffectsToDecl::State {
} }
ForLoopsToLoops(); ForLoopsToLoops();
ElseIfsToElseWithNestedIfs();
ctx.Clone(); ctx.Clone();
} }

View File

@ -23,7 +23,8 @@ namespace transform {
/// A transform that hoists expressions with side-effects to a variable /// A transform that hoists expressions with side-effects to a variable
/// declaration just before the statement of usage. This transform may also /// declaration just before the statement of usage. This transform may also
/// decompose for-loops into loops so that let declarations can be emitted /// decompose for-loops into loops so that let declarations can be emitted
/// before loop condition expressions and/or continuing statements. /// before loop condition expressions and/or continuing statements. It may also
/// similarly decompose 'else if's to 'else { if }'s for the same reason.
/// @see crbug.com/tint/406 /// @see crbug.com/tint/406
class PromoteSideEffectsToDecl class PromoteSideEffectsToDecl
: public Castable<PromoteSideEffectsToDecl, Transform> { : public Castable<PromoteSideEffectsToDecl, Transform> {

View File

@ -276,6 +276,93 @@ fn f() {
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
TEST_F(PromoteSideEffectsToDeclTest, TypeCtorToLet_ArrayInElseIf) {
auto* src = R"(
fn f() {
var f = 1.0;
if (true) {
var marker = 0;
} else if (f == array<f32, 2u>(f, f)[0]) {
var marker = 1;
}
}
)";
auto* expect = R"(
fn f() {
var f = 1.0;
if (true) {
var marker = 0;
} else {
let tint_symbol = array<f32, 2u>(f, f);
if ((f == tint_symbol[0])) {
var marker = 1;
}
}
}
)";
DataMap data;
data.Add<PromoteSideEffectsToDecl::Config>(/* type_ctor_to_let */ true,
/* dynamic_index_to_var */ false);
auto got = Run<PromoteSideEffectsToDecl>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, TypeCtorToLet_ArrayInElseIfChain) {
auto* src = R"(
fn f() {
var f = 1.0;
if (true) {
var marker = 0;
} else if (true) {
var marker = 1;
} else if (f == array<f32, 2u>(f, f)[0]) {
var marker = 2;
} else if (f == array<f32, 2u>(f, f)[1]) {
var marker = 3;
} else if (true) {
var marker = 4;
} else {
var marker = 5;
}
}
)";
auto* expect = R"(
fn f() {
var f = 1.0;
if (true) {
var marker = 0;
} else if (true) {
var marker = 1;
} else {
let tint_symbol = array<f32, 2u>(f, f);
if ((f == tint_symbol[0])) {
var marker = 2;
} else {
let tint_symbol_1 = array<f32, 2u>(f, f);
if ((f == tint_symbol_1[1])) {
var marker = 3;
} else if (true) {
var marker = 4;
} else {
var marker = 5;
}
}
}
}
)";
DataMap data;
data.Add<PromoteSideEffectsToDecl::Config>(/* type_ctor_to_let */ true,
/* dynamic_index_to_var */ false);
auto got = Run<PromoteSideEffectsToDecl>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, TypeCtorToLet_ArrayInArrayArray) { TEST_F(PromoteSideEffectsToDeclTest, TypeCtorToLet_ArrayInArrayArray) {
auto* src = R"( auto* src = R"(
fn f() { fn f() {
@ -638,6 +725,190 @@ fn f() {
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
TEST_F(PromoteSideEffectsToDeclTest, DynamicIndexToVar_ArrayIndexInElseIf) {
auto* src = R"(
fn f() {
var i : i32;
let p = array<i32, 2>(1, 2);
if (false) {
var marker = 0;
} else if (p[i] < 3) {
var marker = 1;
}
}
)";
auto* expect = R"(
fn f() {
var i : i32;
let p = array<i32, 2>(1, 2);
if (false) {
var marker = 0;
} else {
var var_for_index = p;
if ((var_for_index[i] < 3)) {
var marker = 1;
}
}
}
)";
DataMap data;
data.Add<PromoteSideEffectsToDecl::Config>(/* type_ctor_to_let */ false,
/* dynamic_index_to_var */ true);
auto got = Run<PromoteSideEffectsToDecl>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest,
DynamicIndexToVar_ArrayIndexInElseIfChain) {
auto* src = R"(
fn f() {
var i : i32;
let p = array<i32, 2>(1, 2);
if (true) {
var marker = 0;
} else if (true) {
var marker = 1;
} else if (p[i] < 3) {
var marker = 2;
} else if (p[i] < 4) {
var marker = 3;
} else if (true) {
var marker = 4;
} else {
var marker = 5;
}
}
)";
auto* expect = R"(
fn f() {
var i : i32;
let p = array<i32, 2>(1, 2);
if (true) {
var marker = 0;
} else if (true) {
var marker = 1;
} else {
var var_for_index = p;
if ((var_for_index[i] < 3)) {
var marker = 2;
} else {
var var_for_index_1 = p;
if ((var_for_index_1[i] < 4)) {
var marker = 3;
} else if (true) {
var marker = 4;
} else {
var marker = 5;
}
}
}
}
)";
DataMap data;
data.Add<PromoteSideEffectsToDecl::Config>(/* type_ctor_to_let */ false,
/* dynamic_index_to_var */ true);
auto got = Run<PromoteSideEffectsToDecl>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, DynamicIndexToVar_MatrixIndexInElseIf) {
auto* src = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
if (false) {
var marker_if = 1;
} else if (p[i].x < 3.0) {
var marker_else_if = 1;
}
}
)";
auto* expect = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
if (false) {
var marker_if = 1;
} else {
var var_for_index = p;
if ((var_for_index[i].x < 3.0)) {
var marker_else_if = 1;
}
}
}
)";
DataMap data;
data.Add<PromoteSideEffectsToDecl::Config>(/* type_ctor_to_let */ false,
/* dynamic_index_to_var */ true);
auto got = Run<PromoteSideEffectsToDecl>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest,
DynamicIndexToVar_MatrixIndexInElseIfChain) {
auto* src = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
if (true) {
var marker = 0;
} else if (true) {
var marker = 1;
} else if (p[i].x < 3.0) {
var marker = 2;
} else if (p[i].y < 3.0) {
var marker = 3;
} else if (true) {
var marker = 4;
} else {
var marker = 5;
}
}
)";
auto* expect = R"(
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
if (true) {
var marker = 0;
} else if (true) {
var marker = 1;
} else {
var var_for_index = p;
if ((var_for_index[i].x < 3.0)) {
var marker = 2;
} else {
var var_for_index_1 = p;
if ((var_for_index_1[i].y < 3.0)) {
var marker = 3;
} else if (true) {
var marker = 4;
} else {
var marker = 5;
}
}
}
}
)";
DataMap data;
data.Add<PromoteSideEffectsToDecl::Config>(/* type_ctor_to_let */ false,
/* dynamic_index_to_var */ true);
auto got = Run<PromoteSideEffectsToDecl>(src, data);
EXPECT_EQ(expect, str(got));
}
TEST_F(PromoteSideEffectsToDeclTest, DynamicIndexToVar_ArrayIndexLiteral) { TEST_F(PromoteSideEffectsToDeclTest, DynamicIndexToVar_ArrayIndexLiteral) {
auto* src = R"( auto* src = R"(
fn f() { fn f() {