From f82feb39fe809da43132ec87bc8540b3d5c848f5 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 23 Jan 2021 15:05:52 +0000 Subject: [PATCH] Use at::cuda::getCurrentCUDAStream() --- torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp index 16af9a429d3..54e1258202a 100644 --- a/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp @@ -11,13 +11,15 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { #else +#include +#include #include static nvjpegHandle_t nvjpeg_handle = nullptr; void init_nvjpegImage(nvjpegImage_t& img) { for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) { - img.channel[c] = NULL; + img.channel[c] = nullptr; img.pitch[c] = 0; } } @@ -132,7 +134,9 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { } // TODO torch cuda stream support - // TODO output besides RGB + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + nvjpegStatus_t decode_status = nvjpegDecode( nvjpeg_handle, nvjpeg_state, @@ -140,7 +144,7 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) { data.numel(), outputFormat, &outImage, - /*stream=*/0); + stream); // Destroy the state nvjpegJpegStateDestroy(nvjpeg_state);