
Increased GPU memory usage when using a cost_fn different from costs.SqEuclidean() #504

Closed
felix0097 opened this issue Mar 19, 2024 · 10 comments · Fixed by #588

Comments

@felix0097

Hi,

I have a question regarding the memory usage when using a cost_fn different from the default costs.SqEuclidean().

I have a large dataset (~240,000 data points in x and ~400,000 data points in y). If I use pointcloud.PointCloud with the default cost_fn and batch_size=512, everything works fine. However, if I use a different cost_fn, e.g. the Cosine cost function, I run out of GPU memory, even if I further reduce the batch_size (the batch_size argument no longer seems to have any effect).

This works:

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

geom = pointcloud.PointCloud(
    x, y, batch_size=512
)

ot_prob = linear_problem.LinearProblem(geom)
solver = sinkhorn.Sinkhorn()
ot = solver(ot_prob)

This runs out of GPU memory:

from ott.geometry.costs import Cosine

geom = pointcloud.PointCloud(
    x, y,
    cost_fn=Cosine(),
    batch_size=512
)

ot_prob = linear_problem.LinearProblem(geom)
solver = sinkhorn.Sinkhorn()
ot = solver(ot_prob)

I'm using ott-jax==0.4.5

Thanks for your help!
Felix

@michalk8
Collaborator

Strange, the cosine cost should not have any additional memory requirements when using batch_size. In the code above, did you try jitting the solver, i.e. solver = jax.jit(sinkhorn.Sinkhorn())? If not, jitting could enable some memory optimizations.
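
For concreteness, a minimal sketch of the jitting I have in mind (reusing the geom from above; the Sinkhorn instance is callable on a LinearProblem, so it can be wrapped in jax.jit directly):

import jax

from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

# Wrap the callable solver instance in jit so XLA can optimize the whole solve.
solve_fn = jax.jit(sinkhorn.Sinkhorn())
ot = solve_fn(linear_problem.LinearProblem(geom))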

@michalk8
Collaborator

As an alternative, there's a private method called _cosine_to_sqeucl which converts the cosine cost (defined as 1 - cosine_sim(x, y)) to the SqEuclidean cost.
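
For reference, a rough sketch of doing this conversion by hand (this assumes the cosine cost 1 - cosine_sim(x, y); on unit-normalized vectors, ||x_hat - y_hat||^2 = 2 * (1 - cosine_sim(x, y)), so the default SqEuclidean cost on normalized points recovers the cosine cost up to a factor of 2):

import jax.numpy as jnp

from ott.geometry import pointcloud

# Normalize each point to unit norm; SqEuclidean on the normalized points is
# then 2 * (1 - cosine_sim), i.e. the cosine cost up to a constant factor.
x_hat = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
y_hat = y / jnp.linalg.norm(y, axis=-1, keepdims=True)

geom = pointcloud.PointCloud(x_hat, y_hat, batch_size=512)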

@marcocuturi
Contributor

This is a detail, but if the jitting suggested above by Michal still fails, do you see a different behavior when swapping the argument positions of x and y?

The parallelization we currently have implemented is only line-wise, and with your definition each line has ~400k entries. So in practice, the matrices stored at each iteration are 400k x 512, roughly the same footprint as 40k x 5k or 20k x 10k, which can start being a bit heavy depending on your GPU and on whether you are using float64.
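
To put rough numbers on that, a back-of-the-envelope sketch with the shapes from this issue, assuming float32 (the exact intermediates depend on the solver internals):

n, m, batch = 422220, 239696, 512
bytes_per_float = 4  # float32

# One batched block of the cost matrix kept in memory at a time.
print(n * batch * bytes_per_float / 1024**3)  # ~0.81 GiB

# The full, materialized cost matrix, by contrast.
print(n * m * bytes_per_float / 1024**3)  # ~377 GiB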

@felix0097
Author

I tried jitting the solver, but this didn't work: my notebook kernel just dies. I think the issue might be that the code ignores the batch_size argument altogether. I get the following error message:

2024-03-21 06:49:02.013401: W external/tsl/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 377.03GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[16], line 3
      1 ot_prob = linear_problem.LinearProblem(geom)
      2 solver = sinkhorn.Sinkhorn()
----> 3 ot = solver(ot_prob)

File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/ott/solvers/linear/sinkhorn.py:864, in Sinkhorn.__call__(self, ot_prob, init, rng)
    860 initializer = self.create_initializer()
    861 init_dual_a, init_dual_b = initializer(
    862     ot_prob, *init, lse_mode=self.lse_mode, rng=rng
    863 )
--> 864 return run(ot_prob, self, (init_dual_a, init_dual_b))

File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/ott/solvers/linear/sinkhorn.py:1141, in run(ot_prob, solver, init)
   1139 """Run loop of the solver, outputting a state upgraded to an output."""
   1140 iter_fun = _iterations_implicit if solver.implicit_diff else iterations
-> 1141 out = iter_fun(ot_prob, solver, init)
   1142 # Be careful here, the geom and the cost are injected at the end, where it
   1143 # does not interfere with the implicit differentiation.
   1144 out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin)

    [... skipping hidden 5 frame]

File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/ott/solvers/linear/sinkhorn.py:1178, in iterations(ot_prob, solver, init)
   1176 const = ot_prob, solver
   1177 state = solver.init_state(ot_prob, init)
-> 1178 state = fix_point(
   1179     cond_fn, body_fn, solver.min_iterations, solver.max_iterations,
   1180     solver.inner_iterations, const, state
   1181 )
   1182 return solver.output_from_state(ot_prob, state)

File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/ott/math/fixed_point_loop.py:92, in fixpoint_iter(cond_fn, body_fn, min_iterations, max_iterations, inner_iterations, constants, state)
     86   (_, state), _ = jax.lax.scan(
     87       lambda carry, x: unrolled_body_fn(carry), (0, state),
     88       None,
     89       length=max_iterations // inner_iterations
     90   )
     91 else:
---> 92   _, state = jax.lax.while_loop(max_cond_fn, unrolled_body_fn, (0, state))
     93 return state

    [... skipping hidden 21 frame]

File /vol/data/miniconda3/envs/similarity/lib/python3.10/site-packages/jax/_src/compiler.py:237, in backend_compile(backend, module, options, host_callbacks)
    232   return backend.compile(built_c, compile_options=options,
    233                          host_callbacks=host_callbacks)
    234 # Some backends don't have `host_callbacks` option yet
    235 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    236 # to take in `host_callbacks`
--> 237 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 404834557696 bytes.

The code tries to allocate 377.03GiB of GPU memory, which corresponds almost exactly to the size of the full cost matrix in float32: x.shape[0] * y.shape[0] * 4 / 1024**3 = 239696 * 422220 * 4 / 1024**3 = 377.02GiB.

Also, the 377.03GiB is independent of the batch_size I use: I can reduce the batch size and the code still tries to allocate 377.03GiB of memory.

@michalk8
Collaborator

I looked at test_sinkhorn_online_memory_jit and modified it with cost_fn=costs.Cosine(); the memory usage didn't increase much (7.6MiB -> 8.7MiB) on CPU, so I'm not sure exactly what's going on on your system.
@felix0097 What's your JAX/OTT-JAX version?

@michalk8
Collaborator

Also, could you please check the generated jaxpr using jax.make_jaxpr to see whether the full cost matrix is indeed being materialized there?
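
For example, something along these lines (a sketch; adjust to your actual arrays), then search the printed jaxpr for a full-sized f32[n, m] intermediate:

import jax

from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

def run(x, y):
    geom = pointcloud.PointCloud(x, y, cost_fn=costs.Cosine(), batch_size=512)
    return sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))

# make_jaxpr only traces the computation abstractly, so this is safe to run
# even when actually executing it would exhaust GPU memory.
print(jax.make_jaxpr(run)(x, y))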

@felix0097
Author

I'm using jax==0.4.25 and ott-jax==0.4.5 @michalk8.

I've attached the output of jax.make_jaxpr below. I'm not really familiar with how to interpret the results, but the pattern f32[422220,239696] shows up quite a few times, e.g. from line 188 onwards and from line 657 onwards.

make_jaxpr.txt

@felix0097
Author

I tried the code above on a different system as well and see the same problem. I used the JAX container (version 23.10-py3) from NGC with jax==0.4.17.dev20231020 and ott-jax==0.4.4.

@michalk8
Collaborator

Ok, will take a closer look at this. For now, I recommend converting the cosine cost to sqeucl as mentioned above.

@michalk8
Collaborator

michalk8 commented Apr 3, 2024

Seems like this is an issue with epsilon=None, in which case the epsilon scale is computed from the mean of the full cost matrix:

Traceback (most recent call last):
  File "/mnt/task_runtime/test.py", line 13, in <module>
    out = solve_fn(geom, min_iterations=1, max_iterations=1)
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/_solve.py", line 60, in solve
    return solver(prob)
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/sinkhorn.py", line 864, in __call__
    return run(ot_prob, self, (init_dual_a, init_dual_b))
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/sinkhorn.py", line 1144, in run
    out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin)
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/sinkhorn.py", line 342, in set_cost
    return self.set(reg_ot_cost=compute_kl_reg_cost(f, g, ot_prob, lse_mode))
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/sinkhorn.py", line 256, in compute_kl_reg_cost
    fa = ot_prob.geom.potential_from_scaling(ot_prob.a)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/geometry.py", line 519, in potential_from_scaling
    return self.epsilon * jnp.log(scaling)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/geometry.py", line 162, in epsilon
    return self._epsilon.target
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/geometry.py", line 150, in _epsilon
    scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/geometry.py", line 130, in mean_cost_matrix
    tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/pointcloud.py", line 376, in apply_cost
    return self._apply_cost(arr, axis, fn=fn)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/pointcloud.py", line 393, in _apply_cost
    return app(
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/pointcloud.py", line 784, in _apply_cost_xy
    c = _cost(x, y, norm_x, norm_y, cost_fn, scale_cost)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/pointcloud.py", line 760, in _cost
    cost = norm_x + norm_y + one_line_pairwise(x, y)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/costs.py", line 317, in pairwise
    cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + self._ridge)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 404834557696 bytes. 

Passing epsilon=<some float> fixes this, but it is not a good solution overall; we should be pre-computing these statistics beforehand. However, this will require a lot of work and break the API in some places. In general I support this and will think of a few ways to do it most efficiently.
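
For anyone hitting this in the meantime, a minimal sketch of that workaround (the 1e-2 below is purely an illustrative placeholder; choose an epsilon that matches the scale of your cost):

from ott.geometry import costs, pointcloud

geom = pointcloud.PointCloud(
    x, y,
    cost_fn=costs.Cosine(),
    batch_size=512,
    epsilon=1e-2,  # explicit value avoids computing the mean of the full cost matrix
)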
