Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added caching to predict method #35

Merged
merged 7 commits into from
Jan 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions src/pybasic/pybasic.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ class BaSiC(BaseModel):
# Private attributes for internal processing
_score: float = PrivateAttr(None)
_reweight_score: float = PrivateAttr(None)
_flatfield: np.ndarray = PrivateAttr(None)
_darkfield: np.ndarray = PrivateAttr(None)
_alm_settings = {
"lambda_darkfield",
"lambda_flatfield",
Expand All @@ -146,6 +148,10 @@ def __init__(self, **kwargs) -> None:
self.darkfield = np.zeros((self.working_size,) * 2, dtype=np.float64)
self.flatfield = np.zeros((self.working_size,) * 2, dtype=np.float64)

# Initialize the internal cache
self._darkfield = np.zeros((self.working_size,) * 2, dtype=np.float64)
self._flatfield = np.zeros((self.working_size,) * 2, dtype=np.float64)

if self.device is not Device.cpu:
# TODO: sanity checks on device selection
pass
Expand Down Expand Up @@ -218,7 +224,7 @@ def fit(self, images: np.ndarray) -> None:
# # TODO: implement inexact_alm_rspca_l1_intflat?
# raise IOError("Initial flatfield option not implemented yet!")
# else:
X_k_A, X_k_E, X_k_A_offset = inexact_alm_rspca_l1(
X_k_A, X_k_E, X_k_A_offset, self._score = inexact_alm_rspca_l1(
D, weight=weight, **self.dict(include=self._alm_settings)
)

Expand Down Expand Up @@ -251,6 +257,7 @@ def fit(self, images: np.ndarray) -> None:
)
flatfield_last = flatfield_current
darkfield_last = darkfield_current
self._reweight_score = np.maximum(mad_flatfield, mad_darkfield)
if (
np.maximum(mad_flatfield, mad_darkfield) <= self.reweighting_tol
or reweighting_iter >= self.max_reweight_iterations
Expand All @@ -263,6 +270,9 @@ def fit(self, images: np.ndarray) -> None:
if self.get_darkfield:
self.darkfield = X_A_offset

self._darkfield = self.darkfield
self._flatfield = self.flatfield

def predict(
self, images: np.ndarray, timelapse: bool = False
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
Expand All @@ -287,8 +297,14 @@ def predict(
... imsave(f"image_{i}.tif")
"""

# Convert to the correct format
im_float = images.astype(np.float64)

# Check the image size
if not all(i == d for i, d in zip(self._flatfield.shape, images.shape)):
self._flatfield = resize(self.flatfield, images.shape[:2])
self._darkfield = resize(self.darkfield, images.shape[:2])

# Initialize the output
output = np.zeros(images.shape, dtype=images.dtype)

Expand All @@ -302,13 +318,13 @@ def unshade(ins, outs, i, dark, flat):
# If one or fewer workers, don't user ThreadPool. Useful for debugging.
if self.max_workers <= 1:
for i in range(images.shape[-1]):
unshade(im_float, output, i, self.darkfield, self.flatfield)
unshade(im_float, output, i, self._darkfield, self._flatfield)

else:
with ThreadPoolExecutor(self.max_workers) as executor:
threads = executor.map(
lambda x: unshade(
im_float, output, x, self.darkfield, self.flatfield
im_float, output, x, self._darkfield, self._flatfield
),
range(images.shape[-1]),
)
Expand Down
2 changes: 1 addition & 1 deletion src/pybasic/tools/inexact_alm.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,4 @@ def inexact_alm_rspca_l1(
A_offset = np.squeeze(A_offset)
A_offset = A_offset + B1_offset * np.reshape(W_idct_hat, -1, order="F")

return A1_hat, E1_hat, A_offset
return A1_hat, E1_hat, A_offset, stopCriterion
24 changes: 24 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pybasic import BaSiC
import numpy as np
import pytest
from skimage.transform import resize


@pytest.fixture
Expand Down Expand Up @@ -69,15 +70,38 @@ def test_basic_predict(capsys, test_data):
"""Apply the shading model to the images"""
# flatfield only
basic.flatfield = gradient
basic._flatfield = gradient
corrected = basic.predict(images)
corrected_error = corrected.mean()
assert corrected_error < 0.5

# with darkfield correction
basic.darkfield = np.full(basic.flatfield.shape, 8)
basic._darkfield = np.full(basic.flatfield.shape, 8)
corrected = basic.predict(images)
assert corrected.mean() < corrected_error

"""Test shortcut"""
corrected = basic(images)
assert corrected.mean() < corrected_error


def test_basic_predict_resize(capsys, test_data):

basic = BaSiC(get_darkfield=False)
gradient, images, truth = test_data

images = resize(images, tuple(d * 2 for d in images.shape[:2]))
truth = resize(truth, tuple(d * 2 for d in truth.shape[:2]))

"""Apply the shading model to the images"""
# flatfield only
basic.flatfield = gradient
corrected = basic.predict(images)
corrected_error = corrected.mean()
assert corrected_error < 0.5

# with darkfield correction
basic.darkfield = np.full(basic.flatfield.shape, 8)
corrected = basic.predict(images)
assert corrected.mean() == corrected_error