diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp index 16af9a429d3..54e1258202a 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp @@ -11,13 +11,15 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { #else +#include +#include #include static nvjpegHandle_t nvjpeg_handle = nullptr; void init_nvjpegImage(nvjpegImage_t& img) { for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) { - img.channel[c] = NULL; + img.channel[c] = nullptr; img.pitch[c] = 0; } } @@ -132,7 +134,9 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { } // TODO torch cuda stream support - // TODO output besides RGB + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + nvjpegStatus_t decode_status = nvjpegDecode( nvjpeg_handle, nvjpeg_state, @@ -140,7 +144,7 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { data.numel(), outputFormat, &outImage, - /*stream=*/0); + stream); // Destroy the state nvjpegJpegStateDestroy(nvjpeg_state);