Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BalancedBatchSampler QoL Updates #566

Closed
wants to merge 6 commits into from

Conversation

nimashoghi
Copy link
Collaborator

@nimashoghi nimashoghi commented Aug 24, 2023

Right now, BalancedBatchSampler has some rough edges:

  • It requires a very specific npz format for the metadata storage which needs to be followed for any dataset that wants to be balanced. Currently, the new ASE dataset doesn't support this format.
  • The force_balancing and throw_on_error parameters are confusing.

For this PR, I updated BalancedBatchSampler to rely on a protocol which expects datasets to implement a data_sizes method, which returns the "size" of each dataset sample. I have updated LmdbDataset and OC22LmdbDataset to implement this in a backward-compatible manner w/ the old implementation.

I have also completely removed the neighbors balancing support, as our graph generation happens on GPU anyway, and the number of neighbors changes depending on max_neighbors/cutoff. In many cases, the values stored in the metadata would end up not being accurate.

As a TL;DR, here's essentially the main change (as far as the datasets are concerned).
Previously:

class MyDataset(Dataset):
    def __init__(self, ...):
        ...
        self.metadata_path = ...

Now:

class MyDataset(Dataset):
    def data_sizes(self, batch_idx: List[int]) -> np.ndarray:
        # Use the loaded metadata to load the natoms for samples in batch_idx
        return self.metadata["natoms"][batch_idx]

    def __init__(self, ...):
        ...
        self.metadata = np.load(...) # Load all metadata in the init method

Tasks:

  • Update BalancedBatchSampler to use datasets' data_sizes method
  • Replace BalancedBatchSampler's force_balancing and throw_on_error parameters with on_error

nimashoghi and others added 5 commits August 24, 2023 18:27
@codecov
Copy link

codecov bot commented Aug 30, 2023

Codecov Report

Attention: Patch coverage is 55.63910% with 59 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
ocpmodels/common/balanced_batch_sampler.py 58.82% 42 Missing ⚠️
ocpmodels/trainers/base_trainer.py 22.22% 7 Missing ⚠️
ocpmodels/datasets/lmdb_dataset.py 50.00% 5 Missing ⚠️
ocpmodels/datasets/oc22_lmdb_dataset.py 54.54% 5 Missing ⚠️
Files with missing lines Coverage Δ
ocpmodels/common/data_parallel.py 21.05% <100.00%> (-27.38%) ⬇️
ocpmodels/datasets/lmdb_dataset.py 38.58% <50.00%> (+0.44%) ⬆️
ocpmodels/datasets/oc22_lmdb_dataset.py 16.31% <54.54%> (+3.23%) ⬆️
ocpmodels/trainers/base_trainer.py 16.93% <22.22%> (+0.08%) ⬆️
ocpmodels/common/balanced_batch_sampler.py 58.82% <58.82%> (ø)

@github-actions
Copy link

This PR has been marked as stale because it has been open for 30 days with no activity.

@github-actions github-actions bot added the stale label Sep 30, 2023
@abhshkdz abhshkdz added dont-close and removed stale labels Oct 2, 2023
@abhshkdz abhshkdz self-assigned this Oct 2, 2023
@mshuaibii mshuaibii marked this pull request as draft April 8, 2024 19:53
@mshuaibii mshuaibii mentioned this pull request Apr 8, 2024
@mshuaibii mshuaibii requested a review from wood-b April 8, 2024 21:12
@mshuaibii mshuaibii added the enhancement New feature or request label Apr 9, 2024
@wood-b wood-b removed the dont-close label Oct 24, 2024
@wood-b
Copy link
Collaborator

wood-b commented Oct 24, 2024

Closing this PR as it was incorporated in this PR #753

@wood-b wood-b closed this Oct 24, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants