From d8ff56c34f9313a48855e88df8857223dd2e22ea Mon Sep 17 00:00:00 2001
From: RezaYazdaniAminabadi <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Date: Fri, 4 Sep 2020 22:20:45 -0700
Subject: [PATCH] fixing adam copy fp16-param and add more compile flags for
 cpu_adam (#367)

* fixing adam copy fp16-param-add more compile flags for cpu_adam

* run precommit

* fix variance indexes

* fix array-sizes

* move adam_test

* rename perf test
---
 csrc/adam/cpu_adam.cpp   | 496 ++++++++++++++++++++++++++++-----------
 csrc/includes/cpu_adam.h |   9 +-
 setup.py                 |  34 +--
 tests/perf/adam_test1.py |  22 ++
 4 files changed, 406 insertions(+), 155 deletions(-)
 mode change 100644 => 100755 csrc/adam/cpu_adam.cpp
 create mode 100755 tests/perf/adam_test1.py
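Reviewer note (not part of the original patch): the heart of this fix is that the fp32 -> fp16 copy-back to the GPU now runs once per TILE-sized chunk, with the last chunk sized by copy_size, instead of being triggered by an `(i + SIMD_WIDTH) % TILE == 0` test inside the SIMD loop that silently dropped any tail shorter than a full tile. Below is a minimal self-contained sketch of the tiled, double-buffered pattern; the names (tiled_fp16_copy, TILE_ELEMS, host_buf) are illustrative, not DeepSpeed's API, and the explicit event guard before reusing a staging buffer is a safety addition that the patch itself does not carry.

    #include <cuda_runtime.h>
    #include <cuda_fp16.h>
    #include <algorithm>
    #include <cstddef>

    constexpr size_t TILE_ELEMS = 1 << 20;  // elements per tile (assumed value)

    // host_buf[0..1] are pinned staging buffers of TILE_ELEMS halves each.
    void tiled_fp16_copy(const float* params, __half* dev_params, size_t n,
                         __half* host_buf[2], cudaStream_t stream)
    {
        cudaEvent_t done[2];
        cudaEventCreate(&done[0]);
        cudaEventCreate(&done[1]);
        bool idx = false;
        for (size_t t = 0; t < n; t += TILE_ELEMS) {
            size_t copy_size = std::min(TILE_ELEMS, n - t);  // last tile may be short
            cudaEventSynchronize(done[idx]);  // buffer is free once its last copy drained
            for (size_t j = 0; j < copy_size; ++j)  // fp32 -> fp16 on the host
                host_buf[idx][j] = (__half)params[t + j];
            cudaMemcpyAsync(dev_params + t, host_buf[idx],
                            copy_size * sizeof(__half),
                            cudaMemcpyHostToDevice, stream);
            cudaEventRecord(done[idx], stream);
            idx = !idx;  // convert the next tile while this one is in flight
        }
        cudaStreamSynchronize(stream);  // drain the final copy
        cudaEventDestroy(done[0]);
        cudaEventDestroy(done[1]);
    }

The patch alternates `_doubled_buffer` without such a guard, which appears to rely on tile k's async copy having drained by the time tile k+2's conversion starts writing into the same buffer.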
diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp
old mode 100644
new mode 100755
index 2d3b521bea1c..9399438b57f6
--- a/csrc/adam/cpu_adam.cpp
+++ b/csrc/adam/cpu_adam.cpp
@@ -49,54 +49,56 @@ void Adam_Optimizer::Step(float* _params,
     size_t tile = 0;
 
+    for (size_t t = 0; t < _param_size; t += TILE) {
+        size_t copy_size = TILE;
+        if ((t + TILE) > _param_size) copy_size = _param_size - t;
+        size_t offset = copy_size + t;
 #pragma omp parallel for
-    for (size_t i = 0; i < _param_size; i += SIMD_WIDTH) {
-        __m512 grad_4 = _mm512_loadu_ps(grads + i);
+        for (size_t i = t; i < offset; i += SIMD_WIDTH) {
+            __m512 grad_4 = _mm512_loadu_ps(grads + i);
 
-        __m512 momntum_4 = _mm512_loadu_ps(_exp_avg + i);
-        __m512 varianc_4 = _mm512_loadu_ps(_exp_avg_sq + i);
+            __m512 momntum_4 = _mm512_loadu_ps(_exp_avg + i);
+            __m512 varianc_4 = _mm512_loadu_ps(_exp_avg_sq + i);
 
-        __m512 param_4 = _mm512_loadu_ps(_params + i);
+            __m512 param_4 = _mm512_loadu_ps(_params + i);
 
-        if (_weight_decay > 0) {
-            __m512 weight_decay4 = _mm512_set1_ps(_weight_decay);
-            grad_4 = _mm512_fmadd_ps(param_4, weight_decay4, grad_4);
-        }
+            if (_weight_decay > 0) {
+                __m512 weight_decay4 = _mm512_set1_ps(_weight_decay);
+                grad_4 = _mm512_fmadd_ps(param_4, weight_decay4, grad_4);
+            }
 
-        momntum_4 = _mm512_mul_ps(momntum_4, betta1_4);
-        momntum_4 = _mm512_fmadd_ps(grad_4, betta1_minus1_4, momntum_4);
+            momntum_4 = _mm512_mul_ps(momntum_4, betta1_4);
+            momntum_4 = _mm512_fmadd_ps(grad_4, betta1_minus1_4, momntum_4);
 
-        varianc_4 = _mm512_mul_ps(varianc_4, betta2_4);
-        grad_4 = _mm512_mul_ps(grad_4, grad_4);
-        varianc_4 = _mm512_fmadd_ps(grad_4, betta2_minus1_4, varianc_4);
+            varianc_4 = _mm512_mul_ps(varianc_4, betta2_4);
+            grad_4 = _mm512_mul_ps(grad_4, grad_4);
+            varianc_4 = _mm512_fmadd_ps(grad_4, betta2_minus1_4, varianc_4);
 
-        grad_4 = _mm512_sqrt_ps(varianc_4) / bias2_sqrt;
-        grad_4 = _mm512_add_ps(grad_4, eps_4);
-        grad_4 = _mm512_div_ps(momntum_4, grad_4);
+            grad_4 = _mm512_sqrt_ps(varianc_4) / bias2_sqrt;
+            grad_4 = _mm512_add_ps(grad_4, eps_4);
+            grad_4 = _mm512_div_ps(momntum_4, grad_4);
 
-        param_4 = _mm512_fmadd_ps(grad_4, step_size_4, param_4);
+            param_4 = _mm512_fmadd_ps(grad_4, step_size_4, param_4);
 
-        _mm512_storeu_ps(_params + i, param_4);
-        _mm512_storeu_ps(_exp_avg + i, momntum_4);
-        _mm512_storeu_ps(_exp_avg_sq + i, varianc_4);
+            _mm512_storeu_ps(_params + i, param_4);
+            _mm512_storeu_ps(_exp_avg + i, momntum_4);
+            _mm512_storeu_ps(_exp_avg_sq + i, varianc_4);
+        }
 
         if (dev_params) {
-            if ((i + SIMD_WIDTH) % TILE == 0) {
-                size_t offset = tile * TILE;
 #pragma omp parallel for
-                for (size_t j = 0; j < TILE; j += 4) {
-                    _doubled_buffer[buf_index][j] = (__half)_params[offset + j];
-                    _doubled_buffer[buf_index][j + 1] = (__half)_params[offset + j + 1];
-                    _doubled_buffer[buf_index][j + 2] = (__half)_params[offset + j + 2];
-                    _doubled_buffer[buf_index][j + 3] = (__half)_params[offset + j + 3];
-                }
-                CUDA_CHECK(cudaMemcpyAsync(dev_params + tile * TILE,
-                                           _doubled_buffer[buf_index],
-                                           TILE * sizeof(__half),
-                                           cudaMemcpyHostToDevice,
-                                           Context::Instance().GetCurrentStream()));
-                buf_index = !buf_index;
-                tile++;
+            for (size_t j = 0; j < copy_size; j += 4) {
+                _doubled_buffer[buf_index][j] = (__half)_params[t + j];
+                _doubled_buffer[buf_index][j + 1] = (__half)_params[t + j + 1];
+                _doubled_buffer[buf_index][j + 2] = (__half)_params[t + j + 2];
+                _doubled_buffer[buf_index][j + 3] = (__half)_params[t + j + 3];
             }
+
+            CUDA_CHECK(cudaMemcpyAsync(dev_params + t,
+                                       _doubled_buffer[buf_index],
+                                       copy_size * sizeof(__half),
+                                       cudaMemcpyHostToDevice,
+                                       Context::Instance().GetCurrentStream()));
+            buf_index = !buf_index;
         }
     }
 }
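Reviewer note (not part of the original patch): the vector arithmetic in Step, Step_4, and Step_8 is identical; only the unroll factor changes. For reference, the update each __m512 lane receives can be written once as a fixed-trip loop that an optimizing compiler will typically unroll. This sketch is a paraphrase, not the patch's code: the patch keeps the unrolling explicit and guards the weight-decay fmadd with `if (_weight_decay > 0)`, while the sketch folds it in unconditionally (pass a zero `wd` vector to disable it).

    #include <immintrin.h>
    #include <cstddef>

    constexpr int SIMD_WIDTH = 16;  // floats per __m512

    template <int UNROLL>  // 1, 4, or 8, mirroring Step / Step_4 / Step_8
    void adam_span(float* params, const float* grads, float* exp_avg,
                   float* exp_avg_sq, size_t i, __m512 b1, __m512 b1m,
                   __m512 b2, __m512 b2m, __m512 eps, __m512 step_size,
                   __m512 bias2_sqrt, __m512 wd)
    {
        for (int k = 0; k < UNROLL; ++k) {
            size_t o = i + (size_t)k * SIMD_WIDTH;
            __m512 grad = _mm512_loadu_ps(grads + o);
            __m512 mom = _mm512_loadu_ps(exp_avg + o);
            __m512 var = _mm512_loadu_ps(exp_avg_sq + o);
            __m512 par = _mm512_loadu_ps(params + o);
            grad = _mm512_fmadd_ps(par, wd, grad);  // weight-decay term
            mom = _mm512_fmadd_ps(grad, b1m, _mm512_mul_ps(mom, b1));
            var = _mm512_fmadd_ps(_mm512_mul_ps(grad, grad), b2m,
                                  _mm512_mul_ps(var, b2));
            __m512 denom = _mm512_add_ps(
                _mm512_div_ps(_mm512_sqrt_ps(var), bias2_sqrt), eps);
            // step_size holds -alpha / bias_correction1, so the fmadd subtracts
            par = _mm512_fmadd_ps(_mm512_div_ps(mom, denom), step_size, par);
            _mm512_storeu_ps(params + o, par);
            _mm512_storeu_ps(exp_avg + o, mom);
            _mm512_storeu_ps(exp_avg_sq + o, var);
        }
    }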
@@ -134,113 +136,116 @@ void Adam_Optimizer::Step_4(float* _params,
 
     __m512 bias2_sqrt = _mm512_sqrt_ps(bias_correction2_4);
 
+    for (size_t t = 0; t < _param_size; t += TILE) {
+        size_t copy_size = TILE;
+        if ((t + TILE) > _param_size) copy_size = _param_size - t;
+        size_t offset = copy_size + t;
 #pragma omp parallel for
-    for (size_t i = 0; i < _param_size; i += (SIMD_WIDTH << 2)) {
-        __m512 grad_4[4];
-        grad_4[0] = _mm512_loadu_ps(grads + i);
-        grad_4[1] = _mm512_loadu_ps(grads + i + SIMD_WIDTH);
-        grad_4[2] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 1));
-        grad_4[3] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 3);
-
-        __m512 momntum_4[2];
-        momntum_4[0] = _mm512_loadu_ps(_exp_avg + i);
-        momntum_4[1] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH);
-        momntum_4[2] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 1));
-        momntum_4[3] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 3);
-
-        __m512 varianc_4[2];
-        varianc_4[0] = _mm512_loadu_ps(_exp_avg_sq + i);
-        varianc_4[1] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH);
-        varianc_4[2] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1));
-        varianc_4[3] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3);
-
-        __m512 param_4[2];
-        param_4[0] = _mm512_loadu_ps(_params + i);
-        param_4[1] = _mm512_loadu_ps(_params + i + SIMD_WIDTH);
-        param_4[2] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 1));
-        param_4[3] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 3);
-
-        if (_weight_decay > 0) {
-            __m512 weight_decay4 = _mm512_set1_ps(_weight_decay);
-            grad_4[0] = _mm512_fmadd_ps(param_4[0], weight_decay4, grad_4[0]);
-            grad_4[1] = _mm512_fmadd_ps(param_4[1], weight_decay4, grad_4[1]);
-            grad_4[2] = _mm512_fmadd_ps(param_4[2], weight_decay4, grad_4[2]);
-            grad_4[3] = _mm512_fmadd_ps(param_4[3], weight_decay4, grad_4[3]);
+        for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
+            __m512 grad_4[4];
+            grad_4[0] = _mm512_loadu_ps(grads + i);
+            grad_4[1] = _mm512_loadu_ps(grads + i + SIMD_WIDTH);
+            grad_4[2] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 1));
+            grad_4[3] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 3);
+
+            __m512 momntum_4[4];
+            momntum_4[0] = _mm512_loadu_ps(_exp_avg + i);
+            momntum_4[1] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH);
+            momntum_4[2] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 1));
+            momntum_4[3] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 3);
+
+            __m512 varianc_4[4];
+            varianc_4[0] = _mm512_loadu_ps(_exp_avg_sq + i);
+            varianc_4[1] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH);
+            varianc_4[2] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1));
+            varianc_4[3] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3);
+
+            __m512 param_4[4];
+            param_4[0] = _mm512_loadu_ps(_params + i);
+            param_4[1] = _mm512_loadu_ps(_params + i + SIMD_WIDTH);
+            param_4[2] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 1));
+            param_4[3] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 3);
+
+            if (_weight_decay > 0) {
+                __m512 weight_decay4 = _mm512_set1_ps(_weight_decay);
+                grad_4[0] = _mm512_fmadd_ps(param_4[0], weight_decay4, grad_4[0]);
+                grad_4[1] = _mm512_fmadd_ps(param_4[1], weight_decay4, grad_4[1]);
+                grad_4[2] = _mm512_fmadd_ps(param_4[2], weight_decay4, grad_4[2]);
+                grad_4[3] = _mm512_fmadd_ps(param_4[3], weight_decay4, grad_4[3]);
+            }
+
+            momntum_4[0] = _mm512_mul_ps(momntum_4[0], betta1_4);
+            momntum_4[0] = _mm512_fmadd_ps(grad_4[0], betta1_minus1_4, momntum_4[0]);
+            momntum_4[1] = _mm512_mul_ps(momntum_4[1], betta1_4);
+            momntum_4[1] = _mm512_fmadd_ps(grad_4[1], betta1_minus1_4, momntum_4[1]);
+            momntum_4[2] = _mm512_mul_ps(momntum_4[2], betta1_4);
+            momntum_4[2] = _mm512_fmadd_ps(grad_4[2], betta1_minus1_4, momntum_4[2]);
+            momntum_4[3] = _mm512_mul_ps(momntum_4[3], betta1_4);
+            momntum_4[3] = _mm512_fmadd_ps(grad_4[3], betta1_minus1_4, momntum_4[3]);
+
+            varianc_4[0] = _mm512_mul_ps(varianc_4[0], betta2_4);
+            varianc_4[1] = _mm512_mul_ps(varianc_4[1], betta2_4);
+            varianc_4[2] = _mm512_mul_ps(varianc_4[2], betta2_4);
+            varianc_4[3] = _mm512_mul_ps(varianc_4[3], betta2_4);
+            grad_4[0] = _mm512_mul_ps(grad_4[0], grad_4[0]);
+            grad_4[1] = _mm512_mul_ps(grad_4[1], grad_4[1]);
+            grad_4[2] = _mm512_mul_ps(grad_4[2], grad_4[2]);
+            grad_4[3] = _mm512_mul_ps(grad_4[3], grad_4[3]);
+            varianc_4[0] = _mm512_fmadd_ps(grad_4[0], betta2_minus1_4, varianc_4[0]);
+            varianc_4[1] = _mm512_fmadd_ps(grad_4[1], betta2_minus1_4, varianc_4[1]);
+            varianc_4[2] = _mm512_fmadd_ps(grad_4[2], betta2_minus1_4, varianc_4[2]);
+            varianc_4[3] = _mm512_fmadd_ps(grad_4[3], betta2_minus1_4, varianc_4[3]);
+
+            grad_4[0] = _mm512_sqrt_ps(varianc_4[0]) / bias2_sqrt;
+            grad_4[1] = _mm512_sqrt_ps(varianc_4[1]) / bias2_sqrt;
+            grad_4[2] = _mm512_sqrt_ps(varianc_4[2]) / bias2_sqrt;
+            grad_4[3] = _mm512_sqrt_ps(varianc_4[3]) / bias2_sqrt;
+
+            grad_4[0] = _mm512_add_ps(grad_4[0], eps_4);
+            grad_4[1] = _mm512_add_ps(grad_4[1], eps_4);
+            grad_4[2] = _mm512_add_ps(grad_4[2], eps_4);
+            grad_4[3] = _mm512_add_ps(grad_4[3], eps_4);
+            grad_4[0] = _mm512_div_ps(momntum_4[0], grad_4[0]);
+            grad_4[1] = _mm512_div_ps(momntum_4[1], grad_4[1]);
+            grad_4[2] = _mm512_div_ps(momntum_4[2], grad_4[2]);
+            grad_4[3] = _mm512_div_ps(momntum_4[3], grad_4[3]);
+
+            param_4[0] = _mm512_fmadd_ps(grad_4[0], step_size_4, param_4[0]);
+            param_4[1] = _mm512_fmadd_ps(grad_4[1], step_size_4, param_4[1]);
+            param_4[2] = _mm512_fmadd_ps(grad_4[2], step_size_4, param_4[2]);
+            param_4[3] = _mm512_fmadd_ps(grad_4[3], step_size_4, param_4[3]);
+
+            _mm512_storeu_ps(_params + i, param_4[0]);
+            _mm512_storeu_ps(_params + i + SIMD_WIDTH, param_4[1]);
+            _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 1), param_4[2]);
+            _mm512_storeu_ps(_params + i + SIMD_WIDTH * 3, param_4[3]);
+
+            _mm512_storeu_ps(_exp_avg + i, momntum_4[0]);
+            _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH, momntum_4[1]);
+            _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 1), momntum_4[2]);
+            _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 3, momntum_4[3]);
+
+            _mm512_storeu_ps(_exp_avg_sq + i, varianc_4[0]);
+            _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH, varianc_4[1]);
+            _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1), varianc_4[2]);
+            _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3]);
         }
 
-        momntum_4[0] = _mm512_mul_ps(momntum_4[0], betta1_4);
-        momntum_4[0] = _mm512_fmadd_ps(grad_4[0], betta1_minus1_4, momntum_4[0]);
-        momntum_4[1] = _mm512_mul_ps(momntum_4[1], betta1_4);
-        momntum_4[1] = _mm512_fmadd_ps(grad_4[1], betta1_minus1_4, momntum_4[1]);
-        momntum_4[2] = _mm512_mul_ps(momntum_4[2], betta1_4);
-        momntum_4[2] = _mm512_fmadd_ps(grad_4[2], betta1_minus1_4, momntum_4[2]);
-        momntum_4[3] = _mm512_mul_ps(momntum_4[3], betta1_4);
-        momntum_4[3] = _mm512_fmadd_ps(grad_4[3], betta1_minus1_4, momntum_4[3]);
-
-        varianc_4[0] = _mm512_mul_ps(varianc_4[0], betta2_4);
-        varianc_4[1] = _mm512_mul_ps(varianc_4[1], betta2_4);
-        varianc_4[2] = _mm512_mul_ps(varianc_4[2], betta2_4);
-        varianc_4[3] = _mm512_mul_ps(varianc_4[3], betta2_4);
-        grad_4[0] = _mm512_mul_ps(grad_4[0], grad_4[0]);
-        grad_4[1] = _mm512_mul_ps(grad_4[1], grad_4[1]);
-        grad_4[2] = _mm512_mul_ps(grad_4[2], grad_4[2]);
-        grad_4[3] = _mm512_mul_ps(grad_4[3], grad_4[3]);
-        varianc_4[0] = _mm512_fmadd_ps(grad_4[0], betta2_minus1_4, varianc_4[0]);
-        varianc_4[1] = _mm512_fmadd_ps(grad_4[1], betta2_minus1_4, varianc_4[1]);
-        varianc_4[2] = _mm512_fmadd_ps(grad_4[2], betta2_minus1_4, varianc_4[2]);
-        varianc_4[3] = _mm512_fmadd_ps(grad_4[3], betta2_minus1_4, varianc_4[3]);
-
-        grad_4[0] = _mm512_sqrt_ps(varianc_4[0]) / bias2_sqrt;
-        grad_4[1] = _mm512_sqrt_ps(varianc_4[1]) / bias2_sqrt;
-        grad_4[2] = _mm512_sqrt_ps(varianc_4[2]) / bias2_sqrt;
-        grad_4[3] = _mm512_sqrt_ps(varianc_4[3]) / bias2_sqrt;
-
-        grad_4[0] = _mm512_add_ps(grad_4[0], eps_4);
-        grad_4[1] = _mm512_add_ps(grad_4[1], eps_4);
-        grad_4[2] = _mm512_add_ps(grad_4[2], eps_4);
-        grad_4[3] = _mm512_add_ps(grad_4[3], eps_4);
-        grad_4[0] = _mm512_div_ps(momntum_4[0], grad_4[0]);
-        grad_4[1] = _mm512_div_ps(momntum_4[1], grad_4[1]);
-        grad_4[2] = _mm512_div_ps(momntum_4[2], grad_4[2]);
-        grad_4[3] = _mm512_div_ps(momntum_4[3], grad_4[3]);
-
-        param_4[0] = _mm512_fmadd_ps(grad_4[0], step_size_4, param_4[0]);
-        param_4[1] = _mm512_fmadd_ps(grad_4[1], step_size_4, param_4[1]);
-        param_4[2] = _mm512_fmadd_ps(grad_4[2], step_size_4, param_4[2]);
-        param_4[3] = _mm512_fmadd_ps(grad_4[3], step_size_4, param_4[3]);
-
-        _mm512_storeu_ps(_params + i, param_4[0]);
-        _mm512_storeu_ps(_params + i + SIMD_WIDTH, param_4[1]);
-        _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 1), param_4[2]);
-        _mm512_storeu_ps(_params + i + SIMD_WIDTH * 3, param_4[3]);
-
-        _mm512_storeu_ps(_exp_avg + i, momntum_4[0]);
-        _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH, momntum_4[1]);
-        _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 1), momntum_4[2]);
-        _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 3, momntum_4[3]);
-
-        _mm512_storeu_ps(_exp_avg_sq + i, varianc_4[0]);
-        _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH, varianc_4[1]);
-        _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1), varianc_4[2]);
-        _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3]);
 
         if (dev_params) {
-            if ((i + (SIMD_WIDTH << 2)) % TILE == 0) {
-                size_t offset = tile * TILE;
 #pragma omp parallel for
-                for (size_t j = 0; j < TILE; j += 4) {
-                    _doubled_buffer[buf_index][j] = (__half)_params[offset + j];
-                    _doubled_buffer[buf_index][j + 1] = (__half)_params[offset + j + 1];
-                    _doubled_buffer[buf_index][j + 2] = (__half)_params[offset + j + 2];
-                    _doubled_buffer[buf_index][j + 3] = (__half)_params[offset + j + 3];
-                }
-                CUDA_CHECK(cudaMemcpyAsync(dev_params + tile * TILE,
-                                           _doubled_buffer[buf_index],
-                                           TILE * sizeof(__half),
-                                           cudaMemcpyHostToDevice,
-                                           Context::Instance().GetCurrentStream()));
-                buf_index = !buf_index;
-                tile++;
+            for (size_t j = 0; j < copy_size; j += 4) {
+                _doubled_buffer[buf_index][j] = (__half)_params[t + j];
+                _doubled_buffer[buf_index][j + 1] = (__half)_params[t + j + 1];
+                _doubled_buffer[buf_index][j + 2] = (__half)_params[t + j + 2];
+                _doubled_buffer[buf_index][j + 3] = (__half)_params[t + j + 3];
             }
+
+            CUDA_CHECK(cudaMemcpyAsync(dev_params + t,
+                                       _doubled_buffer[buf_index],
+                                       copy_size * sizeof(__half),
+                                       cudaMemcpyHostToDevice,
+                                       Context::Instance().GetCurrentStream()));
+            buf_index = !buf_index;
         }
     }
 }
@@ -262,6 +267,219 @@ int create_adam_optimizer(int optimizer_id,
     return 0;
 }
 
+void Adam_Optimizer::Step_8(float* _params,
+                            float* grads,
+                            float* _exp_avg,
+                            float* _exp_avg_sq,
+                            size_t _param_size,
+                            __half* dev_params)
+{
+    _betta1_t *= _betta1;
+    _betta2_t *= _betta2;
+
+    __m512 betta1_4 = _mm512_set1_ps(_betta1);
+    __m512 betta2_4 = _mm512_set1_ps(_betta2);
+
+    bool buf_index = 0;
+
+    float betta1_minus1 = 1 - _betta1;
+    float betta2_minus1 = 1 - _betta2;
+    __m512 betta1_minus1_4 = _mm512_set1_ps(betta1_minus1);
+    __m512 betta2_minus1_4 = _mm512_set1_ps(betta2_minus1);
+
+    float bias_correction1 = 1 - _betta1_t;
+    float bias_correction2 = 1 - _betta2_t;
+    //__m512 bias_correction1_4 = _mm512_set1_ps(bias_correction1);
+    __m512 bias_correction2_4 = _mm512_set1_ps(bias_correction2);
+
+    __m512 eps_4 = _mm512_set1_ps(_eps);
+
+    float step_size = -1 * _alpha / bias_correction1;
+    __m512 step_size_4 = _mm512_set1_ps(step_size);
+
+    __m512 bias2_sqrt = _mm512_sqrt_ps(bias_correction2_4);
+
+    for (size_t t = 0; t < _param_size; t += TILE) {
+        size_t copy_size = TILE;
+        if ((t + TILE) > _param_size) copy_size = _param_size - t;
+        size_t offset = copy_size + t;
+#pragma omp parallel for
+        for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
+            __m512 grad_4[8];
+            grad_4[0] = _mm512_loadu_ps(grads + i);
+            grad_4[1] = _mm512_loadu_ps(grads + i + SIMD_WIDTH);
+            grad_4[2] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 1));
+            grad_4[3] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 3);
+            grad_4[4] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 2));
+            grad_4[5] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 5);
+            grad_4[6] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 6);
+            grad_4[7] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 7);
+
+            __m512 momntum_4[8];
+            momntum_4[0] = _mm512_loadu_ps(_exp_avg + i);
+            momntum_4[1] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH);
+            momntum_4[2] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 1));
+            momntum_4[3] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 3);
+            momntum_4[4] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 2));
+            momntum_4[5] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 5);
+            momntum_4[6] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 6);
+            momntum_4[7] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 7);
+
+            __m512 varianc_4[8];
+            varianc_4[0] = _mm512_loadu_ps(_exp_avg_sq + i);
+            varianc_4[1] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH);
+            varianc_4[2] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1));
+            varianc_4[3] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3);
+            varianc_4[4] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 2));
+            varianc_4[5] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 5);
+            varianc_4[6] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 6);
+            varianc_4[7] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7);
+
+            __m512 param_4[8];
+            param_4[0] = _mm512_loadu_ps(_params + i);
+            param_4[1] = _mm512_loadu_ps(_params + i + SIMD_WIDTH);
+            param_4[2] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 1));
+            param_4[3] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 3);
+            param_4[4] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 2));
+            param_4[5] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 5);
+            param_4[6] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 6);
+            param_4[7] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 7);
+
+            if (_weight_decay > 0) {
+                __m512 weight_decay4 = _mm512_set1_ps(_weight_decay);
+                grad_4[0] = _mm512_fmadd_ps(param_4[0], weight_decay4, grad_4[0]);
+                grad_4[1] = _mm512_fmadd_ps(param_4[1], weight_decay4, grad_4[1]);
+                grad_4[2] = _mm512_fmadd_ps(param_4[2], weight_decay4, grad_4[2]);
+                grad_4[3] = _mm512_fmadd_ps(param_4[3], weight_decay4, grad_4[3]);
+                grad_4[4] = _mm512_fmadd_ps(param_4[4], weight_decay4, grad_4[4]);
+                grad_4[5] = _mm512_fmadd_ps(param_4[5], weight_decay4, grad_4[5]);
+                grad_4[6] = _mm512_fmadd_ps(param_4[6], weight_decay4, grad_4[6]);
+                grad_4[7] = _mm512_fmadd_ps(param_4[7], weight_decay4, grad_4[7]);
+            }
+
+            momntum_4[0] = _mm512_mul_ps(momntum_4[0], betta1_4);
+            momntum_4[0] = _mm512_fmadd_ps(grad_4[0], betta1_minus1_4, momntum_4[0]);
+            momntum_4[1] = _mm512_mul_ps(momntum_4[1], betta1_4);
+            momntum_4[1] = _mm512_fmadd_ps(grad_4[1], betta1_minus1_4, momntum_4[1]);
+            momntum_4[2] = _mm512_mul_ps(momntum_4[2], betta1_4);
+            momntum_4[2] = _mm512_fmadd_ps(grad_4[2], betta1_minus1_4, momntum_4[2]);
+            momntum_4[3] = _mm512_mul_ps(momntum_4[3], betta1_4);
+            momntum_4[3] = _mm512_fmadd_ps(grad_4[3], betta1_minus1_4, momntum_4[3]);
+            momntum_4[4] = _mm512_mul_ps(momntum_4[4], betta1_4);
+            momntum_4[4] = _mm512_fmadd_ps(grad_4[4], betta1_minus1_4, momntum_4[4]);
+            momntum_4[5] = _mm512_mul_ps(momntum_4[5], betta1_4);
+            momntum_4[5] = _mm512_fmadd_ps(grad_4[5], betta1_minus1_4, momntum_4[5]);
+            momntum_4[6] = _mm512_mul_ps(momntum_4[6], betta1_4);
+            momntum_4[6] = _mm512_fmadd_ps(grad_4[6], betta1_minus1_4, momntum_4[6]);
+            momntum_4[7] = _mm512_mul_ps(momntum_4[7], betta1_4);
+            momntum_4[7] = _mm512_fmadd_ps(grad_4[7], betta1_minus1_4, momntum_4[7]);
+
+            varianc_4[0] = _mm512_mul_ps(varianc_4[0], betta2_4);
+            varianc_4[1] = _mm512_mul_ps(varianc_4[1], betta2_4);
+            varianc_4[2] = _mm512_mul_ps(varianc_4[2], betta2_4);
+            varianc_4[3] = _mm512_mul_ps(varianc_4[3], betta2_4);
+            varianc_4[4] = _mm512_mul_ps(varianc_4[4], betta2_4);
+            varianc_4[5] = _mm512_mul_ps(varianc_4[5], betta2_4);
+            varianc_4[6] = _mm512_mul_ps(varianc_4[6], betta2_4);
+            varianc_4[7] = _mm512_mul_ps(varianc_4[7], betta2_4);
+            grad_4[0] = _mm512_mul_ps(grad_4[0], grad_4[0]);
+            grad_4[1] = _mm512_mul_ps(grad_4[1], grad_4[1]);
+            grad_4[2] = _mm512_mul_ps(grad_4[2], grad_4[2]);
+            grad_4[3] = _mm512_mul_ps(grad_4[3], grad_4[3]);
+            grad_4[4] = _mm512_mul_ps(grad_4[4], grad_4[4]);
+            grad_4[5] = _mm512_mul_ps(grad_4[5], grad_4[5]);
+            grad_4[6] = _mm512_mul_ps(grad_4[6], grad_4[6]);
+            grad_4[7] = _mm512_mul_ps(grad_4[7], grad_4[7]);
+            varianc_4[0] = _mm512_fmadd_ps(grad_4[0], betta2_minus1_4, varianc_4[0]);
+            varianc_4[1] = _mm512_fmadd_ps(grad_4[1], betta2_minus1_4, varianc_4[1]);
+            varianc_4[2] = _mm512_fmadd_ps(grad_4[2], betta2_minus1_4, varianc_4[2]);
+            varianc_4[3] = _mm512_fmadd_ps(grad_4[3], betta2_minus1_4, varianc_4[3]);
+            varianc_4[4] = _mm512_fmadd_ps(grad_4[4], betta2_minus1_4, varianc_4[4]);
+            varianc_4[5] = _mm512_fmadd_ps(grad_4[5], betta2_minus1_4, varianc_4[5]);
+            varianc_4[6] = _mm512_fmadd_ps(grad_4[6], betta2_minus1_4, varianc_4[6]);
+            varianc_4[7] = _mm512_fmadd_ps(grad_4[7], betta2_minus1_4, varianc_4[7]);
+
+            grad_4[0] = _mm512_sqrt_ps(varianc_4[0]) / bias2_sqrt;
+            grad_4[1] = _mm512_sqrt_ps(varianc_4[1]) / bias2_sqrt;
+            grad_4[2] = _mm512_sqrt_ps(varianc_4[2]) / bias2_sqrt;
+            grad_4[3] = _mm512_sqrt_ps(varianc_4[3]) / bias2_sqrt;
+            grad_4[4] = _mm512_sqrt_ps(varianc_4[4]) / bias2_sqrt;
+            grad_4[5] = _mm512_sqrt_ps(varianc_4[5]) / bias2_sqrt;
+            grad_4[6] = _mm512_sqrt_ps(varianc_4[6]) / bias2_sqrt;
+            grad_4[7] = _mm512_sqrt_ps(varianc_4[7]) / bias2_sqrt;
+
+            grad_4[0] = _mm512_add_ps(grad_4[0], eps_4);
+            grad_4[1] = _mm512_add_ps(grad_4[1], eps_4);
+            grad_4[2] = _mm512_add_ps(grad_4[2], eps_4);
+            grad_4[3] = _mm512_add_ps(grad_4[3], eps_4);
+            grad_4[4] = _mm512_add_ps(grad_4[4], eps_4);
+            grad_4[5] = _mm512_add_ps(grad_4[5], eps_4);
+            grad_4[6] = _mm512_add_ps(grad_4[6], eps_4);
+            grad_4[7] = _mm512_add_ps(grad_4[7], eps_4);
+            grad_4[0] = _mm512_div_ps(momntum_4[0], grad_4[0]);
+            grad_4[1] = _mm512_div_ps(momntum_4[1], grad_4[1]);
+            grad_4[2] = _mm512_div_ps(momntum_4[2], grad_4[2]);
+            grad_4[3] = _mm512_div_ps(momntum_4[3], grad_4[3]);
+            grad_4[4] = _mm512_div_ps(momntum_4[4], grad_4[4]);
+            grad_4[5] = _mm512_div_ps(momntum_4[5], grad_4[5]);
+            grad_4[6] = _mm512_div_ps(momntum_4[6], grad_4[6]);
+            grad_4[7] = _mm512_div_ps(momntum_4[7], grad_4[7]);
+
+            param_4[0] = _mm512_fmadd_ps(grad_4[0], step_size_4, param_4[0]);
+            param_4[1] = _mm512_fmadd_ps(grad_4[1], step_size_4, param_4[1]);
+            param_4[2] = _mm512_fmadd_ps(grad_4[2], step_size_4, param_4[2]);
+            param_4[3] = _mm512_fmadd_ps(grad_4[3], step_size_4, param_4[3]);
+            param_4[4] = _mm512_fmadd_ps(grad_4[4], step_size_4, param_4[4]);
+            param_4[5] = _mm512_fmadd_ps(grad_4[5], step_size_4, param_4[5]);
+            param_4[6] = _mm512_fmadd_ps(grad_4[6], step_size_4, param_4[6]);
+            param_4[7] = _mm512_fmadd_ps(grad_4[7], step_size_4, param_4[7]);
+
+            _mm512_storeu_ps(_params + i, param_4[0]);
+            _mm512_storeu_ps(_params + i + SIMD_WIDTH, param_4[1]);
+            _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 1), param_4[2]);
+            _mm512_storeu_ps(_params + i + SIMD_WIDTH * 3, param_4[3]);
+            _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 2), param_4[4]);
+            _mm512_storeu_ps(_params + i + SIMD_WIDTH * 5, param_4[5]);
+            _mm512_storeu_ps(_params + i + SIMD_WIDTH * 6, param_4[6]);
+            _mm512_storeu_ps(_params + i + SIMD_WIDTH * 7, param_4[7]);
+
+            _mm512_storeu_ps(_exp_avg + i, momntum_4[0]);
+            _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH, momntum_4[1]);
+            _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 1), momntum_4[2]);
+            _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 3, momntum_4[3]);
+            _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 2), momntum_4[4]);
+            _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 5, momntum_4[5]);
+            _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 6, momntum_4[6]);
+            _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 7, momntum_4[7]);
+
+            _mm512_storeu_ps(_exp_avg_sq + i, varianc_4[0]);
+            _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH, varianc_4[1]);
+            _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1), varianc_4[2]);
+            _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3]);
+            _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 2), varianc_4[4]);
+            _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 5, varianc_4[5]);
+            _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 6, varianc_4[6]);
+            _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7, varianc_4[7]);
+        }
+        if (dev_params) {
+#pragma omp parallel for
+            for (size_t j = 0; j < copy_size; j += 4) {
+                _doubled_buffer[buf_index][j] = (__half)_params[t + j];
+                _doubled_buffer[buf_index][j + 1] = (__half)_params[t + j + 1];
+                _doubled_buffer[buf_index][j + 2] = (__half)_params[t + j + 2];
+                _doubled_buffer[buf_index][j + 3] = (__half)_params[t + j + 3];
+            }
+
+            CUDA_CHECK(cudaMemcpyAsync(dev_params + t,
+                                       _doubled_buffer[buf_index],
+                                       copy_size * sizeof(__half),
+                                       cudaMemcpyHostToDevice,
+                                       Context::Instance().GetCurrentStream()));
+            buf_index = !buf_index;
+        }
+    }
+}
+
 int ds_adam_step(int optimizer_id,
                  torch::Tensor& params,
                  torch::Tensor& grads,
@@ -281,7 +499,7 @@ int ds_adam_step(int optimizer_id,
     std::shared_ptr<Adam_Optimizer> opt =
         std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
 
-    opt->Step_4(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
+    opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
 
     return 0;
 }
@@ -308,7 +526,7 @@ int ds_adam_step_plus_copy(int optimizer_id,
     std::shared_ptr<Adam_Optimizer> opt =
         std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
 
-    opt->Step_4(
+    opt->Step_8(
         params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);
 
     return 0;
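Reviewer note (not part of the original patch): both ds_adam_step and ds_adam_step_plus_copy now dispatch to Step_8, whose inner loop advances by SIMD_WIDTH << 3 = 128 floats with no scalar tail loop, so _param_size appears to be assumed a multiple of 128 (or padded to one) by the callers. A hypothetical guard that would make the assumption explicit:

    // Hypothetical check; DeepSpeed does not ship this helper.
    constexpr size_t STEP8_STRIDE = 16 * 8;  // SIMD_WIDTH << 3 floats per iteration

    inline bool step8_compatible(size_t param_size)
    {
        return param_size % STEP8_STRIDE == 0;
    }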
diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h
index d1a2aa088731..40f4cba692ea 100644
--- a/csrc/includes/cpu_adam.h
+++ b/csrc/includes/cpu_adam.h
@@ -18,7 +18,8 @@
 }
 
 #define SIMD_WIDTH 16
-#define TILE 1048576
+
+#define TILE (1024 * 1024 * 1024)
 
 class Adam_Optimizer {
 public:
@@ -55,6 +56,12 @@ class Adam_Optimizer {
                 float* _exp_avg_sa,
                 size_t param_size,
                 __half* dev_param = nullptr);
+    void Step_8(float* _params,
+                float* grads,
+                float* _exp_avg,
+                float* _exp_avg_sq,
+                size_t _param_size,
+                __half* dev_params = nullptr);
 
 private:
     float _alpha;
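Reviewer note (not part of the original patch): TILE grows from 2^20 to 2^30 elements, so a model of up to a billion parameters is handled in a single tile, and each _doubled_buffer staging buffer presumably has to hold TILE halves if it is allocated at full tile size. Back-of-envelope:

    constexpr size_t TILE_ELEMS = 1024ull * 1024 * 1024;  // 2^30 elements
    constexpr size_t TILE_BYTES = TILE_ELEMS * 2;         // sizeof(__half) == 2
    static_assert(TILE_BYTES == 2ull * 1024 * 1024 * 1024,
                  "2 GiB per fp16 staging buffer at full tile size");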
diff --git a/setup.py b/setup.py
index ae5010a65b11..10564e59a6e8 100755
--- a/setup.py
+++ b/setup.py
@@ -118,21 +118,25 @@ def fetch_requirements(path):
 ## Adam ##
 if BUILD_MASK & DS_BUILD_ADAM:
     ext_modules.append(
-        CUDAExtension(
-            name='deepspeed.ops.adam.cpu_adam_op',
-            sources=[
-                'csrc/adam/cpu_adam.cpp',
-            ],
-            include_dirs=['csrc/includes'],
-            extra_compile_args={
-                'cxx':
-                ['-O3',
-                 '-std=c++14',
-                 '-g',
-                 '-Wno-reorder',
-                 '-march=native',
-                 '-fopenmp']
-            }))
+        CUDAExtension(name='deepspeed.ops.adam.cpu_adam_op',
+                      sources=[
+                          'csrc/adam/cpu_adam.cpp',
+                      ],
+                      include_dirs=['csrc/includes',
+                                    '/usr/local/cuda/include'],
+                      extra_compile_args={
+                          'cxx': [
+                              '-O3',
+                              '-std=c++14',
+                              '-L/usr/local/cuda/lib64',
+                              '-lcudart',
+                              '-lcublas',
+                              '-g',
+                              '-Wno-reorder',
+                              '-march=native',
+                              '-fopenmp'
+                          ]
+                      }))
 
 ## Transformer ##
 if BUILD_MASK & DS_BUILD_TRANSFORMER:
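Reviewer note (not part of the original patch): -L/usr/local/cuda/lib64, -lcudart, and -lcublas are linker options, so inside extra_compile_args['cxx'] they have no effect at the compile stage (at most a compiler warning); CUDAExtension normally links the CUDA runtime on its own at link time. The conventional spelling would be library_dirs=['/usr/local/cuda/lib64'] plus libraries=['cudart', 'cublas'], or extra_link_args, on the extension.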
diff --git a/tests/perf/adam_test1.py b/tests/perf/adam_test1.py
new file mode 100755
index 000000000000..800cb4f42eaa
--- /dev/null
+++ b/tests/perf/adam_test1.py
@@ -0,0 +1,22 @@
+import torch
+from deepspeed.ops.adam import DeepSpeedCPUAdam
+import time
+
+device = 'cpu'
+model_size = 1 * 1024**3
+param = torch.nn.Parameter(torch.ones(model_size, device=device))
+param_fp16 = torch.nn.Parameter(torch.ones(model_size,
+                                           dtype=torch.half,
+                                           device='cuda:0'))
+
+optimizer = DeepSpeedCPUAdam([param])
+#torch.set_num_threads(128)
+param.grad = torch.ones(model_size, device=device)
+avg = 0
+for i in range(10):
+    start = time.time()
+    optimizer.step(fp16_param_groups=[param_fp16])
+    stop = time.time()
+    avg += (stop - start)
+    param.grad = torch.ones(model_size, device=device) * 2
+print("Elapsed Time is ", avg / 10)
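Reviewer note (not part of the original patch): two caveats when reading numbers from tests/perf/adam_test1.py. First, the step enqueues the fp16 copies with cudaMemcpyAsync and the loop never synchronizes the CUDA stream, so the wall-clock average appears to reflect mostly the CPU-side Step_8 work, with the copy time potentially undercounted. Second, at model_size = 1 * 1024**3, the fp32 param, grad, and the optimizer's two moment buffers come to roughly 16 GiB of host RAM, plus a 2 GiB fp16 copy on the GPU. For timing the CPU path in isolation, a rough hypothetical C++ driver could look like the following (linked against csrc/adam/cpu_adam.cpp; the default constructor and Step_8 signature are assumed from csrc/includes/cpu_adam.h):

    #include <chrono>
    #include <cstdio>
    #include <vector>

    #include "cpu_adam.h"  // from csrc/includes; declares Adam_Optimizer

    int main()
    {
        const size_t n = 1024 * 1024;  // a multiple of 128, Step_8's stride
        std::vector<float> p(n, 1.f), g(n, 1.f), m(n, 0.f), v(n, 0.f);
        Adam_Optimizer opt;  // hyperparameter defaults assumed from the header
        auto t0 = std::chrono::steady_clock::now();
        for (int step = 0; step < 10; ++step)
            opt.Step_8(p.data(), g.data(), m.data(), v.data(), n);  // no device copy
        auto t1 = std::chrono::steady_clock::now();
        std::printf("avg step: %.3f ms\n",
                    std::chrono::duration<double, std::milli>(t1 - t0).count() / 10);
        return 0;
    }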