Allow unknown number of batches with data loaders

epwalsh committed Oct 2, 2024
1 parent 87f1e89 commit b921299
Showing 11 changed files with 164 additions and 113 deletions.
1 change: 0 additions & 1 deletion docs/source/train/index.rst
@@ -11,4 +11,3 @@

   callbacks
   checkpoint
-  utils
5 changes: 0 additions & 5 deletions docs/source/train/utils.rst

This file was deleted.

2 changes: 1 addition & 1 deletion src/examples/train.py
@@ -22,6 +22,7 @@
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride
from olmo_core.train import (
+    Duration,
    TrainerConfig,
    prepare_training_environment,
    teardown_training_environment,
@@ -37,7 +38,6 @@
    SequenceLengthSchedulerCallback,
    WandBCallback,
)
-from olmo_core.train.utils import Duration
from olmo_core.utils import get_default_device, seed_all

16 changes: 12 additions & 4 deletions src/olmo_core/data/data_loader.py
@@ -113,6 +113,9 @@ def __init__(
    def epoch(self) -> int:
        """
        Get the current epoch (1-based).
+
+        .. warning::
+            Accessing this before :meth:`reshuffle()` is called will raise an error.
        """
        if self._epoch is None:
            raise RuntimeError(
@@ -122,17 +125,22 @@ def epoch(self) -> int:

    @property
    @abstractmethod
-    def total_batches(self) -> int:
+    def total_batches(self) -> Optional[int]:
        """
-        The total number of batches that the dataset will produce over the course of an epoch.
+        The total number of batches that the dataset will produce over the course of an epoch,
+        if known. Otherwise this should return ``None``.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        """
-        Returns the total number of batches in an epoch, the same as :data:`total_batches`.
+        Returns the total number of batches in an epoch (same as :data:`total_batches`) if known,
+        otherwise a :class:`TypeError` is raised.
        """
-        return self.total_batches
+        if self.total_batches is not None:
+            return self.total_batches
+        else:
+            raise TypeError("data loader length (number of batches) is unknown")

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """
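For downstream code, the practical consequence of total_batches becoming Optional is that len() on a data loader can now raise. The following is a minimal sketch of the guard a caller might use; StreamingLoader is a hypothetical stand-in for a loader with an unknown batch count, not a class from this repository:

from typing import Optional


class StreamingLoader:
    """Hypothetical loader whose batch count is unknown until the stream ends."""

    @property
    def total_batches(self) -> Optional[int]:
        return None  # the size is not known ahead of time

    def __len__(self) -> int:
        # Mirrors the new DataLoaderBase.__len__ contract from this commit.
        if self.total_batches is not None:
            return self.total_batches
        raise TypeError("data loader length (number of batches) is unknown")


loader = StreamingLoader()
try:
    num_batches: Optional[int] = len(loader)
except TypeError:
    num_batches = None  # fall back to an open-ended training loop
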
6 changes: 5 additions & 1 deletion src/olmo_core/train/__init__.py
@@ -48,15 +48,19 @@
from ..distributed.utils import init_distributed, is_distributed
from ..io import add_cached_path_clients
from ..utils import LogFilterType, prepare_cli_environment, seed_all
+from .common import Duration, DurationUnit, LoadStrategy, ReduceType
from .config import TrainerConfig
-from .trainer import LoadStrategy, Trainer
+from .trainer import Trainer

__all__ = [
    "prepare_training_environment",
    "teardown_training_environment",
    "TrainerConfig",
    "Trainer",
    "LoadStrategy",
+    "Duration",
+    "DurationUnit",
+    "ReduceType",
]

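With these re-exports in place, downstream code can import the duration and checkpoint-loading types straight from the package root instead of the internal modules this commit removes or changes, e.g.:

# New canonical imports, all re-exported from the package root:
from olmo_core.train import Duration, DurationUnit, LoadStrategy, ReduceType

# The old paths no longer work after this commit:
# from olmo_core.train.utils import Duration
# from olmo_core.train.trainer import LoadStrategy
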
2 changes: 1 addition & 1 deletion src/olmo_core/train/callbacks/evaluator_callback.py
@@ -11,7 +11,7 @@
from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.utils import format_float, move_to_device

-from ..utils import Duration
+from ..common import Duration
from .callback import Callback, CallbackConfig

if TYPE_CHECKING:
116 changes: 116 additions & 0 deletions src/olmo_core/train/common.py
@@ -0,0 +1,116 @@
+from dataclasses import dataclass
+
+from ..config import StrEnum
+
+
+class DurationUnit(StrEnum):
+    """
+    Units that can be used to define a :class:`Duration`.
+    """
+
+    steps = "steps"
+    """
+    Steps (batches).
+    """
+    epochs = "epochs"
+    """
+    Epochs.
+    """
+    tokens = "tokens"
+    """
+    Tokens.
+    """
+
+
+@dataclass
+class Duration:
+    value: int
+    """
+    The value of the duration.
+    """
+    unit: DurationUnit
+    """
+    The unit associated with the :data:`value`.
+    """
+
+    @classmethod
+    def steps(cls, steps: int) -> "Duration":
+        """
+        Define a duration from a number of steps.
+        """
+        return cls(value=steps, unit=DurationUnit.steps)
+
+    @classmethod
+    def epochs(cls, epochs: int) -> "Duration":
+        """
+        Define a duration from a number of epochs.
+        """
+        return cls(value=epochs, unit=DurationUnit.epochs)
+
+    @classmethod
+    def tokens(cls, tokens: int) -> "Duration":
+        """
+        Define a duration from a number of tokens.
+        """
+        return cls(value=tokens, unit=DurationUnit.tokens)
+
+    def due(self, *, step: int, tokens: int, epoch: int) -> bool:
+        """
+        Check if the duration is due.
+        """
+        if self.unit == DurationUnit.steps:
+            return step >= self.value
+        elif self.unit == DurationUnit.tokens:
+            return tokens >= self.value
+        elif self.unit == DurationUnit.epochs:
+            return epoch > self.value
+        else:
+            raise NotImplementedError
+
+
+class LoadStrategy(StrEnum):
+    """
+    Determines the strategy for loading checkpoints prior to training.
+    """
+
+    if_available = "if_available"
+    """
+    Only load from the load path if a checkpoint exists there.
+    """
+
+    always = "always"
+    """
+    Always try loading from the load path.
+    """
+
+    never = "never"
+    """
+    Never load from the load path.
+    """
+
+
+class ReduceType(StrEnum):
+    """
+    An enumeration of the allowed ways to reduce a metric across ranks.
+    """
+
+    mean = "mean"
+    """
+    Average across the process group.
+    """
+
+    sum = "sum"
+    """
+    Add across the process group.
+    """
+
+    max = "max"
+    """
+    Take the max across the process group.
+    """
+
+    l2_norm = "l2_norm"
+    """
+    For metrics that are computed as L2 norms on each rank, this will correctly reduce the norm
+    across the process group to produce the global L2 norm.
+    """
4 changes: 2 additions & 2 deletions src/olmo_core/train/config.py
@@ -16,8 +16,8 @@
from ..utils import get_default_device
from .callbacks import Callback, CallbackConfig
from .checkpoint import Checkpointer
-from .trainer import LoadStrategy, Trainer
-from .utils import Duration
+from .common import Duration, LoadStrategy
+from .trainer import Trainer


@dataclass
58 changes: 23 additions & 35 deletions src/olmo_core/train/trainer.py
@@ -25,7 +25,6 @@
from torch.optim import Optimizer

from ..aliases import PathOrStr
-from ..config import StrEnum
from ..data import DataLoaderBase
from ..data.utils import get_labels, split_batch
from ..distributed.utils import (
@@ -55,14 +54,8 @@
    SpeedMonitorCallback,
)
from .checkpoint import Checkpointer
-from .utils import (
-    Duration,
-    DurationUnit,
-    EnvRngStates,
-    ReduceType,
-    move_metrics,
-    reduce_metrics,
-)
+from .common import Duration, DurationUnit, LoadStrategy, ReduceType
+from .utils import EnvRngStates, move_metrics, reduce_metrics

log = logging.getLogger(__name__)

@@ -73,27 +66,6 @@
SEQ_LEN_METRIC = "data/sequence length"


-class LoadStrategy(StrEnum):
-    """
-    Determines the strategy for loading checkpoints prior to training.
-    """
-
-    if_available = "if_available"
-    """
-    Only load from the load path if a checkpoint exists there.
-    """
-
-    always = "always"
-    """
-    Always try loading from the load path.
-    """
-
-    never = "never"
-    """
-    Never load from the load path.
-    """
-
-
class TrainerStateDict(TypedDict):
    global_step: int
    global_train_tokens_seen: int
@@ -171,6 +143,13 @@ class Trainer:
    max_duration: Duration
    """
    The duration to train for.
+
+    .. important::
+        The total number of training steps must be known ahead of time for various reasons such
+        as setting a learning rate schedule. Therefore if your data loader's number of batches
+        (:data:`~olmo_core.data.data_loader.DataLoaderBase.total_batches`) is unknown ahead of time,
+        you must set the ``max_duration`` in terms of :meth:`tokens <Duration.tokens>`
+        or :meth:`steps <Duration.steps>`, but not epochs.
    """

    rank_microbatch_size: int
@@ -420,18 +399,21 @@ def tokens_per_batch(self) -> int:
        return self.global_batch_size

    @property
-    def steps_per_epoch(self) -> int:
+    def steps_per_epoch(self) -> Optional[int]:
        """
-        The total number of training steps in an epoch.
+        The total number of training steps in an epoch, if known.
        """
        return self.data_loader.total_batches

    @property
-    def tokens_per_epoch(self) -> int:
+    def tokens_per_epoch(self) -> Optional[int]:
        """
        The total number of tokens in the training dataset, minus left-overs.
        """
-        return self.steps_per_epoch * self.tokens_per_batch
+        if self.steps_per_epoch is not None:
+            return self.steps_per_epoch * self.tokens_per_batch
+        else:
+            return None

    @property
    def max_steps(self) -> int:
@@ -441,13 +423,19 @@ def max_steps(self) -> int:
        if self.max_duration.unit == DurationUnit.steps:
            return self.max_duration.value
        elif self.max_duration.unit == DurationUnit.epochs:
+            if self.data_loader.total_batches is None:
+                raise RuntimeError(
+                    "the number of steps cannot be determined from an 'epochs' duration since "
+                    "the data loader's number of batches is unknown"
+                )
            max_epochs = self.max_duration.value
            complete_epochs_remaining = max(max_epochs - self.epoch, 0)
            steps_remaining_this_epoch = max(
                self.data_loader.total_batches - self.data_loader.batches_processed, 0
            )
            steps_remaining = (
-                complete_epochs_remaining * self.steps_per_epoch + steps_remaining_this_epoch
+                complete_epochs_remaining * self.data_loader.total_batches
+                + steps_remaining_this_epoch
            )
            return self.global_step + steps_remaining
        elif self.max_duration.unit == DurationUnit.tokens:
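The new guard in max_steps means an 'epochs' duration is only usable when the data loader can report its length; a loader with an unknown batch count must be bounded in steps or tokens instead, as the max_duration docstring above notes. A sketch of the distinction (values are illustrative):

from olmo_core.train import Duration

# With a data loader whose total_batches is None:
max_duration = Duration.tokens(100_000_000)  # OK: independent of the batch count
max_duration = Duration.steps(10_000)        # OK: counted directly by the trainer
# max_duration = Duration.epochs(1)          # Trainer.max_steps would raise RuntimeError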