tint: Fix DemoteToHelper for atomicCmpXchgWeak

We cannot explicitly name the result type of this builtin, so we have
to redeclare it manually.

Fixed: oss-fuzz:53347, oss-fuzz:53343
Change-Id: I23816b8b35eb20ae91472143ab30668b573d65bf
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110160
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: James Price <jrprice@google.com>
This commit is contained in:
James Price 2022-11-14 20:30:38 +00:00 committed by Dawn LUCI CQ
parent fbd00b44ee
commit 8cd34f8425
9 changed files with 445 additions and 5 deletions

View File

@ -14,6 +14,7 @@
#include "src/tint/transform/demote_to_helper.h"
#include <unordered_map>
#include <unordered_set>
#include <utility>
@ -24,6 +25,7 @@
#include "src/tint/sem/reference.h"
#include "src/tint/sem/statement.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
#include "src/tint/utils/map.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::DemoteToHelper);
@ -102,6 +104,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
// Mask all writes to host-visible memory using the discarded flag.
// We also insert a discard statement before all return statements in entry points for shaders
// that discard.
std::unordered_map<const sem::Type*, Symbol> atomic_cmpxchg_result_types;
for (auto* node : src->ASTNodes().Objects()) {
Switch(
node,
@ -174,11 +177,51 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&,
// }
// let y = x + tmp;
auto result = b.Sym();
auto result_decl =
b.Decl(b.Var(result, CreateASTTypeFor(ctx, sem_call->Type())));
auto* masked_call =
b.If(b.Not(flag),
b.Block(b.Assign(result, ctx.CloneWithoutTransform(call))));
const ast::Type* result_ty = nullptr;
const ast::Statement* masked_call = nullptr;
if (builtin->Type() == sem::BuiltinType::kAtomicCompareExchangeWeak) {
// Special case for atomicCompareExchangeWeak as we cannot name its
// result type. We have to declare an equivalent struct and copy the
// original member values over to it.
// Declare a struct to hold the result values.
auto* result_struct = sem_call->Type()->As<sem::Struct>();
auto* atomic_ty = result_struct->Members()[0]->Type();
result_ty = b.ty.type_name(
utils::GetOrCreate(atomic_cmpxchg_result_types, atomic_ty, [&]() {
auto name = b.Sym();
b.Structure(
name,
utils::Vector{
b.Member("old_value", CreateASTTypeFor(ctx, atomic_ty)),
b.Member("exchanged", b.ty.bool_()),
});
return name;
}));
// Generate the masked call and member-wise copy:
// if (!tint_discarded) {
// let tmp_result = atomicCompareExchangeWeak(&p, 1, 2);
// result.exchanged = tmp_result.exchanged;
// result.old_value = tmp_result.old_value;
// }
auto tmp_result = b.Sym();
masked_call =
b.If(b.Not(flag),
b.Block(utils::Vector{
b.Decl(b.Let(tmp_result, ctx.CloneWithoutTransform(call))),
b.Assign(b.MemberAccessor(result, "old_value"),
b.MemberAccessor(tmp_result, "old_value")),
b.Assign(b.MemberAccessor(result, "exchanged"),
b.MemberAccessor(tmp_result, "exchanged")),
}));
} else {
result_ty = CreateASTTypeFor(ctx, sem_call->Type());
masked_call =
b.If(b.Not(flag),
b.Block(b.Assign(result, ctx.CloneWithoutTransform(call))));
}
auto* result_decl = b.Decl(b.Var(result, result_ty));
hoist_to_decl_before.Prepare(sem_call);
hoist_to_decl_before.InsertBefore(stmt, result_decl);
hoist_to_decl_before.InsertBefore(stmt, masked_call);

View File

@ -1044,6 +1044,76 @@ fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) i3
EXPECT_EQ(expect, str(got));
}
TEST_F(DemoteToHelperTest, AtomicCompareExchangeWeak) {
auto* src = R"(
@group(0) @binding(0) var t : texture_2d<f32>;
@group(0) @binding(1) var s : sampler;
@group(0) @binding(2) var<storage, read_write> a : atomic<i32>;
@fragment
fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) i32 {
if (in == 0.0) {
discard;
}
var result = 0;
if (!atomicCompareExchangeWeak(&a, i32(in), 42).exchanged) {
let xchg = atomicCompareExchangeWeak(&a, i32(in), 42);
result = i32(textureSample(t, s, coord).x) * xchg.old_value;
}
return result;
}
)";
auto* expect = R"(
var<private> tint_discarded = false;
struct tint_symbol_1 {
old_value : i32,
exchanged : bool,
}
@group(0) @binding(0) var t : texture_2d<f32>;
@group(0) @binding(1) var s : sampler;
@group(0) @binding(2) var<storage, read_write> a : atomic<i32>;
@fragment
fn foo(@location(0) in : f32, @location(1) coord : vec2<f32>) -> @location(0) i32 {
if ((in == 0.0)) {
tint_discarded = true;
}
var result = 0;
var tint_symbol : tint_symbol_1;
if (!(tint_discarded)) {
let tint_symbol_2 = atomicCompareExchangeWeak(&(a), i32(in), 42);
tint_symbol.old_value = tint_symbol_2.old_value;
tint_symbol.exchanged = tint_symbol_2.exchanged;
}
if (!(tint_symbol.exchanged)) {
var tint_symbol_3 : tint_symbol_1;
if (!(tint_discarded)) {
let tint_symbol_4 = atomicCompareExchangeWeak(&(a), i32(in), 42);
tint_symbol_3.old_value = tint_symbol_4.old_value;
tint_symbol_3.exchanged = tint_symbol_4.exchanged;
}
let xchg = tint_symbol_3;
result = (i32(textureSample(t, s, coord).x) * xchg.old_value);
}
if (tint_discarded) {
discard;
}
return result;
}
)";
auto got = Run<DemoteToHelper>(src);
EXPECT_EQ(expect, str(got));
}
// Test that no masking is generated for calls to `atomicLoad()`.
TEST_F(DemoteToHelperTest, AtomicLoad) {
auto* src = R"(

View File

@ -0,0 +1,12 @@
@group(0) @binding(0) var<storage, read_write> a : atomic<i32>;
@fragment
fn foo() -> @location(0) i32 {
discard;
var x = 0;
let result = atomicCompareExchangeWeak(&a, 0, 1);
if (result.exchanged) {
x = result.old_value;
}
return x;
}

View File

@ -0,0 +1,50 @@
static bool tint_discarded = false;
struct tint_symbol_2 {
int old_value;
bool exchanged;
};
RWByteAddressBuffer a : register(u0, space0);
struct tint_symbol {
int value : SV_Target0;
};
struct atomic_compare_exchange_weak_ret_type {
int old_value;
bool exchanged;
};
atomic_compare_exchange_weak_ret_type tint_atomicCompareExchangeWeak(RWByteAddressBuffer buffer, uint offset, int compare, int value) {
atomic_compare_exchange_weak_ret_type result=(atomic_compare_exchange_weak_ret_type)0;
buffer.InterlockedCompareExchange(offset, compare, value, result.old_value);
result.exchanged = result.old_value == compare;
return result;
}
int foo_inner() {
tint_discarded = true;
int x = 0;
tint_symbol_2 tint_symbol_1 = (tint_symbol_2)0;
if (!(tint_discarded)) {
const atomic_compare_exchange_weak_ret_type tint_symbol_3 = tint_atomicCompareExchangeWeak(a, 0u, 0, 1);
tint_symbol_1.old_value = tint_symbol_3.old_value;
tint_symbol_1.exchanged = tint_symbol_3.exchanged;
}
const tint_symbol_2 result = tint_symbol_1;
if (result.exchanged) {
x = result.old_value;
}
return x;
}
tint_symbol foo() {
const int inner_result = foo_inner();
tint_symbol wrapper_result = (tint_symbol)0;
wrapper_result.value = inner_result;
if (tint_discarded) {
discard;
}
return wrapper_result;
}

View File

@ -0,0 +1,50 @@
static bool tint_discarded = false;
struct tint_symbol_2 {
int old_value;
bool exchanged;
};
RWByteAddressBuffer a : register(u0, space0);
struct tint_symbol {
int value : SV_Target0;
};
struct atomic_compare_exchange_weak_ret_type {
int old_value;
bool exchanged;
};
atomic_compare_exchange_weak_ret_type tint_atomicCompareExchangeWeak(RWByteAddressBuffer buffer, uint offset, int compare, int value) {
atomic_compare_exchange_weak_ret_type result=(atomic_compare_exchange_weak_ret_type)0;
buffer.InterlockedCompareExchange(offset, compare, value, result.old_value);
result.exchanged = result.old_value == compare;
return result;
}
int foo_inner() {
tint_discarded = true;
int x = 0;
tint_symbol_2 tint_symbol_1 = (tint_symbol_2)0;
if (!(tint_discarded)) {
const atomic_compare_exchange_weak_ret_type tint_symbol_3 = tint_atomicCompareExchangeWeak(a, 0u, 0, 1);
tint_symbol_1.old_value = tint_symbol_3.old_value;
tint_symbol_1.exchanged = tint_symbol_3.exchanged;
}
const tint_symbol_2 result = tint_symbol_1;
if (result.exchanged) {
x = result.old_value;
}
return x;
}
tint_symbol foo() {
const int inner_result = foo_inner();
tint_symbol wrapper_result = (tint_symbol)0;
wrapper_result.value = inner_result;
if (tint_discarded) {
discard;
}
return wrapper_result;
}

View File

@ -0,0 +1,47 @@
#version 310 es
precision mediump float;
struct atomic_compare_exchange_resulti32 {
int old_value;
bool exchanged;
};
bool tint_discarded = false;
struct tint_symbol_1 {
int old_value;
bool exchanged;
};
layout(location = 0) out int value;
layout(binding = 0, std430) buffer a_block_ssbo {
int inner;
} a;
int foo() {
tint_discarded = true;
int x = 0;
tint_symbol_1 tint_symbol = tint_symbol_1(0, false);
if (!(tint_discarded)) {
atomic_compare_exchange_resulti32 atomic_compare_result;
atomic_compare_result.old_value = atomicCompSwap(a.inner, 0, 1);
atomic_compare_result.exchanged = atomic_compare_result.old_value == 0;
atomic_compare_exchange_resulti32 tint_symbol_2 = atomic_compare_result;
tint_symbol.old_value = tint_symbol_2.old_value;
tint_symbol.exchanged = tint_symbol_2.exchanged;
}
tint_symbol_1 result = tint_symbol;
if (result.exchanged) {
x = result.old_value;
}
return x;
}
void main() {
int inner_result = foo();
value = inner_result;
if (tint_discarded) {
discard;
}
return;
}

View File

@ -0,0 +1,50 @@
#include <metal_stdlib>
using namespace metal;
struct atomic_compare_exchange_resulti32 {
int old_value;
bool exchanged;
};
atomic_compare_exchange_resulti32 atomicCompareExchangeWeak_1(device atomic_int* atomic, int compare, int value) {
int old_value = compare;
bool exchanged = atomic_compare_exchange_weak_explicit(atomic, &old_value, value, memory_order_relaxed, memory_order_relaxed);
return {old_value, exchanged};
}
struct tint_symbol_2 {
int old_value;
bool exchanged;
};
struct tint_symbol {
int value [[color(0)]];
};
int foo_inner(thread bool* const tint_symbol_4, device atomic_int* const tint_symbol_5) {
*(tint_symbol_4) = true;
int x = 0;
tint_symbol_2 tint_symbol_1 = {};
if (!(*(tint_symbol_4))) {
atomic_compare_exchange_resulti32 const tint_symbol_3 = atomicCompareExchangeWeak_1(tint_symbol_5, 0, 1);
tint_symbol_1.old_value = tint_symbol_3.old_value;
tint_symbol_1.exchanged = tint_symbol_3.exchanged;
}
tint_symbol_2 const result = tint_symbol_1;
if (result.exchanged) {
x = result.old_value;
}
return x;
}
fragment tint_symbol foo(device atomic_int* tint_symbol_7 [[buffer(0)]]) {
thread bool tint_symbol_6 = false;
int const inner_result = foo_inner(&(tint_symbol_6), tint_symbol_7);
tint_symbol wrapper_result = {};
wrapper_result.value = inner_result;
if (tint_symbol_6) {
discard_fragment();
}
return wrapper_result;
}

View File

@ -0,0 +1,106 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 0
; Bound: 56
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %foo "foo" %value
OpExecutionMode %foo OriginUpperLeft
OpName %tint_discarded "tint_discarded"
OpName %value "value"
OpName %a_block "a_block"
OpMemberName %a_block 0 "inner"
OpName %a "a"
OpName %foo_inner "foo_inner"
OpName %x "x"
OpName %tint_symbol_1 "tint_symbol_1"
OpMemberName %tint_symbol_1 0 "old_value"
OpMemberName %tint_symbol_1 1 "exchanged"
OpName %tint_symbol "tint_symbol"
OpName %__atomic_compare_exchange_resulti32 "__atomic_compare_exchange_resulti32"
OpMemberName %__atomic_compare_exchange_resulti32 0 "old_value"
OpMemberName %__atomic_compare_exchange_resulti32 1 "exchanged"
OpName %foo "foo"
OpDecorate %value Location 0
OpDecorate %a_block Block
OpMemberDecorate %a_block 0 Offset 0
OpDecorate %a DescriptorSet 0
OpDecorate %a Binding 0
OpMemberDecorate %tint_symbol_1 0 Offset 0
OpMemberDecorate %tint_symbol_1 1 Offset 4
OpMemberDecorate %__atomic_compare_exchange_resulti32 0 Offset 0
OpMemberDecorate %__atomic_compare_exchange_resulti32 1 Offset 4
%bool = OpTypeBool
%2 = OpConstantNull %bool
%_ptr_Private_bool = OpTypePointer Private %bool
%tint_discarded = OpVariable %_ptr_Private_bool Private %2
%int = OpTypeInt 32 1
%_ptr_Output_int = OpTypePointer Output %int
%8 = OpConstantNull %int
%value = OpVariable %_ptr_Output_int Output %8
%a_block = OpTypeStruct %int
%_ptr_StorageBuffer_a_block = OpTypePointer StorageBuffer %a_block
%a = OpVariable %_ptr_StorageBuffer_a_block StorageBuffer
%12 = OpTypeFunction %int
%true = OpConstantTrue %bool
%_ptr_Function_int = OpTypePointer Function %int
%tint_symbol_1 = OpTypeStruct %int %bool
%_ptr_Function_tint_symbol_1 = OpTypePointer Function %tint_symbol_1
%21 = OpConstantNull %tint_symbol_1
%__atomic_compare_exchange_resulti32 = OpTypeStruct %int %bool
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%uint_0 = OpConstant %uint 0
%_ptr_StorageBuffer_int = OpTypePointer StorageBuffer %int
%int_1 = OpConstant %int 1
%_ptr_Function_bool = OpTypePointer Function %bool
%void = OpTypeVoid
%48 = OpTypeFunction %void
%foo_inner = OpFunction %int None %12
%14 = OpLabel
%x = OpVariable %_ptr_Function_int Function %8
%tint_symbol = OpVariable %_ptr_Function_tint_symbol_1 Function %21
OpStore %tint_discarded %true
OpStore %x %8
%23 = OpLoad %bool %tint_discarded
%22 = OpLogicalNot %bool %23
OpSelectionMerge %24 None
OpBranchConditional %22 %25 %24
%25 = OpLabel
%33 = OpAccessChain %_ptr_StorageBuffer_int %a %uint_0
%35 = OpAtomicCompareExchange %int %33 %uint_1 %uint_0 %uint_0 %int_1 %8
%36 = OpIEqual %bool %35 %8
%26 = OpCompositeConstruct %__atomic_compare_exchange_resulti32 %35 %36
%37 = OpAccessChain %_ptr_Function_int %tint_symbol %uint_0
%38 = OpCompositeExtract %int %26 0
OpStore %37 %38
%40 = OpAccessChain %_ptr_Function_bool %tint_symbol %uint_1
%41 = OpCompositeExtract %bool %26 1
OpStore %40 %41
OpBranch %24
%24 = OpLabel
%42 = OpLoad %tint_symbol_1 %tint_symbol
%43 = OpCompositeExtract %bool %42 1
OpSelectionMerge %44 None
OpBranchConditional %43 %45 %44
%45 = OpLabel
%46 = OpCompositeExtract %int %42 0
OpStore %x %46
OpBranch %44
%44 = OpLabel
%47 = OpLoad %int %x
OpReturnValue %47
OpFunctionEnd
%foo = OpFunction %void None %48
%51 = OpLabel
%52 = OpFunctionCall %int %foo_inner
OpStore %value %52
%53 = OpLoad %bool %tint_discarded
OpSelectionMerge %54 None
OpBranchConditional %53 %55 %54
%55 = OpLabel
OpKill
%54 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -0,0 +1,12 @@
@group(0) @binding(0) var<storage, read_write> a : atomic<i32>;
@fragment
fn foo() -> @location(0) i32 {
discard;
var x = 0;
let result = atomicCompareExchangeWeak(&(a), 0, 1);
if (result.exchanged) {
x = result.old_value;
}
return x;
}