diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/transform.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/transform.py index baf09b4ed..82f6c05b6 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/transform.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/transform.py @@ -227,7 +227,8 @@ def semantic_segmentation_transformer( y = np.array(y) out = apply_transform(transform, image=x, mask=y) x, y = out['image'], out['mask'] - y = y.astype(int) + if y is not None: + y = y.astype(int) return x, y diff --git a/tests/pytorch_learner/dataset/test_transform.py b/tests/pytorch_learner/dataset/test_transform.py index 653b9ed42..73c3eb9c8 100644 --- a/tests/pytorch_learner/dataset/test_transform.py +++ b/tests/pytorch_learner/dataset/test_transform.py @@ -4,7 +4,8 @@ import albumentations as A from rastervision.pytorch_learner.dataset.transform import ( - yxyx_to_albu, albu_to_yxyx, xywh_to_albu, apply_transform) + yxyx_to_albu, albu_to_yxyx, xywh_to_albu, apply_transform, + semantic_segmentation_transformer) class TestTransforms(unittest.TestCase): @@ -65,6 +66,29 @@ def test_box_format_conversions_xywh(self): boxes_albu = xywh_to_albu(boxes, (10, 10)) np.testing.assert_allclose(boxes_albu, boxes_albu_gt) + def test_semantic_segmentation_transformer(self): + # w/ y, w/o transform + x_in, y_in = np.zeros((10, 10, 3), dtype=np.uint8), np.zeros((10, 10)) + x_out, y_out = semantic_segmentation_transformer((x_in, y_in), None) + np.issubdtype(y_out.dtype, int) + + # w/ y, w/ transform + x_out, y_out = semantic_segmentation_transformer((x_in, y_in), + A.Resize(20, 20)) + self.assertEqual(x_out.shape, (20, 20, 3)) + self.assertEqual(y_out.shape, (20, 20)) + np.issubdtype(y_out.dtype, int) + + # w/o y, w/o transform + x_out, y_out = semantic_segmentation_transformer((x_in, None), None) + self.assertIsNone(y_out) + + # w/o y, w/ transform + x_out, y_out = semantic_segmentation_transformer((x_in, None), + A.Resize(20, 20)) + self.assertEqual(x_out.shape, (20, 20, 3)) + self.assertIsNone(y_out) + if __name__ == '__main__': unittest.main()