tint/transform/utils: Correctly scope for-loop init

When using HoistToDeclBefore on a for-loop initializer, the inserted
statement would be scoped outside the for-loop. This was incorrect.

Change-Id: I764d07068e907cc203145ac8d6f0110b1b73e667
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/122301
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Ben Clayton <bclayton@chromium.org>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2023-03-02 17:18:27 +00:00 committed by Dawn LUCI CQ
parent b990d393f5
commit 3cde73cb1a
21 changed files with 233 additions and 111 deletions

View File

@ -436,17 +436,33 @@ fn a_U_X_X(pre : i32, p : U_X_X, post : i32) -> vec4<i32> {
}
fn b() {
{
let ptr_index_save = first();
for(let p1 = &(U[ptr_index_save]); true; ) {
let p1 = &(U[ptr_index_save]);
loop {
if (!(true)) {
break;
}
{
a_U_X_X(10, U_X_X(u32(ptr_index_save), u32(second())), 20);
}
}
}
}
fn c_U() {
{
let ptr_index_save_1 = first();
for(let p1 = &(U[ptr_index_save_1]); true; ) {
let p1 = &(U[ptr_index_save_1]);
loop {
if (!(true)) {
break;
}
{
a_U_X_X(10, U_X_X(u32(ptr_index_save_1), u32(second())), 20);
}
}
}
}
fn d() {

View File

@ -391,11 +391,16 @@ fn idx2() -> i32 {
}
fn main() {
{
let tint_symbol = &(a[idx1()]);
let tint_symbol_1 = idx2();
for((*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1); ; ) {
(*(tint_symbol))[tint_symbol_1] = ((*(tint_symbol))[tint_symbol_1] + 1);
loop {
{
break;
}
}
}
}
)";

View File

@ -272,10 +272,15 @@ fn f() {
auto* expect = R"(
fn f() {
var insert_after = 1;
{
let tint_symbol : array<f32, 4u> = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
for(var i = tint_symbol[insert_after]; ; ) {
var i = tint_symbol[insert_after];
loop {
{
break;
}
}
}
}
)";
@ -301,10 +306,15 @@ fn f() {
const arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
let runtime_value = 1;
var insert_after = 1;
{
let tint_symbol : array<f32, 4u> = arr;
for(var i = tint_symbol[runtime_value]; ; ) {
var i = tint_symbol[runtime_value];
loop {
{
break;
}
}
}
}
)";
@ -332,10 +342,15 @@ const arr = array<f32, 4u>(0.0, 1.0, 2.0, 3.0);
fn f() {
let runtime_value = 1;
var insert_after = 1;
{
let tint_symbol : array<f32, 4u> = arr;
for(var i = tint_symbol[runtime_value]; ; ) {
var i = tint_symbol[runtime_value];
loop {
{
break;
}
}
}
}
)";
@ -377,10 +392,15 @@ fn get_b_runtime(s : S) -> f32 {
fn f() {
var insert_after = 1;
{
let tint_symbol : S = S(1, 2.0, vec3<f32>());
for(var x = get_b_runtime(tint_symbol); ; ) {
var x = get_b_runtime(tint_symbol);
loop {
{
break;
}
}
}
}
)";
@ -412,10 +432,15 @@ struct S {
auto* expect = R"(
fn f() {
var insert_after = 1;
{
let tint_symbol : S = S(1, 2.0, vec3<f32>());
for(var x = get_b_runtime(tint_symbol); ; ) {
var x = get_b_runtime(tint_symbol);
loop {
{
break;
}
}
}
}
fn get_b_runtime(s : S) -> f32 {
@ -683,8 +708,8 @@ fn f() {
auto* expect = R"(
fn f() {
let runtime_value = 0;
let tint_symbol : array<f32, 1u> = array<f32, 1u>(0.0);
{
let tint_symbol : array<f32, 1u> = array<f32, 1u>(0.0);
var f = tint_symbol[runtime_value];
loop {
let tint_symbol_1 : array<f32, 1u> = array<f32, 1u>(1.0);
@ -728,8 +753,8 @@ fn f() {
const arr_a = array<f32, 1u>(0.0);
const arr_b = array<f32, 1u>(1.0);
const arr_c = array<f32, 1u>(2.0);
let tint_symbol : array<f32, 1u> = arr_a;
{
let tint_symbol : array<f32, 1u> = arr_a;
var f = tint_symbol[runtime_value];
loop {
let tint_symbol_1 : array<f32, 1u> = arr_b;
@ -1301,10 +1326,22 @@ fn Y() {
fn Z() {
var i = 10;
{
let tint_symbol_2 : array<i32, 1u> = array<i32, 1u>(i);
for(var f = tint_symbol_2[0]; (f < 10); f = (f + 1)) {
var f = tint_symbol_2[0];
loop {
if (!((f < 10))) {
break;
}
{
var i = 20;
}
continuing {
f = (f + 1);
}
}
}
}
)";

View File

@ -848,11 +848,16 @@ fn a(i : i32) -> i32 {
fn f() {
var b = 1;
{
let tint_symbol = a(0);
for(var r = (tint_symbol + b); ; ) {
var r = (tint_symbol + b);
loop {
{
var marker = 0;
break;
}
}
}
}
)";
@ -2169,14 +2174,19 @@ fn a(i : i32) -> bool {
fn f() {
var b = true;
{
var tint_symbol = a(0);
if (tint_symbol) {
tint_symbol = b;
}
for(var r = tint_symbol; ; ) {
var r = tint_symbol;
loop {
{
var marker = 0;
break;
}
}
}
}
)";

View File

@ -112,6 +112,7 @@ struct HoistToDeclBefore::State {
/// loop, so that declaration statements can be inserted before the
/// condition expression or continuing statement.
struct LoopInfo {
utils::Vector<StmtBuilder, 8> init_decls;
utils::Vector<StmtBuilder, 8> cond_decls;
utils::Vector<StmtBuilder, 8> cont_decls;
};
@ -198,7 +199,7 @@ struct HoistToDeclBefore::State {
// Next emit the for-loop body
body_stmts.Push(ctx.Clone(for_loop->body));
// Finally create the continuing block if there was one.
// Create the continuing block if there was one.
const ast::BlockStatement* continuing = nullptr;
if (auto* cont = for_loop->continuing) {
// Continuing block starts with any let declarations used by
@ -210,8 +211,17 @@ struct HoistToDeclBefore::State {
auto* body = b.Block(body_stmts);
auto* loop = b.Loop(body, continuing);
// If the loop has no initializer statements, then we're done.
// Otherwise, wrap loop with another block, prefixed with the initializer
// statements
if (!info->init_decls.IsEmpty() || for_loop->initializer) {
auto stmts = Build(info->init_decls);
if (auto* init = for_loop->initializer) {
return b.Block(ctx.Clone(init), loop);
stmts.Push(ctx.Clone(init));
}
stmts.Push(loop);
return b.Block(std::move(stmts));
}
return loop;
}
@ -299,7 +309,7 @@ struct HoistToDeclBefore::State {
// Need to convert 'else if' to 'else { if }'.
auto else_if_info = ElseIf(else_if->Declaration());
// Index the map to convert this else if, even if `stmt` is nullptr.
// Index the map to decompose this else if, even if `stmt` is nullptr.
auto& decls = else_if_info->cond_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
@ -311,7 +321,7 @@ struct HoistToDeclBefore::State {
// Insertion point is a for-loop condition.
// For-loop needs to be decomposed to a loop.
// Index the map to convert this for-loop, even if `stmt` is nullptr.
// Index the map to decompose this for-loop, even if `stmt` is nullptr.
auto& decls = ForLoop(fl)->cond_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
@ -323,7 +333,7 @@ struct HoistToDeclBefore::State {
// Insertion point is a while condition.
// While needs to be decomposed to a loop.
// Index the map to convert this while, even if `stmt` is nullptr.
// Index the map to decompose this while, even if `stmt` is nullptr.
auto& decls = WhileLoop(w)->cond_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
@ -348,11 +358,14 @@ struct HoistToDeclBefore::State {
// These require special care.
if (fl->Declaration()->initializer == ip) {
// Insertion point is a for-loop initializer.
// Insert the new statement above the for-loop.
// For-loop needs to be decomposed to a loop.
// Index the map to decompose this for-loop, even if `stmt` is nullptr.
auto& decls = ForLoop(fl)->init_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
ctx.InsertBefore(fl->Block()->Declaration()->statements, fl->Declaration(),
std::forward<BUILDER>(builder));
decls.Push(std::forward<BUILDER>(builder));
}
return true;
}
@ -360,11 +373,12 @@ struct HoistToDeclBefore::State {
// Insertion point is a for-loop continuing statement.
// For-loop needs to be decomposed to a loop.
// Index the map to convert this for-loop, even if `stmt` is nullptr.
// Index the map to decompose this for-loop, even if `stmt` is nullptr.
auto& decls = ForLoop(fl)->cont_decls;
if constexpr (!std::is_same_v<BUILDER, Decompose>) {
decls.Push(std::forward<BUILDER>(builder));
}
return true;
}

View File

@ -82,8 +82,16 @@ TEST_F(HoistToDeclBeforeTest, ForLoopInit) {
auto* expect = R"(
fn f() {
{
var tint_symbol : i32 = 1i;
for(var a = tint_symbol; true; ) {
var a = tint_symbol;
loop {
if (!(true)) {
break;
}
{
}
}
}
}
)";
@ -545,8 +553,16 @@ fn foo() {
}
fn f() {
{
foo();
for(var a = 1i; true; ) {
var a = 1i;
loop {
if (!(true)) {
break;
}
{
}
}
}
}
)";
@ -584,8 +600,16 @@ fn foo() {
}
fn f() {
{
foo();
for(var a = 1i; true; ) {
var a = 1i;
loop {
if (!(true)) {
break;
}
{
}
}
}
}
)";

View File

@ -126,10 +126,15 @@ fn f() {
fn f() {
var i : i32;
let p = array<array<i32, 2>, 2>(array<i32, 2>(1, 2), array<i32, 2>(3, 4));
{
var var_for_index : array<array<i32, 2u>, 2u> = p;
for(let x = var_for_index[i]; ; ) {
let x = var_for_index[i];
loop {
{
break;
}
}
}
}
)";
@ -154,10 +159,15 @@ fn f() {
fn f() {
var i : i32;
let p = mat2x2(1.0, 2.0, 3.0, 4.0);
{
var var_for_index : mat2x2<f32> = p;
for(let x = var_for_index[i]; ; ) {
let x = var_for_index[i];
loop {
{
break;
}
}
}
}
)";

View File

@ -7,14 +7,16 @@ int foo() {
void tint_symbol() {
float arr[4] = float[4](0.0f, 0.0f, 0.0f, 0.0f);
{
int tint_symbol_1 = foo();
int a_save = tint_symbol_1;
while (true) {
{
for(; ; ) {
float x = arr[a_save];
break;
}
}
}
}
void main() {

View File

@ -20,12 +20,16 @@ int foo() {
fragment void tint_symbol() {
tint_array<float, 4> arr = tint_array<float, 4>{};
{
int const tint_symbol_1 = foo();
int const a_save = tint_symbol_1;
for(; ; ) {
while (true) {
{
float const x = arr[a_save];
break;
}
}
}
return;
}

View File

@ -23,8 +23,8 @@ int idx3() {
void foo() {
float a[4] = (float[4])0;
const int tint_symbol_save = idx1();
{
const int tint_symbol_save = idx1();
a[tint_symbol_save] = (a[tint_symbol_save] * 2.0f);
while (true) {
const int tint_symbol_2 = idx2();

View File

@ -23,8 +23,8 @@ int idx3() {
void foo() {
float a[4] = (float[4])0;
const int tint_symbol_save = idx1();
{
const int tint_symbol_save = idx1();
a[tint_symbol_save] = (a[tint_symbol_save] * 2.0f);
while (true) {
const int tint_symbol_2 = idx2();

View File

@ -35,9 +35,9 @@ int idx3() {
void foo() {
float a[4] = float[4](0.0f, 0.0f, 0.0f, 0.0f);
{
int tint_symbol_2 = idx1();
int tint_symbol_save = tint_symbol_2;
{
a[tint_symbol_save] = (a[tint_symbol_save] * 2.0f);
while (true) {
int tint_symbol_3 = idx2();

View File

@ -37,9 +37,9 @@ int idx3(thread uint* const tint_symbol_7) {
void foo(thread uint* const tint_symbol_8) {
tint_array<float, 4> a = tint_array<float, 4>{};
{
int const tint_symbol_2 = idx1(tint_symbol_8);
int const tint_symbol_save = tint_symbol_2;
{
a[tint_symbol_save] = (a[tint_symbol_save] * 2.0f);
while (true) {
int const tint_symbol_3 = idx2(tint_symbol_8);

View File

@ -37,10 +37,10 @@ int idx6() {
}
void main() {
{
const int tint_symbol_save = idx1();
const int tint_symbol_save_1 = idx2();
const int tint_symbol_1 = idx3();
{
buffer.Store((((64u * uint(tint_symbol_save)) + (16u * uint(tint_symbol_save_1))) + (4u * uint(tint_symbol_1))), asuint((asint(buffer.Load((((64u * uint(tint_symbol_save)) + (16u * uint(tint_symbol_save_1))) + (4u * uint(tint_symbol_1))))) - 1)));
while (true) {
if (!((v < 10u))) {

View File

@ -37,10 +37,10 @@ int idx6() {
}
void main() {
{
const int tint_symbol_save = idx1();
const int tint_symbol_save_1 = idx2();
const int tint_symbol_1 = idx3();
{
buffer.Store((((64u * uint(tint_symbol_save)) + (16u * uint(tint_symbol_save_1))) + (4u * uint(tint_symbol_1))), asuint((asint(buffer.Load((((64u * uint(tint_symbol_save)) + (16u * uint(tint_symbol_save_1))) + (4u * uint(tint_symbol_1))))) - 1)));
while (true) {
if (!((v < 10u))) {

View File

@ -44,12 +44,12 @@ int idx6() {
}
void tint_symbol_1() {
{
int tint_symbol_6 = idx1();
int tint_symbol_7 = idx2();
int tint_symbol_2_save = tint_symbol_6;
int tint_symbol_2_save_1 = tint_symbol_7;
int tint_symbol_3 = idx3();
{
tint_symbol.inner[tint_symbol_2_save].a[tint_symbol_2_save_1][tint_symbol_3] = (tint_symbol.inner[tint_symbol_2_save].a[tint_symbol_2_save_1][tint_symbol_3] - 1);
while (true) {
if (!((v < 10u))) {

View File

@ -49,12 +49,12 @@ int idx6(thread uint* const tint_symbol_15) {
}
void tint_symbol_1(thread uint* const tint_symbol_16, device tint_array<S, 1>* const tint_symbol_17) {
{
int const tint_symbol_6 = idx1(tint_symbol_16);
int const tint_symbol_7 = idx2(tint_symbol_16);
int const tint_symbol_2_save = tint_symbol_6;
int const tint_symbol_2_save_1 = tint_symbol_7;
int const tint_symbol_3 = idx3(tint_symbol_16);
{
(*(tint_symbol_17))[tint_symbol_2_save].a[tint_symbol_2_save_1][tint_symbol_3] = as_type<int>((as_type<uint>((*(tint_symbol_17))[tint_symbol_2_save].a[tint_symbol_2_save_1][tint_symbol_3]) - as_type<uint>(1)));
while (true) {
if (!((*(tint_symbol_16) < 10u))) {

View File

@ -37,10 +37,10 @@ int idx6() {
}
void main() {
{
const int tint_symbol_save = idx1();
const int tint_symbol_save_1 = idx2();
const int tint_symbol_1 = idx3();
{
buffer.Store((((64u * uint(tint_symbol_save)) + (16u * uint(tint_symbol_save_1))) + (4u * uint(tint_symbol_1))), asuint((asint(buffer.Load((((64u * uint(tint_symbol_save)) + (16u * uint(tint_symbol_save_1))) + (4u * uint(tint_symbol_1))))) + 1)));
while (true) {
if (!((v < 10u))) {

View File

@ -37,10 +37,10 @@ int idx6() {
}
void main() {
{
const int tint_symbol_save = idx1();
const int tint_symbol_save_1 = idx2();
const int tint_symbol_1 = idx3();
{
buffer.Store((((64u * uint(tint_symbol_save)) + (16u * uint(tint_symbol_save_1))) + (4u * uint(tint_symbol_1))), asuint((asint(buffer.Load((((64u * uint(tint_symbol_save)) + (16u * uint(tint_symbol_save_1))) + (4u * uint(tint_symbol_1))))) + 1)));
while (true) {
if (!((v < 10u))) {

View File

@ -44,12 +44,12 @@ int idx6() {
}
void tint_symbol_1() {
{
int tint_symbol_6 = idx1();
int tint_symbol_7 = idx2();
int tint_symbol_2_save = tint_symbol_6;
int tint_symbol_2_save_1 = tint_symbol_7;
int tint_symbol_3 = idx3();
{
tint_symbol.inner[tint_symbol_2_save].a[tint_symbol_2_save_1][tint_symbol_3] = (tint_symbol.inner[tint_symbol_2_save].a[tint_symbol_2_save_1][tint_symbol_3] + 1);
while (true) {
if (!((v < 10u))) {

View File

@ -49,12 +49,12 @@ int idx6(thread uint* const tint_symbol_15) {
}
void tint_symbol_1(thread uint* const tint_symbol_16, device tint_array<S, 1>* const tint_symbol_17) {
{
int const tint_symbol_6 = idx1(tint_symbol_16);
int const tint_symbol_7 = idx2(tint_symbol_16);
int const tint_symbol_2_save = tint_symbol_6;
int const tint_symbol_2_save_1 = tint_symbol_7;
int const tint_symbol_3 = idx3(tint_symbol_16);
{
(*(tint_symbol_17))[tint_symbol_2_save].a[tint_symbol_2_save_1][tint_symbol_3] = as_type<int>((as_type<uint>((*(tint_symbol_17))[tint_symbol_2_save].a[tint_symbol_2_save_1][tint_symbol_3]) + as_type<uint>(1)));
while (true) {
if (!((*(tint_symbol_16) < 10u))) {