
Support for decoding jpegs on GPU with nvjpeg #3792

Merged (37 commits) May 11, 2021

Conversation

NicolasHug
Member

@NicolasHug NicolasHug commented May 7, 2021

Closes #2786
Closes #2742

This is based on @jamt9000's great initial work in #2786. I mostly made some minor clean ups and added some tests.

In terms of usage, the supported API adds a device parameter to io.image.decode_jpeg.

For benchmarks, see #2786 (comment). Overall this seems to offer a 2-3X speedup over CPU decoding (without libjpeg-turbo).

Note: we use nvjpegCreateSimple(), which is only available in CUDA >= 10.1. So even though nvjpeg exists for 10.0, this won't compile. I assume this is OK since our CI only tests 10.1 and upwards anyway.

While we can move forward with this basic version, there seems to be room for improvement.

@NicolasHug NicolasHug marked this pull request as ready for review May 7, 2021 16:25
@NicolasHug NicolasHug changed the title WIP Support for decoding jpegs on GPU with nvjpeg Support for decoding jpegs on GPU with nvjpeg May 7, 2021
Contributor

@datumbox datumbox left a comment


Thanks for the PR!

I'll have a more thorough look next week. I only had a quick check and left a few comments for discussion. Let me know your thoughts.

Edit: I had a second look. Overall the PR looks great. I flagged two more things to discuss prior to merging. With the exception of the potential memory leak that we need to investigate, all other comments are questions about the API, so there is no need to modify your code.


#else

static nvjpegHandle_t nvjpeg_handle = nullptr;
Contributor


Consider adding this in an anonymous namespace, as it looks like an internal detail of the implementation. Also, just checking whether this should be released at some point to avoid memory leaks.

Member Author


Also just checking whether this should be released at any point to avoid memory leaks

Yeah, this is a good point. Creating/allocating it on each call has significant overhead, so it makes sense to declare this as a global variable (related discussion: #2786 (comment)). But this means we never really know when to release it, and the memory will only be freed when the process exits.
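The "create once, never release" idiom being discussed can be sketched language-agnostically. This Python sketch mimics the C++ pattern; FakeNvjpegHandle is a hypothetical stand-in for nvjpegHandle_t, not the real library type:

```python
# Sketch of the "create lazily once, leak on process exit" idiom.
# FakeNvjpegHandle stands in for the real nvjpegHandle_t; the actual
# implementation is C++ using nvjpegCreateSimple().

class FakeNvjpegHandle:
    """Pretend handle whose creation is expensive."""
    instances_created = 0

    def __init__(self):
        FakeNvjpegHandle.instances_created += 1

_handle = None  # module-level global, analogous to `static nvjpegHandle_t`

def get_handle():
    """Create the handle on first use, then reuse it forever.

    The handle is never destroyed: it lives until the process exits,
    which is exactly the memory-release question raised in the review.
    """
    global _handle
    if _handle is None:
        _handle = FakeNvjpegHandle()
    return _handle

h1 = get_handle()
h2 = get_handle()
assert h1 is h2                                  # same cached handle
assert FakeNvjpegHandle.instances_created == 1   # created only once
```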

Member


For reference: this is thread-safe

Contributor


Good that it's thread-safe, but it's still unclear to me whether we have to find a way to release it or whether we can leave it be. We don't have such an idiom in TorchVision, but I wonder if there are examples of resources in PyTorch core that are never released.

@ezyang how do you handle situations like this on core?

Contributor


We leak a bunch of global statically initialized objects, as it is the easiest way to avoid destructor ordering problems. If nvjpeg is a very simple library, you might be able to write an RAII class for this object and have it destruct properly on shutdown.

Precedent in PyTorch is the cudnn convolution cache, look at that for some inspiration.
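The RAII suggestion (create lazily, but register a destructor so the handle is released cleanly at shutdown) might look roughly like this hypothetical Python analogue using atexit; this is a sketch of the idea, not the cudnn cache code referenced above:

```python
import atexit

# Hypothetical RAII-style variant: the resource is created lazily and a
# cleanup hook is registered so it is released at interpreter shutdown,
# mirroring a C++ RAII object with static storage duration.

class ManagedHandle:
    def __init__(self):
        self.destroyed = False

    def destroy(self):
        # Stand-in for nvjpegDestroy(handle)
        self.destroyed = True

_handle = None

def get_handle():
    global _handle
    if _handle is None:
        _handle = ManagedHandle()
        atexit.register(_handle.destroy)  # run the "destructor" at shutdown
    return _handle

h = get_handle()
assert get_handle() is h    # still a single cached handle
assert not h.destroyed      # destroyed only when the process exits
```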

Contributor


@ezyang Thanks for the insights.

@NicolasHug Given Ed's input, I think we can mark this as resolved if you want.

@@ -21,7 +21,8 @@ static auto registry = torch::RegisterOperators()
 .op("image::encode_jpeg", &encode_jpeg)
 .op("image::read_file", &read_file)
 .op("image::write_file", &write_file)
-.op("image::decode_image", &decode_image);
+.op("image::decode_image", &decode_image)
+.op("image::decode_jpeg_cuda", &decode_jpeg_cuda);
Contributor


Not sure if the dispatcher would make sense here. Since this is the first IO method we add for the GPU, it might be worth checking the naming conventions (_cuda), as this will be reproduced in the near future in other places. Thoughts @fmassa ?

Member


I think using the dispatcher would be good, but I'm not sure how it handles constructor functions (like torch.empty / torch.rand).

Indeed, this function always takes CPU tensors, and it's up to a device argument to decide if we should dispatch to the CPU or the CUDA version.

@ezyang do you know if we can use the dispatcher to dispatch taking a torch.device into account, knowing that all tensors live in the CPU?
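The situation described here (all input tensors live on CPU, and an explicit device argument, rather than a tensor's device, decides which kernel runs) can be sketched as plain Python-level dispatch. The `_decode_jpeg_cpu` / `_decode_jpeg_cuda` names below are hypothetical stand-ins, not the registered operators:

```python
# Hypothetical sketch: dispatch on an explicit `device` argument rather
# than on the device of the input tensor (which always lives on CPU here).

def _decode_jpeg_cpu(data):
    return ("decoded on cpu", data)

def _decode_jpeg_cuda(data):
    return ("decoded on cuda", data)

def decode_jpeg(data, device="cpu"):
    """Front-end that routes to a backend based on `device`,
    even though `data` itself is always a CPU buffer."""
    kind = str(device).split(":")[0]  # accept "cuda:0", "cpu", etc.
    if kind == "cuda":
        return _decode_jpeg_cuda(data)
    if kind == "cpu":
        return _decode_jpeg_cpu(data)
    raise ValueError(f"unsupported device: {device!r}")

assert decode_jpeg(b"\xff\xd8", device="cuda:0")[0] == "decoded on cuda"
assert decode_jpeg(b"\xff\xd8")[0] == "decoded on cpu"
```

This is essentially what dispatching "taking a torch.device into account" would have to do at some level, since no tensor carries the target device.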

Contributor


@fmassa How about reading the data on CPU, since that's needed anyway, and then calling to() to move it to the right device? This can happen on the Python side and remain hidden. Then, with the binary data living on the GPU, the dispatcher can be used as normal to decide whether the decoding should happen on the GPU or the CPU. Thoughts on this?

Member


Still, nvjpeg requires the input data to live on the CPU, so we would need to move it back to the CPU within the function, which would be inefficient. I would have preferred to pass the tensor directly as a CUDA tensor as well, but I'm not sure this is possible without further overhead.

Contributor


Thanks for the clarifications concerning nvjpeg. I think we can investigate in future PRs how to do this more elegantly. No need to block this PR.

Member

@fmassa fmassa left a comment


Looks great to me, thanks!

I've made a few comments, none of which are merge-blocking, I think.

img_nvjpeg = f(data, mode=mode, device='cuda')

# Some difference expected between jpeg implementations
tester.assertTrue((img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2)
Member


Do we want to consider the mean difference or the max difference here? What would be the minimum value for a max-based test to pass here?

Member Author

@NicolasHug NicolasHug May 11, 2021


The max error can be quite high, unfortunately: the minimum threshold for all tests to pass seems to be 52; below that, some tests start failing.
In test_decode_jpeg, we also test the MAE (with the same threshold of 2).
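To illustrate why the two thresholds can differ so much, here is a small made-up example (the numbers are illustrative, not taken from the actual test images): a handful of badly-off pixels barely moves the mean absolute error but dominates the max error:

```python
# Illustrative only: two "decodings" of a 10x10 single-channel image that
# agree everywhere except two pixels that are off by a lot.

ref = [[100] * 10 for _ in range(10)]
out = [row[:] for row in ref]
out[0][0] = 152   # one pixel off by 52
out[5][5] = 130   # one pixel off by 30

diffs = [abs(a - b) for ra, rb in zip(ref, out) for a, b in zip(ra, rb)]
mae = sum(diffs) / len(diffs)   # (52 + 30) / 100 = 0.82
max_err = max(diffs)

assert mae < 2        # the PR's mean-error threshold passes...
assert max_err == 52  # ...while the max error is as large as 52
```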

Member


Hmm, it looks suspicious that the decoding gives such large differences. Something to keep an eye on.

nvjpegImage_t out_image;

for (int c = 0; c < num_channels_output; c++) {
out_image.channel[c] = out_tensor[c].data_ptr<uint8_t>();
Member


nit: this is fine for now, but it adds extra overhead, as we need to create a full Tensor just to extract the data pointer (and Tensor construction is heavy). Given that we generally only have 3 channels, that shouldn't be much of an issue, but it's still good to keep in mind.

Some alternatives would be to directly use the raw data_ptr adding the correct offsets, like

uint8_t * out_tensor_ptr = out_tensor.data_ptr<uint8_t>();
...
out_image.channel[c] = out_tensor_ptr + c * height * width;

Also, it's interesting that nvjpeg accepts decoding images in both CHW and HWC formats; I wonder if there are any performance implications of decoding in CHW.
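The raw-pointer alternative sketched above can be mirrored in Python with a flat buffer and memoryview slices; the offsets match the `c * height * width` arithmetic in the C++ snippet, assuming a contiguous CHW layout (the dimensions below are made up for illustration):

```python
height, width, channels = 4, 5, 3

# One contiguous CHW allocation, like the output tensor's storage.
flat = bytearray(channels * height * width)

# Per-channel "pointers": views starting at offset c * height * width,
# avoiding the cost of materializing a separate tensor per channel.
plane = height * width
buf = memoryview(flat)
channel_views = [buf[c * plane:(c + 1) * plane] for c in range(channels)]

# Writing through a channel view lands at the right place in the flat buffer.
channel_views[1][0] = 255
assert flat[1 * plane + 0] == 255
assert len(channel_views[0]) == plane
```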



Comment on lines 52 to 55
TORCH_CHECK(
create_status == NVJPEG_STATUS_SUCCESS,
"nvjpegCreateSimple failed: ",
create_status);
Member


I'm wondering if we should clear this if the creation fails; otherwise, we might not be able to run this function anymore, as the handle will be invalid.

Contributor

@datumbox datumbox left a comment


LGTM!

@NicolasHug
Member Author

CI is green(ish) so I'll merge.

Thanks everyone for the reviews and especially @jamt9000 for the initial work!
I'll follow up and open an issue with potential future improvements.

if (create_status != NVJPEG_STATUS_SUCCESS) {
// Reset handle so that one can still call the function again in the
// same process if there was a failure
free(nvjpeg_handle);
Contributor


Since it's an opaque handle I think the use of free() may not be correct unless it's documented as being supported.

(I would hope that it simply leaves the handle as null if initialisation fails, although I don't see that in the docs; here it just re-inits the handle without any freeing when the hw backend fails, though.)

Member Author


Hm that's a good point. What would you recommend instead of free?

Before pushing this I did a quick test by inserting

       nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);
       create_status = NVJPEG_STATUS_NOT_INITIALIZED;   // <- this

and all the tests failed gracefully with E RuntimeError: nvjpegCreateSimple failed: 1. Since I was running the tests with pytest test/test_image.py -k cuda, they were all in the same process and pytest was just catching the RuntimeErrors, so I assumed it was OK.
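The retry behaviour being tested here can be sketched generically: if lazy initialisation fails, the cached global must be left null so a later call in the same process attempts creation again. This is a hypothetical Python analogue (the real fix is in C++, and `_fail_next` is a test hook invented for the sketch):

```python
class InitError(RuntimeError):
    pass

_handle = None
_fail_next = {"flag": False}  # test hook to simulate one failed creation

def _create_handle():
    # Stand-in for nvjpegCreateSimple(); fails once when the hook is set.
    if _fail_next["flag"]:
        _fail_next["flag"] = False
        raise InitError("handle creation failed")
    return object()

def get_handle():
    global _handle
    if _handle is None:
        try:
            _handle = _create_handle()
        except InitError:
            _handle = None  # leave the cache empty so the next call retries
            raise
    return _handle

_fail_next["flag"] = True
try:
    get_handle()              # first call fails gracefully...
except InitError:
    pass
assert get_handle() is get_handle()  # ...and later calls succeed and cache
```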

Member Author


I didn't use it because I was wondering whether nvjpegDestroy would properly work with a bad handle?

Contributor


I guess if you assume anything can happen if initialisation fails then it might end up being an arbitrary value like 0xDEADBEEF and all you can do is reset it to null.

@NicolasHug NicolasHug merged commit f87ce88 into pytorch:master May 11, 2021
@jamt9000 jamt9000 mentioned this pull request May 11, 2021
facebook-github-bot pushed a commit that referenced this pull request May 19, 2021
Summary: Co-authored-by: James Thewlis <[email protected]>

Reviewed By: datumbox

Differential Revision: D28473331

fbshipit-source-id: d82d415e81876b660e599997860c737848d9afc0
NicolasHug added a commit to NicolasHug/vision that referenced this pull request May 19, 2021

6 participants