Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support larger cluster #73

Merged
merged 5 commits into from
Oct 22, 2018
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix error in #57, clean up comments and naming
Chen Qin committed Oct 19, 2018
commit ee19094380af6f9799de129c8cacdc19c238d17d
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ before_install:
- source ${TRAVIS}/travis_setup_env.sh

install:
- pip install --user cpplint pylint
- pip install --user cpplint pylint kubernetes

script: scripts/travis_script.sh

60 changes: 30 additions & 30 deletions src/allreduce_base.cc
Original file line number Diff line number Diff line change
@@ -454,47 +454,47 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (i == parent_index) {
if (size_down_in != total_size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
// only watch for exception in live channels
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
finished = false;
}
if (size_up_out != total_size && size_up_out < size_up_reduce) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
} else {
if (links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
// size_write <= size_read
if (links[i].size_write != total_size) {
if (links[i].size_write < size_down_in) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
// only watch for exception in live channels
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
finished = false;
}
}
}
// finish runing allreduce
if (finished) break;
// select must return
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
// read data from childs
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && selecter.CheckRead(links[i].sock)) {
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
@@ -551,7 +551,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
}
}
// read data from parent
if (selecter.CheckRead(links[parent_index].sock) &&
if (watcher.CheckRead(links[parent_index].sock) &&
total_size > size_down_in) {
ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
@@ -620,37 +620,37 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
while (true) {
bool finished = true;
// select helper
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (in_link == -2) {
selecter.WatchRead(links[i].sock); finished = false;
watcher.WatchRead(links[i].sock); finished = false;
}
if (i == in_link && links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock); finished = false;
watcher.WatchRead(links[i].sock); finished = false;
}
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
if (links[i].size_write < size_in) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
finished = false;
}
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
}
// finish running
if (finished) break;
// select
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
if (in_link == -2) {
// probe in-link
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckRead(links[i].sock)) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
@@ -663,7 +663,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
}
} else {
// read from in link
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) {
ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[in_link], ret);
@@ -717,20 +717,20 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < read_ptr) {
selecter.WatchWrite(prev.sock);
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
watcher.Poll();
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size;
if (start + size > total_size) {
@@ -811,20 +811,20 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < reduce_ptr) {
selecter.WatchWrite(prev.sock);
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
watcher.Poll();
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) {
return ReportError(&next, ret);
18 changes: 9 additions & 9 deletions src/allreduce_robust-inl.h
Original file line number Diff line number Diff line change
@@ -70,29 +70,29 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
utils::Assert(stage != 2 && stage != 1, "invalie stage id");
}
// select helper
utils::SelectHelper selecter;
utils::PollHelper watcher;
bool done = (stage == 3);
for (int i = 0; i < nlink; ++i) {
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
switch (stage) {
case 0:
if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
break;
case 1:
if (i == parent_index) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
break;
case 2:
if (i == parent_index) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
break;
case 3:
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
done = false;
}
break;
@@ -101,11 +101,11 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
}
// finish all the stages, and write out message
if (done) break;
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// recive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
@@ -114,7 +114,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
// read data from childs
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
if (selecter.CheckRead(links[i].sock)) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[i], ret);
}
44 changes: 22 additions & 22 deletions src/allreduce_robust.cc
Original file line number Diff line number Diff line change
@@ -334,7 +334,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
if (len == sizeof(sig)) all_links[i].size_write = 2;
}
}
utils::SelectHelper rsel;
utils::PollHelper rsel;
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) {
@@ -343,23 +343,23 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
}
if (finished) break;
// wait to read from the channels to discard data
rsel.Select();
rsel.Poll();
}
for (int i = 0; i < nlink; ++i) {
if (!all_links[i].sock.BadSocket()) {
utils::SelectHelper::WaitExcept(all_links[i].sock);
utils::PollHelper::WaitExcept(all_links[i].sock);
}
}
while (true) {
utils::SelectHelper rsel;
utils::PollHelper rsel;
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) {
rsel.WatchRead(all_links[i].sock); finished = false;
}
}
if (finished) break;
rsel.Select();
rsel.Poll();
for (int i = 0; i < nlink; ++i) {
if (all_links[i].sock.BadSocket()) continue;
if (all_links[i].size_read == 0) {
@@ -624,32 +624,32 @@ AllreduceRobust::TryRecoverData(RecoverType role,
}
while (true) {
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (i == recv_link && links[i].size_read != size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
finished = false;
}
if (req_in[i] && links[i].size_write != size) {
if (role == kHaveData ||
(links[recv_link].size_read != links[i].size_write)) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
finished = false;
}
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
}
if (finished) break;
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
if (role == kRequestData) {
const int pid = recv_link;
if (selecter.CheckRead(links[pid].sock)) {
if (watcher.CheckRead(links[pid].sock)) {
ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size);
if (ret != kSuccess) {
return ReportError(&links[pid], ret);
@@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (role == kPassData) {
const int pid = recv_link;
const size_t buffer_size = links[pid].buffer_size;
if (selecter.CheckRead(links[pid].sock)) {
if (watcher.CheckRead(links[pid].sock)) {
size_t min_write = size;
for (int i = 0; i < nlink; ++i) {
if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
@@ -1144,22 +1144,22 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
char *buf = reinterpret_cast<char*>(sendrecvbuf_);
while (true) {
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != read_end) {
selecter.WatchRead(prev.sock);
watcher.WatchRead(prev.sock);
finished = false;
}
if (write_ptr < read_ptr && write_ptr != write_end) {
selecter.WatchWrite(next.sock);
watcher.WatchWrite(next.sock);
finished = false;
}
selecter.WatchException(prev.sock);
selecter.WatchException(next.sock);
watcher.WatchException(prev.sock);
watcher.WatchException(next.sock);
if (finished) break;
selecter.Select();
if (selecter.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
if (selecter.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
if (read_ptr != read_end && selecter.CheckRead(prev.sock)) {
watcher.Poll();
if (watcher.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
if (watcher.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
if (read_ptr != read_end && watcher.CheckRead(prev.sock)) {
ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
if (len == 0) {
prev.sock.Close(); return ReportError(&prev, kRecvZeroLen);
866 changes: 431 additions & 435 deletions src/socket.h

Large diffs are not rendered by default.