Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make jax & odes optional #3163

Merged
merged 26 commits into from
Jul 28, 2023
Merged

Conversation

arjxn-py
Copy link
Member

Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes #3146

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • New feature (non-breaking change which adds functionality)
  • Optimization (back-end change that speeds up the code)
  • Bug fix (non-breaking change which fixes an issue)

Key checklist:

  • No style issues: $ pre-commit run (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • All tests pass: $ python run-tests.py --all
  • The documentation builds: $ python run-tests.py --doctest

You can run unit and doctests together at once, using $ python run-tests.py --quick.

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

@arjxn-py arjxn-py force-pushed the make-jax-optional branch from 6498835 to fe5f37d Compare July 19, 2023 10:36
@codecov
Copy link

codecov bot commented Jul 19, 2023

Codecov Report

Patch coverage: 100.00% and no project coverage change.

Comparison is base (f588f20) 99.71% compared to head (fe704bf) 99.71%.
Report is 56 commits behind head on develop.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #3163   +/-   ##
========================================
  Coverage    99.71%   99.71%           
========================================
  Files          248      248           
  Lines        18761    18764    +3     
========================================
+ Hits         18707    18710    +3     
  Misses          54       54           
Files Changed Coverage Δ
pybamm/util.py 100.00% <100.00%> (ø)

... and 2 files with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@Saransh-cpp Saransh-cpp self-requested a review July 19, 2023 14:03
Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @arjxn-py! This would also require updating the docs and getting rid of -

PyBaMM/pybamm/util.py

Lines 295 to 341 in 02c56a2

def install_jax(arguments=None): # pragma: no cover
"""
Install compatible versions of jax, jaxlib.
Command Line Interface::
$ pybamm_install_jax
| optional arguments:
| -h, --help show help message
| -f, --force force install compatible versions of jax and jaxlib
"""
parser = argparse.ArgumentParser(description="Install jax and jaxlib")
parser.add_argument(
"-f",
"--force",
action="store_true",
help="force install compatible versions of"
f" jax ({JAX_VERSION}) and jaxlib ({JAXLIB_VERSION})",
)
args = parser.parse_args(arguments)
if system() == "Windows":
raise NotImplementedError("Jax is not available on Windows")
# Raise an error if jax and jaxlib are already installed, but incompatible
# and --force is not set
elif importlib.util.find_spec("jax") is not None:
if not args.force and not is_jax_compatible():
raise ValueError(
"Jax is already installed but the installed version of jax or jaxlib is"
" not supported by PyBaMM. \nYou can force install compatible versions"
f" of jax ({JAX_VERSION}) and jaxlib ({JAXLIB_VERSION}) using the"
" following command: \npybamm_install_jax --force"
)
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
f"jax=={JAX_VERSION}",
f"jaxlib=={JAXLIB_VERSION}",
]
)

from both, code and docs.

setup.py Outdated Show resolved Hide resolved
@arjxn-py arjxn-py marked this pull request as draft July 19, 2023 15:06
@arjxn-py arjxn-py marked this pull request as ready for review July 22, 2023 10:34
@arjxn-py arjxn-py requested a review from Saransh-cpp July 22, 2023 10:35
@arjxn-py
Copy link
Member Author

arjxn-py commented Jul 22, 2023

This would also require updating the docs and getting rid from both - code and docs.

@Saransh-cpp do we also want to deprecate pybamm_install_odes?

Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work, thanks @arjxn-py! Some comments below.

docs/source/user_guide/installation/GNU-linux.rst Outdated Show resolved Hide resolved
CHANGELOG.md Outdated Show resolved Hide resolved
pybamm/util.py Outdated Show resolved Hide resolved
docs/source/api/util.rst Outdated Show resolved Hide resolved
CHANGELOG.md Outdated Show resolved Hide resolved
setup.py Outdated Show resolved Hide resolved
setup.py Show resolved Hide resolved
@arjxn-py arjxn-py force-pushed the make-jax-optional branch from dd10a60 to cb49480 Compare July 25, 2023 10:46
@arjxn-py
Copy link
Member Author

Deprecation warning looks like this :
WhatsApp Image 2023-07-25 at 16 06 26
Adding note for odes before re-requesting review.

pybamm/install_odes.py Outdated Show resolved Hide resolved
pybamm/install_odes.py Outdated Show resolved Hide resolved
pybamm/util.py Outdated Show resolved Hide resolved
pybamm/util.py Outdated Show resolved Hide resolved
@arjxn-py arjxn-py requested a review from Saransh-cpp July 26, 2023 10:34
pybamm/install_odes.py Outdated Show resolved Hide resolved
docs/source/user_guide/installation/index.rst Outdated Show resolved Hide resolved
docs/source/user_guide/installation/index.rst Outdated Show resolved Hide resolved
@arjxn-py arjxn-py force-pushed the make-jax-optional branch from d029faa to 31f81d4 Compare July 27, 2023 14:28
@arjxn-py arjxn-py requested a review from Saransh-cpp July 27, 2023 14:29
@arjxn-py arjxn-py requested a review from Saransh-cpp July 27, 2023 15:12
Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpicking 🙂

CHANGELOG.md Outdated Show resolved Hide resolved
docs/source/user_guide/installation/GNU-linux.rst Outdated Show resolved Hide resolved
docs/source/user_guide/installation/GNU-linux.rst Outdated Show resolved Hide resolved
docs/source/user_guide/installation/index.rst Outdated Show resolved Hide resolved
docs/source/user_guide/installation/index.rst Outdated Show resolved Hide resolved
docs/source/user_guide/installation/index.rst Outdated Show resolved Hide resolved
Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @arjxn-py!

@Saransh-cpp Saransh-cpp merged commit 695917e into pybamm-team:develop Jul 28, 2023
@arjxn-py arjxn-py deleted the make-jax-optional branch July 28, 2023 03:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Make jax, jaxlib, and scikits.odes "extra requires" in setup.py
2 participants