Add a Blocked Convolution Proof of Concept #97

Closed · wants to merge 10 commits
6 changes: 6 additions & 0 deletions Manifest.toml
@@ -69,6 +69,12 @@ version = "0.5.2"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[SIMD]]
deps = ["InteractiveUtils", "Test"]
git-tree-sha1 = "a81f30058aa91fb53c794169436b402c3102a960"
uuid = "fdea26ae-647d-5447-a871-4b548cad5224"
version = "2.6.0"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

1 change: 1 addition & 0 deletions Project.toml
@@ -7,6 +7,7 @@ BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

2 changes: 1 addition & 1 deletion src/NNlib.jl
@@ -29,7 +29,7 @@ include("impl/depthwiseconv_direct.jl")
# im2col implementations of convolutional and depthwise-convolutional algorithms
include("impl/conv_im2col.jl")
include("impl/depthwiseconv_im2col.jl")

include("impl/blocked_conv.jl")
# Direct implementations of pooling
include("impl/pooling_direct.jl")

165 changes: 165 additions & 0 deletions src/impl/blocked_conv.jl
@@ -0,0 +1,165 @@
export blocked_conv, block, deblock
using SIMD

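# Drop the last spatial dimension, which is assumed to be a singleton; this is
# the inverse of NNlib's insert_singleton_spatial_dimension, used below to run
# 1-D/2-D convolutions through the 3-D kernel.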
@inline function remove_singleton_spatial_dimension(x::AbstractArray)
return reshape(x, size(x)[1:end-3]..., size(x)[end-1:end]...)
end

@inline function remove_singleton_spatial_dimension(x, reps::Int)
for r in 1:reps
x = remove_singleton_spatial_dimension(x)
end
return x
end

# unoptimized blocking
function block(x, block_axis = 3, block_len = 8)
@assert size(x)[block_axis] % block_len == 0
shape = [i for i in size(x)]
shape[block_axis] ÷= block_len
insert!(shape, block_axis, block_len)
permute = vcat([block_axis], [i for i=1:length(shape) if i != block_axis])
permutedims(reshape(x,Tuple(shape)), permute)
end

function deblock(x, block_axis = 3)
permute = [i for i = 2:length(size(x))]
insert!(permute, block_axis, 1)
shape = [size(x)[i] for i = 2:length(size(x))]
shape[block_axis] *= size(x)[1]
reshape(permutedims(x, permute), Tuple(shape))
end
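As an illustration (not part of this diff), a minimal block/deblock round trip, assuming a W × H × C × N activation whose channel count is divisible by the default block length of 8:

x = rand(Float32, 8, 8, 16, 2)    # W × H × C × N activation
bx = block(x, 3)                  # size(bx) == (8, 8, 8, 2, 2): channels split into blocks of 8, block index moved to the front
@assert deblock(bx, 3) == x       # permutedims + reshape are exact, so the round trip restores x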

## Iteration indices, outer to inner:
# batch - n
# out_channels - j
# out depth - dₒ
# in channels - i
# filter depth - dₖ
# filter height - hₖ
# filter width - wₖ
# out height - hₒ
# out width - wₒ
# in blocked channels - ii
# out blocked channels (SIMD vectorized) - jj
function blocked_conv_inner_loop!(Out::Array{T,6},
X::Array{T,6},
W::Array{T,7},
ol::Int64,
::Type{Vec{B,T}},
cdims::DenseConvDims) where {B,T}
cₒ, cᵢ, Wₖ, Hₖ, Dₖ, Cᵢ, Cₒ = size(W)
cₒ, Wₒ, Hₒ, Dₒ, Cₒ, N = size(Out)
cᵢ, Wᵢ, Hᵢ, Dᵢ, Cᵢ, N = size(X)
p = padding(cdims)
s = stride(cdims)
d = dilation(cdims)
padded_regions, central_region = calc_padding_regions(cdims)
# get fused loop indexes
ool = ol - 1
n = div(ool, Cₒ)
ool -= (n) * Cₒ
j = ool

n += 1
j += 1
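# e.g. with Cₒ = 4 blocked output channels, job ol = 6 decomposes to n = 2, j = 2
# (batch entry 2, output-channel block 2)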

# Calculate the central region without conditionals
w_region, h_region, d_region = central_region
@inbounds for i = 1:Cᵢ, dₒ = d_region, hₒ = h_region, dₖ = 1:Dₖ, hₖ = 1:Hₖ, wₖ = 1:Wₖ, wₒ = w_region
# pre-calculate indexes for the inner loop
dᵢ = 1 + s[3] * (dₒ - 1) + d[3] * (dₖ - 1) - p[5]
hᵢ = 1 + s[2] * (hₒ - 1) + d[2] * (hₖ - 1) - p[3]
wᵢ = 1 + s[1] * (wₒ - 1) + d[1] * (wₖ - 1) - p[1]
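# e.g. with stride 2, dilation 2, pad 3: (wₒ, wₖ) = (5, 2) reads wᵢ = 1 + 2*4 + 2*1 - 3 = 8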

F_w = Wₖ - (wₖ - 1)
F_h = Hₖ - (hₖ - 1)
F_d = Dₖ - (dₖ - 1)
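# The filter taps are indexed in reverse (F_w, F_h, F_d), so this computes a true
# convolution, matching NNlib's default of flipping the kernel; note this PoC does
# not consult flipkernel(cdims).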
@inbounds for ii = 1:B
tmpI = Vec{B, T}(X[ii, wᵢ, hᵢ, dᵢ, i, n])
tmpO = vload(Vec{B, T}, view(Out, :, wₒ, hₒ, dₒ, j, n), 1)
tmpW = vload(Vec{B, T}, view(W, :, ii, F_w, F_h, F_d, i, j), 1)
tmpOut = fma(tmpI, tmpW, tmpO)
vstore(tmpOut, view(Out, :, wₒ, hₒ, dₒ, j, n), 1)
end
end

# Calculate the padded regions with conditionals
@inbounds for (w_region, h_region, d_region) in padded_regions
@inbounds for i = 1:Cᵢ, dₒ = d_region, hₒ = h_region, dₖ = 1:Dₖ, hₖ = 1:Hₖ, wₖ = 1:Wₖ, wₒ = w_region
# pre-calculate indexes for the inner loop
dᵢ = 1 + s[3] * (dₒ - 1) + d[3] * (dₖ - 1) - p[5]
hᵢ = 1 + s[2] * (hₒ - 1) + d[2] * (hₖ - 1) - p[3]
wᵢ = 1 + s[1] * (wₒ - 1) + d[1] * (wₖ - 1) - p[1]
# Skip input positions that fall outside the array (padding region)
if (hᵢ < 1 || wᵢ < 1 || dᵢ < 1 || hᵢ > Hᵢ || wᵢ > Wᵢ || dᵢ > Dᵢ)
continue
end
F_w = Wₖ - (wₖ - 1)
F_h = Hₖ - (hₖ - 1)
F_d = Dₖ - (dₖ - 1)
@inbounds for ii = 1:B
tmpI = Vec{B, T}(X[ii, wᵢ, hᵢ, dᵢ, i, n])
tmpO = vload(Vec{B, T}, view(Out, :, wₒ, hₒ, dₒ, j, n), 1)
tmpW = vload(Vec{B, T}, view(W, :, ii, F_w, F_h, F_d, i, j), 1)
tmpOut = fma(tmpI, tmpW, tmpO)
vstore(tmpOut, view(Out, :, wₒ, hₒ, dₒ, j, n), 1)
end
end
end
end
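For reference (not part of this diff), a minimal standalone sketch of the SIMD.jl pattern used in the inner loop above, assuming a vector width of 8: one input scalar is broadcast across all lanes and accumulated into contiguous output channels with a fused multiply-add.

using SIMD
a = rand(Float32, 8); b = rand(Float32, 8)
s = Vec{8,Float32}(2f0)              # broadcast one scalar across all lanes, like tmpI
va = vload(Vec{8,Float32}, a, 1)     # load 8 contiguous floats starting at index 1, like tmpW
vb = vload(Vec{8,Float32}, b, 1)     # current accumulator values, like tmpO
vstore(fma(s, va, vb), b, 1)         # b[1:8] .= s .* a[1:8] .+ b[1:8], one fma per lane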

function blocked_conv!(Out::Array{T,6},
X::Array{T,6},
W::Array{T,7},
cdims::DenseConvDims) where T<:Number
@assert size(Out)[1] == size(W)[1]
@assert size(X)[1] == size(W)[2]
## Fuse the two outermost loops (batch and output-channel blocks) so there are
## enough jobs for the threads; this matters most when the batch size is small
out_loop_size = size(Out)[6] * size(Out)[5]
@inbounds Threads.@threads for ol = 1:out_loop_size
blocked_conv_inner_loop!(Out, X, W, ol, Vec{size(X)[1],T}, cdims)
end
end


function blocked_conv(X::Array{T,6}, W::Array{T,7}, cdims::DenseConvDims) where T<:Number
Out = zeros(T, size(W,1), output_size(cdims)...,
div(channels_out(cdims),size(W, 1)), size(X, 6))
blocked_conv!(Out, X, W, cdims)
Out
end

for N in (3, 4)
@eval begin
function $(Symbol("blocked_conv!"))(
y::AbstractArray{yT,$(N+1)}, x::AbstractArray{xT,$(N+1)},
w::AbstractArray{wT,$(N+2)}, cdims::ConvDims) where {yT, xT, wT}
$(Symbol("blocked_conv!"))(
insert_singleton_spatial_dimension(y, $(5 - N)),
insert_singleton_spatial_dimension(x, $(5 - N)),
insert_singleton_spatial_dimension(w, $(5 - N)),
insert_singleton_spatial_dimension(cdims, $(5 - N))
)

# We explicitly return `y` here, because the backend call
# itself may return a reshaped view, which we don't want.
return y
end
end
@eval begin
function $(Symbol("blocked_conv"))(
x::AbstractArray{xT,$(N+1)},
w::AbstractArray{wT,$(N+2)}, cdims::ConvDims) where {xT, wT}
remove_singleton_spatial_dimension(
$(Symbol("blocked_conv"))(
insert_singleton_spatial_dimension(x, $(5 - N)),
insert_singleton_spatial_dimension(w, $(5 - N)),
insert_singleton_spatial_dimension(cdims, $(5 - N))
),
$(5 - N)
)
end
end
end
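For illustration (not part of this diff), a hedged usage sketch of the generated 2-D frontend; shapes are arbitrary, but channel counts must be divisible by the block length of 8:

X = rand(Float32, 28, 28, 16, 4)              # W × H × C_in × N
W = rand(Float32, 3, 3, 16, 8)                # kernel: W × H × C_in × C_out
cdims = DenseConvDims(X, W)                   # conv dims are described on the unblocked arrays
bX = block(X, 3)                              # block the input channel axis
bW = block(block(W, 3), 5)                    # block both channel axes of the weights
Y = deblock(blocked_conv(bX, bW, cdims), 3)   # ≈ conv(X, W, cdims)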
60 changes: 60 additions & 0 deletions test/blocked_conv.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using BenchmarkTools
using Test
using NNlib
using LinearAlgebra

BLAS.set_num_threads(4)

function test_blocked_conv(im_size,
k_size,
rank,
pad,
stride,
dilation;
benchmark = false)
X_shape = vcat([im_size for i in 1:rank], [32, 128])
W_shape = vcat([k_size for i in 1:rank], [32, 16])

X = rand(Float32, X_shape...)
W = rand(Float32, W_shape...)

bX = block(X, rank + 1)
bW = block(block(W, rank + 1), rank + 3)


if benchmark
println("Data Shape: $(size(X))")
println("Weight Shape: $(size(W))")
println("pad=$pad, stride=$stride, dilation=$dilation")
# print("block_data: ")
# @btime block($X, $(rank + 1))
# print("block_weights: ")
# @btime block(block($W, $(rank + 1)), $(rank + 3))


c = DenseConvDims(X, W; stride = stride, dilation = dilation, padding = pad)
print("blocked_conv2d: ")
@btime Out1 = blocked_conv($bX, $bW, $c)
print("NNlib.conv: ")
@btime Out2 = conv($X, $W, $c)
println()
end
c = DenseConvDims(X, W; stride = stride, dilation = dilation, padding = pad)
Out1 = blocked_conv(bX, bW, c)
Out2 = conv(X, W, c)
@test isapprox(deblock(Out1, rank + 1), Out2)
end

do_benchmarking = false

for im_size = [32, 64, 128, 192]
for k_size = [5]
for pad = [3], stride = [2], dilation = [2]
test_blocked_conv(im_size, k_size, 1, pad, stride, dilation, benchmark = do_benchmarking)
test_blocked_conv(im_size, k_size, 2, pad, stride, dilation, benchmark = do_benchmarking)
if im_size <= 32
test_blocked_conv(im_size, k_size, 3, pad, stride, dilation, benchmark = do_benchmarking)
end
end
end
end