Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes required to support jax version 0.4.33 #555

Merged
merged 14 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/pytest_ubuntu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ jobs:
pip install pytest-split
pip install -r requirements.txt
pip install -r dev_requirements.txt
conda install -c conda-forge svmbir>=0.3.3
# svmbir install temporarily disabled due to import errors
#conda install -c conda-forge svmbir>=0.3.3
conda install -c conda-forge astra-toolbox
conda install -c conda-forge pyyaml
pip install --upgrade --force-reinstall scipy>=1.6.0 # Temporary fix for GLIBCXX_3.4.30 not found in conda forge version
Expand Down
9 changes: 8 additions & 1 deletion docs/source/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ can be installed from PyPI
From GitHub
-----------

SCICO can be downloaded from the `GitHub repo
The development version of SCICO can be downloaded from the `GitHub repo
<https://github.com/lanl/scico>`_. Note that, since the SCICO repo has
a submodule, it should be cloned via the command
::
Expand All @@ -102,6 +102,13 @@ Install using the commands
pip install -e .


If a clone of the SCICO repository is not needed, it is simpler to
install directly using ``pip``
::

pip install git+https://github.com/lanl/scico



GPU Support
-----------
Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/ct_astra_modl_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
Expand All @@ -67,7 +72,7 @@
applies if GPU is not available).
"""
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/ct_astra_odp_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
Expand All @@ -70,7 +75,7 @@
from scico.linop.xray.astra import XRayTransform2D


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/ct_astra_unet_train_foam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

import numpy as np

from mpl_toolkits.axes_grid1 import make_axes_locatable
Expand All @@ -36,7 +41,7 @@
from scico.flax.examples import load_ct_data


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/deconv_modl_train_foam1.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
Expand All @@ -67,7 +72,7 @@
from scico.linop import CircularConvolve


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/deconv_odp_train_foam1.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
Expand All @@ -75,7 +80,7 @@
from scico.linop import CircularConvolve


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


Expand Down
7 changes: 6 additions & 1 deletion examples/scripts/denoise_dncnn_train_bsds.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,19 @@

import jax

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_image_data


platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print("Platform: ", platform)


Expand Down
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ scipy>=1.6.0
imageio>=2.17
tifffile
matplotlib
jaxlib>=0.4.3,<=0.4.31
jax>=0.4.3,<=0.4.31
orbax-checkpoint<=0.5.7
flax>=0.8.0,<=0.8.3
jaxlib>=0.4.3,<=0.4.33
jax>=0.4.3,<=0.4.33
orbax-checkpoint>=0.5.0
flax>=0.8.0,<=0.9.0
pyabel>=0.9.0
9 changes: 7 additions & 2 deletions scico/flax/examples/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ class UnitCircle:
import jax
import jax.numpy as jnp

try:
from jax.extend.backend import get_backend # introduced in jax 0.4.33
except ImportError:
from jax.lib.xla_bridge import get_backend

from scico.linop import CircularConvolve
from scico.numpy import Array

Expand Down Expand Up @@ -260,7 +265,7 @@ def generate_ct_data(
fbp = (fbp - fbp.min()) / (fbp.max() - fbp.min())

if verbose: # pragma: no cover
platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print(f"{'Platform':26s}{':':4s}{platform}")
print(f"{'Device count':26s}{':':4s}{jax.device_count()}")
print(f"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}")
Expand Down Expand Up @@ -333,7 +338,7 @@ def generate_blur_data(
blurn = jnp.clip(blurn, 0, 1)

if verbose: # pragma: no cover
platform = jax.lib.xla_bridge.get_backend().platform
platform = get_backend().platform
print(f"{'Platform':26s}{':':4s}{platform}")
print(f"{'Device count':26s}{':':4s}{jax.device_count()}")
print(f"{'Data generation':19s}{'time[s]:':10s}{time_dtgen:>7.2f}")
Expand Down
51 changes: 33 additions & 18 deletions scico/flax/train/checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -13,9 +13,7 @@

import jax

import orbax.checkpoint

from flax.training import orbax_utils
import orbax.checkpoint as ocp

from .state import TrainState
from .typed_dict import ConfigDict
Expand Down Expand Up @@ -48,13 +46,20 @@ def checkpoint_restore(
if isinstance(workdir_, str):
workdir_ = Path(workdir_)
if workdir_.exists():
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
checkpoint_manager = orbax.checkpoint.CheckpointManager(workdir_, orbax_checkpointer)
step = checkpoint_manager.latest_step()
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
workdir_,
item_names=("state", "config"),
options=options,
)
step = mngr.latest_step()
if step is not None:
target = {"state": state, "config": {}}
ckpt = checkpoint_manager.restore(step, items=target)
state = ckpt["state"]
restored = mngr.restore(
step, args=ocp.args.Composite(state=ocp.args.StandardRestore(state))
)
mngr.wait_until_finished()
mngr.close()
state = restored.state
elif not ok_no_ckpt:
raise FileNotFoundError("Could not read from checkpoint: " + str(workdir))

Expand All @@ -74,13 +79,23 @@ def checkpoint_save(state: TrainState, config: ConfigDict, workdir: Union[str, P
workdir: Path in which to store checkpoint files.
"""
if jax.process_index() == 0:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
# Bundle config and model parameters together
ckpt = {"state": state, "config": config}
save_args = orbax_utils.save_args_from_target(ckpt)
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=3, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
workdir, orbax_checkpointer, options
options = ocp.CheckpointManagerOptions(max_to_keep=3, create=True)
mngr = ocp.CheckpointManager(
workdir,
item_names=("state", "config"),
options=options,
)
step = int(state.step)
checkpoint_manager.save(step, ckpt, save_kwargs={"save_args": save_args})
# Remove non-serializable partial functools in post_lst if it exists
config_ = config.copy()
if "post_lst" in config_:
config_.pop("post_lst", None) # type: ignore
mngr.save(
step,
args=ocp.args.Composite(
state=ocp.args.StandardSave(state),
config=ocp.args.JsonSave(config_),
),
)
mngr.wait_until_finished()
mngr.close()
3 changes: 2 additions & 1 deletion scico/functional/_denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array: # type: ignore
r"""Apply DnCNN denoiser.

*Warning*: The `lam` parameter is ignored, and has no effect on
the output.
the output for :class:`.DnCNN` objects initialized with
:code:`variant` parameter values other than `6N` and `17N`.

Args:
x: Input array.
Expand Down
13 changes: 9 additions & 4 deletions scico/numpy/_blockarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""Block array class."""

import inspect
from functools import wraps
from functools import WRAPPER_ASSIGNMENTS, wraps
from typing import Callable

import jax
Expand Down Expand Up @@ -174,10 +174,15 @@ def prop_ba(self):
def _da_method_wrapper(method_name):
method = getattr(Array, method_name)

if method.__name__ is None:
return method
# Don't try to set attributes that are None. Not clear why some
# functions/methods (e.g. block_until_ready) have None values
# for these attributes.
wrapper_assignments = WRAPPER_ASSIGNMENTS
for attr in ("__name__", "__qualname__"):
if getattr(method, attr) is None:
wrapper_assignments = tuple(x for x in wrapper_assignments if x != attr)

@wraps(method)
@wraps(method, assigned=wrapper_assignments)
def method_ba(self, *args, **kwargs):
result = tuple(getattr(x, method_name)(*args, **kwargs) for x in self)

Expand Down
4 changes: 2 additions & 2 deletions scico/numpy/_wrapped_function_lists.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SPORCO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
Expand Down Expand Up @@ -84,7 +84,7 @@
"arccosh",
"arctanh",
"around",
"round_",
"round",
"rint",
"fix",
"floor",
Expand Down
Loading