Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes required to support jax version 0.4.33 #555

Merged
merged 14 commits into from
Oct 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/pytest_ubuntu.yml
Original file line number Diff line number Diff line change
@@ -60,7 +60,8 @@ jobs:
pip install pytest-split
pip install -r requirements.txt
pip install -r dev_requirements.txt
conda install -c conda-forge svmbir>=0.3.3
# svmbir install temporarily disabled due to import errors
#conda install -c conda-forge svmbir>=0.3.3
conda install -c conda-forge astra-toolbox
conda install -c conda-forge pyyaml
pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version
9 changes: 8 additions & 1 deletion docs/source/install.rst
Original file line number Diff line number Diff line change
@@ -87,7 +87,7 @@ can be installed from PyPI
From GitHub
-----------

SCICO can be downloaded from the `GitHub repo
The development version of SCICO can be downloaded from the `GitHub repo
<https://github.com/lanl/scico>`_. Note that, since the SCICO repo has
a submodule, it should be cloned via the command
::
@@ -102,6 +102,13 @@ Install using the commands
pip install -e .


If a clone of the SCICO repository is not needed, it is simpler to
install directly using ``pip``
::

pip install git+https://github.com/lanl/scico



GPU Support
-----------
7 changes: 6 additions & 1 deletion examples/scripts/ct_astra_modl_train_foam2.py
Original file line number Diff line number Diff line change
@@ -54,6 +54,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
@@ -67,7 +72,7 @@
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


7 changes: 6 additions & 1 deletion examples/scripts/ct_astra_odp_train_foam2.py
Original file line number Diff line number Diff line change
@@ -61,6 +61,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
@@ -70,7 +75,7 @@
from scico.linop.xray.astra import XRayTransform2D


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


7 changes: 6 additions & 1 deletion examples/scripts/ct_astra_unet_train_foam2.py
Original file line number Diff line number Diff line change
@@ -27,6 +27,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

import numpy as np

from mpl_toolkits.axes_grid1 import make_axes_locatable
@@ -36,7 +41,7 @@
from scico.flax.examples import load_ct_data


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


7 changes: 6 additions & 1 deletion examples/scripts/deconv_modl_train_foam1.py
Original file line number Diff line number Diff line change
@@ -58,6 +58,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
@@ -67,7 +72,7 @@
from scico.linop import CircularConvolve


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


7 changes: 6 additions & 1 deletion examples/scripts/deconv_odp_train_foam1.py
Original file line number Diff line number Diff line change
@@ -66,6 +66,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
@@ -75,7 +80,7 @@
from scico.linop import CircularConvolve


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


7 changes: 6 additions & 1 deletion examples/scripts/denoise_dncnn_train_bsds.py
Original file line number Diff line number Diff line change
@@ -24,14 +24,19 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_image_data


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -4,8 +4,8 @@ scipy>=1.6.0
imageio>=2.17
tifffile
matplotlib
jaxlib>=0.4.3,<=0.4.31
jax>=0.4.3,<=0.4.31
orbax-checkpoint<=0.5.7
flax>=0.8.0,<=0.8.3
jaxlib>=0.4.3,<=0.4.33
jax>=0.4.3,<=0.4.33
orbax-checkpoint>=0.5.0
flax>=0.8.0,<=0.9.0
pyabel>=0.9.0
9 changes: 7 additions & 2 deletions scico/flax/examples/data_generation.py
Original file line number Diff line number Diff line change
@@ -45,6 +45,11 @@ class UnitCircle:
import jax
import jax.numpy as jnp

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from scico.linop import CircularConvolve
from scico.numpy import Array

@@ -260,7 +265,7 @@ def generate_ct_data(
fbp = (fbp - fbp.min()) / (fbp.max() - fbp.min())

if verbose: # pragma: no cover
platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print(f"{'Platform':26s}{':':4s}{platform}")
print(f"{'Device count':26s}{':':4s}{jax.device_count()}")
print(f"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}")
@@ -333,7 +338,7 @@ def generate_blur_data(
blurn = jnp.clip(blurn, 0, 1)

if verbose: # pragma: no cover
platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print(f"{'Platform':26s}{':':4s}{platform}")
print(f"{'Device count':26s}{':':4s}{jax.device_count()}")
print(f"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}")
51 changes: 33 additions & 18 deletions scico/flax/train/checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
@@ -13,9 +13,7 @@

import jax

import orbax.checkpoint

from flax.training import orbax_utils
import orbax.checkpoint as ocp

from .state import TrainState
from .typed_dict import ConfigDict
@@ -48,13 +46,20 @@ def checkpoint_restore(
if isinstance(workdir_, str):
workdir_ = Path(workdir_)
if workdir_.exists():
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
checkpoint_manager = orbax.checkpoint.CheckpointManager(workdir_, orbax_checkpointer)
step = checkpoint_manager.latest_step()
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
workdir_,
item_names=("state", "config"),
options=options,
)
step = mngr.latest_step()
if step is not None:
target = {"state": state, "config": {}}
ckpt = checkpoint_manager.restore(step, items=target)
state = ckpt["state"]
restored = mngr.restore(
step, args=ocp.args.Composite(state=ocp.args.StandardRestore(state))
)
mngr.wait_until_finished()
mngr.close()
state = restored.state
elif not ok_no_ckpt:
raise FileNotFoundError("Could not read from checkpoint: " + str(workdir))

@@ -74,13 +79,23 @@ def checkpoint_save(state: TrainState, config: ConfigDict, workdir: Union[str, P
workdir: Path in which to store checkpoint files.
"""
if jax.process_index() == 0:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
# Bundle config and model parameters together
ckpt = {"state": state, "config": config}
save_args = orbax_utils.save_args_from_target(ckpt)
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=3, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
workdir, orbax_checkpointer, options
options = ocp.CheckpointManagerOptions(max_to_keep=3, create=True)
mngr = ocp.CheckpointManager(
workdir,
item_names=("state", "config"),
options=options,
)
step = int(state.step)
checkpoint_manager.save(step, ckpt, save_kwargs={"save_args": save_args})
# Remove non-serializable partial functools in post_lst if it exists
config_ = config.copy()
if "post_lst" in config_:
config_.pop("post_lst", None) # type: ignore
mngr.save(
step,
args=ocp.args.Composite(
state=ocp.args.StandardSave(state),
config=ocp.args.JsonSave(config_),
),
)
mngr.wait_until_finished()
mngr.close()
3 changes: 2 additions & 1 deletion scico/functional/_denoiser.py
Original file line number Diff line number Diff line change
@@ -128,7 +128,8 @@ def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array: # type: ignore
r"""Apply DnCNN denoiser.

*Warning*: The `lam` parameter is ignored, and has no effect on
the output.
the output for :class:`.DnCNN` objects initialized with
:code:`variant` parameter values other than `6N` and `17N`.

Args:
x: Input array.
13 changes: 9 additions & 4 deletions scico/numpy/_blockarray.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@
"""Block array class."""

import inspect
from functools import wraps
from functools import WRAPPER_ASSIGNMENTS, wraps
from typing import Callable

import jax
@@ -174,10 +174,15 @@ def prop_ba(self):
def _da_method_wrapper(method_name):
method = getattr(Array, method_name)

if method.__name__ is None:
return method
# Don't try to set attributes that are None. Not clear why some
# functions/methods (e.g. block_until_ready) have None values
# for these attributes.
wrapper_assignments = WRAPPER_ASSIGNMENTS
for attr in ("__name__", "__qualname__"):
if getattr(method, attr) is None:
wrapper_assignments = tuple(x for x in wrapper_assignments if x != attr)

@wraps(method)
@wraps(method, assigned=wrapper_assignments)
def method_ba(self, *args, **kwargs):
result = tuple(getattr(x, method_name)(*args, **kwargs) for x in self)

4 changes: 2 additions & 2 deletions scico/numpy/_wrapped_function_lists.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SPORCO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
@@ -84,7 +84,7 @@
"arccosh",
"arctanh",
"around",
"round_",
"round",
"rint",
"fix",
"floor",