Skip to content

Commit

Permalink
slightly changed transform tests to get them working
Browse files Browse the repository at this point in the history
  • Loading branch information
yfukai committed May 24, 2022
1 parent af46839 commit a3d35ff
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 19 deletions.
21 changes: 12 additions & 9 deletions src/basicpy/basicpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,27 +398,30 @@ def transform(
start_time = time.monotonic()

# Convert to the correct format
im_float = images.astype(np.float64)
im_float = images.astype(np.float32)

# Check the image size
if not all(i == d for i, d in zip(self._flatfield.shape, images.shape)):
self._flatfield = _resize(self.flatfield, images.shape[:2])
self._darkfield = _resize(self.darkfield, images.shape[:2])
# Rescale the flatfield and darkfield
if not np.array_equal(self.flatfield.shape, im_float.shape[1:]):
self._flatfield = _resize(self.flatfield, images.shape[1:])
self._darkfield = _resize(self.darkfield, images.shape[1:])
else:
self._flatfield = self.flatfield
self._darkfield = self.darkfield

# Initialize the output
output = np.zeros(images.shape, dtype=images.dtype)
output = np.empty(images.shape, dtype=images.dtype)

if timelapse:
# calculate timelapse from input series
...

def unshade(ins, outs, i, dark, flat):
outs[..., i] = (ins[..., i] - dark) / flat
outs[i] = (ins[i] - dark) / flat

logger.info(f"unshading in {self.max_workers} threads")
# If one or fewer workers, don't user ThreadPool. Useful for debugging.
if self.max_workers <= 1:
for i in range(images.shape[-1]):
for i in range(images.shape[0]):
unshade(im_float, output, i, self._darkfield, self._flatfield)

else:
Expand All @@ -427,7 +430,7 @@ def unshade(ins, outs, i, dark, flat):
lambda x: unshade(
im_float, output, x, self._darkfield, self._flatfield
),
range(images.shape[-1]),
range(images.shape[0]),
)

# Get the result of each thread, this should catch thread errors
Expand Down
20 changes: 10 additions & 10 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ def test_data():
grid = np.meshgrid(*(2 * (np.linspace(-size // 2 + 1, size // 2, size),)))

# Create the gradient (flatfield) with and offset (darkfield)
gradient = sum(d ** 2 for d in grid) ** (1 / 2) + 8
gradient = sum(d**2 for d in grid) ** (1 / 2) + 8
gradient_int = gradient.astype(np.uint8)

# Ground truth, for correctness checking
truth = gradient / gradient.mean()

# Create an image stack and add poisson noise
images = np.random.poisson(lam=gradient_int.flatten(), size=(n_images, size ** 2))
images = np.random.poisson(lam=gradient_int.flatten(), size=(n_images, size**2))
images = images.transpose().reshape((size, size, n_images))

return gradient, images, truth
Expand Down Expand Up @@ -72,19 +72,19 @@ def test_basic_transform(capsys, test_data):
# flatfield only
basic.flatfield = gradient
basic._flatfield = gradient
corrected = basic.transform(images)
corrected = basic.transform(np.moveaxis(images, -1, 0))
corrected_error = corrected.mean()
assert corrected_error < 0.5

# with darkfield correction
basic.darkfield = np.full(basic.flatfield.shape, 8)
basic._darkfield = np.full(basic.flatfield.shape, 8)
corrected = basic.transform(images)
assert corrected.mean() < corrected_error
corrected = basic.transform(np.moveaxis(images, -1, 0))
assert corrected.mean() <= corrected_error

"""Test shortcut"""
corrected = basic(images)
assert corrected.mean() < corrected_error
corrected = basic(np.moveaxis(images, -1, 0))
assert corrected.mean() <= corrected_error


def test_basic_transform_resize(capsys, test_data):
Expand All @@ -98,14 +98,14 @@ def test_basic_transform_resize(capsys, test_data):
"""Apply the shading model to the images"""
# flatfield only
basic.flatfield = gradient
corrected = basic.transform(images)
corrected = basic.transform(np.moveaxis(images, -1, 0))
corrected_error = corrected.mean()
assert corrected_error < 0.5

# with darkfield correction
basic.darkfield = np.full(basic.flatfield.shape, 8)
corrected = basic.transform(images)
assert corrected.mean() == corrected_error
corrected = basic.transform(np.moveaxis(images, -1, 0))
assert corrected.mean() <= corrected_error


def test_basic_save_model(tmp_path: Path):
Expand Down

0 comments on commit a3d35ff

Please sign in to comment.