[WIP] Use Tullio for pairwise distances #386
base: master
This is cool, but I would want to see more extensive benchmarking before adopting it. I wonder whether Tullio's performance drops off for quite large problems? Either way, I feel like we need some graphs; maybe a good thing to add to the benchmarking day list?
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…Functions.jl into tgf/tullio
Tests are failing because of the AD. For
function colwise(d::BinaryOp, x::AbstractVector, y::AbstractVector)
    return @tullio out[i] := d(x[i], y[i])
end
the differentiation gives
┌ Warning: symbolic gradient failed
│ err = "no diffrule found for function d(_, _)."
└ @ Tullio ~/.julia/packages/Tullio/qPZkO/src/macro.jl:1264
We could use the
You can disable the symbolic differentiation with
"The macro also tries to provide a gradient for use with Tracker or (via ChainRules) for Zygote, Yota, etc. (Disable with grad=false, or nograd=A.) This is done in one of two ways:"
I think
Codecov Report
@@             Coverage Diff             @@
##           master     #386       +/-   ##
===========================================
- Coverage   93.09%   31.89%   -61.20%
===========================================
  Files          52       53        +1
  Lines        1202     1182       -20
===========================================
- Hits         1119      377      -742
- Misses         83      805      +722
Co-authored-by: David Widmann <[email protected]>
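As far as I can tell, there is no concrete solution with Tullio for this problem.
function rrule(::typeof(pairwise), metric, x, y)
    # Fall back to a non-mutating broadcast and let `pullback` (presumably Zygote's) differentiate it.
    # Note: a complete rrule would also have to return tangents for `pairwise` and `metric`.
    return pullback(x, y) do x, y
        metric.(x, permutedims(y))
    end
end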
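The broadcast used inside the rrule above can be tried on its own, without Tullio or any AD package. The following is only a minimal sketch: `pairwise_broadcast` and `sqdist` are hypothetical names for illustration and are not part of KernelFunctions.jl or Distances.jl.

```julia
# Hypothetical helper: a non-mutating pairwise fallback built only on broadcasting.
# `permutedims(y)` turns the vector `y` into a 1×n row, so the broadcast produces
# the full matrix of metric(x[i], y[j]) values.
pairwise_broadcast(metric, x::AbstractVector, y::AbstractVector) =
    metric.(x, permutedims(y))

# Squared Euclidean distance between two points (illustrative metric).
sqdist(a, b) = sum(abs2, a .- b)

x = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
y = [[0.0, 0.0], [3.0, 4.0]]

D = pairwise_broadcast(sqdist, x, y)  # 3×2 matrix with D[i, j] == sqdist(x[i], y[j])
```

Because nothing is mutated, this form is the kind of code Zygote can differentiate, at the cost of not exploiting any structure of the metric.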
I think this could be fixed by replacing https://github.com/mcabbott/Tullio.jl/blob/93278c6bf0441382fde9c52fedac8dc41e3e4648/src/eval.jl#L49-L58 with
function ChainRulesCore.rrule(ev::Eval, args...)
    Z = ev.fwd(args...)
    function tullio_back(Δ)
        dxs = map(ev.rev(Δ, Z, args...)) do dx
            dx === nothing ? ChainRulesCore.ZeroTangent() : dx
        end
        return (ChainRulesCore.ZeroTangent(), dxs...)
    end
    return Z, tullio_back
end
# without gradient definition
ChainRulesCore.@opt_out rrule(ev::Eval{<:Any,Nothing}, args...)
Even if it does not fix our specific use case here, I think it deserves a PR. I'll make one later today.
Hmm, this works but it does not really help. The current error is gone, but since Zygote can't differentiate mutating functions, it still can't differentiate through it. It's helpful, though, for AD systems that support mutation 🤷
Benchmark result
Judge result: Benchmark Report for /home/runner/work/KernelFunctions.jl/KernelFunctions.jl (Job Properties, Results, Benchmark Group List, Julia versioninfo for Target and Baseline)
Target result: Benchmark Report for /home/runner/work/KernelFunctions.jl/KernelFunctions.jl (Job Properties, Results, Benchmark Group List, Julia versioninfo)
Baseline result: Benchmark Report for /home/runner/work/KernelFunctions.jl/KernelFunctions.jl (Job Properties, Results, Benchmark Group List, Julia versioninfo)
Runtime information
This is an extended discussion on the status of this PR and on using Tullio for computing pairwise distances. It follows up on the discussion during our last meeting.
Distances vs Tullio
The current issues with Tullio (and their solutions)
One slightly different option is to implement things as a mix of Tullio and standard functionality, in a way that doesn't require implementing new For example, suppose that
function pairwise(d, x::RowVecs, y::RowVecs)
    X, Y = x.X, y.X  # underlying matrices; each row is one input
    # `:=` allocates D (a plain `=` would require D to exist already)
    @tullio D[i, j] := (X[i, k] - Y[j, k])^2
    return map(f, D)
end
Yes, you've had to implement a specialised method of Alternatively, presumably there's a way to teach Tullio about any particular
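The "kernel plus `map`" pattern from this comment can be illustrated without Tullio. In the sketch below a plain comprehension stands in for the `@tullio` kernel; `sqdist_matrix` and `pairwise_via_sq` are hypothetical names, and the assumption is that the distance has the form d(x, y) = f(sum((x - y).^2)).

```julia
# Hypothetical stand-in for the @tullio kernel: squared Euclidean distances
# between the rows of X and the rows of Y, written as a plain comprehension.
sqdist_matrix(X::AbstractMatrix, Y::AbstractMatrix) =
    [sum(abs2(X[i, k] - Y[j, k]) for k in axes(X, 2)) for i in axes(X, 1), j in axes(Y, 1)]

# Apply the elementwise function f to the squared distances.
# With f = sqrt this yields the Euclidean distance matrix.
pairwise_via_sq(f, X, Y) = map(f, sqdist_matrix(X, Y))

X = [0.0 0.0; 3.0 4.0]            # two points, one per row
Y = [0.0 0.0; 0.0 4.0; 3.0 0.0]   # three points, one per row

D = pairwise_via_sq(sqrt, X, Y)   # 2×3 matrix of Euclidean distances
```

Only the inner kernel would need a fast (or GPU-friendly) implementation; the outer `map` is ordinary, non-mutating Julia code.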
Thanks for the write-up! It makes me wonder, though, what exactly the benefits of switching to Tullio would be. Unfortunately, IIRC my PR wouldn't help with the issues that you faced here. And if we have to implement ChainRules derivatives, then one of the main selling points of Tullio (AD) would not help in our case, and we could just implement and optimize rules for Distances.
I'm also curious in what cases Tullio reduces allocations. I assume you talk about Speaking about
So it seems the main argument for Tullio would be GPU compatibility. IIRC people already tried to address this with packages such as DistancesCUDA (not sure about the name...).
In defense of Distances, my impression is not that it is completely outdated. It's actively maintained, and there were many improvements and refactorings in the last months. In my possibly biased opinion, its source code is also much cleaner, simpler, and easier to read than Tullio's (probably not too surprising, to some extent, given the different goals of both packages).
Regarding GPU support, we or some separate package could just restrict our custom pairwise etc. implementations such that we only forward specific array types to Distances and handle the rest in a non-mutating way (or the other way around).
@willtebbutt The problem with this is that it does not solve the problem of the generic fallback. Of course one could just use broadcasting
I was wrong about this (I confused it with the other
That sounds more like a temporary solution, right? And also a lot more work! I think Tullio would at least allow us to have a general solution (and, if my AD proposal works, with very little work on our side).
I think that's what you ideally want to do. Based on specific traits of arrays, such as the ones in ArrayInterface, you use different implementations (e.g. in-place or out-of-place). That's what SciML does all the time, e.g. all ODE methods are implemented twice, once for in-place and once for out-of-place functions. So if such traits became available in Base, then this would be what Distances should and probably would do. Unfortunately, ArrayInterface is quite a heavy dependency, so currently such traits are not available in Distances.
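A rough sketch of what such trait-based dispatch could look like in plain Julia. The trait `supports_setindex` and the function `colwise_sqdist` are hypothetical and only stand in for the kind of query a package like ArrayInterface provides; this is not how Distances is implemented.

```julia
# Hypothetical trait: can we write into arrays of this type?
supports_setindex(::Type{<:Array}) = true
supports_setindex(::Type) = false              # conservative default (ranges, immutable arrays, ...)
supports_setindex(x) = supports_setindex(typeof(x))

# Pick the implementation from the trait of the first argument.
colwise_sqdist(x::AbstractVector, y::AbstractVector) =
    _colwise_sqdist(Val(supports_setindex(x)), x, y)

# In-place path: preallocate the output and fill it.
function _colwise_sqdist(::Val{true}, x, y)
    out = similar(x, float(eltype(x)))
    for i in eachindex(x, y)
        out[i] = abs2(x[i] - y[i])
    end
    return out
end

# Out-of-place path: no mutation, so it also works for arrays without setindex!
# and for AD tools that cannot differentiate through mutation.
_colwise_sqdist(::Val{false}, x, y) = map((a, b) -> abs2(a - b), x, y)

a = colwise_sqdist([1.0, 2.0, 3.0], [1.0, 0.0, 5.0])  # takes the in-place path
r = colwise_sqdist(1:3, 3:-1:1)                        # ranges take the out-of-place path
```

Both paths return the same values; the trait only decides how the result is produced.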
I have been trying to get this minimally working, but I am observing that it is terribly slow on the GPU. Am I doing something completely wrong?
julia> using Tullio, BenchmarkTools, CUDA
julia> x = rand(5, 5);
julia> test(x) = @tullio D[i, j] := (x[k, i] - x[k, j])^2 grad=false;
julia> test2(x) = @tullio D[i, j] := (x[i, k] - x[j, k])^2 grad=false;
julia> @btime test($x);
104.748 ns (1 allocation: 256 bytes)
julia> @btime test($(x |> cu));
2.989 ms (834 allocations: 129.17 KiB)
julia> @btime test2($x);
118.719 ns (1 allocation: 256 bytes)
julia> @btime test2($(x |> cu));
2.973 ms (834 allocations: 129.17 KiB)
Summary
We have a long-standing problem with binary operations like DotProduct not satisfying the requirements of the Distances.jl framework (not a proper metric). Additionally, Distances.jl is very incompatible with GPU operations (see JuliaStats/Distances.jl#143 and JuliaStats/Distances.jl#137).
Using Tullio.jl should solve both these problems. Some quick benchmarks show that Tullio is both faster and more GPU-able than Distances.jl.
There is a longer discussion about this PR in #380
This should also close #98 and replace #194
Proposed changes
pairwise and colwise
Distances.jl are their types (we stop using Distances.pairwise).
pairwise for ColVecs and RowVecs when possible to improve speed (and GPU compatibility)
AbstractBinaryOp abstract type for objects like DotProduct and Delta and combine them with Distances using BinaryOp = Union{AbstractBinaryOp,Distances.PreMetric}.
What alternatives have you considered?
Dropping Distances.jl operations anyway, but without Tullio. However, the benchmarks show that Tullio is faster.
Breaking changes