diff --git a/src/tint/transform/demote_to_helper.cc b/src/tint/transform/demote_to_helper.cc index f6d2388540..cb35c67632 100644 --- a/src/tint/transform/demote_to_helper.cc +++ b/src/tint/transform/demote_to_helper.cc @@ -14,6 +14,7 @@ #include "src/tint/transform/demote_to_helper.h" +#include #include #include @@ -24,6 +25,7 @@ #include "src/tint/sem/reference.h" #include "src/tint/sem/statement.h" #include "src/tint/transform/utils/hoist_to_decl_before.h" +#include "src/tint/utils/map.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::DemoteToHelper); @@ -102,6 +104,7 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&, // Mask all writes to host-visible memory using the discarded flag. // We also insert a discard statement before all return statements in entry points for shaders // that discard. + std::unordered_map atomic_cmpxchg_result_types; for (auto* node : src->ASTNodes().Objects()) { Switch( node, @@ -174,11 +177,51 @@ Transform::ApplyResult DemoteToHelper::Apply(const Program* src, const DataMap&, // } // let y = x + tmp; auto result = b.Sym(); - auto result_decl = - b.Decl(b.Var(result, CreateASTTypeFor(ctx, sem_call->Type()))); - auto* masked_call = - b.If(b.Not(flag), - b.Block(b.Assign(result, ctx.CloneWithoutTransform(call)))); + const ast::Type* result_ty = nullptr; + const ast::Statement* masked_call = nullptr; + if (builtin->Type() == sem::BuiltinType::kAtomicCompareExchangeWeak) { + // Special case for atomicCompareExchangeWeak as we cannot name its + // result type. We have to declare an equivalent struct and copy the + // original member values over to it. + + // Declare a struct to hold the result values. + auto* result_struct = sem_call->Type()->As(); + auto* atomic_ty = result_struct->Members()[0]->Type(); + result_ty = b.ty.type_name( + utils::GetOrCreate(atomic_cmpxchg_result_types, atomic_ty, [&]() { + auto name = b.Sym(); + b.Structure( + name, + utils::Vector{ + b.Member("old_value", CreateASTTypeFor(ctx, atomic_ty)), + b.Member("exchanged", b.ty.bool_()), + }); + return name; + })); + + // Generate the masked call and member-wise copy: + // if (!tint_discarded) { + // let tmp_result = atomicCompareExchangeWeak(&p, 1, 2); + // result.exchanged = tmp_result.exchanged; + // result.old_value = tmp_result.old_value; + // } + auto tmp_result = b.Sym(); + masked_call = + b.If(b.Not(flag), + b.Block(utils::Vector{ + b.Decl(b.Let(tmp_result, ctx.CloneWithoutTransform(call))), + b.Assign(b.MemberAccessor(result, "old_value"), + b.MemberAccessor(tmp_result, "old_value")), + b.Assign(b.MemberAccessor(result, "exchanged"), + b.MemberAccessor(tmp_result, "exchanged")), + })); + } else { + result_ty = CreateASTTypeFor(ctx, sem_call->Type()); + masked_call = + b.If(b.Not(flag), + b.Block(b.Assign(result, ctx.CloneWithoutTransform(call)))); + } + auto* result_decl = b.Decl(b.Var(result, result_ty)); hoist_to_decl_before.Prepare(sem_call); hoist_to_decl_before.InsertBefore(stmt, result_decl); hoist_to_decl_before.InsertBefore(stmt, masked_call); diff --git a/src/tint/transform/demote_to_helper_test.cc b/src/tint/transform/demote_to_helper_test.cc index 59754aa589..943da0df8c 100644 --- a/src/tint/transform/demote_to_helper_test.cc +++ b/src/tint/transform/demote_to_helper_test.cc @@ -1044,6 +1044,76 @@ fn foo(@location(0) in : f32, @location(1) coord : vec2) -> @location(0) i3 EXPECT_EQ(expect, str(got)); } +TEST_F(DemoteToHelperTest, AtomicCompareExchangeWeak) { + auto* src = R"( +@group(0) @binding(0) var t : texture_2d; + +@group(0) @binding(1) var s : sampler; + +@group(0) @binding(2) var a : atomic; + +@fragment +fn foo(@location(0) in : f32, @location(1) coord : vec2) -> @location(0) i32 { + if (in == 0.0) { + discard; + } + var result = 0; + if (!atomicCompareExchangeWeak(&a, i32(in), 42).exchanged) { + let xchg = atomicCompareExchangeWeak(&a, i32(in), 42); + result = i32(textureSample(t, s, coord).x) * xchg.old_value; + } + return result; +} +)"; + + auto* expect = R"( +var tint_discarded = false; + +struct tint_symbol_1 { + old_value : i32, + exchanged : bool, +} + +@group(0) @binding(0) var t : texture_2d; + +@group(0) @binding(1) var s : sampler; + +@group(0) @binding(2) var a : atomic; + +@fragment +fn foo(@location(0) in : f32, @location(1) coord : vec2) -> @location(0) i32 { + if ((in == 0.0)) { + tint_discarded = true; + } + var result = 0; + var tint_symbol : tint_symbol_1; + if (!(tint_discarded)) { + let tint_symbol_2 = atomicCompareExchangeWeak(&(a), i32(in), 42); + tint_symbol.old_value = tint_symbol_2.old_value; + tint_symbol.exchanged = tint_symbol_2.exchanged; + } + if (!(tint_symbol.exchanged)) { + var tint_symbol_3 : tint_symbol_1; + if (!(tint_discarded)) { + let tint_symbol_4 = atomicCompareExchangeWeak(&(a), i32(in), 42); + tint_symbol_3.old_value = tint_symbol_4.old_value; + tint_symbol_3.exchanged = tint_symbol_4.exchanged; + } + let xchg = tint_symbol_3; + result = (i32(textureSample(t, s, coord).x) * xchg.old_value); + } + if (tint_discarded) { + discard; + } + return result; +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + // Test that no masking is generated for calls to `atomicLoad()`. TEST_F(DemoteToHelperTest, AtomicLoad) { auto* src = R"( diff --git a/test/tint/statements/discard/atomic_cmpxchg.wgsl b/test/tint/statements/discard/atomic_cmpxchg.wgsl new file mode 100644 index 0000000000..f87ebac232 --- /dev/null +++ b/test/tint/statements/discard/atomic_cmpxchg.wgsl @@ -0,0 +1,12 @@ +@group(0) @binding(0) var a : atomic; + +@fragment +fn foo() -> @location(0) i32 { + discard; + var x = 0; + let result = atomicCompareExchangeWeak(&a, 0, 1); + if (result.exchanged) { + x = result.old_value; + } + return x; +} diff --git a/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.dxc.hlsl b/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.dxc.hlsl new file mode 100644 index 0000000000..ed03b3f20e --- /dev/null +++ b/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.dxc.hlsl @@ -0,0 +1,50 @@ +static bool tint_discarded = false; + +struct tint_symbol_2 { + int old_value; + bool exchanged; +}; + +RWByteAddressBuffer a : register(u0, space0); + +struct tint_symbol { + int value : SV_Target0; +}; +struct atomic_compare_exchange_weak_ret_type { + int old_value; + bool exchanged; +}; + +atomic_compare_exchange_weak_ret_type tint_atomicCompareExchangeWeak(RWByteAddressBuffer buffer, uint offset, int compare, int value) { + atomic_compare_exchange_weak_ret_type result=(atomic_compare_exchange_weak_ret_type)0; + buffer.InterlockedCompareExchange(offset, compare, value, result.old_value); + result.exchanged = result.old_value == compare; + return result; +} + + +int foo_inner() { + tint_discarded = true; + int x = 0; + tint_symbol_2 tint_symbol_1 = (tint_symbol_2)0; + if (!(tint_discarded)) { + const atomic_compare_exchange_weak_ret_type tint_symbol_3 = tint_atomicCompareExchangeWeak(a, 0u, 0, 1); + tint_symbol_1.old_value = tint_symbol_3.old_value; + tint_symbol_1.exchanged = tint_symbol_3.exchanged; + } + const tint_symbol_2 result = tint_symbol_1; + if (result.exchanged) { + x = result.old_value; + } + return x; +} + +tint_symbol foo() { + const int inner_result = foo_inner(); + tint_symbol wrapper_result = (tint_symbol)0; + wrapper_result.value = inner_result; + if (tint_discarded) { + discard; + } + return wrapper_result; +} diff --git a/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.fxc.hlsl b/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.fxc.hlsl new file mode 100644 index 0000000000..ed03b3f20e --- /dev/null +++ b/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.fxc.hlsl @@ -0,0 +1,50 @@ +static bool tint_discarded = false; + +struct tint_symbol_2 { + int old_value; + bool exchanged; +}; + +RWByteAddressBuffer a : register(u0, space0); + +struct tint_symbol { + int value : SV_Target0; +}; +struct atomic_compare_exchange_weak_ret_type { + int old_value; + bool exchanged; +}; + +atomic_compare_exchange_weak_ret_type tint_atomicCompareExchangeWeak(RWByteAddressBuffer buffer, uint offset, int compare, int value) { + atomic_compare_exchange_weak_ret_type result=(atomic_compare_exchange_weak_ret_type)0; + buffer.InterlockedCompareExchange(offset, compare, value, result.old_value); + result.exchanged = result.old_value == compare; + return result; +} + + +int foo_inner() { + tint_discarded = true; + int x = 0; + tint_symbol_2 tint_symbol_1 = (tint_symbol_2)0; + if (!(tint_discarded)) { + const atomic_compare_exchange_weak_ret_type tint_symbol_3 = tint_atomicCompareExchangeWeak(a, 0u, 0, 1); + tint_symbol_1.old_value = tint_symbol_3.old_value; + tint_symbol_1.exchanged = tint_symbol_3.exchanged; + } + const tint_symbol_2 result = tint_symbol_1; + if (result.exchanged) { + x = result.old_value; + } + return x; +} + +tint_symbol foo() { + const int inner_result = foo_inner(); + tint_symbol wrapper_result = (tint_symbol)0; + wrapper_result.value = inner_result; + if (tint_discarded) { + discard; + } + return wrapper_result; +} diff --git a/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.glsl b/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.glsl new file mode 100644 index 0000000000..a66bdd9f4b --- /dev/null +++ b/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.glsl @@ -0,0 +1,47 @@ +#version 310 es +precision mediump float; + +struct atomic_compare_exchange_resulti32 { + int old_value; + bool exchanged; +}; + + +bool tint_discarded = false; +struct tint_symbol_1 { + int old_value; + bool exchanged; +}; + +layout(location = 0) out int value; +layout(binding = 0, std430) buffer a_block_ssbo { + int inner; +} a; + +int foo() { + tint_discarded = true; + int x = 0; + tint_symbol_1 tint_symbol = tint_symbol_1(0, false); + if (!(tint_discarded)) { + atomic_compare_exchange_resulti32 atomic_compare_result; + atomic_compare_result.old_value = atomicCompSwap(a.inner, 0, 1); + atomic_compare_result.exchanged = atomic_compare_result.old_value == 0; + atomic_compare_exchange_resulti32 tint_symbol_2 = atomic_compare_result; + tint_symbol.old_value = tint_symbol_2.old_value; + tint_symbol.exchanged = tint_symbol_2.exchanged; + } + tint_symbol_1 result = tint_symbol; + if (result.exchanged) { + x = result.old_value; + } + return x; +} + +void main() { + int inner_result = foo(); + value = inner_result; + if (tint_discarded) { + discard; + } + return; +} diff --git a/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.msl b/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.msl new file mode 100644 index 0000000000..cfdeaaffab --- /dev/null +++ b/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.msl @@ -0,0 +1,50 @@ +#include + +using namespace metal; + +struct atomic_compare_exchange_resulti32 { + int old_value; + bool exchanged; +}; +atomic_compare_exchange_resulti32 atomicCompareExchangeWeak_1(device atomic_int* atomic, int compare, int value) { + int old_value = compare; + bool exchanged = atomic_compare_exchange_weak_explicit(atomic, &old_value, value, memory_order_relaxed, memory_order_relaxed); + return {old_value, exchanged}; +} + +struct tint_symbol_2 { + int old_value; + bool exchanged; +}; + +struct tint_symbol { + int value [[color(0)]]; +}; + +int foo_inner(thread bool* const tint_symbol_4, device atomic_int* const tint_symbol_5) { + *(tint_symbol_4) = true; + int x = 0; + tint_symbol_2 tint_symbol_1 = {}; + if (!(*(tint_symbol_4))) { + atomic_compare_exchange_resulti32 const tint_symbol_3 = atomicCompareExchangeWeak_1(tint_symbol_5, 0, 1); + tint_symbol_1.old_value = tint_symbol_3.old_value; + tint_symbol_1.exchanged = tint_symbol_3.exchanged; + } + tint_symbol_2 const result = tint_symbol_1; + if (result.exchanged) { + x = result.old_value; + } + return x; +} + +fragment tint_symbol foo(device atomic_int* tint_symbol_7 [[buffer(0)]]) { + thread bool tint_symbol_6 = false; + int const inner_result = foo_inner(&(tint_symbol_6), tint_symbol_7); + tint_symbol wrapper_result = {}; + wrapper_result.value = inner_result; + if (tint_symbol_6) { + discard_fragment(); + } + return wrapper_result; +} + diff --git a/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.spvasm b/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.spvasm new file mode 100644 index 0000000000..b210415256 --- /dev/null +++ b/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.spvasm @@ -0,0 +1,106 @@ +; SPIR-V +; Version: 1.3 +; Generator: Google Tint Compiler; 0 +; Bound: 56 +; Schema: 0 + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %foo "foo" %value + OpExecutionMode %foo OriginUpperLeft + OpName %tint_discarded "tint_discarded" + OpName %value "value" + OpName %a_block "a_block" + OpMemberName %a_block 0 "inner" + OpName %a "a" + OpName %foo_inner "foo_inner" + OpName %x "x" + OpName %tint_symbol_1 "tint_symbol_1" + OpMemberName %tint_symbol_1 0 "old_value" + OpMemberName %tint_symbol_1 1 "exchanged" + OpName %tint_symbol "tint_symbol" + OpName %__atomic_compare_exchange_resulti32 "__atomic_compare_exchange_resulti32" + OpMemberName %__atomic_compare_exchange_resulti32 0 "old_value" + OpMemberName %__atomic_compare_exchange_resulti32 1 "exchanged" + OpName %foo "foo" + OpDecorate %value Location 0 + OpDecorate %a_block Block + OpMemberDecorate %a_block 0 Offset 0 + OpDecorate %a DescriptorSet 0 + OpDecorate %a Binding 0 + OpMemberDecorate %tint_symbol_1 0 Offset 0 + OpMemberDecorate %tint_symbol_1 1 Offset 4 + OpMemberDecorate %__atomic_compare_exchange_resulti32 0 Offset 0 + OpMemberDecorate %__atomic_compare_exchange_resulti32 1 Offset 4 + %bool = OpTypeBool + %2 = OpConstantNull %bool +%_ptr_Private_bool = OpTypePointer Private %bool +%tint_discarded = OpVariable %_ptr_Private_bool Private %2 + %int = OpTypeInt 32 1 +%_ptr_Output_int = OpTypePointer Output %int + %8 = OpConstantNull %int + %value = OpVariable %_ptr_Output_int Output %8 + %a_block = OpTypeStruct %int +%_ptr_StorageBuffer_a_block = OpTypePointer StorageBuffer %a_block + %a = OpVariable %_ptr_StorageBuffer_a_block StorageBuffer + %12 = OpTypeFunction %int + %true = OpConstantTrue %bool +%_ptr_Function_int = OpTypePointer Function %int +%tint_symbol_1 = OpTypeStruct %int %bool +%_ptr_Function_tint_symbol_1 = OpTypePointer Function %tint_symbol_1 + %21 = OpConstantNull %tint_symbol_1 +%__atomic_compare_exchange_resulti32 = OpTypeStruct %int %bool + %uint = OpTypeInt 32 0 + %uint_1 = OpConstant %uint 1 + %uint_0 = OpConstant %uint 0 +%_ptr_StorageBuffer_int = OpTypePointer StorageBuffer %int + %int_1 = OpConstant %int 1 +%_ptr_Function_bool = OpTypePointer Function %bool + %void = OpTypeVoid + %48 = OpTypeFunction %void + %foo_inner = OpFunction %int None %12 + %14 = OpLabel + %x = OpVariable %_ptr_Function_int Function %8 +%tint_symbol = OpVariable %_ptr_Function_tint_symbol_1 Function %21 + OpStore %tint_discarded %true + OpStore %x %8 + %23 = OpLoad %bool %tint_discarded + %22 = OpLogicalNot %bool %23 + OpSelectionMerge %24 None + OpBranchConditional %22 %25 %24 + %25 = OpLabel + %33 = OpAccessChain %_ptr_StorageBuffer_int %a %uint_0 + %35 = OpAtomicCompareExchange %int %33 %uint_1 %uint_0 %uint_0 %int_1 %8 + %36 = OpIEqual %bool %35 %8 + %26 = OpCompositeConstruct %__atomic_compare_exchange_resulti32 %35 %36 + %37 = OpAccessChain %_ptr_Function_int %tint_symbol %uint_0 + %38 = OpCompositeExtract %int %26 0 + OpStore %37 %38 + %40 = OpAccessChain %_ptr_Function_bool %tint_symbol %uint_1 + %41 = OpCompositeExtract %bool %26 1 + OpStore %40 %41 + OpBranch %24 + %24 = OpLabel + %42 = OpLoad %tint_symbol_1 %tint_symbol + %43 = OpCompositeExtract %bool %42 1 + OpSelectionMerge %44 None + OpBranchConditional %43 %45 %44 + %45 = OpLabel + %46 = OpCompositeExtract %int %42 0 + OpStore %x %46 + OpBranch %44 + %44 = OpLabel + %47 = OpLoad %int %x + OpReturnValue %47 + OpFunctionEnd + %foo = OpFunction %void None %48 + %51 = OpLabel + %52 = OpFunctionCall %int %foo_inner + OpStore %value %52 + %53 = OpLoad %bool %tint_discarded + OpSelectionMerge %54 None + OpBranchConditional %53 %55 %54 + %55 = OpLabel + OpKill + %54 = OpLabel + OpReturn + OpFunctionEnd diff --git a/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.wgsl b/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.wgsl new file mode 100644 index 0000000000..b430e288bd --- /dev/null +++ b/test/tint/statements/discard/atomic_cmpxchg.wgsl.expected.wgsl @@ -0,0 +1,12 @@ +@group(0) @binding(0) var a : atomic; + +@fragment +fn foo() -> @location(0) i32 { + discard; + var x = 0; + let result = atomicCompareExchangeWeak(&(a), 0, 1); + if (result.exchanged) { + x = result.old_value; + } + return x; +}