Vectorization memory issues with PointCloud.apply_lse_kernel with online=True #20
Comments
We did some benchmarks, the result seems to match

Regarding other benchmarks, I've used this code:

#!/usr/bin/env python3
from time import perf_counter
import argparse

import numpy as np
import jax.numpy as jnp
import jax.profiler
import ott

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--online", action='store_true')
    parser.add_argument("--lse", action='store_true')
    parser.add_argument("-n", type=int, default=15)
    args = parser.parse_args()

    online = args.online
    lse = args.lse
    n_obs = args.n

    # Label of the run, e.g. "online_lse_17".
    mod = "online" if online else "offline"
    mod += "_lse" if lse else "_no_lse"
    mod += f"_{n_obs}"

    src = jnp.asarray(np.random.normal(size=(2 ** n_obs, 30)))
    tgt = jnp.asarray(np.random.normal(size=(2 ** (n_obs - 1), 30)))
    print(src.shape, tgt.shape, mod)

    # With the proposed online: Optional[int] signature, the value is the
    # batch size; None keeps the offline (fully materialized) mode.
    batch_size = 2 ** (n_obs - 2) if online else None
    pc = ott.geometry.PointCloud(src, tgt, epsilon=1e-2, online=batch_size)

    t = perf_counter()
    r = ott.core.sinkhorn.sinkhorn(pc, lse_mode=lse)
    print(r.transport_mass())
    print("Time:", perf_counter() - t)

and ran it, with this output:

(131072, 30) (65536, 30) online_lse_17
0.9999974
Time: 986.0139930024743
Regarding the errors, they have been fixed, the problem was that during
Thanks a lot Michal for spotting the memory issue and proposing a solution. When you have a solution you would like to propose, don't hesitate to do a PR and we will review the code.
Solving Issue #18 when no epsilon is passed to a geometry defined with a kernel matrix
We've noticed that with online=True, memory usage got worse than with online=False. After checking the traceback (attached below), we found that the 2 nested vmaps in apply_lse_kernel vectorize the code such that the n x m matrix is fully materialized (n and m being the numbers of points in the PointCloud).

I have a small fix in here which uses online: Optional[int] = None as a batch size. Then, in apply_lse_kernel, I use lax.scan, and within the loop body the fully vectorized computation for that batch is used. This reduces the memory complexity from O(n * m) to O(max(n, m) * batch_size).

So far, I can only report that I can run sinkhorn on a PointCloud of shape (65536, 32768) using 8516MiB of memory (online=16384), which previously raised OOM on a 16GiB GPU (Tesla T4). It took 1157s to run sinkhorn with epsilon=1e-2; I will try to do more comprehensive benchmarks later.
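To make the batching idea concrete, here is a minimal sketch of it. It is not the actual patch: the function names (sq_dists, lse_dense, lse_batched), the squared Euclidean cost, and the simplified signature (only one potential g, no marginals) are all assumptions for illustration. It only shows how lax.scan over row batches avoids building the full n x m matrix at once.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import logsumexp


def sq_dists(x, y):
    # Pairwise squared Euclidean distances, shape (len(x), len(y)).
    return (
        jnp.sum(x ** 2, axis=1)[:, None]
        + jnp.sum(y ** 2, axis=1)[None, :]
        - 2.0 * x @ y.T
    )


def lse_dense(x, y, g, eps):
    # Reference version: materializes the full (n, m) matrix.
    return logsumexp((g[None, :] - sq_dists(x, y)) / eps, axis=1)


def lse_batched(x, y, g, eps, batch_size):
    # Hypothetical batched version: processes batch_size rows of x per scan
    # step, so only a (batch_size, m) block exists at any one time.
    n, d = x.shape
    n_batches = -(-n // batch_size)  # ceiling division
    pad = n_batches * batch_size - n
    # Pad with zero rows so that x splits into equally sized batches.
    x_batches = jnp.pad(x, ((0, pad), (0, 0))).reshape(n_batches, batch_size, d)

    def body(carry, xb):
        # Fully vectorized computation, restricted to one batch of rows.
        out = logsumexp((g[None, :] - sq_dists(xb, y)) / eps, axis=1)
        return carry, out

    _, out = lax.scan(body, None, x_batches)
    # Drop the results that belong to the padding rows.
    return out.reshape(-1)[:n]
```

As a rough sanity check on the numbers above, assuming float32: a full (65536, 32768) matrix takes 65536 * 32768 * 4 bytes = 8 GiB, while a single (16384, 65536) block takes 4 GiB, so the batch size trades peak memory against the number of sequential scan steps.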
As for tests, they all pass, except 4:
I think there might be a more efficient approach.
For benchmarking, I've used the same code as in #9
jax_online_err.txt