From f5aa2547d88f145c4b5223d5bb8f79db69ee5288 Mon Sep 17 00:00:00 2001 From: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com> Date: Fri, 30 Oct 2020 09:01:04 -0700 Subject: [PATCH] Add CPUAdam optimizer for zero-offload in deepspeed engine (#484) * add adamW to CPU-ADAM implementation * supporting cpu-adam optimizer for zero-offload on deepspeed side * bump DSE to match cpu-adam updates Co-authored-by: Jeff Rasley --- DeepSpeedExamples | 2 +- csrc/adam/cpu_adam.cpp | 58 +++++++++++++++++++++------- csrc/includes/cpu_adam.h | 7 +++- deepspeed/__init__.py | 2 +- deepspeed/ops/adam/cpu_adam.py | 49 ++++++++++++++++++++++- deepspeed/runtime/config.py | 7 +++- deepspeed/runtime/engine.py | 31 +++++++++++---- deepspeed/runtime/zero/stage2.py | 14 ++++--- docs/_pages/config-json.md | 10 ++++- docs/code-docs/source/optimizers.rst | 12 ++++++ tests/unit/test_cpu_adam.py | 12 +++++- 11 files changed, 164 insertions(+), 40 deletions(-) create mode 100755 docs/code-docs/source/optimizers.rst diff --git a/DeepSpeedExamples b/DeepSpeedExamples index ba63ad0fa861..a79272cc8b8f 160000 --- a/DeepSpeedExamples +++ b/DeepSpeedExamples @@ -1 +1 @@ -Subproject commit ba63ad0fa861d28b3b33bc2c20f702647403e258 +Subproject commit a79272cc8b8f0c5b66c803e581a1355341eacb77 diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index 380bc4ea0ab0..1528b28584f6 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -32,7 +32,7 @@ void Adam_Optimizer::Step(float* _params, float bias_correction2 = 1 / sqrt(1 - _betta2_t); float step_size = -1 * _alpha / bias_correction1; - + float w_decay = -1 * _alpha * _weight_decay; size_t rounded_size = 0; #if defined(__AVX512__) or defined(__AVX256__) @@ -57,8 +57,8 @@ void Adam_Optimizer::Step(float* _params, step_size_4.data = SIMD_SET(step_size); AVX_Data weight_decay4; - if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay); - + if (_weight_decay > 0) + weight_decay4.data = (_adamw_mode ? 
SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); for (size_t t = 0; t < rounded_size; t += TILE) { @@ -78,9 +78,9 @@ void Adam_Optimizer::Step(float* _params, AVX_Data param_4; param_4.data = SIMD_LOAD(_params + i); - if (_weight_decay > 0) + if (_weight_decay > 0 && !_adamw_mode) { grad_4.data = SIMD_FMA(param_4.data, weight_decay4.data, grad_4.data); - + } momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data); momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data); @@ -91,7 +91,9 @@ void Adam_Optimizer::Step(float* _params, grad_4.data = SIMD_SQRT(variance_4.data); grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data); grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data); - + if (_weight_decay > 0 && _adamw_mode) { + param_4.data = SIMD_FMA(param_4.data, weight_decay4.data, param_4.data); + } param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data); SIMD_STORE(_params + i, param_4.data); @@ -119,8 +121,7 @@ void Adam_Optimizer::Step(float* _params, float param = _params[k]; float momentum = _exp_avg[k]; float variance = _exp_avg_sq[k]; - if (_weight_decay > 0) grad = param * _weight_decay + grad; - + if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } momentum *= momentum * _betta1; momentum = grad * betta1_minus1 + momentum; @@ -131,7 +132,7 @@ void Adam_Optimizer::Step(float* _params, grad = sqrt(variance); grad = grad * bias_correction2 + _eps; grad = momentum / grad; - + if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } param = grad * step_size + param; if (dev_params) _doubled_buffer[_buf_index][k - rounded_size] = (__half)param; @@ -184,6 +185,10 @@ void Adam_Optimizer::Step_4(float* _params, AVX_Data step_size_4; step_size_4.data = SIMD_SET(step_size); + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay4; + if (_weight_decay > 0) + weight_decay4.data = (_adamw_mode ? 
SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 2)); for (size_t t = 0; t < rounded_size; t += TILE) { @@ -216,9 +221,7 @@ void Adam_Optimizer::Step_4(float* _params, param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1)); param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3); - if (_weight_decay > 0) { - AVX_Data weight_decay4; - weight_decay4.data = SIMD_SET(_weight_decay); + if (_weight_decay > 0 && !_adamw_mode) { grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data); grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data); grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data); @@ -261,6 +264,13 @@ void Adam_Optimizer::Step_4(float* _params, grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data); grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data); + if (_weight_decay > 0 && _adamw_mode) { + param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data); + param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data); + param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data); + param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data); + } + param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data); param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data); param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data); @@ -313,9 +323,11 @@ int create_adam_optimizer(int optimizer_id, float betta1 = 0.9, float betta2 = 0.999, float eps = 1e-8, - float weight_decay = 0) + float weight_decay = 0, + bool adamw_mode = true) { - auto opt = std::make_shared(alpha, betta1, betta2, eps, weight_decay); + auto opt = + std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); s_optimizers[optimizer_id] = opt; #if defined(__AVX512__) @@ -369,6 +381,11 @@ void Adam_Optimizer::Step_8(float* _params, AVX_Data step_size_4; step_size_4.data = SIMD_SET(step_size); + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay4; + if (_weight_decay > 0) + weight_decay4.data = + (_adamw_mode ? 
_mm512_set1_ps(w_decay) : _mm512_set1_ps(_weight_decay)); rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 3)); for (size_t t = 0; t < rounded_size; t += TILE) { @@ -417,7 +434,7 @@ void Adam_Optimizer::Step_8(float* _params, param_4[6].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 6); param_4[7].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 7); - if (_weight_decay > 0) { + if (_weight_decay > 0 && !_adamw_mode) { AVX_Data weight_decay4; weight_decay4.data = SIMD_SET(_weight_decay); grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data); @@ -498,6 +515,17 @@ void Adam_Optimizer::Step_8(float* _params, grad_4[6].data = SIMD_DIV(momentum_4[6].data, grad_4[6].data); grad_4[7].data = SIMD_DIV(momentum_4[7].data, grad_4[7].data); + if (_weight_decay > 0 && _adamw_mode) { + param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data); + param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data); + param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data); + param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data); + param_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, param_4[4].data); + param_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, param_4[5].data); + param_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, param_4[6].data); + param_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, param_4[7].data); + } + param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data); param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data); param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data); diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index 996450c56ae9..3db12254bfe6 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -50,7 +50,8 @@ class Adam_Optimizer { float betta1 = 0.9, float betta2 = 0.999, float eps = 1e-8, - float weight_decay = 0) + float weight_decay = 0, + bool adamw_mode = true) : _alpha(alpha), _betta1(betta1), _betta2(betta2), @@ -58,7 +59,8 @@ class Adam_Optimizer { _weight_decay(weight_decay), _betta1_t(1.0), _betta2_t(1.0), - _buf_index(false) + _buf_index(false), + _adamw_mode(adamw_mode) { cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); @@ -115,4 +117,5 @@ class Adam_Optimizer { float* _doubled_buffer[2]; bool _buf_index; + bool _adamw_mode; }; diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 3a9b79292edd..61a6f730fb17 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -7,7 +7,7 @@ from . 
import ops from .runtime.engine import DeepSpeedEngine -from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_ADAM +from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER from .runtime.pipe.engine import PipelineEngine from .runtime.lr_schedules import add_tuning_arguments from .runtime.config import DeepSpeedConfig diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py index 76d8d323f6f5..35f9a6f360cf 100755 --- a/deepspeed/ops/adam/cpu_adam.py +++ b/deepspeed/ops/adam/cpu_adam.py @@ -1,3 +1,7 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' + import math import torch import importlib @@ -6,6 +10,40 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer): + """Fast vectorized implementation of two variations of the Adam optimizer on CPU: + + - Adam: A Method for Stochastic Optimization (https://arxiv.org/abs/1412.6980); + - AdamW: Fixing Weight Decay Regularization in Adam (https://arxiv.org/abs/1711.05101v1) + + DeepSpeed CPU Adam(W) provides a 5x to 7x speedup over torch.optim.Adam(W). + To use this optimizer, the model's master parameters (in FP32) must + reside in CPU memory. + + To train on a heterogeneous system that coordinates CPU and GPU, DeepSpeed offers + the ZeRO-Offload technology, which efficiently offloads the optimizer states into CPU memory + with minimal impact on training throughput. DeepSpeedCPUAdam plays an important role in minimizing + the latency overhead of the optimizer on the CPU. Please refer to the ZeRO-Offload tutorial + (https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology. + + When calling the step function, two options are available: (1) update the optimizer's states, or (2) update + the optimizer's states and copy the parameters back to the GPU at the same time. We have seen that the second + option can bring 30% higher throughput than doing the copy separately with option one. + + + Arguments: + model_params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in DeepSpeed CPUAdam! 
+ adamw_mode: select between Adam and AdamW implementations (default: AdamW) + """ optimizer_id = 0 @@ -16,7 +54,8 @@ def __init__(self, 0.999), eps=1e-8, weight_decay=0, - amsgrad=False): + amsgrad=False, + adamw_mode=True): default_args = dict(lr=lr, betas=betas, @@ -30,7 +69,13 @@ def __init__(self, global ds_opt_adam ds_opt_adam = importlib.import_module('deepspeed.ops.adam.cpu_adam_op') - ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay) + ds_opt_adam.create_adam(self.opt_id, + lr, + betas[0], + betas[1], + eps, + weight_decay, + adamw_mode) def __setstate__(self, state): super(DeepSpeedCPUAdam, self).__setstate__(state) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index b49e3f3c144f..10ceab52a8dd 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -15,17 +15,20 @@ from deepspeed.utils import logger TENSOR_CORE_ALIGN_SIZE = 8 + ADAM_OPTIMIZER = 'adam' LAMB_OPTIMIZER = 'lamb' ONEBIT_ADAM_OPTIMIZER = 'onebitadam' -DEEPSPEED_ADAM = 'deepspeed_adam' DEEPSPEED_OPTIMIZERS = [ ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, - DEEPSPEED_ADAM ] +# extra optimizer parameters for adam +TORCH_ADAM_PARAM = "torch_adam" +ADAM_W_MODE_PARAM = "adam_w_mode" + def get_amp_enabled(param_dict): if AMP in param_dict.keys(): diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index a442154a03bc..0932ef9e4998 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -19,8 +19,10 @@ from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer -from deepspeed.runtime.config import DeepSpeedConfig, \ - ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, DEEPSPEED_ADAM, DEEPSPEED_OPTIMIZERS +from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \ + ADAM_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, \ + TORCH_ADAM_PARAM, ADAM_W_MODE_PARAM + from deepspeed.runtime.dataloader import DeepSpeedDataLoader from deepspeed.runtime.constants import \ ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ @@ -548,15 +550,30 @@ def _configure_basic_optimizer(self, model_parameters): raise ValueError( "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details" ) + if self.optimizer_name() == ADAM_OPTIMIZER: - if self.zero_cpu_offload(): + torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False) + adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE_PARAM, True) + + # zero-offload torch-adam adam_w_mode optimizer + # T|F T T torch.optim.AdamW + # T|F T F torch.optim.Adam + # T F T|F DeepSpeedCPUAdam(adam_w_mode) + # F F T|F FusedAdam(adam_w_mode) + if torch_adam and adam_w_mode: + optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters) + elif torch_adam and not adam_w_mode: optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters) - else: + elif self.zero_cpu_offload() and not torch_adam: + from deepspeed.ops.adam import DeepSpeedCPUAdam + optimizer = DeepSpeedCPUAdam(model_parameters, + **optimizer_parameters, + adamw_mode=adam_w_mode) + elif not self.zero_cpu_offload() and not torch_adam: from apex.optimizers.fused_adam import FusedAdam + optimizer_parameters[ADAM_W_MODE_PARAM] = adam_w_mode optimizer = 
FusedAdam(model_parameters, **optimizer_parameters) - elif self.optimizer_name() == DEEPSPEED_ADAM: - from deepspeed.ops.adam import DeepSpeedCPUAdam - optimizer = DeepSpeedCPUAdam(model_parameters, **optimizer_parameters) + elif self.optimizer_name() == LAMB_OPTIMIZER: from deepspeed.ops.lamb import FusedLamb optimizer = FusedLamb(model_parameters, **optimizer_parameters) diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index c86686a37146..52f4c28173ed 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -157,8 +157,7 @@ def __init__(self, self.cpu_offload = cpu_offload - self.deepspeed_adam_offload = (cpu_offload - and type(init_optimizer) == DeepSpeedCPUAdam) + self.deepspeed_adam_offload = cpu_offload self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu' @@ -1416,10 +1415,13 @@ def step(self, closure=None): timers('optimizer_step').start() if self.deepspeed_adam_offload: from deepspeed.ops.adam import DeepSpeedCPUAdam - self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups) - #self.optimizer.step() - #for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): - # fp16_partitions[partition_id].data.copy_(fp32_partition.data) + if type(self.optimizer) == DeepSpeedCPUAdam: + self.optimizer.step( + fp16_param_groups=self.parallel_partitioned_fp16_groups) + else: + self.optimizer.step() + for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): + fp16_partitions[partition_id].data.copy_(fp32_partition.data) else: self.optimizer.step() diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 049dbf3c2899..3efc2ced025f 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -34,10 +34,10 @@ title: "DeepSpeed Configuration JSON" | Fields | Value | Example | | ------ | ------------------------------------------------------------ | ------------------------------ | -| type | The optimizer name. DeepSpeed natively supports **Adam**, **DeepSpeedAdam**, **OneBitAdam**, and **LAMB** optimizers and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` | +| type | The optimizer name. DeepSpeed natively supports **Adam**, **OneBitAdam**, and **Lamb** optimizers and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` | | params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` | - Example of ***optimizer*** + Example of ***optimizer*** with Adam ```json "optimizer": { @@ -53,6 +53,12 @@ title: "DeepSpeed Configuration JSON" } } ``` +The Adam optimizer also supports the following two params keys/values in addition to the standard parameters from [torch.optim.Adam](https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam): +| "params" key | Description | Default | +| ------------- | --------------------------------------------------------------------------- | --------| +| torch\_adam | Use torch's implementation of Adam instead of our fused Adam implementation | false | +| adam\_w\_mode | Apply decoupled weight decay as in AdamW, rather than standard L2 regularization | true | + Another example of ***optimizer*** with 1-bit Adam specific parameters is as follows. 
```json diff --git a/docs/code-docs/source/optimizers.rst b/docs/code-docs/source/optimizers.rst new file mode 100755 index 000000000000..65f1ca2bf33f --- /dev/null +++ b/docs/code-docs/source/optimizers.rst @@ -0,0 +1,12 @@ +Optimizers +=================== + +DeepSpeed offers high-performance implementations of the Adam and Lamb optimizers on CPU and GPU, respectively. + +DeepSpeed CPU Adam +---------------------------- +.. autoclass:: deepspeed.ops.adam.DeepSpeedCPUAdam + +DeepSpeed Fused Lamb +---------------------------- +.. autoclass:: deepspeed.ops.lamb.FusedLamb diff --git a/tests/unit/test_cpu_adam.py b/tests/unit/test_cpu_adam.py index a6a13d1a2e9e..5130dc72fa74 100755 --- a/tests/unit/test_cpu_adam.py +++ b/tests/unit/test_cpu_adam.py @@ -1,6 +1,6 @@ import argparse import torch - +import apex import time import numpy as np import pytest @@ -37,8 +37,12 @@ def test_cpu_adam_opt(model_size): param = torch.nn.Parameter(torch.randn(model_size, device=device)) torch.set_rng_state(rng_state) param1 = torch.nn.Parameter(torch.randn(model_size, device=device)) + torch.set_rng_state(rng_state) + param2_data = torch.randn(model_size, device=device).cuda() + param2 = torch.nn.Parameter(param2_data) - optimizer1 = torch.optim.Adam([param1]) + optimizer1 = torch.optim.AdamW([param1]) + optimizer2 = apex.optimizers.FusedAdam([param2]) optimizer = DeepSpeedCPUAdam([param]) for i in range(10): @@ -46,8 +50,12 @@ def test_cpu_adam_opt(model_size): param.grad = torch.randn(model_size, device=device) torch.set_rng_state(rng_state) param1.grad = torch.randn(model_size, device=device) + torch.set_rng_state(rng_state) + param2.grad = torch.randn(model_size, device=device).cuda() optimizer.step() + optimizer2.step() optimizer1.step() check_equal(param, param1, atol=1e-2, verbose=True) + check_equal(param, param2.cpu(), atol=1e-2, verbose=True)
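Note on the kernel changes in csrc/adam/cpu_adam.cpp: the new `_adamw_mode` branches switch between folding the weight decay into the gradient (classic L2 regularization) and applying decoupled, AdamW-style decay directly to the parameter before the step. The following is a minimal scalar sketch in Python of the per-element arithmetic performed by `Adam_Optimizer::Step` (the AVX paths apply the same math SIMD-wide); it is illustrative only, not part of the patch.

```python
import math

def cpu_adam_step_reference(param, grad, exp_avg, exp_avg_sq, step,
                            lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                            weight_decay=0.0, adamw_mode=True):
    """Scalar sketch of one update, mirroring Adam_Optimizer::Step element-wise."""
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 / math.sqrt(1 - beta2 ** step)
    step_size = -lr / bias_correction1        # mirrors `step_size` in the C++
    w_decay = -lr * weight_decay              # mirrors `w_decay` (AdamW path)

    if weight_decay > 0 and not adamw_mode:
        # Classic L2 regularization: fold decay into the gradient before the moments.
        grad = param * weight_decay + grad

    # First and second moment estimates.
    exp_avg = exp_avg * beta1 + grad * (1 - beta1)
    exp_avg_sq = exp_avg_sq * beta2 + grad * grad * (1 - beta2)

    # Bias-corrected denominator and update direction.
    denom = math.sqrt(exp_avg_sq) * bias_correction2 + eps
    update = exp_avg / denom

    if weight_decay > 0 and adamw_mode:
        # Decoupled (AdamW-style) weight decay applied directly to the parameter.
        param = param + w_decay * param

    param = param + update * step_size
    return param, exp_avg, exp_avg_sq
```

As in the C++ above, the decoupled decay term uses the un-bias-corrected learning rate (`-lr * weight_decay`) and is applied before the Adam update is added.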
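Note on the engine wiring: `_configure_basic_optimizer` now keys off two optional entries under `"optimizer" -> "params"` (`torch_adam` and `adam_w_mode`, i.e. `TORCH_ADAM_PARAM` and `ADAM_W_MODE_PARAM`) together with `zero_cpu_offload()`. A condensed sketch of that decision table follows; the config fragment is shown as a plain Python dict purely for illustration, and the `zero_optimization` key names are assumed from the DeepSpeed config of this era rather than defined by this patch.

```python
def select_adam_variant(zero_cpu_offload: bool, torch_adam: bool, adam_w_mode: bool) -> str:
    """Condensed view of the branching added for the 'adam' optimizer type."""
    if torch_adam:
        # torch implementations are chosen regardless of ZeRO-Offload.
        return "torch.optim.AdamW" if adam_w_mode else "torch.optim.Adam"
    if zero_cpu_offload:
        # ZeRO-Offload keeps FP32 master params on CPU -> vectorized CPU optimizer.
        return "deepspeed.ops.adam.DeepSpeedCPUAdam(adamw_mode=%s)" % adam_w_mode
    # GPU path: apex fused Adam, with adam_w_mode forwarded.
    return "apex.optimizers.FusedAdam(adam_w_mode=%s)" % adam_w_mode

# Illustrative config fragment (would normally live in the DeepSpeed JSON config).
ds_config_fragment = {
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-3,
            "weight_decay": 0.01,
            "torch_adam": False,   # use the fused / CPU implementations
            "adam_w_mode": True,   # decoupled weight decay (AdamW)
        },
    },
    "zero_optimization": {
        "stage": 2,
        "cpu_offload": True,       # routes the optimizer to DeepSpeedCPUAdam
    },
}
```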
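Finally, a minimal standalone usage sketch of `DeepSpeedCPUAdam`, along the lines of tests/unit/test_cpu_adam.py: FP32 master parameters live in CPU memory, gradients are assigned on CPU, and `step()` updates them in place. Under ZeRO-Offload, stage2.py additionally passes `fp16_param_groups=` so the updated values are copied back to the GPU partitions in the same call; that path is omitted here. This assumes the extension is built and importable and is a sketch rather than a drop-in test.

```python
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

# FP32 master parameters resident in CPU memory, as the optimizer requires.
params = [torch.nn.Parameter(torch.randn(64, 1024, device='cpu'))]

# adamw_mode=True (the default) applies decoupled weight decay, AdamW-style.
optimizer = DeepSpeedCPUAdam(params, lr=1e-3, weight_decay=0.01, adamw_mode=True)

for _ in range(10):
    for p in params:
        p.grad = torch.randn_like(p)   # stand-in for real gradients
    optimizer.step()                   # vectorized CPU update of params and moments
    optimizer.zero_grad()
```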