Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add setter for epoch in iteration #3407

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions composer/core/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,13 +525,8 @@ def __init__(
raise ValueError(f'The `token` argument has units of {token.unit}; not {TimeUnit.TOKEN}.')
self._token = token

epoch_in_iteration = Time.from_input(epoch_in_iteration, TimeUnit.EPOCH)
if epoch_in_iteration.unit != TimeUnit.EPOCH:
raise ValueError((
f'The `epoch_in_iteration` argument has units of {epoch_in_iteration.unit}; '
f'not {TimeUnit.EPOCH}.'
))
self._epoch_in_iteration = epoch_in_iteration
self._epoch_in_iteration = Time(0, TimeUnit.EPOCH)
self.epoch_in_iteration = epoch_in_iteration
b-chu marked this conversation as resolved.
Show resolved Hide resolved
b-chu marked this conversation as resolved.
Show resolved Hide resolved

token_in_iteration = Time.from_input(token_in_iteration, TimeUnit.TOKEN)
if token_in_iteration.unit != TimeUnit.TOKEN:
Expand Down Expand Up @@ -619,7 +614,7 @@ def load_state_dict(self, state: dict[str, Any]) -> None:
if 'iteration' in state:
self._iteration = Time(state['iteration'], TimeUnit.ITERATION)
if 'epoch_in_iteration' in state:
self._epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH)
self.epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH)
if 'token_in_iteration' in state:
self._token_in_iteration = Time(state['token_in_iteration'], TimeUnit.TOKEN)
if 'iteration_wct' in state:
Expand Down Expand Up @@ -655,6 +650,20 @@ def epoch_in_iteration(self) -> Time[int]:
"""The epoch count in the current iteration (resets at 0 at the beginning of every iteration)."""
return self._epoch_in_iteration

@epoch_in_iteration.setter
def epoch_in_iteration(
self,
epoch_in_iteration: Union[int, Time[int]], # pyright: ignore[reportPropertyTypeMismatch]
):
"""Sets epoch count in the current iteration."""
epoch_in_iteration = Time.from_input(epoch_in_iteration, TimeUnit.EPOCH)
if epoch_in_iteration.unit != TimeUnit.EPOCH:
raise ValueError((
f'The `epoch_in_iteration` argument has units of {epoch_in_iteration.unit}; '
f'not {TimeUnit.EPOCH}.'
))
self._epoch_in_iteration = epoch_in_iteration
b-chu marked this conversation as resolved.
Show resolved Hide resolved

@property
def token_in_iteration(self) -> Time[int]:
"""The token count in the current iteration (resets at 0 at the beginning of every iteration)."""
Expand Down
7 changes: 7 additions & 0 deletions tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,13 @@ def test_timestamp_update():
assert timestamp is not timestamp_2


def test_set_timestamp():
timestamp = Timestamp(epoch_in_iteration=1)
assert timestamp.epoch_in_iteration == 1
timestamp.epoch_in_iteration = 2
assert timestamp.epoch_in_iteration == 2
b-chu marked this conversation as resolved.
Show resolved Hide resolved


def test_timestamp_to_next_batch_epoch_iteration():
timestamp = Timestamp()
# Step batch 0 in epoch 0
Expand Down
Loading