[WIP] nvJPEG support #2786
Conversation
Here are some timings (on a Tesla P100) of the ops themselves, and of simulating doing something useful with the tensor by computing .sum() on the GPU. Times are as given by time.time() so may not account for the async nature of CUDA.

Timing code:

    import torch
    import torchvision
    import time
    import numpy as np

    img_path = './test/assets/grace_hopper_517x606.jpg'

    from torchvision.io.image import (
        decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
        encode_png, write_png, write_file)

    data = read_file(img_path)

    # warmup nvjpeg
    torch.ops.image.decode_jpeg_cuda(data)

    times_nv = []
    times_lj = []
    for i in range(30):
        tic = time.time()
        img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data)
        time_nv = time.time() - tic
        print('time_nv', time_nv)
        times_nv.append(time_nv)

        tic = time.time()
        img_jpeg = torch.ops.image.decode_jpeg(data)
        time_ljpg = time.time() - tic
        print('time_ljpg', time_ljpg)
        times_lj.append(time_ljpg)

    times_sum_nv = []
    times_sum_lj = []
    for i in range(30):
        tic = time.time()
        img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data)
        img_nvjpeg_sum = img_nvjpeg.float().sum()
        print(img_nvjpeg_sum)
        time_nv_sum = time.time() - tic
        print('time_nv', time_nv_sum)
        times_sum_nv.append(time_nv_sum)

        tic = time.time()
        img_jpeg = torch.ops.image.decode_jpeg(data)
        img_ljpg_sum = img_jpeg.cuda().float().sum()
        print(img_ljpg_sum)
        time_ljpg_sum = time.time() - tic
        print('time_ljpg', time_ljpg_sum)
        times_sum_lj.append(time_ljpg_sum)

    print('mean nvjpeg', np.mean(times_nv), 'sd =', np.std(times_nv))
    print('mean libjpeg', np.mean(times_lj), 'sd =', np.std(times_lj))
    print('mean nvjpeg and .sum() on gpu', np.mean(times_sum_nv), 'sd =', np.std(times_sum_nv))
    print('mean libjpeg and .sum() on gpu', np.mean(times_sum_lj), 'sd =', np.std(times_sum_lj))
Also, nvjpeg is even faster with data.pin_memory().
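The only change for that variant is page-locking the input buffer before calling the op; a minimal sketch (same script as above):

    # pin_memory() copies the byte tensor into page-locked host memory,
    # which makes the host-to-device transfer inside the decode op faster.
    data = read_file(img_path).pin_memory()
    img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data)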
Hi @jamt9000, thanks a lot for the PR! For the timing code, can you add a torch.cuda.synchronize()? Also, about the API, I wonder if we should have a separate function like … For the code living in … I have a question wrt …
Timings with torch.cuda.synchronize():

Without pin_memory: …

With pin_memory: …
Code:

    import torch
    import torchvision
    import time
    import numpy as np

    img_path = './test/assets/grace_hopper_517x606.jpg'

    from torchvision.io.image import (
        decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
        encode_png, write_png, write_file)

    data = read_file(img_path)  # .pin_memory()

    # warmup nvjpeg
    torch.ops.image.decode_jpeg_cuda(data)

    times_nv = []
    times_lj = []
    for i in range(30):
        torch.cuda.synchronize()
        tic = time.time()
        img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data)
        torch.cuda.synchronize()
        time_nv = time.time() - tic
        print('time_nv', time_nv)
        times_nv.append(time_nv)

        tic = time.time()
        img_jpeg = torch.ops.image.decode_jpeg(data)
        time_ljpg = time.time() - tic
        print('time_ljpg', time_ljpg)
        times_lj.append(time_ljpg)

    times_sum_nv = []
    times_sum_lj = []
    for i in range(30):
        torch.cuda.synchronize()
        tic = time.time()
        img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data)
        img_nvjpeg_sum = img_nvjpeg.float().sum()
        torch.cuda.synchronize()
        print(img_nvjpeg_sum)
        time_nv_sum = time.time() - tic
        print('time_nv', time_nv_sum)
        times_sum_nv.append(time_nv_sum)

        tic = time.time()
        img_jpeg = torch.ops.image.decode_jpeg(data)
        img_ljpg_sum = img_jpeg.cuda().float().sum()
        torch.cuda.synchronize()
        print(img_ljpg_sum)
        time_ljpg_sum = time.time() - tic
        print('time_ljpg', time_ljpg_sum)
        times_sum_lj.append(time_ljpg_sum)

    print('mean nvjpeg', np.mean(times_nv), 'sd =', np.std(times_nv))
    print('mean libjpeg', np.mean(times_lj), 'sd =', np.std(times_lj))
    print('mean nvjpeg and .sum() on gpu', np.mean(times_sum_nv), 'sd =', np.std(times_sum_nv))
    print('mean libjpeg and .sum() on gpu', np.mean(times_sum_lj), 'sd =', np.std(times_sum_lj))
Those results seem very interesting! I'll have a closer look at this next week. I want to compare against libjpeg-turbo as well, and check how we can distribute the binaries for this. Another thing I would like to think about is whether an API for decoding a list of images at once makes sense, and how we could make this as user-friendly as possible. This will probably be left for a separate PR though, as it's orthogonal to adding CUDA support for decode_jpeg.

Also, are your results using libjpeg or libjpeg-turbo? One way to check is to inspect which library the compiled torchvision extension is linked against.
For the Python API, a single function would certainly make sense. In fact, I suppose read_image would be the higher-level function people will call, so the device flag would make sense there, although it wouldn't have GPU PNG decoding. In terms of C++, the implementations will necessarily be separate, although the op under torch.ops.image.* could still be unified (I guess this will be relevant if the data pipeline gets compiled to TorchScript, and I don't know how selecting a GPU device would work there if the input buffers are all CPU; it would be bad to have a GPU ordinal baked into the TorchScript).
It is non-turbo libjpeg.
This reminds me of vl_imreadjpeg from MatConvNet, which as I recall worked pretty well at speedily turning a list of filenames into a GPU array of images. It helped that the dataset getBatch functions operated on a list of indices, whereas PyTorch's Dataset.__getitem__ fetches one index at a time.
Just request the user to provide a …
Ok, thanks for confirming. I'll try on my end to see the speedup that we obtain by using libjpeg-turbo.
My thinking is that some datasets (like IterableDataset) allow the dataset to return a whole batch at once. This means that the dataset implementation could forward a list of image paths to the decoding function; a rough sketch follows below.
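A sketch of what such a dataset could look like; the class name is hypothetical, and only the production of batched encoded buffers is shown (the batched GPU decode op itself doesn't exist yet):

    import torch
    from torch.utils.data import IterableDataset
    from torchvision.io.image import read_file

    class BatchedJpegPaths(IterableDataset):
        """Yields one list of encoded JPEG buffers per batch, so that a
        batched GPU decode op could consume them in a single call."""

        def __init__(self, paths, batch_size):
            self.paths = paths
            self.batch_size = batch_size

        def __iter__(self):
            for i in range(0, len(self.paths), self.batch_size):
                # One raw uint8 buffer per image; a batched nvjpeg op could
                # decode the whole list at once instead of looping in Python.
                yield [read_file(p) for p in self.paths[i:i + self.batch_size]]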
Hi @jamt9000! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but we do not have a signature on file. In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Is the image stuff under torchvision/csrc/io/image/ now?
Yes, the CPU implementation now lives in torchvision/csrc/io/image/cpu.
It will be picked up from the CUDA SDK lib and include paths set up by CUDAExtension.
Now working on CI for Linux GPU! https://app.circleci.com/pipelines/github/pytorch/vision/5916/workflows/d618a212-368c-495a-b8e5-5d10f9857ba0/jobs/381128

It looks like building as a CUDAExtension with a recent CUDA version is sufficient for nvjpeg.h/libnvjpeg.so to be present, since they seem to ship as standard in CUDA toolkit installs (i.e. the full system install with nvcc, not conda's cudatoolkit, which is just the runtime libs).

(It might pick up the .so from conda's cudatoolkit depending on the library path order, but it ought to be the same version, as otherwise there would be bigger problems with mismatched PyTorch CUDA.)
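For reference, a minimal sketch of what that looks like in a setup script; the source file names are illustrative, and the only nvjpeg-specific part is linking against libnvjpeg:

    from torch.utils.cpp_extension import CUDAExtension

    # CUDAExtension adds the CUDA toolkit's include/ and lib64/ directories
    # to the build, so a full toolkit install already provides nvjpeg.h and
    # libnvjpeg.so without any extra paths.
    image_ext = CUDAExtension(
        name='torchvision.image',
        sources=[
            'torchvision/csrc/io/image/image.cpp',  # illustrative paths
            'torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp',
        ],
        libraries=['nvjpeg'],  # link against libnvjpeg.so from the toolkit
    )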
I don't know if we should implement a dispatcher or an option to pick nvjpeg on …
    #include <ATen/cuda/CUDAContext.h>
    #include <nvjpeg.h>

    static nvjpegHandle_t nvjpeg_handle = nullptr;
Is there device-specific state associated with the nvjpegHandle? i.e. is it safe/optimal to create a nvjpeg_handle on one device, then switch CUDA devices and use the previously created nvjpeg_handle?
It isn't completely obvious...although these parts of DALI give some clues
The docs say:
The library handle is used in any consecutive nvJPEG library calls, and should be initialized first. The library handle is thread safe, and can be used by multiple threads simultaneously.
With the links above where the struct is just defined as a global variable, and the fact that they call it a "library handle", I'd say it's fair to assume that it's not device-specific.
From a quick test with 2 GPUs, this passes fine:

    img_nvjpeg = decode_jpeg(data, mode=mode, device='cuda:0')
    img_nvjpeg2 = decode_jpeg(data, mode=mode, device='cuda:1')
    self.assertTrue((img_nvjpeg.cpu().float() - img_nvjpeg2.cpu().float()).abs().mean() < 1e-10)
I believe we will soon be deprecating CUDA 9.2 in PyTorch, so I propose that we don't ship this feature with torchvision for CUDA 9.2 binaries.
I think given the current API of … Additionally, with the …
Hi @jamt9000, do you think you would have time sometime soon to address @ajtulloch's and my comments so that we can move forward with merging this PR? If you are busy that's fine; I could ask someone from the team to build on top of this PR so that we can get it merged soon. What do you think?
Could you also report some benchmark results in megapixels/second of JPEG decode throughput (i.e. images-per-second * image-height * image-width)? That will be useful for comparing to other implementations (e.g. libjpeg-turbo, nvJPEG batched on A100, etc.). You might find building libjpeg-turbo (ensuring you have nasm/yasm installed so you have SIMD decode built), running its decode benchmark, and reporting those results useful for comparing against the "best" CPU JPEG decode impl.
Hello @fmassa! I should have some time this weekend to look at it. Things I'm not sure about are a) how the device should be given to the Python wrapper and passed down, and b) where to initialise and store global state.
    @@ -149,7 +149,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
         write_file(filename, output)


    -def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    +def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED,
    +                device: torch.device = 'cpu') -> torch.Tensor:
Should the annotation be Union[int, str, torch.device]? e.g. …
I would just make it a torch.device, because I think torchscript doesn't yet support Union.
    @@ -166,7 +167,11 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG
         Returns:
             output (Tensor[image_channels, image_height, image_width])
         """
    -    output = torch.ops.image.decode_jpeg(input, mode.value)
    +    device = torch.device(device)
    +    if device.type == 'cuda':
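For context, the rest of that branch could plausibly continue as sketched below; the argument list of the CUDA op is an assumption, since its exact signature is what this PR is defining:

    device = torch.device(device)
    if device.type == 'cuda':
        # decode directly into GPU memory via nvjpeg
        output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
    else:
        output = torch.ops.image.decode_jpeg(input, mode.value)
    return output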
Should the dispatching be done here or in C++? How will it interact with torchscript?
I was thinking about making it dispatch in C++ directly, but this should also work with torchscript, I believe.
Hi @jamt9000, very sorry for the delay in replying; I didn't have the bandwidth to spend time on the PR. @NicolasHug will be stepping up and making sure we can get this PR to completion for the next release, and will be looking into the global state, handling CUDA versions < 10, dispatching to CUDA / CPU in C++, benchmarking, etc. Nicolas will be reviewing the PR and will probably have questions; let us know how you would both prefer to collaborate on this PR.
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!
I ran some quick benchmarks on a Tesla V100 (CPU decoding is with libjpeg, not libjpeg-turbo):

Code:

    import torch
    from torch.utils.benchmark import Timer
    from torchvision.io.image import decode_jpeg, read_file, ImageReadMode

    img_path = 'test/assets/encode_jpeg/grace_hopper_517x606.jpg'
    data = read_file(img_path)
    img = decode_jpeg(data)
    num_pixels = img.numel() / 3
    num_runs = 1000

    stmt = "decode_jpeg(data, device='{}')"
    setup = 'from torchvision.io.image import decode_jpeg'
    globals = {'data': data}

    for device in ('cpu', 'cuda'):
        timer = Timer(stmt=stmt.format(device), setup=setup, globals=globals)
        measurement = timer.timeit(num_runs)
        print(
            f"{device.upper():<5} mean: {measurement.mean * 1000:.3f} ms, median: {measurement.median * 1000:.3f} ms, "
            f"Throughput = {num_pixels / 1_000_000 / measurement.median:.3f} Megapixel / sec, "
            f"{1 / measurement.median:.3f} fps"
        )

Using libjpeg-turbo: …

So GPU decoding is faster than CPU decoding with libjpeg, but slower than with libjpeg-turbo. nvjpeg supports batch decoding, so I will try to implement that and see if we can get a gain there.
Hello! I'm quite busy at the moment but happy to discuss here or on Slack if @NicolasHug wants to make improvements.

It may also be worth checking the overhead of transferring to GPU, which will be incurred with libjpeg-turbo but not with nvjpeg; a sketch of such a measurement follows below.
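A sketch of one way to measure that, reusing the Timer harness from the benchmark above (this assumes the device argument added in this PR; torch.utils.benchmark.Timer takes care of CUDA synchronization when timing):

    import torch
    from torch.utils.benchmark import Timer
    from torchvision.io.image import decode_jpeg, read_file

    data = read_file('test/assets/encode_jpeg/grace_hopper_517x606.jpg')

    # CPU decode followed by an explicit host-to-device copy, versus
    # decoding directly into GPU memory with nvjpeg.
    for stmt in ("decode_jpeg(data).cuda()", "decode_jpeg(data, device='cuda')"):
        t = Timer(stmt=stmt, globals={'data': data, 'decode_jpeg': decode_jpeg}).timeit(100)
        print(f"{stmt:<40} median: {t.median * 1000:.3f} ms")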
I implemented a quick-and-dirty version of batch-decoding (https://github.com/NicolasHug/vision/blob/380d5b5b563a55d26a69e66bdb1c9076b521031c/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp#L178) and ran a few more benchmarks on a V100. Everything is single-threaded; I didn't observe any significant change when allowing more than one CPU thread for batch-decoding. Looking at the different throughputs: …

CC @fmassa

Benchmark code for ref:

    import torch
    from torch.utils.benchmark import Timer
    from torchvision.io.image import decode_jpeg, read_file, ImageReadMode, write_jpeg, encode_jpeg
    from torchvision import transforms as T
    import sys

    img_path = sys.argv[1]
    data = read_file(img_path)
    img = decode_jpeg(data)

    def sumup(name, mean, median, throughput, fps):
        print(
            f"{name:<20} - mean: {mean:<7.2f} ms, median: {median:<7.2f} ms, "
            f"Throughput = {throughput:<7.1f} Megapixel / sec, "
            f"{fps:<7.1f} fps"
        )

    print(f"Using {img_path}")
    print(f"{img.shape = }, {data.shape = }")
    height, width = img.shape[-2:]
    num_pixels = height * width
    num_runs = 30

    for batch_size in (1, 4, 16, 32, 64):
        print(f"{batch_size = }")

        # non-batch implem
        for device in ('cpu', 'cuda'):
            if batch_size >= 32 and height >= 1000:
                print(f"skipping for-loop for {batch_size = } and {device = }")
                continue
            stmt = f"for _ in range(batch_size): decode_jpeg(data, device='{device}')"
            setup = 'from torchvision.io.image import decode_jpeg'
            globals = {'data': data, 'batch_size': batch_size}
            t = Timer(stmt=stmt, setup=setup, globals=globals).timeit(num_runs)
            sumup(f"for-loop {device}", t.mean * 1000, t.median * 1000, num_pixels * batch_size / 1e6 / t.median, batch_size / t.median)

        # Batch implem
        stmt = "torch.ops.image.decode_jpeg_batch_cuda(batch_data, mode, device, batch_size, height, width)"
        setup = 'import torch'
        batch_data = torch.cat([data] * batch_size, dim=0)
        globals = {
            'batch_data': batch_data, 'mode': ImageReadMode.UNCHANGED.value, 'device': torch.device('cuda'), 'batch_size': batch_size,
            'height': height, 'width': width
        }
        t = Timer(stmt=stmt, setup=setup, globals=globals).timeit(num_runs)
        sumup(f"BATCH cuda", t.mean * 1000, t.median * 1000, num_pixels * batch_size / 1e6 / t.median, batch_size / t.median)
Results were reported for 300x300, 517x606, 1920x1080, and 2000x2000 images.
@NicolasHug Is it possible for you to benchmark nvjpeg on an A100? It is supposed to be much faster than a V100 according to: https://developer.nvidia.com/nvjpeg
Sorry @cceyda, I don't have access to an A100, but this might be in the roadmap in the future.
This is a minimal proof-of-concept supporting nvJPEG (#2742); it adds a torch.ops.image.decode_jpeg_cuda op. For nvjpeg.h it will need cudatoolkit-dev.

It currently succeeds in loading RGB JPEG images into a GPU tensor, although I wouldn't currently expect it to be very fast because it re-initialises the library each time it is called (not sure if there is a global place to do this?). The loaded images differ from PIL's, although they look visually similar. It is also currently under cpu/ even though it needs CUDA, because I wasn't sure how to integrate with the image op code otherwise.