Add CPUAdam optimizer for zero-offload in deepspeed engine (microsoft#484)

* Add AdamW to the CPU-Adam implementation

* Support the CPU-Adam optimizer for ZeRO-Offload on the DeepSpeed side

* Bump DSE to match CPU-Adam updates

Co-authored-by: Jeff Rasley <[email protected]>
RezaYazdaniAminabadi and jeffra authored Oct 30, 2020
1 parent d720fdb commit f5aa254
Showing 11 changed files with 164 additions and 40 deletions.
58 changes: 43 additions & 15 deletions csrc/adam/cpu_adam.cpp
@@ -32,7 +32,7 @@ void Adam_Optimizer::Step(float* _params,
float bias_correction2 = 1 / sqrt(1 - _betta2_t);

float step_size = -1 * _alpha / bias_correction1;

float w_decay = -1 * _alpha * _weight_decay;
size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)
@@ -57,8 +57,8 @@ void Adam_Optimizer::Step(float* _params,
step_size_4.data = SIMD_SET(step_size);

AVX_Data weight_decay4;
if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay);

if (_weight_decay > 0)
weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);

for (size_t t = 0; t < rounded_size; t += TILE) {
@@ -78,9 +78,9 @@ void Adam_Optimizer::Step(float* _params,
AVX_Data param_4;
param_4.data = SIMD_LOAD(_params + i);

if (_weight_decay > 0)
if (_weight_decay > 0 && !_adamw_mode) {
grad_4.data = SIMD_FMA(param_4.data, weight_decay4.data, grad_4.data);

}
momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);

@@ -91,7 +91,9 @@ void Adam_Optimizer::Step(float* _params,
grad_4.data = SIMD_SQRT(variance_4.data);
grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);

if (_weight_decay > 0 && _adamw_mode) {
param_4.data = SIMD_FMA(param_4.data, weight_decay4.data, param_4.data);
}
param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);

SIMD_STORE(_params + i, param_4.data);
@@ -119,8 +121,7 @@ void Adam_Optimizer::Step(float* _params,
float param = _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0) grad = param * _weight_decay + grad;

if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;

@@ -131,7 +132,7 @@ void Adam_Optimizer::Step(float* _params,
grad = sqrt(variance);
grad = grad * bias_correction2 + _eps;
grad = momentum / grad;

if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
if (dev_params) _doubled_buffer[_buf_index][k - rounded_size] = (__half)param;

@@ -184,6 +185,10 @@ void Adam_Optimizer::Step_4(float* _params,
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);

float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay4;
if (_weight_decay > 0)
weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 2));

for (size_t t = 0; t < rounded_size; t += TILE) {
@@ -216,9 +221,7 @@ void Adam_Optimizer::Step_4(float* _params,
param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);

if (_weight_decay > 0) {
AVX_Data weight_decay4;
weight_decay4.data = SIMD_SET(_weight_decay);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
@@ -261,6 +264,13 @@ void Adam_Optimizer::Step_4(float* _params,
grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);

if (_weight_decay > 0 && _adamw_mode) {
param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data);
}

param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
@@ -313,9 +323,11 @@ int create_adam_optimizer(int optimizer_id,
float betta1 = 0.9,
float betta2 = 0.999,
float eps = 1e-8,
float weight_decay = 0)
float weight_decay = 0,
bool adamw_mode = true)
{
auto opt = std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay);
auto opt =
std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay, adamw_mode);

s_optimizers[optimizer_id] = opt;
#if defined(__AVX512__)
@@ -369,6 +381,11 @@ void Adam_Optimizer::Step_8(float* _params,
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);

float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay4;
if (_weight_decay > 0)
weight_decay4.data =
(_adamw_mode ? _mm512_set1_ps(w_decay) : _mm512_set1_ps(_weight_decay));
rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 3));

for (size_t t = 0; t < rounded_size; t += TILE) {
@@ -417,7 +434,7 @@ void Adam_Optimizer::Step_8(float* _params,
param_4[6].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 6);
param_4[7].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 7);

if (_weight_decay > 0) {
if (_weight_decay > 0 && !_adamw_mode) {
AVX_Data weight_decay4;
weight_decay4.data = SIMD_SET(_weight_decay);
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
@@ -498,6 +515,17 @@ void Adam_Optimizer::Step_8(float* _params,
grad_4[6].data = SIMD_DIV(momentum_4[6].data, grad_4[6].data);
grad_4[7].data = SIMD_DIV(momentum_4[7].data, grad_4[7].data);

if (_weight_decay > 0 && _adamw_mode) {
param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data);
param_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, param_4[4].data);
param_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, param_4[5].data);
param_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, param_4[6].data);
param_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, param_4[7].data);
}

param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
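To make the Adam-vs-AdamW distinction introduced by the new _adamw_mode flag easier to follow, here is a minimal Python sketch of the scalar (non-SIMD) update path above. It illustrates the math only; the standalone function and its names are not part of the commit:

```python
import math

def adam_scalar_step(param, grad, m, v, step,
                     lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                     weight_decay=0.0, adamw_mode=True):
    """One scalar Adam/AdamW update mirroring the non-SIMD path in Step()."""
    if weight_decay > 0 and not adamw_mode:
        # Classic Adam: fold the L2 penalty into the gradient.
        grad = param * weight_decay + grad

    m = beta1 * m + (1 - beta1) * grad          # first moment
    v = beta2 * v + (1 - beta2) * grad * grad   # second moment

    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = math.sqrt(v) / math.sqrt(bias_correction2) + eps
    update = m / (bias_correction1 * denom)

    if weight_decay > 0 and adamw_mode:
        # AdamW: decoupled weight decay applied directly to the parameter,
        # matching `param += w_decay * param` with w_decay = -lr * weight_decay.
        param -= lr * weight_decay * param

    param -= lr * update
    return param, m, v
```

The SIMD paths (Step, Step_4, Step_8) vectorize this same logic with AVX256/AVX512 intrinsics.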
7 changes: 5 additions & 2 deletions csrc/includes/cpu_adam.h
@@ -50,15 +50,17 @@ class Adam_Optimizer {
float betta1 = 0.9,
float betta2 = 0.999,
float eps = 1e-8,
float weight_decay = 0)
float weight_decay = 0,
bool adamw_mode = true)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_buf_index(false)
_buf_index(false),
_adamw_mode(adamw_mode)
{
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
@@ -115,4 +117,5 @@ class Adam_Optimizer {

float* _doubled_buffer[2];
bool _buf_index;
bool _adamw_mode;
};
2 changes: 1 addition & 1 deletion deepspeed/__init__.py
@@ -7,7 +7,7 @@
from . import ops

from .runtime.engine import DeepSpeedEngine
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_ADAM
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.pipe.engine import PipelineEngine
from .runtime.lr_schedules import add_tuning_arguments
from .runtime.config import DeepSpeedConfig
49 changes: 47 additions & 2 deletions deepspeed/ops/adam/cpu_adam.py
@@ -1,3 +1,7 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

import math
import torch
import importlib
@@ -6,6 +10,40 @@


class DeepSpeedCPUAdam(torch.optim.Optimizer):
"""Fast vectorized implementation of two variations of Adam optimizer on CPU:
- Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980);
- AdamW: FIXING WEIGHT DECAY REGULARIZATION IN ADAM (https://arxiv.org/abs/1711.05101v1)
DeepSpeed CPU Adam(W) provides between 5x to 7x speedu over torch.optim.adam(W).
In order to apply this optimizer, the model requires to have its master parameter (in FP32)
reside on the CPU memory.
To train on a hetrogeneous system, such as coordinating CPU and GPU, DeepSpeed offers
the ZeRO-Offload technology which efficiently offloads the optimizer states into CPU memory,
with minimal impact on training througput. DeepSpeedCPUAdam plays an important role to minimize
the overhead of the optimizer's latency on CPU. Please refer to ZeRO-Offload tutorial
(https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology.
For calling step function, there are two options available: (1) update optimizer's states and (2) update
optimizer's states and copy the parameters back to GPU at the same time. We have seen that the second
option can bring 30% higher throughput than the doing the copy separately using option one.
Arguments:
model_params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in DeepSpeed CPUAdam!
adamw_mode: select between Adam and AdamW implementations (default: AdamW)
"""

optimizer_id = 0

@@ -16,7 +54,8 @@ def __init__(self,
0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False):
amsgrad=False,
adamw_mode=True):

default_args = dict(lr=lr,
betas=betas,
@@ -30,7 +69,13 @@ def __init__(self,

global ds_opt_adam
ds_opt_adam = importlib.import_module('deepspeed.ops.adam.cpu_adam_op')
ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay)
ds_opt_adam.create_adam(self.opt_id,
lr,
betas[0],
betas[1],
eps,
weight_decay,
adamw_mode)

def __setstate__(self, state):
super(DeepSpeedCPUAdam, self).__setstate__(state)
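As a rough usage sketch of the new optimizer (assuming the cpu_adam extension is built; the linear model and data below are placeholders, not from the commit):

```python
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

# Master (FP32) parameters stay in CPU memory for ZeRO-Offload.
model = torch.nn.Linear(1024, 1024)

optimizer = DeepSpeedCPUAdam(model.parameters(),
                             lr=1e-3,
                             betas=(0.9, 0.999),
                             eps=1e-8,
                             weight_decay=0.01,
                             adamw_mode=True)  # AdamW-style decoupled decay

x = torch.randn(8, 1024)
loss = model(x).pow(2).mean()
loss.backward()

optimizer.step()       # option (1): update optimizer states only
optimizer.zero_grad()
```

When driven by the ZeRO stage-2 optimizer further below, step() is instead called with fp16_param_groups= so the updated parameters are copied back to the GPU in the same pass (option (2)).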
7 changes: 5 additions & 2 deletions deepspeed/runtime/config.py
@@ -15,17 +15,20 @@
from deepspeed.utils import logger

TENSOR_CORE_ALIGN_SIZE = 8

ADAM_OPTIMIZER = 'adam'
LAMB_OPTIMIZER = 'lamb'
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
DEEPSPEED_ADAM = 'deepspeed_adam'
DEEPSPEED_OPTIMIZERS = [
ADAM_OPTIMIZER,
LAMB_OPTIMIZER,
ONEBIT_ADAM_OPTIMIZER,
DEEPSPEED_ADAM
]

# extra optimizer parameters for adam
TORCH_ADAM_PARAM = "torch_adam"
ADAM_W_MODE_PARAM = "adam_w_mode"


def get_amp_enabled(param_dict):
if AMP in param_dict.keys():
31 changes: 24 additions & 7 deletions deepspeed/runtime/engine.py
@@ -19,8 +19,10 @@
from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_ADAM, DEEPSPEED_OPTIMIZERS
from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, \
TORCH_ADAM_PARAM, ADAM_W_MODE_PARAM

from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
@@ -548,15 +550,30 @@ def _configure_basic_optimizer(self, model_parameters):
raise ValueError(
"'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details"
)

if self.optimizer_name() == ADAM_OPTIMIZER:
if self.zero_cpu_offload():
torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)
adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE_PARAM, True)

# zero-offload  torch-adam  adam_w_mode  optimizer
# T|F           T           T            torch.optim.AdamW
# T|F           T           F            torch.optim.Adam
# T             F           T|F          DeepSpeedCPUAdam(adam_w_mode)
# F             F           T|F          FusedAdam(adam_w_mode)
if torch_adam and adam_w_mode:
optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters)
elif torch_adam and not adam_w_mode:
optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
else:
elif self.zero_cpu_offload() and not torch_adam:
from deepspeed.ops.adam import DeepSpeedCPUAdam
optimizer = DeepSpeedCPUAdam(model_parameters,
**optimizer_parameters,
adamw_mode=adam_w_mode)
elif not self.zero_cpu_offload() and not torch_adam:
from apex.optimizers.fused_adam import FusedAdam
optimizer_parameters[ADAM_W_MODE_PARAM] = adam_w_mode
optimizer = FusedAdam(model_parameters, **optimizer_parameters)
elif self.optimizer_name() == DEEPSPEED_ADAM:
from deepspeed.ops.adam import DeepSpeedCPUAdam
optimizer = DeepSpeedCPUAdam(model_parameters, **optimizer_parameters)

elif self.optimizer_name() == LAMB_OPTIMIZER:
from deepspeed.ops.lamb import FusedLamb
optimizer = FusedLamb(model_parameters, **optimizer_parameters)
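Following the routing table above, a DeepSpeed config along these lines (values illustrative, not from the commit) would route to DeepSpeedCPUAdam with AdamW-style decay when ZeRO-Offload is enabled:

```json
{
  "train_batch_size": 8,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.001,
      "weight_decay": 0.01,
      "torch_adam": false,
      "adam_w_mode": true
    }
  },
  "zero_optimization": {
    "stage": 2,
    "cpu_offload": true
  }
}
```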
14 changes: 8 additions & 6 deletions deepspeed/runtime/zero/stage2.py
@@ -157,8 +157,7 @@ def __init__(self,

self.cpu_offload = cpu_offload

self.deepspeed_adam_offload = (cpu_offload
and type(init_optimizer) == DeepSpeedCPUAdam)
self.deepspeed_adam_offload = cpu_offload

self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'

@@ -1416,10 +1415,13 @@ def step(self, closure=None):
timers('optimizer_step').start()
if self.deepspeed_adam_offload:
from deepspeed.ops.adam import DeepSpeedCPUAdam
self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
#self.optimizer.step()
#for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
# fp16_partitions[partition_id].data.copy_(fp32_partition.data)
if type(self.optimizer) == DeepSpeedCPUAdam:
self.optimizer.step(
fp16_param_groups=self.parallel_partitioned_fp16_groups)
else:
self.optimizer.step()
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp16_partitions[partition_id].data.copy_(fp32_partition.data)
else:
self.optimizer.step()

10 changes: 8 additions & 2 deletions docs/_pages/config-json.md
@@ -34,10 +34,10 @@ title: "DeepSpeed Configuration JSON"

| Fields | Value | Example |
| ------ | ------------------------------------------------------------ | ------------------------------ |
| type | The optimizer name. DeepSpeed natively supports **Adam**, **DeepSpeedAdam**, **OneBitAdam**, and **LAMB** optimizers and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` |
| type | The optimizer name. DeepSpeed natively supports **Adam**, **OneBitAdam**, and **Lamb** optimizers and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` |
| params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` |

Example of ***optimizer***
Example of ***optimizer*** with Adam

```json
"optimizer": {
@@ -53,6 +53,12 @@ title: "DeepSpeed Configuration JSON"
}
}
```
The Adam optimizer also supports the following two `params` keys/values in addition to the standard parameters from [torch.optim.Adam](https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam):

| "params" key | Description | Default |
| ------------- | --------------------------------------------------------------------------- | --------|
| torch\_adam | Use torch's implementation of Adam instead of the fused Adam implementation | false |
| adam\_w\_mode | Apply decoupled weight decay as in AdamW, rather than standard L2 regularization | true |
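For example, a sketch of an ***optimizer*** block using these keys (values are illustrative):

```json
"optimizer": {
  "type": "Adam",
  "params": {
    "lr": 0.001,
    "betas": [0.9, 0.999],
    "eps": 1e-8,
    "weight_decay": 0.01,
    "torch_adam": false,
    "adam_w_mode": true
  }
}
```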

Another example of ***optimizer*** with 1-bit Adam specific parameters is as follows.

```json