1-bit adam #353

Merged: 22 commits, Sep 8, 2020
3 changes: 2 additions & 1 deletion deepspeed/runtime/config.py
@@ -14,9 +14,10 @@
from deepspeed.utils import logger

TENSOR_CORE_ALIGN_SIZE = 8
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
ADAM_OPTIMIZER = 'adam'
LAMB_OPTIMIZER = 'lamb'
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER]
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER]


def get_amp_enabled(param_dict):
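
Registering 'onebitadam' above is what lets a DeepSpeed config select the new optimizer. The sketch below is illustrative only: the exact "type" casing and the OnebitAdam parameter names (freeze_step, cuda_aware) are assumptions, not taken from this diff.

# Illustrative config sketch; parameter names are assumptions.
ds_config = {
    "train_batch_size": 32,
    "fp16": {"enabled": True},  # the engine wires OnebitAdam through the fp16 optimizer path
    "optimizer": {
        "type": "OneBitAdam",
        "params": {
            "lr": 2e-4,
            "freeze_step": 23000,  # assumed: warm-up steps of uncompressed Adam
            "cuda_aware": True     # assumed: picks the *_cuda over the *_host collectives
        }
    }
}
# The dict (or its JSON-file equivalent) is what deepspeed.initialize consumes.
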
154 changes: 154 additions & 0 deletions deepspeed/runtime/custom_collectives.py
@@ -0,0 +1,154 @@
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''

from mpi4py import MPI
import numpy as np
import cupy


def my_igather(rank, size, comm, sendbuf, recbuf, root):
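# A minimal non-blocking gather: the root posts an Irecv from every other rank
# into recbuf[idx] and copies its own contribution directly, while every other
# rank posts a single Isend to the root. Callers must wait on the returned
# request list (see MPI.Request.Waitall in gather_cuda / gather_host below).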
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(comm.Irecv(recbuf[idx], source=idx))
else:
recbuf[rank] = sendbuf
else:
req.append(comm.Isend(sendbuf, dest=root))
return req


def gather_cuda(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# We do in-place operations on cupy buffers so we do not return any buffers
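# Passing cupy (GPU) arrays straight to mpi4py assumes a CUDA-aware MPI build;
# gather_host below is the fallback that stages data through host memory.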
requests = []
for idx in range(world_size):
req_sign = my_igather(rank,
world_size,
comm,
cupy_sign_list_packed[idx],
cupy_recvbuf_sign,
root=idx)
requests += req_sign

for idx in range(world_size):
req_scale = my_igather(rank,
world_size,
comm,
cupy_worker_scale,
cupy_recvbuf_scale,
root=idx)
requests += req_scale

MPI.Request.Waitall(requests)


def gather_host(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
# In-place operations are not possible for newly created cupy arrays
# so we need to return the new buffers
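# Every cupy.asnumpy below is a device-to-host copy and every cupy.asarray a
# host-to-device copy, so this path adds transfers but works without a
# CUDA-aware MPI build.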
numpy_recvbuf_sign = np.zeros([world_size,
cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)

# 1. convert from cupy to numpy
numpy_sign_list_packed = cupy_sign_list_packed

for idx in range(world_size):
numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx])

numpy_worker_scale = cupy.asnumpy(cupy_worker_scale)
numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale)

cupy.cuda.get_current_stream().synchronize()

# 2. use numpy buffers for communication
requests = []

for idx in range(world_size):
req_sign = my_igather(rank,
world_size,
comm,
numpy_sign_list_packed[idx],
numpy_recvbuf_sign,
root=idx)
requests += req_sign

for idx in range(world_size):
req_scale = my_igather(rank,
world_size,
comm,
numpy_worker_scale,
numpy_recvbuf_scale,
root=idx)
requests += req_scale

MPI.Request.Waitall(requests)

# 3. Convert back from numpy to cupy
cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign)
for idx in range(world_size):
cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx])

cupy_worker_scale = cupy.asarray(numpy_worker_scale)
cupy_recvbuf_scale = cupy.asarray(numpy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()

return cupy_sign_list_packed, cupy_recvbuf_sign, cupy_worker_scale, cupy_recvbuf_scale


def allgather_cuda(comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)


def allgather_host(comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):

# 1. Convert cupy to numpy
numpy_recvbuf_sign_server = np.zeros([comm.Get_size(),
cupy_server_sign_packed.size],
dtype=cupy_server_sign_packed.dtype)
numpy_recvbuf_scale_server = np.zeros([comm.Get_size(),
1],
dtype=cupy_server_scale.dtype)

numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed)
numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
numpy_server_scale = cupy.asnumpy(cupy_server_scale)
numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()

# 2. Communicate numpy buffers
comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server)
comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server)
comm.Barrier()

# 3. Convert numpy back to cupy
cupy_server_sign_packed = cupy.asarray(numpy_server_sign_packed)
cupy_recvbuf_sign_server = cupy.asarray(numpy_recvbuf_sign_server)
cupy_server_scale = cupy.asarray(numpy_server_scale)
cupy_recvbuf_scale_server = cupy.asarray(numpy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()

return cupy_server_sign_packed, cupy_recvbuf_sign_server, cupy_server_scale, cupy_recvbuf_scale_server
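
For orientation, the sketch below (not code from this diff) shows how the primitives above are typically composed into the two-phase compressed allreduce behind 1-bit Adam: each worker compresses its update to packed sign bits plus a single scale, phase one gathers chunk i of every worker onto rank i via gather_cuda (or gather_host without CUDA-aware MPI), each rank averages and recompresses the chunk it serves, and phase two redistributes the averaged chunks via allgather_cuda / allgather_host. The function name, chunking, and averaging here are illustrative assumptions, and error compensation is omitted; the real implementation lives in deepspeed/runtime/fp16/onebit_adam.py (imported in engine.py below), which is not shown in this section.

# Illustrative sketch only; not part of this diff.
import cupy

from deepspeed.runtime.custom_collectives import gather_cuda, allgather_cuda


def compressed_allreduce_sketch(tensor, comm, rank, world_size):
    # 1. Compress locally: one scale (mean magnitude) plus packed sign bits.
    worker_scale = cupy.mean(cupy.absolute(tensor)).reshape(1)
    sign_packed = cupy.packbits(tensor > 0)
    # Each rank serves one equal chunk (assumes the length divides by world_size).
    sign_list_packed = cupy.split(sign_packed, world_size)

    recvbuf_sign = cupy.zeros([world_size, sign_list_packed[0].size],
                              dtype=sign_packed.dtype)
    recvbuf_scale = cupy.zeros([world_size, 1], dtype=worker_scale.dtype)

    # 2. Phase one: chunk `rank` from every worker lands on this rank.
    gather_cuda(rank, world_size, comm, sign_list_packed,
                recvbuf_sign, worker_scale, recvbuf_scale)

    # 3. Decompress, average, and recompress the chunk this rank serves.
    bits = cupy.unpackbits(recvbuf_sign.flatten()).astype(cupy.float32)
    chunks = bits.reshape(world_size, -1) * 2.0 - 1.0        # bits -> {-1, +1}
    server_chunk = cupy.mean(chunks * recvbuf_scale, axis=0)
    server_scale = cupy.mean(cupy.absolute(server_chunk)).reshape(1)
    server_sign_packed = cupy.packbits(server_chunk > 0)

    # 4. Phase two: every rank broadcasts its averaged chunk to all ranks.
    recvbuf_sign_server = cupy.zeros([world_size, server_sign_packed.size],
                                     dtype=server_sign_packed.dtype)
    recvbuf_scale_server = cupy.zeros([world_size, 1], dtype=server_scale.dtype)
    allgather_cuda(comm, server_sign_packed, recvbuf_sign_server,
                   server_scale, recvbuf_scale_server)

    # Reassemble the full result: chunk i of the output comes from rank i.
    bits = cupy.unpackbits(recvbuf_sign_server.flatten()).astype(cupy.float32)
    signs = bits.reshape(world_size, -1) * 2.0 - 1.0
    return (signs * recvbuf_scale_server).flatten()[:tensor.size]

# Typical call site (with mpi4py): comm = MPI.COMM_WORLD, rank = comm.Get_rank(),
# world_size = comm.Get_size(), and `tensor` is the flattened momentum buffer.
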
22 changes: 16 additions & 6 deletions deepspeed/runtime/engine.py
@@ -18,7 +18,8 @@
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_OPTIMIZERS

from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
@@ -27,8 +28,6 @@
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules

from deepspeed.ops.lamb import FusedLamb

from deepspeed.utils import logger
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer

@@ -122,6 +121,7 @@ def __init__(self,
self.config_params = config_params
self.loaded_checkpoint_mp_world_size = None
self.loaded_checkpoint_dp_world_size = None
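# Lets an optimizer take over gradient averaging: backward() below skips the
# engine's default allreduce when this flag is set to False.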
self.enable_backward_allreduce = True

if dist_init_required is None:
dist_init_required = not dist.is_initialized()
@@ -527,6 +527,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):

def _configure_basic_optimizer(self, model_parameters):
optimizer_parameters = self.optimizer_params()
# print(optimizer_parameters.keys())
if 'max_grad_norm' in optimizer_parameters.keys():
raise ValueError(
"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
@@ -535,7 +536,11 @@ def _configure_basic_optimizer(self, model_parameters):
from apex.optimizers.fused_adam import FusedAdam
optimizer = FusedAdam(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == LAMB_OPTIMIZER:
from deepspeed.ops.lamb import FusedLamb
optimizer = FusedLamb(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
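# OnebitAdam receives the engine handle (the second argument) so it can
# coordinate with the engine, e.g. by toggling enable_backward_allreduce and
# running its own compressed communication instead of the default allreduce.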
optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)
else:
torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
@@ -545,7 +550,8 @@ def _configure_fp16_optimizer(self, optimizer):
initial_dynamic_scale = self.initial_dynamic_scale()
dynamic_loss_args = self.dynamic_loss_scale_args()
clip_grad = self.gradient_clipping()
if self.optimizer_name() == ADAM_OPTIMIZER:
if self.optimizer_name() == ADAM_OPTIMIZER or self.optimizer_name(
) == ONEBIT_ADAM_OPTIMIZER:
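# 1-bit Adam goes through the same FP16_Optimizer wrapper path as fused Adam.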
if self.dynamic_loss_scale():
logger.info('Creating fp16 optimizer with dynamic loss scale')
timers = self.timers if self.wall_clock_breakdown() else None
@@ -734,7 +740,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)

def backward(self, loss, allreduce_gradients=True):
def backward(self, loss, allreduce_gradients=True, release_loss=False):
r"""Execute backward pass on the loss

Arguments:
@@ -796,7 +802,7 @@ def backward(self, loss, allreduce_gradients=True):
self.timers('backward_allreduce_microstep').start()
self.timers('backward_allreduce').start()

if allreduce_gradients:
if allreduce_gradients and self.enable_backward_allreduce:
self.allreduce_gradients()

if self.wall_clock_breakdown():
@@ -805,6 +811,10 @@
self.timers('backward').stop()
self.timers('backward_microstep').stop()

if release_loss:
# loss.data = None
pass
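# release_loss is currently a no-op placeholder; the commented-out line above
# shows the intended behavior.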

return loss

def is_gradient_accumulation_boundary(self):
18 changes: 18 additions & 0 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -101,6 +101,20 @@ def __init__(self,

self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)
self.initialize_optimizer_states()

def initialize_optimizer_states(self):
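# Run one optimizer step on all-zero gradients so the wrapped optimizer
# allocates its internal state (e.g. Adam's exp_avg / exp_avg_sq) up front;
# the dummy gradients are dropped again below.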
for i, group in enumerate(self.fp16_groups):
self.fp32_groups_flat[i].grad = torch.zeros(
self.fp32_groups_flat[i].size(),
device=self.fp32_groups_flat[i].device)

self.optimizer.step()

for i, group in enumerate(self.fp16_groups):
self.fp32_groups_flat[i].grad = None

return

def zero_grad(self, set_grads_to_None=True):
"""
@@ -204,6 +218,9 @@ def step(self, closure=None):
if p.grad is None else p.grad.to(data_type) for p in group
]))

for p in group:
p.grad = None
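# The fp16 gradients have already been flattened into grads_groups_flat[i],
# so releasing them here lowers peak memory.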

self.fp32_groups_flat[i].grad = grads_groups_flat[i]

self.start_timers([COMPUTE_NORM])
@@ -223,6 +240,7 @@
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
self.log_timers(OVERFLOW_TIMERS)
grads_groups_flat = None
return self.overflow

self.start_timers([UNSCALE_AND_CLIP])