1-bit adam (#353)
* 1-bit Adam v1 (squash) (#346)

* testing_onebit

* test_passed

* test_passed

* updated compressed_allreduce

* 123

* cpu2gpu test added

* Add non-cuda-aware code path. Segfaults for > 2 procs.

* Works for 4 procs with numpy buffers now. TODO: cleanup, evaluate perf.

* Fix gather. Cleanup.

* Add new tests.

* Reduce memory footprint. BERT large with BS=16 works.

* Revert "Reduce memory footprint. BERT large with BS=16 works."

This reverts commit e7f38fc.

* Update optim to support bert-large.

* With initialization added to bert_onebit_adam

* This works!!

* Force igather for cupy. Better performance. Need cleanup and reorg to
support TCP now.

* X

* Testing the finetune task for FP32 training

* Added the finetune task for FP32 training

* Added the freeze_step control flag

* Added freeze_step inside fp32_onebitadam

* Separate freeze_kernel added

* Added the sanity test for the Compressed_Allreduce

* Test for Compressed_Allreduce passed, but AllGather needs sync.

* add checks for finetuning.

* Finetune run passed on EastUS

* Add one bit adam clean file.

* Refactor comms. code and move it to a new file.

* fix compile/run errors.

* Adding changes for onebit_adam from Hank

* Save memory by modifying in-place.

Co-authored-by: Your Name <[email protected]>
Co-authored-by: Ammar Ahmad Awan <[email protected]>
Co-authored-by: tanghl1994 <[email protected]>
Co-authored-by: Hank <[email protected]>
Co-authored-by: root <[email protected]>

* Staging 1bit adam v1 (#348)

* Refactor to correct locations.

* Deleted unused files.

* Fix imports for refactored codebase.

* Update the comm reduce test.

* Fix some errors.

* Fix optimizer name

* Delete unused tests.

* Fix formatting for pre-commit.

* Add cupy dependencies.

* Add cupy for cuda 10

* Add mpi4py requirement.

Co-authored-by: Ammar Ahmad Awan <[email protected]>

* Use correct initialization for exp_avg.

* Cleanup onebit adam.

* minor wording fix.

* Cleanup custom collectives.

* Fixes for TCP support.

* Fix formatting.

* move 1bit adam reqs

* Delay importing 1-bit Adam unless it is used; this ensures we also delay importing mpi4py

* Fix cuda version parsing.

* Temporary tcp fix.

* Update install.sh

* Refactor code to properly support cuda-aware comm.

* Fix imports.

* add cuda_aware flag to tests.

* Cleanup. Add copyrights.

* Add 1-bit Adam tutorial v1.

* Minor fixes to copyright and print statements.

* Update utils.py

* Update utils.py

Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Your Name <[email protected]>
Co-authored-by: tanghl1994 <[email protected]>
Co-authored-by: Hank <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Ammar Ahmad Awan <[email protected]>
7 people authored Sep 8, 2020
1 parent 1ebcd6c commit fa66867
Showing 12 changed files with 871 additions and 11 deletions.
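For orientation, here is a hedged sketch of how a user might select the newly registered optimizer once this commit is installed. The 'onebitadam' name comes from the config.py change below; the specific parameters shown (lr, freeze_step, cuda_aware) are illustrative assumptions drawn from the commit message and the added tutorial, not from the diffs in this excerpt.

```python
# Hedged sketch: selecting the newly registered 1-bit Adam optimizer via a
# DeepSpeed config dict. The optimizer type maps to the 'onebitadam' constant
# added in deepspeed/runtime/config.py below; parameter names/values are
# illustrative assumptions.
ds_config = {
    "train_batch_size": 16,
    "fp16": {"enabled": True},
    "optimizer": {
        "type": "OneBitAdam",      # resolved against the lowercase 'onebitadam' constant
        "params": {
            "lr": 2e-4,
            "freeze_step": 23000,  # assumed warmup steps before compressed communication starts
            "cuda_aware": False    # assumed flag: host-staged (gather_host) vs cuda-aware (gather_cuda) path
        }
    }
}
```

Such a dict would normally live in the DeepSpeed JSON config passed to deepspeed.initialize. Per the requirement changes in this commit, running 1-bit Adam also needs mpi4py and a cupy build matching the local CUDA version (e.g. cupy-cuda100 for CUDA 10.0).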
3 changes: 2 additions & 1 deletion deepspeed/runtime/config.py
@@ -14,9 +14,10 @@
from deepspeed.utils import logger

TENSOR_CORE_ALIGN_SIZE = 8
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
ADAM_OPTIMIZER = 'adam'
LAMB_OPTIMIZER = 'lamb'
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER]
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER]


def get_amp_enabled(param_dict):
154 changes: 154 additions & 0 deletions deepspeed/runtime/custom_collectives.py
@@ -0,0 +1,154 @@
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''

from mpi4py import MPI
import numpy as np
import cupy


def my_igather(rank, size, comm, sendbuf, recbuf, root):
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(comm.Irecv(recbuf[idx], source=idx))
else:
recbuf[rank] = sendbuf
else:
req.append(comm.Isend(sendbuf, dest=root))
return req


def gather_cuda(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# We do in-place operations on cupy buffers so we do not return any buffers
requests = []
for idx in range(world_size):
req_sign = my_igather(rank,
world_size,
comm,
cupy_sign_list_packed[idx],
cupy_recvbuf_sign,
root=idx)
requests += req_sign

for idx in range(world_size):
req_scale = my_igather(rank,
world_size,
comm,
cupy_worker_scale,
cupy_recvbuf_scale,
root=idx)
requests += req_scale

MPI.Request.Waitall(requests)


def gather_host(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# In-place operations are not possible for newly created cupy arrays
# so we need to return the new buffers
numpy_recvbuf_sign = np.zeros([world_size,
cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)

# 1. convert from cupy to numpy
numpy_sign_list_packed = cupy_sign_list_packed

for idx in range(world_size):
numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx])

numpy_worker_scale = cupy.asnumpy(cupy_worker_scale)
numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale)

cupy.cuda.get_current_stream().synchronize()

# 2. use numpy buffers for communication
requests = []

for idx in range(world_size):
req_sign = my_igather(rank,
world_size,
comm,
numpy_sign_list_packed[idx],
numpy_recvbuf_sign,
root=idx)
requests += req_sign

for idx in range(world_size):
req_scale = my_igather(rank,
world_size,
comm,
numpy_worker_scale,
numpy_recvbuf_scale,
root=idx)
requests += req_scale

MPI.Request.Waitall(requests)

# 3. Convert back from numpy to cupy
cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign)
for idx in range(world_size):
cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx])

cupy_worker_scale = cupy.asarray(numpy_worker_scale)
cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()

return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale


def allgather_cuda(comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)


def allgather_host(comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):

# 1. Convert cupy to numpy
numpy_recvbuf_sign_server = np.zeros([comm.Get_size(),
cupy_server_sign_packed.size],
dtype=cupy_server_sign_packed.dtype)
numpy_recvbuf_scale_server = np.zeros([comm.Get_size(),
1],
dtype=cupy_server_scale.dtype)

numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed)
numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
numpy_server_scale = cupy.asnumpy(cupy_server_scale)
numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()

# 2. Communicate numpy buffers
comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server)
comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server)
comm.Barrier()

# 3. Convert numpy back to cupy
cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed)
cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server)
cupy_server_scale = cupy.asarray(numpy_server_scale)
cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()

return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server
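A minimal sketch of how the my_igather helper above composes into the gather phase used by gather_cuda and gather_host: every rank takes one turn as root, the nonblocking requests from each call are pooled, and a single Waitall completes them, so after the loop every rank holds every worker's payload. The script below is illustrative only; the payload shape and the mpirun launch line are assumptions, and it requires mpi4py plus a DeepSpeed build that contains this commit.

```python
# Illustrative use of my_igather (added above); run with e.g.
#   mpirun -np 4 python this_script.py
from mpi4py import MPI
import numpy as np
from deepspeed.runtime.custom_collectives import my_igather

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

sendbuf = np.full(4, rank, dtype=np.float32)      # each rank contributes its rank id
recvbuf = np.zeros([size, 4], dtype=np.float32)   # row i will hold rank i's payload

requests = []
for root in range(size):
    # each call returns the pending Isend/Irecv requests for this root's turn
    requests += my_igather(rank, size, comm, sendbuf, recvbuf, root=root)
MPI.Request.Waitall(requests)

print(f"rank {rank} gathered first elements:", recvbuf[:, 0])  # 0, 1, ..., size-1 on every rank
```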
22 changes: 16 additions & 6 deletions deepspeed/runtime/engine.py
@@ -18,7 +18,8 @@
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_OPTIMIZERS

from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
@@ -27,8 +28,6 @@
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules

from deepspeed.ops.lamb import FusedLamb

from deepspeed.utils import logger
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer

@@ -122,6 +121,7 @@ def __init__(self,
self.config_params = config_params
self.loaded_checkpoint_mp_world_size = None
self.loaded_checkpoint_dp_world_size = None
self.enable_backward_allreduce = True

if dist_init_required is None:
dist_init_required = not dist.is_initialized()
@@ -527,6 +527,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):

def _configure_basic_optimizer(self, model_parameters):
optimizer_parameters = self.optimizer_params()
# print(optimizer_parameters.keys())
if 'max_grad_norm' in optimizer_parameters.keys():
raise ValueError(
"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
@@ -535,7 +536,11 @@ def _configure_basic_optimizer(self, model_parameters):
from apex.optimizers.fused_adam import FusedAdam
optimizer = FusedAdam(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == LAMB_OPTIMIZER:
from deepspeed.ops.lamb import FusedLamb
optimizer = FusedLamb(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)
else:
torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
@@ -545,7 +550,8 @@ def _configure_fp16_optimizer(self, optimizer):
initial_dynamic_scale = self.initial_dynamic_scale()
dynamic_loss_args = self.dynamic_loss_scale_args()
clip_grad = self.gradient_clipping()
if self.optimizer_name() == ADAM_OPTIMIZER:
if self.optimizer_name() == ADAM_OPTIMIZER or self.optimizer_name(
) == ONEBIT_ADAM_OPTIMIZER:
if self.dynamic_loss_scale():
logger.info('Creating fp16 optimizer with dynamic loss scale')
timers = self.timers if self.wall_clock_breakdown() else None
@@ -734,7 +740,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)

def backward(self, loss, allreduce_gradients=True):
def backward(self, loss, allreduce_gradients=True, release_loss=False):
r"""Execute backward pass on the loss
Arguments:
@@ -796,7 +802,7 @@ def backward(self, loss, allreduce_gradients=True):
self.timers('backward_allreduce_microstep').start()
self.timers('backward_allreduce').start()

if allreduce_gradients:
if allreduce_gradients and self.enable_backward_allreduce:
self.allreduce_gradients()

if self.wall_clock_breakdown():
@@ -805,6 +811,10 @@ def backward(self, loss, allreduce_gradients=True):
self.timers('backward').stop()
self.timers('backward_microstep').stop()

if release_loss:
# loss.data = None
pass

return loss

def is_gradient_accumulation_boundary(self):
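The two engine changes above (the new enable_backward_allreduce attribute and the extra condition in backward()) give an optimizer a way to take over gradient communication from the engine. The sketch below illustrates only that hook, with a hypothetical CommOwningOptimizer; it is not the OnebitAdam implementation in deepspeed/runtime/fp16/onebit_adam.py, which this excerpt does not include.

```python
# Hedged illustration of the enable_backward_allreduce hook added above.
# CommOwningOptimizer is hypothetical; OnebitAdam is constructed with the engine
# instance the same way (OnebitAdam(model_parameters, self, ...)).
class CommOwningOptimizer:
    def __init__(self, params, deepspeed_engine):
        self.params = list(params)
        self.engine = deepspeed_engine
        # Disable the engine's default gradient averaging; backward() now skips
        # allreduce_gradients() because of the
        # "if allreduce_gradients and self.enable_backward_allreduce" guard.
        deepspeed_engine.enable_backward_allreduce = False

    def step(self):
        # A real implementation would run error-compensated 1-bit compression and
        # the custom gather/allgather collectives here instead of a dense allreduce.
        pass
```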
18 changes: 18 additions & 0 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -101,6 +101,20 @@ def __init__(self,

self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)
self.initialize_optimizer_states()

def initialize_optimizer_states(self):
for i, group in enumerate(self.fp16_groups):
self.fp32_groups_flat[i].grad = torch.zeros(
self.fp32_groups_flat[i].size(),
device=self.fp32_groups_flat[i].device)

self.optimizer.step()

for i, group in enumerate(self.fp16_groups):
self.fp32_groups_flat[i].grad = None

return

def zero_grad(self, set_grads_to_None=True):
"""
@@ -204,6 +218,9 @@ def step(self, closure=None):
if p.grad is None else p.grad.to(data_type) for p in group
]))

for p in group:
p.grad = None

self.fp32_groups_flat[i].grad = grads_groups_flat[i]

self.start_timers([COMPUTE_NORM])
@@ -223,6 +240,7 @@ def step(self, closure=None):
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
self.log_timers(OVERFLOW_TIMERS)
grads_groups_flat = None
return self.overflow

self.start_timers([UNSCALE_AND_CLIP])
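The initialize_optimizer_states() addition above materializes the inner optimizer's state (for Adam, exp_avg and exp_avg_sq) by assigning zero gradients and taking one step before training; the commit message's "Use correct initialization for exp_avg." points at the same issue. The standalone sketch below reproduces the trick with a plain torch.optim.Adam on a single parameter; the tensor size and learning rate are arbitrary assumptions.

```python
# Minimal sketch of the zero-grad state-initialization trick: with a zero
# gradient the Adam update is numerically a no-op (exp_avg and exp_avg_sq stay
# zero), but the step allocates the optimizer state tensors up front.
import torch

param = torch.zeros(1024, requires_grad=True)
optimizer = torch.optim.Adam([param], lr=1e-3)

param.grad = torch.zeros_like(param)  # dummy zero gradient
optimizer.step()                      # allocates exp_avg / exp_avg_sq, leaves param unchanged
param.grad = None                     # drop the dummy gradient

print(sorted(optimizer.state[param].keys()))  # expect something like ['exp_avg', 'exp_avg_sq', 'step']
```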