(OTF) Normalization and element references #715

Merged
merged 94 commits into from
Aug 5, 2024

Conversation

lbluque
Collaborator

@lbluque lbluque commented May 24, 2024

This PR enables on-the-fly (OTF) fitting and estimation of normalization values and element references.

  • Normalizers and LinearReference modules are trainer attributes.
  • This also cleans up the use of linear references, which previously lived inside datasets. They are now saved as part of the checkpoint, so there is no need to insert them into checkpoints after training for testing/inference.
  • Snuck in a fix when reading ASE Datasets from a list of paths
  • Normalization values and/or linear references can be estimated at runtime before training. The config also allows hard-setting a value for the mean or the rmsd (root mean square difference). For example, the config below enables this and sets the forces mean to zero, so the estimated rmsd corresponds to the RMS force:
dataset:
  train:
    transforms:
      normalizer:
          fit:
              targets:
                 energy: {}
                 forces: { mean: 0.0 }
          batch_size: 32
          num_batches: 1000
      element_references:
        fit:
          targets:
            - energy
          batch_size: 32
          num_batches: 1000
  • Added scripts to fit linear references and/or normalizers using the train dataset in a standard config (with the fitting directive as specified above), i.e.
python src/fairchem/core/scripts/fit_references.py --config path/to/config.yml
python src/fairchem/core/scripts/fit_normalizers.py --config path/to/config.yml --linref-path path/energy_linref.pt
  • Linear references can also be passed as a file in the dataset/transforms block (for example, a file fit with the script above, or a legacy npz file):
      element_references:
        energy:
          file: /path/to/file.pt/or/npz
  • normalization values can also be passed from a file for many targets (the script above generates a dict with targets and normalizers):
      normalizer:
        file: norms.pt
  • Or they can be passed as individual files per target (an npz or state_dict.pt with "mean" and "std"):
      normalizer:
        energy:
          file: energy_norms.pt  # or .npz
  • using lin_ref for linear references inside datasets is still enabled for backwards compatibility.
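
To make the two fitting steps above concrete, here is a minimal NumPy sketch of the idea, not the code in this PR. It assumes element references are a least-squares fit of total energy against per-element atom counts, and that a normalizer is a (mean, rmsd) pair where the mean can be held fixed. The function names `fit_element_references` and `fit_normalizer` and the toy data are hypothetical.

```python
import numpy as np


def fit_element_references(counts, energies):
    """Least-squares fit of one reference energy per element.

    counts:   (n_structures, n_elements) atom counts per element
    energies: (n_structures,) total energies
    """
    coeffs, *_ = np.linalg.lstsq(counts, energies, rcond=None)
    return coeffs


def fit_normalizer(values, mean=None):
    """Return (mean, rmsd). If `mean` is given it is held fixed, not estimated."""
    values = np.asarray(values, dtype=float)
    if mean is None:
        mean = values.mean()
    rmsd = np.sqrt(np.mean((values - mean) ** 2))
    return float(mean), float(rmsd)


# Toy data: 4 structures over 2 hypothetical elements.
counts = np.array([[2, 0], [1, 1], [0, 3], [2, 2]], dtype=float)
true_refs = np.array([-1.0, -2.5])
residual = np.array([0.1, -0.1, 0.2, -0.2])
energies = counts @ true_refs + residual

# Step 1: fit the references and subtract them from the energies.
refs = fit_element_references(counts, energies)
referenced = energies - counts @ refs

# Step 2: fit the energy normalizer on the referenced energies.
e_mean, e_rmsd = fit_normalizer(referenced)

# Forces with the mean hard-set to zero: the rmsd is then the RMS force.
forces = np.array([[0.3, -0.4, 0.0], [-0.3, 0.4, 1.2]])
f_mean, f_rmsd = fit_normalizer(forces, mean=0.0)
```

Fitting the normalizer after subtracting the references mirrors how the `fit_normalizers.py` call above takes a `--linref-path`. Treating all force components as one flat population is an assumption of this sketch.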

TODO:

  • Make sure that otf_fit does not refit on resubmission
  • Write unit-tests
  • Add option to run fit normalizers/element references and save
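
For the first TODO item, one possible way to avoid refitting on resubmission is to save the fitted values and load them if the file already exists. The sketch below is only an illustration under that assumption: it uses JSON and a hypothetical `load_or_fit` helper, whereas the PR itself saves `.pt` files and checkpoint state.

```python
import json
import tempfile
from pathlib import Path


def load_or_fit(path, fit_fn):
    """Return (values, fitted_now): load from `path` if present, else fit and save."""
    path = Path(path)
    if path.exists():
        # A previous run already fitted these values, so reuse them.
        return json.loads(path.read_text()), False
    values = fit_fn()
    path.write_text(json.dumps(values))
    return values, True


calls = []


def toy_fit():
    """Stand-in for an expensive fit over the training set."""
    calls.append(1)
    return {"energy": {"mean": -1.5, "rmsd": 0.7}}


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "normalizers.json"
    first, fitted_first = load_or_fit(target, toy_fit)     # fits and saves
    second, fitted_second = load_or_fit(target, toy_fit)   # loads, no refit
```

The second call returns the saved values without calling `toy_fit` again, which is the behaviour a resubmitted job would need.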

@lbluque lbluque marked this pull request as draft May 24, 2024 22:08
@lbluque lbluque requested review from wood-b, mshuaibii and misko July 20, 2024 00:19
@lbluque
Collaborator Author

lbluque commented Aug 2, 2024

@mshuaibii @misko @wood-b finally here are some validation training runs,
https://fairwandb.org/fairchem/norms-refs-val

we should be set to go now!

misko
misko previously approved these changes Aug 2, 2024
Collaborator

@wood-b wood-b left a comment

LGTM. Great job pushing this through and thanks for the validation!

@lbluque lbluque added this pull request to the merge queue Aug 5, 2024
Merged via the queue into main with commit 029d4d3 Aug 5, 2024
7 checks passed
@lbluque lbluque deleted the norms-and-refs branch August 5, 2024 04:14
lbluque added a commit that referenced this pull request Aug 6, 2024
lbluque added a commit that referenced this pull request Aug 6, 2024
@zulissimeta zulissimeta added enhancement New feature or request minor Minor version release labels Aug 13, 2024
misko pushed a commit that referenced this pull request Jan 17, 2025
* denorm targets in _forward only

* linear reference class

* atomref in normalizer

* raise input error

* clean up normalizer interface

* add element refs

* add element refs correctly

* ruff

* fix save_checkpoint

* reference and dereference

* 2xnorm linref trainer add

* clean-up

* otf linear reference fit

* fix tensor device

* otf element references and normalizers

* use only present elements when fitting

* lint

* _forward norm and derefd values

* fix list of paths in src

* total mean and std

* fitted flag to avoid refitting normalizers/references on rerun

* allow passing lstsq driver

* element ref unit tests

* remove superfluous type

* lint fix

* allow setting batch_size explicitly

* test applying element refs

* normalizer tests

* increase distributed timeout

* save normalizers and linear refs in otf_fit

* remove debug code

* fix removing refs

* swap otf_fit for fit, and save all normalizers in one file

* log loading and saving normalizers

* fit references and normalizer scripts

* lint fixes

* allow absent optim key in config

* lin-ref description

* read files based on extension

* pass seed

* rename dataset fixture

* check if file is none

* pass generator correctly

* separate method for norms and refs

* add normalizer code back

* fix Generator construction

* import order

* log warnings if multiple inputs are passed

* raise Error if duplicate references or norms are set

* use len batch

* assert element reference targets are scalar

* fix name and rename method

* load and save norms and refs using same logic

* fix creating normalizer

* remove print statements

* adding new notebook for using fairchem models with NEBs without CatTSunami enumeration (#764)

* adding new notebook for using fairchem models with NEBs

* adding md tutorials

* blocking code cells that arent needed or take too long

* warn instead of error when duplicate norm/ref target names

* allow timeout to be read from config

* test seed noseed ref fits

* lotsa refactoring

* lotsa fixing

* more fixing...

* num_workers zero to prevent mp issues

* add otf norms smoke test and fixes

* allow overriding normalization fit values

* update tests

* fix normalizer loading

* use rmsd instead of only stdev

* fix tests

* correct rmsd calc and fix loading

* clean up norm loading and log values

* logg linear reference metrics

* load element references state dict

* fix loading and tests

* fix imports in scripts

* fix test?

* fix test

* use numpy as default to fit references

* minor fixes

* rm torch_tempdir fixture

---------

Co-authored-by: Brook Wander <[email protected]>
Co-authored-by: Muhammed Shuaibi <[email protected]>
Former-commit-id: 4ad6633733df9c76620ee779b6851a119e920f0b
Labels
enhancement New feature or request minor Minor version release
6 participants