forked from microsoft/DeepSpeed
* Progressive layer dropping docs (microsoft#499)
* test
* Adding tutorial and news page for PLD
* updating the tutorial and posts of PLD
* update the finetune tutorial
* Update PLD tutorial (microsoft#512)
* Update installation instructions
* Format fix
* ZeRO tutorial
* Format fixes
* ZeRO-Offload
* ZeRO and ZeRO-Offload tutorials
* Update navigation page
* Format fixes
* Add yuxhe feedback
* Fix blog post link
* Fix OneBit-Adam link; tweak scheduler example
* Fix date link
* Add DeepSpeed_Adam
* Add PLD tutorial to navigation

Co-authored-by: Shaden Smith <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>

* updating the pld docs
* DeepSpeed implementation of PLD (microsoft#508)
* DeepSpeed implementation of PLD
* Format fixes
* Formatting fixes
* Fix broken url
* Address PR feedback
* Bump DSE

Co-authored-by: Minjia Zhang <[email protected]>
Co-authored-by: Shaden Smith <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Minjia Zhang <[email protected]>
Commit be1147c (1 parent: e082d47). Showing 11 changed files with 238 additions and 10 deletions.
Submodule DeepSpeedExamples updated (12 files).
@@ -155,8 +155,7 @@ all repos using our CLA.
 This project has adopted the [Microsoft Open Source Code of
 Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the
 [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact
-[[email protected]](mailto:[email protected]) with any additional questions or
-comments.
+[[email protected]](mailto:[email protected]) with any additional questions or comments.

 # Publications
 1. Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He. (2019) ZeRO: Memory Optimization Towards Training A Trillion Parameter Models. [ArXiv:1910.02054](https://arxiv.org/abs/1910.02054)
New file deepspeed/runtime/progressive_layer_drop.py (path per the import in the tests below):
@@ -0,0 +1,33 @@
import numpy as np
from deepspeed.utils import log_dist


class ProgressiveLayerDrop(object):
    r""" Progressive Layer Dropping (PLD) for model training.
        This implements the PLD technique for compressed model training
        from this paper: https://arxiv.org/pdf/2010.13369.pdf
    Args:
        theta (float): a hyper-parameter that controls the trade-off between training time and robustness.
            The lower the theta value, the faster the training speed. Default value: 0.5.
        gamma (float): a hyper-parameter that controls how fast the drop ratio increases. Default value: 0.001.
    """
    def __init__(self, theta=0.5, gamma=0.001):
        super().__init__()

        self.theta = theta
        self.gamma = gamma
        self.current_theta = 1.0
        log_dist(f'Enabled progressive layer dropping (theta = {self.theta})', ranks=[0])

    def get_state(self):
        kwargs = {'progressive_layer_drop': True, 'pld_theta': self.get_theta()}
        return kwargs

    def get_theta(self):
        return self.current_theta

    def update_state(self, global_step):
        def _prob(x, gamma, p):
            return (1. - p) * np.exp(-gamma * x) + p

        self.current_theta = _prob(global_step, self.gamma, self.theta)
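For intuition, the keep probability starts at 1.0 and decays toward `theta` at a rate set by `gamma`. A minimal standalone sketch of the same schedule (not part of the commit; the step values below are purely illustrative):

import numpy as np

def pld_theta(step, gamma=0.001, theta=0.5):
    # Same closed form as ProgressiveLayerDrop.update_state():
    # starts at 1.0 (keep every layer) and decays toward theta.
    return (1. - theta) * np.exp(-gamma * step) + theta

for step in (0, 100, 1000, 5000, 10000):
    print(step, round(pld_theta(step), 4))
# prints: 0 1.0, 100 0.9524, 1000 0.6839, 5000 0.5034, 10000 0.5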
New unit test file for PLD:
@@ -0,0 +1,117 @@
import numpy as np
import deepspeed
import pytest
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from common import distributed_test
from simple_model import SimpleModel, PLD_SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict


@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0])
def test_pld_schedule(tmpdir, theta):
    gamma = 0.001

    pld_scheduler = ProgressiveLayerDrop(theta, gamma)
    for i in range(10):
        pld_scheduler.update_state(i)
        expected_theta = (1. - theta) * np.exp(-gamma * i) + theta
        actual_theta = pld_scheduler.get_theta()
        assert expected_theta == actual_theta


@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0])
def test_pld_model(tmpdir, theta):
    gamma = 0.001
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": 'Adam',
            "params": {
                "lr": 0.0001
            }
        },
        "fp16": {
            "enabled": True
        },
        "progressive_layer_drop": {
            "enabled": True,
            "theta": theta,
            "gamma": gamma
        }
    }

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = PLD_SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_pld_model(args, model, hidden_dim, theta, gamma):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=50,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for i, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

            expected_theta = (1. - theta) * np.exp(-gamma * i) + theta
            actual_theta = model.get_pld_theta()
            assert expected_theta == actual_theta

    _test_pld_model(args=args,
                    model=model,
                    hidden_dim=hidden_dim,
                    theta=theta,
                    gamma=gamma)


def test_non_pld_model(tmpdir):
    gamma = 0.001
    theta = 0.5
    config_dict = {
        "train_batch_size": 1,
        "steps_per_print": 1,
        "optimizer": {
            "type": 'Adam',
            "params": {
                "lr": 0.0001
            }
        },
        "fp16": {
            "enabled": True
        },
        "progressive_layer_drop": {
            "enabled": True,
            "theta": theta,
            "gamma": gamma
        }
    }

    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 10

    model = SimpleModel(hidden_dim, empty_grad=False)

    @distributed_test(world_size=[1])
    def _test_non_pld_model(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())

        data_loader = random_dataloader(model=model,
                                        total_samples=1,
                                        hidden_dim=hidden_dim,
                                        device=model.device)

        for i, batch in enumerate(data_loader):
            with pytest.raises(TypeError):
                loss = model(batch[0], batch[1])

    _test_non_pld_model(args=args, model=model, hidden_dim=hidden_dim)
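Outside of pytest, a user script would enable PLD through the same config block. The sketch below is not part of the commit: TinyNet and train_loader are placeholders for a real model and dataloader, and it assumes a DeepSpeed version whose deepspeed.initialize accepts config_params; only the progressive_layer_drop keys and the engine calls (backward, step, get_pld_theta) come from the code above.

import torch
import deepspeed

class TinyNet(torch.nn.Module):
    # Placeholder model for the sketch; a real PLD run would use a Transformer encoder
    # (e.g. the BERT models in DeepSpeedExamples).
    def __init__(self, hidden_dim=10):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        return self.loss_fn(self.linear(x), y)

ds_config = {
    "train_batch_size": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 0.0001}},
    "fp16": {"enabled": True},
    # Same keys exercised by the tests above.
    "progressive_layer_drop": {"enabled": True, "theta": 0.5, "gamma": 0.001}
}

net = TinyNet()
model_engine, _, _, _ = deepspeed.initialize(args=None,
                                             model=net,
                                             model_parameters=net.parameters(),
                                             config_params=ds_config)

for step, (x, y) in enumerate(train_loader):    # train_loader: placeholder DataLoader of (inputs, labels)
    loss = model_engine(x, y)                   # forward returns the loss in this style of model
    model_engine.backward(loss)
    model_engine.step()                         # keep probability follows the PLD schedule as steps advance
    current_theta = model_engine.get_pld_theta()    # current keep probability, e.g. for logging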