Support for decoding jpegs on GPU with nvjpeg (#3792)
Co-authored-by: James Thewlis <[email protected]>
NicolasHug and jamt9000 authored May 11, 2021
1 parent 4500208 commit f87ce88
Showing 9 changed files with 291 additions and 11 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -61,7 +61,7 @@ include(CMakePackageConfigHelpers)

set(TVCPP torchvision/csrc)
list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops
${TVCPP}/ops/autograd ${TVCPP}/ops/cpu)
${TVCPP}/ops/autograd ${TVCPP}/ops/cpu ${TVCPP}/io/image/cuda)
if(WITH_CUDA)
list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
endif()
17 changes: 16 additions & 1 deletion setup.py
@@ -315,8 +315,23 @@ def get_extensions():
image_library += [jpeg_lib]
image_include += [jpeg_include]

# Locating nvjpeg
# Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI
nvjpeg_found = (
extension is CUDAExtension and
CUDA_HOME is not None and
os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h'))
)

print('NVJPEG found: {0}'.format(nvjpeg_found))
image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))]
if nvjpeg_found:
print('Building torchvision with NVJPEG image support')
image_link_flags.append('nvjpeg')

image_path = os.path.join(extensions_dir, 'io', 'image')
image_src = glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp'))
image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp'))
+ glob.glob(os.path.join(image_path, 'cuda', '*.cpp')))

if png_found or jpeg_found:
ext_modules.append(extension(
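For reference, a minimal standalone sketch of the nvjpeg detection added to setup.py above. It only checks whether nvjpeg.h ships with the CUDA toolkit pointed to by CUDA_HOME (bundled for CUDA >= 10.1); the actual build script additionally requires that the extension being built is a CUDAExtension.

# Hedged sketch of the detection step above, not the build script itself.
import os
from torch.utils.cpp_extension import CUDA_HOME

nvjpeg_found = (
    CUDA_HOME is not None
    and os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h'))
)
print('NVJPEG found: {0}'.format(nvjpeg_found))
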
22 changes: 18 additions & 4 deletions test/common_utils.py
@@ -24,6 +24,9 @@
PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367"
PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG)
IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true'
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available'


@contextlib.contextmanager
@@ -407,11 +410,8 @@ def call_args_to_kwargs_only(call_args, *callable_or_arg_names):

def cpu_and_gpu():
import pytest # noqa
# ignore CPU tests in RE as they're already covered by another contbuild
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available'

# ignore CPU tests in RE as they're already covered by another contbuild
devices = [] if IN_RE_WORKER else ['cpu']

if torch.cuda.is_available():
@@ -427,3 +427,17 @@ def cpu_and_gpu():
devices.append(pytest.param('cuda', marks=cuda_marks))

return devices


def needs_cuda(test_func):
import pytest # noqa

if IN_FBCODE and not IN_RE_WORKER:
# We don't want to skip in fbcode, so we just don't collect
# TODO: a slightly more robust way would be to detect if we're in a sandcastle instance
# so that the test will still be collected (and skipped) in the devvms.
return pytest.mark.dont_collect(test_func)
elif torch.cuda.is_available():
return test_func
else:
return pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)(test_func)
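A hedged usage sketch of the helpers touched above, as they would appear in a pytest-based test module (test_image.py below applies needs_cuda this way); the test names and bodies here are hypothetical.

import pytest
import torch
from common_utils import cpu_and_gpu, needs_cuda


@pytest.mark.parametrize('device', cpu_and_gpu())
def test_on_cpu_and_gpu(device):
    # Runs once per collected device; the 'cuda' entry carries the
    # skip / dont_collect marks attached inside cpu_and_gpu().
    assert torch.zeros(1, device=device).numel() == 1


@needs_cuda
def test_cuda_only():
    # Skipped (or simply not collected, in fbcode) when CUDA is unavailable.
    assert torch.cuda.is_available()
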
41 changes: 40 additions & 1 deletion test/test_image.py
@@ -3,10 +3,11 @@
import os
import unittest

import pytest
import numpy as np
import torch
from PIL import Image
from common_utils import get_tmp_dir
from common_utils import get_tmp_dir, needs_cuda

from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
@@ -278,5 +279,43 @@ def test_write_file_non_ascii(self):
os.unlink(fpath)


@needs_cuda
@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize('img_path', get_images(IMAGE_ROOT, ".jpg"))
@pytest.mark.parametrize('scripted', (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
if 'cmyk' in img_path:
pytest.xfail("Decoding a CMYK jpeg isn't supported")
tester = ImageTester()
data = read_file(img_path)
img = decode_image(data, mode=mode)
f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
img_nvjpeg = f(data, mode=mode, device='cuda')

# Some difference expected between jpeg implementations
tester.assertTrue((img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2)


@needs_cuda
@pytest.mark.parametrize('cuda_device', ('cuda', 'cuda:0', torch.device('cuda')))
def test_decode_jpeg_cuda_device_param(cuda_device):
"""Make sure we can pass a string or a torch.device as device param"""
data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
decode_jpeg(data, device=cuda_device)


@needs_cuda
def test_decode_jpeg_cuda_errors():
data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_jpeg(data.reshape(-1, 1), device='cuda')
with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
decode_jpeg(data.to('cuda'), device='cuda')
with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
decode_jpeg(data.to(torch.float), device='cuda')
with pytest.raises(RuntimeError, match="Expected a cuda device"):
torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu')


if __name__ == '__main__':
unittest.main()
185 changes: 185 additions & 0 deletions torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp
@@ -0,0 +1,185 @@
#include "decode_jpeg_cuda.h"

#include <ATen/ATen.h>

#if NVJPEG_FOUND
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <nvjpeg.h>
#endif

#include <mutex>
#include <string>

namespace vision {
namespace image {

#if !NVJPEG_FOUND

torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
torch::Device device) {
TORCH_CHECK(
false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support");
}

#else

namespace {
static nvjpegHandle_t nvjpeg_handle = nullptr;
}

torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
torch::Device device) {
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");

TORCH_CHECK(
!data.is_cuda(),
"The input tensor must be on CPU when decoding with nvjpeg")

TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");

TORCH_CHECK(device.is_cuda(), "Expected a cuda device")

at::cuda::CUDAGuard device_guard(device);

// Create global nvJPEG handle
static std::once_flag nvjpeg_handle_creation_flag;
std::call_once(nvjpeg_handle_creation_flag, []() {
if (nvjpeg_handle == nullptr) {
nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);

if (create_status != NVJPEG_STATUS_SUCCESS) {
// Reset handle so that one can still call the function again in the
// same process if there was a failure
free(nvjpeg_handle);
nvjpeg_handle = nullptr;
}
TORCH_CHECK(
create_status == NVJPEG_STATUS_SUCCESS,
"nvjpegCreateSimple failed: ",
create_status);
}
});

// Create the jpeg state
nvjpegJpegState_t jpeg_state;
nvjpegStatus_t state_status =
nvjpegJpegStateCreate(nvjpeg_handle, &jpeg_state);

TORCH_CHECK(
state_status == NVJPEG_STATUS_SUCCESS,
"nvjpegJpegStateCreate failed: ",
state_status);

auto datap = data.data_ptr<uint8_t>();

// Get the image information
int num_channels;
nvjpegChromaSubsampling_t subsampling;
int widths[NVJPEG_MAX_COMPONENT];
int heights[NVJPEG_MAX_COMPONENT];
nvjpegStatus_t info_status = nvjpegGetImageInfo(
nvjpeg_handle,
datap,
data.numel(),
&num_channels,
&subsampling,
widths,
heights);

if (info_status != NVJPEG_STATUS_SUCCESS) {
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status);
}

if (subsampling == NVJPEG_CSS_UNKNOWN) {
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling");
}

int width = widths[0];
int height = heights[0];

nvjpegOutputFormat_t output_format;
int num_channels_output;

switch (mode) {
case IMAGE_READ_MODE_UNCHANGED:
num_channels_output = num_channels;
// For some reason, setting output_format to NVJPEG_OUTPUT_UNCHANGED will
// not properly decode RGB images (it's fine for grayscale), so we set
// output_format manually here
if (num_channels == 1) {
output_format = NVJPEG_OUTPUT_Y;
} else if (num_channels == 3) {
output_format = NVJPEG_OUTPUT_RGB;
} else {
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(
false,
"When mode is UNCHANGED, only 1 or 3 input channels are allowed.");
}
break;
case IMAGE_READ_MODE_GRAY:
output_format = NVJPEG_OUTPUT_Y;
num_channels_output = 1;
break;
case IMAGE_READ_MODE_RGB:
output_format = NVJPEG_OUTPUT_RGB;
num_channels_output = 3;
break;
default:
nvjpegJpegStateDestroy(jpeg_state);
TORCH_CHECK(
false, "The provided mode is not supported for JPEG decoding on GPU");
}

auto out_tensor = torch::empty(
{int64_t(num_channels_output), int64_t(height), int64_t(width)},
torch::dtype(torch::kU8).device(device));

// nvjpegImage_t is a struct with
// - an array of pointers to each channel
// - the pitch for each channel
// which must be filled in manually
nvjpegImage_t out_image;

for (int c = 0; c < num_channels_output; c++) {
out_image.channel[c] = out_tensor[c].data_ptr<uint8_t>();
out_image.pitch[c] = width;
}
for (int c = num_channels_output; c < NVJPEG_MAX_COMPONENT; c++) {
out_image.channel[c] = nullptr;
out_image.pitch[c] = 0;
}

cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index());

nvjpegStatus_t decode_status = nvjpegDecode(
nvjpeg_handle,
jpeg_state,
datap,
data.numel(),
output_format,
&out_image,
stream);

nvjpegJpegStateDestroy(jpeg_state);

TORCH_CHECK(
decode_status == NVJPEG_STATUS_SUCCESS,
"nvjpegDecode failed: ",
decode_status);

return out_tensor;
}

#endif // NVJPEG_FOUND

} // namespace image
} // namespace vision
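A hedged Python-side sketch of how this operator is invoked once registered; it mirrors the dispatch added to torchvision/io/image.py further below. The JPEG path is hypothetical, and a torchvision build with nvJPEG plus an available CUDA device are assumed.

import torch
from torchvision.io.image import read_file, ImageReadMode

data = read_file('path/to/image.jpg')  # raw JPEG bytes as a 1-D uint8 CPU tensor
img = torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.RGB.value, torch.device('cuda'))
# The decoded CHW tensor lives on the GPU: dtype uint8, 3 channels in RGB mode.
assert img.is_cuda and img.dtype == torch.uint8 and img.shape[0] == 3
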
15 changes: 15 additions & 0 deletions torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h
@@ -0,0 +1,15 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
torch::Device device);

} // namespace image
} // namespace vision
3 changes: 2 additions & 1 deletion torchvision/csrc/io/image/image.cpp
@@ -21,7 +21,8 @@ static auto registry = torch::RegisterOperators()
.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image);
.op("image::decode_image", &decode_image)
.op("image::decode_jpeg_cuda", &decode_jpeg_cuda);

} // namespace image
} // namespace vision
1 change: 1 addition & 0 deletions torchvision/csrc/io/image/image.h
@@ -6,3 +6,4 @@
#include "cpu/encode_jpeg.h"
#include "cpu/encode_png.h"
#include "cpu/read_write_file.h"
#include "cuda/decode_jpeg_cuda.h"
16 changes: 13 additions & 3 deletions torchvision/io/image.py
@@ -148,24 +148,34 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
write_file(filename, output)


def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED,
device: str = 'cpu') -> torch.Tensor:
"""
Decodes a JPEG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
Args:
input (Tensor[1]): a one dimensional uint8 tensor containing
the raw bytes of the JPEG image.
the raw bytes of the JPEG image. This tensor must be on CPU,
regardless of the ``device`` parameter.
mode (ImageReadMode): the read mode used for optionally
converting the image. Default: `ImageReadMode.UNCHANGED`.
See `ImageReadMode` class for more information on various
available modes.
device (str or torch.device): The device on which the decoded image will
be stored. If a cuda device is specified, the image will be decoded
with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
supported for CUDA version >= 10.1.
Returns:
output (Tensor[image_channels, image_height, image_width])
"""
output = torch.ops.image.decode_jpeg(input, mode.value)
device = torch.device(device)
if device.type == 'cuda':
output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
else:
output = torch.ops.image.decode_jpeg(input, mode.value)
return output


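A short usage sketch of the updated public API (hypothetical file path; assumes torchvision was built with nvJPEG and a CUDA device is available):

from torchvision.io.image import read_file, decode_jpeg, ImageReadMode

data = read_file('path/to/image.jpg')  # the raw bytes stay on CPU
img = decode_jpeg(data, mode=ImageReadMode.RGB, device='cuda')
print(img.shape, img.dtype, img.device)  # e.g. torch.Size([3, H, W]) torch.uint8 cuda:0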