diff --git a/test/test_image.py b/test/test_image.py
index b3ab0b2364a..cdeadf7a0a0 100644
--- a/test/test_image.py
+++ b/test/test_image.py
@@ -2,14 +2,12 @@
 import io
 import glob
 import unittest
-import sys
 
 import torch
-import torchvision
 from PIL import Image
 from torchvision.io.image import (
     decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
-    encode_png, write_png, write_file)
+    encode_png, write_png, write_file, ImageReadMode)
 import numpy as np
 
 from common_utils import get_tmp_dir
@@ -49,9 +47,9 @@ def normalize_dimensions(img_pil):
 
 class ImageTester(unittest.TestCase):
     def test_decode_jpeg(self):
-        conversion = [(None, 0), ("L", 1), ("RGB", 3)]
+        conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)]
         for img_path in get_images(IMAGE_ROOT, ".jpg"):
-            for pil_mode, channels in conversion:
+            for pil_mode, mode in conversion:
                 with Image.open(img_path) as img:
                     is_cmyk = img.mode == "CMYK"
                     if pil_mode is not None:
@@ -66,7 +64,7 @@ def test_decode_jpeg(self):
                 img_pil = normalize_dimensions(img_pil)
 
                 data = read_file(img_path)
-                img_ljpeg = decode_image(data, channels=channels)
+                img_ljpeg = decode_image(data, mode=mode)
 
                 # Permit a small variation on pixel values to account for implementation
                 # differences between Pillow and LibJPEG.
@@ -165,9 +163,10 @@ def test_write_jpeg(self):
                 self.assertEqual(torch_bytes, pil_bytes)
 
     def test_decode_png(self):
-        conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)]
+        conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
+                      ("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
         for img_path in get_images(FAKEDATA_DIR, ".png"):
-            for pil_mode, channels in conversion:
+            for pil_mode, mode in conversion:
                 with Image.open(img_path) as img:
                     if pil_mode is not None:
                         img = img.convert(pil_mode)
@@ -175,7 +174,7 @@ def test_decode_png(self):
                 img_pil = normalize_dimensions(img_pil)
 
                 data = read_file(img_path)
-                img_lpng = decode_image(data, channels=channels)
+                img_lpng = decode_image(data, mode=mode)
 
                 tol = 0 if conversion is None else 1
                 self.assertTrue(img_lpng.allclose(img_pil, atol=tol))
diff --git a/torchvision/csrc/cpu/image/image_read_mode.h b/torchvision/csrc/cpu/image/image_read_mode.h
new file mode 100644
index 00000000000..00ff4f6b581
--- /dev/null
+++ b/torchvision/csrc/cpu/image/image_read_mode.h
@@ -0,0 +1,9 @@
+#pragma once
+
+/* Should be kept in-sync with Python ImageReadMode enum */
+using ImageReadMode = int64_t;
+#define IMAGE_READ_MODE_UNCHANGED 0
+#define IMAGE_READ_MODE_GRAY 1
+#define IMAGE_READ_MODE_GRAY_ALPHA 2
+#define IMAGE_READ_MODE_RGB 3
+#define IMAGE_READ_MODE_RGB_ALPHA 4
\ No newline at end of file
diff --git a/torchvision/csrc/cpu/image/read_image_cpu.cpp b/torchvision/csrc/cpu/image/read_image_cpu.cpp
index 5839017d3d7..6039b870f31 100644
--- a/torchvision/csrc/cpu/image/read_image_cpu.cpp
+++ b/torchvision/csrc/cpu/image/read_image_cpu.cpp
@@ -1,15 +1,14 @@
 #include "read_image_cpu.h"
-#include
+#include "readjpeg_cpu.h"
+#include "readpng_cpu.h"
 
-torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) {
+torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
   // Check that the input tensor dtype is uint8
   TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
   // Check that the input tensor is 1-dimensional
   TORCH_CHECK(
       data.dim() == 1 && data.numel() > 0,
       "Expected a non empty 1-dimensional tensor");
-  TORCH_CHECK(
-      channels >= 0 && channels <= 4, "Number of channels not supported");
 
   auto datap = data.data_ptr();
 
@@ -17,9 +16,9 @@ torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) {
   const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
 
   if (memcmp(jpeg_signature, datap, 3) == 0) {
-    return decodeJPEG(data, channels);
+    return decodeJPEG(data, mode);
   } else if (memcmp(png_signature, datap, 4) == 0) {
-    return decodePNG(data, channels);
+    return decodePNG(data, mode);
   } else {
     TORCH_CHECK(
         false,
diff --git a/torchvision/csrc/cpu/image/read_image_cpu.h b/torchvision/csrc/cpu/image/read_image_cpu.h
index e926a8474da..6186d0d0d98 100644
--- a/torchvision/csrc/cpu/image/read_image_cpu.h
+++ b/torchvision/csrc/cpu/image/read_image_cpu.h
@@ -1,8 +1,8 @@
 #pragma once
 
-#include "readjpeg_cpu.h"
-#include "readpng_cpu.h"
+#include
+#include "image_read_mode.h"
 
 C10_EXPORT torch::Tensor decode_image(
     const torch::Tensor& data,
-    int64_t channels = 0);
+    ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
diff --git a/torchvision/csrc/cpu/image/readjpeg_cpu.cpp b/torchvision/csrc/cpu/image/readjpeg_cpu.cpp
index d093dca0963..58584612697 100644
--- a/torchvision/csrc/cpu/image/readjpeg_cpu.cpp
+++ b/torchvision/csrc/cpu/image/readjpeg_cpu.cpp
@@ -1,10 +1,9 @@
 #include "readjpeg_cpu.h"
 
 #include
-#include
 
 #if !JPEG_FOUND
-torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
+torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
   TORCH_CHECK(
       false, "decodeJPEG: torchvision not compiled with libjpeg support");
 }
@@ -69,16 +68,13 @@ static void torch_jpeg_set_source_mgr(
   src->pub.next_input_byte = src->data;
 }
 
-torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
+torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
   // Check that the input tensor dtype is uint8
   TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
   // Check that the input tensor is 1-dimensional
   TORCH_CHECK(
       data.dim() == 1 && data.numel() > 0,
       "Expected a non empty 1-dimensional tensor");
-  TORCH_CHECK(
-      channels == 0 || channels == 1 || channels == 3,
-      "Number of channels not supported");
 
   struct jpeg_decompress_struct cinfo;
   struct torch_jpeg_error_mgr jerr;
@@ -102,30 +98,33 @@ torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
   // read info from header.
   jpeg_read_header(&cinfo, TRUE);
 
-  int current_channels = cinfo.num_components;
+  int channels = cinfo.num_components;
 
-  if (channels > 0 && channels != current_channels) {
-    switch (channels) {
-      case 1: // Gray
-        cinfo.out_color_space = JCS_GRAYSCALE;
+  if (mode != IMAGE_READ_MODE_UNCHANGED) {
+    switch (mode) {
+      case IMAGE_READ_MODE_GRAY:
+        if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
+          cinfo.out_color_space = JCS_GRAYSCALE;
+          channels = 1;
+        }
         break;
-      case 3: // RGB
-        cinfo.out_color_space = JCS_RGB;
+      case IMAGE_READ_MODE_RGB:
+        if (cinfo.jpeg_color_space != JCS_RGB) {
+          cinfo.out_color_space = JCS_RGB;
+          channels = 3;
+        }
         break;
      /*
       * Libjpeg does not support converting from CMYK to grayscale etc. There
       * is a way to do this but it involves converting it manually to RGB:
       * https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313
-      *
       */
      default:
        jpeg_destroy_decompress(&cinfo);
-        TORCH_CHECK(false, "Invalid number of output channels.");
+        TORCH_CHECK(false, "Provided mode not supported");
    }
 
    jpeg_calc_output_dimensions(&cinfo);
-  } else {
-    channels = current_channels;
  }
 
  jpeg_start_decompress(&cinfo);
diff --git a/torchvision/csrc/cpu/image/readjpeg_cpu.h b/torchvision/csrc/cpu/image/readjpeg_cpu.h
index 0e7bb137d12..f05d05a9064 100644
--- a/torchvision/csrc/cpu/image/readjpeg_cpu.h
+++ b/torchvision/csrc/cpu/image/readjpeg_cpu.h
@@ -1,7 +1,8 @@
 #pragma once
 
 #include
+#include "image_read_mode.h"
 
 C10_EXPORT torch::Tensor decodeJPEG(
     const torch::Tensor& data,
-    int64_t channels = 0);
+    ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
diff --git a/torchvision/csrc/cpu/image/readpng_cpu.cpp b/torchvision/csrc/cpu/image/readpng_cpu.cpp
index 94f5060dab6..7adc125b2e8 100644
--- a/torchvision/csrc/cpu/image/readpng_cpu.cpp
+++ b/torchvision/csrc/cpu/image/readpng_cpu.cpp
@@ -1,26 +1,22 @@
 #include "readpng_cpu.h"
-// Comment
 #include
-#include
 
 #if !PNG_FOUND
-torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
+torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
   TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support");
 }
 
 #else
 #include
 #include
 
-torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
+torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
   // Check that the input tensor dtype is uint8
   TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
   // Check that the input tensor is 1-dimensional
   TORCH_CHECK(
       data.dim() == 1 && data.numel() > 0,
       "Expected a non empty 1-dimensional tensor");
-  TORCH_CHECK(
-      channels >= 0 && channels <= 4, "Number of channels not supported");
 
   auto png_ptr =
       png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
@@ -74,75 +70,85 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
     TORCH_CHECK(retval == 1, "Could read image metadata from content.")
   }
 
-  int current_channels = png_get_channels(png_ptr, info_ptr);
+  int channels = png_get_channels(png_ptr, info_ptr);
 
-  if (channels > 0) {
+  if (mode != IMAGE_READ_MODE_UNCHANGED) {
     // TODO: consider supporting PNG_INFO_tRNS
     bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
     bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0;
     bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0;
 
-    switch (channels) {
-      case 1: // Gray
-        if (is_palette) {
-          png_set_palette_to_rgb(png_ptr);
-          has_alpha = true;
-        }
-
-        if (has_alpha) {
-          png_set_strip_alpha(png_ptr);
-        }
-
-        if (has_color) {
-          png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
+    switch (mode) {
+      case IMAGE_READ_MODE_GRAY:
+        if (color_type != PNG_COLOR_TYPE_GRAY) {
+          if (is_palette) {
+            png_set_palette_to_rgb(png_ptr);
+            has_alpha = true;
+          }
+
+          if (has_alpha) {
+            png_set_strip_alpha(png_ptr);
+          }
+
+          if (has_color) {
+            png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
+          }
+          channels = 1;
         }
         break;
-      case 2: // Gray + Alpha
-        if (is_palette) {
-          png_set_palette_to_rgb(png_ptr);
-          has_alpha = true;
-        }
-
-        if (!has_alpha) {
-          png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
-        }
-
-        if (has_color) {
-          png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
+      case IMAGE_READ_MODE_GRAY_ALPHA:
+        if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) {
+          if (is_palette) {
+            png_set_palette_to_rgb(png_ptr);
+            has_alpha = true;
+          }
+
+          if (!has_alpha) {
+            png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
+          }
+
+          if (has_color) {
+            png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
+          }
+          channels = 2;
         }
         break;
-      case 3:
-        if (is_palette) {
-          png_set_palette_to_rgb(png_ptr);
-          has_alpha = true;
-        } else if (!has_color) {
-          png_set_gray_to_rgb(png_ptr);
-        }
-
-        if (has_alpha) {
-          png_set_strip_alpha(png_ptr);
+      case IMAGE_READ_MODE_RGB:
+        if (color_type != PNG_COLOR_TYPE_RGB) {
+          if (is_palette) {
+            png_set_palette_to_rgb(png_ptr);
+            has_alpha = true;
+          } else if (!has_color) {
+            png_set_gray_to_rgb(png_ptr);
+          }
+
+          if (has_alpha) {
+            png_set_strip_alpha(png_ptr);
+          }
+          channels = 3;
         }
         break;
-      case 4:
-        if (is_palette) {
-          png_set_palette_to_rgb(png_ptr);
-          has_alpha = true;
-        } else if (!has_color) {
-          png_set_gray_to_rgb(png_ptr);
-        }
-
-        if (!has_alpha) {
-          png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
+      case IMAGE_READ_MODE_RGB_ALPHA:
+        if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) {
+          if (is_palette) {
+            png_set_palette_to_rgb(png_ptr);
+            has_alpha = true;
+          } else if (!has_color) {
+            png_set_gray_to_rgb(png_ptr);
+          }
+
+          if (!has_alpha) {
+            png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
+          }
+          channels = 4;
         }
         break;
       default:
         png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
-        TORCH_CHECK(false, "Invalid number of output channels.");
+        TORCH_CHECK(false, "Provided mode not supported");
     }
 
     png_read_update_info(png_ptr, info_ptr);
-  } else {
-    channels = current_channels;
   }
 
   auto tensor =
diff --git a/torchvision/csrc/cpu/image/readpng_cpu.h b/torchvision/csrc/cpu/image/readpng_cpu.h
index a36032ddb25..9c74cb2c678 100644
--- a/torchvision/csrc/cpu/image/readpng_cpu.h
+++ b/torchvision/csrc/cpu/image/readpng_cpu.h
@@ -1,9 +1,8 @@
 #pragma once
 
-// Comment
 #include
-#include
+#include "image_read_mode.h"
 
 C10_EXPORT torch::Tensor decodePNG(
     const torch::Tensor& data,
-    int64_t channels = 0);
+    ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
diff --git a/torchvision/io/image.py b/torchvision/io/image.py
index 01e1e1e5ca0..f11a7ccb634 100644
--- a/torchvision/io/image.py
+++ b/torchvision/io/image.py
@@ -4,6 +4,8 @@
 import os.path as osp
 import importlib.machinery
 
+from enum import Enum
+
 _HAS_IMAGE_OPT = False
 
 try:
@@ -47,6 +49,14 @@
     pass
 
 
+class ImageReadMode(Enum):
+    UNCHANGED = 0
+    GRAY = 1
+    GRAY_ALPHA = 2
+    RGB = 3
+    RGB_ALPHA = 4
+
+
 def read_file(path: str) -> torch.Tensor:
     """
     Reads and outputs the bytes contents of a file as a uint8 Tensor
@@ -74,24 +84,26 @@ def write_file(filename: str, data: torch.Tensor) -> None:
     torch.ops.image.write_file(filename, data)
 
 
-def decode_png(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
+def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
     """
     Decodes a PNG image into a 3 dimensional RGB Tensor.
-    Optionally converts the image to the desired number of color channels.
+    Optionally converts the image to the desired format.
     The values of the output tensor are uint8 between 0 and 255.
 
     Arguments:
         input (Tensor[1]): a one dimensional uint8 tensor containing
         the raw bytes of the PNG image.
-        channels (int): the number of output channels for the decoded
-        image. 0 keeps the original number of channels, 1 converts to Grayscale
-        2 converts to Grayscale with Alpha, 3 converts to RGB and 4 coverts to
-        RGB with Alpha. Default: 0
+        mode (ImageReadMode): the read mode used for optionally
+        converting the image. Use `ImageReadMode.UNCHANGED` for loading
+        the image as-is, `ImageReadMode.GRAY` for converting to grayscale,
+        `ImageReadMode.GRAY_ALPHA` for grayscale with transparency,
+        `ImageReadMode.RGB` for RGB and `ImageReadMode.RGB_ALPHA` for
+        RGB with transparency. Default: `ImageReadMode.UNCHANGED`
 
     Returns:
         output (Tensor[image_channels, image_height, image_width])
     """
-    output = torch.ops.image.decode_png(input, channels)
+    output = torch.ops.image.decode_png(input, mode.value)
     return output
 
 
@@ -137,23 +149,24 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
     write_file(filename, output)
 
 
-def decode_jpeg(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
+def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
     """
     Decodes a JPEG image into a 3 dimensional RGB Tensor.
-    Optionally converts the image to the desired number of color channels.
+    Optionally converts the image to the desired format.
     The values of the output tensor are uint8 between 0 and 255.
 
     Arguments:
         input (Tensor[1]): a one dimensional uint8 tensor containing
         the raw bytes of the JPEG image.
-        channels (int): the number of output channels for the decoded
-        image. 0 keeps the original number of channels, 1 converts to Grayscale
-        and 3 converts to RGB. Default: 0
+        mode (ImageReadMode): the read mode used for optionally
+        converting the image. Use `ImageReadMode.UNCHANGED` for loading
+        the image as-is, `ImageReadMode.GRAY` for converting to grayscale
+        and `ImageReadMode.RGB` for RGB. Default: `ImageReadMode.UNCHANGED`
 
     Returns:
         output (Tensor[image_channels, image_height, image_width])
     """
-    output = torch.ops.image.decode_jpeg(input, channels)
+    output = torch.ops.image.decode_jpeg(input, mode.value)
     return output
 
 
@@ -202,12 +215,12 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
     write_file(filename, output)
 
 
-def decode_image(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
+def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
     """
     Detects whether an image is a JPEG or PNG and performs the appropriate
     operation to decode the image into a 3 dimensional RGB Tensor.
-    Optionally converts the image to the desired number of color channels.
+    Optionally converts the image to the desired format.
     The values of the output tensor are uint8 between 0 and 255.
 
     Parameters
@@ -215,39 +228,41 @@ def decode_image(input: torch.Tensor, channels: int = 0) -> torch.Tensor:
     input: Tensor
         a one dimensional uint8 tensor containing the raw bytes of the
         PNG or JPEG image.
-    channels: int
-        the number of output channels of the decoded image. JPEG and PNG images
-        have different permitted values. The default value is 0 and it keeps
-        the original number of channels. See `decode_jpeg()` and `decode_png()`
-        for more information. Default: 0
+    mode: ImageReadMode
+        the read mode used for optionally converting the image. JPEG
+        and PNG images have different permitted values. The default
+        value is `ImageReadMode.UNCHANGED` and it keeps the image as-is.
+        See `decode_jpeg()` and `decode_png()` for more information.
+        Default: `ImageReadMode.UNCHANGED`
 
     Returns
     -------
     output: Tensor[image_channels, image_height, image_width]
     """
-    output = torch.ops.image.decode_image(input, channels)
+    output = torch.ops.image.decode_image(input, mode.value)
     return output
 
 
-def read_image(path: str, channels: int = 0) -> torch.Tensor:
+def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
     """
     Reads a JPEG or PNG image into a 3 dimensional RGB Tensor.
-    Optionally converts the image to the desired number of color channels.
+    Optionally converts the image to the desired format.
    The values of the output tensor are uint8 between 0 and 255.
 
     Parameters
     ----------
     path: str
         path of the JPEG or PNG image.
-    channels: int
-        the number of output channels of the decoded image. JPEG and PNG images
-        have different permitted values. The default value is 0 and it keeps
-        the original number of channels. See `decode_jpeg()` and `decode_png()`
-        for more information. Default: 0
+    mode: ImageReadMode
+        the read mode used for optionally converting the image. JPEG
+        and PNG images have different permitted values. The default
+        value is `ImageReadMode.UNCHANGED` and it keeps the image as-is.
+        See `decode_jpeg()` and `decode_png()` for more information.
+        Default: `ImageReadMode.UNCHANGED`
 
     Returns
     -------
     output: Tensor[image_channels, image_height, image_width]
     """
     data = read_file(path)
-    return decode_image(data, channels)
+    return decode_image(data, mode)
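
Usage sketch: a minimal example of the Python API this patch introduces (ImageReadMode together with read_image/decode_image), assuming a torchvision build that includes these changes; the image paths below are placeholders, not files shipped with the repository.

    from torchvision.io.image import ImageReadMode, decode_image, read_file, read_image

    # Force a 3-channel RGB tensor regardless of how the file is stored on disk.
    img_rgb = read_image("path/to/image.jpg", mode=ImageReadMode.RGB)
    print(img_rgb.shape)  # torch.Size([3, H, W])

    # ImageReadMode.UNCHANGED (the default) keeps the channel layout of the source file.
    data = read_file("path/to/image.png")
    img_raw = decode_image(data, mode=ImageReadMode.UNCHANGED)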