fix tests

jgrss committed May 16, 2024
1 parent 76be1e8 commit db626a6

Showing 5 changed files with 44 additions and 41 deletions.
10 changes: 5 additions & 5 deletions src/cultionet/losses/losses.py
@@ -238,7 +238,6 @@ def tanimoto_distance(
         scale = 1.0 / self.depth
 
         if mask is not None:
-            mask = einops.rearrange(mask, 'b h w -> b 1 h w')
             y = y * mask
             yhat = yhat * mask
 
@@ -252,16 +251,18 @@ def tanimoto_distance(
         tpl = tpl * weights
         sq_sum = sq_sum * weights
 
-        numerator = tpl + self.smooth
         denominator = 0.0
         for d in range(0, self.depth):
            a = 2.0**d
            b = -(2.0 * a - 1.0)
            denominator = denominator + torch.reciprocal(
-                (a * sq_sum) + (b * tpl) + self.smooth
+                (a * sq_sum) + (b * tpl)
            )
+        denominator = torch.nan_to_num(
+            denominator, nan=0.0, posinf=0.0, neginf=0.0
+        )
 
-        return ((numerator * denominator) * scale).sum(dim=1)
+        return ((tpl * denominator) * scale).sum(dim=1)
 
     def forward(
         self,
@@ -327,7 +328,6 @@ def tanimoto_dist(
 
         # Apply a mask to zero-out gradients where mask == 0
         if mask is not None:
-            mask = einops.rearrange(mask, 'b h w -> b 1 h w')
             ytrue = ytrue * mask
             ypred = ypred * mask
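For context on the losses.py change: the `smooth` offset is dropped from both the numerator and the per-depth reciprocals, and divisions that now blow up (for example, an all-zero class under the mask) are zeroed with `torch.nan_to_num` instead of being smoothed away. A minimal standalone sketch of the updated term, assuming `tpl` (the true-positive-like product sum) and `sq_sum` (the sum of squares) have already been reduced to shape `(batch, classes)`; everything outside the diff is illustrative:

```python
import torch


def tanimoto_complement_term(
    tpl: torch.Tensor, sq_sum: torch.Tensor, depth: int = 5
) -> torch.Tensor:
    # Depth-weighted sum of reciprocals:
    #   sum_d 1 / (2^d * sq_sum - (2^(d+1) - 1) * tpl)
    denominator = 0.0
    for d in range(depth):
        a = 2.0**d
        b = -(2.0 * a - 1.0)
        denominator = denominator + torch.reciprocal(a * sq_sum + b * tpl)
    # With the smooth offset gone, empty classes produce inf/nan reciprocals;
    # the commit zeroes them out rather than smoothing them.
    denominator = torch.nan_to_num(denominator, nan=0.0, posinf=0.0, neginf=0.0)
    scale = 1.0 / depth
    return ((tpl * denominator) * scale).sum(dim=1)
```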
2 changes: 2 additions & 0 deletions src/cultionet/models/lightning.py
@@ -3,6 +3,7 @@
 import warnings
 from pathlib import Path
 
+import einops
 import pandas as pd
 import torch
 import torch.nn as nn
@@ -527,6 +528,7 @@ def get_true_labels(
         mask = torch.where(batch.y == -1, 0, 1).to(
             dtype=torch.uint8, device=batch.y.device
         )
+        mask = einops.rearrange(mask, 'b h w -> b 1 h w')
 
         return {
             "true_edge": true_edge,
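The channel-broadcast step thus moves out of the individual loss functions and into `get_true_labels`, so the mask reaches the losses already shaped `(b, 1, h, w)`. A tiny illustration with made-up values:

```python
import einops
import torch

y = torch.tensor([[[0, -1], [1, 2]]])  # (b=1, h=2, w=2) labels; -1 marks ignored pixels
mask = torch.where(y == -1, 0, 1).to(dtype=torch.uint8)
mask = einops.rearrange(mask, 'b h w -> b 1 h w')  # equivalent to mask.unsqueeze(1)
print(mask.shape)  # torch.Size([1, 1, 2, 2]); broadcasts over (b, c, h, w) predictions
```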
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -24,8 +24,8 @@ def create_batch(
     width: int = 20,
     rng: Optional[np.random.Generator] = None,
 ) -> Data:
-    x = torch.randn(1, num_channels, num_time, height, width)
-    y = torch.randint(low=0, high=3, size=(1, height, width))
+    x = torch.rand(1, num_channels, num_time, height, width)
+    y = torch.randint(low=-1, high=3, size=(1, height, width))
     bdist = torch.rand(1, height, width)
 
     if rng is None:
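A note on the swap (my reading; the commit message doesn't say): `torch.randn` draws from a standard normal, while `torch.rand` draws uniformly from `[0, 1)`, so the synthetic imagery now stays in `[0, 1]`, and `low=-1` injects the ignore label that the new mask logic keys on. A quick check with illustrative shapes:

```python
import torch

x = torch.rand(1, 3, 12, 20, 20)  # uniform in [0, 1)
y = torch.randint(low=-1, high=3, size=(1, 20, 20))  # labels in {-1, 0, 1, 2}

assert x.min() >= 0 and x.max() <= 1  # the new test_augmentation.py range checks
assert (y == -1).any()  # near-certain on a 20x20 grid; these pixels get masked out
```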
8 changes: 8 additions & 0 deletions tests/test_augmentation.py
@@ -68,6 +68,10 @@ def test_augmenter_loading():
         width=50,
     )
 
+    assert batch.x.min() >= 0
+    assert batch.x.max() <= 1
+    assert batch.y.min() == -1
+
     batch.segments = np.uint8(nd_label(batch.y.squeeze().numpy() == 1)[0])
     batch.props = regionprops(batch.segments)
     aug_batch = method(batch.copy(), aug_args=aug.aug_args)
@@ -93,6 +97,10 @@ def test_augmenter_loading():
         width=50,
     )
 
+    assert batch.x.min() >= 0
+    assert batch.x.max() <= 1
+    assert batch.y.min() == -1
+
     aug_batch = method(batch.copy(), aug_args=aug.aug_args)
 
     if method.name_ == 'rotate-90':
61 changes: 27 additions & 34 deletions tests/test_loss.py
@@ -44,6 +44,9 @@
 DIST_TARGETS = torch.from_numpy(
     rng.random((BATCH_SIZE, HEIGHT, WIDTH))
 ).float()
+MASK = torch.from_numpy(
+    rng.integers(low=0, high=2, size=(BATCH_SIZE, 1, HEIGHT, WIDTH))
+).long()
 
 
 def test_loss_preprocessing():
@@ -53,18 +56,14 @@ def test_loss_preprocessing():
     )
     inputs, targets = preprocessor(INPUTS_CROP_LOGIT, DISCRETE_TARGETS)
 
-    assert inputs.shape == (BATCH_SIZE * HEIGHT * WIDTH, 2)
-    assert targets.shape == (BATCH_SIZE * HEIGHT * WIDTH, 2)
-    assert torch.allclose(targets.max(dim=0).values, torch.ones(2))
+    assert inputs.shape == (BATCH_SIZE, 2, HEIGHT, WIDTH)
+    assert targets.shape == (BATCH_SIZE, 2, HEIGHT, WIDTH)
     assert torch.allclose(
-        inputs.sum(dim=1), torch.ones(BATCH_SIZE * HEIGHT * WIDTH), rtol=0.1
+        inputs.sum(dim=1), torch.ones(BATCH_SIZE, HEIGHT, WIDTH), rtol=0.1
     )
     assert torch.allclose(
         inputs,
-        rearrange(
-            F.softmax(INPUTS_CROP_LOGIT, dim=1, dtype=INPUTS_CROP_LOGIT.dtype),
-            'b c h w -> (b h w) c',
-        ),
+        F.softmax(INPUTS_CROP_LOGIT, dim=1, dtype=INPUTS_CROP_LOGIT.dtype),
     )
 
     # Input probabilities
@@ -73,31 +72,26 @@ def test_loss_preprocessing():
     )
     inputs, targets = preprocessor(INPUTS_CROP_PROB, DISCRETE_TARGETS)
 
-    assert inputs.shape == (BATCH_SIZE * HEIGHT * WIDTH, 2)
-    assert targets.shape == (BATCH_SIZE * HEIGHT * WIDTH, 2)
-    assert torch.allclose(targets.max(dim=0).values, torch.ones(2))
+    assert inputs.shape == (BATCH_SIZE, 2, HEIGHT, WIDTH)
+    assert targets.shape == (BATCH_SIZE, 2, HEIGHT, WIDTH)
     assert torch.allclose(
-        inputs.sum(dim=1), torch.ones(BATCH_SIZE * HEIGHT * WIDTH), rtol=0.1
+        inputs.sum(dim=1), torch.ones(BATCH_SIZE, HEIGHT, WIDTH), rtol=0.1
    )
     assert torch.allclose(
         inputs,
-        rearrange(INPUTS_CROP_PROB, 'b c h w -> (b h w) c'),
+        INPUTS_CROP_PROB,
     )
 
     preprocessor = LossPreprocessing(
         transform_logits=False, one_hot_targets=True
     )
-    # This should fail because there are more class targets than the input dimensions
     with pytest.raises(ValueError):
-        inputs, targets = preprocessor(INPUTS_EDGE_PROB, DISCRETE_TARGETS)
+        inputs, targets = preprocessor(INPUTS_EDGE_PROB, DISCRETE_EDGE_TARGETS)
 
-    assert inputs.shape == (BATCH_SIZE * HEIGHT * WIDTH, 1)
-    assert targets.shape == (BATCH_SIZE * HEIGHT * WIDTH, 1)
-    assert torch.allclose(targets.max(dim=0).values, torch.ones(1))
+    assert inputs.shape == (BATCH_SIZE, 1, HEIGHT, WIDTH)
+    assert targets.shape == (BATCH_SIZE, 1, HEIGHT, WIDTH)
     assert torch.allclose(
         inputs,
-        rearrange(INPUTS_EDGE_PROB, 'b c h w -> (b h w) c'),
+        INPUTS_EDGE_PROB,
     )
 
     # Regression
@@ -107,32 +101,31 @@ def test_loss_preprocessing():
     inputs, targets = preprocessor(INPUTS_DIST, DIST_TARGETS)
 
     # Preprocessing should not change the inputs other than the shape
-    assert torch.allclose(
-        inputs, rearrange(INPUTS_DIST, 'b c h w -> (b h w) c')
-    )
-    assert torch.allclose(
-        targets, rearrange(DIST_TARGETS, 'b h w -> (b h w) 1')
-    )
+    assert torch.allclose(inputs, INPUTS_DIST)
+    assert torch.allclose(targets, rearrange(DIST_TARGETS, 'b h w -> b 1 h w'))
 
 
 def test_tanimoto_classification_loss():
-    loss_func = TanimotoDistLoss(
-        scale_pos_weight=False,
-        transform_logits=False,
-        one_hot_targets=True,
-    )
+    loss_func = TanimotoDistLoss()
 
     loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS)
-    assert round(float(loss.item()), 3) == 0.611
+    assert round(float(loss.item()), 3) == 0.61
+
+    loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS, mask=MASK)
+    assert round(float(loss.item()), 3) == 0.608
 
     loss_func = TanimotoComplementLoss()
     loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS)
-    assert round(float(loss.item()), 3) == 0.824
+    assert round(float(loss.item()), 3) == 0.649
+
+    loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS, mask=MASK)
+    assert round(float(loss.item()), 3) == 0.647
 
 
 def test_tanimoto_regression_loss():
     loss_func = TanimotoDistLoss(one_hot_targets=False)
     loss = loss_func(INPUTS_DIST, DIST_TARGETS)
-    assert round(float(loss.item()), 4) == 0.4174
+    assert round(float(loss.item()), 3) == 0.417
 
     loss_func = TanimotoComplementLoss(one_hot_targets=False)
     loss = loss_func(INPUTS_DIST, DIST_TARGETS)
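Finally, a sketch of how the new `mask` argument is exercised. The import path and constructor defaults here are assumptions based on the tests above, and the exact asserted values depend on the seeded rng, so they are omitted:

```python
import torch

# Assumed import path; adjust to wherever the losses are exported.
from cultionet.losses import TanimotoComplementLoss, TanimotoDistLoss

batch_size, height, width = 2, 20, 20
inputs = torch.rand(batch_size, 2, height, width)  # class probabilities (b, c, h, w)
targets = torch.randint(low=0, high=2, size=(batch_size, height, width))
mask = torch.randint(low=0, high=2, size=(batch_size, 1, height, width)).long()

# Pixels where mask == 0 are zeroed inside the loss and contribute nothing.
print(TanimotoDistLoss()(inputs, targets, mask=mask))
print(TanimotoComplementLoss()(inputs, targets, mask=mask))
```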
