1-bit adam #353
Merged — 22 commits, Sep 8, 2020
Changes from 11 commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deepspeed/runtime/config.py
@@ -14,9 +14,10 @@
from deepspeed.utils import logger

TENSOR_CORE_ALIGN_SIZE = 8
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
ADAM_OPTIMIZER = 'adam'
LAMB_OPTIMIZER = 'lamb'
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER]
DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER]


def get_amp_enabled(param_dict):
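
With ONEBIT_ADAM_OPTIMIZER added to DEEPSPEED_OPTIMIZERS, the optimizer can be selected by name in a DeepSpeed config. A minimal sketch of such a config, assuming the usual optimizer "type"/"params" layout and that the name is lower-cased to 'onebitadam' when matched; the parameter values are illustrative and mirror OnebitAdam's constructor arguments:

# Hypothetical config dict passed to deepspeed.initialize(..., config_params=ds_config).
ds_config = {
    "train_batch_size": 32,
    "fp16": {"enabled": True},
    "optimizer": {
        "type": "OneBitAdam",        # lower-cased to 'onebitadam' when matched
        "params": {
            "lr": 1e-4,
            "freeze_step": 1000,     # warmup steps before compressed communication starts
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 0.0
        }
    }
}
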
165 changes: 165 additions & 0 deletions deepspeed/runtime/custom_collectives.py
@@ -0,0 +1,165 @@
from mpi4py import MPI
import numpy as np
import cupy


def my_igather(rank, size, comm, sendbuf, recbuf, root):
req = []
if rank == root:
for idx in range(size):
if idx != rank:
req.append(comm.Irecv(recbuf[idx], source=idx))
else:
recbuf[rank] = sendbuf
else:
req.append(comm.Isend(sendbuf, dest=root))
return req


def gather_cuda(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
requests = []
for idx in range(world_size):
req_sign = my_igather(rank,
world_size,
comm,
cupy_sign_list_packed[idx],
cupy_recvbuf_sign,
root=idx)
requests += req_sign

for idx in range(world_size):
req_scale = my_igather(rank,
world_size,
comm,
cupy_worker_scale,
cupy_recvbuf_scale,
root=idx)
requests += req_scale

MPI.Request.Waitall(requests)


def gather_host(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
numpy_recvbuf_sign = np.zeros([world_size,
cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
numpy_recvbuf_scale = np.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)

# 1. convert from cupy to numpy
numpy_sign_list_packed = cupy_sign_list_packed

for idx in range(world_size):
numpy_sign_list_packed[idx] = cupy.asnumpy(cupy_sign_list_packed[idx])

numpy_worker_scale = cupy.asnumpy(cupy_worker_scale)
numpy_recvbuf_scale = cupy.asnumpy(cupy_recvbuf_scale)

cupy.cuda.get_current_stream().synchronize()

# 2. use numpy buffers for communication
requests = []

for idx in range(world_size):
req_sign = my_igather(rank,
world_size,
comm,
numpy_sign_list_packed[idx],
numpy_recvbuf_sign,
root=idx)
requests += req_sign

for idx in range(world_size):
req_scale = my_igather(rank,
world_size,
comm,
numpy_worker_scale,
numpy_recvbuf_scale,
root=idx)
requests += req_scale

MPI.Request.Waitall(requests)

# 3. Convert back from numpy to cupy
cupy_recvbuf_sign = cupy.asarray(numpy_recvbuf_sign)
for idx in range(world_size):
cupy_sign_list_packed[idx] = cupy.asarray(numpy_sign_list_packed[idx])

cupy_worker_scale = cupy.array(numpy_worker_scale)
cupy_recvbuf_scale = cupy.array(numpy_recvbuf_scale)
cupy.cuda.get_current_stream().synchronize()


def gather(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale):
cuda_aware = False
if cuda_aware:
gather_cuda(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
else:
gather_host(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)


def allgather(comm,
cupy_server_sign_packed,
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server):
cuda_aware = False
if cuda_aware:
comm.Allgather(cupy_server_sign_packed, cupy_recvbuf_sign_server)
comm.Allgather(cupy_server_scale, cupy_recvbuf_scale_server)
else:
# 1. Convert cupy to numpy
numpy_recvbuf_sign_server = np.zeros(
[comm.Get_size(),
cupy_server_sign_packed.size],
dtype=cupy_server_sign_packed.dtype)
numpy_recvbuf_scale_server = np.zeros([comm.Get_size(),
1],
dtype=cupy_server_scale.dtype)

numpy_server_sign_packed = cupy.asnumpy(cupy_server_sign_packed[0])
numpy_recvbuf_sign_server = cupy.asnumpy(cupy_recvbuf_sign_server)
numpy_server_scale = cupy.asnumpy(cupy_server_scale)
numpy_recvbuf_scale_server = cupy.asnumpy(cupy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()

# 2. Communicate numpy buffers
comm.Allgather(numpy_server_sign_packed, numpy_recvbuf_sign_server)
comm.Allgather(numpy_server_scale, numpy_recvbuf_scale_server)
comm.Barrier()

# 3. Convert numpy back to cupy
cupy_server_sign_packed = cupy.array(numpy_server_sign_packed)
cupy_recvbuf_sign_server = cupy.array(numpy_recvbuf_sign_server)
cupy_server_scale = cupy.array(numpy_server_scale)
cupy_recvbuf_scale_server = cupy.array(numpy_recvbuf_scale_server)
cupy.cuda.get_current_stream().synchronize()
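
These collectives only ever move packed sign bits plus one floating-point scale per worker. A single-process numpy sketch of that encode/decode round trip, standing in for the cupy packbits/unpackbits calls used above (names are illustrative):

import numpy as np

x = np.random.randn(64).astype(np.float32)

# encode: one scale plus 1 bit per element
scale = np.linalg.norm(x) / np.sqrt(x.size)
signs = (np.sign(x) + 1).astype(bool)    # mirrors buffer_m.sign().add_(1).bool()
packed = np.packbits(signs)              # 64 bits -> 8 bytes: what gather()/allgather() ship

# decode: unpack bits and map {0,1} -> {-1,+1}, then rescale
unpacked = np.unpackbits(packed)[:x.size].astype(np.float32)
x_hat = scale * (unpacked * 2.0 - 1.0)

print(np.linalg.norm(x - x_hat))         # compression error, absorbed by the error buffers
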
22 changes: 16 additions & 6 deletions deepspeed/runtime/engine.py
@@ -18,7 +18,8 @@
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_OPTIMIZERS

from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
@@ -27,8 +28,6 @@
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules

from deepspeed.ops.lamb import FusedLamb

from deepspeed.utils import logger
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer

@@ -122,6 +121,7 @@ def __init__(self,
self.config_params = config_params
self.loaded_checkpoint_mp_world_size = None
self.loaded_checkpoint_dp_world_size = None
self.enable_backward_allreduce = True

if dist_init_required is None:
dist_init_required = not dist.is_initialized()
@@ -527,6 +527,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):

def _configure_basic_optimizer(self, model_parameters):
optimizer_parameters = self.optimizer_params()
# print(optimizer_parameters.keys())
if 'max_grad_norm' in optimizer_parameters.keys():
raise ValueError(
"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
@@ -535,7 +536,11 @@ def _configure_basic_optimizer(self, model_parameters):
from apex.optimizers.fused_adam import FusedAdam
optimizer = FusedAdam(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == LAMB_OPTIMIZER:
from deepspeed.ops.lamb import FusedLamb
optimizer = FusedLamb(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam
optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters)
else:
torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
@@ -545,7 +550,8 @@ def _configure_fp16_optimizer(self, optimizer):
initial_dynamic_scale = self.initial_dynamic_scale()
dynamic_loss_args = self.dynamic_loss_scale_args()
clip_grad = self.gradient_clipping()
if self.optimizer_name() == ADAM_OPTIMIZER:
if self.optimizer_name() == ADAM_OPTIMIZER or self.optimizer_name(
) == ONEBIT_ADAM_OPTIMIZER:
if self.dynamic_loss_scale():
logger.info('Creating fp16 optimizer with dynamic loss scale')
timers = self.timers if self.wall_clock_breakdown() else None
@@ -734,7 +740,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)

def backward(self, loss, allreduce_gradients=True):
def backward(self, loss, allreduce_gradients=True, release_loss=False):
r"""Execute backward pass on the loss
Arguments:
@@ -796,7 +802,7 @@ def backward(self, loss, allreduce_gradients=True):
self.timers('backward_allreduce_microstep').start()
self.timers('backward_allreduce').start()

if allreduce_gradients:
if allreduce_gradients and self.enable_backward_allreduce:
self.allreduce_gradients()

if self.wall_clock_breakdown():
@@ -805,6 +811,10 @@ def backward(self, loss, allreduce_gradients=True):
self.timers('backward').stop()
self.timers('backward_microstep').stop()

if release_loss:
# loss.data = None
pass

return loss

def is_gradient_accumulation_boundary(self):
18 changes: 18 additions & 0 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -101,6 +101,20 @@ def __init__(self,

self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)
self.initialize_optimizer_states()

def initialize_optimizer_states(self):
for i, group in enumerate(self.fp16_groups):
self.fp32_groups_flat[i].grad = torch.zeros(
self.fp32_groups_flat[i].size(),
device=self.fp32_groups_flat[i].device)

self.optimizer.step()

for i, group in enumerate(self.fp16_groups):
self.fp32_groups_flat[i].grad = None

return

def zero_grad(self, set_grads_to_None=True):
"""
@@ -204,6 +218,9 @@ def step(self, closure=None):
if p.grad is None else p.grad.to(data_type) for p in group
]))

for p in group:
p.grad = None

self.fp32_groups_flat[i].grad = grads_groups_flat[i]

self.start_timers([COMPUTE_NORM])
@@ -223,6 +240,7 @@ def step(self, closure=None):
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
self.log_timers(OVERFLOW_TIMERS)
grads_groups_flat = None
return self.overflow

self.start_timers([UNSCALE_AND_CLIP])
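
initialize_optimizer_states materializes the wrapped optimizer's per-parameter state up front by stepping it once on zero gradients. A standalone sketch of the same trick with plain torch.optim.Adam (illustrative only):

import torch

p = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adam([p], lr=1e-3)
print(len(opt.state))          # 0: no per-parameter state yet

p.grad = torch.zeros_like(p)   # a zero gradient leaves the weights unchanged
opt.step()                     # but forces Adam to allocate its state buffers now
p.grad = None

print(opt.state[p].keys())     # dict_keys(['step', 'exp_avg', 'exp_avg_sq'])
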
356 changes: 356 additions & 0 deletions deepspeed/runtime/fp16/onebit_adam.py
@@ -0,0 +1,356 @@
'''
Copyright 2019 The Microsoft DeepSpeed Team
'''
import types
import torch
import importlib
import numpy as np
import time
import cupy
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack
from deepspeed.utils.logging import logger

from mpi4py import MPI
from deepspeed.runtime.custom_collectives import gather, allgather


class OnebitAdam(torch.optim.Optimizer):
"""Implements the 1-bit Adam algorithm. Currently GPU-only.
For a usage example, please see the DeepSpeed tutorial (TODO).
It has been proposed in APMSqueeze (https://arxiv.org/abs/2008.11343)
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
freeze_step (int, optional): Number of steps for warmup (uncompressed)
stage before we start using compressed communication. (default 100000)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
max_coeff(float, optional): maximum value of the lamb coefficient (default: 10.0)
min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in 1-bit Adam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self,
params,
deepspeed=None,
lr=1e-3,
freeze_step=100000,
bias_correction=True,
betas=(0.9,
0.999),
eps=1e-8,
eps_inside_sqrt=False,
weight_decay=0.,
max_grad_norm=0.,
amsgrad=False):

if amsgrad:
raise RuntimeError('1-bit Adam does not support the AMSGrad variant.')
defaults = dict(lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
max_grad_norm=max_grad_norm)

super(OnebitAdam, self).__init__(params, defaults)
from mpi4py import MPI
self.eps_mode = 0 if eps_inside_sqrt else 1

self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()
self.comm_time = 0.0
self.step_time = 0.0
self.ave_step = 1
self.bk_time = 0.0
self.divider = int(self.size * 8 / np.gcd(self.size, 8))
self.deepspeed = deepspeed
self.adam_freeze_key = False
self.initialize = False
self.freeze_step = freeze_step

def torch2cupy(self, tensor):
return cupy.fromDlpack(to_dlpack(tensor))

def cupy2torch(self, cupy_tensor):
return from_dlpack(cupy_tensor.toDlpack())

def compress_by_chunk(self, cupy_bool_tensor, num_chunks):
packed_sign = cupy.packbits(cupy_bool_tensor)
sign_list_packed = cupy.split(packed_sign, num_chunks)
cupy.cuda.get_current_stream().synchronize()
return sign_list_packed

def Compressed_Allreduce(self,
buffer_m: torch.tensor,
worker_error,
server_error,
rank,
world_size,
comm,
local_rank):

all_start_time = time.time()
original_size = buffer_m.numel()
cupy.cuda.Device(local_rank).use()

if torch.numel(buffer_m) != torch.numel(worker_error):
empty_tensor = torch.zeros(torch.numel(worker_error) - torch.numel(buffer_m),
device=buffer_m.device)
buffer_m = torch.cat([buffer_m, empty_tensor])

buffer_m.add_(worker_error)
worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
sign_buffer_m = buffer_m.sign().add_(1).bool()
sign_buffer_m = sign_buffer_m.float()
sign_buffer_m.add_(-0.5).mul_(2.0)
worker_error.set_((buffer_m - worker_scale * sign_buffer_m))
sign_buffer_m = None

compensated_buffer_m = buffer_m
compensated_buffer_m.sign_()
compensated_buffer_m = compensated_buffer_m.add_(1).bool()
cupy_worker_scale = self.torch2cupy(worker_scale)
cupy_compensated_buffer_m = self.torch2cupy(compensated_buffer_m)
compensated_buffer_m = None

cupy_sign_list_packed = self.compress_by_chunk(cupy_compensated_buffer_m,
world_size)
cupy_compensated_buffer_m = None

cupy_recvbuf_sign = cupy.zeros([world_size,
cupy_sign_list_packed[rank].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale = cupy.zeros([world_size, 1], dtype=cupy_worker_scale.dtype)

# Communication Phase 1
gather_start = time.time()
gather(rank,
world_size,
comm,
cupy_sign_list_packed,
cupy_recvbuf_sign,
cupy_worker_scale,
cupy_recvbuf_scale)
gather_end = time.time()

cupy_unpacked_sign = (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
world_size,
-1)
cupy_recvbuf_sign = None
unpacked_sign = self.cupy2torch(cupy_unpacked_sign).float()
cupy_unpacked_sign = None
unpacked_sign = unpacked_sign.add_(-0.5).mul_(2.0)
worker_scale = self.cupy2torch(cupy_recvbuf_scale).mul_(1 / world_size)
compensated_server_m = unpacked_sign.mul_(worker_scale).sum(0)
unpacked_sign = None

compensated_server_m.add_(server_error)
server_scale = torch.norm(compensated_server_m) / np.sqrt(
compensated_server_m.numel())
sign_server_m = compensated_server_m.sign().add_(1).bool()
sign_server_m = sign_server_m.float()
sign_server_m.add_(-0.5).mul_(2.0)
server_error.set_(compensated_server_m - server_scale * sign_server_m)
sign_server_m = None

compensated_server_m.sign_()
compensated_server_m = compensated_server_m.add_(1).bool()
cupy_server_scale = self.torch2cupy(server_scale)
cupy_compensated_server_m = self.torch2cupy(compensated_server_m)
compensated_server_m = None

cupy_server_sign_packed = self.compress_by_chunk(cupy_compensated_server_m, 1)

cupy_recvbuf_sign_server = cupy.zeros(
[world_size,
cupy_server_sign_packed[0].size],
dtype=cupy_sign_list_packed[0].dtype)
cupy_recvbuf_scale_server = cupy.zeros([world_size,
1],
dtype=cupy_worker_scale.dtype)

# Communication Phase 2
allgather(comm,
cupy_server_sign_packed[0],
cupy_recvbuf_sign_server,
cupy_server_scale,
cupy_recvbuf_scale_server)

cupy_server_unpacked_sign = (cupy.unpackbits(
cupy_recvbuf_sign_server.flatten())).reshape(world_size,
-1)
cupy_recvbuf_sign_server = None

server_unpacked_sign = self.cupy2torch(cupy_server_unpacked_sign)
cupy_server_unpacked_sign = None

server_unpacked_sign = server_unpacked_sign.float().add_(-0.5).mul_(2.0)
server_scale = self.cupy2torch(cupy_recvbuf_scale_server)
buffer_m = server_unpacked_sign.mul_(server_scale).flatten()[0:original_size]

return buffer_m

def step(self, closure=None, grads=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
grads (list of tensors, optional): weight gradient to use for the
optimizer update. If gradients have type torch.half, parameters
are expected to be in type torch.float. (default: None)
output params (list of tensors, optional): A reduced precision copy
of the updated weights written out in addition to the regular
updated weights. Have to be of same type as gradients. (default: None)
scale (float, optional): factor to divide gradient tensor values
by before applying to weights. (default: 1)
"""
loss = None
if closure is not None:
loss = closure()

gather_time = 0
allgather_time = 0
all_time = 0

if self.adam_freeze_key is False:
v_diff_buffer = 0.0

if grads is None:
grads_group = [None] * len(self.param_groups)
# backward compatibility
# assuming a list/generator of parameter means single group
elif isinstance(grads, types.GeneratorType):
grads_group = [grads]
elif type(grads[0]) != list:
grads_group = [grads]
else:
grads_group = grads

for group, grads_this_group in zip(self.param_groups, grads_group):
if grads_this_group is None:
grads_this_group = [None] * len(group['params'])

bias_correction = 1 if group['bias_correction'] else 0

for p, grad in zip(group['params'], grads_this_group):
if p.grad is None and grad is None:
continue
if grad is None:
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
'1-bit Adam does not support sparse gradients, please consider SparseAdam instead'
)

state = self.state[p]

# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)

state['tensor_size'] = torch.numel(p.data)
state['corrected_tensor_size'] = state['tensor_size']

if state['tensor_size'] % (self.size * self.divider) != 0:
state['corrected_tensor_size'] += ((self.size * self.divider) -
(state['tensor_size'] %
(self.size * self.divider)))
state['server_chunk_size'] = state[
'corrected_tensor_size'] // self.size

if not self.initialize or (self.adam_freeze_key
and 'worker_error' not in state.keys()):
if torch.distributed.get_rank() == 0:
print("Allocating worker_error and setting exp_avg_sq to half")
torch.cuda.empty_cache()
state['worker_error'] = torch.zeros(state['corrected_tensor_size'],
device=p.device)
state['server_error'] = torch.zeros(state['server_chunk_size'],
device=p.device)
torch.cuda.empty_cache()
self.adam_freeze_key = True
if not self.initialize and torch.distributed.get_rank() == 0:
print("Cupy Buffers Initialized")

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']

state['step'] += 1

if self.adam_freeze_key is False:
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
grad = None
if self.initialize:
update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])

else:
if 'non_freeze' in group.keys() and group['non_freeze'] is True:
torch.distributed.all_reduce(grad)
grad.mul_(1 / torch.distributed.get_world_size())
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
grad = None
else:
if self.initialize is True:
exp_avg.mul_(beta1).add_(1 - beta1, grad)
grad = None

if self.size > 1:
exp_avg.set_(
self.Compressed_Allreduce(exp_avg,
state['worker_error'],
state['server_error'],
self.rank,
self.size,
self.comm,
self.deepspeed.local_rank))
if self.initialize:
update = exp_avg / (exp_avg_sq.sqrt() + group['eps'])

if self.initialize:
if group['weight_decay'] > 0.0:
update += group['weight_decay'] * p.data
with torch.no_grad():
p.add_(-group['lr'] * update)

if not self.initialize:
print('Pop out errors', flush=True)
state.pop('worker_error')
state.pop('server_error')

if not self.initialize:
self.adam_freeze_key = False
self.initialize = True
print(
f"Finished the initialization step at rant {torch.distributed.get_rank()}"
)
return loss

if self.adam_freeze_key is False:
if state['step'] >= self.freeze_step:
self.adam_freeze_key = True
self.deepspeed.enable_backward_allreduce = False

return loss
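
The heart of Compressed_Allreduce above is error-feedback 1-bit quantization: each call folds the previous round's quantization error back into the buffer before compressing, and stores the new residual for the next round. A single-process torch sketch of that worker-side loop (no MPI or cupy; names and sizes are illustrative):

import torch

torch.manual_seed(0)
worker_error = torch.zeros(1024)          # persists across steps, like state['worker_error']

for step in range(3):
    buffer_m = torch.randn(1024)          # stands in for the momentum exp_avg
    buffer_m.add_(worker_error)           # error feedback from the previous round

    scale = buffer_m.norm() / (buffer_m.numel() ** 0.5)
    signs = buffer_m.sign()               # the 1-bit payload (plus one scale value)
    compressed = scale * signs

    worker_error = buffer_m - compressed  # residual carried to the next step
    print(step, worker_error.norm().item())
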
6 changes: 6 additions & 0 deletions deepspeed/runtime/utils.py
@@ -8,6 +8,7 @@

import torch
from torch._six import inf
import torch.distributed as dist

from deepspeed.utils import logger

@@ -32,6 +33,11 @@ def check_using_norm(self, norm_group):
op=torch.distributed.ReduceOp.MAX,
group=self.mpu.get_model_parallel_group())
overflow = overflow_gpu[0].item()
else:
cuda_overflow = torch.cuda.FloatTensor([overflow])
dist.all_reduce(cuda_overflow, op=torch.distributed.ReduceOp.MAX)
dist.barrier()
overflow = cuda_overflow[0].item()

return bool(overflow)

6 changes: 4 additions & 2 deletions deepspeed/utils/timer.py
@@ -145,7 +145,8 @@ def stop(self, report_speed=True):
torch.cuda.synchronize()
self.end_time = time.time()
duration = self.end_time - self.start_time
self.total_elapsed_time += duration
# self.total_elapsed_time += duration
self.total_elapsed_time = duration
if self.local_step_count % self.steps_per_output == 0:
if report_speed:
self.logging("{}/{}, SamplesPerSec={}".format(
@@ -165,7 +166,8 @@ def avg_samples_per_sec(self):
if self.total_step_count > 0:
samples_per_step = self.batch_size * self.num_workers
total_step_offset = self.total_step_count - self.start_step
avg_time_per_step = self.total_elapsed_time / total_step_offset
# avg_time_per_step = self.total_elapsed_time / total_step_offset
avg_time_per_step = self.total_elapsed_time / self.steps_per_output
# training samples per second
return samples_per_step / avg_time_per_step
return float("-inf")
1 change: 1 addition & 0 deletions requirements/requirements-1bit-adam.txt
@@ -0,0 +1 @@
mpi4py
5 changes: 5 additions & 0 deletions setup.py
@@ -27,6 +27,11 @@ def fetch_requirements(path):
dev_requires = fetch_requirements('requirements/requirements-dev.txt')
sparse_attn_requires = fetch_requirements('requirements/requirements-sparse-attn.txt')

onebit_adam_requires = fetch_requirements('requirements/requirements-1bit-adam.txt')
if torch.cuda.is_available():
onebit_adam_requires.append(f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}")
install_requires += onebit_adam_requires

# Build environment variables for custom builds
DS_BUILD_LAMB_MASK = 1
DS_BUILD_TRANSFORMER_MASK = 10
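
The cupy wheel name is built from the CUDA version string reported by torch. A quick worked example of that expression (version values are illustrative):

# e.g. torch.version.cuda == "10.1"  ->  "101"  ->  "cupy-cuda101"
# e.g. torch.version.cuda == "11.0"  ->  "110"  ->  "cupy-cuda110"
cuda_version = "10.1"
wheel = f"cupy-cuda{cuda_version.replace('.', '')[:3]}"
print(wheel)  # cupy-cuda101
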
74 changes: 74 additions & 0 deletions tests/onebitadam/test_com_reduce.py
@@ -0,0 +1,74 @@
from mpi4py import MPI
import time
import torch
import torch.distributed as dist
import numpy as np
import deepspeed
from deepspeed.runtime.fp16.onebit_adam import OnebitAdam

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

torch.distributed.init_process_group(backend='nccl',
init_method='tcp://worker-1:2345',
world_size=size,
rank=rank)

dummy_model = [torch.nn.Parameter(torch.ones(10))]
dummy_optim = OnebitAdam(dummy_model)

device = torch.device('cuda', rank % torch.cuda.device_count())


def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
#a_sign = a.sign()
scale = a.norm() / np.sqrt(a.numel())
a_compressed = scale * a_sign
a_sign = None
worker_error = a - a_compressed
dist.all_reduce(a_compressed)
a_compressed.mul_(1 / dist.get_world_size())
a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
#a_server_sign = a_compressed.sign()
a_list = torch.chunk(a_compressed, chunks=dist.get_world_size())
server_scale = [chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list]
a_sign_list = torch.chunk(a_server_sign, dist.get_world_size())
a_server_compressed = torch.cat(
[server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())])
rank = dist.get_rank()
server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
return a_server_compressed, worker_error, server_error


tensor_size = 200 * 2**20
server_size = int(tensor_size / size)

# a = -torch.ones(tensor_size, device=device)
a = torch.randn(tensor_size, device=device)
#if rank == 0:
# print('a is: ',a)
worker_error = torch.zeros_like(a)
server_error = torch.zeros(server_size, device=device)
a_torch, worker_error_torch, server_error_torch = torch_sim(a)
torch.cuda.empty_cache()
local_rank = rank % torch.cuda.device_count()
a_after = dummy_optim.Compressed_Allreduce(a,
worker_error,
server_error,
rank,
size,
comm,
local_rank)
#print('a becomes ',a)
#if rank == 0:
if True:
print('Rank is {} =============================================='.format(rank))
print('Diff is: ', torch.norm(a_after - a_torch))
#print('Original Norm is: ', torch.norm(a_after))
#print('Compressed_addreduce gives: ', a_after[0:10])
print('Worker error diff is: ', torch.norm(worker_error - worker_error_torch))
print('Server error diff is: ', torch.norm(server_error - server_error_torch))
print('+++++++++++++++++++++++++++++++')
#print('torch sim gives: ', a_torch[0:10])