diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index d884451700e9..e6afa7afe3e8 100755 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include #include #include @@ -11,9 +10,13 @@ #include "cuda.h" #include "curand.h" #include "custom_cuda_layers.h" +#include static std::unordered_map> s_optimizers; +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) + + // C++ interface void Adam_Optimizer::Step(float* _params, @@ -26,80 +29,136 @@ void Adam_Optimizer::Step(float* _params, _betta1_t *= _betta1; _betta2_t *= _betta2; - bool buf_index = 0; - - __m512 betta1_4 = _mm512_set1_ps(_betta1); - __m512 betta2_4 = _mm512_set1_ps(_betta2); + AVX_512 betta1_4; + betta1_4.data = _mm512_set1_ps(_betta1); + AVX_512 betta2_4; + betta2_4.data = _mm512_set1_ps(_betta2); float betta1_minus1 = 1 - _betta1; float betta2_minus1 = 1 - _betta2; - __m512 betta1_minus1_4 = _mm512_set1_ps(betta1_minus1); - __m512 betta2_minus1_4 = _mm512_set1_ps(betta2_minus1); + AVX_512 betta1_minus1_4; + betta1_minus1_4.data = _mm512_set1_ps(betta1_minus1); + AVX_512 betta2_minus1_4; + betta2_minus1_4.data = _mm512_set1_ps(betta2_minus1); float bias_correction1 = 1 - _betta1_t; - float bias_correction2 = 1 - _betta2_t; - //__m512 bias_correction1_4 = _mm512_set1_ps(bias_correction1); - __m512 bias_correction2_4 = _mm512_set1_ps(bias_correction2); + float bias_correction2 = 1 / sqrt(1 - _betta2_t); + //AVX_512 bias_correction1_4 = _mm512_set1_ps(bias_correction1); + AVX_512 bias2_sqrt ; + bias2_sqrt.data = _mm512_set1_ps(bias_correction2); - __m512 eps_4 = _mm512_set1_ps(_eps); + AVX_512 eps_4; + eps_4.data = _mm512_set1_ps(_eps); float step_size = -1 * _alpha / bias_correction1; - __m512 step_size_4 = _mm512_set1_ps(step_size); + AVX_512 step_size_4; + step_size_4.data = _mm512_set1_ps(step_size); - __m512 bias2_sqrt = _mm512_sqrt_ps(bias_correction2_4); + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); - size_t 
tile = 0; - - for (size_t t = 0; t < _param_size; t += TILE) { + for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; - if ((t + TILE) > _param_size) copy_size = _param_size - t; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; size_t offset = copy_size + t; #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH) { - __m512 grad_4 = _mm512_loadu_ps(grads + i); + AVX_512 grad_4; + grad_4.data = _mm512_loadu_ps(grads + i); - __m512 momntum_4 = _mm512_loadu_ps(_exp_avg + i); - __m512 varianc_4 = _mm512_loadu_ps(_exp_avg_sq + i); + AVX_512 momntum_4; + momntum_4.data = _mm512_loadu_ps(_exp_avg + i); + AVX_512 varianc_4; + varianc_4.data = _mm512_loadu_ps(_exp_avg_sq + i); - __m512 param_4 = _mm512_loadu_ps(_params + i); + AVX_512 param_4; + param_4.data = _mm512_loadu_ps(_params + i); if (_weight_decay > 0) { - __m512 weight_decay4 = _mm512_set1_ps(_weight_decay); - grad_4 = _mm512_fmadd_ps(param_4, weight_decay4, grad_4); + AVX_512 weight_decay4; + weight_decay4.data = _mm512_set1_ps(_weight_decay); + grad_4.data = _mm512_fmadd_ps(param_4.data, weight_decay4.data, grad_4.data); } - momntum_4 = _mm512_mul_ps(momntum_4, betta1_4); - momntum_4 = _mm512_fmadd_ps(grad_4, betta1_minus1_4, momntum_4); - - varianc_4 = _mm512_mul_ps(varianc_4, betta2_4); - grad_4 = _mm512_mul_ps(grad_4, grad_4); - varianc_4 = _mm512_fmadd_ps(grad_4, betta2_minus1_4, varianc_4); - - grad_4 = _mm512_sqrt_ps(varianc_4) / bias2_sqrt; - grad_4 = _mm512_add_ps(grad_4, eps_4); - grad_4 = _mm512_div_ps(momntum_4, grad_4); - - param_4 = _mm512_fmadd_ps(grad_4, step_size_4, param_4); - - _mm512_storeu_ps(_params + i, param_4); - _mm512_storeu_ps(_exp_avg + i, momntum_4); - _mm512_storeu_ps(_exp_avg_sq + i, varianc_4); + momntum_4.data = _mm512_mul_ps(momntum_4.data, betta1_4.data); + momntum_4.data = _mm512_fmadd_ps(grad_4.data, betta1_minus1_4.data, momntum_4.data); + + varianc_4.data = _mm512_mul_ps(varianc_4.data, betta2_4.data); + grad_4.data = 
_mm512_mul_ps(grad_4.data, grad_4.data); + varianc_4.data = _mm512_fmadd_ps(grad_4.data, betta2_minus1_4.data, varianc_4.data); + + grad_4.data = _mm512_sqrt_ps(varianc_4.data); + grad_4.data = _mm512_fmadd_ps(grad_4.data, bias2_sqrt.data, eps_4.data); + grad_4.data = _mm512_div_ps(momntum_4.data, grad_4.data); + + param_4.data = _mm512_fmadd_ps(grad_4.data, step_size_4.data, param_4.data); + if (dev_params) { + for (size_t j = 0; j < SIMD_WIDTH; j += 4) { + _doubled_buffer[_buf_index][(i - t) + j] = (__half)param_4.data_f[j]; + _doubled_buffer[_buf_index][(i - t) + j + 1] = (__half)param_4.data_f[j + 1]; + _doubled_buffer[_buf_index][(i - t) + j + 2] = (__half)param_4.data_f[j + 2]; + _doubled_buffer[_buf_index][(i - t) + j + 3] = (__half)param_4.data_f[j + 3]; + } + } + _mm512_storeu_ps(_params + i, param_4.data); + _mm512_storeu_ps(_exp_avg + i, momntum_4.data); + _mm512_storeu_ps(_exp_avg_sq + i, varianc_4.data); } - if (dev_params) { + if (dev_params) {/* #pragma omp parallel for for (size_t j = 0; j < copy_size; j += 4) { - _doubled_buffer[buf_index][j] = (__half)_params[t + j]; - _doubled_buffer[buf_index][j + 1] = (__half)_params[t + j + 1]; - _doubled_buffer[buf_index][j + 2] = (__half)_params[t + j + 2]; - _doubled_buffer[buf_index][j + 3] = (__half)_params[t + j + 3]; - } + _doubled_buffer[_buf_index][j] = (__half)_params[t + j]; + _doubled_buffer[_buf_index][j + 1] = (__half)_params[t + j + 1]; + _doubled_buffer[_buf_index][j + 2] = (__half)_params[t + j + 2]; + _doubled_buffer[_buf_index][j + 3] = (__half)_params[t + j + 3]; + }*/ CUDA_CHECK(cudaMemcpyAsync(dev_params + t, - _doubled_buffer[buf_index], + _doubled_buffer[_buf_index], copy_size * sizeof(__half), cudaMemcpyHostToDevice, Context::Instance().GetCurrentStream())); - buf_index = !buf_index; + _buf_index = !_buf_index; + } + } + + if(_param_size > rounded_size) + { +#pragma omp parallel for + for (size_t k = rounded_size; k < _param_size; 
k++) + { + float grad = grads[k]; + float param = _params[k]; + float momntum = _exp_avg[k]; + float varianc = _exp_avg_sq[k]; + if (_weight_decay > 0) { + grad = param * _weight_decay + grad; + } + + momntum = momntum * _betta1; + momntum = grad * betta1_minus1 + momntum; + + varianc = varianc * _betta2; + grad = grad * grad; + varianc = grad * betta2_minus1 + varianc; + + grad = sqrt(varianc); + grad = grad * bias_correction2 + _eps; + grad = momntum / grad; + + param = grad * step_size + param; + if (dev_params) + _doubled_buffer[_buf_index][k - rounded_size] = (__half)param; + + _params[k] = param; + _exp_avg[k] = momntum; + _exp_avg_sq[k] = varianc; + } + if (dev_params) { + CUDA_CHECK(cudaMemcpyAsync(dev_params + rounded_size, + _doubled_buffer[_buf_index], + (_param_size - rounded_size) * sizeof(__half), + cudaMemcpyHostToDevice, + Context::Instance().GetCurrentStream())); } } } @@ -114,141 +173,165 @@ void Adam_Optimizer::Step_4(float* _params, _betta1_t *= _betta1; _betta2_t *= _betta2; - __m512 betta1_4 = _mm512_set1_ps(_betta1); - __m512 betta2_4 = _mm512_set1_ps(_betta2); - - bool buf_index = 0; - size_t tile = 0; + AVX_512 betta1_4; + betta1_4.data = _mm512_set1_ps(_betta1); + AVX_512 betta2_4; + betta2_4.data = _mm512_set1_ps(_betta2); float betta1_minus1 = 1 - _betta1; float betta2_minus1 = 1 - _betta2; - __m512 betta1_minus1_4 = _mm512_set1_ps(betta1_minus1); - __m512 betta2_minus1_4 = _mm512_set1_ps(betta2_minus1); + AVX_512 betta1_minus1_4; + betta1_minus1_4.data = _mm512_set1_ps(betta1_minus1); + AVX_512 betta2_minus1_4; + betta2_minus1_4.data = _mm512_set1_ps(betta2_minus1); float bias_correction1 = 1 - _betta1_t; - float bias_correction2 = 1 - _betta2_t; - //__m512 bias_correction1_4 = _mm512_set1_ps(bias_correction1); - __m512 bias_correction2_4 = _mm512_set1_ps(bias_correction2); + float bias_correction2 = 1 / sqrt(1 - _betta2_t); + //AVX_512 bias_correction1_4 = _mm512_set1_ps(bias_correction1); + AVX_512 bias2_sqrt ; + bias2_sqrt.data = 
_mm512_set1_ps(bias_correction2); - __m512 eps_4 = _mm512_set1_ps(_eps); + AVX_512 eps_4; + eps_4.data = _mm512_set1_ps(_eps); float step_size = -1 * _alpha / bias_correction1; - __m512 step_size_4 = _mm512_set1_ps(step_size); + AVX_512 step_size_4; + step_size_4.data = _mm512_set1_ps(step_size); - __m512 bias2_sqrt = _mm512_sqrt_ps(bias_correction2_4); + size_t rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 2)); - for (size_t t = 0; t < _param_size; t += TILE) { + for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; - if ((t + TILE) > _param_size) copy_size = _param_size - t; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; size_t offset = copy_size + t; #pragma omp parallel for for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) { - __m512 grad_4[4]; - grad_4[0] = _mm512_loadu_ps(grads + i); - grad_4[1] = _mm512_loadu_ps(grads + i + SIMD_WIDTH); - grad_4[2] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 1)); - grad_4[3] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 3); - - __m512 momntum_4[4]; - momntum_4[0] = _mm512_loadu_ps(_exp_avg + i); - momntum_4[1] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH); - momntum_4[2] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 1)); - momntum_4[3] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 3); - - __m512 varianc_4[4]; - varianc_4[0] = _mm512_loadu_ps(_exp_avg_sq + i); - varianc_4[1] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH); - varianc_4[2] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1)); - varianc_4[3] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3); - - __m512 param_4[4]; - param_4[0] = _mm512_loadu_ps(_params + i); - param_4[1] = _mm512_loadu_ps(_params + i + SIMD_WIDTH); - param_4[2] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 1)); - param_4[3] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 3); + AVX_512 grad_4[4]; + grad_4[0].data = _mm512_loadu_ps(grads + i); + grad_4[1].data = _mm512_loadu_ps(grads + i + SIMD_WIDTH); + grad_4[2].data = _mm512_loadu_ps(grads + i 
+ (SIMD_WIDTH << 1)); + grad_4[3].data = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 3); + + AVX_512 momntum_4[4]; + momntum_4[0].data = _mm512_loadu_ps(_exp_avg + i); + momntum_4[1].data = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH); + momntum_4[2].data = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 1)); + momntum_4[3].data = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 3); + + AVX_512 varianc_4[4]; + varianc_4[0].data = _mm512_loadu_ps(_exp_avg_sq + i); + varianc_4[1].data = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH); + varianc_4[2].data = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1)); + varianc_4[3].data = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3); + + AVX_512 param_4[4]; + param_4[0].data = _mm512_loadu_ps(_params + i); + param_4[1].data = _mm512_loadu_ps(_params + i + SIMD_WIDTH); + param_4[2].data = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 1)); + param_4[3].data = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 3); if (_weight_decay > 0) { - __m512 weight_decay4 = _mm512_set1_ps(_weight_decay); - grad_4[0] = _mm512_fmadd_ps(param_4[0], weight_decay4, grad_4[0]); - grad_4[1] = _mm512_fmadd_ps(param_4[1], weight_decay4, grad_4[1]); - grad_4[2] = _mm512_fmadd_ps(param_4[2], weight_decay4, grad_4[2]); - grad_4[3] = _mm512_fmadd_ps(param_4[3], weight_decay4, grad_4[3]); + AVX_512 weight_decay4; + weight_decay4.data = _mm512_set1_ps(_weight_decay); + grad_4[0].data = _mm512_fmadd_ps(param_4[0].data, weight_decay4.data, grad_4[0].data); + grad_4[1].data = _mm512_fmadd_ps(param_4[1].data, weight_decay4.data, grad_4[1].data); + grad_4[2].data = _mm512_fmadd_ps(param_4[2].data, weight_decay4.data, grad_4[2].data); + grad_4[3].data = _mm512_fmadd_ps(param_4[3].data, weight_decay4.data, grad_4[3].data); + } + + momntum_4[0].data = _mm512_mul_ps(momntum_4[0].data, betta1_4.data); + momntum_4[0].data = _mm512_fmadd_ps(grad_4[0].data, betta1_minus1_4.data, momntum_4[0].data); + momntum_4[1].data = _mm512_mul_ps(momntum_4[1].data, betta1_4.data); + 
momntum_4[1].data = _mm512_fmadd_ps(grad_4[1].data, betta1_minus1_4.data, momntum_4[1].data); + momntum_4[2].data = _mm512_mul_ps(momntum_4[2].data, betta1_4.data); + momntum_4[2].data = _mm512_fmadd_ps(grad_4[2].data, betta1_minus1_4.data, momntum_4[2].data); + momntum_4[3].data = _mm512_mul_ps(momntum_4[3].data, betta1_4.data); + momntum_4[3].data = _mm512_fmadd_ps(grad_4[3].data, betta1_minus1_4.data, momntum_4[3].data); + + varianc_4[0].data = _mm512_mul_ps(varianc_4[0].data, betta2_4.data); + varianc_4[1].data = _mm512_mul_ps(varianc_4[1].data, betta2_4.data); + varianc_4[2].data = _mm512_mul_ps(varianc_4[2].data, betta2_4.data); + varianc_4[3].data = _mm512_mul_ps(varianc_4[3].data, betta2_4.data); + grad_4[0].data = _mm512_mul_ps(grad_4[0].data, grad_4[0].data); + grad_4[1].data = _mm512_mul_ps(grad_4[1].data, grad_4[1].data); + grad_4[2].data = _mm512_mul_ps(grad_4[2].data, grad_4[2].data); + grad_4[3].data = _mm512_mul_ps(grad_4[3].data, grad_4[3].data); + varianc_4[0].data = _mm512_fmadd_ps(grad_4[0].data, betta2_minus1_4.data, varianc_4[0].data); + varianc_4[1].data = _mm512_fmadd_ps(grad_4[1].data, betta2_minus1_4.data, varianc_4[1].data); + varianc_4[2].data = _mm512_fmadd_ps(grad_4[2].data, betta2_minus1_4.data, varianc_4[2].data); + varianc_4[3].data = _mm512_fmadd_ps(grad_4[3].data, betta2_minus1_4.data, varianc_4[3].data); + + grad_4[0].data = _mm512_sqrt_ps(varianc_4[0].data); + grad_4[1].data = _mm512_sqrt_ps(varianc_4[1].data); + grad_4[2].data = _mm512_sqrt_ps(varianc_4[2].data); + grad_4[3].data = _mm512_sqrt_ps(varianc_4[3].data); + + grad_4[0].data = _mm512_fmadd_ps(grad_4[0].data, bias2_sqrt.data, eps_4.data); + grad_4[1].data = _mm512_fmadd_ps(grad_4[1].data, bias2_sqrt.data, eps_4.data); + grad_4[2].data = _mm512_fmadd_ps(grad_4[2].data, bias2_sqrt.data, eps_4.data); + grad_4[3].data = _mm512_fmadd_ps(grad_4[3].data, bias2_sqrt.data, eps_4.data); + grad_4[0].data = _mm512_div_ps(momntum_4[0].data, grad_4[0].data); + grad_4[1].data = 
_mm512_div_ps(momntum_4[1].data, grad_4[1].data); + grad_4[2].data = _mm512_div_ps(momntum_4[2].data, grad_4[2].data); + grad_4[3].data = _mm512_div_ps(momntum_4[3].data, grad_4[3].data); + + param_4[0].data = _mm512_fmadd_ps(grad_4[0].data, step_size_4.data, param_4[0].data); + param_4[1].data = _mm512_fmadd_ps(grad_4[1].data, step_size_4.data, param_4[1].data); + param_4[2].data = _mm512_fmadd_ps(grad_4[2].data, step_size_4.data, param_4[2].data); + param_4[3].data = _mm512_fmadd_ps(grad_4[3].data, step_size_4.data, param_4[3].data); + + if (dev_params) { + for(int u = 0;u < 4;u++) + { + for (size_t j = 0; j < SIMD_WIDTH; j += 4) { + _doubled_buffer[_buf_index][(i - t) + (u << 4) + j] = (__half)param_4[u].data_f[j]; + _doubled_buffer[_buf_index][(i - t) + (u << 4) + j + 1] = (__half)param_4[u].data_f[j + 1]; + _doubled_buffer[_buf_index][(i - t) + (u << 4) + j + 2] = (__half)param_4[u].data_f[j + 2]; + _doubled_buffer[_buf_index][(i - t) + (u << 4) + j + 3] = (__half)param_4[u].data_f[j + 3]; + } + } } - momntum_4[0] = _mm512_mul_ps(momntum_4[0], betta1_4); - momntum_4[0] = _mm512_fmadd_ps(grad_4[0], betta1_minus1_4, momntum_4[0]); - momntum_4[1] = _mm512_mul_ps(momntum_4[1], betta1_4); - momntum_4[1] = _mm512_fmadd_ps(grad_4[1], betta1_minus1_4, momntum_4[1]); - momntum_4[2] = _mm512_mul_ps(momntum_4[2], betta1_4); - momntum_4[2] = _mm512_fmadd_ps(grad_4[2], betta1_minus1_4, momntum_4[2]); - momntum_4[3] = _mm512_mul_ps(momntum_4[3], betta1_4); - momntum_4[3] = _mm512_fmadd_ps(grad_4[3], betta1_minus1_4, momntum_4[3]); - - varianc_4[0] = _mm512_mul_ps(varianc_4[0], betta2_4); - varianc_4[1] = _mm512_mul_ps(varianc_4[1], betta2_4); - varianc_4[2] = _mm512_mul_ps(varianc_4[2], betta2_4); - varianc_4[3] = _mm512_mul_ps(varianc_4[3], betta2_4); - grad_4[0] = _mm512_mul_ps(grad_4[0], grad_4[0]); - grad_4[1] = _mm512_mul_ps(grad_4[1], grad_4[1]); - grad_4[2] = _mm512_mul_ps(grad_4[2], grad_4[2]); - grad_4[3] = 
_mm512_mul_ps(grad_4[3], grad_4[3]); - varianc_4[0] = _mm512_fmadd_ps(grad_4[0], betta2_minus1_4, varianc_4[0]); - varianc_4[1] = _mm512_fmadd_ps(grad_4[1], betta2_minus1_4, varianc_4[1]); - varianc_4[2] = _mm512_fmadd_ps(grad_4[2], betta2_minus1_4, varianc_4[2]); - varianc_4[3] = _mm512_fmadd_ps(grad_4[3], betta2_minus1_4, varianc_4[3]); - - grad_4[0] = _mm512_sqrt_ps(varianc_4[0]) / bias2_sqrt; - grad_4[1] = _mm512_sqrt_ps(varianc_4[1]) / bias2_sqrt; - grad_4[2] = _mm512_sqrt_ps(varianc_4[2]) / bias2_sqrt; - grad_4[3] = _mm512_sqrt_ps(varianc_4[3]) / bias2_sqrt; - - grad_4[0] = _mm512_add_ps(grad_4[0], eps_4); - grad_4[1] = _mm512_add_ps(grad_4[1], eps_4); - grad_4[2] = _mm512_add_ps(grad_4[2], eps_4); - grad_4[3] = _mm512_add_ps(grad_4[3], eps_4); - grad_4[0] = _mm512_div_ps(momntum_4[0], grad_4[0]); - grad_4[1] = _mm512_div_ps(momntum_4[1], grad_4[1]); - grad_4[2] = _mm512_div_ps(momntum_4[2], grad_4[2]); - grad_4[3] = _mm512_div_ps(momntum_4[3], grad_4[3]); - - param_4[0] = _mm512_fmadd_ps(grad_4[0], step_size_4, param_4[0]); - param_4[1] = _mm512_fmadd_ps(grad_4[1], step_size_4, param_4[1]); - param_4[2] = _mm512_fmadd_ps(grad_4[2], step_size_4, param_4[2]); - param_4[3] = _mm512_fmadd_ps(grad_4[3], step_size_4, param_4[3]); - - _mm512_storeu_ps(_params + i, param_4[0]); - _mm512_storeu_ps(_params + i + SIMD_WIDTH, param_4[1]); - _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 1), param_4[2]); - _mm512_storeu_ps(_params + i + SIMD_WIDTH * 3, param_4[3]); - - _mm512_storeu_ps(_exp_avg + i, momntum_4[0]); - _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH, momntum_4[1]); - _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 1), momntum_4[2]); - _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 3, momntum_4[3]); - - _mm512_storeu_ps(_exp_avg_sq + i, varianc_4[0]); - _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH, varianc_4[1]); - _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1), varianc_4[2]); - _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3]); + 
_mm512_storeu_ps(_params + i, param_4[0].data); + _mm512_storeu_ps(_params + i + SIMD_WIDTH, param_4[1].data); + _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 1), param_4[2].data); + _mm512_storeu_ps(_params + i + SIMD_WIDTH * 3, param_4[3].data); + + _mm512_storeu_ps(_exp_avg + i, momntum_4[0].data); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH, momntum_4[1].data); + _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 1), momntum_4[2].data); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 3, momntum_4[3].data); + + _mm512_storeu_ps(_exp_avg_sq + i, varianc_4[0].data); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH, varianc_4[1].data); + _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1), varianc_4[2].data); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3].data); } - if (dev_params) { + if (dev_params) {/* #pragma omp parallel for for (size_t j = 0; j < copy_size; j += 4) { - _doubled_buffer[buf_index][j] = (__half)_params[t + j]; - _doubled_buffer[buf_index][j + 1] = (__half)_params[t + j + 1]; - _doubled_buffer[buf_index][j + 2] = (__half)_params[t + j + 2]; - _doubled_buffer[buf_index][j + 3] = (__half)_params[t + j + 3]; - } + _doubled_buffer[_buf_index][j] = (__half)_params[t + j]; + _doubled_buffer[_buf_index][j + 1] = (__half)_params[t + j + 1]; + _doubled_buffer[_buf_index][j + 2] = (__half)_params[t + j + 2]; + _doubled_buffer[_buf_index][j + 3] = (__half)_params[t + j + 3]; + }*/ CUDA_CHECK(cudaMemcpyAsync(dev_params + t, - _doubled_buffer[buf_index], + _doubled_buffer[_buf_index], copy_size * sizeof(__half), cudaMemcpyHostToDevice, Context::Instance().GetCurrentStream())); - buf_index = !buf_index; + _buf_index = !_buf_index; } } + if(_param_size > rounded_size) + Step((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? 
(dev_params + rounded_size) : dev_params)); } int create_adam_optimizer(int optimizer_id, @@ -278,207 +361,237 @@ void Adam_Optimizer::Step_8(float* _params, _betta1_t *= _betta1; _betta2_t *= _betta2; - __m512 betta1_4 = _mm512_set1_ps(_betta1); - __m512 betta2_4 = _mm512_set1_ps(_betta2); - - bool buf_index = 0; + AVX_512 betta1_4; + betta1_4.data = _mm512_set1_ps(_betta1); + AVX_512 betta2_4; + betta2_4.data = _mm512_set1_ps(_betta2); float betta1_minus1 = 1 - _betta1; float betta2_minus1 = 1 - _betta2; - __m512 betta1_minus1_4 = _mm512_set1_ps(betta1_minus1); - __m512 betta2_minus1_4 = _mm512_set1_ps(betta2_minus1); + AVX_512 betta1_minus1_4; + betta1_minus1_4.data = _mm512_set1_ps(betta1_minus1); + AVX_512 betta2_minus1_4; + betta2_minus1_4.data = _mm512_set1_ps(betta2_minus1); float bias_correction1 = 1 - _betta1_t; - float bias_correction2 = 1 - _betta2_t; - //__m512 bias_correction1_4 = _mm512_set1_ps(bias_correction1); - __m512 bias_correction2_4 = _mm512_set1_ps(bias_correction2); + float bias_correction2 = 1 / sqrt(1 - _betta2_t); + //AVX_512 bias_correction1_4 = _mm512_set1_ps(bias_correction1); + AVX_512 bias2_sqrt ; + bias2_sqrt.data = _mm512_set1_ps(bias_correction2); - __m512 eps_4 = _mm512_set1_ps(_eps); + AVX_512 eps_4; + eps_4.data = _mm512_set1_ps(_eps); float step_size = -1 * _alpha / bias_correction1; - __m512 step_size_4 = _mm512_set1_ps(step_size); + AVX_512 step_size_4; + step_size_4.data = _mm512_set1_ps(step_size); - __m512 bias2_sqrt = _mm512_sqrt_ps(bias_correction2_4); + size_t rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 3)); - for (size_t t = 0; t < _param_size; t += TILE) { + for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; - if ((t + TILE) > _param_size) copy_size = _param_size - t; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; size_t offset = copy_size + t; #pragma omp parallel for for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) { - __m512 grad_4[8]; - grad_4[0] = 
_mm512_loadu_ps(grads + i); - grad_4[1] = _mm512_loadu_ps(grads + i + SIMD_WIDTH); - grad_4[2] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 1)); - grad_4[3] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 3); - grad_4[4] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 2)); - grad_4[5] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 5); - grad_4[6] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 6); - grad_4[7] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 7); - - __m512 momntum_4[8]; - momntum_4[0] = _mm512_loadu_ps(_exp_avg + i); - momntum_4[1] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH); - momntum_4[2] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 1)); - momntum_4[3] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 3); - momntum_4[4] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 2)); - momntum_4[5] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 5); - momntum_4[6] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 6); - momntum_4[7] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 7); - - __m512 varianc_4[8]; - varianc_4[0] = _mm512_loadu_ps(_exp_avg_sq + i); - varianc_4[1] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH); - varianc_4[2] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1)); - varianc_4[3] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3); - varianc_4[4] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 2)); - varianc_4[5] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 5); - varianc_4[6] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 6); - varianc_4[7] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7); - - __m512 param_4[8]; - param_4[0] = _mm512_loadu_ps(_params + i); - param_4[1] = _mm512_loadu_ps(_params + i + SIMD_WIDTH); - param_4[2] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 1)); - param_4[3] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 3); - param_4[4] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 2)); - param_4[5] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 5); - param_4[6] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 6); - param_4[7] = 
_mm512_loadu_ps(_params + i + SIMD_WIDTH * 7); + AVX_512 grad_4[8]; + grad_4[0].data = _mm512_loadu_ps(grads + i); + grad_4[1].data = _mm512_loadu_ps(grads + i + SIMD_WIDTH); + grad_4[2].data = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 1)); + grad_4[3].data = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 3); + grad_4[4].data = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 2)); + grad_4[5].data = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 5); + grad_4[6].data = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 6); + grad_4[7].data = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 7); + + AVX_512 momntum_4[8]; + momntum_4[0].data = _mm512_loadu_ps(_exp_avg + i); + momntum_4[1].data = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH); + momntum_4[2].data = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 1)); + momntum_4[3].data = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 3); + momntum_4[4].data = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 2)); + momntum_4[5].data = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 5); + momntum_4[6].data = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 6); + momntum_4[7].data = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 7); + + AVX_512 varianc_4[8]; + varianc_4[0].data = _mm512_loadu_ps(_exp_avg_sq + i); + varianc_4[1].data = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH); + varianc_4[2].data = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1)); + varianc_4[3].data = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3); + varianc_4[4].data = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 2)); + varianc_4[5].data = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 5); + varianc_4[6].data = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 6); + varianc_4[7].data = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7); + + AVX_512 param_4[8]; + param_4[0].data = _mm512_loadu_ps(_params + i); + param_4[1].data = _mm512_loadu_ps(_params + i + SIMD_WIDTH); + param_4[2].data = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 1)); + param_4[3].data = _mm512_loadu_ps(_params + i + SIMD_WIDTH 
* 3); + param_4[4].data = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 2)); + param_4[5].data = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 5); + param_4[6].data = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 6); + param_4[7].data = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 7); if (_weight_decay > 0) { - __m512 weight_decay4 = _mm512_set1_ps(_weight_decay); - grad_4[0] = _mm512_fmadd_ps(param_4[0], weight_decay4, grad_4[0]); - grad_4[1] = _mm512_fmadd_ps(param_4[1], weight_decay4, grad_4[1]); - grad_4[2] = _mm512_fmadd_ps(param_4[2], weight_decay4, grad_4[2]); - grad_4[3] = _mm512_fmadd_ps(param_4[3], weight_decay4, grad_4[3]); - grad_4[4] = _mm512_fmadd_ps(param_4[4], weight_decay4, grad_4[4]); - grad_4[5] = _mm512_fmadd_ps(param_4[5], weight_decay4, grad_4[5]); - grad_4[6] = _mm512_fmadd_ps(param_4[6], weight_decay4, grad_4[6]); - grad_4[7] = _mm512_fmadd_ps(param_4[7], weight_decay4, grad_4[7]); + AVX_512 weight_decay4; + weight_decay4.data = _mm512_set1_ps(_weight_decay); + grad_4[0].data = _mm512_fmadd_ps(param_4[0].data, weight_decay4.data, grad_4[0].data); + grad_4[1].data = _mm512_fmadd_ps(param_4[1].data, weight_decay4.data, grad_4[1].data); + grad_4[2].data = _mm512_fmadd_ps(param_4[2].data, weight_decay4.data, grad_4[2].data); + grad_4[3].data = _mm512_fmadd_ps(param_4[3].data, weight_decay4.data, grad_4[3].data); + grad_4[4].data = _mm512_fmadd_ps(param_4[4].data, weight_decay4.data, grad_4[4].data); + grad_4[5].data = _mm512_fmadd_ps(param_4[5].data, weight_decay4.data, grad_4[5].data); + grad_4[6].data = _mm512_fmadd_ps(param_4[6].data, weight_decay4.data, grad_4[6].data); + grad_4[7].data = _mm512_fmadd_ps(param_4[7].data, weight_decay4.data, grad_4[7].data); + } + + momntum_4[0].data = _mm512_mul_ps(momntum_4[0].data, betta1_4.data); + momntum_4[0].data = _mm512_fmadd_ps(grad_4[0].data, betta1_minus1_4.data, momntum_4[0].data); + momntum_4[1].data = _mm512_mul_ps(momntum_4[1].data, betta1_4.data); + momntum_4[1].data = _mm512_fmadd_ps(grad_4[1].data, 
betta1_minus1_4.data, momntum_4[1].data); + momntum_4[2].data = _mm512_mul_ps(momntum_4[2].data, betta1_4.data); + momntum_4[2].data = _mm512_fmadd_ps(grad_4[2].data, betta1_minus1_4.data, momntum_4[2].data); + momntum_4[3].data = _mm512_mul_ps(momntum_4[3].data, betta1_4.data); + momntum_4[3].data = _mm512_fmadd_ps(grad_4[3].data, betta1_minus1_4.data, momntum_4[3].data); + momntum_4[4].data = _mm512_mul_ps(momntum_4[4].data, betta1_4.data); + momntum_4[4].data = _mm512_fmadd_ps(grad_4[4].data, betta1_minus1_4.data, momntum_4[4].data); + momntum_4[5].data = _mm512_mul_ps(momntum_4[5].data, betta1_4.data); + momntum_4[5].data = _mm512_fmadd_ps(grad_4[5].data, betta1_minus1_4.data, momntum_4[5].data); + momntum_4[6].data = _mm512_mul_ps(momntum_4[6].data, betta1_4.data); + momntum_4[6].data = _mm512_fmadd_ps(grad_4[6].data, betta1_minus1_4.data, momntum_4[6].data); + momntum_4[7].data = _mm512_mul_ps(momntum_4[7].data, betta1_4.data); + momntum_4[7].data = _mm512_fmadd_ps(grad_4[7].data, betta1_minus1_4.data, momntum_4[7].data); + + varianc_4[0].data = _mm512_mul_ps(varianc_4[0].data, betta2_4.data); + varianc_4[1].data = _mm512_mul_ps(varianc_4[1].data, betta2_4.data); + varianc_4[2].data = _mm512_mul_ps(varianc_4[2].data, betta2_4.data); + varianc_4[3].data = _mm512_mul_ps(varianc_4[3].data, betta2_4.data); + varianc_4[4].data = _mm512_mul_ps(varianc_4[4].data, betta2_4.data); + varianc_4[5].data = _mm512_mul_ps(varianc_4[5].data, betta2_4.data); + varianc_4[6].data = _mm512_mul_ps(varianc_4[6].data, betta2_4.data); + varianc_4[7].data = _mm512_mul_ps(varianc_4[7].data, betta2_4.data); + grad_4[0].data = _mm512_mul_ps(grad_4[0].data, grad_4[0].data); + grad_4[1].data = _mm512_mul_ps(grad_4[1].data, grad_4[1].data); + grad_4[2].data = _mm512_mul_ps(grad_4[2].data, grad_4[2].data); + grad_4[3].data = _mm512_mul_ps(grad_4[3].data, grad_4[3].data); + grad_4[4].data = _mm512_mul_ps(grad_4[4].data, grad_4[4].data); + grad_4[5].data = _mm512_mul_ps(grad_4[5].data, 
grad_4[5].data); + grad_4[6].data = _mm512_mul_ps(grad_4[6].data, grad_4[6].data); + grad_4[7].data = _mm512_mul_ps(grad_4[7].data, grad_4[7].data); + varianc_4[0].data = _mm512_fmadd_ps(grad_4[0].data, betta2_minus1_4.data, varianc_4[0].data); + varianc_4[1].data = _mm512_fmadd_ps(grad_4[1].data, betta2_minus1_4.data, varianc_4[1].data); + varianc_4[2].data = _mm512_fmadd_ps(grad_4[2].data, betta2_minus1_4.data, varianc_4[2].data); + varianc_4[3].data = _mm512_fmadd_ps(grad_4[3].data, betta2_minus1_4.data, varianc_4[3].data); + varianc_4[4].data = _mm512_fmadd_ps(grad_4[4].data, betta2_minus1_4.data, varianc_4[4].data); + varianc_4[5].data = _mm512_fmadd_ps(grad_4[5].data, betta2_minus1_4.data, varianc_4[5].data); + varianc_4[6].data = _mm512_fmadd_ps(grad_4[6].data, betta2_minus1_4.data, varianc_4[6].data); + varianc_4[7].data = _mm512_fmadd_ps(grad_4[7].data, betta2_minus1_4.data, varianc_4[7].data); + + grad_4[0].data = _mm512_sqrt_ps(varianc_4[0].data); + grad_4[1].data = _mm512_sqrt_ps(varianc_4[1].data); + grad_4[2].data = _mm512_sqrt_ps(varianc_4[2].data); + grad_4[3].data = _mm512_sqrt_ps(varianc_4[3].data); + grad_4[4].data = _mm512_sqrt_ps(varianc_4[4].data); + grad_4[5].data = _mm512_sqrt_ps(varianc_4[5].data); + grad_4[6].data = _mm512_sqrt_ps(varianc_4[6].data); + grad_4[7].data = _mm512_sqrt_ps(varianc_4[7].data); + + grad_4[0].data = _mm512_fmadd_ps(grad_4[0].data, bias2_sqrt.data, eps_4.data); + grad_4[1].data = _mm512_fmadd_ps(grad_4[1].data, bias2_sqrt.data, eps_4.data); + grad_4[2].data = _mm512_fmadd_ps(grad_4[2].data, bias2_sqrt.data, eps_4.data); + grad_4[3].data = _mm512_fmadd_ps(grad_4[3].data, bias2_sqrt.data, eps_4.data); + grad_4[4].data = _mm512_fmadd_ps(grad_4[4].data, bias2_sqrt.data, eps_4.data); + grad_4[5].data = _mm512_fmadd_ps(grad_4[5].data, bias2_sqrt.data, eps_4.data); + grad_4[6].data = _mm512_fmadd_ps(grad_4[6].data, bias2_sqrt.data, eps_4.data); + grad_4[7].data = _mm512_fmadd_ps(grad_4[7].data, bias2_sqrt.data, 
eps_4.data); + grad_4[0].data = _mm512_div_ps(momntum_4[0].data, grad_4[0].data); + grad_4[1].data = _mm512_div_ps(momntum_4[1].data, grad_4[1].data); + grad_4[2].data = _mm512_div_ps(momntum_4[2].data, grad_4[2].data); + grad_4[3].data = _mm512_div_ps(momntum_4[3].data, grad_4[3].data); + grad_4[4].data = _mm512_div_ps(momntum_4[4].data, grad_4[4].data); + grad_4[5].data = _mm512_div_ps(momntum_4[5].data, grad_4[5].data); + grad_4[6].data = _mm512_div_ps(momntum_4[6].data, grad_4[6].data); + grad_4[7].data = _mm512_div_ps(momntum_4[7].data, grad_4[7].data); + + param_4[0].data = _mm512_fmadd_ps(grad_4[0].data, step_size_4.data, param_4[0].data); + param_4[1].data = _mm512_fmadd_ps(grad_4[1].data, step_size_4.data, param_4[1].data); + param_4[2].data = _mm512_fmadd_ps(grad_4[2].data, step_size_4.data, param_4[2].data); + param_4[3].data = _mm512_fmadd_ps(grad_4[3].data, step_size_4.data, param_4[3].data); + param_4[4].data = _mm512_fmadd_ps(grad_4[4].data, step_size_4.data, param_4[4].data); + param_4[5].data = _mm512_fmadd_ps(grad_4[5].data, step_size_4.data, param_4[5].data); + param_4[6].data = _mm512_fmadd_ps(grad_4[6].data, step_size_4.data, param_4[6].data); + param_4[7].data = _mm512_fmadd_ps(grad_4[7].data, step_size_4.data, param_4[7].data); + + _mm512_storeu_ps(_params + i, param_4[0].data); + _mm512_storeu_ps(_params + i + SIMD_WIDTH, param_4[1].data); + _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 1), param_4[2].data); + _mm512_storeu_ps(_params + i + SIMD_WIDTH * 3, param_4[3].data); + _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 2), param_4[4].data); + _mm512_storeu_ps(_params + i + SIMD_WIDTH * 5, param_4[5].data); + _mm512_storeu_ps(_params + i + SIMD_WIDTH * 6, param_4[6].data); + _mm512_storeu_ps(_params + i + SIMD_WIDTH * 7, param_4[7].data); + + //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t), param_4[0]); + //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1]); + 
//_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1), param_4[2]); + //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3]); + //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 2), param_4[4]); + //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 5, param_4[5]); + //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 6, param_4[6]); + //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 7, param_4[7]); + if (dev_params) { + for(int u = 0;u < 8;u++) + { + for (size_t j = 0; j < SIMD_WIDTH; j += 4) { + _doubled_buffer[_buf_index][(i - t) + (u << 4) + (j << 2)] = (__half)param_4[u].data_f[(j << 2)]; + _doubled_buffer[_buf_index][(i - t) + (u << 4) + (j << 2) + 1] = (__half)param_4[u].data_f[(j << 2) + 1]; + _doubled_buffer[_buf_index][(i - t) + (u << 4) + (j << 2) + 2] = (__half)param_4[u].data_f[(j << 2) + 2]; + _doubled_buffer[_buf_index][(i - t) + (u << 4) + (j << 2) + 3] = (__half)param_4[u].data_f[(j << 2) + 3]; + } + } } - momntum_4[0] = _mm512_mul_ps(momntum_4[0], betta1_4); - momntum_4[0] = _mm512_fmadd_ps(grad_4[0], betta1_minus1_4, momntum_4[0]); - momntum_4[1] = _mm512_mul_ps(momntum_4[1], betta1_4); - momntum_4[1] = _mm512_fmadd_ps(grad_4[1], betta1_minus1_4, momntum_4[1]); - momntum_4[2] = _mm512_mul_ps(momntum_4[2], betta1_4); - momntum_4[2] = _mm512_fmadd_ps(grad_4[2], betta1_minus1_4, momntum_4[2]); - momntum_4[3] = _mm512_mul_ps(momntum_4[3], betta1_4); - momntum_4[3] = _mm512_fmadd_ps(grad_4[3], betta1_minus1_4, momntum_4[3]); - momntum_4[4] = _mm512_mul_ps(momntum_4[4], betta1_4); - momntum_4[4] = _mm512_fmadd_ps(grad_4[4], betta1_minus1_4, momntum_4[4]); - momntum_4[5] = _mm512_mul_ps(momntum_4[5], betta1_4); - momntum_4[5] = _mm512_fmadd_ps(grad_4[5], betta1_minus1_4, momntum_4[5]); - momntum_4[6] = _mm512_mul_ps(momntum_4[6], betta1_4); - momntum_4[6] = _mm512_fmadd_ps(grad_4[6], betta1_minus1_4, momntum_4[6]); 
- momntum_4[7] = _mm512_mul_ps(momntum_4[7], betta1_4); - momntum_4[7] = _mm512_fmadd_ps(grad_4[7], betta1_minus1_4, momntum_4[7]); - - varianc_4[0] = _mm512_mul_ps(varianc_4[0], betta2_4); - varianc_4[1] = _mm512_mul_ps(varianc_4[1], betta2_4); - varianc_4[2] = _mm512_mul_ps(varianc_4[2], betta2_4); - varianc_4[3] = _mm512_mul_ps(varianc_4[3], betta2_4); - varianc_4[4] = _mm512_mul_ps(varianc_4[4], betta2_4); - varianc_4[5] = _mm512_mul_ps(varianc_4[5], betta2_4); - varianc_4[6] = _mm512_mul_ps(varianc_4[6], betta2_4); - varianc_4[7] = _mm512_mul_ps(varianc_4[7], betta2_4); - grad_4[0] = _mm512_mul_ps(grad_4[0], grad_4[0]); - grad_4[1] = _mm512_mul_ps(grad_4[1], grad_4[1]); - grad_4[2] = _mm512_mul_ps(grad_4[2], grad_4[2]); - grad_4[3] = _mm512_mul_ps(grad_4[3], grad_4[3]); - grad_4[4] = _mm512_mul_ps(grad_4[4], grad_4[4]); - grad_4[5] = _mm512_mul_ps(grad_4[5], grad_4[5]); - grad_4[6] = _mm512_mul_ps(grad_4[6], grad_4[6]); - grad_4[7] = _mm512_mul_ps(grad_4[7], grad_4[7]); - varianc_4[0] = _mm512_fmadd_ps(grad_4[0], betta2_minus1_4, varianc_4[0]); - varianc_4[1] = _mm512_fmadd_ps(grad_4[1], betta2_minus1_4, varianc_4[1]); - varianc_4[2] = _mm512_fmadd_ps(grad_4[2], betta2_minus1_4, varianc_4[2]); - varianc_4[3] = _mm512_fmadd_ps(grad_4[3], betta2_minus1_4, varianc_4[3]); - varianc_4[4] = _mm512_fmadd_ps(grad_4[4], betta2_minus1_4, varianc_4[4]); - varianc_4[5] = _mm512_fmadd_ps(grad_4[5], betta2_minus1_4, varianc_4[5]); - varianc_4[6] = _mm512_fmadd_ps(grad_4[6], betta2_minus1_4, varianc_4[6]); - varianc_4[7] = _mm512_fmadd_ps(grad_4[7], betta2_minus1_4, varianc_4[7]); - - grad_4[0] = _mm512_sqrt_ps(varianc_4[0]) / bias2_sqrt; - grad_4[1] = _mm512_sqrt_ps(varianc_4[1]) / bias2_sqrt; - grad_4[2] = _mm512_sqrt_ps(varianc_4[2]) / bias2_sqrt; - grad_4[3] = _mm512_sqrt_ps(varianc_4[3]) / bias2_sqrt; - grad_4[4] = _mm512_sqrt_ps(varianc_4[4]) / bias2_sqrt; - grad_4[5] = _mm512_sqrt_ps(varianc_4[5]) / bias2_sqrt; - grad_4[6] = _mm512_sqrt_ps(varianc_4[6]) / bias2_sqrt; 
- grad_4[7] = _mm512_sqrt_ps(varianc_4[7]) / bias2_sqrt; - - grad_4[0] = _mm512_add_ps(grad_4[0], eps_4); - grad_4[1] = _mm512_add_ps(grad_4[1], eps_4); - grad_4[2] = _mm512_add_ps(grad_4[2], eps_4); - grad_4[3] = _mm512_add_ps(grad_4[3], eps_4); - grad_4[4] = _mm512_add_ps(grad_4[4], eps_4); - grad_4[5] = _mm512_add_ps(grad_4[5], eps_4); - grad_4[6] = _mm512_add_ps(grad_4[6], eps_4); - grad_4[7] = _mm512_add_ps(grad_4[7], eps_4); - grad_4[0] = _mm512_div_ps(momntum_4[0], grad_4[0]); - grad_4[1] = _mm512_div_ps(momntum_4[1], grad_4[1]); - grad_4[2] = _mm512_div_ps(momntum_4[2], grad_4[2]); - grad_4[3] = _mm512_div_ps(momntum_4[3], grad_4[3]); - grad_4[4] = _mm512_div_ps(momntum_4[4], grad_4[4]); - grad_4[5] = _mm512_div_ps(momntum_4[5], grad_4[5]); - grad_4[6] = _mm512_div_ps(momntum_4[6], grad_4[6]); - grad_4[7] = _mm512_div_ps(momntum_4[7], grad_4[7]); - - param_4[0] = _mm512_fmadd_ps(grad_4[0], step_size_4, param_4[0]); - param_4[1] = _mm512_fmadd_ps(grad_4[1], step_size_4, param_4[1]); - param_4[2] = _mm512_fmadd_ps(grad_4[2], step_size_4, param_4[2]); - param_4[3] = _mm512_fmadd_ps(grad_4[3], step_size_4, param_4[3]); - param_4[4] = _mm512_fmadd_ps(grad_4[4], step_size_4, param_4[4]); - param_4[5] = _mm512_fmadd_ps(grad_4[5], step_size_4, param_4[5]); - param_4[6] = _mm512_fmadd_ps(grad_4[6], step_size_4, param_4[6]); - param_4[7] = _mm512_fmadd_ps(grad_4[7], step_size_4, param_4[7]); - - _mm512_storeu_ps(_params + i, param_4[0]); - _mm512_storeu_ps(_params + i + SIMD_WIDTH, param_4[1]); - _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 1), param_4[2]); - _mm512_storeu_ps(_params + i + SIMD_WIDTH * 3, param_4[3]); - _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 2), param_4[4]); - _mm512_storeu_ps(_params + i + SIMD_WIDTH * 5, param_4[5]); - _mm512_storeu_ps(_params + i + SIMD_WIDTH * 6, param_4[6]); - _mm512_storeu_ps(_params + i + SIMD_WIDTH * 7, param_4[7]); - - _mm512_storeu_ps(_doubled_buffer[buf_index] + (i - t), param_4[0]); - 
_mm512_storeu_ps(_doubled_buffer[buf_index] + (i - t) + SIMD_WIDTH, param_4[1]); - _mm512_storeu_ps(_doubled_buffer[buf_index] + (i - t) + (SIMD_WIDTH << 1), param_4[2]); - _mm512_storeu_ps(_doubled_buffer[buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3]); - _mm512_storeu_ps(_doubled_buffer[buf_index] + (i - t) + (SIMD_WIDTH << 2), param_4[4]); - _mm512_storeu_ps(_doubled_buffer[buf_index] + (i - t) + SIMD_WIDTH * 5, param_4[5]); - _mm512_storeu_ps(_doubled_buffer[buf_index] + (i - t) + SIMD_WIDTH * 6, param_4[6]); - _mm512_storeu_ps(_doubled_buffer[buf_index] + (i - t) + SIMD_WIDTH * 7, param_4[7]); - - _mm512_storeu_ps(_exp_avg + i, momntum_4[0]); - _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH, momntum_4[1]); - _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 1), momntum_4[2]); - _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 3, momntum_4[3]); - _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 2), momntum_4[4]); - _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 5, momntum_4[5]); - _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 6, momntum_4[6]); - _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 7, momntum_4[7]); - - _mm512_storeu_ps(_exp_avg_sq + i, varianc_4[0]); - _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH, varianc_4[1]); - _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1), varianc_4[2]); - _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3]); - _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 2), varianc_4[4]); - _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 5, varianc_4[5]); - _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 6, varianc_4[6]); - _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7, varianc_4[7]); + _mm512_storeu_ps(_exp_avg + i, momntum_4[0].data); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH, momntum_4[1].data); + _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 1), momntum_4[2].data); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 3, momntum_4[3].data); + _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 2), momntum_4[4].data); + 
_mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 5, momntum_4[5].data); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 6, momntum_4[6].data); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 7, momntum_4[7].data); + + _mm512_storeu_ps(_exp_avg_sq + i, varianc_4[0].data); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH, varianc_4[1].data); + _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1), varianc_4[2].data); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3].data); + _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 2), varianc_4[4].data); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 5, varianc_4[5].data); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 6, varianc_4[6].data); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7, varianc_4[7].data); } if (dev_params) { - launch_param_update(_doubled_buffer[buf_index], + /*launch_param_update(_doubled_buffer[_buf_index], dev_params + t, copy_size, Context::Instance().GetCurrentStream()); - buf_index = !buf_index; + _buf_index = !_buf_index;*/ + CUDA_CHECK(cudaMemcpyAsync(dev_params + t, + _doubled_buffer[_buf_index], + copy_size * sizeof(__half), + cudaMemcpyHostToDevice, + Context::Instance().GetCurrentStream())); + _buf_index = !_buf_index; } } + if(_param_size > rounded_size) + Step_4((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? 
(dev_params + rounded_size) : dev_params)); } int ds_adam_step(int optimizer_id, diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h old mode 100644 new mode 100755 index 40f4cba692ea..f56eb1501aa4 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -7,6 +7,7 @@ #include "cublas_v2.h" #include "cuda.h" #include "curand.h" +#include #define CUDA_CHECK(callstr) \ { \ @@ -34,7 +35,8 @@ class Adam_Optimizer { _eps(eps), _weight_decay(weight_decay), _betta1_t(1.0), - _betta2_t(1.0) + _betta2_t(1.0), + _buf_index(false) { cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(__half)); cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(__half)); @@ -64,6 +66,12 @@ class Adam_Optimizer { __half* dev_params = nullptr); private: + + union AVX_512{ + __m512 data; + float data_f[16]; + }; + float _alpha; float _betta1; float _betta2; @@ -74,4 +82,5 @@ class Adam_Optimizer { float _betta2_t; __half* _doubled_buffer[2]; + bool _buf_index; }; diff --git a/tests/unit/test_adam_acuracy.py b/tests/unit/test_adam_acuracy.py index e616e140bdde..f61b6ecba58b 100755 --- a/tests/unit/test_adam_acuracy.py +++ b/tests/unit/test_adam_acuracy.py @@ -10,9 +10,6 @@ def check_equal(first, second, atol=1e-2, verbose=False): - if verbose: - print(first) - print(second) x = first.detach().numpy() y = second.detach().numpy() if verbose: @@ -23,6 +20,11 @@ def check_equal(first, second, atol=1e-2, verbose=False): @pytest.mark.parametrize('model_size', [ + (64), + (22), + (55), + (127), + (1024), (1048576), ]) # yapf: disable def test_adam_opt(model_size):