Skip to content

Commit

Permalink
Timelapse option in transform (#80)
Browse files Browse the repository at this point in the history
* udpated transform method

* fixed tests

* fixed typo

* solving problems with tests

* fixed transform bug

* removed tests for rescaling-fitting

* removed resized transform example

* added dask test

* set baseline type to optinal

* updated testdata

* deleted notebook outputs

* Update timelapse

* re-organized transform part

* tests running...

* made the resize method changable

* resize method tested

* tested resize method

* added comments for the transform function

* minor fix

* fixed fitting_weight shape bug

Co-authored-by: Tim Morello <[email protected]>
  • Loading branch information
yfukai and tdmorello authored Jul 29, 2022
1 parent 046d762 commit 4571862
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 509 deletions.
415 changes: 8 additions & 407 deletions misc_notebooks/organize_test_data.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def tests(session: Session) -> None:
)
session.install(".")
session.install(
"dask",
"pytest",
"pytest-benchmark",
"pytest-datafiles",
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,13 @@ bench =
jax
jaxlib>=0.3.10 # to import jaxlib.xla_extension.XlaRuntimeError
scipy
dask =
dask
dev =
black
bump2version
darglint
dask
flake8
flake8-alphabetize
flake8-black
Expand Down
193 changes: 122 additions & 71 deletions src/basicpy/basicpy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""Main BaSiC class.
Todo:
Keep examples up to date with changing API.
"""

# Core modules
Expand All @@ -11,7 +8,6 @@
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from multiprocessing import cpu_count
from pathlib import Path
Expand All @@ -22,13 +18,14 @@
# 3rd party modules
import numpy as np
from jax import device_put
from jax.image import ResizeMethod, resize
from jax.image import ResizeMethod
from jax.image import resize as jax_resize

# FIXME change this to jax.xla.XlaRuntimeError
# when https://github.com/google/jax/pull/10676 gets merged
from jaxlib.xla_extension import XlaRuntimeError
from pydantic import BaseModel, Field, PrivateAttr
from skimage.transform import resize as _resize
from skimage.transform import resize as skimage_resize

from basicpy._jax_routines import ApproximateFit, LadmapFit
from basicpy.tools.dct_tools import JaxDCT
Expand Down Expand Up @@ -70,6 +67,19 @@ class FittingMode(str, Enum):
approximate: str = "approximate"


class ResizeMode(str, Enum):

jax: str = "jax"
skimage: str = "skimage"
skimage_dask: str = "skimage_dask"


class TimelapseTransformMode(str, Enum):

additive: str = "additive"
multiplicative: str = "multiplicative"


# multiple channels should be handled by creating a `basic` object for each chan
class BaSiC(BaseModel):
"""A class for fitting and applying BaSiC illumination correction profiles."""
Expand Down Expand Up @@ -151,9 +161,14 @@ class BaSiC(BaseModel):
1e-3,
description="Optimization tolerance for update diff.",
)
resize_method: ResizeMethod = Field(
ResizeMethod.LINEAR,
description="Resize method to use when downsampling images.",
resize_mode: ResizeMode = Field(
ResizeMode.jax,
description="Resize mode to use when downsampling images. "
+ "Must be one of 'jax', 'skimage', and 'skimage_dask'",
)
resize_params: Dict = Field(
{},
description="Parameters for the resize function when downsampling images.",
)
reweighting_tol: float = Field(
1e-2,
Expand All @@ -171,8 +186,6 @@ class BaSiC(BaseModel):
# Private attributes for internal processing
_score: float = PrivateAttr(None)
_reweight_score: float = PrivateAttr(None)
_flatfield: np.ndarray = PrivateAttr(None)
_darkfield: np.ndarray = PrivateAttr(None)
_weight: float = PrivateAttr(None)
_residual: float = PrivateAttr(None)
_S: float = PrivateAttr(None)
Expand Down Expand Up @@ -211,7 +224,36 @@ def __call__(

return self.transform(images, timelapse)

def _resize(self, Im):
def _resize(self, Im, target_shape):
if self.resize_mode == ResizeMode.jax:
resize_params = dict(method=ResizeMethod.LINEAR)
resize_params.update(self.resize_params)
return jax_resize(Im, target_shape, **resize_params)
elif self.resize_mode == ResizeMode.skimage:
return skimage_resize(
Im, target_shape, preserve_range=True, **self.resize_params
)
elif self.resize_mode == ResizeMode.skimage_dask:
assert np.array_equal(target_shape[:-2], Im.shape[:-2])
import dask.array as da

return (
da.from_array(
[
skimage_resize(
Im[tuple(inds)],
target_shape[-2:],
preserve_range=True,
**self.resize_params,
)
for inds in np.ndindex(Im.shape[:-2])
]
)
.reshape((*Im.shape[:-2], *target_shape[-2:]))
.compute()
)

def _resize_to_working_size(self, Im):
"""
Resize the images to the working size.
"""
Expand All @@ -225,7 +267,9 @@ def _resize(self, Im):
)
else:
working_shape = self.working_size
Im = resize(Im, [*Im.shape[:2], *working_shape], self.resize_method)
target_shape = [*Im.shape[:2], *working_shape]
Im = self._resize(Im, target_shape)

return Im

def fit(
Expand All @@ -236,15 +280,17 @@ def fit(
Args:
images: Input images to fit shading model.
Must be 3-dimensional array with dimension of (T,Y,X).
Must be 3-dimensional or 4-dimensional array
with dimension of (T,Y,X) or (T,Z,Y,X).
T can be either of time or mosaic position.
fitting_weight: relative fitting weight for each pixel.
Higher value means more contribution to fitting.
Must has the same shape as images.
Example:
>>> from basicpy import BaSiC
>>> from basicpy.tools import load_images
>>> images = load_images('./images')
>>> from basicpy import data as bdata
>>> images = bdata.wsi_brain()
>>> basic = BaSiC() # use default settings
>>> basic.fit(images)
Expand All @@ -253,6 +299,8 @@ def fit(
ndim = images.ndim
if images.ndim == 3:
images = images[:, np.newaxis, ...]
if fitting_weight is not None:
fitting_weight = fitting_weight[:, np.newaxis, ...]
elif images.ndim == 4:
if self.fitting_mode == FittingMode.approximate:
raise ValueError(
Expand All @@ -268,11 +316,11 @@ def fit(
start_time = time.monotonic()

Im = device_put(images).astype(jnp.float32)
Im = self._resize(Im)
Im = self._resize_to_working_size(Im)

if fitting_weight is not None:
Ws = device_put(fitting_weight).astype(jnp.float32)
Ws = self._resize(Ws)
Ws = self._resize_to_working_size(Ws)
# normalize relative weight to 0 to 1
Ws_min = jnp.min(Ws)
Ws_max = jnp.max(Ws)
Expand Down Expand Up @@ -433,88 +481,92 @@ def fit(
self._residual = I_R
logger.info(f"Iteration {i} finished.")

self.flatfield = skimage_resize(S, images.shape[1:])
self.darkfield = skimage_resize(D, images.shape[1:])
if ndim == 3:
self.flatfield = S[0]
self.darkfield = D[0]
else:
self.flatfield = S
self.darkfield = D
self.flatfield = self.flatfield[0]
self.darkfield = self.darkfield[0]
self.baseline = B
logger.info(
f"=== BaSiC fit finished in {time.monotonic()-start_time} seconds ==="
)

def transform(
self, images: np.ndarray, timelapse: bool = False
self, images: np.ndarray, timelapse: Union[bool, TimelapseTransformMode] = False
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
"""Apply profile to images.
Todo:
Add in baseline/timelapse correction.
Args:
images: input images to correct
timelapse: calculate timelapse/photobleaching offsets. Currently
does nothing.
images: input images to correct. See `fit`.
timelapse: If `True`, corrects the timelapse/photobleaching offsets,
assuming that the residual is the product of flatfield and
the object fluorescence. Also accepts "multplicative"
(the same as `True`) or "additive" (residual is the object
fluorescence).
Returns:
An array of the same size as images. If timelapse is True, returns
a flat array of baseline corrections used in the calculations.
corrected images
Example:
>>> basic.fit(images)
>>> corrected = basic.transform(images)
>>> for i, im in enumerate(corrected):
... imsave(f"image_{i}.tif")
"""

if self.baseline is None:
raise RuntimeError("BaSiC object is not initialized")

logger.info("=== BaSiC transform started ===")
start_time = time.monotonic()

# Convert to the correct format
im_float = images.astype(np.float32)

# Rescale the flatfield and darkfield
if not np.array_equal(self.flatfield.shape, im_float.shape[1:]):
self._flatfield = _resize(self.flatfield, images.shape[1:])
self._darkfield = _resize(self.darkfield, images.shape[1:])
else:
self._flatfield = self.flatfield
self._darkfield = self.darkfield
# Image = B_n x S_l + D_l + I_R_nl

# Initialize the output
output = np.empty(images.shape, dtype=images.dtype)
# in timelapse cases ...
# "Multiplicative" mode
# Real Image x S_l = I_R_nl
# Image = (B_n + Real Image) x S_l + D_l
# Real Image = (Image - D_l) / S_l - B_n

if timelapse:
# calculate timelapse from input series
...

def unshade(ins, outs, i, dark, flat):
outs[i] = (ins[i] - dark) / flat
# "Additive" mode
# Real Image = I_R_nl
# Image = B_n x S_l + D_l + Real Image
# Real Image = Image - D_l - (S_l x B_n)

logger.info(f"unshading in {self.max_workers} threads")
# If one or fewer workers, don't user ThreadPool. Useful for debugging.
if self.max_workers <= 1:
for i in range(images.shape[0]):
unshade(im_float, output, i, self._darkfield, self._flatfield)
# in non-timelapse cases ...
# we assume B_n is the mean of Real Image
# and then always assume Multiplicative mode
# the image model is
# Image = Real Image x S_l + D_l
# Real Image = (Image - D_l) / S_l

else:
with ThreadPoolExecutor(self.max_workers) as executor:
threads = executor.map(
lambda x: unshade(
im_float, output, x, self._darkfield, self._flatfield
),
range(images.shape[0]),
if timelapse:
if timelapse is True:
timelapse = TimelapseTransformMode.multiplicative

baseline_inds = tuple([slice(None)] + ([np.newaxis] * (im_float.ndim - 1)))
if timelapse == TimelapseTransformMode.multiplicative:
output = (im_float - self.darkfield[np.newaxis]) / self.flatfield[
np.newaxis
] - self.baseline[baseline_inds]
elif timelapse == TimelapseTransformMode.additive:
baseline_flatfield = (
self.flatfield[np.newaxis] * self.baseline[baseline_inds]
)

# Get the result of each thread, this should catch thread errors
for thread in threads:
assert thread is None

output = im_float - self.darkfield[np.newaxis] - baseline_flatfield
else:
raise ValueError(
"timelapse value must be bool, 'multiplicative' or 'additive'"
)
else:
output = (im_float - self.darkfield[np.newaxis]) / self.flatfield[
np.newaxis
]
logger.info(
f"=== BaSiC transform finished in {time.monotonic()-start_time} seconds ==="
)
return output.astype(images.dtype)
return output

# REFACTOR large datasets will probably prefer saving corrected images to
# files directly, a generator may be handy
Expand All @@ -524,18 +576,17 @@ def fit_transform(
"""Fit and transform on data.
Args:
images: input images to fit and correct
images: input images to fit and correct. See `fit`.
Returns:
profiles and corrected images
corrected images
Example:
>>> profiles, corrected = basic.fit_transform(images)
>>> corrected = basic.fit_transform(images)
"""
self.fit(images)
corrected = self.transform(images, timelapse)

# NOTE or only return corrected images and user can get profiles separately
return corrected

@property
Expand Down
Loading

0 comments on commit 4571862

Please sign in to comment.