Skip to content

Commit

Permalink
Add support of mode and remove channels (pytorch#3024)
Browse files Browse the repository at this point in the history
* Add support of mode and remove channels.

* Replacing integer mode with define constants.
  • Loading branch information
datumbox authored and bryant1410 committed Nov 22, 2020
1 parent 8fe6709 commit 7098884
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 124 deletions.
17 changes: 8 additions & 9 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
import io
import glob
import unittest
import sys

import torch
import torchvision
from PIL import Image
from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
encode_png, write_png, write_file)
encode_png, write_png, write_file, ImageReadMode)
import numpy as np

from common_utils import get_tmp_dir
Expand Down Expand Up @@ -49,9 +47,9 @@ def normalize_dimensions(img_pil):

class ImageTester(unittest.TestCase):
def test_decode_jpeg(self):
conversion = [(None, 0), ("L", 1), ("RGB", 3)]
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("RGB", ImageReadMode.RGB)]
for img_path in get_images(IMAGE_ROOT, ".jpg"):
for pil_mode, channels in conversion:
for pil_mode, mode in conversion:
with Image.open(img_path) as img:
is_cmyk = img.mode == "CMYK"
if pil_mode is not None:
Expand All @@ -66,7 +64,7 @@ def test_decode_jpeg(self):

img_pil = normalize_dimensions(img_pil)
data = read_file(img_path)
img_ljpeg = decode_image(data, channels=channels)
img_ljpeg = decode_image(data, mode=mode)

# Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG.
Expand Down Expand Up @@ -165,17 +163,18 @@ def test_write_jpeg(self):
self.assertEqual(torch_bytes, pil_bytes)

def test_decode_png(self):
conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)]
conversion = [(None, ImageReadMode.UNCHANGED), ("L", ImageReadMode.GRAY), ("LA", ImageReadMode.GRAY_ALPHA),
("RGB", ImageReadMode.RGB), ("RGBA", ImageReadMode.RGB_ALPHA)]
for img_path in get_images(FAKEDATA_DIR, ".png"):
for pil_mode, channels in conversion:
for pil_mode, mode in conversion:
with Image.open(img_path) as img:
if pil_mode is not None:
img = img.convert(pil_mode)
img_pil = torch.from_numpy(np.array(img))

img_pil = normalize_dimensions(img_pil)
data = read_file(img_path)
img_lpng = decode_image(data, channels=channels)
img_lpng = decode_image(data, mode=mode)

tol = 0 if conversion is None else 1
self.assertTrue(img_lpng.allclose(img_pil, atol=tol))
Expand Down
9 changes: 9 additions & 0 deletions torchvision/csrc/cpu/image/image_read_mode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once

#include <cstdint>

/**
 * Output color layout requested from the image decoders.
 *
 * The numeric values must be kept in sync with the Python ImageReadMode enum.
 * ImageReadMode remains a plain 64-bit integer alias (not an enum class) so
 * every existing signature that takes an ImageReadMode is unchanged.
 *
 * The constants are typed constexpr values rather than preprocessor macros:
 * they keep the same names and values, remain usable in `case` labels and as
 * default arguments, but obey normal scoping and type checking.
 */
using ImageReadMode = std::int64_t;

/// Keep whatever layout the encoded image already has (no conversion).
constexpr ImageReadMode IMAGE_READ_MODE_UNCHANGED = 0;
/// Single-channel grayscale output.
constexpr ImageReadMode IMAGE_READ_MODE_GRAY = 1;
/// Two-channel grayscale + alpha output.
constexpr ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2;
/// Three-channel RGB output.
constexpr ImageReadMode IMAGE_READ_MODE_RGB = 3;
/// Four-channel RGB + alpha output.
constexpr ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4;
11 changes: 5 additions & 6 deletions torchvision/csrc/cpu/image/read_image_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
#include "read_image_cpu.h"
#include <cstring>
#include "readjpeg_cpu.h"
#include "readpng_cpu.h"

torch::Tensor decode_image(const torch::Tensor& data, int64_t channels) {
torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels >= 0 && channels <= 4, "Number of channels not supported");

auto datap = data.data_ptr<uint8_t>();

const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"

if (memcmp(jpeg_signature, datap, 3) == 0) {
return decodeJPEG(data, channels);
return decodeJPEG(data, mode);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decodePNG(data, channels);
return decodePNG(data, mode);
} else {
TORCH_CHECK(
false,
Expand Down
6 changes: 3 additions & 3 deletions torchvision/csrc/cpu/image/read_image_cpu.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include "readjpeg_cpu.h"
#include "readpng_cpu.h"
#include <torch/torch.h>
#include "image_read_mode.h"

C10_EXPORT torch::Tensor decode_image(
const torch::Tensor& data,
int64_t channels = 0);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
33 changes: 16 additions & 17 deletions torchvision/csrc/cpu/image/readjpeg_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#include "readjpeg_cpu.h"

#include <ATen/ATen.h>
#include <string>

#if !JPEG_FOUND
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(
false, "decodeJPEG: torchvision not compiled with libjpeg support");
}
Expand Down Expand Up @@ -69,16 +68,13 @@ static void torch_jpeg_set_source_mgr(
src->pub.next_input_byte = src->data;
}

torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels == 0 || channels == 1 || channels == 3,
"Number of channels not supported");

struct jpeg_decompress_struct cinfo;
struct torch_jpeg_error_mgr jerr;
Expand All @@ -102,30 +98,33 @@ torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) {
// read info from header.
jpeg_read_header(&cinfo, TRUE);

int current_channels = cinfo.num_components;
int channels = cinfo.num_components;

if (channels > 0 && channels != current_channels) {
switch (channels) {
case 1: // Gray
cinfo.out_color_space = JCS_GRAYSCALE;
if (mode != IMAGE_READ_MODE_UNCHANGED) {
switch (mode) {
case IMAGE_READ_MODE_GRAY:
if (cinfo.jpeg_color_space != JCS_GRAYSCALE) {
cinfo.out_color_space = JCS_GRAYSCALE;
channels = 1;
}
break;
case 3: // RGB
cinfo.out_color_space = JCS_RGB;
case IMAGE_READ_MODE_RGB:
if (cinfo.jpeg_color_space != JCS_RGB) {
cinfo.out_color_space = JCS_RGB;
channels = 3;
}
break;
/*
* Libjpeg does not support converting from CMYK to grayscale etc. There
* is a way to do this but it involves converting it manually to RGB:
* https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313
*
*/
default:
jpeg_destroy_decompress(&cinfo);
TORCH_CHECK(false, "Invalid number of output channels.");
TORCH_CHECK(false, "Provided mode not supported");
}

jpeg_calc_output_dimensions(&cinfo);
} else {
channels = current_channels;
}

jpeg_start_decompress(&cinfo);
Expand Down
3 changes: 2 additions & 1 deletion torchvision/csrc/cpu/image/readjpeg_cpu.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#pragma once

#include <torch/torch.h>
#include "image_read_mode.h"

C10_EXPORT torch::Tensor decodeJPEG(
const torch::Tensor& data,
int64_t channels = 0);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
118 changes: 62 additions & 56 deletions torchvision/csrc/cpu/image/readpng_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
#include "readpng_cpu.h"

// Comment
#include <ATen/ATen.h>
#include <string>

#if !PNG_FOUND
torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support");
}
#else
#include <png.h>
#include <setjmp.h>

torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");
TORCH_CHECK(
channels >= 0 && channels <= 4, "Number of channels not supported");

auto png_ptr =
png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
Expand Down Expand Up @@ -74,75 +70,85 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) {
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}

int current_channels = png_get_channels(png_ptr, info_ptr);
int channels = png_get_channels(png_ptr, info_ptr);

if (channels > 0) {
if (mode != IMAGE_READ_MODE_UNCHANGED) {
// TODO: consider supporting PNG_INFO_tRNS
bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0;
bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0;

switch (channels) {
case 1: // Gray
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}

if (has_alpha) {
png_set_strip_alpha(png_ptr);
}

if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
switch (mode) {
case IMAGE_READ_MODE_GRAY:
if (color_type != PNG_COLOR_TYPE_GRAY) {
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}

if (has_alpha) {
png_set_strip_alpha(png_ptr);
}

if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
}
channels = 1;
}
break;
case 2: // Gray + Alpha
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}

if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
}

if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
case IMAGE_READ_MODE_GRAY_ALPHA:
if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) {
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
}

if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
}

if (has_color) {
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
}
channels = 2;
}
break;
case 3:
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}

if (has_alpha) {
png_set_strip_alpha(png_ptr);
case IMAGE_READ_MODE_RGB:
if (color_type != PNG_COLOR_TYPE_RGB) {
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}

if (has_alpha) {
png_set_strip_alpha(png_ptr);
}
channels = 3;
}
break;
case 4:
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}

if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
case IMAGE_READ_MODE_RGB_ALPHA:
if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) {
if (is_palette) {
png_set_palette_to_rgb(png_ptr);
has_alpha = true;
} else if (!has_color) {
png_set_gray_to_rgb(png_ptr);
}

if (!has_alpha) {
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
}
channels = 4;
}
break;
default:
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "Invalid number of output channels.");
TORCH_CHECK(false, "Provided mode not supported");
}

png_read_update_info(png_ptr, info_ptr);
} else {
channels = current_channels;
}

auto tensor =
Expand Down
5 changes: 2 additions & 3 deletions torchvision/csrc/cpu/image/readpng_cpu.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#pragma once

// Comment
#include <torch/torch.h>
#include <string>
#include "image_read_mode.h"

C10_EXPORT torch::Tensor decodePNG(
const torch::Tensor& data,
int64_t channels = 0);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
Loading

0 comments on commit 7098884

Please sign in to comment.