Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inferrability of cat #149

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,44 @@ Concatenate `x` and `y` (and any `z`s) along the channel dimension (third dimens
Equivalent to `cat(x, y, zs...; dims=3)`.
Convenient reduction operator for use with `Parallel`.
"""
cat_channels(xy...) = cat(xy...; dims = 3)
cat_channels(xy...) = inferredcat(xy...; dims = 3)

function inferredcat(xs::T...; dims = :)::T where T <: AbstractArray
Copy link
Member Author

@DhairyaLGandhi DhairyaLGandhi Apr 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that this limits it to arrays of the same type

cat(xs...; dims = dims)
end

# `rrule` doesn't infer through `_project` neatly
# function Zygote.ChainRules.rrule(::typeof(inferredcat), xs::T...; dims = :)::T where T <: AbstractArray
# sz = size.(xs)
# function inferredcat_pullback(Δ)
# (Zygote.ChainRules.NoTangent(), makesub(Δ, size(Δ)[dims], sz, dims = dims)...,)
# end
# inferredcat(xs...; dims = dims), inferredcat_pullback
# end

Zygote.@adjoint function inferredcat(xs::T...; dims = :) where T <: AbstractArray
sz = size.(xs)
lz = length.(xs)
inferredcat(xs..., dims = dims), Δ -> (partition_grad(Δ, size(Δ)[dims], sz, dims = dims)...,)
end

function partition_grad(d::AbstractArray{T,N}, x, sz; dims = :) where {T,N}
sizeatdim = map(x -> x[dims], sz)
x_start = 1
m = map(enumerate(sizeatdim)) do (i, ix)
x_stop = x_start + ix - 1
p::Base.UnitRange{Int64} = x_start:x_stop
x_start = x_start + ix
p
end
function gen_indices(m, sz)
map(m, sz) do m, sz
ntuple(x -> x == dims ? m : 1:sz[x], N)
end
end
ix = gen_indices(m, sz)
map(ix_ -> @view(d[ix_...]), ix)
end

"""
swapdims(perm)
Expand Down
4 changes: 4 additions & 0 deletions test/infer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
@testset "Inferrability" begin
r = rand(Float32, 56, 56, 64, 1)
@inferred gradient((x,y) -> sum(Metalhead.cat_channels(x, y)), r, r)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ GC.gc()
# Other tests
@testset verbose = true "Other" begin
include("other.jl")
include("infer.jl")
end

GC.gc()
Expand Down