From 8494fcceb51196223edcff9a743c6626f144aa75 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 11 Feb 2024 13:33:08 +0900 Subject: [PATCH] add unit test to make sure the `mask` is binary --- tests/unit/data/image/test_mvtec_loco.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/unit/data/image/test_mvtec_loco.py b/tests/unit/data/image/test_mvtec_loco.py index f58355c8e4..a322fd6672 100644 --- a/tests/unit/data/image/test_mvtec_loco.py +++ b/tests/unit/data/image/test_mvtec_loco.py @@ -30,3 +30,10 @@ def datamodule(self, dataset_path: Path, task_type: TaskType) -> MVTecLoco: _datamodule.setup() return _datamodule + + def test_mask_is_binary(self, datamodule: MVTecLoco) -> None: + """Test if the mask tensor is binary.""" + if datamodule.test_data.task in (TaskType.DETECTION, TaskType.SEGMENTATION): + mask_tensor = datamodule.test_data[0]["mask"] + is_binary = (mask_tensor.eq(0) | mask_tensor.eq(1)).all() + assert is_binary.item() is True