Adds clustering & nearest neighbour protobuf serialization #276

Merged: 3 commits, Sep 22, 2022
Changes from all commits
@@ -165,6 +165,11 @@ public ConfiguredObjectProvenance getProvenance() {
return new ConfiguredObjectProvenanceImpl(this, "EnsembleCombiner");
}

@Override
public Class<Label> getTypeWitness() {
return Label.class;
}

/**
* Exports this voting combiner to ONNX.
* <p>
@@ -156,6 +156,11 @@ public ConfiguredObjectProvenance getProvenance() {
return new ConfiguredObjectProvenanceImpl(this,"EnsembleCombiner");
}

@Override
public Class<Label> getTypeWitness() {
return Label.class;
}

/**
* Exports this voting combiner to ONNX.
* <p>
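The two hunks above add `getTypeWitness()` overrides returning `Label.class`. The file names are not visible in this view; they are most likely the Label ensemble combiners (e.g. `VotingCombiner`), and the `@Override` implies the method is declared on `EnsembleCombiner`. A minimal sketch, under those assumptions, of how generic serialization code might use the witness to confirm a combiner's output type at runtime:

```java
import org.tribuo.classification.Label;
import org.tribuo.classification.ensemble.VotingCombiner; // assumed to be one of the modified combiners
import org.tribuo.ensemble.EnsembleCombiner;

public final class TypeWitnessSketch {
    public static void main(String[] args) {
        EnsembleCombiner<Label> combiner = new VotingCombiner();
        // The witness lets generic code (e.g. protobuf deserialization) check the
        // Output type a combiner produces without inspecting erased generics.
        Class<Label> witness = combiner.getTypeWitness();
        if (!witness.equals(Label.class)) {
            throw new IllegalStateException("Expected a Label combiner, found " + witness);
        }
    }
}
```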
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@

package org.tribuo.clustering.hdbscan;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
@@ -26,11 +28,16 @@
import org.tribuo.Prediction;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.hdbscan.HdbscanTrainer.Distance;
import org.tribuo.clustering.hdbscan.protos.ClusterExemplarProto;
import org.tribuo.clustering.hdbscan.protos.HdbscanModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

import java.io.IOException;
@@ -57,6 +64,11 @@
public final class HdbscanModel extends Model<ClusterID> {
private static final long serialVersionUID = 1L;

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

private final List<Integer> clusterLabels;

private final DenseVector outlierScoresVector;
@@ -76,13 +88,59 @@ public final class HdbscanModel extends Model<ClusterID> {
ImmutableOutputInfo<ClusterID> outputIDInfo, List<Integer> clusterLabels, DenseVector outlierScoresVector,
List<HdbscanTrainer.ClusterExemplar> clusterExemplars, DistanceType distType, double noisePointsOutlierScore) {
super(name,description,featureIDMap,outputIDInfo,false);
- this.clusterLabels = clusterLabels;
+ this.clusterLabels = Collections.unmodifiableList(clusterLabels);
this.outlierScoresVector = outlierScoresVector;
- this.clusterExemplars = clusterExemplars;
+ this.clusterExemplars = Collections.unmodifiableList(clusterExemplars);
this.distType = distType;
this.noisePointsOutlierScore = noisePointsOutlierScore;
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
public static HdbscanModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
HdbscanModelProto proto = message.unpack(HdbscanModelProto.class);

ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
if (!carrier.outputDomain().getOutput(0).getClass().equals(ClusterID.class)) {
throw new IllegalStateException("Invalid protobuf, output domain is not a clustering domain, found " + carrier.outputDomain().getClass());
}
@SuppressWarnings("unchecked") // guarded by getClass
ImmutableOutputInfo<ClusterID> outputDomain = (ImmutableOutputInfo<ClusterID>) carrier.outputDomain();

Tensor outlierScoresTensor = Tensor.deserialize(proto.getOutlierScoresVector());
if (!(outlierScoresTensor instanceof DenseVector)) {
throw new IllegalStateException("Invalid protobuf, outlier scores must be a dense vector, found " + outlierScoresTensor.getClass());
}
DenseVector outlierScoresVector = (DenseVector) outlierScoresTensor;

List<Integer> clusterLabels = new ArrayList<>(proto.getClusterLabelsList());
for (Integer i : clusterLabels) {
if (outputDomain.getOutput(i) == null && i != -1) {
throw new IllegalStateException("Invalid protobuf, found cluster id " + i + " which is not present in the domain " + outputDomain);
}
}
if (clusterLabels.size() != outlierScoresVector.size()) {
throw new IllegalStateException("Invalid protobuf, expected the same number of outlier scores as cluster labels, found " +outlierScoresVector.size() + " scores and " + clusterLabels.size() + " labels");
}

List<HdbscanTrainer.ClusterExemplar> exemplars = new ArrayList<>();
for (ClusterExemplarProto p : proto.getClusterExemplarsList()) {
exemplars.add(HdbscanTrainer.ClusterExemplar.deserialize(p));
}

DistanceType distType = DistanceType.valueOf(proto.getDistType());

return new HdbscanModel(carrier.name(), carrier.provenance(), carrier.featureDomain(),
outputDomain, clusterLabels, outlierScoresVector, exemplars, distType, proto.getNoisePointsOutlierScore());
}

/**
* Returns the cluster labels for the training data.
* <p>
@@ -202,10 +260,32 @@ public Optional<Excuse<ClusterID>> getExcuse(Example<ClusterID> example) {
return Optional.empty();
}

@Override
public ModelProto serialize() {
ModelDataCarrier<ClusterID> carrier = createDataCarrier();

HdbscanModelProto.Builder modelBuilder = HdbscanModelProto.newBuilder();
modelBuilder.setMetadata(carrier.serialize());
modelBuilder.addAllClusterLabels(clusterLabels);
modelBuilder.setOutlierScoresVector(outlierScoresVector.serialize());
modelBuilder.setDistType(distType.name());
for (HdbscanTrainer.ClusterExemplar e : clusterExemplars) {
modelBuilder.addClusterExemplars(e.serialize());
}
modelBuilder.setNoisePointsOutlierScore(noisePointsOutlierScore);

ModelProto.Builder builder = ModelProto.newBuilder();
builder.setSerializedData(Any.pack(modelBuilder.build()));
builder.setClassName(HdbscanModel.class.getName());
builder.setVersion(CURRENT_VERSION);

return builder.build();
}

@Override
protected HdbscanModel copy(String newName, ModelProvenance newProvenance) {
DenseVector copyOutlierScoresVector = outlierScoresVector.copy();
- List<Integer> copyClusterLabels = Collections.unmodifiableList(clusterLabels);
+ List<Integer> copyClusterLabels = new ArrayList<>(clusterLabels);
List<HdbscanTrainer.ClusterExemplar> copyExemplars = new ArrayList<>(clusterExemplars);
return new HdbscanModel(newName, newProvenance, featureIDMap, outputIDInfo, copyClusterLabels,
copyOutlierScoresVector, copyExemplars, distType, noisePointsOutlierScore);
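A minimal sketch of the round trip the new `serialize()` / `deserializeFromProto(...)` pair enables. It relies only on the methods visible in the hunk above and on the protobuf getters implied by the builder calls (`getVersion`, `getClassName`, `getSerializedData`); the trained model is assumed to come from an existing `HdbscanTrainer` run and is not constructed here:

```java
import com.google.protobuf.InvalidProtocolBufferException;
import org.tribuo.clustering.hdbscan.HdbscanModel;
import org.tribuo.protos.core.ModelProto;

public final class HdbscanProtoRoundTrip {
    // 'model' is assumed to be produced elsewhere, e.g. by an HdbscanTrainer.
    static HdbscanModel roundTrip(HdbscanModel model) throws InvalidProtocolBufferException {
        // serialize() packs the model data into an Any inside the ModelProto envelope.
        ModelProto proto = model.serialize();
        // The deserialization factory receives the version, class name and payload
        // that serialize() wrote into the envelope.
        return HdbscanModel.deserializeFromProto(proto.getVersion(),
                proto.getClassName(),
                proto.getSerializedData());
    }
}
```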
@@ -27,10 +27,12 @@
import org.tribuo.Trainer;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ImmutableClusteringInfo;
import org.tribuo.clustering.hdbscan.protos.ClusterExemplarProto;
import org.tribuo.math.distance.DistanceType;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.neighbour.NeighboursQuery;
import org.tribuo.math.neighbour.NeighboursQueryFactory;
import org.tribuo.math.neighbour.NeighboursQueryFactoryType;
@@ -924,6 +926,26 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(label, outlierScore, features, maxDistToEdge);
}

ClusterExemplarProto serialize() {
ClusterExemplarProto.Builder builder = ClusterExemplarProto.newBuilder();

builder.setLabel(label);
builder.setOutlierScore(outlierScore);
builder.setFeatures(features.serialize());
builder.setMaxDistToEdge(maxDistToEdge);

return builder.build();
}

static ClusterExemplar deserialize(ClusterExemplarProto proto) {
Tensor tensor = Tensor.deserialize(proto.getFeatures());
if (!(tensor instanceof SGDVector)) {
throw new IllegalStateException("Invalid protobuf, features must be an SGDVector, found " + tensor.getClass());
}
SGDVector vector = (SGDVector) tensor;
return new ClusterExemplar(proto.getLabel(), proto.getOutlierScore(), vector, proto.getMaxDistToEdge());
}
}

}
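The exemplar serialization added above is package-private. A minimal sketch of how it might be exercised from a test in `org.tribuo.clustering.hdbscan`, assuming the `ClusterExemplar` constructor seen in `deserialize` is accessible within the package (the constructor arguments and sample values here are hypothetical):

```java
package org.tribuo.clustering.hdbscan;

import org.tribuo.clustering.hdbscan.protos.ClusterExemplarProto;
import org.tribuo.math.la.DenseVector;

final class ClusterExemplarRoundTripSketch {
    public static void main(String[] args) {
        // Hypothetical exemplar; the argument order mirrors the deserialize call above.
        HdbscanTrainer.ClusterExemplar exemplar = new HdbscanTrainer.ClusterExemplar(
                1, 0.25, DenseVector.createDenseVector(new double[]{1.0, 2.0, 3.0}), 0.5);
        ClusterExemplarProto proto = exemplar.serialize();
        HdbscanTrainer.ClusterExemplar restored = HdbscanTrainer.ClusterExemplar.deserialize(proto);
        // equals/hashCode are defined on ClusterExemplar, so the round trip can be checked directly.
        if (!exemplar.equals(restored)) {
            throw new AssertionError("Round trip changed the exemplar");
        }
    }
}
```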