Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax on Metal w/ GPU support #3380

Merged
merged 2 commits into from
Oct 2, 2023
Merged

Conversation

BradyPlanden
Copy link
Member

Description

This adds:

  • FP32 support for Jax on Metal with GPU compilation (M-series GPU's don't support FP64)
  • Logic for parallelised GPU implementation on Metal

This fixes #3274; however, there are errors in the Jax GPU tests when ran on apple-silicon. I'll open a seperate issue for those though.

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • New feature (non-breaking change which adds functionality)
  • Optimization (back-end change that speeds up the code)
  • Bug fix (non-breaking change which fixes an issue)

Key checklist:

  • No style issues: $ pre-commit run (or $ nox -s pre-commit) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • All tests pass: $ python run-tests.py --all (or $ nox -s tests)
  • The documentation builds: $ python run-tests.py --doctest (or $ nox -s doctests)

You can run integration tests, unit tests, and doctests together at once, using $ python run-tests.py --quick (or $ nox -s quick).

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

@codecov
Copy link

codecov bot commented Sep 29, 2023

Codecov Report

Attention: 1 lines in your changes are missing coverage. Please review.

Comparison is base (c5aa3ce) 99.56% compared to head (577f9a1) 99.58%.
Report is 47 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3380      +/-   ##
===========================================
+ Coverage    99.56%   99.58%   +0.01%     
===========================================
  Files          253      254       +1     
  Lines        19561    19821     +260     
===========================================
+ Hits         19476    19738     +262     
+ Misses          85       83       -2     
Files Coverage Δ
pybamm/__init__.py 100.00% <100.00%> (ø)
...bamm/expression_tree/operations/evaluate_python.py 99.30% <100.00%> (+<0.01%) ⬆️
pybamm/solvers/base_solver.py 100.00% <100.00%> (ø)
pybamm/solvers/idaklu_solver.py 100.00% <100.00%> (+0.90%) ⬆️
pybamm/solvers/jax_bdf_solver.py 98.90% <100.00%> (+<0.01%) ⬆️
pybamm/solvers/processed_variable.py 100.00% <100.00%> (ø)
pybamm/solvers/processed_variable_computed.py 100.00% <100.00%> (ø)
pybamm/solvers/solution.py 100.00% <100.00%> (ø)
pybamm/solvers/jax_solver.py 90.69% <0.00%> (ø)

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@jsbrittain jsbrittain left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BradyPlanden these code changes look good. I also don't see how they would have caused that regression in the casadi solver, so would recommend rerunning the benchmarks in-case this was due to an overloaded/slow runner.

@BradyPlanden BradyPlanden merged commit f81de94 into pybamm-team:develop Oct 2, 2023
31 of 32 checks passed
@BradyPlanden BradyPlanden deleted the jax-metal branch October 2, 2023 07:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

testing runner with GPU support
2 participants