Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 435593406
  • Loading branch information
LaetitiaPapaxanthos committed Mar 18, 2022
1 parent c395639 commit 79a7f2c
Show file tree
Hide file tree
Showing 98 changed files with 2,307 additions and 3 deletions.
14 changes: 14 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# coding=utf-8
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
Expand Down
7 changes: 7 additions & 0 deletions docs/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ Neural Potentials

icnn.ICNN

Neural Potentials
-----------------
.. autosummary::
:toctree: _autosummary

neuraldual.NeuralDualSolver
neuraldual.NeuralDual

References
----------
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ There are currently three packages, ``geometry``, ``core`` and ``tools``, playin
notebooks/soft_sort.ipynb
notebooks/application_biology.ipynb
notebooks/fairness.ipynb
notebooks/neural_dual.ipynb


.. toctree::
Expand Down
502 changes: 502 additions & 0 deletions docs/notebooks/neural_dual.ipynb

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions ott/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# coding=utf-8
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""OTT library."""

from . import core
Expand Down
15 changes: 15 additions & 0 deletions ott/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# coding=utf-8
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""OTT core libraries: the engines behind most computations happening in OTT."""

# pytype: disable=import-error # kwargs-checking
Expand All @@ -11,6 +25,7 @@
from . import problems
from . import sinkhorn
from . import sinkhorn_lr
from . import neuraldual
from .implicit_differentiation import ImplicitDiff
from .problems import LinearProblem
from .sinkhorn import Sinkhorn
Expand Down
14 changes: 14 additions & 0 deletions ott/core/anderson.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# coding=utf-8
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tools for Anderson acceleration."""
from typing import Any
import jax
Expand Down
14 changes: 14 additions & 0 deletions ott/core/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# coding=utf-8
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""pytree_nodes Dataclasses."""

import dataclasses
Expand Down
14 changes: 14 additions & 0 deletions ott/core/discrete_barycenter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# coding=utf-8
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Implementation of Janati+(2020) Wasserstein barycenter algorithm."""

Expand Down
14 changes: 14 additions & 0 deletions ott/core/fixed_point_loop.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# coding=utf-8
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""jheek@ backprop-friendly implementation of fixed point loop."""
from typing import Callable, Any
Expand Down
14 changes: 14 additions & 0 deletions ott/core/gromov_wasserstein.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# coding=utf-8
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""A Jax version of the regularised GW Solver (Peyre et al. 2016)."""
import functools
Expand Down
24 changes: 22 additions & 2 deletions ott/core/icnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# coding=utf-8
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Implementation of Amos+(2017) input convex neural networks (ICNN)."""

Expand Down Expand Up @@ -77,17 +91,23 @@ class ICNN(nn.Module):
init_std: float = 0.1
init_fn: Callable = jax.nn.initializers.normal
act_fn: Callable = nn.leaky_relu
pos_weights: bool = True

def setup(self):
num_hidden = len(self.dim_hidden)

w_zs = list()

if self.pos_weights:
Dense = PositiveDense
else:
Dense = nn.Dense

for i in range(1, num_hidden):
w_zs.append(PositiveDense(
w_zs.append(Dense(
self.dim_hidden[i], kernel_init=self.init_fn(self.init_std),
use_bias=False))
w_zs.append(PositiveDense(
w_zs.append(Dense(
1, kernel_init=self.init_fn(self.init_std), use_bias=False))
self.w_zs = w_zs

Expand Down
14 changes: 14 additions & 0 deletions ott/core/implicit_differentiation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# coding=utf-8
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions entering the implicit differentiation of Sinkhorn."""

from typing import Callable, Optional, Tuple
Expand Down
14 changes: 14 additions & 0 deletions ott/core/momentum.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
# coding=utf-8
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions related to momemtum."""

import jax.numpy as jnp
Expand Down
Loading

0 comments on commit 79a7f2c

Please sign in to comment.