Adding Zygote compatibility #42
I'd very much like to see Zygote support too. As far as I understand Zygote, the first choice should generally be to avoid mutation entirely and write in a more functional style, which would be quite feasible here: we would just have to write explicit non-mutating methods for all manifolds instead of relying on mutating variants and …
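To make the distinction concrete, here is a minimal sketch of the two styles for a single operation, projecting a point onto the unit sphere. The function names are hypothetical and not part of Manifolds.jl's API; only `norm` from the LinearAlgebra standard library is used.

```julia
using LinearAlgebra: norm

# Hypothetical mutating variant: writes the result into a preallocated `y`.
function project_sphere!(y, x)
    y .= x ./ norm(x)
    return y
end

# Hypothetical non-mutating variant written functionally: it builds a new
# array instead of writing into one, which is the style Zygote handles best.
project_sphere(x) = x ./ norm(x)

x = [3.0, 4.0]
y = similar(x)
project_sphere!(y, x)        # y now holds [0.6, 0.8]
project_sphere(x) ≈ y        # true
```

The functional version is simpler to differentiate, but each call allocates a new array, which is the efficiency cost raised later in this thread.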
I don't know if that's the case long term. I think in the short term … My thought is to start by adding a macro for defining new functions that have both mutating and non-mutating versions, since we have a lot of duplication there. That would use the above pattern with … For custom adjoints, in the short term we can define them using …
For one thing, …
Making such a macro may be quite complicated, but it could work.
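A minimal sketch of the kind of macro discussed above, assuming the simplest case where the output has the same type and shape as the first argument. The macro name `allocating` and the example functions are hypothetical and not part of Manifolds.jl.

```julia
# Hypothetical macro: given an existing mutating function with signature
# mutating_name(y, x, args...), define an allocating version
# allocating_name(x, args...) that allocates `y` with `similar(x)`.
macro allocating(mutating_name, allocating_name)
    return esc(quote
        function $(allocating_name)(x, args...)
            y = similar(x)
            $(mutating_name)(y, x, args...)
            return y
        end
    end)
end

# Hypothetical mutating kernel used for illustration.
function scale_by!(y, x, c)
    y .= c .* x
    return y
end

# Generates scale_by(x, c) from scale_by!.
@allocating scale_by! scale_by

scale_by([1.0, 2.0], 3.0)   # returns [3.0, 6.0]
```

The hard part a real macro would face is outputs whose type or size differs from the input (for example, when the result lives in a different representation than the first argument); this sketch deliberately ignores that case.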
Sounds reasonable; the more AD backends we can support, the better.
I haven't mixed Zygote with StaticArrays extensively, but in my simple tests, it was slower than using a standard … If passed a static array, …
julia> using Zygote, StaticArrays, BenchmarkTools
julia> function foo(x)
y = similar(x)
copyto!(y, x^2)
return y
end
foo (generic function with 1 method)
julia> function bar(x)
y = Zygote.Buffer(x)
copyto!(y, x^2)
return copy(y)
end
bar (generic function with 1 method)
julia> x = @SMatrix randn(3,3);
julia> @benchmark foo(x)
BenchmarkTools.Trial:
memory estimate: 80 bytes
allocs estimate: 1
--------------
minimum time: 41.659 ns (0.00% GC)
median time: 43.811 ns (0.00% GC)
mean time: 58.513 ns (12.49% GC)
maximum time: 7.474 μs (98.40% GC)
--------------
samples: 10000
evals/sample: 990
julia> @benchmark bar(x)
BenchmarkTools.Trial:
memory estimate: 80 bytes
allocs estimate: 1
--------------
minimum time: 42.524 ns (0.00% GC)
median time: 44.118 ns (0.00% GC)
mean time: 59.761 ns (12.47% GC)
maximum time: 8.164 μs (98.92% GC)
--------------
samples: 10000
evals/sample: 990
julia> x2 = collect(x);
julia> @benchmark foo(x2)
BenchmarkTools.Trial:
memory estimate: 320 bytes
allocs estimate: 2
--------------
minimum time: 127.207 ns (0.00% GC)
median time: 145.000 ns (0.00% GC)
mean time: 184.259 ns (7.84% GC)
maximum time: 5.387 μs (94.93% GC)
--------------
samples: 10000
evals/sample: 900
I looked at …
I've taken a look at … Do you know why …
I have a local branch where I've begun work on this, and I'll open a breaking WIP PR.
Not specifically. It seems to be its main point, though. I think the idea is to sidestep …
Great! I have some experience with broadcasting in …
I don't fully understand that, but it apparently doesn't work with broadcasting at all? So you'll have to do quite a lot in that PR 🙂.
julia> a = Zygote.Buffer([1.0 2.0; 3. 4.])
Zygote.Buffer{Float64,Array{Float64,2}}([1.390671161567e-309 0.0; 6.9043377050637e-310 0.0], false)
julia> a .*= 2
ERROR: MethodError: no method matching iterate(::Zygote.Buffer{Float64,Array{Float64,2}})
Closest candidates are:
iterate(::Core.SimpleVector) at essentials.jl:604
iterate(::Core.SimpleVector, ::Any) at essentials.jl:604
iterate(::ExponentialBackOff) at error.jl:214
...
Stacktrace:
[1] copyto!(::Array{Float64,1}, ::Zygote.Buffer{Float64,Array{Float64,2}}) at ./abstractarray.jl:722
[2] _collect(::UnitRange{Int64}, ::Zygote.Buffer{Float64,Array{Float64,2}}, ::Base.HasEltype, ::Base.HasLength) at ./array.jl:550
[3] collect(::Zygote.Buffer{Float64,Array{Float64,2}}) at ./array.jl:544
[4] broadcastable(::Zygote.Buffer{Float64,Array{Float64,2}}) at ./broadcast.jl:659
[5] broadcasted(::Function, ::Zygote.Buffer{Float64,Array{Float64,2}}, ::Int64) at ./broadcast.jl:1213
[6] top-level scope at REPL[16]:1
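Until broadcasting is supported, one workaround is to write into a `Zygote.Buffer` with explicit indexing, which is the usage pattern Zygote documents for `Buffer`. This is a minimal sketch assuming a recent Zygote; the function name `double_entries` is hypothetical. Note that `Zygote.Buffer(x)` allocates an uninitialized buffer of the same shape and element type as `x`, which is why the REPL output above shows arbitrary values rather than the matrix entries.

```julia
using Zygote

# Hypothetical example: double every entry using indexed writes into a Buffer
# instead of broadcasting (`a .*= 2` fails on a Buffer, as shown above).
function double_entries(x::AbstractMatrix)
    buf = Zygote.Buffer(x)           # uninitialized, same shape/eltype as x
    for i in eachindex(x)
        buf[i] = 2 * x[i]            # plain setindex! is supported on a Buffer
    end
    return copy(buf)                 # get a regular array back out of the Buffer
end

x = [1.0 2.0; 3.0 4.0]
double_entries(x)                    # [2.0 4.0; 6.0 8.0]

# Zygote differentiates through the indexed Buffer writes.
g = Zygote.gradient(x -> sum(double_entries(x)), x)[1]   # every entry is 2.0
```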
Sure, though I think you'll get the notification anyways. I'm going to start here, and if it looks like changes are needed to StaticArrays, we can go that route.
😧 well this will be fun.
The tricky thing is making sure we've considered the edge cases with our special array-wrapping types like …
After putting some time into this, I now think that the functional approach will be cleaner and easier to support going forward. A few downsides: 1) some loss in efficiency, but how much? 2) …
One possible way forward is to have both mutating and non-mutating versions of all functions. Having only non-mutating versions is not acceptable for me, as it would very significantly slow down my computations. I originally wrote my FunManifolds.jl code in a non-mutating style and then gradually replaced it with mutating code, which resulted in massive speed-ups.
Ugh, that will be so frustrating to keep synchronized. It's too bad base Julia doesn't have better tools for writing general interfaces that support both mutating and functional styles. I've heard rumors that some such features might be coming in 2.0. In the meantime, maybe getting …
Generic tests can take care of checking that everything works correctly. Having both variants might be less frustrating than dealing with …
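A generic test of the kind suggested here could check, for any pair of variants, that the mutating version returns its output argument and agrees with the non-mutating version. This is a minimal sketch using the Test standard library; the helper `test_variants_agree` and the example pair `double`/`double!` are hypothetical.

```julia
using Test

# Hypothetical helper: check that a non-mutating `f` and its mutating
# counterpart `f!` (signature f!(y, args...)) agree on the given inputs.
function test_variants_agree(f, f!, args...; atol = 1e-12)
    expected = f(args...)
    y = similar(expected)
    result = f!(y, args...)
    @test result === y                        # f! should return its output argument
    @test isapprox(y, expected; atol = atol)  # both variants compute the same values
end

# Hypothetical pair of variants for illustration.
double(x) = 2 .* x
double!(y, x) = (y .= 2 .* x; y)

@testset "mutating vs non-mutating" begin
    test_variants_agree(double, double!, [1.0, 2.0, 3.0])
end
```

Running one such helper over every function of every manifold would catch the two versions drifting apart without writing separate tests for each.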
I didn't know that 2.0 could have better support for that. I wonder how that would work. Anyway, what we have looks more like a Zygote problem than a Julia problem.
Right, and mutation might not even be the biggest problem on the road to Zygote compatibility. We'll see.
It would be nice to add compatibility for Zygote. I personally need Zygote for its complex number support, and I'd also like to use Manifolds in the same code.
For the most part, we can probably expect Zygote to just work, with one major blocker: Zygote doesn't support mutation. I don't think this will be a big problem, though. Zygote offers Zygote.Buffer, which behaves just like similar and allows mutation. All we should need to do is add something like this: …
Every non-mutating function that uses similar_result then returns the output of finalize_result, and Zygote should just work.
Of course, this just brings Zygote support to the same level as ReverseDiff. It doesn't handle the issues raised in #17, although it should be easier to define custom behavior for embedded manifolds with Zygote than with ReverseDiff.
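The proposal hinges on routing allocation and finalization through two hooks. Below is a minimal sketch of that pattern in plain Julia; the hook names follow the issue text, but their signatures are assumptions and do not reflect Manifolds.jl's actual API, and `negate`/`negate!` are hypothetical examples. The Zygote-specific methods are shown only as comments so the sketch runs without Zygote installed.

```julia
# Hypothetical hooks named after the proposal; signatures are assumed.
similar_result(x::AbstractArray) = similar(x)   # default: ordinary allocation
finalize_result(y::AbstractArray) = y            # default: return the array unchanged

# A Zygote-aware setup could instead use methods such as
#     similar_result(x::AbstractArray) = Zygote.Buffer(x)
#     finalize_result(y::Zygote.Buffer) = copy(y)
# so that the mutating kernel below writes into a Buffer.

# Mutating kernel for a hypothetical operation (elementwise negation),
# written with plain indexing so it would also work on a Zygote.Buffer.
function negate!(y, x)
    for i in eachindex(x)
        y[i] = -x[i]
    end
    return y
end

# Non-mutating wrapper: allocate through the hook, mutate, then finalize.
function negate(x)
    y = similar_result(x)
    negate!(y, x)
    return finalize_result(y)
end

negate([1.0, -2.0])   # returns [-1.0, 2.0]
```

With this split, supporting a Buffer-based backend means changing only the two hook methods; the mutating kernels stay untouched.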