Fixing CPU-Adam convergence issue (microsoft#503)
* fixing cpu-adam

* fixing copy with optimizer for data and model parallelism

* fixing cpu-adam

* fix cpu-adam

* fix cpu-adam
RezaYazdaniAminabadi authored Nov 5, 2020
1 parent 4c37d70 commit 7d4d742
Showing 4 changed files with 64 additions and 24 deletions.
32 changes: 27 additions & 5 deletions csrc/adam/cpu_adam.cpp
100644 → 100755
@@ -122,7 +122,7 @@ void Adam_Optimizer::Step(float* _params,
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
- momentum *= momentum * _betta1;
+ momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;

variance = variance * _betta2;
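The one-line change above is the heart of the convergence fix: the old statement momentum *= momentum * _betta1 expands to momentum = momentum * (momentum * _betta1), i.e. betta1 * momentum^2, so the first moment was squared on every update instead of exponentially decayed. For reference, a minimal scalar Python sketch of the corrected Adam/AdamW step (textbook formulation with illustrative names, not the extension's exact arithmetic):

    import math

    def adam_step(param, grad, m, v, step, lr,
                  beta1=0.9, beta2=0.999, eps=1e-8,
                  weight_decay=0.0, adamw_mode=True):
        # One scalar Adam/AdamW update, mirroring the corrected loop above.
        if weight_decay > 0 and not adamw_mode:
            grad = param * weight_decay + grad        # classic L2 path, as in Step()
        m = m * beta1 + (1 - beta1) * grad            # decay, then accumulate (the fixed line)
        v = v * beta2 + (1 - beta2) * grad * grad
        m_hat = m / (1 - beta1 ** step)               # bias correction, step is 1-based
        v_hat = v / (1 - beta2 ** step)
        if weight_decay > 0 and adamw_mode:
            param -= lr * weight_decay * param        # decoupled (AdamW) weight decay
        param -= lr * m_hat / (math.sqrt(v_hat) + eps)
        return param, m, v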
@@ -333,13 +333,31 @@ int create_adam_optimizer(int optimizer_id,
#if defined(__AVX512__)
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with AVX512 arithmetic capability." << std::endl;
printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
alpha,
betta1,
betta2,
weight_decay,
(int)adamw_mode);
#else
#if defined(__AVX256__)
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with AVX2 arithmetic capability." << std::endl;
printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
alpha,
betta1,
betta2,
weight_decay,
(int)adamw_mode);
#else
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with scalar arithmetic capability." << std::endl;
printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
alpha,
betta1,
betta2,
weight_decay,
(int)adamw_mode);
#endif
#endif
return 0;
@@ -434,8 +452,6 @@ void Adam_Optimizer::Step_8(float* _params,
param_4[7].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 7);

if (_weight_decay > 0 && !_adamw_mode) {
- AVX_Data weight_decay4;
- weight_decay4.data = SIMD_SET(_weight_decay);
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
@@ -593,6 +609,8 @@ void Adam_Optimizer::Step_8(float* _params,
}

int ds_adam_step(int optimizer_id,
+ size_t step,
+ float lr,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
@@ -610,13 +628,16 @@ int ds_adam_step(int optimizer_id,

std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
- opt->IncrementStep();
+ opt->IncrementStep(step);
+ opt->update_lr(lr);
opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));

return 0;
}

int ds_adam_step_plus_copy(int optimizer_id,
+ size_t step,
+ float lr,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
@@ -637,7 +658,8 @@ int ds_adam_step_plus_copy(int optimizer_id,

std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
- opt->IncrementStep();
+ opt->IncrementStep(step);
+ opt->update_lr(lr);
opt->Step_8(
params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);

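With these changes both entry points take the step count and learning rate from the caller instead of tracking them inside the extension, and ds_adam_step_plus_copy still refreshes the fp16 copy that lives with the model after the CPU-side update. A rough Python model of that update-then-copy flow (tensor names and the helper are illustrative, not the binding's API):

    import torch

    def step_plus_copy(fp32_param, fp32_update, fp16_param):
        # Apply the CPU-side Adam result to the fp32 master weights, then push
        # the new values into the fp16 tensor the forward/backward pass uses.
        fp32_param.add_(fp32_update)
        fp16_param.data.copy_(fp32_param)   # copy_ performs the float32 -> float16 cast

    fp32_master = torch.randn(8)
    fp16_model = fp32_master.half()         # would normally live on the GPU
    step_plus_copy(fp32_master, torch.full_like(fp32_master, -0.01), fp16_model)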
15 changes: 12 additions & 3 deletions csrc/includes/cpu_adam.h
100644 → 100755
@@ -59,6 +59,7 @@ class Adam_Optimizer {
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
+ _step(0),
_buf_index(false),
_adamw_mode(adamw_mode)
{
@@ -88,11 +89,18 @@
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params = nullptr);
- inline void IncrementStep()
+ inline void IncrementStep(size_t step)
{
- _betta1_t *= _betta1;
- _betta2_t *= _betta2;
+ if (_step < step) {
+ _step++;
+ if (_step != step) {
+ throw std::runtime_error("Optimizer lost track of step count!\n");
+ }
+ _betta1_t *= _betta1;
+ _betta2_t *= _betta2;
+ }
}
+ inline void update_lr(float lr) { _alpha = lr; }

private:
#if defined(__AVX512__) or defined(__AVX256__)
@@ -114,6 +122,7 @@

float _betta1_t;
float _betta2_t;
+ size_t _step;

float* _doubled_buffer[2];
bool _buf_index;
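This header change is what keeps the C++ optimizer in sync with the Python wrapper: the wrapper now reports its own step count, IncrementStep advances the cached beta powers (used for bias correction) only when that count actually moves forward, and update_lr picks up whatever learning rate the current schedule dictates. A pure-Python model of the same bookkeeping (illustrative only, not the extension's API):

    class StepTracker:
        # Mirrors the spirit of Adam_Optimizer::IncrementStep / update_lr.
        def __init__(self, beta1=0.9, beta2=0.999, lr=1e-3):
            self.beta1, self.beta2, self.lr = beta1, beta2, lr
            self.step = 0
            self.beta1_t = 1.0   # beta1 ** step, cached for bias correction
            self.beta2_t = 1.0   # beta2 ** step

        def increment_step(self, step):
            # Advance only when the caller reports a newer step, so repeated
            # calls with the same step number (one per parameter tensor) are no-ops.
            if self.step < step:
                self.step += 1
                if self.step != step:
                    raise RuntimeError("Optimizer lost track of step count!")
                self.beta1_t *= self.beta1
                self.beta2_t *= self.beta2

        def update_lr(self, lr):
            self.lr = lr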
34 changes: 20 additions & 14 deletions deepspeed/ops/adam/cpu_adam.py
@@ -95,32 +95,38 @@ def step(self, closure=None, fp16_param_groups=None):
if p.grad is None:
continue

- grad = p.grad.data
state = self.state[p]
# State initialization
if len(state) == 0:
print(f'group {group_id} param {param_id} = {p.numel()}')
state['step'] = 0
# gradient momentums
- state['exp_avg'] = torch.zeros_like(p.data, device='cpu')
+ state['exp_avg'] = torch.zeros_like(
+ p.data,
+ memory_format=torch.preserve_format)
# gradient variances
- state['exp_avg_sq'] = torch.zeros_like(p.data, device='cpu')
+ state['exp_avg_sq'] = torch.zeros_like(
+ p.data,
+ memory_format=torch.preserve_format)

- exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
state['step'] += 1

if fp16_param_groups is not None:
- p_fp16 = fp16_param_groups[group_id][param_id]
- ds_opt_adam.adam_update_copy(self.opt_id,
- p.data,
- grad,
- exp_avg,
- exp_avg_sq,
- p_fp16)
+ ds_opt_adam.adam_update_copy(
+ self.opt_id,
+ state['step'],
+ group['lr'],
+ p.data,
+ p.grad.data,
+ state['exp_avg'],
+ state['exp_avg_sq'],
+ fp16_param_groups[group_id][param_id].data)
else:
ds_opt_adam.adam_update(self.opt_id,
+ state['step'],
+ group['lr'],
p.data,
- grad,
- exp_avg,
- exp_avg_sq)
+ p.grad.data,
+ state['exp_avg'],
+ state['exp_avg_sq'])
return loss
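The wrapper now forwards state['step'] and group['lr'] on every call and hands the fp16 partition tensor to the kernel directly, so per-group learning rates and external LR schedulers take effect inside the C++ code. Nothing changes for the user; a minimal usage sketch (hyperparameter values are placeholders; the constructor is assumed to take torch-style lr/weight_decay arguments):

    import torch
    from deepspeed.ops.adam import DeepSpeedCPUAdam

    model = torch.nn.Linear(16, 4)                    # parameters stay on the CPU
    optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3, weight_decay=0.0)

    loss = model(torch.randn(2, 16)).sum()
    loss.backward()
    optimizer.step()      # drives ds_opt_adam.adam_update with step and lr
    optimizer.zero_grad()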
7 changes: 5 additions & 2 deletions deepspeed/runtime/zero/stage2.py
@@ -1416,8 +1416,11 @@ def step(self, closure=None):
if self.deepspeed_adam_offload:
from deepspeed.ops.adam import DeepSpeedCPUAdam
if type(self.optimizer) == DeepSpeedCPUAdam:
- self.optimizer.step(
- fp16_param_groups=self.parallel_partitioned_fp16_groups)
+ fp16_param_groups = [
+ fp16_partitions[partition_id]
+ for fp16_partitions in self.parallel_partitioned_fp16_groups
+ ]
+ self.optimizer.step(fp16_param_groups=fp16_param_groups)
else:
self.optimizer.step()
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
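This is the data/model-parallel half of the fix: parallel_partitioned_fp16_groups holds, for every parameter group, one flat fp16 partition per rank, while DeepSpeedCPUAdam.step indexes fp16_param_groups[group_id][param_id]. Passing the whole nested structure therefore let the per-parameter lookup land on a partition chosen by list position rather than by this rank's partition_id; the new code narrows each group to the local partition first. A small self-contained illustration of that indexing (dummy tensors, with world size and partition_id hard-coded as an assumption):

    import torch

    # Two parameter groups, data-parallel world size 2; this rank owns partition 1.
    parallel_partitioned_fp16_groups = [
        [torch.zeros(4, dtype=torch.float16), torch.ones(4, dtype=torch.float16)],  # group 0
        [torch.zeros(6, dtype=torch.float16), torch.ones(6, dtype=torch.float16)],  # group 1
    ]
    partition_id = 1

    # Narrow every group to this rank's own fp16 partition before calling step():
    fp16_param_groups = [
        fp16_partitions[partition_id]
        for fp16_partitions in parallel_partitioned_fp16_groups
    ]
    assert all(float(t[0]) == 1.0 for t in fp16_param_groups)  # local partitions selected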
