From 64ad5871aeccbb26b56dbac1a67f63c96bdca731 Mon Sep 17 00:00:00 2001 From: Siddharth Ancha Date: Sat, 28 Jan 2023 09:53:29 -0500 Subject: [PATCH] [Fix] Fix ignore class id from -1 to 255 in `master` (#2515) ## Motivation This fixes #2493. When the `label_map` is created, the index for ignored classes was being set to -1, whereas the index that is actually ignored is 255. This worked indirectly since -1 was underflowed to 255 when converting to uint8. The same fix was made in the 1.x by #2332 but this fix was never made to `master`. ## Modification The only small modification is setting the index of ignored classes to 255 instead of -1. ## Checklist - [x] Pre-commit or other linting tools are used to fix the potential lint issues. - _I've fixed all linting/pre-commit errors._ - [x] The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness. - _No unit tests need to be added. Unit tests that are affected were modified. - [x] If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. - _I don't think this change affects MMDet or MMDet3D._ - [x] The documentation has been modified accordingly, like docstring or example tutorials. - _This change fixes an existing bug and doesn't require modifying any documentation/docstring._ --- mmseg/datasets/custom.py | 4 ++-- tests/test_data/test_loading.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index 4615d4114e..0ee15b1aad 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -349,7 +349,7 @@ def get_classes_and_palette(self, classes=None, palette=None): self.label_map = {} for i, c in enumerate(self.CLASSES): if c not in class_names: - self.label_map[i] = -1 + self.label_map[i] = 255 else: self.label_map[i] = class_names.index(c) @@ -364,7 +364,7 @@ def get_palette_for_custom_classes(self, class_names, palette=None): palette = [] for old_id, new_id in sorted( self.label_map.items(), key=lambda x: x[1]): - if new_id != -1: + if new_id != 255: palette.append(self.PALETTE[old_id]) palette = type(self.PALETTE)(palette) diff --git a/tests/test_data/test_loading.py b/tests/test_data/test_loading.py index d41d460231..19f495accd 100644 --- a/tests/test_data/test_loading.py +++ b/tests/test_data/test_loading.py @@ -187,7 +187,7 @@ def test_load_seg_custom_classes(self): # classes=["A", "C", "D"] which removes class "B". label_map={ 0: 0, - 1: -1, # simulate removing class 1 + 1: 255, # simulate removing class 1 2: 1, 3: 2 }, @@ -204,7 +204,7 @@ def test_load_seg_custom_classes(self): true_mask = np.ones_like(gt_array) * 255 # all zeros get mapped to 255 true_mask[2:4, 2:4] = 0 # 1s are reduced to class 0 mapped to class 0 - true_mask[2:4, 6:8] = -1 # 2s are reduced to class 1 which is removed + true_mask[2:4, 6:8] = 255 # 2s are reduced to class 1 which is removed true_mask[6:8, 2:4] = 1 # 3s are reduced to class 2 mapped to class 1 true_mask[6:8, 6:8] = 2 # 4s are reduced to class 3 mapped to class 2