diff --git a/.travis.yml b/.travis.yml index a0ae62e3..1e0aa244 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,7 +28,7 @@ before_install: - source ${TRAVIS}/travis_setup_env.sh install: - - pip install --user cpplint pylint + - pip install --user cpplint pylint kubernetes urllib3 script: scripts/travis_script.sh diff --git a/doc/Doxyfile b/doc/Doxyfile index 254a9467..3e64641f 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -42,7 +42,6 @@ SUBGROUPING = YES INLINE_GROUPED_CLASSES = NO INLINE_SIMPLE_STRUCTS = NO TYPEDEF_HIDES_STRUCT = NO -SYMBOL_CACHE_SIZE = 0 LOOKUP_CACHE_SIZE = 0 #--------------------------------------------------------------------------- # Build related configuration options @@ -76,7 +75,6 @@ GENERATE_DEPRECATEDLIST= YES ENABLED_SECTIONS = MAX_INITIALIZER_LINES = 30 SHOW_USED_FILES = YES -SHOW_DIRECTORIES = NO SHOW_FILES = YES SHOW_NAMESPACES = YES FILE_VERSION_FILTER = @@ -142,7 +140,6 @@ HTML_COLORSTYLE_HUE = 220 HTML_COLORSTYLE_SAT = 100 HTML_COLORSTYLE_GAMMA = 80 HTML_TIMESTAMP = YES -HTML_ALIGN_MEMBERS = YES HTML_DYNAMIC_SECTIONS = NO GENERATE_DOCSET = NO DOCSET_FEEDNAME = "Doxygen generated docs" @@ -169,7 +166,6 @@ ECLIPSE_DOC_ID = org.doxygen.Project DISABLE_INDEX = NO GENERATE_TREEVIEW = NO ENUM_VALUES_PER_LINE = 4 -USE_INLINE_TREES = NO TREEVIEW_WIDTH = 250 EXT_LINKS_IN_WINDOW = NO FORMULA_FONTSIZE = 10 @@ -218,8 +214,6 @@ MAN_LINKS = NO #--------------------------------------------------------------------------- GENERATE_XML = YES XML_OUTPUT = xml -XML_SCHEMA = -XML_DTD = XML_PROGRAMLISTING = YES #--------------------------------------------------------------------------- # configuration options for the AutoGen Definitions output diff --git a/include/rabit/c_api.h b/include/rabit/c_api.h index 1b47f804..4668ae47 100644 --- a/include/rabit/c_api.h +++ b/include/rabit/c_api.h @@ -41,13 +41,22 @@ RABIT_DLL void RabitInit(int argc, char *argv[]); */ RABIT_DLL void RabitFinalize(void); -/*! \brief get rank of current process */ +/*! 
+ * \brief get rank of current process + * \return rank number of worker + * */ RABIT_DLL int RabitGetRank(void); -/*! \brief get total number of process */ +/*! + * \brief get total number of processes + * \return total world size + * */ RABIT_DLL int RabitGetWorldSize(void); -/*! \brief get rank of current process */ +/*! + * \brief whether rabit env is in distributed mode + * \return if rabit is distributed + * */ RABIT_DLL int RabitIsDistributed(void); /*! @@ -136,6 +145,7 @@ RABIT_DLL void RabitCheckPoint(const char *global_model, /*! * \return version number of current stored model, * which means how many calls to CheckPoint we made so far + * \return rabit version number */ RABIT_DLL int RabitVersionNumber(void); @@ -144,7 +154,7 @@ RABIT_DLL int RabitVersionNumber(void); * \brief a Dummy function, * used to cause force link of C API into the DLL. * \code - * // force link rabit C API library. + * \/\/force link rabit C API library. * static int must_link_rabit_ = RabitLinkTag(); * \endcode * \return a dummy integer. diff --git a/include/rabit/rabit.h b/include/rabit/rabit.h index 1eda2ea7..9686eef3 100644 --- a/include/rabit/rabit.h +++ b/include/rabit/rabit.h @@ -79,14 +79,18 @@ inline void Init(int argc, char *argv[]); * \brief finalizes the rabit engine, call this function after you finished with all the jobs */ inline void Finalize(); -/*! \brief gets rank of the current process */ +/*! \brief gets rank of the current process + * \return rank number of worker*/ inline int GetRank(); -/*! \brief gets total number of processes */ +/*! \brief gets total number of processes + * \return total world size*/ inline int GetWorldSize(); -/*! \brief whether rabit env is in distributed mode */ +/*! \brief whether rabit env is in distributed mode + * \return is distributed*/ inline bool IsDistributed(); -/*! \brief gets processor's name */ +/*! \brief gets processor's name + * \return processor name*/ inline std::string GetProcessorName(); /*! 
* \brief prints the msg to the tracker, diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 862187bc..2e79324e 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -208,9 +208,9 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const { } else { fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str()); #ifdef _MSC_VER - Sleep(1); + Sleep(retry << 1); #else - sleep(1); + sleep(retry << 1); #endif continue; } @@ -454,29 +454,29 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, while (true) { // select helper bool finished = true; - utils::SelectHelper selecter; + utils::PollHelper watcher; for (int i = 0; i < nlink; ++i) { if (i == parent_index) { if (size_down_in != total_size) { - selecter.WatchRead(links[i].sock); + watcher.WatchRead(links[i].sock); // only watch for exception in live channels - selecter.WatchException(links[i].sock); + watcher.WatchException(links[i].sock); finished = false; } if (size_up_out != total_size && size_up_out < size_up_reduce) { - selecter.WatchWrite(links[i].sock); + watcher.WatchWrite(links[i].sock); } } else { if (links[i].size_read != total_size) { - selecter.WatchRead(links[i].sock); + watcher.WatchRead(links[i].sock); } // size_write <= size_read if (links[i].size_write != total_size) { if (links[i].size_write < size_down_in) { - selecter.WatchWrite(links[i].sock); + watcher.WatchWrite(links[i].sock); } // only watch for exception in live channels - selecter.WatchException(links[i].sock); + watcher.WatchException(links[i].sock); finished = false; } } @@ -484,17 +484,17 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, // finish runing allreduce if (finished) break; // select must return - selecter.Select(); + watcher.Poll(); // exception handling for (int i = 0; i < nlink; ++i) { // recive OOB message from some link - if (selecter.CheckExcept(links[i].sock)) { + if (watcher.CheckExcept(links[i].sock)) { return ReportError(&links[i], kGetExcept); } } // 
read data from childs for (int i = 0; i < nlink; ++i) { - if (i != parent_index && selecter.CheckRead(links[i].sock)) { + if (i != parent_index && watcher.CheckRead(links[i].sock)) { ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size); if (ret != kSuccess) { return ReportError(&links[i], ret); @@ -551,7 +551,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, } } // read data from parent - if (selecter.CheckRead(links[parent_index].sock) && + if (watcher.CheckRead(links[parent_index].sock) && total_size > size_down_in) { ssize_t len = links[parent_index].sock. Recv(sendrecvbuf + size_down_in, total_size - size_down_in); @@ -620,37 +620,37 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { while (true) { bool finished = true; // select helper - utils::SelectHelper selecter; + utils::PollHelper watcher; for (int i = 0; i < nlink; ++i) { if (in_link == -2) { - selecter.WatchRead(links[i].sock); finished = false; + watcher.WatchRead(links[i].sock); finished = false; } if (i == in_link && links[i].size_read != total_size) { - selecter.WatchRead(links[i].sock); finished = false; + watcher.WatchRead(links[i].sock); finished = false; } if (in_link != -2 && i != in_link && links[i].size_write != total_size) { if (links[i].size_write < size_in) { - selecter.WatchWrite(links[i].sock); + watcher.WatchWrite(links[i].sock); } finished = false; } - selecter.WatchException(links[i].sock); + watcher.WatchException(links[i].sock); } // finish running if (finished) break; // select - selecter.Select(); + watcher.Poll(); // exception handling for (int i = 0; i < nlink; ++i) { // recive OOB message from some link - if (selecter.CheckExcept(links[i].sock)) { + if (watcher.CheckExcept(links[i].sock)) { return ReportError(&links[i], kGetExcept); } } if (in_link == -2) { // probe in-link for (int i = 0; i < nlink; ++i) { - if (selecter.CheckRead(links[i].sock)) { + if (watcher.CheckRead(links[i].sock)) { ReturnType ret = 
links[i].ReadToArray(sendrecvbuf_, total_size); if (ret != kSuccess) { return ReportError(&links[i], ret); @@ -663,7 +663,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) { } } else { // read from in link - if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) { + if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) { ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size); if (ret != kSuccess) { return ReportError(&links[in_link], ret); @@ -717,20 +717,20 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size, while (true) { // select helper bool finished = true; - utils::SelectHelper selecter; + utils::PollHelper watcher; if (read_ptr != stop_read) { - selecter.WatchRead(next.sock); + watcher.WatchRead(next.sock); finished = false; } if (write_ptr != stop_write) { if (write_ptr < read_ptr) { - selecter.WatchWrite(prev.sock); + watcher.WatchWrite(prev.sock); } finished = false; } if (finished) break; - selecter.Select(); - if (read_ptr != stop_read && selecter.CheckRead(next.sock)) { + watcher.Poll(); + if (read_ptr != stop_read && watcher.CheckRead(next.sock)) { size_t size = stop_read - read_ptr; size_t start = read_ptr % total_size; if (start + size > total_size) { @@ -811,20 +811,20 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_, while (true) { // select helper bool finished = true; - utils::SelectHelper selecter; + utils::PollHelper watcher; if (read_ptr != stop_read) { - selecter.WatchRead(next.sock); + watcher.WatchRead(next.sock); finished = false; } if (write_ptr != stop_write) { if (write_ptr < reduce_ptr) { - selecter.WatchWrite(prev.sock); + watcher.WatchWrite(prev.sock); } finished = false; } if (finished) break; - selecter.Select(); - if (read_ptr != stop_read && selecter.CheckRead(next.sock)) { + watcher.Poll(); + if (read_ptr != stop_read && watcher.CheckRead(next.sock)) { ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read); if (ret != kSuccess) { 
return ReportError(&next, ret); diff --git a/src/allreduce_robust-inl.h b/src/allreduce_robust-inl.h index d3cbc003..7baa14bf 100644 --- a/src/allreduce_robust-inl.h +++ b/src/allreduce_robust-inl.h @@ -69,30 +69,30 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, if (parent_index == -1) { utils::Assert(stage != 2 && stage != 1, "invalie stage id"); } - // select helper - utils::SelectHelper selecter; + // poll helper + utils::PollHelper watcher; bool done = (stage == 3); for (int i = 0; i < nlink; ++i) { - selecter.WatchException(links[i].sock); + watcher.WatchException(links[i].sock); switch (stage) { case 0: if (i != parent_index && links[i].size_read != sizeof(EdgeType)) { - selecter.WatchRead(links[i].sock); + watcher.WatchRead(links[i].sock); } break; case 1: if (i == parent_index) { - selecter.WatchWrite(links[i].sock); + watcher.WatchWrite(links[i].sock); } break; case 2: if (i == parent_index) { - selecter.WatchRead(links[i].sock); + watcher.WatchRead(links[i].sock); } break; case 3: if (i != parent_index && links[i].size_write != sizeof(EdgeType)) { - selecter.WatchWrite(links[i].sock); + watcher.WatchWrite(links[i].sock); done = false; } break; @@ -101,11 +101,11 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, } // finish all the stages, and write out message if (done) break; - selecter.Select(); + watcher.Poll(); // exception handling for (int i = 0; i < nlink; ++i) { // recive OOB message from some link - if (selecter.CheckExcept(links[i].sock)) { + if (watcher.CheckExcept(links[i].sock)) { return ReportError(&links[i], kGetExcept); } } @@ -114,7 +114,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value, // read data from childs for (int i = 0; i < nlink; ++i) { if (i != parent_index) { - if (selecter.CheckRead(links[i].sock)) { + if (watcher.CheckRead(links[i].sock)) { ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType)); if (ret != kSuccess) return ReportError(&links[i], ret); } diff --git 
a/src/allreduce_robust.cc b/src/allreduce_robust.cc index a48a349a..210d5d8a 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -334,7 +334,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { if (len == sizeof(sig)) all_links[i].size_write = 2; } } - utils::SelectHelper rsel; + utils::PollHelper rsel; bool finished = true; for (int i = 0; i < nlink; ++i) { if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) { @@ -343,15 +343,15 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { } if (finished) break; // wait to read from the channels to discard data - rsel.Select(); + rsel.Poll(); } for (int i = 0; i < nlink; ++i) { if (!all_links[i].sock.BadSocket()) { - utils::SelectHelper::WaitExcept(all_links[i].sock); + utils::PollHelper::WaitExcept(all_links[i].sock); } } while (true) { - utils::SelectHelper rsel; + utils::PollHelper rsel; bool finished = true; for (int i = 0; i < nlink; ++i) { if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) { @@ -359,7 +359,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { } } if (finished) break; - rsel.Select(); + rsel.Poll(); for (int i = 0; i < nlink; ++i) { if (all_links[i].sock.BadSocket()) continue; if (all_links[i].size_read == 0) { @@ -624,32 +624,32 @@ AllreduceRobust::TryRecoverData(RecoverType role, } while (true) { bool finished = true; - utils::SelectHelper selecter; + utils::PollHelper watcher; for (int i = 0; i < nlink; ++i) { if (i == recv_link && links[i].size_read != size) { - selecter.WatchRead(links[i].sock); + watcher.WatchRead(links[i].sock); finished = false; } if (req_in[i] && links[i].size_write != size) { if (role == kHaveData || (links[recv_link].size_read != links[i].size_write)) { - selecter.WatchWrite(links[i].sock); + watcher.WatchWrite(links[i].sock); } finished = false; } - selecter.WatchException(links[i].sock); + watcher.WatchException(links[i].sock); } if (finished) break; - 
selecter.Select(); + watcher.Poll(); // exception handling for (int i = 0; i < nlink; ++i) { - if (selecter.CheckExcept(links[i].sock)) { + if (watcher.CheckExcept(links[i].sock)) { return ReportError(&links[i], kGetExcept); } } if (role == kRequestData) { const int pid = recv_link; - if (selecter.CheckRead(links[pid].sock)) { + if (watcher.CheckRead(links[pid].sock)) { ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size); if (ret != kSuccess) { return ReportError(&links[pid], ret); @@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role, if (role == kPassData) { const int pid = recv_link; const size_t buffer_size = links[pid].buffer_size; - if (selecter.CheckRead(links[pid].sock)) { + if (watcher.CheckRead(links[pid].sock)) { size_t min_write = size; for (int i = 0; i < nlink; ++i) { if (req_in[i]) min_write = std::min(links[i].size_write, min_write); @@ -1144,22 +1144,22 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_, char *buf = reinterpret_cast(sendrecvbuf_); while (true) { bool finished = true; - utils::SelectHelper selecter; + utils::PollHelper watcher; if (read_ptr != read_end) { - selecter.WatchRead(prev.sock); + watcher.WatchRead(prev.sock); finished = false; } if (write_ptr < read_ptr && write_ptr != write_end) { - selecter.WatchWrite(next.sock); + watcher.WatchWrite(next.sock); finished = false; } - selecter.WatchException(prev.sock); - selecter.WatchException(next.sock); + watcher.WatchException(prev.sock); + watcher.WatchException(next.sock); if (finished) break; - selecter.Select(); - if (selecter.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept); - if (selecter.CheckExcept(next.sock)) return ReportError(&next, kGetExcept); - if (read_ptr != read_end && selecter.CheckRead(prev.sock)) { + watcher.Poll(); + if (watcher.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept); + if (watcher.CheckExcept(next.sock)) return ReportError(&next, kGetExcept); + if (read_ptr != read_end && watcher.CheckRead(prev.sock)) { 
ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr); if (len == 0) { prev.sock.Close(); return ReportError(&prev, kRecvZeroLen); diff --git a/src/socket.h b/src/socket.h index 83d28e88..cfc2449c 100644 --- a/src/socket.h +++ b/src/socket.h @@ -20,17 +20,22 @@ #include #include #include -#include #include #endif #include #include +#include +#include #include "../include/rabit/internal/utils.h" #if defined(_WIN32) typedef int ssize_t; typedef int sock_size_t; + +static inline int poll(struct pollfd *pfd, int nfds, + int timeout) { return WSAPoll ( pfd, nfds, timeout ); } #else +#include typedef int SOCKET; typedef size_t sock_size_t; const int INVALID_SOCKET = -1; @@ -78,7 +83,7 @@ struct SockAddr { std::string buf; buf.resize(256); #ifdef _WIN32 const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, - &buf[0], buf.length()); + &buf[0], buf.length()); #else const char *s = inet_ntop(AF_INET, &addr.sin_addr, &buf[0], buf.length()); @@ -126,11 +131,11 @@ class Socket { #ifdef _WIN32 WSADATA wsa_data; if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { - Socket::Error("Startup"); + Socket::Error("Startup"); } if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { - WSACleanup(); - utils::Error("Could not find a usable version of Winsock.dll\n"); + WSACleanup(); + utils::Error("Could not find a usable version of Winsock.dll\n"); } #endif } @@ -209,7 +214,8 @@ class Socket { inline int GetSockError(void) const { int error = 0; socklen_t len = sizeof(error); - if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) { + if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) != 0) { Error("GetSockError"); } return error; @@ -419,109 +425,100 @@ class TCPSocket : public Socket{ } }; -/*! \brief helper data structure to perform select */ -struct SelectHelper { +/*! 
\brief helper data structure to perform poll */ +struct PollHelper { public: - SelectHelper(void) { - FD_ZERO(&read_set); - FD_ZERO(&write_set); - FD_ZERO(&except_set); - maxfd = 0; - } /*! * \brief add file descriptor to watch for read * \param fd file descriptor to be watched */ inline void WatchRead(SOCKET fd) { - FD_SET(fd, &read_set); - if (fd > maxfd) maxfd = fd; + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLIN; } /*! * \brief add file descriptor to watch for write * \param fd file descriptor to be watched */ inline void WatchWrite(SOCKET fd) { - FD_SET(fd, &write_set); - if (fd > maxfd) maxfd = fd; + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLOUT; } /*! * \brief add file descriptor to watch for exception * \param fd file descriptor to be watched */ inline void WatchException(SOCKET fd) { - FD_SET(fd, &except_set); - if (fd > maxfd) maxfd = fd; + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLPRI; } /*! * \brief Check if the descriptor is ready for read * \param fd file descriptor to check status */ inline bool CheckRead(SOCKET fd) const { - return FD_ISSET(fd, &read_set) != 0; + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0); } /*! * \brief Check if the descriptor is ready for write * \param fd file descriptor to check status */ inline bool CheckWrite(SOCKET fd) const { - return FD_ISSET(fd, &write_set) != 0; + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0); } /*! * \brief Check if the descriptor has any exception * \param fd file descriptor to check status */ inline bool CheckExcept(SOCKET fd) const { - return FD_ISSET(fd, &except_set) != 0; + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0); } /*! 
* \brief wait for exception event on a single descriptor * \param fd the file descriptor to wait the event for - * \param timeout the timeout counter, can be 0, which means wait until the event happen + * \param timeout the timeout in milliseconds, can be negative, which means wait until the event happens * \return 1 if success, 0 if timeout, and -1 if error occurs */ - inline static int WaitExcept(SOCKET fd, long timeout = 0) { // NOLINT(*) - fd_set wait_set; - FD_ZERO(&wait_set); - FD_SET(fd, &wait_set); - return Select_(static_cast(fd + 1), - NULL, NULL, &wait_set, timeout); + inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*) + pollfd pfd; + pfd.fd = fd; + pfd.events = POLLPRI; + return poll(&pfd, 1, timeout); } + /*! - * \brief peform select on the set defined - * \param select_read whether to watch for read event - * \param select_write whether to watch for write event - * \param select_except whether to watch for exception event - * \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block - * \return number of active descriptors selected, - * return -1 if error occurs + * \brief perform poll on the set defined, read, write, exception + * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block + * \note on success the watched set is narrowed to the descriptors and events that fired */ - inline int Select(long timeout = 0) { // NOLINT(*) - int ret = Select_(static_cast(maxfd + 1), - &read_set, &write_set, &except_set, timeout); - if (ret == -1) { - Socket::Error("Select"); + inline void Poll(long timeout = -1) { // NOLINT(*) + std::vector fdset; + fdset.reserve(fds.size()); + for (auto kv : fds) { + fdset.push_back(kv.second); } - return ret; - } - - private: - inline static int Select_(int maxfd, fd_set *rfds, - fd_set *wfds, fd_set *efds, long timeout) { // NOLINT(*) -#if !defined(_WIN32) - utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE"); -#endif - if (timeout == 0) { - return select(maxfd, rfds, wfds, efds, NULL); + int ret = 
poll(fdset.data(), fdset.size(), timeout); + if (ret == -1) { + Socket::Error("Poll"); } else { - timeval tm; - tm.tv_usec = (timeout % 1000) * 1000; - tm.tv_sec = timeout / 1000; - return select(maxfd, rfds, wfds, efds, &tm); + for (auto& pfd : fdset) { + auto revents = pfd.revents & pfd.events; + if (!revents) { + fds.erase(pfd.fd); + } else { + fds[pfd.fd].events = revents; + } + } } } - SOCKET maxfd; - fd_set read_set, write_set, except_set; + std::unordered_map fds; }; } // namespace utils } // namespace rabit diff --git a/test/test.mk b/test/test.mk index 4a545113..9dfebb02 100644 --- a/test/test.mk +++ b/test/test.mk @@ -4,6 +4,7 @@ all: model_recover_10_10k model_recover_10_10k_die_same model_recover_10_10k_die_hard local_recover_10_10k # this experiment test recovery with actually process exit, use keepalive to keep program alive +# TODO: enable those tests once we fix issue in rabit model_recover_10_10k: ../dmlc-core/tracker/dmlc-submit --cluster local --num-workers=10 model_recover 10000 mock=0,0,1,0 mock=1,1,1,0