Skip to content

Commit

Permalink
Mark Sinkhorn online as CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Feb 7, 2023
1 parent c60cf50 commit 9991e15
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions tests/solvers/linear/sinkhorn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,19 +451,24 @@ def test_restart(self, lse_mode: bool):
# check only one iteration suffices when restarting with same data.
assert num_iter_restarted == 1

@pytest.mark.cpu
@pytest.mark.limit_memory("110 MB")
@pytest.mark.fast.with_args("batch_size", [500, 1000], only_fast=0)
def test_sinkhorn_online_memory(self, batch_size: int):
@pytest.mark.fast.with_args(
"batch_size,jit", [(500, True), (1000, False)], only_fast=0
)
def test_sinkhorn_online_memory(self, batch_size: int, jit: bool):
# offline: Total memory allocated: 240.1MiB
# online (500): Total memory allocated: 33.4MiB
# online (500): Total memory allocated: 33.4MiB; GPU: 203.4MiB
# online (1000): Total memory allocated: 45.6MiB
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m = 5000, 4000
x = jax.random.uniform(rngs[0], (n, 2))
y = jax.random.uniform(rngs[1], (m, 2))
geom = pointcloud.PointCloud(x, y, batch_size=batch_size, epsilon=1)
problem = linear_problem.LinearProblem(geom)
solver = sinkhorn.Sinkhorn()
solver = sinkhorn.Sinkhorn(jit=False)
if jit:
solver = jax.jit(solver)

out = solver(problem)
assert out.converged
Expand Down

0 comments on commit 9991e15

Please sign in to comment.