
Add a gather_for_metrics capability #540

Merged
muellerzr merged 14 commits into main from dset_len on Jul 21, 2022

Conversation

muellerzr
Collaborator

@muellerzr muellerzr commented Jul 20, 2022

Introduce a gather_for_metrics function

What does this add?

This PR adds a new function to Accelerator called gather_for_metrics, which assists with calculating metrics correctly in distributed setups

Who is it for?

Users of Accelerate who want to ensure their reported metrics are fully accurate

Why is it needed?

To make sure every process receives batches of the same size, Accelerate pads the last batch of a prepared dataloader with duplicates of the final samples. These duplicates need to be dropped when calculating metrics on the last batch, and currently that looks something like:

samples_seen = 0
for step, batch in enumerate(eval_dataloader):
    # ... run the model and gather `predictions` and `references` ...
    if accelerator.use_distributed:
        # Then see if we're on the last batch of our eval dataloader
        if step == len(eval_dataloader) - 1:
            # The last batch needs to be truncated on distributed systems
            # as it contains additional (duplicated) samples
            predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
            references = references[: len(eval_dataloader.dataset) - samples_seen]
        else:
            # Otherwise we add the number of samples seen
            samples_seen += references.shape[0]

This PR adds a new utility, Accelerator.gather_for_metrics, which handles this check for us automatically, entirely thanks to the GradientState capability.

Note: this check previously didn't work on TPUs, where the dataset needs to be accessed as eval_dataloader._loader.dataset; this PR fixes that as well.

What parts of the API does this impact?

User-facing:

  • A new Accelerator.gather_for_metrics function was added

Internal structure:

  • Preprocessed dataloaders now have a new total_dataset_length attribute
  • GradientState now keeps track of the number of samples seen; a sketch of how these fit together follows below
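
As a rough illustration of how these two pieces combine, here is a minimal sketch of the truncation logic. The function name and exact control flow are assumptions based on the description above, not the merged implementation:

def gather_for_metrics_sketch(accelerator, tensor, dataloader):
    # Gather the tensor from all processes as usual
    gathered = accelerator.gather(tensor)
    state = accelerator.gradient_state
    if state.end_of_dataloader:
        # The last batch was padded with duplicates of the final samples,
        # so keep only the entries that belong to the real dataset
        remaining = dataloader.total_dataset_length - state.samples_seen
        gathered = gathered[:remaining]
    else:
        # Otherwise, record how many samples have been processed so far
        state.samples_seen += gathered.shape[0]
    return gathered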

Basic Usage Example(s):

When calculating metrics, users can now do the following to properly calculate their metrics:

ddp_input, ddp_target = next(iter(dataloader))
with torch.no_grad():
    logits = ddp_model(ddp_input)
    # Gather from all processes, dropping the duplicated samples on the last batch
    logits, ddp_target = accelerator.gather_for_metrics((logits, ddp_target), dataloader)
    accuracy_multi = accuracy(logits.argmax(dim=-1), ddp_target)

When would I use it, and when wouldn't I?

Since this works on both distributed and non-distributed systems, it can always be used whenever the evaluation dataloader has been prepared by the Accelerator. Users should just add this to any script that calculates metrics; a sketch of a full evaluation loop follows below.
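
For reference, here is a sketch of a complete evaluation loop using the new API. The model, dataloader, and batch structure are placeholders standing in for the user's own objects; only gather_for_metrics comes from this PR:

model.eval()
all_predictions, all_references = [], []
for batch in eval_dataloader:
    inputs, references = batch
    with torch.no_grad():
        logits = model(inputs)
    # No manual samples_seen bookkeeping needed: the duplicated samples on
    # the final batch are dropped automatically
    predictions, references = accelerator.gather_for_metrics(
        (logits.argmax(dim=-1), references), eval_dataloader
    )
    all_predictions.append(predictions)
    all_references.append(references)
accuracy = (torch.cat(all_predictions) == torch.cat(all_references)).float().mean()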

TODO:

  • Update the other examples to use this new API. multiprocess_metrics will serve as a lower-level example

@muellerzr muellerzr added the enhancement New feature or request label Jul 20, 2022
@muellerzr muellerzr requested a review from sgugger July 20, 2022 15:15
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jul 20, 2022

The documentation is not available anymore as the PR was closed or merged.

Collaborator

@sgugger sgugger left a comment


Very clever! It's really great, just make sure to update the docs and the examples :-)

@muellerzr muellerzr requested a review from sgugger July 20, 2022 16:23
@muellerzr muellerzr changed the title Add a gather_metrics capability Add a gather_for_metrics capability Jul 20, 2022
Copy link
Collaborator

@sgugger sgugger left a comment


Nice doc!

docs/source/quicktour.mdx (resolved)
@muellerzr muellerzr merged commit 164943c into main Jul 21, 2022
@muellerzr muellerzr deleted the dset_len branch July 21, 2022 11:40
@plamb-viso

Is this functionality in 0.11.0?

@muellerzr
Collaborator Author

It is not; we suggest not using it for now and doing the check manually, as shown in the metric example script, since some bugs were discovered: #575

@plamb-viso

Cool, thank you, excited for this change when it happens
