diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index f06fdb0e6df..72ed45492ba 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -144,17 +144,24 @@ def to_numpy_array(img) -> np.ndarray: return to_numpy(img) -def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension: +def infer_channel_dimension_format( + image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None +) -> ChannelDimension: """ Infers the channel dimension format of `image`. Args: image (`np.ndarray`): The image to infer the channel dimension of. + num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`): + The number of channels of the image. Returns: The channel dimension of the image. """ + num_channels = num_channels if num_channels is not None else (1, 3) + num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels + if image.ndim == 3: first_dim, last_dim = 0, 2 elif image.ndim == 4: @@ -162,9 +169,9 @@ def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension: else: raise ValueError(f"Unsupported number of image dimensions: {image.ndim}") - if image.shape[first_dim] in (1, 3): + if image.shape[first_dim] in num_channels: return ChannelDimension.FIRST - elif image.shape[last_dim] in (1, 3): + elif image.shape[last_dim] in num_channels: return ChannelDimension.LAST raise ValueError("Unable to infer channel dimension format") diff --git a/tests/utils/test_image_utils.py b/tests/utils/test_image_utils.py index 6d6c0f5d9b0..f62f647c42d 100644 --- a/tests/utils/test_image_utils.py +++ b/tests/utils/test_image_utils.py @@ -578,6 +578,10 @@ def test_infer_channel_dimension(self): with pytest.raises(ValueError): infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50))) + # But if we explicitly set one of the number of channels to 50 it works + inferred_dim = infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)), num_channels=50) + self.assertEqual(inferred_dim, ChannelDimension.LAST) + # Test we correctly identify the channel dimension image = np.random.randint(0, 256, (3, 4, 5)) inferred_dim = infer_channel_dimension_format(image)