Microbatch first last batch serial #11072

Merged
QMalcolm merged 21 commits into main from microbatch-first-last-batch-serial
Dec 7, 2024

Conversation

@MichelleArk MichelleArk commented Nov 28, 2024

Resolves #11094
Resolves #11104

Problem

Model pre & post hooks were running on each batch for a microbatch model

Solution

  • First and last batch are always run in serial
  • First batch runs pre-hook
  • Last batch runs post-hook
  • Account for when first batch = last batch
  • If first batch fails, skip remaining batches
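
The rules above can be sketched as a small planning function. This is only an illustration, not the code in this PR: `BatchPlan` and `plan_batches` are hypothetical names, and the sketch assumes batches are identified by their position in an ordered list.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class BatchPlan:
    """Hypothetical per-batch plan; not a dbt-core type."""

    index: int
    run_pre_hook: bool
    run_post_hook: bool
    force_serial: bool


def plan_batches(num_batches: int) -> List[BatchPlan]:
    """Decide hooks and serial execution for each batch position."""
    last = num_batches - 1
    plans = []
    for idx in range(num_batches):
        is_first = idx == 0
        is_last = idx == last
        # A single batch is both first and last, so it gets both hooks.
        plans.append(BatchPlan(idx, is_first, is_last, is_first or is_last))
    return plans
```

With four batches, only index 0 runs the pre-hook, only index 3 runs the post-hook, and both are forced to run serially. With one batch, that batch gets both hooks.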

Examples

When all batches succeed
Screenshot 2024-12-06 at 19 29 17 (2)

When the first batch fails
Screenshot 2024-12-06 at 19 29 33 (2)

When a middle batch fails
Screenshot 2024-12-06 at 19 30 49 (2)

Checklist

  • I have read the contributing guide and understand what's expected of me.
  • I have run this code in development, and it appears to resolve the stated issue.
  • This PR includes tests, or tests are not required or relevant for this PR.
  • This PR has no interface changes (e.g., macros, CLI, logs, JSON artifacts, config files, adapter interface, etc.) or this PR has already received feedback and approval from Product or DX.
  • This PR includes type annotations for new and modified functions.

@cla-bot cla-bot bot added the cla:yes label Nov 28, 2024

Thank you for your pull request! We could not find a changelog entry for this change. For details on how to document a change, see the contributing guide.

codecov bot commented Nov 28, 2024

Codecov Report

Attention: Patch coverage is 98.38710% with 1 line in your changes missing coverage. Please review.

Project coverage is 88.90%. Comparing base (1b7d9b5) to head (0198741).
Report is 6 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main   #11072      +/-   ##
==========================================
- Coverage   89.18%   88.90%   -0.28%     
==========================================
  Files         183      183              
  Lines       23783    23933     +150     
==========================================
+ Hits        21211    21278      +67     
- Misses       2572     2655      +83     
Flag Coverage Δ
integration 86.20% <96.77%> (-0.37%) ⬇️
unit 61.96% <35.48%> (-0.21%) ⬇️

Components Coverage Δ
Unit Tests 61.96% <35.48%> (-0.21%) ⬇️
Integration Tests 86.20% <96.77%> (-0.37%) ⬇️

@QMalcolm QMalcolm left a comment

Woooo! Thank you for doing this work ❤️ I think this will make it so hooks work how people expect, which is incredibly important. I do think there are some changes we can make, though, to improve the mental model of what happens where, for maintainability.

Comment on lines 721 to 725
# Run first batch in serial
relation_exists = self._submit_batch(
    node, relation_exists, batches, batch_idx, batch_results, pool, parallel=False
)
batch_idx += 1

I think it's actually preferable to have the first batch separated out. It makes it mirror how the last batch is separated out. We could also have an optional arg force_sequential (defaulted to False) which would skip the should_run_in_parallel check in _submit_batch
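
A rough sketch of that suggestion, assuming a simplified standalone function rather than the real `RunTask._submit_batch`. The callables passed in (`run_batch`, `submit_to_pool`, `should_run_in_parallel`) and the returned mode string are hypothetical stand-ins used only to make the control flow visible:

```python
from typing import Callable


def submit_batch(
    run_batch: Callable[[], object],
    submit_to_pool: Callable[[Callable[[], object]], None],
    should_run_in_parallel: Callable[[], bool],
    force_sequential: bool = False,
) -> str:
    """Run one batch and report how it was run ("parallel" or "sequential")."""
    # When force_sequential is set, the parallelism check is never consulted,
    # so that check does not need to know where the batch sits in the list.
    if not force_sequential and should_run_in_parallel():
        submit_to_pool(run_batch)
        return "parallel"
    run_batch()
    return "sequential"
```

The caller would pass `force_sequential=True` for the first and last batch and leave the default for everything in between.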

Comment on lines 611 to 613
elif self.batch_idx == 0 or self.batch_idx == len(self.batches) - 1:
    # First and last batch don't run in parallel
    run_in_parallel = False

This check could also be skipped if we're instead handling force_sequential to determine if we should even check should_run_in_parallel in _submit_batch. It'd be nice for this function to be less dependent on "where" it is, and I think this check breaks that.

Comment on lines 719 to 721
while batch_idx < len(runner.batches) - 1:
relation_exists = self._submit_batch(
node, relation_exists, batches, batch_idx, batch_results, pool

Another reason for splitting out the first batch:
We don't want to do any of the other batches if the first batch fails. By lumping it in the while loop with all other batches (except for the last batch), we lose that. We could add that logic to the loop; however, I think that'd crowd the loop logic, as it will only ever be needed for the first batch result.
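
To make the shape of that argument concrete, here is a minimal sketch with the first batch handled outside the loop. It is not the PR's implementation: `run_microbatches` and the status strings are hypothetical, a `ThreadPoolExecutor` stands in for the real thread pool, and the sketch assumes the last batch still runs when a middle batch fails.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


def run_microbatches(
    batches: List[str],
    run_batch: Callable[[str], bool],
    max_workers: int = 4,
) -> List[str]:
    """Return one status per batch: "success", "error", or "skipped"."""
    if not batches:
        return []
    statuses = ["skipped"] * len(batches)

    # First batch runs serially; if it fails, everything else stays skipped.
    statuses[0] = "success" if run_batch(batches[0]) else "error"
    if statuses[0] == "error" or len(batches) == 1:
        return statuses

    # Middle batches may run concurrently; the loop needs no first-batch logic.
    middle = range(1, len(batches) - 1)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(run_batch, [batches[i] for i in middle]))
    for i, ok in zip(middle, results):
        statuses[i] = "success" if ok else "error"

    # Last batch runs serially, after all middle batches have finished.
    statuses[-1] = "success" if run_batch(batches[-1]) else "error"
    return statuses
```

Because the early return sits right after the first batch, the failure check appears once and the middle-batch loop stays unaware of it.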

should_run_in_parallel shouldn't, and no longer needs to, take into
consideration where a batch exists in a larger context. The first and
last batch for a microbatch model are now forced to run sequentially
by `handle_microbatch_model`
…hen batches are skipped

This was necessary specifically because the default `on_skip` set the `X of Y` part
of the skipped log using the `node_index` and the `num_nodes`. If there were 2
nodes and we were on the 4th batch of the second node, we'd get a message like
`SKIPPED 4 of 2...`, which didn't make much sense. In a future commit we're likely
going to add a custom event for logging the start, result, and skipping of batches
for better readability of the logs.
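
A small sketch of the counter mismatch described above. The function name and the exact message wording are hypothetical and do not reproduce dbt's log format; the point is only that batch positions should be paired with the batch count, not with `num_nodes`.

```python
from typing import Optional


def skip_message(
    name: str,
    node_index: int,
    num_nodes: int,
    batch_index: Optional[int] = None,
    num_batches: Optional[int] = None,
) -> str:
    """Build a skip line, using batch counters only when a batch is skipped."""
    if batch_index is None or num_batches is None:
        # Whole-node skip: node position out of total nodes.
        return f"SKIPPED {node_index} of {num_nodes} {name}"
    # Batch skip: batch position out of total batches for this node.
    return f"SKIPPED batch {batch_index} of {num_batches} for {name}"
```

For the second of 2 nodes on its 4th of 6 batches, this yields `SKIPPED batch 4 of 6 for my_model` instead of a line that pairs 4 with 2.
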
Previously `MicrobatchModelRunner.on_skip` only handled when a _batch_ of
the model was being skipped. However, that method is also used when the
entire microbatch model is being skipped due to an upstream node error. Because
we previously _weren't_ handling this second case, it'd cause an unhandled
runtime exception. Thus we now need to check whether we're running a batch or not,
and if there is no batch, use the super's on_skip method.
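
The fallback can be sketched with two toy classes. `NodeRunner` and `MicrobatchRunner` here are hypothetical stand-ins, not the dbt-core runner classes, and the sketch assumes `batch_idx` is `None` whenever no batch is in progress.

```python
from typing import List, Optional


class NodeRunner:
    """Stand-in base runner that knows how to skip a whole node."""

    def __init__(self, name: str) -> None:
        self.name = name

    def on_skip(self) -> str:
        return f"skipped node {self.name}"


class MicrobatchRunner(NodeRunner):
    """Stand-in runner that may be skipped per batch or as a whole node."""

    def __init__(self, name: str, batches: List[str]) -> None:
        super().__init__(name)
        self.batches = batches
        self.batch_idx: Optional[int] = None  # None until a batch starts

    def on_skip(self) -> str:
        if self.batch_idx is None:
            # No batch in progress: the whole model is being skipped,
            # for example because an upstream node errored.
            return super().on_skip()
        return f"skipped batch {self.batch_idx + 1} of {len(self.batches)} for {self.name}"
```

Without the `None` check, the batch branch would try to do arithmetic on a missing index, which matches the kind of unhandled exception the commit describes.
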
Previously we were doing an if+elif for setting pre and post hooks
for batches, where the `if` matched if the batch wasn't the first
batch, and the `elif` matched if the batch wasn't the last batch. The
issue with this is that if the `if` was hit, the `elif` _wouldn't_ be hit.
This caused the first batch to appropriately not run the `post-hook`, but
then every batch after would run the `post-hook`.
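
The difference is easy to see side by side. Both functions below are hypothetical reductions of the logic described in this commit message, returning `(run_pre_hook, run_post_hook)` for batch `idx` out of `n` batches:

```python
from typing import Tuple


def hooks_buggy(idx: int, n: int) -> Tuple[bool, bool]:
    """if + elif: once the first branch matches, the second is never checked."""
    run_pre, run_post = True, True
    if idx != 0:
        run_pre = False
    elif idx != n - 1:
        run_post = False
    return run_pre, run_post


def hooks_fixed(idx: int, n: int) -> Tuple[bool, bool]:
    """Two independent checks: each hook is decided on its own condition."""
    run_pre, run_post = True, True
    if idx != 0:
        run_pre = False
    if idx != n - 1:
        run_post = False
    return run_pre, run_post
```

For three batches, `hooks_buggy` leaves the post-hook enabled on batches 1 and 2, while `hooks_fixed` enables it only on batch 2. Both agree when there is a single batch, which runs both hooks.
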
@QMalcolm QMalcolm force-pushed the microbatch-first-last-batch-serial branch from 5a548c1 to 74a76c4 Compare December 6, 2024 22:57
assert "post-hook" in event.data.msg # type: ignore


class TestWhenOnlyOneBatchRunBothPostAndPreHooks(BaseMicrobatchTest):

thank you for adding this one 🙏

We started using self.get_runner in a previous commit. Unfortunately
there were unintended consequences in doing so. Namely, incorrectly
incrementing the number of nodes being run. We do _eventually_ want
to move to using self.get_runner, but it is out of scope for this
segment of work (mostly due to time constraints).

@MichelleArk MichelleArk left a comment

✅ Can't approve my own PR, but thank you so much for taking this over the finish line 🏁

The rest of the functional changes + tests look great 🚀

@QMalcolm QMalcolm added the proto update update proto definitions in CI label Dec 7, 2024
@QMalcolm QMalcolm force-pushed the microbatch-first-last-batch-serial branch 2 times, most recently from 2f9e68e to 670d907 Compare December 7, 2024 00:41
@QMalcolm QMalcolm removed the proto update update proto definitions in CI label Dec 7, 2024
@QMalcolm QMalcolm force-pushed the microbatch-first-last-batch-serial branch from 670d907 to 3fa7bbc Compare December 7, 2024 01:19
@QMalcolm QMalcolm marked this pull request as ready for review December 7, 2024 01:47
@QMalcolm QMalcolm requested a review from a team as a code owner December 7, 2024 01:47

@QMalcolm QMalcolm left a comment

Self approval 🙈

QMalcolm commented Dec 7, 2024

Tests look good, proceeding with merging

@QMalcolm QMalcolm merged commit 03fdb4c into main Dec 7, 2024
53 of 56 checks passed
@QMalcolm QMalcolm deleted the microbatch-first-last-batch-serial branch December 7, 2024 18:43
github-actions bot pushed a commit that referenced this pull request Dec 7, 2024
* microbatch: split out first and last batch to run in serial

* only run pre_hook on first batch, post_hook on last batch

* refactor: internalize parallel to RunTask._submit_batch

* Add optional `force_sequential` to `_submit_batch` to allow for skipping parallelism check

* Force last batch to run sequentially

* Force first batch to run sequentially

* Remove batch_idx check in `should_run_in_parallel`

`should_run_in_parallel` shouldn't, and no longer needs to, take into
consideration where a batch exists in a larger context. The first and
last batch for a microbatch model are now forced to run sequentially
by `handle_microbatch_model`

* Begin skipping batches if first batch fails

* Write custom `on_skip` for `MicrobatchModelRunner` to better handle when batches are skipped

This was necessary specifically because the default `on_skip` set the `X of Y` part
of the skipped log using the `node_index` and the `num_nodes`. If there were 2
nodes and we were on the 4th batch of the second node, we'd get a message like
`SKIPPED 4 of 2...`, which didn't make much sense. In a future commit we're likely
going to add a custom event for logging the start, result, and skipping of batches
for better readability of the logs.

* Add microbatch pre-hook, post-hook, and sequential first/last batch tests

* Fix/Add tests around first batch failure vs latter batch failure

* Correct MicrobatchModelRunner.on_skip to handle skipping the entire node

Previously `MicrobatchModelRunner.on_skip` only handled when a _batch_ of
the model was being skipped. However, that method is also used when the
entire microbatch model is being skipped due to an upstream node error. Because
we previously _weren't_ handling this second case, it'd cause an unhandled
runtime exception. Thus, we now need to check whether we're running a batch or not,
and if there is no batch, use the super's on_skip method.

* Correct conditional logic for setting pre- and post-hooks for batches

Previously we were doing an if+elif for setting pre- and post-hooks
for batches, where the `if` matched if the batch wasn't the first
batch, and the `elif` matched if the batch wasn't the last batch. The
issue with this is that if the `if` was hit, the `elif` _wouldn't_ be hit.
This caused the first batch to appropriately not run the `post-hook`, but
then every batch after would run the `post-hook`.

* Add two new event types `LogStartBatch` and `LogBatchResult`

* Update MicrobatchModelRunner to use new batch specific log events

* Fix event testing

* Update microbatch integration tests to catch batch specific event types

---------

Co-authored-by: Quigley Malcolm <[email protected]>
(cherry picked from commit 03fdb4c)
QMalcolm pushed a commit that referenced this pull request Dec 9, 2024
Co-authored-by: Quigley Malcolm <[email protected]>
(cherry picked from commit 03fdb4c)

Co-authored-by: Michelle Ark <[email protected]>