Commit 8b91c9b
rm threadpool arg
MikeInnes committed Nov 28, 2018
1 parent 25dcd3c commit 8b91c9b
Showing 1 changed file with 9 additions and 9 deletions.
src/nnpack/interface.jl (18 changes: 9 additions & 9 deletions)
@@ -12,17 +12,17 @@ function softmax!(x::A) where A<:AbstractVecOrMat{Float64}
 end
 
 softmax!(x::A) where A<:AbstractVecOrMat{Float32} =
-    nnp_softmax_output(x, x, threadpool = shared_threadpool[])
+    nnp_softmax_output(x, x)
 
 softmax!(y::A, x::A) where A<:AbstractVecOrMat{Float64} = softmax!(Float32.(y), Float32.(x))
 
 softmax!(y::A, x::A) where A<:AbstractVecOrMat{Float32} =
-    nnp_softmax_output(x, y, threadpool = shared_threadpool[])
+    nnp_softmax_output(x, y)
 
 softmax(x::A) where A<:AbstractVecOrMat{Float64} = softmax(Float32.(x))
 
 softmax(x::A) where A<:AbstractVecOrMat{Float32} =
-    nnp_softmax_output(x, similar(x), threadpool = shared_threadpool[])
+    nnp_softmax_output(x, similar(x))
 
 maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float64, 4} =
     maxpool(Float32.(x), k, pad = pad, stride = stride)
@@ -40,7 +40,7 @@ maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Fl
     maxpool!(Float32.(y), Float32.(x), k, pad = pad, stride = stride)
 
 maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float32, 4} =
-    nnp_max_pooling_output(x, y, k, padding = expand(Val{length(k)}, pad), stride = expand(Val{length(k)}, stride), threadpool = shared_threadpool[])
+    nnp_max_pooling_output(x, y, k, padding = expand(Val{length(k)}, pad), stride = expand(Val{length(k)}, stride))
 
 conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float64, 4} =
     conv(Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
@@ -88,7 +88,7 @@ function conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, al
     if flipkernel == 0
         w = reverse(reverse(w, dims=1), dims=2)
     end
-    nnp_convolution_output(y, x, w, b, algo = algo, padding = pad, stride = stride, threadpool = shared_threadpool[])
+    nnp_convolution_output(y, x, w, b, algo = algo, padding = pad, stride = stride)
 end
 
 crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
@@ -104,7 +104,7 @@ function ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo
     pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
     if fallback
         conv2d_grad_x!(zeros(Float32, size(x)), x, w, dy, padding = pad_, stride = stride_, dilation = dilation)
-    else
+    else
         ∇conv_data!(zeros(Float32, size(x)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
     end
 end
@@ -114,7 +114,7 @@ end
 
 function ∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float32, 4}
     flipkernel == 0 && (w = reverse(reverse(w, dims=1), dims=2))
-    nnp_convolution_input_gradient(dx, x, dy, w, padding = pad, stride = stride, algo = algo, threadpool = shared_threadpool[])
+    nnp_convolution_input_gradient(dx, x, dy, w, padding = pad, stride = stride, algo = algo)
 end
 
 ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float64, 4} =
@@ -124,7 +124,7 @@ function ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, al
     pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
     if fallback
         conv2d_grad_w!(zeros(Float32, size(w)), x, w, dy, padding = pad_, stride = stride_, dilation = dilation)
-    else
+    else
         ∇conv_filter!(zeros(Float32, size(w)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
     end
 end
@@ -134,6 +134,6 @@ end
 
 function ∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float32, 4}
     flipkernel == 0 && (w = reverse(reverse(w, dims=1), dims=2))
-    dw .= nnp_convolution_kernel_gradient(dw, x, dy, w, padding = pad, stride = stride, algo = algo, threadpool = shared_threadpool[])
+    dw .= nnp_convolution_kernel_gradient(dw, x, dy, w, padding = pad, stride = stride, algo = algo)
     flipkernel == 0 ? reverse(reverse(dw, dims=1), dims=2) : dw
 end
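
Below is a minimal usage sketch (not part of the commit) of the Float32 code paths touched above, now called without any threadpool keyword. It assumes the package that vendors src/nnpack/interface.jl (e.g. NNlib) is loaded and exports these wrappers, and that the Float32 conv method mirrors the Float64 one shown in the diff; the array shapes and values are illustrative only.

    # Hypothetical example; assumes the nnpack-backed wrappers above are in scope.
    x = rand(Float32, 10, 5)            # Float32 matrix
    softmax!(x)                         # in place: dispatches to nnp_softmax_output(x, x)

    img = rand(Float32, 28, 28, 1, 8)   # 4-D Float32 input (W, H, C, N)
    w   = rand(Float32, 3, 3, 1, 4)     # convolution kernel
    y   = conv(img, w, pad = 1, stride = 1)   # Float32 path ends in nnp_convolution_output

    k  = (2, 2)                         # pooling window; default stride is k
    yp = similar(img, 14, 14, 1, 8)     # preallocated output
    maxpool!(yp, img, k)                # dispatches to nnp_max_pooling_output(img, yp, k, ...)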
