-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEAT] Add (Low-Rank) FUGW barycenters + brain tutorial #526
base: main
Are you sure you want to change the base?
Conversation
…/ott into feat/fugw-barycenters
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #526 +/- ##
==========================================
- Coverage 90.93% 90.74% -0.20%
==========================================
Files 68 68
Lines 7063 7088 +25
Branches 998 1004 +6
==========================================
+ Hits 6423 6432 +9
- Misses 486 502 +16
Partials 154 154
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks a lot! will let @michalk8 have a look too!
scale_cost: Union[int, float, Literal["mean", "max_cost"]] = 1.0, | ||
**kwargs: Any, | ||
): | ||
assert y is None or costs is None, "Cannot specify both `y` and `costs`." | ||
) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would remove None
as this is an __init__
function
# y as costs | ||
if problem._y_as_costs: | ||
cost = problem._y[0, :, :] | ||
# if not : initialized with euclidian metrics |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Euclidean
# if not : initialized with euclidian metrics | ||
else: | ||
coords = problem._y[0, :, :] | ||
pairwise_sq_dists = jnp.sum((coords[:, None] - coords[None]) ** 2, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be more consistent to use a pointcloud.PointCloud
geometry to compute this automatically, with, e.g., a cost_fn=Euclidean
function. This may also avoid numerical issues with jnp.sqrt
Completed with @S-bazaz for @marcocuturi 's 2024 Computational OT course at ENSAE.
Implements: