Vectorization memory issues with PointCloud.apply_lse_kernel with online=True #20

Closed
michalk8 opened this issue Feb 26, 2022 · 3 comments · Fixed by #23

Comments

@michalk8
Collaborator

We've noticed that memory usage with online=True is worse than with online=False.
After checking the traceback (attached below), we've found that the 2 nested vmaps in apply_lse_kernel vectorize the code such that the n x m matrix is fully materialized (n and m being the numbers of points in the PointCloud).
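
To illustrate the point about the nested vmaps, here is a minimal sketch. It is not the actual apply_lse_kernel code: the function names (sq_cost, lse_row, lse_nested_vmap, intermediate_shapes), the squared Euclidean cost and the simplified log-sum-exp are all assumptions made for illustration. It traces a doubly vmapped row-wise log-sum-exp and lists the intermediate shapes that contain both n and m.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sq_cost(x, y):
    # Squared Euclidean cost between two single points (illustrative choice).
    return jnp.sum((x - y) ** 2)


def lse_row(x, ys, g, eps):
    # Inner vmap: costs from one source point to all m target points.
    costs = jax.vmap(sq_cost, in_axes=(None, 0))(x, ys)
    return eps * logsumexp((g - costs) / eps)


def lse_nested_vmap(xs, ys, g, eps):
    # Outer vmap: repeat the row computation for all n source points.
    return jax.vmap(lse_row, in_axes=(0, None, None, None))(xs, ys, g, eps)


def intermediate_shapes(fn, *args):
    # Collect the output shapes of the top-level equations in the traced program.
    closed = jax.make_jaxpr(fn)(*args)
    return {tuple(v.aval.shape) for eqn in closed.jaxpr.eqns for v in eqn.outvars}


n, m, d = 7, 5, 3
xs = jnp.ones((n, d))
ys = jnp.zeros((m, d))
g = jnp.zeros(m)

shapes = intermediate_shapes(lse_nested_vmap, xs, ys, g, 0.5)
# Shapes that involve both n and m, i.e. the full pairwise intermediates.
print(sorted(s for s in shapes if n in s and m in s))
```

The traced program only shows which intermediates exist; whether XLA fuses some of them away at compile time is a separate question that this sketch does not answer.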

I have a small fix here which uses online: Optional[int] = None as a batch size. Then in apply_lse_kernel, I use lax.scan, and within the loop body the fully vectorized computation for that batch is used. This reduces the memory complexity from O(n * m) to O(max(n, m) * batch_size).
So far, I can only report that I can run sinkhorn on a PointCloud of shape (65536, 32768) using 8516MiB of memory (online=16384), which previously raised OOM on a 16GiB GPU (Tesla T4).
It took 1157s to run sinkhorn with epsilon=1e-2; I will try to do more comprehensive benchmarks later.
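
The batching idea described above can be sketched roughly as follows. This is not the code from the fix: lse_dense, lse_batched and pairwise_sq_dist are hypothetical names, the cost is assumed to be squared Euclidean, and the sketch requires n to be divisible by batch_size to stay short. Each lax.scan step handles one batch of source points, so only a (batch_size, m) block of the cost matrix exists at a time.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.special import logsumexp


def pairwise_sq_dist(a, b):
    # |a|^2 + |b|^2 - 2 a.b, giving a (len(a), len(b)) block of costs.
    a2 = jnp.sum(a ** 2, axis=1)[:, None]
    b2 = jnp.sum(b ** 2, axis=1)[None, :]
    return a2 + b2 - 2.0 * a @ b.T


def lse_dense(xs, ys, g, eps):
    # Reference version: materializes the full (n, m) cost matrix.
    cost = pairwise_sq_dist(xs, ys)
    return eps * logsumexp((g[None, :] - cost) / eps, axis=1)


def lse_batched(xs, ys, g, eps, batch_size):
    n, d = xs.shape
    # Simplification for this sketch: no padding of a ragged last batch.
    assert n % batch_size == 0, "sketch assumes n is divisible by batch_size"
    x_batches = xs.reshape(n // batch_size, batch_size, d)

    def body(carry, x_batch):
        # Fully vectorized computation, but only for one batch of rows.
        cost = pairwise_sq_dist(x_batch, ys)  # (batch_size, m)
        out = eps * logsumexp((g[None, :] - cost) / eps, axis=1)
        return carry, out

    # No state is carried between batches, so the carry is None.
    _, out = lax.scan(body, None, x_batches)
    return out.reshape(n)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
xs = jax.random.normal(k1, (8, 3))
ys = jax.random.normal(k2, (6, 3))
g = jax.random.normal(k3, (6,))

batched = lse_batched(xs, ys, g, 0.5, batch_size=4)
dense = lse_dense(xs, ys, g, 0.5)
print(jnp.max(jnp.abs(batched - dense)))
```

A real implementation would also need to handle sizes that are not a multiple of the batch size, for example by padding the last batch and masking it out.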
As for tests, they all pass, except 4:

FAILED tests/core/sinkhorn_bures_test.py::SinkhornTest::test_bures_point_cloud_ker-batch
FAILED tests/geometry/geometry_pointcloud_apply_test.py::ApplyTest::test_apply_cost_and_kernel
FAILED tests/core/sinkhorn_test.py::SinkhornTest::test_apply_transport_geometry_from_potentials
FAILED tests/core/sinkhorn_test.py::SinkhornTest::test_apply_transport_geometry_from_scalings

I think there might be a more efficient approach.

For benchmarking, I've used the same code as in #9
jax_online_err.txt

@michalk8
Collaborator Author

We did some benchmarks; the results seem to match between online=True and online=False, as well as with the results from before the changes mentioned above.
I am including a notebook based on this tutorial - ott_on_off_check.ipynb.txt.

Regarding other benchmarks, I've used this code:

#!/usr/bin/env python3

from time import perf_counter
import argparse

import numpy as np
import jax.numpy as jnp
import ott
import jax.profiler

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--online", action='store_true')
    parser.add_argument("--lse", action='store_true')
    parser.add_argument("-n", type=int, default=15)
    args = parser.parse_args()
    online = args.online
    lse = args.lse
    n_obs = args.n

    mod = "online" if online else "offline"
    mod += "_lse" if lse else "_no_lse"
    mod += f"_{n_obs}"

    src = jnp.asarray(np.random.normal(size=(2 ** n_obs, 30)))
    tgt = jnp.asarray(np.random.normal(size=(2 ** (n_obs - 1), 30)))
    print(src.shape, tgt.shape, mod)

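    # Only batch when --online is given (assuming online=None means offline, as in the proposed fix).
    batch_size = 2 ** (n_obs - 2) if online else None
    pc = ott.geometry.PointCloud(src, tgt, epsilon=1e-2, online=batch_size)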
    t = perf_counter()
    r = ott.core.sinkhorn.sinkhorn(pc, lse_mode=lse)
    print(r.transport_mass())
    print("Time:", perf_counter() - t)

I ran it with XLA_PYTHON_CLIENT_ALLOCATOR=platform nohup python run.py -n 17 --online --lse &
on an Nvidia A100. It consumed about 33GiB of memory and the output was:

(131072, 30) (65536, 30) online_lse_17
0.9999974
Time: 986.0139930024743

@michalk8
Collaborator Author

michalk8 commented Mar 1, 2022

Regarding the errors, they have been fixed. The problem was that during vmap, dummy objects are used to get the shapes of the pytrees, and in PointCloud.__init__ I accessed the shape of the input to get the batch size.
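
The failure mode described above can be reproduced in isolation with a small sketch. EagerCloud, LazyCloud and rebuild_with_placeholders are hypothetical names, not OTT classes, and the explicit tree_unflatten call with object() leaves only stands in for what JAX may do internally when it inspects pytree structure. A class that reads .shape in __init__ breaks when rebuilt from placeholders, while one that defers the access to a property does not.

```python
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class EagerCloud:
    """Hypothetical pytree that reads the shape eagerly in __init__."""

    def __init__(self, x):
        self.x = x
        self.n = x.shape[0]  # fails if x is a placeholder without .shape

    def tree_flatten(self):
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@tree_util.register_pytree_node_class
class LazyCloud:
    """Hypothetical pytree that defers the shape access until it is needed."""

    def __init__(self, x):
        self.x = x

    @property
    def n(self):
        return self.x.shape[0]

    def tree_flatten(self):
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def rebuild_with_placeholders(tree):
    # Rebuild the same structure with dummy leaves that have no array attributes.
    leaves, treedef = tree_util.tree_flatten(tree)
    return tree_util.tree_unflatten(treedef, [object() for _ in leaves])


lazy = LazyCloud(jnp.ones((4, 2)))
rebuilt = rebuild_with_placeholders(lazy)  # works: __init__ never touches .shape

try:
    rebuild_with_placeholders(EagerCloud(jnp.ones((4, 2))))
    eager_failed = False
except AttributeError:
    eager_failed = True  # object() has no attribute 'shape'
print("eager rebuild failed:", eager_failed)
```

Deferring the access is only one possible design; whether the actual fix took this route is not stated in the thread.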

@LaetitiaPapaxanthos
Contributor

LaetitiaPapaxanthos commented Mar 1, 2022

Thanks a lot Michal for spotting the memory issue and proposing a solution. When you have a solution you would like to propose, don't hesitate to do a PR and we will review the code.

michalk8 pushed a commit that referenced this issue Jun 27, 2024
Solving Issue #18 when no epsilon is passed to a geometry defined with a kernel matrix
michalk8 pushed a commit that referenced this issue Jun 27, 2024
PiperOrigin-RevId: 411071648