Skip to content

Commit

Permalink
Added caching to predict method (#35)
Browse files Browse the repository at this point in the history
* Added benchmark library asv

* Revert "Delete CONTRIBUTING.rst"

This reverts commit 01d8fe1.

* Revert "Added benchmark library asv"

This reverts commit d067f04.

* Deleted CONTRIBUTING.rst

* Set score values to object attributes

* Added caching for resizing of flatfield/darkfield
Nicholas-Schaub authored Jan 20, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent abb7124 commit cfad35a
Showing 2 changed files with 41 additions and 2 deletions.
19 changes: 17 additions & 2 deletions src/pybasic/pybasic.py
Original file line number Diff line number Diff line change
@@ -125,6 +125,8 @@ class BaSiC(BaseModel):
# Private attributes for internal processing
_score: float = PrivateAttr(None)
_reweight_score: float = PrivateAttr(None)
_flatfield: np.ndarray = PrivateAttr(None)
_darkfield: np.ndarray = PrivateAttr(None)
_alm_settings = {
"lambda_darkfield",
"lambda_flatfield",
@@ -146,6 +148,10 @@ def __init__(self, **kwargs) -> None:
self.darkfield = np.zeros((self.working_size,) * 2, dtype=np.float64)
self.flatfield = np.zeros((self.working_size,) * 2, dtype=np.float64)

# Initialize the internal cache
self._darkfield = np.zeros((self.working_size,) * 2, dtype=np.float64)
self._flatfield = np.zeros((self.working_size,) * 2, dtype=np.float64)

if self.device is not Device.cpu:
# TODO: sanity checks on device selection
pass
@@ -264,6 +270,9 @@ def fit(self, images: np.ndarray) -> None:
if self.get_darkfield:
self.darkfield = X_A_offset

self._darkfield = self.darkfield
self._flatfield = self.flatfield

def predict(
self, images: np.ndarray, timelapse: bool = False
) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]:
@@ -288,8 +297,14 @@ def predict(
... imsave(f"image_{i}.tif")
"""

# Convert to the correct format
im_float = images.astype(np.float64)

# Check the image size
if not all(i == d for i, d in zip(self._flatfield.shape, images.shape)):
self._flatfield = resize(self.flatfield, images.shape[:2])
self._darkfield = resize(self.darkfield, images.shape[:2])

# Initialize the output
output = np.zeros(images.shape, dtype=images.dtype)

@@ -303,13 +318,13 @@ def unshade(ins, outs, i, dark, flat):
# If one or fewer workers, don't user ThreadPool. Useful for debugging.
if self.max_workers <= 1:
for i in range(images.shape[-1]):
unshade(im_float, output, i, self.darkfield, self.flatfield)
unshade(im_float, output, i, self._darkfield, self._flatfield)

else:
with ThreadPoolExecutor(self.max_workers) as executor:
threads = executor.map(
lambda x: unshade(
im_float, output, x, self.darkfield, self.flatfield
im_float, output, x, self._darkfield, self._flatfield
),
range(images.shape[-1]),
)
24 changes: 24 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pybasic import BaSiC
import numpy as np
import pytest
from skimage.transform import resize


@pytest.fixture
@@ -69,15 +70,38 @@ def test_basic_predict(capsys, test_data):
"""Apply the shading model to the images"""
# flatfield only
basic.flatfield = gradient
basic._flatfield = gradient
corrected = basic.predict(images)
corrected_error = corrected.mean()
assert corrected_error < 0.5

# with darkfield correction
basic.darkfield = np.full(basic.flatfield.shape, 8)
basic._darkfield = np.full(basic.flatfield.shape, 8)
corrected = basic.predict(images)
assert corrected.mean() < corrected_error

"""Test shortcut"""
corrected = basic(images)
assert corrected.mean() < corrected_error


def test_basic_predict_resize(capsys, test_data):

basic = BaSiC(get_darkfield=False)
gradient, images, truth = test_data

images = resize(images, tuple(d * 2 for d in images.shape[:2]))
truth = resize(truth, tuple(d * 2 for d in truth.shape[:2]))

"""Apply the shading model to the images"""
# flatfield only
basic.flatfield = gradient
corrected = basic.predict(images)
corrected_error = corrected.mean()
assert corrected_error < 0.5

# with darkfield correction
basic.darkfield = np.full(basic.flatfield.shape, 8)
corrected = basic.predict(images)
assert corrected.mean() == corrected_error

0 comments on commit cfad35a

Please sign in to comment.