2022-02-23 01:12:02.683505: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2085] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 4.00GiB (4294967296B) on device ordinal 0 BufferAssignment OOM Debugging. BufferAssignment stats: parameter allocation: 4.00GiB constant allocation: 0B maybe_live_out allocation: 4.00GiB preallocated temp allocation: 0B total allocation: 8.00GiB total fragmentation: 0B (0.00%) allocation 0: 0x55cd9cf8bed0, size 4294967296, output shape is |f32[16384,65536]|, maybe-live-out: value: <1 exponential.3 @0> (size=4294967296,offset=0): f32[16384,65536]{1,0} contains:<1 exponential.3 @0> positions: exponential.3 uses: from instruction:%exponential.3 = f32[16384,65536]{1,0} exponential(f32[16384,65536]{1,0} %parameter.1), metadata={op_type="exp" op_name="jit(exp)/exp" source_file="/vol/storage2/MK/ott/ott/geometry/ops.py" source_line=27} allocation 1: 0x55cd9cf8bf80, size 4294967296, parameter 0, shape |f32[16384,65536]| at ShapeIndex {}: value: <0 parameter.1 @0> (size=4294967296,offset=0): f32[16384,65536]{1,0} contains:<0 parameter.1 @0> positions: parameter.1 uses: exponential.3, operand 0 from instruction:%parameter.1 = f32[16384,65536]{1,0} parameter(0) (65536, 30) (16384, 30) Traceback (most recent call last): File "run.py", line 15, in r = ott.core.sinkhorn.sinkhorn(pc) File "/vol/storage2/MK/ott/ott/core/sinkhorn.py", line 959, in sinkhorn return sink(ot_prob, (init_dual_a, init_dual_b)) File "/vol/storage2/MK/ott/ott/core/sinkhorn.py", line 384, in __call__ return run_fn(ot_prob, self, (init_dual_a, init_dual_b)) File "/vol/storage2/MK/ott/ott/core/sinkhorn.py", line 524, in run out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin) File "/vol/storage2/MK/ott/ott/core/sinkhorn.py", line 201, in set_cost return self.set(reg_ot_cost=ent_reg_cost(f, g, ot_prob, lse_mode)) File "/vol/storage2/MK/ott/ott/core/sinkhorn.py", line 176, in ent_reg_cost total_sum = jnp.sum(ot_prob.geom.marginal_from_potentials(f, g)) File "/vol/storage2/MK/ott/ott/geometry/geometry.py", line 258, in marginal_from_potentials z = self.apply_lse_kernel(f, g, self.epsilon, axis=axis)[0] File "/vol/storage2/MK/ott/ott/geometry/pointcloud.py", line 136, in apply_lse_kernel h_res, h_sgn = app(self.x, self.y, self._norm_x, self._norm_y, f, g, eps, File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback return fun(*args, **kwargs) File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/api.py", line 1452, in batched_fun out_flat = batching.batch( File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs)) File "/vol/storage2/MK/ott/ott/geometry/pointcloud.py", line 315, in _apply_lse_kernel_xy return ops.logsumexp((f + g - c) / eps, b=vec, return_sign=True, axis=-1) File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback return fun(*args, **kwargs) File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 218, in __call__ out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat) File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 281, in bind outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/interpreters/batching.py", line 248, in process_custom_jvp_call out_vals = prim.bind(fun, jvp, *in_vals) File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 281, in bind outs = top_trace.process_custom_jvp_call(self, fun, jvp, tracers) # type: ignore File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/core.py", line 633, in process_custom_jvp_call return fun.call_wrapped(*tracers) File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs)) File "/vol/storage2/MK/ott/ott/geometry/ops.py", line 27, in logsumexp return jax.scipy.special.logsumexp( File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/scipy/special.py", line 120, in logsumexp out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)), File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 207, in exp return exp_p.bind(x) File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/core.py", line 272, in bind out = top_trace.process_primitive(self, tracers, params) File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/interpreters/batching.py", line 163, in process_primitive val_out, dim_out = batched_primitive(vals_in, dims_in, **params) File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/interpreters/batching.py", line 377, in vectorized_batcher return prim.bind(*batched_args, **params), batch_dims[0] File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/core.py", line 272, in bind out = top_trace.process_primitive(self, tracers, params) File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/core.py", line 624, in process_primitive return primitive.impl(*tracers, **params) File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/interpreters/xla.py", line 418, in apply_primitive return compiled_fun(*args) File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/interpreters/xla.py", line 442, in return lambda *args, **kw: compiled(*args, **kw)[0] File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1100, in _execute_compiled out_bufs = compiled.execute(input_bufs) jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 4.00GiB (4294967296B) on device ordinal 0 BufferAssignment OOM Debugging. BufferAssignment stats: parameter allocation: 4.00GiB constant allocation: 0B maybe_live_out allocation: 4.00GiB preallocated temp allocation: 0B total allocation: 8.00GiB total fragmentation: 0B (0.00%) allocation 0: 0x55cd9cf8bed0, size 4294967296, output shape is |f32[16384,65536]|, maybe-live-out: value: <1 exponential.3 @0> (size=4294967296,offset=0): f32[16384,65536]{1,0} contains:<1 exponential.3 @0> positions: exponential.3 uses: from instruction:%exponential.3 = f32[16384,65536]{1,0} exponential(f32[16384,65536]{1,0} %parameter.1), metadata={op_type="exp" op_name="jit(exp)/exp" source_file="/vol/storage2/MK/ott/ott/geometry/ops.py" source_line=27} allocation 1: 0x55cd9cf8bf80, size 4294967296, parameter 0, shape |f32[16384,65536]| at ShapeIndex {}: value: <0 parameter.1 @0> (size=4294967296,offset=0): f32[16384,65536]{1,0} contains:<0 parameter.1 @0> positions: parameter.1 uses: exponential.3, operand 0 from instruction:%parameter.1 = f32[16384,65536]{1,0} parameter(0) The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified. -------------------- The above exception was the direct cause of the following exception: Traceback (most recent call last): File "run.py", line 15, in r = ott.core.sinkhorn.sinkhorn(pc) File "/vol/storage2/MK/ott/ott/core/sinkhorn.py", line 959, in sinkhorn return sink(ot_prob, (init_dual_a, init_dual_b)) File "/vol/storage2/MK/ott/ott/core/sinkhorn.py", line 384, in __call__ return run_fn(ot_prob, self, (init_dual_a, init_dual_b)) File "/vol/storage2/MK/ott/ott/core/sinkhorn.py", line 524, in run out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin) File "/vol/storage2/MK/ott/ott/core/sinkhorn.py", line 201, in set_cost return self.set(reg_ot_cost=ent_reg_cost(f, g, ot_prob, lse_mode)) File "/vol/storage2/MK/ott/ott/core/sinkhorn.py", line 176, in ent_reg_cost total_sum = jnp.sum(ot_prob.geom.marginal_from_potentials(f, g)) File "/vol/storage2/MK/ott/ott/geometry/geometry.py", line 258, in marginal_from_potentials z = self.apply_lse_kernel(f, g, self.epsilon, axis=axis)[0] File "/vol/storage2/MK/ott/ott/geometry/pointcloud.py", line 136, in apply_lse_kernel h_res, h_sgn = app(self.x, self.y, self._norm_x, self._norm_y, f, g, eps, File "/vol/storage2/MK/ott/ott/geometry/pointcloud.py", line 315, in _apply_lse_kernel_xy return ops.logsumexp((f + g - c) / eps, b=vec, return_sign=True, axis=-1) File "/vol/storage2/MK/ott/ott/geometry/ops.py", line 27, in logsumexp return jax.scipy.special.logsumexp( File "/vol/storage/miniconda3/envs/cellrank/lib/python3.8/site-packages/jax/_src/scipy/special.py", line 120, in logsumexp out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)), out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)), RuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 4.00GiB (4294967296B) on device ordinal 0 BufferAssignment OOM Debugging. BufferAssignment stats: parameter allocation: 4.00GiB constant allocation: 0B maybe_live_out allocation: 4.00GiB preallocated temp allocation: 0B total allocation: 8.00GiB total fragmentation: 0B (0.00%) allocation 0: 0x55cd9cf8bed0, size 4294967296, output shape is |f32[16384,65536]|, maybe-live-out: value: <1 exponential.3 @0> (size=4294967296,offset=0): f32[16384,65536]{1,0} contains:<1 exponential.3 @0> positions: exponential.3 uses: from instruction:%exponential.3 = f32[16384,65536]{1,0} exponential(f32[16384,65536]{1,0} %parameter.1), metadata={op_type="exp" op_name="jit(exp)/exp" source_file="/vol/storage2/MK/ott/ott/geometry/ops.py" source_line=27} allocation 1: 0x55cd9cf8bf80, size 4294967296, parameter 0, shape |f32[16384,65536]| at ShapeIndex {}: value: <0 parameter.1 @0> (size=4294967296,offset=0): f32[16384,65536]{1,0} contains:<0 parameter.1 @0> positions: parameter.1 uses: exponential.3, operand 0 from instruction:%parameter.1 = f32[16384,65536]{1,0} parameter(0)