Skip to content

Commit

Permalink
Avoid an unnecessary host-device sync when creating initial loss tenso…
Browse files Browse the repository at this point in the history
…rs (#68)
  • Loading branch information
epwalsh authored Oct 11, 2024
1 parent ad4c8bb commit 5af60ba
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- `prepare_cli_environment()` now calls `add_cached_path_clients()`.
- Removed an unnecessary host-device sync.

## [v1.4.0](https://github.com/allenai/OLMo-core/releases/tag/v1.4.0) - 2024-10-02

Expand Down
7 changes: 5 additions & 2 deletions src/olmo_core/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,9 +1057,11 @@ def _train_batch(self, batch: Dict[str, Any], dry_run: bool = False):
# In case this helps with memory utilization.
del batch

ce_batch_loss = torch.tensor(0.0, device=self.device)
ce_batch_loss = move_to_device(torch.tensor(0.0), self.device)
z_batch_loss = (
None if self.z_loss_multiplier is None else torch.tensor(0.0, device=self.device)
None
if self.z_loss_multiplier is None
else move_to_device(torch.tensor(0.0), self.device)
)

# Train one micro-batch at a time.
Expand Down Expand Up @@ -1170,6 +1172,7 @@ def _fit_epoch(self):

if first_batch or self.global_step % self.metrics_collect_interval == 0:
self._log_metrics()
torch.cuda.set_sync_debug_mode("warn")

first_batch = False

Expand Down
6 changes: 6 additions & 0 deletions src/olmo_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,12 @@ def filter_warnings():
category=UserWarning,
message="Please use DTensor instead.*",
)
warnings.filterwarnings(
action="ignore",
category=UserWarning,
message="Synchronization debug mode is a prototype feature.*",
module="torch.cuda",
)
warnings.filterwarnings(
action="ignore",
category=FutureWarning,
Expand Down

0 comments on commit 5af60ba

Please sign in to comment.