From cdb67ed92da5593e044a99b2962e93e4a7c2d754 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Thu, 29 Sep 2022 22:37:02 -0400 Subject: [PATCH 1/2] Replacing DistanceType with a Distance interface. --- .../clustering/hdbscan/HdbscanModel.java | 24 +- .../clustering/hdbscan/HdbscanOptions.java | 2 +- .../clustering/hdbscan/HdbscanTrainer.java | 49 +- .../hdbscan/protos/HdbscanModelProto.java | 246 +++--- .../protos/HdbscanModelProtoOrBuilder.java | 17 +- .../protos/TribuoClusteringHdbscan.java | 16 +- .../protos/tribuo-clustering-hdbscan.proto | 2 +- .../clustering/hdbscan/TestHdbscan.java | 8 +- .../tribuo/clustering/kmeans/KMeansModel.java | 22 +- .../clustering/kmeans/KMeansOptions.java | 2 +- .../clustering/kmeans/KMeansTrainer.java | 38 +- .../kmeans/protos/KMeansModelProto.java | 246 +++--- .../protos/KMeansModelProtoOrBuilder.java | 17 +- .../kmeans/protos/TribuoClusteringKmeans.java | 10 +- .../protos/tribuo-clustering-kmeans.proto | 2 +- .../tribuo/clustering/kmeans/TestKMeans.java | 10 +- .../common/nearest/KNNClassifierOptions.java | 2 +- .../org/tribuo/common/nearest/KNNModel.java | 36 +- .../org/tribuo/common/nearest/KNNTrainer.java | 28 +- .../common/nearest/protos/KNNModelProto.java | 246 +++--- .../protos/KNNModelProtoOrBuilder.java | 17 +- .../nearest/protos/TribuoCommonKnn.java | 18 +- .../resources/protos/tribuo-common-knn.proto | 2 +- .../org/tribuo/common/nearest/TestKNN.java | 10 +- .../tribuo/math/distance/CosineDistance.java | 90 ++ .../org/tribuo/math/distance/Distance.java | 43 + .../tribuo/math/distance/DistanceType.java | 41 +- .../org/tribuo/math/distance/L1Distance.java | 90 ++ .../org/tribuo/math/distance/L2Distance.java | 90 ++ .../tribuo/math/distance/package-info.java | 3 +- .../neighbour/NeighboursQueryFactory.java | 6 +- .../neighbour/NeighboursQueryFactoryType.java | 10 +- .../bruteforce/NeighboursBruteForce.java | 14 +- .../NeighboursBruteForceFactory.java | 21 +- .../math/neighbour/kdtree/DimensionNode.java | 12 +- .../tribuo/math/neighbour/kdtree/KDTree.java | 20 +- .../math/neighbour/kdtree/KDTreeFactory.java | 23 +- .../math/protos/BruteForceFactoryProto.java | 246 +++--- .../BruteForceFactoryProtoOrBuilder.java | 17 +- .../org/tribuo/math/protos/DistanceProto.java | 817 ++++++++++++++++++ .../math/protos/DistanceProtoOrBuilder.java | 42 + .../math/protos/KDTreeFactoryProto.java | 246 +++--- .../protos/KDTreeFactoryProtoOrBuilder.java | 17 +- .../org/tribuo/math/protos/TribuoMath.java | 56 +- .../tribuo/math/protos/TribuoMathImpl.java | 15 +- .../resources/protos/tribuo-math-impl.proto | 4 +- .../main/resources/protos/tribuo-math.proto | 9 + .../org/tribuo/math/neighbour/TestKDTree.java | 22 +- .../neighbour/TestNeighborsBruteForce.java | 22 +- 49 files changed, 2252 insertions(+), 794 deletions(-) create mode 100644 Math/src/main/java/org/tribuo/math/distance/CosineDistance.java create mode 100644 Math/src/main/java/org/tribuo/math/distance/Distance.java create mode 100644 Math/src/main/java/org/tribuo/math/distance/L1Distance.java create mode 100644 Math/src/main/java/org/tribuo/math/distance/L2Distance.java create mode 100644 Math/src/main/java/org/tribuo/math/protos/DistanceProto.java create mode 100644 Math/src/main/java/org/tribuo/math/protos/DistanceProtoOrBuilder.java diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanModel.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanModel.java index d3f9c90a9..c8f0965f6 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanModel.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanModel.java @@ -31,12 +31,12 @@ import org.tribuo.clustering.hdbscan.protos.ClusterExemplarProto; import org.tribuo.clustering.hdbscan.protos.HdbscanModelProto; import org.tribuo.impl.ModelDataCarrier; -import org.tribuo.math.distance.DistanceType; import org.tribuo.math.la.DenseVector; import org.tribuo.math.la.SGDVector; import org.tribuo.math.la.SparseVector; import org.tribuo.math.la.Tensor; import org.tribuo.math.la.VectorTuple; +import org.tribuo.protos.ProtoUtil; import org.tribuo.protos.core.ModelProto; import org.tribuo.provenance.ModelProvenance; @@ -78,7 +78,7 @@ public final class HdbscanModel extends Model { // This is not final to support deserialization of older models. It will be final in a future version which doesn't // maintain serialization compatibility with 4.X. - private DistanceType distType; + private org.tribuo.math.distance.Distance dist; private final List clusterExemplars; @@ -86,12 +86,12 @@ public final class HdbscanModel extends Model { HdbscanModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo outputIDInfo, List clusterLabels, DenseVector outlierScoresVector, - List clusterExemplars, DistanceType distType, double noisePointsOutlierScore) { + List clusterExemplars, org.tribuo.math.distance.Distance dist, double noisePointsOutlierScore) { super(name,description,featureIDMap,outputIDInfo,false); this.clusterLabels = Collections.unmodifiableList(clusterLabels); this.outlierScoresVector = outlierScoresVector; this.clusterExemplars = Collections.unmodifiableList(clusterExemplars); - this.distType = distType; + this.dist = dist; this.noisePointsOutlierScore = noisePointsOutlierScore; } @@ -135,10 +135,10 @@ public static HdbscanModel deserializeFromProto(int version, String className, A exemplars.add(HdbscanTrainer.ClusterExemplar.deserialize(p)); } - DistanceType distType = DistanceType.valueOf(proto.getDistType()); + org.tribuo.math.distance.Distance dist = ProtoUtil.deserialize(proto.getDistance()); return new HdbscanModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), - outputDomain, clusterLabels, outlierScoresVector, exemplars, distType, proto.getNoisePointsOutlierScore()); + outputDomain, clusterLabels, outlierScoresVector, exemplars, dist, proto.getNoisePointsOutlierScore()); } /** @@ -222,7 +222,7 @@ public Prediction predict(Example example) { if (Double.compare(noisePointsOutlierScore, 0) > 0) { // This will be true from models > 4.2 boolean isNoisePoint = true; for (HdbscanTrainer.ClusterExemplar clusterExemplar : clusterExemplars) { - double distance = DistanceType.getDistance(clusterExemplar.getFeatures(), vector, distType); + double distance = dist.computeDistance(clusterExemplar.getFeatures(), vector); if (isNoisePoint && distance <= clusterExemplar.getMaxDistToEdge()) { isNoisePoint = false; } @@ -239,7 +239,7 @@ public Prediction predict(Example example) { } else { for (HdbscanTrainer.ClusterExemplar clusterExemplar : clusterExemplars) { - double distance = DistanceType.getDistance(clusterExemplar.getFeatures(), vector, distType); + double distance = dist.computeDistance(clusterExemplar.getFeatures(), vector); if (distance < minDistance) { minDistance = distance; clusterLabel = clusterExemplar.getLabel(); @@ -268,7 +268,7 @@ public ModelProto serialize() { modelBuilder.setMetadata(carrier.serialize()); modelBuilder.addAllClusterLabels(clusterLabels); modelBuilder.setOutlierScoresVector(outlierScoresVector.serialize()); - modelBuilder.setDistType(distType.name()); + modelBuilder.setDistance(dist.serialize()); for (HdbscanTrainer.ClusterExemplar e : clusterExemplars) { modelBuilder.addClusterExemplars(e.serialize()); } @@ -288,13 +288,13 @@ protected HdbscanModel copy(String newName, ModelProvenance newProvenance) { List copyClusterLabels = new ArrayList<>(clusterLabels); List copyExemplars = new ArrayList<>(clusterExemplars); return new HdbscanModel(newName, newProvenance, featureIDMap, outputIDInfo, copyClusterLabels, - copyOutlierScoresVector, copyExemplars, distType, noisePointsOutlierScore); + copyOutlierScoresVector, copyExemplars, dist, noisePointsOutlierScore); } private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { in.defaultReadObject(); - if (distType == null) { - distType = distanceType.getDistanceType(); + if (dist == null) { + dist = distanceType.getDistanceType().getDistance(); } } } diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanOptions.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanOptions.java index 604165a4f..bac527ff1 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanOptions.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanOptions.java @@ -71,6 +71,6 @@ public String getOptionsDescription() { */ public HdbscanTrainer getTrainer() { logger.info("Configuring Hdbscan Trainer"); - return new HdbscanTrainer(minClusterSize, distType, k, numThreads, nqFactoryType); + return new HdbscanTrainer(minClusterSize, distType.getDistance(), k, numThreads, nqFactoryType); } } diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java index 87f408352..4fa39efa0 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/HdbscanTrainer.java @@ -135,7 +135,7 @@ public DistanceType getDistanceType() { private Distance distanceType; @Config(description = "The distance function to use.") - private DistanceType distType; + private org.tribuo.math.distance.Distance dist; @Config(mandatory = true, description = "The number of nearest-neighbors to use in the initial density approximation. " + "This includes the point itself.") @@ -154,19 +154,18 @@ public DistanceType getDistanceType() { /** * for olcut. */ - private HdbscanTrainer() { - } + private HdbscanTrainer() {} /** * Constructs an HDBSCAN* trainer with only the minClusterSize parameter. * * @param minClusterSize The minimum number of points required to form a cluster. - * {@link #distType} defaults to {@link DistanceType#L2}, {@link #k} defaults to {@link #minClusterSize}, + * {@link #dist} defaults to {@link DistanceType#L2}, {@link #k} defaults to {@link #minClusterSize}, * {@link #numThreads} defaults to 1 and {@link #neighboursQueryFactory} defaults to * {@link NeighboursBruteForceFactory}. */ public HdbscanTrainer(int minClusterSize) { - this(minClusterSize, DistanceType.L2, minClusterSize, 1, NeighboursQueryFactoryType.BRUTE_FORCE); + this(minClusterSize, DistanceType.L2.getDistance(), minClusterSize, 1, NeighboursQueryFactoryType.BRUTE_FORCE); } /** @@ -182,24 +181,24 @@ public HdbscanTrainer(int minClusterSize) { */ @Deprecated public HdbscanTrainer(int minClusterSize, Distance distanceType, int k, int numThreads) { - this(minClusterSize, distanceType.getDistanceType(), k, numThreads, NeighboursQueryFactoryType.BRUTE_FORCE); + this(minClusterSize, distanceType.getDistanceType().getDistance(), k, numThreads, NeighboursQueryFactoryType.BRUTE_FORCE); } /** * Constructs an HDBSCAN* trainer using the supplied parameters. * * @param minClusterSize The minimum number of points required to form a cluster. - * @param distType The distance function. + * @param dist The distance function. * @param k The number of nearest-neighbors to use in the initial density approximation. * @param numThreads The number of threads. * @param nqFactoryType The nearest neighbour query implementation factory to use. */ - public HdbscanTrainer(int minClusterSize, DistanceType distType, int k, int numThreads, NeighboursQueryFactoryType nqFactoryType) { + public HdbscanTrainer(int minClusterSize, org.tribuo.math.distance.Distance dist, int k, int numThreads, NeighboursQueryFactoryType nqFactoryType) { this.minClusterSize = minClusterSize; - this.distType = distType; + this.dist = dist; this.k = k; this.numThreads = numThreads; - this.neighboursQueryFactory = NeighboursQueryFactoryType.getNeighboursQueryFactory(nqFactoryType, distType, numThreads); + this.neighboursQueryFactory = NeighboursQueryFactoryType.getNeighboursQueryFactory(nqFactoryType, dist, numThreads); } /** @@ -211,7 +210,7 @@ public HdbscanTrainer(int minClusterSize, DistanceType distType, int k, int numT */ public HdbscanTrainer(int minClusterSize, int k, NeighboursQueryFactory neighboursQueryFactory) { this.minClusterSize = minClusterSize; - this.distType = neighboursQueryFactory.getDistanceType(); + this.dist = neighboursQueryFactory.getDistance(); this.k = k; this.neighboursQueryFactory = neighboursQueryFactory; } @@ -222,19 +221,19 @@ public HdbscanTrainer(int minClusterSize, int k, NeighboursQueryFactory neighbou @Override public synchronized void postConfig() { if (this.distanceType != null) { - if (this.distType != null) { + if (this.dist != null) { throw new PropertyException("distType", "Both distType and distanceType must not both be set."); } else { - this.distType = this.distanceType.getDistanceType(); + this.dist = this.distanceType.getDistanceType().getDistance(); this.distanceType = null; } } if (neighboursQueryFactory == null) { int numberThreads = (this.numThreads <= 0) ? 1 : this.numThreads; - this.neighboursQueryFactory = new NeighboursBruteForceFactory(distType, numberThreads); + this.neighboursQueryFactory = new NeighboursBruteForceFactory(dist, numberThreads); } else { - if (!this.distType.equals(neighboursQueryFactory.getDistanceType())) { + if (!this.dist.equals(neighboursQueryFactory.getDistance())) { throw new PropertyException("neighboursQueryFactory", "distType and its field on the " + "NeighboursQueryFactory must be equal."); } @@ -264,7 +263,7 @@ public HdbscanModel train(Dataset examples, Map r } DenseVector coreDistances = calculateCoreDistances(data, k, neighboursQueryFactory); - ExtendedMinimumSpanningTree emst = constructEMST(data, coreDistances, distType); + ExtendedMinimumSpanningTree emst = constructEMST(data, coreDistances, dist); double[] pointNoiseLevels = new double[data.length]; // The levels at which each point becomes noise int[] pointLastClusters = new int[data.length]; // The last label of each point before becoming noise @@ -284,7 +283,7 @@ public HdbscanModel train(Dataset examples, Map r ImmutableOutputInfo outputMap = new ImmutableClusteringInfo(counts); // Compute the cluster exemplars. - List clusterExemplars = computeExemplars(data, clusterAssignments, distType); + List clusterExemplars = computeExemplars(data, clusterAssignments, dist); // Get the outlier score value for points that are predicted as noise points. double noisePointsOutlierScore = getNoisePointsOutlierScore(clusterAssignments); @@ -295,7 +294,7 @@ public HdbscanModel train(Dataset examples, Map r examples.getProvenance(), trainerProvenance, runProvenance); return new HdbscanModel("hdbscan-model", provenance, featureMap, outputMap, clusterLabels, outlierScoresVector, - clusterExemplars, distType, noisePointsOutlierScore); + clusterExemplars, dist, noisePointsOutlierScore); } @Override @@ -347,12 +346,12 @@ private static DenseVector calculateCoreDistances(SGDVector[] data, int k, Neigh * core distances for each point. * @param data An array of {@link DenseVector} containing the data. * @param coreDistances A {@link DenseVector} containing the core distances for every point. - * @param distType The distance metric to employ. + * @param dist The distance metric to employ. * @return An {@link ExtendedMinimumSpanningTree} representation of the data using the mutual reachability distances, * and the graph is sorted by edge weight in ascending order. */ private static ExtendedMinimumSpanningTree constructEMST(SGDVector[] data, DenseVector coreDistances, - DistanceType distType) { + org.tribuo.math.distance.Distance dist) { // One bit is set (true) for each attached point, and unset (false) for unattached points: BitSet attachedPoints = new BitSet(data.length); @@ -380,7 +379,7 @@ private static ExtendedMinimumSpanningTree constructEMST(SGDVector[] data, Dense continue; } - double mutualReachabilityDistance = DistanceType.getDistance(data[currentPoint], data[neighbor], distType); + double mutualReachabilityDistance = dist.computeDistance(data[currentPoint], data[neighbor]); if (coreDistances.get(currentPoint) > mutualReachabilityDistance) { mutualReachabilityDistance = coreDistances.get(currentPoint); } @@ -754,11 +753,11 @@ private static Map>> generateClusterAssignme * * @param data An array of {@link DenseVector} containing the data. * @param clusterAssignments A map of the cluster labels, and the points assigned to them. - * @param distType The distance metric to employ. + * @param dist The distance metric to employ. * @return A list of {@link ClusterExemplar}s which are used for predictions. */ private static List computeExemplars(SGDVector[] data, Map>> clusterAssignments, - DistanceType distType) { + org.tribuo.math.distance.Distance dist) { List clusterExemplars = new ArrayList<>(); // The formula to calculate the exemplar number. This calculates the number of exemplars to be used for this // configuration. The appropriate number of exemplars is important for prediction. At the time, this @@ -797,7 +796,7 @@ else if (numExemplarsThisCluster > outlierScoreIndexTree.size()) { SGDVector features = data[partialClusterExemplar.getValue()]; double maxInnerDist = Double.NEGATIVE_INFINITY; for (Entry entry : outlierScoreIndexTree.entrySet()) { - double distance = DistanceType.getDistance(features, data[entry.getValue()], distType); + double distance = dist.computeDistance(features, data[entry.getValue()]); if (distance > maxInnerDist){ maxInnerDist = distance; } @@ -834,7 +833,7 @@ private static double getNoisePointsOutlierScore(Map buil } private HdbscanModelProto() { clusterLabels_ = emptyIntList(); - distType_ = ""; clusterExemplars_ = java.util.Collections.emptyList(); } @@ -104,9 +103,16 @@ private HdbscanModelProto( break; } case 34: { - java.lang.String s = input.readStringRequireUtf8(); + org.tribuo.math.protos.DistanceProto.Builder subBuilder = null; + if (distance_ != null) { + subBuilder = distance_.toBuilder(); + } + distance_ = input.readMessage(org.tribuo.math.protos.DistanceProto.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom(distance_); + distance_ = subBuilder.buildPartial(); + } - distType_ = s; break; } case 42: { @@ -241,42 +247,30 @@ public org.tribuo.math.protos.TensorProtoOrBuilder getOutlierScoresVectorOrBuild return getOutlierScoresVector(); } - public static final int DIST_TYPE_FIELD_NUMBER = 4; - private volatile java.lang.Object distType_; + public static final int DISTANCE_FIELD_NUMBER = 4; + private org.tribuo.math.protos.DistanceProto distance_; /** - * string dist_type = 4; - * @return The distType. + * .tribuo.math.DistanceProto distance = 4; + * @return Whether the distance field is set. */ @java.lang.Override - public java.lang.String getDistType() { - java.lang.Object ref = distType_; - if (ref instanceof java.lang.String) { - return (java.lang.String) ref; - } else { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - distType_ = s; - return s; - } + public boolean hasDistance() { + return distance_ != null; } /** - * string dist_type = 4; - * @return The bytes for distType. + * .tribuo.math.DistanceProto distance = 4; + * @return The distance. */ @java.lang.Override - public com.google.protobuf.ByteString - getDistTypeBytes() { - java.lang.Object ref = distType_; - if (ref instanceof java.lang.String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - distType_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } + public org.tribuo.math.protos.DistanceProto getDistance() { + return distance_ == null ? org.tribuo.math.protos.DistanceProto.getDefaultInstance() : distance_; + } + /** + * .tribuo.math.DistanceProto distance = 4; + */ + @java.lang.Override + public org.tribuo.math.protos.DistanceProtoOrBuilder getDistanceOrBuilder() { + return getDistance(); } public static final int CLUSTER_EXEMPLARS_FIELD_NUMBER = 5; @@ -358,8 +352,8 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (outlierScoresVector_ != null) { output.writeMessage(3, getOutlierScoresVector()); } - if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(distType_)) { - com.google.protobuf.GeneratedMessageV3.writeString(output, 4, distType_); + if (distance_ != null) { + output.writeMessage(4, getDistance()); } for (int i = 0; i < clusterExemplars_.size(); i++) { output.writeMessage(5, clusterExemplars_.get(i)); @@ -398,8 +392,9 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeMessageSize(3, getOutlierScoresVector()); } - if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(distType_)) { - size += com.google.protobuf.GeneratedMessageV3.computeStringSize(4, distType_); + if (distance_ != null) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(4, getDistance()); } for (int i = 0; i < clusterExemplars_.size(); i++) { size += com.google.protobuf.CodedOutputStream @@ -436,8 +431,11 @@ public boolean equals(final java.lang.Object obj) { if (!getOutlierScoresVector() .equals(other.getOutlierScoresVector())) return false; } - if (!getDistType() - .equals(other.getDistType())) return false; + if (hasDistance() != other.hasDistance()) return false; + if (hasDistance()) { + if (!getDistance() + .equals(other.getDistance())) return false; + } if (!getClusterExemplarsList() .equals(other.getClusterExemplarsList())) return false; if (java.lang.Double.doubleToLongBits(getNoisePointsOutlierScore()) @@ -466,8 +464,10 @@ public int hashCode() { hash = (37 * hash) + OUTLIER_SCORES_VECTOR_FIELD_NUMBER; hash = (53 * hash) + getOutlierScoresVector().hashCode(); } - hash = (37 * hash) + DIST_TYPE_FIELD_NUMBER; - hash = (53 * hash) + getDistType().hashCode(); + if (hasDistance()) { + hash = (37 * hash) + DISTANCE_FIELD_NUMBER; + hash = (53 * hash) + getDistance().hashCode(); + } if (getClusterExemplarsCount() > 0) { hash = (37 * hash) + CLUSTER_EXEMPLARS_FIELD_NUMBER; hash = (53 * hash) + getClusterExemplarsList().hashCode(); @@ -627,8 +627,12 @@ public Builder clear() { outlierScoresVector_ = null; outlierScoresVectorBuilder_ = null; } - distType_ = ""; - + if (distanceBuilder_ == null) { + distance_ = null; + } else { + distance_ = null; + distanceBuilder_ = null; + } if (clusterExemplarsBuilder_ == null) { clusterExemplars_ = java.util.Collections.emptyList(); bitField0_ = (bitField0_ & ~0x00000002); @@ -679,7 +683,11 @@ public org.tribuo.clustering.hdbscan.protos.HdbscanModelProto buildPartial() { } else { result.outlierScoresVector_ = outlierScoresVectorBuilder_.build(); } - result.distType_ = distType_; + if (distanceBuilder_ == null) { + result.distance_ = distance_; + } else { + result.distance_ = distanceBuilder_.build(); + } if (clusterExemplarsBuilder_ == null) { if (((bitField0_ & 0x00000002) != 0)) { clusterExemplars_ = java.util.Collections.unmodifiableList(clusterExemplars_); @@ -754,9 +762,8 @@ public Builder mergeFrom(org.tribuo.clustering.hdbscan.protos.HdbscanModelProto if (other.hasOutlierScoresVector()) { mergeOutlierScoresVector(other.getOutlierScoresVector()); } - if (!other.getDistType().isEmpty()) { - distType_ = other.distType_; - onChanged(); + if (other.hasDistance()) { + mergeDistance(other.getDistance()); } if (clusterExemplarsBuilder_ == null) { if (!other.clusterExemplars_.isEmpty()) { @@ -1134,80 +1141,123 @@ public org.tribuo.math.protos.TensorProtoOrBuilder getOutlierScoresVectorOrBuild return outlierScoresVectorBuilder_; } - private java.lang.Object distType_ = ""; + private org.tribuo.math.protos.DistanceProto distance_; + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.DistanceProto, org.tribuo.math.protos.DistanceProto.Builder, org.tribuo.math.protos.DistanceProtoOrBuilder> distanceBuilder_; + /** + * .tribuo.math.DistanceProto distance = 4; + * @return Whether the distance field is set. + */ + public boolean hasDistance() { + return distanceBuilder_ != null || distance_ != null; + } /** - * string dist_type = 4; - * @return The distType. + * .tribuo.math.DistanceProto distance = 4; + * @return The distance. */ - public java.lang.String getDistType() { - java.lang.Object ref = distType_; - if (!(ref instanceof java.lang.String)) { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - distType_ = s; - return s; + public org.tribuo.math.protos.DistanceProto getDistance() { + if (distanceBuilder_ == null) { + return distance_ == null ? org.tribuo.math.protos.DistanceProto.getDefaultInstance() : distance_; } else { - return (java.lang.String) ref; + return distanceBuilder_.getMessage(); } } /** - * string dist_type = 4; - * @return The bytes for distType. + * .tribuo.math.DistanceProto distance = 4; */ - public com.google.protobuf.ByteString - getDistTypeBytes() { - java.lang.Object ref = distType_; - if (ref instanceof String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - distType_ = b; - return b; + public Builder setDistance(org.tribuo.math.protos.DistanceProto value) { + if (distanceBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + distance_ = value; + onChanged(); } else { - return (com.google.protobuf.ByteString) ref; + distanceBuilder_.setMessage(value); } + + return this; } /** - * string dist_type = 4; - * @param value The distType to set. - * @return This builder for chaining. + * .tribuo.math.DistanceProto distance = 4; */ - public Builder setDistType( - java.lang.String value) { - if (value == null) { - throw new NullPointerException(); - } - - distType_ = value; - onChanged(); + public Builder setDistance( + org.tribuo.math.protos.DistanceProto.Builder builderForValue) { + if (distanceBuilder_ == null) { + distance_ = builderForValue.build(); + onChanged(); + } else { + distanceBuilder_.setMessage(builderForValue.build()); + } + return this; } /** - * string dist_type = 4; - * @return This builder for chaining. + * .tribuo.math.DistanceProto distance = 4; */ - public Builder clearDistType() { - - distType_ = getDefaultInstance().getDistType(); - onChanged(); + public Builder mergeDistance(org.tribuo.math.protos.DistanceProto value) { + if (distanceBuilder_ == null) { + if (distance_ != null) { + distance_ = + org.tribuo.math.protos.DistanceProto.newBuilder(distance_).mergeFrom(value).buildPartial(); + } else { + distance_ = value; + } + onChanged(); + } else { + distanceBuilder_.mergeFrom(value); + } + return this; } /** - * string dist_type = 4; - * @param value The bytes for distType to set. - * @return This builder for chaining. + * .tribuo.math.DistanceProto distance = 4; */ - public Builder setDistTypeBytes( - com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - checkByteStringIsUtf8(value); + public Builder clearDistance() { + if (distanceBuilder_ == null) { + distance_ = null; + onChanged(); + } else { + distance_ = null; + distanceBuilder_ = null; + } + + return this; + } + /** + * .tribuo.math.DistanceProto distance = 4; + */ + public org.tribuo.math.protos.DistanceProto.Builder getDistanceBuilder() { - distType_ = value; onChanged(); - return this; + return getDistanceFieldBuilder().getBuilder(); + } + /** + * .tribuo.math.DistanceProto distance = 4; + */ + public org.tribuo.math.protos.DistanceProtoOrBuilder getDistanceOrBuilder() { + if (distanceBuilder_ != null) { + return distanceBuilder_.getMessageOrBuilder(); + } else { + return distance_ == null ? + org.tribuo.math.protos.DistanceProto.getDefaultInstance() : distance_; + } + } + /** + * .tribuo.math.DistanceProto distance = 4; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.DistanceProto, org.tribuo.math.protos.DistanceProto.Builder, org.tribuo.math.protos.DistanceProtoOrBuilder> + getDistanceFieldBuilder() { + if (distanceBuilder_ == null) { + distanceBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.DistanceProto, org.tribuo.math.protos.DistanceProto.Builder, org.tribuo.math.protos.DistanceProtoOrBuilder>( + getDistance(), + getParentForChildren(), + isClean()); + distance_ = null; + } + return distanceBuilder_; } private java.util.List clusterExemplars_ = diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/protos/HdbscanModelProtoOrBuilder.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/protos/HdbscanModelProtoOrBuilder.java index 5eda24812..00896b8e3 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/protos/HdbscanModelProtoOrBuilder.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/protos/HdbscanModelProtoOrBuilder.java @@ -55,16 +55,19 @@ public interface HdbscanModelProtoOrBuilder extends org.tribuo.math.protos.TensorProtoOrBuilder getOutlierScoresVectorOrBuilder(); /** - * string dist_type = 4; - * @return The distType. + * .tribuo.math.DistanceProto distance = 4; + * @return Whether the distance field is set. */ - java.lang.String getDistType(); + boolean hasDistance(); /** - * string dist_type = 4; - * @return The bytes for distType. + * .tribuo.math.DistanceProto distance = 4; + * @return The distance. */ - com.google.protobuf.ByteString - getDistTypeBytes(); + org.tribuo.math.protos.DistanceProto getDistance(); + /** + * .tribuo.math.DistanceProto distance = 4; + */ + org.tribuo.math.protos.DistanceProtoOrBuilder getDistanceOrBuilder(); /** * repeated .tribuo.clustering.hdbscan.ClusterExemplarProto cluster_exemplars = 5; diff --git a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/protos/TribuoClusteringHdbscan.java b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/protos/TribuoClusteringHdbscan.java index abc3a467e..a0e99de0b 100644 --- a/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/protos/TribuoClusteringHdbscan.java +++ b/Clustering/Hdbscan/src/main/java/org/tribuo/clustering/hdbscan/protos/TribuoClusteringHdbscan.java @@ -38,16 +38,16 @@ public static void registerAllExtensions( "\021tribuo-math.proto\"\202\001\n\024ClusterExemplarPr" + "oto\022\r\n\005label\030\001 \001(\005\022\025\n\routlier_score\030\002 \001(" + "\001\022*\n\010features\030\003 \001(\0132\030.tribuo.math.Tensor" + - "Proto\022\030\n\020max_dist_to_edge\030\004 \001(\001\"\226\002\n\021Hdbs" + + "Proto\022\030\n\020max_dist_to_edge\030\004 \001(\001\"\261\002\n\021Hdbs" + "canModelProto\022-\n\010metadata\030\001 \001(\0132\033.tribuo" + ".core.ModelDataProto\022\026\n\016cluster_labels\030\002" + " \003(\005\0227\n\025outlier_scores_vector\030\003 \001(\0132\030.tr" + - "ibuo.math.TensorProto\022\021\n\tdist_type\030\004 \001(\t" + - "\022J\n\021cluster_exemplars\030\005 \003(\0132/.tribuo.clu" + - "stering.hdbscan.ClusterExemplarProto\022\"\n\032" + - "noise_points_outlier_score\030\006 \001(\001B(\n$org." + - "tribuo.clustering.hdbscan.protosP\001b\006prot" + - "o3" + "ibuo.math.TensorProto\022,\n\010distance\030\004 \001(\0132" + + "\032.tribuo.math.DistanceProto\022J\n\021cluster_e" + + "xemplars\030\005 \003(\0132/.tribuo.clustering.hdbsc" + + "an.ClusterExemplarProto\022\"\n\032noise_points_" + + "outlier_score\030\006 \001(\001B(\n$org.tribuo.cluste" + + "ring.hdbscan.protosP\001b\006proto3" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, @@ -66,7 +66,7 @@ public static void registerAllExtensions( internal_static_tribuo_clustering_hdbscan_HdbscanModelProto_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_tribuo_clustering_hdbscan_HdbscanModelProto_descriptor, - new java.lang.String[] { "Metadata", "ClusterLabels", "OutlierScoresVector", "DistType", "ClusterExemplars", "NoisePointsOutlierScore", }); + new java.lang.String[] { "Metadata", "ClusterLabels", "OutlierScoresVector", "Distance", "ClusterExemplars", "NoisePointsOutlierScore", }); org.tribuo.protos.core.TribuoCore.getDescriptor(); org.tribuo.math.protos.TribuoMath.getDescriptor(); } diff --git a/Clustering/Hdbscan/src/main/resources/protos/tribuo-clustering-hdbscan.proto b/Clustering/Hdbscan/src/main/resources/protos/tribuo-clustering-hdbscan.proto index 43481fd2c..306ea12fe 100644 --- a/Clustering/Hdbscan/src/main/resources/protos/tribuo-clustering-hdbscan.proto +++ b/Clustering/Hdbscan/src/main/resources/protos/tribuo-clustering-hdbscan.proto @@ -47,7 +47,7 @@ message HdbscanModelProto { tribuo.core.ModelDataProto metadata = 1; repeated int32 cluster_labels = 2; tribuo.math.TensorProto outlier_scores_vector = 3; - string dist_type = 4; + tribuo.math.DistanceProto distance = 4; repeated ClusterExemplarProto cluster_exemplars = 5; double noise_points_outlier_score = 6; } diff --git a/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java b/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java index dd8fd2a0c..713b7a64b 100644 --- a/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java +++ b/Clustering/Hdbscan/src/test/java/org/tribuo/clustering/hdbscan/TestHdbscan.java @@ -74,7 +74,7 @@ */ public class TestHdbscan { - private static final HdbscanTrainer t = new HdbscanTrainer(5, DistanceType.L2, 5,2, NeighboursQueryFactoryType.KD_TREE); + private static final HdbscanTrainer t = new HdbscanTrainer(5, DistanceType.L2.getDistance(), 5,2, NeighboursQueryFactoryType.KD_TREE); @BeforeAll public static void setup() { @@ -96,7 +96,7 @@ public void testInvocationCounter() { CSVDataSource csvSource = new CSVDataSource<>(Paths.get("src/test/resources/basic-gaussians.csv"),rowProcessor,false); Dataset dataset = new MutableDataset<>(csvSource); - HdbscanTrainer trainer = new HdbscanTrainer(7, DistanceType.L2, 7,4, NeighboursQueryFactoryType.BRUTE_FORCE); + HdbscanTrainer trainer = new HdbscanTrainer(7, DistanceType.L2.getDistance(), 7,4, NeighboursQueryFactoryType.BRUTE_FORCE); for (int i = 0; i < 5; i++) { HdbscanModel model = trainer.train(dataset); } @@ -124,7 +124,7 @@ public void testEndToEndTrainWithCSVData() { CSVDataSource csvSource = new CSVDataSource<>(Paths.get("src/test/resources/basic-gaussians.csv"),rowProcessor,false); Dataset dataset = new MutableDataset<>(csvSource); - NeighboursQueryFactory kdTreeFactory = new KDTreeFactory(DistanceType.L2, 4); + NeighboursQueryFactory kdTreeFactory = new KDTreeFactory(DistanceType.L2.getDistance(), 4); HdbscanTrainer trainer = new HdbscanTrainer(7,7,kdTreeFactory); HdbscanModel model = trainer.train(dataset); @@ -156,7 +156,7 @@ public void testEndToEndPredictWithCSVData() { CSVDataSource csvTestSource = new CSVDataSource<>(Paths.get("src/test/resources/basic-gaussians-predict.csv"),rowProcessor,false); Dataset testSet = new MutableDataset<>(csvTestSource); - HdbscanTrainer trainer = new HdbscanTrainer(7, DistanceType.L2, 7,1, NeighboursQueryFactoryType.BRUTE_FORCE); + HdbscanTrainer trainer = new HdbscanTrainer(7, DistanceType.L2.getDistance(), 7,1, NeighboursQueryFactoryType.BRUTE_FORCE); HdbscanModel model = trainer.train(dataset); List clusterLabels = model.getClusterLabels(); diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansModel.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansModel.java index 69423abbe..e14902f22 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansModel.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansModel.java @@ -30,13 +30,13 @@ import org.tribuo.clustering.kmeans.KMeansTrainer.Distance; import org.tribuo.clustering.kmeans.protos.KMeansModelProto; import org.tribuo.impl.ModelDataCarrier; -import org.tribuo.math.distance.DistanceType; import org.tribuo.math.la.DenseVector; import org.tribuo.math.la.SGDVector; import org.tribuo.math.la.SparseVector; import org.tribuo.math.la.Tensor; import org.tribuo.math.la.VectorTuple; import org.tribuo.math.protos.TensorProto; +import org.tribuo.protos.ProtoUtil; import org.tribuo.protos.core.ModelProto; import org.tribuo.provenance.ModelProvenance; @@ -77,13 +77,13 @@ public class KMeansModel extends Model { // This is not final to support deserialization of older models. It will be final in a future version which doesn't // maintain serialization compatibility with 4.X. - private DistanceType distType; + private org.tribuo.math.distance.Distance dist; KMeansModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, - ImmutableOutputInfo outputIDInfo, DenseVector[] centroidVectors, DistanceType distType) { + ImmutableOutputInfo outputIDInfo, DenseVector[] centroidVectors, org.tribuo.math.distance.Distance dist) { super(name,description,featureIDMap,outputIDInfo,false); this.centroidVectors = centroidVectors; - this.distType = distType; + this.dist = dist; } /** @@ -125,9 +125,9 @@ public static KMeansModel deserializeFromProto(int version, String className, An } } - DistanceType distType = DistanceType.valueOf(proto.getDistType()); + org.tribuo.math.distance.Distance dist = ProtoUtil.deserialize(proto.getDistance()); - return new KMeansModel(carrier.name(), carrier.provenance(), featureDomain, outputDomain, centroids, distType); + return new KMeansModel(carrier.name(), carrier.provenance(), featureDomain, outputDomain, centroids, dist); } /** @@ -191,7 +191,7 @@ public Prediction predict(Example example) { double minDistance = Double.POSITIVE_INFINITY; int id = -1; for (int i = 0; i < centroidVectors.length; i++) { - double distance = DistanceType.getDistance(centroidVectors[i], vector, distType); + double distance = dist.computeDistance(centroidVectors[i], vector); if (distance < minDistance) { minDistance = distance; @@ -217,7 +217,7 @@ public ModelProto serialize() { KMeansModelProto.Builder modelBuilder = KMeansModelProto.newBuilder(); modelBuilder.setMetadata(carrier.serialize()); - modelBuilder.setDistType(distType.name()); + modelBuilder.setDistance(dist.serialize()); for (DenseVector e : centroidVectors) { modelBuilder.addCentroidVectors(e.serialize()); } @@ -236,13 +236,13 @@ protected KMeansModel copy(String newName, ModelProvenance newProvenance) { for (int i = 0; i < centroidVectors.length; i++) { newCentroids[i] = centroidVectors[i].copy(); } - return new KMeansModel(newName,newProvenance,featureIDMap,outputIDInfo,newCentroids,distType); + return new KMeansModel(newName,newProvenance,featureIDMap,outputIDInfo,newCentroids,dist); } private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { in.defaultReadObject(); - if (distType == null) { - distType = distanceType.getDistanceType(); + if (dist == null) { + dist = distanceType.getDistanceType().getDistance(); } } } diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java index a6aa0d51b..792e4a7df 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java @@ -64,6 +64,6 @@ public class KMeansOptions implements Options { */ public KMeansTrainer getTrainer() { logger.info("Configuring K-Means Trainer"); - return new KMeansTrainer(centroids, iterations, distType, initialisation, numThreads, seed); + return new KMeansTrainer(centroids, iterations, distType.getDistance(), initialisation, numThreads, seed); } } diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java index ec40312ab..1ea08c7a0 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java @@ -158,7 +158,7 @@ public enum Initialisation { private Distance distanceType; @Config(description = "The distance function to use.") - private DistanceType distType; + private org.tribuo.math.distance.Distance dist; @Config(description = "The centroid initialisation method to use.") private Initialisation initialisationType = Initialisation.RANDOM; @@ -199,12 +199,12 @@ public KMeansTrainer(int centroids, int iterations, Distance distanceType, int n * * @param centroids The number of centroids to use. * @param iterations The maximum number of iterations. - * @param distType The distance function. + * @param dist The distance function. * @param numThreads The number of threads. * @param seed The random seed. */ - public KMeansTrainer(int centroids, int iterations, DistanceType distType, int numThreads, long seed) { - this(centroids,iterations,distType,Initialisation.RANDOM,numThreads,seed); + public KMeansTrainer(int centroids, int iterations, org.tribuo.math.distance.Distance dist, int numThreads, long seed) { + this(centroids,iterations,dist,Initialisation.RANDOM,numThreads,seed); } /** @@ -221,7 +221,7 @@ public KMeansTrainer(int centroids, int iterations, DistanceType distType, int n */ @Deprecated public KMeansTrainer(int centroids, int iterations, Distance distanceType, Initialisation initialisationType, int numThreads, long seed) { - this(centroids, iterations, distanceType.getDistanceType(), initialisationType, numThreads, seed); + this(centroids, iterations, distanceType.getDistanceType().getDistance(), initialisationType, numThreads, seed); } /** @@ -229,15 +229,15 @@ public KMeansTrainer(int centroids, int iterations, Distance distanceType, Initi * * @param centroids The number of centroids to use. * @param iterations The maximum number of iterations. - * @param distType The distance function. + * @param dist The distance function. * @param initialisationType The centroid initialization method. * @param numThreads The number of threads. * @param seed The random seed. */ - public KMeansTrainer(int centroids, int iterations, DistanceType distType, Initialisation initialisationType, int numThreads, long seed) { + public KMeansTrainer(int centroids, int iterations, org.tribuo.math.distance.Distance dist, Initialisation initialisationType, int numThreads, long seed) { this.centroids = centroids; this.iterations = iterations; - this.distType = distType; + this.dist = dist; this.initialisationType = initialisationType; this.numThreads = numThreads; this.seed = seed; @@ -252,10 +252,10 @@ public synchronized void postConfig() { this.rng = new SplittableRandom(seed); if (this.distanceType != null) { - if (this.distType != null) { - throw new PropertyException("distType", "Both distType and distanceType must not both be set."); + if (this.dist != null) { + throw new PropertyException("dist", "Both dist and distanceType must not both be set."); } else { - this.distType = this.distanceType.getDistanceType(); + this.dist = this.distanceType.getDistanceType().getDistance(); this.distanceType = null; } } @@ -285,8 +285,6 @@ public KMeansModel train(Dataset examples, Map ru } ImmutableFeatureMap featureMap = examples.getFeatureIDMap(); - - int[] oldCentre = new int[examples.size()]; SGDVector[] data = new SGDVector[examples.size()]; double[] weights = new double[examples.size()]; @@ -308,7 +306,7 @@ public KMeansModel train(Dataset examples, Map ru centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG); break; case PLUSPLUS: - centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, distType); + centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, dist); break; default: throw new IllegalStateException("Unknown initialisation" + initialisationType); @@ -328,7 +326,7 @@ public KMeansModel train(Dataset examples, Map ru SGDVector vector = e.vector; for (int j = 0; j < centroids; j++) { DenseVector cluster = centroidVectors[j]; - double distance = DistanceType.getDistance(cluster, vector, distType); + double distance = dist.computeDistance(cluster, vector); if (distance < minDist) { minDist = distance; clusterID = j; @@ -402,7 +400,7 @@ public KMeansModel train(Dataset examples, Map ru ModelProvenance provenance = new ModelProvenance(KMeansModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance); - return new KMeansModel("k-means-model", provenance, featureMap, outputMap, centroidVectors, distType); + return new KMeansModel("k-means-model", provenance, featureMap, outputMap, centroidVectors, dist); } @Override @@ -460,11 +458,11 @@ private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableF * @param centroids The number of centroids to create. * @param data The dataset of {@link SGDVector} to use. * @param rng The RNG to use. - * @param distType The distance function. + * @param dist The distance function. * @return A {@link DenseVector} array of centroids. */ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SGDVector[] data, SplittableRandom rng, - DistanceType distType) { + org.tribuo.math.distance.Distance dist) { if (centroids > data.length) { throw new IllegalArgumentException("The number of centroids may not exceed the number of samples."); } @@ -486,7 +484,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SGDVecto // go through every vector and see if the min distance to the // newest centroid is smaller than previous min distance for vec for (int j = 0; j < data.length; j++) { - double tempDistance = DistanceType.getDistance(prevCentroid, data[j], distType); + double tempDistance = dist.computeDistance(prevCentroid, data[j]); minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance); } @@ -565,7 +563,7 @@ protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map build } private KMeansModelProto() { centroidVectors_ = java.util.Collections.emptyList(); - distType_ = ""; } @java.lang.Override @@ -78,9 +77,16 @@ private KMeansModelProto( break; } case 26: { - java.lang.String s = input.readStringRequireUtf8(); + org.tribuo.math.protos.DistanceProto.Builder subBuilder = null; + if (distance_ != null) { + subBuilder = distance_.toBuilder(); + } + distance_ = input.readMessage(org.tribuo.math.protos.DistanceProto.parser(), extensionRegistry); + if (subBuilder != null) { + subBuilder.mergeFrom(distance_); + distance_ = subBuilder.buildPartial(); + } - distType_ = s; break; } default: { @@ -184,42 +190,30 @@ public org.tribuo.math.protos.TensorProtoOrBuilder getCentroidVectorsOrBuilder( return centroidVectors_.get(index); } - public static final int DIST_TYPE_FIELD_NUMBER = 3; - private volatile java.lang.Object distType_; + public static final int DISTANCE_FIELD_NUMBER = 3; + private org.tribuo.math.protos.DistanceProto distance_; /** - * string dist_type = 3; - * @return The distType. + * .tribuo.math.DistanceProto distance = 3; + * @return Whether the distance field is set. */ @java.lang.Override - public java.lang.String getDistType() { - java.lang.Object ref = distType_; - if (ref instanceof java.lang.String) { - return (java.lang.String) ref; - } else { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - distType_ = s; - return s; - } + public boolean hasDistance() { + return distance_ != null; } /** - * string dist_type = 3; - * @return The bytes for distType. + * .tribuo.math.DistanceProto distance = 3; + * @return The distance. */ @java.lang.Override - public com.google.protobuf.ByteString - getDistTypeBytes() { - java.lang.Object ref = distType_; - if (ref instanceof java.lang.String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - distType_ = b; - return b; - } else { - return (com.google.protobuf.ByteString) ref; - } + public org.tribuo.math.protos.DistanceProto getDistance() { + return distance_ == null ? org.tribuo.math.protos.DistanceProto.getDefaultInstance() : distance_; + } + /** + * .tribuo.math.DistanceProto distance = 3; + */ + @java.lang.Override + public org.tribuo.math.protos.DistanceProtoOrBuilder getDistanceOrBuilder() { + return getDistance(); } private byte memoizedIsInitialized = -1; @@ -242,8 +236,8 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) for (int i = 0; i < centroidVectors_.size(); i++) { output.writeMessage(2, centroidVectors_.get(i)); } - if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(distType_)) { - com.google.protobuf.GeneratedMessageV3.writeString(output, 3, distType_); + if (distance_ != null) { + output.writeMessage(3, getDistance()); } unknownFields.writeTo(output); } @@ -262,8 +256,9 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeMessageSize(2, centroidVectors_.get(i)); } - if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(distType_)) { - size += com.google.protobuf.GeneratedMessageV3.computeStringSize(3, distType_); + if (distance_ != null) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(3, getDistance()); } size += unknownFields.getSerializedSize(); memoizedSize = size; @@ -287,8 +282,11 @@ public boolean equals(final java.lang.Object obj) { } if (!getCentroidVectorsList() .equals(other.getCentroidVectorsList())) return false; - if (!getDistType() - .equals(other.getDistType())) return false; + if (hasDistance() != other.hasDistance()) return false; + if (hasDistance()) { + if (!getDistance() + .equals(other.getDistance())) return false; + } if (!unknownFields.equals(other.unknownFields)) return false; return true; } @@ -308,8 +306,10 @@ public int hashCode() { hash = (37 * hash) + CENTROID_VECTORS_FIELD_NUMBER; hash = (53 * hash) + getCentroidVectorsList().hashCode(); } - hash = (37 * hash) + DIST_TYPE_FIELD_NUMBER; - hash = (53 * hash) + getDistType().hashCode(); + if (hasDistance()) { + hash = (37 * hash) + DISTANCE_FIELD_NUMBER; + hash = (53 * hash) + getDistance().hashCode(); + } hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; @@ -460,8 +460,12 @@ public Builder clear() { } else { centroidVectorsBuilder_.clear(); } - distType_ = ""; - + if (distanceBuilder_ == null) { + distance_ = null; + } else { + distance_ = null; + distanceBuilder_ = null; + } return this; } @@ -503,7 +507,11 @@ public org.tribuo.clustering.kmeans.protos.KMeansModelProto buildPartial() { } else { result.centroidVectors_ = centroidVectorsBuilder_.build(); } - result.distType_ = distType_; + if (distanceBuilder_ == null) { + result.distance_ = distance_; + } else { + result.distance_ = distanceBuilder_.build(); + } onBuilt(); return result; } @@ -581,9 +589,8 @@ public Builder mergeFrom(org.tribuo.clustering.kmeans.protos.KMeansModelProto ot } } } - if (!other.getDistType().isEmpty()) { - distType_ = other.distType_; - onChanged(); + if (other.hasDistance()) { + mergeDistance(other.getDistance()); } this.mergeUnknownFields(other.unknownFields); onChanged(); @@ -974,80 +981,123 @@ public org.tribuo.math.protos.TensorProto.Builder addCentroidVectorsBuilder( return centroidVectorsBuilder_; } - private java.lang.Object distType_ = ""; + private org.tribuo.math.protos.DistanceProto distance_; + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.DistanceProto, org.tribuo.math.protos.DistanceProto.Builder, org.tribuo.math.protos.DistanceProtoOrBuilder> distanceBuilder_; + /** + * .tribuo.math.DistanceProto distance = 3; + * @return Whether the distance field is set. + */ + public boolean hasDistance() { + return distanceBuilder_ != null || distance_ != null; + } /** - * string dist_type = 3; - * @return The distType. + * .tribuo.math.DistanceProto distance = 3; + * @return The distance. */ - public java.lang.String getDistType() { - java.lang.Object ref = distType_; - if (!(ref instanceof java.lang.String)) { - com.google.protobuf.ByteString bs = - (com.google.protobuf.ByteString) ref; - java.lang.String s = bs.toStringUtf8(); - distType_ = s; - return s; + public org.tribuo.math.protos.DistanceProto getDistance() { + if (distanceBuilder_ == null) { + return distance_ == null ? org.tribuo.math.protos.DistanceProto.getDefaultInstance() : distance_; } else { - return (java.lang.String) ref; + return distanceBuilder_.getMessage(); } } /** - * string dist_type = 3; - * @return The bytes for distType. + * .tribuo.math.DistanceProto distance = 3; */ - public com.google.protobuf.ByteString - getDistTypeBytes() { - java.lang.Object ref = distType_; - if (ref instanceof String) { - com.google.protobuf.ByteString b = - com.google.protobuf.ByteString.copyFromUtf8( - (java.lang.String) ref); - distType_ = b; - return b; + public Builder setDistance(org.tribuo.math.protos.DistanceProto value) { + if (distanceBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + distance_ = value; + onChanged(); } else { - return (com.google.protobuf.ByteString) ref; + distanceBuilder_.setMessage(value); } + + return this; } /** - * string dist_type = 3; - * @param value The distType to set. - * @return This builder for chaining. + * .tribuo.math.DistanceProto distance = 3; */ - public Builder setDistType( - java.lang.String value) { - if (value == null) { - throw new NullPointerException(); - } - - distType_ = value; - onChanged(); + public Builder setDistance( + org.tribuo.math.protos.DistanceProto.Builder builderForValue) { + if (distanceBuilder_ == null) { + distance_ = builderForValue.build(); + onChanged(); + } else { + distanceBuilder_.setMessage(builderForValue.build()); + } + return this; } /** - * string dist_type = 3; - * @return This builder for chaining. + * .tribuo.math.DistanceProto distance = 3; */ - public Builder clearDistType() { - - distType_ = getDefaultInstance().getDistType(); - onChanged(); + public Builder mergeDistance(org.tribuo.math.protos.DistanceProto value) { + if (distanceBuilder_ == null) { + if (distance_ != null) { + distance_ = + org.tribuo.math.protos.DistanceProto.newBuilder(distance_).mergeFrom(value).buildPartial(); + } else { + distance_ = value; + } + onChanged(); + } else { + distanceBuilder_.mergeFrom(value); + } + return this; } /** - * string dist_type = 3; - * @param value The bytes for distType to set. - * @return This builder for chaining. + * .tribuo.math.DistanceProto distance = 3; */ - public Builder setDistTypeBytes( - com.google.protobuf.ByteString value) { - if (value == null) { - throw new NullPointerException(); - } - checkByteStringIsUtf8(value); + public Builder clearDistance() { + if (distanceBuilder_ == null) { + distance_ = null; + onChanged(); + } else { + distance_ = null; + distanceBuilder_ = null; + } + + return this; + } + /** + * .tribuo.math.DistanceProto distance = 3; + */ + public org.tribuo.math.protos.DistanceProto.Builder getDistanceBuilder() { - distType_ = value; onChanged(); - return this; + return getDistanceFieldBuilder().getBuilder(); + } + /** + * .tribuo.math.DistanceProto distance = 3; + */ + public org.tribuo.math.protos.DistanceProtoOrBuilder getDistanceOrBuilder() { + if (distanceBuilder_ != null) { + return distanceBuilder_.getMessageOrBuilder(); + } else { + return distance_ == null ? + org.tribuo.math.protos.DistanceProto.getDefaultInstance() : distance_; + } + } + /** + * .tribuo.math.DistanceProto distance = 3; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.DistanceProto, org.tribuo.math.protos.DistanceProto.Builder, org.tribuo.math.protos.DistanceProtoOrBuilder> + getDistanceFieldBuilder() { + if (distanceBuilder_ == null) { + distanceBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.DistanceProto, org.tribuo.math.protos.DistanceProto.Builder, org.tribuo.math.protos.DistanceProtoOrBuilder>( + getDistance(), + getParentForChildren(), + isClean()); + distance_ = null; + } + return distanceBuilder_; } @java.lang.Override public final Builder setUnknownFields( diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/protos/KMeansModelProtoOrBuilder.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/protos/KMeansModelProtoOrBuilder.java index 2110ea8af..15fe0bec8 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/protos/KMeansModelProtoOrBuilder.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/protos/KMeansModelProtoOrBuilder.java @@ -47,14 +47,17 @@ org.tribuo.math.protos.TensorProtoOrBuilder getCentroidVectorsOrBuilder( int index); /** - * string dist_type = 3; - * @return The distType. + * .tribuo.math.DistanceProto distance = 3; + * @return Whether the distance field is set. */ - java.lang.String getDistType(); + boolean hasDistance(); /** - * string dist_type = 3; - * @return The bytes for distType. + * .tribuo.math.DistanceProto distance = 3; + * @return The distance. */ - com.google.protobuf.ByteString - getDistTypeBytes(); + org.tribuo.math.protos.DistanceProto getDistance(); + /** + * .tribuo.math.DistanceProto distance = 3; + */ + org.tribuo.math.protos.DistanceProtoOrBuilder getDistanceOrBuilder(); } diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/protos/TribuoClusteringKmeans.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/protos/TribuoClusteringKmeans.java index f0ae4453c..1c29cbe47 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/protos/TribuoClusteringKmeans.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/protos/TribuoClusteringKmeans.java @@ -30,12 +30,12 @@ public static void registerAllExtensions( java.lang.String[] descriptorData = { "\n\036tribuo-clustering-kmeans.proto\022\030tribuo" + ".clustering.kmeans\032\021tribuo-core.proto\032\021t" + - "ribuo-math.proto\"\210\001\n\020KMeansModelProto\022-\n" + + "ribuo-math.proto\"\243\001\n\020KMeansModelProto\022-\n" + "\010metadata\030\001 \001(\0132\033.tribuo.core.ModelDataP" + "roto\0222\n\020centroid_vectors\030\002 \003(\0132\030.tribuo." + - "math.TensorProto\022\021\n\tdist_type\030\003 \001(\tB\'\n#o" + - "rg.tribuo.clustering.kmeans.protosP\001b\006pr" + - "oto3" + "math.TensorProto\022,\n\010distance\030\003 \001(\0132\032.tri" + + "buo.math.DistanceProtoB\'\n#org.tribuo.clu" + + "stering.kmeans.protosP\001b\006proto3" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, @@ -48,7 +48,7 @@ public static void registerAllExtensions( internal_static_tribuo_clustering_kmeans_KMeansModelProto_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_tribuo_clustering_kmeans_KMeansModelProto_descriptor, - new java.lang.String[] { "Metadata", "CentroidVectors", "DistType", }); + new java.lang.String[] { "Metadata", "CentroidVectors", "Distance", }); org.tribuo.protos.core.TribuoCore.getDescriptor(); org.tribuo.math.protos.TribuoMath.getDescriptor(); } diff --git a/Clustering/KMeans/src/main/resources/protos/tribuo-clustering-kmeans.proto b/Clustering/KMeans/src/main/resources/protos/tribuo-clustering-kmeans.proto index b22279c2b..e0a90ad39 100644 --- a/Clustering/KMeans/src/main/resources/protos/tribuo-clustering-kmeans.proto +++ b/Clustering/KMeans/src/main/resources/protos/tribuo-clustering-kmeans.proto @@ -36,5 +36,5 @@ KMeansModel proto message KMeansModelProto { tribuo.core.ModelDataProto metadata = 1; repeated tribuo.math.TensorProto centroid_vectors = 2; - string dist_type = 3; + tribuo.math.DistanceProto distance = 3; } diff --git a/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java b/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java index fe7e51bad..83addce8a 100644 --- a/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java +++ b/Clustering/KMeans/src/test/java/org/tribuo/clustering/kmeans/TestKMeans.java @@ -43,10 +43,10 @@ */ public class TestKMeans { - private static final KMeansTrainer t = new KMeansTrainer(4,10, DistanceType.L2, + private static final KMeansTrainer t = new KMeansTrainer(4,10, DistanceType.L2.getDistance(), KMeansTrainer.Initialisation.RANDOM, 1,1); - private static final KMeansTrainer plusPlus = new KMeansTrainer(4,10, DistanceType.L2, + private static final KMeansTrainer plusPlus = new KMeansTrainer(4,10, DistanceType.L2.getDistance(), KMeansTrainer.Initialisation.PLUSPLUS, 1,1); @BeforeAll @@ -172,7 +172,7 @@ public void testPlusPlusTooManyCentroids() { @Test public void testSetInvocationCount() { // Create new trainer and dataset so as not to mess with the other tests - KMeansTrainer originalTrainer = new KMeansTrainer(4,10, DistanceType.L2, + KMeansTrainer originalTrainer = new KMeansTrainer(4,10, DistanceType.L2.getDistance(), KMeansTrainer.Initialisation.RANDOM, 1,1); Pair,Dataset> p = ClusteringDataGenerator.denseTrainTest(); @@ -190,7 +190,7 @@ public void testSetInvocationCount() { // Create a new model with same configuration, but set the invocation count to numOfInvocations // Assert that this succeeded, this means RNG will be at state where originalTrainer was before // it performed its last train. - KMeansTrainer newTrainer = new KMeansTrainer(4,10, DistanceType.L2, + KMeansTrainer newTrainer = new KMeansTrainer(4,10, DistanceType.L2.getDistance(), KMeansTrainer.Initialisation.RANDOM, 1,1); newTrainer.setInvocationCount(numOfInvocations); assertEquals(numOfInvocations,newTrainer.getInvocationCount()); @@ -211,7 +211,7 @@ public void testSetInvocationCount() { @Test public void testNegativeInvocationCount(){ assertThrows(IllegalArgumentException.class, () -> { - KMeansTrainer t = new KMeansTrainer(4,10, DistanceType.L2, + KMeansTrainer t = new KMeansTrainer(4,10, DistanceType.L2.getDistance(), KMeansTrainer.Initialisation.RANDOM, 1,1); t.setInvocationCount(-1); }); diff --git a/Common/NearestNeighbour/src/main/java/org/tribuo/common/nearest/KNNClassifierOptions.java b/Common/NearestNeighbour/src/main/java/org/tribuo/common/nearest/KNNClassifierOptions.java index afeba8c61..20caa06d7 100644 --- a/Common/NearestNeighbour/src/main/java/org/tribuo/common/nearest/KNNClassifierOptions.java +++ b/Common/NearestNeighbour/src/main/java/org/tribuo/common/nearest/KNNClassifierOptions.java @@ -97,6 +97,6 @@ private EnsembleCombiner