writer/msl: Emit helper functions for atomicCompareExchangeWeak

By generating a helper function for these, we can keep the atomic expression pre-statement-free. This can help prevent for-loops from being transformed into while loops.

Change-Id: Id034ea5ea9be601661ddb78db973015d845c420f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57463
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2021-07-09 20:21:49 +00:00 committed by Tint LUCI CQ
parent 659bcbeacf
commit 9569e2c790
7 changed files with 154 additions and 143 deletions

View File

@ -55,6 +55,7 @@
#include "src/sem/void_type.h" #include "src/sem/void_type.h"
#include "src/transform/msl.h" #include "src/transform/msl.h"
#include "src/utils/defer.h" #include "src/utils/defer.h"
#include "src/utils/get_or_create.h"
#include "src/utils/scoped_assignment.h" #include "src/utils/scoped_assignment.h"
#include "src/writer/float_to_string.h" #include "src/writer/float_to_string.h"
@ -91,6 +92,8 @@ bool GeneratorImpl::Generate() {
line(); line();
line() << "using namespace metal;"; line() << "using namespace metal;";
auto helpers_insertion_point = current_buffer_->lines.size();
for (auto* const type_decl : program_->AST().TypeDecls()) { for (auto* const type_decl : program_->AST().TypeDecls()) {
if (!type_decl->Is<ast::Alias>()) { if (!type_decl->Is<ast::Alias>()) {
if (!EmitTypeDecl(TypeOf(type_decl))) { if (!EmitTypeDecl(TypeOf(type_decl))) {
@ -137,6 +140,11 @@ bool GeneratorImpl::Generate() {
line(); line();
} }
if (!helpers_.lines.empty()) {
current_buffer_->Insert("", helpers_insertion_point++, 0);
current_buffer_->Insert(helpers_, helpers_insertion_point++, 0);
}
return true; return true;
} }
@ -454,7 +462,7 @@ bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
bool GeneratorImpl::EmitAtomicCall(std::ostream& out, bool GeneratorImpl::EmitAtomicCall(std::ostream& out,
ast::CallExpression* expr, ast::CallExpression* expr,
const sem::Intrinsic* intrinsic) { const sem::Intrinsic* intrinsic) {
auto call = [&](const char* name) { auto call = [&](const std::string& name, bool append_memory_order_relaxed) {
out << name; out << name;
{ {
ScopedParen sp(out); ScopedParen sp(out);
@ -467,84 +475,77 @@ bool GeneratorImpl::EmitAtomicCall(std::ostream& out,
return false; return false;
} }
} }
out << ", memory_order_relaxed"; if (append_memory_order_relaxed) {
out << ", memory_order_relaxed";
}
} }
return true; return true;
}; };
switch (intrinsic->Type()) { switch (intrinsic->Type()) {
case sem::IntrinsicType::kAtomicLoad: case sem::IntrinsicType::kAtomicLoad:
return call("atomic_load_explicit"); return call("atomic_load_explicit", true);
case sem::IntrinsicType::kAtomicStore: case sem::IntrinsicType::kAtomicStore:
return call("atomic_store_explicit"); return call("atomic_store_explicit", true);
case sem::IntrinsicType::kAtomicAdd: case sem::IntrinsicType::kAtomicAdd:
return call("atomic_fetch_add_explicit"); return call("atomic_fetch_add_explicit", true);
case sem::IntrinsicType::kAtomicMax: case sem::IntrinsicType::kAtomicMax:
return call("atomic_fetch_max_explicit"); return call("atomic_fetch_max_explicit", true);
case sem::IntrinsicType::kAtomicMin: case sem::IntrinsicType::kAtomicMin:
return call("atomic_fetch_min_explicit"); return call("atomic_fetch_min_explicit", true);
case sem::IntrinsicType::kAtomicAnd: case sem::IntrinsicType::kAtomicAnd:
return call("atomic_fetch_and_explicit"); return call("atomic_fetch_and_explicit", true);
case sem::IntrinsicType::kAtomicOr: case sem::IntrinsicType::kAtomicOr:
return call("atomic_fetch_or_explicit"); return call("atomic_fetch_or_explicit", true);
case sem::IntrinsicType::kAtomicXor: case sem::IntrinsicType::kAtomicXor:
return call("atomic_fetch_xor_explicit"); return call("atomic_fetch_xor_explicit", true);
case sem::IntrinsicType::kAtomicExchange: case sem::IntrinsicType::kAtomicExchange:
return call("atomic_exchange_explicit"); return call("atomic_exchange_explicit", true);
case sem::IntrinsicType::kAtomicCompareExchangeWeak: { case sem::IntrinsicType::kAtomicCompareExchangeWeak: {
auto* target = expr->params()[0]; auto* ptr_ty = TypeOf(expr->params()[0])->UnwrapRef()->As<sem::Pointer>();
auto* compare_value = expr->params()[1]; auto sc = ptr_ty->StorageClass();
auto* value = expr->params()[2];
auto prev_value = UniqueIdentifier("prev_value"); auto func = utils::GetOrCreate(
auto matched = UniqueIdentifier("matched"); atomicCompareExchangeWeak_, sc, [&]() -> std::string {
auto name = UniqueIdentifier("atomicCompareExchangeWeak");
auto& buf = helpers_;
{ // prev_value = <compare_value>; line(&buf) << "template <typename A, typename T>";
auto pre = line(); {
if (!EmitType(pre, TypeOf(value), "")) { auto f = line(&buf);
return false; f << "vec<T, 2> " << name << "(";
} if (!EmitStorageClass(f, sc)) {
pre << " " << prev_value << " = "; return "";
if (!EmitExpression(pre, compare_value)) { }
return false; f << " A* atomic, T compare, T value) {";
} }
pre << ";";
}
{ // bool matched = atomic_compare_exchange_weak_explicit( buf.IncrementIndent();
// target, &got, <value>, memory_order_relaxed, memory_order_relaxed) TINT_DEFER({
auto pre = line(); buf.DecrementIndent();
pre << "bool " << matched << " = atomic_compare_exchange_weak_explicit"; line(&buf) << "}";
{ line(&buf);
ScopedParen sp(pre); });
if (!EmitExpression(pre, target)) {
return false;
}
pre << ", &" << prev_value << ", ";
if (!EmitExpression(pre, value)) {
return false;
}
pre << ", memory_order_relaxed, memory_order_relaxed";
}
pre << ";";
}
{ // [u]int2(got, matched) line(&buf) << "T prev_value = compare;";
if (!EmitType(out, TypeOf(expr), "")) { line(&buf) << "bool matched = "
return false; "atomic_compare_exchange_weak_explicit(atomic, "
} "&prev_value, value, memory_order_relaxed, "
out << "(" << prev_value << ", " << matched << ")"; "memory_order_relaxed);";
} line(&buf) << "return {prev_value, matched};";
return true; return name;
});
return call(func, false);
} }
default: default:
@ -1867,24 +1868,10 @@ bool GeneratorImpl::EmitType(std::ostream& out,
} }
if (auto* ptr = type->As<sem::Pointer>()) { if (auto* ptr = type->As<sem::Pointer>()) {
switch (ptr->StorageClass()) { if (!EmitStorageClass(out, ptr->StorageClass())) {
case ast::StorageClass::kFunction: return false;
case ast::StorageClass::kPrivate:
case ast::StorageClass::kUniformConstant:
out << "thread ";
break;
case ast::StorageClass::kWorkgroup:
out << "threadgroup ";
break;
case ast::StorageClass::kStorage:
out << "device ";
break;
case ast::StorageClass::kUniform:
out << "constant ";
break;
default:
TINT_ICE(Writer, diagnostics_) << "unhandled storage class for pointer";
} }
out << " ";
if (ptr->StoreType()->Is<sem::Array>()) { if (ptr->StoreType()->Is<sem::Array>()) {
std::string inner = "(*" + name + ")"; std::string inner = "(*" + name + ")";
if (!EmitType(out, ptr->StoreType(), inner)) { if (!EmitType(out, ptr->StoreType(), inner)) {
@ -2004,6 +1991,29 @@ bool GeneratorImpl::EmitType(std::ostream& out,
return false; return false;
} }
bool GeneratorImpl::EmitStorageClass(std::ostream& out, ast::StorageClass sc) {
switch (sc) {
case ast::StorageClass::kFunction:
case ast::StorageClass::kPrivate:
case ast::StorageClass::kUniformConstant:
out << "thread";
return true;
case ast::StorageClass::kWorkgroup:
out << "threadgroup";
return true;
case ast::StorageClass::kStorage:
out << "device";
return true;
case ast::StorageClass::kUniform:
out << "constant";
return true;
default:
break;
}
TINT_ICE(Writer, diagnostics_) << "unhandled storage class: " << sc;
return false;
}
bool GeneratorImpl::EmitPackedType(std::ostream& out, bool GeneratorImpl::EmitPackedType(std::ostream& out,
const sem::Type* type, const sem::Type* type,
const std::string& name) { const std::string& name) {

View File

@ -16,6 +16,7 @@
#define SRC_WRITER_MSL_GENERATOR_IMPL_H_ #define SRC_WRITER_MSL_GENERATOR_IMPL_H_
#include <string> #include <string>
#include <unordered_map>
#include "src/ast/array_accessor_expression.h" #include "src/ast/array_accessor_expression.h"
#include "src/ast/assignment_statement.h" #include "src/ast/assignment_statement.h"
@ -218,6 +219,11 @@ class GeneratorImpl : public TextGenerator {
bool EmitType(std::ostream& out, bool EmitType(std::ostream& out,
const sem::Type* type, const sem::Type* type,
const std::string& name); const std::string& name);
/// Handles generating a storage class
/// @param out the output of the type stream
/// @param sc the storage class to generate
/// @returns true if the storage class is emitted
bool EmitStorageClass(std::ostream& out, ast::StorageClass sc);
/// Handles generating an MSL-packed storage type. /// Handles generating an MSL-packed storage type.
/// If the type does not have a packed form, the standard non-packed form is /// If the type does not have a packed form, the standard non-packed form is
/// emitted. /// emitted.
@ -282,11 +288,20 @@ class GeneratorImpl : public TextGenerator {
uint32_t align; uint32_t align;
}; };
TextBuffer helpers_; // Helper functions emitted at the top of the output
/// @returns the MSL packed type size and alignment in bytes for the given /// @returns the MSL packed type size and alignment in bytes for the given
/// type. /// type.
SizeAndAlign MslPackedTypeSizeAndAlign(const sem::Type* ty); SizeAndAlign MslPackedTypeSizeAndAlign(const sem::Type* ty);
using StorageClassToString =
std::unordered_map<ast::StorageClass, std::string>;
std::function<bool()> emit_continuing_; std::function<bool()> emit_continuing_;
/// Name of atomicCompareExchangeWeak() helper for the given pointer storage
/// class.
StorageClassToString atomicCompareExchangeWeak_;
}; };
} // namespace msl } // namespace msl

View File

@ -185,9 +185,8 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInit) {
// return; // return;
// } // }
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup); Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt = Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2); auto* multi_stmt = Block(Ignore(1), Ignore(2));
auto* f = For(Decl(Var("b", nullptr, multi_stmt)), nullptr, nullptr, auto* f = For(multi_stmt, nullptr, nullptr, Block(Return()));
Block(Return()));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -196,9 +195,10 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInit) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
int prev_value = 1; {
bool matched = atomic_compare_exchange_weak_explicit(&(a), &prev_value, 2, memory_order_relaxed, memory_order_relaxed); (void) 1;
int2 b = int2(prev_value, matched); (void) 2;
}
for(; ; ) { for(; ; ) {
return; return;
} }
@ -225,35 +225,6 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCond) {
)"); )");
} }
TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCond) {
// var<workgroup> a : atomic<i32>;
// for(; atomicCompareExchangeWeak(&a, 1, 2).x == 0; ) {
// return;
// }
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt = create<ast::BinaryExpression>(
ast::BinaryOp::kEqual,
MemberAccessor(Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2),
"x"),
Expr(0));
auto* f = For(nullptr, multi_stmt, nullptr, Block(Return()));
WrapInFunction(f);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( while (true) {
int prev_value = 1;
bool matched = atomic_compare_exchange_weak_explicit(&(a), &prev_value, 2, memory_order_relaxed, memory_order_relaxed);
if (!((int2(prev_value, matched).x == 0))) { break; }
return;
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCont) { TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCont) {
// for(; ; i = i + 1) { // for(; ; i = i + 1) {
// return; // return;
@ -276,13 +247,12 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCont) {
TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCont) { TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCont) {
// var<workgroup> a : atomic<i32>; // var<workgroup> a : atomic<i32>;
// for(; ; ignore(atomicCompareExchangeWeak(&a, 1, 2))) { // for(; ; { ignore(1); ignore(2); }) {
// return; // return;
// } // }
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup); Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt = auto* multi_stmt = Block(Ignore(1), Ignore(2));
Ignore(Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2));
auto* f = For(nullptr, nullptr, multi_stmt, Block(Return())); auto* f = For(nullptr, nullptr, multi_stmt, Block(Return()));
WrapInFunction(f); WrapInFunction(f);
@ -293,9 +263,10 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCont) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( while (true) { EXPECT_EQ(gen.result(), R"( while (true) {
return; return;
int prev_value = 1; {
bool matched = atomic_compare_exchange_weak_explicit(&(a), &prev_value, 2, memory_order_relaxed, memory_order_relaxed); (void) 1;
(void) int2(prev_value, matched); (void) 2;
}
} }
)"); )");
} }
@ -322,22 +293,13 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleInitCondCont) {
TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInitCondCont) { TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInitCondCont) {
// var<workgroup> a : atomic<i32>; // var<workgroup> a : atomic<i32>;
// for(var b = atomicCompareExchangeWeak(&a, 1, 2); // for({ ignore(1); ignore(2); }; true; { ignore(3); ignore(4); }) {
// atomicCompareExchangeWeak(&a, 1, 2).x == 0;
// ignore(atomicCompareExchangeWeak(&a, 1, 2))) {
// return; // return;
// } // }
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup); Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt_a = Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2); auto* multi_stmt_a = Block(Ignore(1), Ignore(2));
auto* multi_stmt_b = create<ast::BinaryExpression>( auto* multi_stmt_b = Block(Ignore(3), Ignore(4));
ast::BinaryOp::kEqual, auto* f = For(multi_stmt_a, Expr(true), multi_stmt_b, Block(Return()));
MemberAccessor(Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2),
"x"),
Expr(0));
auto* multi_stmt_c =
Ignore(Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2));
auto* f = For(Decl(Var("b", nullptr, multi_stmt_a)), multi_stmt_b,
multi_stmt_c, Block(Return()));
WrapInFunction(f); WrapInFunction(f);
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -346,17 +308,17 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInitCondCont) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( { EXPECT_EQ(gen.result(), R"( {
int prev_value = 1; {
bool matched = atomic_compare_exchange_weak_explicit(&(a), &prev_value, 2, memory_order_relaxed, memory_order_relaxed); (void) 1;
int2 b = int2(prev_value, matched); (void) 2;
}
while (true) { while (true) {
int prev_value_1 = 1; if (!(true)) { break; }
bool matched_1 = atomic_compare_exchange_weak_explicit(&(a), &prev_value_1, 2, memory_order_relaxed, memory_order_relaxed);
if (!((int2(prev_value_1, matched_1).x == 0))) { break; }
return; return;
int prev_value_2 = 1; {
bool matched_2 = atomic_compare_exchange_weak_explicit(&(a), &prev_value_2, 2, memory_order_relaxed, memory_order_relaxed); (void) 3;
(void) int2(prev_value_2, matched_2); (void) 4;
}
} }
} }
)"); )");

View File

@ -1,14 +1,20 @@
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
template <typename A, typename T>
vec<T, 2> atomicCompareExchangeWeak_1(device A* atomic, T compare, T value) {
T prev_value = compare;
bool matched = atomic_compare_exchange_weak_explicit(atomic, &prev_value, value, memory_order_relaxed, memory_order_relaxed);
return {prev_value, matched};
}
struct SB_RW { struct SB_RW {
/* 0x0000 */ atomic_int arg_0; /* 0x0000 */ atomic_int arg_0;
}; };
void atomicCompareExchangeWeak_12871c(device SB_RW& sb_rw) { void atomicCompareExchangeWeak_12871c(device SB_RW& sb_rw) {
int prev_value = 1; int2 res = atomicCompareExchangeWeak_1(&(sb_rw.arg_0), 1, 1);
bool matched = atomic_compare_exchange_weak_explicit(&(sb_rw.arg_0), &prev_value, 1, memory_order_relaxed, memory_order_relaxed);
int2 res = int2(prev_value, matched);
} }
fragment void fragment_main(device SB_RW& sb_rw [[buffer(0)]]) { fragment void fragment_main(device SB_RW& sb_rw [[buffer(0)]]) {

View File

@ -1,14 +1,20 @@
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
template <typename A, typename T>
vec<T, 2> atomicCompareExchangeWeak_1(device A* atomic, T compare, T value) {
T prev_value = compare;
bool matched = atomic_compare_exchange_weak_explicit(atomic, &prev_value, value, memory_order_relaxed, memory_order_relaxed);
return {prev_value, matched};
}
struct SB_RW { struct SB_RW {
/* 0x0000 */ atomic_uint arg_0; /* 0x0000 */ atomic_uint arg_0;
}; };
void atomicCompareExchangeWeak_6673da(device SB_RW& sb_rw) { void atomicCompareExchangeWeak_6673da(device SB_RW& sb_rw) {
uint prev_value = 1u; uint2 res = atomicCompareExchangeWeak_1(&(sb_rw.arg_0), 1u, 1u);
bool matched = atomic_compare_exchange_weak_explicit(&(sb_rw.arg_0), &prev_value, 1u, memory_order_relaxed, memory_order_relaxed);
uint2 res = uint2(prev_value, matched);
} }
fragment void fragment_main(device SB_RW& sb_rw [[buffer(0)]]) { fragment void fragment_main(device SB_RW& sb_rw [[buffer(0)]]) {

View File

@ -1,10 +1,16 @@
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
template <typename A, typename T>
vec<T, 2> atomicCompareExchangeWeak_1(threadgroup A* atomic, T compare, T value) {
T prev_value = compare;
bool matched = atomic_compare_exchange_weak_explicit(atomic, &prev_value, value, memory_order_relaxed, memory_order_relaxed);
return {prev_value, matched};
}
void atomicCompareExchangeWeak_89ea3b(threadgroup atomic_int* const tint_symbol_1) { void atomicCompareExchangeWeak_89ea3b(threadgroup atomic_int* const tint_symbol_1) {
int prev_value = 1; int2 res = atomicCompareExchangeWeak_1(&(*(tint_symbol_1)), 1, 1);
bool matched = atomic_compare_exchange_weak_explicit(&(*(tint_symbol_1)), &prev_value, 1, memory_order_relaxed, memory_order_relaxed);
int2 res = int2(prev_value, matched);
} }
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) { kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -1,10 +1,16 @@
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
template <typename A, typename T>
vec<T, 2> atomicCompareExchangeWeak_1(threadgroup A* atomic, T compare, T value) {
T prev_value = compare;
bool matched = atomic_compare_exchange_weak_explicit(atomic, &prev_value, value, memory_order_relaxed, memory_order_relaxed);
return {prev_value, matched};
}
void atomicCompareExchangeWeak_b2ab2c(threadgroup atomic_uint* const tint_symbol_1) { void atomicCompareExchangeWeak_b2ab2c(threadgroup atomic_uint* const tint_symbol_1) {
uint prev_value = 1u; uint2 res = atomicCompareExchangeWeak_1(&(*(tint_symbol_1)), 1u, 1u);
bool matched = atomic_compare_exchange_weak_explicit(&(*(tint_symbol_1)), &prev_value, 1u, memory_order_relaxed, memory_order_relaxed);
uint2 res = uint2(prev_value, matched);
} }
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) { kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {