Skip to content

Commit

Permalink
add unit test to make sure the mask is binary
Browse files Browse the repository at this point in the history
  • Loading branch information
willyfh committed Feb 11, 2024
1 parent d2cbcf3 commit 8494fcc
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tests/unit/data/image/test_mvtec_loco.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,10 @@ def datamodule(self, dataset_path: Path, task_type: TaskType) -> MVTecLoco:
_datamodule.setup()

return _datamodule

def test_mask_is_binary(self, datamodule: MVTecLoco) -> None:
"""Test if the mask tensor is binary."""
if datamodule.test_data.task in (TaskType.DETECTION, TaskType.SEGMENTATION):
mask_tensor = datamodule.test_data[0]["mask"]
is_binary = (mask_tensor.eq(0) | mask_tensor.eq(1)).all()
assert is_binary.item() is True

0 comments on commit 8494fcc

Please sign in to comment.