-
-
Notifications
You must be signed in to change notification settings - Fork 66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Inferrability of cat
#149
Inferrability of cat
#149
Conversation
cat_channels(xy...) = cat(xy...; dims = 3) | ||
cat_channels(xy...) = inferredcat(xy...; dims = 3) | ||
|
||
function inferredcat(xs::T...; dims = :)::T where T <: AbstractArray |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note that this limits it to arrays of the same type
As luck would have it, https://github.com/JuliaDiff/ChainRules.jl/blob/8c34f19d3a8a8a224c9fbe20524d2a08c8b9bf81/src/rulesets/Base/array.jl#L345 actually is type stable if you invoke Given the relative ease of this approach, it's also worth exploring changing |
The julia lang pr should resolve the inferrability of the Zygote adjoint. Fwiw, I had tried threading the existing |
It's unclear at this point whether the base PR will be backported. In contrast, using the ChainRule works back to at least 1.6 and is something we should be doing anyways (deleting old redundant adjoints in Zygote to reduce our maintenance burden). |
Happy to backport if that is the blocker. |
My point was that we can get a type stable cat_channels(xy...) = cat(xy...; dims = Val(3)) The Base PR is nice but not at all necessary for this approach. |
Now that the aforementioned 2 PRs have landed and #170 has changed |
cat
fails to infer for non constant arguments (possible improvement in JuliaLang/julia#45028). This causes a lot of churn in the models that make use ofcat
. I defined anrrule
but that didn't infer properly (I think I know why, but we can protect ourselves from accidental regressions).rrule
:adjoint
:I also ran it against FluxBench and pulled out some tests from there
Note that there is a massive regression in the compile time of certain models (like DenseNet)