Skip to content

Commit

Permalink
Rename proj_fn to projector, allow passing array
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Sep 12, 2024
1 parent 3cfeee2 commit 4eaaedc
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
19 changes: 12 additions & 7 deletions src/ott/tools/sliced.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -54,7 +54,7 @@ def sliced_wasserstein(
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
cost_fn: Optional[costs.CostFn] = None,
proj_fn: Optional[Projector] = None,
projector: Optional[Union[jnp.ndarray, Projector]] = None,
return_transport: bool = False,
return_dual_variables: bool = False,
**kwargs: Any,
Expand All @@ -74,9 +74,9 @@ def sliced_wasserstein(
cost_fn: Cost function. Must be a submodular function of two real arguments,
i.e. such that :math:`\partial c(x,y)/\partial x \partial y <0`. If
:obj:`None`, use :class:`~ott.geometry.costs.SqEuclidean`.
proj_fn: Projection function, mapping any ``[b, dim]`` matrix of coordinates
to ``[b, n_proj]`` matrix of features, on which 1D transports (for
``n_proj`` directions) are subsequently computed independently.
projector: Array of shape ``[n_proj, dim]``, or a function mapping any
``[b, dim]`` array of coordinates to ``[b, n_proj]`` array of features,
on which ``n_proj``-1D transports are computed independently.
By default, use :func:`~ott.tools.sliced.random_proj_sphere`.
return_transport: Whether to store ``n_proj`` transport plans in the output.
return_dual_variables: Whether to store ``n_proj`` pairs of dual vectors
Expand All @@ -88,10 +88,15 @@ def sliced_wasserstein(
Returns:
The sliced Wasserstein distance with the corresponding output object.
"""
if proj_fn is None:
if projector is None:
proj_fn = random_proj_sphere
elif callable(projector):
proj_fn = projector
else:
proj_fn = lambda arr, **_: arr @ projector.T

x_proj, y_proj = proj_fn(x, **kwargs), proj_fn(y, **kwargs),
x_proj = proj_fn(x, **kwargs)
y_proj = proj_fn(y, **kwargs)
geom = pointcloud.PointCloud(x_proj, y_proj, cost_fn=cost_fn)

out = linear.solve_univariate(
Expand Down
21 changes: 14 additions & 7 deletions tests/tools/sliced_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,26 @@ def gen_data(

class TestSliced:

@pytest.mark.parametrize("proj_fn", [None, custom_proj])
@pytest.mark.parametrize("projector", [None, custom_proj])
@pytest.mark.parametrize("cost_fn", [costs.PNormP(1.3), None])
def test_random_projs(
self, rng: jax.Array, cost_fn: Optional[costs.CostFn],
proj_fn: Optional[Projector]
projector: Optional[Projector]
):
n, m, dim, n_proj = 12, 17, 5, 13
rng1, rng2 = jax.random.split(rng, 2)
a, x, b, y = gen_data(rng1, n, m, dim)

# Test non-negative and returns output as needed.
cost, out = sliced.sliced_wasserstein(
x, y, a, b, cost_fn=cost_fn, proj_fn=proj_fn, n_proj=n_proj, rng=rng2
x,
y,
a=a,
b=b,
cost_fn=cost_fn,
projector=projector,
n_proj=n_proj,
rng=rng2
)
assert cost > 0.0
np.testing.assert_array_equal(cost, jnp.sum(out.ot_costs))
Expand All @@ -78,14 +85,14 @@ def test_consistency_with_id(

# Test matches standard implementation when using identity.
cost, _ = sliced.sliced_wasserstein(
x, y, proj_fn=lambda x: x, cost_fn=cost_fn
x, y, projector=lambda x, **_: x, cost_fn=cost_fn
)
geom = pointcloud.PointCloud(x=x, y=y, cost_fn=cost_fn)
geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn)
out_lin = jnp.sum(linear.solve_univariate(geom).ot_costs)
np.testing.assert_allclose(out_lin, cost, rtol=1e-6, atol=1e-6)

@pytest.mark.parametrize("proj_fn", [None, custom_proj])
def test_diff(self, rng: jax.Array, proj_fn: Optional[Projector]):
@pytest.mark.parametrize("projector", [None, custom_proj])
def test_diff(self, rng: jax.Array, projector: Optional[Projector]):
eps = 1e-4
n, m, dim = 13, 16, 7
a, x, b, y = gen_data(rng, n, m, dim)
Expand Down

0 comments on commit 4eaaedc

Please sign in to comment.