From 13d852da08b2fe2c4998da9d563d418a32de23c2 Mon Sep 17 00:00:00 2001 From: Yulu Jia Date: Wed, 30 Dec 2020 22:14:35 -0800 Subject: [PATCH 1/2] add SyncBatchNorm - add byteps.torch.SyncBatchNorm - add BYTEPS_SYNC_BN_GLOBAL to choose between global sync and node local sync. defaults to node local sync - add node local allgather Signed-off-by: yulu.jia --- CHANGELOG.rst | 3 + byteps/common/__init__.py | 5 +- byteps/common/common.h | 7 +- byteps/common/communicator.cc | 3 + byteps/common/communicator.h | 2 + byteps/common/core_loops.cc | 57 ++++++- byteps/common/core_loops.h | 1 + byteps/common/global.cc | 10 +- byteps/common/global.h | 2 + byteps/common/operations.cc | 143 +++++++++++++++- byteps/common/operations.h | 13 +- byteps/common/scheduled_queue.cc | 5 + byteps/torch/__init__.py | 3 +- byteps/torch/ops.cc | 97 ++++++++++- byteps/torch/ops.h | 2 +- byteps/torch/ops.py | 97 ++++++++++- byteps/torch/sync_batch_norm.py | 285 +++++++++++++++++++++++++++++++ launcher/launch.py | 4 +- tests/test_bps_torch_syncbn.py | 144 ++++++++++++++++ 19 files changed, 849 insertions(+), 34 deletions(-) create mode 100644 byteps/torch/sync_batch_norm.py create mode 100644 tests/test_bps_torch_syncbn.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4c47dcc2c4..210e874337 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,9 @@ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Changelog for BytePS ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +0.2.5.post2 (2020-11) +* add ability to collect PushPull performance data + 0.2.4 (2020-06) ------------------ * Fix compatibility issue with tf2 + standalone keras diff --git a/byteps/common/__init__.py b/byteps/common/__init__.py index 7160696160..83a8d7d3d6 100644 --- a/byteps/common/__init__.py +++ b/byteps/common/__init__.py @@ -60,9 +60,10 @@ def init(self, lazy=True): """A function that inits BytePS.""" atexit.register(self.shutdown) if lazy: - return self.C_LIB_CTYPES.byteps_lazy_init() + ret = self.C_LIB_CTYPES.byteps_lazy_init() else: - return self.C_LIB_CTYPES.byteps_init() + ret = self.C_LIB_CTYPES.byteps_init() + return ret def shutdown(self): """A function that shuts BytePS down.""" diff --git a/byteps/common/common.h b/byteps/common/common.h index 1db7f595aa..ad7fa9e44b 100644 --- a/byteps/common/common.h +++ b/byteps/common/common.h @@ -98,6 +98,8 @@ enum QueueType { COPYH2D, COORDINATE_BROADCAST, BROADCAST, + ALLGATHER, + COORDINATE_ALLGATHER, QUEUE_NUM_AND_NOT_A_REAL_QUEUE_TYPE_AND_MUST_BE_THE_LAST }; @@ -115,7 +117,10 @@ const std::vector LogStrings = {"COORDINATE_REDUCE", "DECOMPRESS", "COPYH2D", "COORDINATE_BROADCAST", - "BROADCAST"}; + "BROADCAST", + "ALLGATHER", + "COORDINATE_ALLGATHER", +}; class Status { public: diff --git a/byteps/common/communicator.cc b/byteps/common/communicator.cc index 1ed9b238d9..f29726b3b5 100644 --- a/byteps/common/communicator.cc +++ b/byteps/common/communicator.cc @@ -191,6 +191,9 @@ void BytePSCommSocket::startListenThread() { // only root starts this in case BCAST_READY: BytePSGlobal::GetBroadcastTable()->AddReadyCount(message.key); break; + case ALLGATHER_READY: + BytePSGlobal::GetAllgatherTable()->AddReadyCount(message.key); + break; case PUSH_READY: BytePSGlobal::GetPushTable()->AddReadyCount(message.key); break; diff --git a/byteps/common/communicator.h b/byteps/common/communicator.h index ac42786eb6..5851dac802 100644 --- a/byteps/common/communicator.h +++ b/byteps/common/communicator.h @@ -44,9 +44,11 @@ enum BytePSCommSignal { REDUCE_READY, PCIE_REDUCE_READY, BCAST_READY, + 
ALLGATHER_READY, PUSH_READY, DO_REDUCE, DO_BROADCAST, + DO_ALLGATHER, DO_GROUP, DO_COPYH2D }; diff --git a/byteps/common/core_loops.cc b/byteps/common/core_loops.cc index a17803e38b..cbb569f1f4 100644 --- a/byteps/common/core_loops.cc +++ b/byteps/common/core_loops.cc @@ -162,6 +162,11 @@ bool RunCoordinateLoopOnce(QueueType this_op) { comm = BytePSGlobal::GetNccl()->GetSignalComm(); break; } + case COORDINATE_ALLGATHER: { + sig = ALLGATHER_READY; + comm = BytePSGlobal::GetNccl()->GetSignalComm(); + break; + } case COORDINATE_PUSH: { sig = PUSH_READY; comm = BytePSGlobal::GetBasicComm(); @@ -189,9 +194,20 @@ bool RunCoordinateLoopOnce(QueueType this_op) { inline void PostNcclCalls( std::shared_ptr task, QueueType this_op) { - BPS_CHECK(this_op == REDUCE || this_op == BROADCAST) - << "Only REDUCE and BROADCAST use NCCL."; - auto tensor = (this_op == REDUCE) ? task->tensor : task->output; + BPS_CHECK(this_op == REDUCE || this_op == BROADCAST || this_op == ALLGATHER) + << "Only REDUCE, BROADCAST and ALLGATHER use NCCL."; + + decltype(task->tensor) tensor; + + switch (this_op) { + case REDUCE: + case ALLGATHER: { + tensor = task->tensor; + break; + } + default: + tensor = task->output; + } BPS_CHECK(tensor); BPS_CHECK_EQ(0, tensor->size() % tensor->shape().num_elements()); @@ -203,6 +219,7 @@ inline void PostNcclCalls( if (task->device == CPU_DEVICE_ID) { p = (char *)(task->gpu_ptr) + offset; } + auto out_p = (char *)(task->output->data()) + offset; auto nccl_dtype = getNcclDataType(tensor->dtype()); @@ -213,6 +230,7 @@ inline void PostNcclCalls( auto nccl_size = nccl->GetSize(); auto nccl_rank = nccl->GetRank(key, this_op); + auto num_elem_all = len / unit_len; auto num_elem_per_gpu = len / nccl_size / unit_len; auto left_elem = (len / unit_len) - (num_elem_per_gpu * nccl_size); if (BytePSGlobal::IsUsingReduce()) { @@ -251,6 +269,12 @@ inline void PostNcclCalls( (ncclRedOp_t)ncclSum, (int)nccl_root, (ncclComm_t)nccl_comm, (cudaStream_t)nccl_stream)); } + } else if (this_op == ALLGATHER) { + BPS_CHECK(task->device != CPU_DEVICE_ID); + NCCLCHECK(ncclAllGather( + (const void *)(p), + (void *)out_p, (size_t)num_elem_all, (ncclDataType_t)nccl_dtype, + (ncclComm_t)nccl_comm, (cudaStream_t)nccl_stream)); } else { if (num_elem_per_gpu) { NCCLCHECK(ncclAllGather( @@ -275,7 +299,7 @@ bool RunRootNcclLoopOnce() { BPS_CHECK_EQ(rank, root); int nccl_size = BytePSGlobal::GetNccl()->GetSize(); - QueueType nccl_ops[] = {REDUCE, BROADCAST}; + QueueType nccl_ops[] = {REDUCE, BROADCAST, ALLGATHER}; auto nccl_entry = std::make_shared(); auto &tasks = nccl_entry->tasks; @@ -294,8 +318,22 @@ bool RunRootNcclLoopOnce() { if (nccl_size > 1) { // notify non-root devices + BytePSCommSignal sig; + switch (this_op) { + case REDUCE: + sig = DO_REDUCE; + break; + case BROADCAST: + sig = DO_BROADCAST; + break; + case ALLGATHER: + sig = DO_ALLGATHER; + break; + default: + BPS_CHECK(0) << "unsupported operation: " << this_op; + } struct BytePSCommMsg msg = { - rank, (this_op == REDUCE) ? 
DO_REDUCE : DO_BROADCAST, task->key}; + rank, sig, task->key}; signal_comm->broadcastSignal(&msg, sizeof(BytePSCommMsg)); PostNcclCalls(task, this_op); } @@ -337,6 +375,8 @@ bool RunNonRootNcclLoopOnce() { QueueType this_op = REDUCE; if (msg.signal == DO_BROADCAST) { this_op = BROADCAST; + } else if (msg.signal == DO_ALLGATHER) { + this_op = ALLGATHER; } else { BPS_CHECK_EQ(msg.signal, DO_REDUCE) << msg.signal << ", " << DO_REDUCE; } @@ -752,6 +792,13 @@ bool RunNonRootCopyHost2DeviceLoopOnce() { return true; } +void CoordinateAllgatherLoop() { + while (RunCoordinateLoopOnce(COORDINATE_ALLGATHER) && + !BytePSGlobal::ShouldShutdown()) { + } + BytePSGlobal::ReportThreadFinish(); +} + void CoordinateReduceLoop() { while (RunCoordinateLoopOnce(COORDINATE_REDUCE) && !BytePSGlobal::ShouldShutdown()) { diff --git a/byteps/common/core_loops.h b/byteps/common/core_loops.h index 2437c33239..bd89b17268 100644 --- a/byteps/common/core_loops.h +++ b/byteps/common/core_loops.h @@ -22,6 +22,7 @@ namespace common { void CoordinateReduceLoop(); void CoordinateBroadcastLoop(); +void CoordinateAllgatherLoop(); void CoordinatePushLoop(); diff --git a/byteps/common/global.cc b/byteps/common/global.cc index 770c79c037..2fdda10437 100644 --- a/byteps/common/global.cc +++ b/byteps/common/global.cc @@ -68,6 +68,7 @@ std::mutex BytePSGlobal::_encode_mutex; ReadyTable* BytePSGlobal::_reduce_table; ReadyTable* BytePSGlobal::_pcie_reduce_table; ReadyTable* BytePSGlobal::_broadcast_table; +ReadyTable* BytePSGlobal::_allgather_table; ReadyTable* BytePSGlobal::_push_table; ReadyTable* BytePSGlobal::_copy_table; bool BytePSGlobal::_is_using_reduce = false; @@ -232,6 +233,9 @@ void BytePSGlobal::Init() { _reduce_table = new ReadyTable(GetPcieSwitchSize() - 1, "NCCL_REDUCE"); _broadcast_table = new ReadyTable(GetPcieSwitchSize() - 1, "NCCL_BROADCAST"); + _allgather_table = + new ReadyTable(GetPcieSwitchSize() - 1, "NCCL_ALLGATHER"); + BPS_LOG(DEBUG) << "Created reduce table, broadcast table and alltagher table"; } // Configure the reduce strategy @@ -370,6 +374,10 @@ void BytePSGlobal::Shutdown() { delete _broadcast_table; _broadcast_table = NULL; } + if (_allgather_table) { + delete _allgather_table; + _allgather_table = NULL; + } if (_push_table) { delete _push_table; _push_table = NULL; @@ -701,7 +709,7 @@ std::size_t PushPullSpeed::_limit = 1024; std::chrono::time_point PushPullSpeed::_last_ts; bool PushPullSpeed::_initialized = false; bool PushPullSpeed::_should_record = - getenv("BYTEPS_TELEMETRY_ON") ? atoi(getenv("BYTEPS_TELEMETRY_ON")) : true; + getenv("BYTEPS_TELEMETRY_ON") ? 
atoi(getenv("BYTEPS_TELEMETRY_ON")) : false; void PushPullSpeed::RecordSpeed(std::shared_ptr task) { std::lock_guard lock(_mtx); diff --git a/byteps/common/global.h b/byteps/common/global.h index 387e84f244..883d0f3577 100644 --- a/byteps/common/global.h +++ b/byteps/common/global.h @@ -110,6 +110,7 @@ class BytePSGlobal { static ReadyTable* GetReduceTable() { return _reduce_table; } static ReadyTable* GetPcieReduceTable() { return _pcie_reduce_table; } static ReadyTable* GetBroadcastTable() { return _broadcast_table; } + static ReadyTable* GetAllgatherTable() { return _allgather_table; } static ReadyTable* GetPushTable() { return _push_table; } // reduce strategies @@ -187,6 +188,7 @@ class BytePSGlobal { static ReadyTable* _reduce_table; static ReadyTable* _pcie_reduce_table; static ReadyTable* _broadcast_table; + static ReadyTable* _allgather_table; static ReadyTable* _push_table; // (key, ready_signal_count) pair, only valid for non-root device diff --git a/byteps/common/operations.cc b/byteps/common/operations.cc index aeb198f3e6..56c0c0efde 100644 --- a/byteps/common/operations.cc +++ b/byteps/common/operations.cc @@ -80,6 +80,7 @@ void byteps_lazy_init() { } else { func.push_back(CoordinateReduceLoop); func.push_back(CoordinateBroadcastLoop); + func.push_back(CoordinateAllgatherLoop); func.push_back(NonRootNcclLoop); } @@ -280,6 +281,103 @@ Status EnqueueTensor(BPSContext &context, std::shared_ptr input, return Status::OK(); } +Status EnqueueAllgatherTensor(BPSContext &context, std::shared_ptr input, + std::shared_ptr output, + std::shared_ptr ready_event, const int device, + const int priority, const int version, + StatusCallback callback, + std::shared_ptr> queue_list) { + if (BytePSGlobal::ShouldShutdown()) { + return Status::OK(); + } + + auto &name = context.tensor_name; + + // add queue + if (BytePSGlobal::IsRootDevice() && !context.compressor_list.empty()) { + auto it = std::find(queue_list->begin(), queue_list->end(), PUSH); + it = queue_list->insert(it, COMPRESS); // before PUSH + it = std::find(queue_list->begin(), queue_list->end(), PULL); + queue_list->insert(it + 1, DECOMPRESS); // after PULL + } + + std::shared_ptr e(new TensorTableEntry); + e->tensor_name = name; + e->context = &context; + e->tensor = input; + e->output = output; + e->ready_event = ready_event; + e->device = device; + e->priority = priority; + e->version = version; + e->callback = callback; + + if (device == CPU_DEVICE_ID) { + cudaError_t err = cudaHostRegister(const_cast(input->data()), + input->size(), cudaHostRegisterMapped); + if (err == cudaSuccess) { + BPS_LOG(DEBUG) << name + << " cpu address has changed, so it is pinned again."; + } + CUDA_CALL(cudaHostGetDevicePointer(&(context.gpu_ptr), + const_cast(input->data()), 0)); + } + + e->cpubuff = context.cpubuff; + e->gpu_ptr = context.gpu_ptr; + e->pcie_cpubuff = context.pcie_cpubuff; + e->queue_list = *queue_list; + e->counter_ptr = std::make_shared(0); + e->total_partnum = context.key_list.size(); + + std::vector> partitions; + PartitionTensor(e, partitions); + BPS_CHECK_EQ(context.key_list.size(), partitions.size()) + << name << ": " << context.key_list.size() << ", " << partitions.size(); + + if (e->queue_list.size() == 0) { + BPS_CHECK(e->tensor_name != ""); + BPS_LOG(TRACE) << e->tensor_name << ", device=" << e->device + << " has no queue_list assigned, skipped"; + e->callback(Status::OK()); + return Status::OK(); + } + + // add for profiling + if (context.profile_flag) { + auto now = std::chrono::system_clock::now(); + auto duration = 
now.time_since_epoch(); + auto us = std::chrono::duration_cast(duration); + + BPSCommTime *ret = new BPSCommTime; + ret->start_t = (long long)(us.count()); + context.comm_time.push(ret); + } + + unsigned int accumulated = 0; + for (size_t i = 0; i < partitions.size(); ++i) { + auto task = partitions[i]; + task->key = context.key_list[i]; // assign the key now + BPS_CHECK(task->tensor_name != ""); + BPS_LOG(TRACE) << "EnqueueTensor: " << (task->tensor_name) + << ", key=" << (task->key) << ", offset=" << (task->offset) + << ", len=" << (task->len) << ", device=" << (task->device) + << " rank=" << BytePSGlobal::GetLocalRank(); + + BytePSGlobal::GetScheduledQueue(e->queue_list[0])->addTask(task); + accumulated += task->len; + } + + auto tensor = (e->tensor ? e->tensor : e->output); + BPS_CHECK(tensor); + BPS_CHECK_EQ(accumulated, tensor->size()) + << "accumulated partition size not equal to original tensor size"; + + BPS_LOG(TRACE) << "EnqueueTensor finished: " << name + << ", rank=" << BytePSGlobal::GetLocalRank(); + return Status::OK(); +} + void InitTensor(BPSContext &context, size_t size, int dtype, void *cpubuff) { std::lock_guard lock(context.init_mutex); if (context.initialized) { @@ -348,8 +446,16 @@ void InitTensor(BPSContext &context, size_t size, int dtype, void *cpubuff) { shm_obj->openPcieSharedMemory(key_list[0], aligned_size); context.cpubuff = context.pcie_cpubuff.back(); } else { - context.cpubuff = shm_obj->openSharedMemory(std::string("BytePS_ShM_"), + int debug_mode = getenv("BYTEPS_DEBUG_MODE") ? + atoi(getenv("BYTEPS_DEBUG_MODE")) : 0; + std::string shm_prefix = std::string("BytePS_ShM_"); + if (debug_mode) { + shm_prefix = shm_prefix + std::to_string(BytePSGlobal::GetWorkerID()) + + std::string("_"); + } + context.cpubuff = shm_obj->openSharedMemory(shm_prefix, key_list[0], aligned_size); + } BPS_LOG(TRACE) << name << ": open shared memory size " << aligned_size; @@ -426,7 +532,9 @@ void RegisterCompressor(const std::string &name, return BytePSGlobal::RegisterCompressor(name, kwargs); } -std::shared_ptr> GetPushQueueList(int device) { +std::shared_ptr> GetPushQueueList(int device, + bool node_local) { + auto queue_list = std::make_shared>(); // Per-PCIe-switch NCCL reduce @@ -438,7 +546,7 @@ std::shared_ptr> GetPushQueueList(int device) { } // Copy from GPU to CPU - if (BytePSGlobal::IsDistributed() || BytePSGlobal::IsCrossPcieSwitch()) { + if ((!node_local && BytePSGlobal::IsDistributed()) || BytePSGlobal::IsCrossPcieSwitch()) { queue_list->push_back(COPYD2H); } @@ -449,7 +557,7 @@ std::shared_ptr> GetPushQueueList(int device) { // Push in distributed mode // In case IsCrossPcieSwitch(), PUSH runs as a dummy barrier - if (BytePSGlobal::IsDistributed() || BytePSGlobal::IsCrossPcieSwitch()) { + if ((!node_local && BytePSGlobal::IsDistributed()) || BytePSGlobal::IsCrossPcieSwitch()) { if (BytePSGlobal::IsRootDevice()) { queue_list->push_back(PUSH); } else { @@ -459,18 +567,20 @@ std::shared_ptr> GetPushQueueList(int device) { return queue_list; } -std::shared_ptr> GetPullQueueList(int device) { +std::shared_ptr> GetPullQueueList(int device, + bool node_local) { + auto queue_list = std::make_shared>(); // Pull in distributed mode - if (BytePSGlobal::IsDistributed()) { + if (!node_local && BytePSGlobal::IsDistributed()) { if (BytePSGlobal::IsRootDevice()) { queue_list->push_back(PULL); } } // Copy from CPU to GPU - if (BytePSGlobal::IsDistributed() || BytePSGlobal::IsCrossPcieSwitch()) { + if ((!node_local && BytePSGlobal::IsDistributed()) || 
BytePSGlobal::IsCrossPcieSwitch()) { queue_list->push_back(COPYH2D); } @@ -484,5 +594,24 @@ std::shared_ptr> GetPullQueueList(int device) { return queue_list; } +std::shared_ptr> GetAllgatherQueueList(int device, + bool node_local) { + + auto queue_list = std::make_shared>(); + + assert(node_local); + + // Per-PCIe-switch NCCL reduce + if (BytePSGlobal::GetNccl()->IsSignalRoot()) { + queue_list->push_back(ALLGATHER); + BPS_LOG(DEBUG) << "added ALLGATHER to queue list"; + } else { + queue_list->push_back(COORDINATE_ALLGATHER); + queue_list->push_back(ALLGATHER); + BPS_LOG(DEBUG) << "added COORDINATE_ALLGATHER and ALLGATHER to queue list"; + } + return queue_list; +} + } // namespace common } // namespace byteps diff --git a/byteps/common/operations.h b/byteps/common/operations.h index a91d93900b..6df341b3b9 100644 --- a/byteps/common/operations.h +++ b/byteps/common/operations.h @@ -69,6 +69,13 @@ Status EnqueueTensor(BPSContext &context, std::shared_ptr input, StatusCallback callback, std::shared_ptr> queue_list); +Status EnqueueAllgatherTensor(BPSContext &context, std::shared_ptr input, + std::shared_ptr output, + std::shared_ptr ready_event, const int device, + const int priority, const int version, + StatusCallback callback, + std::shared_ptr> queue_list); + void InitTensor(BPSContext &context, size_t size, int dtype, void *cpubuff); // Only call these in Framework plugins for the best performance @@ -79,9 +86,11 @@ void RegisterCompressor(const std::string &name, BPSContext &GetContextFromName(const std::string &name); -std::shared_ptr> GetPushQueueList(int device); +std::shared_ptr> GetPushQueueList(int device, bool node_local = false); + +std::shared_ptr> GetPullQueueList(int device, bool node_local = false); -std::shared_ptr> GetPullQueueList(int device); +std::shared_ptr> GetAllgatherQueueList(int device, bool node_local = false); } // namespace common } // namespace byteps diff --git a/byteps/common/scheduled_queue.cc b/byteps/common/scheduled_queue.cc index 4b4a88e1ea..303ed32710 100644 --- a/byteps/common/scheduled_queue.cc +++ b/byteps/common/scheduled_queue.cc @@ -74,6 +74,11 @@ BytePSScheduledQueue::BytePSScheduledQueue(QueueType type) { _rt = BytePSGlobal::GetBroadcastTable(); } break; + case ALLGATHER: + if (BytePSGlobal::GetNccl()->IsSignalRoot()) { + _rt = BytePSGlobal::GetAllgatherTable(); + } + break; default: break; } diff --git a/byteps/torch/__init__.py b/byteps/torch/__init__.py index dd3729aaa3..b59f3a6d95 100644 --- a/byteps/torch/__init__.py +++ b/byteps/torch/__init__.py @@ -22,10 +22,11 @@ from byteps.torch.compression import Compression from byteps.torch.ops import push_pull_async_inplace as byteps_push_pull -from byteps.torch.ops import push_pull +from byteps.torch.ops import push_pull, allgather from byteps.torch.ops import poll, synchronize, declare from byteps.torch.ops import init, shutdown, suspend, resume from byteps.torch.ops import size, local_size, rank, local_rank +from byteps.torch.sync_batch_norm import SyncBatchNorm import os import torch diff --git a/byteps/torch/ops.cc b/byteps/torch/ops.cc index 1a98645588..1ecc86da4a 100644 --- a/byteps/torch/ops.cc +++ b/byteps/torch/ops.cc @@ -52,7 +52,8 @@ int GetDeviceID(const ::torch::Tensor& tensor) { } // namespace void StartTask(::torch::Tensor tensor, ::torch::Tensor output, int average, - const std::string tensor_name, int version, int priority, int handle) { + const std::string tensor_name, int version, int priority, int handle, + bool node_local = false) { auto device = GetDeviceID(tensor); auto 
ready_event = RecordReadyEvent(device); @@ -67,8 +68,8 @@ void StartTask(::torch::Tensor tensor, ::torch::Tensor output, int average, ? const_cast(byteps_input->data()) : nullptr); - auto queue_list = common::GetPushQueueList(device); - auto queue_list_pull = common::GetPullQueueList(device); + auto queue_list = common::GetPushQueueList(device, node_local); + auto queue_list_pull = common::GetPullQueueList(device, node_local); queue_list->insert(queue_list->end(), queue_list_pull->begin(), queue_list_pull->end()); @@ -97,16 +98,79 @@ void StartTask(::torch::Tensor tensor, ::torch::Tensor output, int average, } int DoPushPull(::torch::Tensor tensor, ::torch::Tensor output, int average, - const std::string& name, int version, int priority) { + const std::string& name, int version, int priority, + bool node_local) { ThrowIfError(common::CheckInitialized()); auto handle = handle_manager.AllocateHandle(); std::string tensor_name = GetOpName("byteps", name.c_str(), 0); auto& context = common::GetContextFromName(tensor_name); if (context.initialized) { - StartTask(tensor, output, average, tensor_name, version, priority, handle); + StartTask(tensor, output, average, tensor_name, version, priority, handle, + node_local); } else { - std::thread t(StartTask, tensor, output, average, tensor_name, version, priority, handle); + std::thread t(StartTask, tensor, output, average, tensor_name, version, priority, handle, node_local); + t.detach(); + } + return handle; +} + +void StartAllGatherTask(::torch::Tensor tensor, ::torch::Tensor output, int average, + const std::string tensor_name, int version, int priority, int handle, + bool node_local = false) { + + auto device = GetDeviceID(tensor); + auto ready_event = RecordReadyEvent(device); + auto byteps_input = std::make_shared(tensor); + auto byteps_output = std::make_shared(output); + size_t size = byteps_input->size(); + auto dtype = byteps_input->dtype(); + + auto& context = common::GetContextFromName(tensor_name); + common::InitTensor(context, size, dtype, + (device == CPU_DEVICE_ID) + ? const_cast(byteps_input->data()) + : nullptr); + + auto queue_list = common::GetAllgatherQueueList(device, node_local); + + auto enqueue_result = common::EnqueueAllgatherTensor( + context, byteps_input, byteps_output, ready_event, device, priority, + version, + [handle, average, tensor, output](const Status& status) mutable { + // Will execute in the `device` context. 
+ if (average) { +#if TORCH_VERSION >= 1005000000 + if (isIntegralType(output.scalar_type(), false)) { + output.floor_divide_(byteps_size()); + handle_manager.MarkDone(handle, status); + return; + } +#endif + output.div_(byteps_size()); + } + handle_manager.MarkDone(handle, status); + }, + queue_list); + + ThrowIfError(enqueue_result); + return; + +} + +int DoAllGather(::torch::Tensor tensor, ::torch::Tensor output, int average, + const std::string& name, int version, int priority, + bool node_local) { + ThrowIfError(common::CheckInitialized()); + + auto handle = handle_manager.AllocateHandle(); + std::string tensor_name = GetOpName("byteps", name.c_str(), 0); + auto& context = common::GetContextFromName(tensor_name); + if (context.initialized) { + StartAllGatherTask(tensor, output, average, tensor_name, version, priority, handle, + node_local); + } else { + std::thread t(StartAllGatherTask, tensor, output, average, tensor_name, version, priority, handle, node_local); t.detach(); } return handle; @@ -146,10 +210,11 @@ pybind11::tuple DoPushPullGroupSync(::torch::Tensor tensor, int curr_count; if (context.initialized) { - StartTask(tensor, output, average, tensor_name, version, priority, handle); + StartTask(tensor, output, average, tensor_name, version, priority, handle, + false); } else { std::thread t(StartTask, tensor, output, average, tensor_name, version, - priority, handle); + priority, handle, false); t.detach(); } @@ -183,6 +248,14 @@ PYBIND11_MODULE(c_lib, m) { m.def("byteps_torch_push_pull_group_sync_torch_FloatTensor", &DoPushPullGroupSync); m.def("byteps_torch_push_pull_group_sync_torch_DoubleTensor", &DoPushPullGroupSync); + // allgather + m.def("byteps_torch_allgather_async_torch_ByteTensor", &DoAllGather); + m.def("byteps_torch_allgather_async_torch_IntTensor", &DoAllGather); + m.def("byteps_torch_allgather_async_torch_LongTensor", &DoAllGather); + m.def("byteps_torch_allgather_async_torch_HalfTensor", &DoAllGather); + m.def("byteps_torch_allgather_async_torch_FloatTensor", &DoAllGather); + m.def("byteps_torch_allgather_async_torch_DoubleTensor", &DoAllGather); + #if HAVE_CUDA m.def("byteps_torch_push_pull_async_torch_cuda_ByteTensor", &DoPushPull); m.def("byteps_torch_push_pull_async_torch_cuda_IntTensor", &DoPushPull); @@ -197,6 +270,14 @@ PYBIND11_MODULE(c_lib, m) { m.def("byteps_torch_push_pull_group_sync_torch_cuda_HalfTensor", &DoPushPullGroupSync); m.def("byteps_torch_push_pull_group_sync_torch_cuda_FloatTensor", &DoPushPullGroupSync); m.def("byteps_torch_push_pull_group_sync_torch_cuda_DoubleTensor", &DoPushPullGroupSync); + + // allgather + m.def("byteps_torch_allgather_async_torch_cuda_ByteTensor", &DoAllGather); + m.def("byteps_torch_allgather_async_torch_cuda_IntTensor", &DoAllGather); + m.def("byteps_torch_allgather_async_torch_cuda_LongTensor", &DoAllGather); + m.def("byteps_torch_allgather_async_torch_cuda_HalfTensor", &DoAllGather); + m.def("byteps_torch_allgather_async_torch_cuda_FloatTensor", &DoAllGather); + m.def("byteps_torch_allgather_async_torch_cuda_DoubleTensor", &DoAllGather); #endif // basics diff --git a/byteps/torch/ops.h b/byteps/torch/ops.h index 42f5a99e8e..3fbfa0d70a 100644 --- a/byteps/torch/ops.h +++ b/byteps/torch/ops.h @@ -39,7 +39,7 @@ size_t grad_count_; #define PUSHPULL_H(torch_Tensor, THTensor) \ extern "C" int byteps_torch_push_pull_async_##torch_Tensor( \ THTensor* tensor, THTensor* output, int average, char* name, \ - int version, int priority); + int version, int priority, bool node_local = false); PUSHPULL_H(torch_ByteTensor, 
THByteTensor) PUSHPULL_H(torch_IntTensor, THIntTensor) diff --git a/byteps/torch/ops.py b/byteps/torch/ops.py index 27208ba857..348d6358d8 100644 --- a/byteps/torch/ops.py +++ b/byteps/torch/ops.py @@ -66,12 +66,13 @@ def _push_pull_function_factory(tensor): def _push_pull_group_function_factory(tensor): return 'byteps_torch_push_pull_group_sync_' + tensor.type().replace('.', '_') -def _do_push_pull_async(tensor, output, average, name, version=0, priority=0): +def _do_push_pull_async(tensor, output, average, name, version=0, priority=0, + node_local=False): c_lib.byteps_torch_declare_tensor(name.encode() if name is not None else _NULL) function = _check_function(_push_pull_function_factory, tensor) handle = getattr(c_lib, function)(tensor, output, average, name.encode() if name is not None else _NULL, - version, priority) + version, priority, node_local) _handle_map[handle] = (tensor, output) return handle @@ -85,7 +86,8 @@ def _do_push_pull_group_sync(tensor, output, average, name, version=0, priority= return handle, curr_count -def push_pull_async(tensor, average=True, name=None, version=0, priority=0): +def push_pull_async(tensor, average=True, name=None, version=0, priority=0, + node_local=False): """ A function that performs asynchronous averaging or summation of the input tensor over all the BytePS processes. The input tensor is not modified. @@ -103,7 +105,8 @@ def push_pull_async(tensor, average=True, name=None, version=0, priority=0): `synchronize()`. """ output = tensor.new(tensor.shape) - return _do_push_pull_async(tensor, output, average, name, version, priority) + return _do_push_pull_async(tensor, output, average, name, version, priority, + node_local) class BytePSPushPull(torch.autograd.Function): @@ -196,6 +199,92 @@ def push_pull_inplace(tensor, average=True, name=None, version=0, priority=0): handle = push_pull_async_inplace(tensor, average, name, version, priority) return synchronize(handle) +def _allgather_function_factory(tensor): + return 'byteps_torch_allgather_async_' + tensor.type().replace('.', '_') + + +def _do_allgather_async(tensor, output, name, version, priority, node_local): + average = False + c_lib.byteps_torch_declare_tensor(name.encode() if name is not None else _NULL) + function = _check_function(_allgather_function_factory, tensor) + handle = getattr(c_lib, function)( + tensor, output, average, name.encode() if name is not None else _NULL, + version, priority, node_local) + _handle_map[handle] = (tensor, output) + return handle + + +def allgather_async(tensor, name=None, version=0, priority=0, node_local=True): + """ + A function that asynchronously concatenates the input tensor with the same input + tensor on all other Horovod processes. The input tensor is not modified. + + The concatenation is done on the first dimension, so the input tensors on the + different processes must have the same rank and shape, except for the first + dimension, which is allowed to be different. + + Arguments: + tensor: A tensor to allgather. + name: A name of the allgather operation. + + Returns: + A handle to the allgather operation that can be used with `poll()` or + `synchronize()`. 
+ """ + output_shape = list(tensor.shape) + output_shape[0] *= local_size() + output = torch.empty(output_shape).cuda() + return _do_allgather_async(tensor, output, name, version, priority, node_local) + +class BytepsAllgather(torch.autograd.Function): + """An autograd function that performs allgather on a tensor.""" + + @staticmethod + def forward(ctx, tensor, name): + ctx.dim = tensor.shape[0] + ctx.name = name + ctx.version = version + ctx.priority = priority + handle = allgather_async(tensor, name) + return synchronize(handle) + + @staticmethod + def backward(ctx, grad_output): + grad_reduced = allreduce(grad_output, average=True) + + dim_t = torch.IntTensor([ctx.dim]) + dim = allgather(dim_t).view(size()) + + r = rank() + offset = torch.sum(dim.narrow(0, 0, r)).item() if r != 0 else 0 + return grad_reduced.narrow(0, offset, ctx.dim), None + + +def allgather(tensor, name=None, version=0, priority=0): + """ + A function that concatenates the input tensor with the same input tensor on + all other Horovod processes. The input tensor is not modified. + + The concatenation is done on the first dimension, so the input tensors on the + different processes must have the same rank and shape, except for the first + dimension, which is allowed to be different. + + This acts as a thin wrapper around an autograd function. If your input + tensor requires gradients, then callings this function will allow gradients + to be computed and backpropagated. + + Arguments: + tensor: A tensor to allgather. + name: A name of the allgather operation. + + Returns: + A tensor of the same type as `tensor`, concatenated on dimension zero + across all processes. The shape is identical to the input shape, except for + the first dimension, which may be greater and is the sum of all first + dimensions of the tensors in different Horovod processes. + """ + return BytepsAllgather.apply(tensor, name, version, priority) + def poll(handle): """ diff --git a/byteps/torch/sync_batch_norm.py b/byteps/torch/sync_batch_norm.py new file mode 100644 index 0000000000..c65a3c9978 --- /dev/null +++ b/byteps/torch/sync_batch_norm.py @@ -0,0 +1,285 @@ +# Based on https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/_functions.py +# Modifications copyright 2020 Maka Autonomous Robotic Systems +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from byteps.torch.ops import push_pull_async, allgather_async, size, rank, synchronize, local_size, local_rank + +from distutils.version import LooseVersion + +import os +import torch +import time +from torch.autograd.function import Function +import torch.nn.functional as F +from torch.nn.modules.batchnorm import _BatchNorm +from pprint import pprint + + +my_syncbn_dict = dict() +my_counter = 0 + +# Backward compat for old PyTorch +if not hasattr(torch.jit, 'unused'): + torch.jit.unused = lambda x: x + + +_SYNC_BN_V2 = ( + LooseVersion(torch.__version__) >= LooseVersion('1.5.0') and + LooseVersion(torch.__version__) <= LooseVersion('1.6.0') +) +_SYNC_BN_V3 = LooseVersion(torch.__version__) >= LooseVersion('1.6.0') + + +class SyncBatchNorm(_BatchNorm): + """Applies synchronous version of N-dimensional BatchNorm. + + In this version, normalization parameters are synchronized across workers during forward pass. + This is very useful in situations where each GPU can fit a very small number of examples. + + See https://pytorch.org/docs/stable/nn.html#batchnorm2d for more details about BatchNorm. + + Arguments: + num_features: number of channels `C` from the shape `(N, C, ...)` + eps: a value added to the denominator for numerical stability. Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to `None` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to `True`, this module has + learnable affine parameters. Default: `True` + track_running_stats: a boolean value that when set to `True`, this + module tracks the running mean and variance, and when set to `False`, + this module does not track such statistics and always uses batch + statistics in both training and eval modes. Default: `True` + + .. note:: Only GPU input tensors are supported in the training mode. 
+ """ + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, + track_running_stats=True): + global my_counter + super().__init__(num_features, eps, momentum, affine, track_running_stats) + self.mark = torch.tensor([my_counter]) + my_counter += 1 + node_local = os.getenv('BYTEPS_SYNC_BN_GLOBAL', 'False').lower() not in ["true", "1"] + self.sync_size = size() + if node_local: + self.sync_size = local_size() + self.skip_syncbn = os.getenv('BYTEPS_SKIP_SYNC_BN', 'False').lower() in ["true", "1"] + + def _check_input_dim(self, input): + if input.dim() < 2: + raise ValueError('expected at least 2D input (got {}D input)'.format(input.dim())) + + def _run_bn(self, input): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training or not self.track_running_stats, self.momentum, self.eps) + + @torch.jit.unused + def _maybe_run_sync_bn(self, input): + # if size() == 1: + if self.sync_size == 1: + return self._run_bn(input) + return _SyncBatchNorm.apply( + input, self.weight, self.bias, self.running_mean, self.running_var, + self.eps, self.momentum, self.num_features, self.mark) + + def forward(self, input): + # currently only GPU input is supported by underlying kernel from PyTorch + if not input.is_cuda: + raise ValueError('SyncBatchNorm expected input tensor to be on GPU') + + self._check_input_dim(input) + + if self.training and self.track_running_stats: + self.num_batches_tracked = self.num_batches_tracked + 1 + + if not self.training and self.track_running_stats: + return self._run_bn(input) + elif self.skip_syncbn: + return self._run_bn(input) + else: + return self._maybe_run_sync_bn(input) + + +class _SyncBatchNorm(Function): + @staticmethod + def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, num_features, mark): + global my_syncbn_dict + node_local = os.getenv('BYTEPS_SYNC_BN_GLOBAL', 'False').lower() not in ["true", "1"] + my_rank = rank() + sync_size = size() + if node_local: + my_rank = local_rank() + sync_size = local_size() + input = input.contiguous() + weight = weight.contiguous() + + my_size = input.numel() // input.size(1) + tmp_size = [0] * sync_size + tmp_size[my_rank] = my_size + count = torch.tensor(tmp_size).cuda() + + # calculate mean/invstd for input. + mean, invstd = torch.batch_norm_stats(input, eps) + raw_mean = mean + raw_invstd = invstd + + my_unique_id = mark.numpy()[0] + if my_unique_id not in my_syncbn_dict: + my_syncbn_dict[my_unique_id] = { + "tmp_count": torch.zeros((sync_size,), dtype=mean.dtype, device=mean.device), + "tmp_mean": torch.zeros((sync_size, num_features), dtype=mean.dtype, device=mean.device), + "tmp_invstd": torch.zeros((sync_size, num_features), dtype=mean.dtype, device=mean.device), + "tmp_sum_dy": torch.zeros((num_features,), dtype=mean.dtype, device=mean.device), + "tmp_sum_dy_xmu": torch.zeros((num_features,), dtype=mean.dtype, device=mean.device), + } + tmp_dict = my_syncbn_dict[my_unique_id] + tmp_count = tmp_dict["tmp_count"] + tmp_mean = tmp_dict["tmp_mean"] + tmp_invstd = tmp_dict["tmp_invstd"] + tmp_sum_dy = tmp_dict["tmp_sum_dy"] + tmp_sum_dy_xmu = tmp_dict["tmp_sum_dy_xmu"] + tmp_dict["name_count"] = 'sync_batch_norm.count.' + str(torch.numel(tmp_count)) + '.' + str(tmp_count.dtype) + '.' + str(my_unique_id) + tmp_dict["name_mean"] = 'sync_batch_norm.mean.' + str(torch.numel(tmp_mean)) + '.' + str(tmp_mean.dtype) + '.' + str(my_unique_id) + tmp_dict["name_invstd"] = 'sync_batch_norm.invstd.' + str(torch.numel(tmp_invstd)) + '.' 
+ str(tmp_invstd.dtype) + '.' + str(my_unique_id) + tmp_dict["name_sum_dy"] = 'sync_batch_norm.sum_dy.' + str(torch.numel(tmp_sum_dy)) + '.' + str(tmp_sum_dy.dtype) + '.' + str(my_unique_id) + tmp_dict["name_sum_dy_xmu"] = 'sync_batch_norm.sum_dy_xmu.' + str(torch.numel(tmp_sum_dy_xmu)) + '.' + str(tmp_sum_dy_xmu.dtype) + '.' + str(my_unique_id) + tmp_dict = my_syncbn_dict[my_unique_id] + tmp_count = tmp_dict["tmp_count"].fill_(0) + tmp_mean = tmp_dict["tmp_mean"].fill_(0) + tmp_invstd = tmp_dict["tmp_invstd"].fill_(0) + name_count = tmp_dict["name_count"] + name_mean = tmp_dict["name_mean"] + name_invstd = tmp_dict["name_invstd"] + + tmp_invstd[my_rank].copy_(raw_invstd) + tmp_mean[my_rank].copy_(raw_mean) + tmp_count[my_rank] = my_size + + # raw_count = torch.tensor([my_size*1.0]).cuda() + raw_count = torch.full((1,), input.numel() // input.size(1), + dtype=mean.dtype, + device=mean.device) + count_handle = allgather_async(raw_count.unsqueeze(0), name=name_count, node_local=node_local) + mean_handle = allgather_async(raw_mean.unsqueeze(0), name=name_mean, node_local=node_local) + invstd_handle = allgather_async(raw_invstd.unsqueeze(0), name=name_invstd, node_local=node_local) + + # wait on the async communication to finish + count_all = synchronize(count_handle) + mean_all = synchronize(mean_handle) + invstd_all = synchronize(invstd_handle) + + if _SYNC_BN_V3: + counts_for_bngswc = count_all.view(-1).float().to(input.device) + else: + # backwards compatibility + counts_for_bngswc = count_all.view(-1).tolist() + + # calculate global mean & invstd + mean, invstd = torch.batch_norm_gather_stats_with_counts( + input, + mean_all, + invstd_all, + running_mean, + running_var, + momentum, + eps, + counts_for_bngswc + ) + + self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32), mark) + + # apply element-wise normalization + return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps) + + @staticmethod + def backward(self, grad_output): + grad_output = grad_output.contiguous() + saved_input, weight, mean, invstd, count_all, mark = self.saved_tensors + need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[0:3] + + # calculate local stats as well as grad_weight / grad_bias + sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce( + grad_output, + saved_input, + mean, + invstd, + weight, + need_input_grad, + need_weight_grad, + need_bias_grad + ) + + if need_input_grad: + node_local = os.getenv('BYTEPS_SYNC_BN_GLOBAL', 'False').lower() not in ["true", "1"] + my_rank = rank() + sync_size = size() + if node_local: + my_rank = local_rank() + sync_size = local_size() + + my_unique_id = mark.numpy()[0] + assert my_unique_id in my_syncbn_dict + tmp_dict = my_syncbn_dict[my_unique_id] + + tmp_sum_dy = tmp_dict["tmp_sum_dy"] + tmp_sum_dy_xmu = tmp_dict["tmp_sum_dy_xmu"] + name_sum_dy = tmp_dict["name_sum_dy"] + name_sum_dy_xmu = tmp_dict["name_sum_dy_xmu"] + + tmp_sum_dy.copy_(sum_dy) + tmp_sum_dy_xmu.copy_(sum_dy_xmu) + sum_dy = tmp_sum_dy + sum_dy_xmu = tmp_sum_dy_xmu + + sum_dy_handle = push_pull_async(sum_dy, average=False, name=name_sum_dy, node_local=node_local) + sum_dy_xmu_handle = push_pull_async(sum_dy_xmu, average=False, name=name_sum_dy_xmu, node_local=node_local) + + # wait on the async communication to finish + sum_dy = synchronize(sum_dy_handle) + sum_dy_xmu = synchronize(sum_dy_xmu_handle) + + if _SYNC_BN_V2 or _SYNC_BN_V3: + count_all_sum = count_all.sum() + mean_dy = sum_dy / count_all_sum + mean_dy_xmu = sum_dy_xmu 
/ count_all_sum + else: + # before 1.5.0, sum_dy was sum of means from every worker, so we just + # need to divide it by number of workers + mean_dy = sum_dy / sync_size + mean_dy_xmu = sum_dy_xmu / sync_size + + # backward pass for gradient calculation + grad_input = torch.batch_norm_backward_elemt( + grad_output, + saved_input, + mean, + invstd, + weight, + mean_dy, + mean_dy_xmu + ) + else: + grad_input = None + + # synchronizing of grad_weight / grad_bias is not needed as distributed + # training would handle all reduce. + if weight is None or not need_weight_grad: + grad_weight = None + + if weight is None or not need_bias_grad: + grad_bias = None + + return grad_input, grad_weight, grad_bias, None, None, None, None, None, None diff --git a/launcher/launch.py b/launcher/launch.py index 3da1e6be6d..77be5ff01f 100644 --- a/launcher/launch.py +++ b/launcher/launch.py @@ -85,10 +85,10 @@ def _get_allocation(nodes, quota): def _get_quota(nodes, local_size): if len(nodes) > 1: - cpu_nums = reduce(lambda x, y: (len(x) + len(y)), nodes) + cpu_nums = reduce(lambda x, y: (x + len(y)), nodes, 0) else: cpu_nums = len(nodes[0]) - + # default quota is the number of cpus for non-root processess default_quota = int(os.getenv("BYTEPS_NUMA_DEFAULT_QUOTA", 6)) while default_quota >= 1 and default_quota * local_size > cpu_nums: diff --git a/tests/test_bps_torch_syncbn.py b/tests/test_bps_torch_syncbn.py new file mode 100644 index 0000000000..936d15edf8 --- /dev/null +++ b/tests/test_bps_torch_syncbn.py @@ -0,0 +1,144 @@ +# Copyright 2018 Uber Technologies, Inc. All Rights Reserved. +# Modifications copyright (C) 2019 Intel Corporation +# Modifications copyright (C) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from distutils.version import LooseVersion + +import inspect +import itertools +import os +import platform +import sys +import unittest +import warnings +import time +import json + +from collections.abc import Iterable + +import numpy as np +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +import byteps.torch as bps + +sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, 'utils')) + + +_1_5_api = LooseVersion(torch.__version__) >= LooseVersion('1.5.0') + +ccl_supported_types = set([torch.ByteTensor, torch.CharTensor, torch.ShortTensor, + torch.IntTensor, torch.LongTensor, torch.FloatTensor, + torch.DoubleTensor]) + + +class TorchTests(unittest.TestCase): + """ + Tests for ops in horovod.torch. + """ + + def __init__(self, *args, **kwargs): + super(TorchTests, self).__init__(*args, **kwargs) + warnings.simplefilter('module') + + def convert_cpu_fp16_to_fp32(self, *values): + # PyTorch doesn't support any CPU ops on FP16 tensors. + # In case we need to do ops, we will convert tensor to FP32 here. 
+ result = [] + for value in values: + if value.dtype in [torch.float16, torch.HalfTensor] and not value.is_cuda: + result.append(value.float()) + else: + result.append(value) + return result + + def cast_and_place(self, tensor, dtype): + if dtype.is_cuda: + return tensor.cuda(bps.local_rank()).type(dtype) + return tensor.type(dtype) + + def filter_supported_types(self, types): + if 'CCL_ROOT' in os.environ: + types = [t for t in types if t in ccl_supported_types] + return types + + def test_gpu_required(self): + if not torch.cuda.is_available(): + skip_or_fail_gpu_test(self, "No GPUs available") + + def test_horovod_sync_batch_norm(self): + print("xxx starting the test") + """Tests Horovod version of SyncBatchNorm.""" + if not torch.cuda.is_available(): + self.skipTest("No GPUs available") + + bps.init() + torch.cuda.set_device(bps.local_rank()) + print("xxx my local rank is ", bps.rank()) + + ts_list = [ + torch.stack([ + torch.tensor([ + [r, r + 1], + [r * 2, r * 2 + 1], + [r * 3, r * 3 + 1], + [r * 4, r * 4 + 1] + ]) + for r in range(bps.size()) + ]), + torch.stack([ + torch.tensor([ + [r + 1], + [r * 2 + 1], + [r * 3 + 1], + [r * 4 + 1] + ]) + for r in range(bps.size()) + ]), + ] + + sync_bn = bps.SyncBatchNorm(num_features=4) + sync_bn.cuda(bps.local_rank()) + + bn = torch.nn.BatchNorm1d(num_features=4) + bn.cuda(bps.local_rank()) + for idx, ts in enumerate(ts_list): + + ts = ts.cuda(bps.local_rank()).float() + ts1 = ts.clone().requires_grad_() + ts2 = ts.clone().requires_grad_() + + # Training + sync_bn_out = sync_bn(ts1[bps.rank()].unsqueeze(0)) + bn_out = bn(ts2) + assert torch.allclose(sync_bn_out, bn_out[bps.rank()].unsqueeze(0), 1e-6) + assert torch.allclose(sync_bn.running_mean, bn.running_mean, 1e-6) + assert torch.allclose(sync_bn.running_var, bn.running_var, 1e-6) + + # Gradients + sync_bn_out.sum().backward() + bn_out.mean(dim=0).sum().backward() + assert torch.allclose(bps.push_pull(sync_bn.weight.grad, name='sync_bn.weight.grad.' + str(idx)), bn.weight.grad, 1e-6) + assert torch.allclose(bps.push_pull(sync_bn.bias.grad, name='sync_bn.bias.grad.' + str(idx)), bn.bias.grad, 1e-6) + assert torch.allclose(bps.push_pull(ts1.grad, name='ts1.grad.' + str(idx)), ts2.grad, 1e-6) + break + + + +if __name__ == "__main__": + unittest.main() From 7355e1669e09edc7a5bca5103957cdf34952de09 Mon Sep 17 00:00:00 2001 From: Yulu Jia Date: Wed, 19 May 2021 09:31:54 -0700 Subject: [PATCH 2/2] add SyncBatchNorm - add byteps.torch.SyncBatchNorm - add BYTEPS_SYNC_BN_GLOBAL to choose between global sync and node local sync. defaults to node local sync - add node local allgather Signed-off-by: yulu.jia Signed-off-by: Yulu Jia --- byteps/torch/ops.py | 4 ++++ tests/test_bps_torch_syncbn.py | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/byteps/torch/ops.py b/byteps/torch/ops.py index 348d6358d8..812db55ebe 100644 --- a/byteps/torch/ops.py +++ b/byteps/torch/ops.py @@ -226,6 +226,10 @@ def allgather_async(tensor, name=None, version=0, priority=0, node_local=True): Arguments: tensor: A tensor to allgather. name: A name of the allgather operation. + node_local: if True, the allgather is performed among ranks on the same + node. If False, the allgather is performed globally among + all ranks. Only node_local = True is implemented for the + moment. 
Returns: A handle to the allgather operation that can be used with `poll()` or diff --git a/tests/test_bps_torch_syncbn.py b/tests/test_bps_torch_syncbn.py index 936d15edf8..2ffe638450 100644 --- a/tests/test_bps_torch_syncbn.py +++ b/tests/test_bps_torch_syncbn.py @@ -82,14 +82,12 @@ def test_gpu_required(self): skip_or_fail_gpu_test(self, "No GPUs available") def test_horovod_sync_batch_norm(self): - print("xxx starting the test") """Tests Horovod version of SyncBatchNorm.""" if not torch.cuda.is_available(): self.skipTest("No GPUs available") bps.init() torch.cuda.set_device(bps.local_rank()) - print("xxx my local rank is ", bps.rank()) ts_list = [ torch.stack([
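
A minimal usage sketch of the new byteps.torch.SyncBatchNorm (illustrative only, not part of the patch). It assumes a BytePS job is already running, with the usual scheduler, server and worker environment variables set, and at least one visible GPU; the model layout, tensor shapes and environment handling below are arbitrary examples.

# Usage sketch (illustrative, not from the patch). Assumes BytePS scheduler,
# server and worker environment variables are already set and a GPU is visible.
import os
import torch
import byteps.torch as bps

# Leave BYTEPS_SYNC_BN_GLOBAL unset or '0' for the default node-local sync;
# set it to '1' (or 'true') before constructing the module for global sync.
os.environ.setdefault('BYTEPS_SYNC_BN_GLOBAL', '0')

bps.init()
torch.cuda.set_device(bps.local_rank())

# Arbitrary example model; bps.SyncBatchNorm is a drop-in replacement for
# torch.nn.BatchNorm2d whose batch statistics are synchronized across workers.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    bps.SyncBatchNorm(num_features=8),
    torch.nn.ReLU(),
).cuda()

x = torch.randn(4, 3, 32, 32, device='cuda')
out = model(x)          # forward allgathers per-GPU count/mean/invstd
out.mean().backward()   # backward push-pulls sum_dy and sum_dy_xmu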
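
A similar sketch for the lower-level allgather primitive added to byteps.torch.ops, again illustrative rather than part of the patch. With the default node_local=True in allgather_async, the gather spans only the ranks on the local node, so the first dimension grows by local_size(). The tensor contents and the name 'example.allgather' are made-up examples.

import torch
import byteps.torch as bps

bps.init()
torch.cuda.set_device(bps.local_rank())

# Each local rank contributes one row filled with its local rank id.
x = torch.full((1, 4), float(bps.local_rank()), device='cuda')

# The name is an arbitrary label; the underlying allgather_async
# currently defaults to node_local=True.
gathered = bps.allgather(x, name='example.allgather')

# Node-local gather: dim 0 is multiplied by the number of ranks on this node.
assert gathered.shape[0] == x.shape[0] * bps.local_size()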