Skip to content

Commit

Permalink
Replace select() with poll() in SelectHandler
Browse files Browse the repository at this point in the history
When launching multiple distributed XGBoosts within
one spark job it's pretty easy to run beyond FD_SETSIZE=1024
on linux.
  • Loading branch information
Boris Filippov committed Apr 20, 2018
1 parent 7bc46b8 commit 115aa0b
Showing 1 changed file with 45 additions and 45 deletions.
90 changes: 45 additions & 45 deletions src/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,22 @@
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <sys/ioctl.h>
#endif
#include <string>
#include <cstring>
#include <vector>
#include <unordered_map>
#include "../include/rabit/internal/utils.h"

#if defined(_WIN32)
typedef int ssize_t;
typedef int sock_size_t;

static inline int poll(struct pollfd *pfd, int nfds,
int timeout) { return WSAPoll ( pfd, nfds, timeout ); }
#else
#include <sys/poll.h>
typedef int SOCKET;
typedef size_t sock_size_t;
const int INVALID_SOCKET = -1;
Expand Down Expand Up @@ -422,106 +427,101 @@ class TCPSocket : public Socket{
/*! \brief helper data structure to perform select */
struct SelectHelper {
public:
SelectHelper(void) {
FD_ZERO(&read_set);
FD_ZERO(&write_set);
FD_ZERO(&except_set);
maxfd = 0;
}
/*!
* \brief add file descriptor to watch for read
* \param fd file descriptor to be watched
*/
inline void WatchRead(SOCKET fd) {
FD_SET(fd, &read_set);
if (fd > maxfd) maxfd = fd;
auto& pfd = fds[fd];
pfd.fd = fd;
pfd.events |= POLLIN;
}
/*!
* \brief add file descriptor to watch for write
* \param fd file descriptor to be watched
*/
inline void WatchWrite(SOCKET fd) {
FD_SET(fd, &write_set);
if (fd > maxfd) maxfd = fd;
auto& pfd = fds[fd];
pfd.fd = fd;
pfd.events |= POLLOUT;
}
/*!
* \brief add file descriptor to watch for exception
* \param fd file descriptor to be watched
*/
inline void WatchException(SOCKET fd) {
FD_SET(fd, &except_set);
if (fd > maxfd) maxfd = fd;
auto& pfd = fds[fd];
pfd.fd = fd;
pfd.events |= POLLPRI;
}
/*!
* \brief Check if the descriptor is ready for read
* \param fd file descriptor to check status
*/
inline bool CheckRead(SOCKET fd) const {
return FD_ISSET(fd, &read_set) != 0;
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
}
/*!
* \brief Check if the descriptor is ready for write
* \param fd file descriptor to check status
*/
inline bool CheckWrite(SOCKET fd) const {
return FD_ISSET(fd, &write_set) != 0;
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
}
/*!
* \brief Check if the descriptor has any exception
* \param fd file descriptor to check status
*/
inline bool CheckExcept(SOCKET fd) const {
return FD_ISSET(fd, &except_set) != 0;
const auto& pfd = fds.find(fd);
return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
}
/*!
* \brief wait for exception event on a single descriptor
* \param fd the file descriptor to wait the event for
* \param timeout the timeout counter, can be 0, which means wait until the event happen
* \param timeout the timeout counter, can be negative, which means wait until the event happen
* \return 1 if success, 0 if timeout, and -1 if error occurs
*/
inline static int WaitExcept(SOCKET fd, long timeout = 0) { // NOLINT(*)
fd_set wait_set;
FD_ZERO(&wait_set);
FD_SET(fd, &wait_set);
return Select_(static_cast<int>(fd + 1),
NULL, NULL, &wait_set, timeout);
inline static int WaitExcept(SOCKET fd, long timeout = -1) { // NOLINT(*)
pollfd pfd;
pfd.fd = fd;
pfd.events = POLLPRI;
return poll(&pfd, 1, timeout);
}

/*!
* \brief peform select on the set defined
* \param select_read whether to watch for read event
* \param select_write whether to watch for write event
* \param select_except whether to watch for exception event
* \param timeout specify timeout in micro-seconds(ms) if equals 0, means select will always block
* \param timeout specify timeout in milliseconds(ms) if negative, means select will always block
* \return number of active descriptors selected,
* return -1 if error occurs
*/
inline int Select(long timeout = 0) { // NOLINT(*)
int ret = Select_(static_cast<int>(maxfd + 1),
&read_set, &write_set, &except_set, timeout);
inline void Select(long timeout = -1) { // NOLINT(*)
std::vector<pollfd> fdset;
fdset.reserve(fds.size());
for (auto kv : fds) {
fdset.push_back(kv.second);
}
int ret = poll(fdset.data(), fdset.size(), timeout);
if (ret == -1) {
Socket::Error("Select");
}
return ret;
}

private:
inline static int Select_(int maxfd, fd_set *rfds,
fd_set *wfds, fd_set *efds, long timeout) { // NOLINT(*)
#if !defined(_WIN32)
utils::Assert(maxfd < FD_SETSIZE, "maxdf must be smaller than FDSETSIZE");
#endif
if (timeout == 0) {
return select(maxfd, rfds, wfds, efds, NULL);
} else {
timeval tm;
tm.tv_usec = (timeout % 1000) * 1000;
tm.tv_sec = timeout / 1000;
return select(maxfd, rfds, wfds, efds, &tm);
for (auto& pfd : fdset) {
auto revents = pfd.revents & pfd.events;
if (!revents) {
fds.erase(pfd.fd);
} else {
fds[pfd.fd].events = revents;
}
}
}
}

SOCKET maxfd;
fd_set read_set, write_set, except_set;
std::unordered_map<SOCKET, pollfd> fds;
};
} // namespace utils
} // namespace rabit
Expand Down

0 comments on commit 115aa0b

Please sign in to comment.