Fail fast if scheduler warmup and max duration are incompatible (#2458)
dakinggg authored Aug 23, 2023
1 parent 6226657 commit e7a0922
Showing 2 changed files with 94 additions and 5 deletions.
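
Before the diff itself, a brief aside on the machinery the new check relies on: Composer time strings such as '100ba' or '10ep' parse into Time objects whose .unit field is what the added helper compares. A minimal sketch, assuming the composer.core.Time API exactly as it is used in the diff below (the specific values are illustrative):

from composer.core import Time
from composer.core.time import TimeUnit

# Parse Composer time strings; each Time carries a value and a TimeUnit.
t_warmup = Time.from_timestring('100ba')  # batch-based warmup
t_max = Time.from_timestring('10ep')      # epoch-based max duration

# The new check compares units: identical units, a 'dur' warmup, or a 'ba'/'ep'
# mix are accepted; any other mismatch makes the warmup schedulers raise a ValueError.
print(t_warmup.unit == t_max.unit)      # False: 'ba' vs 'ep'
print(t_warmup.unit == TimeUnit('ba'))  # True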
32 changes: 28 additions & 4 deletions composer/optim/scheduler.py
@@ -475,7 +475,7 @@ def __call__(self, state: State, ssr: float = 1.0):
class PolynomialScheduler(ComposerScheduler):
r"""Sets the learning rate to be proportional to a power of the fraction of training time left.
- Specifially, the learning rate multiplier :math:`\alpha` can be expressed as:
+ Specifically, the learning rate multiplier :math:`\alpha` can be expressed as:
.. math::
\alpha(t) = \alpha_f + (1 - \alpha_f) \times (1 - \tau) ^ {\kappa}
@@ -510,6 +510,22 @@ def __call__(self, state: State, ssr: float = 1.0):
return current_factor


def _raise_if_warmup_and_max_duration_incompatible(t_warmup: Union[str, Time], t_max: Union[str, Time]):
if isinstance(t_warmup, str):
t_warmup = Time.from_timestring(t_warmup)
if isinstance(t_max, str):
t_max = Time.from_timestring(t_max)
units_same = t_warmup.unit == t_max.unit
warmup_is_dur = t_warmup.unit == TimeUnit('dur')
batches_vs_epochs = (t_warmup.unit == TimeUnit('ba') and
t_max.unit == TimeUnit('ep')) or (t_warmup.unit == TimeUnit('ep') and
t_max.unit == TimeUnit('ba'))
if not units_same and not warmup_is_dur and not batches_vs_epochs:
raise ValueError(f'Cannot use warmup scheduler with max_duration {t_max} and warmup {t_warmup}. '
't_warmup units must be the same as max_duration units, warmup must be in units "dur", '
'max_duration must be "ba" and t_warmup "ep", or max_duration must be "ep" and t_warmup "ba".')
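
As a quick illustration (not part of the commit), the helper above is expected to behave as sketched below; this assumes it stays importable as a private function from composer.optim.scheduler and that the time strings use Composer's '<value><unit>' format:

from composer.optim.scheduler import _raise_if_warmup_and_max_duration_incompatible

# Accepted: matching units, a 'dur'-based warmup, or the batch/epoch special case.
_raise_if_warmup_and_max_duration_incompatible('100ba', '1000ba')  # same units
_raise_if_warmup_and_max_duration_incompatible('0.25dur', '10ep')  # warmup in 'dur'
_raise_if_warmup_and_max_duration_incompatible('2ep', '1000ba')    # 'ep' warmup, 'ba' max

# Rejected: any other mismatch, e.g. token-based warmup with an epoch-based max duration.
_raise_if_warmup_and_max_duration_incompatible('5000tok', '10ep')  # raises ValueError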


class MultiStepWithWarmupScheduler(ComposerScheduler):
r"""Decays the learning rate discretely at fixed milestones, with an initial warmup.
@@ -558,6 +574,8 @@ def __init__(self,
self.step_scheduler = MultiStepScheduler(milestones=milestones, gamma=gamma)

def __call__(self, state: State, ssr: float = 1.0):
assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'
_raise_if_warmup_and_max_duration_incompatible(self.t_warmup, state.max_duration)
t_warmup = _convert_time(self.t_warmup, state)
if t_warmup.value == 0:
warnings.warn(
@@ -639,7 +657,7 @@ class LinearWithWarmupScheduler(ComposerScheduler):
\alpha_i + (alpha_f - \alpha_i) \times \tau_w & \text{otherwise}
\end{cases}
- Given :math:`\tau_w`, the fraction of post-warmup time elpased (clipped to the interval :math:`[0, 1]`), as:
+ Given :math:`\tau_w`, the fraction of post-warmup time elapsed (clipped to the interval :math:`[0, 1]`), as:
.. math::
\tau_w = (t - t_{warmup}) / t_{max}
@@ -676,6 +694,8 @@ def __init__(self,
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=alpha_i, t_max=t_warmup)

def __call__(self, state: State, ssr: float = 1.0):
assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'
_raise_if_warmup_and_max_duration_incompatible(self.t_warmup, state.max_duration)
t_warmup = _convert_time(self.t_warmup, state)
if t_warmup.value == 0:
warnings.warn(
@@ -713,7 +733,7 @@ class CosineAnnealingWithWarmupScheduler(ComposerScheduler):
\alpha_f + (1 - \alpha_f) \times \frac{1}{2} (1 + \cos(\pi \times \tau_w)) & \text{otherwise}
\end{cases}
- Given :math:`\tau_w`, the fraction of post-warmup time elpased (clipped to the interval :math:`[0, 1]`), as:
+ Given :math:`\tau_w`, the fraction of post-warmup time elapsed (clipped to the interval :math:`[0, 1]`), as:
.. math::
\tau_w = (t - t_{warmup}) / t_{max}
@@ -744,6 +764,8 @@ def __init__(self,
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)

def __call__(self, state: State, ssr: float = 1.0):
assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'
_raise_if_warmup_and_max_duration_incompatible(self.t_warmup, state.max_duration)
t_warmup = _convert_time(self.t_warmup, state)
if t_warmup.value == 0:
warnings.warn(
@@ -779,7 +801,7 @@ class PolynomialWithWarmupScheduler(ComposerScheduler):
\alpha_f + (1 - \alpha_f) \times (1 - \tau_w) ^ {\kappa} & \text{otherwise}
\end{cases}
- Given :math:`\tau_w`, the fraction of post-warmup time elpased (clipped to the interval :math:`[0, 1]`), as:
+ Given :math:`\tau_w`, the fraction of post-warmup time elapsed (clipped to the interval :math:`[0, 1]`), as:
.. math::
\tau_w = (t - t_{warmup}) / t_{max}
@@ -814,6 +836,8 @@ def __init__(self,
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)

def __call__(self, state: State, ssr: float = 1.0):
assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'
_raise_if_warmup_and_max_duration_incompatible(self.t_warmup, state.max_duration)
t_warmup = _convert_time(self.t_warmup, state)
if t_warmup.value == 0:
warnings.warn(
67 changes: 66 additions & 1 deletion tests/optim/test_scheduler.py
@@ -8,7 +8,7 @@
from torch.utils.data import DataLoader

from composer.core import State, Time
- from composer.core.time import TimeUnit
+ from composer.core.time import Timestamp, TimeUnit
from composer.devices import DeviceCPU, DeviceGPU
from composer.optim.scheduler import (ComposerScheduler, ConstantWithWarmupScheduler, CosineAnnealingScheduler,
CosineAnnealingWarmRestartsScheduler, CosineAnnealingWithWarmupScheduler,
@@ -177,3 +177,68 @@ def test_scheduler_trains(
seed=rank_zero_seed,
)
trainer.fit()


@pytest.mark.parametrize('scheduler_class', [
CosineAnnealingWithWarmupScheduler, MultiStepWithWarmupScheduler, ConstantWithWarmupScheduler,
LinearWithWarmupScheduler, PolynomialWithWarmupScheduler
])
@pytest.mark.parametrize('max_duration_unit', ['tok', 'sp', 'ba', 'ep'])
@pytest.mark.parametrize('warmup_duration_unit', ['ba', 'tok', 'sp', 'ep', 'dur'])
def test_warmup_schedulers_fail_fast(scheduler_class: Type[ComposerScheduler], max_duration_unit: str,
warmup_duration_unit: str, dummy_schedulers_state: State):
if warmup_duration_unit == max_duration_unit or warmup_duration_unit == 'dur' or (
max_duration_unit == 'ep' and warmup_duration_unit == 'ba') or (max_duration_unit == 'ba' and
warmup_duration_unit == 'ep'):
error_context = contextlib.nullcontext()
else:
error_context = pytest.raises(ValueError, match='Cannot use warmup scheduler with max_duration')

tokens_per_sample = 8
samples_per_batch = 16
batches_per_epoch = 32
num_epochs = 4
total_batches = batches_per_epoch * num_epochs
total_samples = total_batches * samples_per_batch
total_tokens = total_samples * tokens_per_sample

warmup_duration_pct = 0.25
warmup_batches = int(total_batches * warmup_duration_pct)
warmup_samples = int(total_samples * warmup_duration_pct)
warmup_tokens = int(total_tokens * warmup_duration_pct)
warmup_epochs = int(num_epochs * warmup_duration_pct)

max_duration_unit_to_str = {
'tok': f'{total_tokens}tok',
'sp': f'{total_samples}sp',
'ba': f'{total_batches}ba',
'ep': f'{num_epochs}ep',
}

warmup_duration_unit_to_str = {
'tok': f'{warmup_tokens}tok',
'sp': f'{warmup_samples}sp',
'ba': f'{warmup_batches}ba',
'ep': f'{warmup_epochs}ep',
'dur': f'{warmup_duration_pct}dur',
}

max_duration_str = max_duration_unit_to_str[max_duration_unit]
warmup_duration_str = warmup_duration_unit_to_str[warmup_duration_unit]
num_steps = total_batches

if scheduler_class == MultiStepWithWarmupScheduler:
scheduler = scheduler_class(milestones=['60ba'], t_warmup=warmup_duration_str) # type: ignore
else:
scheduler = scheduler_class(t_warmup=warmup_duration_str) # type: ignore

state = dummy_schedulers_state
state.max_duration = Time.from_timestring(max_duration_str)
state.timestamp = Timestamp()
state.set_dataloader([None] * batches_per_epoch, 'train')

with error_context:
for _ in range(num_steps):
_ = scheduler(state)
state.timestamp = state.timestamp.to_next_batch(samples=samples_per_batch,
tokens=tokens_per_sample * samples_per_batch)
