Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support larger cluster #73

Merged
merged 5 commits into from
Oct 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ before_install:
- source ${TRAVIS}/travis_setup_env.sh

install:
- pip install --user cpplint pylint
- pip install --user cpplint pylint kubernetes urllib3

script: scripts/travis_script.sh

Expand Down
6 changes: 0 additions & 6 deletions doc/Doxyfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ SUBGROUPING = YES
INLINE_GROUPED_CLASSES = NO
INLINE_SIMPLE_STRUCTS = NO
TYPEDEF_HIDES_STRUCT = NO
SYMBOL_CACHE_SIZE = 0
LOOKUP_CACHE_SIZE = 0
#---------------------------------------------------------------------------
# Build related configuration options
Expand Down Expand Up @@ -76,7 +75,6 @@ GENERATE_DEPRECATEDLIST= YES
ENABLED_SECTIONS =
MAX_INITIALIZER_LINES = 30
SHOW_USED_FILES = YES
SHOW_DIRECTORIES = NO
SHOW_FILES = YES
SHOW_NAMESPACES = YES
FILE_VERSION_FILTER =
Expand Down Expand Up @@ -142,7 +140,6 @@ HTML_COLORSTYLE_HUE = 220
HTML_COLORSTYLE_SAT = 100
HTML_COLORSTYLE_GAMMA = 80
HTML_TIMESTAMP = YES
HTML_ALIGN_MEMBERS = YES
HTML_DYNAMIC_SECTIONS = NO
GENERATE_DOCSET = NO
DOCSET_FEEDNAME = "Doxygen generated docs"
Expand All @@ -169,7 +166,6 @@ ECLIPSE_DOC_ID = org.doxygen.Project
DISABLE_INDEX = NO
GENERATE_TREEVIEW = NO
ENUM_VALUES_PER_LINE = 4
USE_INLINE_TREES = NO
TREEVIEW_WIDTH = 250
EXT_LINKS_IN_WINDOW = NO
FORMULA_FONTSIZE = 10
Expand Down Expand Up @@ -218,8 +214,6 @@ MAN_LINKS = NO
#---------------------------------------------------------------------------
GENERATE_XML = YES
XML_OUTPUT = xml
XML_SCHEMA =
XML_DTD =
XML_PROGRAMLISTING = YES
#---------------------------------------------------------------------------
# configuration options for the AutoGen Definitions output
Expand Down
18 changes: 14 additions & 4 deletions include/rabit/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,22 @@ RABIT_DLL void RabitInit(int argc, char *argv[]);
*/
RABIT_DLL void RabitFinalize(void);

/*! \brief get rank of current process */
/*!
* \brief get rank of current process
* \return rank number of worker
* */
RABIT_DLL int RabitGetRank(void);

/*! \brief get total number of process */
/*!
* \brief get total number of processes
* \return total world size
* */
RABIT_DLL int RabitGetWorldSize(void);

/*! \brief get rank of current process */
/*!
* \brief check whether rabit is running in distributed mode
* \return whether rabit is distributed
* */
RABIT_DLL int RabitIsDistributed(void);

/*!
Expand Down Expand Up @@ -136,6 +145,7 @@ RABIT_DLL void RabitCheckPoint(const char *global_model,
/*!
* \brief get the version number of the current stored model,
* which means how many calls to CheckPoint we made so far
* \return rabit version number
*/
RABIT_DLL int RabitVersionNumber(void);

Expand All @@ -144,7 +154,7 @@ RABIT_DLL int RabitVersionNumber(void);
* \brief a Dummy function,
* used to cause force link of C API into the DLL.
* \code
* // force link rabit C API library.
* \/\/force link rabit C API library.
* static int must_link_rabit_ = RabitLinkTag();
* \endcode
* \return a dummy integer.
Expand Down
12 changes: 8 additions & 4 deletions include/rabit/rabit.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,18 @@ inline void Init(int argc, char *argv[]);
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
*/
inline void Finalize();
/*! \brief gets rank of the current process */
/*! \brief gets rank of the current process
* \return rank number of worker*/
inline int GetRank();
/*! \brief gets total number of processes */
/*! \brief gets total number of processes
* \return total world size*/
inline int GetWorldSize();
/*! \brief whether rabit env is in distributed mode */
/*! \brief whether rabit env is in distributed mode
* \return is distributed*/
inline bool IsDistributed();

/*! \brief gets processor's name */
/*! \brief gets processor's name
* \return processor name*/
inline std::string GetProcessorName();
/*!
* \brief prints the msg to the tracker,
Expand Down
64 changes: 32 additions & 32 deletions src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
} else {
fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str());
#ifdef _MSC_VER
Sleep(1);
Sleep(retry << 1);
#else
sleep(1);
sleep(retry << 1);
#endif
continue;
}
Expand Down Expand Up @@ -454,47 +454,47 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
while (true) {
// poll helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (i == parent_index) {
if (size_down_in != total_size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
// only watch for exception in live channels
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
finished = false;
}
if (size_up_out != total_size && size_up_out < size_up_reduce) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
} else {
if (links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
// size_write <= size_read
if (links[i].size_write != total_size) {
if (links[i].size_write < size_down_in) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
// only watch for exception in live channels
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
finished = false;
}
}
}
// finish running allreduce
if (finished) break;
// poll must return
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
// read data from children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && selecter.CheckRead(links[i].sock)) {
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
Expand Down Expand Up @@ -551,7 +551,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
}
}
// read data from parent
if (selecter.CheckRead(links[parent_index].sock) &&
if (watcher.CheckRead(links[parent_index].sock) &&
total_size > size_down_in) {
ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
Expand Down Expand Up @@ -620,37 +620,37 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
while (true) {
bool finished = true;
// poll helper
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (in_link == -2) {
selecter.WatchRead(links[i].sock); finished = false;
watcher.WatchRead(links[i].sock); finished = false;
}
if (i == in_link && links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock); finished = false;
watcher.WatchRead(links[i].sock); finished = false;
}
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
if (links[i].size_write < size_in) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
finished = false;
}
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
}
// finish running
if (finished) break;
// poll
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
if (in_link == -2) {
// probe in-link
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckRead(links[i].sock)) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
Expand All @@ -663,7 +663,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
}
} else {
// read from in link
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) {
ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[in_link], ret);
Expand Down Expand Up @@ -717,20 +717,20 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
while (true) {
// poll helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < read_ptr) {
selecter.WatchWrite(prev.sock);
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
watcher.Poll();
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size;
if (start + size > total_size) {
Expand Down Expand Up @@ -811,20 +811,20 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
while (true) {
// poll helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < reduce_ptr) {
selecter.WatchWrite(prev.sock);
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
watcher.Poll();
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) {
return ReportError(&next, ret);
Expand Down
20 changes: 10 additions & 10 deletions src/allreduce_robust-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,30 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
if (parent_index == -1) {
utils::Assert(stage != 2 && stage != 1, "invalie stage id");
}
// select helper
utils::SelectHelper selecter;
// poll helper
utils::PollHelper watcher;
bool done = (stage == 3);
for (int i = 0; i < nlink; ++i) {
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
switch (stage) {
case 0:
if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
break;
case 1:
if (i == parent_index) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
break;
case 2:
if (i == parent_index) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
break;
case 3:
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
done = false;
}
break;
Expand All @@ -101,11 +101,11 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
}
// finish all the stages, and write out message
if (done) break;
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
Expand All @@ -114,7 +114,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
// read data from children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
if (selecter.CheckRead(links[i].sock)) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[i], ret);
}
Expand Down
Loading