Increased GPU memory usage when using a cost_fn different from costs.SqEuclidean() #504
Comments
Strange, cosine should not have any more memory requirements when using |
As an alternative, there's a private method called |
This is a detail, but if the jitting suggested above by Michal still fails, do you see a different behavior by inverting the arg positions of
The current parallelization we have implemented is only line-wise, with each line having 400k entries under your definition. So in practice, the matrices stored at each iteration are 400k x 512 ~ 40k x 5k ~ 20k x 10k, which can start to be a bit heavy depending on your GPU and on whether you are using float64. |
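As a rough check on the sizes quoted above, the sketch below computes the footprint of a single 400k x 512 block in float32 and in float64. The helper name `matrix_bytes` is hypothetical, and the figures ignore any temporaries XLA may allocate on top of the block itself.

```python
def matrix_bytes(rows: int, cols: int, bytes_per_element: int) -> int:
    """Size in bytes of a dense rows x cols array."""
    return rows * cols * bytes_per_element


n_y = 400_000      # approximate number of points in y, as reported in the issue
batch_size = 512   # batch_size passed to the point cloud

for dtype_name, itemsize in (("float32", 4), ("float64", 8)):
    size_gib = matrix_bytes(n_y, batch_size, itemsize) / 2**30
    print(f"{dtype_name}: {size_gib:.2f} GiB per {n_y} x {batch_size} block")
```

A block of this size is well under 2 GiB in either precision, so a single batch on its own should not exhaust a typical GPU.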
I tried
The code tries to allocate
Also, the |
I looked at the |
Also, could you please check the traced program (the jaxpr) using jax.make_jaxpr to see whether it's indeed being materialized there? |
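For readers unfamiliar with `jax.make_jaxpr`, here is a minimal sketch of the kind of check being requested. The functions `cosine_cost` and `apply_cost` are toy stand-ins written for illustration, not OTT's actual code path, and the arrays are deliberately tiny. If an intermediate with the full (n, m) shape shows up in the printed jaxpr, the traced program contains a dense pairwise matrix. XLA may still fuse it away at compile time, so this is an indication rather than proof.

```python
import jax
import jax.numpy as jnp


def cosine_cost(x, y, ridge=1e-8):
    # Toy cosine distance between two vectors (illustrative only).
    x_norm = jnp.linalg.norm(x)
    y_norm = jnp.linalg.norm(y)
    return 1.0 - jnp.vdot(x, y) / (x_norm * y_norm + ridge)


def apply_cost(x, y, vec):
    # Builds the full pairwise cost matrix, then applies it to a vector.
    cost = jax.vmap(lambda xi: jax.vmap(lambda yj: cosine_cost(xi, yj))(y))(x)
    return cost @ vec


x = jnp.ones((6, 3))
y = jnp.ones((4, 3))
vec = jnp.ones((4,))

jaxpr = jax.make_jaxpr(apply_cost)(x, y, vec)
print(jaxpr)  # look for intermediates of shape [6,4], i.e. (n, m)

out = apply_cost(x, y, vec)
```

With real data, the shape to look for would be the number of points in x by the number of points in y, rather than the batch size by the number of points in y.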
I'm using
I've attached the output of the jax.make_jaxpr function below. I'm not really familiar with how to interpret the results, but the pattern |
I tried the code above on a different system as well and have the same problem. I used the Jax container (version |
Ok, will take a closer look at this. For now, I recommend converting the cosine cost to sqeucl as mentioned above. |
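The conversion rests on a standard identity: for unit-norm vectors, the squared Euclidean distance equals twice the cosine distance, so 1 - cos(x, y) = 0.5 * ||x/|x| - y/|y|||^2. The sketch below checks this numerically. NumPy is used only to keep the check self-contained; the helper `normalize_rows` is hypothetical, and the example assumes the squared-Euclidean cost is the plain ||x - y||^2.

```python
import numpy as np


def normalize_rows(a, eps=1e-12):
    # Scale each row to unit Euclidean norm (eps guards against zero rows).
    norms = np.linalg.norm(a, axis=1, keepdims=True)
    return a / np.maximum(norms, eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
y = rng.normal(size=(4, 3))

x_hat, y_hat = normalize_rows(x), normalize_rows(y)

# Cosine distance computed directly from the normalized inputs.
cosine_cost = 1.0 - x_hat @ y_hat.T

# Squared Euclidean distance between the normalized points.
diff = x_hat[:, None, :] - y_hat[None, :, :]
sq_euclidean = np.sum(diff**2, axis=-1)

print(np.allclose(cosine_cost, 0.5 * sq_euclidean))
```

In practice this would mean normalizing x and y once up front and passing them to pointcloud.PointCloud with the default cost_fn. Because of the factor of 2 between the two costs, the regularization strength may need rescaling to get comparable results.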
Seems like this is an issue with
Traceback (most recent call last):
  File "/mnt/task_runtime/test.py", line 13, in <module>
    out = solve_fn(geom, min_iterations=1, max_iterations=1)
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/_solve.py", line 60, in solve
    return solver(prob)
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/sinkhorn.py", line 864, in __call__
    return run(ot_prob, self, (init_dual_a, init_dual_b))
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/sinkhorn.py", line 1144, in run
    out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin)
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/sinkhorn.py", line 342, in set_cost
    return self.set(reg_ot_cost=compute_kl_reg_cost(f, g, ot_prob, lse_mode))
  File "/usr/local/lib/python3.10/dist-packages/ott/solvers/linear/sinkhorn.py", line 256, in compute_kl_reg_cost
    fa = ot_prob.geom.potential_from_scaling(ot_prob.a)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/geometry.py", line 519, in potential_from_scaling
    return self.epsilon * jnp.log(scaling)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/geometry.py", line 162, in epsilon
    return self._epsilon.target
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/geometry.py", line 150, in _epsilon
    scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/geometry.py", line 130, in mean_cost_matrix
    tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/pointcloud.py", line 376, in apply_cost
    return self._apply_cost(arr, axis, fn=fn)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/pointcloud.py", line 393, in _apply_cost
    return app(
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/pointcloud.py", line 784, in _apply_cost_xy
    c = _cost(x, y, norm_x, norm_y, cost_fn, scale_cost)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/pointcloud.py", line 760, in _cost
    cost = norm_x + norm_y + one_line_pairwise(x, y)
  File "/usr/local/lib/python3.10/dist-packages/ott/geometry/costs.py", line 317, in pairwise
    cosine_similarity = jnp.vdot(x, y) / (x_norm * y_norm + self._ridge)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 404834557696 bytes.
Passing the |
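To put the requested allocation in perspective, the sketch below compares it with two reference sizes: a dense float32 cost matrix over all points, and a single 512-row batch. The point counts are the approximate figures reported in the issue (about 240,000 and 400,000), so the ratios are only indicative.

```python
requested = 404_834_557_696      # bytes, taken from the error message
n_x, n_y = 240_000, 400_000      # approximate dataset sizes from the issue
batch_size = 512
float32_bytes = 4

dense = n_x * n_y * float32_bytes           # full pairwise cost matrix
batched = batch_size * n_y * float32_bytes  # one batch of rows

print(f"requested: {requested / 2**30:.1f} GiB")
print(f"requested / dense matrix: {requested / dense:.2f}")
print(f"requested / one batch:    {requested / batched:.0f}")
```

The request is on the order of the full dense matrix and hundreds of times larger than one batch, which is consistent with batch_size not taking effect on this code path.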
Hi,
I have a question regarding the memory usage when using a cost_fn different from the default costs.SqEuclidean(). I have a large dataset (~240,000 datapoints in x and ~400,000 datapoints in y). If I use pointcloud.PointCloud with the default cost_fn and batch_size=512, everything works fine. However, if I use a different cost_fn, e.g. the Cosine cost function, I run out of GPU memory, even if I further reduce the batch_size (somehow the batch_size argument does not really seem to have an effect anymore).
This works:
This runs out of GPU memory:
I'm using ott-jax==0.4.5
Thanks for your help!
Felix