Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(train) Add support for torch.compile (EXPERIMENTAL) #2931

Open
wants to merge 30 commits into
base: main
Choose a base branch
from

Conversation

canergen
Copy link
Member

@canergen canergen commented Aug 7, 2024

No description provided.

@canergen canergen added the cuda tests Run test suite on CUDA label Aug 7, 2024
@canergen
Copy link
Member Author

canergen commented Aug 7, 2024

@ori-kron-wis Can you add tests for all pytorch models (not pyro and not jax) for compile. Can you check speed improvements on your end? You should execute it with: model.train(accelerator='cuda', plan_kwargs={'n_epochs_kl_warmup': 100, 'compile': True}, datasplitter_kwargs={'drop_last': True})

trainingplans with future imports and test_compute_kl revert
@ori-kron-wis ori-kron-wis self-assigned this Aug 26, 2024
@ori-kron-wis ori-kron-wis self-requested a review August 26, 2024 14:49
@canergen
Copy link
Member Author

Needs tests like: model2.train(accelerator='cuda', batch_size=5000, max_epochs=100, train_size=0.9, plan_kwargs={'n_epochs_kl_warmup': 100, 'compile': True}, datasplitter_kwargs={'drop_last': True}) and then get_elbo, get_reconstruction_loss, get_latent.

@ori-kron-wis
Copy link
Collaborator

ori-kron-wis commented Sep 17, 2024

I added torch compile tests for most models (of course not working with the github action due to that error) - on new servers, it worked fine and was faster, although the compile part will add some overhead.

Currently pyro test not working on a multi GPU machine. Need to see why (only test_pyro_bayesian_regression). once we remove it everything works (it should be passed here)

Copy link

codecov bot commented Sep 18, 2024

Codecov Report

Attention: Patch coverage is 71.42857% with 4 lines in your changes missing coverage. Please review.

Project coverage is 84.51%. Comparing base (c2e3714) to head (4400ac2).

Files with missing lines Patch % Lines
src/scvi/train/_trainingplans.py 71.42% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2931      +/-   ##
==========================================
- Coverage   84.53%   84.51%   -0.02%     
==========================================
  Files         178      178              
  Lines       15062    15071       +9     
==========================================
+ Hits        12732    12737       +5     
- Misses       2330     2334       +4     
Files with missing lines Coverage Δ
src/scvi/module/_vae.py 94.48% <ø> (ø)
src/scvi/train/_trainingplans.py 93.84% <71.42%> (-0.75%) ⬇️
---- 🚨 Try these New Features:

@ori-kron-wis ori-kron-wis added this to the scvi-tools 1.3 milestone Sep 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cuda tests Run test suite on CUDA
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants