2D Transpose Convolutions #54
Conversation
This is great, but I think we can simplify a bit -- we don't actually need the `conv_transpose` alias. How about if the `ConvTranspose` layer just calls the gradient function directly, and we also define the derivative of that function? Aside from not needing the NNlib patch, this has the big bonus that nested AD will then also work through convolutions.
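For context, a minimal sketch of that suggestion. This is not the PR's actual code: the two-argument `∇conv_data(dy, w)` signature is taken from the diff further down and treated as an assumption, and `ConvTranspose` here is a bare struct rather than Flux's layer.

```julia
# Sketch: the forward pass of a transposed convolution is the input-gradient
# of a regular convolution, so the layer can call the gradient kernel directly
# instead of going through a separate conv_transpose alias.
# Assumes `using NNlib` with this PR merged; released versions may differ.
struct ConvTranspose{W<:AbstractArray}
    weight::W
end

# `∇conv_data(dy, w)` is the signature proposed in this PR's diff.
(c::ConvTranspose)(x) = ∇conv_data(x, c.weight)
```

Defining a derivative for `∇conv_data` itself is what makes nested AD work: differentiating the layer's forward pass then just means differentiating that gradient function again.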
That sounds cool. I was working on writing the gradient hookup for it. I felt that we should refactor the conv interface. EDIT: Ahh nvm, fixed it :)
This should fix the failing gradtest.
Force-pushed from b579b91 to 97b095b.
This has now been resolved for v1.0.
It'd be good if we could take the opportunity to make the interface a bit more consistent. If I understand correctly,

I realized that for
src/conv.jl (Outdated)

```
@@ -36,8 +53,14 @@ function crosscor(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray
    x, w, pad = pad_, stride = stride_, dilation = dilation)
end

∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, flipkernel = 0) where A<:AbstractArray =
    ∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
function ∇conv_data(dy::A, w::A, x_dims=nothing; pad = 0, stride = 1, dilation = 1, flipkernel = 0) where A<:AbstractArray
```
A couple of small things (sketched after this list):

- Can you make this a keyword argument `size`?
- Can you add a simple three-arg wrapper with a deprecation warning?
- You should compare `x_dims === nothing` so that type inference can remove the check.
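A self-contained sketch of all three suggestions, with hypothetical names (`grad_data` stands in for `∇conv_data`, and the bodies are stand-ins, not NNlib's kernels):

```julia
# 1. `size` as a keyword argument, defaulting to `nothing`.
function grad_data(dy::AbstractArray, w::AbstractArray; size = nothing)
    # 3. Compare with `===`: when `size` is `nothing` its type is `Nothing`,
    #    so type inference can prove the branch dead and remove the check.
    if size === nothing
        size = Base.size(dy)  # stand-in for the real shape inference
    end
    return similar(dy, size)  # stand-in for the real gradient computation
end

# 2. A three-arg method kept only as a deprecation shim for the old call style.
function grad_data(dy::AbstractArray, w::AbstractArray, size)
    Base.depwarn("pass `size` as a keyword argument", :grad_data)
    return grad_data(dy, w; size = size)
end
```

Note that inside the method the keyword `size` shadows `Base.size`, which is exactly why the call above has to be qualified as `Base.size(dy)`; the next comments discuss this.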
When the `size` method is being used in the function, it gives a `MethodError` if this argument is named `size`. How about using `dims` instead?
`dims` is a bit weird because it usually refers to the dimensions you act on. How about just calling `Base.size` inside the function?
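A tiny reproduction of that clash, assuming the function keeps a keyword named `size` (names here are hypothetical):

```julia
# Inside the method body, the keyword `size` shadows `Base.size`, so an
# unqualified `size(x)` tries to call the keyword's value (e.g. `nothing`)
# and throws a MethodError.
f(x; size = nothing) = size(x)        # MethodError: Nothing is not callable
g(x; size = nothing) = Base.size(x)   # fine: qualified call reaches Base.size

g(rand(2, 3))  # returns (2, 3)
```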
Ok great, this is all but there now.

Once this is merged we're going to have to upper-bound CuArrays and Flux when we tag NNlib again. We'll have to update both of those for the new API as well (I know you have a PR for Flux already).
```julia
if size === nothing
    size = cdims(Base.size(x), dilation_dims(w, dilation), pad_, stride_)
end
conv!(similar(x, size), x, w, pad = pad_, stride = stride_, dilation = dilation)
```
Does the `size` argument actually make sense for `conv`? I don't know if there's a similar ambiguity in the sizes as compared with the transpose.
`conv` acts as the gradient function of the input for conv transpose during the backward pass. Hence, just like `conv_data`, there exists an ambiguity here as well.
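To make the ambiguity concrete (a worked example, not from the thread): with a strided convolution, several input lengths map to the same output length, so the reverse direction cannot infer the input size uniquely.

```julia
# Output length of a 1D convolution with kernel size k, stride s, no padding.
outlen(n, k, s) = fld(n - k, s) + 1

outlen(7, 3, 2)  # 3
outlen(8, 3, 2)  # 3  -> inputs of length 7 and 8 both yield length-3 outputs,
                 #      so the transpose needs the target size passed explicitly
```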
Ok, that's fine, just checking.
src/conv.jl (Outdated)

```julia
∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, flipkernel=0) where A<:AbstractArray =
    ∇conv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
∇conv_filter(dy::A, x::A, size::Tuple; pad = 0, stride = 1, dilation = 1, flipkernel=0) where A<:AbstractArray =
    ∇conv_filter!(zeros(eltype(dy),size), dy, x; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
```
This should use `similar`.
To clarify, you mean to use `zero(similar(dy, size))`, right?
Just `similar` should be fine, if you're about to overwrite it anyway.
Currently we have `similar` in use in this branch, which is failing the NaN test for `∇conv_filter`. However, if `zero` is used instead of `similar` then the tests pass.
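A plausible explanation for that failure (an illustration, assuming the in-place kernel accumulates into its output buffer rather than overwriting every element): `similar` returns uninitialized memory, which can contain NaNs or garbage.

```julia
x = rand(Float32, 4, 4)

buf = similar(x)  # uninitialized: contents are whatever was in memory
buf .+= 1f0       # accumulating into it mixes that garbage into the result

out = zero(x)     # guaranteed all-zero buffer
out .+= 1f0       # deterministic: all ones
```

So `similar` alone is only safe when the kernel fully overwrites the buffer; an accumulating kernel needs `zero` (or `fill!(similar(...), 0)`).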
src/conv.jl (Outdated)

```julia
∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, flipkernel=0) where A<:AbstractArray =
    ∇conv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
∇conv_filter(dy::A, x::A, size::Tuple; pad = 0, stride = 1, dilation = 1, flipkernel=0) where A<:AbstractArray =
```
Are you intentionally doing data and filter in different ways? Why not use a `size` kwarg for both?
Yes, it is intentional, because for `conv_data` it serves a dual purpose: when `size=nothing` it performs a `conv_transpose`, else it is a `conv_grad`. `conv_filter` has only one purpose, which is to find the gradient of the filter. Hence, we require the value of `size` to be passed.
But they are not actually different operations, right? Conv transpose is just the gradient with a particular inferred size. (Lmk if my understanding is off here.)

So we could in principle just write a size inference for `conv_filter`; but if you don't want to do that for now, it'd be fine to just do `size` as a kwarg without a default (which will error if it's not provided).
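For reference, a keyword argument declared without a default is required in Julia; a minimal sketch with hypothetical names:

```julia
# No `= default` on the keyword, so omitting it throws UndefKeywordError.
grad_filter(dy, x; size) = similar(dy, size)  # body is a stand-in

grad_filter(rand(2, 2), rand(3, 3))                       # UndefKeywordError: `size` not assigned
grad_filter(rand(2, 2), rand(3, 3); size = (2, 2, 1, 1))  # ok
```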
Bump! This would be great to have.
I didn't see any big missing pieces, but I am worried that we might not have sufficient test coverage. I am working on getting Codecov or something hooked up so that we can be sure that what we're merging covers as many corner cases as possible.
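A hedged sketch of the kind of finite-difference gradient check such coverage tests typically rely on (the helper and tolerance here are illustrative, not NNlib's actual `gradtest`):

```julia
# Central-difference estimate of the gradient of a scalar-valued f at x.
function numgrad(f, x; eps = 1e-5)
    g = similar(x)
    for i in eachindex(x)
        x[i] += eps;  fp = f(x)
        x[i] -= 2eps; fm = f(x)
        x[i] += eps               # restore x
        g[i] = (fp - fm) / (2eps)
    end
    return g
end

x = rand(5)
loss(x) = sum(abs2, x)                            # analytic gradient is 2x
maximum(abs.(numgrad(loss, x) .- 2 .* x)) < 1e-6  # should hold
```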
@tejank10 can you rebase this on top of the latest `master`?
Force-pushed from 56fbbb0 to c4ac366.
Codecov Report

```
@@            Coverage Diff             @@
##           master      #54      +/-  ##
==========================================
+ Coverage   70.46%   72.15%   +1.69%
==========================================
  Files           9        9
  Lines         579      607      +28
==========================================
+ Hits          408      438      +30
+ Misses        171      169       -2
```

Continue to review the full report at Codecov.
Bump 🙂 Is any help needed to get this out the door?
@tejank10 Great, thanks. Can you synthesize a few more tests to exercise the codepaths that are missing (as evidenced by the code coverage)? I'm particularly interested that we hit the first couple of branches in
Awesome. I'm calling this good, and will be testing it out with an autoencoder experiment in the near future!
Added 2D transpose convolutions and tests