Skip to content

Commit

Permalink
Enable passing number of channels when inferring data format (hugging…
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts authored and EduardoPach committed Aug 9, 2023
1 parent 018061a commit 6222207
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,27 +144,34 @@ def to_numpy_array(img) -> np.ndarray:
return to_numpy(img)


def infer_channel_dimension_format(
    image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    The first candidate axis is checked before the last one, so when both axes have an accepted size the result is
    `ChannelDimension.FIRST`.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of. Must have 3 or 4 dimensions; for 4-dimensional input the
            leading axis is skipped (presumably a batch axis) and axes 1 and 3 are the candidates.
        num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The number of channels of the image. A single integer (Python `int` or NumPy integer scalar) is treated
            as a one-element tuple.

    Returns:
        The channel dimension of the image.

    Raises:
        ValueError: If `image` does not have 3 or 4 dimensions, or if neither candidate axis has a size listed in
            `num_channels`.
    """
    if num_channels is None:
        num_channels = (1, 3)
    elif isinstance(num_channels, (int, np.integer)):
        # Wrap scalars (including NumPy integers such as `np.int64`) so the membership tests below work.
        num_channels = (num_channels,)

    if image.ndim == 3:
        first_dim, last_dim = 0, 2
    elif image.ndim == 4:
        first_dim, last_dim = 1, 3
    else:
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")

    # Channels-first takes precedence when both candidate axes match.
    if image.shape[first_dim] in num_channels:
        return ChannelDimension.FIRST
    elif image.shape[last_dim] in num_channels:
        return ChannelDimension.LAST
    raise ValueError("Unable to infer channel dimension format")

Expand Down
4 changes: 4 additions & 0 deletions tests/utils/test_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,10 @@ def test_infer_channel_dimension(self):
with pytest.raises(ValueError):
infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))

# But if we explicitly pass 50 as an accepted number of channels, inference succeeds
inferred_dim = infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)), num_channels=50)
self.assertEqual(inferred_dim, ChannelDimension.LAST)

# Test we correctly identify the channel dimension
image = np.random.randint(0, 256, (3, 4, 5))
inferred_dim = infer_channel_dimension_format(image)
Expand Down

0 comments on commit 6222207

Please sign in to comment.