From 420418edf0d266dfeda9f8ff514a94638421000d Mon Sep 17 00:00:00 2001 From: gileshd Date: Fri, 19 Jul 2024 22:04:48 +0100 Subject: [PATCH 1/3] Add utility function for sklearn kmeans --- dynamax/utils/cluster.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 dynamax/utils/cluster.py diff --git a/dynamax/utils/cluster.py b/dynamax/utils/cluster.py new file mode 100644 index 00000000..a31c9e05 --- /dev/null +++ b/dynamax/utils/cluster.py @@ -0,0 +1,28 @@ +from typing import Tuple + +from jax import numpy as jnp +from jax import random as jr + +from jaxtyping import Array, Float + + +def kmeans_sklearn( + k: int, X: Float[Array, "num_samples state_dim"], key: Array +) -> Tuple[Float[Array, "num_states state_dim"], Float[Array, "num_samples"]]: + """ + Compute the cluster centers and assignments using the sklearn K-means algorithm. + + Args: + k (int): The number of clusters. + X (Array(N, D)): The input data array. N samples of dimension D. + key (Array): The random seed array. + + Returns: + Array(k, D), Array(N,): The cluster centers and labels + """ + from sklearn.cluster import KMeans + + key, subkey = jr.split(key) # Create a random seed for SKLearn. + sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. + km = KMeans(k, random_state=int(sklearn_key)).fit(X) + return jnp.array(km.cluster_centers_), jnp.array(km.labels_) From 828f1d144e3631cbc7d9e790e11e9c83863e4242 Mon Sep 17 00:00:00 2001 From: gileshd Date: Fri, 19 Jul 2024 22:00:51 +0100 Subject: [PATCH 2/3] Update SSMs to use kmeans utility function --- dynamax/hidden_markov_model/models/arhmm.py | 7 ++-- .../hidden_markov_model/models/gamma_hmm.py | 9 ++---- .../models/gaussian_hmm.py | 32 ++++--------------- dynamax/hidden_markov_model/models/gmm_hmm.py | 15 +++------ .../hidden_markov_model/models/linreg_hmm.py | 7 ++-- .../hidden_markov_model/models/logreg_hmm.py | 16 ++++++---- 6 files changed, 28 insertions(+), 58 deletions(-) diff --git a/dynamax/hidden_markov_model/models/arhmm.py b/dynamax/hidden_markov_model/models/arhmm.py index 2eff832b..07c1ee75 100644 --- a/dynamax/hidden_markov_model/models/arhmm.py +++ b/dynamax/hidden_markov_model/models/arhmm.py @@ -10,6 +10,7 @@ from dynamax.parameters import ParameterProperties from dynamax.types import Scalar from dynamax.utils.bijectors import RealToPSDBijector +from dynamax.utils.cluster import kmeans_sklearn from tensorflow_probability.substrates import jax as tfp from typing import NamedTuple, Optional, Tuple, Union @@ -42,12 +43,8 @@ def initialize(self, emissions=None): if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" - from sklearn.cluster import KMeans - key, subkey = jr.split(key) # Create a random seed for SKLearn. - sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. - km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim)) _emission_weights = jnp.zeros((self.num_states, self.emission_dim, self.emission_dim * self.num_lags)) - _emission_biases = jnp.array(km.cluster_centers_) + _emission_biases, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key) _emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1)) elif method.lower() == "prior": diff --git a/dynamax/hidden_markov_model/models/gamma_hmm.py b/dynamax/hidden_markov_model/models/gamma_hmm.py index 2efcdd86..f44538ea 100644 --- a/dynamax/hidden_markov_model/models/gamma_hmm.py +++ b/dynamax/hidden_markov_model/models/gamma_hmm.py @@ -8,6 +8,7 @@ from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions from dynamax.types import Scalar +from dynamax.utils.cluster import kmeans_sklearn import optax from typing import NamedTuple, Optional, Tuple, Union @@ -38,13 +39,9 @@ def initialize(self, if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" - from sklearn.cluster import KMeans - key, subkey = jr.split(key) # Create a random seed for SKLearn. - sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. - km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, 1)) - + cluster_centers, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, 1), key) _emission_concentrations = jnp.ones((self.num_states,)) - _emission_rates = jnp.ravel(1.0 / km.cluster_centers_) + _emission_rates = jnp.ravel(1.0 / cluster_centers) elif method.lower() == "prior": _emission_concentrations = jnp.ones((self.num_states,)) diff --git a/dynamax/hidden_markov_model/models/gaussian_hmm.py b/dynamax/hidden_markov_model/models/gaussian_hmm.py index c1904878..fe1e8dc3 100644 --- a/dynamax/hidden_markov_model/models/gaussian_hmm.py +++ b/dynamax/hidden_markov_model/models/gaussian_hmm.py @@ -17,6 +17,7 @@ from dynamax.utils.distributions import niw_posterior_update from dynamax.utils.bijectors import RealToPSDBijector from dynamax.utils.utils import pytree_sum +from dynamax.utils.cluster import kmeans_sklearn from typing import NamedTuple, Optional, Tuple, Union @@ -70,12 +71,7 @@ def initialize(self, key=jr.PRNGKey(0), emissions=None): if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" - from sklearn.cluster import KMeans - key, subkey = jr.split(key) # Create a random seed for SKLearn. - sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. - km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim)) - - _emission_means = jnp.array(km.cluster_centers_) + _emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key) _emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1)) elif method.lower() == "prior": @@ -168,11 +164,7 @@ def initialize(self, key=jr.PRNGKey(0), if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" - from sklearn.cluster import KMeans - key, subkey = jr.split(key) # Create a random seed for SKLearn. - sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. - km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim)) - _emission_means = jnp.array(km.cluster_centers_) + _emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key) _emission_scale_diags = jnp.ones((self.num_states, self.emission_dim)) elif method.lower() == "prior": @@ -289,11 +281,7 @@ def initialize(self, key=jr.PRNGKey(0), """ if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" - from sklearn.cluster import KMeans - key, subkey = jr.split(key) # Create a random seed for SKLearn. - sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. - km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim)) - _emission_means = jnp.array(km.cluster_centers_) + _emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key) _emission_scales = jnp.ones((self.num_states,)) elif method.lower() == "prior": @@ -391,11 +379,7 @@ def initialize(self, key=jr.PRNGKey(0), """ if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" - from sklearn.cluster import KMeans - key, subkey = jr.split(key) # Create a random seed for SKLearn. - sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. - km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim)) - _emission_means = jnp.array(km.cluster_centers_) + _emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key) _emission_cov = jnp.eye(self.emission_dim) elif method.lower() == "prior": @@ -513,11 +497,7 @@ def initialize(self, key=jr.PRNGKey(0), """ if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" - from sklearn.cluster import KMeans - key, subkey = jr.split(key) # Create a random seed for SKLearn. - sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. - km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim)) - _emission_means = jnp.array(km.cluster_centers_) + _emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key) _emission_cov_diag_factors = jnp.ones((self.num_states, self.emission_dim)) _emission_cov_low_rank_factors = jnp.zeros((self.num_states, self.emission_dim, self.emission_rank)) diff --git a/dynamax/hidden_markov_model/models/gmm_hmm.py b/dynamax/hidden_markov_model/models/gmm_hmm.py index 8b7e778c..e6f55d84 100644 --- a/dynamax/hidden_markov_model/models/gmm_hmm.py +++ b/dynamax/hidden_markov_model/models/gmm_hmm.py @@ -15,6 +15,7 @@ from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions from dynamax.utils.bijectors import RealToPSDBijector from dynamax.utils.utils import pytree_sum +from dynamax.utils.cluster import kmeans_sklearn from dynamax.types import Scalar from typing import NamedTuple, Optional, Tuple, Union @@ -77,12 +78,9 @@ def initialize(self, key=jr.PRNGKey(0), emissions=None): if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" - from sklearn.cluster import KMeans - key, subkey = jr.split(key) # Create a random seed for SKLearn. - sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. - km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim)) + cluster_centers, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key) _emission_weights = jnp.ones((self.num_states, self.num_components)) / self.num_components - _emission_means = jnp.tile(jnp.array(km.cluster_centers_)[:, None, :], (1, self.num_components, 1)) + _emission_means = jnp.tile(jnp.array(cluster_centers)[:, None, :], (1, self.num_components, 1)) _emission_covs = jnp.tile(jnp.eye(self.emission_dim), (self.num_states, self.num_components, 1, 1)) elif method.lower() == "prior": @@ -299,12 +297,9 @@ def initialize(self, key=jr.PRNGKey(0), emissions=None): if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" - from sklearn.cluster import KMeans - key, subkey = jr.split(key) # Create a random seed for SKLearn. - sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. - km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim)) + cluster_centers, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key) _emission_weights = jnp.ones((self.num_states, self.num_components)) / self.num_components - _emission_means = jnp.tile(jnp.array(km.cluster_centers_)[:, None, :], (1, self.num_components, 1)) + _emission_means = jnp.tile(jnp.array(cluster_centers)[:, None, :], (1, self.num_components, 1)) _emission_scale_diags = jnp.ones((self.num_states, self.num_components, self.emission_dim)) elif method.lower() == "prior": diff --git a/dynamax/hidden_markov_model/models/linreg_hmm.py b/dynamax/hidden_markov_model/models/linreg_hmm.py index df63c1bd..26947f61 100644 --- a/dynamax/hidden_markov_model/models/linreg_hmm.py +++ b/dynamax/hidden_markov_model/models/linreg_hmm.py @@ -9,6 +9,7 @@ from dynamax.types import Scalar from dynamax.utils.utils import pytree_sum from dynamax.utils.bijectors import RealToPSDBijector +from dynamax.utils.cluster import kmeans_sklearn from tensorflow_probability.substrates import jax as tfp from typing import NamedTuple, Optional, Tuple, Union @@ -58,12 +59,8 @@ def initialize(self, emissions=None): if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" - from sklearn.cluster import KMeans - key, subkey = jr.split(key) # Create a random seed for SKLearn. - sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. - km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim)) _emission_weights = jnp.zeros((self.num_states, self.emission_dim, self.input_dim)) - _emission_biases = jnp.array(km.cluster_centers_) + _emission_biases, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key) _emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1)) elif method.lower() == "prior": diff --git a/dynamax/hidden_markov_model/models/logreg_hmm.py b/dynamax/hidden_markov_model/models/logreg_hmm.py index 2da4dd84..b9957a09 100644 --- a/dynamax/hidden_markov_model/models/logreg_hmm.py +++ b/dynamax/hidden_markov_model/models/logreg_hmm.py @@ -8,6 +8,7 @@ from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions from dynamax.types import Scalar +from dynamax.utils.cluster import kmeans_sklearn import optax from typing import NamedTuple, Optional, Tuple, Union @@ -48,16 +49,19 @@ def initialize(self, if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" assert inputs is not None, "Need inputs to initialize the model with K-Means!" - from sklearn.cluster import KMeans flat_emissions = emissions.reshape(-1,) flat_inputs = inputs.reshape(-1, self.input_dim) - key, subkey = jr.split(key) # Create a random seed for SKLearn. - sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. - km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(flat_inputs) + + _, km_labels = kmeans_sklearn(self.num_states, flat_inputs, key) _emission_weights = jnp.zeros((self.num_states, self.input_dim)) - _emission_biases = jnp.array([tfb.Sigmoid().inverse(flat_emissions[km.labels_ == k].mean()) - for k in range(self.num_states)]) + cluster_emissions_means = jnp.array( + [jnp.mean(flat_emissions, where=km_labels == k) for k in range(self.num_states)] + ) + cluster_emissions_means = jnp.where( + jnp.isnan(cluster_emissions_means), flat_emissions.mean(), cluster_emissions_means + ) + _emission_biases = tfb.Sigmoid().inverse(cluster_emissions_means) elif method.lower() == "prior": # TODO: Use an MNIW prior From 0f8646b45e67e89212f974ba5d2c9901fb887610 Mon Sep 17 00:00:00 2001 From: gileshd Date: Fri, 19 Jul 2024 22:14:02 +0100 Subject: [PATCH 3/3] Add jax implementation of kmeans --- dynamax/utils/cluster.py | 82 +++++++++++++++++++++++++++++++++-- dynamax/utils/cluster_test.py | 50 +++++++++++++++++++++ 2 files changed, 128 insertions(+), 4 deletions(-) create mode 100644 dynamax/utils/cluster_test.py diff --git a/dynamax/utils/cluster.py b/dynamax/utils/cluster.py index a31c9e05..cda9fb82 100644 --- a/dynamax/utils/cluster.py +++ b/dynamax/utils/cluster.py @@ -1,9 +1,9 @@ -from typing import Tuple - +from functools import partial +from jax import lax, jit from jax import numpy as jnp from jax import random as jr - -from jaxtyping import Array, Float +from jaxtyping import Array, Int, Float +from typing import NamedTuple, Tuple def kmeans_sklearn( @@ -26,3 +26,77 @@ def kmeans_sklearn( sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. km = KMeans(k, random_state=int(sklearn_key)).fit(X) return jnp.array(km.cluster_centers_), jnp.array(km.labels_) + + +class KMeansState(NamedTuple): + centroids: Float[Array, "num_states state_dim"] + assignments: Int[Array, "num_samples"] + prev_centroids: Float[Array, "num_states state_dim"] + itr: int + + +@partial(jit, static_argnums=(1, 3)) +def kmeans_jax( + X: Float[Array, "num_samples state_dim"], + k: int, + key: Array = jr.PRNGKey(0), + max_iters: int = 1000, +) -> KMeansState: + """ + Perform k-means clustering using JAX. + + K-means++ initialization is used to initialize the centroids. + + Args: + X (Array): The input data array of shape (n_samples, n_features). + k (int): The number of clusters. + max_iters (int, optional): The maximum number of iterations. Defaults to 1000. + key (PRNGKey, optional): The random key for initialization. Defaults to jr.PRNGKey(0). + + Returns: + KMeansState: A named tuple containing the final centroids array of shape (k, n_features), + the assignments array of shape (n_samples,) indicating the cluster index for each sample, + the previous centroids array of shape (k, n_features), and the number of iterations. + """ + + def _update_centroids(X: Array, assignments: Array): + new_centroids = jnp.array([jnp.mean(X, axis=0, where=(assignments == i)[:, None]) for i in range(k)]) + return new_centroids + + def _update_assignments(X, centroids): + return jnp.argmin(jnp.linalg.norm(X[:, None] - centroids, axis=2), axis=1) + + def body(carry: KMeansState): + centroids, assignments, *_ = carry + new_centroids = _update_centroids(X, assignments) + new_assignments = _update_assignments(X, new_centroids) + return KMeansState(new_centroids, new_assignments, centroids, carry.itr + 1) + + def cond(carry: KMeansState): + return jnp.any(carry.centroids != carry.prev_centroids) & (carry.itr < max_iters) + + def init(key): + """kmeans++ initialization of centroids + + Iteratively sample new centroids with probability proportional to the squared distance + from the closest centroid. This initialization method is more stable than random + initialization and leads to faster convergence. + Ref: Arthur, D., & Vassilvitskii, S. (2006). + """ + centroids = jnp.zeros((k, X.shape[1])) + centroids = centroids.at[0, :].set(jr.choice(key, X)) + for i in range(1, k): + squared_diffs = jnp.sum((X[:, None, :] - centroids[None, :i, :]) ** 2, axis=2) + min_squared_dists = jnp.min(squared_diffs, axis=1) + probs = min_squared_dists / jnp.sum(min_squared_dists) + centroids = centroids.at[i, :].set(jr.choice(key, X, p=probs)) + assignments = _update_assignments(X, centroids) + # Perform one iteration to update centroids + updated_centroids = _update_centroids(X, assignments) + updated_assignments = _update_assignments(X, updated_centroids) + return KMeansState(updated_centroids, updated_assignments, centroids, 1) + + init_state = init(key) + state = lax.while_loop(cond, body, init_state) + + return state diff --git a/dynamax/utils/cluster_test.py b/dynamax/utils/cluster_test.py new file mode 100644 index 00000000..414120b3 --- /dev/null +++ b/dynamax/utils/cluster_test.py @@ -0,0 +1,50 @@ +from jax import numpy as jnp +from jax import random as jr +from jax import vmap + +from dynamax.utils.cluster import kmeans_jax + + +def test_kmeans_jax_toy(): + """Checks that kmeans works against toy example. + + Ref: scikit-learn tests + """ + + key = jr.PRNGKey(101) + x = jnp.array([[0, 0], [0.5, 0], [0.5, 1], [1, 1]]) + + centroids, assignments, *_ = kmeans_jax(x, 2, key) + + # There are two possible solutions for the centroids and assignments + try: + expected_labels = jnp.array([0, 0, 1, 1]) + expected_centers = jnp.array([[0.25, 0], [0.75, 1]]) + assert jnp.all(assignments == expected_labels) + assert jnp.allclose(centroids, expected_centers) + except AssertionError: + expected_labels = jnp.array([1, 1, 0, 0]) + expected_centers = jnp.array([[0.75, 1.0], [0.25, 0.0]]) + assert jnp.all(assignments == expected_labels) + assert jnp.allclose(centroids, expected_centers) + + +def test_kmeans_jax_vmap(): + """Test that kmeans_jax works with vmap.""" + + def _gen_data(key): + """Generate 3 clusters of 10 samples each.""" + subkeys = jr.split(key, 3) + means = jnp.array([-2., 0., 2.]) + _2D_normal = lambda key, mean: jr.normal(key, (10, 2))*0.2 + mean + return vmap(_2D_normal)(subkeys, means).reshape(-1, 2) + + key = jr.PRNGKey(5) + key, *data_subkeys = jr.split(key,3) + # Generate 2 samples of the 3-cluster data + x = vmap(_gen_data)(jnp.array(data_subkeys)) + + alg_subkeys = jr.split(key, 2) + _, assignments, *_ = vmap(kmeans_jax, (0, None, 0))(x, 3, alg_subkeys) + # Check that the assignments are the same for both samples (clusters are very distinct) + assert jnp.all(assignments[0] == assignments[1])