diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index aaa9e4eadc6a..03f13ae82057 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -202,7 +202,12 @@ def infer_channel_dimension_format( else: raise ValueError(f"Unsupported number of image dimensions: {image.ndim}") - if image.shape[first_dim] in num_channels: + if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels: + logger.warning( + f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension." + ) + return ChannelDimension.FIRST + elif image.shape[first_dim] in num_channels: return ChannelDimension.FIRST elif image.shape[last_dim] in num_channels: return ChannelDimension.LAST