Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a multi-label linear sgd classifier #106

Merged
merged 26 commits into from
Dec 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
0fe1894
Initial draft of multi-label logistic regression. Needs thresholding …
Craigacp Nov 5, 2020
e91fb8c
Finishes the multi-label linear model implementation.
Craigacp Nov 6, 2020
bc14b69
Tidying up the names and the docs
Craigacp Nov 6, 2020
83e9bc9
Reducing memory usage of single labels in a MultiLabel.
Craigacp Nov 13, 2020
414ed6e
Add an option to quiesce ONNX tests.
Craigacp Nov 26, 2020
061cd23
Optimizing DenseVector and DenseMatrix to use more efficient operatio…
Craigacp Nov 29, 2020
85e7423
Relaxing LinearParameters so predict accepts a DenseVector.
Craigacp Nov 29, 2020
af7316e
Fixing a bug in DenseVector.createDenseVector.
Craigacp Nov 29, 2020
8a117b0
Adding a method to the VectorNormalizer interface that normalizes thi…
Craigacp Nov 29, 2020
daf61e1
Converts DenseMatrix.normalizeRows over to use the in place normaliza…
Craigacp Nov 29, 2020
c6d1420
Adding a new common project for SGD.
Craigacp Nov 30, 2020
a6ea4ca
Refactoring the various SGD objective functions to share an interface…
Craigacp Nov 30, 2020
734bf8d
Refactoring the different LinearSGDModels so they share a common base…
Craigacp Nov 30, 2020
9c0d9b0
Refactoring the different LinearSGDTrainers so they share a common ba…
Craigacp Nov 30, 2020
236e750
Tidying up the argument names in LinearSGDModel.
Craigacp Nov 30, 2020
1dae605
Removing unnecessary MultiLabel SGD Util class.
Craigacp Nov 30, 2020
f21b4a0
Adding a package-info.java to Common/SGD.
Craigacp Nov 30, 2020
818d5bc
Fix licenses in Common/SGD and MultiLabel/SGD
Craigacp Nov 30, 2020
001c5b4
Fixing the license files again.
Craigacp Dec 9, 2020
4c3fdf0
Restoring backwards compatibility for classification & regression Lin…
Craigacp Dec 9, 2020
3421bf7
Removing new deprecated code.
Craigacp Dec 9, 2020
237c48b
Adding deprecated annotations to the old weights in the linear sgd su…
Craigacp Dec 10, 2020
3860e0e
Javadoc updates.
Craigacp Dec 10, 2020
67c04c6
Adding a note to the roadmap about the multi-label linear sgd.
Craigacp Dec 10, 2020
de37d03
Fixing the review comments.
Craigacp Dec 17, 2020
4e3b724
Removing the unused LossAndGradient class.
Craigacp Dec 17, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ public String toReadableString() {
return builder.toString();
}

@Override
public String toString() {
return toReadableString();
}

@Override
public Iterator<Pair<Integer, Label>> iterator() {
return new ImmutableInfoIterator(idLabelMap);
Expand Down
5 changes: 5 additions & 0 deletions Classification/SGD/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
<artifactId>tribuo-math</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-common-sgd</artifactId>
<version>${project.version}</version>
</dependency>
<!-- test time dependencies -->
<dependency>
<groupId>${project.groupId}</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@

package org.tribuo.classification.sgd;

import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.common.sgd.SGDObjective;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.util.VectorNormalizer;

Expand All @@ -29,15 +27,23 @@
* An objective knows if it generates a probabilistic model or not,
* and what kind of normalization needs to be applied to produce probability values.
*/
public interface LabelObjective extends Configurable, Provenancable<ConfiguredObjectProvenance> {
public interface LabelObjective extends SGDObjective<Integer> {

/**
* Scores a prediction, returning the loss and a vector of per label gradients.
* @param truth The true label id.
*
* @deprecated In 4.1, to migrate to the new name {@link #lossAndGradient}.
* @param truth The true label id.
* @param prediction The prediction for each label id.
* @return The score and per label gradient.
*/
public Pair<Double,SGDVector> valueAndGradient(int truth, SGDVector prediction);
@Deprecated
Pair<Double, SGDVector> valueAndGradient(int truth, SGDVector prediction);

@Override
default Pair<Double, SGDVector> lossAndGradient(Integer truth, SGDVector prediction) {
return valueAndGradient(truth, prediction);
}

/**
* Generates a new {@link VectorNormalizer} which normalizes the predictions into [0,1].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,21 @@

package org.tribuo.classification.sgd.linear;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.sgd.AbstractLinearSGDModel;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.provenance.ModelProvenance;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;

/**
* The inference time version of a linear model trained using SGD.
Expand All @@ -52,122 +42,89 @@
* Proceedings of COMPSTAT, 2010.
* </pre>
*/
public class LinearSGDModel extends Model<Label> {
public class LinearSGDModel extends AbstractLinearSGDModel<Label> {
private static final long serialVersionUID = 2L;

private final DenseMatrix weights;
private final VectorNormalizer normalizer;

LinearSGDModel(String name, ModelProvenance description,
ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap,
// Unused as the weights now live in AbstractLinearSGDModel
// It remains for serialization compatibility with Tribuo 4.0
@Deprecated
private DenseMatrix weights = null;

/**
* Constructs a linear classification model trained via SGD.
* @param name The model name.
* @param provenance The model provenance.
* @param featureIDMap The feature domain.
* @param outputIDInfo The output domain.
* @param parameters The model parameters (i.e., the weight matrix).
* @param normalizer The normalization function.
* @param generatesProbabilities Does this model generate probabilities?
*/
LinearSGDModel(String name, ModelProvenance provenance,
ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo,
LinearParameters parameters, VectorNormalizer normalizer, boolean generatesProbabilities) {
super(name, description, featureIDMap, labelIDMap, generatesProbabilities);
this.weights = parameters.getWeightMatrix();
super(name, provenance, featureIDMap, outputIDInfo, parameters.getWeightMatrix(), generatesProbabilities);
this.normalizer = normalizer;
}

private LinearSGDModel(String name, ModelProvenance description,
ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap,
/**
* Constructs a linear classification model trained via SGD.
* @param name The model name.
* @param provenance The model provenance.
* @param featureIDMap The feature domain.
* @param outputIDInfo The output domain.
* @param weights The model parameters (i.e., the weight matrix).
* @param normalizer The normalization function.
* @param generatesProbabilities Does this model generate probabilities?
*/
private LinearSGDModel(String name, ModelProvenance provenance,
ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo,
DenseMatrix weights, VectorNormalizer normalizer, boolean generatesProbabilities) {
super(name, description, featureIDMap, labelIDMap, generatesProbabilities);
this.weights = weights;
super(name, provenance, featureIDMap, outputIDInfo, weights, generatesProbabilities);
this.normalizer = normalizer;
}

@Override
public Prediction<Label> predict(Example<Label> example) {
SparseVector features = SparseVector.createSparseVector(example,featureIDMap,true);
// Due to bias feature
if (features.numActiveElements() == 1) {
throw new IllegalArgumentException("No features found in Example " + example.toString());
}
DenseVector prediction = weights.leftMultiply(features);
PredAndActive predTuple = predictSingle(example);
DenseVector prediction = predTuple.prediction;
prediction.normalize(normalizer);

double maxScore = Double.NEGATIVE_INFINITY;
Label maxLabel = null;
Map<String,Label> predMap = new LinkedHashMap<>();
for (int i = 0; i < prediction.size(); i++) {
String labelName = outputIDInfo.getOutput(i).getLabel();
Label label = new Label(labelName, prediction.get(i));
double score = prediction.get(i);
Label label = new Label(labelName, score);
predMap.put(labelName,label);
if (label.getScore() > maxScore) {
maxScore = label.getScore();
if (score > maxScore) {
maxScore = score;
maxLabel = label;
}
}
return new Prediction<>(maxLabel, predMap, features.numActiveElements()-1, example, generatesProbabilities);
return new Prediction<>(maxLabel, predMap, predTuple.numActiveFeatures-1, example, generatesProbabilities);
}

@Override
public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
int maxFeatures = n < 0 ? featureIDMap.size() + 1 : n;

Comparator<Pair<String,Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));

//
// Use a priority queue to find the top N features.
int numClasses = weights.getDimension1Size();
int numFeatures = weights.getDimension2Size()-1; //Removing the bias feature.
Map<String, List<Pair<String,Double>>> map = new HashMap<>();
for (int i = 0; i < numClasses; i++) {
PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures, comparator);

for (int j = 0; j < numFeatures; j++) {
Pair<String,Double> curr = new Pair<>(featureIDMap.get(j).getName(), weights.get(i,j));

if (q.size() < maxFeatures) {
q.offer(curr);
} else if (comparator.compare(curr, q.peek()) > 0) {
q.poll();
q.offer(curr);
}
}
Pair<String,Double> curr = new Pair<>(BIAS_FEATURE, weights.get(i,numFeatures));

if (q.size() < maxFeatures) {
q.offer(curr);
} else if (comparator.compare(curr, q.peek()) > 0) {
q.poll();
q.offer(curr);
}
ArrayList<Pair<String,Double>> b = new ArrayList<>();
while (q.size() > 0) {
b.add(q.poll());
}

Collections.reverse(b);
map.put(outputIDInfo.getOutput(i).getLabel(), b);
}
return map;
protected LinearSGDModel copy(String newName, ModelProvenance newProvenance) {
return new LinearSGDModel(newName,newProvenance,featureIDMap,outputIDInfo,new DenseMatrix(baseWeights),normalizer,generatesProbabilities);
}

@Override
public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
Prediction<Label> prediction = predict(example);
Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
int numClasses = weights.getDimension1Size();
int numFeatures = weights.getDimension2Size()-1;

for (int i = 0; i < numClasses; i++) {
List<Pair<String, Double>> classScores = new ArrayList<>();
for (Feature f : example) {
int id = featureIDMap.getID(f.getName());
if (id > -1) {
double score = weights.get(i,id) * f.getValue();
classScores.add(new Pair<>(f.getName(), score));
}
}
classScores.add(new Pair<>(Model.BIAS_FEATURE,weights.get(i,numFeatures)));
classScores.sort((Pair<String, Double> o1, Pair<String, Double> o2) -> o2.getB().compareTo(o1.getB()));
weightMap.put(outputIDInfo.getOutput(i).getLabel(), classScores);
}

return Optional.of(new Excuse<>(example, prediction, weightMap));
protected String getDimensionName(int index) {
return outputIDInfo.getOutput(index).getLabel();
}

@Override
protected LinearSGDModel copy(String newName, ModelProvenance newProvenance) {
return new LinearSGDModel(newName,newProvenance,featureIDMap,outputIDInfo,new DenseMatrix(weights),normalizer,generatesProbabilities);
private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
in.defaultReadObject();

// Bounce old 4.0 style models into the new 4.1 style models
if (weights != null && baseWeights == null) {
baseWeights = weights;
weights = null;
}
}
}
Loading