Skip to content

Commit

Permalink
Merge pull request #53 from maxfreu/patch-1
Browse files (browse the repository at this point in the history)
Save allocations during algorithm search
  • Loading branch information
ToucheSir authored Jun 18, 2022
2 parents 24cd95d + a5e4d55 commit d29ab6e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions ext/NNlibCUDA/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
CUDA = "3.3.1"
NNlib = "0.8.6"
CUDA = "3.11"
NNlib = "0.8.7"
julia = "1.6"

[extras]
Expand Down
4 changes: 2 additions & 2 deletions ext/NNlibCUDA/src/cudnn/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray
alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
convDesc, dx, depad = cudnnConvolutionDescriptorAndPaddedInput(cdims, dx)
xDesc, yDesc, wDesc = cudnnTensorDescriptor(dx), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(w)
p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx)
p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx, beta!=0)
with_workspace(p.memory) do workspace
cudnnConvolutionBackwardData(handle(), alpha, wDesc, w, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, xDesc, dx)
end
Expand All @@ -115,7 +115,7 @@ function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArr
alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
convDesc, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
xDesc, yDesc, wDesc = cudnnTensorDescriptor(x), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(dw)
p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw);
p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw, beta!=0);
with_workspace(p.memory) do workspace
cudnnConvolutionBackwardFilter(handle(), alpha, xDesc, x, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, wDesc, dw);
end
Expand Down

0 comments on commit d29ab6e

Please sign in to comment.