tools/remote-compile: clang-format

This was using the old tint code style

Change-Id: I1aff541eb4cc0d7ec0e045b555710aa605c4da28
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105141
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Auto-Submit: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-10-10 12:17:03 +00:00 committed by Dawn LUCI CQ
parent c8f1075310
commit 467e6e3fc6
6 changed files with 657 additions and 670 deletions

View File

@ -19,10 +19,10 @@
/// The return structure of a compile function /// The return structure of a compile function
struct CompileResult { struct CompileResult {
/// True if shader compiled /// True if shader compiled
bool success = false; bool success = false;
/// Output of the compiler /// Output of the compiler
std::string output; std::string output;
}; };
CompileResult CompileMslUsingMetalAPI(const std::string& src); CompileResult CompileMslUsingMetalAPI(const std::string& src);

View File

@ -34,8 +34,8 @@ namespace {
/// Print the tool usage, and exit with 1. /// Print the tool usage, and exit with 1.
void ShowUsage() { void ShowUsage() {
const char* name = "tint-remote-compile"; const char* name = "tint-remote-compile";
printf(R"(%s is a tool for compiling a shader on a remote machine printf(R"(%s is a tool for compiling a shader on a remote machine
usage as server: usage as server:
%s -s [-p port-number] %s -s [-p port-number]
@ -48,8 +48,8 @@ usage as client:
Alternatively, you can pass xcrun arguments so %s can be used as a Alternatively, you can pass xcrun arguments so %s can be used as a
drop-in replacement. drop-in replacement.
)", )",
name, name, name, name); name, name, name, name);
exit(1); exit(1);
} }
/// The protocol version code. Bump each time the protocol changes /// The protocol version code. Bump each time the protocol changes
@ -57,98 +57,98 @@ constexpr uint32_t kProtocolVersion = 1;
/// Supported shader source languages /// Supported shader source languages
enum SourceLanguage { enum SourceLanguage {
MSL, MSL,
}; };
/// Stream is a serialization wrapper around a socket /// Stream is a serialization wrapper around a socket
struct Stream { struct Stream {
/// The underlying socket /// The underlying socket
Socket* const socket; Socket* const socket;
/// Error state /// Error state
std::string error; std::string error;
/// Writes a uint32_t to the socket /// Writes a uint32_t to the socket
Stream operator<<(uint32_t v) { Stream operator<<(uint32_t v) {
if (error.empty()) { if (error.empty()) {
Write(&v, sizeof(v)); Write(&v, sizeof(v));
}
return *this;
}
/// Reads a uint32_t from the socket
Stream operator>>(uint32_t& v) {
if (error.empty()) {
Read(&v, sizeof(v));
}
return *this;
}
/// Writes a std::string to the socket
Stream operator<<(const std::string& v) {
if (error.empty()) {
uint32_t count = static_cast<uint32_t>(v.size());
*this << count;
if (count) {
Write(v.data(), count);
}
}
return *this;
}
/// Reads a std::string from the socket
Stream operator>>(std::string& v) {
uint32_t count = 0;
*this >> count;
if (count) {
std::vector<char> buf(count);
if (Read(buf.data(), count)) {
v = std::string(buf.data(), buf.size());
}
} else {
v.clear();
}
return *this;
}
/// Writes an enum value to the socket
template <typename T>
std::enable_if_t<std::is_enum<T>::value, Stream> operator<<(T e) {
return *this << static_cast<uint32_t>(e);
}
/// Reads an enum value from the socket
template <typename T>
std::enable_if_t<std::is_enum<T>::value, Stream> operator>>(T& e) {
uint32_t v;
*this >> v;
e = static_cast<T>(v);
return *this;
}
private:
bool Write(const void* data, size_t size) {
if (error.empty()) {
if (!socket->Write(data, size)) {
error = "Socket::Write() failed";
}
}
return error.empty();
}
bool Read(void* data, size_t size) {
auto buf = reinterpret_cast<uint8_t*>(data);
while (size > 0 && error.empty()) {
if (auto n = socket->Read(buf, size)) {
if (n > size) {
error = "Socket::Read() returned more bytes than requested";
return false;
} }
size -= n; return *this;
buf += n; }
}
/// Reads a uint32_t from the socket
Stream operator>>(uint32_t& v) {
if (error.empty()) {
Read(&v, sizeof(v));
}
return *this;
}
/// Writes a std::string to the socket
Stream operator<<(const std::string& v) {
if (error.empty()) {
uint32_t count = static_cast<uint32_t>(v.size());
*this << count;
if (count) {
Write(v.data(), count);
}
}
return *this;
}
/// Reads a std::string from the socket
Stream operator>>(std::string& v) {
uint32_t count = 0;
*this >> count;
if (count) {
std::vector<char> buf(count);
if (Read(buf.data(), count)) {
v = std::string(buf.data(), buf.size());
}
} else {
v.clear();
}
return *this;
}
/// Writes an enum value to the socket
template <typename T>
std::enable_if_t<std::is_enum<T>::value, Stream> operator<<(T e) {
return *this << static_cast<uint32_t>(e);
}
/// Reads an enum value from the socket
template <typename T>
std::enable_if_t<std::is_enum<T>::value, Stream> operator>>(T& e) {
uint32_t v;
*this >> v;
e = static_cast<T>(v);
return *this;
}
private:
bool Write(const void* data, size_t size) {
if (error.empty()) {
if (!socket->Write(data, size)) {
error = "Socket::Write() failed";
}
}
return error.empty();
}
bool Read(void* data, size_t size) {
auto buf = reinterpret_cast<uint8_t*>(data);
while (size > 0 && error.empty()) {
if (auto n = socket->Read(buf, size)) {
if (n > size) {
error = "Socket::Read() returned more bytes than requested";
return false;
}
size -= n;
buf += n;
}
}
return error.empty();
} }
return error.empty();
}
}; };
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -157,113 +157,111 @@ struct Stream {
/// Base class for all messages /// Base class for all messages
struct Message { struct Message {
/// The type of the message /// The type of the message
enum class Type { enum class Type {
ConnectionRequest, ConnectionRequest,
ConnectionResponse, ConnectionResponse,
CompileRequest, CompileRequest,
CompileResponse, CompileResponse,
}; };
explicit Message(Type ty) : type(ty) {} explicit Message(Type ty) : type(ty) {}
const Type type; const Type type;
}; };
struct ConnectionResponse : Message { // Server -> Client struct ConnectionResponse : Message { // Server -> Client
ConnectionResponse() : Message(Type::ConnectionResponse) {} ConnectionResponse() : Message(Type::ConnectionResponse) {}
template <typename T> template <typename T>
void Serialize(T&& f) { void Serialize(T&& f) {
f(error); f(error);
} }
std::string error; std::string error;
}; };
struct ConnectionRequest : Message { // Client -> Server struct ConnectionRequest : Message { // Client -> Server
using Response = ConnectionResponse; using Response = ConnectionResponse;
explicit ConnectionRequest(uint32_t proto_ver = kProtocolVersion) explicit ConnectionRequest(uint32_t proto_ver = kProtocolVersion)
: Message(Type::ConnectionRequest), protocol_version(proto_ver) {} : Message(Type::ConnectionRequest), protocol_version(proto_ver) {}
template <typename T> template <typename T>
void Serialize(T&& f) { void Serialize(T&& f) {
f(protocol_version); f(protocol_version);
} }
uint32_t protocol_version; uint32_t protocol_version;
}; };
struct CompileResponse : Message { // Server -> Client struct CompileResponse : Message { // Server -> Client
CompileResponse() : Message(Type::CompileResponse) {} CompileResponse() : Message(Type::CompileResponse) {}
template <typename T> template <typename T>
void Serialize(T&& f) { void Serialize(T&& f) {
f(error); f(error);
} }
std::string error; std::string error;
}; };
struct CompileRequest : Message { // Client -> Server struct CompileRequest : Message { // Client -> Server
using Response = CompileResponse; using Response = CompileResponse;
CompileRequest() : Message(Type::CompileRequest) {} CompileRequest() : Message(Type::CompileRequest) {}
CompileRequest(SourceLanguage lang, std::string src) CompileRequest(SourceLanguage lang, std::string src)
: Message(Type::CompileRequest), language(lang), source(src) {} : Message(Type::CompileRequest), language(lang), source(src) {}
template <typename T> template <typename T>
void Serialize(T&& f) { void Serialize(T&& f) {
f(language); f(language);
f(source); f(source);
} }
SourceLanguage language; SourceLanguage language;
std::string source; std::string source;
}; };
/// Writes the message `m` to the stream `s` /// Writes the message `m` to the stream `s`
template <typename MESSAGE> template <typename MESSAGE>
std::enable_if_t<std::is_base_of<Message, MESSAGE>::value, Stream>& operator<<( std::enable_if_t<std::is_base_of<Message, MESSAGE>::value, Stream>& operator<<(Stream& s,
Stream& s, const MESSAGE& m) {
const MESSAGE& m) { s << m.type;
s << m.type; const_cast<MESSAGE&>(m).Serialize([&s](const auto& value) { s << value; });
const_cast<MESSAGE&>(m).Serialize([&s](const auto& value) { s << value; }); return s;
return s;
} }
/// Reads the message `m` from the stream `s` /// Reads the message `m` from the stream `s`
template <typename MESSAGE> template <typename MESSAGE>
std::enable_if_t<std::is_base_of<Message, MESSAGE>::value, Stream>& operator>>( std::enable_if_t<std::is_base_of<Message, MESSAGE>::value, Stream>& operator>>(Stream& s,
Stream& s, MESSAGE& m) {
MESSAGE& m) { Message::Type ty;
Message::Type ty; s >> ty;
s >> ty; if (ty == m.type) {
if (ty == m.type) { m.Serialize([&s](auto& value) { s >> value; });
m.Serialize([&s](auto& value) { s >> value; }); } else {
} else { std::stringstream ss;
std::stringstream ss; ss << "expected message type " << static_cast<int>(m.type) << ", got "
ss << "expected message type " << static_cast<int>(m.type) << ", got " << static_cast<int>(ty);
<< static_cast<int>(ty); s.error = ss.str();
s.error = ss.str(); }
} return s;
return s;
} }
/// Writes the request message `req` to the stream `s`, then reads and returns /// Writes the request message `req` to the stream `s`, then reads and returns
/// the response message from the same stream. /// the response message from the same stream.
template <typename REQUEST, typename RESPONSE = typename REQUEST::Response> template <typename REQUEST, typename RESPONSE = typename REQUEST::Response>
RESPONSE Send(Stream& s, const REQUEST& req) { RESPONSE Send(Stream& s, const REQUEST& req) {
s << req; s << req;
if (s.error.empty()) {
RESPONSE resp;
s >> resp;
if (s.error.empty()) { if (s.error.empty()) {
return resp; RESPONSE resp;
s >> resp;
if (s.error.empty()) {
return resp;
}
} }
} return {};
return {};
} }
} // namespace } // namespace
@ -272,173 +270,172 @@ bool RunServer(std::string port);
bool RunClient(std::string address, std::string port, std::string file); bool RunClient(std::string address, std::string port, std::string file);
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
bool run_server = false; bool run_server = false;
std::string port = "19000"; std::string port = "19000";
std::vector<std::string> args; std::vector<std::string> args;
for (int i = 1; i < argc; i++) { for (int i = 1; i < argc; i++) {
std::string arg = argv[i]; std::string arg = argv[i];
if (arg == "-s" || arg == "--server") { if (arg == "-s" || arg == "--server") {
run_server = true; run_server = true;
continue; continue;
}
if (arg == "-p" || arg == "--port") {
if (i < argc - 1) {
i++;
port = argv[i];
} else {
printf("expected port number");
exit(1);
}
continue;
}
// xcrun flags are ignored so this executable can be used as a replacement
// for xcrun.
if ((arg == "-x" || arg == "-sdk") && (i < argc - 1)) {
i++;
continue;
}
if (arg == "metal") {
for (; i < argc; i++) {
if (std::string(argv[i]) == "-c") {
break;
}
}
continue;
}
args.emplace_back(arg);
} }
if (arg == "-p" || arg == "--port") {
if (i < argc - 1) { bool success = false;
i++;
port = argv[i]; if (run_server) {
} else { success = RunServer(port);
printf("expected port number"); } else {
std::string address;
std::string file;
switch (args.size()) {
case 1:
if (auto* addr = getenv("TINT_REMOTE_COMPILE_ADDRESS")) {
address = addr;
}
file = args[0];
break;
case 2:
address = args[0];
file = args[1];
break;
}
if (address.empty() || file.empty()) {
ShowUsage();
}
success = RunClient(address, port, file);
}
if (!success) {
exit(1); exit(1);
}
continue;
} }
// xcrun flags are ignored so this executable can be used as a replacement return 0;
// for xcrun.
if ((arg == "-x" || arg == "-sdk") && (i < argc - 1)) {
i++;
continue;
}
if (arg == "metal") {
for (; i < argc; i++) {
if (std::string(argv[i]) == "-c") {
break;
}
}
continue;
}
args.emplace_back(arg);
}
bool success = false;
if (run_server) {
success = RunServer(port);
} else {
std::string address;
std::string file;
switch (args.size()) {
case 1:
if (auto* addr = getenv("TINT_REMOTE_COMPILE_ADDRESS")) {
address = addr;
}
file = args[0];
break;
case 2:
address = args[0];
file = args[1];
break;
}
if (address.empty() || file.empty()) {
ShowUsage();
}
success = RunClient(address, port, file);
}
if (!success) {
exit(1);
}
return 0;
} }
bool RunServer(std::string port) { bool RunServer(std::string port) {
auto server_socket = Socket::Listen("", port.c_str()); auto server_socket = Socket::Listen("", port.c_str());
if (!server_socket) { if (!server_socket) {
printf("Failed to listen on port %s\n", port.c_str()); printf("Failed to listen on port %s\n", port.c_str());
return false; return false;
} }
printf("Listening on port %s...\n", port.c_str()); printf("Listening on port %s...\n", port.c_str());
while (auto conn = server_socket->Accept()) { while (auto conn = server_socket->Accept()) {
std::thread([=] { std::thread([=] {
DEBUG("Client connected..."); DEBUG("Client connected...");
Stream stream{conn.get()}; Stream stream{conn.get()};
{ {
ConnectionRequest req; ConnectionRequest req;
stream >> req; stream >> req;
if (!stream.error.empty()) { if (!stream.error.empty()) {
printf("%s\n", stream.error.c_str()); printf("%s\n", stream.error.c_str());
return; return;
} }
ConnectionResponse resp; ConnectionResponse resp;
if (req.protocol_version != kProtocolVersion) { if (req.protocol_version != kProtocolVersion) {
DEBUG("Protocol version mismatch"); DEBUG("Protocol version mismatch");
resp.error = "Protocol version mismatch"; resp.error = "Protocol version mismatch";
stream << resp; stream << resp;
return; return;
} }
stream << resp; stream << resp;
} }
DEBUG("Connection established"); DEBUG("Connection established");
{ {
CompileRequest req; CompileRequest req;
stream >> req; stream >> req;
if (!stream.error.empty()) { if (!stream.error.empty()) {
printf("%s\n", stream.error.c_str()); printf("%s\n", stream.error.c_str());
return; return;
} }
#ifdef TINT_ENABLE_MSL_COMPILATION_USING_METAL_API #ifdef TINT_ENABLE_MSL_COMPILATION_USING_METAL_API
if (req.language == SourceLanguage::MSL) { if (req.language == SourceLanguage::MSL) {
auto result = CompileMslUsingMetalAPI(req.source); auto result = CompileMslUsingMetalAPI(req.source);
CompileResponse resp; CompileResponse resp;
if (!result.success) { if (!result.success) {
resp.error = result.output; resp.error = result.output;
} }
stream << resp; stream << resp;
return; return;
} }
#endif #endif
CompileResponse resp; CompileResponse resp;
resp.error = "server cannot compile this type of shader"; resp.error = "server cannot compile this type of shader";
stream << resp; stream << resp;
} }
}).detach(); }).detach();
} }
return true; return true;
} }
bool RunClient(std::string address, std::string port, std::string file) { bool RunClient(std::string address, std::string port, std::string file) {
// Read the file // Read the file
std::ifstream input(file, std::ios::binary); std::ifstream input(file, std::ios::binary);
if (!input) { if (!input) {
printf("Couldn't open '%s'\n", file.c_str()); printf("Couldn't open '%s'\n", file.c_str());
return false; return false;
} }
std::string source((std::istreambuf_iterator<char>(input)), std::string source((std::istreambuf_iterator<char>(input)), std::istreambuf_iterator<char>());
std::istreambuf_iterator<char>());
constexpr const int timeout_ms = 10000; constexpr const int timeout_ms = 10000;
DEBUG("Connecting to %s:%s...", address.c_str(), port.c_str()); DEBUG("Connecting to %s:%s...", address.c_str(), port.c_str());
auto conn = Socket::Connect(address.c_str(), port.c_str(), timeout_ms); auto conn = Socket::Connect(address.c_str(), port.c_str(), timeout_ms);
if (!conn) { if (!conn) {
printf("Connection failed\n"); printf("Connection failed\n");
return false; return false;
} }
Stream stream{conn.get()}; Stream stream{conn.get()};
DEBUG("Sending connection request..."); DEBUG("Sending connection request...");
auto conn_resp = Send(stream, ConnectionRequest{kProtocolVersion}); auto conn_resp = Send(stream, ConnectionRequest{kProtocolVersion});
if (!stream.error.empty()) { if (!stream.error.empty()) {
printf("%s\n", stream.error.c_str()); printf("%s\n", stream.error.c_str());
return false; return false;
} }
if (!conn_resp.error.empty()) { if (!conn_resp.error.empty()) {
printf("%s\n", conn_resp.error.c_str()); printf("%s\n", conn_resp.error.c_str());
return false; return false;
} }
DEBUG("Connection established. Requesting compile..."); DEBUG("Connection established. Requesting compile...");
auto comp_resp = Send(stream, CompileRequest{SourceLanguage::MSL, source}); auto comp_resp = Send(stream, CompileRequest{SourceLanguage::MSL, source});
if (!stream.error.empty()) { if (!stream.error.empty()) {
printf("%s\n", stream.error.c_str()); printf("%s\n", stream.error.c_str());
return false; return false;
} }
if (!comp_resp.error.empty()) { if (!comp_resp.error.empty()) {
printf("%s\n", comp_resp.error.c_str()); printf("%s\n", comp_resp.error.c_str());
return false; return false;
} }
DEBUG("Compilation successful"); DEBUG("Compilation successful");
return true; return true;
} }

View File

@ -23,34 +23,33 @@
#pragma clang diagnostic pop #pragma clang diagnostic pop
CompileResult CompileMslUsingMetalAPI(const std::string& src) { CompileResult CompileMslUsingMetalAPI(const std::string& src) {
CompileResult result; CompileResult result;
result.success = false;
NSError* error = nil;
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
if (!device) {
result.output = "MTLCreateSystemDefaultDevice returned null";
result.success = false; result.success = false;
NSError* error = nil;
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
if (!device) {
result.output = "MTLCreateSystemDefaultDevice returned null";
result.success = false;
return result;
}
NSString* source = [NSString stringWithCString:src.c_str() encoding:NSUTF8StringEncoding];
MTLCompileOptions* compileOptions = [MTLCompileOptions new];
compileOptions.languageVersion = MTLLanguageVersion1_2;
id<MTLLibrary> library = [device newLibraryWithSource:source
options:compileOptions
error:&error];
if (!library) {
NSString* output = [error localizedDescription];
result.output = [output UTF8String];
result.success = false;
}
return result; return result;
}
NSString* source = [NSString stringWithCString:src.c_str()
encoding:NSUTF8StringEncoding];
MTLCompileOptions* compileOptions = [MTLCompileOptions new];
compileOptions.languageVersion = MTLLanguageVersion1_2;
id<MTLLibrary> library = [device newLibraryWithSource:source
options:compileOptions
error:&error];
if (!library) {
NSString* output = [error localizedDescription];
result.output = [output UTF8String];
result.success = false;
}
return result;
} }
#endif #endif

View File

@ -26,62 +26,62 @@
/// The lock can be held by an arbitrary number of readers or a single writer. /// The lock can be held by an arbitrary number of readers or a single writer.
/// Also known as a shared mutex. /// Also known as a shared mutex.
class RWMutex { class RWMutex {
public: public:
inline RWMutex() = default; inline RWMutex() = default;
/// lockReader() locks the mutex for reading. /// lockReader() locks the mutex for reading.
/// Multiple read locks can be held while there are no writer locks. /// Multiple read locks can be held while there are no writer locks.
inline void lockReader(); inline void lockReader();
/// unlockReader() unlocks the mutex for reading. /// unlockReader() unlocks the mutex for reading.
inline void unlockReader(); inline void unlockReader();
/// lockWriter() locks the mutex for writing. /// lockWriter() locks the mutex for writing.
/// If the lock is already locked for reading or writing, lockWriter blocks /// If the lock is already locked for reading or writing, lockWriter blocks
/// until the lock is available. /// until the lock is available.
inline void lockWriter(); inline void lockWriter();
/// unlockWriter() unlocks the mutex for writing. /// unlockWriter() unlocks the mutex for writing.
inline void unlockWriter(); inline void unlockWriter();
private: private:
RWMutex(const RWMutex&) = delete; RWMutex(const RWMutex&) = delete;
RWMutex& operator=(const RWMutex&) = delete; RWMutex& operator=(const RWMutex&) = delete;
int readLocks = 0; int readLocks = 0;
int pendingWriteLocks = 0; int pendingWriteLocks = 0;
std::mutex mutex; std::mutex mutex;
std::condition_variable cv; std::condition_variable cv;
}; };
void RWMutex::lockReader() { void RWMutex::lockReader() {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
readLocks++; readLocks++;
} }
void RWMutex::unlockReader() { void RWMutex::unlockReader() {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
readLocks--; readLocks--;
if (readLocks == 0 && pendingWriteLocks > 0) { if (readLocks == 0 && pendingWriteLocks > 0) {
cv.notify_one(); cv.notify_one();
} }
} }
void RWMutex::lockWriter() { void RWMutex::lockWriter() {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
if (readLocks > 0) { if (readLocks > 0) {
pendingWriteLocks++; pendingWriteLocks++;
cv.wait(lock, [&] { return readLocks == 0; }); cv.wait(lock, [&] { return readLocks == 0; });
pendingWriteLocks--; pendingWriteLocks--;
} }
lock.release(); // Keep lock held lock.release(); // Keep lock held
} }
void RWMutex::unlockWriter() { void RWMutex::unlockWriter() {
if (pendingWriteLocks > 0) { if (pendingWriteLocks > 0) {
cv.notify_one(); cv.notify_one();
} }
mutex.unlock(); mutex.unlock();
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -90,49 +90,49 @@ void RWMutex::unlockWriter() {
/// RLock is a RAII read lock helper for a RWMutex. /// RLock is a RAII read lock helper for a RWMutex.
class RLock { class RLock {
public: public:
/// Constructor. /// Constructor.
/// Locks `mutex` with a read-lock for the lifetime of the WLock. /// Locks `mutex` with a read-lock for the lifetime of the WLock.
/// @param mutex the mutex /// @param mutex the mutex
explicit inline RLock(RWMutex& mutex); explicit inline RLock(RWMutex& mutex);
/// Destructor. /// Destructor.
/// Unlocks the RWMutex. /// Unlocks the RWMutex.
inline ~RLock(); inline ~RLock();
/// Move constructor /// Move constructor
/// @param other the other RLock to move into this RLock. /// @param other the other RLock to move into this RLock.
inline RLock(RLock&& other); inline RLock(RLock&& other);
/// Move assignment operator /// Move assignment operator
/// @param other the other RLock to move into this RLock. /// @param other the other RLock to move into this RLock.
/// @returns this RLock so calls can be chained /// @returns this RLock so calls can be chained
inline RLock& operator=(RLock&& other); inline RLock& operator=(RLock&& other);
private: private:
RLock(const RLock&) = delete; RLock(const RLock&) = delete;
RLock& operator=(const RLock&) = delete; RLock& operator=(const RLock&) = delete;
RWMutex* m; RWMutex* m;
}; };
RLock::RLock(RWMutex& mutex) : m(&mutex) { RLock::RLock(RWMutex& mutex) : m(&mutex) {
m->lockReader(); m->lockReader();
} }
RLock::~RLock() { RLock::~RLock() {
if (m != nullptr) { if (m != nullptr) {
m->unlockReader(); m->unlockReader();
} }
} }
RLock::RLock(RLock&& other) { RLock::RLock(RLock&& other) {
m = other.m; m = other.m;
other.m = nullptr; other.m = nullptr;
} }
RLock& RLock::operator=(RLock&& other) { RLock& RLock::operator=(RLock&& other) {
m = other.m; m = other.m;
other.m = nullptr; other.m = nullptr;
return *this; return *this;
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -141,50 +141,50 @@ RLock& RLock::operator=(RLock&& other) {
/// WLock is a RAII write lock helper for a RWMutex. /// WLock is a RAII write lock helper for a RWMutex.
class WLock { class WLock {
public: public:
/// Constructor. /// Constructor.
/// Locks `mutex` with a write-lock for the lifetime of the WLock. /// Locks `mutex` with a write-lock for the lifetime of the WLock.
/// @param mutex the mutex /// @param mutex the mutex
explicit inline WLock(RWMutex& mutex); explicit inline WLock(RWMutex& mutex);
/// Destructor. /// Destructor.
/// Unlocks the RWMutex. /// Unlocks the RWMutex.
inline ~WLock(); inline ~WLock();
/// Move constructor /// Move constructor
/// @param other the other WLock to move into this WLock. /// @param other the other WLock to move into this WLock.
inline WLock(WLock&& other); inline WLock(WLock&& other);
/// Move assignment operator /// Move assignment operator
/// @param other the other WLock to move into this WLock. /// @param other the other WLock to move into this WLock.
/// @returns this WLock so calls can be chained /// @returns this WLock so calls can be chained
inline WLock& operator=(WLock&& other); inline WLock& operator=(WLock&& other);
private: private:
WLock(const WLock&) = delete; WLock(const WLock&) = delete;
WLock& operator=(const WLock&) = delete; WLock& operator=(const WLock&) = delete;
RWMutex* m; RWMutex* m;
}; };
WLock::WLock(RWMutex& mutex) : m(&mutex) { WLock::WLock(RWMutex& mutex) : m(&mutex) {
m->lockWriter(); m->lockWriter();
} }
WLock::~WLock() { WLock::~WLock() {
if (m != nullptr) { if (m != nullptr) {
m->unlockWriter(); m->unlockWriter();
} }
} }
WLock::WLock(WLock&& other) { WLock::WLock(WLock&& other) {
m = other.m; m = other.m;
other.m = nullptr; other.m = nullptr;
} }
WLock& WLock::operator=(WLock&& other) { WLock& WLock::operator=(WLock&& other) {
m = other.m; m = other.m;
other.m = nullptr; other.m = nullptr;
return *this; return *this;
} }
#endif // TOOLS_SRC_CMD_REMOTE_COMPILE_RWMUTEX_H_ #endif // TOOLS_SRC_CMD_REMOTE_COMPILE_RWMUTEX_H_

View File

@ -45,266 +45,257 @@ namespace {
constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1); constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1);
void init() { void init() {
#if defined(_WIN32) #if defined(_WIN32)
if (wsaInitCount++ == 0) { if (wsaInitCount++ == 0) {
WSADATA winsockData; WSADATA winsockData;
(void)WSAStartup(MAKEWORD(2, 2), &winsockData); (void)WSAStartup(MAKEWORD(2, 2), &winsockData);
} }
#endif #endif
} }
void term() { void term() {
#if defined(_WIN32) #if defined(_WIN32)
if (--wsaInitCount == 0) { if (--wsaInitCount == 0) {
WSACleanup(); WSACleanup();
} }
#endif #endif
} }
bool setBlocking(SOCKET s, bool blocking) { bool setBlocking(SOCKET s, bool blocking) {
#if defined(_WIN32) #if defined(_WIN32)
u_long mode = blocking ? 0 : 1; u_long mode = blocking ? 0 : 1;
return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR; return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR;
#else #else
auto arg = fcntl(s, F_GETFL, nullptr); auto arg = fcntl(s, F_GETFL, nullptr);
if (arg < 0) { if (arg < 0) {
return false; return false;
} }
arg = blocking ? (arg & ~O_NONBLOCK) : (arg | O_NONBLOCK); arg = blocking ? (arg & ~O_NONBLOCK) : (arg | O_NONBLOCK);
return fcntl(s, F_SETFL, arg) >= 0; return fcntl(s, F_SETFL, arg) >= 0;
#endif #endif
} }
bool errored(SOCKET s) { bool errored(SOCKET s) {
if (s == InvalidSocket) { if (s == InvalidSocket) {
return true; return true;
} }
char error = 0; char error = 0;
socklen_t len = sizeof(error); socklen_t len = sizeof(error);
getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len); getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len);
return error != 0; return error != 0;
} }
class Impl : public Socket { class Impl : public Socket {
public: public:
static std::shared_ptr<Impl> create(const char* address, const char* port) { static std::shared_ptr<Impl> create(const char* address, const char* port) {
init(); init();
addrinfo hints = {}; addrinfo hints = {};
hints.ai_family = AF_INET; hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM; hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP; hints.ai_protocol = IPPROTO_TCP;
hints.ai_flags = AI_PASSIVE; hints.ai_flags = AI_PASSIVE;
addrinfo* info = nullptr; addrinfo* info = nullptr;
auto err = getaddrinfo(address, port, &hints, &info); auto err = getaddrinfo(address, port, &hints, &info);
#if !defined(_WIN32) #if !defined(_WIN32)
if (err) { if (err) {
printf("getaddrinfo(%s, %s) error: %s\n", address, port, printf("getaddrinfo(%s, %s) error: %s\n", address, port, gai_strerror(err));
gai_strerror(err)); }
}
#endif #endif
if (info) { if (info) {
auto socket = auto socket = ::socket(info->ai_family, info->ai_socktype, info->ai_protocol);
::socket(info->ai_family, info->ai_socktype, info->ai_protocol); auto out = std::make_shared<Impl>(info, socket);
auto out = std::make_shared<Impl>(info, socket); out->setOptions();
out->setOptions(); return out;
return out; }
freeaddrinfo(info);
term();
return nullptr;
} }
freeaddrinfo(info); explicit Impl(SOCKET socket) : info(nullptr), s(socket) {}
term(); Impl(addrinfo* info, SOCKET socket) : info(info), s(socket) {}
return nullptr;
}
explicit Impl(SOCKET socket) : info(nullptr), s(socket) {} ~Impl() {
Impl(addrinfo* info, SOCKET socket) : info(info), s(socket) {} freeaddrinfo(info);
Close();
~Impl() { term();
freeaddrinfo(info);
Close();
term();
}
template <typename FUNCTION>
void lock(FUNCTION&& f) {
RLock l(mutex);
f(s, info);
}
void setOptions() {
RLock l(mutex);
if (s == InvalidSocket) {
return;
} }
int enable = 1; template <typename FUNCTION>
void lock(FUNCTION&& f) {
RLock l(mutex);
f(s, info);
}
void setOptions() {
RLock l(mutex);
if (s == InvalidSocket) {
return;
}
int enable = 1;
#if !defined(_WIN32) #if !defined(_WIN32)
// Prevent sockets lingering after process termination, causing // Prevent sockets lingering after process termination, causing
// reconnection issues on the same port. // reconnection issues on the same port.
setsockopt(s, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&enable), setsockopt(s, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<char*>(&enable), sizeof(enable));
sizeof(enable));
struct { struct {
int l_onoff; /* linger active */ int l_onoff; /* linger active */
int l_linger; /* how many seconds to linger for */ int l_linger; /* how many seconds to linger for */
} linger = {false, 0}; } linger = {false, 0};
setsockopt(s, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&linger), setsockopt(s, SOL_SOCKET, SO_LINGER, reinterpret_cast<char*>(&linger), sizeof(linger));
sizeof(linger));
#endif // !defined(_WIN32) #endif // !defined(_WIN32)
// Enable TCP_NODELAY. // Enable TCP_NODELAY.
setsockopt(s, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char*>(&enable), setsockopt(s, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char*>(&enable), sizeof(enable));
sizeof(enable));
}
bool IsOpen() override {
{
RLock l(mutex);
if ((s != InvalidSocket) && !errored(s)) {
return true;
}
} }
WLock lock(mutex);
s = InvalidSocket;
return false;
}
void Close() override { bool IsOpen() override {
{ {
RLock l(mutex); RLock l(mutex);
if (s != InvalidSocket) { if ((s != InvalidSocket) && !errored(s)) {
return true;
}
}
WLock lock(mutex);
s = InvalidSocket;
return false;
}
void Close() override {
{
RLock l(mutex);
if (s != InvalidSocket) {
#if defined(_WIN32) #if defined(_WIN32)
closesocket(s); closesocket(s);
#else #else
::shutdown(s, SHUT_RDWR); ::shutdown(s, SHUT_RDWR);
#endif #endif
} }
} }
WLock l(mutex); WLock l(mutex);
if (s != InvalidSocket) { if (s != InvalidSocket) {
#if !defined(_WIN32) #if !defined(_WIN32)
::close(s); ::close(s);
#endif #endif
s = InvalidSocket; s = InvalidSocket;
}
} }
}
size_t Read(void* buffer, size_t bytes) override { size_t Read(void* buffer, size_t bytes) override {
RLock lock(mutex); RLock lock(mutex);
if (s == InvalidSocket) { if (s == InvalidSocket) {
return 0; return 0;
}
auto len = recv(s, reinterpret_cast<char*>(buffer), static_cast<int>(bytes), 0);
return (len < 0) ? 0 : len;
} }
auto len =
recv(s, reinterpret_cast<char*>(buffer), static_cast<int>(bytes), 0);
return (len < 0) ? 0 : len;
}
bool Write(const void* buffer, size_t bytes) override { bool Write(const void* buffer, size_t bytes) override {
RLock lock(mutex); RLock lock(mutex);
if (s == InvalidSocket) { if (s == InvalidSocket) {
return false; return false;
}
if (bytes == 0) {
return true;
}
return ::send(s, reinterpret_cast<const char*>(buffer), static_cast<int>(bytes), 0) > 0;
} }
if (bytes == 0) {
return true; std::shared_ptr<Socket> Accept() override {
std::shared_ptr<Impl> out;
lock([&](SOCKET socket, const addrinfo*) {
if (socket != InvalidSocket) {
init();
out = std::make_shared<Impl>(::accept(socket, 0, 0));
out->setOptions();
}
});
return out;
} }
return ::send(s, reinterpret_cast<const char*>(buffer),
static_cast<int>(bytes), 0) > 0;
}
std::shared_ptr<Socket> Accept() override { private:
std::shared_ptr<Impl> out; addrinfo* const info;
lock([&](SOCKET socket, const addrinfo*) { SOCKET s = InvalidSocket;
if (socket != InvalidSocket) { RWMutex mutex;
init();
out = std::make_shared<Impl>(::accept(socket, 0, 0));
out->setOptions();
}
});
return out;
}
private:
addrinfo* const info;
SOCKET s = InvalidSocket;
RWMutex mutex;
}; };
} // anonymous namespace } // anonymous namespace
std::shared_ptr<Socket> Socket::Listen(const char* address, const char* port) { std::shared_ptr<Socket> Socket::Listen(const char* address, const char* port) {
auto impl = Impl::create(address, port); auto impl = Impl::create(address, port);
if (!impl) { if (!impl) {
return nullptr; return nullptr;
}
impl->lock([&](SOCKET socket, const addrinfo* info) {
if (bind(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) != 0) {
impl.reset();
return;
} }
impl->lock([&](SOCKET socket, const addrinfo* info) {
if (bind(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) != 0) {
impl.reset();
return;
}
if (listen(socket, 0) != 0) { if (listen(socket, 0) != 0) {
impl.reset(); impl.reset();
return; return;
} }
}); });
return impl; return impl;
} }
std::shared_ptr<Socket> Socket::Connect(const char* address, std::shared_ptr<Socket> Socket::Connect(const char* address,
const char* port, const char* port,
uint32_t timeoutMillis) { uint32_t timeoutMillis) {
auto impl = Impl::create(address, port); auto impl = Impl::create(address, port);
if (!impl) { if (!impl) {
return nullptr; return nullptr;
}
std::shared_ptr<Socket> out;
impl->lock([&](SOCKET socket, const addrinfo* info) {
if (socket == InvalidSocket) {
return;
} }
if (timeoutMillis == 0) { std::shared_ptr<Socket> out;
if (::connect(socket, info->ai_addr, impl->lock([&](SOCKET socket, const addrinfo* info) {
static_cast<int>(info->ai_addrlen)) == 0) { if (socket == InvalidSocket) {
out = impl; return;
} }
return;
if (timeoutMillis == 0) {
if (::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) == 0) {
out = impl;
}
return;
}
if (!setBlocking(socket, false)) {
return;
}
auto res = ::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen));
if (res == 0) {
if (setBlocking(socket, true)) {
out = impl;
}
} else {
const auto microseconds = timeoutMillis * 1000;
fd_set fdset;
FD_ZERO(&fdset);
FD_SET(socket, &fdset);
timeval tv;
tv.tv_sec = microseconds / 1000000;
tv.tv_usec = microseconds - static_cast<uint32_t>(tv.tv_sec * 1000000);
res = select(static_cast<int>(socket + 1), nullptr, &fdset, nullptr, &tv);
if (res > 0 && !errored(socket) && setBlocking(socket, true)) {
out = impl;
}
}
});
if (!out) {
return nullptr;
} }
if (!setBlocking(socket, false)) { return out->IsOpen() ? out : nullptr;
return;
}
auto res =
::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen));
if (res == 0) {
if (setBlocking(socket, true)) {
out = impl;
}
} else {
const auto microseconds = timeoutMillis * 1000;
fd_set fdset;
FD_ZERO(&fdset);
FD_SET(socket, &fdset);
timeval tv;
tv.tv_sec = microseconds / 1000000;
tv.tv_usec = microseconds - static_cast<uint32_t>(tv.tv_sec * 1000000);
res = select(static_cast<int>(socket + 1), nullptr, &fdset, nullptr, &tv);
if (res > 0 && !errored(socket) && setBlocking(socket, true)) {
out = impl;
}
}
});
if (!out) {
return nullptr;
}
return out->IsOpen() ? out : nullptr;
} }

View File

@ -20,52 +20,52 @@
/// Socket provides an OS abstraction to a TCP socket. /// Socket provides an OS abstraction to a TCP socket.
class Socket { class Socket {
public: public:
/// Connects to the given TCP address and port. /// Connects to the given TCP address and port.
/// @param address the target socket address /// @param address the target socket address
/// @param port the target socket port /// @param port the target socket port
/// @param timeoutMillis the timeout for the connection attempt. /// @param timeoutMillis the timeout for the connection attempt.
/// If timeoutMillis is non-zero and no connection was made before /// If timeoutMillis is non-zero and no connection was made before
/// timeoutMillis milliseconds, then nullptr is returned. /// timeoutMillis milliseconds, then nullptr is returned.
/// @returns the connected Socket, or nullptr on failure /// @returns the connected Socket, or nullptr on failure
static std::shared_ptr<Socket> Connect(const char* address, static std::shared_ptr<Socket> Connect(const char* address,
const char* port, const char* port,
uint32_t timeoutMillis); uint32_t timeoutMillis);
/// Begins listening for connections on the given TCP address and port. /// Begins listening for connections on the given TCP address and port.
/// Call Accept() on the returned Socket to block and wait for a connection. /// Call Accept() on the returned Socket to block and wait for a connection.
/// @param address the socket address to listen on. Use "localhost" for /// @param address the socket address to listen on. Use "localhost" for
/// connections from only this machine, or an empty string to allow /// connections from only this machine, or an empty string to allow
/// connections from any incoming address. /// connections from any incoming address.
/// @param port the socket port to listen on /// @param port the socket port to listen on
/// @returns the Socket that listens for connections /// @returns the Socket that listens for connections
static std::shared_ptr<Socket> Listen(const char* address, const char* port); static std::shared_ptr<Socket> Listen(const char* address, const char* port);
/// Attempts to read at most `n` bytes into buffer, returning the actual /// Attempts to read at most `n` bytes into buffer, returning the actual
/// number of bytes read. /// number of bytes read.
/// read() will block until the socket is closed or at least one byte is read. /// read() will block until the socket is closed or at least one byte is read.
/// @param buffer the output buffer. Must be at least `n` bytes in size. /// @param buffer the output buffer. Must be at least `n` bytes in size.
/// @param n the maximum number of bytes to read /// @param n the maximum number of bytes to read
/// @return the number of bytes read, or 0 if the socket was closed /// @return the number of bytes read, or 0 if the socket was closed
virtual size_t Read(void* buffer, size_t n) = 0; virtual size_t Read(void* buffer, size_t n) = 0;
/// Writes `n` bytes from buffer into the socket. /// Writes `n` bytes from buffer into the socket.
/// @param buffer the source data buffer. Must be at least `n` bytes in size. /// @param buffer the source data buffer. Must be at least `n` bytes in size.
/// @param n the number of bytes to read from `buffer` /// @param n the number of bytes to read from `buffer`
/// @returns true on success, or false if there was an error or the socket was /// @returns true on success, or false if there was an error or the socket was
/// closed. /// closed.
virtual bool Write(const void* buffer, size_t n) = 0; virtual bool Write(const void* buffer, size_t n) = 0;
/// @returns true if the socket has not been closed. /// @returns true if the socket has not been closed.
virtual bool IsOpen() = 0; virtual bool IsOpen() = 0;
/// Closes the socket. /// Closes the socket.
virtual void Close() = 0; virtual void Close() = 0;
/// Blocks for a connection to be made to the listening port, or for the /// Blocks for a connection to be made to the listening port, or for the
/// Socket to be closed. /// Socket to be closed.
/// @returns a pointer to the next established incoming connection /// @returns a pointer to the next established incoming connection
virtual std::shared_ptr<Socket> Accept() = 0; virtual std::shared_ptr<Socket> Accept() = 0;
}; };
#endif // TOOLS_SRC_CMD_REMOTE_COMPILE_SOCKET_H_ #endif // TOOLS_SRC_CMD_REMOTE_COMPILE_SOCKET_H_