From 9f40db7b5f22266e81dc47f45092fe7ac84147de Mon Sep 17 00:00:00 2001 From: Samantha Campo Date: Tue, 25 Aug 2020 16:07:16 -0400 Subject: [PATCH 1/7] Initial attempt --- .../clustering/kmeans/KMeansOptions.java | 7 +- .../clustering/kmeans/KMeansTrainer.java | 140 ++++++++++++++---- .../tribuo/clustering/kmeans/TrainTest.java | 7 +- .../tribuo/clustering/kmeans/TestKMeans.java | 6 +- 4 files changed, 131 insertions(+), 29 deletions(-) diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java index 3646165f1..cfe38066f 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java @@ -20,6 +20,7 @@ import com.oracle.labs.mlrg.olcut.config.Options; import org.tribuo.Trainer; import org.tribuo.clustering.kmeans.KMeansTrainer.Distance; +import org.tribuo.clustering.kmeans.KMeansTrainer.Initialisation; import java.util.logging.Logger; @@ -35,6 +36,9 @@ public class KMeansOptions implements Options { public int centroids = 10; @Option(longName="kmeans-distance",usage="Distance function in K-Means. Defaults to EUCLIDEAN.") public Distance distance = Distance.EUCLIDEAN; + @Option(longName="kmeans-initialisation",usage="Initialisation function " + + "in K-Means. Defaults to UNIFORM.") + public Initialisation initialisation = Initialisation.UNIFORM; @Option(longName="kmeans-num-threads",usage="Number of computation threads in K-Means. Defaults to 4.") public int numThreads = 4; @Option(longName="kmeans-seed", usage = "Sets the random seed for K-Means.") @@ -43,6 +47,7 @@ public class KMeansOptions implements Options { public KMeansTrainer getTrainer() { logger.info("Configuring K-Means Trainer"); //public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, int seed) { - return new KMeansTrainer(centroids,iterations,distance,numThreads,seed); + return new KMeansTrainer(centroids,iterations,distance,initialisation + ,numThreads,seed); } } diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java index a59793bce..5f1bce68a 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java @@ -27,20 +27,15 @@ import org.tribuo.clustering.ClusterID; import org.tribuo.clustering.ImmutableClusteringInfo; import org.tribuo.math.la.DenseVector; +import org.tribuo.math.la.SGDVector; import org.tribuo.math.la.SparseVector; import org.tribuo.provenance.ModelProvenance; import org.tribuo.provenance.TrainerProvenance; import org.tribuo.provenance.impl.TrainerProvenanceImpl; import java.time.OffsetDateTime; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.Map.Entry; -import java.util.SplittableRandom; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.atomic.AtomicInteger; @@ -89,6 +84,21 @@ public enum Distance { L1 } + /** + * Possible initialization functions. + */ + public enum Initialisation { + /** + * Initialize centroids by choosing uniformly at random from the data + * points. 
+ */ + UNIFORM, + /** + * KMeans++ initialisation. + */ + PLUSPLUS + } + @Config(mandatory = true,description="Number of centroids (i.e. the \"k\" in k-means).") private int centroids; @@ -98,6 +108,9 @@ public enum Distance { @Config(mandatory = true,description="The distance function to use.") private Distance distanceType; + @Config(mandatory = true,description="The initialisation to use.") + private Initialisation initialisationType; + @Config(description="The number of threads to use for training.") private int numThreads = 1; @@ -121,10 +134,12 @@ private KMeansTrainer() {} * @param numThreads The number of threads. * @param seed The random seed. */ - public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, long seed) { + public KMeansTrainer(int centroids, int iterations, Distance distanceType + , Initialisation initialisationType, int numThreads, long seed) { this.centroids = centroids; this.iterations = iterations; this.distanceType = distanceType; + this.initialisationType = initialisationType; this.numThreads = numThreads; this.seed = seed; postConfig(); @@ -146,7 +161,6 @@ public KMeansModel train(Dataset examples, Map ru trainInvocationCounter++; } ImmutableFeatureMap featureMap = examples.getFeatureIDMap(); - DenseVector[] centroidVectors = initialiseCentroids(centroids,examples,featureMap,localRNG); ForkJoinPool fjp = new ForkJoinPool(numThreads); @@ -161,6 +175,19 @@ public KMeansModel train(Dataset examples, Map ru n++; } + DenseVector[] centroidVectors; + switch (initialisationType) { + case UNIFORM: + centroidVectors = initialiseCentroids(centroids,examples,featureMap,localRNG); + break; + case PLUSPLUS: + centroidVectors = initialisePlusPlusCentroids(centroids, + data, featureMap,localRNG); + break; + default: + throw new IllegalStateException("Unknown initialisation" + initialisationType); + } + Map> clusterAssignments = new HashMap<>(); for (int i = 0; i < centroids; i++) { clusterAssignments.put(i,Collections.synchronizedList(new ArrayList<>())); @@ -193,20 +220,7 @@ public KMeansModel train(Dataset examples, Map ru SparseVector vector = e.vector; for (int j = 0; j < centroids; j++) { DenseVector cluster = centroidVectors[j]; - double distance; - switch (distanceType) { - case EUCLIDEAN: - distance = cluster.euclideanDistance(vector); - break; - case COSINE: - distance = cluster.cosineDistance(vector); - break; - case L1: - distance = cluster.l1Distance(vector); - break; - default: - throw new IllegalStateException("Unknown distance " + distanceType); - } + double distance = getDistance(cluster, vector); if (distance < minDist) { minDist = distance; clusterID = j; @@ -261,8 +275,6 @@ public int getInvocationCount() { /** * Initialisation method called at the start of each train call. * - * Used to allow overriding for kmeans++, kmedoids etc. - * * @param centroids The number of centroids to create. * @param examples The dataset to use. * @param featureMap The feature map to use for centroid sampling. 
@@ -284,6 +296,84 @@ protected static DenseVector[] initialiseCentroids(int centroids, Dataset maxDist) { + maxDist = curDist; + idxOfMax = i; + } + } + return idxOfMax; + } + protected double minDistancePerVector(SparseVector curVec, + int numInitializedCentroids, + DenseVector[] centroidVectors) { + double minDistance = Double.POSITIVE_INFINITY; + double tempDistance; + + // iterate through all previously initialized centroid + for (int i = 0; i < numInitializedCentroids; i++) { + DenseVector curCentroid = centroidVectors[i]; + tempDistance = getDistance(curCentroid, curVec); + minDistance = Math.max(minDistance, tempDistance); + } + return minDistance; + } + protected double getDistance(DenseVector cluster, SGDVector vector) { + double distance; + switch (distanceType) { + case EUCLIDEAN: + distance = cluster.euclideanDistance(vector); + break; + case COSINE: + distance = cluster.cosineDistance(vector); + break; + case L1: + distance = cluster.l1Distance(vector); + break; + default: + throw new IllegalStateException("Unknown distance " + distanceType); + } + return distance; + } + protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map> clusterAssignments, SparseVector[] data, double[] weights) { // M step Stream>> mStream; diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/TrainTest.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/TrainTest.java index 257bc0dbd..928f0fea5 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/TrainTest.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/TrainTest.java @@ -28,6 +28,7 @@ import org.tribuo.clustering.ClusteringFactory; import org.tribuo.clustering.evaluation.ClusteringEvaluation; import org.tribuo.clustering.kmeans.KMeansTrainer.Distance; +import org.tribuo.clustering.kmeans.KMeansTrainer.Initialisation; import org.tribuo.data.DataOptions; import java.io.IOException; @@ -54,6 +55,9 @@ public String getOptionsDescription() { public int iterations = 10; @Option(charName='d',longName="distance",usage="Distance function to use in the e step. Defaults to EUCLIDEAN.") public Distance distance = Distance.EUCLIDEAN; + @Option(charName='i',longName="initialisation",usage="Type of initialisation " + + "to use for centroids. 
Defaults to UNIFORM.") + public Initialisation initialisation = Initialisation.UNIFORM; @Option(charName='t',longName="num-threads",usage="Number of threads to use (default 4, range (1, num hw threads)).") public int numThreads = 4; } @@ -87,7 +91,8 @@ public static void main(String[] args) throws IOException { Dataset train = data.getA(); //public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, int seed) - KMeansTrainer trainer = new KMeansTrainer(o.centroids,o.iterations,o.distance,o.numThreads,o.general.seed); + KMeansTrainer trainer = new KMeansTrainer(o.centroids,o.iterations, + o.distance,o.initialisation,o.numThreads,o.general.seed); Model model = trainer.train(train); logger.info("Finished training model"); ClusteringEvaluation evaluation = factory.getEvaluator().evaluate(model,train); diff --git a/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java b/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java index a3c715828..ed89b5ce6 100644 --- a/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java +++ b/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java @@ -38,7 +38,8 @@ */ public class TestKMeans { - private static final KMeansTrainer t = new KMeansTrainer(4,10,Distance.EUCLIDEAN,1,1); + private static final KMeansTrainer t = new KMeansTrainer(4,10, + Distance.EUCLIDEAN, KMeansTrainer.Initialisation.UNIFORM, 1,1); @BeforeAll public static void setup() { @@ -52,7 +53,8 @@ public void testEvaluation() { Dataset test = ClusteringDataGenerator.gaussianClusters(500, 2L); ClusteringEvaluator eval = new ClusteringEvaluator(); - KMeansTrainer trainer = new KMeansTrainer(5,10,Distance.EUCLIDEAN,1,1); + KMeansTrainer trainer = new KMeansTrainer(5,10,Distance.EUCLIDEAN, + KMeansTrainer.Initialisation.UNIFORM,1,1); KMeansModel model = trainer.train(data); From 7be66288d4de971516a8d6937f354b5ab00ea1b6 Mon Sep 17 00:00:00 2001 From: Samantha Campo Date: Tue, 25 Aug 2020 23:42:41 -0400 Subject: [PATCH 2/7] Possibly complete but needs more testing --- .../clustering/kmeans/KMeansTrainer.java | 86 ++++++++++--------- .../tribuo/clustering/kmeans/TestKMeans.java | 17 +++- 2 files changed, 58 insertions(+), 45 deletions(-) diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java index 5f1bce68a..da1447b41 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java @@ -32,7 +32,9 @@ import org.tribuo.provenance.ModelProvenance; import org.tribuo.provenance.TrainerProvenance; import org.tribuo.provenance.impl.TrainerProvenanceImpl; +import org.tribuo.util.Util; +import javax.xml.stream.XMLInputFactory; import java.time.OffsetDateTime; import java.util.*; import java.util.Map.Entry; @@ -182,7 +184,7 @@ public KMeansModel train(Dataset examples, Map ru break; case PLUSPLUS: centroidVectors = initialisePlusPlusCentroids(centroids, - data, featureMap,localRNG); + data,featureMap,localRNG); break; default: throw new IllegalStateException("Unknown initialisation" + initialisationType); @@ -300,62 +302,62 @@ protected DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVector[] data, ImmutableFeatureMap featureMap, SplittableRandom rng) { + int numFeatures = featureMap.size(); + double[] minDistancePerVector = new double[data.length]; 
+ Arrays.fill(minDistancePerVector, Double.POSITIVE_INFINITY); + double[] squared_min_distance = new double[data.length]; + double[] probabilities = new double[data.length]; DenseVector[] centroidVectors = new DenseVector[centroids]; - double[] newCentroid = getRandomCentroidFromData(data); - centroidVectors[0] = DenseVector.createDenseVector(newCentroid); - // Set each currently uninitialised centroid + // set first centroid randomly from the data + centroidVectors[0] = getRandomCentroidFromData(data, numFeatures); + + // Set each uninitialised centroid remaining for (int i = 1; i < centroids; i++) { - double[] distancePerVec = new double[data.length]; + DenseVector prevCentroid = centroidVectors[i-1]; - // go through every vector + // go through every vector and see if the min distance to the + // current vector is smaller than previous min distance for vec + double tempDistance; for (int j = 0; j < data.length; j++) { SparseVector curVec = data[j]; - distancePerVec[j] = minDistancePerVector(curVec, i, centroidVectors); + tempDistance = getDistance(prevCentroid, curVec); + minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance); } - int idxOfMax = argmax(distancePerVec); - newCentroid = data[idxOfMax].toDenseArray(); - centroidVectors[i] = DenseVector.createDenseVector(newCentroid); + + // square the distances and get total for normalization + double total = 0.0; + for (int j = 0; j < data.length; j++) { + squared_min_distance[j] = Math.pow(minDistancePerVector[j], 2); + total += squared_min_distance[j]; + } + + // compute probabilites as p[i] = D(xi)^2 / sum(D(x)^2) + for (int j = 0; j < probabilities.length; j++) { + probabilities[j] = squared_min_distance[j] / total; + } + + // sample from probabilites to get the new centroid from data + double[] cdf = Util.generateCDF(probabilities); + int idx = Util.sampleFromCDF(cdf, rng); + centroidVectors[i] = sparseToDense(data[idx], numFeatures); } return centroidVectors; } - protected double[] getRandomCentroidFromData(SparseVector[] data) { - Random rand = new Random(); - int rand_idx = rand.nextInt(data.length); - double[] newCentroid = data[rand_idx].toDenseArray(); - return newCentroid; + protected DenseVector getRandomCentroidFromData(SparseVector[] data, + int numFeatures) { + int rand_idx = rng.nextInt(data.length); + return sparseToDense(data[rand_idx], numFeatures); } - protected int argmax(double[] distancePerVec) { - double maxDist = Double.NEGATIVE_INFINITY; - int idxOfMax = -1; - - double curDist; - for (int i = 0; i < distancePerVec.length; i++) { - curDist = distancePerVec[i]; - if (curDist > maxDist) { - maxDist = curDist; - idxOfMax = i; - } - } - return idxOfMax; - } - protected double minDistancePerVector(SparseVector curVec, - int numInitializedCentroids, - DenseVector[] centroidVectors) { - double minDistance = Double.POSITIVE_INFINITY; - double tempDistance; - - // iterate through all previously initialized centroid - for (int i = 0; i < numInitializedCentroids; i++) { - DenseVector curCentroid = centroidVectors[i]; - tempDistance = getDistance(curCentroid, curVec); - minDistance = Math.max(minDistance, tempDistance); - } - return minDistance; + protected DenseVector sparseToDense(SparseVector vec, int numFeatures) { + DenseVector dense = new DenseVector(numFeatures); + dense.intersectAndAddInPlace(vec); + return dense; } + protected double getDistance(DenseVector cluster, SGDVector vector) { double distance; switch (distanceType) { diff --git 
a/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java b/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java index ed89b5ce6..32c13cd35 100644 --- a/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java +++ b/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java @@ -41,6 +41,9 @@ public class TestKMeans { private static final KMeansTrainer t = new KMeansTrainer(4,10, Distance.EUCLIDEAN, KMeansTrainer.Initialisation.UNIFORM, 1,1); + private static final KMeansTrainer plusPlus = new KMeansTrainer(4,10, + Distance.EUCLIDEAN, KMeansTrainer.Initialisation.PLUSPLUS, 1,1); + @BeforeAll public static void setup() { Logger logger = Logger.getLogger(KMeansTrainer.class.getName()); @@ -53,10 +56,10 @@ public void testEvaluation() { Dataset test = ClusteringDataGenerator.gaussianClusters(500, 2L); ClusteringEvaluator eval = new ClusteringEvaluator(); - KMeansTrainer trainer = new KMeansTrainer(5,10,Distance.EUCLIDEAN, - KMeansTrainer.Initialisation.UNIFORM,1,1); +// KMeansTrainer trainer = new KMeansTrainer(5,10,Distance.EUCLIDEAN, +// KMeansTrainer.Initialisation.UNIFORM,1,1); - KMeansModel model = trainer.train(data); + KMeansModel model = plusPlus.train(data); ClusteringEvaluation trainEvaluation = eval.evaluate(model,data); assertFalse(Double.isNaN(trainEvaluation.adjustedMI())); @@ -73,6 +76,14 @@ public void testKMeans(Pair,Dataset> p) { e.evaluate(m,p.getB()); } + public void testKMeansPlusPlus(Pair, + Dataset> p) { + Model m = plusPlus.train(p.getA()); + ClusteringEvaluator e = new ClusteringEvaluator(); + e.evaluate(m,p.getB()); + } + + @Test public void testDenseData() { Pair,Dataset> p = ClusteringDataGenerator.denseTrainTest(); From 47099dd05080f8936d6b2db984e7a438ada9c021 Mon Sep 17 00:00:00 2001 From: Samantha Campo Date: Tue, 25 Aug 2020 23:55:42 -0400 Subject: [PATCH 3/7] start testing --- .../tribuo/clustering/kmeans/TestKMeans.java | 32 +++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java b/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java index 32c13cd35..1ee53d04b 100644 --- a/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java +++ b/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java @@ -52,24 +52,44 @@ public static void setup() { @Test public void testEvaluation() { - Dataset data = ClusteringDataGenerator.gaussianClusters(500, 1L); - Dataset test = ClusteringDataGenerator.gaussianClusters(500, 2L); - ClusteringEvaluator eval = new ClusteringEvaluator(); + Pair, Dataset> datasets = + getEvaluationData(); // KMeansTrainer trainer = new KMeansTrainer(5,10,Distance.EUCLIDEAN, // KMeansTrainer.Initialisation.UNIFORM,1,1); - KMeansModel model = plusPlus.train(data); + KMeansModel model = t.train(datasets.getA()); + evaluationHelper(model, datasets); + } - ClusteringEvaluation trainEvaluation = eval.evaluate(model,data); + @Test + public void testPlusPlusEvaluation() { + Pair, Dataset> datasets = + getEvaluationData(); + + KMeansModel model = plusPlus.train(datasets.getA()); + evaluationHelper(model, datasets); + } + + public void evaluationHelper(KMeansModel model, Pair, + Dataset> datasets) { + ClusteringEvaluator eval = new ClusteringEvaluator(); + ClusteringEvaluation trainEvaluation = eval.evaluate(model,datasets.getA()); assertFalse(Double.isNaN(trainEvaluation.adjustedMI())); 
assertFalse(Double.isNaN(trainEvaluation.normalizedMI())); - ClusteringEvaluation testEvaluation = eval.evaluate(model,test); + ClusteringEvaluation testEvaluation = eval.evaluate(model, + datasets.getB()); assertFalse(Double.isNaN(testEvaluation.adjustedMI())); assertFalse(Double.isNaN(testEvaluation.normalizedMI())); } + public Pair, Dataset> getEvaluationData() { + Dataset data = ClusteringDataGenerator.gaussianClusters(500, 1L); + Dataset test = ClusteringDataGenerator.gaussianClusters(500, 2L); + return new Pair<>(data, test); + } + public void testKMeans(Pair,Dataset> p) { Model m = t.train(p.getA()); ClusteringEvaluator e = new ClusteringEvaluator(); From 740ca327d914bb1e5bd3b6c5083fb7ca779f3dfc Mon Sep 17 00:00:00 2001 From: Samantha Campo Date: Wed, 26 Aug 2020 14:55:16 -0400 Subject: [PATCH 4/7] testing complete --- .../clustering/kmeans/KMeansTrainer.java | 9 +- .../tribuo/clustering/kmeans/TestKMeans.java | 98 +++++++++++-------- 2 files changed, 60 insertions(+), 47 deletions(-) diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java index da1447b41..38e767a5d 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java @@ -251,7 +251,6 @@ public KMeansModel train(Dataset examples, Map ru } } - Map counts = new HashMap<>(); for (Entry> e : clusterAssignments.entrySet()) { counts.put(e.getKey(),new MutableLong(e.getValue().size())); @@ -306,7 +305,7 @@ protected DenseVector[] initialisePlusPlusCentroids(int centroids, double[] minDistancePerVector = new double[data.length]; Arrays.fill(minDistancePerVector, Double.POSITIVE_INFINITY); - double[] squared_min_distance = new double[data.length]; + double[] squaredMinDistance = new double[data.length]; double[] probabilities = new double[data.length]; DenseVector[] centroidVectors = new DenseVector[centroids]; @@ -329,13 +328,13 @@ protected DenseVector[] initialisePlusPlusCentroids(int centroids, // square the distances and get total for normalization double total = 0.0; for (int j = 0; j < data.length; j++) { - squared_min_distance[j] = Math.pow(minDistancePerVector[j], 2); - total += squared_min_distance[j]; + squaredMinDistance[j] = Math.pow(minDistancePerVector[j], 2); + total += squaredMinDistance[j]; } // compute probabilites as p[i] = D(xi)^2 / sum(D(x)^2) for (int j = 0; j < probabilities.length; j++) { - probabilities[j] = squared_min_distance[j] / total; + probabilities[j] = squaredMinDistance[j] / total; } // sample from probabilites to get the new centroid from data diff --git a/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java b/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java index 1ee53d04b..7be62e01a 100644 --- a/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java +++ b/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java @@ -52,86 +52,100 @@ public static void setup() { @Test public void testEvaluation() { - Pair, Dataset> datasets = - getEvaluationData(); - -// KMeansTrainer trainer = new KMeansTrainer(5,10,Distance.EUCLIDEAN, -// KMeansTrainer.Initialisation.UNIFORM,1,1); - - KMeansModel model = t.train(datasets.getA()); - evaluationHelper(model, datasets); + runEvaluation(t); } @Test public void testPlusPlusEvaluation() { - Pair, Dataset> datasets = - getEvaluationData(); - - 
KMeansModel model = plusPlus.train(datasets.getA()); - evaluationHelper(model, datasets); + runEvaluation(plusPlus); } - public void evaluationHelper(KMeansModel model, Pair, - Dataset> datasets) { + public void runEvaluation(KMeansTrainer trainer) { + Dataset data = ClusteringDataGenerator.gaussianClusters(500, 1L); + Dataset test = ClusteringDataGenerator.gaussianClusters(500, 2L); ClusteringEvaluator eval = new ClusteringEvaluator(); - ClusteringEvaluation trainEvaluation = eval.evaluate(model,datasets.getA()); + + KMeansModel model = trainer.train(data); + + ClusteringEvaluation trainEvaluation = eval.evaluate(model,data); assertFalse(Double.isNaN(trainEvaluation.adjustedMI())); assertFalse(Double.isNaN(trainEvaluation.normalizedMI())); - ClusteringEvaluation testEvaluation = eval.evaluate(model, - datasets.getB()); + ClusteringEvaluation testEvaluation = eval.evaluate(model,test); assertFalse(Double.isNaN(testEvaluation.adjustedMI())); - assertFalse(Double.isNaN(testEvaluation.normalizedMI())); - } + assertFalse(Double.isNaN(testEvaluation.normalizedMI())); } - public Pair, Dataset> getEvaluationData() { - Dataset data = ClusteringDataGenerator.gaussianClusters(500, 1L); - Dataset test = ClusteringDataGenerator.gaussianClusters(500, 2L); - return new Pair<>(data, test); - } - - public void testKMeans(Pair,Dataset> p) { - Model m = t.train(p.getA()); + public void testTrainer(Pair, + Dataset> p, KMeansTrainer trainer) { + Model m = trainer.train(p.getA()); ClusteringEvaluator e = new ClusteringEvaluator(); e.evaluate(m,p.getB()); } - public void testKMeansPlusPlus(Pair, - Dataset> p) { - Model m = plusPlus.train(p.getA()); - ClusteringEvaluator e = new ClusteringEvaluator(); - e.evaluate(m,p.getB()); + public void runDenseData(KMeansTrainer trainer) { + Pair,Dataset> p = ClusteringDataGenerator.denseTrainTest(); + testTrainer(p, trainer); } - @Test public void testDenseData() { - Pair,Dataset> p = ClusteringDataGenerator.denseTrainTest(); - testKMeans(p); + runDenseData(t); } @Test - public void testSparseData() { + public void testPlusPlusDenseData() { + runDenseData(plusPlus); + } + + public void runSparseData(KMeansTrainer trainer) { Pair,Dataset> p = ClusteringDataGenerator.sparseTrainTest(); - testKMeans(p); + testTrainer(p, trainer); } @Test - public void testInvalidExample() { + public void testSparseData() { + runSparseData(t); + } + + @Test + public void testPlusPlusSparseData() { + runSparseData(plusPlus); + } + + public void runInvalidExample(KMeansTrainer trainer) { assertThrows(IllegalArgumentException.class, () -> { Pair, Dataset> p = ClusteringDataGenerator.denseTrainTest(); - Model m = t.train(p.getA()); + Model m = trainer.train(p.getA()); m.predict(ClusteringDataGenerator.invalidSparseExample()); }); } @Test - public void testEmptyExample() { + public void testInvalidExample() { + runInvalidExample(t); + } + + @Test + public void testPlusPlusInvalidExample() { + runInvalidExample(plusPlus); + } + + + public void runEmptyExample(KMeansTrainer trainer) { assertThrows(IllegalArgumentException.class, () -> { Pair, Dataset> p = ClusteringDataGenerator.denseTrainTest(); - Model m = t.train(p.getA()); + Model m = trainer.train(p.getA()); m.predict(ClusteringDataGenerator.emptyExample()); }); } + @Test + public void testEmptyExample() { + runEmptyExample(t); + } + + @Test + public void testPlusPlusEmptyExample() { + runEmptyExample(plusPlus); + } } From 5d6291450f3ac9d5c2df0b960b222b0e83ffc432 Mon Sep 17 00:00:00 2001 From: Samantha Campo Date: Mon, 31 Aug 2020 14:23:00 -0400 
Subject: [PATCH 5/7] Finish documentation --- .../clustering/kmeans/KMeansTrainer.java | 38 ++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java index 38e767a5d..b1b55b326 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java @@ -110,7 +110,8 @@ public enum Initialisation { @Config(mandatory = true,description="The distance function to use.") private Distance distanceType; - @Config(mandatory = true,description="The initialisation to use.") + @Config(mandatory = true,description="The centroid initialisation " + + "method to use.") private Initialisation initialisationType; @Config(description="The number of threads to use for training.") @@ -133,6 +134,7 @@ private KMeansTrainer() {} * @param centroids The number of centroids to use. * @param iterations The maximum number of iterations. * @param distanceType The distance function. + * @param initialisationType The centroid initialization method. * @param numThreads The number of threads. * @param seed The random seed. */ @@ -274,7 +276,8 @@ public int getInvocationCount() { } /** - * Initialisation method called at the start of each train call. + * Initialisation method called at the start of each train call when + * using uniform centroid initialisation. * * @param centroids The number of centroids to create. * @param examples The dataset to use. @@ -297,6 +300,16 @@ protected static DenseVector[] initialiseCentroids(int centroids, Dataset Date: Tue, 1 Sep 2020 16:08:01 -0400 Subject: [PATCH 6/7] Add check and test for more centroids then data. Fix wrapping. Documentation. --- .../clustering/kmeans/KMeansOptions.java | 8 +- .../clustering/kmeans/KMeansTrainer.java | 129 ++++++++++-------- .../tribuo/clustering/kmeans/TrainTest.java | 6 +- .../tribuo/clustering/kmeans/TestKMeans.java | 22 ++- 4 files changed, 93 insertions(+), 72 deletions(-) diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java index cfe38066f..6807bf7ba 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java @@ -36,9 +36,8 @@ public class KMeansOptions implements Options { public int centroids = 10; @Option(longName="kmeans-distance",usage="Distance function in K-Means. Defaults to EUCLIDEAN.") public Distance distance = Distance.EUCLIDEAN; - @Option(longName="kmeans-initialisation",usage="Initialisation function " + - "in K-Means. Defaults to UNIFORM.") - public Initialisation initialisation = Initialisation.UNIFORM; + @Option(longName="kmeans-initialisation",usage="Initialisation function in K-Means. Defaults to RANDOM.") + public Initialisation initialisation = Initialisation.RANDOM; @Option(longName="kmeans-num-threads",usage="Number of computation threads in K-Means. 
Defaults to 4.") public int numThreads = 4; @Option(longName="kmeans-seed", usage = "Sets the random seed for K-Means.") @@ -47,7 +46,6 @@ public class KMeansOptions implements Options { public KMeansTrainer getTrainer() { logger.info("Configuring K-Means Trainer"); //public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, int seed) { - return new KMeansTrainer(centroids,iterations,distance,initialisation - ,numThreads,seed); + return new KMeansTrainer(centroids,iterations,distance,initialisation,numThreads,seed); } } diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java index b1b55b326..f5330178c 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java @@ -34,10 +34,15 @@ import org.tribuo.provenance.impl.TrainerProvenanceImpl; import org.tribuo.util.Util; -import javax.xml.stream.XMLInputFactory; import java.time.OffsetDateTime; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Map.Entry; +import java.util.SplittableRandom; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.atomic.AtomicInteger; @@ -64,6 +69,11 @@ * "The Elements of Statistical Learning" * Springer 2001. PDF * + *
+ * D. Arthur, S. Vassilvitskii.
+ * "K-Means++: The Advantages of Careful Seeding"
+ * PDF
+ * 
*/ public class KMeansTrainer implements Trainer { private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName()); @@ -94,30 +104,29 @@ public enum Initialisation { * Initialize centroids by choosing uniformly at random from the data * points. */ - UNIFORM, + RANDOM, /** * KMeans++ initialisation. */ PLUSPLUS } - @Config(mandatory = true,description="Number of centroids (i.e. the \"k\" in k-means).") + @Config(mandatory = true, description = "Number of centroids (i.e. the \"k\" in k-means).") private int centroids; - @Config(mandatory = true,description="The number of iterations to run.") + @Config(mandatory = true, description = "The number of iterations to run.") private int iterations; - @Config(mandatory = true,description="The distance function to use.") + @Config(mandatory = true, description = "The distance function to use.") private Distance distanceType; - @Config(mandatory = true,description="The centroid initialisation " + - "method to use.") + @Config(mandatory = true, description = "The centroid initialisation method to use.") private Initialisation initialisationType; - @Config(description="The number of threads to use for training.") + @Config(description = "The number of threads to use for training.") private int numThreads = 1; - @Config(mandatory = true,description="The seed to use for the RNG.") + @Config(mandatory = true, description = "The seed to use for the RNG.") private long seed; private SplittableRandom rng; @@ -127,10 +136,12 @@ public enum Initialisation { /** * for olcut. */ - private KMeansTrainer() {} + private KMeansTrainer() { + } /** * Constructs a K-Means trainer using the supplied parameters. + * * @param centroids The number of centroids to use. * @param iterations The maximum number of iterations. * @param distanceType The distance function. @@ -138,8 +149,7 @@ private KMeansTrainer() {} * @param numThreads The number of threads. * @param seed The random seed. */ - public KMeansTrainer(int centroids, int iterations, Distance distanceType - , Initialisation initialisationType, int numThreads, long seed) { + public KMeansTrainer(int centroids, int iterations, Distance distanceType, Initialisation initialisationType, int numThreads, long seed) { this.centroids = centroids; this.iterations = iterations; this.distanceType = distanceType; @@ -159,7 +169,7 @@ public KMeansModel train(Dataset examples, Map ru // Creates a new local RNG and adds one to the invocation count. 
TrainerProvenance trainerProvenance; SplittableRandom localRNG; - synchronized(this) { + synchronized (this) { localRNG = rng.split(); trainerProvenance = getProvenance(); trainInvocationCounter++; @@ -174,27 +184,26 @@ public KMeansModel train(Dataset examples, Map ru int n = 0; for (Example example : examples) { weights[n] = example.getWeight(); - data[n] = SparseVector.createSparseVector(example,featureMap,false); + data[n] = SparseVector.createSparseVector(example, featureMap, false); oldCentre[n] = -1; n++; } DenseVector[] centroidVectors; switch (initialisationType) { - case UNIFORM: - centroidVectors = initialiseCentroids(centroids,examples,featureMap,localRNG); + case RANDOM: + centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG); break; case PLUSPLUS: - centroidVectors = initialisePlusPlusCentroids(centroids, - data,featureMap,localRNG); + centroidVectors = initialisePlusPlusCentroids(centroids, data, featureMap, localRNG, distanceType); break; default: throw new IllegalStateException("Unknown initialisation" + initialisationType); } - Map> clusterAssignments = new HashMap<>(); + Map> clusterAssignments = new HashMap<>(); for (int i = 0; i < centroids; i++) { - clusterAssignments.put(i,Collections.synchronizedList(new ArrayList<>())); + clusterAssignments.put(i, Collections.synchronizedList(new ArrayList<>())); } boolean converged = false; @@ -203,18 +212,18 @@ public KMeansModel train(Dataset examples, Map ru //logger.log(Level.INFO,"Beginning iteration " + i); AtomicInteger changeCounter = new AtomicInteger(0); - for (Entry> e : clusterAssignments.entrySet()) { + for (Entry> e : clusterAssignments.entrySet()) { e.getValue().clear(); } // E step Stream vecStream = Arrays.stream(data); - Stream intStream = IntStream.range(0,data.length).boxed(); + Stream intStream = IntStream.range(0, data.length).boxed(); Stream eStream; if (numThreads > 1) { - eStream = StreamUtil.boundParallelism(StreamUtil.zip(intStream,vecStream,IntAndVector::new).parallel()); + eStream = StreamUtil.boundParallelism(StreamUtil.zip(intStream, vecStream, IntAndVector::new).parallel()); } else { - eStream = StreamUtil.zip(intStream,vecStream,IntAndVector::new); + eStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new); } try { fjp.submit(() -> eStream.forEach((IntAndVector e) -> { @@ -224,7 +233,7 @@ public KMeansModel train(Dataset examples, Map ru SparseVector vector = e.vector; for (int j = 0; j < centroids; j++) { DenseVector cluster = centroidVectors[j]; - double distance = getDistance(cluster, vector); + double distance = getDistance(cluster, vector, distanceType); if (distance < minDist) { minDist = distance; clusterID = j; @@ -239,11 +248,11 @@ public KMeansModel train(Dataset examples, Map ru } })).get(); } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException("Parallel execution failed",e); + throw new RuntimeException("Parallel execution failed", e); } //logger.log(Level.INFO, "E step completed. " + changeCounter.get() + " words updated."); - mStep(fjp,centroidVectors,clusterAssignments,data,weights); + mStep(fjp, centroidVectors, clusterAssignments, data, weights); logger.log(Level.INFO, "Iteration " + i + " completed. 
" + changeCounter.get() + " examples updated."); @@ -253,21 +262,22 @@ public KMeansModel train(Dataset examples, Map ru } } - Map counts = new HashMap<>(); - for (Entry> e : clusterAssignments.entrySet()) { - counts.put(e.getKey(),new MutableLong(e.getValue().size())); + Map counts = new HashMap<>(); + for (Entry> e : clusterAssignments.entrySet()) { + counts.put(e.getKey(), new MutableLong(e.getValue().size())); } ImmutableOutputInfo outputMap = new ImmutableClusteringInfo(counts); - ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance); + ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(), + examples.getProvenance(), trainerProvenance, runProvenance); - return new KMeansModel("",provenance,featureMap,outputMap,centroidVectors, distanceType); + return new KMeansModel("", provenance, featureMap, outputMap, centroidVectors, distanceType); } @Override public KMeansModel train(Dataset dataset) { - return train(dataset,Collections.emptyMap()); + return train(dataset, Collections.emptyMap()); } @Override @@ -279,13 +289,13 @@ public int getInvocationCount() { * Initialisation method called at the start of each train call when * using uniform centroid initialisation. * - * @param centroids The number of centroids to create. - * @param examples The dataset to use. + * @param centroids The number of centroids to create. * @param featureMap The feature map to use for centroid sampling. * @param rng The RNG to use. * @return A {@link DenseVector} array of centroids. */ - protected static DenseVector[] initialiseCentroids(int centroids, Dataset examples, ImmutableFeatureMap featureMap, SplittableRandom rng) { + private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableFeatureMap featureMap, + SplittableRandom rng) { DenseVector[] centroidVectors = new DenseVector[centroids]; int numFeatures = featureMap.size(); for (int i = 0; i < centroids; i++) { @@ -310,10 +320,13 @@ protected static DenseVector[] initialiseCentroids(int centroids, Dataset data.length) { + throw new IllegalArgumentException("The number of centroids may not exceed the number of samples."); + } + int numFeatures = featureMap.size(); double[] minDistancePerVector = new double[data.length]; Arrays.fill(minDistancePerVector, Double.POSITIVE_INFINITY); @@ -323,25 +336,25 @@ protected DenseVector[] initialisePlusPlusCentroids(int centroids, DenseVector[] centroidVectors = new DenseVector[centroids]; // set first centroid randomly from the data - centroidVectors[0] = getRandomCentroidFromData(data, numFeatures); + centroidVectors[0] = getRandomCentroidFromData(data, numFeatures, rng); // Set each uninitialised centroid remaining for (int i = 1; i < centroids; i++) { - DenseVector prevCentroid = centroidVectors[i-1]; + DenseVector prevCentroid = centroidVectors[i - 1]; // go through every vector and see if the min distance to the // newest centroid is smaller than previous min distance for vec double tempDistance; for (int j = 0; j < data.length; j++) { SparseVector curVec = data[j]; - tempDistance = getDistance(prevCentroid, curVec); + tempDistance = getDistance(prevCentroid, curVec, distanceType); minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance); } // square the distances and get total for normalization double total = 0.0; - for (int j = 0; j < data.length; j++) { - squaredMinDistance[j] = Math.pow(minDistancePerVector[j], 2); + for (int j 
= 0; j < data.length; j++) { + squaredMinDistance[j] = minDistancePerVector[j] * minDistancePerVector[j]; total += squaredMinDistance[j]; } @@ -363,10 +376,11 @@ protected DenseVector[] initialisePlusPlusCentroids(int centroids, * * @param data The dataset of {@link SparseVector} to use. * @param numFeatures The number of features. - * @return A {@Link DenseVector} representing a centroid. + * @param rng The RNG to use. + * @return A {@link DenseVector} representing a centroid. */ - protected DenseVector getRandomCentroidFromData(SparseVector[] data, - int numFeatures) { + private static DenseVector getRandomCentroidFromData(SparseVector[] data, + int numFeatures, SplittableRandom rng) { int rand_idx = rng.nextInt(data.length); return sparseToDense(data[rand_idx], numFeatures); } @@ -379,13 +393,14 @@ protected DenseVector getRandomCentroidFromData(SparseVector[] data, * @param numFeatures The number of features. * @return A {@link DenseVector} containing the information from vec. */ - protected DenseVector sparseToDense(SparseVector vec, int numFeatures) { + private static DenseVector sparseToDense(SparseVector vec, int numFeatures) { DenseVector dense = new DenseVector(numFeatures); dense.intersectAndAddInPlace(vec); return dense; } - protected double getDistance(DenseVector cluster, SGDVector vector) { + private static double getDistance(DenseVector cluster, SGDVector vector, + Distance distanceType) { double distance; switch (distanceType) { case EUCLIDEAN: @@ -403,9 +418,9 @@ protected double getDistance(DenseVector cluster, SGDVector vector) { return distance; } - protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map> clusterAssignments, SparseVector[] data, double[] weights) { + protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map> clusterAssignments, SparseVector[] data, double[] weights) { // M step - Stream>> mStream; + Stream>> mStream; if (numThreads > 1) { mStream = StreamUtil.boundParallelism(clusterAssignments.entrySet().stream().parallel()); } else { @@ -418,21 +433,21 @@ protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map f * weights[idx]); + newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]); counter++; } if (counter > 0) { - newCentroid.scaleInPlace(1.0/counter); + newCentroid.scaleInPlace(1.0 / counter); } })).get(); } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException("Parallel execution failed",e); + throw new RuntimeException("Parallel execution failed", e); } } @Override public String toString() { - return "KMeansTrainer(centroids="+centroids+",distanceType="+ distanceType +",seed="+seed+",numThreads="+numThreads+")"; + return "KMeansTrainer(centroids=" + centroids + ",distanceType=" + distanceType + ",seed=" + seed + ",numThreads=" + numThreads + ")"; } @Override diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/TrainTest.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/TrainTest.java index 928f0fea5..7b972b737 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/TrainTest.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/TrainTest.java @@ -55,9 +55,9 @@ public String getOptionsDescription() { public int iterations = 10; @Option(charName='d',longName="distance",usage="Distance function to use in the e step. 
Defaults to EUCLIDEAN.") public Distance distance = Distance.EUCLIDEAN; - @Option(charName='i',longName="initialisation",usage="Type of initialisation " + - "to use for centroids. Defaults to UNIFORM.") - public Initialisation initialisation = Initialisation.UNIFORM; + @Option(charName='i',longName="initialisation",usage="Type of initialisation to use for centroids. Defaults " + + "to RANDOM.") + public Initialisation initialisation = Initialisation.RANDOM; @Option(charName='t',longName="num-threads",usage="Number of threads to use (default 4, range (1, num hw threads)).") public int numThreads = 4; } diff --git a/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java b/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java index 7be62e01a..f9d054ca8 100644 --- a/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java +++ b/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java @@ -38,11 +38,11 @@ */ public class TestKMeans { - private static final KMeansTrainer t = new KMeansTrainer(4,10, - Distance.EUCLIDEAN, KMeansTrainer.Initialisation.UNIFORM, 1,1); + private static final KMeansTrainer t = new KMeansTrainer(4,10, Distance.EUCLIDEAN, + KMeansTrainer.Initialisation.RANDOM, 1,1); - private static final KMeansTrainer plusPlus = new KMeansTrainer(4,10, - Distance.EUCLIDEAN, KMeansTrainer.Initialisation.PLUSPLUS, 1,1); + private static final KMeansTrainer plusPlus = new KMeansTrainer(4,10, Distance.EUCLIDEAN, + KMeansTrainer.Initialisation.PLUSPLUS, 1,1); @BeforeAll public static void setup() { @@ -73,10 +73,10 @@ public void runEvaluation(KMeansTrainer trainer) { ClusteringEvaluation testEvaluation = eval.evaluate(model,test); assertFalse(Double.isNaN(testEvaluation.adjustedMI())); - assertFalse(Double.isNaN(testEvaluation.normalizedMI())); } + assertFalse(Double.isNaN(testEvaluation.normalizedMI())); + } - public void testTrainer(Pair, - Dataset> p, KMeansTrainer trainer) { + public void testTrainer(Pair, Dataset> p, KMeansTrainer trainer) { Model m = trainer.train(p.getA()); ClusteringEvaluator e = new ClusteringEvaluator(); e.evaluate(m,p.getB()); @@ -148,4 +148,12 @@ public void testEmptyExample() { public void testPlusPlusEmptyExample() { runEmptyExample(plusPlus); } + + @Test + public void testPlusPlusTooManyCentroids() { + assertThrows(IllegalArgumentException.class, () -> { + Dataset data = ClusteringDataGenerator.gaussianClusters(3, 1L); + plusPlus.train(data); + }); + } } From 6dd32bc2dcaccd3dd7ea1a4252e1a9e3f0dac314 Mon Sep 17 00:00:00 2001 From: Samantha Campo Date: Wed, 2 Sep 2020 21:49:01 -0400 Subject: [PATCH 7/7] Fix PR issues. --- .../clustering/kmeans/KMeansTrainer.java | 19 +++++++++++++------ .../tribuo/clustering/kmeans/TrainTest.java | 11 +++++------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java index f5330178c..95bb2ecf1 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java @@ -69,6 +69,8 @@ * "The Elements of Statistical Learning" * Springer 2001. PDF * + *

+ * For more on optional kmeans++ initialisation, see:

  * D. Arthur, S. Vassilvitskii.
  * "K-Means++: The Advantages of Careful Seeding"
@@ -286,8 +288,8 @@ public int getInvocationCount() {
     }
 
     /**
-     * Initialisation method called at the start of each train call when
-     * using uniform centroid initialisation.
+     * Initialisation method called at the start of each train call when using the default centroid initialisation.
+     * Centroids are initialised using a uniform random sample from the feature domain.
      *
      * @param centroids  The number of centroids to create.
      * @param featureMap The feature map to use for centroid sampling.
@@ -311,8 +313,7 @@ private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableF
     }
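The reworked javadoc above describes the default initialisation as a uniform random sample from the feature domain. The body of initialiseRandomCentroids is not shown in this hunk, so the snippet below is only a sketch of that idea, assuming each centroid dimension is drawn uniformly within the corresponding feature's observed range:

```java
import java.util.Random;

final class RandomInitSketch {
    /** Draws k centroids, one uniform value per feature dimension within [min, max]. */
    static double[][] randomCentroids(int k, double[] featureMin, double[] featureMax, Random rng) {
        int numFeatures = featureMin.length;
        double[][] centroids = new double[k][numFeatures];
        for (int i = 0; i < k; i++) {
            for (int f = 0; f < numFeatures; f++) {
                double range = featureMax[f] - featureMin[f];
                centroids[i][f] = featureMin[f] + rng.nextDouble() * range;
            }
        }
        return centroids;
    }
}
```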
 
     /**
-     * Initialisation method called at the start of each train call when
-     * using kmeans++ centroid initialisation.
+     * Initialisation method called at the start of each train call when using kmeans++ centroid initialisation.
      *
      * @param centroids The number of centroids to create.
      * @param data The dataset of {@link SparseVector} to use.
@@ -344,10 +345,9 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SparseVe
 
             // for each vector, check whether the distance to the newest
             // centroid is smaller than the smallest distance seen so far
-            double tempDistance;
             for (int j = 0; j < data.length; j++) {
                 SparseVector curVec = data[j];
-                tempDistance = getDistance(prevCentroid, curVec, distanceType);
+                double tempDistance = getDistance(prevCentroid, curVec, distanceType);
                 minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance);
             }
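The hunk above touches the distance-tracking part of the kmeans++ seeding loop introduced earlier in the series: each point keeps the distance to its nearest already-chosen centroid, those distances are then squared and normalised into a probability distribution, and the next centroid is drawn from the resulting CDF. The following is a condensed, self-contained sketch of that procedure, using plain arrays and java.util.Random in place of Tribuo's SparseVector, SplittableRandom and Util.generateCDF/sampleFromCDF helpers, so treat it as an illustration rather than the patch's actual code:

```java
import java.util.Arrays;
import java.util.Random;

final class PlusPlusSeedingSketch {
    /** Picks k seed indices from points using kmeans++ style D(x)^2 weighting. */
    static int[] pickSeeds(double[][] points, int k, Random rng) {
        int n = points.length;
        int[] seeds = new int[k];
        double[] minDist = new double[n];
        Arrays.fill(minDist, Double.POSITIVE_INFINITY);

        seeds[0] = rng.nextInt(n);                        // first seed chosen uniformly
        for (int i = 1; i < k; i++) {
            double[] prev = points[seeds[i - 1]];
            double[] weight = new double[n];
            double total = 0.0;
            for (int j = 0; j < n; j++) {
                double d = euclidean(prev, points[j]);
                minDist[j] = Math.min(minDist[j], d);     // distance to nearest chosen seed
                weight[j] = minDist[j] * minDist[j];      // D(x)^2 weighting
                total += weight[j];
            }
            // sample an index proportionally to weight[j], i.e. from the implicit CDF
            double u = rng.nextDouble() * total;
            double cumulative = 0.0;
            int idx = n - 1;
            for (int j = 0; j < n; j++) {
                cumulative += weight[j];
                if (u <= cumulative) { idx = j; break; }
            }
            seeds[i] = idx;
        }
        return seeds;
    }

    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    public static void main(String[] args) {
        double[][] points = {{0, 0}, {0, 1}, {10, 10}, {10, 11}, {20, 0}};
        System.out.println(Arrays.toString(pickSeeds(points, 3, new Random(1))));
    }
}
```

Points far from every chosen centroid receive a large weight, which is what spreads the initial centroids out compared with purely random seeding.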
 
@@ -399,6 +399,13 @@ private static DenseVector sparseToDense(SparseVector vec, int numFeatures) {
         return dense;
     }
 
+    /**
+     * Computes the distance between a cluster centroid and a vector, using the supplied distance metric.
+     * @param cluster A {@link DenseVector} representing a centroid.
+     * @param vector A {@link SGDVector} representing an example.
+     * @param distanceType The distance metric to employ.
+     * @return A double representing the distance from vector to centroid.
+     */
     private static double getDistance(DenseVector cluster, SGDVector vector,
                                       Distance distanceType) {
         double distance;
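The new javadoc above documents getDistance without restating what each Distance option computes. For reference, minimal plain-array versions of the three metrics are sketched below; the real implementations live on Tribuo's DenseVector (euclideanDistance, cosineDistance, l1Distance), and the cosine case assumes cosine distance is defined as one minus cosine similarity:

```java
final class DistanceSketch {
    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    static double l1(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += Math.abs(a[i] - b[i]);
        }
        return sum;
    }

    // Cosine distance taken as 1 - cosine similarity.
    static double cosine(double[] a, double[] b) {
        double dot = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return 1.0 - dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}
```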
diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/TrainTest.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/TrainTest.java
index 7b972b737..9842f7326 100644
--- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/TrainTest.java
+++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/TrainTest.java
@@ -49,16 +49,15 @@ public String getOptionsDescription() {
         }
         public DataOptions general;
 
-        @Option(charName='n',longName="num-clusters",usage="Number of clusters to infer. Defaults to 5.")
+        @Option(charName='n',longName="num-clusters",usage="Number of clusters to infer.")
         public int centroids = 5;
-        @Option(charName='i',longName="iterations",usage="Maximum number of iterations. Defaults to 10.")
+        @Option(charName='i',longName="iterations",usage="Maximum number of iterations.")
         public int iterations = 10;
-        @Option(charName='d',longName="distance",usage="Distance function to use in the e step. Defaults to EUCLIDEAN.")
+        @Option(charName='d',longName="distance",usage="Distance function to use in the e step.")
+        @Option(charName='d',longName="distance",usage="Distance function to use in the e step.")
         public Distance distance = Distance.EUCLIDEAN;
-        @Option(charName='i',longName="initialisation",usage="Type of initialisation to use for centroids. Defaults " +
-                "to RANDOM.")
+        @Option(charName='s',longName="initialisation",usage="Type of initialisation to use for centroids.")
         public Initialisation initialisation = Initialisation.RANDOM;
-        @Option(charName='t',longName="num-threads",usage="Number of threads to use (default 4, range (1, num hw threads)).")
+        @Option(charName='t',longName="num-threads",usage="Number of threads to use (range (1, num hw threads)).")
         public int numThreads = 4;
     }