tint/hlsl: for default-only switch, only emit condition if it has side-effects

This fixes edge-cases, like the condition expression being a type-cast,
which DXC apparently sees as a variable re-declaration. Example:

fn foo(x : f32) {
  switch (i32(x)) {
    default {
    }
  }
}

was emitted as HLSL:

void foo(float x) {
  int(x);
  do {
  } while (false);
}

The `int(x)` is seen as a re-declaration of `x` by DXC.

We fix this by only emitted the condition expression if it has
side-effects (which currently means it contains a call expression).

Bug: tint:1820
Change-Id: I7e4320fa09ea2d634c9e324cb0b752b0ee7dcde9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/118161
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2023-02-01 15:46:34 +00:00 committed by Dawn LUCI CQ
parent 98bd83a8fc
commit eab1f62629
11 changed files with 271 additions and 8 deletions

View File

@ -3859,10 +3859,9 @@ bool GeneratorImpl::EmitDefaultOnlySwitch(const ast::SwitchStatement* stmt) {
// default case body. We work around this here by emitting the default case
// without the switch.
// Emit the switch condition as-is in case it has side-effects (e.g.
// function call). Note that's it's fine not to assign the result of the
// expression.
{
// Emit the switch condition as-is if it has side-effects (e.g.
// function call). Note that we can ignore the result of the expression (if any).
if (auto* sem_cond = builder_.Sem().Get(stmt->condition); sem_cond->HasSideEffects()) {
auto out = line();
if (!EmitExpression(out, stmt->condition)) {
return false;

View File

@ -66,7 +66,16 @@ TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch_MixedDefault) {
)");
}
TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch_OnlyDefaultCase) {
TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch_OnlyDefaultCase_NoSideEffectsCondition) {
// var<private> cond : i32;
// var<private> a : i32;
// fn test() {
// switch(cond) {
// default: {
// a = 42;
// }
// }
// }
GlobalVar("cond", ty.i32(), type::AddressSpace::kPrivate);
GlobalVar("a", ty.i32(), type::AddressSpace::kPrivate);
auto* s = Switch( //
@ -79,7 +88,45 @@ TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch_OnlyDefaultCase) {
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(s)) << gen.error();
EXPECT_EQ(gen.result(), R"( cond;
EXPECT_EQ(gen.result(), R"( do {
a = 42;
} while (false);
)");
}
TEST_F(HlslGeneratorImplTest_Switch, Emit_Switch_OnlyDefaultCase_SideEffectsCondition) {
// var<private> global : i32;
// fn bar() -> i32 {
// global = 84;
// return global;
// }
//
// var<private> a : i32;
// fn test() {
// switch(bar()) {
// default: {
// a = 42;
// }
// }
// }
GlobalVar("global", ty.i32(), type::AddressSpace::kPrivate);
Func("bar", {}, ty.i32(),
utils::Vector{ //
Assign("global", Expr(84_i)), //
Return("global")});
GlobalVar("a", ty.i32(), type::AddressSpace::kPrivate);
auto* s = Switch( //
Call("bar"), //
DefaultCase(Block(Assign(Expr("a"), Expr(42_i)))));
WrapInFunction(s);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(s)) << gen.error();
EXPECT_EQ(gen.result(), R"( bar();
do {
a = 42;
} while (false);

View File

@ -0,0 +1,22 @@
fn foo(x : f32) {
switch (i32(x)) {
default {
}
}
}
var<private> global : i32;
fn baz(x : i32) -> i32 {
global = 42;
return x;
}
fn bar(x : f32) {
switch (baz(i32(x))) {
default {
}
}
}
fn main() {
}

View File

@ -0,0 +1,25 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
void foo(float x) {
do {
} while (false);
}
static int global = 0;
int baz(int x) {
global = 42;
return x;
}
void bar(float x) {
baz(int(x));
do {
} while (false);
}
void main() {
}

View File

@ -0,0 +1,25 @@
[numthreads(1, 1, 1)]
void unused_entry_point() {
return;
}
void foo(float x) {
do {
} while (false);
}
static int global = 0;
int baz(int x) {
global = 42;
return x;
}
void bar(float x) {
baz(int(x));
do {
} while (false);
}
void main() {
}

View File

@ -0,0 +1,31 @@
#version 310 es
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void unused_entry_point() {
return;
}
void foo(float x) {
switch(int(x)) {
default: {
break;
}
}
}
int global = 0;
int baz(int x) {
global = 42;
return x;
}
void bar(float x) {
switch(baz(int(x))) {
default: {
break;
}
}
}
void tint_symbol() {
}

View File

@ -0,0 +1,28 @@
#include <metal_stdlib>
using namespace metal;
void foo(float x) {
switch(int(x)) {
default: {
break;
}
}
}
int baz(int x) {
thread int tint_symbol_1 = 0;
tint_symbol_1 = 42;
return x;
}
void bar(float x) {
switch(baz(int(x))) {
default: {
break;
}
}
}
void tint_symbol() {
}

View File

@ -0,0 +1,65 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 31
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
OpExecutionMode %unused_entry_point LocalSize 1 1 1
OpName %global "global"
OpName %unused_entry_point "unused_entry_point"
OpName %foo "foo"
OpName %x "x"
OpName %baz "baz"
OpName %x_0 "x"
OpName %bar "bar"
OpName %x_1 "x"
OpName %main "main"
%int = OpTypeInt 32 1
%_ptr_Private_int = OpTypePointer Private %int
%4 = OpConstantNull %int
%global = OpVariable %_ptr_Private_int Private %4
%void = OpTypeVoid
%5 = OpTypeFunction %void
%float = OpTypeFloat 32
%9 = OpTypeFunction %void %float
%17 = OpTypeFunction %int %int
%int_42 = OpConstant %int 42
%unused_entry_point = OpFunction %void None %5
%8 = OpLabel
OpReturn
OpFunctionEnd
%foo = OpFunction %void None %9
%x = OpFunctionParameter %float
%13 = OpLabel
%15 = OpConvertFToS %int %x
OpSelectionMerge %14 None
OpSwitch %15 %16
%16 = OpLabel
OpBranch %14
%14 = OpLabel
OpReturn
OpFunctionEnd
%baz = OpFunction %int None %17
%x_0 = OpFunctionParameter %int
%20 = OpLabel
OpStore %global %int_42
OpReturnValue %x_0
OpFunctionEnd
%bar = OpFunction %void None %9
%x_1 = OpFunctionParameter %float
%24 = OpLabel
%27 = OpConvertFToS %int %x_1
%26 = OpFunctionCall %int %baz %27
OpSelectionMerge %25 None
OpSwitch %26 %28
%28 = OpLabel
OpBranch %25
%25 = OpLabel
OpReturn
OpFunctionEnd
%main = OpFunction %void None %5
%30 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,23 @@
fn foo(x : f32) {
switch(i32(x)) {
default: {
}
}
}
var<private> global : i32;
fn baz(x : i32) -> i32 {
global = 42;
return x;
}
fn bar(x : f32) {
switch(baz(i32(x))) {
default: {
}
}
}
fn main() {
}

View File

@ -2,7 +2,6 @@
void f() {
int i = 0;
int result = 0;
i;
do {
result = 44;
break;

View File

@ -2,7 +2,6 @@
void f() {
int i = 0;
int result = 0;
i;
do {
result = 44;
break;