Cleanup/refactor of decoders and related tests (#8617)
NicolasHug authored Aug 30, 2024
1 parent c33b00a commit db60022
Showing 18 changed files with 141 additions and 133 deletions.
100 changes: 61 additions & 39 deletions test/test_image.py
@@ -42,6 +42,19 @@
IS_WINDOWS = sys.platform in ("win32", "cygwin")
IS_MACOS = sys.platform == "darwin"
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
WEBP_TEST_IMAGES_DIR = os.environ.get("WEBP_TEST_IMAGES_DIR", "")

# Hacky way of figuring out whether we compiled with libavif/libheif (those are
# currently disabled by default)
try:
    _decode_avif(torch.arange(10, dtype=torch.uint8))
except Exception as e:
    DECODE_AVIF_ENABLED = "torchvision not compiled with libavif support" not in str(e)

try:
    _decode_heic(torch.arange(10, dtype=torch.uint8))
except Exception as e:
    DECODE_HEIC_ENABLED = "torchvision not compiled with libheif support" not in str(e)


def _get_safe_image_name(name):
@@ -149,17 +162,6 @@ def test_invalid_exif(tmpdir, size):
torch.testing.assert_close(expected, output)


def test_decode_jpeg_errors():
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
decode_jpeg(torch.empty((100,), dtype=torch.float16))

with pytest.raises(RuntimeError, match="Not a JPEG file"):
decode_jpeg(torch.empty((100), dtype=torch.uint8))


def test_decode_bad_huffman_images():
# sanity check: make sure we can decode the bad Huffman encoding
bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg"))
@@ -235,10 +237,6 @@ def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):


def test_decode_png_errors():
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_png(torch.empty((), dtype=torch.uint8))
with pytest.raises(RuntimeError, match="Content is not png"):
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
with pytest.raises(RuntimeError, match="Out of bound read in decode_png"):
decode_png(read_file(os.path.join(DAMAGED_PNG, "sigsegv.png")))
with pytest.raises(RuntimeError, match="Content is too small for png"):
@@ -864,20 +862,28 @@ def test_decode_gif(tmpdir, name, scripted):
torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0)


@pytest.mark.parametrize("decode_fun", (decode_gif, decode_webp))
def test_decode_gif_webp_errors(decode_fun):
decode_fun_and_match = [
    (decode_png, "Content is not png"),
    (decode_jpeg, "Not a JPEG file"),
    (decode_gif, re.escape("DGifOpenFileName() failed - 103")),
    (decode_webp, "WebPGetFeatures failed."),
]
if DECODE_AVIF_ENABLED:
    decode_fun_and_match.append((_decode_avif, "BMFF parsing failed"))
if DECODE_HEIC_ENABLED:
    decode_fun_and_match.append((_decode_heic, "Invalid input: No 'ftyp' box"))


@pytest.mark.parametrize("decode_fun, match", decode_fun_and_match)
def test_decode_bad_encoded_data(decode_fun, match):
    encoded_data = torch.randint(0, 256, (100,), dtype=torch.uint8)
    with pytest.raises(RuntimeError, match="Input tensor must be 1-dimensional"):
        decode_fun(encoded_data[None])
    with pytest.raises(RuntimeError, match="Input tensor must have uint8 data type"):
        decode_fun(encoded_data.float())
    with pytest.raises(RuntimeError, match="Input tensor must be contiguous"):
        decode_fun(encoded_data[::2])
    if decode_fun is decode_gif:
        expected_match = re.escape("DGifOpenFileName() failed - 103")
    elif decode_fun is decode_webp:
        expected_match = "WebPGetFeatures failed."
    with pytest.raises(RuntimeError, match=expected_match):
    with pytest.raises(RuntimeError, match=match):
        decode_fun(encoded_data)


@@ -890,21 +896,27 @@ def test_decode_webp(decode_fun, scripted):
img = decode_fun(encoded_bytes)
assert img.shape == (3, 100, 100)
assert img[None].is_contiguous(memory_format=torch.channels_last)
img += 123 # make sure image buffer wasn't freed by underlying decoding lib


# This test is skipped because it requires webp images that we're not including
# within the repo. The test images were downloaded from the different pages of
# https://developers.google.com/speed/webp/gallery
# Note that converting an RGBA image to RGB leads to bad results because the
# transparent pixels aren't necessarily set to "black" or "white", they can be
# random stuff. This is consistent with PIL results.
@pytest.mark.skip(reason="Need to download test images first")
# This test is skipped by default because it requires webp images that we're not
# including within the repo. The test images were downloaded manually from the
# different pages of https://developers.google.com/speed/webp/gallery
@pytest.mark.skipif(not WEBP_TEST_IMAGES_DIR, reason="WEBP_TEST_IMAGES_DIR is not set")
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
"mode, pil_mode", ((ImageReadMode.RGB, "RGB"), (ImageReadMode.RGB_ALPHA, "RGBA"), (ImageReadMode.UNCHANGED, None))
"mode, pil_mode",
(
# Note that converting an RGBA image to RGB leads to bad results because the
# transparent pixels aren't necessarily set to "black" or "white", they can be
# random stuff. This is consistent with PIL results.
(ImageReadMode.RGB, "RGB"),
(ImageReadMode.RGB_ALPHA, "RGBA"),
(ImageReadMode.UNCHANGED, None),
),
)
@pytest.mark.parametrize("filename", Path("/home/nicolashug/webp_samples").glob("*.webp"))
@pytest.mark.parametrize("filename", Path(WEBP_TEST_IMAGES_DIR).glob("*.webp"), ids=lambda p: p.name)
def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename):
encoded_bytes = read_file(filename)
if scripted:
@@ -915,9 +927,10 @@ def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename)
pil_img = Image.open(filename).convert(pil_mode)
from_pil = F.pil_to_tensor(pil_img)
assert_equal(img, from_pil)
img += 123 # make sure image buffer wasn't freed by underlying decoding lib
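
The mode parametrization above notes that converting RGBA to RGB gives "random stuff" for transparent pixels. A small standalone illustration of that caveat (not part of the commit), using PIL directly:

from PIL import Image

# A fully transparent pixel can still carry arbitrary RGB values.
rgba = Image.new("RGBA", (1, 1), (200, 10, 30, 0))
rgb = rgba.convert("RGB")  # alpha is simply dropped, no compositing happens
print(rgb.getpixel((0, 0)))  # (200, 10, 30): the "invisible" color survives conversion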


@pytest.mark.xfail(reason="AVIF support not enabled yet.")
@pytest.mark.skipif(not DECODE_AVIF_ENABLED, reason="AVIF support not enabled.")
@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_avif(decode_fun, scripted):
@@ -927,12 +940,20 @@ def test_decode_avif(decode_fun, scripted):
img = decode_fun(encoded_bytes)
assert img.shape == (3, 100, 100)
assert img[None].is_contiguous(memory_format=torch.channels_last)
img += 123 # make sure image buffer wasn't freed by underlying decoding lib


@pytest.mark.xfail(reason="AVIF and HEIC support not enabled yet.")
# Note: decode_image fails because some of these files have a (valid) signature
# we don't recognize. We should probably use libmagic....
@pytest.mark.parametrize("decode_fun", (_decode_avif, _decode_heic))
decode_funs = []
if DECODE_AVIF_ENABLED:
    decode_funs.append(_decode_avif)
if DECODE_HEIC_ENABLED:
    decode_funs.append(_decode_heic)


@pytest.mark.skipif(not decode_funs, reason="Built without avif and heic support.")
@pytest.mark.parametrize("decode_fun", decode_funs)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
"mode, pil_mode",
@@ -945,7 +966,7 @@ def test_decode_avif(decode_fun, scripted):
@pytest.mark.parametrize(
"filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"), ids=lambda p: p.name
)
def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename):
def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, filename):
if "reversed_dimg_order" in str(filename):
# Pillow properly decodes this one, but we don't (order of parts of the
# image is wrong). This is due to a bug that was recently fixed in
@@ -996,21 +1017,21 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
g = make_grid([img, from_pil])
F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png"))

is__decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic"
if mode == ImageReadMode.RGB and not is__decode_heic:
is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic"
if mode == ImageReadMode.RGB and not is_decode_heic:
# We don't compare torchvision's AVIF against PIL for RGB because
# results look pretty different on RGBA images (other images are fine).
# The result on torchvision basically just plainly ignores the alpha
# channel, resulting in transparent pixels looking dark. PIL seems to be
# using a sort of k-nn thing (Take a look at the resulting images)
return
if filename.name == "sofa_grid1x5_420.avif" and is__decode_heic:
if filename.name == "sofa_grid1x5_420.avif" and is_decode_heic:
return

torch.testing.assert_close(img, from_pil, rtol=0, atol=3)


@pytest.mark.xfail(reason="HEIC support not enabled yet.")
@pytest.mark.skipif(not DECODE_HEIC_ENABLED, reason="HEIC support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_heic(decode_fun, scripted):
@@ -1020,6 +1041,7 @@ def test_decode_heic(decode_fun, scripted):
img = decode_fun(encoded_bytes)
assert img.shape == (3, 100, 100)
assert img[None].is_contiguous(memory_format=torch.channels_last)
img += 123 # make sure image buffer wasn't freed by underlying decoding lib


if __name__ == "__main__":
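For reference, the capability probe added near the top of test_image.py reduces to the following standalone sketch: call a decoder on a few garbage bytes and decide from the error message whether torchvision was built with the corresponding native library. The helper name below is made up for illustration; in the test module the flags are simply set inline, and _decode_avif / _decode_heic are assumed to be imported as in that module.

import torch

def _is_decoder_enabled(decode_fun, missing_lib_msg):
    # Garbage input: any failure other than "not compiled with ... support"
    # means the native library is present (the bytes are just not decodable).
    try:
        decode_fun(torch.arange(10, dtype=torch.uint8))
    except Exception as e:
        return missing_lib_msg not in str(e)
    return True

# Mirrors the module-level flags in the diff above:
# DECODE_AVIF_ENABLED = _is_decoder_enabled(_decode_avif, "torchvision not compiled with libavif support")
# DECODE_HEIC_ENABLED = _is_decoder_enabled(_decode_heic, "torchvision not compiled with libheif support")
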
43 changes: 43 additions & 0 deletions torchvision/csrc/io/image/common.cpp
@@ -0,0 +1,43 @@

#include "common.h"
#include <torch/torch.h>

namespace vision {
namespace image {

void validate_encoded_data(const torch::Tensor& encoded_data) {
  TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
  TORCH_CHECK(
      encoded_data.dtype() == torch::kU8,
      "Input tensor must have uint8 data type, got ",
      encoded_data.dtype());
  TORCH_CHECK(
      encoded_data.dim() == 1 && encoded_data.numel() > 0,
      "Input tensor must be 1-dimensional and non-empty, got ",
      encoded_data.dim(),
      " dims and ",
      encoded_data.numel(),
      " numels.");
}

bool should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
    ImageReadMode mode,
    bool has_alpha) {
  // Return true if the calling decoding function should return a 3D RGB tensor,
  // and false if it should return a 4D RGBA tensor.
  // This function ignores the requested "grayscale" modes and treats them as
  // "unchanged", so it should only be used on decoders that don't support
  // grayscale outputs.

  if (mode == IMAGE_READ_MODE_RGB) {
    return true;
  }
  if (mode == IMAGE_READ_MODE_RGB_ALPHA) {
    return false;
  }
  // From here we assume mode is "unchanged", even for grayscale ones.
  return !has_alpha;
}

} // namespace image
} // namespace vision
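
For readers skimming the C++ above, an equivalent Python sketch (illustration only, not part of the commit) of the decision that should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video encodes: True means the calling decoder should produce a 3-channel RGB tensor, False means a 4-channel RGBA tensor.

from torchvision.io import ImageReadMode

def returns_rgb(mode, has_alpha):
    if mode == ImageReadMode.RGB:
        return True
    if mode == ImageReadMode.RGB_ALPHA:
        return False
    # Anything else, including the grayscale modes, is treated as UNCHANGED:
    # follow whatever the file itself contains.
    return not has_alpha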
torchvision/csrc/io/image/common.h
@@ -1,6 +1,7 @@
#pragma once

#include <stdint.h>
#include <torch/torch.h>

namespace vision {
namespace image {
@@ -13,5 +14,11 @@ const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2;
const ImageReadMode IMAGE_READ_MODE_RGB = 3;
const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4;

void validate_encoded_data(const torch::Tensor& encoded_data);

bool should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
ImageReadMode mode,
bool has_alpha);

} // namespace image
} // namespace vision
26 changes: 5 additions & 21 deletions torchvision/csrc/io/image/cpu/decode_avif.cpp
@@ -1,4 +1,5 @@
#include "decode_avif.h"
#include "../common.h"

#if AVIF_FOUND
#include "avif/avif.h"
@@ -33,16 +34,7 @@ torch::Tensor decode_avif(
// Refer there for more detail about what each function does, and which
// structure/data is available after which call.

TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
"Input tensor must have uint8 data type, got ",
encoded_data.dtype());
TORCH_CHECK(
encoded_data.dim() == 1,
"Input tensor must be 1-dimensional, got ",
encoded_data.dim(),
" dims.");
validate_encoded_data(encoded_data);

DecoderPtr decoder(avifDecoderCreate());
TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder.");
@@ -60,6 +52,7 @@ torch::Tensor decode_avif(
result == AVIF_RESULT_OK,
"avifDecoderParse failed: ",
avifResultToString(result));
printf("avif num images = %d\n", decoder->imageCount);
TORCH_CHECK(
decoder->imageCount == 1, "Avif file contains more than one image");

@@ -78,18 +71,9 @@ torch::Tensor decode_avif(
auto use_uint8 = (decoder->image->depth <= 8);
rgb.depth = use_uint8 ? 8 : 16;

if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB &&
mode != IMAGE_READ_MODE_RGB_ALPHA) {
// Other modes aren't supported, but we don't error or even warn because we
// have generic entry points like decode_image which may support all modes,
// it just depends on the underlying decoder.
mode = IMAGE_READ_MODE_UNCHANGED;
}

// If return_rgb is false it means we return rgba - nothing else.
auto return_rgb =
(mode == IMAGE_READ_MODE_RGB ||
(mode == IMAGE_READ_MODE_UNCHANGED && !decoder->alphaPresent));
should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
mode, decoder->alphaPresent);

auto num_channels = return_rgb ? 3 : 4;
rgb.format = return_rgb ? AVIF_RGB_FORMAT_RGB : AVIF_RGB_FORMAT_RGBA;
2 changes: 1 addition & 1 deletion torchvision/csrc/io/image/cpu/decode_avif.h
@@ -1,7 +1,7 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"
#include "../common.h"

namespace vision {
namespace image {
12 changes: 2 additions & 10 deletions torchvision/csrc/io/image/cpu/decode_gif.cpp
@@ -1,5 +1,6 @@
#include "decode_gif.h"
#include <cstring>
#include "../common.h"
#include "giflib/gif_lib.h"

namespace vision {
@@ -34,16 +35,7 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) {
// Refer over there for more details on the libgif API, API ref, and a
// detailed description of the GIF format.

TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
"Input tensor must have uint8 data type, got ",
encoded_data.dtype());
TORCH_CHECK(
encoded_data.dim() == 1,
"Input tensor must be 1-dimensional, got ",
encoded_data.dim(),
" dims.");
validate_encoded_data(encoded_data);

int error = D_GIF_SUCCEEDED;

25 changes: 4 additions & 21 deletions torchvision/csrc/io/image/cpu/decode_heic.cpp
@@ -1,4 +1,5 @@
#include "decode_heic.h"
#include "../common.h"

#if HEIC_FOUND
#include "libheif/heif_cxx.h"
@@ -19,26 +20,8 @@ torch::Tensor decode_heic(
torch::Tensor decode_heic(
const torch::Tensor& encoded_data,
ImageReadMode mode) {
TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
"Input tensor must have uint8 data type, got ",
encoded_data.dtype());
TORCH_CHECK(
encoded_data.dim() == 1,
"Input tensor must be 1-dimensional, got ",
encoded_data.dim(),
" dims.");

if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB &&
mode != IMAGE_READ_MODE_RGB_ALPHA) {
// Other modes aren't supported, but we don't error or even warn because we
// have generic entry points like decode_image which may support all modes,
// it just depends on the underlying decoder.
mode = IMAGE_READ_MODE_UNCHANGED;
}
validate_encoded_data(encoded_data);

// If return_rgb is false it means we return rgba - nothing else.
auto return_rgb = true;

int height = 0;
@@ -82,8 +65,8 @@ torch::Tensor decode_heic(
bit_depth = handle.get_luma_bits_per_pixel();

return_rgb =
(mode == IMAGE_READ_MODE_RGB ||
(mode == IMAGE_READ_MODE_UNCHANGED && !handle.has_alpha_channel()));
should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
mode, handle.has_alpha_channel());

height = handle.get_height();
width = handle.get_width();