Skip to content

Commit

Permalink
Make GMM documentation more visible, fix bug in M step of EM algorithm (
Browse files Browse the repository at this point in the history
#144)

* Make GMM documentation more visible, fix bug in M step of EM algorithm

* Remove unused imports of adhoc_import and importlib
  • Loading branch information
geoff-davis authored Oct 10, 2022
1 parent baedbc5 commit 468c2d5
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 284 deletions.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ There are currently three packages, ``geometry``, ``core`` and ``tools``, playin
notebooks/neural_dual.ipynb
notebooks/icnn_inits.ipynb
notebooks/wasserstein_barycenters_gmms.ipynb
notebooks/gmm_pair_demo.ipynb

.. toctree::
:maxdepth: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@
},
"outputs": [],
"source": [
"from colabtools import adhoc_import\n",
"import importlib\n",
"\n",
"import ott\n",
"from ott.tools.gaussian_mixture import gaussian_mixture\n",
"from ott.tools.gaussian_mixture import gaussian_mixture_pair\n",
Expand Down
11 changes: 5 additions & 6 deletions ott/tools/gaussian_mixture/fit_gmm_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,17 +259,16 @@ def _m_step_fn(
Returns:
A GaussianMixturePair with updated parameters.
"""
params = (pair,)
state = opt_init(params)
state = opt_init((pair,))

for _ in range(steps):
grad_objective = grad_objective_fn(pair, obs0, obs1)
updates, state = opt_update(grad_objective, state, params)
params = optax.apply_updates(params, updates)
for j, gmm in enumerate((params[0].gmm0, params[0].gmm1)):
updates, state = opt_update(grad_objective, state, (pair,))
(pair,) = optax.apply_updates((pair,), updates)
for j, gmm in enumerate((pair.gmm0, pair.gmm1)):
if gmm.has_nans():
raise ValueError(f'NaN in gmm{j}')
return params[0]
return pair

return _m_step_fn

Expand Down
275 changes: 0 additions & 275 deletions ott/tools/gaussian_mixture/gmm_demo.ipynb

This file was deleted.

0 comments on commit 468c2d5

Please sign in to comment.