Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Misc/project structure #176

Merged
merged 35 commits into from
Nov 22, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
425fc0e
Reorganize repo structure
Nov 11, 2022
7f2fa02
Update PC docs
michalk8 Nov 15, 2022
a41ade8
Update imports, fix some types
michalk8 Nov 16, 2022
0d30656
Fix more types, pet-peeves
michalk8 Nov 16, 2022
da2c97a
Fix tests
michalk8 Nov 16, 2022
3718cf7
Merge branch 'main' into misc/project-structure
michalk8 Nov 16, 2022
eb384eb
Update cost funcs and potentials
michalk8 Nov 16, 2022
5280021
Fix LR initializer
michalk8 Nov 16, 2022
a5999b5
Fix k-means initializer
michalk8 Nov 16, 2022
35ae68a
Move `utils`
michalk8 Nov 16, 2022
c1dfdb4
Update imports in notebooks
michalk8 Nov 16, 2022
f213453
Update geometry docs
michalk8 Nov 18, 2022
4b560d2
Update initializers
michalk8 Nov 18, 2022
75602d2
Update math docs
michalk8 Nov 18, 2022
7dcbea6
Update problem docstrings
michalk8 Nov 18, 2022
3365f24
Update `solvers` docstrings
michalk8 Nov 18, 2022
861bdfa
Update `tools` docstrings
michalk8 Nov 18, 2022
4faeca5
Remove remaining `core` mentions from docstrings
michalk8 Nov 18, 2022
ba29e10
Start updating documentation
michalk8 Nov 18, 2022
2bd6d8b
Fix typing
michalk8 Nov 21, 2022
d54074a
Update solvers docs
michalk8 Nov 21, 2022
6a15213
Add initializers
michalk8 Nov 21, 2022
987ad25
Update docs
michalk8 Nov 21, 2022
8e0fcb7
Fix MetaOT links
michalk8 Nov 21, 2022
f7b9295
Fix bibliography links
michalk8 Nov 21, 2022
8875b2f
Fix more links in the notebooks
michalk8 Nov 21, 2022
791bfed
Follow line length in README.md
michalk8 Nov 21, 2022
41a7ca7
Update `tests` structure
michalk8 Nov 22, 2022
0f1ad3b
Update badges
michalk8 Nov 22, 2022
75e358e
Add TODOs, fix citation in `index.rst`, move `implicit_diff`
michalk8 Nov 22, 2022
8c1ed2a
Fix implicit_diff, TODOs in costs
michalk8 Nov 22, 2022
9186c61
Use `jax.lax.cond` in `UnbalancedBures`
michalk8 Nov 22, 2022
f49351c
Fix `UnbalancedBures`
michalk8 Nov 22, 2022
687e594
Update CI versions
michalk8 Nov 22, 2022
aea8912
Fix UnbalancedBures's norm
michalk8 Nov 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion ott/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""OTT library."""
from . import core, geometry, tools
from . import geometry, initializers, math, problems, solvers, tools, utils
from ._version import __version__
44 changes: 0 additions & 44 deletions ott/core/__init__.py

This file was deleted.

33 changes: 0 additions & 33 deletions ott/core/_math_utils.py

This file was deleted.

71 changes: 0 additions & 71 deletions ott/core/momentum.py

This file was deleted.

93 changes: 0 additions & 93 deletions ott/core/problems.py

This file was deleted.

7 changes: 1 addition & 6 deletions ott/geometry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""OTT ground geometries: Classes and cost functions to instantiate them."""
from . import costs, low_rank, ops
from .epsilon_scheduler import Epsilon
from .geometry import Geometry
from .graph import Graph
from .grid import Grid
from .pointcloud import PointCloud
from . import costs, epsilon_scheduler, geometry, graph, grid, pointcloud, segment
39 changes: 25 additions & 14 deletions ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@
import abc
import functools
import math
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Tuple, Union

import jax
import jax.numpy as jnp

from ott.core import fixed_point_loop
from ott.geometry import matrix_square_root
from ott.math import fixed_point_loop, matrix_square_root

__all__ = [
"PNorm", "SqPNorm", "Euclidean", "SqEuclidean", "Cosine", "Bures",
"UnbalancedBures"
]


@jax.tree_util.register_pytree_node_class
Expand All @@ -32,9 +36,11 @@ class CostFn(abc.ABC):

Cost functions evaluate a function on a pair of inputs. For convenience,
that function is split into two norms -- evaluated on each input separately --
followed by a pairwise cost that involves both inputs, as in
followed by a pairwise cost that involves both inputs, as in:

.. math::

c(x,y) = norm(x) + norm(y) + pairwise(x,y)
c(x,y) = norm(x) + norm(y) + pairwise(x,y)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this won't look good with maths

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think it renders ok-ish, maybe we can just replace norm(x) with ||x||


If the norm function is not implemented, that value is handled as a 0.
"""
Expand Down Expand Up @@ -108,14 +114,14 @@ class TICost(CostFn):

@abc.abstractmethod
def h(self, z: jnp.ndarray) -> float:
"""RBF function acting on difference of `x-y` to ouput cost."""
"""TI function acting on difference of :math:`x-y` to output cost."""

def h_legendre(self, z: jnp.ndarray) -> float:
"""Legendre transform of RBF function `h` (when latter is convex)."""
"""Legendre transform of TI function :func:`h` (when latter is convex)."""
raise NotImplementedError("`h_legendre` not implemented.")

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute cost as evaluation of :func:`h` on `x-y`."""
"""Compute cost as evaluation of :func:`h` on :math:`x-y`."""
return self.h(x - y)


Expand All @@ -128,9 +134,10 @@ class SqPNorm(TICost):
"""

def __init__(self, p: float):
super().__init__()
assert p >= 1.0, "p parameter in sq. p-norm should be >= 1.0"
self.p = p
self.q = 1. / (1 - 1 / self.p) if p > 1.0 else 'inf'
self.q = 1. / (1. - 1. / self.p) if p > 1.0 else "inf"

def h(self, z: jnp.ndarray) -> float:
return 0.5 * jnp.linalg.norm(z, self.p) ** 2
Expand All @@ -152,9 +159,10 @@ class PNorm(TICost):
"""p-norm (to the power p) of the difference of two vectors."""

def __init__(self, p: float):
super().__init__()
assert p >= 1.0, "p parameter in p-norm should be >= 1.0"
self.p = p
self.q = 1. / (1 - 1 / self.p)
self.q = 1. / (1. - 1. / self.p) if p > 1. else "inf"

def h(self, z: jnp.ndarray) -> float:
return jnp.linalg.norm(z, self.p) ** self.p / self.p
Expand All @@ -175,8 +183,9 @@ def tree_unflatten(cls, aux_data, children):
class Euclidean(CostFn):
"""Euclidean distance.

Note that the Euclidean distance is not cast as a `TICost`, because this
would correspond to `h = jnp.linalg.norm`, whose gradient is not invertible,
Note that the Euclidean distance is not cast as a
:class:`~ott.geometry.costs.TICost`, since this would correspond to :math:`h`
being :func:`jax.numpy.linalg.norm`, whose gradient is not invertible,
because the function is not strictly convex (it is linear on rays).
"""

Expand Down Expand Up @@ -429,6 +438,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
log_m_pi += -0.5 * ldet_c_ab

# If all logdet signs are 1, output value, nan otherwise.
# TODO(michalk8): use lax.cond
return jnp.where(
sldet_c == 1 and sldet_c_ab == 1 and sldet_ab == 1 and sldet_t_ab == 1,
2 * sig2 * mass_x * mass_y - 2 * (sig2 + gam) * jnp.exp(log_m_pi),
Expand All @@ -444,7 +454,8 @@ def tree_unflatten(cls, aux_data, children):
return cls(aux_data[0], aux_data[1], aux_data[2], **aux_data[3])


def x_to_means_and_covs(x: jnp.ndarray, dimension: jnp.ndarray) -> jnp.ndarray:
def x_to_means_and_covs(x: jnp.ndarray,
dimension: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Extract means and covariance matrices of Gaussians from raveled vector.

Args:
Expand All @@ -456,7 +467,7 @@ def x_to_means_and_covs(x: jnp.ndarray, dimension: jnp.ndarray) -> jnp.ndarray:
covariances: [num_gaussians, dimension] array that holds the covariances.
"""
x = jnp.atleast_2d(x)
means = x[:, 0:dimension]
means = x[:, :dimension]
covariances = jnp.reshape(
x[:, dimension:dimension + dimension ** 2], (-1, dimension, dimension)
)
Expand Down
2 changes: 2 additions & 0 deletions ott/geometry/epsilon_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import jax
import jax.numpy as jnp

__all__ = ["Epsilon"]


@jax.tree_util.register_pytree_node_class
class Epsilon:
Expand Down
Loading