Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposed changes to #541 #543

Merged
merged 53 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
a324841
Add type annotation
bwohlberg Jul 18, 2024
1ebf4e0
Remove jax distributed data generation option
bwohlberg Jul 18, 2024
a75ab62
Remove jax distributed data generation option
bwohlberg Jul 18, 2024
997f52a
Clean up
bwohlberg Jul 18, 2024
5f4552f
Clean up
bwohlberg Jul 18, 2024
a0b72ae
Extend docs
bwohlberg Jul 18, 2024
f77e212
Add additional test for exception state
bwohlberg Jul 18, 2024
5ab1d05
Tracer conversion error fix from Cristina
bwohlberg Jul 18, 2024
ecbfca5
Omitted import
bwohlberg Jul 18, 2024
28c828c
Clean up
bwohlberg Jul 18, 2024
dcd358c
Consistent phrasing
bwohlberg Jul 18, 2024
fbb4564
Merge branch 'cristina/issue535' into brendt/issue535_extended
bwohlberg Jul 19, 2024
8f286d0
Clean up some f-strings
bwohlberg Jul 22, 2024
5e85f8c
Add missing ray init
bwohlberg Jul 22, 2024
d69fdd2
Set dtype
bwohlberg Jul 22, 2024
3ef66a6
Merge branch 'cristina/issue535' into brendt/issue535_extended
bwohlberg Jul 22, 2024
0d97b3f
Fix indentation error
bwohlberg Jul 22, 2024
eec3242
Update module docstring
bwohlberg Jul 23, 2024
a7fa89f
Experimental solution to ray/jax failure
bwohlberg Jul 23, 2024
85ded0f
Bug fix
bwohlberg Jul 23, 2024
e7461f0
Improve docstring
bwohlberg Jul 23, 2024
5dac79f
Implement hack to resolve jax/ray conflict
bwohlberg Jul 23, 2024
25f318e
Debug attempt
bwohlberg Jul 23, 2024
9218e4d
Debug attempt
bwohlberg Jul 23, 2024
e73ae7d
Debug attempt
bwohlberg Jul 23, 2024
c9714e4
Debug attempt
bwohlberg Jul 23, 2024
e24ccdd
Debug attempt
bwohlberg Jul 23, 2024
9bbad64
Debug attempt
bwohlberg Jul 23, 2024
325fb9b
New solution attempt
bwohlberg Jul 23, 2024
d521aa3
Debug attempt
bwohlberg Jul 23, 2024
8eef347
Debug attempt
bwohlberg Jul 23, 2024
fa08d8c
Debug attempt
bwohlberg Jul 23, 2024
47b8067
Debug attempt
bwohlberg Jul 23, 2024
931c763
Debug attempt
bwohlberg Jul 23, 2024
a9cafff
Debug attempt
bwohlberg Jul 23, 2024
5f7001e
Debug attempt
bwohlberg Jul 23, 2024
644c189
Debug attempt
bwohlberg Jul 23, 2024
89b4772
Debug attempt
bwohlberg Jul 23, 2024
fdb8520
Debug attempt
bwohlberg Jul 23, 2024
978759e
Return to earlier approach
bwohlberg Jul 23, 2024
fc2315a
Extend comment
bwohlberg Jul 23, 2024
039a970
Clean up and improve function logic
bwohlberg Jul 23, 2024
9dca046
Address some problems
bwohlberg Jul 23, 2024
1fcd82d
Clean up
bwohlberg Jul 23, 2024
6cdf217
Rename function for consistency with related functions
bwohlberg Jul 23, 2024
f2acaf2
Merge branch 'main' into brendt/issue535_extended
bwohlberg Jul 23, 2024
32e7b01
Merge branch 'brendt/issue535_extended' into brendt/issue535_extended…
bwohlberg Jul 23, 2024
aa1467a
Bug fix
bwohlberg Jul 23, 2024
cc678fd
Clean up
bwohlberg Jul 23, 2024
d650f3a
Bug fix
bwohlberg Jul 23, 2024
0fe46a4
Address pylint complaint
bwohlberg Jul 23, 2024
41fa25e
Revert unworkable structure
bwohlberg Jul 24, 2024
2c96637
Merge branch 'cristina/issue535' into brendt/issue535_extended
bwohlberg Jul 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Version 0.0.6 (unreleased)
``scico.flax.save_variables`` and ``scico.flax.load_variables``
respectively.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.30.
• Support ``flax`` versions between 0.8.0 and 0.8.3 (inclusive).
• Support ``flax`` versions 0.8.0 to 0.8.3.



Expand Down
150 changes: 54 additions & 96 deletions scico/flax/examples/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import os
import warnings
from functools import partial
from time import time
from typing import Callable, List, Tuple, Union

Expand Down Expand Up @@ -47,8 +48,6 @@
have_astra = False
else:
have_astra = True

if have_astra:
from scico.linop.xray.astra import XRayTransform2D


Expand Down Expand Up @@ -88,7 +87,7 @@ def __init__(
attn2: Mass attenuation parameter for material 2.
Default: 10.
"""
super(Foam2, self).__init__(radius=0.5, material=SimpleMaterial(attn1))
super().__init__(radius=0.5, material=SimpleMaterial(attn1))
if porosity < 0 or porosity > 1:
raise ValueError("Porosity must be in the range [0,1).")
self.sprinkle(
Expand All @@ -98,11 +97,11 @@ def __init__(
)


def generate_foam2_images(seed: float, size: int, ndata: int) -> Array:
"""Generate batch of foam2 structures.
def generate_foam1_images(seed: float, size: int, ndata: int) -> np.ndarray:
"""Generate batch of xdesign foam-like structures.

Generate batch of images with :class:`Foam2` structure
(foam-like material with two different attenuations).
Generate batch of images with `xdesign` foam-like structure, which
uses one attenuation.

Args:
seed: Seed for data generation.
Expand All @@ -115,22 +114,20 @@ def generate_foam2_images(seed: float, size: int, ndata: int) -> Array:
if not have_xdesign:
raise RuntimeError("Package xdesign is required for use of this function.")

# np.random.seed(seed)
saux = jnp.zeros((ndata, size, size, 1))
np.random.seed(seed)
saux = np.zeros((ndata, size, size, 1))
for i in range(ndata):
foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1)
saux = saux.at[i, ..., 0].set(discrete_phantom(foam, size=size))
# normalize
saux = saux / jnp.max(saux, axis=(1, 2), keepdims=True)
foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1)
saux[i, ..., 0] = discrete_phantom(foam, size=size)

return saux


def generate_foam1_images(seed: float, size: int, ndata: int) -> Array:
"""Generate batch of xdesign foam-like structures.
def generate_foam2_images(seed: float, size: int, ndata: int) -> np.ndarray:
"""Generate batch of foam2 structures.

Generate batch of images with `xdesign` foam-like structure, which
uses one attenuation.
Generate batch of images with :class:`Foam2` structure
(foam-like material with two different attenuations).

Args:
seed: Seed for data generation.
Expand All @@ -143,11 +140,13 @@ def generate_foam1_images(seed: float, size: int, ndata: int) -> Array:
if not have_xdesign:
raise RuntimeError("Package xdesign is required for use of this function.")

# np.random.seed(seed)
saux = jnp.zeros((ndata, size, size, 1))
np.random.seed(seed)
saux = np.zeros((ndata, size, size, 1))
for i in range(ndata):
foam = Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1)
saux = saux.at[i, ..., 0].set(discrete_phantom(foam, size=size))
foam = Foam2(size_range=[0.075, 0.0025], gap=1e-3, porosity=1)
saux[i, ..., 0] = discrete_phantom(foam, size=size)
# normalize
saux /= np.max(saux, axis=(1, 2), keepdims=True)

return saux

Expand Down Expand Up @@ -180,7 +179,11 @@ def batched_f(f_: Callable, vr: Array) -> Array:
evaluation preserves the batch axis.
"""
nproc = jax.device_count()
res = jax.pmap(lambda i: vector_f(f_, vr[i]))(jnp.arange(nproc))
if vr.shape[0] != nproc:
vrr = vr.reshape((nproc, -1, *vr.shape[:1]))
else:
vrr = vr
res = jax.pmap(partial(vector_f, f_))(vrr)
return res


Expand All @@ -191,8 +194,7 @@ def generate_ct_data(
imgfunc: Callable = generate_foam2_images,
seed: int = 1234,
verbose: bool = False,
prefer_ray: bool = True,
) -> Tuple[Array, ...]:
) -> Tuple[Array, Array, Array]:
"""Generate batch of computed tomography (CT) data.

Generate batch of CT data for training of machine learning network
Expand All @@ -205,9 +207,6 @@ def generate_ct_data(
imgfunc: Function for generating input images (e.g. foams).
seed: Seed for data generation.
verbose: Flag indicating whether to print status messages.
Default: ``False``.
prefer_ray: Use ray for distributed processing if available.
Default: ``True``.

Returns:
tuple: A tuple (img, sino, fbp) containing:
Expand All @@ -220,29 +219,24 @@ def generate_ct_data(
raise RuntimeError("Package astra is required for use of this function.")

# Generate input data.
if have_ray and prefer_ray:
start_time = time()
img = ray_distributed_data_generation(imgfunc, size, nimg, seed)
time_dtgen = time() - start_time
else:
start_time = time()
img = distributed_data_generation(imgfunc, size, nimg, False)
time_dtgen = time() - start_time
# Clip to [0,1] range.
start_time = time()
img = distributed_data_generation(imgfunc, size, nimg, seed)
time_dtgen = time() - start_time
# clip to [0,1] range
img = jnp.clip(img, 0, 1)

nproc = jax.device_count()

# Configure a CT projection operator to generate synthetic measurements.
angles = np.linspace(0, jnp.pi, nproj) # evenly spaced projection angles
gt_sh = (size, size)
detector_spacing = 1
detector_spacing = 1.0
A = XRayTransform2D(gt_sh, size, detector_spacing, angles) # X-ray transform operator

# Compute sinograms in parallel.
start_time = time()
if nproc > 1:
# Shard array
# shard array
imgshd = img.reshape((nproc, -1, size, size, 1))
sinoshd = batched_f(A, imgshd)
sino = sinoshd.reshape((-1, nproj, size, 1))
Expand Down Expand Up @@ -284,8 +278,7 @@ def generate_blur_data(
imgfunc: Callable,
seed: int = 4321,
verbose: bool = False,
prefer_ray: bool = True,
) -> Tuple[Array, ...]:
) -> Tuple[Array, Array]:
"""Generate batch of blurred data.

Generate batch of blurred data for training of machine learning
Expand All @@ -299,24 +292,16 @@ def generate_blur_data(
imgfunc: Function to generate foams.
seed: Seed for data generation.
verbose: Flag indicating whether to print status messages.
Default: ``False``.
prefer_ray: Use ray for distributed processing if available.
Default: ``True``.

Returns:
tuple: A tuple (img, blurn) containing:

- **img** : Generated foam images.
- **blurn** : Corresponding blurred and noisy images.
"""
if have_ray and prefer_ray:
start_time = time()
img = ray_distributed_data_generation(imgfunc, size, nimg, seed)
time_dtgen = time() - start_time
else:
start_time = time()
img = distributed_data_generation(imgfunc, size, nimg, False)
time_dtgen = time() - start_time
start_time = time()
img = distributed_data_generation(imgfunc, size, nimg, seed)
time_dtgen = time() - start_time

# Clip to [0,1] range.
img = jnp.clip(img, 0, 1)
Expand Down Expand Up @@ -356,44 +341,16 @@ def generate_blur_data(


def distributed_data_generation(
imgenf: Callable, size: int, nimg: int, sharded: bool = True
) -> Array:
"""Data generation distributed among processes using jax.

Args:
imagenf: Function for batch-data generation.
size: Size of image to generate.
ndata: Number of images to generate.
sharded: Flag to indicate if data is to be returned as the
chunks generated by each process or consolidated.
Default: ``True``.

Returns:
Array of generated data.
"""
nproc = jax.device_count()
seeds = jnp.arange(nproc)
if nproc > 1 and nimg % nproc > 0:
raise ValueError("Number of images to generate must be divisible by the number of devices")

ndata_per_proc = int(nimg // nproc)

idx = np.arange(nproc)
imgs = jax.vmap(imgenf, (0, None, None))(idx, size, ndata_per_proc)

# imgs = jax.pmap(imgenf, static_broadcasted_argnums=(1, 2))(seeds, size, ndata_per_proc)

if not sharded:
imgs = imgs.reshape((-1, size, size, 1))

return imgs


def ray_distributed_data_generation(
imgenf: Callable, size: int, nimg: int, seedg: float = 123
) -> Array:
) -> np.ndarray:
"""Data generation distributed among processes using ray.

*Warning:* callable `imgenf` should not make use of any jax functions
to avoid the risk of errors when running with GPU devices, in which
case jax is initialized to expect the availability of GPUs, which are
then not available within the `ray.remote` function due to the absence
of any declared GPUs as a `num_gpus` parameter of `@ray.remote`.

Args:
imagenf: Function for batch-data generation.
size: Size of image to generate.
Expand All @@ -405,27 +362,28 @@ def ray_distributed_data_generation(
"""
if not have_ray:
raise RuntimeError("Package ray is required for use of this function.")
if not ray.is_initialized():
raise RuntimeError("Ray must be initialized via ray.init() before using this function.")

@ray.remote
def data_gen(seed, size, ndata, imgf):
return imgf(seed, size, ndata)

# Use half of available CPU resources.
# Use half of available CPU resources
ar = ray.available_resources()
if "CPU" not in ar:
warnings.warn("No CPU key in ray.available_resources() output")
nproc = max(int(ar.get("CPU", "1")) // 2, 1)
# nproc = max(int(ar["CPU"]) // 2, 1)
warnings.warn("No CPU key in ray.available_resources() output.")
nproc = max(int(ar.get("CPU", 1)) // 2, 1)
if nproc > nimg:
nproc = nimg
if nproc > 1 and nimg % nproc > 0:
raise ValueError(
f"Number of images to generate ({nimg}) "
f"must be divisible by the number of available devices ({nproc})"
f"Number of images to generate ({nimg}) must be divisible by "
f"the number of available devices ({nproc})."
)

ndata_per_proc = int(nimg // nproc)

@ray.remote
def data_gen(seed, size, ndata, imgf):
return imgf(seed, size, ndata)

ray_return = ray.get(
[data_gen.remote(seed + seedg, size, ndata_per_proc, imgenf) for seed in range(nproc)]
)
Expand Down
8 changes: 0 additions & 8 deletions scico/flax/examples/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def load_ct_data(
nproj: int,
cache_path: Optional[str] = None,
verbose: bool = False,
prefer_ray: bool = True,
) -> Tuple[CTDataSetDict, ...]: # pragma: no cover
"""
Load or generate CT data.
Expand Down Expand Up @@ -77,8 +76,6 @@ def load_ct_data(
Default: ``None``.
verbose: Flag indicating whether to print status messages.
Default: ``False``.
prefer_ray: Use ray for distributed processing if available.
Default: ``True``.

Returns:
tuple: A tuple (trdt, ttdt) containing:
Expand Down Expand Up @@ -146,7 +143,6 @@ def load_ct_data(
size,
nproj,
verbose=verbose,
prefer_ray=prefer_ray,
)
# Separate training and testing partitions.
trdt = {"img": img[:train_nimg], "sino": sino[:train_nimg], "fbp": fbp[:train_nimg]}
Expand Down Expand Up @@ -186,7 +182,6 @@ def load_foam1_blur_data(
noise_sigma: float,
cache_path: Optional[str] = None,
verbose: bool = False,
prefer_ray: bool = True,
) -> Tuple[DataSetDict, ...]: # pragma: no cover
"""Load or generate blurred data based on xdesign foam structures.

Expand Down Expand Up @@ -214,8 +209,6 @@ def load_foam1_blur_data(
Default: ``None``.
verbose: Flag indicating whether to print status messages.
Default: ``False``.
prefer_ray: Use ray for distributed processing if available.
Default: ``True``.

Returns:
tuple: A tuple (train_ds, test_ds) containing:
Expand Down Expand Up @@ -297,7 +290,6 @@ def load_foam1_blur_data(
noise_sigma,
imgfunc=generate_foam1_images,
verbose=verbose,
prefer_ray=prefer_ray,
)
# Separate training and testing partitions.
train_ds = {"image": blrn[:train_nimg], "label": img[:train_nimg]}
Expand Down
2 changes: 1 addition & 1 deletion scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(
Pdx = np.stack((dx[0] * jnp.cos(angles), dx[1] * jnp.sin(angles)))
Pdiag1 = np.abs(Pdx[0] + Pdx[1])
Pdiag2 = np.abs(Pdx[0] - Pdx[1])
max_width = np.max(np.maximum(Pdiag1, Pdiag2))
max_width: float = np.max(np.maximum(Pdiag1, Pdiag2))

if max_width > 1:
warn(
Expand Down
Loading
Loading