
Commit 00d34f6

Add option to compile loss function, move logits FP32 casting into loss function (#77)

This change gives a small but meaningful improvement to throughput and
GPU memory usage by allowing the cast of logits to FP32 to be fused
into the loss function via `torch.compile()`.
epwalsh authored Nov 1, 2024
1 parent 4928f82 commit 00d34f6
Showing 10 changed files with 28 additions and 6 deletions.
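The heart of the change is sketched below (a minimal illustration, not the repository's exact code; the shapes and `ignore_index=-100` are assumptions): when the FP32 cast lives inside a `torch.compile`d loss function, the compiler can fuse the upcast into the cross-entropy computation instead of materializing a separate FP32 logits tensor.

```python
import torch
import torch.nn.functional as F

def loss_fn(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Cast inside the compiled region so the upcast can be fused with the
    # cross-entropy computation rather than allocated as a full FP32 tensor.
    logits = logits.float()
    return F.cross_entropy(logits, labels, ignore_index=-100)

compiled_loss_fn = torch.compile(loss_fn)

logits = torch.randn(8, 32_000, dtype=torch.bfloat16)  # illustrative shapes
labels = torch.randint(0, 32_000, (8,))
loss = compiled_loss_fn(logits, labels)
```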
CHANGELOG.md (1 addition, 0 deletions)

@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

+- Added option to compile the trainer's loss function (`Trainer.compile_loss`).
 - Added `SourceMixtureDataset` for composing a training mixture based on ratios of source datasets.
 - Added `NumpyFSLDatasetMixture` for constructing a `NumpyDatasetBase` from a `SourceMixtureDataset`. Note this is only supported for FSL datasets.
 - Added tests for `SourceMixture*` and `NumpyFSLDatasetMixture`.
src/olmo_core/nn/functional/cross_entropy_loss.py (3 additions, 0 deletions)

@@ -32,6 +32,7 @@ def cross_entropy_loss(
     :returns: The cross entropy loss and optionally the z-loss.
     """
+    logits = logits.float()
     loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)

     if not compute_z_loss:
@@ -87,6 +88,8 @@ def fused_cross_entropy_loss(
     if _fused_cross_entropy_loss is None:
         raise RuntimeError("triton is required for fused_cross_entropy_loss")

+    logits = logits.float()
+
     loss, z_loss = _fused_cross_entropy_loss(
         logits,
         labels,
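A hedged usage sketch of the updated function (the vocabulary size and bf16 inputs are illustrative, and the keyword arguments are inferred from this diff): low-precision logits can now be passed in directly, since the upcast happens inside the loss.

```python
import torch
from olmo_core.nn.functional.cross_entropy_loss import cross_entropy_loss

logits = torch.randn(16, 50_304, dtype=torch.bfloat16)  # illustrative shapes
labels = torch.randint(0, 50_304, (16,))

# The cast to FP32 now happens inside the function, so no caller-side .float().
ce_loss, z_loss = cross_entropy_loss(logits, labels, compute_z_loss=True)
```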
src/olmo_core/nn/transformer/model.py (1 addition, 1 deletion)

@@ -212,7 +212,7 @@ def forward(
             h = block(h, max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens)

         h = self.norm(h) if self.norm is not None else h
-        out = self.w_out(h).float() if self.w_out is not None else h
+        out = self.w_out(h) if self.w_out is not None else h
         return out

     def apply_activation_checkpointing(
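To see why dropping the `.float()` on the model output matters, a back-of-the-envelope sketch (batch size, sequence length, and vocabulary size are illustrative): upcasting the full logits tensor in the forward pass doubles its footprint, whereas deferring the cast to a compiled loss avoids ever materializing the FP32 copy.

```python
batch, seq_len, vocab = 4, 4096, 50_304  # illustrative sizes

bf16_bytes = batch * seq_len * vocab * 2  # logits as the model now returns them
fp32_bytes = batch * seq_len * vocab * 4  # the same tensor upcast to FP32

print(f"bf16 logits: {bf16_bytes / 1e9:.2f} GB")  # ~1.65 GB
print(f"fp32 logits: {fp32_bytes / 1e9:.2f} GB")  # ~3.30 GB
```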
src/olmo_core/train/config.py (1 addition, 0 deletions)

@@ -44,6 +44,7 @@ class TrainerConfig(Config):
     metrics_collect_interval: int = 5
     callbacks: Dict[str, Callback] = field(default_factory=dict)
     fused_loss: bool = False
+    compile_loss: bool = False
     z_loss_multiplier: Optional[float] = None
     autocast_precision: Optional[DType] = None
src/olmo_core/train/trainer.py (17 additions, 2 deletions)

@@ -202,6 +202,11 @@ class Trainer:
     depend on the input sizes.
     """

+    compile_loss: bool = False
+    """
+    Compile the loss function.
+    """
+
     z_loss_multiplier: Optional[float] = None
     """
     Use Z-loss with this multiplier.
@@ -259,6 +264,9 @@ class Trainer:
     _thread_pool: Optional[ThreadPoolExecutor] = None
     _bookkeeping_pg: Optional[dist.ProcessGroup] = None
     _checkpoint_loaded: bool = False
+    # NOTE: do not assign a default here or it will become a bound method due to the way
+    # dataclasses work.
+    _loss_fn = None

     def __post_init__(self):
         self.save_folder = normalize_path(self.save_folder)
@@ -350,6 +358,14 @@ def __post_init__(self):
         for callback in self.callbacks.values():
             callback.post_attach()

+        # Set loss function.
+        if self.fused_loss:
+            self._loss_fn = fused_cross_entropy_loss
+        else:
+            self._loss_fn = cross_entropy_loss
+        if self.compile_loss:
+            self._loss_fn = torch.compile(self._loss_fn)
+
     @property
     def global_batch_size(self) -> int:
         """
@@ -874,7 +890,6 @@ def get_losses(
         :returns: The cross entropy and optional Z-loss, respectively.
         """
-        loss_fn = cross_entropy_loss if not self.fused_loss else fused_cross_entropy_loss
         if compute_z_loss is None:
             compute_z_loss = self.z_loss_multiplier is not None

@@ -888,7 +903,7 @@
         # shape: (batch_size * (seq_len - 1),)
         labels = labels.view(-1)

-        ce_loss, z_loss = loss_fn(
+        ce_loss, z_loss = self._loss_fn(  # type: ignore
             logits_for_loss,
             labels,
             ignore_index=self.data_loader.collator.label_ignore_index,
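The `# NOTE` in the hunk above points at a general Python pitfall worth spelling out (a standalone sketch, not repository code): a plain function assigned at class level is a descriptor, so attribute access through an instance binds it as a method and silently prepends `self`. Assigning the function in `__post_init__` instead creates an instance attribute, which is never bound.

```python
from dataclasses import dataclass

def loss(x):
    return x + 1

@dataclass
class Trainer:
    _loss_fn = loss  # class-level function: instance access binds it as a method

t = Trainer()
try:
    t._loss_fn(1)  # actually calls loss(t, 1)
except TypeError as e:
    print(e)  # loss() takes 1 positional argument but 2 were given

t._loss_fn = loss  # instance attribute: no descriptor binding
print(t._loss_fn(1))  # 2
```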
src/scripts/train/Llama-8B.py (1 addition, 1 deletion)

@@ -44,7 +44,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig:
             metrics_collect_interval=10,
             cancel_check_interval=1,
             z_loss_multiplier=1e-5,
-            fused_loss=True,
+            compile_loss=True,
         )
         .with_callback(
             "checkpointer",
src/scripts/train/OLMo-13B.py (1 addition, 1 deletion)

@@ -47,7 +47,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig:
             metrics_collect_interval=10,
             cancel_check_interval=1,
             z_loss_multiplier=1e-5,
-            fused_loss=True,
+            compile_loss=True,
         )
         .with_callback(
             "checkpointer",
src/scripts/train/OLMo-1B.py (1 addition, 0 deletions)

@@ -43,6 +43,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig:
             metrics_collect_interval=10,
             cancel_check_interval=1,
             z_loss_multiplier=1e-5,
+            compile_loss=True,
         )
         .with_callback(
             "checkpointer",
src/scripts/train/OLMo-7B.py (1 addition, 1 deletion)

@@ -47,7 +47,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig:
             metrics_collect_interval=10,
             cancel_check_interval=1,
             z_loss_multiplier=1e-5,
-            fused_loss=True,
+            compile_loss=True,
        )
        .with_callback(
            "checkpointer",
src/scripts/train/OLMoE-1B-7B.py (1 addition, 0 deletions)

@@ -77,6 +77,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig:
             metrics_collect_interval=10,
             cancel_check_interval=1,
             z_loss_multiplier=1e-5,
+            compile_loss=True,
         )
         .with_callback(
             "checkpointer",
