Skip to content

Commit

Permalink
Use NNlib.conv_bias_act for Conv
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack committed Feb 7, 2022
1 parent 7b56813 commit 6573f65
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,13 @@ function (c::Conv)(x::AbstractArray)
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ = NNlib.fast_act(c.σ, x)
cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
σ.(conv(x, c.weight, cdims) .+ b)
_conv_bias_act(x, c.weight, cdims, b, σ)
end

_conv_bias_act(x, w, cdims, b, σ) = NNlib.conv_bias_act(x, w, cdims, b, σ)
_conv_bias_act(x::CuArray, w::CuArray, cdims, b::Zeros, σ) =
_conv_bias_act(x, w, cdims, CUDA.zeros(size(b)...), σ)

_channels_in(l ::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
_channels_out(l::Conv) = size(l.weight, ndims(l.weight))

Expand Down

0 comments on commit 6573f65

Please sign in to comment.