Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Timelapse option in transform #80

Merged
merged 29 commits into from
Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c23952f
udpated transform method
yfukai Jun 27, 2022
945e117
fixed tests
yfukai Jun 27, 2022
55ac57b
fixed typo
yfukai Jun 27, 2022
820cfe2
Merge branch 'timelapse_option_in_transform' of https://github.com/yf…
tdmorello Jun 30, 2022
5c2a0d5
Merge branch 'dev' into timelapse_option_in_transform
tdmorello Jun 30, 2022
d1257fb
Merge branch 'dev' of https://github.com/peng-lab/BaSiCPy into timela…
yfukai Jul 1, 2022
92ea4ba
Merge branch 'timelapse_option_in_transform' of ssh://github.com/yfuk…
yfukai Jul 1, 2022
2bae103
solving problems with tests
yfukai Jul 1, 2022
5b04955
fixed transform bug
yfukai Jul 1, 2022
fbaa5f6
removed tests for rescaling-fitting
yfukai Jul 1, 2022
66185bf
removed resized transform example
yfukai Jul 1, 2022
6df8ef6
Merge branch 'flatfield-and-darkfield-shape' into timelapse_option_in…
yfukai Jul 2, 2022
8f409e7
added dask test
yfukai Jul 3, 2022
b6250eb
Merge branch 'dev' of https://github.com/peng-lab/BaSiCPy into timela…
yfukai Jul 4, 2022
446ee39
set baseline type to optinal
yfukai Jul 4, 2022
26ed361
updated testdata
yfukai Jul 4, 2022
8bc58bd
deleted notebook outputs
yfukai Jul 4, 2022
466179d
Merge branch 'timelapse_option_in_transform' of https://github.com/yf…
tdmorello Jul 4, 2022
01d457d
Update timelapse
tdmorello Jul 5, 2022
382a50a
re-organized transform part
yfukai Jul 9, 2022
989e47f
tests running...
yfukai Jul 9, 2022
9bef4d3
made the resize method changable
yfukai Jul 9, 2022
67f7e86
Merge branch 'timelapse_option_in_transform' into timelapse-update
yfukai Jul 10, 2022
1dee296
Merge pull request #3 from tdmorello/timelapse-update
yfukai Jul 10, 2022
e2cf1e0
resize method tested
yfukai Jul 12, 2022
455a965
tested resize method
yfukai Jul 12, 2022
69de43d
added comments for the transform function
yfukai Jul 15, 2022
e30c3d5
minor fix
yfukai Jul 15, 2022
d01bb9c
fixed fitting_weight shape bug
yfukai Jul 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
415 changes: 8 additions & 407 deletions misc_notebooks/organize_test_data.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def tests(session: Session) -> None:
)
session.install(".")
session.install(
"dask",
"pytest",
"pytest-benchmark",
"pytest-datafiles",
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,13 @@ bench =
jax
jaxlib>=0.3.10 # to import jaxlib.xla_extension.XlaRuntimeError
scipy
dask =
dask
dev =
black
bump2version
darglint
dask
flake8
flake8-alphabetize
flake8-black
Expand Down
193 changes: 122 additions & 71 deletions src/basicpy/basicpy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""Main BaSiC class.

Todo:
Keep examples up to date with changing API.
"""

# Core modules
Expand All @@ -11,7 +8,6 @@
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from multiprocessing import cpu_count
from pathlib import Path
Expand All @@ -22,13 +18,14 @@
# 3rd party modules
import numpy as np
from jax import device_put
from jax.image import ResizeMethod, resize
from jax.image import ResizeMethod
from jax.image import resize as jax_resize

# FIXME change this to jax.xla.XlaRuntimeError
# when https://github.com/google/jax/pull/10676 gets merged
from jaxlib.xla_extension import XlaRuntimeError
from pydantic import BaseModel, Field, PrivateAttr
from skimage.transform import resize as _resize
from skimage.transform import resize as skimage_resize

from basicpy._jax_routines import ApproximateFit, LadmapFit
from basicpy.tools.dct_tools import JaxDCT
Expand Down Expand Up @@ -70,6 +67,19 @@ class FittingMode(str, Enum):
approximate: str = "approximate"


class ResizeMode(str, Enum):

jax: str = "jax"
skimage: str = "skimage"
skimage_dask: str = "skimage_dask"


class TimelapseTransformMode(str, Enum):

additive: str = "additive"
multiplicative: str = "multiplicative"


# multiple channels should be handled by creating a `basic` object for each chan
class BaSiC(BaseModel):
"""A class for fitting and applying BaSiC illumination correction profiles."""
Expand Down Expand Up @@ -151,9 +161,14 @@ class BaSiC(BaseModel):
1e-3,
description="Optimization tolerance for update diff.",
)
resize_method: ResizeMethod = Field(
ResizeMethod.LINEAR,
description="Resize method to use when downsampling images.",
resize_mode: ResizeMode = Field(
ResizeMode.jax,
description="Resize mode to use when downsampling images. "
+ "Must be one of 'jax', 'skimage', and 'skimage_dask'",
)
resize_params: Dict = Field(
{},
description="Parameters for the resize function when downsampling images.",
)
reweighting_tol: float = Field(
1e-2,
Expand All @@ -171,8 +186,6 @@ class BaSiC(BaseModel):
# Private attributes for internal processing
_score: float = PrivateAttr(None)
_reweight_score: float = PrivateAttr(None)
_flatfield: np.ndarray = PrivateAttr(None)
_darkfield: np.ndarray = PrivateAttr(None)
_weight: float = PrivateAttr(None)
_residual: float = PrivateAttr(None)
_S: float = PrivateAttr(None)
Expand Down Expand Up @@ -211,7 +224,36 @@ def __call__(

return self.transform(images, timelapse)

def _resize(self, Im):
def _resize(self, Im, target_shape):
if self.resize_mode == ResizeMode.jax:
resize_params = dict(method=ResizeMethod.LINEAR)
resize_params.update(self.resize_params)
return jax_resize(Im, target_shape, **resize_params)
elif self.resize_mode == ResizeMode.skimage:
return skimage_resize(
Im, target_shape, preserve_range=True, **self.resize_params
)
elif self.resize_mode == ResizeMode.skimage_dask:
assert np.array_equal(target_shape[:-2], Im.shape[:-2])
import dask.array as da

return (
da.from_array(
[
skimage_resize(
Im[tuple(inds)],
target_shape[-2:],
preserve_range=True,
**self.resize_params,
)
for inds in np.ndindex(Im.shape[:-2])
]
)
.reshape((*Im.shape[:-2], *target_shape[-2:]))
.compute()
)

def _resize_to_working_size(self, Im):
"""
Resize the images to the working size.
"""
Expand All @@ -225,7 +267,9 @@ def _resize(self, Im):
)
else:
working_shape = self.working_size
Im = resize(Im, [*Im.shape[:2], *working_shape], self.resize_method)
target_shape = [*Im.shape[:2], *working_shape]
Im = self._resize(Im, target_shape)

return Im

def fit(
Expand All @@ -236,15 +280,17 @@ def fit(

Args:
images: Input images to fit shading model.
Must be 3-dimensional array with dimension of (T,Y,X).
Must be 3-dimensional or 4-dimensional array
with dimension of (T,Y,X) or (T,Z,Y,X).
T can be either of time or mosaic position.
fitting_weight: relative fitting weight for each pixel.
Higher value means more contribution to fitting.
Must has the same shape as images.

Example:
>>> from basicpy import BaSiC
>>> from basicpy.tools import load_images
>>> images = load_images('./images')
>>> from basicpy import data as bdata
>>> images = bdata.wsi_brain()
>>> basic = BaSiC() # use default settings
>>> basic.fit(images)

Expand All @@ -253,6 +299,8 @@ def fit(
ndim = images.ndim
if images.ndim == 3:
images = images[:, np.newaxis, ...]
if fitting_weight is not None:
fitting_weight = fitting_weight[:, np.newaxis, ...]
elif images.ndim == 4:
if self.fitting_mode == FittingMode.approximate:
raise ValueError(
Expand All @@ -268,11 +316,11 @@ def fit(
start_time = time.monotonic()

Im = device_put(images).astype(jnp.float32)
Im = self._resize(Im)
Im = self._resize_to_working_size(Im)

if fitting_weight is not None:
Ws = device_put(fitting_weight).astype(jnp.float32)
Ws = self._resize(Ws)
Ws = self._resize_to_working_size(Ws)
# normalize relative weight to 0 to 1
Ws_min = jnp.min(Ws)
Ws_max = jnp.max(Ws)
Expand Down Expand Up @@ -433,88 +481,92 @@ def fit(
self._residual = I_R
logger.info(f"Iteration {i} finished.")

self.flatfield = skimage_resize(S, images.shape[1:])
self.darkfield = skimage_resize(D, images.shape[1:])
if ndim == 3:
self.flatfield = S[0]
self.darkfield = D[0]
else:
self.flatfield = S
self.darkfield = D
self.flatfield = self.flatfield[0]
self.darkfield = self.darkfield[0]
self.baseline = B
logger.info(
f"=== BaSiC fit finished in {time.monotonic()-start_time} seconds ==="
)

def transform(
self, images: np.ndarray, timelapse: bool = False
self, images: np.ndarray, timelapse: Union[bool, TimelapseTransformMode] = False
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
"""Apply profile to images.

Todo:
Add in baseline/timelapse correction.

Args:
images: input images to correct
timelapse: calculate timelapse/photobleaching offsets. Currently
does nothing.
images: input images to correct. See `fit`.
timelapse: If `True`, corrects the timelapse/photobleaching offsets,
assuming that the residual is the product of flatfield and
the object fluorescence. Also accepts "multplicative"
(the same as `True`) or "additive" (residual is the object
fluorescence).

Returns:
An array of the same size as images. If timelapse is True, returns
a flat array of baseline corrections used in the calculations.
corrected images

Example:
>>> basic.fit(images)
>>> corrected = basic.transform(images)
>>> for i, im in enumerate(corrected):
... imsave(f"image_{i}.tif")
"""

if self.baseline is None:
raise RuntimeError("BaSiC object is not initialized")

logger.info("=== BaSiC transform started ===")
start_time = time.monotonic()

# Convert to the correct format
im_float = images.astype(np.float32)

# Rescale the flatfield and darkfield
if not np.array_equal(self.flatfield.shape, im_float.shape[1:]):
self._flatfield = _resize(self.flatfield, images.shape[1:])
self._darkfield = _resize(self.darkfield, images.shape[1:])
else:
self._flatfield = self.flatfield
self._darkfield = self.darkfield
# Image = B_n x S_l + D_l + I_R_nl

# Initialize the output
output = np.empty(images.shape, dtype=images.dtype)
# in timelapse cases ...
# "Multiplicative" mode
# Real Image x S_l = I_R_nl
# Image = (B_n + Real Image) x S_l + D_l
# Real Image = (Image - D_l) / S_l - B_n

if timelapse:
# calculate timelapse from input series
...

def unshade(ins, outs, i, dark, flat):
outs[i] = (ins[i] - dark) / flat
# "Additive" mode
# Real Image = I_R_nl
# Image = B_n x S_l + D_l + Real Image
# Real Image = Image - D_l - (S_l x B_n)

logger.info(f"unshading in {self.max_workers} threads")
# If one or fewer workers, don't user ThreadPool. Useful for debugging.
if self.max_workers <= 1:
for i in range(images.shape[0]):
unshade(im_float, output, i, self._darkfield, self._flatfield)
# in non-timelapse cases ...
# we assume B_n is the mean of Real Image
# and then always assume Multiplicative mode
# the image model is
# Image = Real Image x S_l + D_l
# Real Image = (Image - D_l) / S_l

else:
with ThreadPoolExecutor(self.max_workers) as executor:
threads = executor.map(
lambda x: unshade(
im_float, output, x, self._darkfield, self._flatfield
),
range(images.shape[0]),
if timelapse:
if timelapse is True:
timelapse = TimelapseTransformMode.multiplicative

baseline_inds = tuple([slice(None)] + ([np.newaxis] * (im_float.ndim - 1)))
if timelapse == TimelapseTransformMode.multiplicative:
output = (im_float - self.darkfield[np.newaxis]) / self.flatfield[
np.newaxis
] - self.baseline[baseline_inds]
elif timelapse == TimelapseTransformMode.additive:
baseline_flatfield = (
self.flatfield[np.newaxis] * self.baseline[baseline_inds]
)

# Get the result of each thread, this should catch thread errors
for thread in threads:
assert thread is None

output = im_float - self.darkfield[np.newaxis] - baseline_flatfield
else:
raise ValueError(
"timelapse value must be bool, 'multiplicative' or 'additive'"
)
else:
output = (im_float - self.darkfield[np.newaxis]) / self.flatfield[
np.newaxis
]
logger.info(
f"=== BaSiC transform finished in {time.monotonic()-start_time} seconds ==="
)
return output.astype(images.dtype)
return output

# REFACTOR large datasets will probably prefer saving corrected images to
# files directly, a generator may be handy
Expand All @@ -524,18 +576,17 @@ def fit_transform(
"""Fit and transform on data.

Args:
images: input images to fit and correct
images: input images to fit and correct. See `fit`.

Returns:
profiles and corrected images
corrected images

Example:
>>> profiles, corrected = basic.fit_transform(images)
>>> corrected = basic.fit_transform(images)
"""
self.fit(images)
corrected = self.transform(images, timelapse)

# NOTE or only return corrected images and user can get profiles separately
return corrected

@property
Expand Down
Loading