Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix hanging trainings #132

Merged
merged 2 commits into from
Jan 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 13 additions & 22 deletions src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,6 @@ bool AllreduceBase::Shutdown(void) {
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("shutdown"));
tracker.Close();
// close listening sockets
sock_listen.Close();
utils::TCPSocket::Finalize();
return true;
} catch (const std::exception& e) {
Expand Down Expand Up @@ -282,6 +280,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
}
try {
utils::TCPSocket tracker = this->ConnectTracker();
fprintf(stdout, "task %s connected to the tracker\n", task_id.c_str());
tracker.SendStr(std::string(cmd));

// the rank of previous link, next link in ring
Expand All @@ -304,6 +303,8 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
// tracker got overwhelmed and was not able to assign a correct rank
if (rank == -1) exit(-1);

fprintf(stdout, "task %s got new rank %d\n", task_id.c_str(), rank);

Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
sizeof(num_neighbors), "ReConnectLink failure 4");
for (int i = 0; i < num_neighbors; ++i) {
Expand All @@ -317,25 +318,15 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");

if (sock_listen == INVALID_SOCKET || sock_listen.AtMark()) {
if (!sock_listen.IsClosed()) {
sock_listen.Close();
}
// create listening socket
sock_listen.Create();
sock_listen.SetKeepAlive(true);
// http://deepix.github.io/2016/10/21/tcprst.html
sock_listen.SetLinger(0);
// [slave_port, slave_port+1 .... slave_port + newrank ...slave_port + nport_trial)
// work around processes bind to same port without set reuse option,
// start explore from slave_port + newrank towards end
port = sock_listen.TryBindHost(slave_port + newrank % nport_trial, slave_port + nport_trial);
// if no port bindable, explore first half of range
if (port == -1) sock_listen.TryBindHost(slave_port, newrank % nport_trial + slave_port);

utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();
}
utils::TCPSocket sock_listen;
if (!sock_listen.IsClosed()) {
sock_listen.Close();
}
// create listening socket
sock_listen.Create();
int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial);
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();

// get the number of nodes to connect to and the number of nodes to accept from the tracker
int num_conn, num_accept, num_error = 1;
Expand Down Expand Up @@ -423,7 +414,7 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) {
}
if (!match) all_links.push_back(r);
}

sock_listen.Close();
this->parent_index = -1;
// setup tree links and ring structure
tree_links.plinks.clear();
Expand Down
4 changes: 0 additions & 4 deletions src/allreduce_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -571,10 +571,6 @@ class AllreduceBase : public IEngine {
int world_size;
// connect retry time
int connect_retry;
// backdoor listening peer connection
utils::TCPSocket sock_listen;
// backdoor port
int port = 0;
// enable bootstrap cache 0 false 1 true
bool rabit_bootstrap_cache = false;
// enable detailed logging
Expand Down
3 changes: 0 additions & 3 deletions src/allreduce_robust.cc
Original file line number Diff line number Diff line change
Expand Up @@ -708,9 +708,6 @@ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
return true;
}
}
// print on tracker to help debugging
TrackerPrint("[ERROR] rank " + std::to_string(rank) + "@"+
host_uri + ":" +std::to_string(port) + " timeout\n");
_error("[%d] exit due to time out %d s\n", rank, timeout_sec);
return false;
});
Expand Down