diff --git a/src/pybasic/pybasic.py b/src/pybasic/pybasic.py index b485ac4b..031659eb 100644 --- a/src/pybasic/pybasic.py +++ b/src/pybasic/pybasic.py @@ -125,6 +125,8 @@ class BaSiC(BaseModel): # Private attributes for internal processing _score: float = PrivateAttr(None) _reweight_score: float = PrivateAttr(None) + _flatfield: np.ndarray = PrivateAttr(None) + _darkfield: np.ndarray = PrivateAttr(None) _alm_settings = { "lambda_darkfield", "lambda_flatfield", @@ -146,6 +148,10 @@ def __init__(self, **kwargs) -> None: self.darkfield = np.zeros((self.working_size,) * 2, dtype=np.float64) self.flatfield = np.zeros((self.working_size,) * 2, dtype=np.float64) + # Initialize the internal cache + self._darkfield = np.zeros((self.working_size,) * 2, dtype=np.float64) + self._flatfield = np.zeros((self.working_size,) * 2, dtype=np.float64) + if self.device is not Device.cpu: # TODO: sanity checks on device selection pass @@ -218,7 +224,7 @@ def fit(self, images: np.ndarray) -> None: # # TODO: implement inexact_alm_rspca_l1_intflat? # raise IOError("Initial flatfield option not implemented yet!") # else: - X_k_A, X_k_E, X_k_A_offset = inexact_alm_rspca_l1( + X_k_A, X_k_E, X_k_A_offset, self._score = inexact_alm_rspca_l1( D, weight=weight, **self.dict(include=self._alm_settings) ) @@ -251,6 +257,7 @@ def fit(self, images: np.ndarray) -> None: ) flatfield_last = flatfield_current darkfield_last = darkfield_current + self._reweight_score = np.maximum(mad_flatfield, mad_darkfield) if ( np.maximum(mad_flatfield, mad_darkfield) <= self.reweighting_tol or reweighting_iter >= self.max_reweight_iterations @@ -263,6 +270,9 @@ def fit(self, images: np.ndarray) -> None: if self.get_darkfield: self.darkfield = X_A_offset + self._darkfield = self.darkfield + self._flatfield = self.flatfield + def predict( self, images: np.ndarray, timelapse: bool = False ) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]: @@ -287,8 +297,14 @@ def predict( ... imsave(f"image_{i}.tif") """ + # Convert to the correct format im_float = images.astype(np.float64) + # Check the image size + if not all(i == d for i, d in zip(self._flatfield.shape, images.shape)): + self._flatfield = resize(self.flatfield, images.shape[:2]) + self._darkfield = resize(self.darkfield, images.shape[:2]) + # Initialize the output output = np.zeros(images.shape, dtype=images.dtype) @@ -302,13 +318,13 @@ def unshade(ins, outs, i, dark, flat): # If one or fewer workers, don't user ThreadPool. Useful for debugging. if self.max_workers <= 1: for i in range(images.shape[-1]): - unshade(im_float, output, i, self.darkfield, self.flatfield) + unshade(im_float, output, i, self._darkfield, self._flatfield) else: with ThreadPoolExecutor(self.max_workers) as executor: threads = executor.map( lambda x: unshade( - im_float, output, x, self.darkfield, self.flatfield + im_float, output, x, self._darkfield, self._flatfield ), range(images.shape[-1]), ) diff --git a/src/pybasic/tools/inexact_alm.py b/src/pybasic/tools/inexact_alm.py index d1c4d3b5..9004f98c 100644 --- a/src/pybasic/tools/inexact_alm.py +++ b/src/pybasic/tools/inexact_alm.py @@ -189,4 +189,4 @@ def inexact_alm_rspca_l1( A_offset = np.squeeze(A_offset) A_offset = A_offset + B1_offset * np.reshape(W_idct_hat, -1, order="F") - return A1_hat, E1_hat, A_offset + return A1_hat, E1_hat, A_offset, stopCriterion diff --git a/tests/test_basic.py b/tests/test_basic.py index 1419e1b0..2cacb989 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,6 +1,7 @@ from pybasic import BaSiC import numpy as np import pytest +from skimage.transform import resize @pytest.fixture @@ -69,15 +70,38 @@ def test_basic_predict(capsys, test_data): """Apply the shading model to the images""" # flatfield only basic.flatfield = gradient + basic._flatfield = gradient corrected = basic.predict(images) corrected_error = corrected.mean() assert corrected_error < 0.5 # with darkfield correction basic.darkfield = np.full(basic.flatfield.shape, 8) + basic._darkfield = np.full(basic.flatfield.shape, 8) corrected = basic.predict(images) assert corrected.mean() < corrected_error """Test shortcut""" corrected = basic(images) assert corrected.mean() < corrected_error + + +def test_basic_predict_resize(capsys, test_data): + + basic = BaSiC(get_darkfield=False) + gradient, images, truth = test_data + + images = resize(images, tuple(d * 2 for d in images.shape[:2])) + truth = resize(truth, tuple(d * 2 for d in truth.shape[:2])) + + """Apply the shading model to the images""" + # flatfield only + basic.flatfield = gradient + corrected = basic.predict(images) + corrected_error = corrected.mean() + assert corrected_error < 0.5 + + # with darkfield correction + basic.darkfield = np.full(basic.flatfield.shape, 8) + corrected = basic.predict(images) + assert corrected.mean() == corrected_error