writer/msl: Emit helper functions for atomicCompareExchangeWeak

By generating a helper function for these, we can keep the atomic expression pre-statement-free. This can help prevent for-loops from being transformed into while loops.

Change-Id: Id034ea5ea9be601661ddb78db973015d845c420f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57463
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2021-07-09 20:21:49 +00:00 committed by Tint LUCI CQ
parent 659bcbeacf
commit 9569e2c790
7 changed files with 154 additions and 143 deletions

View File

@ -55,6 +55,7 @@
#include "src/sem/void_type.h"
#include "src/transform/msl.h"
#include "src/utils/defer.h"
#include "src/utils/get_or_create.h"
#include "src/utils/scoped_assignment.h"
#include "src/writer/float_to_string.h"
@ -91,6 +92,8 @@ bool GeneratorImpl::Generate() {
line();
line() << "using namespace metal;";
auto helpers_insertion_point = current_buffer_->lines.size();
for (auto* const type_decl : program_->AST().TypeDecls()) {
if (!type_decl->Is<ast::Alias>()) {
if (!EmitTypeDecl(TypeOf(type_decl))) {
@ -137,6 +140,11 @@ bool GeneratorImpl::Generate() {
line();
}
if (!helpers_.lines.empty()) {
current_buffer_->Insert("", helpers_insertion_point++, 0);
current_buffer_->Insert(helpers_, helpers_insertion_point++, 0);
}
return true;
}
@ -454,7 +462,7 @@ bool GeneratorImpl::EmitIntrinsicCall(std::ostream& out,
bool GeneratorImpl::EmitAtomicCall(std::ostream& out,
ast::CallExpression* expr,
const sem::Intrinsic* intrinsic) {
auto call = [&](const char* name) {
auto call = [&](const std::string& name, bool append_memory_order_relaxed) {
out << name;
{
ScopedParen sp(out);
@ -467,84 +475,77 @@ bool GeneratorImpl::EmitAtomicCall(std::ostream& out,
return false;
}
}
if (append_memory_order_relaxed) {
out << ", memory_order_relaxed";
}
}
return true;
};
switch (intrinsic->Type()) {
case sem::IntrinsicType::kAtomicLoad:
return call("atomic_load_explicit");
return call("atomic_load_explicit", true);
case sem::IntrinsicType::kAtomicStore:
return call("atomic_store_explicit");
return call("atomic_store_explicit", true);
case sem::IntrinsicType::kAtomicAdd:
return call("atomic_fetch_add_explicit");
return call("atomic_fetch_add_explicit", true);
case sem::IntrinsicType::kAtomicMax:
return call("atomic_fetch_max_explicit");
return call("atomic_fetch_max_explicit", true);
case sem::IntrinsicType::kAtomicMin:
return call("atomic_fetch_min_explicit");
return call("atomic_fetch_min_explicit", true);
case sem::IntrinsicType::kAtomicAnd:
return call("atomic_fetch_and_explicit");
return call("atomic_fetch_and_explicit", true);
case sem::IntrinsicType::kAtomicOr:
return call("atomic_fetch_or_explicit");
return call("atomic_fetch_or_explicit", true);
case sem::IntrinsicType::kAtomicXor:
return call("atomic_fetch_xor_explicit");
return call("atomic_fetch_xor_explicit", true);
case sem::IntrinsicType::kAtomicExchange:
return call("atomic_exchange_explicit");
return call("atomic_exchange_explicit", true);
case sem::IntrinsicType::kAtomicCompareExchangeWeak: {
auto* target = expr->params()[0];
auto* compare_value = expr->params()[1];
auto* value = expr->params()[2];
auto* ptr_ty = TypeOf(expr->params()[0])->UnwrapRef()->As<sem::Pointer>();
auto sc = ptr_ty->StorageClass();
auto prev_value = UniqueIdentifier("prev_value");
auto matched = UniqueIdentifier("matched");
auto func = utils::GetOrCreate(
atomicCompareExchangeWeak_, sc, [&]() -> std::string {
auto name = UniqueIdentifier("atomicCompareExchangeWeak");
auto& buf = helpers_;
{ // prev_value = <compare_value>;
auto pre = line();
if (!EmitType(pre, TypeOf(value), "")) {
return false;
}
pre << " " << prev_value << " = ";
if (!EmitExpression(pre, compare_value)) {
return false;
}
pre << ";";
}
{ // bool matched = atomic_compare_exchange_weak_explicit(
// target, &got, <value>, memory_order_relaxed, memory_order_relaxed)
auto pre = line();
pre << "bool " << matched << " = atomic_compare_exchange_weak_explicit";
line(&buf) << "template <typename A, typename T>";
{
ScopedParen sp(pre);
if (!EmitExpression(pre, target)) {
return false;
auto f = line(&buf);
f << "vec<T, 2> " << name << "(";
if (!EmitStorageClass(f, sc)) {
return "";
}
pre << ", &" << prev_value << ", ";
if (!EmitExpression(pre, value)) {
return false;
}
pre << ", memory_order_relaxed, memory_order_relaxed";
}
pre << ";";
f << " A* atomic, T compare, T value) {";
}
{ // [u]int2(got, matched)
if (!EmitType(out, TypeOf(expr), "")) {
return false;
}
out << "(" << prev_value << ", " << matched << ")";
}
return true;
buf.IncrementIndent();
TINT_DEFER({
buf.DecrementIndent();
line(&buf) << "}";
line(&buf);
});
line(&buf) << "T prev_value = compare;";
line(&buf) << "bool matched = "
"atomic_compare_exchange_weak_explicit(atomic, "
"&prev_value, value, memory_order_relaxed, "
"memory_order_relaxed);";
line(&buf) << "return {prev_value, matched};";
return name;
});
return call(func, false);
}
default:
@ -1867,24 +1868,10 @@ bool GeneratorImpl::EmitType(std::ostream& out,
}
if (auto* ptr = type->As<sem::Pointer>()) {
switch (ptr->StorageClass()) {
case ast::StorageClass::kFunction:
case ast::StorageClass::kPrivate:
case ast::StorageClass::kUniformConstant:
out << "thread ";
break;
case ast::StorageClass::kWorkgroup:
out << "threadgroup ";
break;
case ast::StorageClass::kStorage:
out << "device ";
break;
case ast::StorageClass::kUniform:
out << "constant ";
break;
default:
TINT_ICE(Writer, diagnostics_) << "unhandled storage class for pointer";
if (!EmitStorageClass(out, ptr->StorageClass())) {
return false;
}
out << " ";
if (ptr->StoreType()->Is<sem::Array>()) {
std::string inner = "(*" + name + ")";
if (!EmitType(out, ptr->StoreType(), inner)) {
@ -2004,6 +1991,29 @@ bool GeneratorImpl::EmitType(std::ostream& out,
return false;
}
bool GeneratorImpl::EmitStorageClass(std::ostream& out, ast::StorageClass sc) {
switch (sc) {
case ast::StorageClass::kFunction:
case ast::StorageClass::kPrivate:
case ast::StorageClass::kUniformConstant:
out << "thread";
return true;
case ast::StorageClass::kWorkgroup:
out << "threadgroup";
return true;
case ast::StorageClass::kStorage:
out << "device";
return true;
case ast::StorageClass::kUniform:
out << "constant";
return true;
default:
break;
}
TINT_ICE(Writer, diagnostics_) << "unhandled storage class: " << sc;
return false;
}
bool GeneratorImpl::EmitPackedType(std::ostream& out,
const sem::Type* type,
const std::string& name) {

View File

@ -16,6 +16,7 @@
#define SRC_WRITER_MSL_GENERATOR_IMPL_H_
#include <string>
#include <unordered_map>
#include "src/ast/array_accessor_expression.h"
#include "src/ast/assignment_statement.h"
@ -218,6 +219,11 @@ class GeneratorImpl : public TextGenerator {
bool EmitType(std::ostream& out,
const sem::Type* type,
const std::string& name);
/// Handles generating a storage class
/// @param out the output of the type stream
/// @param sc the storage class to generate
/// @returns true if the storage class is emitted
bool EmitStorageClass(std::ostream& out, ast::StorageClass sc);
/// Handles generating an MSL-packed storage type.
/// If the type does not have a packed form, the standard non-packed form is
/// emitted.
@ -282,11 +288,20 @@ class GeneratorImpl : public TextGenerator {
uint32_t align;
};
TextBuffer helpers_; // Helper functions emitted at the top of the output
/// @returns the MSL packed type size and alignment in bytes for the given
/// type.
SizeAndAlign MslPackedTypeSizeAndAlign(const sem::Type* ty);
using StorageClassToString =
std::unordered_map<ast::StorageClass, std::string>;
std::function<bool()> emit_continuing_;
/// Name of atomicCompareExchangeWeak() helper for the given pointer storage
/// class.
StorageClassToString atomicCompareExchangeWeak_;
};
} // namespace msl

View File

@ -185,9 +185,8 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInit) {
// return;
// }
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt = Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2);
auto* f = For(Decl(Var("b", nullptr, multi_stmt)), nullptr, nullptr,
Block(Return()));
auto* multi_stmt = Block(Ignore(1), Ignore(2));
auto* f = For(multi_stmt, nullptr, nullptr, Block(Return()));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@ -196,9 +195,10 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInit) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
int prev_value = 1;
bool matched = atomic_compare_exchange_weak_explicit(&(a), &prev_value, 2, memory_order_relaxed, memory_order_relaxed);
int2 b = int2(prev_value, matched);
{
(void) 1;
(void) 2;
}
for(; ; ) {
return;
}
@ -225,35 +225,6 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCond) {
)");
}
TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCond) {
// var<workgroup> a : atomic<i32>;
// for(; atomicCompareExchangeWeak(&a, 1, 2).x == 0; ) {
// return;
// }
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt = create<ast::BinaryExpression>(
ast::BinaryOp::kEqual,
MemberAccessor(Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2),
"x"),
Expr(0));
auto* f = For(nullptr, multi_stmt, nullptr, Block(Return()));
WrapInFunction(f);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( while (true) {
int prev_value = 1;
bool matched = atomic_compare_exchange_weak_explicit(&(a), &prev_value, 2, memory_order_relaxed, memory_order_relaxed);
if (!((int2(prev_value, matched).x == 0))) { break; }
return;
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCont) {
// for(; ; i = i + 1) {
// return;
@ -276,13 +247,12 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleCont) {
TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCont) {
// var<workgroup> a : atomic<i32>;
// for(; ; ignore(atomicCompareExchangeWeak(&a, 1, 2))) {
// for(; ; { ignore(1); ignore(2); }) {
// return;
// }
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt =
Ignore(Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2));
auto* multi_stmt = Block(Ignore(1), Ignore(2));
auto* f = For(nullptr, nullptr, multi_stmt, Block(Return()));
WrapInFunction(f);
@ -293,9 +263,10 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtCont) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( while (true) {
return;
int prev_value = 1;
bool matched = atomic_compare_exchange_weak_explicit(&(a), &prev_value, 2, memory_order_relaxed, memory_order_relaxed);
(void) int2(prev_value, matched);
{
(void) 1;
(void) 2;
}
}
)");
}
@ -322,22 +293,13 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithSimpleInitCondCont) {
TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInitCondCont) {
// var<workgroup> a : atomic<i32>;
// for(var b = atomicCompareExchangeWeak(&a, 1, 2);
// atomicCompareExchangeWeak(&a, 1, 2).x == 0;
// ignore(atomicCompareExchangeWeak(&a, 1, 2))) {
// for({ ignore(1); ignore(2); }; true; { ignore(3); ignore(4); }) {
// return;
// }
Global("a", ty.atomic<i32>(), ast::StorageClass::kWorkgroup);
auto* multi_stmt_a = Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2);
auto* multi_stmt_b = create<ast::BinaryExpression>(
ast::BinaryOp::kEqual,
MemberAccessor(Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2),
"x"),
Expr(0));
auto* multi_stmt_c =
Ignore(Call("atomicCompareExchangeWeak", AddressOf("a"), 1, 2));
auto* f = For(Decl(Var("b", nullptr, multi_stmt_a)), multi_stmt_b,
multi_stmt_c, Block(Return()));
auto* multi_stmt_a = Block(Ignore(1), Ignore(2));
auto* multi_stmt_b = Block(Ignore(3), Ignore(4));
auto* f = For(multi_stmt_a, Expr(true), multi_stmt_b, Block(Return()));
WrapInFunction(f);
GeneratorImpl& gen = Build();
@ -346,17 +308,17 @@ TEST_F(MslGeneratorImplTest, Emit_ForLoopWithMultiStmtInitCondCont) {
ASSERT_TRUE(gen.EmitStatement(f)) << gen.error();
EXPECT_EQ(gen.result(), R"( {
int prev_value = 1;
bool matched = atomic_compare_exchange_weak_explicit(&(a), &prev_value, 2, memory_order_relaxed, memory_order_relaxed);
int2 b = int2(prev_value, matched);
{
(void) 1;
(void) 2;
}
while (true) {
int prev_value_1 = 1;
bool matched_1 = atomic_compare_exchange_weak_explicit(&(a), &prev_value_1, 2, memory_order_relaxed, memory_order_relaxed);
if (!((int2(prev_value_1, matched_1).x == 0))) { break; }
if (!(true)) { break; }
return;
int prev_value_2 = 1;
bool matched_2 = atomic_compare_exchange_weak_explicit(&(a), &prev_value_2, 2, memory_order_relaxed, memory_order_relaxed);
(void) int2(prev_value_2, matched_2);
{
(void) 3;
(void) 4;
}
}
}
)");

View File

@ -1,14 +1,20 @@
#include <metal_stdlib>
using namespace metal;
template <typename A, typename T>
vec<T, 2> atomicCompareExchangeWeak_1(device A* atomic, T compare, T value) {
T prev_value = compare;
bool matched = atomic_compare_exchange_weak_explicit(atomic, &prev_value, value, memory_order_relaxed, memory_order_relaxed);
return {prev_value, matched};
}
struct SB_RW {
/* 0x0000 */ atomic_int arg_0;
};
void atomicCompareExchangeWeak_12871c(device SB_RW& sb_rw) {
int prev_value = 1;
bool matched = atomic_compare_exchange_weak_explicit(&(sb_rw.arg_0), &prev_value, 1, memory_order_relaxed, memory_order_relaxed);
int2 res = int2(prev_value, matched);
int2 res = atomicCompareExchangeWeak_1(&(sb_rw.arg_0), 1, 1);
}
fragment void fragment_main(device SB_RW& sb_rw [[buffer(0)]]) {

View File

@ -1,14 +1,20 @@
#include <metal_stdlib>
using namespace metal;
template <typename A, typename T>
vec<T, 2> atomicCompareExchangeWeak_1(device A* atomic, T compare, T value) {
T prev_value = compare;
bool matched = atomic_compare_exchange_weak_explicit(atomic, &prev_value, value, memory_order_relaxed, memory_order_relaxed);
return {prev_value, matched};
}
struct SB_RW {
/* 0x0000 */ atomic_uint arg_0;
};
void atomicCompareExchangeWeak_6673da(device SB_RW& sb_rw) {
uint prev_value = 1u;
bool matched = atomic_compare_exchange_weak_explicit(&(sb_rw.arg_0), &prev_value, 1u, memory_order_relaxed, memory_order_relaxed);
uint2 res = uint2(prev_value, matched);
uint2 res = atomicCompareExchangeWeak_1(&(sb_rw.arg_0), 1u, 1u);
}
fragment void fragment_main(device SB_RW& sb_rw [[buffer(0)]]) {

View File

@ -1,10 +1,16 @@
#include <metal_stdlib>
using namespace metal;
template <typename A, typename T>
vec<T, 2> atomicCompareExchangeWeak_1(threadgroup A* atomic, T compare, T value) {
T prev_value = compare;
bool matched = atomic_compare_exchange_weak_explicit(atomic, &prev_value, value, memory_order_relaxed, memory_order_relaxed);
return {prev_value, matched};
}
void atomicCompareExchangeWeak_89ea3b(threadgroup atomic_int* const tint_symbol_1) {
int prev_value = 1;
bool matched = atomic_compare_exchange_weak_explicit(&(*(tint_symbol_1)), &prev_value, 1, memory_order_relaxed, memory_order_relaxed);
int2 res = int2(prev_value, matched);
int2 res = atomicCompareExchangeWeak_1(&(*(tint_symbol_1)), 1, 1);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {

View File

@ -1,10 +1,16 @@
#include <metal_stdlib>
using namespace metal;
template <typename A, typename T>
vec<T, 2> atomicCompareExchangeWeak_1(threadgroup A* atomic, T compare, T value) {
T prev_value = compare;
bool matched = atomic_compare_exchange_weak_explicit(atomic, &prev_value, value, memory_order_relaxed, memory_order_relaxed);
return {prev_value, matched};
}
void atomicCompareExchangeWeak_b2ab2c(threadgroup atomic_uint* const tint_symbol_1) {
uint prev_value = 1u;
bool matched = atomic_compare_exchange_weak_explicit(&(*(tint_symbol_1)), &prev_value, 1u, memory_order_relaxed, memory_order_relaxed);
uint2 res = uint2(prev_value, matched);
uint2 res = atomicCompareExchangeWeak_1(&(*(tint_symbol_1)), 1u, 1u);
}
kernel void compute_main(uint local_invocation_index [[thread_index_in_threadgroup]]) {