-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Deprecate power
in PointCloud
, introduce TICost
and use it to compute Entropic (Brenier) maps.
#167
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #167 +/- ##
==========================================
+ Coverage 89.83% 89.91% +0.07%
==========================================
Files 51 51
Lines 5068 5118 +50
Branches 519 521 +2
==========================================
+ Hits 4553 4602 +49
- Misses 391 397 +6
+ Partials 124 119 -5
|
|
||
import jax | ||
import jax.numpy as jnp | ||
import jax.scipy as jsp | ||
import jax.tree_util as jtu | ||
from typing_extensions import Literal | ||
|
||
from ott.geometry import pointcloud | ||
from ott.geometry import costs, pointcloud | ||
|
||
__all__ = ["DualPotentials", "EntropicPotentials"] | ||
Potential_t = Callable[[jnp.ndarray], float] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the same vein as cost_fn
, I'm wondering if Potential_t
could be renamed to PotentialFn_t
or just PotentialFn
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had the same reflex when Michal used it for the first time, but I think it makes sense :) Here it turns out this is just a type (_t
) and can be either a vector of a function.
ott/geometry/costs.py
Outdated
|
||
c(x,y) = h(z), where z := x-y. | ||
|
||
where h is a function strictly convex (or concave) function mapping vectors |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a small repetition here: I think you meant "where h is a strictly convex (or concave) function ...". It's minor, I know ;)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
great catch! thanks.
power
in point cloud API, introduce RBF costs.
power
in point cloud API, introduce RBF costs.power
in PointCloud
, introduce RBFCost
and use it to compute Entropic (Brenier) maps.
power
in PointCloud
, introduce RBFCost
and use it to compute Entropic (Brenier) maps.power
in PointCloud
, introduce TIFCost
and use it to compute Entropic (Brenier) maps.
power
in PointCloud
, introduce TIFCost
and use it to compute Entropic (Brenier) maps.power
in PointCloud
, introduce TICost
and use it to compute Entropic (Brenier) maps.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just minor fixes for Potentials
, then it can be merged.
It seems I need your explicit approval? maybe related to changes I did in "branch protection" |
…ompute Entropic (Brenier) maps. (#167) * deperecate `power`, introduce h maps in potentials * Deprecate power and introduce h function in costs. * linter * linter * revert abstractmethod. * linter * linter * PNorm -> SqPNorm * PNorm -> SqPNorm in tests. * another fix for abstract method. * fix abc.abstractmethod * linter * nb fix * linter * nb bug fix * modify ipynb * abc.abstractmethod for RBF * fixes and additions. * fix `cor` in neuraldual * fix in neuraldual * p-norm ** p implemented, fixes. * various fixes. Change to `TICost` * various fixes * fix nb * last fixes.
cost(x,y) := h(x-y)
whereh
is from vectors to reals. Specify (when known) the legendre transform ofh
. This plays a role in the application of the Brenier theorem.SqPNorms
, i.e. squared p-norms, cost functions. Squared p-norms have closed forms for their legendre transform (corresponding squared q-norm, where1/p + 1/q = 1
)power
option inPointCloud
geometries. This was confusing and redundant with the ability to write down directly a customCostFn
for that purpose. Raise an error ifpower
is passed to help debugging.