diff --git a/include/rabit/internal/engine.h b/include/rabit/internal/engine.h index 79565218..81dee72d 100644 --- a/include/rabit/internal/engine.h +++ b/include/rabit/internal/engine.h @@ -158,8 +158,8 @@ class IEngine { * \param msg message to be printed in the tracker */ virtual void TrackerPrint(const std::string &msg) = 0; - virtual void TrackerSetConfig(const std::string &key, const std::string &value) = 0; - virtual void TrackerGetConfig(const std::string& key, std::string* value) = 0; + virtual void TrackerSetConfig(const std::string &key, const int size, const void* value) = 0; + virtual void TrackerGetConfig(const std::string& key, const int size, void* value) = 0; }; /*! \brief initializes the engine module */ diff --git a/include/rabit/internal/rabit-inl.h b/include/rabit/internal/rabit-inl.h index 861861b3..3ad5ee11 100644 --- a/include/rabit/internal/rabit-inl.h +++ b/include/rabit/internal/rabit-inl.h @@ -142,7 +142,8 @@ inline void Broadcast(std::vector *sendrecv_data, int root, const char* c Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root, caller); } } -inline void Broadcast(std::string *sendrecv_data, int root, const char* caller) { +inline void Broadcast(std::string *sendrecv_data, int root, + const char* caller) { size_t size = sendrecv_data->length(); Broadcast(&size, sizeof(size), root, caller); if (sendrecv_data->length() != size) { @@ -182,12 +183,12 @@ inline void TrackerPrint(const std::string &msg) { engine::GetEngine()->TrackerPrint(msg); } -inline void TrackerSetConfig(const std::string &key, const std::string &value) { - engine::GetEngine()->TrackerSetConfig(key, value); +inline void TrackerSetConfig(const std::string &key, const int bsize, const void *value) { + engine::GetEngine()->TrackerSetConfig(key, bsize, value); } -inline void TrackerGetConfig(const std::string &key, std::string* value) { - engine::GetEngine()->TrackerGetConfig(key, value); +inline void TrackerGetConfig(const std::string &key, const int bsize, void *value) { + engine::GetEngine()->TrackerGetConfig(key, bsize, value); } #ifndef RABIT_STRICT_CXX98_ @@ -202,7 +203,7 @@ inline void TrackerPrintf(const char *fmt, ...) { TrackerPrint(msg); } -inline void TrackerSetConfig(const char *key, const char *value, ...) { +inline void TrackerSetConfig(const char *key, const int bsize, const void *value, ...) { const int kPrintBuffer = 1 << 10; std::string k(kPrintBuffer, '\0'), v(kPrintBuffer, '\0'); @@ -210,28 +211,22 @@ inline void TrackerSetConfig(const char *key, const char *value, ...) { va_start(args1, key); va_start(args2, value); vsnprintf(&k[0], kPrintBuffer, key, args1); - vsnprintf(&v[0], kPrintBuffer, value, args2); va_end(args1); va_end(args2); k.resize(strlen(k.c_str())); - v.resize(strlen(v.c_str())); - engine::GetEngine()->TrackerSetConfig(k, v); + engine::GetEngine()->TrackerSetConfig(k, bsize, value); } -inline void TrackerGetConfig(const char *key, char* value, ...) { +inline void TrackerGetConfig(const char *key, const int bsize, void* value, ...) { const int kPrintBuffer = 1 << 10; std::string k(kPrintBuffer, '\0'), v(kPrintBuffer, '\0'); - va_list args1, args2; + va_list args1; va_start(args1, key); - va_start(args2, value); vsnprintf(&k[0], kPrintBuffer, key, args1); - vsnprintf(&v[0], kPrintBuffer, value, args2); va_end(args1); - va_end(args2); k.resize(strlen(k.c_str())); - v.resize(strlen(v.c_str())); - engine::GetEngine()->TrackerGetConfig(k, &v); + engine::GetEngine()->TrackerGetConfig(k, bsize, value); } #endif // RABIT_STRICT_CXX98_ // load latest check point diff --git a/include/rabit/rabit.h b/include/rabit/rabit.h index cac237d8..ca0e5c9a 100644 --- a/include/rabit/rabit.h +++ b/include/rabit/rabit.h @@ -105,13 +105,13 @@ inline void TrackerPrint(const std::string &msg); * \param key configuration key * \param value value of config */ -inline void TrackerSetConfig(const std::string &key, const std::string &value); +inline void TrackerSetConfig(const std::string &key, const int bsize, const void* value); /*! * \brief get config to tracker, * \param key configuration key * \param value value of config */ -inline void TrackerGetConfig(const std::string &key, std::string* value); +inline void TrackerGetConfig(const std::string &key, const int bsize, void* value); #ifndef RABIT_STRICT_CXX98_ /*! @@ -127,13 +127,13 @@ inline void TrackerPrintf(const char *fmt, ...); * \param key configuration key * \param value value of config */ -inline void TrackerSetConfig(const char *key, const char *value, ...); +inline void TrackerSetConfig(const char *key, const int bsize, const void* value, ...); /*! * \brief get config to tracker, * \param key configuration key * \param value value of config */ -inline void TrackerGetConfig(const char *key, char* value, ...); +inline void TrackerGetConfig(const char *key, const int bsize, void* value, ...); #endif // RABIT_STRICT_CXX98_ /*! * \brief broadcasts a memory region to every node from the root @@ -143,7 +143,8 @@ inline void TrackerGetConfig(const char *key, char* value, ...); * \param size the data size * \param root the process root */ -inline void Broadcast(void *sendrecv_data, size_t size, int root, const char* caller = __builtin_FUNCTION()); +inline void Broadcast(void *sendrecv_data, size_t size, int root, + const char* caller = __builtin_FUNCTION()); /*! * \brief broadcasts an std::vector to every node from root * \param sendrecv_data the pointer to send/receive vector, @@ -153,14 +154,16 @@ inline void Broadcast(void *sendrecv_data, size_t size, int root, const char* ca * that can be directly transmitted by sending the sizeof(DType) */ template -inline void Broadcast(std::vector *sendrecv_data, int root, const char* caller = __builtin_FUNCTION()); +inline void Broadcast(std::vector *sendrecv_data, int root, + const char* caller = __builtin_FUNCTION()); /*! * \brief broadcasts a std::string to every node from the root * \param sendrecv_data the pointer to the send/receive buffer, * for the receiver, the vector does not need to be pre-allocated * \param root the process root */ -inline void Broadcast(std::string *sendrecv_data, int root, const char* caller = __builtin_FUNCTION()); +inline void Broadcast(std::string *sendrecv_data, int root, + const char* caller = __builtin_FUNCTION()); /*! * \brief performs in-place Allreduce on sendrecvbuf * this function is NOT thread-safe @@ -185,7 +188,8 @@ inline void Broadcast(std::string *sendrecv_data, int root, const char* caller = template inline void Allreduce(DType *sendrecvbuf, size_t count, void (*prepare_fun)(void *) = NULL, - void *prepare_arg = NULL, const char* caller = __builtin_FUNCTION()); + void *prepare_arg = NULL, + const char* caller = __builtin_FUNCTION()); // C++11 support for lambda prepare function #if DMLC_USE_CXX11 /*! @@ -214,7 +218,8 @@ inline void Allreduce(DType *sendrecvbuf, size_t count, */ template inline void Allreduce(DType *sendrecvbuf, size_t count, - std::function prepare_fun, const char* caller = __builtin_FUNCTION()); + std::function prepare_fun, + const char* caller = __builtin_FUNCTION()); #endif // C++11 /*! * \brief loads the latest check point diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 3c184efe..b11527ad 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -146,19 +146,20 @@ void AllreduceBase::TrackerPrint(const std::string &msg) { tracker.Close(); } -void AllreduceBase::TrackerSetConfig(const std::string &key, const std::string &value) { +void AllreduceBase::TrackerSetConfig(const std::string &key, const int bytesize, const void* value) { utils::TCPSocket tracker = this->ConnectTracker(); tracker.SendStr(std::string("set")); tracker.SendStr(key); - tracker.SendStr(value); + tracker.Send(&bytesize, sizeof(int)); + tracker.SendAll(value, bytesize); tracker.Close(); } -void AllreduceBase::TrackerGetConfig(const std::string &key, std::string* value) { +void AllreduceBase::TrackerGetConfig(const std::string &key, const int bytesize, void* value) { utils::TCPSocket tracker = this->ConnectTracker(); tracker.SendStr(std::string("get")); tracker.SendStr(key); - tracker.RecvStr(value); + tracker.RecvAll(value, bytesize); tracker.Close(); } diff --git a/src/allreduce_base.h b/src/allreduce_base.h index da41fda5..29472855 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -54,8 +54,8 @@ class AllreduceBase : public IEngine { * \param msg message to be printed in the tracker */ virtual void TrackerPrint(const std::string &msg); - virtual void TrackerSetConfig(const std::string &key, const std::string &value); - virtual void TrackerGetConfig(const std::string& key, std::string* value); + virtual void TrackerSetConfig(const std::string &key, const int bytesize, const void* value); + virtual void TrackerGetConfig(const std::string &key, const int bytesize, void* value); /*! \brief get rank */ virtual int GetRank(void) const { diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index f3c77568..e56c3b52 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -92,6 +92,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_, if (prepare_fun != NULL) prepare_fun(prepare_arg); return; } + bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter); // now we are free to remove the last result, if any if (resbuf.LastSeqNo() != -1 &&