-
-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix conv with groups when falling in direct backend #468
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ToucheSir @mcabbott not the most elegant solution, but this is the idea that I had for merging both backends with the type declaration differences. Open for suggestions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good and tests pass. Will merge tomorrow if there are no objections.
This was a nice enhancement, thanks @gabrielpreviato ! |
Fixing #369.
When the convolution filter type is not the same as the input array, conv uses the direct backend
Probably this backed was not updated when grouped convolution was added to conv, which led to some false warnings about dimension mismatches given that the check_dimension function changed as well.
This PR adds the same logic used in im2col backend to the direct backed, so grouped convolutions also works for the direct backend.
What has changed
Now for both conv and depthwiseconv (and their forwarding definitions), we have the functions definitions of both im2col and direct backend in a for-loop.
The idea is that future features or corrections will always be implemented for both backends avoiding the problem that happened in #369.
For this, the loop needs a new element, where we had the frontend and backend, we now have the frontend, backend, and signature. This signature is used for the parametric types of the AbstractArray that are used for the functions since there are some type restrictions for the im2col backend.
PR Checklist