Skip to content

Commit

Permalink
Add GIF decoder (#8406)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored May 8, 2024
1 parent 1644fff commit e4d2d1a
Show file tree
Hide file tree
Showing 22 changed files with 2,644 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/unittest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ eval "$($(which conda) shell.bash hook)" && conda deactivate && conda activate c

echo '::group::Install testing utilities'
# TODO: remove the <8 constraint on pytest when https://github.com/pytorch/vision/issues/8238 is closed
pip install --progress-bar=off "pytest<8" pytest-mock pytest-cov expecttest!=0.2.0
pip install --progress-bar=off "pytest<8" pytest-mock pytest-cov expecttest!=0.2.0 requests
echo '::endgroup::'

python test/smoke_test.py
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ jobs:
echo '::group::Lint C source'
set +e
./.github/scripts/run-clang-format.py -r torchvision/csrc --clang-format-executable ./clang-format
./.github/scripts/run-clang-format.py -r torchvision/csrc --clang-format-executable ./clang-format --exclude "torchvision/csrc/io/image/cpu/giflib/*"
if [ $? -ne 0 ]; then
git --no-pager diff
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ include(GNUInstallDirs)
include(CMakePackageConfigHelpers)

set(TVCPP torchvision/csrc)
list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops
list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/io/image/cpu/giflib ${TVCPP}/models ${TVCPP}/ops
${TVCPP}/ops/autograd ${TVCPP}/ops/cpu ${TVCPP}/io/image/cuda)
if(WITH_CUDA)
list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ We don't officially support building from source using `pip`, but _if_ you do, y
#### Other development dependencies (some of these are needed to run tests):

```
pip install expecttest flake8 typing mypy pytest pytest-mock scipy
pip install expecttest flake8 typing mypy pytest pytest-mock scipy requests
```

## Development Process
Expand Down
1 change: 1 addition & 0 deletions docs/source/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Images
encode_jpeg
decode_jpeg
write_jpeg
decode_gif
encode_png
decode_png
write_png
Expand Down
27 changes: 15 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,11 @@ def get_extensions():
image_macros += [("NVJPEG_FOUND", str(int(use_nvjpeg)))]

image_path = os.path.join(extensions_dir, "io", "image")
image_src = glob.glob(os.path.join(image_path, "*.cpp")) + glob.glob(os.path.join(image_path, "cpu", "*.cpp"))
image_src = (
glob.glob(os.path.join(image_path, "*.cpp"))
+ glob.glob(os.path.join(image_path, "cpu", "*.cpp"))
+ glob.glob(os.path.join(image_path, "cpu", "giflib", "*.c"))
)

if is_rocm_pytorch:
image_src += glob.glob(os.path.join(image_path, "hip", "*.cpp"))
Expand All @@ -341,18 +345,17 @@ def get_extensions():
else:
image_src += glob.glob(os.path.join(image_path, "cuda", "*.cpp"))

if use_png or use_jpeg:
ext_modules.append(
extension(
"torchvision.image",
image_src,
include_dirs=image_include + include_dirs + [image_path],
library_dirs=image_library + library_dirs,
define_macros=image_macros,
libraries=image_link_flags,
extra_compile_args=extra_compile_args,
)
ext_modules.append(
extension(
"torchvision.image",
image_src,
include_dirs=image_include + include_dirs + [image_path],
library_dirs=image_library + library_dirs,
define_macros=image_macros,
libraries=image_link_flags,
extra_compile_args=extra_compile_args,
)
)

# Locating ffmpeg
ffmpeg_exe = shutil.which("ffmpeg")
Expand Down
46 changes: 45 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import glob
import io
import os
import re
import sys
from pathlib import Path

import numpy as np
import pytest
import requests
import torch
import torchvision.transforms.functional as F
from common_utils import assert_equal, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
_read_png_16,
decode_gif,
decode_image,
decode_jpeg,
decode_png,
Expand Down Expand Up @@ -548,5 +551,46 @@ def test_pathlib_support(tmpdir):
write_png(img, write_path)


@pytest.mark.parametrize("name", ("gifgrid", "fire", "porsche", "treescap", "treescap-interlaced", "solid2", "x-trans"))
def test_decode_gif(tmpdir, name):
    """Check that torchvision and PIL decode the GIFLIB sample images identically.

    The sample images are downloaded from the GIFLIB repository
    (https://sourceforge.net/p/giflib/code/ci/master/tree/pic/) into ``tmpdir``
    and every frame decoded by torchvision is compared to the matching PIL
    frame with zero tolerance.

    We're not testing against "welcome2" because PIL and GIFLIB disagree on
    what the background color should be (likely a difference in the way they
    handle transparency?)
    """
    path = tmpdir / f"{name}.gif"
    url = f"https://sourceforge.net/p/giflib/code/ci/master/tree/pic/{name}.gif?format=raw"

    # Bound the request so that an unresponsive server fails the test instead
    # of hanging the whole test run.
    response = requests.get(url, timeout=60)
    # Fail right here on an HTTP error rather than saving an error page as a
    # ".gif" and getting a confusing decoding failure further down.
    response.raise_for_status()
    with open(path, "wb") as f:
        f.write(response.content)

    tv_out = read_image(path)
    if tv_out.ndim == 3:
        # Single-image GIFs come back without a batch dim: add it back so that
        # the comparison loop below handles both cases.
        tv_out = tv_out[None]

    assert tv_out.is_contiguous(memory_format=torch.channels_last)

    # For some reason, not using Image.open() as a CM causes "ResourceWarning: unclosed file"
    with Image.open(path) as pil_img:
        pil_seq = ImageSequence.Iterator(pil_img)

        # NOTE(review): zip() stops at the shorter of the two sequences, so a
        # mismatch in the number of decoded frames would go unnoticed here.
        for pil_frame, tv_frame in zip(pil_seq, tv_out):
            pil_frame = F.pil_to_tensor(pil_frame.convert("RGB"))
            torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0)


def test_decode_gif_errors():
    """Check that decode_gif raises RuntimeError with the expected message on bad inputs."""
    random_bytes = torch.randint(0, 256, (100,), dtype=torch.uint8)

    # Each entry pairs an invalid input with the pattern that the message of
    # the raised RuntimeError has to match.
    invalid_cases = (
        (random_bytes[None], "Input tensor must be 1-dimensional"),
        (random_bytes.float(), "Input tensor must have uint8 data type"),
        (random_bytes[::2], "Input tensor must be contiguous"),
        (random_bytes, re.escape("DGifOpenFileName() failed - 103")),
    )

    for bad_input, expected_pattern in invalid_cases:
        with pytest.raises(RuntimeError, match=expected_pattern):
            decode_gif(bad_input)


if __name__ == "__main__":
pytest.main([__file__])
157 changes: 157 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_gif.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#include "decode_gif.h"
#include <algorithm>
#include <cstring>
#include "giflib/gif_lib.h"

namespace vision {
namespace image {

// Read cursor over the encoded input tensor. A pointer to one of these is
// handed to GIFLIB as user data so that the read callback knows where the
// encoded bytes live and how many of them were already consumed.
struct reader_helper_t {
  const uint8_t* encoded_data; // first byte of the input tensor
  size_t encoded_data_size; // total number of bytes in the input tensor
  size_t num_bytes_read; // number of bytes consumed so far
};

// Read callback used by GIFLIB routines to get the encoded bytes.
// Copies at most `len` bytes into `buf`. The data is read from the input
// tensor passed to decode_gif(), starting at the `num_bytes_read` position,
// which is then advanced by the number of bytes copied.
// Returns the number of bytes actually copied: this is smaller than `len`
// (possibly 0) when fewer than `len` bytes are left in the tensor.
int read_from_tensor(GifFileType* gifFile, GifByteType* buf, int len) {
  // the UserData field was set in DGifOpen()
  reader_helper_t* reader_helper =
      static_cast<reader_helper_t*>(gifFile->UserData);

  if (len <= 0) {
    // Nothing was requested. This also keeps a negative length from being
    // turned into a huge size_t below.
    return 0;
  }

  // Never read more than what is left in the tensor.
  size_t num_bytes_to_read = std::min(
      static_cast<size_t>(len),
      reader_helper->encoded_data_size - reader_helper->num_bytes_read);
  if (num_bytes_to_read == 0) {
    // Input exhausted (or empty): nothing to copy.
    return 0;
  }

  // Copy only the clamped amount: copying `len` bytes here would read past
  // the end of the tensor on truncated or malformed inputs.
  std::memcpy(
      buf,
      reader_helper->encoded_data + reader_helper->num_bytes_read,
      num_bytes_to_read);
  reader_helper->num_bytes_read += num_bytes_to_read;
  return static_cast<int>(num_bytes_to_read);
}

torch::Tensor decode_gif(const torch::Tensor& encoded_data) {
// LibGif docs: https://giflib.sourceforge.net/intro.html
// Refer over there for more details on the libgif API, API ref, and a
// detailed description of the GIF format.

TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
"Input tensor must have uint8 data type, got ",
encoded_data.dtype());
TORCH_CHECK(
encoded_data.dim() == 1,
"Input tensor must be 1-dimensional, got ",
encoded_data.dim(),
" dims.");

int error = D_GIF_SUCCEEDED;

// We're using DGidOpen. The other entrypoints of libgif are
// DGifOpenFileName and DGifOpenFileHandle but we don't want to use those,
// since we need to read the encoded bytes from a tensor of encoded bytes, not
// from a file (for consistency with existing jpeg and png decoders). Using
// DGifOpen is the only way to read from a custom source.
// For that we need to provide a reader function `read_from_tensor` that
// reads from the tensor, and we have to keep track of the number of bytes
// read so far: this is why we need the reader_helper struct.

// TODO: We are potentially doing an unnecessary copy of the encoded bytes:
// - 1 copy in from file to tensor (in read_file())
// - 1 copy from tensor to GIFLIB buffers (in read_from_tensor())
// Since we're vendoring GIFLIB we can potentially modify the calls to
// InternalRead() and just set the `buf` pointer to the tensor data directly.
// That might even save allocation of those buffers.
// If we do that, we'd have to make sure the buffers are never written to by
// GIFLIB, otherwise we'd be overridding the tensor data.
reader_helper_t reader_helper;
reader_helper.encoded_data = encoded_data.data_ptr<uint8_t>();
reader_helper.encoded_data_size = encoded_data.numel();
reader_helper.num_bytes_read = 0;
GifFileType* gifFile =
DGifOpen(static_cast<void*>(&reader_helper), read_from_tensor, &error);

TORCH_CHECK(
(gifFile != nullptr) && (error == D_GIF_SUCCEEDED),
"DGifOpenFileName() failed - ",
error);

if (DGifSlurp(gifFile) == GIF_ERROR) {
auto gifFileError = gifFile->Error;
DGifCloseFile(gifFile, &error);
TORCH_CHECK(false, "DGifSlurp() failed - ", gifFileError);
}
auto num_images = gifFile->ImageCount;

// This check should already done within DGifSlurp(), just to be safe
TORCH_CHECK(num_images > 0, "GIF file should contain at least one image!");

// Note:
// The GIF format has this notion of "canvas" and "canvas size", where each
// image could be displayed on the canvas at different offsets, forming a
// mosaic/picture wall like so:
//
// <--- canvas W --->
// ------------------------ ^
// | | | |
// | img1 | img3 | |
// | |------------| canvas H
// |---------- | |
// | img2 | img4 | |
// | | | |
// ------------------------ v
// The GifLib docs indicate that this is mostly vestigial
// (https://giflib.sourceforge.net/whatsinagif/bits_and_bytes.html), and
// modern viewers ignore the canvas size as well as image offsets. Hence,
// we're ignoring that too:
// - We're ignoring the canvas width and height and assume that the shape of
// the canvas and of all images is the shape of the first image.
// - We're enforcing that all images have the same shape.
// - Left and Top offsets of each image are ignored as well and assumed to be
// 0.

auto out_h = gifFile->SavedImages[0].ImageDesc.Height;
auto out_w = gifFile->SavedImages[0].ImageDesc.Width;

// We output a channels-last tensor for consistency with other image decoders.
// Torchvision's resize tends to be is faster on uint8 channels-last tensors.
auto options = torch::TensorOptions()
.dtype(torch::kU8)
.memory_format(torch::MemoryFormat::ChannelsLast);
auto out = torch::empty(
{int64_t(num_images), 3, int64_t(out_h), int64_t(out_w)}, options);
auto out_a = out.accessor<uint8_t, 4>();

for (int i = 0; i < num_images; i++) {
const SavedImage& img = gifFile->SavedImages[i];
const GifImageDesc& desc = img.ImageDesc;
TORCH_CHECK(
desc.Width == out_w && desc.Height == out_h,
"All images in the gif should have the same dimensions.");

const ColorMapObject* cmap =
desc.ColorMap ? desc.ColorMap : gifFile->SColorMap;
TORCH_CHECK(
cmap != nullptr,
"Global and local color maps are missing. This should never happen!");

for (int h = 0; h < desc.Height; h++) {
for (int w = 0; w < desc.Width; w++) {
auto c = img.RasterBits[h * desc.Width + w];
GifColorType rgb = cmap->Colors[c];
out_a[i][0][h][w] = rgb.Red;
out_a[i][1][h][w] = rgb.Green;
out_a[i][2][h][w] = rgb.Blue;
}
}
}
out = out.squeeze(0); // remove batch dim if there's only one image

DGifCloseFile(gifFile, &error);
TORCH_CHECK(error == D_GIF_SUCCEEDED, "DGifCloseFile() failed - ", error);

return out;
}

} // namespace image
} // namespace vision
12 changes: 12 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_gif.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include <torch/types.h>

namespace vision {
namespace image {

// Decodes a GIF from its encoded bytes.
// encoded_data tensor must be 1D uint8 and contiguous (this is checked, and
// an error is raised otherwise).
// Returns a uint8 tensor of shape (num_images, 3, H, W), allocated as
// channels-last, with the leading dim squeezed out when the GIF contains a
// single image.
C10_EXPORT torch::Tensor decode_gif(const torch::Tensor& encoded_data);

} // namespace image
} // namespace vision
11 changes: 10 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "decode_image.h"

#include "decode_gif.h"
#include "decode_jpeg.h"
#include "decode_png.h"

Expand All @@ -23,16 +24,24 @@ torch::Tensor decode_image(

const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
const uint8_t gif_signature_1[6] = {
0x47, 0x49, 0x46, 0x38, 0x39, 0x61}; // == "GIF89a"
const uint8_t gif_signature_2[6] = {
0x47, 0x49, 0x46, 0x38, 0x37, 0x61}; // == "GIF87a"

if (memcmp(jpeg_signature, datap, 3) == 0) {
return decode_jpeg(data, mode, apply_exif_orientation);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decode_png(
data, mode, /*allow_16_bits=*/false, apply_exif_orientation);
} else if (
memcmp(gif_signature_1, datap, 6) == 0 ||
memcmp(gif_signature_2, datap, 6) == 0) {
return decode_gif(data);
} else {
TORCH_CHECK(
false,
"Unsupported image file. Only jpeg and png ",
"Unsupported image file. Only jpeg, png and gif ",
"are currently supported.");
}
}
Expand Down
28 changes: 28 additions & 0 deletions torchvision/csrc/io/image/cpu/giflib/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
These files come from the GIFLIB project (https://giflib.sourceforge.net/) and
are licensed under the MIT license.

Some modifications have been made to the original files:
- Remove use of "register" keyword in gifalloc.c for C++17 compatibility.
- Declare loop variable i in DGifGetImageHeader as int instead of unsigned int.

Below is the original license text from the COPYING file of the GIFLIB project:

= MIT LICENSE

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
Loading

0 comments on commit e4d2d1a

Please sign in to comment.