From 9991e1573c92bb03f29dea5bb535a1851a2f5d53 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 7 Feb 2023 17:29:31 +0100 Subject: [PATCH] Mark Sinkhorn online as CPU --- tests/solvers/linear/sinkhorn_test.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index a25064026..3e2545024 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -451,11 +451,14 @@ def test_restart(self, lse_mode: bool): # check only one iteration suffices when restarting with same data. assert num_iter_restarted == 1 + @pytest.mark.cpu @pytest.mark.limit_memory("110 MB") - @pytest.mark.fast.with_args("batch_size", [500, 1000], only_fast=0) - def test_sinkhorn_online_memory(self, batch_size: int): + @pytest.mark.fast.with_args( + "batch_size,jit", [(500, True), (1000, False)], only_fast=0 + ) + def test_sinkhorn_online_memory(self, batch_size: int, jit: bool): # offline: Total memory allocated: 240.1MiB - # online (500): Total memory allocated: 33.4MiB + # online (500): Total memory allocated: 33.4MiB; GPU: 203.4MiB # online (1000): Total memory allocated: 45.6MiB rngs = jax.random.split(jax.random.PRNGKey(0), 4) n, m = 5000, 4000 @@ -463,7 +466,9 @@ def test_sinkhorn_online_memory(self, batch_size: int): y = jax.random.uniform(rngs[1], (m, 2)) geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=1) problem = linear_problem.LinearProblem(geom) - solver = sinkhorn.Sinkhorn() + solver = sinkhorn.Sinkhorn(jit=False) + if jit: + solver = jax.jit(solver) out = solver(problem) assert out.converged