Add tokens to iterations (#3374)
b-chu authored Jun 13, 2024
1 parent aedb229 commit 13f2a4f
Showing 6 changed files with 82 additions and 23 deletions.
28 changes: 15 additions & 13 deletions composer/core/callback.py
@@ -273,19 +273,21 @@ def batch_end(self, state: State, logger: Logger) -> None:
The following :attr:`.State.timestamp` member variables are
incremented immediately before the :attr:`.Event.BATCH_END` event.
-+------------------------------------+
-| :attr:`.Timestamp.batch`           |
-+------------------------------------+
-| :attr:`.Timestamp.batch_in_epoch`  |
-+------------------------------------+
-| :attr:`.Timestamp.sample`          |
-+------------------------------------+
-| :attr:`.Timestamp.sample_in_epoch` |
-+------------------------------------+
-| :attr:`.Timestamp.token`           |
-+------------------------------------+
-| :attr:`.Timestamp.token_in_epoch`  |
-+------------------------------------+
++---------------------------------------+
+| :attr:`.Timestamp.batch`              |
++---------------------------------------+
+| :attr:`.Timestamp.batch_in_epoch`     |
++---------------------------------------+
+| :attr:`.Timestamp.sample`             |
++---------------------------------------+
+| :attr:`.Timestamp.sample_in_epoch`    |
++---------------------------------------+
+| :attr:`.Timestamp.token`              |
++---------------------------------------+
+| :attr:`.Timestamp.token_in_epoch`     |
++---------------------------------------+
+| :attr:`.Timestamp.token_in_iteration` |
++---------------------------------------+
Args:
state (State): The training state.
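To make the docstring concrete, here is a minimal sketch (not part of this commit) of a user callback that reads the new counter once `BATCH_END` fires; the `TokenProgress` class name and the metric key are hypothetical:

```python
from composer.core import Callback, State
from composer.loggers import Logger

class TokenProgress(Callback):
    """Hypothetical callback that logs the per-iteration token count."""

    def batch_end(self, state: State, logger: Logger) -> None:
        # By BATCH_END, token_in_iteration has been incremented for this
        # batch; it resets to 0 when the trainer starts a new iteration.
        logger.log_metrics({'token_in_iteration': int(state.timestamp.token_in_iteration)})
```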
2 changes: 1 addition & 1 deletion composer/core/state.py
@@ -766,7 +766,7 @@ def _iteration_length(self, iteration_length: Optional[Union[str, Time[int]]]):
    return
if isinstance(iteration_length, str):
    iteration_length = ensure_time(iteration_length, TimeUnit.EPOCH)
-if iteration_length.unit != TimeUnit.EPOCH:
+if iteration_length.unit != TimeUnit.EPOCH and iteration_length.unit != TimeUnit.TOKEN:
    raise NotImplementedError(f'{iteration_length.unit} is not allowed as a unit for iteration_length.')
self.__iteration_length = iteration_length

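A short sketch of what the relaxed check now accepts, using only the `ensure_time`/`TimeUnit` helpers visible above (values are illustrative):

```python
from composer.core.time import TimeUnit, ensure_time

# Timestrings are parsed by ensure_time; a bare int would default to epochs here.
epochs = ensure_time('2ep', TimeUnit.EPOCH)     # Time(2, TimeUnit.EPOCH)
tokens = ensure_time('500tok', TimeUnit.EPOCH)  # Time(500, TimeUnit.TOKEN)

for length in (epochs, tokens):
    # Mirrors the setter's validation: any other unit raises NotImplementedError.
    assert length.unit in (TimeUnit.EPOCH, TimeUnit.TOKEN)
```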
32 changes: 32 additions & 0 deletions composer/core/time.py
@@ -473,6 +473,7 @@ class Timestamp(Serializable):
sample (int | Time[int], optional): The sample.
token (int | Time[int], optional): The token.
epoch_in_iteration (int | Time[int], optional): The epoch in the iteration.
token_in_iteration (int | Time[int], optional): The token in the iteration.
batch_in_epoch (int | Time[int], optional): The batch in the epoch.
sample_in_epoch (int | Time[int], optional): The sample in the epoch.
token_in_epoch (int | Time[int], optional): The token in the epoch.
@@ -490,6 +491,7 @@ def __init__(
sample: Union[int, Time[int]] = 0,
token: Union[int, Time[int]] = 0,
epoch_in_iteration: Union[int, Time[int]] = 0,
token_in_iteration: Union[int, Time[int]] = 0,
batch_in_epoch: Union[int, Time[int]] = 0,
sample_in_epoch: Union[int, Time[int]] = 0,
token_in_epoch: Union[int, Time[int]] = 0,
@@ -531,6 +533,14 @@ def __init__(
))
self._epoch_in_iteration = epoch_in_iteration

token_in_iteration = Time.from_input(token_in_iteration, TimeUnit.TOKEN)
if token_in_iteration.unit != TimeUnit.TOKEN:
    raise ValueError((
        f'The `token_in_iteration` argument has units of {token_in_iteration.unit}; '
        f'not {TimeUnit.TOKEN}.'
    ))
self._token_in_iteration = token_in_iteration

batch_in_epoch = Time.from_input(batch_in_epoch, TimeUnit.BATCH)
if batch_in_epoch.unit != TimeUnit.BATCH:
    raise ValueError(
@@ -579,6 +589,7 @@ def state_dict(self) -> dict[str, Any]:
'sample': self.sample.value,
'token': self.token.value,
'epoch_in_iteration': self.epoch_in_iteration.value,
'token_in_iteration': self.token_in_iteration.value,
'batch_in_epoch': self.batch_in_epoch.value,
'sample_in_epoch': self.sample_in_epoch.value,
'token_in_epoch': self.token_in_epoch.value,
@@ -609,6 +620,8 @@ def load_state_dict(self, state: dict[str, Any]) -> None:
self._iteration = Time(state['iteration'], TimeUnit.ITERATION)
if 'epoch_in_iteration' in state:
    self._epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH)
if 'token_in_iteration' in state:
    self._token_in_iteration = Time(state['token_in_iteration'], TimeUnit.TOKEN)
if 'iteration_wct' in state:
    self._iteration_wct = state['iteration_wct']

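A hedged round-trip sketch of the two methods above (default-constructed `Timestamp` objects, values illustrative): the new key is emitted by `state_dict` and restored by `load_state_dict`, and the `in state` guard keeps checkpoints written before this change loadable:

```python
from composer.core.time import Timestamp

ts = Timestamp(token_in_iteration=20)
restored = Timestamp()
restored.load_state_dict(ts.state_dict())
assert int(restored.token_in_iteration) == 20

# A pre-change checkpoint simply lacks the key, so the guarded load
# leaves the default value of 0 untouched.
```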
@@ -642,6 +655,11 @@ def epoch_in_iteration(self) -> Time[int]:
"""The epoch count in the current iteration (resets at 0 at the beginning of every iteration)."""
return self._epoch_in_iteration

@property
def token_in_iteration(self) -> Time[int]:
"""The token count in the current iteration (resets at 0 at the beginning of every iteration)."""
return self._token_in_iteration

@property
def batch_in_epoch(self) -> Time[int]:
"""The batch count in the current epoch (resets at 0 at the beginning of every epoch)."""
@@ -814,6 +832,7 @@ def to_next_batch(
sample_in_epoch=self.sample_in_epoch + samples,
token=self.token + tokens,
token_in_epoch=self.token_in_epoch + tokens,
token_in_iteration=self.token_in_iteration + tokens,
total_wct=self.total_wct + duration,
iteration_wct=self.iteration_wct + duration,
epoch_wct=self.epoch_wct + duration,
@@ -822,6 +841,7 @@

def to_next_epoch(
self,
tokens: Union[int, Time] = 0,
duration: Optional[datetime.timedelta] = None,
):
"""Create a new :class:`.Timestamp`, advanced to the next epoch.
@@ -841,6 +861,7 @@ def to_next_epoch(
>>> timestamp.copy(
... epoch=timestamp.epoch + 1,
... epoch_in_iteration=timestamp.epoch_in_iteration + 1,
... token_in_iteration=timestamp.token_in_iteration + tokens,
... batch_in_epoch=0,
... sample_in_epoch=0,
... token_in_epoch=0,
@@ -851,12 +872,17 @@ def to_next_epoch(
... )
Timestamp(...)
Args:
tokens (int | Time, optional): The number of tokens trained in the batch. Defaults to 0.
duration (datetime.timedelta, optional): The duration to train the batch.
"""
if duration is None:
    duration = datetime.timedelta(seconds=0)
return self.copy(
    epoch=self.epoch + 1,
    epoch_in_iteration=self.epoch_in_iteration + 1,
    token_in_iteration=self.token_in_iteration + tokens,
    batch_in_epoch=0,
    sample_in_epoch=0,
    token_in_epoch=0,
@@ -886,6 +912,7 @@ def to_next_iteration(
>>> timestamp.copy(
... iteration=timestamp.iteration + 1,
... epoch_in_iteration=0,
... token_in_iteration=0,
... batch_in_epoch=0,
... sample_in_epoch=0,
... token_in_epoch=0,
@@ -902,6 +929,7 @@ def to_next_iteration(
return self.copy(
    iteration=self.iteration + 1,
    epoch_in_iteration=0,
    token_in_iteration=0,
    batch_in_epoch=0,
    sample_in_epoch=0,
    token_in_epoch=0,
@@ -919,6 +947,7 @@ def copy(
sample: Optional[Union[int, Time[int]]] = None,
token: Optional[Union[int, Time[int]]] = None,
epoch_in_iteration: Optional[Union[int, Time[int]]] = None,
token_in_iteration: Optional[Union[int, Time[int]]] = None,
batch_in_epoch: Optional[Union[int, Time[int]]] = None,
sample_in_epoch: Optional[Union[int, Time[int]]] = None,
token_in_epoch: Optional[Union[int, Time[int]]] = None,
@@ -938,6 +967,7 @@ def copy(
sample (int | Time[int], optional): The sample.
token (int | Time[int], optional): The token.
epoch_in_iteration (int | Time[int], optional): The epoch in the iteration.
token_in_iteration (int | Time[int], optional): The token in the iteration.
batch_in_epoch (int | Time[int], optional): The batch in the epoch.
sample_in_epoch (int | Time[int], optional): The sample in the epoch.
token_in_epoch (int | Time[int], optional): The token in the epoch.
@@ -957,6 +987,7 @@ def copy(
sample=sample if sample is not None else self.sample,
token=token if token is not None else self.token,
epoch_in_iteration=epoch_in_iteration if epoch_in_iteration is not None else self.epoch_in_iteration,
token_in_iteration=token_in_iteration if token_in_iteration is not None else self.token_in_iteration,
batch_in_epoch=batch_in_epoch if batch_in_epoch is not None else self.batch_in_epoch,
sample_in_epoch=sample_in_epoch if sample_in_epoch is not None else self.sample_in_epoch,
token_in_epoch=token_in_epoch if token_in_epoch is not None else self.token_in_epoch,
@@ -975,6 +1006,7 @@ def __repr__(self) -> str:
f'sample={int(self.sample)}, '
f'token={int(self.token)}, '
f'epoch_in_iteration={int(self.epoch_in_iteration)}, '
f'token_in_iteration={int(self.token_in_iteration)}, '
f'batch_in_epoch={int(self.batch_in_epoch)}, '
f'sample_in_epoch={int(self.sample_in_epoch)}, '
f'token_in_epoch={int(self.token_in_epoch)}, '
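Taken together, the `time.py` changes implement the following bookkeeping; a small sketch whose expected values follow the diff and the tests below:

```python
from composer.core.time import Timestamp

ts = Timestamp()
ts = ts.to_next_batch(samples=10, tokens=20)  # token_in_iteration: 0 -> 20
ts = ts.to_next_epoch(tokens=5)               # carries across epochs: 20 -> 25
assert int(ts.token_in_iteration) == 25

ts = ts.to_next_iteration()                   # iteration boundary resets the counter
assert int(ts.token_in_iteration) == 0
assert int(ts.iteration) == 1
```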
32 changes: 25 additions & 7 deletions composer/trainer/trainer.py
@@ -2610,10 +2610,24 @@ def _train_loop(self) -> None:

self.engine.run_event(Event.BATCH_CHECKPOINT)

-if self.state.timestamp >= self.state.max_duration:
+if (
+    self.state.timestamp >= self.state.max_duration or (
+        self.state._iteration_length is not None and
+        self.state.timestamp.token_in_iteration.unit == self.state._iteration_length.unit and
+        self.state.timestamp.token_in_iteration >= self.state._iteration_length
+    )
+):
    # If max_duration is specified in batches, samples, or tokens, and the
    # max_duration is reached mid-epoch, then break out of the dataloader
    # to finish the epoch early and finish training.

+    # Increment iteration
+    if (
+        self.state._iteration_length is not None and
+        self.state.timestamp.token_in_iteration.unit == self.state._iteration_length.unit and
+        self.state.timestamp.token_in_iteration >= self.state._iteration_length
+    ):
+        self._increment_iteration()
    finished_epoch_early = True
    break

@@ -2649,12 +2663,10 @@ def _train_loop(self) -> None:
# Increment iteration
if (
    self.state._iteration_length is not None and
-    self.state.timestamp.epoch_in_iteration == self.state._iteration_length
+    self.state.timestamp.epoch_in_iteration.unit == self.state._iteration_length.unit and
+    self.state.timestamp.epoch_in_iteration >= self.state._iteration_length
):
-    self.state.previous_timestamp = self.state.timestamp
-    self.state.timestamp = self.state.timestamp.to_next_iteration()
-    self.engine.run_event(Event.ITERATION_END)
-    self.engine.run_event(Event.ITERATION_CHECKPOINT)
+    self._increment_iteration()

# Log final time values
self.logger.log_metrics({
@@ -3039,6 +3051,12 @@ def _train_microbatch(

return microbatch_loss_dict

def _increment_iteration(self):
    self.state.previous_timestamp = self.state.timestamp
    self.state.timestamp = self.state.timestamp.to_next_iteration()
    self.engine.run_event(Event.ITERATION_END)
    self.engine.run_event(Event.ITERATION_CHECKPOINT)

def predict(
self,
dataloader: Union[DataLoader, DataSpec],
Expand Down Expand Up @@ -3506,7 +3524,7 @@ def _eval_loop(
outputs.append(v)
else:
outputs = self.state.outputs.cpu()
-batch = DeviceCPU().batch_to_device(self.state.batch,)
+batch = DeviceCPU().batch_to_device(self.state.batch)
else:
outputs = self.state.outputs
batch = self.state.batch
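The two loop guards above repeat one predicate; distilled as a standalone sketch (free-standing `Time` values stand in for trainer state, purely illustrative):

```python
from composer.core.time import Time, TimeUnit

iteration_length = Time(1000, TimeUnit.TOKEN)    # stand-in for state._iteration_length
token_in_iteration = Time(1024, TimeUnit.TOKEN)  # stand-in for state.timestamp.token_in_iteration

# Units must match before magnitudes are compared, so an epoch-valued
# iteration_length is never cut short by the token counter (and vice versa).
if (
    iteration_length is not None and
    token_in_iteration.unit == iteration_length.unit and
    token_in_iteration >= iteration_length
):
    print('token budget for this iteration reached; increment the iteration')
```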
1 change: 1 addition & 0 deletions tests/checkpoint/test_state_dict.py
@@ -568,6 +568,7 @@ def test_get_resumption_state_dict():
'sample': 0,
'token': 0,
'epoch_in_iteration': 0,
'token_in_iteration': 0,
'batch_in_epoch': 0,
'sample_in_epoch': 0,
'token_in_epoch': 0,
10 changes: 8 additions & 2 deletions tests/test_time.py
@@ -151,7 +151,7 @@ def test_timestamp_to_next_batch_epoch_iteration():
# Step batch 0 in epoch 0
timestamp = timestamp.to_next_batch(10, 20, datetime.timedelta(seconds=5))
assert timestamp.batch == 1
-assert timestamp.batch_in_epoch == 1
+assert timestamp.token_in_iteration == 20
+assert timestamp.batch_in_epoch == 1
assert timestamp.sample == 10
assert timestamp.sample_in_epoch == 10
@@ -163,9 +163,10 @@ def test_timestamp_to_next_batch_epoch_iteration():
assert timestamp.batch_wct == datetime.timedelta(seconds=5)

# Finish epoch 0
-timestamp = timestamp.to_next_epoch(datetime.timedelta(seconds=5))
+timestamp = timestamp.to_next_epoch(duration=datetime.timedelta(seconds=5))
assert timestamp.epoch == 1
assert timestamp.batch == 1
assert timestamp.token_in_iteration == 20
assert timestamp.batch_in_epoch == 0
assert timestamp.sample == 10
assert timestamp.sample_in_epoch == 0
@@ -181,6 +182,7 @@ def test_timestamp_to_next_batch_epoch_iteration():
assert timestamp.epoch == 1
assert timestamp.batch == 2
assert timestamp.epoch_in_iteration == 1
assert timestamp.token_in_iteration == 20
assert timestamp.batch_in_epoch == 1
assert timestamp.sample == 15
assert timestamp.sample_in_epoch == 5
@@ -195,6 +197,7 @@ def test_timestamp_to_next_batch_epoch_iteration():
timestamp = timestamp.to_next_batch(5, 1, datetime.timedelta(seconds=10))
assert timestamp.epoch == 1
assert timestamp.batch == 3
assert timestamp.token_in_iteration == 21
assert timestamp.batch_in_epoch == 2
assert timestamp.sample == 20
assert timestamp.sample_in_epoch == 10
@@ -210,6 +213,7 @@ def test_timestamp_to_next_batch_epoch_iteration():
assert timestamp.epoch == 2
assert timestamp.batch == 3
assert timestamp.epoch_in_iteration == 2
assert timestamp.token_in_iteration == 21
assert timestamp.batch_in_epoch == 0
assert timestamp.sample == 20
assert timestamp.sample_in_epoch == 0
@@ -224,6 +228,7 @@ def test_timestamp_to_next_batch_epoch_iteration():
assert timestamp.epoch == 2
assert timestamp.batch == 4
assert timestamp.epoch_in_iteration == 2
assert timestamp.token_in_iteration == 22
assert timestamp.batch_in_epoch == 1
assert timestamp.sample == 25
assert timestamp.sample_in_epoch == 5
@@ -240,6 +245,7 @@ def test_timestamp_to_next_batch_epoch_iteration():
assert timestamp.epoch == 2
assert timestamp.batch == 4
assert timestamp.epoch_in_iteration == 0
assert timestamp.token_in_iteration == 0
assert timestamp.batch_in_epoch == 0
assert timestamp.sample == 25
assert timestamp.sample_in_epoch == 0