Adds a better method for accessing the centroids of a KMeansModel (#98)
* Adding back a constructor to KMeansTrainer to preserve backwards compatibility with 4.0 releases.

* Adding a new method to get the centroids out of KMeansModel, and updating the tutorial to use it and K-Means++.

* Updating the javadoc in KMeansModel.getCentroidVectors to point users at the new method.

* KMeansTrainer shouldn't have a mandatory initialisation parameter as it breaks backwards compatibility with provenance from 4.0.
Craigacp authored Nov 30, 2020
1 parent f262f2a commit 931a57d
Showing 3 changed files with 182 additions and 31 deletions.
@@ -19,6 +19,7 @@
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
@@ -27,8 +28,10 @@
import org.tribuo.clustering.kmeans.KMeansTrainer.Distance;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -64,6 +67,13 @@ public class KMeansModel extends Model<ClusterID> {

    /**
     * Returns a copy of the centroids.
     * <p>
     * In most cases you should prefer {@link #getCentroids} as
     * it performs the mapping from Tribuo's internal feature ids
     * to the externally visible feature names for you.
     * This method provides direct access to the centroid vectors
     * for use in downstream processing if the ids are not relevant
     * (or are known to match).
     * @return The centroids.
     */
    public DenseVector[] getCentroidVectors() {
@@ -76,6 +86,32 @@ public DenseVector[] getCentroidVectors() {
        return copies;
    }

    /**
     * Returns a list of features, one per centroid.
     * <p>
     * This should be used in preference to {@link #getCentroidVectors()}
     * as it performs the mapping from Tribuo's internal feature ids to
     * the externally visible feature names.
     * </p>
     * @return A list containing all the centroids.
     */
    public List<List<Feature>> getCentroids() {
        List<List<Feature>> output = new ArrayList<>(centroidVectors.length);

        for (int i = 0; i < centroidVectors.length; i++) {
            List<Feature> features = new ArrayList<>(featureIDMap.size());

            for (VectorTuple v : centroidVectors[i]) {
                Feature f = new Feature(featureIDMap.get(v.index).getName(),v.value);
                features.add(f);
            }

            output.add(features);
        }

        return output;
    }

    @Override
    public Prediction<ClusterID> predict(Example<ClusterID> example) {
        SparseVector vector = SparseVector.createSparseVector(example,featureIDMap,false);
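The id-to-name mapping that `getCentroids` performs in the hunk above can be illustrated without Tribuo on the classpath. This is a minimal sketch in plain Java, not Tribuo code: `CentroidNameSketch`, `nameCentroids`, and the `idToName` array are hypothetical stand-ins for `KMeansModel`, `getCentroids`, and `featureIDMap`, and a `Map<String, Double>` stands in for `List<Feature>`.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for the name mapping; none of these names exist in Tribuo.
final class CentroidNameSketch {

    // Pairs each value of a dense centroid with the feature name registered for its index.
    static List<Map<String, Double>> nameCentroids(double[][] centroids, String[] idToName) {
        List<Map<String, Double>> output = new ArrayList<>(centroids.length);
        for (double[] centroid : centroids) {
            // LinkedHashMap keeps the features in index order when printed.
            Map<String, Double> named = new LinkedHashMap<>();
            for (int i = 0; i < centroid.length; i++) {
                named.put(idToName[i], centroid[i]);
            }
            output.add(named);
        }
        return output;
    }

    public static void main(String[] args) {
        String[] idToName = {"A", "B"};                        // assumed id -> name table
        double[][] centroids = {{-1.73, -0.02}, {2.74, 2.87}};  // illustrative values only
        System.out.println(nameCentroids(centroids, idToName));
    }
}
```

The point of returning named values is that callers no longer need to know how internal ids were assigned, which is the motivation the javadoc gives for preferring `getCentroids` over `getCentroidVectors`.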
@@ -122,8 +122,8 @@ public enum Initialisation {
    @Config(mandatory = true, description = "The distance function to use.")
    private Distance distanceType;

    @Config(mandatory = true, description = "The centroid initialisation method to use.")
    private Initialisation initialisationType;
    @Config(description = "The centroid initialisation method to use.")
    private Initialisation initialisationType = Initialisation.RANDOM;

    @Config(description = "The number of threads to use for training.")
    private int numThreads = 1;
@@ -141,6 +141,19 @@ public enum Initialisation {
    private KMeansTrainer() {
    }

    /**
     * Constructs a K-Means trainer using the supplied parameters and the default random initialisation.
     *
     * @param centroids The number of centroids to use.
     * @param iterations The maximum number of iterations.
     * @param distanceType The distance function.
     * @param numThreads The number of threads.
     * @param seed The random seed.
     */
    public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, long seed) {
        this(centroids,iterations,distanceType,Initialisation.RANDOM,numThreads,seed);
    }

    /**
     * Constructs a K-Means trainer using the supplied parameters.
     *
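The restored constructor above follows a common pattern for keeping an older signature alive: the shorter overload delegates to the fuller one and supplies the default. This is a reduced sketch of that pattern, assuming a hypothetical `TrainerSketch` class with fewer parameters than the real `KMeansTrainer`; it is not Tribuo code.

```java
// Hypothetical, reduced illustration of a backwards-compatible constructor overload.
final class TrainerSketch {
    enum Initialisation { RANDOM, PLUSPLUS }

    private final int centroids;
    private final Initialisation initialisation;
    private final long seed;

    // Older, shorter signature: kept so existing callers still compile,
    // and delegating to the full constructor with the default initialisation.
    TrainerSketch(int centroids, long seed) {
        this(centroids, Initialisation.RANDOM, seed);
    }

    // Newer signature that lets the caller choose the initialisation.
    TrainerSketch(int centroids, Initialisation initialisation, long seed) {
        this.centroids = centroids;
        this.initialisation = initialisation;
        this.seed = seed;
    }

    Initialisation initialisation() {
        return initialisation;
    }

    public static void main(String[] args) {
        System.out.println(new TrainerSketch(5, 1L).initialisation());
        System.out.println(new TrainerSketch(5, Initialisation.PLUSPLUS, 1L).initialisation());
    }
}
```

Giving the field a default value and dropping `mandatory = true`, as the `@Config` change above does, serves the same goal for configurations and provenance written before the initialisation option existed.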
160 changes: 131 additions & 29 deletions tutorials/clustering-tribuo-v4.ipynb
@@ -34,7 +34,8 @@
"import org.tribuo.clustering.evaluation.*;\n",
"import org.tribuo.clustering.example.ClusteringDataGenerator;\n",
"import org.tribuo.clustering.kmeans.*;\n",
"import org.tribuo.clustering.kmeans.KMeansTrainer.Distance;"
"import org.tribuo.clustering.kmeans.KMeansTrainer.Distance;\n",
"import org.tribuo.clustering.kmeans.KMeansTrainer.Initialisation;"
]
},
{
@@ -97,7 +98,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Training with 5 clusters took (00:00:00:062)\n"
"Training with 5 clusters took (00:00:00:049)\n"
]
}
],
@@ -125,16 +126,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"DenseVector(size=2,values=[-1.7294066290817505,-0.019280856227650595])\n",
"DenseVector(size=2,values=[2.740410056407627,2.8737688541143247])\n",
"DenseVector(size=2,values=[0.05102068424764918,0.0757660102333321])\n",
"DenseVector(size=2,values=[5.174977643580621,5.088149544081452])\n",
"DenseVector(size=2,values=[9.938804461039872,-0.020702060844743055])\n"
"[(A, -1.729407), (B, -0.019281)]\n",
"[(A, 2.740410), (B, 2.873769)]\n",
"[(A, 0.051021), (B, 0.075766)]\n",
"[(A, 5.174978), (B, 5.088150)]\n",
"[(A, 9.938804), (B, -0.020702)]\n"
]
}
],
"source": [
"var centroids = model.getCentroidVectors();\n",
"var centroids = model.getCentroids();\n",
"for (var centroid : centroids) {\n",
" System.out.println(centroid);\n",
"}"
@@ -154,7 +155,76 @@
"|4|2|\n",
"|5|4|\n",
"\n",
"Though the first one is a bit far out as it's x_1 should be -1.0 not -1.7, and there is a little wobble in the rest. Still it's pretty good considering K-Means assumes spherical gaussians and our data generator has a covariance matrix per gaussian."
"Though the first one is a bit far out as its \"A\" feature should be -1.0 not -1.7, and there is a little wobble in the rest. Still, it's pretty good considering K-Means assumes spherical Gaussians and our data generator has a covariance matrix per Gaussian."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## K-Means++\n",
"Tribuo also includes the K-Means++ initialisation algorithm, which we can run on our toy problem as follows:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training with 5 clusters took (00:00:00:042)\n"
]
}
],
"source": [
"var plusplusTrainer = new KMeansTrainer(5,10,Distance.EUCLIDEAN,Initialisation.PLUSPLUS,1,1);\n",
"var startTime = System.currentTimeMillis();\n",
"var plusplusModel = plusplusTrainer.train(data);\n",
"var endTime = System.currentTimeMillis();\n",
"System.out.println(\"Training with 5 clusters took \" + Util.formatDuration(startTime,endTime));"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training time isn't much different in this case, but the K-Means++ initialisation does take longer than the default on larger datasets. However the resulting clusters are usually better.\n",
"\n",
"We can check the centroids from this model using the same method as before."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(A, -1.567863), (B, -0.029534)]\n",
"[(A, 9.938804), (B, -0.020702)]\n",
"[(A, 3.876203), (B, 3.930657)]\n",
"[(A, 0.399868), (B, 0.330537)]\n",
"[(A, 5.520480), (B, 5.390406)]\n"
]
}
],
"source": [
"var ppCentroids = plusplusModel.getCentroids();\n",
"for (var centroid : ppCentroids) {\n",
" System.out.println(centroid);\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see in this case that the K-Means++ initialisation has warped the centroids slightly, so the fit isn't quite as nice as the default initialisation, but that's why we have evaluation data and measure model fit."
]
},
{
@@ -170,7 +240,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -181,7 +251,7 @@
"Adjusted MI = 0.8113314999600718"
]
},
"execution_count": 7,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -200,18 +270,18 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Clustering Evaluation\n",
"Normalized MI = 0.8154291916732408\n",
"Adjusted MI = 0.8139169342020222"
"Normalized MI = 0.8154291916732409\n",
"Adjusted MI = 0.8139169342020223"
]
},
"execution_count": 8,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -225,7 +295,39 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that as expected it's a pretty good correlation to the ground truth labels. K-Means (of the kind implemented in Tribuo) is similar to a gaussian mixture using spherical gaussians, and our data generator uses gaussians with full rank covariances, so it won't be perfect."
"We see that as expected it's a pretty good correlation to the ground truth labels. K-Means (of the kind implemented in Tribuo) is similar to a gaussian mixture using spherical gaussians, and our data generator uses gaussians with full rank covariances, so it won't be perfect.\n",
"\n",
"We can also check the K-Means++ model in the same way:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Clustering Evaluation\n",
"Normalized MI = 0.7881995472105396\n",
"Adjusted MI = 0.7864797287891366"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"var testPlusPlusEvaluation = eval.evaluate(plusplusModel,test);\n",
"testPlusPlusEvaluation.toString();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As expected given the slightly poorer quality centroids this initialisation produced, the fit isn't quite as good. However, we emphasise that K-Means++ usually improves the quality of the clustering, so it's worth testing out if you're clustering data with Tribuo."
]
},
{
@@ -238,14 +340,14 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training with 5 clusters on 4 threads took (00:00:00:073)\n"
"Training with 5 clusters on 4 threads took (00:00:00:038)\n"
]
}
],
Expand All @@ -267,14 +369,14 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training with 20 clusters on 4 threads took (00:00:00:059)\n"
"Training with 20 clusters on 4 threads took (00:00:00:038)\n"
]
}
],
Expand All @@ -295,18 +397,18 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Clustering Evaluation\n",
"Normalized MI = 0.8104463467727057\n",
"Adjusted MI = 0.8088941747451207"
"Normalized MI = 0.8104463467727059\n",
"Adjusted MI = 0.8088941747451209"
]
},
"execution_count": 11,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -325,7 +427,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -336,7 +438,7 @@
"Adjusted MI = 0.860327445295668"
]
},
"execution_count": 12,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -358,7 +460,7 @@
"metadata": {},
"source": [
"## Conclusion\n",
"We looked at clustering using Tribuo's K-Means implementation, comparing both the single-threaded and multi-threaded versions, then looked at the performance metrics available when there are ground truth clusterings.\n",
"We looked at clustering using Tribuo's K-Means implementation, experimented with different initialisations, and compared both the single-threaded and multi-threaded versions. Then we looked at the performance metrics available when there are ground truth clusterings.\n",
"\n",
"We plan to further expand Tribuo's clustering functionality to incorporate other algorithms in the future. If you want to help, or have specific algorithmic requirements, file an issue on our [github page](https://github.com/oracle/tribuo)."
]
Expand All @@ -376,9 +478,9 @@
"mimetype": "text/x-java-source",
"name": "Java",
"pygments_lexer": "java",
"version": "14+36-1461"
"version": "16+10-20201111"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}
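
The tutorial changes above use K-Means++ initialisation but do not show how it picks starting centroids. As a rough illustration, here is a sketch of the textbook K-Means++ seeding rule in plain Java: the first centroid is drawn uniformly at random, and each later one is drawn with probability proportional to its squared distance from the nearest centroid chosen so far. This is a generic version of the algorithm and is not taken from Tribuo's implementation; `KMeansPlusPlusSketch`, `initialise`, and `squaredDistance` are hypothetical names.

```java
import java.util.Arrays;
import java.util.Random;

// Generic K-Means++ seeding sketch. NOT Tribuo's implementation; all names are hypothetical.
final class KMeansPlusPlusSketch {

    // Squared Euclidean distance between two points of equal length.
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    // Picks k initial centroids from data using distance-weighted sampling.
    static double[][] initialise(double[][] data, int k, Random rng) {
        int n = data.length;
        double[][] centroids = new double[k][];
        centroids[0] = data[rng.nextInt(n)].clone();

        // minDist[i] holds the squared distance from point i to its nearest chosen centroid.
        double[] minDist = new double[n];
        Arrays.fill(minDist, Double.POSITIVE_INFINITY);

        for (int c = 1; c < k; c++) {
            double total = 0.0;
            for (int i = 0; i < n; i++) {
                minDist[i] = Math.min(minDist[i], squaredDistance(data[i], centroids[c - 1]));
                total += minDist[i];
            }
            // Sample an index with probability minDist[i] / total.
            double threshold = rng.nextDouble() * total;
            int chosen = n - 1; // fallback if the running sum never exceeds the threshold
            double cumulative = 0.0;
            for (int i = 0; i < n; i++) {
                cumulative += minDist[i];
                if (cumulative > threshold) {
                    chosen = i;
                    break;
                }
            }
            centroids[c] = data[chosen].clone();
        }
        return centroids;
    }

    public static void main(String[] args) {
        double[][] data = {{0.0, 0.0}, {0.1, 0.0}, {10.0, 10.0}, {10.1, 9.9}};
        double[][] seeds = initialise(data, 2, new Random(1L));
        System.out.println(Arrays.deepToString(seeds));
    }
}
```

Because points that coincide with an existing centroid have zero weight, the seeds tend to spread across the data. The extra distance computations in this step are consistent with the tutorial's remark that K-Means++ initialisation takes longer than the default on larger datasets.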