From abdb08e9f9b13d6ec4da0873490b0e39003c3580 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 13 Jul 2023 05:43:33 -0400 Subject: [PATCH 01/17] fused adam can build --- accelerator/cpu_accelerator.py | 6 +- csrc/cpu/adam/cpu_adam.h | 201 ++++++++++++++++++++++++++ csrc/cpu/adam/fused_adam.cpp | 253 +++++++++++++++++++++++++++++++++ csrc/cpu/adam/simd.h | 140 ++++++++++++++++++ op_builder/cpu/__init__.py | 1 + op_builder/cpu/fused_adam.py | 25 ++++ 6 files changed, 624 insertions(+), 2 deletions(-) create mode 100644 csrc/cpu/adam/cpu_adam.h create mode 100644 csrc/cpu/adam/fused_adam.cpp create mode 100644 csrc/cpu/adam/simd.h create mode 100644 op_builder/cpu/fused_adam.py diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index b1fb036fed07..3a041b52f3e4 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -245,12 +245,14 @@ def get_op_builder(self, class_name): # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed # if successful this also means we're doing a local install and not JIT compile path from op_builder import __deepspeed__ # noqa: F401 - from op_builder.cpu import CCLCommBuilder, NotImplementedBuilder + from op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, NotImplementedBuilder except ImportError: - from deepspeed.ops.op_builder.cpu import CCLCommBuilder, NotImplementedBuilder + from deepspeed.ops.op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, NotImplementedBuilder if class_name == "CCLCommBuilder": return CCLCommBuilder + elif class_name == "FusedAdamBuilder": + return FusedAdamBuilder else: # return a NotImplementedBuilder to avoid get NoneType[Name] in unit tests return NotImplementedBuilder diff --git a/csrc/cpu/adam/cpu_adam.h b/csrc/cpu/adam/cpu_adam.h new file mode 100644 index 000000000000..8b540ca98384 --- /dev/null +++ b/csrc/cpu/adam/cpu_adam.h @@ -0,0 +1,201 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#define NOMINMAX // Windows idiosyncrasy + // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c + +#include +#include +#include "simd.h" + +#include +typedef unsigned short ds_half_precision_t; + +#define STEP(SPAN) \ + void Step_##SPAN(float* _params, \ + float* grads, \ + float* _exp_avg, \ + float* _exp_avg_sq, \ + size_t _param_size, \ + ds_half_precision_t* dev_param = nullptr, \ + bool half_precision = false); + +class Adam_Optimizer { +public: + Adam_Optimizer(float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true) + : _alpha(alpha), + _betta1(betta1), + _betta2(betta2), + _eps(eps), + _weight_decay(weight_decay), + _betta1_t(1.0), + _betta2_t(1.0), + _step(0), + _adamw_mode(adamw_mode) + { + } + ~Adam_Optimizer() + { + } + +#if defined(__AVX512__) or defined(__AVX256__) + template + void Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t param_size, + ds_half_precision_t* dev_param = nullptr, + bool half_precision = false); +#endif + STEP(1) + STEP(4) + STEP(8) + inline void IncrementStep(size_t step, float beta1, float beta2) + { + if (beta1 != _betta1 || beta2 != _betta2) { + _step = step; + _betta1 = beta1; + _betta2 = beta2; + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + } else { + _step++; + if (_step != step) { + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } else { + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } + } + } + inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction) + { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + + _bias_correction1 = 1.0f; + _bias_correction2 = 1.0f; + if (bias_correction == 1) { + _bias_correction1 = 1 - _betta1_t; + _bias_correction2 = 1 / sqrt(1 - _betta2_t); + } + } + +private: + float _alpha; + float _betta1; + float _betta2; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; + + float _bias_correction1; + float _bias_correction2; + + bool _adamw_mode; + +}; + +#if defined(__AVX512__) or defined(__AVX256__) +template +void Adam_Optimizer::Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + ds_half_precision_t* dev_params, + bool half_precision) +{ + size_t new_rounded_size = 0; + int rshft = half_precision ? 1 : 0; + + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay4; + if (_weight_decay > 0) + weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); + for (size_t t = 0; t < new_rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { + AVX_Data grad_4[span]; + simd_load(grad_4, grads + (i >> rshft), half_precision); + + AVX_Data momentum_4[span]; + simd_load(momentum_4, _exp_avg + i, false); + + AVX_Data variance_4[span]; + simd_load(variance_4, _exp_avg_sq + i, false); + + AVX_Data param_4[span]; + simd_load(param_4, _params + (i >> rshft), half_precision); + + if (_weight_decay > 0 && !_adamw_mode) { + simd_fma(grad_4, param_4, weight_decay4, grad_4); + } + + simd_mul(momentum_4, momentum_4, betta1_4); + simd_fma(momentum_4, grad_4, betta1_minus1_4, momentum_4); + simd_mul(variance_4, variance_4, betta2_4); + simd_mul(grad_4, grad_4, grad_4); + simd_fma(variance_4, grad_4, betta2_minus1_4, variance_4); + simd_sqrt(grad_4, variance_4); + simd_fma(grad_4, grad_4, bias2_sqrt, eps_4); + simd_div(grad_4, momentum_4, grad_4); + + if (_weight_decay > 0 && _adamw_mode) { + simd_fma(param_4, param_4, weight_decay4, param_4); + } + + simd_fma(param_4, grad_4, step_size_4, param_4); + + simd_store(_params + (i >> rshft), param_4, half_precision); + simd_store(_exp_avg + i, momentum_4, false); + simd_store(_exp_avg_sq + i, variance_4, false); + } + } + *rounded_size = new_rounded_size; +} +#endif diff --git a/csrc/cpu/adam/fused_adam.cpp b/csrc/cpu/adam/fused_adam.cpp new file mode 100644 index 000000000000..c285cd3fef3b --- /dev/null +++ b/csrc/cpu/adam/fused_adam.cpp @@ -0,0 +1,253 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "cpu_adam.h" +#include +#include +#include +#include +#include +#include + +static std::unordered_map> s_optimizers; + +// C++ interface + +void Adam_Optimizer::Step_1(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + ds_half_precision_t* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<1>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + dev_params, + half_precision); +#endif + if (_param_size > rounded_size) { + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + ds_half_precision_t* grads_cast_h; + ds_half_precision_t* params_cast_h; + if (half_precision) { + grads_cast_h = reinterpret_cast(grads); + params_cast_h = reinterpret_cast(_params); + } + + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; + float param = half_precision ? (float)params_cast_h[k] : _params[k]; + float momentum = _exp_avg[k]; + float variance = _exp_avg_sq[k]; + if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } + momentum = momentum * _betta1; + momentum = grad * betta1_minus1 + momentum; + + variance = variance * _betta2; + grad = grad * grad; + variance = grad * betta2_minus1 + variance; + + grad = sqrt(variance); + grad = grad * _bias_correction2 + _eps; + grad = momentum / grad; + if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } + param = grad * step_size + param; + if (half_precision) + params_cast_h[k] = (ds_half_precision_t)param; + else + _params[k] = param; + _exp_avg[k] = momentum; + _exp_avg_sq[k] = variance; + } + } + } +} + +void Adam_Optimizer::Step_4(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + ds_half_precision_t* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<4>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + dev_params, + half_precision); +#endif + if (_param_size > rounded_size) + Step_1((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int create_adam_optimizer(int optimizer_id, + float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true, + bool should_log = false) +{ + auto opt = + std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); + + s_optimizers[optimizer_id] = opt; + + if (should_log) { + std::string avx_type = ""; +#if defined(__AVX512__) + avx_type = "AVX512"; +#else +#if defined(__AVX256__) + avx_type = "AVX2"; +#else + avx_type = "scalar"; +#endif +#endif + + printf("Adam Optimizer #%d is created with %s arithmetic capability.\n", + optimizer_id, + avx_type.c_str()); + printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n", + alpha, + betta1, + betta2, + weight_decay, + (int)adamw_mode); + } + + return 0; +} + +void Adam_Optimizer::Step_8(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + ds_half_precision_t* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<8>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + dev_params, + half_precision); +#endif + if (_param_size > rounded_size) + Step_4((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int ds_adam_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq) +{ + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + // assert(params.options().dtype() == grads.options().dtype()); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + exp_avg_sq_ptr, + params_c.numel(), + nullptr, + (params.options().dtype() == at::kHalf)); + + return 0; +} + +int destroy_adam_optimizer(int optimizer_id) +{ + s_optimizers.erase(optimizer_id); + + return 0; +} + +void multi_tensor_adam(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, + const float beta1, + const float beta2, + const float epsilon, + const int step, + const int mode, + const int bias_correction, + const float weight_decay) { +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + /* + m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); + m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); + m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); + */ + m.def("multi_tensor_adam", + &multi_tensor_adam, + "Compute and apply gradient update to parameters for Adam optimizer"); +} diff --git a/csrc/cpu/adam/simd.h b/csrc/cpu/adam/simd.h new file mode 100644 index 000000000000..712dd5b32e96 --- /dev/null +++ b/csrc/cpu/adam/simd.h @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#if (__x86_64__ || __i386__) +#include +#include +#endif + +#define TILE (128 * 1024 * 1024) +#if defined(__AVX512__) or defined(__AVX256__) + +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) + +#if defined(__AVX512__) +#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm512_loadu_ps(x) +#define SIMD_SET(x) _mm512_set1_ps(x) +#define SIMD_ADD(x, y) _mm512_add_ps(x, y) +#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm512_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm512_div_ps(x, y) +#define SIMD_WIDTH 16 + +#define SIMD_LOAD2(x, h) \ + ((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x)) +#define SIMD_STORE2(x, d, h) \ + ((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ + : _mm512_storeu_ps(x, d)) + +#define INTV __m256i +#elif defined(__AVX256__) +#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm256_loadu_ps(x) +#define SIMD_SET(x) _mm256_set1_ps(x) +#define SIMD_ADD(x, y) _mm256_add_ps(x, y) +#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm256_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm256_div_ps(x, y) +#define SIMD_WIDTH 8 +#define SIMD_LOAD2(x, h) \ + ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x)) + +#define SIMD_STORE2(x, d, h) \ + ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ + : _mm256_storeu_ps(x, d)) + +#define INTV __m128i +#endif + +union AVX_Data { +#if defined(__AVX512__) + __m512 data; +#elif defined(__AVX256__) + __m256 data; +#endif + // float data_f[16]; +}; + +template +inline void simd_store(float* dst, AVX_Data* src, bool half_precision) +{ + size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH); +#pragma unroll + for (size_t i = 0; i < span; ++i) { SIMD_STORE2(dst + width * i, src[i].data, half_precision); } +} +template +inline void simd_load(AVX_Data* dst, float* src, bool half_precision) +{ + size_t width = (half_precision ? 1 : SIMD_WIDTH); +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD2(src + width * i, half_precision); } +} +template +inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a[i].data); + } +} +template +inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data src_a) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a.data); + } +} +template +inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data* src_m_r, AVX_Data* src_a) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r[i].data, src_a[i].data); + } +} +template +inline void simd_sqrt(AVX_Data* dst, AVX_Data* src) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_SQRT(src[i].data); } +} +template +inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma unroll + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); } +} + +#endif diff --git a/op_builder/cpu/__init__.py b/op_builder/cpu/__init__.py index b3e854d37d5f..301edcd5b565 100644 --- a/op_builder/cpu/__init__.py +++ b/op_builder/cpu/__init__.py @@ -5,4 +5,5 @@ '''Copyright The Microsoft DeepSpeed Team''' from .comm import CCLCommBuilder +from .fused_adam import FusedAdamBuilder from .no_impl import NotImplementedBuilder diff --git a/op_builder/cpu/fused_adam.py b/op_builder/cpu/fused_adam.py new file mode 100644 index 000000000000..e11a81c76e7e --- /dev/null +++ b/op_builder/cpu/fused_adam.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CPUOpBuilder + +import sys + + +class FusedAdamBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_ADAM" + NAME = "fused_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/cpu/adam/fused_adam.cpp'] + + def include_paths(self): + return ['csrc/cpu/includes', 'csrc/cpu/adam'] From f8916365689d4899bfefcaba19cc488270d5693b Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Mon, 17 Jul 2023 22:56:20 -0400 Subject: [PATCH 02/17] use cpu adam to implement fused adam --- csrc/cpu/adam/fused_adam.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/csrc/cpu/adam/fused_adam.cpp b/csrc/cpu/adam/fused_adam.cpp index c285cd3fef3b..1fe10eccecec 100644 --- a/csrc/cpu/adam/fused_adam.cpp +++ b/csrc/cpu/adam/fused_adam.cpp @@ -229,7 +229,7 @@ int destroy_adam_optimizer(int optimizer_id) void multi_tensor_adam(int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, + std::vector> tensor_lists, /*gpmv*/ const float lr, const float beta1, const float beta2, @@ -238,6 +238,12 @@ void multi_tensor_adam(int chunk_size, const int mode, const int bias_correction, const float weight_decay) { + create_adam_optimizer(0); + for (int i = 0; i < tensor_lists[0].size(); i++) { + ds_adam_step(0, step, lr, beta1, beta2, epsilon, weight_decay, bias_correction, + tensor_lists[1][i], tensor_lists[0][i], tensor_lists[2][i], tensor_lists[3][i]); + } + destroy_adam_optimizer(0); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) From c7c2451bcbbf9b7b3e99c0667c03773c1263d34a Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Tue, 18 Jul 2023 05:03:38 -0400 Subject: [PATCH 03/17] enable zero stage 1 and 2 for synchronized accelerator (a.k.a. CPU) --- accelerator/cpu_accelerator.py | 4 ++-- deepspeed/runtime/utils.py | 9 +++++++++ deepspeed/runtime/zero/stage_1_and_2.py | 9 ++++----- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index 3a041b52f3e4..32a45f0f5874 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -86,8 +86,8 @@ def Stream(self): return None def stream(self, stream): - from deepspeed.runtime.utils import noop_decorator - return noop_decorator + from deepspeed.runtime.utils import noop_context + return noop_context() def current_stream(self, device_index=None): return None diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 0320f4c9779a..9b07812d47cf 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -51,6 +51,15 @@ def __init__(self, params): def noop_decorator(func): return func +class noop_context(object): + def __init__(self): + pass + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass def ensure_directory_exists(filename): """Create the directory path to ``filename`` if it does not already exist. diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index ec94fb99f0cb..7258a5d28eb9 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -374,10 +374,8 @@ def __init__(self, self.reduce_bucket_size = int(reduce_bucket_size) self.allgather_bucket_size = int(allgather_bucket_size) - self.reduction_event = get_accelerator().Event(enable_timing=False, blocking=False) - self.reduction_stream = get_accelerator().Stream() - self.cpu_computation_stream = get_accelerator().Stream() - self.copy_grad_stream = get_accelerator().Stream() + self.reduction_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream() + #self.copy_grad_stream = get_accelerator().Stream() self.callback_queued = False self.param_dict = {} @@ -904,7 +902,8 @@ def gradient_reduction_w_predivide(self, tensor): def average_tensor(self, tensor): if self.overlap_comm: stream = self.reduction_stream - stream.wait_stream(get_accelerator().current_stream()) + if not get_accelerator().is_synchronized_device(): + stream.wait_stream(get_accelerator().current_stream()) else: stream = get_accelerator().current_stream() From 49f5e415677c3468c5d2d090406ae8691c4ca928 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Tue, 18 Jul 2023 05:04:49 -0400 Subject: [PATCH 04/17] remove unused parameters --- csrc/cpu/adam/cpu_adam.h | 3 --- csrc/cpu/adam/fused_adam.cpp | 9 --------- 2 files changed, 12 deletions(-) diff --git a/csrc/cpu/adam/cpu_adam.h b/csrc/cpu/adam/cpu_adam.h index 8b540ca98384..5ebf400eaa6e 100644 --- a/csrc/cpu/adam/cpu_adam.h +++ b/csrc/cpu/adam/cpu_adam.h @@ -21,7 +21,6 @@ typedef unsigned short ds_half_precision_t; float* _exp_avg, \ float* _exp_avg_sq, \ size_t _param_size, \ - ds_half_precision_t* dev_param = nullptr, \ bool half_precision = false); class Adam_Optimizer { @@ -55,7 +54,6 @@ class Adam_Optimizer { float* _exp_avg, float* _exp_avg_sq, size_t param_size, - ds_half_precision_t* dev_param = nullptr, bool half_precision = false); #endif STEP(1) @@ -121,7 +119,6 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, float* _exp_avg, float* _exp_avg_sq, size_t _param_size, - ds_half_precision_t* dev_params, bool half_precision) { size_t new_rounded_size = 0; diff --git a/csrc/cpu/adam/fused_adam.cpp b/csrc/cpu/adam/fused_adam.cpp index 1fe10eccecec..e93d17d14954 100644 --- a/csrc/cpu/adam/fused_adam.cpp +++ b/csrc/cpu/adam/fused_adam.cpp @@ -20,7 +20,6 @@ void Adam_Optimizer::Step_1(float* _params, float* _exp_avg, float* _exp_avg_sq, size_t _param_size, - ds_half_precision_t* dev_params, bool half_precision) { size_t rounded_size = 0; @@ -31,7 +30,6 @@ void Adam_Optimizer::Step_1(float* _params, _exp_avg, _exp_avg_sq, _param_size, - dev_params, half_precision); #endif if (_param_size > rounded_size) { @@ -86,7 +84,6 @@ void Adam_Optimizer::Step_4(float* _params, float* _exp_avg, float* _exp_avg_sq, size_t _param_size, - ds_half_precision_t* dev_params, bool half_precision) { size_t rounded_size = 0; @@ -97,7 +94,6 @@ void Adam_Optimizer::Step_4(float* _params, _exp_avg, _exp_avg_sq, _param_size, - dev_params, half_precision); #endif if (_param_size > rounded_size) @@ -106,7 +102,6 @@ void Adam_Optimizer::Step_4(float* _params, (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), half_precision); } @@ -155,7 +150,6 @@ void Adam_Optimizer::Step_8(float* _params, float* _exp_avg, float* _exp_avg_sq, size_t _param_size, - ds_half_precision_t* dev_params, bool half_precision) { size_t rounded_size = 0; @@ -166,7 +160,6 @@ void Adam_Optimizer::Step_8(float* _params, _exp_avg, _exp_avg_sq, _param_size, - dev_params, half_precision); #endif if (_param_size > rounded_size) @@ -175,7 +168,6 @@ void Adam_Optimizer::Step_8(float* _params, (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), half_precision); } @@ -214,7 +206,6 @@ int ds_adam_step(int optimizer_id, exp_avg_ptr, exp_avg_sq_ptr, params_c.numel(), - nullptr, (params.options().dtype() == at::kHalf)); return 0; From 13bc33ad55323796a2043527c7b8d42e957037f2 Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Wed, 19 Jul 2023 08:52:40 +0800 Subject: [PATCH 05/17] fix format error --- csrc/cpu/adam/cpu_adam.h | 17 +++++----- csrc/cpu/adam/fused_adam.cpp | 61 ++++++++++++++++-------------------- deepspeed/runtime/utils.py | 3 ++ 3 files changed, 37 insertions(+), 44 deletions(-) diff --git a/csrc/cpu/adam/cpu_adam.h b/csrc/cpu/adam/cpu_adam.h index 5ebf400eaa6e..12892beb86b0 100644 --- a/csrc/cpu/adam/cpu_adam.h +++ b/csrc/cpu/adam/cpu_adam.h @@ -15,12 +15,12 @@ #include typedef unsigned short ds_half_precision_t; -#define STEP(SPAN) \ - void Step_##SPAN(float* _params, \ - float* grads, \ - float* _exp_avg, \ - float* _exp_avg_sq, \ - size_t _param_size, \ +#define STEP(SPAN) \ + void Step_##SPAN(float* _params, \ + float* grads, \ + float* _exp_avg, \ + float* _exp_avg_sq, \ + size_t _param_size, \ bool half_precision = false); class Adam_Optimizer { @@ -42,9 +42,7 @@ class Adam_Optimizer { _adamw_mode(adamw_mode) { } - ~Adam_Optimizer() - { - } + ~Adam_Optimizer() {} #if defined(__AVX512__) or defined(__AVX256__) template @@ -108,7 +106,6 @@ class Adam_Optimizer { float _bias_correction2; bool _adamw_mode; - }; #if defined(__AVX512__) or defined(__AVX256__) diff --git a/csrc/cpu/adam/fused_adam.cpp b/csrc/cpu/adam/fused_adam.cpp index e93d17d14954..bcadc1df0753 100644 --- a/csrc/cpu/adam/fused_adam.cpp +++ b/csrc/cpu/adam/fused_adam.cpp @@ -3,13 +3,13 @@ // DeepSpeed Team -#include "cpu_adam.h" #include #include #include #include #include #include +#include "cpu_adam.h" static std::unordered_map> s_optimizers; @@ -24,13 +24,7 @@ void Adam_Optimizer::Step_1(float* _params, { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - half_precision); + Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, half_precision); #endif if (_param_size > rounded_size) { float betta1_minus1 = 1 - _betta1; @@ -88,13 +82,7 @@ void Adam_Optimizer::Step_4(float* _params, { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - half_precision); + Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, half_precision); #endif if (_param_size > rounded_size) Step_1((_params + rounded_size), @@ -154,13 +142,7 @@ void Adam_Optimizer::Step_8(float* _params, { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - half_precision); + Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, half_precision); #endif if (_param_size > rounded_size) Step_4((_params + rounded_size), @@ -219,20 +201,31 @@ int destroy_adam_optimizer(int optimizer_id) } void multi_tensor_adam(int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, /*gpmv*/ - const float lr, - const float beta1, - const float beta2, - const float epsilon, - const int step, - const int mode, - const int bias_correction, - const float weight_decay) { + at::Tensor noop_flag, + std::vector> tensor_lists, /*gpmv*/ + const float lr, + const float beta1, + const float beta2, + const float epsilon, + const int step, + const int mode, + const int bias_correction, + const float weight_decay) +{ create_adam_optimizer(0); for (int i = 0; i < tensor_lists[0].size(); i++) { - ds_adam_step(0, step, lr, beta1, beta2, epsilon, weight_decay, bias_correction, - tensor_lists[1][i], tensor_lists[0][i], tensor_lists[2][i], tensor_lists[3][i]); + ds_adam_step(0, + step, + lr, + beta1, + beta2, + epsilon, + weight_decay, + bias_correction, + tensor_lists[1][i], + tensor_lists[0][i], + tensor_lists[2][i], + tensor_lists[3][i]); } destroy_adam_optimizer(0); } diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 9b07812d47cf..973264f901fe 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -51,7 +51,9 @@ def __init__(self, params): def noop_decorator(func): return func + class noop_context(object): + def __init__(self): pass @@ -61,6 +63,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): pass + def ensure_directory_exists(filename): """Create the directory path to ``filename`` if it does not already exist. From 6f9d839d958f98524451cbe512a915fa667d867c Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Tue, 18 Jul 2023 23:19:13 -0400 Subject: [PATCH 06/17] Remove adam class --- csrc/cpu/adam/cpu_adam.h | 145 ++++++++++++----------------------- csrc/cpu/adam/fused_adam.cpp | 141 +++++++++++++++------------------- 2 files changed, 109 insertions(+), 177 deletions(-) diff --git a/csrc/cpu/adam/cpu_adam.h b/csrc/cpu/adam/cpu_adam.h index 12892beb86b0..3d366b5a505c 100644 --- a/csrc/cpu/adam/cpu_adam.h +++ b/csrc/cpu/adam/cpu_adam.h @@ -15,34 +15,21 @@ #include typedef unsigned short ds_half_precision_t; -#define STEP(SPAN) \ - void Step_##SPAN(float* _params, \ - float* grads, \ - float* _exp_avg, \ - float* _exp_avg_sq, \ - size_t _param_size, \ - bool half_precision = false); - -class Adam_Optimizer { -public: - Adam_Optimizer(float alpha = 1e-3, - float betta1 = 0.9, - float betta2 = 0.999, - float eps = 1e-8, - float weight_decay = 0, - bool adamw_mode = true) - : _alpha(alpha), - _betta1(betta1), - _betta2(betta2), - _eps(eps), - _weight_decay(weight_decay), - _betta1_t(1.0), - _betta2_t(1.0), - _step(0), - _adamw_mode(adamw_mode) - { - } - ~Adam_Optimizer() {} +#define STEP(SPAN) \ + void Step_##SPAN(float* _params, \ + float* grads, \ + float* _exp_avg, \ + float* _exp_avg_sq, \ + size_t _param_size, \ + float lr, \ + float betta1, \ + float betta2, \ + float eps, \ + float weight_decay, \ + float bias_correction1, \ + float bias_correction2, \ + bool half_precision = false, \ + bool adam_mode = true); #if defined(__AVX512__) or defined(__AVX256__) template @@ -52,101 +39,67 @@ class Adam_Optimizer { float* _exp_avg, float* _exp_avg_sq, size_t param_size, - bool half_precision = false); + float lr, + float betta1, + float betta2, + float eps, + float weight_decay, + float bias_correction1, + float bias_correction2, + bool half_precision, + bool adamw_mode); #endif - STEP(1) - STEP(4) - STEP(8) - inline void IncrementStep(size_t step, float beta1, float beta2) - { - if (beta1 != _betta1 || beta2 != _betta2) { - _step = step; - _betta1 = beta1; - _betta2 = beta2; - _betta1_t = std::pow(_betta1, step); - _betta2_t = std::pow(_betta2, step); - } else { - _step++; - if (_step != step) { - _betta1_t = std::pow(_betta1, step); - _betta2_t = std::pow(_betta2, step); - _step = step; - } else { - _betta1_t *= _betta1; - _betta2_t *= _betta2; - } - } - } - inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction) - { - _alpha = lr; - _eps = epsilon; - _weight_decay = weight_decay; - - _bias_correction1 = 1.0f; - _bias_correction2 = 1.0f; - if (bias_correction == 1) { - _bias_correction1 = 1 - _betta1_t; - _bias_correction2 = 1 / sqrt(1 - _betta2_t); - } - } - -private: - float _alpha; - float _betta1; - float _betta2; - float _eps; - float _weight_decay; - - float _betta1_t; - float _betta2_t; - size_t _step; - - float _bias_correction1; - float _bias_correction2; - - bool _adamw_mode; -}; +STEP(1) +STEP(4) +STEP(8) #if defined(__AVX512__) or defined(__AVX256__) template -void Adam_Optimizer::Step_AVX(size_t* rounded_size, +void Step_AVX(size_t* rounded_size, float* _params, float* grads, float* _exp_avg, float* _exp_avg_sq, size_t _param_size, - bool half_precision) + float lr, + float betta1, + float betta2, + float eps, + float weight_decay, + float bias_correction1, + float bias_correction2, + bool half_precision, + bool adamw_mode) { size_t new_rounded_size = 0; int rshft = half_precision ? 1 : 0; AVX_Data betta1_4; - betta1_4.data = SIMD_SET(_betta1); + betta1_4.data = SIMD_SET(betta1); AVX_Data betta2_4; - betta2_4.data = SIMD_SET(_betta2); + betta2_4.data = SIMD_SET(betta2); - float betta1_minus1 = 1 - _betta1; - float betta2_minus1 = 1 - _betta2; + float betta1_minus1 = 1 - betta1; + float betta2_minus1 = 1 - betta2; AVX_Data betta1_minus1_4; betta1_minus1_4.data = SIMD_SET(betta1_minus1); AVX_Data betta2_minus1_4; betta2_minus1_4.data = SIMD_SET(betta2_minus1); AVX_Data bias2_sqrt; - bias2_sqrt.data = SIMD_SET(_bias_correction2); + bias2_sqrt.data = SIMD_SET(bias_correction2); AVX_Data eps_4; - eps_4.data = SIMD_SET(_eps); + eps_4.data = SIMD_SET(eps); - float step_size = -1 * _alpha / _bias_correction1; + float step_size = -1 * lr / bias_correction1; AVX_Data step_size_4; step_size_4.data = SIMD_SET(step_size); - float w_decay = -1 * _alpha * _weight_decay; + float w_decay = -1 * lr * weight_decay; AVX_Data weight_decay4; - if (_weight_decay > 0) - weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + if (weight_decay > 0) + weight_decay4.data = (adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(weight_decay)); new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); for (size_t t = 0; t < new_rounded_size; t += TILE) { size_t copy_size = TILE; @@ -166,7 +119,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, AVX_Data param_4[span]; simd_load(param_4, _params + (i >> rshft), half_precision); - if (_weight_decay > 0 && !_adamw_mode) { + if (weight_decay > 0 && !adamw_mode) { simd_fma(grad_4, param_4, weight_decay4, grad_4); } @@ -179,7 +132,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, simd_fma(grad_4, grad_4, bias2_sqrt, eps_4); simd_div(grad_4, momentum_4, grad_4); - if (_weight_decay > 0 && _adamw_mode) { + if (weight_decay > 0 && adamw_mode) { simd_fma(param_4, param_4, weight_decay4, param_4); } diff --git a/csrc/cpu/adam/fused_adam.cpp b/csrc/cpu/adam/fused_adam.cpp index bcadc1df0753..5f05f5f77469 100644 --- a/csrc/cpu/adam/fused_adam.cpp +++ b/csrc/cpu/adam/fused_adam.cpp @@ -15,23 +15,31 @@ static std::unordered_map> s_optimizers; // C++ interface -void Adam_Optimizer::Step_1(float* _params, +void Step_1(float* _params, float* grads, float* _exp_avg, float* _exp_avg_sq, size_t _param_size, - bool half_precision) + float lr, + float betta1, + float betta2, + float eps, + float weight_decay, + float bias_correction1, + float bias_correction2, + bool half_precision, + bool adamw_mode) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, half_precision); + Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, lr, betta1, betta2, eps, weight_decay, bias_correction1, bias_correction2, half_precision, adamw_mode); #endif if (_param_size > rounded_size) { - float betta1_minus1 = 1 - _betta1; - float betta2_minus1 = 1 - _betta2; + float betta1_minus1 = 1 - betta1; + float betta2_minus1 = 1 - betta2; - float step_size = -1 * _alpha / _bias_correction1; - float w_decay = -1 * _alpha * _weight_decay; + float step_size = -1 * lr / bias_correction1; + float w_decay = -1 * lr * weight_decay; ds_half_precision_t* grads_cast_h; ds_half_precision_t* params_cast_h; if (half_precision) { @@ -49,18 +57,18 @@ void Adam_Optimizer::Step_1(float* _params, float param = half_precision ? (float)params_cast_h[k] : _params[k]; float momentum = _exp_avg[k]; float variance = _exp_avg_sq[k]; - if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } - momentum = momentum * _betta1; + if (weight_decay > 0 && !adamw_mode) { grad = param * weight_decay + grad; } + momentum = momentum * betta1; momentum = grad * betta1_minus1 + momentum; - variance = variance * _betta2; + variance = variance * betta2; grad = grad * grad; variance = grad * betta2_minus1 + variance; grad = sqrt(variance); - grad = grad * _bias_correction2 + _eps; + grad = grad * bias_correction2 + eps; grad = momentum / grad; - if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } + if (weight_decay > 0 && adamw_mode) { param += w_decay * param; } param = grad * step_size + param; if (half_precision) params_cast_h[k] = (ds_half_precision_t)param; @@ -73,16 +81,24 @@ void Adam_Optimizer::Step_1(float* _params, } } -void Adam_Optimizer::Step_4(float* _params, +void Step_4(float* _params, float* grads, float* _exp_avg, float* _exp_avg_sq, size_t _param_size, - bool half_precision) + float lr, + float betta1, + float betta2, + float eps, + float weight_decay, + float bias_correction1, + float bias_correction2, + bool half_precision, + bool adamw_mode) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, half_precision); + Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, lr, betta1, betta2, eps, weight_decay, bias_correction1, bias_correction2, half_precision, adamw_mode); #endif if (_param_size > rounded_size) Step_1((_params + rounded_size), @@ -90,59 +106,29 @@ void Adam_Optimizer::Step_4(float* _params, (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), (_param_size - rounded_size), - half_precision); + lr, betta1, betta2, eps, weight_decay, + bias_correction1, bias_correction2, + half_precision, adamw_mode); } -int create_adam_optimizer(int optimizer_id, - float alpha = 1e-3, - float betta1 = 0.9, - float betta2 = 0.999, - float eps = 1e-8, - float weight_decay = 0, - bool adamw_mode = true, - bool should_log = false) -{ - auto opt = - std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); - - s_optimizers[optimizer_id] = opt; - - if (should_log) { - std::string avx_type = ""; -#if defined(__AVX512__) - avx_type = "AVX512"; -#else -#if defined(__AVX256__) - avx_type = "AVX2"; -#else - avx_type = "scalar"; -#endif -#endif - - printf("Adam Optimizer #%d is created with %s arithmetic capability.\n", - optimizer_id, - avx_type.c_str()); - printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n", - alpha, - betta1, - betta2, - weight_decay, - (int)adamw_mode); - } - - return 0; -} - -void Adam_Optimizer::Step_8(float* _params, +void Step_8(float* _params, float* grads, float* _exp_avg, float* _exp_avg_sq, size_t _param_size, - bool half_precision) + float lr, + float betta1, + float betta2, + float eps, + float weight_decay, + float bias_correction1, + float bias_correction2, + bool half_precision, + bool adamw_mode) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, half_precision); + Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, lr, betta1, betta2, eps, weight_decay, bias_correction1, bias_correction2, half_precision, adamw_mode); #endif if (_param_size > rounded_size) Step_4((_params + rounded_size), @@ -150,7 +136,9 @@ void Adam_Optimizer::Step_8(float* _params, (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), (_param_size - rounded_size), - half_precision); + lr, betta1, betta2, eps, weight_decay, + bias_correction1, bias_correction2, + half_precision, adamw_mode); } int ds_adam_step(int optimizer_id, @@ -161,6 +149,7 @@ int ds_adam_step(int optimizer_id, float epsilon, float weight_decay, bool bias_correction, + bool adam_mode, torch::Tensor& params, torch::Tensor& grads, torch::Tensor& exp_avg, @@ -178,24 +167,20 @@ int ds_adam_step(int optimizer_id, float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step, beta1, beta2); - opt->update_state(lr, epsilon, weight_decay, bias_correction); - - opt->Step_8(params_ptr, + float bias_correction1 = 1.0f, bias_correction2= 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1.0 - std::pow(beta1, step); + bias_correction2 = 1 / sqrt(1.0 - std::pow(beta2, step)); + } + Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.numel(), - (params.options().dtype() == at::kHalf)); - - return 0; -} - -int destroy_adam_optimizer(int optimizer_id) -{ - s_optimizers.erase(optimizer_id); + lr, beta1, beta2, epsilon, weight_decay, + bias_correction1, bias_correction2, + (params.options().dtype() == at::kHalf), + adam_mode); return 0; } @@ -212,7 +197,6 @@ void multi_tensor_adam(int chunk_size, const int bias_correction, const float weight_decay) { - create_adam_optimizer(0); for (int i = 0; i < tensor_lists[0].size(); i++) { ds_adam_step(0, step, @@ -222,21 +206,16 @@ void multi_tensor_adam(int chunk_size, epsilon, weight_decay, bias_correction, + mode, tensor_lists[1][i], tensor_lists[0][i], tensor_lists[2][i], tensor_lists[3][i]); } - destroy_adam_optimizer(0); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - /* - m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); - m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); - m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); - */ m.def("multi_tensor_adam", &multi_tensor_adam, "Compute and apply gradient update to parameters for Adam optimizer"); From 618936ff131a3c028eef5cb5e927edff16c8228b Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 19 Jul 2023 14:47:48 +0800 Subject: [PATCH 07/17] fix format --- csrc/cpu/adam/cpu_adam.h | 60 ++++++------ csrc/cpu/adam/fused_adam.cpp | 173 +++++++++++++++++++++++------------ 2 files changed, 146 insertions(+), 87 deletions(-) diff --git a/csrc/cpu/adam/cpu_adam.h b/csrc/cpu/adam/cpu_adam.h index 3d366b5a505c..c288f8d5a49f 100644 --- a/csrc/cpu/adam/cpu_adam.h +++ b/csrc/cpu/adam/cpu_adam.h @@ -32,22 +32,22 @@ typedef unsigned short ds_half_precision_t; bool adam_mode = true); #if defined(__AVX512__) or defined(__AVX256__) - template - void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t param_size, - float lr, - float betta1, - float betta2, - float eps, - float weight_decay, - float bias_correction1, - float bias_correction2, - bool half_precision, - bool adamw_mode); +template +void Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t param_size, + float lr, + float betta1, + float betta2, + float eps, + float weight_decay, + float bias_correction1, + float bias_correction2, + bool half_precision, + bool adamw_mode); #endif STEP(1) STEP(4) @@ -56,20 +56,20 @@ STEP(8) #if defined(__AVX512__) or defined(__AVX256__) template void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - float lr, - float betta1, - float betta2, - float eps, - float weight_decay, - float bias_correction1, - float bias_correction2, - bool half_precision, - bool adamw_mode) + float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + float lr, + float betta1, + float betta2, + float eps, + float weight_decay, + float bias_correction1, + float bias_correction2, + bool half_precision, + bool adamw_mode) { size_t new_rounded_size = 0; int rshft = half_precision ? 1 : 0; diff --git a/csrc/cpu/adam/fused_adam.cpp b/csrc/cpu/adam/fused_adam.cpp index 5f05f5f77469..cae8ad362126 100644 --- a/csrc/cpu/adam/fused_adam.cpp +++ b/csrc/cpu/adam/fused_adam.cpp @@ -16,23 +16,37 @@ static std::unordered_map> s_optimizers; // C++ interface void Step_1(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - float lr, - float betta1, - float betta2, - float eps, - float weight_decay, - float bias_correction1, - float bias_correction2, - bool half_precision, - bool adamw_mode) + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + float lr, + float betta1, + float betta2, + float eps, + float weight_decay, + float bias_correction1, + float bias_correction2, + bool half_precision, + bool adamw_mode) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, lr, betta1, betta2, eps, weight_decay, bias_correction1, bias_correction2, half_precision, adamw_mode); + Step_AVX<1>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + lr, + betta1, + betta2, + eps, + weight_decay, + bias_correction1, + bias_correction2, + half_precision, + adamw_mode); #endif if (_param_size > rounded_size) { float betta1_minus1 = 1 - betta1; @@ -82,23 +96,37 @@ void Step_1(float* _params, } void Step_4(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - float lr, - float betta1, - float betta2, - float eps, - float weight_decay, - float bias_correction1, - float bias_correction2, - bool half_precision, - bool adamw_mode) + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + float lr, + float betta1, + float betta2, + float eps, + float weight_decay, + float bias_correction1, + float bias_correction2, + bool half_precision, + bool adamw_mode) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, lr, betta1, betta2, eps, weight_decay, bias_correction1, bias_correction2, half_precision, adamw_mode); + Step_AVX<4>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + lr, + betta1, + betta2, + eps, + weight_decay, + bias_correction1, + bias_correction2, + half_precision, + adamw_mode); #endif if (_param_size > rounded_size) Step_1((_params + rounded_size), @@ -106,29 +134,49 @@ void Step_4(float* _params, (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), (_param_size - rounded_size), - lr, betta1, betta2, eps, weight_decay, - bias_correction1, bias_correction2, - half_precision, adamw_mode); + lr, + betta1, + betta2, + eps, + weight_decay, + bias_correction1, + bias_correction2, + half_precision, + adamw_mode); } void Step_8(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - float lr, - float betta1, - float betta2, - float eps, - float weight_decay, - float bias_correction1, - float bias_correction2, - bool half_precision, - bool adamw_mode) + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + float lr, + float betta1, + float betta2, + float eps, + float weight_decay, + float bias_correction1, + float bias_correction2, + bool half_precision, + bool adamw_mode) { size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, _params, grads, _exp_avg, _exp_avg_sq, _param_size, lr, betta1, betta2, eps, weight_decay, bias_correction1, bias_correction2, half_precision, adamw_mode); + Step_AVX<8>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + lr, + betta1, + betta2, + eps, + weight_decay, + bias_correction1, + bias_correction2, + half_precision, + adamw_mode); #endif if (_param_size > rounded_size) Step_4((_params + rounded_size), @@ -136,9 +184,15 @@ void Step_8(float* _params, (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), (_param_size - rounded_size), - lr, betta1, betta2, eps, weight_decay, - bias_correction1, bias_correction2, - half_precision, adamw_mode); + lr, + betta1, + betta2, + eps, + weight_decay, + bias_correction1, + bias_correction2, + half_precision, + adamw_mode); } int ds_adam_step(int optimizer_id, @@ -167,20 +221,25 @@ int ds_adam_step(int optimizer_id, float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - float bias_correction1 = 1.0f, bias_correction2= 1.0f; + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; if (bias_correction == 1) { bias_correction1 = 1.0 - std::pow(beta1, step); bias_correction2 = 1 / sqrt(1.0 - std::pow(beta2, step)); } Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - exp_avg_sq_ptr, - params_c.numel(), - lr, beta1, beta2, epsilon, weight_decay, - bias_correction1, bias_correction2, - (params.options().dtype() == at::kHalf), - adam_mode); + grads_ptr, + exp_avg_ptr, + exp_avg_sq_ptr, + params_c.numel(), + lr, + beta1, + beta2, + epsilon, + weight_decay, + bias_correction1, + bias_correction2, + (params.options().dtype() == at::kHalf), + adam_mode); return 0; } From 10f8b3846b833ba95b5ac2f0d2271fcf5c74f1d2 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 19 Jul 2023 04:56:08 -0400 Subject: [PATCH 08/17] support stage3 --- accelerator/cpu_accelerator.py | 2 +- csrc/cpu/comm/ccl.cpp | 1 + deepspeed/runtime/zero/parameter_offload.py | 2 +- .../runtime/zero/partition_parameters.py | 6 ++- .../zero/partitioned_param_coordinator.py | 10 +++-- deepspeed/runtime/zero/stage3.py | 38 ++++++++++++------- 6 files changed, 37 insertions(+), 22 deletions(-) diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index 035b9b6b4740..f9228fb506c6 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -139,7 +139,7 @@ def reset_max_memory_cached(self, device_index=None): return def memory_stats(self, device_index=None): - return self.get_rss() + return None def reset_peak_memory_stats(self, device_index=None): self.reset_rss() diff --git a/csrc/cpu/comm/ccl.cpp b/csrc/cpu/comm/ccl.cpp index 0f974fb859e0..77ae3ea9d70e 100644 --- a/csrc/cpu/comm/ccl.cpp +++ b/csrc/cpu/comm/ccl.cpp @@ -98,6 +98,7 @@ ccl::datatype get_ccl_datatype(c10::ScalarType type) ccl::datatype ccl_type; switch (type) { case c10::ScalarType::Int: ccl_type = ccl::datatype::int32; break; + case c10::ScalarType::Long: ccl_type = ccl::datatype::int64; break; case c10::ScalarType::Float: ccl_type = ccl::datatype::float32; break; case c10::ScalarType::Double: ccl_type = ccl::datatype::float64; break; case c10::ScalarType::BFloat16: ccl_type = ccl::datatype::bfloat16; break; diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index d6ad32b16f99..9e29957782c1 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -247,7 +247,7 @@ def __init__(self, self._prefetch_bucket_sz = int(prefetch_bucket_size) self._max_reuse_distance_in_numel = int(max_reuse_distance) self._max_available_parameters_in_numel = int(max_live_parameters) - self.__allgather_stream = get_accelerator().Stream() if overlap_comm else get_accelerator().default_stream() + self.__allgather_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream() if overlap_comm else get_accelerator().default_stream() if not hasattr(module, "ds_inflight_param_registry"): module.ds_inflight_param_registry = dict() diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 1718a2af113a..a7ad2ce32823 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -273,7 +273,8 @@ def free_param(param: Parameter) -> None: if get_accelerator().on_accelerator(param.data): # need to make sure that we don't free the parameter while it is still # being used for computation - param.data.record_stream(get_accelerator().current_stream()) + if not get_accelerator().is_synchronized_device(): + param.data.record_stream(get_accelerator().current_stream()) # param.data doesn't store anything meaningful in partitioned state param.data = torch.empty(0, dtype=param.dtype, device=param.device) param.ds_status = ZeroParamStatus.NOT_AVAILABLE @@ -609,7 +610,8 @@ def wait(self) -> None: param.ds_status = ZeroParamStatus.AVAILABLE for part_to_copy in partitions: - part_to_copy.record_stream(get_accelerator().current_stream()) + if not get_accelerator().is_synchronized_device(): + part_to_copy.record_stream(get_accelerator().current_stream()) param_offset += ds_tensor_numel diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index b5d3e50a9c3e..666b127ce4d4 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -296,12 +296,14 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: self.__inflight_param_registry.pop(param).wait() - event = get_accelerator().Event() - event.record() - self.__ongoing_fetch_events.append(event) + if not get_accelerator().is_synchronized_device(): + event = get_accelerator().Event() + event.record() + self.__ongoing_fetch_events.append(event) assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() - get_accelerator().current_stream().wait_stream(self.__allgather_stream) + if not get_accelerator().is_synchronized_device(): + get_accelerator().current_stream().wait_stream(self.__allgather_stream) self.__profiler.stop_event(wait_event_name, wait_numel) # kick off parameter prefetches for upcoming modules diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index d8d25ffa8e29..454426e5396f 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -188,7 +188,7 @@ def __init__(self, self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu ### streams used for overlapping computation with communication - self.reduce_and_partition_stream = get_accelerator().Stream() if overlap_comm else get_accelerator( + self.reduce_and_partition_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream() if overlap_comm else get_accelerator( ).default_stream() ############################################################################ @@ -996,7 +996,8 @@ def independent_gradient_partition_epilogue(self): self.__reduce_and_partition_ipg_grads() self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) - self.reduce_and_partition_stream.synchronize() + if not get_accelerator().is_synchronized_device(): + self.reduce_and_partition_stream.synchronize() #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad #TODO: use a similar code path for both cpu_offload and non-cpu offload @@ -1083,7 +1084,8 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): @instrument_w_nvtx @torch.no_grad() def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: - self.reduce_and_partition_stream.wait_stream(get_accelerator().default_stream()) + if not get_accelerator().is_synchronized_device(): + self.reduce_and_partition_stream.wait_stream(get_accelerator().default_stream()) if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel() < self.reduce_bucket_size: # move the gradient to a contiguous buffer @@ -1092,7 +1094,8 @@ def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow(0, self.elements_in_ipg_bucket, param.grad.numel()).view_as(param.grad) new_grad_tensor.copy_(param.grad, non_blocking=True) - param.grad.record_stream(get_accelerator().current_stream()) + if not get_accelerator().is_synchronized_device(): + param.grad.record_stream(get_accelerator().current_stream()) param.grad.data = new_grad_tensor self.params_in_ipg_bucket.append(param) @@ -1130,9 +1133,10 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: self.params_in_ipg_bucket.clear() - event = get_accelerator().Event() - event.record() - self.param_reduce_events.append(event) + if not get_accelerator().is_synchronized_device(): + event = get_accelerator().Event() + event.record() + self.param_reduce_events.append(event) @instrument_w_nvtx def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tensor]: @@ -1317,7 +1321,8 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L fp32_grad_tensor.copy_(grad_buffer) # free the gradient - param.grad.record_stream(get_accelerator().current_stream()) + if not get_accelerator().is_synchronized_device(): + param.grad.record_stream(get_accelerator().current_stream()) param.grad = None if self.offload_optimizer and self.swap_optimizer: @@ -1690,8 +1695,9 @@ def _prepare_fp32_grad_for_sub_group(self, sub_group_id): # release all the gradient since we have already created a necessary copy in dp_grad_partition self.zero_grad(set_to_none=True) - for grad in filter(lambda g: get_accelerator().on_accelerator(g), self.averaged_gradients[sub_group_id]): - grad.record_stream(get_accelerator().current_stream()) + if not get_accelerator().is_synchronized_device(): + for grad in filter(lambda g: get_accelerator().on_accelerator(g), self.averaged_gradients[sub_group_id]): + grad.record_stream(get_accelerator().current_stream()) self.averaged_gradients[sub_group_id] = None @@ -1979,7 +1985,8 @@ def has_overflow(self, partition_gradients=True): overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8) self.inf_or_nan_tracker.zero_() - get_accelerator().default_stream().wait_stream(self.reduce_and_partition_stream) + if not get_accelerator().is_synchronized_device(): + get_accelerator().default_stream().wait_stream(self.reduce_and_partition_stream) dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group) else: @@ -2049,7 +2056,8 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: """get fp32 gradient partition dictionary accessed as grad_dict[parameter_group_index][parameter_index] """ - self.reduce_and_partition_stream.synchronize() + if not get_accelerator().is_synchronized_device(): + self.reduce_and_partition_stream.synchronize() grad_dict = collections.defaultdict(dict) if self.offload_optimizer: for group in self.fp16_groups: @@ -2083,7 +2091,8 @@ def get_fp32_grad_for_param(self, param) -> Tensor: if not param.requires_grad: return None - self.reduce_and_partition_stream.synchronize() + if not get_accelerator().is_synchronized_device(): + self.reduce_and_partition_stream.synchronize() if self.offload_optimizer: group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] @@ -2098,7 +2107,8 @@ def get_full_hp_param(self, param, optim_state_key=None) -> Tensor: if not param.requires_grad: return None - self.reduce_and_partition_stream.synchronize() + if not get_accelerator().is_synchronized_device(): + self.reduce_and_partition_stream.synchronize() group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] if self._swappable_optimizer_subgroup(group_idx): From 3c749bdd9c72db7e6dded32c38b9cc9972211547 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 19 Jul 2023 08:46:21 -0400 Subject: [PATCH 09/17] reuse simd.h --- csrc/cpu/adam/simd.h | 140 ----------------------------------- op_builder/cpu/fused_adam.py | 2 +- 2 files changed, 1 insertion(+), 141 deletions(-) delete mode 100644 csrc/cpu/adam/simd.h diff --git a/csrc/cpu/adam/simd.h b/csrc/cpu/adam/simd.h deleted file mode 100644 index 712dd5b32e96..000000000000 --- a/csrc/cpu/adam/simd.h +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// SPDX-License-Identifier: Apache-2.0 - -// DeepSpeed Team - -#pragma once - -#if (__x86_64__ || __i386__) -#include -#include -#endif - -#define TILE (128 * 1024 * 1024) -#if defined(__AVX512__) or defined(__AVX256__) - -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) - -#if defined(__AVX512__) -#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) -#define SIMD_LOAD(x) _mm512_loadu_ps(x) -#define SIMD_SET(x) _mm512_set1_ps(x) -#define SIMD_ADD(x, y) _mm512_add_ps(x, y) -#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) -#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) -#define SIMD_SQRT(x) _mm512_sqrt_ps(x) -#define SIMD_DIV(x, y) _mm512_div_ps(x, y) -#define SIMD_WIDTH 16 - -#define SIMD_LOAD2(x, h) \ - ((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x)) -#define SIMD_STORE2(x, d, h) \ - ((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ - : _mm512_storeu_ps(x, d)) - -#define INTV __m256i -#elif defined(__AVX256__) -#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) -#define SIMD_LOAD(x) _mm256_loadu_ps(x) -#define SIMD_SET(x) _mm256_set1_ps(x) -#define SIMD_ADD(x, y) _mm256_add_ps(x, y) -#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) -#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) -#define SIMD_SQRT(x) _mm256_sqrt_ps(x) -#define SIMD_DIV(x, y) _mm256_div_ps(x, y) -#define SIMD_WIDTH 8 -#define SIMD_LOAD2(x, h) \ - ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x)) - -#define SIMD_STORE2(x, d, h) \ - ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ - : _mm256_storeu_ps(x, d)) - -#define INTV __m128i -#endif - -union AVX_Data { -#if defined(__AVX512__) - __m512 data; -#elif defined(__AVX256__) - __m256 data; -#endif - // float data_f[16]; -}; - -template -inline void simd_store(float* dst, AVX_Data* src, bool half_precision) -{ - size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH); -#pragma unroll - for (size_t i = 0; i < span; ++i) { SIMD_STORE2(dst + width * i, src[i].data, half_precision); } -} -template -inline void simd_load(AVX_Data* dst, float* src, bool half_precision) -{ - size_t width = (half_precision ? 1 : SIMD_WIDTH); -#pragma unroll - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD2(src + width * i, half_precision); } -} -template -inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a) -{ -#pragma unroll - for (size_t i = 0; i < span; ++i) { - dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a[i].data); - } -} -template -inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data src_a) -{ -#pragma unroll - for (size_t i = 0; i < span; ++i) { - dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a.data); - } -} -template -inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data* src_m_r, AVX_Data* src_a) -{ -#pragma unroll - for (size_t i = 0; i < span; ++i) { - dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r[i].data, src_a[i].data); - } -} -template -inline void simd_sqrt(AVX_Data* dst, AVX_Data* src) -{ -#pragma unroll - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_SQRT(src[i].data); } -} -template -inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) -{ -#pragma unroll - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r.data); } -} -template -inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) -{ -#pragma unroll - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r[i].data); } -} -template -inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) -{ -#pragma unroll - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r.data); } -} -template -inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) -{ -#pragma unroll - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r[i].data); } -} -template -inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) -{ -#pragma unroll - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); } -} - -#endif diff --git a/op_builder/cpu/fused_adam.py b/op_builder/cpu/fused_adam.py index e11a81c76e7e..41d8ea56d124 100644 --- a/op_builder/cpu/fused_adam.py +++ b/op_builder/cpu/fused_adam.py @@ -22,4 +22,4 @@ def sources(self): return ['csrc/cpu/adam/fused_adam.cpp'] def include_paths(self): - return ['csrc/cpu/includes', 'csrc/cpu/adam'] + return ['csrc/includes', 'csrc/cpu/adam'] From 20a7d8d0017bf6a6f415413208bbc6b57b8b126a Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 19 Jul 2023 10:16:19 -0400 Subject: [PATCH 10/17] fix format --- op_builder/cpu/fused_adam.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/op_builder/cpu/fused_adam.py b/op_builder/cpu/fused_adam.py index 41d8ea56d124..a5ac257d3301 100644 --- a/op_builder/cpu/fused_adam.py +++ b/op_builder/cpu/fused_adam.py @@ -5,8 +5,6 @@ from .builder import CPUOpBuilder -import sys - class FusedAdamBuilder(CPUOpBuilder): BUILD_VAR = "DS_BUILD_FUSED_ADAM" From d77d67b386e676d4126835141b6d8f85068a4b1a Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 19 Jul 2023 21:36:42 -0400 Subject: [PATCH 11/17] make memory_stat return meaningful dict --- accelerator/cpu_accelerator.py | 6 +++++- deepspeed/runtime/zero/stage3.py | 8 +++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index f9228fb506c6..f3c6b8bab2e9 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -139,7 +139,11 @@ def reset_max_memory_cached(self, device_index=None): return def memory_stats(self, device_index=None): - return None + mem = self.get_rss() + mem_stat = {} + mem_stat['allocated_bytes.all.current'] = mem + mem_stat['allocated_bytes.all.peak'] = self.max_mem + return mem_stat def reset_peak_memory_stats(self, device_index=None): self.reset_rss() diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 454426e5396f..52e1ba12c8d6 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -188,8 +188,8 @@ def __init__(self, self.device = get_accelerator().current_device_name() if not self.offload_optimizer else OffloadDeviceEnum.cpu ### streams used for overlapping computation with communication - self.reduce_and_partition_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream() if overlap_comm else get_accelerator( - ).default_stream() + self.reduce_and_partition_stream = None if get_accelerator().is_synchronized_device() else get_accelerator( + ).Stream() if overlap_comm else get_accelerator().default_stream() ############################################################################ @@ -1898,7 +1898,9 @@ def step(self, closure=None): # warn user about caching allocator flushes memory_stats = get_accelerator().memory_stats() - alloc_retries = memory_stats["num_alloc_retries"] if memory_stats is not None else 0 + alloc_retries = memory_stats.get("num_alloc_retries") + if alloc_retries == None: + alloc_retries = 0 if alloc_retries > self.n_caching_allocator_flushes: if dist.get_rank() == 0: logger.warning( From 4745bdf96178a87a16cd69246b1a0199b1ab7411 Mon Sep 17 00:00:00 2001 From: Guokai Ma Date: Thu, 20 Jul 2023 23:30:46 +0800 Subject: [PATCH 12/17] fix format --- deepspeed/runtime/zero/parameter_offload.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 9e29957782c1..d89ad2436857 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -247,7 +247,8 @@ def __init__(self, self._prefetch_bucket_sz = int(prefetch_bucket_size) self._max_reuse_distance_in_numel = int(max_reuse_distance) self._max_available_parameters_in_numel = int(max_live_parameters) - self.__allgather_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream() if overlap_comm else get_accelerator().default_stream() + self.__allgather_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream( + ) if overlap_comm else get_accelerator().default_stream() if not hasattr(module, "ds_inflight_param_registry"): module.ds_inflight_param_registry = dict() From 4d30d4d0681cd702c9e06df9737f85ae918eb1b5 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Mon, 24 Jul 2023 22:14:08 -0400 Subject: [PATCH 13/17] add cpu_adam --- accelerator/cpu_accelerator.py | 6 ++++-- op_builder/cpu/__init__.py | 1 + op_builder/cpu/cpu_adam.py | 27 +++++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 op_builder/cpu/cpu_adam.py diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index 4bfb026f17aa..33320bf224b1 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -254,14 +254,16 @@ def get_op_builder(self, class_name): # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed # if successful this also means we're doing a local install and not JIT compile path from op_builder import __deepspeed__ # noqa: F401 - from op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, NotImplementedBuilder + from op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder except ImportError: - from deepspeed.ops.op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, NotImplementedBuilder + from deepspeed.ops.op_builder.cpu import CCLCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder if class_name == "CCLCommBuilder": return CCLCommBuilder elif class_name == "FusedAdamBuilder": return FusedAdamBuilder + elif class_name == "CPUAdamBuilder": + return CPUAdamBuilder else: # return a NotImplementedBuilder to avoid get NoneType[Name] in unit tests return NotImplementedBuilder diff --git a/op_builder/cpu/__init__.py b/op_builder/cpu/__init__.py index 301edcd5b565..0ecf2caf65ae 100644 --- a/op_builder/cpu/__init__.py +++ b/op_builder/cpu/__init__.py @@ -6,4 +6,5 @@ from .comm import CCLCommBuilder from .fused_adam import FusedAdamBuilder +from .cpu_adam import CPUAdamBuilder from .no_impl import NotImplementedBuilder diff --git a/op_builder/cpu/cpu_adam.py b/op_builder/cpu/cpu_adam.py new file mode 100644 index 000000000000..30343226f469 --- /dev/null +++ b/op_builder/cpu/cpu_adam.py @@ -0,0 +1,27 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CPUOpBuilder + + +class CPUAdamBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAM" + NAME = "cpu_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/adam/cpu_adam.cpp'] + + def libraries_args(self): + args = super().libraries_args() + return args + + def include_paths(self): + return ['csrc/includes'] From 06a89fc73f157d90de0d544f00a17f2630a82fe6 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Tue, 25 Jul 2023 04:10:10 -0400 Subject: [PATCH 14/17] reuse cpu_adam --- csrc/adam/cpu_adam.cpp | 279 ----------------------------------- csrc/cpu/adam/cpu_adam.h | 148 ------------------- csrc/cpu/adam/fused_adam.cpp | 237 +---------------------------- csrc/includes/cpu_adam.h | 23 +++ op_builder/cpu/cpu_adam.py | 2 +- op_builder/cpu/fused_adam.py | 4 +- op_builder/cpu_adam.py | 4 +- 7 files changed, 33 insertions(+), 664 deletions(-) delete mode 100644 csrc/cpu/adam/cpu_adam.h diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index 4d3d5a45e628..cbf20a2e9651 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -19,285 +19,6 @@ #include "custom_cuda_layers.h" #endif -static std::unordered_map> s_optimizers; - -// C++ interface - -void Adam_Optimizer::Step_1(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); -#endif - if (_param_size > rounded_size) { - float betta1_minus1 = 1 - _betta1; - float betta2_minus1 = 1 - _betta2; - - float step_size = -1 * _alpha / _bias_correction1; - float w_decay = -1 * _alpha * _weight_decay; - ds_half_precision_t* grads_cast_h; - ds_half_precision_t* params_cast_h; - if (half_precision) { - grads_cast_h = reinterpret_cast(grads); - params_cast_h = reinterpret_cast(_params); - } - - for (size_t t = rounded_size; t < _param_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > _param_size) copy_size = _param_size - t; - size_t offset = copy_size + t; -#if defined(__ENABLE_CUDA__) - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#endif -#pragma omp parallel for - for (size_t k = t; k < offset; k++) { - float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; - float param = half_precision ? (float)params_cast_h[k] : _params[k]; - float momentum = _exp_avg[k]; - float variance = _exp_avg_sq[k]; - if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } - momentum = momentum * _betta1; - momentum = grad * betta1_minus1 + momentum; - - variance = variance * _betta2; - grad = grad * grad; - variance = grad * betta2_minus1 + variance; - - grad = sqrt(variance); - grad = grad * _bias_correction2 + _eps; - grad = momentum / grad; - if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } - param = grad * step_size + param; -#if defined(__ENABLE_CUDA__) - if (dev_params) _doubled_buffer[_buf_index][k - t] = param; -#endif - if (half_precision) - params_cast_h[k] = (ds_half_precision_t)param; - else - _params[k] = param; - _exp_avg[k] = momentum; - _exp_avg_sq[k] = variance; - } -#if defined(__ENABLE_CUDA__) - if (dev_params) { - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); - - _buf_index = !_buf_index; - } -#endif - } - } -} - -void Adam_Optimizer::Step_4(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); -#endif - if (_param_size > rounded_size) - Step_1((_params + rounded_size), - (grads + rounded_size), - (_exp_avg + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); -} - -int create_adam_optimizer(int optimizer_id, - float alpha = 1e-3, - float betta1 = 0.9, - float betta2 = 0.999, - float eps = 1e-8, - float weight_decay = 0, - bool adamw_mode = true, - bool should_log = false) -{ - auto opt = - std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); - - s_optimizers[optimizer_id] = opt; - - if (should_log) { - std::string avx_type = ""; -#if defined(__AVX512__) - avx_type = "AVX512"; -#else -#if defined(__AVX256__) - avx_type = "AVX2"; -#else - avx_type = "scalar"; -#endif -#endif - - printf("Adam Optimizer #%d is created with %s arithmetic capability.\n", - optimizer_id, - avx_type.c_str()); - printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n", - alpha, - betta1, - betta2, - weight_decay, - (int)adamw_mode); - } - - return 0; -} - -void Adam_Optimizer::Step_8(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - ds_half_precision_t* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); -#endif - if (_param_size > rounded_size) - Step_4((_params + rounded_size), - (grads + rounded_size), - (_exp_avg + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); -} - -int ds_adam_step(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float epsilon, - float weight_decay, - bool bias_correction, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& exp_avg_sq) -{ - auto params_c = params.contiguous(); - auto grads_c = grads.contiguous(); - auto exp_avg_c = exp_avg.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - - // assert(params.options().dtype() == grads.options().dtype()); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step, beta1, beta2); - opt->update_state(lr, epsilon, weight_decay, bias_correction); - - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - exp_avg_sq_ptr, - params_c.numel(), - nullptr, - (params.options().dtype() == at::kHalf)); - -#if defined(__ENABLE_CUDA__) - opt->SynchronizeStreams(); -#endif - return 0; -} - -int ds_adam_step_plus_copy(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float epsilon, - float weight_decay, - bool bias_correction, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& exp_avg_sq, - torch::Tensor& gpu_params) -{ -#if defined(__ENABLE_CUDA__) - auto params_c = params.contiguous(); - auto gpu_params_c = gpu_params.contiguous(); - auto exp_avg_c = exp_avg.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - auto grads_c = grads.contiguous(); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step, beta1, beta2); - opt->update_state(lr, epsilon, weight_decay, bias_correction); - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - exp_avg_sq_ptr, - params_c.numel(), - gpu_params_ptr, - (params.options().dtype() == at::kHalf)); - - opt->SynchronizeStreams(); -#else - assert(false); -#endif - return 0; -} - -int destroy_adam_optimizer(int optimizer_id) -{ - s_optimizers.erase(optimizer_id); - - return 0; -} - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); diff --git a/csrc/cpu/adam/cpu_adam.h b/csrc/cpu/adam/cpu_adam.h deleted file mode 100644 index c288f8d5a49f..000000000000 --- a/csrc/cpu/adam/cpu_adam.h +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// SPDX-License-Identifier: Apache-2.0 - -// DeepSpeed Team - -#pragma once - -#define NOMINMAX // Windows idiosyncrasy - // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c - -#include -#include -#include "simd.h" - -#include -typedef unsigned short ds_half_precision_t; - -#define STEP(SPAN) \ - void Step_##SPAN(float* _params, \ - float* grads, \ - float* _exp_avg, \ - float* _exp_avg_sq, \ - size_t _param_size, \ - float lr, \ - float betta1, \ - float betta2, \ - float eps, \ - float weight_decay, \ - float bias_correction1, \ - float bias_correction2, \ - bool half_precision = false, \ - bool adam_mode = true); - -#if defined(__AVX512__) or defined(__AVX256__) -template -void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t param_size, - float lr, - float betta1, - float betta2, - float eps, - float weight_decay, - float bias_correction1, - float bias_correction2, - bool half_precision, - bool adamw_mode); -#endif -STEP(1) -STEP(4) -STEP(8) - -#if defined(__AVX512__) or defined(__AVX256__) -template -void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - float lr, - float betta1, - float betta2, - float eps, - float weight_decay, - float bias_correction1, - float bias_correction2, - bool half_precision, - bool adamw_mode) -{ - size_t new_rounded_size = 0; - int rshft = half_precision ? 1 : 0; - - AVX_Data betta1_4; - betta1_4.data = SIMD_SET(betta1); - AVX_Data betta2_4; - betta2_4.data = SIMD_SET(betta2); - - float betta1_minus1 = 1 - betta1; - float betta2_minus1 = 1 - betta2; - AVX_Data betta1_minus1_4; - betta1_minus1_4.data = SIMD_SET(betta1_minus1); - AVX_Data betta2_minus1_4; - betta2_minus1_4.data = SIMD_SET(betta2_minus1); - - AVX_Data bias2_sqrt; - bias2_sqrt.data = SIMD_SET(bias_correction2); - - AVX_Data eps_4; - eps_4.data = SIMD_SET(eps); - - float step_size = -1 * lr / bias_correction1; - AVX_Data step_size_4; - step_size_4.data = SIMD_SET(step_size); - - float w_decay = -1 * lr * weight_decay; - AVX_Data weight_decay4; - if (weight_decay > 0) - weight_decay4.data = (adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(weight_decay)); - new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); - for (size_t t = 0; t < new_rounded_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; - size_t offset = copy_size + t; -#pragma omp parallel for - for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { - AVX_Data grad_4[span]; - simd_load(grad_4, grads + (i >> rshft), half_precision); - - AVX_Data momentum_4[span]; - simd_load(momentum_4, _exp_avg + i, false); - - AVX_Data variance_4[span]; - simd_load(variance_4, _exp_avg_sq + i, false); - - AVX_Data param_4[span]; - simd_load(param_4, _params + (i >> rshft), half_precision); - - if (weight_decay > 0 && !adamw_mode) { - simd_fma(grad_4, param_4, weight_decay4, grad_4); - } - - simd_mul(momentum_4, momentum_4, betta1_4); - simd_fma(momentum_4, grad_4, betta1_minus1_4, momentum_4); - simd_mul(variance_4, variance_4, betta2_4); - simd_mul(grad_4, grad_4, grad_4); - simd_fma(variance_4, grad_4, betta2_minus1_4, variance_4); - simd_sqrt(grad_4, variance_4); - simd_fma(grad_4, grad_4, bias2_sqrt, eps_4); - simd_div(grad_4, momentum_4, grad_4); - - if (weight_decay > 0 && adamw_mode) { - simd_fma(param_4, param_4, weight_decay4, param_4); - } - - simd_fma(param_4, grad_4, step_size_4, param_4); - - simd_store(_params + (i >> rshft), param_4, half_precision); - simd_store(_exp_avg + i, momentum_4, false); - simd_store(_exp_avg_sq + i, variance_4, false); - } - } - *rounded_size = new_rounded_size; -} -#endif diff --git a/csrc/cpu/adam/fused_adam.cpp b/csrc/cpu/adam/fused_adam.cpp index cae8ad362126..cc1f0336b6c6 100644 --- a/csrc/cpu/adam/fused_adam.cpp +++ b/csrc/cpu/adam/fused_adam.cpp @@ -11,239 +11,8 @@ #include #include "cpu_adam.h" -static std::unordered_map> s_optimizers; - // C++ interface -void Step_1(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - float lr, - float betta1, - float betta2, - float eps, - float weight_decay, - float bias_correction1, - float bias_correction2, - bool half_precision, - bool adamw_mode) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - lr, - betta1, - betta2, - eps, - weight_decay, - bias_correction1, - bias_correction2, - half_precision, - adamw_mode); -#endif - if (_param_size > rounded_size) { - float betta1_minus1 = 1 - betta1; - float betta2_minus1 = 1 - betta2; - - float step_size = -1 * lr / bias_correction1; - float w_decay = -1 * lr * weight_decay; - ds_half_precision_t* grads_cast_h; - ds_half_precision_t* params_cast_h; - if (half_precision) { - grads_cast_h = reinterpret_cast(grads); - params_cast_h = reinterpret_cast(_params); - } - - for (size_t t = rounded_size; t < _param_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > _param_size) copy_size = _param_size - t; - size_t offset = copy_size + t; -#pragma omp parallel for - for (size_t k = t; k < offset; k++) { - float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; - float param = half_precision ? (float)params_cast_h[k] : _params[k]; - float momentum = _exp_avg[k]; - float variance = _exp_avg_sq[k]; - if (weight_decay > 0 && !adamw_mode) { grad = param * weight_decay + grad; } - momentum = momentum * betta1; - momentum = grad * betta1_minus1 + momentum; - - variance = variance * betta2; - grad = grad * grad; - variance = grad * betta2_minus1 + variance; - - grad = sqrt(variance); - grad = grad * bias_correction2 + eps; - grad = momentum / grad; - if (weight_decay > 0 && adamw_mode) { param += w_decay * param; } - param = grad * step_size + param; - if (half_precision) - params_cast_h[k] = (ds_half_precision_t)param; - else - _params[k] = param; - _exp_avg[k] = momentum; - _exp_avg_sq[k] = variance; - } - } - } -} - -void Step_4(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - float lr, - float betta1, - float betta2, - float eps, - float weight_decay, - float bias_correction1, - float bias_correction2, - bool half_precision, - bool adamw_mode) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - lr, - betta1, - betta2, - eps, - weight_decay, - bias_correction1, - bias_correction2, - half_precision, - adamw_mode); -#endif - if (_param_size > rounded_size) - Step_1((_params + rounded_size), - (grads + rounded_size), - (_exp_avg + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - lr, - betta1, - betta2, - eps, - weight_decay, - bias_correction1, - bias_correction2, - half_precision, - adamw_mode); -} - -void Step_8(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - float lr, - float betta1, - float betta2, - float eps, - float weight_decay, - float bias_correction1, - float bias_correction2, - bool half_precision, - bool adamw_mode) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - lr, - betta1, - betta2, - eps, - weight_decay, - bias_correction1, - bias_correction2, - half_precision, - adamw_mode); -#endif - if (_param_size > rounded_size) - Step_4((_params + rounded_size), - (grads + rounded_size), - (_exp_avg + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - lr, - betta1, - betta2, - eps, - weight_decay, - bias_correction1, - bias_correction2, - half_precision, - adamw_mode); -} - -int ds_adam_step(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float epsilon, - float weight_decay, - bool bias_correction, - bool adam_mode, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& exp_avg_sq) -{ - auto params_c = params.contiguous(); - auto grads_c = grads.contiguous(); - auto exp_avg_c = exp_avg.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - - // assert(params.options().dtype() == grads.options().dtype()); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - float bias_correction1 = 1.0f, bias_correction2 = 1.0f; - if (bias_correction == 1) { - bias_correction1 = 1.0 - std::pow(beta1, step); - bias_correction2 = 1 / sqrt(1.0 - std::pow(beta2, step)); - } - Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - exp_avg_sq_ptr, - params_c.numel(), - lr, - beta1, - beta2, - epsilon, - weight_decay, - bias_correction1, - bias_correction2, - (params.options().dtype() == at::kHalf), - adam_mode); - - return 0; -} - void multi_tensor_adam(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, /*gpmv*/ @@ -256,6 +25,11 @@ void multi_tensor_adam(int chunk_size, const int bias_correction, const float weight_decay) { + static bool initialized = false; + if (!initialized) { + create_adam_optimizer(0); + initialized = true; + } for (int i = 0; i < tensor_lists[0].size(); i++) { ds_adam_step(0, step, @@ -265,7 +39,6 @@ void multi_tensor_adam(int chunk_size, epsilon, weight_decay, bias_correction, - mode, tensor_lists[1][i], tensor_lists[0][i], tensor_lists[2][i], diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index 4648aede93ee..0a5fa44097fb 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -9,6 +9,7 @@ // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c #include +#include #include #include "simd.h" @@ -250,3 +251,25 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size, *rounded_size = new_rounded_size; } #endif + +int create_adam_optimizer(int optimizer_id, + float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true, + bool should_log = false); + +int ds_adam_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq); diff --git a/op_builder/cpu/cpu_adam.py b/op_builder/cpu/cpu_adam.py index 30343226f469..0c8438aea40d 100644 --- a/op_builder/cpu/cpu_adam.py +++ b/op_builder/cpu/cpu_adam.py @@ -17,7 +17,7 @@ def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' def sources(self): - return ['csrc/adam/cpu_adam.cpp'] + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] def libraries_args(self): args = super().libraries_args() diff --git a/op_builder/cpu/fused_adam.py b/op_builder/cpu/fused_adam.py index a5ac257d3301..34b43825b090 100644 --- a/op_builder/cpu/fused_adam.py +++ b/op_builder/cpu/fused_adam.py @@ -17,7 +17,7 @@ def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' def sources(self): - return ['csrc/cpu/adam/fused_adam.cpp'] + return ['csrc/cpu/adam/fused_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] def include_paths(self): - return ['csrc/includes', 'csrc/cpu/adam'] + return ['csrc/includes'] diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py index 29cdced0d9f2..e500a1eea907 100644 --- a/op_builder/cpu_adam.py +++ b/op_builder/cpu_adam.py @@ -19,9 +19,9 @@ def absolute_name(self): def sources(self): if self.build_for_cpu: - return ['csrc/adam/cpu_adam.cpp'] + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] - return ['csrc/adam/cpu_adam.cpp', 'csrc/common/custom_cuda_kernel.cu'] + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp', 'csrc/common/custom_cuda_kernel.cu'] def libraries_args(self): args = super().libraries_args() From c58c19a55c16d317516bd1008135ce724ffffcc5 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Tue, 25 Jul 2023 04:31:06 -0400 Subject: [PATCH 15/17] header cleanup --- csrc/adam/cpu_adam.cpp | 14 -------------- csrc/cpu/adam/fused_adam.cpp | 6 ------ 2 files changed, 20 deletions(-) diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index cbf20a2e9651..96809827f3e1 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -4,20 +4,6 @@ // DeepSpeed Team #include "cpu_adam.h" -#include -#include -#include -#include -#include -#include - -#if defined(__ENABLE_CUDA__) -#include -#include "cublas_v2.h" -#include "cuda.h" -#include "curand.h" -#include "custom_cuda_layers.h" -#endif PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/csrc/cpu/adam/fused_adam.cpp b/csrc/cpu/adam/fused_adam.cpp index cc1f0336b6c6..d25578f410da 100644 --- a/csrc/cpu/adam/fused_adam.cpp +++ b/csrc/cpu/adam/fused_adam.cpp @@ -3,12 +3,6 @@ // DeepSpeed Team -#include -#include -#include -#include -#include -#include #include "cpu_adam.h" // C++ interface From 038d026832e97443a2132be8d006c2096ae7d90f Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Tue, 25 Jul 2023 04:45:06 -0400 Subject: [PATCH 16/17] fix cpu_adam --- csrc/includes/cpu_adam.h | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index 0a5fa44097fb..1246df8b1449 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -273,3 +273,20 @@ int ds_adam_step(int optimizer_id, torch::Tensor& grads, torch::Tensor& exp_avg, torch::Tensor& exp_avg_sq); + +int ds_adam_step_plus_copy(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq, + torch::Tensor& gpu_params); + +int destroy_adam_optimizer(int optimizer_id); + From 43beb57785ac7ba636bb12c00cb1dea8d49ea38b Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Tue, 25 Jul 2023 04:57:16 -0400 Subject: [PATCH 17/17] fix format, add missing file --- csrc/adam/cpu_adam_impl.cpp | 299 ++++++++++++++++++++++++++++++++++++ csrc/includes/cpu_adam.h | 1 - 2 files changed, 299 insertions(+), 1 deletion(-) create mode 100644 csrc/adam/cpu_adam_impl.cpp diff --git a/csrc/adam/cpu_adam_impl.cpp b/csrc/adam/cpu_adam_impl.cpp new file mode 100644 index 000000000000..742cb4292777 --- /dev/null +++ b/csrc/adam/cpu_adam_impl.cpp @@ -0,0 +1,299 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include +#include +#include +#include +#include "cpu_adam.h" + +#if defined(__ENABLE_CUDA__) +#include +#include "cublas_v2.h" +#include "cuda.h" +#include "curand.h" +#include "custom_cuda_layers.h" +#endif + +static std::unordered_map> s_optimizers; + +// C++ interface + +void Adam_Optimizer::Step_1(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + ds_half_precision_t* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<1>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + dev_params, + half_precision); +#endif + if (_param_size > rounded_size) { + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + ds_half_precision_t* grads_cast_h; + ds_half_precision_t* params_cast_h; + if (half_precision) { + grads_cast_h = reinterpret_cast(grads); + params_cast_h = reinterpret_cast(_params); + } + + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; +#if defined(__ENABLE_CUDA__) + if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } +#endif +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; + float param = half_precision ? (float)params_cast_h[k] : _params[k]; + float momentum = _exp_avg[k]; + float variance = _exp_avg_sq[k]; + if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } + momentum = momentum * _betta1; + momentum = grad * betta1_minus1 + momentum; + + variance = variance * _betta2; + grad = grad * grad; + variance = grad * betta2_minus1 + variance; + + grad = sqrt(variance); + grad = grad * _bias_correction2 + _eps; + grad = momentum / grad; + if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } + param = grad * step_size + param; +#if defined(__ENABLE_CUDA__) + if (dev_params) _doubled_buffer[_buf_index][k - t] = param; +#endif + if (half_precision) + params_cast_h[k] = (ds_half_precision_t)param; + else + _params[k] = param; + _exp_avg[k] = momentum; + _exp_avg_sq[k] = variance; + } +#if defined(__ENABLE_CUDA__) + if (dev_params) { + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); + + _buf_index = !_buf_index; + } +#endif + } + } +} + +void Adam_Optimizer::Step_4(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + ds_half_precision_t* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<4>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + dev_params, + half_precision); +#endif + if (_param_size > rounded_size) + Step_1((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int create_adam_optimizer(int optimizer_id, + float alpha, + float betta1, + float betta2, + float eps, + float weight_decay, + bool adamw_mode, + bool should_log) +{ + auto opt = + std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); + + s_optimizers[optimizer_id] = opt; + + if (should_log) { + std::string avx_type = ""; +#if defined(__AVX512__) + avx_type = "AVX512"; +#else +#if defined(__AVX256__) + avx_type = "AVX2"; +#else + avx_type = "scalar"; +#endif +#endif + + printf("Adam Optimizer #%d is created with %s arithmetic capability.\n", + optimizer_id, + avx_type.c_str()); + printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n", + alpha, + betta1, + betta2, + weight_decay, + (int)adamw_mode); + } + + return 0; +} + +void Adam_Optimizer::Step_8(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + ds_half_precision_t* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<8>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + dev_params, + half_precision); +#endif + if (_param_size > rounded_size) + Step_4((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int ds_adam_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq) +{ + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + // assert(params.options().dtype() == grads.options().dtype()); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + exp_avg_sq_ptr, + params_c.numel(), + nullptr, + (params.options().dtype() == at::kHalf)); + +#if defined(__ENABLE_CUDA__) + opt->SynchronizeStreams(); +#endif + return 0; +} + +int ds_adam_step_plus_copy(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq, + torch::Tensor& gpu_params) +{ +#if defined(__ENABLE_CUDA__) + auto params_c = params.contiguous(); + auto gpu_params_c = gpu_params.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + auto grads_c = grads.contiguous(); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + ds_half_precision_t* gpu_params_ptr = (ds_half_precision_t*)gpu_params_c.data_ptr(); + float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + exp_avg_sq_ptr, + params_c.numel(), + gpu_params_ptr, + (params.options().dtype() == at::kHalf)); + + opt->SynchronizeStreams(); +#else + assert(false); +#endif + return 0; +} + +int destroy_adam_optimizer(int optimizer_id) +{ + s_optimizers.erase(optimizer_id); + + return 0; +} diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index 1246df8b1449..c4f7edcd7410 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -289,4 +289,3 @@ int ds_adam_step_plus_copy(int optimizer_id, torch::Tensor& gpu_params); int destroy_adam_optimizer(int optimizer_id); -