Add a gather_for_metrics capability #540
Conversation
Very clever! It's really great, just make sure to update the docs and the examples :-)
Nice doc!
Is this functionality in 0.11.0?

It is not; we suggest not using it for now and doing the check manually, as shown in the metric example script, since some bugs were discovered: #575

Cool, thank you, excited for this change when it happens
Introduce a `gather_for_metrics` function

What does this add?
This PR adds a new function to `Accelerator` called `gather_for_metrics`, which assists with calculating the right metric in distributed setups.

Who is it for?
Users of Accelerate who want to ensure that their reported metrics are fully accurate.
Why is it needed?
To make sure all batches have the same batch size, Accelerate pads the last batch with duplicates of the last sample. These padded samples need to be dropped when calculating the final metrics on the last batch, and currently that check looks something like the sketch below.
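A sketch of the manual check, modeled on the pattern used in Accelerate's metric example script; `accelerator`, `model`, `eval_dataloader`, and `metric` are assumed to be set up and prepared already:

```python
import torch

samples_seen = 0
model.eval()
for step, batch in enumerate(eval_dataloader):
    inputs, targets = batch
    with torch.no_grad():
        outputs = model(inputs)
    predictions = outputs.argmax(dim=-1)
    # Gather predictions and labels from every process
    predictions, references = accelerator.gather((predictions, targets))
    if accelerator.num_processes > 1:
        if step == len(eval_dataloader) - 1:
            # Last batch: keep only the real samples, dropping the padded duplicates
            predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
            references = references[: len(eval_dataloader.dataset) - samples_seen]
        else:
            samples_seen += references.shape[0]
    metric.add_batch(predictions=predictions, references=references)
```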
This PR adds a new utility called `Accelerator.gather_for_metrics`, which handles this check for us, entirely thanks to the `GradientState` capability.

What parts of the API does this impact?
User-facing:
- `Accelerator.gather_for_metrics` function was added

Internal structure:
- `total_dataset_length` attribute
- `GradientState` now keeps track of the number of samples seen
Basic Usage Example(s):
When calculating metrics, users can now do the following:
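A minimal sketch of the new flow, again assuming `accelerator`, `model`, `eval_dataloader`, and `metric` are already set up:

```python
import torch

model.eval()
for batch in eval_dataloader:
    inputs, targets = batch
    with torch.no_grad():
        outputs = model(inputs)
    predictions = outputs.argmax(dim=-1)
    # Gathers across processes and drops the duplicated samples that were
    # padded onto the final batch, so the metric sees each sample exactly once
    predictions, references = accelerator.gather_for_metrics((predictions, targets))
    metric.add_batch(predictions=predictions, references=references)
```

This replaces the manual bookkeeping shown earlier with a single call.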
When would I use it, and when wouldn't I?
Since this works on both distributed and non-distributed systems, use it whenever the evaluation dataset has been prepared by the `Accelerator`. Users should just add this to any script that calculates metrics.
TODO: