Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change log level to warning #111

Merged
merged 9 commits into from
Dec 1, 2022
2 changes: 1 addition & 1 deletion src/basicpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from basicpy.basicpy import BaSiC

# Set logger level from environment variable
logging_level = os.getenv("BASIC_LOG_LEVEL", default="INFO").upper()
logging_level = os.getenv("BASIC_LOG_LEVEL", default="WARNING").upper()
logger = logging.getLogger(__name__)
logger.setLevel(logging_level)

Expand Down
102 changes: 50 additions & 52 deletions src/basicpy/basicpy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Main BaSiC class.
"""
"""Main BaSiC class."""

# Core modules
from __future__ import annotations
Expand All @@ -11,7 +10,7 @@
from enum import Enum
from multiprocessing import cpu_count
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp

Expand All @@ -20,10 +19,7 @@
from jax import device_put
from jax.image import ResizeMethod
from jax.image import resize as jax_resize

# FIXME change this to jax.xla.XlaRuntimeError
# when https://github.com/google/jax/pull/10676 gets merged
from pydantic import BaseModel, Field, PrivateAttr
from pydantic import BaseModel, Field, PrivateAttr, root_validator
from skimage.transform import resize as skimage_resize

from basicpy._jax_routines import ApproximateFit, LadmapFit
Expand All @@ -34,9 +30,6 @@

newax = jnp.newaxis

# from basicpy.tools.dct2d_tools import dct2d, idct2d
# from basicpy.tools.inexact_alm import inexact_alm_rspca_l1

# Get number of available threads to limit CPU thrashing
# From preadator: https://pypi.org/project/preadator/
if hasattr(os, "sched_getaffinity"):
Expand All @@ -54,32 +47,36 @@


class Device(Enum):
"""Device selection enum."""

cpu: str = "cpu"
gpu: str = "gpu"
tpu: str = "tpu"


class FittingMode(str, Enum):
"""Fit method enum."""

ladmap: str = "ladmap"
approximate: str = "approximate"


class ResizeMode(str, Enum):
"""Resize method enum."""

jax: str = "jax"
skimage: str = "skimage"
skimage_dask: str = "skimage_dask"


class TimelapseTransformMode(str, Enum):
"""Timelapse transformation enum."""

additive: str = "additive"
multiplicative: str = "multiplicative"


# multiple channels should be handled by creating a `basic` object for each chan
# multiple channels should be handled by creating a `basic` object for each channel
class BaSiC(BaseModel):
"""A class for fitting and applying BaSiC illumination correction profiles."""

Expand All @@ -101,7 +98,6 @@ class BaSiC(BaseModel):
fitting_mode: FittingMode = Field(
FittingMode.ladmap, description="Must be one of ['ladmap', 'approximate']"
)

epsilon: float = Field(
0.1,
description="Weight regularization term.",
Expand Down Expand Up @@ -194,29 +190,23 @@ class BaSiC(BaseModel):
_profiles_fname = "profiles.npy"

class Config:
"""Pydantic class configuration."""

arbitrary_types_allowed = True
extra = "forbid"

def __init__(self, **kwargs) -> None:
"""Initialize BaSiC with the provided settings."""

log_str = f"Initializing BaSiC {id(self)} with parameters: \n"
for k, v in kwargs.items():
log_str += f"{k}: {v}\n"
logger.info(log_str)

super().__init__(**kwargs)

if self.device is not Device.cpu:
# TODO: sanity checks on device selection
pass
@root_validator(pre=True)
def debug_log_values(cls, values: Dict[str, Any]):
"""Use a validator to echo input values."""
logger.debug("Initializing BaSiC with parameters:")
for k, v in values.items():
logger.debug(f"{k}: {v}")
return values

def __call__(
self, images: np.ndarray, timelapse: bool = False
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
"""Shortcut for BaSiC.transform"""

"""Shortcut for `BaSiC.transform`."""
return self.transform(images, timelapse)

def _resize(self, Im, target_shape):
Expand All @@ -225,11 +215,13 @@ def _resize(self, Im, target_shape):
resize_params.update(self.resize_params)
Im = device_put(Im).astype(jnp.float32)
return jax_resize(Im, target_shape, **resize_params)

elif self.resize_mode == ResizeMode.skimage:
Im = skimage_resize(
Im, target_shape, preserve_range=True, **self.resize_params
)
return device_put(Im).astype(jnp.float32)

elif self.resize_mode == ResizeMode.skimage_dask:
assert np.array_equal(target_shape[:-2], Im.shape[:-2])
import dask.array as da
Expand All @@ -252,9 +244,7 @@ def _resize(self, Im, target_shape):
return device_put(Im).astype(jnp.float32)

def _resize_to_working_size(self, Im):
"""
Resize the images to the working size.
"""
"""Resize the images to the working size."""
if self.working_size is not None:
if np.isscalar(self.working_size):
working_shape = [self.working_size] * (Im.ndim - 2)
Expand All @@ -273,8 +263,7 @@ def _resize_to_working_size(self, Im):
def fit(
self, images: np.ndarray, fitting_weight: Optional[np.ndarray] = None
) -> None:
"""
Generate illumination correction profiles from images.
"""Generate illumination correction profiles from images.

Args:
images: Input images to fit shading model.
Expand All @@ -293,7 +282,6 @@ def fit(
>>> basic.fit(images)

"""

ndim = images.ndim
if images.ndim == 3:
images = images[:, np.newaxis, ...]
Expand Down Expand Up @@ -352,9 +340,9 @@ def fit(
self._smoothness_darkfield = self.smoothness_darkfield
self._sparse_cost_darkfield = self.sparse_cost_darkfield

logger.info(f"_smoothness_flatfield set to {self._smoothness_flatfield}")
logger.info(f"_smoothness_darkfield set to {self._smoothness_darkfield}")
logger.info(f"_sparse_cost_darkfield set to {self._sparse_cost_darkfield}")
logger.debug(f"_smoothness_flatfield set to {self._smoothness_flatfield}")
logger.debug(f"_smoothness_darkfield set to {self._smoothness_darkfield}")
logger.debug(f"_sparse_cost_darkfield set to {self._sparse_cost_darkfield}")

# spectral_norm = jnp.linalg.norm(Im.reshape((Im.shape[0], -1)), ord=2)
_temp = jnp.linalg.svd(Im2.reshape((Im2.shape[0], -1)), full_matrices=False)
Expand Down Expand Up @@ -393,7 +381,7 @@ def fit(
fitting_step = ApproximateFit(**fit_params)

for i in range(self.max_reweight_iterations):
logger.info(f"reweighting iteration {i}")
logger.debug(f"reweighting iteration {i}")
if self.fitting_mode == FittingMode.approximate:
S = jnp.zeros(Im2.shape[1:], dtype=jnp.float32)
else:
Expand All @@ -415,15 +403,19 @@ def fit(
B,
I_R,
)
logger.info(f"single-step optimization score: {norm_ratio}.")
logger.info(f"mean of S: {float(jnp.mean(S))}.")
logger.debug(f"single-step optimization score: {norm_ratio}.")
logger.debug(f"mean of S: {float(jnp.mean(S))}.")
self._score = norm_ratio
if not converged:
logger.warning("single-step optimization did not converge.")
logger.debug("single-step optimization did not converge.")
if S.max() == 0:
logger.error("S is zero. Please try to decrease smoothness_darkfield.")
logger.error(
"Estimated flatfield is zero. "
+ "Please try to decrease smoothness_darkfield."
)
raise RuntimeError(
"S is zero. Please try to decrease smoothness_darkfield."
"Estimated flatfield is zero. "
+ "Please try to decrease smoothness_darkfield."
)
self._S = S
self._D_R = D_R
Expand All @@ -441,7 +433,7 @@ def fit(
self._weight_dark = W_D
self._residual = I_R

logger.info(f"Iteration {i} finished.")
logger.debug(f"Iteration {i} finished.")
if last_S is not None:
mad_flatfield = jnp.sum(jnp.abs(S - last_S)) / jnp.sum(np.abs(last_S))
if self.get_darkfield:
Expand All @@ -451,8 +443,11 @@ def fit(
self._reweight_score = max(mad_flatfield, mad_darkfield)
else:
self._reweight_score = mad_flatfield
logger.info(f"reweighting score: {self._reweight_score}")
logger.info(f"elapsed time: {time.monotonic() - start_time} seconds")
logger.debug(f"reweighting score: {self._reweight_score}")
logger.info(
f"Iteration {i} elapsed time: "
+ f"{time.monotonic() - start_time} seconds"
)

if self._reweight_score <= self.reweighting_tol:
logger.info("Reweighting converged.")
Expand All @@ -462,6 +457,9 @@ def fit(
last_S = S
last_D = D

if not converged:
logger.warning("Single-step optimization did not converge at the last reweighting step.")

assert S is not None
assert D is not None
assert B is not None
Expand All @@ -472,7 +470,7 @@ def fit(
if self.fitting_mode == FittingMode.approximate:
B = jnp.mean(Im, axis=(1, 2, 3))
I_R = jnp.zeros(Im.shape, dtype=jnp.float32)
logger.info(f"reweighting iteration for baseline {i}")
logger.debug(f"reweighting iteration for baseline {i}")
I_R, B, norm_ratio, converged = fitting_step.fit_baseline(
Im,
W,
Expand All @@ -486,7 +484,7 @@ def fit(
W = fitting_step.calc_weights_baseline(I_B, I_R) * Ws
self._weight = W
self._residual = I_R
logger.info(f"Iteration {i} finished.")
logger.debug(f"Iteration {i} finished.")

self.flatfield = skimage_resize(S, images.shape[1:])
self.darkfield = skimage_resize(D, images.shape[1:])
Expand All @@ -507,7 +505,7 @@ def transform(
images: input images to correct. See `fit`.
timelapse: If `True`, corrects the timelapse/photobleaching offsets,
assuming that the residual is the product of flatfield and
the object fluorescence. Also accepts "multplicative"
the object fluorescence. Also accepts "multiplicative"
(the same as `True`) or "additive" (residual is the object
fluorescence).

Expand All @@ -518,7 +516,6 @@ def transform(
>>> basic.fit(images)
>>> corrected = basic.transform(images)
"""

if self.baseline is None:
raise RuntimeError("BaSiC object is not initialized")

Expand Down Expand Up @@ -598,12 +595,12 @@ def fit_transform(

@property
def score(self):
"""The BaSiC fit final score"""
"""The BaSiC fit final score."""
return self._score

@property
def reweight_score(self):
"""The BaSiC fit final reweighting score"""
"""The BaSiC fit final reweighting score."""
return self._reweight_score

@property
Expand All @@ -622,7 +619,8 @@ def save_model(self, model_dir: PathLike, overwrite: bool = False) -> None:
model_dir: path to model directory

Raises:
FileExistsError: if model directory already exists"""
FileExistsError: if model directory already exists
"""
path = Path(model_dir)

try:
Expand Down
1 change: 1 addition & 0 deletions src/basicpy/datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Datasets used for testing."""
import glob
from os import path

Expand Down