How to use im2col and col2im together with GEMM to achieve a conv2d operation? #19

Open
tangjiasheng opened this issue Mar 30, 2018 · 2 comments

@tangjiasheng

Can you provide some sample code? I have no idea how to manipulate the 5-dimensional (or 6-dimensional) tensors.

@akhauriyash

akhauriyash commented Aug 29, 2018

Check out the link here:
https://leonardoaraujosantos.gitbooks.io/artificial-inteligence/content/making_faster.html

For a better understanding, open im2col.py and add the code below at the bottom. Then run python im2col.py and observe the output:

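import torch
from torch.autograd import Variable  # optional on PyTorch >= 0.4, where tensors and Variables are merged

# Build a (1, 3, 4, 4) input on the GPU and fill it with 1..48 so each patch is easy to trace.
init = Variable(torch.randn((1, 3, 4, 4)).cuda())
z = 0
for i in range(init.size(1)):          # channels
    for j in range(init.size(2)):      # rows
        for k in range(init.size(3)):  # columns
            init[0, i, j, k] = z + 1
            z += 1

# Unroll with kernel_size=2, stride=1, padding=0; the result is a 6-D tensor.
unrolled = im2col_batch(init, 2, 1, 0)
print(init.size())
print(unrolled.size())
# Merge dims 1-3 into rows and dims 4-5 into columns to get one 2-D matrix per sample.
unr_reshape = torch.reshape(unrolled, (1, unrolled.size(1)*unrolled.size(2)*unrolled.size(3), unrolled.size(4)*unrolled.size(5)))
print(unr_reshape)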

You should be able to replicate the blog image with this code. I hope this helps in understanding how to manipulate the dimensions!
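If no GPU is at hand, the same shapes can be checked on the CPU. The sketch below swaps in PyTorch's built-in `torch.nn.functional.unfold` for `im2col_batch` plus the reshape. It assumes the 6-D output above is laid out as (N, C, kh, kw, out_h, out_w), which is what the reshape suggests but which I have not verified against this repository.

```python
import torch
import torch.nn.functional as F

# Same toy input as above: shape (1, 3, 4, 4) holding the values 1..48.
x = torch.arange(1, 49, dtype=torch.float32).reshape(1, 3, 4, 4)

# F.unfold is PyTorch's own im2col. It is used here only as a CPU stand-in
# for im2col_batch followed by the reshape (an assumption, see the note above).
cols = F.unfold(x, kernel_size=2, stride=1, padding=0)

# Layout is (N, C*kh*kw, out_h*out_w): 3*2*2 = 12 rows, 3*3 = 9 columns.
print(tuple(cols.shape))

# Each column is one receptive field: the 2x2 patch of every channel, flattened.
first_patch = cols[0, :, 0].tolist()
print(first_patch)
```

Each of the 9 columns corresponds to one position of the 2x2 window, so a convolution over the whole image becomes a single matrix product with the flattened filters.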

Cheers!

@akhauriyash

akhauriyash commented Aug 29, 2018

Alternatively, change im2col_batch so that the 4-D (batched) case returns the reshaped 3-D result directly:

def im2col_batch(input, kernel_size, stride, padding):
    if input.dim() == 3:
        # Single sample: unroll it directly.
        return _im2col(input, kernel_size, stride, padding)
    elif input.dim() == 4:
        # Batch: allocate the output once, then unroll each sample into its slice.
        shape = (input.size(0),) + im2col_shape(input.size()[1:], kernel_size, stride, padding)
        out = input.new(*shape)
        for x, o in zip(input, out):
            _im2col(x, kernel_size, stride, padding, out=o)
        # Flatten the 6-D result to 3-D: merge dims 1-3 and dims 4-5.
        out = torch.reshape(out, (out.size(0), out.size(1)*out.size(2)*out.size(3), out.size(4)*out.size(5)))
        return out
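To connect this back to the original question, here is a minimal sketch of the GEMM step and of where col2im comes in. It again uses `F.unfold` and `F.fold` as stand-ins for this repository's im2col and col2im, so treat it as an illustration under that assumption rather than as this project's API. The helper name `conv2d_gemm` is hypothetical, and bias and dilation are left out.

```python
import torch
import torch.nn.functional as F

def conv2d_gemm(x, weight, stride=1, padding=0):
    """conv2d as im2col followed by one matrix multiply (no bias, no dilation)."""
    n, c, h, w = x.shape
    out_c, _, kh, kw = weight.shape
    out_h = (h + 2 * padding - kh) // stride + 1
    out_w = (w + 2 * padding - kw) // stride + 1
    # im2col: (N, C, H, W) -> (N, C*kh*kw, out_h*out_w)
    cols = F.unfold(x, (kh, kw), stride=stride, padding=padding)
    # Flatten every filter into one row: (out_c, C*kh*kw)
    w_mat = weight.reshape(out_c, -1)
    # GEMM, broadcast over the batch: (out_c, K) @ (N, K, L) -> (N, out_c, L)
    out = w_mat @ cols
    return out.reshape(n, out_c, out_h, out_w)

torch.manual_seed(0)
x = torch.randn(2, 3, 5, 5, requires_grad=True)
weight = torch.randn(4, 3, 2, 2)

# Forward pass: compare against the built-in convolution.
y = conv2d_gemm(x, weight, stride=1, padding=1)
ref = F.conv2d(x, weight, stride=1, padding=1)
max_err = (y - ref).abs().max().item()
print(tuple(y.shape), max_err)

# Backward pass: col2im (F.fold) scatters columns back into image layout and
# sums overlapping entries, which turns the column gradient into the input gradient.
grad_out = torch.randn_like(y)
y.backward(grad_out)
grad_cols = weight.reshape(4, -1).t() @ grad_out.reshape(2, 4, -1)  # (N, K, L)
grad_x = F.fold(grad_cols, output_size=(5, 5), kernel_size=(2, 2), stride=1, padding=1)
grad_err = (grad_x - x.grad).abs().max().item()
print(grad_err)
```

In short, im2col plus GEMM gives the forward convolution, while col2im is only needed for the gradient with respect to the input. Both errors should be at floating-point rounding level.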
