diff --git a/.github/scripts/export_IS_M1_CONDA_BUILD_JOB.sh b/.github/scripts/export_IS_M1_CONDA_BUILD_JOB.sh new file mode 100755 index 00000000000..1cca56ddc56 --- /dev/null +++ b/.github/scripts/export_IS_M1_CONDA_BUILD_JOB.sh @@ -0,0 +1,3 @@ +#!/bin/sh + +export IS_M1_CONDA_BUILD_JOB=1 diff --git a/.github/scripts/setup-env.sh b/.github/scripts/setup-env.sh index cf059f97a44..adb1256303f 100755 --- a/.github/scripts/setup-env.sh +++ b/.github/scripts/setup-env.sh @@ -30,6 +30,7 @@ conda create \ python="${PYTHON_VERSION}" pip \ ninja cmake \ libpng \ + libwebp \ 'ffmpeg<4.3' conda activate ci conda install --quiet --yes libjpeg-turbo -c pytorch diff --git a/.github/workflows/build-conda-m1.yml b/.github/workflows/build-conda-m1.yml index bd10e4e4634..e8f6546a678 100644 --- a/.github/workflows/build-conda-m1.yml +++ b/.github/workflows/build-conda-m1.yml @@ -42,6 +42,7 @@ jobs: test-infra-repository: pytorch/test-infra test-infra-ref: main build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + env-var-script: ./.github/scripts/export_IS_M1_CONDA_BUILD_JOB.sh pre-script: ${{ matrix.pre-script }} post-script: ${{ matrix.post-script }} package-name: ${{ matrix.package-name }} diff --git a/CMakeLists.txt b/CMakeLists.txt index 2db9c1e274a..cf305c4ec17 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,10 @@ option(WITH_CUDA "Enable CUDA support" OFF) option(WITH_MPS "Enable MPS support" OFF) option(WITH_PNG "Enable features requiring LibPNG." ON) option(WITH_JPEG "Enable features requiring LibJPEG." ON) +# Libwebp is disabled by default, which means enabling it from cmake is largely +# untested. Since building from cmake is very low pri anyway, this is OK. If +# you're a user and you need this, please open an issue (and a PR!). +option(WITH_WEBP "Enable features requiring LibWEBP." OFF) if(WITH_CUDA) enable_language(CUDA) @@ -32,6 +36,11 @@ if (WITH_JPEG) find_package(JPEG REQUIRED) endif() +if (WITH_WEBP) + add_definitions(-DWEBP_FOUND) + find_package(WEBP REQUIRED) +endif() + function(CUDA_CONVERT_FLAGS EXISTING_TARGET) get_property(old_flags TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS) if(NOT "${old_flags}" STREQUAL "") @@ -104,6 +113,10 @@ if (WITH_JPEG) target_link_libraries(${PROJECT_NAME} PRIVATE ${JPEG_LIBRARIES}) endif() +if (WITH_WEBP) + target_link_libraries(${PROJECT_NAME} PRIVATE ${WEBP_LIBRARIES}) +endif() + set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchVision INSTALL_RPATH ${TORCH_INSTALL_PREFIX}/lib) @@ -118,6 +131,10 @@ if (WITH_JPEG) include_directories(${JPEG_INCLUDE_DIRS}) endif() +if (WITH_WEBP) + include_directories(${WEBP_INCLUDE_DIRS}) +endif() + set(TORCHVISION_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchVision" CACHE STRING "install path for TorchVisionConfig.cmake") configure_package_config_file(cmake/TorchVisionConfig.cmake.in diff --git a/docs/source/io.rst b/docs/source/io.rst index f8258713163..638f310bf69 100644 --- a/docs/source/io.rst +++ b/docs/source/io.rst @@ -10,6 +10,11 @@ videos. Images ------ +Torchvision currently supports decoding JPEG, PNG, WEBP and GIF images. JPEG +decoding can also be done on CUDA GPUs. + +For encoding, JPEG (cpu and CUDA) and PNG are supported. + .. autosummary:: :toctree: generated/ :template: function.rst @@ -20,6 +25,7 @@ Images decode_jpeg write_jpeg decode_gif + decode_webp encode_png decode_png write_png diff --git a/packaging/pre_build_script.sh b/packaging/pre_build_script.sh index 9cd1518d525..9b1f93b5abe 100644 --- a/packaging/pre_build_script.sh +++ b/packaging/pre_build_script.sh @@ -1,4 +1,5 @@ #!/bin/bash + if [[ "$(uname)" == Darwin ]]; then # Uninstall Conflicting jpeg brew formulae jpeg_packages=$(brew list | grep jpeg) @@ -12,8 +13,10 @@ if [[ "$(uname)" == Darwin ]]; then fi if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then - # Install libpng from Anaconda (defaults) - conda install libpng -yq + conda install libpng libwebp -yq + # Installing webp also installs a non-turbo jpeg, so we uninstall jpeg stuff + # before re-installing them + conda uninstall libjpeg-turbo libjpeg -y conda install -yq ffmpeg=4.2 libjpeg-turbo -c pytorch # Copy binaries to be included in the wheel distribution @@ -29,7 +32,7 @@ else conda install -yq ffmpeg=4.2 libjpeg-turbo -c pytorch-nightly fi - yum install -y libjpeg-turbo-devel freetype gnutls + yum install -y libjpeg-turbo-devel libwebp-devel freetype gnutls pip install auditwheel fi diff --git a/packaging/torchvision/meta.yaml b/packaging/torchvision/meta.yaml index 78ac930f8e5..a847328a77e 100644 --- a/packaging/torchvision/meta.yaml +++ b/packaging/torchvision/meta.yaml @@ -11,6 +11,7 @@ requirements: - {{ compiler('c') }} # [win] - libpng - libjpeg-turbo + - libwebp - ffmpeg >=4.2.2, <5.0.0 # [linux] host: @@ -28,6 +29,7 @@ requirements: - libpng - ffmpeg >=4.2.2, <5.0.0 # [linux] - libjpeg-turbo + - libwebp - pillow >=5.3.0, !=8.3.* - pytorch-mutex 1.0 {{ build_variant }} # [not osx ] {{ environ.get('CONDA_PYTORCH_CONSTRAINT', 'pytorch') }} diff --git a/setup.py b/setup.py index 1dc863b040a..fb3b503e6e6 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ DEBUG = os.getenv("DEBUG", "0") == "1" USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1" USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" +USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) USE_FFMPEG = os.getenv("TORCHVISION_USE_FFMPEG", "1") == "1" @@ -41,6 +42,7 @@ print(f"{DEBUG = }") print(f"{USE_PNG = }") print(f"{USE_JPEG = }") +print(f"{USE_WEBP = }") print(f"{USE_NVJPEG = }") print(f"{NVCC_FLAGS = }") print(f"{USE_FFMPEG = }") @@ -308,6 +310,22 @@ def make_image_extension(): else: warnings.warn("Building torchvision without JPEG support") + if USE_WEBP: + webp_found, webp_include_dir, webp_library_dir = find_library(header="webp/decode.h") + if webp_found: + print("Building torchvision with WEBP support") + print(f"{webp_include_dir = }") + print(f"{webp_library_dir = }") + if webp_include_dir is not None and webp_library_dir is not None: + # if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add. + include_dirs.append(webp_include_dir) + library_dirs.append(webp_library_dir) + webp_library = "libwebp" if sys.platform == "win32" else "webp" + libraries.append(webp_library) + define_macros += [("WEBP_FOUND", 1)] + else: + warnings.warn("Building torchvision without WEBP support") + if USE_NVJPEG and torch.cuda.is_available(): nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() diff --git a/test/assets/fakedata/logos/rgb_pytorch.webp b/test/assets/fakedata/logos/rgb_pytorch.webp new file mode 100644 index 00000000000..e594584d76d Binary files /dev/null and b/test/assets/fakedata/logos/rgb_pytorch.webp differ diff --git a/test/smoke_test.py b/test/smoke_test.py index 464fabee935..f98d019bea5 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -1,13 +1,15 @@ """Run smoke tests""" +import os import sys from pathlib import Path import torch import torchvision -from torchvision.io import decode_jpeg, read_file, read_image +from torchvision.io import decode_jpeg, decode_webp, read_file, read_image from torchvision.models import resnet50, ResNet50_Weights + SCRIPT_DIR = Path(__file__).parent @@ -25,6 +27,9 @@ def smoke_test_torchvision_read_decode() -> None: img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png")) if img_png.shape != (4, 471, 354): raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}") + img_webp = read_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp")) + if img_webp.shape != (3, 100, 100): + raise RuntimeError(f"Unexpected shape of img_webp: {img_webp.shape}") def smoke_test_torchvision_decode_jpeg(device: str = "cpu"): @@ -77,11 +82,16 @@ def main() -> None: print(f"torchvision: {torchvision.__version__}") print(f"torch.cuda.is_available: {torch.cuda.is_available()}") - # Turn 1.11.0aHASH into 1.11 (major.minor only) - version = ".".join(torchvision.__version__.split(".")[:2]) - if version >= "0.16": - print(f"{torch.ops.image._jpeg_version() = }") - assert torch.ops.image._is_compiled_against_turbo() + print(f"{torch.ops.image._jpeg_version() = }") + if not torch.ops.image._is_compiled_against_turbo(): + msg = "Torchvision wasn't compiled against libjpeg-turbo" + if os.getenv("IS_M1_CONDA_BUILD_JOB") == "1": + # When building the conda package on M1, it's difficult to enforce + # that we build against turbo due to interactions with the libwebp + # package. So we just accept it, instead of raising an error. + print(msg) + else: + raise ValueError(msg) smoke_test_torchvision() smoke_test_torchvision_read_decode() diff --git a/test/test_image.py b/test/test_image.py index f083e53b87b..cce7d6e0ff7 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -18,6 +18,7 @@ decode_image, decode_jpeg, decode_png, + decode_webp, encode_jpeg, encode_png, ImageReadMode, @@ -861,16 +862,32 @@ def test_decode_gif(tmpdir, name, scripted): torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0) -def test_decode_gif_errors(): +@pytest.mark.parametrize("decode_fun", (decode_gif, decode_webp)) +def test_decode_gif_webp_errors(decode_fun): encoded_data = torch.randint(0, 256, (100,), dtype=torch.uint8) with pytest.raises(RuntimeError, match="Input tensor must be 1-dimensional"): - decode_gif(encoded_data[None]) + decode_fun(encoded_data[None]) with pytest.raises(RuntimeError, match="Input tensor must have uint8 data type"): - decode_gif(encoded_data.float()) + decode_fun(encoded_data.float()) with pytest.raises(RuntimeError, match="Input tensor must be contiguous"): - decode_gif(encoded_data[::2]) - with pytest.raises(RuntimeError, match=re.escape("DGifOpenFileName() failed - 103")): - decode_gif(encoded_data) + decode_fun(encoded_data[::2]) + if decode_fun is decode_gif: + expected_match = re.escape("DGifOpenFileName() failed - 103") + else: + expected_match = "WebPDecodeRGB failed." + with pytest.raises(RuntimeError, match=expected_match): + decode_fun(encoded_data) + + +@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image)) +@pytest.mark.parametrize("scripted", (False, True)) +def test_decode_webp(decode_fun, scripted): + encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".webp"))) + if scripted: + decode_fun = torch.jit.script(decode_fun) + img = decode_fun(encoded_bytes) + assert img.shape == (3, 100, 100) + assert img[None].is_contiguous(memory_format=torch.channels_last) if __name__ == "__main__": diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp b/torchvision/csrc/io/image/cpu/decode_image.cpp index 3a18406042e..ed527a44b31 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.cpp +++ b/torchvision/csrc/io/image/cpu/decode_image.cpp @@ -3,6 +3,7 @@ #include "decode_gif.h" #include "decode_jpeg.h" #include "decode_png.h" +#include "decode_webp.h" namespace vision { namespace image { @@ -20,29 +21,43 @@ torch::Tensor decode_image( data.dim() == 1 && data.numel() > 0, "Expected a non empty 1-dimensional tensor"); + auto err_msg = + "Unsupported image file. Only jpeg, png and gif are currently supported."; + auto datap = data.data_ptr(); const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF" + TORCH_CHECK(data.numel() >= 3, err_msg); + if (memcmp(jpeg_signature, datap, 3) == 0) { + return decode_jpeg(data, mode, apply_exif_orientation); + } + const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG" + TORCH_CHECK(data.numel() >= 4, err_msg); + if (memcmp(png_signature, datap, 4) == 0) { + return decode_png(data, mode, apply_exif_orientation); + } + const uint8_t gif_signature_1[6] = { 0x47, 0x49, 0x46, 0x38, 0x39, 0x61}; // == "GIF89a" const uint8_t gif_signature_2[6] = { 0x47, 0x49, 0x46, 0x38, 0x37, 0x61}; // == "GIF87a" - - if (memcmp(jpeg_signature, datap, 3) == 0) { - return decode_jpeg(data, mode, apply_exif_orientation); - } else if (memcmp(png_signature, datap, 4) == 0) { - return decode_png(data, mode, apply_exif_orientation); - } else if ( - memcmp(gif_signature_1, datap, 6) == 0 || + TORCH_CHECK(data.numel() >= 6, err_msg); + if (memcmp(gif_signature_1, datap, 6) == 0 || memcmp(gif_signature_2, datap, 6) == 0) { return decode_gif(data); - } else { - TORCH_CHECK( - false, - "Unsupported image file. Only jpeg, png and gif ", - "are currently supported."); } + + const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF" + const uint8_t webp_signature_end[7] = { + 0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8" + TORCH_CHECK(data.numel() >= 15, err_msg); + if ((memcmp(webp_signature_begin, datap, 4) == 0) && + (memcmp(webp_signature_end, datap + 8, 7) == 0)) { + return decode_webp(data); + } + + TORCH_CHECK(false, err_msg); } } // namespace image diff --git a/torchvision/csrc/io/image/cpu/decode_webp.cpp b/torchvision/csrc/io/image/cpu/decode_webp.cpp new file mode 100644 index 00000000000..844ce61a3e3 --- /dev/null +++ b/torchvision/csrc/io/image/cpu/decode_webp.cpp @@ -0,0 +1,40 @@ +#include "decode_webp.h" + +#if WEBP_FOUND +#include "webp/decode.h" +#endif // WEBP_FOUND + +namespace vision { +namespace image { + +#if !WEBP_FOUND +torch::Tensor decode_webp(const torch::Tensor& data) { + TORCH_CHECK( + false, "decode_webp: torchvision not compiled with libwebp support"); +} +#else + +torch::Tensor decode_webp(const torch::Tensor& encoded_data) { + TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); + TORCH_CHECK( + encoded_data.dtype() == torch::kU8, + "Input tensor must have uint8 data type, got ", + encoded_data.dtype()); + TORCH_CHECK( + encoded_data.dim() == 1, + "Input tensor must be 1-dimensional, got ", + encoded_data.dim(), + " dims."); + + int width = 0; + int height = 0; + auto decoded_data = WebPDecodeRGB( + encoded_data.data_ptr(), encoded_data.numel(), &width, &height); + TORCH_CHECK(decoded_data != nullptr, "WebPDecodeRGB failed."); + auto out = torch::from_blob(decoded_data, {height, width, 3}, torch::kUInt8); + return out.permute({2, 0, 1}); // return CHW, channels-last +} +#endif // WEBP_FOUND + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_webp.h b/torchvision/csrc/io/image/cpu/decode_webp.h new file mode 100644 index 00000000000..00a0c3362f7 --- /dev/null +++ b/torchvision/csrc/io/image/cpu/decode_webp.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_webp(const torch::Tensor& data); + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index 9f7563eebf8..8ca2f814996 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -21,6 +21,7 @@ static auto registry = .op("image::encode_png", &encode_png) .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor", &decode_jpeg) + .op("image::decode_webp", &decode_webp) .op("image::encode_jpeg", &encode_jpeg) .op("image::read_file", &read_file) .op("image::write_file", &write_file) diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index f7e9b63801c..3f47fdec65c 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -4,6 +4,7 @@ #include "cpu/decode_image.h" #include "cpu/decode_jpeg.h" #include "cpu/decode_png.h" +#include "cpu/decode_webp.h" #include "cpu/encode_jpeg.h" #include "cpu/encode_png.h" #include "cpu/read_write_file.h" diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index f38ce687a0d..780b6ab333e 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -25,6 +25,7 @@ decode_image, decode_jpeg, decode_png, + decode_webp, encode_jpeg, encode_png, ImageReadMode, diff --git a/torchvision/io/image.py b/torchvision/io/image.py index eec073ce55e..3414e280e68 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -78,7 +78,7 @@ def decode_png( The values of the output tensor are in uint8 in [0, 255] for most cases. If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535] - (supported from torchvision ``0.21``. Since uint16 support is limited in + (supported from torchvision ``0.21``). Since uint16 support is limited in pytorch, we recommend calling :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True`` after this function to convert the decoded image into a uint8 or float @@ -277,12 +277,13 @@ def decode_image( apply_exif_orientation: bool = False, ) -> torch.Tensor: """ - Detect whether an image is a JPEG, PNG or GIF and performs the appropriate - operation to decode the image into a 3 dimensional RGB or grayscale Tensor. + Detect whether an image is a JPEG, PNG, WEBP, or GIF and performs the + appropriate operation to decode the image into a Tensor. - The values of the output tensor are in uint8 in [0, 255] for most cases. If - the image is a 16-bit png, then the output tensor is uint16 in [0, 65535] - (supported from torchvision ``0.21``. Since uint16 support is limited in + The values of the output tensor are in uint8 in [0, 255] for most cases. + + If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535] + (supported from torchvision ``0.21``). Since uint16 support is limited in pytorch, we recommend calling :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True`` after this function to convert the decoded image into a uint8 or float @@ -290,13 +291,13 @@ def decode_image( Args: input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the - PNG or JPEG image. + image. mode (ImageReadMode): the read mode used for optionally converting the image. Default: ``ImageReadMode.UNCHANGED``. See ``ImageReadMode`` class for more information on various - available modes. Ignored for GIFs. + available modes. Only applies to JPEG and PNG images. apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor. - Ignored for GIFs. Default: False. + Only applies to JPEG and PNG images. Default: False. Returns: output (Tensor[image_channels, image_height, image_width]) @@ -313,10 +314,11 @@ def read_image( apply_exif_orientation: bool = False, ) -> torch.Tensor: """ - Reads a JPEG, PNG or GIF image into a 3 dimensional RGB or grayscale Tensor. + Reads a JPEG, PNG, WEBP, or GIF image into a Tensor. - The values of the output tensor are in uint8 in [0, 255] for most cases. If - the image is a 16-bit png, then the output tensor is uint16 in [0, 65535] + The values of the output tensor are in uint8 in [0, 255] for most cases. + + If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535] (supported from torchvision ``0.21``. Since uint16 support is limited in pytorch, we recommend calling :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True`` @@ -324,13 +326,13 @@ def read_image( tensor. Args: - path (str or ``pathlib.Path``): path of the JPEG, PNG or GIF image. + path (str or ``pathlib.Path``): path of the image. mode (ImageReadMode): the read mode used for optionally converting the image. Default: ``ImageReadMode.UNCHANGED``. See ``ImageReadMode`` class for more information on various - available modes. Ignored for GIFs. + available modes. Only applies to JPEG and PNG images. apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor. - Ignored for GIFs. Default: False. + Only applies to JPEG and PNG images. Default: False. Returns: output (Tensor[image_channels, image_height, image_width]) @@ -359,3 +361,24 @@ def decode_gif(input: torch.Tensor) -> torch.Tensor: if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(decode_gif) return torch.ops.image.decode_gif(input) + + +def decode_webp( + input: torch.Tensor, +) -> torch.Tensor: + """ + Decode a WEBP image into a 3 dimensional RGB Tensor. + + The values of the output tensor are uint8 between 0 and 255. If the input + image is RGBA, the transparency is ignored. + + Args: + input (Tensor[1]): a one dimensional contiguous uint8 tensor containing + the raw bytes of the WEBP image. + + Returns: + Decoded image (Tensor[image_channels, image_height, image_width]) + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(decode_webp) + return torch.ops.image.decode_webp(input)