Do a dry-run batch before starting training (#67)
Fixes a discrepancy after restarts caused by compiling on the first batch.
epwalsh authored Oct 11, 2024
1 parent 71bc5c8 commit 0c75ef6
Showing 7 changed files with 94 additions and 15 deletions.
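
In short, the change has two halves: data loaders gain a `get_mock_batch()` method that returns a throwaway batch with the correct per-rank shape, and the trainer runs a single forward/backward pass on that batch before the first real step, so compilation (and any resulting OOM) happens the same way whether the run is fresh or restarted from a checkpoint. The sketch below restates the idea outside the diff; `MockBatchLoader` and `dry_run` are illustrative stand-ins for the `DataLoaderBase` and `Trainer._dry_run_batch` changes shown in full further down.

from typing import Any, Dict

import torch


class MockBatchLoader:
    """Illustrative stand-in for the loader-side change (see data_loader.py below)."""

    def __init__(self, rank_batch_size: int, max_sequence_length: int, vocab_size: int):
        self.rank_batch_size = rank_batch_size  # per-rank batch size in tokens
        self.max_sequence_length = max_sequence_length
        self.vocab_size = vocab_size

    def get_mock_batch(self) -> Dict[str, Any]:
        # Random token IDs suffice: the batch only exists to trigger compilation
        # and surface OOMs before any real data is consumed.
        num_instances = self.rank_batch_size // self.max_sequence_length
        input_ids = torch.randint(
            0, self.vocab_size, (num_instances, self.max_sequence_length)
        )
        return {"input_ids": input_ids}


def dry_run(trainer) -> None:
    # Illustrative stand-in for Trainer._dry_run_batch(): skip silently for older
    # data loaders without get_mock_batch(), otherwise run one forward/backward
    # pass with metrics suppressed and gradients re-zeroed (dry_run=True).
    try:
        batch = trainer.data_loader.get_mock_batch()
    except NotImplementedError:
        return  # backwards compatibility
    trainer._validate_batch(batch)
    trainer._train_batch(batch, dry_run=True)
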
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `CometCallback` for logging training runs to Comet.ml.
- Added `DataMixBase` class, to allow extending to new data mix groups.
- Added method `DataLoaderBase.get_mock_batch()`.
- Trainer now starts with a dry-run of a fake batch created by `DataLoaderBase.get_mock_batch()`.
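
The mock batch matches the per-rank batch size in tokens, so the size check the trainer now applies to every batch (`_validate_batch()` in trainer.py below) holds for the dry-run as well. With illustrative numbers, not taken from this diff:

import torch

# Illustrative configuration: 8,192 tokens per rank, sequences of 2,048 tokens,
# and a 32,000-token vocabulary (the value the tests below happen to use).
rank_batch_size, max_sequence_length, vocab_size = 8_192, 2_048, 32_000

num_instances = rank_batch_size // max_sequence_length  # 4 instances
mock = {"input_ids": torch.randint(0, vocab_size, (num_instances, max_sequence_length))}

assert mock["input_ids"].numel() == rank_batch_size  # the check _validate_batch enforces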

### Changed

15 changes: 15 additions & 0 deletions src/olmo_core/data/data_loader.py
@@ -208,6 +208,14 @@ def reset(self):
self.batches_processed = 0
self.tokens_processed = 0

@abstractmethod
def get_mock_batch(self) -> Dict[str, Any]:
"""
Return a batch with arbitrary data. This can just be random data as it's only used by the
trainer to do a dry-run of the forward and backward pass before training officially starts.
"""
raise NotImplementedError


class NumpyDataLoaderBase(DataLoaderBase):
"""
@@ -421,6 +429,13 @@ def reshuffle(self, epoch: Optional[int] = None, in_memory: bool = False, **kwar
self._epoch = epoch
self.build_and_save_global_indices(in_memory=in_memory)

def get_mock_batch(self) -> Dict[str, Any]:
num_instances = self.rank_batch_size // self.dataset.max_sequence_length
input_ids = torch.randint(
0, self.dataset.vocab_size, (num_instances, self.dataset.max_sequence_length)
)
return {"input_ids": input_ids}

def _iter_batches(self) -> Iterable[Dict[str, Any]]:
return torch.utils.data.DataLoader(
_IterableDatasetWrapper(self),
31 changes: 28 additions & 3 deletions src/olmo_core/data/numpy_dataset.py
@@ -98,6 +98,7 @@ def __init__(
*paths: PathOrStr,
pad_token_id: int,
eos_token_id: int,
vocab_size: int,
dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16,
):
if not paths:
@@ -106,6 +107,7 @@ def __init__(
self._array_paths = tuple(paths)
self._pad_token_id = pad_token_id
self._eos_token_id = eos_token_id
self._vocab_size = vocab_size
self._dtype = dtype
self._fs_local_rank = get_fs_local_rank()
self._work_dir: Optional[Path] = None
@@ -144,6 +146,10 @@ def pad_token_id(self) -> int:
def eos_token_id(self) -> int:
return self._eos_token_id

@property
def vocab_size(self) -> int:
return self._vocab_size

@property
def dtype(
self,
@@ -340,6 +346,7 @@ def __init__(
sequence_length: int,
pad_token_id: int,
eos_token_id: int,
vocab_size: int,
dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16,
metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
include_instance_metadata: Optional[bool] = None,
@@ -365,7 +372,13 @@ def __init__(
else:
metadata = [metadata or {}] * len(paths)

super().__init__(*paths, pad_token_id=pad_token_id, eos_token_id=eos_token_id, dtype=dtype)
super().__init__(
*paths,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
vocab_size=vocab_size,
dtype=dtype,
)
self._metadata = tuple(metadata)
self._sequence_length = sequence_length
self._max_target_sequence_length = max_target_sequence_length
@@ -502,6 +515,7 @@ def __init__(
sequence_length: int,
pad_token_id: int,
eos_token_id: int,
vocab_size: int,
dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16,
metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
include_instance_metadata: Optional[bool] = None,
@@ -511,6 +525,7 @@ def __init__(
sequence_length=sequence_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
vocab_size=vocab_size,
dtype=dtype,
metadata=metadata,
include_instance_metadata=include_instance_metadata,
@@ -605,7 +620,7 @@ def _write_instance_indices(self):

# Log results.
for path, future in zip(paths_needed, futures):
total_og_docs, total_instances = future.result()
_, total_instances = future.result()
log.info(
f"Created {total_instances:,d} instances of sequence length up to "
f"{self.sequence_length} from '{path}'"
@@ -926,6 +941,7 @@ def __init__(
*paths: PathOrStr,
pad_token_id: int,
eos_token_id: int,
vocab_size: int,
max_sequence_length: int,
min_sequence_length: int = 256,
curriculum: Optional[VSLCurriculum] = None,
@@ -955,7 +971,13 @@ def __init__(
else:
metadata = [metadata or {}] * len(paths)

super().__init__(*paths, pad_token_id=pad_token_id, eos_token_id=eos_token_id, dtype=dtype)
super().__init__(
*paths,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
vocab_size=vocab_size,
dtype=dtype,
)
self._metadata = metadata
self._include_instance_metadata = include_instance_metadata
self._max_sequence_length = max_sequence_length
@@ -1539,6 +1561,7 @@ def build(self) -> NumpyDatasetBase:
max_target_sequence_length=self.max_target_sequence_length,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
vocab_size=self.tokenizer.vocab_size,
dtype=self.get_dtype(),
metadata=metadata,
include_instance_metadata=self.include_instance_metadata,
@@ -1578,6 +1601,7 @@ def build(self) -> NumpyDatasetBase:
sequence_length=self.sequence_length,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
vocab_size=self.tokenizer.vocab_size,
dtype=self.get_dtype(),
metadata=metadata,
include_instance_metadata=self.include_instance_metadata,
@@ -1602,6 +1626,7 @@ def build(self) -> NumpyDatasetBase:
curriculum=None if self.vsl_curriculum is None else self.vsl_curriculum.build(),
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
vocab_size=self.tokenizer.vocab_size,
dtype=self.get_dtype(),
metadata=metadata,
include_instance_metadata=self.include_instance_metadata,
49 changes: 37 additions & 12 deletions src/olmo_core/train/trainer.py
@@ -530,6 +530,9 @@ def fit(self):
og_sigterm_handler = signal.signal(signal.SIGTERM, self._handle_os_signal)
og_sigint_handler = signal.signal(signal.SIGINT, self._handle_os_signal)

# Do a dry-run for compiling and catch OOMs.
self._dry_run_batch()

try:
while not self.training_complete:
self._fit_epoch()
@@ -1025,9 +1028,9 @@ def _train_microbatch_context(
stack.enter_context(self.model.no_sync())
yield

def _train_batch(self, batch: Dict[str, Any]):
def _train_batch(self, batch: Dict[str, Any], dry_run: bool = False):
# Record how many instances are going to be skipped (masked out).
if (instance_mask := batch.get("instance_mask")) is not None:
if (instance_mask := batch.get("instance_mask")) is not None and not dry_run:
self.record_metric("train/masked instances", (~instance_mask).sum(), ReduceType.sum)

# Zero-gradients.
@@ -1078,6 +1081,11 @@ def _train_batch(self, batch: Dict[str, Any]):
# Run backward pass.
loss.backward()

if dry_run:
# Zero-gradients again.
self.optim.zero_grad(set_to_none=True)
return

self.record_metric(TRAIN_CE_LOSS_METRIC, ce_batch_loss, ReduceType.mean)
if z_batch_loss is not None:
self.record_metric(TRAIN_Z_LOSS_METRIC, z_batch_loss, ReduceType.mean)
@@ -1107,6 +1115,32 @@ def _iter_batches(self) -> Generator[Dict[str, Any], None, None]:
except StopIteration:
break

def _validate_batch(self, batch: Dict[str, Any]) -> int:
"""
Validate the data in a batch and return the global total number of tokens in the batch.
"""
# NOTE: To track the global number of tokens seen per batch we make the
# assumption that all ranks see the same batch size in tokens per step,
# which should always be the case for training efficiency at least.
# Alternatively we'd have to use a distributed collective which isn't worth it.
if batch["input_ids"].numel() != self.rank_batch_size:
raise RuntimeError(
f"Expected batch size of {self.rank_batch_size:,d} tokens on rank {get_rank()}, "
f"got input IDs with shape {tuple(batch['input_ids'].shape)} = {batch['input_ids'].numel():,d} tokens"
)
return self.global_batch_size

def _dry_run_batch(self):
try:
batch = self.data_loader.get_mock_batch()
except NotImplementedError:
return # for backwards compatibility

log.info("Starting forward/backward dry-run batch...")
self._validate_batch(batch)
self._train_batch(batch, dry_run=True)
log.info("Dry-run complete")

def _fit_epoch(self):
self.data_loader.reshuffle(self.epoch)

@@ -1118,17 +1152,8 @@ def _fit_epoch(self):
first_batch = True
for batch in self._iter_batches():
# Bookkeeping.
# NOTE: To track the global number of tokens seen per batch we make the
# assumption that all ranks see the same batch size in tokens per step,
# which should always be the case for training efficiency at least.
# Alternatively we'd have to use a distributed collective which isn't worth it.
if batch["input_ids"].numel() != self.rank_batch_size:
raise RuntimeError(
f"Expected batch size of {self.rank_batch_size:,d} tokens on rank {get_rank()}, "
f"got input IDs with shape {tuple(batch['input_ids'].shape)} = {batch['input_ids'].numel():,d} tokens"
)
self.global_step += 1
self.global_train_tokens_seen += self.global_batch_size
self.global_train_tokens_seen += self._validate_batch(batch)

self.record_metric(SEQ_LEN_METRIC, float(batch["input_ids"].shape[1]))

5 changes: 5 additions & 0 deletions src/test/data/custom_data_loader.py
@@ -76,6 +76,11 @@ def reshuffle(self, epoch: Optional[int] = None, **kwargs):
)
]

def get_mock_batch(self) -> Dict[str, Any]:
num_instances = self.rank_batch_size // self.sequence_length
input_ids = torch.randint(0, self.vocab_size, (num_instances, self.sequence_length))
return {"input_ids": input_ids}

def _iter_batches(self) -> Iterable[Dict[str, Any]]:
assert self._dataset is not None, "did you forget to call 'reshuffle()'?"

4 changes: 4 additions & 0 deletions src/test/data/data_loader_test.py
@@ -49,6 +49,7 @@ def get_all_batches() -> List[List[int]]:
sequence_length=sequence_length,
pad_token_id=-1,
eos_token_id=-1,
vocab_size=32_000,
)
for rank in range(world_size):
data_loader = NumpyFSLDataLoader(
@@ -107,6 +108,7 @@ def test_fsl_data_loader_multiple_epochs(
sequence_length=sequence_length,
pad_token_id=-1,
eos_token_id=-1,
vocab_size=32_000,
)
data_loader = NumpyFSLDataLoader(
dataset,
@@ -204,6 +206,7 @@ def get_all_tokens(seq_len: int, tokens_processed: int = 0) -> List[int]:
sequence_length=seq_len,
pad_token_id=-1,
eos_token_id=-1,
vocab_size=32_000,
max_target_sequence_length=max_target_sequence_length,
)
data_loader = NumpyFSLDataLoader(
@@ -261,6 +264,7 @@ def test_vsl_data_loader(
tmp_path / "tokens2.npy",
pad_token_id=0,
eos_token_id=0,
vocab_size=32_000,
min_sequence_length=2,
max_sequence_length=4,
dtype=np.uint16,
3 changes: 3 additions & 0 deletions src/test/data/numpy_dataset_test.py
@@ -28,6 +28,7 @@ def test_numpy_fsl_dataset(tmp_path: Path):
sequence_length=4,
pad_token_id=-1,
eos_token_id=-1,
vocab_size=32_000,
)
assert ds[0]["input_ids"].tolist() == [0, 1, 2, 3]
assert ds[1]["input_ids"].tolist() == [4, 5, 6, 7]
@@ -52,6 +53,7 @@ def test_numpy_padded_fsl_dataset(tmp_path: Path):
sequence_length=8,
pad_token_id=0,
eos_token_id=0,
vocab_size=32_000,
)
ds.prepare()
assert ds[0]["input_ids"].tolist() == [1, 2, 3, 4, 5, 6, 7, 0]
@@ -92,6 +94,7 @@ def test_numpy_vsl_dataset(tmp_path: Path):
data2_path,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
vocab_size=32_000,
max_sequence_length=8,
min_sequence_length=2,
dtype=dtype,
