Skip to content

Commit

Permalink
Adding option to specify epsilon parameter in low-rank sinkhorn
Browse files Browse the repository at this point in the history
  • Loading branch information
marcocuturi committed Feb 3, 2022
1 parent 9bef50b commit 53a0697
Show file tree
Hide file tree
Showing 92 changed files with 188 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ jobs:
python -VV
python -c "import jax; print('jax', jax.__version__)"
python -c "import jaxlib; print('jaxlib', jaxlib.__version__)"
pytest tests
pytest tests -n 8
shell: bash
Binary file added ott/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file added ott/__pycache__/version.cpython-39.pyc
Binary file not shown.
Binary file added ott/core/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file added ott/core/__pycache__/anderson.cpython-39.pyc
Binary file not shown.
Binary file added ott/core/__pycache__/dataclasses.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added ott/core/__pycache__/icnn.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file added ott/core/__pycache__/momentum.cpython-39.pyc
Binary file not shown.
Binary file added ott/core/__pycache__/problems.cpython-39.pyc
Binary file not shown.
Binary file added ott/core/__pycache__/quad_problems.cpython-39.pyc
Binary file not shown.
Binary file added ott/core/__pycache__/sinkhorn.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
4 changes: 2 additions & 2 deletions ott/core/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ott.core import sinkhorn_lr
from ott.geometry import epsilon_scheduler
from ott.geometry import geometry
from ott.geometry import geometry_lr
from ott.geometry import low_rank
from ott.geometry import pointcloud


Expand Down Expand Up @@ -194,7 +194,7 @@ def __call__(self, prob: quad_problems.QuadraticProblem) -> GWOutput:
# Consider converting
if convert:
if not prob.is_fused or isinstance(prob.geom_xy,
geometry_lr.LRCGeometry):
low_rank.LRCGeometry):
prob.geom_xx = prob.geom_xx.to_LRCGeometry()
prob.geom_yy = prob.geom_yy.to_LRCGeometry()
else:
Expand Down
16 changes: 8 additions & 8 deletions ott/core/quad_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ott.core import sinkhorn_lr
from ott.geometry import epsilon_scheduler
from ott.geometry import geometry
from ott.geometry import geometry_lr
from ott.geometry import low_rank
from ott.geometry import pointcloud
# Because Protocol is not available in Python < 3.8
from typing_extensions import Protocol
Expand Down Expand Up @@ -150,10 +150,10 @@ def is_fused(self):
@property
def is_all_geoms_lr(self):
lr_geoms = (
isinstance(self.geom_xx, geometry_lr.LRCGeometry) and
isinstance(self.geom_yy, geometry_lr.LRCGeometry))
isinstance(self.geom_xx, low_rank.LRCGeometry) and
isinstance(self.geom_yy, low_rank.LRCGeometry))
lr_geoms = lr_geoms and (
isinstance(self.geom_xy, geometry_lr.LRCGeometry)
isinstance(self.geom_xy, low_rank.LRCGeometry)
or
not self.is_fused
)
Expand Down Expand Up @@ -228,7 +228,7 @@ def marginal_dependent_cost(self, marginal_1, marginal_2):
fn=self.linear_loss[1])
x_term = jnp.concatenate((tmp1, jnp.ones_like(tmp1)), axis=1)
y_term = jnp.concatenate((jnp.ones_like(tmp2), tmp2), axis=1)
return geometry_lr.LRCGeometry(cost_1=x_term, cost_2=y_term)
return low_rank.LRCGeometry(cost_1=x_term, cost_2=y_term)

def cost_unbalanced_correction(self, transport_matrix, marginal_1, marginal_2,
epsilon, rescale_factor, delta=1e-9) -> float:
Expand Down Expand Up @@ -394,10 +394,10 @@ def update_lr_geom(self, lr_sink):
if self.is_all_geoms_lr:
tmp1r = self.geom_xx.apply_cost_2(q)
tmp2l = jnp.transpose(self.geom_yy.apply_cost_1(r, 1))
geom = geometry_lr.LRCGeometry(cost_1=tmp1r, cost_2=-tmp2l)
geom = geometry_lr.add_lrc_geom(geom, marginal_cost)
geom = low_rank.LRCGeometry(cost_1=tmp1r, cost_2=-tmp2l)
geom = low_rank.add_lrc_geom(geom, marginal_cost)
if self.is_fused:
geom = geometry_lr.add_lrc_geom(
geom = low_rank.add_lrc_geom(
geom, self.geom_xy)

else:
Expand Down
10 changes: 7 additions & 3 deletions ott/core/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class LRSinkhorn(sinkhorn.Sinkhorn):
Attributes:
rank: the rank constraint on the coupling to minimize the linear OT problem
gamma: the (inverse of) gradient stepsize used by mirror descent.
epsilon: entropic regularization added on top of low-rank problem.
lse_mode: whether to run computations in lse or kernel mode. At this moment,
only ``lse_mode=True`` is implemented.
threshold: convergence threshold, used to quantify whether two successive
Expand All @@ -200,6 +201,7 @@ class LRSinkhorn(sinkhorn.Sinkhorn):
def __init__(self,
rank: int = 10,
gamma: float = 1.0,
epsilon: float = 1e-4,
lse_mode: bool = True,
threshold: float = 1e-3,
norm_error: int = 1,
Expand All @@ -213,6 +215,7 @@ def __init__(self,
kwargs_dys: Any = None):
self.rank = rank
self.gamma = gamma
self.epsilon = epsilon
self.lse_mode = lse_mode
self.threshold = threshold
self.inner_iterations = inner_iterations
Expand Down Expand Up @@ -260,12 +263,13 @@ def not_converged(self, state, iteration):

def lr_costs(self, ot_prob, state, iteration):
c_q = ot_prob.geom.apply_cost(state.r, axis=1) / state.g[None, :]
c_q -= jnp.log(state.q) / self.gamma
c_q += (self.epsilon - 1 / self.gamma) * jnp.log(state.q)
c_r = ot_prob.geom.apply_cost(state.q) / state.g[None, :]
c_r -= jnp.log(state.r) / self.gamma
c_r += (self.epsilon - 1 / self.gamma) * jnp.log(state.r)
diag_qcr = jnp.sum(state.q * ot_prob.geom.apply_cost(state.r, axis=1),
axis=0)
h = diag_qcr / state.g ** 2 + jnp.log(state.g) / self.gamma
h = diag_qcr / state.g ** 2 - (
self.epsilon - 1 / self.gamma) * jnp.log(state.g)
return c_q, c_r, h

def dysktra_update(self, c_q, c_r, h, ot_prob, state, iteration,
Expand Down
Binary file added ott/geometry/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file added ott/geometry/__pycache__/costs.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file added ott/geometry/__pycache__/geometry.cpython-39.pyc
Binary file not shown.
Binary file added ott/geometry/__pycache__/grid.cpython-39.pyc
Binary file not shown.
Binary file added ott/geometry/__pycache__/low_rank.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file added ott/geometry/__pycache__/ops.cpython-39.pyc
Binary file not shown.
Binary file not shown.
File renamed without changes.
4 changes: 2 additions & 2 deletions ott/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import jax.numpy as jnp
from ott.geometry import costs
from ott.geometry import geometry
from ott.geometry import geometry_lr
from ott.geometry import low_rank
from ott.geometry import ops


Expand Down Expand Up @@ -294,7 +294,7 @@ def to_LRCGeometry(self, scale=1.0):
cost_1 *= jnp.sqrt(scale)
cost_2 *= jnp.sqrt(scale)

return geometry_lr.LRCGeometry(
return low_rank.LRCGeometry(
cost_1=cost_1,
cost_2=cost_2,
epsilon=self._epsilon_init,
Expand Down
2 changes: 1 addition & 1 deletion ott/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

"""OTT tools: A set of tools to use OT in differentiable ML pipelines."""

from . import plot
#from . import plot
from . import sinkhorn_divergence
from . import soft_sort
from . import transport
Binary file added ott/tools/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file added ott/tools/__pycache__/plot.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file added ott/tools/__pycache__/soft_sort.cpython-39.pyc
Binary file not shown.
Binary file added ott/tools/__pycache__/transport.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
104 changes: 104 additions & 0 deletions ott_jax.egg-info/PKG-INFO
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
Metadata-Version: 2.1
Name: ott-jax
Version: 0.2.2
Summary: OTT: Optimal Transport Tools in Jax.
Home-page: https://github.com/google-research/ott
Author: Google LLC
Author-email: [email protected]
License: UNKNOWN
Keywords: optimal transport,sinkhorn,wasserstein,jax
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.7
Description-Content-Type: text/markdown
License-File: LICENSE

![Tests](https://github.com/google-research/ott/actions/workflows/tests.yml/badge.svg)

<div align="center">
<img src="https://github.com/google-research/ott/raw/master/docs/logoOTT.png" alt="logo" width="150"></img>
</div>

# Optimal Transport Tools (OTT), A toolbox for all things Wasserstein.

**See [full documentation](https://ott-jax.readthedocs.io/en/latest/) for detailed info on the toolbox.**
=======
The goal of OTT is to provide sturdy, versatile and efficient optimal transport solvers, taking advantage of JAX features, such as [JIT](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Using-jit-to-speed-up-functions), [auto-vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap) and [implicit differentiation](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html).

A typical OT problem has two ingredients: a pair of weight vectors `a` and `b` (one for each measure), with a ground cost matrix that is either directly given, or derived as the pairwise evaluation of a cost function on pairs of points taken from two measures. The main design choice in OTT comes from encapsulating the cost in a `Geometry` object, and bundle it with a few useful operations (notably kernel applications). The most common geometry is that of two clouds of vectors compared with the squared Euclidean distance, as illustrated in the example below:

## Example

```py
import jax
import jax.numpy as jnp
from ott.tools import transport
# Samples two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0),4)
n, m, d = 12, 14, 2
x = jax.random.normal(rngs[0], (n,d)) + 1
y = jax.random.uniform(rngs[1], (m,d))
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (m,))
a, b = a / jnp.sum(a), b / jnp.sum(b)
# Computes the couplings via Sinkhorn algorithm.
ot = transport.Transport(x, y, a=a, b=b)
P = ot.matrix
```

The call to `sinkhorn` above works out the optimal transport solution by storing its output. The transport matrix can be instantiated using those optimal solutions and the `Geometry` again. That transoprt matrix links each point from the first point cloud to one or more points from the second, as illustrated below.

![obtained coupling](./images/couplings.png)

To be more precise, the `sinkhorn` algorithm operates on the `Geometry`,
taking into account weights `a` and `b`, to solve the OT problem, produce a named tuple that contains two optimal dual potentials `f` and `g` (vectors of the same size as `a` and `b`), the objective `reg_ot_cost` and a log of the `errors` of the algorithm as it converges, and a `converged` flag.

## Overall description of source code

Currently implements the following classes and functions:

- In the [geometry](ott/geometry) folder,

- The `CostFn` class in [costs.py](ott/geometry/costs.py) and its descendants define cost functions between points. Two simple costs are currently provided, `Euclidean` between vectors, and `Bures`, between a pair of mean vector and covariance (p.d.) matrix.

- The `Geometry` class in [geometry.py](ott/geometry/geometry.py) and its descendants describe a cost structure between two measures. That cost structure is accessed through various member functions, either used when running the Sinkhorn algorithm (typically kernel multiplications, or log-sum-exp row/column-wise application) or after (to apply the OT matrix to a vector).

- In its generic `Geometry` implementation, as in [geometry.py](ott/geometry/geometry.py), an object can be initialized with either a `cost_matrix` along with an `epsilon` regularization parameter (or scheduler), or with a `kernel_matrix`.

- If one wishes to compute OT between two weighted point clouds
<img src="https://render.githubusercontent.com/render/math?math=%24x%3D(x_1%2C%20%5Cdots%2C%20x_n)%24"> and <img src="https://render.githubusercontent.com/render/math?math=%24y%3D(y_1%2C%20%5Cdots%2C%20y_m)%24"> endowed with a
given cost function (e.g. Euclidean) <img src="https://render.githubusercontent.com/render/math?math=%24c%24">, the `PointCloud`
class in [pointcloud.py](ott/geometry/grid.py) can be used to define the corresponding kernel
<img src="https://render.githubusercontent.com/render/math?math=%24K_%7Bij%7D%3D%5Cexp(-c(x_i%2Cy_j)%2F%5Cepsilon)%24">. When the number of these points grows very large, this geometry can be instantiated with an `online=True` parameter, to avoid storing the kernel matrix and choose instead to recompute the matrix on the fly at each application.

- Simlarly, if all measures to be considered are supported on a
separable grid (e.g. <img src="https://render.githubusercontent.com/render/math?math=%24%5C%7B1%2C...%2Cn%5C%7D%5Ed%24">), and the cost is separable
along all axis, i.e. the cost between two points on that
grid is equal to the sum of (possibly <img src="https://render.githubusercontent.com/render/math?math=%24d%24"> different) cost
functions evaluated on each of the <img src="https://render.githubusercontent.com/render/math?math=%24d%24"> pairs of coordinates, then
the application of the kernel is much simplified, both in log space
or on the histograms themselves. This particular case is exploited in the `Grid` geometry in [grid.py](ott/geometry/grid.py) which can be instantiated as a hypercube using a `grid_size` parameter, or directly through grid locations in `x`.

- `LRCGeometry`, low-rank cost geometries, of which a `PointCloud` endowed with a squared-Euclidean distance is a particular example, can efficiently carry apply their cost to another matrix. This is leveraged in particular in the low-rank Sinkhorn (and Gromov-Wasserstein) solvers.


- In the [core](ott/core) folder,
- The `sinkhorn` function in [sinkhorn.py](ott/core/sinkhorn.py) is a wrapper around the `Sinkhorn` solver class, running the Sinkhorn algorithm, with the aim of solving approximately one or various optimal transport problems in parallel. An OT problem is defined by a `Geometry` object, and a pair <img src="https://render.githubusercontent.com/render/math?math=%24(a%2C%20b)%24"> (or batch thereof) of histograms. The function's outputs are stored in a `SinkhornOutput` named t-uple, containing potentials, regularized OT cost, sequence of errors and a convergence flag. Such outputs (with the exception of errors and convergence flag) can be differentiated w.r.t. any of the three inputs `(Geometry, a, b)` either through backprop or implicit differentiation of the optimality conditions of the optimal potentials `f` and `g`.
- A later addition in [sinkhorn_lr.py](ott/core/sinkhorn.py) is focused on the `LRSinkhorn` solver class, which is able to solve OT problems at larger scales using an explicit factorization of couplings as being low-rank.

- In [discrete_barycenter.py](ott/tools/discrete_barycenter.py): implementation of discrete Wasserstein barycenters : given <img src="https://render.githubusercontent.com/render/math?math=%24N%24"> histograms all supported on the same `Geometry`, compute a barycenter of theses measures, using an algorithm by [Janati et al. (2020)](https://arxiv.org/abs/2006.02575)

- In [gromov_wasserstein.py](ott/tools/gromov_wasserstein.py): implementation of two Gromov-Wasserstein solvers (both entropy-regularized and low-rank) to compare two measured-metric spaces, here encoded as a pair of `Geometry` objects, `geom_xx`, `geom_xy` along with weights `a` and `b`. Additional options include using a fused term by specifying `geom_xy`.

- In the [tools](ott/tools) folder,

- In [soft_sort.py](ott/tools/soft_sort.py): implementation of
[soft-sorting](https://papers.nips.cc/paper/2019/hash/d8c24ca8f23c562a5600876ca2a550ce-Abstract.html) operators, notably [soft-quantile transforms](http://proceedings.mlr.press/v119/cuturi20a.html)

- The `sinkhorn_divergence` function in [sinkhorn_divergence.py](ott/tools/sinkhorn_divergence.py), implements the [unbalanced](https://arxiv.org/abs/1910.12958) formulation of the [Sinkhorn divergence](http://proceedings.mlr.press/v84/genevay18a.html), a variant of the Wasserstein distance that uses regularization and is computed by centering the output of `sinkhorn` when comparing two measures.

- The `Transport` class in [sinkhorn_divergence.py](ott/tools/transport.py), provides a simple wrapper to the `sinkhorn` function defined above when the user is primarily interested in computing and storing an OT matrix.

- The [gaussian_mixture](ott/tools/gaussian_mixture) folder provides novel tools to compare and estimate GMMs with an OT perspective.


40 changes: 40 additions & 0 deletions ott_jax.egg-info/SOURCES.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
LICENSE
README.md
pyproject.toml
setup.cfg
setup.py
ott/__init__.py
ott/version.py
ott/core/__init__.py
ott/core/anderson.py
ott/core/dataclasses.py
ott/core/discrete_barycenter.py
ott/core/fixed_point_loop.py
ott/core/gromov_wasserstein.py
ott/core/icnn.py
ott/core/implicit_differentiation.py
ott/core/momentum.py
ott/core/problems.py
ott/core/quad_problems.py
ott/core/sinkhorn.py
ott/core/sinkhorn_lr.py
ott/core/unbalanced_functions.py
ott/geometry/__init__.py
ott/geometry/costs.py
ott/geometry/epsilon_scheduler.py
ott/geometry/geometry.py
ott/geometry/grid.py
ott/geometry/low_rank.py
ott/geometry/matrix_square_root.py
ott/geometry/ops.py
ott/geometry/pointcloud.py
ott/tools/__init__.py
ott/tools/plot.py
ott/tools/sinkhorn_divergence.py
ott/tools/soft_sort.py
ott/tools/transport.py
ott_jax.egg-info/PKG-INFO
ott_jax.egg-info/SOURCES.txt
ott_jax.egg-info/dependency_links.txt
ott_jax.egg-info/requires.txt
ott_jax.egg-info/top_level.txt
1 change: 1 addition & 0 deletions ott_jax.egg-info/dependency_links.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

7 changes: 7 additions & 0 deletions ott_jax.egg-info/requires.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
absl-py>=0.7.0
jax>=0.1.67
jaxlib>=0.1.47
numpy>=1.18.4
matplotlib>=2.0.1
flax>=0.3.6
optax>=0.0.9
1 change: 1 addition & 0 deletions ott_jax.egg-info/top_level.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ott
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/core/sinkhorn_grid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_apply_cost(self):

vec = jax.random.uniform(self.rng, grid_size).ravel()
self.assertAllClose(geom_mat.apply_cost(vec),
geom_grid.apply_cost(vec))
geom_grid.apply_cost(vec), rtol=1e-4, atol=1e-4)

self.assertAllClose(
geom_grid.apply_cost(vec)[:, 0], np.dot(geom_mat.cost_matrix.T, vec))
Expand Down
8 changes: 7 additions & 1 deletion tests/core/sinkhorn_lr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_euclidean_point_cloud(self, use_lrcgeom):
self.assertTrue(jnp.isclose(costs[-2], costs[-1], rtol=threshold))
cost_1 = costs[costs > -1][-1]

solver = sinkhorn_lr.LRSinkhorn(threshold=threshold, rank=20)
solver = sinkhorn_lr.LRSinkhorn(threshold=threshold, rank=20, epsilon=0.0)
out = solver(ot_prob)
costs = out.costs
cost_2 = costs[costs > -1][-1]
Expand All @@ -69,6 +69,12 @@ def test_euclidean_point_cloud(self, use_lrcgeom):
cost_other = out.cost_at_geom(other_geom)
self.assertGreater(cost_other, 0.0)

solver = sinkhorn_lr.LRSinkhorn(threshold=threshold, rank=20, epsilon=1e-2)
out = solver(ot_prob)
costs = out.costs
cost_3 = costs[costs > -1][-1]
self.assertGreater(cost_3, cost_2)


if __name__ == '__main__':
absltest.main()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
12 changes: 6 additions & 6 deletions tests/geometry/geometry_lr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import jax.numpy as jnp
import jax.test_util
from ott.geometry import geometry
from ott.geometry import geometry_lr
from ott.geometry import low_rank
from ott.geometry import pointcloud


Expand All @@ -40,7 +40,7 @@ def test_apply(self):
c = jnp.matmul(c1, c2.T)
bias = 0.27
geom = geometry.Geometry(c + bias)
geom_lr = geometry_lr.LRCGeometry(c1, c2, bias=bias)
geom_lr = low_rank.LRCGeometry(c1, c2, bias=bias)
for dim, axis in ((m, 1), (n, 0)):
for mat_shape in ((dim, 2), (dim,)):
mat = jax.random.normal(keys[2], mat_shape)
Expand Down Expand Up @@ -76,7 +76,7 @@ def test_apply_squared(self):
c = jnp.matmul(c1, c2.T)
geom = geometry.Geometry(c)
geom2 = geometry.Geometry(c ** 2)
geom_lr = geometry_lr.LRCGeometry(c1, c2)
geom_lr = low_rank.LRCGeometry(c1, c2)
for dim, axis in ((m, 1), (n, 0)):
for mat_shape in ((dim, 2), (dim,)):
mat = jax.random.normal(keys[2], mat_shape)
Expand All @@ -103,9 +103,9 @@ def test_add_lr_geoms(self):
d = jnp.matmul(d1, d2.T)
geom = geometry.Geometry(c + d)

geom_lr_c = geometry_lr.LRCGeometry(c1, c2)
geom_lr_d = geometry_lr.LRCGeometry(d1, d2)
geom_lr = geometry_lr.add_lrc_geom(geom_lr_c, geom_lr_d)
geom_lr_c = low_rank.LRCGeometry(c1, c2)
geom_lr_d = low_rank.LRCGeometry(d1, d2)
geom_lr = low_rank.add_lrc_geom(geom_lr_c, geom_lr_d)

for dim, axis in ((m, 1), (n, 0)):
mat = jax.random.normal(keys[1], (dim, 2))
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit 53a0697

Please sign in to comment.