tools/remote-compile: Handle socket disconnects

If the socket disconnected mid-communication, the server could spin, waiting for new data.

Actually handle recv() errors, preventing the server spinning itself to death.

Also fix code style to be more tint-like (snake_case variables, PascalCase functions)

Change-Id: I9fcbfde303a8624e7e1ff87abd33581589f4da42
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105142
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-10-10 15:16:23 +00:00 committed by Dawn LUCI CQ
parent a4f064e3a7
commit 5dac4f9644
4 changed files with 95 additions and 85 deletions

View File

@ -145,6 +145,8 @@ struct Stream {
}
size -= n;
buf += n;
} else {
error = "Socket::Read() failed";
}
}
return error.empty();
@ -238,13 +240,15 @@ std::enable_if_t<std::is_base_of<Message, MESSAGE>::value, Stream>& operator>>(S
MESSAGE& m) {
Message::Type ty;
s >> ty;
if (ty == m.type) {
m.Serialize([&s](auto& value) { s >> value; });
} else {
std::stringstream ss;
ss << "expected message type " << static_cast<int>(m.type) << ", got "
<< static_cast<int>(ty);
s.error = ss.str();
if (s.error.empty()) {
if (ty == m.type) {
m.Serialize([&s](auto& value) { s >> value; });
} else {
std::stringstream ss;
ss << "expected message type " << static_cast<int>(m.type) << ", got "
<< static_cast<int>(ty);
s.error = ss.str();
}
}
return s;
}
@ -291,8 +295,7 @@ int main(int argc, char* argv[]) {
continue;
}
// xcrun flags are ignored so this executable can be used as a replacement
// for xcrun.
// xcrun flags are ignored so this executable can be used as a replacement for xcrun.
if ((arg == "-x" || arg == "-sdk") && (i < argc - 1)) {
i++;
continue;
@ -357,7 +360,7 @@ bool RunServer(std::string port) {
ConnectionRequest req;
stream >> req;
if (!stream.error.empty()) {
printf("%s\n", stream.error.c_str());
DEBUG("%s", stream.error.c_str());
return;
}
ConnectionResponse resp;
@ -374,7 +377,7 @@ bool RunServer(std::string port) {
CompileRequest req;
stream >> req;
if (!stream.error.empty()) {
printf("%s\n", stream.error.c_str());
DEBUG("%s\n", stream.error.c_str());
return;
}
#ifdef TINT_ENABLE_MSL_COMPILATION_USING_METAL_API

View File

@ -29,56 +29,56 @@ class RWMutex {
public:
inline RWMutex() = default;
/// lockReader() locks the mutex for reading.
/// LockReader() locks the mutex for reading.
/// Multiple read locks can be held while there are no writer locks.
inline void lockReader();
inline void LockReader();
/// unlockReader() unlocks the mutex for reading.
inline void unlockReader();
/// UnlockReader() unlocks the mutex for reading.
inline void UnlockReader();
/// lockWriter() locks the mutex for writing.
/// If the lock is already locked for reading or writing, lockWriter blocks
/// LockWriter() locks the mutex for writing.
/// If the lock is already locked for reading or writing, LockWriter blocks
/// until the lock is available.
inline void lockWriter();
inline void LockWriter();
/// unlockWriter() unlocks the mutex for writing.
inline void unlockWriter();
/// UnlockWriter() unlocks the mutex for writing.
inline void UnlockWriter();
private:
RWMutex(const RWMutex&) = delete;
RWMutex& operator=(const RWMutex&) = delete;
int readLocks = 0;
int pendingWriteLocks = 0;
int read_locks = 0;
int pending_write_locks = 0;
std::mutex mutex;
std::condition_variable cv;
};
void RWMutex::lockReader() {
void RWMutex::LockReader() {
std::unique_lock<std::mutex> lock(mutex);
readLocks++;
read_locks++;
}
void RWMutex::unlockReader() {
void RWMutex::UnlockReader() {
std::unique_lock<std::mutex> lock(mutex);
readLocks--;
if (readLocks == 0 && pendingWriteLocks > 0) {
read_locks--;
if (read_locks == 0 && pending_write_locks > 0) {
cv.notify_one();
}
}
void RWMutex::lockWriter() {
void RWMutex::LockWriter() {
std::unique_lock<std::mutex> lock(mutex);
if (readLocks > 0) {
pendingWriteLocks++;
cv.wait(lock, [&] { return readLocks == 0; });
pendingWriteLocks--;
if (read_locks > 0) {
pending_write_locks++;
cv.wait(lock, [&] { return read_locks == 0; });
pending_write_locks--;
}
lock.release(); // Keep lock held
}
void RWMutex::unlockWriter() {
if (pendingWriteLocks > 0) {
void RWMutex::UnlockWriter() {
if (pending_write_locks > 0) {
cv.notify_one();
}
mutex.unlock();
@ -115,12 +115,12 @@ class RLock {
};
RLock::RLock(RWMutex& mutex) : m(&mutex) {
m->lockReader();
m->LockReader();
}
RLock::~RLock() {
if (m != nullptr) {
m->unlockReader();
m->UnlockReader();
}
}
@ -167,12 +167,12 @@ class WLock {
};
WLock::WLock(RWMutex& mutex) : m(&mutex) {
m->lockWriter();
m->LockWriter();
}
WLock::~WLock() {
if (m != nullptr) {
m->unlockWriter();
m->UnlockWriter();
}
}

View File

@ -32,7 +32,7 @@
#if defined(_WIN32)
#include <atomic>
namespace {
std::atomic<int> wsaInitCount = {0};
std::atomic<int> wsa_init_count = {0};
} // anonymous namespace
#else
#include <fcntl.h>
@ -43,24 +43,24 @@ using SOCKET = int;
namespace {
constexpr SOCKET InvalidSocket = static_cast<SOCKET>(-1);
void init() {
void Init() {
#if defined(_WIN32)
if (wsaInitCount++ == 0) {
WSADATA winsockData;
(void)WSAStartup(MAKEWORD(2, 2), &winsockData);
if (wsa_init_count++ == 0) {
WSADATA winsock_data;
(void)WSAStartup(MAKEWORD(2, 2), &winsock_data);
}
#endif
}
void term() {
void Term() {
#if defined(_WIN32)
if (--wsaInitCount == 0) {
if (--wsa_init_count == 0) {
WSACleanup();
}
#endif
}
bool setBlocking(SOCKET s, bool blocking) {
bool SetBlocking(SOCKET s, bool blocking) {
#if defined(_WIN32)
u_long mode = blocking ? 0 : 1;
return ioctlsocket(s, FIONBIO, &mode) == NO_ERROR;
@ -74,7 +74,7 @@ bool setBlocking(SOCKET s, bool blocking) {
#endif
}
bool errored(SOCKET s) {
bool Errored(SOCKET s) {
if (s == InvalidSocket) {
return true;
}
@ -87,7 +87,7 @@ bool errored(SOCKET s) {
class Impl : public Socket {
public:
static std::shared_ptr<Impl> create(const char* address, const char* port) {
init();
Init();
addrinfo hints = {};
hints.ai_family = AF_INET;
@ -106,12 +106,12 @@ class Impl : public Socket {
if (info) {
auto socket = ::socket(info->ai_family, info->ai_socktype, info->ai_protocol);
auto out = std::make_shared<Impl>(info, socket);
out->setOptions();
out->SetOptions();
return out;
}
freeaddrinfo(info);
term();
Term();
return nullptr;
}
@ -121,16 +121,16 @@ class Impl : public Socket {
~Impl() {
freeaddrinfo(info);
Close();
term();
Term();
}
template <typename FUNCTION>
void lock(FUNCTION&& f) {
void Lock(FUNCTION&& f) {
RLock l(mutex);
f(s, info);
}
void setOptions() {
void SetOptions() {
RLock l(mutex);
if (s == InvalidSocket) {
return;
@ -157,7 +157,7 @@ class Impl : public Socket {
bool IsOpen() override {
{
RLock l(mutex);
if ((s != InvalidSocket) && !errored(s)) {
if ((s != InvalidSocket) && !Errored(s)) {
return true;
}
}
@ -188,12 +188,20 @@ class Impl : public Socket {
}
size_t Read(void* buffer, size_t bytes) override {
RLock lock(mutex);
if (s == InvalidSocket) {
return 0;
{
RLock lock(mutex);
if (s == InvalidSocket) {
return 0;
}
size_t len = recv(s, reinterpret_cast<char*>(buffer), static_cast<int>(bytes), 0);
if (len > 0) {
return len;
}
}
auto len = recv(s, reinterpret_cast<char*>(buffer), static_cast<int>(bytes), 0);
return (len < 0) ? 0 : len;
// Socket closed or errored
WLock lock(mutex);
s = InvalidSocket;
return 0;
}
bool Write(const void* buffer, size_t bytes) override {
@ -209,11 +217,13 @@ class Impl : public Socket {
std::shared_ptr<Socket> Accept() override {
std::shared_ptr<Impl> out;
lock([&](SOCKET socket, const addrinfo*) {
Lock([&](SOCKET socket, const addrinfo*) {
if (socket != InvalidSocket) {
init();
out = std::make_shared<Impl>(::accept(socket, 0, 0));
out->setOptions();
Init();
if (auto s = ::accept(socket, 0, 0); s >= 0) {
out = std::make_shared<Impl>(s);
out->SetOptions();
}
}
});
return out;
@ -232,7 +242,7 @@ std::shared_ptr<Socket> Socket::Listen(const char* address, const char* port) {
if (!impl) {
return nullptr;
}
impl->lock([&](SOCKET socket, const addrinfo* info) {
impl->Lock([&](SOCKET socket, const addrinfo* info) {
if (bind(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) != 0) {
impl.reset();
return;
@ -248,46 +258,46 @@ std::shared_ptr<Socket> Socket::Listen(const char* address, const char* port) {
std::shared_ptr<Socket> Socket::Connect(const char* address,
const char* port,
uint32_t timeoutMillis) {
uint32_t timeout_ms) {
auto impl = Impl::create(address, port);
if (!impl) {
return nullptr;
}
std::shared_ptr<Socket> out;
impl->lock([&](SOCKET socket, const addrinfo* info) {
impl->Lock([&](SOCKET socket, const addrinfo* info) {
if (socket == InvalidSocket) {
return;
}
if (timeoutMillis == 0) {
if (timeout_ms == 0) {
if (::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen)) == 0) {
out = impl;
}
return;
}
if (!setBlocking(socket, false)) {
if (!SetBlocking(socket, false)) {
return;
}
auto res = ::connect(socket, info->ai_addr, static_cast<int>(info->ai_addrlen));
if (res == 0) {
if (setBlocking(socket, true)) {
if (SetBlocking(socket, true)) {
out = impl;
}
} else {
const auto microseconds = timeoutMillis * 1000;
const auto timeout_us = timeout_ms * 1000;
fd_set fdset;
FD_ZERO(&fdset);
FD_SET(socket, &fdset);
timeval tv;
tv.tv_sec = microseconds / 1000000;
tv.tv_usec = microseconds - static_cast<uint32_t>(tv.tv_sec * 1000000);
tv.tv_sec = timeout_us / 1000000;
tv.tv_usec = timeout_us - static_cast<uint32_t>(tv.tv_sec * 1000000);
res = select(static_cast<int>(socket + 1), nullptr, &fdset, nullptr, &tv);
if (res > 0 && !errored(socket) && setBlocking(socket, true)) {
if (res > 0 && !Errored(socket) && SetBlocking(socket, true)) {
out = impl;
}
}

View File

@ -24,29 +24,27 @@ class Socket {
/// Connects to the given TCP address and port.
/// @param address the target socket address
/// @param port the target socket port
/// @param timeoutMillis the timeout for the connection attempt.
/// If timeoutMillis is non-zero and no connection was made before
/// timeoutMillis milliseconds, then nullptr is returned.
/// @param timeout_ms the timeout for the connection attempt.
/// If timeout_ms is non-zero and no connection was made before timeout_ms milliseconds,
/// then nullptr is returned.
/// @returns the connected Socket, or nullptr on failure
static std::shared_ptr<Socket> Connect(const char* address,
const char* port,
uint32_t timeoutMillis);
uint32_t timeout_ms);
/// Begins listening for connections on the given TCP address and port.
/// Call Accept() on the returned Socket to block and wait for a connection.
/// @param address the socket address to listen on. Use "localhost" for
/// connections from only this machine, or an empty string to allow
/// connections from any incoming address.
/// @param address the socket address to listen on. Use "localhost" for connections from only
/// this machine, or an empty string to allow connections from any incoming address.
/// @param port the socket port to listen on
/// @returns the Socket that listens for connections
static std::shared_ptr<Socket> Listen(const char* address, const char* port);
/// Attempts to read at most `n` bytes into buffer, returning the actual
/// number of bytes read.
/// Attempts to read at most `n` bytes into buffer, returning the actual number of bytes read.
/// read() will block until the socket is closed or at least one byte is read.
/// @param buffer the output buffer. Must be at least `n` bytes in size.
/// @param n the maximum number of bytes to read
/// @return the number of bytes read, or 0 if the socket was closed
/// @return the number of bytes read, or 0 if the socket was closed or errored
virtual size_t Read(void* buffer, size_t n) = 0;
/// Writes `n` bytes from buffer into the socket.
@ -62,8 +60,7 @@ class Socket {
/// Closes the socket.
virtual void Close() = 0;
/// Blocks for a connection to be made to the listening port, or for the
/// Socket to be closed.
/// Blocks for a connection to be made to the listening port, or for the Socket to be closed.
/// @returns a pointer to the next established incoming connection
virtual std::shared_ptr<Socket> Accept() = 0;
};