ZeRO-Offload (squash) (#381)

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Jie <[email protected]>
Co-authored-by: Arash Ashari <[email protected]>
Co-authored-by: Samyam Rajbhandari <[email protected]>
Co-authored-by: Shaden Smith <[email protected]>
Co-authored-by: arashashari <[email protected]>
Co-authored-by: RezaYazdaniAminabadi <[email protected]>
1 parent 79093d7 · commit 41db1c2
Showing 44 changed files with 3,788 additions and 1,889 deletions.
@@ -0,0 +1,35 @@
#include "custom_cuda_layers.h"

// Casts fp32 master parameters to fp16 device parameters, four elements per
// thread: one vectorized float4 load, one float2 store holding four __half values.
__global__ void param_update_kernel(const float* input, __half* output, int size)
{
    const float4* input_cast = reinterpret_cast<const float4*>(input);
    float2* output_cast = reinterpret_cast<float2*>(output);

    int id = blockIdx.x * blockDim.x + threadIdx.x;

    if (id < size) {
        float4 data = input_cast[id];
        float2 cast_data;
        __half* output_h = reinterpret_cast<__half*>(&cast_data);

        output_h[0] = (__half)data.x;
        output_h[1] = (__half)data.y;
        output_h[2] = (__half)data.z;
        output_h[3] = (__half)data.w;

        output_cast[id] = cast_data;
    }
}

void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
{
    int threads = 512;

    size /= 4;  // each thread handles four elements, so size must be a multiple of 4
    dim3 grid_dim((size - 1) / threads + 1);
    dim3 block_dim(threads);

    param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}
@@ -0,0 +1,118 @@
#pragma once

#include <cpuid.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <x86intrin.h>
#include <cassert>
#include <iostream>  // std::cerr, used by CUDA_CHECK
#include "context.h"
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"

#define CUDA_CHECK(callstr)                                                                    \
    {                                                                                          \
        cudaError_t error_code = callstr;                                                      \
        if (error_code != cudaSuccess) {                                                       \
            std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
            assert(0);                                                                         \
        }                                                                                      \
    }

#define TILE (1024 * 1024 * 1024)

#if defined(__AVX512__)
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_WIDTH 16
#else
#if defined(__AVX256__)
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_WIDTH 8
#endif
#endif

class Adam_Optimizer {
public:
    Adam_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float eps = 1e-8,
                   float weight_decay = 0)
        : _alpha(alpha),
          _betta1(betta1),
          _betta2(betta2),
          _eps(eps),
          _weight_decay(weight_decay),
          _betta1_t(1.0),
          _betta2_t(1.0),
          _buf_index(false)
    {
        // Two pinned (page-locked) host buffers, double-buffered so the CPU
        // can update one tile while the previous tile copies to the GPU.
        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
    }
    ~Adam_Optimizer()
    {
        cudaFreeHost(_doubled_buffer[0]);
        cudaFreeHost(_doubled_buffer[1]);
    }
    void Step(float* _params,
              float* grads,
              float* _exp_avg,
              float* _exp_avg_sq,
              size_t param_size,
              __half* dev_param = nullptr);
    void Step_4(float* _params,
                float* grads,
                float* _exp_avg,
                float* _exp_avg_sq,
                size_t param_size,
                __half* dev_param = nullptr);
    void Step_8(float* _params,
                float* grads,
                float* _exp_avg,
                float* _exp_avg_sq,
                size_t _param_size,
                __half* dev_params = nullptr);
    inline void IncrementStep()
    {
        // Accumulate beta1^t and beta2^t for Adam bias correction.
        _betta1_t *= _betta1;
        _betta2_t *= _betta2;
    }

private:
#if defined(__AVX512__) or defined(__AVX256__)
    union AVX_Data {
#if defined(__AVX512__)
        __m512 data;
#else
        __m256 data;
#endif
        // float data_f[16];
    };
#endif

    float _alpha;
    float _betta1;
    float _betta2;
    float _eps;
    float _weight_decay;

    float _betta1_t;
    float _betta2_t;

    float* _doubled_buffer[2];
    bool _buf_index;
};
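The Step bodies are not part of this excerpt, but the fields above map onto the standard Adam rule, with _betta1_t and _betta2_t holding the accumulated powers that IncrementStep maintains for bias correction (and _weight_decay presumably folded into the gradient as L2 regularization):

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
\hat{m}_t = m_t / (1 - \beta_1^t), \quad \hat{v}_t = v_t / (1 - \beta_2^t)
\theta_t = \theta_{t-1} - \alpha \, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)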
@@ -0,0 +1,7 @@
from ..git_version_info import installed_ops as __installed_ops__
from . import lamb
from . import transformer
if __installed_ops__['sparse-attn']:
    from . import sparse_attention
if __installed_ops__['cpu-adam']:
    from . import adam
@@ -0,0 +1 @@
from .cpu_adam import DeepSpeedCPUAdam
@@ -0,0 +1,81 @@
import math
import torch
import importlib

ds_opt_adam = None


class DeepSpeedCPUAdam(torch.optim.Optimizer):

    optimizer_id = 0

    def __init__(self,
                 model_params,
                 lr=1e-3,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 amsgrad=False):

        default_args = dict(lr=lr,
                            betas=betas,
                            eps=eps,
                            weight_decay=weight_decay,
                            amsgrad=amsgrad)
        super(DeepSpeedCPUAdam, self).__init__(model_params, default_args)

        # Each instance gets a unique id so the C++ extension can keep
        # per-optimizer Adam state (hyperparameters, step counts) on its side.
        self.opt_id = DeepSpeedCPUAdam.optimizer_id
        DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1

        global ds_opt_adam
        ds_opt_adam = importlib.import_module('deepspeed.ops.adam.cpu_adam_op')
        ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay)

    def __setstate__(self, state):
        super(DeepSpeedCPUAdam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None, fp16_param_groups=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0
                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, device='cpu')
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, device='cpu')

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                if fp16_param_groups is not None:
                    # Update the fp32 master weights and write an fp16 copy of
                    # the result into the matching GPU parameter in one call.
                    p_fp16 = fp16_param_groups[group_id][param_id]
                    ds_opt_adam.adam_update_copy(self.opt_id,
                                                 p.data,
                                                 grad,
                                                 exp_avg,
                                                 exp_avg_sq,
                                                 p_fp16)
                else:
                    ds_opt_adam.adam_update(self.opt_id,
                                            p.data,
                                            grad,
                                            exp_avg,
                                            exp_avg_sq)
        return loss
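A minimal usage sketch, assuming the cpu-adam extension was built so that deepspeed.ops.adam.cpu_adam_op is importable; the model and tensor shapes below are made up for illustration:

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

# Parameters, and therefore the Adam state, live on the CPU.
model = torch.nn.Linear(16, 16)
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()  # the Adam update runs inside the C++ op, not in Python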
@@ -0,0 +1,33 @@
import torch
from torch.autograd import Variable
import collections.abc


def async_migrate_to(obj, dev, main_stream=None):
    if torch.is_tensor(obj):
        obj = Variable(obj)
    if isinstance(obj, Variable):
        # 'async' is a reserved word in Python 3.7+; the tensor.cuda()
        # argument is spelled non_blocking.
        v = obj.cuda(dev, non_blocking=True)
        if main_stream is not None:
            v.data.record_stream(main_stream)
        return v
    elif isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    elif isinstance(obj, collections.abc.Sequence):
        return [async_copy_to(o, dev, main_stream) for o in obj]
    else:
        return obj


def async_copy_to(obj, dev, main_stream=None):
    if torch.is_tensor(obj):
        obj = Variable(obj)
    if isinstance(obj, Variable):
        target = torch.empty_like(obj, device=dev).copy_(obj)
        if main_stream is not None:
            target.data.record_stream(main_stream)
        return target
    elif isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    elif isinstance(obj, collections.abc.Sequence):
        return [async_copy_to(o, dev, main_stream) for o in obj]
    else:
        return obj  # non-tensor leaves pass through unchanged
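One caveat worth noting: a non-blocking host-to-device copy only truly overlaps with compute when the source is in pinned memory, which is the same reason Adam_Optimizer above allocates _doubled_buffer with cudaMallocHost. A small illustrative sketch (stream and tensor names are hypothetical):

import torch

# A pinned source makes the copy asynchronous with respect to the host.
src = torch.randn(1 << 20).pin_memory()
dst = torch.empty(1 << 20, device='cuda')

copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    dst.copy_(src, non_blocking=True)  # enqueued on copy_stream, returns immediately

# Make the default stream wait before consuming dst.
torch.cuda.current_stream().wait_stream(copy_stream)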