Feature/sg 907 add yaml explanation (#1422)

Louis-Dupont authored Aug 28, 2023
1 parent 7240726 · commit adc9c3d
Showing 1 changed file: documentation/source/LRScheduling.md (142 additions, 25 deletions)
Learning rate scheduling type is controlled by the training parameter `lr_mode`.
For example, the training code below will start with an initial learning rate of 0.1 and decay by a factor of 0.1 at epochs 100, 150 and 200:

```python
from super_gradients.training import Trainer
...

trainer = Trainer("my_custom_scheduler_training_experiment")

train_dataloader = ...
valid_dataloader = ...
model = ...
train_params = {
    "initial_lr": 0.1,
    "lr_mode": "step",
    "lr_updates": [100, 150, 200],
    "lr_decay_factor": 0.1,
    ...,
}

trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)

```

<details>
<summary>Equivalent in a <code>.yaml</code> configuration file:</summary>

```yaml
training_hyperparams:
  initial_lr: 0.1
  lr_mode: step
  lr_updates:
    - 100
    - 150
    - 200
  lr_decay_factor: 0.1
  ...

...
```
</details>


## Using Custom LR Schedulers

Prerequisites: [phase callbacks](PhaseCallbacks.md), [training with configuration files](configuration_files.md).
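
The scheduler itself lives in its own module and is registered with SG's registry. A minimal sketch of what `myscheduler.py` might contain is shown below; the `register_lr_scheduler` decorator and the `LRCallbackBase` hooks (`is_lr_scheduling_enabled`, `perform_scheduling`) follow the prerequisite guides above, but treat the exact signatures as assumptions rather than a definitive implementation:

```python
# myscheduler.py -- a sketch; exact base-class hooks may differ.
from super_gradients.common.registry.registry import register_lr_scheduler
from super_gradients.training.utils.callbacks import LRCallbackBase, Phase


@register_lr_scheduler("user_step")  # registers the name we will later pass as lr_mode
class UserStepLRCallback(LRCallbackBase):
    def __init__(self, user_lr_updates: list, user_lr_decay_factors: list, **kwargs):
        super().__init__(Phase.TRAIN_EPOCH_END, **kwargs)
        self.lr_updates = user_lr_updates
        self.lr_decay_factors = user_lr_decay_factors

    def is_lr_scheduling_enabled(self, context):
        # Let LR warmup (if any) finish before this scheduler takes over.
        return self.training_params.lr_warmup_epochs <= context.epoch

    def perform_scheduling(self, context):
        # Step decay: apply the factor of the latest update epoch already reached.
        lr = self.initial_lr
        for epoch, factor in zip(self.lr_updates, self.lr_decay_factors):
            if epoch <= context.epoch:
                lr = self.initial_lr * factor
        self.lr = lr
        self.update_lr(context.optimizer, context.epoch, None)
```

Importing the class is then enough to trigger the registration: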
```python
from myscheduler import UserStepLRCallback  # triggers registry, now we can pass "user_step"
```

And finally, use your new scheduler just as any other one supported by SG.
```python
trainer = Trainer("my_custom_scheduler_training_experiment")

# The following code sections marked with '...' are placeholders
# indicating additional necessary code that is not shown for simplicity.
train_dataloader = ...
valid_dataloader = ...
model = ...

train_params = {
    "initial_lr": 0.1,
    "lr_mode": "user_step",
    "user_lr_updates": [100, 150, 200],  # WILL BE PASSED TO UserStepLRCallback CONSTRUCTOR
    "user_lr_decay_factors": [0.1, 0.01, 0.001],  # WILL BE PASSED TO UserStepLRCallback CONSTRUCTOR
    ...,
}

trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)

```

Note that internally, Trainer unpacks [training_params to the scheduler callback constructor](https://github.com/Deci-AI/super-gradients/blob/537a0f0afe7bcf28d331fe2c0fa797fa10f54b99/src/super_gradients/training/sg_trainer/sg_trainer.py#L1078), so scheduler-related parameters are passed through training_params as well.
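
Conceptually, the mechanism is something like the sketch below. This is simplified for illustration and not the actual Trainer code; `LR_SCHEDULERS_CLS_DICT` here is a stand-in for the registry that `register_lr_scheduler` populates:

```python
# Illustrative sketch of the unpacking step, not the actual Trainer internals.
# Stand-in for the registry that @register_lr_scheduler populates:
LR_SCHEDULERS_CLS_DICT = {"user_step": UserStepLRCallback}

# The lr_mode string selects the registered callback class...
scheduler_cls = LR_SCHEDULERS_CLS_DICT[train_params["lr_mode"]]

# ...and training_params is unpacked into its constructor, which is how
# "user_lr_updates" and "user_lr_decay_factors" reach __init__;
# keys the scheduler does not declare are absorbed by **kwargs.
scheduler_callback = scheduler_cls(**train_params)
```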


<details>
<summary>Equivalent in a <code>.yaml</code> configuration file:</summary>

```yaml
training_hyperparams:
  initial_lr: 0.1
  lr_mode: user_step
  user_lr_updates:  # WILL BE PASSED TO UserStepLRCallback CONSTRUCTOR
    - 100
    - 150
    - 200
  user_lr_decay_factors:  # WILL BE PASSED TO UserStepLRCallback CONSTRUCTOR
    - 0.1
    - 0.01
    - 0.001
  ...

...
```
</details>

### Using PyTorch's Native LR Schedulers (torch.optim.lr_scheduler)

PyTorch offers a [wide variety of learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate).
From `Trainer.train(...)` docs:

```
train_epoch(...)
scheduler.step()
....
```
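
As the YAML equivalents below illustrate, a native torch scheduler is selected by passing a mapping as `lr_mode`: the scheduler's class name, its constructor arguments, and the `phase` at which `scheduler.step()` should be called. A sketch of the corresponding python form (the `Phase` import path and the exact dict shape are assumptions inferred from the YAML examples):

```python
from super_gradients.training.utils.callbacks import Phase  # assumed import path

# Assumed python mirror of the YAML lr_mode mapping shown below:
# scheduler class name -> its constructor kwargs, plus the phase to step on.
train_params = {
    "initial_lr": 0.1,
    "lr_mode": {"StepLR": {"gamma": 0.1, "step_size": 1, "phase": Phase.TRAIN_EPOCH_END}},
    ...,
}
```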

### Examples
**Using `StepLR`**

```python
trainer = Trainer("torch_Scheduler_example")

# The following code sections marked with '...' are placeholders
# indicating additional necessary code that is not shown for simplicity.
train_dataloader = ...
valid_dataloader = ...
model = ...
train_params = {
    ...,
    "valid_metrics_list": [Accuracy()],
    "metric_to_watch": "Accuracy",
    "greater_metric_to_watch_is_better": True,
}
trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)
```

<details>
<summary>Equivalent in a <code>.yaml</code> configuration file:</summary>

```yaml
training_hyperparams:
  # Setting up LR Scheduler
  lr_mode:
    StepLR:
      gamma: 0.1
      step_size: 1
      phase: TRAIN_EPOCH_END

  # Setting up other parameters
  max_epochs: 2
  lr_warmup_epochs: 0
  initial_lr: 0.1
  loss: CrossEntropyLoss
  optimizer: SGD
  criterion_params: {}
  optimizer_params:
    weight_decay: 1e-4
    momentum: 0.9
  train_metrics_list:
    - Accuracy
  valid_metrics_list:
    - Accuracy
  metric_to_watch: Accuracy
  greater_metric_to_watch_is_better: true

...
```
</details>


**Using `ReduceLROnPlateau`**

If you choose to use `ReduceLROnPlateau` as the learning rate scheduler, you need to specify a `metric_name`.
This parameter follows the same guidelines as `metric_to_watch`.

For an in-depth understanding of these metrics,
see the [metrics guide](Metrics.md).


```python
trainer = Trainer("torch_ROP_Scheduler_example")

...

train_params = {
    ...,
    "greater_metric_to_watch_is_better": True,
}
trainer.train(model=model, training_params=train_params, train_loader=train_dataloader, valid_loader=valid_dataloader)

```

During training, the scheduler's `state_dict` is saved under the `torch_scheduler_state_dict` entry inside the checkpoint, allowing a resumed run to pick up the scheduling exactly where it left off.
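
For instance, you can load a saved checkpoint and inspect that entry directly; this is a sketch, and the checkpoint path below is hypothetical and depends on your experiment directory:

```python
import torch

# Hypothetical path; SG writes checkpoints under the experiment's checkpoint dir.
ckpt = torch.load("checkpoints/torch_ROP_Scheduler_example/ckpt_latest.pth", map_location="cpu")

# The torch scheduler's state is stored alongside the model weights, so a
# resumed run can restore the scheduling from this entry.
scheduler_state = ckpt["torch_scheduler_state_dict"]
print(scheduler_state)
```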

<details>
<summary>Equivalent in a <code>.yaml</code> configuration file:</summary>

```yaml
training_hyperparams:
  # Setting up LR Scheduler
  lr_mode:
    ReduceLROnPlateau:
      patience: 0
      phase: TRAIN_EPOCH_END
      metric_name: DummyMetric

  # Setting up other parameters
  max_epochs: 2
  lr_decay_factor: 0.1
  lr_warmup_epochs: 0
  initial_lr: 0.1
  loss: CrossEntropyLoss
  optimizer: SGD
  criterion_params: {}
  optimizer_params:
    weight_decay: 1e-4
    momentum: 0.9
  train_metrics_list:
    - Accuracy
  valid_metrics_list:
    - Accuracy
  metric_to_watch: DummyMetric
  greater_metric_to_watch_is_better: true

...
```
</details>
