tools/remote-compile: Handle socket disconnects

If the socket disconnected mid-communication, the server could spin, waiting for new data.

Actually handle recv() errors, preventing the server spinning itself to death.

Also fix code style to be more tint-like (snake_case variables, PascalCase functions)

Change-Id: I9fcbfde303a8624e7e1ff87abd33581589f4da42
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105142
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-10-10 15:16:23 +00:00 committed by Dawn LUCI CQ
parent a4f064e3a7
commit 5dac4f9644
4 changed files with 95 additions and 85 deletions

View File

@ -145,6 +145,8 @@ struct Stream {
} }
size -= n; size -= n;
buf += n; buf += n;
} else {
error = "Socket::Read() failed";
} }
} }
return error.empty(); return error.empty();
@ -238,13 +240,15 @@ std::enable_if_t<std::is_base_of<Message, MESSAGE>::value, Stream>& operator>>(S
MESSAGE& m) { MESSAGE& m) {
Message::Type ty; Message::Type ty;
s >> ty; s >> ty;
if (ty == m.type) { if (s.error.empty()) {
m.Serialize([&s](auto& value) { s >> value; }); if (ty == m.type) {
} else { m.Serialize([&s](auto& value) { s >> value; });
std::stringstream ss; } else {
ss << "expected message type " << static_cast<int>(m.type) << ", got " std::stringstream ss;
<< static_cast<int>(ty); ss << "expected message type " << static_cast<int>(m.type) << ", got "
s.error = ss.str(); << static_cast<int>(ty);
s.error = ss.str();
}
} }
return s; return s;
} }
@ -291,8 +295,7 @@ int main(int argc, char* argv[]) {
continue; continue;
} }
// xcrun flags are ignored so this executable can be used as a replacement // xcrun flags are ignored so this executable can be used as a replacement for xcrun.
// for xcrun.
if ((arg == "-x" || arg == "-sdk") && (i < argc - 1)) { if ((arg == "-x" || arg == "-sdk") && (i < argc - 1)) {
i++; i++;
continue; continue;
@ -357,7 +360,7 @@ bool RunServer(std::string port) {
ConnectionRequest req; ConnectionRequest req;
stream >> req; stream >> req;
if (!stream.error.empty()) { if (!stream.error.empty()) {
printf("%s\n", stream.error.c_str()); DEBUG("%s", stream.error.c_str());
return; return;
} }
ConnectionResponse resp; ConnectionResponse resp;
@ -374,7 +377,7 @@ bool RunServer(std::string port) {
CompileRequest req; CompileRequest req;
stream >> req; stream >> req;
if (!stream.error.empty()) { if (!stream.error.empty()) {
printf("%s\n", stream.error.c_str()); DEBUG("%s\n", stream.error.c_str());
return; return;
} }
#ifdef TINT_ENABLE_MSL_COMPILATION_USING_METAL_API #ifdef TINT_ENABLE_MSL_COMPILATION_USING_METAL_API

View File

@ -29,56 +29,56 @@ class RWMutex {
public: public:
inline RWMutex() = default; inline RWMutex() = default;
/// lockReader() locks the mutex for reading. /// LockReader() locks the mutex for reading.
/// Multiple read locks can be held while there are no writer locks. /// Multiple read locks can be held while there are no writer locks.
inline void lockReader(); inline void LockReader();
/// unlockReader() unlocks the mutex for reading. /// UnlockReader() unlocks the mutex for reading.
inline void unlockReader(); inline void UnlockReader();
/// lockWriter() locks the mutex for writing. /// LockWriter() locks the mutex for writing.
/// If the lock is already locked for reading or writing, lockWriter blocks /// If the lock is already locked for reading or writing, LockWriter blocks
/// until the lock is available. /// until the lock is available.
inline void lockWriter(); inline void LockWriter();
/// unlockWriter() unlocks the mutex for writing. /// UnlockWriter() unlocks the mutex for writing.
inline void unlockWriter(); inline void UnlockWriter();
private: private:
RWMutex(const RWMutex&) = delete; RWMutex(const RWMutex&) = delete;
RWMutex& operator=(const RWMutex&) = delete; RWMutex& operator=(const RWMutex&) = delete;
int readLocks = 0; int read_locks = 0;
int pendingWriteLocks = 0; int pending_write_locks = 0;
std::mutex mutex; std::mutex mutex;
std::condition_variable cv; std::condition_variable cv;
}; };
void RWMutex::lockReader() { void RWMutex::LockReader() {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
readLocks++; read_locks++;
} }
void RWMutex::unlockReader() { void RWMutex::UnlockReader() {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
readLocks--; read_locks--;
if (readLocks == 0 && pendingWriteLocks > 0) { if (read_locks == 0 && pending_write_locks > 0) {
cv.notify_one(); cv.notify_one();
} }
} }
void RWMutex::lockWriter() { void RWMutex::LockWriter() {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
if (readLocks > 0) { if (read_locks > 0) {
pendingWriteLocks++; pending_write_locks++;
cv.wait(lock, [&] { return readLocks == 0; }); cv.wait(lock, [&] { return read_locks == 0; });
pendingWriteLocks--; pending_write_locks--;
} }
lock.release(); // Keep lock held lock.release(); // Keep lock held
} }
void RWMutex::unlockWriter() { void RWMutex::UnlockWriter() {
if (pendingWriteLocks > 0) { if (pending_write_locks > 0) {
cv.notify_one(); cv.notify_one();
} }
mutex.unlock(); mutex.unlock();
@ -115,12 +115,12 @@ class RLock {
}; };
RLock::RLock(RWMutex& mutex) : m(&mutex) { RLock::RLock(RWMutex& mutex) : m(&mutex) {
m->lockReader(); m->LockReader();
} }
RLock::~RLock() { RLock::~RLock() {
if (m != nullptr) { if (m != nullptr) {
m->unlockReader(); m->UnlockReader();
} }
} }
@ -167,12 +167,12 @@ class WLock {
}; };
WLock::WLock(RWMutex& mutex) : m(&mutex) { WLock::WLock(RWMutex& mutex) : m(&mutex) {
m->lockWriter(); m->LockWriter();
} }
WLock::~WLock() { WLock::~WLock() {
if (m != nullptr) { if (m != nullptr) {
m->unlockWriter(); m->UnlockWriter();
} }
} }

View File

@ -32,7 +32,7 @@
#if defined(_WIN32) #if defined(_WIN32)
#include <atomic> #include <atomic>
namespace { namespace {
std::atomic<int> wsaInitCount = {0}; std::atomic<int> wsa_init_count = {0};
} // anonymous namespace } // anonymous namespace
#else #else
#include <fcntl.h> #include <fcntl.h>
@ -43,24 +43,24 @@ using SOCKET = int;
namespace { namespace {
constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1); constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1);
void init() { void Init() {
#if defined(_WIN32) #if defined(_WIN32)
if (wsaInitCount++ == 0) { if (wsa_init_count++ == 0) {
WSADATA winsockData; WSADATA winsock_data;
(void)WSAStartup(MAKEWORD(2, 2), &winsockData); (void)WSAStartup(MAKEWORD(2, 2), &winsock_data);
} }
#endif #endif
} }
void term() { void Term() {
#if defined(_WIN32) #if defined(_WIN32)
if (--wsaInitCount == 0) { if (--wsa_init_count == 0) {
WSACleanup(); WSACleanup();
} }
#endif #endif
} }
bool setBlocking(SOCKET s, bool blocking) { bool SetBlocking(SOCKET s, bool blocking) {
#if defined(_WIN32) #if defined(_WIN32)
u_long mode = blocking ? 0 : 1; u_long mode = blocking ? 0 : 1;
return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR; return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR;
@ -74,7 +74,7 @@ bool setBlocking(SOCKET s, bool blocking) {
#endif #endif
} }
bool errored(SOCKET s) { bool Errored(SOCKET s) {
if (s == InvalidSocket) { if (s == InvalidSocket) {
return true; return true;
} }
@ -87,7 +87,7 @@ bool errored(SOCKET s) {
class Impl : public Socket { class Impl : public Socket {
public: public:
static std::shared_ptr<Impl> create(const char* address, const char* port) { static std::shared_ptr<Impl> create(const char* address, const char* port) {
init(); Init();
addrinfo hints = {}; addrinfo hints = {};
hints.ai_family = AF_INET; hints.ai_family = AF_INET;
@ -106,12 +106,12 @@ class Impl : public Socket {
if (info) { if (info) {
auto socket = ::socket(info->ai_family, info->ai_socktype, info->ai_protocol); auto socket = ::socket(info->ai_family, info->ai_socktype, info->ai_protocol);
auto out = std::make_shared<Impl>(info, socket); auto out = std::make_shared<Impl>(info, socket);
out->setOptions(); out->SetOptions();
return out; return out;
} }
freeaddrinfo(info); freeaddrinfo(info);
term(); Term();
return nullptr; return nullptr;
} }
@ -121,16 +121,16 @@ class Impl : public Socket {
~Impl() { ~Impl() {
freeaddrinfo(info); freeaddrinfo(info);
Close(); Close();
term(); Term();
} }
template <typename FUNCTION> template <typename FUNCTION>
void lock(FUNCTION&& f) { void Lock(FUNCTION&& f) {
RLock l(mutex); RLock l(mutex);
f(s, info); f(s, info);
} }
void setOptions() { void SetOptions() {
RLock l(mutex); RLock l(mutex);
if (s == InvalidSocket) { if (s == InvalidSocket) {
return; return;
@ -157,7 +157,7 @@ class Impl : public Socket {
bool IsOpen() override { bool IsOpen() override {
{ {
RLock l(mutex); RLock l(mutex);
if ((s != InvalidSocket) && !errored(s)) { if ((s != InvalidSocket) && !Errored(s)) {
return true; return true;
} }
} }
@ -188,12 +188,20 @@ class Impl : public Socket {
} }
size_t Read(void* buffer, size_t bytes) override { size_t Read(void* buffer, size_t bytes) override {
RLock lock(mutex); {
if (s == InvalidSocket) { RLock lock(mutex);
return 0; if (s == InvalidSocket) {
return 0;
}
size_t len = recv(s, reinterpret_cast<char*>(buffer), static_cast<int>(bytes), 0);
if (len > 0) {
return len;
}
} }
auto len = recv(s, reinterpret_cast<char*>(buffer), static_cast<int>(bytes), 0); // Socket closed or errored
return (len < 0) ? 0 : len; WLock lock(mutex);
s = InvalidSocket;
return 0;
} }
bool Write(const void* buffer, size_t bytes) override { bool Write(const void* buffer, size_t bytes) override {
@ -209,11 +217,13 @@ class Impl : public Socket {
std::shared_ptr<Socket> Accept() override { std::shared_ptr<Socket> Accept() override {
std::shared_ptr<Impl> out; std::shared_ptr<Impl> out;
lock([&](SOCKET socket, const addrinfo*) { Lock([&](SOCKET socket, const addrinfo*) {
if (socket != InvalidSocket) { if (socket != InvalidSocket) {
init(); Init();
out = std::make_shared<Impl>(::accept(socket, 0, 0)); if (auto s = ::accept(socket, 0, 0); s >= 0) {
out->setOptions(); out = std::make_shared<Impl>(s);
out->SetOptions();
}
} }
}); });
return out; return out;
@ -232,7 +242,7 @@ std::shared_ptr<Socket> Socket::Listen(const char* address, const char* port) {
if (!impl) { if (!impl) {
return nullptr; return nullptr;
} }
impl->lock([&](SOCKET socket, const addrinfo* info) { impl->Lock([&](SOCKET socket, const addrinfo* info) {
if (bind(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) != 0) { if (bind(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) != 0) {
impl.reset(); impl.reset();
return; return;
@ -248,46 +258,46 @@ std::shared_ptr<Socket> Socket::Listen(const char* address, const char* port) {
std::shared_ptr<Socket> Socket::Connect(const char* address, std::shared_ptr<Socket> Socket::Connect(const char* address,
const char* port, const char* port,
uint32_t timeoutMillis) { uint32_t timeout_ms) {
auto impl = Impl::create(address, port); auto impl = Impl::create(address, port);
if (!impl) { if (!impl) {
return nullptr; return nullptr;
} }
std::shared_ptr<Socket> out; std::shared_ptr<Socket> out;
impl->lock([&](SOCKET socket, const addrinfo* info) { impl->Lock([&](SOCKET socket, const addrinfo* info) {
if (socket == InvalidSocket) { if (socket == InvalidSocket) {
return; return;
} }
if (timeoutMillis == 0) { if (timeout_ms == 0) {
if (::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) == 0) { if (::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) == 0) {
out = impl; out = impl;
} }
return; return;
} }
if (!setBlocking(socket, false)) { if (!SetBlocking(socket, false)) {
return; return;
} }
auto res = ::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)); auto res = ::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen));
if (res == 0) { if (res == 0) {
if (setBlocking(socket, true)) { if (SetBlocking(socket, true)) {
out = impl; out = impl;
} }
} else { } else {
const auto microseconds = timeoutMillis * 1000; const auto timeout_us = timeout_ms * 1000;
fd_set fdset; fd_set fdset;
FD_ZERO(&fdset); FD_ZERO(&fdset);
FD_SET(socket, &fdset); FD_SET(socket, &fdset);
timeval tv; timeval tv;
tv.tv_sec = microseconds / 1000000; tv.tv_sec = timeout_us / 1000000;
tv.tv_usec = microseconds - static_cast<uint32_t>(tv.tv_sec * 1000000); tv.tv_usec = timeout_us - static_cast<uint32_t>(tv.tv_sec * 1000000);
res = select(static_cast<int>(socket + 1), nullptr, &fdset, nullptr, &tv); res = select(static_cast<int>(socket + 1), nullptr, &fdset, nullptr, &tv);
if (res > 0 && !errored(socket) && setBlocking(socket, true)) { if (res > 0 && !Errored(socket) && SetBlocking(socket, true)) {
out = impl; out = impl;
} }
} }

View File

@ -24,29 +24,27 @@ class Socket {
/// Connects to the given TCP address and port. /// Connects to the given TCP address and port.
/// @param address the target socket address /// @param address the target socket address
/// @param port the target socket port /// @param port the target socket port
/// @param timeoutMillis the timeout for the connection attempt. /// @param timeout_ms the timeout for the connection attempt.
/// If timeoutMillis is non-zero and no connection was made before /// If timeout_ms is non-zero and no connection was made before timeout_ms milliseconds,
/// timeoutMillis milliseconds, then nullptr is returned. /// then nullptr is returned.
/// @returns the connected Socket, or nullptr on failure /// @returns the connected Socket, or nullptr on failure
static std::shared_ptr<Socket> Connect(const char* address, static std::shared_ptr<Socket> Connect(const char* address,
const char* port, const char* port,
uint32_t timeoutMillis); uint32_t timeout_ms);
/// Begins listening for connections on the given TCP address and port. /// Begins listening for connections on the given TCP address and port.
/// Call Accept() on the returned Socket to block and wait for a connection. /// Call Accept() on the returned Socket to block and wait for a connection.
/// @param address the socket address to listen on. Use "localhost" for /// @param address the socket address to listen on. Use "localhost" for connections from only
/// connections from only this machine, or an empty string to allow /// this machine, or an empty string to allow connections from any incoming address.
/// connections from any incoming address.
/// @param port the socket port to listen on /// @param port the socket port to listen on
/// @returns the Socket that listens for connections /// @returns the Socket that listens for connections
static std::shared_ptr<Socket> Listen(const char* address, const char* port); static std::shared_ptr<Socket> Listen(const char* address, const char* port);
/// Attempts to read at most `n` bytes into buffer, returning the actual /// Attempts to read at most `n` bytes into buffer, returning the actual number of bytes read.
/// number of bytes read.
/// read() will block until the socket is closed or at least one byte is read. /// read() will block until the socket is closed or at least one byte is read.
/// @param buffer the output buffer. Must be at least `n` bytes in size. /// @param buffer the output buffer. Must be at least `n` bytes in size.
/// @param n the maximum number of bytes to read /// @param n the maximum number of bytes to read
/// @return the number of bytes read, or 0 if the socket was closed /// @return the number of bytes read, or 0 if the socket was closed or errored
virtual size_t Read(void* buffer, size_t n) = 0; virtual size_t Read(void* buffer, size_t n) = 0;
/// Writes `n` bytes from buffer into the socket. /// Writes `n` bytes from buffer into the socket.
@ -62,8 +60,7 @@ class Socket {
/// Closes the socket. /// Closes the socket.
virtual void Close() = 0; virtual void Close() = 0;
/// Blocks for a connection to be made to the listening port, or for the /// Blocks for a connection to be made to the listening port, or for the Socket to be closed.
/// Socket to be closed.
/// @returns a pointer to the next established incoming connection /// @returns a pointer to the next established incoming connection
virtual std::shared_ptr<Socket> Accept() = 0; virtual std::shared_ptr<Socket> Accept() = 0;
}; };