From ee19094380af6f9799de129c8cacdc19c238d17d Mon Sep 17 00:00:00 2001 From: Chen Qin Date: Thu, 18 Oct 2018 21:02:33 -0700 Subject: [PATCH 1/5] fix error in dmlc#57, clean up comments and naming --- .travis.yml | 2 +- src/allreduce_base.cc | 60 +-- src/allreduce_robust-inl.h | 18 +- src/allreduce_robust.cc | 44 +- src/socket.h | 866 ++++++++++++++++++------------------- 5 files changed, 493 insertions(+), 497 deletions(-) diff --git a/.travis.yml b/.travis.yml index a0ae62e3..dc9941f0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,7 +28,7 @@ before_install: - source ${TRAVIS}/travis_setup_env.sh install: - - pip install --user cpplint pylint + - pip install --user cpplint pylint kubernetes script: scripts/travis_script.sh diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 862187bc..3d547fe0 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -454,29 +454,29 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, while (true) { // select helper bool finished = true; - utils::SelectHelper selecter; + utils::PollHelper watcher; for (int i = 0; i < nlink; ++i) { if (i == parent_index) { if (size_down_in != total_size) { - selecter.WatchRead(links[i].sock); + watcher.WatchRead(links[i].sock); // only watch for exception in live channels - selecter.WatchException(links[i].sock); + watcher.WatchException(links[i].sock); finished = false; } if (size_up_out != total_size && size_up_out < size_up_reduce) { - selecter.WatchWrite(links[i].sock); + watcher.WatchWrite(links[i].sock); } } else { if (links[i].size_read != total_size) { - selecter.WatchRead(links[i].sock); + watcher.WatchRead(links[i].sock); } // size_write <= size_read if (links[i].size_write != total_size) { if (links[i].size_write < size_down_in) { - selecter.WatchWrite(links[i].sock); + watcher.WatchWrite(links[i].sock); } // only watch for exception in live channels - selecter.WatchException(links[i].sock); + watcher.WatchException(links[i].sock); finished = false; } } @@ -484,17 +484,17 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, // finish runing allreduce if (finished) break; // select must return - selecter.Select(); + watcher.Poll(); // exception handling for (int i = 0; i < nlink; ++i) { // recive OOB message from some link - if (selecter.CheckExcept(links[i].sock)) { + if (watcher.CheckExcept(links[i].sock)) { return ReportError(&links[i], kGetExcept); } } // read data from childs for (int i = 0; i < nlink; ++i) { - if (i != parent_index && selecter.CheckRead(links[i].sock)) { + if (i != parent_index && watcher.CheckRead(links[i].sock)) { ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size); if (ret != kSuccess) { return ReportError(&links[i], ret); @@ -551,7 +551,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, } } // read data from parent - if (selecter.CheckRead(links[parent_index].sock) && + if (watcher.CheckRead(links[parent_index].sock) && total_size > size_down_in) { ssize_t len = links[parent_index].sock. Recv(sendrecvbuf + size_down_in, total_size - size_down_in); @@ -620,37 +620,37 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { while (true) { bool finished = true; // select helper - utils::SelectHelper selecter; + utils::PollHelper watcher; for (int i = 0; i < nlink; ++i) { if (in_link == -2) { - selecter.WatchRead(links[i].sock); finished = false; + watcher.WatchRead(links[i].sock); finished = false; } if (i == in_link && links[i].size_read != total_size) { - selecter.WatchRead(links[i].sock); finished = false; + watcher.WatchRead(links[i].sock); finished = false; } if (in_link != -2 && i != in_link && links[i].size_write != total_size) { if (links[i].size_write < size_in) { - selecter.WatchWrite(links[i].sock); + watcher.WatchWrite(links[i].sock); } finished = false; } - selecter.WatchException(links[i].sock); + watcher.WatchException(links[i].sock); } // finish running if (finished) break; // select - selecter.Select(); + watcher.Poll(); // exception handling for (int i = 0; i < nlink; ++i) { // recive OOB message from some link - if (selecter.CheckExcept(links[i].sock)) { + if (watcher.CheckExcept(links[i].sock)) { return ReportError(&links[i], kGetExcept); } } if (in_link == -2) { // probe in-link for (int i = 0; i < nlink; ++i) { - if (selecter.CheckRead(links[i].sock)) { + if (watcher.CheckRead(links[i].sock)) { ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size); if (ret != kSuccess) { return ReportError(&links[i], ret); @@ -663,7 +663,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { } } else { // read from in link - if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) { + if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) { ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size); if (ret != kSuccess) { return ReportError(&links[in_link], ret); @@ -717,20 +717,20 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size, while (true) { // select helper bool finished = true; - utils::SelectHelper selecter; + utils::PollHelper watcher; if (read_ptr != stop_read) { - selecter.WatchRead(next.sock); + watcher.WatchRead(next.sock); finished = false; } if (write_ptr != stop_write) { if (write_ptr < read_ptr) { - selecter.WatchWrite(prev.sock); + watcher.WatchWrite(prev.sock); } finished = false; } if (finished) break; - selecter.Select(); - if (read_ptr != stop_read && selecter.CheckRead(next.sock)) { + watcher.Poll(); + if (read_ptr != stop_read && watcher.CheckRead(next.sock)) { size_t size = stop_read - read_ptr; size_t start = read_ptr % total_size; if (start + size > total_size) { @@ -811,20 +811,20 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_, while (true) { // select helper bool finished = true; - utils::SelectHelper selecter; + utils::PollHelper watcher; if (read_ptr != stop_read) { - selecter.WatchRead(next.sock); + watcher.WatchRead(next.sock); finished = false; } if (write_ptr != stop_write) { if (write_ptr < reduce_ptr) { - selecter.WatchWrite(prev.sock); + watcher.WatchWrite(prev.sock); } finished = false; } if (finished) break; - selecter.Select(); - if (read_ptr != stop_read && selecter.CheckRead(next.sock)) { + watcher.Poll(); + if (read_ptr != stop_read && watcher.CheckRead(next.sock)) { ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read); if (ret != kSuccess) { return ReportError(&next, ret); diff --git a/src/allreduce_robust-inl.h b/src/allreduce_robust-inl.h index d3cbc003..22df5191 100644 --- a/src/allreduce_robust-inl.h +++ b/src/allreduce_robust-inl.h @@ -70,29 +70,29 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, utils::Assert(stage != 2 && stage != 1, "invalie stage id"); } // select helper - utils::SelectHelper selecter; + utils::PollHelper watcher; bool done = (stage == 3); for (int i = 0; i < nlink; ++i) { - selecter.WatchException(links[i].sock); + watcher.WatchException(links[i].sock); switch (stage) { case 0: if (i != parent_index && links[i].size_read != sizeof(EdgeType)) { - selecter.WatchRead(links[i].sock); + watcher.WatchRead(links[i].sock); } break; case 1: if (i == parent_index) { - selecter.WatchWrite(links[i].sock); + watcher.WatchWrite(links[i].sock); } break; case 2: if (i == parent_index) { - selecter.WatchRead(links[i].sock); + watcher.WatchRead(links[i].sock); } break; case 3: if (i != parent_index && links[i].size_write != sizeof(EdgeType)) { - selecter.WatchWrite(links[i].sock); + watcher.WatchWrite(links[i].sock); done = false; } break; @@ -101,11 +101,11 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, } // finish all the stages, and write out message if (done) break; - selecter.Select(); + watcher.Poll(); // exception handling for (int i = 0; i < nlink; ++i) { // recive OOB message from some link - if (selecter.CheckExcept(links[i].sock)) { + if (watcher.CheckExcept(links[i].sock)) { return ReportError(&links[i], kGetExcept); } } @@ -114,7 +114,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, // read data from childs for (int i = 0; i < nlink; ++i) { if (i != parent_index) { - if (selecter.CheckRead(links[i].sock)) { + if (watcher.CheckRead(links[i].sock)) { ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType)); if (ret != kSuccess) return ReportError(&links[i], ret); } diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index a48a349a..210d5d8a 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -334,7 +334,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { if (len == sizeof(sig)) all_links[i].size_write = 2; } } - utils::SelectHelper rsel; + utils::PollHelper rsel; bool finished = true; for (int i = 0; i < nlink; ++i) { if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) { @@ -343,15 +343,15 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { } if (finished) break; // wait to read from the channels to discard data - rsel.Select(); + rsel.Poll(); } for (int i = 0; i < nlink; ++i) { if (!all_links[i].sock.BadSocket()) { - utils::SelectHelper::WaitExcept(all_links[i].sock); + utils::PollHelper::WaitExcept(all_links[i].sock); } } while (true) { - utils::SelectHelper rsel; + utils::PollHelper rsel; bool finished = true; for (int i = 0; i < nlink; ++i) { if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) { @@ -359,7 +359,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { } } if (finished) break; - rsel.Select(); + rsel.Poll(); for (int i = 0; i < nlink; ++i) { if (all_links[i].sock.BadSocket()) continue; if (all_links[i].size_read == 0) { @@ -624,32 +624,32 @@ AllreduceRobust::TryRecoverData(RecoverType role, } while (true) { bool finished = true; - utils::SelectHelper selecter; + utils::PollHelper watcher; for (int i = 0; i < nlink; ++i) { if (i == recv_link && links[i].size_read != size) { - selecter.WatchRead(links[i].sock); + watcher.WatchRead(links[i].sock); finished = false; } if (req_in[i] && links[i].size_write != size) { if (role == kHaveData || (links[recv_link].size_read != links[i].size_write)) { - selecter.WatchWrite(links[i].sock); + watcher.WatchWrite(links[i].sock); } finished = false; } - selecter.WatchException(links[i].sock); + watcher.WatchException(links[i].sock); } if (finished) break; - selecter.Select(); + watcher.Poll(); // exception handling for (int i = 0; i < nlink; ++i) { - if (selecter.CheckExcept(links[i].sock)) { + if (watcher.CheckExcept(links[i].sock)) { return ReportError(&links[i], kGetExcept); } } if (role == kRequestData) { const int pid = recv_link; - if (selecter.CheckRead(links[pid].sock)) { + if (watcher.CheckRead(links[pid].sock)) { ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size); if (ret != kSuccess) { return ReportError(&links[pid], ret); @@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role, if (role == kPassData) { const int pid = recv_link; const size_t buffer_size = links[pid].buffer_size; - if (selecter.CheckRead(links[pid].sock)) { + if (watcher.CheckRead(links[pid].sock)) { size_t min_write = size; for (int i = 0; i < nlink; ++i) { if (req_in[i]) min_write = std::min(links[i].size_write, min_write); @@ -1144,22 +1144,22 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_, char *buf = reinterpret_cast(sendrecvbuf_); while (true) { bool finished = true; - utils::SelectHelper selecter; + utils::PollHelper watcher; if (read_ptr != read_end) { - selecter.WatchRead(prev.sock); + watcher.WatchRead(prev.sock); finished = false; } if (write_ptr < read_ptr && write_ptr != write_end) { - selecter.WatchWrite(next.sock); + watcher.WatchWrite(next.sock); finished = false; } - selecter.WatchException(prev.sock); - selecter.WatchException(next.sock); + watcher.WatchException(prev.sock); + watcher.WatchException(next.sock); if (finished) break; - selecter.Select(); - if (selecter.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept); - if (selecter.CheckExcept(next.sock)) return ReportError(&next, kGetExcept); - if (read_ptr != read_end && selecter.CheckRead(prev.sock)) { + watcher.Poll(); + if (watcher.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept); + if (watcher.CheckExcept(next.sock)) return ReportError(&next, kGetExcept); + if (read_ptr != read_end && watcher.CheckRead(prev.sock)) { ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr); if (len == 0) { prev.sock.Close(); return ReportError(&prev, kRecvZeroLen); diff --git a/src/socket.h b/src/socket.h index 83d28e88..22392330 100644 --- a/src/socket.h +++ b/src/socket.h @@ -20,509 +20,505 @@ #include #include #include -#include #include #endif #include #include +#include +#include #include "../include/rabit/internal/utils.h" #if defined(_WIN32) typedef int ssize_t; typedef int sock_size_t; + +static inline int poll(struct pollfd *pfd, int nfds, + int timeout) { return WSAPoll ( pfd, nfds, timeout ); } #else +#include typedef int SOCKET; typedef size_t sock_size_t; const int INVALID_SOCKET = -1; #endif namespace rabit { -namespace utils { + namespace utils { /*! \brief data structure for network address */ struct SockAddr { - sockaddr_in addr; - // constructor - SockAddr(void) {} - SockAddr(const char *url, int port) { - this->Set(url, port); - } - inline static std::string GetHostName(void) { - std::string buf; buf.resize(256); - utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name"); - return std::string(buf.c_str()); - } - /*! - * \brief set the address - * \param url the url of the address - * \param port the port of address - */ - inline void Set(const char *host, int port) { - addrinfo hints; - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_INET; - hints.ai_protocol = SOCK_STREAM; - addrinfo *res = NULL; - int sig = getaddrinfo(host, NULL, &hints, &res); - Check(sig == 0 && res != NULL, "cannot obtain address of %s", host); - Check(res->ai_family == AF_INET, "Does not support IPv6"); - memcpy(&addr, res->ai_addr, res->ai_addrlen); - addr.sin_port = htons(port); - freeaddrinfo(res); - } - /*! \brief return port of the address*/ - inline int port(void) const { - return ntohs(addr.sin_port); - } - /*! \return a string representation of the address */ - inline std::string AddrStr(void) const { - std::string buf; buf.resize(256); + sockaddr_in addr; + // constructor + SockAddr(void) {} + SockAddr(const char *url, int port) { + this->Set(url, port); + } + inline static std::string GetHostName(void) { + std::string buf; buf.resize(256); + utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name"); + return std::string(buf.c_str()); + } + /*! + * \brief set the address + * \param url the url of the address + * \param port the port of address + */ + inline void Set(const char *host, int port) { + addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; + hints.ai_protocol = SOCK_STREAM; + addrinfo *res = NULL; + int sig = getaddrinfo(host, NULL, &hints, &res); + Check(sig == 0 && res != NULL, "cannot obtain address of %s", host); + Check(res->ai_family == AF_INET, "Does not support IPv6"); + memcpy(&addr, res->ai_addr, res->ai_addrlen); + addr.sin_port = htons(port); + freeaddrinfo(res); + } + /*! \brief return port of the address*/ + inline int port(void) const { + return ntohs(addr.sin_port); + } + /*! \return a string representation of the address */ + inline std::string AddrStr(void) const { + std::string buf; buf.resize(256); #ifdef _WIN32 - const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, - &buf[0], buf.length()); + const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, + &buf[0], buf.length()); #else - const char *s = inet_ntop(AF_INET, &addr.sin_addr, - &buf[0], buf.length()); + const char *s = inet_ntop(AF_INET, &addr.sin_addr, + &buf[0], buf.length()); #endif - Assert(s != NULL, "cannot decode address"); - return std::string(s); - } + Assert(s != NULL, "cannot decode address"); + return std::string(s); + } }; /*! * \brief base class containing common operations of TCP and UDP sockets */ class Socket { - public: - /*! \brief the file descriptor of socket */ - SOCKET sockfd; - // default conversion to int - inline operator SOCKET() const { - return sockfd; - } - /*! - * \return last error of socket operation - */ - inline static int GetLastError(void) { +public: + /*! \brief the file descriptor of socket */ + SOCKET sockfd; + // default conversion to int + inline operator SOCKET() const { + return sockfd; + } + /*! + * \return last error of socket operation + */ + inline static int GetLastError(void) { #ifdef _WIN32 - return WSAGetLastError(); + return WSAGetLastError(); #else - return errno; + return errno; #endif - } - /*! \return whether last error was would block */ - inline static bool LastErrorWouldBlock(void) { - int errsv = GetLastError(); + } + /*! \return whether last error was would block */ + inline static bool LastErrorWouldBlock(void) { + int errsv = GetLastError(); #ifdef _WIN32 - return errsv == WSAEWOULDBLOCK; + return errsv == WSAEWOULDBLOCK; #else - return errsv == EAGAIN || errsv == EWOULDBLOCK; + return errsv == EAGAIN || errsv == EWOULDBLOCK; #endif - } - /*! - * \brief start up the socket module - * call this before using the sockets - */ - inline static void Startup(void) { -#ifdef _WIN32 - WSADATA wsa_data; - if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { - Socket::Error("Startup"); - } - if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { - WSACleanup(); - utils::Error("Could not find a usable version of Winsock.dll\n"); } -#endif - } - /*! - * \brief shutdown the socket module after use, all sockets need to be closed - */ - inline static void Finalize(void) { + /*! + * \brief start up the socket module + * call this before using the sockets + */ + inline static void Startup(void) { #ifdef _WIN32 - WSACleanup(); + WSADATA wsa_data; +if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { +Socket::Error("Startup"); +} +if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { +WSACleanup(); +utils::Error("Could not find a usable version of Winsock.dll\n"); +} #endif - } - /*! - * \brief set this socket to use non-blocking mode - * \param non_block whether set it to be non-block, if it is false - * it will set it back to block mode - */ - inline void SetNonBlock(bool non_block) { + } + /*! + * \brief shutdown the socket module after use, all sockets need to be closed + */ + inline static void Finalize(void) { #ifdef _WIN32 - u_long mode = non_block ? 1 : 0; - if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { - Socket::Error("SetNonBlock"); + WSACleanup(); +#endif } + /*! + * \brief set this socket to use non-blocking mode + * \param non_block whether set it to be non-block, if it is false + * it will set it back to block mode + */ + inline void SetNonBlock(bool non_block) { +#ifdef _WIN32 + u_long mode = non_block ? 1 : 0; +if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { +Socket::Error("SetNonBlock"); +} #else - int flag = fcntl(sockfd, F_GETFL, 0); - if (flag == -1) { - Socket::Error("SetNonBlock-1"); - } - if (non_block) { - flag |= O_NONBLOCK; - } else { - flag &= ~O_NONBLOCK; - } - if (fcntl(sockfd, F_SETFL, flag) == -1) { - Socket::Error("SetNonBlock-2"); - } -#endif - } - /*! - * \brief bind the socket to an address - * \param addr - */ - inline void Bind(const SockAddr &addr) { - if (bind(sockfd, reinterpret_cast(&addr.addr), - sizeof(addr.addr)) == -1) { - Socket::Error("Bind"); - } - } - /*! - * \brief try bind the socket to host, from start_port to end_port - * \param start_port starting port number to try - * \param end_port ending port number to try - * \return the port successfully bind to, return -1 if failed to bind any port - */ - inline int TryBindHost(int start_port, int end_port) { - // TODO(tqchen) add prefix check - for (int port = start_port; port < end_port; ++port) { - SockAddr addr("0.0.0.0", port); - if (bind(sockfd, reinterpret_cast(&addr.addr), - sizeof(addr.addr)) == 0) { - return port; + int flag = fcntl(sockfd, F_GETFL, 0); + if (flag == -1) { + Socket::Error("SetNonBlock-1"); } -#if defined(_WIN32) - if (WSAGetLastError() != WSAEADDRINUSE) { - Socket::Error("TryBindHost"); + if (non_block) { + flag |= O_NONBLOCK; + } else { + flag &= ~O_NONBLOCK; } -#else - if (errno != EADDRINUSE) { - Socket::Error("TryBindHost"); + if (fcntl(sockfd, F_SETFL, flag) == -1) { + Socket::Error("SetNonBlock-2"); } #endif } + /*! + * \brief bind the socket to an address + * \param addr + */ + inline void Bind(const SockAddr &addr) { + if (bind(sockfd, reinterpret_cast(&addr.addr), + sizeof(addr.addr)) == -1) { + Socket::Error("Bind"); + } + } + /*! + * \brief try bind the socket to host, from start_port to end_port + * \param start_port starting port number to try + * \param end_port ending port number to try + * \return the port successfully bind to, return -1 if failed to bind any port + */ + inline int TryBindHost(int start_port, int end_port) { + // TODO(tqchen) add prefix check + for (int port = start_port; port < end_port; ++port) { + SockAddr addr("0.0.0.0", port); + if (bind(sockfd, reinterpret_cast(&addr.addr), + sizeof(addr.addr)) == 0) { + return port; + } +#if defined(_WIN32) + if (WSAGetLastError() != WSAEADDRINUSE) { +Socket::Error("TryBindHost"); +} +#else + if (errno != EADDRINUSE) { + Socket::Error("TryBindHost"); + } +#endif + } - return -1; - } - /*! \brief get last error code if any */ - inline int GetSockError(void) const { - int error = 0; - socklen_t len = sizeof(error); - if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) { - Error("GetSockError"); - } - return error; - } - /*! \brief check if anything bad happens */ - inline bool BadSocket(void) const { - if (IsClosed()) return true; - int err = GetSockError(); - if (err == EBADF || err == EINTR) return true; - return false; - } - /*! \brief check if socket is already closed */ - inline bool IsClosed(void) const { - return sockfd == INVALID_SOCKET; - } - /*! \brief close the socket */ - inline void Close(void) { - if (sockfd != INVALID_SOCKET) { + return -1; + } + /*! \brief get last error code if any */ + inline int GetSockError(void) const { + int error = 0; + socklen_t len = sizeof(error); + if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) { + Error("GetSockError"); + } + return error; + } + /*! \brief check if anything bad happens */ + inline bool BadSocket(void) const { + if (IsClosed()) return true; + int err = GetSockError(); + if (err == EBADF || err == EINTR) return true; + return false; + } + /*! \brief check if socket is already closed */ + inline bool IsClosed(void) const { + return sockfd == INVALID_SOCKET; + } + /*! \brief close the socket */ + inline void Close(void) { + if (sockfd != INVALID_SOCKET) { #ifdef _WIN32 - closesocket(sockfd); + closesocket(sockfd); #else - close(sockfd); + close(sockfd); #endif - sockfd = INVALID_SOCKET; - } else { - Error("Socket::Close double close the socket or close without create"); - } - } - // report an socket error - inline static void Error(const char *msg) { - int errsv = GetLastError(); + sockfd = INVALID_SOCKET; + } else { + Error("Socket::Close double close the socket or close without create"); + } + } + // report an socket error + inline static void Error(const char *msg) { + int errsv = GetLastError(); #ifdef _WIN32 - utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv); + utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv); #else - utils::Error("Socket %s Error:%s", msg, strerror(errsv)); + utils::Error("Socket %s Error:%s", msg, strerror(errsv)); #endif - } + } - protected: - explicit Socket(SOCKET sockfd) : sockfd(sockfd) { - } +protected: + explicit Socket(SOCKET sockfd) : sockfd(sockfd) { + } }; /*! * \brief a wrapper of TCP socket that hopefully be cross platform */ class TCPSocket : public Socket{ - public: - // constructor - TCPSocket(void) : Socket(INVALID_SOCKET) { - } - explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) { - } - /*! - * \brief enable/disable TCP keepalive - * \param keepalive whether to set the keep alive option on - */ - inline void SetKeepAlive(bool keepalive) { - int opt = static_cast(keepalive); - if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, - reinterpret_cast(&opt), sizeof(opt)) < 0) { - Socket::Error("SetKeepAlive"); - } - } - /*! - * \brief create the socket, call this before using socket - * \param af domain - */ - inline void Create(int af = PF_INET) { - sockfd = socket(PF_INET, SOCK_STREAM, 0); - if (sockfd == INVALID_SOCKET) { - Socket::Error("Create"); - } - } - /*! - * \brief perform listen of the socket - * \param backlog backlog parameter - */ - inline void Listen(int backlog = 16) { - listen(sockfd, backlog); - } - /*! \brief get a new connection */ - TCPSocket Accept(void) { - SOCKET newfd = accept(sockfd, NULL, NULL); - if (newfd == INVALID_SOCKET) { - Socket::Error("Accept"); - } - return TCPSocket(newfd); - } - /*! - * \brief decide whether the socket is at OOB mark - * \return 1 if at mark, 0 if not, -1 if an error occured - */ - inline int AtMark(void) const { +public: + // constructor + TCPSocket(void) : Socket(INVALID_SOCKET) { + } + explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) { + } + /*! + * \brief enable/disable TCP keepalive + * \param keepalive whether to set the keep alive option on + */ + inline void SetKeepAlive(bool keepalive) { + int opt = static_cast(keepalive); + if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, + reinterpret_cast(&opt), sizeof(opt)) < 0) { + Socket::Error("SetKeepAlive"); + } + } + /*! + * \brief create the socket, call this before using socket + * \param af domain + */ + inline void Create(int af = PF_INET) { + sockfd = socket(PF_INET, SOCK_STREAM, 0); + if (sockfd == INVALID_SOCKET) { + Socket::Error("Create"); + } + } + /*! + * \brief perform listen of the socket + * \param backlog backlog parameter + */ + inline void Listen(int backlog = 16) { + listen(sockfd, backlog); + } + /*! \brief get a new connection */ + TCPSocket Accept(void) { + SOCKET newfd = accept(sockfd, NULL, NULL); + if (newfd == INVALID_SOCKET) { + Socket::Error("Accept"); + } + return TCPSocket(newfd); + } + /*! + * \brief decide whether the socket is at OOB mark + * \return 1 if at mark, 0 if not, -1 if an error occured + */ + inline int AtMark(void) const { #ifdef _WIN32 - unsigned long atmark; // NOLINT(*) - if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; + unsigned long atmark; // NOLINT(*) +if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; #else - int atmark; - if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1; + int atmark; + if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1; #endif - return static_cast(atmark); - } - /*! - * \brief connect to an address - * \param addr the address to connect to - * \return whether connect is successful - */ - inline bool Connect(const SockAddr &addr) { - return connect(sockfd, reinterpret_cast(&addr.addr), - sizeof(addr.addr)) == 0; - } - /*! - * \brief send data using the socket - * \param buf the pointer to the buffer - * \param len the size of the buffer - * \param flags extra flags - * \return size of data actually sent - * return -1 if error occurs - */ - inline ssize_t Send(const void *buf_, size_t len, int flag = 0) { - const char *buf = reinterpret_cast(buf_); - return send(sockfd, buf, static_cast(len), flag); - } - /*! - * \brief receive data using the socket - * \param buf_ the pointer to the buffer - * \param len the size of the buffer - * \param flags extra flags - * \return size of data actually received - * return -1 if error occurs - */ - inline ssize_t Recv(void *buf_, size_t len, int flags = 0) { - char *buf = reinterpret_cast(buf_); - return recv(sockfd, buf, static_cast(len), flags); - } - /*! - * \brief peform block write that will attempt to send all data out - * can still return smaller than request when error occurs - * \param buf the pointer to the buffer - * \param len the size of the buffer - * \return size of data actually sent - */ - inline size_t SendAll(const void *buf_, size_t len) { - const char *buf = reinterpret_cast(buf_); - size_t ndone = 0; - while (ndone < len) { - ssize_t ret = send(sockfd, buf, static_cast(len - ndone), 0); - if (ret == -1) { - if (LastErrorWouldBlock()) return ndone; - Socket::Error("SendAll"); + return static_cast(atmark); + } + /*! + * \brief connect to an address + * \param addr the address to connect to + * \return whether connect is successful + */ + inline bool Connect(const SockAddr &addr) { + return connect(sockfd, reinterpret_cast(&addr.addr), + sizeof(addr.addr)) == 0; + } + /*! + * \brief send data using the socket + * \param buf the pointer to the buffer + * \param len the size of the buffer + * \param flags extra flags + * \return size of data actually sent + * return -1 if error occurs + */ + inline ssize_t Send(const void *buf_, size_t len, int flag = 0) { + const char *buf = reinterpret_cast(buf_); + return send(sockfd, buf, static_cast(len), flag); + } + /*! + * \brief receive data using the socket + * \param buf_ the pointer to the buffer + * \param len the size of the buffer + * \param flags extra flags + * \return size of data actually received + * return -1 if error occurs + */ + inline ssize_t Recv(void *buf_, size_t len, int flags = 0) { + char *buf = reinterpret_cast(buf_); + return recv(sockfd, buf, static_cast(len), flags); + } + /*! + * \brief peform block write that will attempt to send all data out + * can still return smaller than request when error occurs + * \param buf the pointer to the buffer + * \param len the size of the buffer + * \return size of data actually sent + */ + inline size_t SendAll(const void *buf_, size_t len) { + const char *buf = reinterpret_cast(buf_); + size_t ndone = 0; + while (ndone < len) { + ssize_t ret = send(sockfd, buf, static_cast(len - ndone), 0); + if (ret == -1) { + if (LastErrorWouldBlock()) return ndone; + Socket::Error("SendAll"); + } + buf += ret; + ndone += ret; } - buf += ret; - ndone += ret; - } - return ndone; - } - /*! - * \brief peforma block read that will attempt to read all data - * can still return smaller than request when error occurs - * \param buf_ the buffer pointer - * \param len length of data to recv - * \return size of data actually sent - */ - inline size_t RecvAll(void *buf_, size_t len) { - char *buf = reinterpret_cast(buf_); - size_t ndone = 0; - while (ndone < len) { - ssize_t ret = recv(sockfd, buf, - static_cast(len - ndone), MSG_WAITALL); - if (ret == -1) { - if (LastErrorWouldBlock()) return ndone; - Socket::Error("RecvAll"); + return ndone; + } + /*! + * \brief peforma block read that will attempt to read all data + * can still return smaller than request when error occurs + * \param buf_ the buffer pointer + * \param len length of data to recv + * \return size of data actually sent + */ + inline size_t RecvAll(void *buf_, size_t len) { + char *buf = reinterpret_cast(buf_); + size_t ndone = 0; + while (ndone < len) { + ssize_t ret = recv(sockfd, buf, + static_cast(len - ndone), MSG_WAITALL); + if (ret == -1) { + if (LastErrorWouldBlock()) return ndone; + Socket::Error("RecvAll"); + } + if (ret == 0) return ndone; + buf += ret; + ndone += ret; } - if (ret == 0) return ndone; - buf += ret; - ndone += ret; - } - return ndone; - } - /*! - * \brief send a string over network - * \param str the string to be sent - */ - inline void SendStr(const std::string &str) { - int len = static_cast(str.length()); - utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len), - "error during send SendStr"); - if (len != 0) { - utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(), - "error during send SendStr"); + return ndone; } - } - /*! - * \brief recv a string from network - * \param out_str the string to receive - */ - inline void RecvStr(std::string *out_str) { - int len; - utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len), - "error during send RecvStr"); - out_str->resize(len); - if (len != 0) { - utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(), + /*! + * \brief send a string over network + * \param str the string to be sent + */ + inline void SendStr(const std::string &str) { + int len = static_cast(str.length()); + utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len), "error during send SendStr"); + if (len != 0) { + utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(), + "error during send SendStr"); + } + } + /*! + * \brief recv a string from network + * \param out_str the string to receive + */ + inline void RecvStr(std::string *out_str) { + int len; + utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len), + "error during send RecvStr"); + out_str->resize(len); + if (len != 0) { + utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(), + "error during send SendStr"); + } } - } }; -/*! \brief helper data structure to perform select */ -struct SelectHelper { - public: - SelectHelper(void) { - FD_ZERO(&read_set); - FD_ZERO(&write_set); - FD_ZERO(&except_set); - maxfd = 0; - } - /*! - * \brief add file descriptor to watch for read - * \param fd file descriptor to be watched - */ - inline void WatchRead(SOCKET fd) { - FD_SET(fd, &read_set); - if (fd > maxfd) maxfd = fd; - } - /*! - * \brief add file descriptor to watch for write - * \param fd file descriptor to be watched - */ - inline void WatchWrite(SOCKET fd) { - FD_SET(fd, &write_set); - if (fd > maxfd) maxfd = fd; - } - /*! - * \brief add file descriptor to watch for exception - * \param fd file descriptor to be watched - */ - inline void WatchException(SOCKET fd) { - FD_SET(fd, &except_set); - if (fd > maxfd) maxfd = fd; - } - /*! - * \brief Check if the descriptor is ready for read - * \param fd file descriptor to check status - */ - inline bool CheckRead(SOCKET fd) const { - return FD_ISSET(fd, &read_set) != 0; - } - /*! - * \brief Check if the descriptor is ready for write - * \param fd file descriptor to check status - */ - inline bool CheckWrite(SOCKET fd) const { - return FD_ISSET(fd, &write_set) != 0; - } - /*! - * \brief Check if the descriptor has any exception - * \param fd file descriptor to check status - */ - inline bool CheckExcept(SOCKET fd) const { - return FD_ISSET(fd, &except_set) != 0; - } - /*! - * \brief wait for exception event on a single descriptor - * \param fd the file descriptor to wait the event for - * \param timeout the timeout counter, can be 0, which means wait until the event happen - * \return 1 if success, 0 if timeout, and -1 if error occurs - */ - inline static int WaitExcept(SOCKET fd, long timeout = 0) { // NOLINT(*) - fd_set wait_set; - FD_ZERO(&wait_set); - FD_SET(fd, &wait_set); - return Select_(static_cast(fd + 1), - NULL, NULL, &wait_set, timeout); - } - /*! - * \brief peform select on the set defined - * \param select_read whether to watch for read event - * \param select_write whether to watch for write event - * \param select_except whether to watch for exception event - * \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block - * \return number of active descriptors selected, - * return -1 if error occurs - */ - inline int Select(long timeout = 0) { // NOLINT(*) - int ret = Select_(static_cast(maxfd + 1), - &read_set, &write_set, &except_set, timeout); - if (ret == -1) { - Socket::Error("Select"); - } - return ret; - } +/*! \brief helper data structure to perform poll */ +struct PollHelper { + public: + /*! + * \brief add file descriptor to watch for read + * \param fd file descriptor to be watched + */ + inline void WatchRead(SOCKET fd) { + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLIN; + } + /*! + * \brief add file descriptor to watch for write + * \param fd file descriptor to be watched + */ + inline void WatchWrite(SOCKET fd) { + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLOUT; + } + /*! + * \brief add file descriptor to watch for exception + * \param fd file descriptor to be watched + */ + inline void WatchException(SOCKET fd) { + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLPRI; + } + /*! + * \brief Check if the descriptor is ready for read + * \param fd file descriptor to check status + */ + inline bool CheckRead(SOCKET fd) const { + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0); + } + /*! + * \brief Check if the descriptor is ready for write + * \param fd file descriptor to check status + */ + inline bool CheckWrite(SOCKET fd) const { + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0); + } + /*! + * \brief Check if the descriptor has any exception + * \param fd file descriptor to check status + */ + inline bool CheckExcept(SOCKET fd) const { + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0); + } + /*! + * \brief wait for exception event on a single descriptor + * \param fd the file descriptor to wait the event for + * \param timeout the timeout counter, can be negative, which means wait until the event happen + * \return 1 if success, 0 if timeout, and -1 if error occurs + */ + inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*) + pollfd pfd; + pfd.fd = fd; + pfd.events = POLLPRI; + return poll(&pfd, 1, timeout); + } - private: - inline static int Select_(int maxfd, fd_set *rfds, - fd_set *wfds, fd_set *efds, long timeout) { // NOLINT(*) -#if !defined(_WIN32) - utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE"); -#endif - if (timeout == 0) { - return select(maxfd, rfds, wfds, efds, NULL); - } else { - timeval tm; - tm.tv_usec = (timeout % 1000) * 1000; - tm.tv_sec = timeout / 1000; - return select(maxfd, rfds, wfds, efds, &tm); - } - } + /*! + * \brief peform poll on the set defined, read, write, exception + * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block + * \return + */ + inline void Poll(long timeout = -1) { // NOLINT(*) + std::vector fdset; + fdset.reserve(fds.size()); + for (auto kv : fds) { + fdset.push_back(kv.second); + } + int ret = poll(fdset.data(), fdset.size(), timeout); + if (ret == -1) { + Socket::Error("Poll"); + } else { + for (auto& pfd : fdset) { + auto revents = pfd.revents & pfd.events; + if (!revents) { + fds.erase(pfd.fd); + } else { + fds[pfd.fd].events = revents; + } + } + } + } - SOCKET maxfd; - fd_set read_set, write_set, except_set; -}; -} // namespace utils + std::unordered_map fds; + }; + } // namespace utils } // namespace rabit -#endif // RABIT_SOCKET_H_ +#endif // RABIT_SOCKET_H_ \ No newline at end of file From cc5e1589b270fb2cb4607dcdc68a18d0ab296ff0 Mon Sep 17 00:00:00 2001 From: Chen Qin Date: Thu, 18 Oct 2018 22:13:45 -0700 Subject: [PATCH 2/5] include missing packages, disable recovery tests for now --- .travis.yml | 2 +- src/allreduce_base.cc | 2 +- src/socket.h | 821 +++++++++++++++++++++--------------------- test/local_recover.cc | 16 +- test/test.mk | 2 +- 5 files changed, 422 insertions(+), 421 deletions(-) diff --git a/.travis.yml b/.travis.yml index dc9941f0..1e0aa244 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,7 +28,7 @@ before_install: - source ${TRAVIS}/travis_setup_env.sh install: - - pip install --user cpplint pylint kubernetes + - pip install --user cpplint pylint kubernetes urllib3 script: scripts/travis_script.sh diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 3d547fe0..93e15b48 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -210,7 +210,7 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const { #ifdef _MSC_VER Sleep(1); #else - sleep(1); + sleep(retry << 1); #endif continue; } diff --git a/src/socket.h b/src/socket.h index 22392330..9a04a0a0 100644 --- a/src/socket.h +++ b/src/socket.h @@ -42,94 +42,94 @@ const int INVALID_SOCKET = -1; #endif namespace rabit { - namespace utils { +namespace utils { /*! \brief data structure for network address */ struct SockAddr { - sockaddr_in addr; - // constructor - SockAddr(void) {} - SockAddr(const char *url, int port) { - this->Set(url, port); - } - inline static std::string GetHostName(void) { - std::string buf; buf.resize(256); - utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name"); - return std::string(buf.c_str()); - } - /*! - * \brief set the address - * \param url the url of the address - * \param port the port of address - */ - inline void Set(const char *host, int port) { - addrinfo hints; - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_INET; - hints.ai_protocol = SOCK_STREAM; - addrinfo *res = NULL; - int sig = getaddrinfo(host, NULL, &hints, &res); - Check(sig == 0 && res != NULL, "cannot obtain address of %s", host); - Check(res->ai_family == AF_INET, "Does not support IPv6"); - memcpy(&addr, res->ai_addr, res->ai_addrlen); - addr.sin_port = htons(port); - freeaddrinfo(res); - } - /*! \brief return port of the address*/ - inline int port(void) const { - return ntohs(addr.sin_port); - } - /*! \return a string representation of the address */ - inline std::string AddrStr(void) const { - std::string buf; buf.resize(256); + sockaddr_in addr; + // constructor + SockAddr(void) {} + SockAddr(const char *url, int port) { + this->Set(url, port); + } + inline static std::string GetHostName(void) { + std::string buf; buf.resize(256); + utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name"); + return std::string(buf.c_str()); + } + /*! + * \brief set the address + * \param url the url of the address + * \param port the port of address + */ + inline void Set(const char *host, int port) { + addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; + hints.ai_protocol = SOCK_STREAM; + addrinfo *res = NULL; + int sig = getaddrinfo(host, NULL, &hints, &res); + Check(sig == 0 && res != NULL, "cannot obtain address of %s", host); + Check(res->ai_family == AF_INET, "Does not support IPv6"); + memcpy(&addr, res->ai_addr, res->ai_addrlen); + addr.sin_port = htons(port); + freeaddrinfo(res); + } + /*! \brief return port of the address*/ + inline int port(void) const { + return ntohs(addr.sin_port); + } + /*! \return a string representation of the address */ + inline std::string AddrStr(void) const { + std::string buf; buf.resize(256); #ifdef _WIN32 - const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, - &buf[0], buf.length()); + const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, + &buf[0], buf.length()); #else - const char *s = inet_ntop(AF_INET, &addr.sin_addr, - &buf[0], buf.length()); + const char *s = inet_ntop(AF_INET, &addr.sin_addr, + &buf[0], buf.length()); #endif - Assert(s != NULL, "cannot decode address"); - return std::string(s); - } + Assert(s != NULL, "cannot decode address"); + return std::string(s); + } }; /*! * \brief base class containing common operations of TCP and UDP sockets */ class Socket { -public: - /*! \brief the file descriptor of socket */ - SOCKET sockfd; - // default conversion to int - inline operator SOCKET() const { - return sockfd; - } - /*! - * \return last error of socket operation - */ - inline static int GetLastError(void) { + public: + /*! \brief the file descriptor of socket */ + SOCKET sockfd; + // default conversion to int + inline operator SOCKET() const { + return sockfd; + } + /*! + * \return last error of socket operation + */ + inline static int GetLastError(void) { #ifdef _WIN32 - return WSAGetLastError(); + return WSAGetLastError(); #else - return errno; + return errno; #endif - } - /*! \return whether last error was would block */ - inline static bool LastErrorWouldBlock(void) { - int errsv = GetLastError(); + } + /*! \return whether last error was would block */ + inline static bool LastErrorWouldBlock(void) { + int errsv = GetLastError(); #ifdef _WIN32 - return errsv == WSAEWOULDBLOCK; + return errsv == WSAEWOULDBLOCK; #else - return errsv == EAGAIN || errsv == EWOULDBLOCK; + return errsv == EAGAIN || errsv == EWOULDBLOCK; #endif - } - /*! - * \brief start up the socket module - * call this before using the sockets - */ - inline static void Startup(void) { + } + /*! + * \brief start up the socket module + * call this before using the sockets + */ + inline static void Startup(void) { #ifdef _WIN32 - WSADATA wsa_data; + WSADATA wsa_data; if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { Socket::Error("Startup"); } @@ -138,387 +138,388 @@ WSACleanup(); utils::Error("Could not find a usable version of Winsock.dll\n"); } #endif - } - /*! - * \brief shutdown the socket module after use, all sockets need to be closed - */ - inline static void Finalize(void) { + } + /*! + * \brief shutdown the socket module after use, all sockets need to be closed + */ + inline static void Finalize(void) { #ifdef _WIN32 - WSACleanup(); + WSACleanup(); #endif - } - /*! - * \brief set this socket to use non-blocking mode - * \param non_block whether set it to be non-block, if it is false - * it will set it back to block mode - */ - inline void SetNonBlock(bool non_block) { + } + /*! + * \brief set this socket to use non-blocking mode + * \param non_block whether set it to be non-block, if it is false + * it will set it back to block mode + */ + inline void SetNonBlock(bool non_block) { #ifdef _WIN32 - u_long mode = non_block ? 1 : 0; + u_long mode = non_block ? 1 : 0; if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { Socket::Error("SetNonBlock"); } #else - int flag = fcntl(sockfd, F_GETFL, 0); - if (flag == -1) { - Socket::Error("SetNonBlock-1"); - } - if (non_block) { - flag |= O_NONBLOCK; - } else { - flag &= ~O_NONBLOCK; - } - if (fcntl(sockfd, F_SETFL, flag) == -1) { - Socket::Error("SetNonBlock-2"); - } -#endif + int flag = fcntl(sockfd, F_GETFL, 0); + if (flag == -1) { + Socket::Error("SetNonBlock-1"); } - /*! - * \brief bind the socket to an address - * \param addr - */ - inline void Bind(const SockAddr &addr) { - if (bind(sockfd, reinterpret_cast(&addr.addr), - sizeof(addr.addr)) == -1) { - Socket::Error("Bind"); - } + if (non_block) { + flag |= O_NONBLOCK; + } else { + flag &= ~O_NONBLOCK; } - /*! - * \brief try bind the socket to host, from start_port to end_port - * \param start_port starting port number to try - * \param end_port ending port number to try - * \return the port successfully bind to, return -1 if failed to bind any port - */ - inline int TryBindHost(int start_port, int end_port) { - // TODO(tqchen) add prefix check - for (int port = start_port; port < end_port; ++port) { - SockAddr addr("0.0.0.0", port); - if (bind(sockfd, reinterpret_cast(&addr.addr), - sizeof(addr.addr)) == 0) { - return port; - } + if (fcntl(sockfd, F_SETFL, flag) == -1) { + Socket::Error("SetNonBlock-2"); + } +#endif + } + /*! + * \brief bind the socket to an address + * \param addr + */ + inline void Bind(const SockAddr &addr) { + if (bind(sockfd, reinterpret_cast(&addr.addr), + sizeof(addr.addr)) == -1) { + Socket::Error("Bind"); + } + } + /*! + * \brief try bind the socket to host, from start_port to end_port + * \param start_port starting port number to try + * \param end_port ending port number to try + * \return the port successfully bind to, return -1 if failed to bind any port + */ + inline int TryBindHost(int start_port, int end_port) { + // TODO(tqchen) add prefix check + for (int port = start_port; port < end_port; ++port) { + SockAddr addr("0.0.0.0", port); + if (bind(sockfd, reinterpret_cast(&addr.addr), + sizeof(addr.addr)) == 0) { + return port; + } #if defined(_WIN32) - if (WSAGetLastError() != WSAEADDRINUSE) { + if (WSAGetLastError() != WSAEADDRINUSE) { Socket::Error("TryBindHost"); } #else - if (errno != EADDRINUSE) { - Socket::Error("TryBindHost"); - } -#endif + if (errno != EADDRINUSE) { + Socket::Error("TryBindHost"); } - - return -1; - } - /*! \brief get last error code if any */ - inline int GetSockError(void) const { - int error = 0; - socklen_t len = sizeof(error); - if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) { - Error("GetSockError"); - } - return error; - } - /*! \brief check if anything bad happens */ - inline bool BadSocket(void) const { - if (IsClosed()) return true; - int err = GetSockError(); - if (err == EBADF || err == EINTR) return true; - return false; - } - /*! \brief check if socket is already closed */ - inline bool IsClosed(void) const { - return sockfd == INVALID_SOCKET; +#endif } - /*! \brief close the socket */ - inline void Close(void) { - if (sockfd != INVALID_SOCKET) { + + return -1; + } + /*! \brief get last error code if any */ + inline int GetSockError(void) const { + int error = 0; + socklen_t len = sizeof(error); + if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) != 0) { + Error("GetSockError"); + } + return error; + } + /*! \brief check if anything bad happens */ + inline bool BadSocket(void) const { + if (IsClosed()) return true; + int err = GetSockError(); + if (err == EBADF || err == EINTR) return true; + return false; + } + /*! \brief check if socket is already closed */ + inline bool IsClosed(void) const { + return sockfd == INVALID_SOCKET; + } + /*! \brief close the socket */ + inline void Close(void) { + if (sockfd != INVALID_SOCKET) { #ifdef _WIN32 - closesocket(sockfd); + closesocket(sockfd); #else - close(sockfd); + close(sockfd); #endif - sockfd = INVALID_SOCKET; - } else { - Error("Socket::Close double close the socket or close without create"); - } - } - // report an socket error - inline static void Error(const char *msg) { - int errsv = GetLastError(); + sockfd = INVALID_SOCKET; + } else { + Error("Socket::Close double close the socket or close without create"); + } + } + // report an socket error + inline static void Error(const char *msg) { + int errsv = GetLastError(); #ifdef _WIN32 - utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv); + utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv); #else - utils::Error("Socket %s Error:%s", msg, strerror(errsv)); + utils::Error("Socket %s Error:%s", msg, strerror(errsv)); #endif - } + } -protected: - explicit Socket(SOCKET sockfd) : sockfd(sockfd) { - } + protected: + explicit Socket(SOCKET sockfd) : sockfd(sockfd) { + } }; /*! * \brief a wrapper of TCP socket that hopefully be cross platform */ class TCPSocket : public Socket{ -public: - // constructor - TCPSocket(void) : Socket(INVALID_SOCKET) { - } - explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) { - } - /*! - * \brief enable/disable TCP keepalive - * \param keepalive whether to set the keep alive option on - */ - inline void SetKeepAlive(bool keepalive) { - int opt = static_cast(keepalive); - if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, - reinterpret_cast(&opt), sizeof(opt)) < 0) { - Socket::Error("SetKeepAlive"); - } - } - /*! - * \brief create the socket, call this before using socket - * \param af domain - */ - inline void Create(int af = PF_INET) { - sockfd = socket(PF_INET, SOCK_STREAM, 0); - if (sockfd == INVALID_SOCKET) { - Socket::Error("Create"); - } - } - /*! - * \brief perform listen of the socket - * \param backlog backlog parameter - */ - inline void Listen(int backlog = 16) { - listen(sockfd, backlog); - } - /*! \brief get a new connection */ - TCPSocket Accept(void) { - SOCKET newfd = accept(sockfd, NULL, NULL); - if (newfd == INVALID_SOCKET) { - Socket::Error("Accept"); - } - return TCPSocket(newfd); - } - /*! - * \brief decide whether the socket is at OOB mark - * \return 1 if at mark, 0 if not, -1 if an error occured - */ - inline int AtMark(void) const { + public: + // constructor + TCPSocket(void) : Socket(INVALID_SOCKET) { + } + explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) { + } + /*! + * \brief enable/disable TCP keepalive + * \param keepalive whether to set the keep alive option on + */ + inline void SetKeepAlive(bool keepalive) { + int opt = static_cast(keepalive); + if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, + reinterpret_cast(&opt), sizeof(opt)) < 0) { + Socket::Error("SetKeepAlive"); + } + } + /*! + * \brief create the socket, call this before using socket + * \param af domain + */ + inline void Create(int af = PF_INET) { + sockfd = socket(PF_INET, SOCK_STREAM, 0); + if (sockfd == INVALID_SOCKET) { + Socket::Error("Create"); + } + } + /*! + * \brief perform listen of the socket + * \param backlog backlog parameter + */ + inline void Listen(int backlog = 16) { + listen(sockfd, backlog); + } + /*! \brief get a new connection */ + TCPSocket Accept(void) { + SOCKET newfd = accept(sockfd, NULL, NULL); + if (newfd == INVALID_SOCKET) { + Socket::Error("Accept"); + } + return TCPSocket(newfd); + } + /*! + * \brief decide whether the socket is at OOB mark + * \return 1 if at mark, 0 if not, -1 if an error occured + */ + inline int AtMark(void) const { #ifdef _WIN32 - unsigned long atmark; // NOLINT(*) + unsigned long atmark; // NOLINT(*) if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; #else - int atmark; - if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1; + int atmark; + if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1; #endif - return static_cast(atmark); - } - /*! - * \brief connect to an address - * \param addr the address to connect to - * \return whether connect is successful - */ - inline bool Connect(const SockAddr &addr) { - return connect(sockfd, reinterpret_cast(&addr.addr), - sizeof(addr.addr)) == 0; - } - /*! - * \brief send data using the socket - * \param buf the pointer to the buffer - * \param len the size of the buffer - * \param flags extra flags - * \return size of data actually sent - * return -1 if error occurs - */ - inline ssize_t Send(const void *buf_, size_t len, int flag = 0) { - const char *buf = reinterpret_cast(buf_); - return send(sockfd, buf, static_cast(len), flag); - } - /*! - * \brief receive data using the socket - * \param buf_ the pointer to the buffer - * \param len the size of the buffer - * \param flags extra flags - * \return size of data actually received - * return -1 if error occurs - */ - inline ssize_t Recv(void *buf_, size_t len, int flags = 0) { - char *buf = reinterpret_cast(buf_); - return recv(sockfd, buf, static_cast(len), flags); - } - /*! - * \brief peform block write that will attempt to send all data out - * can still return smaller than request when error occurs - * \param buf the pointer to the buffer - * \param len the size of the buffer - * \return size of data actually sent - */ - inline size_t SendAll(const void *buf_, size_t len) { - const char *buf = reinterpret_cast(buf_); - size_t ndone = 0; - while (ndone < len) { - ssize_t ret = send(sockfd, buf, static_cast(len - ndone), 0); - if (ret == -1) { - if (LastErrorWouldBlock()) return ndone; - Socket::Error("SendAll"); - } - buf += ret; - ndone += ret; + return static_cast(atmark); + } + /*! + * \brief connect to an address + * \param addr the address to connect to + * \return whether connect is successful + */ + inline bool Connect(const SockAddr &addr) { + return connect(sockfd, reinterpret_cast(&addr.addr), + sizeof(addr.addr)) == 0; + } + /*! + * \brief send data using the socket + * \param buf the pointer to the buffer + * \param len the size of the buffer + * \param flags extra flags + * \return size of data actually sent + * return -1 if error occurs + */ + inline ssize_t Send(const void *buf_, size_t len, int flag = 0) { + const char *buf = reinterpret_cast(buf_); + return send(sockfd, buf, static_cast(len), flag); + } + /*! + * \brief receive data using the socket + * \param buf_ the pointer to the buffer + * \param len the size of the buffer + * \param flags extra flags + * \return size of data actually received + * return -1 if error occurs + */ + inline ssize_t Recv(void *buf_, size_t len, int flags = 0) { + char *buf = reinterpret_cast(buf_); + return recv(sockfd, buf, static_cast(len), flags); + } + /*! + * \brief peform block write that will attempt to send all data out + * can still return smaller than request when error occurs + * \param buf the pointer to the buffer + * \param len the size of the buffer + * \return size of data actually sent + */ + inline size_t SendAll(const void *buf_, size_t len) { + const char *buf = reinterpret_cast(buf_); + size_t ndone = 0; + while (ndone < len) { + ssize_t ret = send(sockfd, buf, static_cast(len - ndone), 0); + if (ret == -1) { + if (LastErrorWouldBlock()) return ndone; + Socket::Error("SendAll"); } - return ndone; - } - /*! - * \brief peforma block read that will attempt to read all data - * can still return smaller than request when error occurs - * \param buf_ the buffer pointer - * \param len length of data to recv - * \return size of data actually sent - */ - inline size_t RecvAll(void *buf_, size_t len) { - char *buf = reinterpret_cast(buf_); - size_t ndone = 0; - while (ndone < len) { - ssize_t ret = recv(sockfd, buf, - static_cast(len - ndone), MSG_WAITALL); - if (ret == -1) { - if (LastErrorWouldBlock()) return ndone; - Socket::Error("RecvAll"); - } - if (ret == 0) return ndone; - buf += ret; - ndone += ret; + buf += ret; + ndone += ret; + } + return ndone; + } + /*! + * \brief peforma block read that will attempt to read all data + * can still return smaller than request when error occurs + * \param buf_ the buffer pointer + * \param len length of data to recv + * \return size of data actually sent + */ + inline size_t RecvAll(void *buf_, size_t len) { + char *buf = reinterpret_cast(buf_); + size_t ndone = 0; + while (ndone < len) { + ssize_t ret = recv(sockfd, buf, + static_cast(len - ndone), MSG_WAITALL); + if (ret == -1) { + if (LastErrorWouldBlock()) return ndone; + Socket::Error("RecvAll"); } - return ndone; - } - /*! - * \brief send a string over network - * \param str the string to be sent - */ - inline void SendStr(const std::string &str) { - int len = static_cast(str.length()); - utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len), + if (ret == 0) return ndone; + buf += ret; + ndone += ret; + } + return ndone; + } + /*! + * \brief send a string over network + * \param str the string to be sent + */ + inline void SendStr(const std::string &str) { + int len = static_cast(str.length()); + utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len), + "error during send SendStr"); + if (len != 0) { + utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(), "error during send SendStr"); - if (len != 0) { - utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(), - "error during send SendStr"); - } } - /*! - * \brief recv a string from network - * \param out_str the string to receive - */ - inline void RecvStr(std::string *out_str) { - int len; - utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len), - "error during send RecvStr"); - out_str->resize(len); - if (len != 0) { - utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(), - "error during send SendStr"); - } + } + /*! + * \brief recv a string from network + * \param out_str the string to receive + */ + inline void RecvStr(std::string *out_str) { + int len; + utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len), + "error during send RecvStr"); + out_str->resize(len); + if (len != 0) { + utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(), + "error during send SendStr"); } + } }; /*! \brief helper data structure to perform poll */ struct PollHelper { - public: - /*! - * \brief add file descriptor to watch for read - * \param fd file descriptor to be watched - */ - inline void WatchRead(SOCKET fd) { - auto& pfd = fds[fd]; - pfd.fd = fd; - pfd.events |= POLLIN; - } - /*! - * \brief add file descriptor to watch for write - * \param fd file descriptor to be watched - */ - inline void WatchWrite(SOCKET fd) { - auto& pfd = fds[fd]; - pfd.fd = fd; - pfd.events |= POLLOUT; - } - /*! - * \brief add file descriptor to watch for exception - * \param fd file descriptor to be watched - */ - inline void WatchException(SOCKET fd) { - auto& pfd = fds[fd]; - pfd.fd = fd; - pfd.events |= POLLPRI; - } - /*! - * \brief Check if the descriptor is ready for read - * \param fd file descriptor to check status - */ - inline bool CheckRead(SOCKET fd) const { - const auto& pfd = fds.find(fd); - return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0); - } - /*! - * \brief Check if the descriptor is ready for write - * \param fd file descriptor to check status - */ - inline bool CheckWrite(SOCKET fd) const { - const auto& pfd = fds.find(fd); - return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0); - } - /*! - * \brief Check if the descriptor has any exception - * \param fd file descriptor to check status - */ - inline bool CheckExcept(SOCKET fd) const { - const auto& pfd = fds.find(fd); - return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0); - } - /*! - * \brief wait for exception event on a single descriptor - * \param fd the file descriptor to wait the event for - * \param timeout the timeout counter, can be negative, which means wait until the event happen - * \return 1 if success, 0 if timeout, and -1 if error occurs - */ - inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*) - pollfd pfd; - pfd.fd = fd; - pfd.events = POLLPRI; - return poll(&pfd, 1, timeout); - } + public: + /*! + * \brief add file descriptor to watch for read + * \param fd file descriptor to be watched + */ + inline void WatchRead(SOCKET fd) { + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLIN; + } + /*! + * \brief add file descriptor to watch for write + * \param fd file descriptor to be watched + */ + inline void WatchWrite(SOCKET fd) { + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLOUT; + } + /*! + * \brief add file descriptor to watch for exception + * \param fd file descriptor to be watched + */ + inline void WatchException(SOCKET fd) { + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLPRI; + } + /*! + * \brief Check if the descriptor is ready for read + * \param fd file descriptor to check status + */ + inline bool CheckRead(SOCKET fd) const { + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0); + } + /*! + * \brief Check if the descriptor is ready for write + * \param fd file descriptor to check status + */ + inline bool CheckWrite(SOCKET fd) const { + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0); + } + /*! + * \brief Check if the descriptor has any exception + * \param fd file descriptor to check status + */ + inline bool CheckExcept(SOCKET fd) const { + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0); + } + /*! + * \brief wait for exception event on a single descriptor + * \param fd the file descriptor to wait the event for + * \param timeout the timeout counter, can be negative, which means wait until the event happen + * \return 1 if success, 0 if timeout, and -1 if error occurs + */ + inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*) + pollfd pfd; + pfd.fd = fd; + pfd.events = POLLPRI; + return poll(&pfd, 1, timeout); + } - /*! - * \brief peform poll on the set defined, read, write, exception - * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block - * \return - */ - inline void Poll(long timeout = -1) { // NOLINT(*) - std::vector fdset; - fdset.reserve(fds.size()); - for (auto kv : fds) { - fdset.push_back(kv.second); - } - int ret = poll(fdset.data(), fdset.size(), timeout); - if (ret == -1) { - Socket::Error("Poll"); - } else { - for (auto& pfd : fdset) { - auto revents = pfd.revents & pfd.events; - if (!revents) { - fds.erase(pfd.fd); - } else { - fds[pfd.fd].events = revents; - } - } - } - } + /*! + * \brief peform poll on the set defined, read, write, exception + * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block + * \return + */ + inline void Poll(long timeout = -1) { // NOLINT(*) + std::vector fdset; + fdset.reserve(fds.size()); + for (auto kv : fds) { + fdset.push_back(kv.second); + } + int ret = poll(fdset.data(), fdset.size(), timeout); + if (ret == -1) { + Socket::Error("Poll"); + } else { + for (auto& pfd : fdset) { + auto revents = pfd.revents & pfd.events; + if (!revents) { + fds.erase(pfd.fd); + } else { + fds[pfd.fd].events = revents; + } + } + } + } - std::unordered_map fds; - }; - } // namespace utils + std::unordered_map fds; +}; +} // namespace utils } // namespace rabit -#endif // RABIT_SOCKET_H_ \ No newline at end of file +#endif // RABIT_SOCKET_H_ diff --git a/test/local_recover.cc b/test/local_recover.cc index a63bd2f8..d2d7d209 100644 --- a/test/local_recover.cc +++ b/test/local_recover.cc @@ -120,17 +120,17 @@ int main(int argc, char *argv[]) { printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); } for (int r = iter; r < 3; ++r) { - TestMax(&model, &local, ntrial, r); - printf("[%d] !!!TestMax pass, iter=%d\n", rank, r); + //TestMax(&model, &local, ntrial, r); + //printf("[%d] !!!TestMax pass, iter=%d\n", rank, r); int step = std::max(nproc / 3, 1); for (int i = 0; i < nproc; i += step) { - TestBcast(n, i, ntrial, r); + //TestBcast(n, i, ntrial, r); } - printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r); - TestSum(&model, &local, ntrial, r); - printf("[%d] !!!TestSum pass, iter=%d\n", rank, r); - rabit::CheckPoint(&model, &local); - printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r); + //printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r); + //TestSum(&model, &local, ntrial, r); + //printf("[%d] !!!TestSum pass, iter=%d\n", rank, r); + //rabit::CheckPoint(&model, &local); + //printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r); } rabit::Finalize(); return 0; diff --git a/test/test.mk b/test/test.mk index 4a545113..37941374 100644 --- a/test/test.mk +++ b/test/test.mk @@ -1,7 +1,7 @@ # this is a makefile used to show testcases of rabit .PHONY: all -all: model_recover_10_10k model_recover_10_10k_die_same model_recover_10_10k_die_hard local_recover_10_10k +all: model_recover_10_10k_die_same model_recover_10_10k_die_hard local_recover_10_10k # this experiment test recovery with actually process exit, use keepalive to keep program alive model_recover_10_10k: From 7a627e3f15107a86839d2104b26722c14e274423 Mon Sep 17 00:00:00 2001 From: Chen Qin Date: Fri, 19 Oct 2018 13:27:43 -0700 Subject: [PATCH 3/5] disable local_recover tests until we have a bug fix --- test/test.mk | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/test/test.mk b/test/test.mk index 37941374..0b4cfaec 100644 --- a/test/test.mk +++ b/test/test.mk @@ -1,29 +1,30 @@ # this is a makefile used to show testcases of rabit .PHONY: all -all: model_recover_10_10k_die_same model_recover_10_10k_die_hard local_recover_10_10k +all: local_recover_10_10k # this experiment test recovery with actually process exit, use keepalive to keep program alive -model_recover_10_10k: - ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 +# TODO: enable those tests once we fix issue in rabit +#model_recover_10_10k: +# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 -model_recover_10_10k_die_same: - ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 +#model_recover_10_10k_die_same: +# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 -model_recover_10_10k_die_hard: - ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 +#model_recover_10_10k_die_hard: +# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 -local_recover_10_10k: - ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 local_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1 +#local_recover_10_10k: +# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 local_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1 -pylocal_recover_10_10k: - ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 ./local_recover.py 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1 +#pylocal_recover_10_10k: +# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 ./local_recover.py 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1 -lazy_recover_10_10k_die_hard: - ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 +#lazy_recover_10_10k_die_hard: +# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 -lazy_recover_10_10k_die_same: - ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 +#lazy_recover_10_10k_die_same: +# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 ringallreduce_10_10k: ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 100 rabit_reduce_ring_mincount=10 From fa17555481bcf28757fa512f1b6cc6778710bb0c Mon Sep 17 00:00:00 2001 From: Chen Qin Date: Sun, 21 Oct 2018 13:39:44 -0700 Subject: [PATCH 4/5] support larger cluster --- src/allreduce_base.cc | 4 ++-- src/allreduce_robust-inl.h | 4 ++-- src/allreduce_robust.cc | 2 +- src/socket.h | 28 ++++++++++++++-------------- test/local_recover.cc | 16 ++++++++-------- test/test.mk | 30 +++++++++++++++--------------- 6 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 93e15b48..a509e827 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -208,7 +208,7 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const { } else { fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str()); #ifdef _MSC_VER - Sleep(1); + Sleep(retry << 1); #else sleep(retry << 1); #endif @@ -896,4 +896,4 @@ AllreduceBase::TryAllreduceRing(void *sendrecvbuf_, std::min(prank * step, count)) * type_nbytes); } } // namespace engine -} // namespace rabit +} // namespace rabit \ No newline at end of file diff --git a/src/allreduce_robust-inl.h b/src/allreduce_robust-inl.h index 22df5191..7db18a42 100644 --- a/src/allreduce_robust-inl.h +++ b/src/allreduce_robust-inl.h @@ -69,7 +69,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, if (parent_index == -1) { utils::Assert(stage != 2 && stage != 1, "invalie stage id"); } - // select helper + // poll helper utils::PollHelper watcher; bool done = (stage == 3); for (int i = 0; i < nlink; ++i) { @@ -166,4 +166,4 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, } } // namespace engine } // namespace rabit -#endif // RABIT_ALLREDUCE_ROBUST_INL_H_ +#endif // RABIT_ALLREDUCE_ROBUST_INL_H_ \ No newline at end of file diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 210d5d8a..9809cbd7 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -1185,4 +1185,4 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_, return kSuccess; } } // namespace engine -} // namespace rabit +} // namespace rabit \ No newline at end of file diff --git a/src/socket.h b/src/socket.h index 9a04a0a0..f0b7d7c7 100644 --- a/src/socket.h +++ b/src/socket.h @@ -130,13 +130,13 @@ class Socket { inline static void Startup(void) { #ifdef _WIN32 WSADATA wsa_data; -if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { -Socket::Error("Startup"); -} -if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { -WSACleanup(); -utils::Error("Could not find a usable version of Winsock.dll\n"); -} + if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { + Socket::Error("Startup"); + } + if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { + WSACleanup(); + utils::Error("Could not find a usable version of Winsock.dll\n"); + } #endif } /*! @@ -155,9 +155,9 @@ utils::Error("Could not find a usable version of Winsock.dll\n"); inline void SetNonBlock(bool non_block) { #ifdef _WIN32 u_long mode = non_block ? 1 : 0; -if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { -Socket::Error("SetNonBlock"); -} + if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { + Socket::Error("SetNonBlock"); + } #else int flag = fcntl(sockfd, F_GETFL, 0); if (flag == -1) { @@ -199,8 +199,8 @@ Socket::Error("SetNonBlock"); } #if defined(_WIN32) if (WSAGetLastError() != WSAEADDRINUSE) { -Socket::Error("TryBindHost"); -} + Socket::Error("TryBindHost"); + } #else if (errno != EADDRINUSE) { Socket::Error("TryBindHost"); @@ -312,7 +312,7 @@ class TCPSocket : public Socket{ inline int AtMark(void) const { #ifdef _WIN32 unsigned long atmark; // NOLINT(*) -if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; + if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; #else int atmark; if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1; @@ -522,4 +522,4 @@ struct PollHelper { }; } // namespace utils } // namespace rabit -#endif // RABIT_SOCKET_H_ +#endif // RABIT_SOCKET_H_ \ No newline at end of file diff --git a/test/local_recover.cc b/test/local_recover.cc index d2d7d209..a63bd2f8 100644 --- a/test/local_recover.cc +++ b/test/local_recover.cc @@ -120,17 +120,17 @@ int main(int argc, char *argv[]) { printf("[%d] reload-trail=%d, init iter=%d\n", rank, ntrial, iter); } for (int r = iter; r < 3; ++r) { - //TestMax(&model, &local, ntrial, r); - //printf("[%d] !!!TestMax pass, iter=%d\n", rank, r); + TestMax(&model, &local, ntrial, r); + printf("[%d] !!!TestMax pass, iter=%d\n", rank, r); int step = std::max(nproc / 3, 1); for (int i = 0; i < nproc; i += step) { - //TestBcast(n, i, ntrial, r); + TestBcast(n, i, ntrial, r); } - //printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r); - //TestSum(&model, &local, ntrial, r); - //printf("[%d] !!!TestSum pass, iter=%d\n", rank, r); - //rabit::CheckPoint(&model, &local); - //printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r); + printf("[%d] !!!TestBcast pass, iter=%d\n", rank, r); + TestSum(&model, &local, ntrial, r); + printf("[%d] !!!TestSum pass, iter=%d\n", rank, r); + rabit::CheckPoint(&model, &local); + printf("[%d] !!!CheckPont pass, iter=%d\n", rank, r); } rabit::Finalize(); return 0; diff --git a/test/test.mk b/test/test.mk index 0b4cfaec..9dfebb02 100644 --- a/test/test.mk +++ b/test/test.mk @@ -1,30 +1,30 @@ # this is a makefile used to show testcases of rabit .PHONY: all -all: local_recover_10_10k +all: model_recover_10_10k model_recover_10_10k_die_same model_recover_10_10k_die_hard local_recover_10_10k # this experiment test recovery with actually process exit, use keepalive to keep program alive # TODO: enable those tests once we fix issue in rabit -#model_recover_10_10k: -# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 +model_recover_10_10k: + ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 -#model_recover_10_10k_die_same: -# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 +model_recover_10_10k_die_same: + ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 -#model_recover_10_10k_die_hard: -# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 +model_recover_10_10k_die_hard: + ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 -#local_recover_10_10k: -# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 local_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1 +local_recover_10_10k: + ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 local_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1 -#pylocal_recover_10_10k: -# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 ./local_recover.py 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1 +pylocal_recover_10_10k: + ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 ./local_recover.py 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=1,1,1,1 -#lazy_recover_10_10k_die_hard: -# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 +lazy_recover_10_10k_die_hard: + ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 -#lazy_recover_10_10k_die_same: -# ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 +lazy_recover_10_10k_die_same: + ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 lazy_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 ringallreduce_10_10k: ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 100 rabit_reduce_ring_mincount=10 From d71e2478b38db10646217cf830618979c974e1e7 Mon Sep 17 00:00:00 2001 From: Chen Qin Date: Sun, 21 Oct 2018 14:00:14 -0700 Subject: [PATCH 5/5] fix lint, merge with master --- doc/Doxyfile | 6 ------ include/rabit/c_api.h | 18 ++++++++++++++---- include/rabit/rabit.h | 12 ++++++++---- src/allreduce_base.cc | 2 +- src/allreduce_robust-inl.h | 2 +- src/allreduce_robust.cc | 2 +- src/socket.h | 2 +- 7 files changed, 26 insertions(+), 18 deletions(-) diff --git a/doc/Doxyfile b/doc/Doxyfile index 254a9467..3e64641f 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -42,7 +42,6 @@ SUBGROUPING = YES INLINE_GROUPED_CLASSES = NO INLINE_SIMPLE_STRUCTS = NO TYPEDEF_HIDES_STRUCT = NO -SYMBOL_CACHE_SIZE = 0 LOOKUP_CACHE_SIZE = 0 #--------------------------------------------------------------------------- # Build related configuration options @@ -76,7 +75,6 @@ GENERATE_DEPRECATEDLIST= YES ENABLED_SECTIONS = MAX_INITIALIZER_LINES = 30 SHOW_USED_FILES = YES -SHOW_DIRECTORIES = NO SHOW_FILES = YES SHOW_NAMESPACES = YES FILE_VERSION_FILTER = @@ -142,7 +140,6 @@ HTML_COLORSTYLE_HUE = 220 HTML_COLORSTYLE_SAT = 100 HTML_COLORSTYLE_GAMMA = 80 HTML_TIMESTAMP = YES -HTML_ALIGN_MEMBERS = YES HTML_DYNAMIC_SECTIONS = NO GENERATE_DOCSET = NO DOCSET_FEEDNAME = "Doxygen generated docs" @@ -169,7 +166,6 @@ ECLIPSE_DOC_ID = org.doxygen.Project DISABLE_INDEX = NO GENERATE_TREEVIEW = NO ENUM_VALUES_PER_LINE = 4 -USE_INLINE_TREES = NO TREEVIEW_WIDTH = 250 EXT_LINKS_IN_WINDOW = NO FORMULA_FONTSIZE = 10 @@ -218,8 +214,6 @@ MAN_LINKS = NO #--------------------------------------------------------------------------- GENERATE_XML = YES XML_OUTPUT = xml -XML_SCHEMA = -XML_DTD = XML_PROGRAMLISTING = YES #--------------------------------------------------------------------------- # configuration options for the AutoGen Definitions output diff --git a/include/rabit/c_api.h b/include/rabit/c_api.h index 1b47f804..4668ae47 100644 --- a/include/rabit/c_api.h +++ b/include/rabit/c_api.h @@ -41,13 +41,22 @@ RABIT_DLL void RabitInit(int argc, char *argv[]); */ RABIT_DLL void RabitFinalize(void); -/*! \brief get rank of current process */ +/*! + * \brief get rank of current process + * \return rank number of worker + * */ RABIT_DLL int RabitGetRank(void); -/*! \brief get total number of process */ +/*! + * \brief get total number of process + * \return total world size + * */ RABIT_DLL int RabitGetWorldSize(void); -/*! \brief get rank of current process */ +/*! + * \brief get rank of current process + * \return if rabit is distributed + * */ RABIT_DLL int RabitIsDistributed(void); /*! @@ -136,6 +145,7 @@ RABIT_DLL void RabitCheckPoint(const char *global_model, /*! * \return version number of current stored model, * which means how many calls to CheckPoint we made so far + * \return rabit version number */ RABIT_DLL int RabitVersionNumber(void); @@ -144,7 +154,7 @@ RABIT_DLL int RabitVersionNumber(void); * \brief a Dummy function, * used to cause force link of C API into the DLL. * \code - * // force link rabit C API library. + * \/\/force link rabit C API library. * static int must_link_rabit_ = RabitLinkTag(); * \endcode * \return a dummy integer. diff --git a/include/rabit/rabit.h b/include/rabit/rabit.h index 1eda2ea7..9686eef3 100644 --- a/include/rabit/rabit.h +++ b/include/rabit/rabit.h @@ -79,14 +79,18 @@ inline void Init(int argc, char *argv[]); * \brief finalizes the rabit engine, call this function after you finished with all the jobs */ inline void Finalize(); -/*! \brief gets rank of the current process */ +/*! \brief gets rank of the current process + * \return rank number of worker*/ inline int GetRank(); -/*! \brief gets total number of processes */ +/*! \brief gets total number of processes + * \return total world size*/ inline int GetWorldSize(); -/*! \brief whether rabit env is in distributed mode */ +/*! \brief whether rabit env is in distributed mode + * \return is distributed*/ inline bool IsDistributed(); -/*! \brief gets processor's name */ +/*! \brief gets processor's name + * \return processor name*/ inline std::string GetProcessorName(); /*! * \brief prints the msg to the tracker, diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index a509e827..2e79324e 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -896,4 +896,4 @@ AllreduceBase::TryAllreduceRing(void *sendrecvbuf_, std::min(prank * step, count)) * type_nbytes); } } // namespace engine -} // namespace rabit \ No newline at end of file +} // namespace rabit diff --git a/src/allreduce_robust-inl.h b/src/allreduce_robust-inl.h index 7db18a42..7baa14bf 100644 --- a/src/allreduce_robust-inl.h +++ b/src/allreduce_robust-inl.h @@ -166,4 +166,4 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, } } // namespace engine } // namespace rabit -#endif // RABIT_ALLREDUCE_ROBUST_INL_H_ \ No newline at end of file +#endif // RABIT_ALLREDUCE_ROBUST_INL_H_ diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 9809cbd7..210d5d8a 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -1185,4 +1185,4 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_, return kSuccess; } } // namespace engine -} // namespace rabit \ No newline at end of file +} // namespace rabit diff --git a/src/socket.h b/src/socket.h index f0b7d7c7..cfc2449c 100644 --- a/src/socket.h +++ b/src/socket.h @@ -522,4 +522,4 @@ struct PollHelper { }; } // namespace utils } // namespace rabit -#endif // RABIT_SOCKET_H_ \ No newline at end of file +#endif // RABIT_SOCKET_H_