Improve the prediction technique #222

Merged: 5 commits, Mar 25, 2022
Changes from 4 commits
@@ -68,14 +68,17 @@ public final class HdbscanModel extends Model<ClusterID> {

private final List<HdbscanTrainer.ClusterExemplar> clusterExemplars;

private final double noisePointsOutlierScore;

HdbscanModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap,
ImmutableOutputInfo<ClusterID> outputIDInfo, List<Integer> clusterLabels, DenseVector outlierScoresVector,
List<HdbscanTrainer.ClusterExemplar> clusterExemplars, DistanceType distType) {
List<HdbscanTrainer.ClusterExemplar> clusterExemplars, DistanceType distType, double noisePointsOutlierScore) {
super(name,description,featureIDMap,outputIDInfo,false);
this.clusterLabels = clusterLabels;
this.outlierScoresVector = outlierScoresVector;
this.clusterExemplars = clusterExemplars;
this.distType = distType;
this.noisePointsOutlierScore = noisePointsOutlierScore;
}

/**
@@ -115,18 +118,39 @@ public Prediction<ClusterID> predict(Example<ClusterID> example) {
if (vector.numActiveElements() == 0) {
throw new IllegalArgumentException("No features found in Example " + example);
}

double minDistance = Double.POSITIVE_INFINITY;
int clusterLabel = -1;
double clusterOutlierScore = 0.0;
for (HdbscanTrainer.ClusterExemplar clusterExemplar : clusterExemplars) {
double distance = DistanceType.getDistance(clusterExemplar.getFeatures(), vector, distType);
if (distance < minDistance) {
minDistance = distance;
clusterLabel = clusterExemplar.getLabel();
clusterOutlierScore = clusterExemplar.getOutlierScore();
int clusterLabel = HdbscanTrainer.OUTLIER_NOISE_CLUSTER_LABEL;
double outlierScore = 0.0;
if (Double.compare(noisePointsOutlierScore, 0) > 0) { // This will be true from models > 4.2
boolean isNoisePoint = true;
for (HdbscanTrainer.ClusterExemplar clusterExemplar : clusterExemplars) {
double distance = DistanceType.getDistance(clusterExemplar.getFeatures(), vector, distType);
if (isNoisePoint && distance <= clusterExemplar.getMaxDistToEdge()) {
isNoisePoint = false;
}
if (distance < minDistance) {
minDistance = distance;
clusterLabel = clusterExemplar.getLabel();
outlierScore = clusterExemplar.getOutlierScore();
}
}
if (isNoisePoint) {
clusterLabel = HdbscanTrainer.OUTLIER_NOISE_CLUSTER_LABEL;
outlierScore = noisePointsOutlierScore;
}
}
else {
for (HdbscanTrainer.ClusterExemplar clusterExemplar : clusterExemplars) {
double distance = DistanceType.getDistance(clusterExemplar.getFeatures(), vector, distType);
if (distance < minDistance) {
minDistance = distance;
clusterLabel = clusterExemplar.getLabel();
outlierScore = clusterExemplar.getOutlierScore();
}
}
}
return new Prediction<>(new ClusterID(clusterLabel, clusterOutlierScore),vector.size(),example);
return new Prediction<>(new ClusterID(clusterLabel, outlierScore),vector.size(),example);
}

@Override
@@ -145,7 +169,7 @@ protected HdbscanModel copy(String newName, ModelProvenance newProvenance) {
List<Integer> copyClusterLabels = Collections.unmodifiableList(clusterLabels);
List<HdbscanTrainer.ClusterExemplar> copyExemplars = new ArrayList<>(clusterExemplars);
return new HdbscanModel(newName, newProvenance, featureIDMap, outputIDInfo, copyClusterLabels,
copyOutlierScoresVector, copyExemplars, distType);
copyOutlierScoresVector, copyExemplars, distType, noisePointsOutlierScore);
}

private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
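Summarizing the HdbscanModel changes above: predict() keeps the original nearest-exemplar assignment, but now also checks whether the example lies within any exemplar's maximum distance to the edge of its cluster; if it does not, the example is labelled as a noise point (cluster 0) and given the model's precomputed noise outlier score. A minimal self-contained sketch of that rule, using simplified stand-in types rather than Tribuo's classes (all names below are illustrative):

import java.util.List;

// Self-contained sketch (not Tribuo API): mirrors the new predict() logic on plain arrays.
final class NoiseAwarePredictionSketch {

    // Simplified stand-in for HdbscanTrainer.ClusterExemplar.
    static final class Exemplar {
        final int label;
        final double outlierScore;
        final double[] features;
        final double maxDistToEdge; // largest distance from this exemplar to its cluster's members
        Exemplar(int label, double outlierScore, double[] features, double maxDistToEdge) {
            this.label = label;
            this.outlierScore = outlierScore;
            this.features = features;
            this.maxDistToEdge = maxDistToEdge;
        }
    }

    static final int NOISE_CLUSTER_LABEL = 0;

    // Returns {label, outlierScore}: nearest-exemplar assignment plus the new noise check.
    static double[] assign(List<Exemplar> exemplars, double[] point, double noisePointsOutlierScore) {
        double minDistance = Double.POSITIVE_INFINITY;
        int label = NOISE_CLUSTER_LABEL;
        double outlierScore = 0.0;
        boolean isNoisePoint = true;
        for (Exemplar e : exemplars) {
            double distance = l2Distance(e.features, point);
            // The point only counts as noise if it falls outside every exemplar's cluster fringe.
            if (distance <= e.maxDistToEdge) {
                isNoisePoint = false;
            }
            if (distance < minDistance) {
                minDistance = distance;
                label = e.label;
                outlierScore = e.outlierScore;
            }
        }
        if (isNoisePoint) {
            label = NOISE_CLUSTER_LABEL;
            outlierScore = noisePointsOutlierScore;
        }
        return new double[]{label, outlierScore};
    }

    private static double l2Distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }
}
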
@@ -55,6 +55,8 @@
import java.util.TreeSet;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
* An HDBSCAN* trainer which generates a hierarchical, density-based clustering representation
@@ -82,6 +84,8 @@ public final class HdbscanTrainer implements Trainer<ClusterID> {

static final int OUTLIER_NOISE_CLUSTER_LABEL = 0;

private static final double MAX_OUTLIER_SCORE = 0.9999;

/**
* Available distance functions.
* @deprecated
@@ -241,15 +245,18 @@ public HdbscanModel train(Dataset<ClusterID> examples, Map<String, Provenance> r
ImmutableOutputInfo<ClusterID> outputMap = new ImmutableClusteringInfo(counts);

// Compute the cluster exemplars.
List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments);
List<ClusterExemplar> clusterExemplars = computeExemplars(data, clusterAssignments, distType);

// Get the outlier score value for points that are predicted as noise points.
double noisePointsOutlierScore = getNoisePointsOutlierScore(clusterAssignments);

logger.log(Level.INFO, "Hdbscan is done.");

ModelProvenance provenance = new ModelProvenance(HdbscanModel.class.getName(), OffsetDateTime.now(),
examples.getProvenance(), trainerProvenance, runProvenance);

return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector,
clusterExemplars, distType);
clusterExemplars, distType, noisePointsOutlierScore);
}

@Override
@@ -705,14 +712,16 @@ private static Map<Integer, List<Pair<Double, Integer>>> generateClusterAssignme
}

/**
* Compute the exemplars. These are representative points which are subsets of their clusters and noise points, and
* Compute the exemplars. These are representative points which are subsets of their clusters, and
* will be used for prediction on unseen data points.
*
* @param data An array of {@link DenseVector} containing the data.
* @param clusterAssignments A map of the cluster labels, and the points assigned to them.
* @param distType The distance metric to employ.
* @return A list of {@link ClusterExemplar}s which are used for predictions.
*/
private static List<ClusterExemplar> computeExemplars(SGDVector[] data, Map<Integer, List<Pair<Double, Integer>>> clusterAssignments) {
private static List<ClusterExemplar> computeExemplars(SGDVector[] data, Map<Integer, List<Pair<Double, Integer>>> clusterAssignments,
DistanceType distType) {
List<ClusterExemplar> clusterExemplars = new ArrayList<>();
// The formula to calculate the exemplar number. This calculates the number of exemplars to be used for this
// configuration. The appropriate number of exemplars is important for prediction. At the time, this
@@ -721,37 +730,68 @@ private static List<ClusterExemplar> computeExemplars(SGDVector[] data, Map<Inte

for (Entry<Integer, List<Pair<Double, Integer>>> e : clusterAssignments.entrySet()) {
int clusterLabel = e.getKey();
List<Pair<Double, Integer>> outlierScoreIndexList = clusterAssignments.get(clusterLabel);

// Put the items into a TreeMap. This achieves the required sorting and removes duplicate outlier scores to
// provide the best samples
TreeMap<Double, Integer> outlierScoreIndexTree = new TreeMap<>();
outlierScoreIndexList.forEach(p -> outlierScoreIndexTree.put(p.getA(), p.getB()));
int numExemplarsThisCluster = e.getValue().size() * numExemplars / data.length;
if (numExemplarsThisCluster > outlierScoreIndexTree.size()) {
numExemplarsThisCluster = outlierScoreIndexTree.size();
}

if (clusterLabel != OUTLIER_NOISE_CLUSTER_LABEL) {
for (int i = 0; i < numExemplarsThisCluster; i++) {
// Note that for non-outliers, the first node is polled from the tree, which has the lowest outlier
// score out of the remaining points assigned this cluster.
Entry<Double, Integer> entry = outlierScoreIndexTree.pollFirstEntry();
clusterExemplars.add(new ClusterExemplar(clusterLabel, entry.getKey(), data[entry.getValue()]));
List<Pair<Double, Integer>> outlierScoreIndexList = clusterAssignments.get(clusterLabel);

// Put the items into a TreeMap. This achieves the required sorting and removes duplicate outlier scores to
// provide the best samples
TreeMap<Double, Integer> outlierScoreIndexTree = new TreeMap<>();
outlierScoreIndexList.forEach(p -> outlierScoreIndexTree.put(p.getA(), p.getB()));
int numExemplarsThisCluster = e.getValue().size() * numExemplars / data.length;
if (numExemplarsThisCluster > outlierScoreIndexTree.size()) {
numExemplarsThisCluster = outlierScoreIndexTree.size();
}
}
else {
for (int i = 0; i < numExemplarsThisCluster; i++) {
// Note that for outliers the last node is polled from the tree, which has the highest outlier score
// out of the remaining points assigned this cluster.
Entry<Double, Integer> entry = outlierScoreIndexTree.pollLastEntry();
clusterExemplars.add(new ClusterExemplar(clusterLabel, entry.getKey(), data[entry.getValue()]));

// First, get the entries that will be used for cluster exemplars.
// Note that for non-outliers, the first node is polled from the tree, which has the lowest outlier
// score out of the remaining points assigned this cluster.
List<Entry<Double, Integer>> partialClusterExemplars = new ArrayList<>();
Stream<Integer> intStream = IntStream.range(0, numExemplarsThisCluster).boxed();
intStream.forEach((i) -> partialClusterExemplars.add(outlierScoreIndexTree.pollFirstEntry()));

// For each of the partial exemplars in this cluster, iterate the remaining nodes in the tree to find
// the maximum distance between the exemplar and the members of the cluster. The other exemplars don't
// need to be checked here since they won't be on the fringe of the cluster.
for (Entry<Double, Integer> partialClusterExemplar : partialClusterExemplars) {
SGDVector features = data[partialClusterExemplar.getValue()];
double maxInnerDist = Double.NEGATIVE_INFINITY;
for (Entry<Double, Integer> entry : outlierScoreIndexTree.entrySet()) {
double distance = DistanceType.getDistance(features, data[entry.getValue()], distType);
if (distance > maxInnerDist){
maxInnerDist = distance;
}
}
clusterExemplars.add(new ClusterExemplar(clusterLabel, partialClusterExemplar.getKey(), features,
maxInnerDist));
}
}
}
return clusterExemplars;
}

/**
* Determine the outlier score value for points that are predicted as noise points.
*
* @param clusterAssignments A map of the cluster labels, and the points assigned to them.
* @return An outlier score value for points predicted as noise points.
*/
private static double getNoisePointsOutlierScore(Map<Integer, List<Pair<Double, Integer>>> clusterAssignments) {

List<Pair<Double, Integer>> outlierScoreIndexList = clusterAssignments.get(OUTLIER_NOISE_CLUSTER_LABEL);
if ((outlierScoreIndexList == null) || outlierScoreIndexList.isEmpty()) {
return MAX_OUTLIER_SCORE;
}

double upperOutlierScoreBound = Double.NEGATIVE_INFINITY;
for (Pair<Double, Integer> outlierScoreIndex : outlierScoreIndexList) {
if (outlierScoreIndex.getA() > upperOutlierScoreBound) {
upperOutlierScoreBound = outlierScoreIndex.getA();
}
}
return upperOutlierScoreBound;
}

@Override
public String toString() {
return "HdbscanTrainer(minClusterSize=" + minClusterSize + ",distanceType=" + distType + ",k=" + k + ",numThreads=" + numThreads + ")";
@@ -771,11 +811,13 @@ final static class ClusterExemplar implements Serializable {
private final Integer label;
private final Double outlierScore;
private final SGDVector features;
private final Double maxDistToEdge;

ClusterExemplar(Integer label, Double outlierScore, SGDVector features) {
ClusterExemplar(Integer label, Double outlierScore, SGDVector features, Double maxDistToEdge) {
this.label = label;
this.outlierScore = outlierScore;
this.features = features;
this.maxDistToEdge = maxDistToEdge;
}

Integer getLabel() {
@@ -789,6 +831,15 @@ Double getOutlierScore() {
SGDVector getFeatures() {
return features;
}

Double getMaxDistToEdge() {
if (maxDistToEdge != null) {
return maxDistToEdge;
}
else {
return Double.NEGATIVE_INFINITY;
}
}
}

}
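
Summarizing the HdbscanTrainer changes above: each cluster exemplar now records the largest distance from itself to the remaining members of its cluster, and the trainer computes a single noise outlier score (the largest outlier score among the training set's noise points, or the 0.9999 ceiling when there are none). A condensed sketch of those two computations on plain arrays, with illustrative names rather than Tribuo's SGDVector/Pair types:

import java.util.List;

// Self-contained sketch (not Tribuo API): condensed versions of the two new training-time computations.
final class HdbscanTrainingSketch {

    static final double MAX_OUTLIER_SCORE = 0.9999;

    // Largest distance from a chosen exemplar to the remaining members of its cluster.
    // Stored on the exemplar and used at prediction time as its "max distance to edge".
    static double maxDistToEdge(double[] exemplarFeatures, List<double[]> otherClusterMembers) {
        double maxInnerDist = Double.NEGATIVE_INFINITY;
        for (double[] member : otherClusterMembers) {
            double distance = l2Distance(exemplarFeatures, member);
            if (distance > maxInnerDist) {
                maxInnerDist = distance;
            }
        }
        return maxInnerDist;
    }

    // Outlier score reported for points later predicted as noise: the upper bound of the
    // training noise points' outlier scores, or a fixed ceiling when training produced no noise points.
    static double noisePointsOutlierScore(List<Double> noisePointOutlierScores) {
        if (noisePointOutlierScores == null || noisePointOutlierScores.isEmpty()) {
            return MAX_OUTLIER_SCORE;
        }
        double upperBound = Double.NEGATIVE_INFINITY;
        for (double score : noisePointOutlierScores) {
            if (score > upperBound) {
                upperBound = score;
            }
        }
        return upperBound;
    }

    private static double l2Distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }
}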
@@ -40,6 +40,9 @@
import org.tribuo.math.distance.DistanceType;
import org.tribuo.test.Helpers;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
@@ -54,6 +57,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;

/**
* Unit tests with small datasets for Hdbscan
@@ -172,6 +176,26 @@ public void testEndToEndPredictWithCSVData() {

assertArrayEquals(expectedLabelPredictions, actualLabelPredictions);
assertArrayEquals(expectedOutlierScorePredictions, actualOutlierScorePredictions);

CSVDataSource<ClusterID> nextCsvTestSource = new CSVDataSource<>(Paths.get("src/test/resources/basic-gaussians-predict-with-outliers.csv"),rowProcessor,false);
Dataset<ClusterID> nextTestSet = new MutableDataset<>(nextCsvTestSource);

predictions = model.predict(nextTestSet);

i = 0;
actualLabelPredictions = new int[nextTestSet.size()];
actualOutlierScorePredictions = new double[nextTestSet.size()];
for (Prediction<ClusterID> pred : predictions) {
actualLabelPredictions[i] = pred.getOutput().getID();
actualOutlierScorePredictions[i] = pred.getOutput().getScore();
i++;
}

int[] nextExpectedLabelPredictions = {5,0,3,0,4,0};
double[] nextExpectedOutlierScorePredictions = {0.04384108680937504,0.837375806784261,0.04922915472735656,0.837375806784261,0.02915273635987492,0.837375806784261};

assertArrayEquals(nextExpectedLabelPredictions, actualLabelPredictions);
assertArrayEquals(nextExpectedOutlierScorePredictions, actualOutlierScorePredictions);
}

public static void runBasicTrainPredict(HdbscanTrainer trainer) {
@@ -216,6 +240,52 @@ public void testBasicTrainPredict() {
runBasicTrainPredict(t);
}

@Test
public void deserializeHdbscanModelV42Test() {
String serializedModelFilename = "Hdbscan_minClSize7_L2_k7_nt1_v4.2.model";
String serializedModelPath = this.getClass().getClassLoader().getResource(serializedModelFilename).getPath();

HdbscanModel model = null;
try (ObjectInputStream oin = new ObjectInputStream(new FileInputStream(serializedModelPath))) {
Object data = oin.readObject();
model = (HdbscanModel) data;
if (!model.validate(ClusterID.class)) {
fail("This is not a Clustering model.");
}
} catch (IOException e) {
fail("There is a problem accessing the serialized model file " + serializedModelPath);
} catch (ClassNotFoundException e) {
fail("There is a problem deserializing the model file " + serializedModelPath);
}

ClusteringFactory clusteringFactory = new ClusteringFactory();
ResponseProcessor<ClusterID> emptyResponseProcessor = new EmptyResponseProcessor<>(clusteringFactory);
Map<String, FieldProcessor> regexMappingProcessors = new HashMap<>();
regexMappingProcessors.put("Feature1", new DoubleFieldProcessor("Feature1"));
regexMappingProcessors.put("Feature2", new DoubleFieldProcessor("Feature2"));
regexMappingProcessors.put("Feature3", new DoubleFieldProcessor("Feature3"));
RowProcessor<ClusterID> rowProcessor = new RowProcessor<>(emptyResponseProcessor,regexMappingProcessors);
CSVDataSource<ClusterID> csvTestSource = new CSVDataSource<>(Paths.get("src/test/resources/basic-gaussians-predict.csv"),rowProcessor,false);
Dataset<ClusterID> testSet = new MutableDataset<>(csvTestSource);

List<Prediction<ClusterID>> predictions = model.predict(testSet);

int i = 0;
int[] actualLabelPredictions = new int[testSet.size()];
double[] actualOutlierScorePredictions = new double[testSet.size()];
for (Prediction<ClusterID> pred : predictions) {
actualLabelPredictions[i] = pred.getOutput().getID();
actualOutlierScorePredictions[i] = pred.getOutput().getScore();
i++;
}

int[] expectedLabelPredictions = {5,3,5,5,3,5,4,4,5,3,3,3,3,4,4,5,4,5,5,4};
double[] expectedOutlierScorePredictions = {0.04384108680937504,0.04922915472735656,4.6591582469379667E-4,0.025225544503289288,0.04922915472735656,0.0,0.044397942146806146,0.044397942146806146,0.025225544503289288,0.0,0.04922915472735656,0.0,0.0,0.044397942146806146,0.02395925569434121,0.003121298369468062,0.02915273635987492,0.03422951971100352,0.0,0.02915273635987492};

assertArrayEquals(expectedLabelPredictions, actualLabelPredictions);
assertArrayEquals(expectedOutlierScorePredictions, actualOutlierScorePredictions);
}

public static void runEvaluation(HdbscanTrainer trainer) {
DataSource<ClusterID> gaussianSource = new GaussianClusterDataSource(1000, 1L);
TrainTestSplitter<ClusterID> splitter = new TrainTestSplitter<>(gaussianSource, 0.7f, 2L);
Binary file not shown.
@@ -0,0 +1,7 @@
Feature1,Feature2,Feature3
-2.3302356259487063,3.9431416146381046,1.0315528543744679
12.5,15.0,17.1
0.41679363204429154,8.247732287302664,9.810651956897404
-16.0,-13.3,14.4
1.2947698963877157,-1.0272570581099394,1.6991984313559259
-14.9,-13.9,-15.5
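
Taken together, these changes mean a point that falls outside every exemplar's cluster fringe is now reported as a noise point (cluster 0) with a meaningful outlier score instead of being forced into the nearest cluster. A rough end-to-end usage sketch follows; the demo class itself is hypothetical, and the HdbscanTrainer constructor arguments (minClusterSize, distType, k, numThreads) are assumed from its toString() shown above:

import org.tribuo.DataSource;
import org.tribuo.MutableDataset;
import org.tribuo.Prediction;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.example.GaussianClusterDataSource;
import org.tribuo.clustering.hdbscan.HdbscanModel;
import org.tribuo.clustering.hdbscan.HdbscanTrainer;
import org.tribuo.math.distance.DistanceType;

import java.util.Collections;
import java.util.List;

public final class HdbscanNoiseDemo {
    public static void main(String[] args) {
        // Train on the synthetic Gaussian mixture used by the tests above.
        DataSource<ClusterID> trainSource = new GaussianClusterDataSource(1000, 1L);
        MutableDataset<ClusterID> trainSet = new MutableDataset<>(trainSource);
        // Constructor arguments (minClusterSize, distType, k, numThreads) are an assumption.
        HdbscanTrainer trainer = new HdbscanTrainer(7, DistanceType.L2, 7, 1);
        HdbscanModel model = trainer.train(trainSet, Collections.emptyMap());

        // Predict on fresh draws; a point outside every exemplar's fringe now comes back as cluster 0.
        DataSource<ClusterID> testSource = new GaussianClusterDataSource(20, 2L);
        MutableDataset<ClusterID> testSet = new MutableDataset<>(testSource);
        List<Prediction<ClusterID>> predictions = model.predict(testSet);
        for (Prediction<ClusterID> pred : predictions) {
            int label = pred.getOutput().getID();       // 0 means predicted noise point
            double outlierScore = pred.getOutput().getScore();
            System.out.println("cluster=" + label + " outlierScore=" + outlierScore);
        }
    }
}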