Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support of mode and remove channels #3024

Merged
merged 2 commits into from
Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
import io
import glob
import unittest
import sys

import torch
import torchvision
from PIL import Image
from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png, write_file)
encode_png, write_png, write_file, ImageReadMode)
import numpy as np

from common_utils import get_tmp_dir
Expand Down Expand Up @@ -49,9 +47,9 @@ def normalize_dimensions(img_pil):

class ImageTester(unittest.TestCase):
def test_decode_jpeg(self):
conversion = [(None, 0), ("L", 1), ("RGB", 3)]
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)]
for img_path in get_images(IMAGE_ROOT, ".jpg"):
for pil_mode, channels in conversion:
for pil_mode, mode in conversion:
with Image.open(img_path) as img:
is_cmyk = img.mode == "CMYK"
if pil_mode is not None:
Expand All @@ -66,7 +64,7 @@ def test_decode_jpeg(self):

img_pil = normalize_dimensions(img_pil)
data = read_file(img_path)
img_ljpeg = decode_image(data, channels=channels)
img_ljpeg = decode_image(data, mode=mode)

# Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG.
Expand Down Expand Up @@ -165,17 +163,18 @@ def test_write_jpeg(self):
self.assertEqual(torch_bytes, pil_bytes)

def test_decode_png(self):
conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)]
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
for img_path in get_images(FAKEDATA_DIR, ".png"):
for pil_mode, channels in conversion:
for pil_mode, mode in conversion:
with Image.open(img_path) as img:
if pil_mode is not None:
img = img.convert(pil_mode)
img_pil = torch.from_numpy(np.array(img))

img_pil = normalize_dimensions(img_pil)
data = read_file(img_path)
img_lpng = decode_image(data, channels=channels)
img_lpng = decode_image(data, mode=mode)

tol = 0 if conversion is None else 1
self.assertTrue(img_lpng.allclose(img_pil, atol=tol))
Expand Down
9 changes: 9 additions & 0 deletions torchvision/csrc/cpu/image/image_read_mode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once

/* Should be kept in-sync with Python ImageReadMode enum */
datumbox marked this conversation as resolved.
Show resolved Hide resolved
#include <cstdint>

/*
 * Color mode requested from the image decoders. It is carried as a plain
 * int64_t alias rather than an enum class; the values must stay in sync with
 * the Python ImageReadMode enum.
 *
 * The constants are typed constexpr values rather than preprocessor macros so
 * that they obey scoping and type checking. They keep the same names and
 * values, and remain usable in `case` labels and as default arguments.
 */
using ImageReadMode = int64_t;

// Keep whatever channel layout the encoded image already has.
constexpr ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0;
// Single-channel grayscale output.
constexpr ImageReadMode IMAGE_READ_MODE_GRAY = 1;
// Two-channel output: grayscale plus alpha.
constexpr ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2;
// Three-channel RGB output.
constexpr ImageReadMode IMAGE_READ_MODE_RGB = 3;
// Four-channel output: RGB plus alpha.
constexpr ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: new line

I'll add it along with the other proposed corrections to minimize CI runs.

Comment on lines +5 to +9
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is fine as is, but I wonder if a more modern pattern now is to use `constexpr ImageReadMode kModeUnchanged = 0;` or something like that, maybe within a namespace.

11 changes: 5 additions & 6 deletions torchvision/csrc/cpu/image/read_image_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
#include "read_image_cpu.h"
#include <cstring>
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"

torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) {
torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels >= 0 && channels <= 4, "Number of channels not supported");

auto datap = data.data_ptr<uint8_t>();

const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"

if (memcmp(jpeg_signature, datap, 3) == 0) {
return decodeJPEG(data, channels);
return decodeJPEG(data, mode);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decodePNG(data, channels);
return decodePNG(data, mode);
} else {
TORCH_CHECK(
false,
Expand Down
6 changes: 3 additions & 3 deletions torchvision/csrc/cpu/image/read_image_cpu.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
#include <torch/torch.h>
#include "image_read_mode.h"

C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data,
int64_t channels = 0);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
33 changes: 16 additions & 17 deletions torchvision/csrc/cpu/image/readjpeg_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#include "readjpeg_cpu.h"

#include <ATen/ATen.h>
#include <string>

#if !JPEG_FOUND
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(
false, "decodeJPEG: torchvision not compiled with libjpeg support");
}
Expand Down Expand Up @@ -69,16 +68,13 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data;
}

torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels == 0 || channels == 1 || channels == 3,
"Number of channels not supported");

struct jpeg_decompress_struct cinfo;
struct torch_jpeg_error_mgr jerr;
Expand All @@ -102,30 +98,33 @@ torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
// read info from header.
jpeg_read_header(&cinfo, TRUE);

int current_channels = cinfo.num_components;
int channels = cinfo.num_components;

if (channels > 0 && channels != current_channels) {
switch (channels) {
case 1: // Gray
cinfo.out_color_space = JCS_GRAYSCALE;
if (mode != IMAGE_READ_MODE_UNCHANGED) {
switch (mode) {
case IMAGE_READ_MODE_GRAY:
if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
cinfo.out_color_space = JCS_GRAYSCALE;
channels = 1;
}
break;
case 3: // RGB
cinfo.out_color_space = JCS_RGB;
case IMAGE_READ_MODE_RGB:
if (cinfo.jpeg_color_space != JCS_RGB) {
cinfo.out_color_space = JCS_RGB;
channels = 3;
}
break;
/*
* Libjpeg does not support converting from CMYK to grayscale etc. There
* is a way to do this but it involves converting it manually to RGB:
* https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313
*
*/
default:
jpeg_destroy_decompress(&cinfo);
TORCH_CHECK(false, "Invalid number of output channels.");
TORCH_CHECK(false, "Provided mode not supported");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that if an unsupported conversion operation is requested, for instance CMYK to RGB, this check is not going to trigger. Instead we will get a `RuntimeError: Unsupported color conversion request` exception on the call of jpeg_start_decompress() below. I think it's better to let libjpeg throw its own exception for its limitations than to have extra code on our side checking for the same thing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it would be good to be explicit here and mention that it's not supported because the input is jpeg? Otherwise it could be confusing to the user, wdyt?

}

jpeg_calc_output_dimensions(&cinfo);
} else {
channels = current_channels;
}

jpeg_start_decompress(&cinfo);
Expand Down
3 changes: 2 additions & 1 deletion torchvision/csrc/cpu/image/readjpeg_cpu.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#pragma once

#include <torch/torch.h>
#include "image_read_mode.h"

C10_EXPORT torch::Tensor decodeJPEG(
const torch::Tensor& data,
int64_t channels = 0);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
118 changes: 62 additions & 56 deletions torchvision/csrc/cpu/image/readpng_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
#include "readpng_cpu.h"

// Comment
#include <ATen/ATen.h>
#include <string>

#if !PNG_FOUND
torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support");
}
#else
#include <png.h>
#include <setjmp.h>

torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels >= 0 && channels <= 4, "Number of channels not supported");

auto png_ptr =
png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
Expand Down Expand Up @@ -74,75 +70,85 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}

int current_channels = png_get_channels(png_ptr, info_ptr);
int channels = png_get_channels(png_ptr, info_ptr);

if (channels > 0) {
if (mode != IMAGE_READ_MODE_UNCHANGED) {
// TODO: consider supporting PNG_INFO_tRNS
bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0;
bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0;

switch (channels) {
case 1: // Gray
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}

if (has_alpha) {
png_set_strip_alpha(png_ptr);
}

if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
switch (mode) {
case IMAGE_READ_MODE_GRAY:
if (color_type != PNG_COLOR_TYPE_GRAY) {
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}

if (has_alpha) {
png_set_strip_alpha(png_ptr);
}

if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
}
channels = 1;
}
break;
case 2: // Gray + Alpha
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}

if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
}

if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
case IMAGE_READ_MODE_GRAY_ALPHA:
if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) {
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}

if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
}

if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
}
channels = 2;
}
break;
case 3:
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}

if (has_alpha) {
png_set_strip_alpha(png_ptr);
case IMAGE_READ_MODE_RGB:
if (color_type != PNG_COLOR_TYPE_RGB) {
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}

if (has_alpha) {
png_set_strip_alpha(png_ptr);
}
channels = 3;
}
break;
case 4:
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}

if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
case IMAGE_READ_MODE_RGB_ALPHA:
if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) {
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}

if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
}
channels = 4;
}
break;
default:
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "Invalid number of output channels.");
TORCH_CHECK(false, "Provided mode not supported");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here about the error message: maybe it would be better to specify that this was not supported for PNG?

}

png_read_update_info(png_ptr, info_ptr);
} else {
channels = current_channels;
}

auto tensor =
Expand Down
5 changes: 2 additions & 3 deletions torchvision/csrc/cpu/image/readpng_cpu.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#pragma once

// Comment
#include <torch/torch.h>
#include <string>
#include "image_read_mode.h"

C10_EXPORT torch::Tensor decodePNG(
const torch::Tensor& data,
int64_t channels = 0);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
Loading