Skip to content

Commit

Permalink
Fix PRO metric calculation on GPU (#1317)
Browse files Browse the repository at this point in the history
* Fix shapes and types in pro metric

* Adjust pro metric tests to fit real scenario

* Fix formatting in test

* Refactor target reshaping

Co-authored-by: Samet Akcay <[email protected]>

---------

Co-authored-by: Samet Akcay <[email protected]>
  • Loading branch information
blaz-r and samet-akcay authored Sep 5, 2023
1 parent 9ec7162 commit a2fbc11
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
7 changes: 5 additions & 2 deletions src/anomalib/utils/metrics/pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ def compute(self) -> Tensor:
target = dim_zero_cat(self.target)
preds = dim_zero_cat(self.preds)

target = target.unsqueeze(1).type(torch.float) # kornia expects N1HW and FloatTensor format
if target.is_cuda:
comps = connected_components_gpu(target.unsqueeze(1))
comps = connected_components_gpu(target)
else:
comps = connected_components_cpu(target.unsqueeze(1))
comps = connected_components_cpu(target)
pro = pro_score(preds, comps, threshold=self.threshold)
return pro

Expand All @@ -63,6 +64,8 @@ def pro_score(predictions: Tensor, comps: Tensor, threshold: float = 0.5) -> Ten
n_comps = len(comps.unique())

preds = comps.clone()
# match the shapes in case one of the tensors is N1HW
preds = preds.reshape(predictions.shape)
preds[~predictions] = 0
if n_comps == 1: # only background
return torch.Tensor([1.0])
Expand Down
9 changes: 7 additions & 2 deletions tests/pre_merge/utils/metrics/test_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ def test_pro():
]
]
)
# ground truth mask is int type
labels = labels.type(torch.int32)

preds = (torch.arange(10) / 10) + 0.05
preds = preds.unsqueeze(1).repeat(1, 5).view(1, 1, 10, 5)
# metrics receive squeezed predictions (N, H, W)
preds = preds.unsqueeze(1).repeat(1, 5).view(1, 10, 5)

thresholds = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
targets = [1.0, 0.8, 0.6, 0.4, 0.2, 0.0]
Expand All @@ -49,8 +52,10 @@ def test_device_consistency():
batch = torch.zeros((32, 256, 256))
for i in range(batch.shape[0]):
batch[i, ...] = random_2d_perlin((256, 256), (torch.tensor(4), torch.tensor(4))) > 0.5
# ground truth mask is int type
batch = batch.type(torch.int32)

preds = transform(batch).unsqueeze(1)
preds = transform(batch)

pro_cpu = PRO()
pro_gpu = PRO()
Expand Down

0 comments on commit a2fbc11

Please sign in to comment.