From 34f8bae6ab337cbad5221b36e19fa24b7a372ca1 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Fri, 7 Aug 2020 23:48:08 +0000
Subject: [PATCH 1/2] add fix and tests for get_lr from lr_scheduler before training starts

---
 deepspeed/pt/deepspeed_lr_schedules.py |  4 ++
 tests/unit/test_lr_schedulers.py       | 59 ++++++++++++++++++++++++++
 2 files changed, 63 insertions(+)
 create mode 100644 tests/unit/test_lr_schedulers.py

diff --git a/deepspeed/pt/deepspeed_lr_schedules.py b/deepspeed/pt/deepspeed_lr_schedules.py
index 97c18ffec0a2..02f66a45aab4 100755
--- a/deepspeed/pt/deepspeed_lr_schedules.py
+++ b/deepspeed/pt/deepspeed_lr_schedules.py
@@ -676,6 +676,10 @@ def __init__(self,
         self.last_batch_iteration = last_batch_iteration
 
     def get_lr(self):
+        if self.last_batch_iteration < 0:
+            logger.warning(
+                "Attempting to get learning rate from scheduler before it has started")
+            return 0.0
         gamma = self._get_gamma()
         return [
             min_lr + (delta_lr * gamma) for min_lr,
diff --git a/tests/unit/test_lr_schedulers.py b/tests/unit/test_lr_schedulers.py
new file mode 100644
index 000000000000..9e089561cf0e
--- /dev/null
+++ b/tests/unit/test_lr_schedulers.py
@@ -0,0 +1,59 @@
+import torch
+import deepspeed
+import argparse
+import pytest
+import json
+import os
+from common import distributed_test
+from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
+
+
+@pytest.mark.parametrize("scheduler_type,params",
+                         [("WarmupLR",
+                           {}),
+                          ("OneCycle",
+                           {
+                               'cycle_min_lr': 0,
+                               'cycle_max_lr': 0
+                           }),
+                          ("LRRangeTest",
+                           {})])
+def test_get_lr_before_train(tmpdir, scheduler_type, params):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            },
+        },
+        "scheduler": {
+            "type": scheduler_type,
+            "params": params
+        },
+        "gradient_clipping": 1.0
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1])
+    def _test_warmup_lr(args, model, hidden_dim):
+        model, _, _, lr_scheduler = deepspeed.initialize(args=args,
+                                                         model=model,
+                                                         model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device,
+                                        dtype=torch.float)
+        for n, batch in enumerate(data_loader):
+            # get lr before training starts
+            lr_scheduler.get_lr()
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_warmup_lr(args=args, model=model, hidden_dim=hidden_dim)

From bc519f107f4538a31280ebff9abc75643140fc0d Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Fri, 7 Aug 2020 23:50:48 +0000
Subject: [PATCH 2/2] rename

---
 tests/unit/test_lr_schedulers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/unit/test_lr_schedulers.py b/tests/unit/test_lr_schedulers.py
index 9e089561cf0e..0c388627a38f 100644
--- a/tests/unit/test_lr_schedulers.py
+++ b/tests/unit/test_lr_schedulers.py
@@ -40,7 +40,7 @@ def test_get_lr_before_train(tmpdir, scheduler_type, params):
     model = SimpleModel(hidden_dim, empty_grad=False)
 
     @distributed_test(world_size=[1])
-    def _test_warmup_lr(args, model, hidden_dim):
+    def _test_get_lr_before_train(args, model, hidden_dim):
         model, _, _, lr_scheduler = deepspeed.initialize(args=args,
                                                          model=model,
                                                          model_parameters=model.parameters())
@@ -56,4 +56,4 @@ def _test_get_lr_before_train(args, model, hidden_dim):
         model.backward(loss)
         model.step()
 
-    _test_warmup_lr(args=args, model=model, hidden_dim=hidden_dim)
+    _test_get_lr_before_train(args=args, model=model, hidden_dim=hidden_dim)
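
For readers outside the DeepSpeed tree, the guard in PATCH 1/2 relies on a convention these schedulers share: last_batch_iteration starts at -1 and only advances on step(), so a get_lr() call before the first step has no valid iteration to interpolate from. Below is a minimal, self-contained sketch of the same pattern; ToyLRRangeTest, its delta_lr parameter, and the linear _get_gamma ramp are all made up for illustration and stand in for the real scheduler in deepspeed/pt/deepspeed_lr_schedules.py.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class ToyLRRangeTest:
    """Hypothetical scheduler mirroring the guard added in PATCH 1/2."""

    def __init__(self, min_lr=0.0, delta_lr=0.001, last_batch_iteration=-1):
        self.min_lr = min_lr
        self.delta_lr = delta_lr
        # -1 is the sentinel meaning step() has not been called yet
        self.last_batch_iteration = last_batch_iteration

    def _get_gamma(self):
        # stand-in for the real interpolation; a linear ramp is enough here
        return float(self.last_batch_iteration)

    def get_lr(self):
        if self.last_batch_iteration < 0:
            # the guard from the patch: warn and return a sentinel instead of
            # computing a learning rate from the invalid -1 iteration
            logger.warning(
                "Attempting to get learning rate from scheduler before it has started")
            return 0.0
        return [self.min_lr + self.delta_lr * self._get_gamma()]

    def step(self):
        self.last_batch_iteration += 1

One detail the sketch mirrors faithfully: the pre-training path returns a bare 0.0 while the normal path returns a list of per-group rates, so callers that always index the result would need to tolerate both shapes.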
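The new test pins this down by calling get_lr() on the very first loop iteration, before model.step() has run, for each of the three parametrized schedulers. Stripped of the distributed harness, that pre-training probe looks like this against the toy scheduler above (the real test only requires that the call not raise; the asserts here additionally pin the sketch's return values):

# reuses ToyLRRangeTest from the previous sketch
sched = ToyLRRangeTest()

# before any step(): the guard fires (with a warning) and returns the sentinel
assert sched.get_lr() == 0.0

# once training advances, get_lr() returns the normal list form again
for _ in range(3):
    sched.step()  # last_batch_iteration: -1 -> 0 -> 1 -> 2
assert sched.get_lr() == [0.0 + 0.001 * 2]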