Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix+tests for get_lr from lr_scheduler before training starts #310

Merged
merged 3 commits into from
Aug 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepspeed/pt/deepspeed_lr_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,10 @@ def __init__(self,
self.last_batch_iteration = last_batch_iteration

def get_lr(self):
if self.last_batch_iteration < 0:
logger.warning(
"Attempting to get learning rate from scheduler before it has started")
return 0.0
gamma = self._get_gamma()
return [
min_lr + (delta_lr * gamma) for min_lr,
Expand Down
59 changes: 59 additions & 0 deletions tests/unit/test_lr_schedulers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import torch
import deepspeed
import argparse
import pytest
import json
import os
from common import distributed_test
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict


@pytest.mark.parametrize("scheduler_type,params",
[("WarmupLR",
{}),
("OneCycle",
{
'cycle_min_lr': 0,
'cycle_max_lr': 0
}),
("LRRangeTest",
{})])
def test_get_lr_before_train(tmpdir, scheduler_type, params):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
},
},
"scheduler": {
"type": scheduler_type,
"params": params
},
"gradient_clipping": 1.0
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10

model = SimpleModel(hidden_dim, empty_grad=False)

@distributed_test(world_size=[1])
def _test_get_lr_before_train(args, model, hidden_dim):
model, _, _, lr_scheduler = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device,
dtype=torch.float)
for n, batch in enumerate(data_loader):
# get lr before training starts
lr_scheduler.get_lr()
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()

_test_get_lr_before_train(args=args, model=model, hidden_dim=hidden_dim)