Skip to content

Commit

Permalink
Use at::cuda::getCurrentCUDAStream()
Browse files Browse the repository at this point in the history
  • Loading branch information
jamt9000 committed Jan 23, 2021
1 parent e485656 commit f82feb3
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) {

#else

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <nvjpeg.h>

static nvjpegHandle_t nvjpeg_handle = nullptr;

void init_nvjpegImage(nvjpegImage_t& img) {
for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) {
img.channel[c] = NULL;
img.channel[c] = nullptr;
img.pitch[c] = 0;
}
}
Expand Down Expand Up @@ -132,15 +134,17 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) {
}

// TODO torch cuda stream support
// TODO output besides RGB

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

nvjpegStatus_t decode_status = nvjpegDecode(
nvjpeg_handle,
nvjpeg_state,
datap,
data.numel(),
outputFormat,
&outImage,
/*stream=*/0);
stream);

// Destroy the state
nvjpegJpegStateDestroy(nvjpeg_state);
Expand Down

0 comments on commit f82feb3

Please sign in to comment.