Gaussian Mixture Model #137

Merged · 14 commits · merged into elixir-nx:main on Jul 18, 2023
Conversation

krstopro (Member):
Added a Gaussian Mixture Model implementation, similar to the one in the scikit-learn Python library. Currently it supports only full covariances and uses k-means for initialisation. Other covariance types (diagonal, tied, spherical) and initialisation methods (random, k-means++) can easily be added.
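For orientation, a hypothetical usage sketch. The module name is inferred from the file added in this PR (lib/scholar/cluster/gmm.ex), and the option names num_gaussians and num_runs come up later in this conversation, but the exact API should be checked against the merged docs:

```elixir
# Hypothetical usage sketch; module and option names are assumptions.
key = Nx.Random.key(42)
{x, _key} = Nx.Random.normal(key, 0.0, 1.0, shape: {100, 2})

# Fit a 3-component GMM; num_runs repeats EM with fresh initial parameters.
model = Scholar.Cluster.GaussianMixture.fit(x, num_gaussians: 3, num_runs: 5)

# Assign each point to its most likely Gaussian.
Scholar.Cluster.GaussianMixture.predict(model, x)
```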

@msluszniak (Contributor) left a comment:
Thank you for your contribution! The PR LGTM :)

@polvalente (Contributor) left a comment:

Lots of really minor comments. Overall this is looking great!

lib/scholar/cluster/gmm.ex: two resolved review threads (outdated).
Doc excerpt under review in lib/scholar/cluster/gmm.ex (the standard EM updates it refers to are sketched after this thread):

> …expectation of the Gaussian assignment for each data point x, and the M-step, which updates the parameters to maximize the expectations found in the E-step. While every iteration of the algorithm is guaranteed to improve the log-likelihood, the final result depends on the initial values of the parameters, so the entire procedure should be repeated several times.
Contributor:

From reading without looking at the code beforehand:

Does this mean that we need to run the algorithm multiple times, or does the algorithm itself repeat this procedure?

krstopro (Member, Author):

The algorithm itself repeats the procedure num_runs times, each time with different initial parameters, similarly to what is already done in the k-means implementation.

krstopro (Member, Author):

I am not sure if I should change anything here. Thoughts?

Contributor:

For me, the description is clear. @polvalente do you have any suggestions?
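For reference, these are the textbook EM updates for a GMM with full covariances that the excerpt summarizes (standard formulas, not copied from the PR). With responsibilities $\gamma_{ik}$, mixture weights $\pi_k$, means $\mu_k$, and covariances $\Sigma_k$:

$$
\text{E-step:}\quad \gamma_{ik} = \frac{\pi_k\,\mathcal{N}(x_i \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j\,\mathcal{N}(x_i \mid \mu_j, \Sigma_j)}
$$

$$
\text{M-step:}\quad N_k = \sum_i \gamma_{ik}, \qquad
\mu_k = \frac{1}{N_k}\sum_i \gamma_{ik}\, x_i, \qquad
\Sigma_k = \frac{1}{N_k}\sum_i \gamma_{ik}\,(x_i-\mu_k)(x_i-\mu_k)^{\mathsf{T}}, \qquad
\pi_k = \frac{N_k}{n}
$$

The $\Sigma_k$ update is the responsibilities-weighted covariance that comes up again in a review thread near the end of this conversation.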

lib/scholar/cluster/gmm.ex: seven more resolved review threads (three outdated).
krstopro and others added 4 commits July 14, 2023 22:14
krstopro (Member, Author) commented on Jul 15, 2023:

@msluszniak @polvalente Thanks for the support and the feedback! I've addressed most of the suggestions; see the comments above (except for vectorization, which I would leave for now). There are a few more remarks I would like to make; see some of the comments I've written.

msluszniak (Contributor):

Are there any blockers for this PR?

krstopro (Member, Author):

@msluszniak Don't think so. I just need to address a few suggestions that you made. Could you have a look at the two reviews I opened? The logsumexp one might be important.

msluszniak (Contributor):

> @msluszniak Don't think so. I just need to address a few suggestions that you made. Could you have a look at the two reviews I opened? The logsumexp one might be important.

Sure, but tbh I cannot see them 😅

krstopro (Member, Author) commented on Jul 17, 2023:

> Sure, but tbh I cannot see them 😅

@msluszniak They are probably collapsed among many others :)
Tagged you in the comments.

msluszniak (Contributor):

Are you sure that your comments are not in pending mode or in the batch and not pushed or sth? Because I really cannot find anything

krstopro (Member, Author):

> Are you sure that your comments are not in pending mode or in the batch and not pushed or sth? Because I really cannot find anything

Sorry, they were in pending mode and I don't think I am able to submit them. Hence, I'll write them here.
Lines 357-366 include an implementation of logsumexp, which calculates the logarithm of the sum of the exponentials of a tensor in a numerically stable way. This is a well-known trick implemented in numerical libraries; see e.g. SciPy:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
However, I am not aware that such a function exists in Nx. Is there one, or should I raise an issue about implementing it?
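For illustration, a minimal sketch of the trick in plain Nx.Defn (not the PR's code): subtract the per-axis maximum before exponentiating so nothing overflows, then add it back after the log.

```elixir
defmodule Stable do
  import Nx.Defn

  # Numerically stable log(sum(exp(t))) along the given axes:
  # exp(t - max) cannot overflow, and the max is added back at the end.
  defn logsumexp(t, opts \\ []) do
    opts = keyword!(opts, axes: [-1])
    max = Nx.reduce_max(t, axes: opts[:axes], keep_axes: true)
    Nx.log(Nx.sum(Nx.exp(t - max), axes: opts[:axes], keep_axes: true)) + max
  end
end
```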

Line 292: doing {num_gaussians, num_features, num_features} = Nx.shape(covariances) inside defn won't work, even though it works outside. I suppose this is intended?

msluszniak (Contributor):

I'm almost sure that there is no such function implemented in Nx, so feel free to send a PR with it. I'm not sure whether we eventually want this type of function in Nx or in Scholar, but it will be helpful.

That line won't work in defn because you're doing pattern matching. If you want to assert that the shape is as expected, you need to check for it separately.

polvalente (Contributor):

You can use case to check the validity of the shape inside defn.
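A hedged sketch of that pattern (the module, helper name, and expected shape are made up for illustration). Shapes are known when the defn is compiled, so case can branch on them, and the fallback clause raises at compile time:

```elixir
defmodule ShapeCheck do
  import Nx.Defn

  # `{g, d, d} = Nx.shape(covariances)` would not compile inside defn,
  # but `case` over the compile-time shape works.
  defn validate_covariances(covariances) do
    case Nx.shape(covariances) do
      # the repeated `d` asserts the last two dimensions are equal
      {_num_gaussians, d, d} -> covariances
      _ -> raise ArgumentError, "expected covariances of shape {num_gaussians, d, d}"
    end
  end
end
```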

krstopro (Member, Author):

> You can use case to check the validity of the shape inside defn.

I see, thanks. It's not an issue here; I just wanted to check that this is intended.
Anyway, if everything else is fine, can you approve the PR?

Review thread on a code excerpt from lib/scholar/cluster/gmm.ex (diff context, truncated):

```elixir
  },
  k < num_gaussians do
    diff = x - means[k]
    covariance = Nx.dot(responsibilities[[.., k]] * Nx.transpose(diff), diff) / nk[k]
```
Contributor:

Suggested change:

```diff
-covariance = Nx.dot(responsibilities[[.., k]] * Nx.transpose(diff), diff) / nk[k]
+covariance = Nx.dot(responsibilities[[.., k]] * diff, [-2], diff / nk[k], [-2])
```

krstopro (Member, Author):

This won't work, as Nx.transpose is applied only to diff, not to the entire product. Adding another axis to responsibilities[[.., k]] solves the problem:

```elixir
covariance = Nx.dot(diff * Nx.new_axis(responsibilities[[.., k]], 1), [0], diff / nk[k], [0])
```
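For context, a small standalone check (with made-up shapes) of what the contraction axes do here: Nx.dot(a, [0], b, [0]) contracts both operands over axis 0, matching Nx.dot(Nx.transpose(a), b) without materializing the transpose.

```elixir
a = Nx.iota({4, 3}, type: :f32)
b = Nx.iota({4, 2}, type: :f32)

left = Nx.dot(a, [0], b, [0])      # contract both over axis 0 -> shape {3, 2}
right = Nx.dot(Nx.transpose(a), b) # same product via an explicit transpose

Nx.all(Nx.equal(left, right))      # => #Nx.Tensor<u8 1>
```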

josevalim (Contributor):

@krstopro please verify @polvalente's suggestions above and we can ship it :)

@msluszniak added the labels "enhancement" (New feature or request) and "good first issue" (Good for newcomers) on Jul 17, 2023
@josevalim merged commit 5c51869 into elixir-nx:main on Jul 18, 2023
josevalim (Contributor):

💚 💙 💜 💛 ❤️

@krstopro changed the title from "Adding Gaussian Mixture Model" to "Gaussian Mixture Model" on Apr 21, 2024