Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exposes views on the HDBSCAN cluster exemplars #229

Merged
merged 4 commits into from
Apr 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
Expand All @@ -29,6 +30,7 @@
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;

import java.io.IOException;
Expand Down Expand Up @@ -107,6 +109,43 @@ public List<Double> getOutlierScores() {
return outlierScores;
}

/**
 * Returns copies of the cluster exemplars held by this model.
 * <p>
 * Each exemplar is duplicated via {@code copy()}, so changes made to the
 * returned list or its elements do not touch the exemplars stored in the model.
 * @return A new list containing a copy of every cluster exemplar.
 */
public List<HdbscanTrainer.ClusterExemplar> getClusterExemplars() {
    List<HdbscanTrainer.ClusterExemplar> copies = new ArrayList<>(clusterExemplars.size());
    clusterExemplars.forEach(exemplar -> copies.add(exemplar.copy()));
    return copies;
}

/**
 * Returns each cluster exemplar as a pair of its cluster label and its named features.
 * <p>
 * In many cases this should be used in preference to {@link #getClusterExemplars()}
 * as it performs the mapping from Tribuo's internal feature ids to
 * the externally visible feature names.
 * @return One pair per cluster exemplar, holding the exemplar's label and
 * the list of its features with their names resolved.
 */
public List<Pair<Integer,List<Feature>>> getClusters() {
    List<Pair<Integer,List<Feature>>> clusters = new ArrayList<>(clusterExemplars.size());

    for (HdbscanTrainer.ClusterExemplar exemplar : clusterExemplars) {
        SGDVector vector = exemplar.getFeatures();
        List<Feature> named = new ArrayList<>(vector.numActiveElements());

        for (VectorTuple tuple : vector) {
            // Look up the feature name for the internal id carried by the tuple.
            named.add(new Feature(featureIDMap.get(tuple.index).getName(), tuple.value));
        }

        clusters.add(new Pair<>(exemplar.getLabel(), named));
    }

    return clusters;
}

@Override
public Prediction<ClusterID> predict(Example<ClusterID> example) {
SGDVector vector;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.TreeMap;
import java.util.TreeSet;
Expand Down Expand Up @@ -805,7 +806,7 @@ public TrainerProvenance getProvenance() {
/**
* A cluster exemplar, with attributes for the point's label, outlier score and its features.
*/
final static class ClusterExemplar implements Serializable {
public final static class ClusterExemplar implements Serializable {
private static final long serialVersionUID = 1L;

private final Integer label;
Expand All @@ -820,26 +821,72 @@ final static class ClusterExemplar implements Serializable {
this.maxDistToEdge = maxDistToEdge;
}

Integer getLabel() {
/**
 * Returns the label stored in this exemplar.
 * @return The exemplar's label.
 */
public Integer getLabel() {
    return this.label;
}

Double getOutlierScore() {
/**
 * Returns the outlier score stored in this exemplar.
 * @return The exemplar's outlier score.
 */
public Double getOutlierScore() {
    return this.outlierScore;
}

SGDVector getFeatures() {
/**
 * Returns the feature vector stored in this exemplar.
 * <p>
 * The vector is returned directly, not as a copy.
 * @return The exemplar's feature vector.
 */
public SGDVector getFeatures() {
    return this.features;
}

Double getMaxDistToEdge() {
/**
 * Returns the maximum distance from this exemplar to the edge of its cluster.
 * <p>
 * For models trained in 4.2 this will return {@link Double#NEGATIVE_INFINITY} as that information is
 * not produced by 4.2 models; in that case the stored value is null.
 * @return The distance to the edge of the cluster, or negative infinity when it is unset.
 */
public Double getMaxDistToEdge() {
    // Both arms are boxed Doubles so the stored value is returned without unboxing.
    return maxDistToEdge == null ? Double.valueOf(Double.NEGATIVE_INFINITY) : maxDistToEdge;
}

/**
 * Creates a copy of this cluster exemplar.
 * <p>
 * The feature vector is duplicated via its own {@code copy()} method; the label,
 * outlier score and the raw (possibly unset) max distance are passed through as-is.
 * @return A deep copy of this cluster exemplar.
 */
public ClusterExemplar copy() {
    SGDVector featuresCopy = this.features.copy();
    return new ClusterExemplar(this.label, this.outlierScore, featuresCopy, this.maxDistToEdge);
}

/**
 * Formats this exemplar as {@code ClusterExemplar(label=..,outlierScore=..,vector=..,maxDistToEdge=..)}.
 * <p>
 * An unset max distance is printed as negative infinity.
 * @return The string form of this exemplar.
 */
@Override
public String toString() {
    double dist;
    if (maxDistToEdge == null) {
        dist = Double.NEGATIVE_INFINITY;
    } else {
        dist = maxDistToEdge;
    }
    StringBuilder sb = new StringBuilder("ClusterExemplar(label=");
    sb.append(label);
    sb.append(",outlierScore=").append(outlierScore);
    sb.append(",vector=").append(features);
    sb.append(",maxDistToEdge=").append(dist);
    sb.append(")");
    return sb.toString();
}

/**
 * Compares this exemplar with another object for equality.
 * <p>
 * Two exemplars are equal when they are of exactly the same class and their label,
 * outlier score, feature vector and max distance to the cluster edge are all equal.
 * Every field is compared null-safely, so a null field yields a comparison result
 * rather than a {@link NullPointerException}, matching the null-tolerant {@link #hashCode()}.
 * @param o The object to compare against.
 * @return True if {@code o} is a cluster exemplar equal to this one.
 */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    if (o == null || getClass() != o.getClass()) {
        return false;
    }
    ClusterExemplar that = (ClusterExemplar) o;
    return Objects.equals(label, that.label)
        && Objects.equals(outlierScore, that.outlierScore)
        && Objects.equals(features, that.features)
        && Objects.equals(maxDistToEdge, that.maxDistToEdge);
}

/**
 * Computes a hash code from the label, outlier score, feature vector and max distance to edge.
 * <p>
 * Null fields contribute zero, and the fields are folded in declaration order with
 * the usual 31 multiplier, which yields the same value as {@code Objects.hash} over them.
 * @return The hash code of this exemplar.
 */
@Override
public int hashCode() {
    int result = 1;
    result = 31 * result + Objects.hashCode(label);
    result = 31 * result + Objects.hashCode(outlierScore);
    result = 31 * result + Objects.hashCode(features);
    result = 31 * result + Objects.hashCode(maxDistToEdge);
    return result;
}
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2022 Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,6 +21,7 @@
import org.junit.jupiter.api.Test;
import org.tribuo.DataSource;
import org.tribuo.Dataset;
import org.tribuo.Feature;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Prediction;
Expand All @@ -37,7 +38,10 @@
import org.tribuo.data.columnar.processors.response.EmptyResponseProcessor;
import org.tribuo.data.csv.CSVDataSource;
import org.tribuo.evaluation.TrainTestSplitter;
import org.tribuo.impl.ArrayExample;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.test.Helpers;

import java.io.FileInputStream;
Expand All @@ -47,8 +51,10 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
Expand All @@ -57,6 +63,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

/**
Expand Down Expand Up @@ -205,8 +212,35 @@ public static void runBasicTrainPredict(HdbscanTrainer trainer) {
Dataset<ClusterID> testData = new MutableDataset<>(splitter.getTest());

HdbscanModel model = trainer.train(trainData);

for (HdbscanTrainer.ClusterExemplar e : model.getClusterExemplars()) {
assertTrue(e.getMaxDistToEdge() > 0.0);
}

List<Integer> clusterLabels = model.getClusterLabels();
List<Double> outlierScores = model.getOutlierScores();
List<Pair<Integer,List<Feature>>> exemplarLists = model.getClusters();
List<HdbscanTrainer.ClusterExemplar> exemplars = model.getClusterExemplars();

assertEquals(exemplars.size(), exemplarLists.size());

// Check there's at least one exemplar per label
Set<Integer> exemplarLabels = exemplarLists.stream().map(Pair::getA).collect(Collectors.toSet());
Set<Integer> clusterLabelSet = new HashSet<>(clusterLabels);
// Remove the noise label
clusterLabelSet.remove(Integer.valueOf(0));
assertEquals(exemplarLabels,clusterLabelSet);

for (int i = 0; i < exemplars.size(); i++) {
HdbscanTrainer.ClusterExemplar e = exemplars.get(i);
Pair<Integer, List<Feature>> p = exemplarLists.get(i);
assertEquals(model.getFeatureIDMap().size(), e.getFeatures().size());
assertEquals(p.getB().size(), e.getFeatures().size());
SGDVector otherFeatures = DenseVector.createDenseVector(
new ArrayExample<>(trainData.getOutputFactory().getUnknownOutput(), p.getB()),
model.getFeatureIDMap(), false);
assertEquals(otherFeatures, e.getFeatures());
}

int [] expectedIntClusterLabels = {4,3,4,5,3,5,3,4,3,4,5,5,3,4,4,0,3,4,0,5,5,3,3,4,4,4,4,4,4,4,4,4,4,0,4,5,3,5,3,4,3,4,4,3,0,5,0,4,4,4,4,4,5,4,3,4,4,4,4,4,5,3,4,3,5,3,4,5,3,4,0,5,4,4,4,4,4,5,4,4,4,4,4,5,3,4,4,3,4,3,5,5,0,5,4,4,3,5,5,4,5,5,3,5,4,4,3,5,4,5,5,5,4,4,5,5,3,5,4,4,3,5,5,3,5,4,4,5,5,5,3,5,4,5,3,4,3,5,4,4,3,3,5,4,4,5,5,4,3,4,5,4,5,4,3,3,3,4,5,4,5,5,3,4,3,3,4,5,3,5,5,5,5,5,4,4,3,4,5,5,4,4,3,4,3,4,5,4,4,5,4,3,3,0,3,5,5,3,3,3,4,3,3,5,5,5,5,3,5,5,3,5,3,4,5,3,3,3,4,4,3,3,3,5,3,4,5,3,5,5,5,3,5,3,5,4,5,4,4,5,5,5,3,5,4,5,5,4,4,4,5,4,5,4,3,3,4,5,4,4,3,3,3,4,5,4,4,4,4,5,4,4,4,5,3,5,4,5,3,5,3,5,4,4,0,4,4,5,3,4,5,5,0,5,4,5,3,4,3,5,5,4,5,5,5,5,5,5,3,5,4,3,3,5,3,4,5,4,3,5,4,3,3,3,5,4,5,4,5,5,4,3,5,4,5,4,5,4,3,4,5,4,4,5,5,5,3,4,5,4,0,3,5,3,4,3,3,5,5,5,4,4,3,3,4,3,5,3,3,4,3,5,3,4,5,4,4,3,4,4,3,3,5,4,4,5,3,5,3,3,4,5,3,4,5,5,4,4,4,5,5,5,5,3,3,4,4,4,4,4,3,5,4,3,4,4,5,3,5,3,4,5,4,4,5,3,4,4,4,5,5,4,5,0,4,5,3,4,5,4,4,4,5,4,4,4,0,3,4,5,5,4,4,3,3,4,3,3,4,5,5,4,3,5,4,4,4,4,5,4,4,3,4,5,5,4,3,4,5,4,3,5,5,5,3,4,4,4,4,4,4,5,3,3,3,5,5,4,5,3,5,3,5,4,5,3,4,5,4,3,5,4,4,5,5,0,3,3,5,5,3,0,5,5,5,5,3,4,5,4,3,3,4,5,4,4,0,5,3,4,4,4,4,5,5,5,3,5,4,3,3,5,3,4,3,5,3,3,4,3,5,4,3,4,3,0,4,5,5,5,3,4,3,5,5,4,5,4,4,4,5,4,3,4,3,4,5,3,5,4,5,3,0,4,0,4,3,3,4,3,0,3,3,3,3,4,4,5,3,3,5,4,4,4,5,5,5,3,3,4,4,3,4,5,3,4,4,5,3,4,4,4,3,4,4,4,5,4,4,5,5,5,4,4,4,5,5,5,5,4,3,4,3,3,3,4,4,5,4,5,4,4,4,4,4,5,4,5,5,5,4,3,5,3,5,4,5,4,4,5,0,5,3,4,5,4,4,5,3,4,4,3,5,4,4,4,5,3,3,4,4,5,5,5,3,4,3,4,5,5,4,4,3,3,4,4,5,5,5,3,4,3,4,4,4,5,5,5,0,4,5,5,3,3,4,5,4,3,3,4,3,4,5,4,3,4,5,5,3,3,4,4,3,3,5,4,5,3,4,5,4,3,3,4,5,5,5,3,3,4,4,5,5,5,4,5,5,5,4,4,4,5,4,5,5,3,3,4,4,3,5,5,3,3,4,4,5,3,3,3};
List<Integer> expectedClusterLabels = Arrays.stream(expectedIntClusterLabels).boxed().collect(Collectors.toList());
Expand Down Expand Up @@ -258,6 +292,11 @@ public void deserializeHdbscanModelV42Test() {
fail("There is a problem deserializing the model file " + serializedModelPath);
}

// In v4.2 models this value is unset and defaults to negative infinity.
for (HdbscanTrainer.ClusterExemplar e : model.getClusterExemplars()) {
assertEquals(Double.NEGATIVE_INFINITY, e.getMaxDistToEdge());
}

ClusteringFactory clusteringFactory = new ClusteringFactory();
ResponseProcessor<ClusterID> emptyResponseProcessor = new EmptyResponseProcessor<>(clusteringFactory);
Map<String, FieldProcessor> regexMappingProcessors = new HashMap<>();
Expand Down