Fix+tests for get_lr from lr_scheduler before training starts (deepspeedai#310)

* add fix and tests for get_lr from lr_scheduler before training starts
jeffra authored Aug 10, 2020
1 parent 903a41a commit cd68e6e
Showing 2 changed files with 63 additions and 0 deletions.
4 changes: 4 additions & 0 deletions deepspeed/pt/deepspeed_lr_schedules.py
@@ -676,6 +676,10 @@ def __init__(self,
        self.last_batch_iteration = last_batch_iteration

    def get_lr(self):
        if self.last_batch_iteration < 0:
            logger.warning(
                "Attempting to get learning rate from scheduler before it has started")
            return 0.0
        gamma = self._get_gamma()
        return [
            min_lr + (delta_lr * gamma) for min_lr,
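The guard added above short-circuits get_lr() while last_batch_iteration is still -1, i.e. before the engine has taken its first step. The following is a minimal, self-contained sketch of that pattern; ToyWarmupScheduler and its parameters are invented for illustration and are not DeepSpeed's actual class.

import logging

logger = logging.getLogger(__name__)


class ToyWarmupScheduler:
    # Simplified stand-in for a warmup LR scheduler; it mirrors the guard in
    # the patch above but is not DeepSpeed's implementation.
    def __init__(self, base_lrs, warmup_num_steps=10, last_batch_iteration=-1):
        self.base_lrs = base_lrs
        self.warmup_num_steps = warmup_num_steps
        self.last_batch_iteration = last_batch_iteration

    def _get_gamma(self):
        # Linear warmup factor in [0, 1].
        return min(1.0, (self.last_batch_iteration + 1) / self.warmup_num_steps)

    def get_lr(self):
        # Before the first step() the iteration counter is still -1, so there
        # is no meaningful learning rate to report yet.
        if self.last_batch_iteration < 0:
            logger.warning(
                "Attempting to get learning rate from scheduler before it has started")
            return 0.0
        gamma = self._get_gamma()
        return [base_lr * gamma for base_lr in self.base_lrs]

    def step(self):
        self.last_batch_iteration += 1


sched = ToyWarmupScheduler(base_lrs=[0.001])
print(sched.get_lr())  # warns and returns 0.0: training has not started yet
sched.step()
print(sched.get_lr())  # returns per-group learning rates, here [0.0001]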
59 changes: 59 additions & 0 deletions tests/unit/test_lr_schedulers.py
@@ -0,0 +1,59 @@
import torch
import deepspeed
import argparse
import pytest
import json
import os
from common import distributed_test
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict


@pytest.mark.parametrize("scheduler_type,params",
[("WarmupLR",
{}),
("OneCycle",
{
'cycle_min_lr': 0,
'cycle_max_lr': 0
}),
("LRRangeTest",
{})])
def test_get_lr_before_train(tmpdir, scheduler_type, params):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
},
},
"scheduler": {
"type": scheduler_type,
"params": params
},
"gradient_clipping": 1.0
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10

model = SimpleModel(hidden_dim, empty_grad=False)

@distributed_test(world_size=[1])
def _test_get_lr_before_train(args, model, hidden_dim):
model, _, _, lr_scheduler = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)
for n, batch in enumerate(data_loader):
# get lr before training starts
lr_scheduler.get_lr()
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()

_test_get_lr_before_train(args=args, model=model, hidden_dim=hidden_dim)
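The new test exercises each scheduler type (WarmupLR, OneCycle, LRRangeTest) by calling get_lr() before the first model.step(). Assuming a working DeepSpeed test environment with the common and simple_model helpers on the path, it can be run in isolation with, for example, pytest tests/unit/test_lr_schedulers.py -k test_get_lr_before_train.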
