issue with using other metric choices with emd2_1d #669

Closed
mrunalimanj opened this issue Aug 13, 2024 · 2 comments

Comments

@mrunalimanj

Describe the bug

emd2_1d (and emd_1d, used in the reproducer below) raises an error when called with a metric that is not one of the sped-up ones, e.g. cosine or yule.

To Reproduce

Steps to reproduce the behavior:
Simple test case adapted from the 1d example code:


```python
import numpy as np
import ot
from ot.datasets import make_1D_gauss as gauss

# Parameters
n = 100  # number of bins

# Bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions (m = mean, s = standard deviation)
a = gauss(n, m=20, s=5)
b = gauss(n, m=60, s=10)

# Fast 1D solver with a non-default metric: this call raises the error
G0 = ot.emd_1d(x, x, a, b, metric="cosine")
```

Running it fails with:

```
     54 G0 = ot.emd_1d(x, x, a, b, metric="cosine")
     55 
     56 # Equivalent to

~/miniconda3/envs/ms-gen/lib/python3.8/site-packages/ot/lp/solver_1d.py in emd_1d(x_a, x_b, a, b, metric, p, dense, log, check_marginals)
    257     perm_b = nx.argsort(x_b_1d)
    258 
--> 259     G_sorted, indices, cost = emd_1d_sorted(
    260         nx.to_numpy(a[perm_a]).astype(np.float64),
    261         nx.to_numpy(b[perm_b]).astype(np.float64),

ot/lp/emd_wrap.pyx in ot.lp.emd_wrap.emd_1d_sorted()

AttributeError: 'float' object has no attribute 'reshape'
```
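The last frame points into the compiled emd_1d_sorted routine. A plausible reading, assuming the branch for non-sped-up metrics turns each position into a (1, 1) array with .reshape before computing the distance, is that the value it receives is already a plain Python float rather than a NumPy scalar. The NumPy-only sketch below (not POT code; the array u is illustrative) shows that difference:

```python
import numpy as np

u = np.array([0.0, 1.0, 2.0])  # illustrative sorted float64 positions

# Indexing a NumPy array yields a NumPy scalar, which does have .reshape.
print(type(u[1]).__name__, u[1].reshape((1, 1)).shape)

# A plain Python float (what a typed C double becomes when it is handed
# back to Python code) has no such method.
as_python_float = float(u[1])
try:
    as_python_float.reshape((1, 1))
except AttributeError as err:
    print(err)  # 'float' object has no attribute 'reshape'

# Wrapping the scalar explicitly yields the (1, 1) array that a pairwise
# distance routine expecting 2-D inputs needs.
print(np.asarray(as_python_float).reshape((1, 1)).shape)
```

If that reading is right, every metric outside the sped-up set would hit the same error, which matches the report.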

Expected behavior

It should return a value instead of raising an error (I can't yet tell whether the result would be mathematically correct).
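On whether the math would be correct: even if the call succeeded, the cosine distance is degenerate for 1-D positions. Between the one-element vectors [x] and [y], the textbook definition 1 - <u, v> / (||u|| ||v||) reduces to 1 - sign(x) * sign(y), and it is undefined when either point is 0. The sketch below is a hand-rolled version of that definition (cosine_distance_1d is a hypothetical helper, not POT's or SciPy's code). The concern is even stronger for yule, which SciPy defines for boolean vectors.

```python
import math


def cosine_distance_1d(x: float, y: float) -> float:
    """Hypothetical helper: cosine distance between the 1-element vectors [x] and [y].

    Uses the textbook definition 1 - <u, v> / (||u|| * ||v||).
    """
    denom = abs(x) * abs(y)
    if denom == 0.0:
        # Undefined when either point sits at the origin.
        return math.nan
    # For scalars this reduces to 1 - sign(x) * sign(y), i.e. 0.0 or 2.0.
    return 1.0 - (x * y) / denom


# Bin positions as in the reproducer: 0.0, 1.0, ..., 99.0
positions = [float(i) for i in range(100)]

# Every pair of strictly positive bins gets the same cost.
costs = {cosine_distance_1d(xi, xj) for xi in positions[1:] for xj in positions[1:]}
print(costs)                          # {0.0}
print(cosine_distance_1d(-3.0, 7.0))  # 2.0: only the signs matter
print(cosine_distance_1d(0.0, 5.0))   # nan: bin 0 is at the origin
```

With x = np.arange(n) as in the reproducer, the cost would be 0 between all positive bins and undefined at bin 0, so it carries no information about how far apart the two Gaussians are.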

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Linux
  • Python version: 3.8.18
  • How was POT installed (source, pip, conda): pip
  • Build command you used (if compiling from source): pip install POT

Output of the following code snippet:

```python
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
```

```
Linux-5.15.0-117-generic-x86_64-with-glibc2.10
Python 3.8.18 | packaged by conda-forge | (default, Oct 10 2023, 15:44:36)
[GCC 12.3.0]
NumPy 1.24.4
SciPy 1.10.1
POT 0.9.4
```

Additional context

@rtavenar
Contributor

Hi @mrunalimanj ,

The emd_1d and emd2_1d functions rely on the result presented (for example) in [1, Remark 2.28]. As such, this approach is valid only for metrics of the form $d(x, y) = |x - y|^p$ (at least, to the best of my knowledge).

I therefore suggest removing the option to select metrics that are not of this form, and I have opened a PR for that.
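The role of the metric's form can be sanity-checked on tiny inputs. The plain-Python sketch below (not POT code; matching_cost, sorted_cost and brute_force_cost are hypothetical helpers) compares the monotone matching that a sort-based 1-D solver returns with an exhaustive search over permutations, which is exact for two uniform measures with the same number of points. For |x - y|^2 and |x - y| the two agree, while for |x - y|^0.5 the sorted matching is beaten. This is consistent with the standard result that sorting is optimal when the cost is a convex function of x - y, i.e. p >= 1 for this family.

```python
from itertools import permutations


def matching_cost(xs, ys, pairing, p):
    """Total cost sum_i |xs[i] - ys[pairing[i]]|**p of a one-to-one matching."""
    return sum(abs(xs[i] - ys[j]) ** p for i, j in enumerate(pairing))


def sorted_cost(xs, ys, p):
    """Monotone matching: the i-th smallest x is sent to the i-th smallest y."""
    return matching_cost(sorted(xs), sorted(ys), range(len(xs)), p)


def brute_force_cost(xs, ys, p):
    """Exact cost for two uniform measures with the same (small) number of points.

    With equal uniform weights, some optimal plan is a permutation
    (Birkhoff-von Neumann theorem), so trying every permutation is exact.
    Costs are left unnormalized, i.e. without the common 1/n factor.
    """
    return min(matching_cost(xs, ys, perm, p) for perm in permutations(range(len(xs))))


xs, ys = [0.0, 1.0], [1.0, 2.0]
for p in (2.0, 1.0, 0.5):
    print(p, sorted_cost(xs, ys, p), brute_force_cost(xs, ys, p))
# 2.0 2.0 2.0
# 1.0 2.0 2.0                  (tie: both matchings cost 2.0)
# 0.5 2.0 1.4142135623730951   (the sorted matching is not optimal)
```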

@rflamary
Collaborator

The PR was merged, so I'm closing the issue.
