tracking optimizer step in cpu-adam when loading checkpoint (microsoft#564)

* tracking optimizer step in cpu-adam when loading checkpoint

* add warning/error message for updating optimizer step count

* resolve build issue

* supporting state update from the python side

* track step from python in all cases

* remove comma
RezaYazdaniAminabadi authored Dec 1, 2020
1 parent c78c29f commit 9f52a36
Showing 3 changed files with 65 additions and 26 deletions.
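
For context before the per-file hunks: Adam scales its step size by the bias-correction term 1 - beta1^t, so the effective learning rate depends on the step count t. The pre-existing C++ check (the "Optimizer lost track of step count!" exception visible in the cpu_adam.h hunk below) assumed the internal counter always advances by exactly one, which breaks when the Python side resumes a step count loaded from a checkpoint. A rough back-of-the-envelope illustration in Python (values are illustrative, not taken from DeepSpeed):

    # If the extension restarted its counter at 1 while training actually
    # resumed at step 1000, the bias-corrected step size would be ~10x too large.
    beta1, lr = 0.9, 1e-3
    for step in (1, 10, 1000):
        bias_correction1 = 1 - beta1 ** step
        print(step, lr / bias_correction1)
    # 1    -> 0.0100...   (10x the base lr)
    # 10   -> 0.0015...
    # 1000 -> 0.0010...   (the correction has decayed away)
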
41 changes: 21 additions & 20 deletions csrc/adam/cpu_adam.cpp
@@ -28,10 +28,7 @@ void Adam_Optimizer::Step(float* _params,
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;

float bias_correction1 = 1 - _betta1_t;
float bias_correction2 = 1 / sqrt(1 - _betta2_t);

float step_size = -1 * _alpha / bias_correction1;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
size_t rounded_size = 0;

@@ -48,7 +45,7 @@ void Adam_Optimizer::Step(float* _params,
betta2_minus1_4.data = SIMD_SET(betta2_minus1);

AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(bias_correction2);
bias2_sqrt.data = SIMD_SET(_bias_correction2);

AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
@@ -130,7 +127,7 @@ void Adam_Optimizer::Step(float* _params,
variance = grad * betta2_minus1 + variance;

grad = sqrt(variance);
grad = grad * bias_correction2 + _eps;
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
@@ -172,16 +169,13 @@ void Adam_Optimizer::Step_4(float* _params,
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);

float bias_correction1 = 1 - _betta1_t;
float bias_correction2 = 1 / sqrt(1 - _betta2_t);
// AVX_Data bias_correction1_4 = SIMD_SET(bias_correction1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(bias_correction2);
bias2_sqrt.data = SIMD_SET(_bias_correction2);

AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);

float step_size = -1 * _alpha / bias_correction1;
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);

@@ -386,16 +380,13 @@ void Adam_Optimizer::Step_8(float* _params,
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);

float bias_correction1 = 1 - _betta1_t;
float bias_correction2 = 1 / sqrt(1 - _betta2_t);
// AVX_Data bias_correction1_4 = SIMD_SET(bias_correction1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(bias_correction2);
bias2_sqrt.data = SIMD_SET(_bias_correction2);

AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);

float step_size = -1 * _alpha / bias_correction1;
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);

@@ -611,6 +602,11 @@ void Adam_Optimizer::Step_8(float* _params,
int ds_adam_step(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float epsilon,
float weight_decay,
bool bias_correction,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
@@ -628,8 +624,8 @@ int ds_adam_step(int optimizer_id,

std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step);
opt->update_lr(lr);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));

return 0;
@@ -638,6 +634,11 @@ int ds_adam_step(int optimizer_id,
int ds_adam_step_plus_copy(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float epsilon,
float weight_decay,
bool bias_correction,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
@@ -658,8 +659,8 @@ int ds_adam_step_plus_copy(int optimizer_id,

std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step);
opt->update_lr(lr);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(
params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);

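In the hunks above, the Step/Step_4/Step_8 paths stop deriving bias_correction1/bias_correction2 locally from _betta1_t/_betta2_t and instead read the cached members _bias_correction1/_bias_correction2 (set once per call by update_state, shown in the header diff below). As a reference for what one scalar parameter update computes, here is a hedged Python rendering; the function name is illustrative, the momentum/variance accumulation lines follow standard Adam, and the step_size, bias-correction, and weight-decay expressions mirror the lines visible in the hunk:

    import math

    def cpu_adam_scalar_update(param, grad, momentum, variance, step,
                               lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                               weight_decay=0.0, adamw_mode=True,
                               bias_correction=True):
        # Cached once per optimizer step in the C++ code (update_state):
        bc1 = 1 - beta1 ** step if bias_correction else 1.0
        bc2 = 1 / math.sqrt(1 - beta2 ** step) if bias_correction else 1.0

        step_size = -lr / bc1              # matches: -1 * _alpha / _bias_correction1
        w_decay = -lr * weight_decay       # matches: -1 * _alpha * _weight_decay

        # Standard Adam first/second-moment accumulation.
        momentum = momentum * beta1 + grad * (1 - beta1)
        variance = variance * beta2 + grad * grad * (1 - beta2)

        # Matches: grad = sqrt(variance); grad = grad * _bias_correction2 + _eps;
        #          grad = momentum / grad; param = grad * step_size + param;
        denom = math.sqrt(variance) * bc2 + eps
        update = momentum / denom
        if weight_decay > 0 and adamw_mode:
            param += w_decay * param       # decoupled (AdamW-style) weight decay
        param += update * step_size
        return param, momentum, variance
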
37 changes: 31 additions & 6 deletions csrc/includes/cpu_adam.h
@@ -89,18 +89,40 @@ class Adam_Optimizer {
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params = nullptr);
inline void IncrementStep(size_t step)

inline void IncrementStep(size_t step, float beta1, float beta2)
{
if (_step < step) {
if (beta1 != _betta1 || beta2 != _betta2) {
_step = step;
_betta1 = beta1;
_betta2 = beta2;
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
} else {
_step++;
if (_step != step) {
throw std::runtime_error("Optimizer lost track of step count!\n");
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
_step = step;
} else {
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
}
inline void update_lr(float lr) { _alpha = lr; }
inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
{
_alpha = lr;
_eps = epsilon;
_weight_decay = weight_decay;

_bias_correction1 = 1.0f;
_bias_correction2 = 1.0f;
if (bias_correction == 1) {
_bias_correction1 = 1 - _betta1_t;
_bias_correction2 = 1 / sqrt(1 - _betta2_t);
}
}

private:
#if defined(__AVX512__) or defined(__AVX256__)
@@ -124,6 +146,9 @@
float _betta2_t;
size_t _step;

float _bias_correction1;
float _bias_correction2;

float* _doubled_buffer[2];
bool _buf_index;
bool _adamw_mode;
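
The header change above replaces update_lr() with update_state() and widens IncrementStep() to take the betas, so the extension can resynchronize with a step count handed down from Python (for example after load_state_dict) instead of throwing. The flattened diff makes the exact branch structure hard to read, so the following Python sketch is only an approximation of the visible logic, with illustrative names: out-of-sequence steps or changed betas trigger a pow()-based recompute of beta^t, in-sequence steps keep the cheap incremental multiply, and the bias-correction factors are cached afterwards from the updated powers (collapsing to 1.0 when bias correction is off).

    import math

    class CpuAdamStateSketch:
        def __init__(self, beta1=0.9, beta2=0.999):
            self.step = 0
            self.beta1, self.beta2 = beta1, beta2
            self.beta1_t, self.beta2_t = 1.0, 1.0
            self.lr = self.eps = self.weight_decay = 0.0
            self.bias_correction1 = self.bias_correction2 = 1.0

        def increment_step(self, step, beta1, beta2):
            if beta1 != self.beta1 or beta2 != self.beta2 or step != self.step + 1:
                # Step count or betas were set externally (checkpoint load,
                # changed hyperparameters): recompute the powers directly.
                self.beta1, self.beta2 = beta1, beta2
                self.beta1_t = beta1 ** step
                self.beta2_t = beta2 ** step
            else:
                # Contiguous step: incremental update, as before.
                self.beta1_t *= beta1
                self.beta2_t *= beta2
            self.step = step

        def update_state(self, lr, eps, weight_decay, bias_correction):
            # Called after increment_step(), because the cached corrections
            # are derived from the freshly updated beta powers.
            self.lr, self.eps, self.weight_decay = lr, eps, weight_decay
            self.bias_correction1 = 1 - self.beta1_t if bias_correction else 1.0
            self.bias_correction2 = (1 / math.sqrt(1 - self.beta2_t)
                                     if bias_correction else 1.0)
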
13 changes: 13 additions & 0 deletions deepspeed/ops/adam/cpu_adam.py
@@ -50,6 +50,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
def __init__(self,
model_params,
lr=1e-3,
bias_correction=True,
betas=(0.9,
0.999),
eps=1e-8,
@@ -61,6 +62,7 @@ def __init__(self,
betas=betas,
eps=eps,
weight_decay=weight_decay,
bias_correction=bias_correction,
amsgrad=amsgrad)
super(DeepSpeedCPUAdam, self).__init__(model_params, default_args)

@@ -112,12 +114,18 @@ def step(self, closure=None, fp16_param_groups=None):
#memory_format=torch.preserve_format)

state['step'] += 1
beta1, beta2 = group['betas']

if fp16_param_groups is not None:
self.ds_opt_adam.adam_update_copy(
self.opt_id,
state['step'],
group['lr'],
beta1,
beta2,
group['eps'],
group['weight_decay'],
group['bias_correction'],
p.data,
p.grad.data,
state['exp_avg'],
@@ -127,6 +135,11 @@
self.ds_opt_adam.adam_update(self.opt_id,
state['step'],
group['lr'],
beta1,
beta2,
group['eps'],
group['weight_decay'],
group['bias_correction'],
p.data,
p.grad.data,
state['exp_avg'],
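
From the Python side, the only user-visible additions are the bias_correction keyword and the extra arguments step() forwards to the extension. A minimal usage sketch follows, assuming the package exposes DeepSpeedCPUAdam at deepspeed.ops.adam (per the file path above) and the standard torch.optim checkpointing flow; model size and file name are illustrative:

    import torch
    from deepspeed.ops.adam import DeepSpeedCPUAdam

    model = torch.nn.Linear(16, 16)
    optimizer = DeepSpeedCPUAdam(model.parameters(),
                                 lr=1e-3,
                                 betas=(0.9, 0.999),
                                 eps=1e-8,
                                 weight_decay=0.0,
                                 bias_correction=True)  # keyword added in this commit

    # ... train for a while, then checkpoint with the usual Optimizer API ...
    torch.save({'model': model.state_dict(), 'optim': optimizer.state_dict()},
               'ckpt.pt')

    # On resume, each parameter's 'step' comes back via load_state_dict(), and
    # step() now forwards it (plus betas, eps, weight_decay, bias_correction)
    # to the C++ kernel, which resynchronizes its internal counter instead of
    # raising "Optimizer lost track of step count!".
    ckpt = torch.load('ckpt.pt')
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optim'])
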
