
[WIP] nvJPEG support #2786

Closed - jamt9000 wants to merge 17 commits

Conversation

@jamt9000 (Contributor) commented Oct 10, 2020

This is a minimal proof of concept supporting nvJPEG (#2742); it adds a torch.ops.image.decode_jpeg_cuda op.

For nvjpeg.h it will need cudatoolkit-dev.

It currently succeeds in loading RGB JPEG images into a GPU tensor, although I wouldn't yet expect it to be very fast because it re-initialises the library each time it is called (not sure if there is a global place to do this?). The loaded images differ slightly from PIL's, although they look visually similar.

It is also currently under cpu/ even though it needs CUDA, because I wasn't sure how else to integrate with the image op code.
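Minimal usage at this stage looks roughly like this (a sketch, assuming the extension was built with nvjpeg available; the shape in the comment is just from the grace_hopper test image):

import torch
from torchvision.io.image import read_file

data = read_file('./test/assets/grace_hopper_517x606.jpg')  # 1-D uint8 tensor of the encoded bytes
img = torch.ops.image.decode_jpeg_cuda(data)                # decode happens on the GPU
print(img.shape, img.dtype, img.device)                     # torch.Size([3, 606, 517]) torch.uint8 cuda:0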

@jamt9000 (Contributor, Author) commented Oct 11, 2020

Here are some timings (on a Tesla P100) of the ops themselves, plus a simulation of doing something useful with the result by computing .sum() on the GPU. Times are as given by time.time(), so they may not account for the asynchronous nature of CUDA.

Timing code
import torch
import torchvision
import time
import numpy as np

img_path = './test/assets/grace_hopper_517x606.jpg'

from torchvision.io.image import (
    decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
    encode_png, write_png, write_file)


data = read_file(img_path)

# warmup nvjpeg
torch.ops.image.decode_jpeg_cuda(data)

times_nv = []
times_lj = []

for i in range(30):
    tic = time.time()
    img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data)
    time_nv = time.time() - tic

    print('time_nv', time_nv)
    times_nv.append(time_nv)

    tic = time.time()
    img_jpeg = torch.ops.image.decode_jpeg(data)
    time_ljpg = time.time() - tic

    print('time_ljpg', time_ljpg)
    times_lj.append(time_ljpg)

times_sum_nv = []
times_sum_lj = []

for i in range(30):
    tic = time.time()
    img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data)
    img_nvjpeg_sum = img_nvjpeg.float().sum()
    print(img_nvjpeg_sum)
    time_nv_sum = time.time() - tic
    print('time_nv', time_nv_sum)
    times_sum_nv.append(time_nv_sum)

    tic = time.time()
    img_jpeg = torch.ops.image.decode_jpeg(data)
    img_ljpg_sum = img_jpeg.cuda().float().sum()
    print(img_ljpg_sum)
    time_ljpg_sum = time.time() - tic
    print('time_ljpg', time_ljpg_sum)
    times_sum_lj.append(time_ljpg_sum)

print('mean nvjpeg', np.mean(times_nv), 'sd =', np.std(times_nv))
print('mean libjpeg', np.mean(times_lj), 'sd =', np.std(times_lj))

print('mean nvjpeg and .sum() on gpu', np.mean(times_sum_nv), 'sd =', np.std(times_sum_nv))
print('mean libjpeg and .sum() on gpu', np.mean(times_sum_lj), 'sd =', np.std(times_sum_lj))
Output:

mean nvjpeg 0.002721214294433594 sd = 0.00019912007008205212
mean libjpeg 0.004462607701619466 sd = 0.00019312457023127426
mean nvjpeg and .sum() on gpu 0.003558675448099772 sd = 0.00027405448467918706
mean libjpeg and .sum() on gpu 0.005720305442810059 sd = 0.00033805148933798226

Also nvjpeg is even faster with data.pin_memory():

mean nvjpeg 0.0016569773356119792 sd = 0.0001282730913417541
mean nvjpeg and .sum() on gpu 0.0022617975870768228 sd = 0.00026759146078698454
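(Pinning here just means calling .pin_memory() on the CPU byte tensor before decoding, so the host-to-device copy goes through page-locked memory:)

data = read_file(img_path).pin_memory()
img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data)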

@fmassa (Member) commented Oct 12, 2020

Hi @jamt9000

Thanks a lot for the PR!
I'll have a closer look at it after Wednesday, but from a first glance it looks very good!
We don't currently have global state in torchvision C++. We do have it in PyTorch, but I think having this in torchvision is fine (I'll check afterwards).

For the timing code, can you add a torch.cuda.synchronize() call before / after the invocations of the CUDA function, so that we can get a more accurate timing? Sometimes benchmarking CUDA code can be misleading because it only benchmarks the kernel launch time, and not the execution, due to the asynchronous nature of CUDA.
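i.e. roughly this pattern around each timed region (sketch):

torch.cuda.synchronize()   # make sure pending GPU work is done before starting the clock
tic = time.time()
img = torch.ops.image.decode_jpeg_cuda(data)
torch.cuda.synchronize()   # wait for the decode to actually finish
elapsed = time.time() - tic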

Also, about the API, I wonder if we should have a separate function like decode_jpeg_cuda or if we should expose it via a device argument in decode_jpeg, something like decode_jpeg(data, device)?

For the code living in cpu/ and not in cuda/ it's fine for now and we can fix that before merging the PR.

I have a question wrt cudatoolkit-dev: do you know if this is something that we could have in the cudatoolkit provided by the pytorch channel? That way we wouldn't need to change the installation instructions and it would work more seamlessly with the rest of the commands. @seemethere do you have an idea?

@jamt9000 (Contributor, Author)

Timings with torch.cuda.synchronize():

Without pin_memory

mean nvjpeg 0.0028519550959269207 sd = 0.00017520657178623838
mean libjpeg 0.0045461893081665036 sd = 0.00025594093510109754
mean nvjpeg and .sum() on gpu 0.0036966085433959963 sd = 0.0005660666376655812
mean libjpeg and .sum() on gpu 0.005520820617675781 sd = 0.00020280927568013896

With pin_memory

mean nvjpeg 0.001844644546508789 sd = 0.00015511983771863112
mean libjpeg 0.004541428883870443 sd = 0.00016329190528891328
mean nvjpeg and .sum() on gpu 0.0026334524154663086 sd = 0.0004767436442469833
mean libjpeg and .sum() on gpu 0.005513159434000651 sd = 0.00019829479987284863
Code
import torch
import torchvision
import time
import numpy as np

img_path = './test/assets/grace_hopper_517x606.jpg'

from torchvision.io.image import (
    decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
    encode_png, write_png, write_file)


data = read_file(img_path)#.pin_memory()

# warmup nvjpeg
torch.ops.image.decode_jpeg_cuda(data)

times_nv = []
times_lj = []

for i in range(30):
    torch.cuda.synchronize()

    tic = time.time()
    img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data)
    torch.cuda.synchronize()
    time_nv = time.time() - tic

    print('time_nv', time_nv)
    times_nv.append(time_nv)

    tic = time.time()
    img_jpeg = torch.ops.image.decode_jpeg(data)
    time_ljpg = time.time() - tic

    print('time_ljpg', time_ljpg)
    times_lj.append(time_ljpg)

times_sum_nv = []
times_sum_lj = []

for i in range(30):
    torch.cuda.synchronize()

    tic = time.time()
    img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data)
    img_nvjpeg_sum = img_nvjpeg.float().sum()
    torch.cuda.synchronize()
    print(img_nvjpeg_sum)
    time_nv_sum = time.time() - tic

    print('time_nv', time_nv_sum)
    times_sum_nv.append(time_nv_sum)

    tic = time.time()
    img_jpeg = torch.ops.image.decode_jpeg(data)
    img_ljpg_sum = img_jpeg.cuda().float().sum()
    torch.cuda.synchronize()
    print(img_ljpg_sum)
    time_ljpg_sum = time.time() - tic
    print('time_ljpg', time_ljpg_sum)
    times_sum_lj.append(time_ljpg_sum)

print('mean nvjpeg', np.mean(times_nv), 'sd =', np.std(times_nv))
print('mean libjpeg', np.mean(times_lj), 'sd =', np.std(times_lj))

print('mean nvjpeg and .sum() on gpu', np.mean(times_sum_nv), 'sd =', np.std(times_sum_nv))
print('mean libjpeg and .sum() on gpu', np.mean(times_sum_lj), 'sd =', np.std(times_sum_lj))

@fmassa (Member) commented Oct 17, 2020

Those results seem very interesting!

I'll have a closer look at this next week. I want to compare against libjpeg-turbo as well, and check on how we can distribute the binaries for this.

Another thing I would like to think about is whether an API for decoding a list of images at once makes sense, and how we could make it as user-friendly as possible. This will probably be left for a separate PR though, as it's orthogonal to adding CUDA support for decode_jpeg. I wonder if you have any thoughts on this?

Also, are your results using libjpeg or libjpeg-turbo for the CPU counterpart? Could you do

ldd path/to/torchvision/image.so

and see if torchvision is linked against libjpeg or libjpeg-turbo?

@jamt9000 (Contributor, Author)

Also, about the API, I wonder if we should have a separate function like decode_jpeg_cuda or if we should expose it via a device argument in decode_jpeg, something like decode_jpeg(data, device)?

For the Python API a single function would certainly make sense. In fact I suppose read_image would be the higher-level function people will call, so the device flag would make sense there too - although it wouldn't have GPU PNG decoding.

In terms of C++ the implementations will necessarily be separate, although the op under torch.ops.image.* could still be unified (I guess this becomes relevant if the data pipeline gets compiled to TorchScript - and I don't know how selecting a GPU device would work there if the input buffers are all on the CPU; it would be bad to have a GPU ordinal baked into the TorchScript).

@jamt9000 (Contributor, Author)

Also, are your results using libjpeg or libjpeg-turbo for the CPU counterpart?

It is non-turbo libjpeg

libjpeg.so.9 => /opt/conda/envs/torchvisiondev/lib/libjpeg.so.9

Another thing is that I would like to think is if an API for decoding a list of images at once makes sense, and how we could make this as user-friendly as possible.

This reminds me of vl_imreadjpeg from MatConvNet, which as I recall worked pretty well at speedily turning a list of filenames into a GPU array of images. It helped that the dataset getBatch functions operated on a list of indices, whereas PyTorch's __getitem__ takes a single index. It should still be a good addition for speedy inference on many files, though (which is probably also the main use case of nvjpeg decoding, since CPU decoding can be done "for free" while the GPU is busy during training, and AFAIK multi-process data loading doesn't work with GPU tensors).
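Something along these lines, purely hypothetical - the name decode_jpegs and its signature are made up here for illustration:

from typing import List
import torch

def decode_jpegs(datas: List[torch.Tensor], device: torch.device = 'cpu') -> List[torch.Tensor]:
    # Naive fallback: loop over the images one by one. A real implementation
    # could hand the whole list to nvjpeg's batched decoder in a single call.
    if torch.device(device).type == 'cuda':
        return [torch.ops.image.decode_jpeg_cuda(d) for d in datas]
    return [torch.ops.image.decode_jpeg(d) for d in datas]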

@fmassa (Member) commented Oct 19, 2020

and I don't know how selecting a gpu device would work there if the input buffers are all cpu

Just request the user to provide a torch.device object, there are a few workarounds that we can do with torchscript so that it's supported.

It is non-turbo libjpeg

Ok, thanks for confirming. I'll try on my end to see the speedup that we obtain by using libjpeg-turbo

whereas PyTorch's __getitem__ takes a single index. It should still be a good addition for speedy inference on many files though

My thinking is that some dataset types (like IterableDataset) allow the dataset to return a whole batch at once. This means the dataset implementation could forward a list of image paths to read_image to decode, while the transforms are also performed in parallel. But you do have a point that GPUs don't play very well with the DataLoader and multiple workers; this is something that will be improved over time.
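Roughly what I have in mind (a sketch only, not a proposed API):

import torch
from torchvision.io.image import read_file, decode_jpeg

class JpegBatchDataset(torch.utils.data.IterableDataset):
    # Yields one decoded batch at a time instead of single samples, so that a
    # future list-based decode could process the whole batch in one call.
    def __init__(self, paths, batch_size):
        self.paths = paths
        self.batch_size = batch_size

    def __iter__(self):
        for i in range(0, len(self.paths), self.batch_size):
            batch_paths = self.paths[i:i + self.batch_size]
            yield [decode_jpeg(read_file(p)) for p in batch_paths]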

@facebook-github-bot

Hi @jamt9000!

Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but we do not have a signature on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@fmassa mentioned this pull request Jan 20, 2021
@jamt9000 (Contributor, Author)

Is the image stuff under torchvision/csrc/io/image/ now?

@andfoy (Contributor) commented Jan 22, 2021

Is the image stuff under torchvision/csrc/io/image/ now?

Yes, the CPU implementation now lives in cpu/, which means this addition should go in a separate cuda/ folder.

@jamt9000 (Contributor, Author)

Now working on CI for linux gpu! https://app.circleci.com/pipelines/github/pytorch/vision/5916/workflows/d618a212-368c-495a-b8e5-5d10f9857ba0/jobs/381128

It looks like building as a CUDAExtension with a recent CUDA version is sufficient for nvjpeg.h/libnvjpeg.so to be available, since it seems to ship as standard with CUDA toolkit installs (i.e. the full system toolkit with nvcc, not conda's cudatoolkit, which is just the runtime libs):

/usr/local/cuda-11.1/targets/x86_64-linux/include/nvjpeg.h
/usr/local/cuda-11.0/targets/x86_64-linux/include/nvjpeg.h
/usr/local/cuda-10.2/targets/x86_64-linux/include/nvjpeg.h
/usr/local/cuda-10.1/targets/x86_64-linux/include/nvjpeg.h
/usr/local/cuda-10.0/include/nvjpeg.h

(It might pick up the .so from conda's cudatoolkit depending on the library path order, but it ought to be the same version, as otherwise there'd be bigger problems with a mismatched PyTorch CUDA build.)
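On the build side it ends up as roughly something like this in setup.py (an illustrative sketch only - the extension name and source paths here are not the exact diff):

from torch.utils.cpp_extension import CUDAExtension

image_ext = CUDAExtension(
    name='torchvision.image',
    sources=['torchvision/csrc/io/image/image.cpp',
             'torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp'],
    libraries=['nvjpeg'],  # libnvjpeg.so ships with full CUDA toolkit installs
)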

@andfoy (Contributor) commented Jan 26, 2021

@jamt9000, thanks for updating the PR. We're still not sure when CUDA 9.2 support is going to be dropped, so I don't know if we should ship nvjpeg on 9.2.

cc @fmassa

@andfoy (Contributor) commented Jan 26, 2021

Should we implement a dispatcher, or an option to pick nvjpeg in read_image/decode_image?

#include <ATen/cuda/CUDAContext.h>
#include <nvjpeg.h>

static nvjpegHandle_t nvjpeg_handle = nullptr;

@ajtulloch:

Is there device-specific state associated with the nvjpegHandle? i.e. is it safe/optimal to create a nvjpeg_handle on one device, then switch CUDA devices and use the previously created nvjpeg_handle?


(Member):

The docs say:

The library handle is used in any consecutive nvJPEG library calls, and should be initialized first. The library handle is thread safe, and can be used by multiple threads simultaneously.

With the links above where the struct is just defined as a global variable, and the fact that they call it a "library handle", I'd say it's fair to assume that it's not device-specific.

From a quick test with 2 GPUs, this passes fine:

                img_nvjpeg = decode_jpeg(data, mode=mode, device='cuda:0')
                img_nvjpeg2 = decode_jpeg(data, mode=mode, device='cuda:1')
                self.assertTrue((img_nvjpeg.cpu().float() - img_nvjpeg2.cpu().float()).abs().mean() < 1e-10)

@fmassa (Member) commented Jan 28, 2021

so I don't know if we should ship nvjpeg on 9.2.

I believe we will soon be deprecating CUDA 9.2 in PyTorch, so I propose that we don't ship this feature with torchvision for CUDA 9.2 binaries.

I don't know if we should implement a dispatcher or an option to pick nvjpeg on read/decode_image?

I think given the current API of nvjpeg, we might need to pass a device argument to decode_jpeg, which will internally dispatch to either the CPU or CUDA implementations.

Additionally, with the device approach, we might be able to properly initialize whatever gpu-specific states are needed, as @ajtulloch mentioned.

@fmassa (Member) commented Feb 2, 2021

Hi @jamt9000 ,

Do you think you would have time sometime soon to address @ajtulloch and my comments so that we can move forward merging this PR?

If you are busy it's fine, I could ask someone from the team to build on top of this PR so that we can get it merged soon.

What do you think?

@ajtulloch commented Feb 2, 2021

Could you also report some benchmark results in megapixels/second of JPEG decode throughput (i.e. images-per-second * image-height * image-width) which will be useful for comparing to other implementations (e.g. libjpeg-turbo, nvJPEG batched on A100, etc).

You might find it useful to build libjpeg-turbo (ensuring you have nasm/yasm installed so that the SIMD decode paths are built), run

./tjbench <my-jpeg>  -fastdct -fastupsample -nowrite

and report those results, as a comparison against the "best" CPU JPEG decode implementation.
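For concreteness, the throughput number is just the following (a sketch, plugging in e.g. the ~2.85 ms synced nvjpeg timing reported earlier for the 517x606 image):

height, width = 606, 517
seconds_per_image = 0.00285
megapixels_per_sec = height * width / seconds_per_image / 1e6   # ~110 MP/s
fps = 1.0 / seconds_per_image                                    # ~350 images/s
print(f"{megapixels_per_sec:.1f} MP/s, {fps:.1f} fps")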

@jamt9000 (Contributor, Author) commented Feb 3, 2021

Hello @fmassa ! I should have some time this weekend to look at it.

Things I'm not sure about are (a) how the device should be given to the Python wrapper and passed down, and (b) where to initialise and store the global state.

@@ -149,7 +149,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
     write_file(filename, output)


-def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
+def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED,
+                device: torch.device = 'cpu') -> torch.Tensor:
@jamt9000 (Contributor, Author):

Should the annotation be Union[int, str, torch.device], e.g. so that a plain string like 'cuda:0' or a device ordinal is also accepted?

(Member):

I would just make it a torch.device, because I think torchscript doesn't yet support Union

@@ -166,7 +167,11 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG
     Returns:
         output (Tensor[image_channels, image_height, image_width])
     """
-    output = torch.ops.image.decode_jpeg(input, mode.value)
+    device = torch.device(device)
+    if device.type == 'cuda':
@jamt9000 (Contributor, Author):

Should the dispatching be done here or in C++? How will it interact with torchscript?

(Member):

I was thinking about making it dispatch in C++ directly, but this should also work with TorchScript, I believe.
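On the Python side it would look roughly like this (a sketch; the exact op signatures may differ, and the C++ route would hide the branch behind a single op):

import torch
from torchvision.io.image import ImageReadMode

def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED,
                device: torch.device = 'cpu') -> torch.Tensor:
    device = torch.device(device)
    if device.type == 'cuda':
        # hand the CPU byte tensor to the nvjpeg-backed op
        return torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
    return torch.ops.image.decode_jpeg(input, mode.value)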

@fmassa (Member) commented Apr 8, 2021

Hi @jamt9000 ,

Very sorry for the delay in replying. I didn't have the bandwidth to spend time with the PR.

@NicolasHug will be stepping up and making sure we can get this PR to completion for the next release, and will be looking into the global state, handling CUDA versions < 10, dispatching to CUDA / CPU on C++ / benchmarking / etc.

Nicolas will be reviewing the PR and will probably have questions; let us know how you would both prefer to collaborate on this PR.

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!

@NicolasHug (Member)

I ran some quick benchmarks on a Tesla V100 (CPU decoding uses libjpeg, not libjpeg-turbo):

CPU   mean: 4.728 ms, median: 4.728 ms, Throughput = 66.265 Megapixel / sec, 211.504 fps
CUDA  mean: 2.196 ms, median: 2.196 ms, Throughput = 142.637 Megapixel / sec, 455.271 fps
Code
import torch
from torch.utils.benchmark import Timer
from torchvision.io.image import decode_jpeg, read_file, ImageReadMode

img_path = 'test/assets/encode_jpeg/grace_hopper_517x606.jpg'
data = read_file(img_path)
img = decode_jpeg(data)
num_pixels = img.numel() / 3

num_runs = 1000

stmt = "decode_jpeg(data, device='{}')"
setup = 'from torchvision.io.image import decode_jpeg'
globals = {'data': data}

for device in ('cpu', 'cuda'):
    timer = Timer(stmt=stmt.format(device), setup=setup, globals=globals)
    measurement = timer.timeit(num_runs)
    print(
        f"{device.upper():<5} mean: {measurement.mean * 1000:.3f} ms, median: {measurement.median * 1000:.3f} ms, "
        f"Throughput = {num_pixels / 1_000_000 / measurement.median:.3f} Megapixel / sec, "
        f"{1 / measurement.median:.3f} fps"
    )

Using libjpeg-turbo's tjbench as suggested by @ajtulloch, I get:


Using fastest DCT/IDCT algorithm

Using fast upsampling code

>>>>>  JPEG 4:2:0 --> BGR (Top-down)  <<<<<

Image size: 517 x 606
Decompress    --> Frame rate:         577.290138 fps
                  Throughput:         180.866155 Megapixels/sec

So GPU decoding is faster than the plain-libjpeg CPU path, but slower than libjpeg-turbo.

nvjpeg supports batch decoding so I will try to implement that and see if we can get a gain there.

@jamt9000 (Contributor, Author)

Nicolas will be reviewing the PR and will probably have questions, let us know how you would prefer you both to collaborate on this PR

Hello! I'm quite busy at the moment but happy to discuss here or on slack if @NicolasHug wants to make improvements

So the GPU decoding is faster than on CPU but slower than with libjpeg turbo

It may also be worth accounting for the overhead of transferring to the GPU, which will be incurred with libjpeg-turbo but not with nvjpeg.
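e.g. something like this, with the same data and Timer setup as in the benchmark above, so the CPU path also pays for the copy to the GPU (sketch):

from torch.utils.benchmark import Timer

t_cpu_plus_copy = Timer(
    stmt="decode_jpeg(data).cuda()",   # decode on CPU, then transfer to the GPU
    setup="from torchvision.io.image import decode_jpeg",
    globals={'data': data},
).timeit(1000)
print(t_cpu_plus_copy.median)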

@NicolasHug (Member) commented May 4, 2021

I implemented a quick-and-dirty version of batch-decoding (https://github.com/NicolasHug/vision/blob/380d5b5b563a55d26a69e66bdb1c9076b521031c/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp#L178), and ran a few more benchmarks on a V100. Everything is single-threaded. I didn't observe any significant change when allowing more than one CPU thread for batch-decoding.

Looking at the different throughputs:

  • batch decoding doesn't seem to be a win, although it's slightly faster than manual for-loop for small-ish images
  • nvjpeg is about 2-3X faster than basic CPU decoding (without libjpeg-turbo), which is somewhat consistent with the results reported in https://developer.nvidia.com/nvjpeg
  • nvjpeg is never faster than libjpeg-turbo
  • The gap between nvjpeg and libjpeg-turbo reduces as the image size grows. They're about the same for the big 2000x2000 image.

CC @fmassa

Benchmark code for ref
import torch
from torch.utils.benchmark import Timer
from torchvision.io.image import decode_jpeg, read_file, ImageReadMode, write_jpeg, encode_jpeg
from torchvision import transforms as T
import sys

img_path = sys.argv[1]
data = read_file(img_path)
img = decode_jpeg(data)

def sumup(name, mean, median, throughput, fps):
    print(
        f"{name:<20} - mean: {mean:<7.2f} ms, median: {median:<7.2f} ms, "
        f"Throughput = {throughput:<7.1f} Megapixel / sec, "
        f"{fps:<7.1f} fps"
    )

print(f"Using {img_path}")
print(f"{img.shape = }, {data.shape = }")
height, width = img.shape[-2:]

num_pixels = height * width
num_runs = 30


for batch_size in (1, 4, 16, 32, 64):
    print(f"{batch_size = }")

    # non-batch implem
    for device in ('cpu', 'cuda'):
        if batch_size >= 32 and height >= 1000:
            print(f"skipping for-loop for {batch_size = } and {device = }")
            continue
        stmt = f"for _ in range(batch_size): decode_jpeg(data, device='{device}')"
        setup = 'from torchvision.io.image import decode_jpeg'
        globals = {'data': data, 'batch_size': batch_size}

        t = Timer(stmt=stmt, setup=setup, globals=globals).timeit(num_runs)
        sumup(f"for-loop {device}", t.mean * 1000, t.median * 1000, num_pixels * batch_size / 1e6 / t.median, batch_size / t.median)

    # Batch implem
    stmt = "torch.ops.image.decode_jpeg_batch_cuda(batch_data, mode, device, batch_size, height, width)"
    setup = 'import torch'
    batch_data = torch.cat([data] * batch_size, dim=0)
    globals = {
        'batch_data': batch_data, 'mode': ImageReadMode.UNCHANGED.value, 'device': torch.device('cuda'), 'batch_size': batch_size,
        'height': height, 'width': width
    }
    t = Timer(stmt=stmt, setup=setup, globals=globals).timeit(num_runs)

    sumup(f"BATCH cuda", t.mean * 1000, t.median * 1000, num_pixels * batch_size / 1e6 / t.median, batch_size / t.median)

300x300

(pt) ➜  vision git:(nvjpeg_bis) ✗ python nvjpeg_bench.py grace_300x300.jpg
Using grace_300x300.jpg
img.shape = torch.Size([3, 300, 300]), data.shape = torch.Size([18048])
batch_size = 1
for-loop cpu         - mean: 1.32    ms, median: 1.32    ms, Throughput = 68.4    Megapixel / sec, 759.4   fps
for-loop cuda        - mean: 0.65    ms, median: 0.65    ms, Throughput = 137.6   Megapixel / sec, 1528.5  fps
BATCH cuda           - mean: 0.66    ms, median: 0.66    ms, Throughput = 136.3   Megapixel / sec, 1514.5  fps
batch_size = 4
for-loop cpu         - mean: 5.25    ms, median: 5.25    ms, Throughput = 68.6    Megapixel / sec, 762.6   fps
for-loop cuda        - mean: 2.70    ms, median: 2.70    ms, Throughput = 133.2   Megapixel / sec, 1479.7  fps
BATCH cuda           - mean: 2.26    ms, median: 2.26    ms, Throughput = 159.1   Megapixel / sec, 1767.5  fps
batch_size = 16
for-loop cpu         - mean: 21.01   ms, median: 21.01   ms, Throughput = 68.5    Megapixel / sec, 761.6   fps
for-loop cuda        - mean: 10.68   ms, median: 10.68   ms, Throughput = 134.8   Megapixel / sec, 1497.9  fps
BATCH cuda           - mean: 9.18    ms, median: 9.18    ms, Throughput = 156.8   Megapixel / sec, 1742.6  fps
batch_size = 32
for-loop cpu         - mean: 41.92   ms, median: 41.92   ms, Throughput = 68.7    Megapixel / sec, 763.4   fps
for-loop cuda        - mean: 21.48   ms, median: 21.48   ms, Throughput = 134.1   Megapixel / sec, 1489.7  fps
BATCH cuda           - mean: 17.15   ms, median: 17.15   ms, Throughput = 167.9   Megapixel / sec, 1865.4  fps
batch_size = 64
for-loop cpu         - mean: 83.91   ms, median: 83.91   ms, Throughput = 68.6    Megapixel / sec, 762.8   fps
for-loop cuda        - mean: 43.52   ms, median: 43.52   ms, Throughput = 132.4   Megapixel / sec, 1470.6  fps
BATCH cuda           - mean: 34.86   ms, median: 34.86   ms, Throughput = 165.2   Megapixel / sec, 1835.8  fps
(pt) ➜  vision git:(nvjpeg_bis) ✗ ../libjpeg-turbo/build/tjbench grace_300x300.jpg -fastdct -fastupsample -nowrite

Using fastest DCT/IDCT algorithm

Using fast upsampling code

>>>>>  JPEG 4:2:0 --> BGR (Top-down)  <<<<<

Image size: 300 x 300
Decompress    --> Frame rate:         2232.743573 fps
                  Throughput:         200.946922 Megapixels/sec

517x606

(pt) ➜  vision git:(nvjpeg_bis) ✗ python nvjpeg_bench.py grace_517x606.jpg
Using grace_517x606.jpg
img.shape = torch.Size([3, 606, 517]), data.shape = torch.Size([73746])
batch_size = 1
for-loop cpu         - mean: 4.73    ms, median: 4.73    ms, Throughput = 66.2    Megapixel / sec, 211.4   fps
for-loop cuda        - mean: 2.16    ms, median: 2.16    ms, Throughput = 145.1   Megapixel / sec, 463.2   fps
BATCH cuda           - mean: 2.11    ms, median: 2.11    ms, Throughput = 148.6   Megapixel / sec, 474.4   fps
batch_size = 4
for-loop cpu         - mean: 19.68   ms, median: 19.68   ms, Throughput = 63.7    Megapixel / sec, 203.3   fps
for-loop cuda        - mean: 8.66    ms, median: 8.66    ms, Throughput = 144.7   Megapixel / sec, 461.8   fps
BATCH cuda           - mean: 10.28   ms, median: 10.28   ms, Throughput = 121.9   Megapixel / sec, 389.0   fps
batch_size = 16
for-loop cpu         - mean: 76.17   ms, median: 76.17   ms, Throughput = 65.8    Megapixel / sec, 210.0   fps
for-loop cuda        - mean: 35.32   ms, median: 35.32   ms, Throughput = 141.9   Megapixel / sec, 453.0   fps
BATCH cuda           - mean: 32.17   ms, median: 32.17   ms, Throughput = 155.8   Megapixel / sec, 497.4   fps
batch_size = 32
for-loop cpu         - mean: 153.15  ms, median: 153.15  ms, Throughput = 65.5    Megapixel / sec, 208.9   fps
for-loop cuda        - mean: 71.05   ms, median: 71.05   ms, Throughput = 141.1   Megapixel / sec, 450.4   fps
BATCH cuda           - mean: 67.50   ms, median: 67.50   ms, Throughput = 148.5   Megapixel / sec, 474.1   fps
batch_size = 64
for-loop cpu         - mean: 302.21  ms, median: 302.21  ms, Throughput = 66.3    Megapixel / sec, 211.8   fps
for-loop cuda        - mean: 146.51  ms, median: 146.51  ms, Throughput = 136.9   Megapixel / sec, 436.8   fps
BATCH cuda           - mean: 137.36  ms, median: 137.36  ms, Throughput = 146.0   Megapixel / sec, 465.9   fps
(pt) ➜  vision git:(nvjpeg_bis) ✗ ../libjpeg-turbo/build/tjbench grace_517x606.jpg -fastdct -fastupsample -nowrite

Using fastest DCT/IDCT algorithm

Using fast upsampling code

>>>>>  JPEG 4:2:0 --> BGR (Top-down)  <<<<<

Image size: 517 x 606
Decompress    --> Frame rate:         575.745642 fps
                  Throughput:         180.382261 Megapixels/sec

1920x1080

(pt) ➜  vision git:(nvjpeg_bis) ✗ python nvjpeg_bench.py grace_1920x1080.jpg
Using grace_1920x1080.jpg
img.shape = torch.Size([3, 1920, 1080]), data.shape = torch.Size([201324])
batch_size = 1
for-loop cpu         - mean: 25.13   ms, median: 25.13   ms, Throughput = 82.5    Megapixel / sec, 39.8    fps
for-loop cuda        - mean: 8.01    ms, median: 8.01    ms, Throughput = 258.7   Megapixel / sec, 124.8   fps
BATCH cuda           - mean: 13.59   ms, median: 13.59   ms, Throughput = 152.6   Megapixel / sec, 73.6    fps
batch_size = 4
for-loop cpu         - mean: 98.52   ms, median: 98.52   ms, Throughput = 84.2    Megapixel / sec, 40.6    fps
for-loop cuda        - mean: 30.40   ms, median: 30.40   ms, Throughput = 272.9   Megapixel / sec, 131.6   fps
BATCH cuda           - mean: 40.49   ms, median: 40.49   ms, Throughput = 204.9   Megapixel / sec, 98.8    fps
batch_size = 16
for-loop cpu         - mean: 393.87  ms, median: 393.87  ms, Throughput = 84.2    Megapixel / sec, 40.6    fps
for-loop cuda        - mean: 123.02  ms, median: 123.02  ms, Throughput = 269.7   Megapixel / sec, 130.1   fps
BATCH cuda           - mean: 206.93  ms, median: 206.93  ms, Throughput = 160.3   Megapixel / sec, 77.3    fps
batch_size = 32
skipping for-loop for batch_size = 32 and device = 'cpu'
skipping for-loop for batch_size = 32 and device = 'cuda'
BATCH cuda           - mean: 454.86  ms, median: 454.86  ms, Throughput = 145.9   Megapixel / sec, 70.4    fps
batch_size = 64
skipping for-loop for batch_size = 64 and device = 'cpu'
skipping for-loop for batch_size = 64 and device = 'cuda'
BATCH cuda           - mean: 871.31  ms, median: 871.31  ms, Throughput = 152.3   Megapixel / sec, 73.5    fps
(pt) ➜  vision git:(nvjpeg_bis) ✗ ../libjpeg-turbo/build/tjbench grace_1920x1080.jpg -fastdct -fastupsample -nowrite

Using fastest DCT/IDCT algorithm

Using fast upsampling code

>>>>>  JPEG 4:2:0 --> BGR (Top-down)  <<<<<

Image size: 1080 x 1920
Decompress    --> Frame rate:         139.281709 fps
                  Throughput:         288.814551 Megapixels/sec

2000x2000

(pt) ➜  vision git:(nvjpeg_bis) ✗ python nvjpeg_bench.py grace_2000x2000.jpg
Using grace_2000x2000.jpg
img.shape = torch.Size([3, 2000, 2000]), data.shape = torch.Size([316299])
batch_size = 1
for-loop cpu         - mean: 46.60   ms, median: 46.60   ms, Throughput = 85.8    Megapixel / sec, 21.5    fps
for-loop cuda        - mean: 13.33   ms, median: 13.33   ms, Throughput = 300.0   Megapixel / sec, 75.0    fps
BATCH cuda           - mean: 18.65   ms, median: 18.65   ms, Throughput = 214.5   Megapixel / sec, 53.6    fps
batch_size = 4
for-loop cpu         - mean: 181.49  ms, median: 181.49  ms, Throughput = 88.2    Megapixel / sec, 22.0    fps
for-loop cuda        - mean: 52.01   ms, median: 52.01   ms, Throughput = 307.6   Megapixel / sec, 76.9    fps
BATCH cuda           - mean: 76.38   ms, median: 76.38   ms, Throughput = 209.5   Megapixel / sec, 52.4    fps
batch_size = 16
for-loop cpu         - mean: 717.28  ms, median: 717.28  ms, Throughput = 89.2    Megapixel / sec, 22.3    fps
for-loop cuda        - mean: 205.97  ms, median: 205.97  ms, Throughput = 310.7   Megapixel / sec, 77.7    fps
BATCH cuda           - mean: 387.17  ms, median: 387.17  ms, Throughput = 165.3   Megapixel / sec, 41.3    fps
batch_size = 32
skipping for-loop for batch_size = 32 and device = 'cpu'
skipping for-loop for batch_size = 32 and device = 'cuda'
BATCH cuda           - mean: 785.07  ms, median: 785.07  ms, Throughput = 163.0   Megapixel / sec, 40.8    fps
batch_size = 64
skipping for-loop for batch_size = 64 and device = 'cpu'
skipping for-loop for batch_size = 64 and device = 'cuda'
BATCH cuda           - mean: 1564.96 ms, median: 1564.96 ms, Throughput = 163.6   Megapixel / sec, 40.9    fps
(pt) ➜  vision git:(nvjpeg_bis) ✗ ../libjpeg-turbo/build/tjbench grace_2000x2000.jpg -fastdct -fastupsample -nowrite

Using fastest DCT/IDCT algorithm

Using fast upsampling code

>>>>>  JPEG 4:2:0 --> BGR (Top-down)  <<<<<

Image size: 2000 x 2000
Decompress    --> Frame rate:         77.544314 fps
                  Throughput:         310.177257 Megapixels/sec

@cceyda commented May 4, 2021

@NicolasHug Is it possible for you to benchmark nvjpeg on an A100? It is supposed to be much faster than a V100 according to https://developer.nvidia.com/nvjpeg:

[image: throughput chart from https://developer.nvidia.com/nvjpeg]

@NicolasHug (Member)

Sorry @cceyda, I don't have access to an A100, but this might be on the roadmap in the future.

@tp-nan commented Oct 17, 2021

single-threaded

I implemented a quick-and-dirty version of batch-decoding (https://github.com/NicolasHug/vision/blob/380d5b5b563a55d26a69e66bdb1c9076b521031c/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp#L178), and ran a few more benchmarks on a V100. Everything is single-threaded. I didn't observe any significant change when allowing more than one CPU thread for batch-decoding.

Looking at the different throughputs:

  • batch decoding doesn't seem to be a win, although it's slightly faster than manual for-loop for small-ish images
  • nvjpeg is about 2-3X faster than basic CPU decoding (without libjpeg-turbo), which is somewhat consistent with the results reported in https://developer.nvidia.com/nvjpeg
  • nvjpeg is never faster than libjpeg-turbo
  • The gap between nvjpeg and libjpeg-turbo reduces as the image size grows. They're about the same for the big 2000x2000 image.
  1. Would it be possible to run a multi-instance, multi-threaded test? Multiple decoder instances and multiple CUDA streams could be used; batched decoding on the GPU uses fewer resources per image on average and could produce higher throughput (see the sketch below).
  2. There are usually at least four GPUs on one machine, so the maximum throughput across all of them could also be tested.
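A rough sketch of what such a multi-GPU, multi-threaded throughput test could look like (assuming the decode_jpeg(data, device=...) API from this PR; the iteration count and test image are arbitrary):

import threading
import time
import torch
from torchvision.io.image import decode_jpeg, read_file

data = read_file('test/assets/encode_jpeg/grace_hopper_517x606.jpg')
num_iters = 500

def worker(device, results):
    decode_jpeg(data, device=device)      # warm up nvjpeg on this device
    torch.cuda.synchronize(device)
    tic = time.time()
    for _ in range(num_iters):
        decode_jpeg(data, device=device)
    torch.cuda.synchronize(device)
    results[device] = num_iters / (time.time() - tic)   # images per second on this GPU

results = {}
threads = [threading.Thread(target=worker, args=(f'cuda:{i}', results))
           for i in range(torch.cuda.device_count())]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)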
