Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP prettyToString(ConfusionMatrix<MultiLabel>) and labelConfusionMat… #128

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

package org.tribuo.classification.evaluation;

import java.util.logging.Logger;

import org.tribuo.classification.Classifiable;
import org.tribuo.evaluation.metrics.EvaluationMetric.Average;
import org.tribuo.evaluation.metrics.MetricTarget;

import java.util.logging.Logger;

/**
* Static functions for computing classification metrics based on a {@link ConfusionMatrix}.
*/
Expand Down Expand Up @@ -60,7 +60,7 @@ public static <T extends Classifiable<T>> double accuracy(T label, ConfusionMatr
double support = cm.support(label);
// handle div-by-zero
if (support == 0d) {
logger.warning("No predictions: accuracy ill-defined");
logger.warning("No predictions for " + label + ": accuracy ill-defined");
return Double.NaN;
}
return cm.tp(label) / cm.support(label);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@

package org.tribuo.classification.evaluation;

import java.util.Arrays;
import java.util.List;

import org.junit.jupiter.api.Test;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.tribuo.classification.Utils.label;
import static org.tribuo.classification.Utils.mkDomain;
import static org.tribuo.classification.Utils.mkPrediction;
import static org.junit.jupiter.api.Assertions.assertEquals;


public class LabelConfusionMatrixTest {
Expand All @@ -38,7 +38,8 @@ public void testMulticlass() {
mkPrediction("a", "a"),
mkPrediction("c", "b"),
mkPrediction("b", "b"),
mkPrediction("b", "c")
mkPrediction("b", "c"),
mkPrediction("a", "b")
);
ImmutableOutputInfo<Label> domain = mkDomain(predictions);
LabelConfusionMatrix cm = new LabelConfusionMatrix(domain, predictions);
Expand All @@ -54,25 +55,25 @@ public void testMulticlass() {
assertEquals(1, cm.tp(a));
assertEquals(0, cm.fp(a));
assertEquals(3, cm.tn(a));
assertEquals(0, cm.fn(a));
assertEquals(1, cm.support(a));
assertEquals(1, cm.fn(a));
assertEquals(2, cm.support(a));

assertEquals(1, cm.tp(b));
assertEquals(1, cm.fp(b));
assertEquals(2, cm.fp(b));
assertEquals(1, cm.tn(b));
assertEquals(1, cm.fn(b));
assertEquals(2, cm.support(b));

assertEquals(0, cm.tp(c));
assertEquals(1, cm.fp(c));
assertEquals(2, cm.tn(c));
assertEquals(3, cm.tn(c));
assertEquals(1, cm.fn(c));
assertEquals(1, cm.support(c));

assertEquals(4, cm.support());
assertEquals(5, cm.support());
String cmToString = cm.toString();
assertEquals(" a b c\n" +
"a 1 0 0\n" +
"a 1 1 0\n" +
"b 0 1 1\n" +
"c 0 1 0\n", cmToString);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@

package org.tribuo.multilabel.evaluation;

import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
Expand All @@ -25,10 +30,6 @@
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.multilabel.MultiLabelFactory;

import java.util.List;
import java.util.Set;
import java.util.function.Function;

/**
* A {@link ConfusionMatrix} which accepts {@link MultiLabel}s.
*
Expand Down Expand Up @@ -158,15 +159,18 @@ public double confusion(MultiLabel predicted, MultiLabel truth) {

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("[");
for (int i = 0; i < mcm.length; i++) {
DenseMatrix cm = mcm[i];
sb.append(cm.toString());
sb.append("\n");
}
sb.append("]");
return sb.toString();
return getDomain().getDomain().stream()
.map(multiLabel -> {
final int tp = (int) tp(multiLabel);
final int fn = (int) fn(multiLabel);
final int fp = (int) fp(multiLabel);
final int tn = (int) tn(multiLabel);
return String.join("\n",
multiLabel.toString(),
String.format(" [tn: %,d fn: %,d]", tn, fn),
String.format(" [fp: %,d tp: %,d]", fp, tp));
}
).collect(Collectors.joining("\n"));
}

static ConfusionMatrixTuple tabulate(ImmutableOutputInfo<MultiLabel> domain, List<Prediction<MultiLabel>> predictions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,38 @@

package org.tribuo.multilabel;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.MutableLabelInfo;
import org.tribuo.classification.evaluation.ClassifierEvaluation;
import org.tribuo.classification.evaluation.ConfusionMatrix;
import org.tribuo.classification.evaluation.LabelConfusionMatrix;
import org.tribuo.classification.sgd.linear.LinearSGDTrainer;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;
import org.tribuo.impl.ListExample;
import org.tribuo.multilabel.baseline.IndependentMultiLabelTrainer;
import org.tribuo.multilabel.evaluation.MultiLabelEvaluator;
import org.tribuo.multilabel.example.MultiLabelDataGenerator;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.tribuo.test.Helpers;

import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import com.oracle.labs.mlrg.olcut.util.Pair;

import static org.junit.jupiter.api.Assertions.assertEquals;

Expand Down Expand Up @@ -67,4 +82,88 @@ public void testIndependentBinaryPredictions() {
Helpers.testModelSerialization(model,MultiLabel.class);
}

@Test
public void testMultiLabelConfusionMatrixToStrings() {
Dataset<MultiLabel> train = MultiLabelDataGenerator.generateTrainData();
Dataset<MultiLabel> test = MultiLabelDataGenerator.generateTestData();

IndependentMultiLabelTrainer trainer = new IndependentMultiLabelTrainer(
new LogisticRegressionTrainer());
Model<MultiLabel> model = trainer.train(train);

ClassifierEvaluation<MultiLabel> evaluation = new MultiLabelEvaluator()
.evaluate(model, test);

System.out.println(evaluation);

// MultiLabelConfusionMatrix toString() hard to interpret
final ConfusionMatrix<MultiLabel> mcm = evaluation.getConfusionMatrix();

System.out.println("new toString()");
System.out.println(mcm);

System.out.println("\npredictions");
evaluation.getPredictions().forEach(System.out::println);

final List<Prediction<MultiLabel>> predictions = evaluation.getPredictions();
System.out.println("\nsingleLabelConfusionMatrix");
System.out.println(singleLabelConfusionMatrix(predictions));
}

public static LabelConfusionMatrix singleLabelConfusionMatrix(final List<Prediction<MultiLabel>> predictions) {
final List<Prediction<Label>> singleLabelPredictions = mkSingleLabelPredictions(predictions);
ImmutableOutputInfo<Label> domain = mkDomain(singleLabelPredictions);
LabelConfusionMatrix cm = new LabelConfusionMatrix(domain, singleLabelPredictions);
return cm;
}

public static List<Prediction<Label>> mkSingleLabelPredictions(List<Prediction<MultiLabel>> predictions) {
return predictions.stream()
.flatMap(p -> {
final Set<Label> trueLabels = p.getExample().getOutput().getLabelSet();
final Set<Label> predicted = p.getOutput().getLabelSet();
// intersection(trueLabels, predicted) = true positives
// predicted - trueLabels = false positives
// trueLabels - predicted = false negatives
return predicted.stream().map(pred -> {
if (trueLabels.contains(pred)) {
return mkPrediction(pred.getLabel(), pred.getLabel());
} else if (trueLabels.size() == 1) {
return mkPrediction(trueLabels.iterator().next().getLabel(), pred.getLabel());
} else {
// arbitrarily pick first trueLabel
return mkPrediction(trueLabels.iterator().next().getLabel(), pred.getLabel());
}
});
}).collect(Collectors.toList());
}

// FIXME HACK copied from Classification/Core/src/test/java/org/tribuo/classification/Utils.java

public static Prediction<Label> mkPrediction(String trueVal, String predVal) {
LabelFactory factory = new LabelFactory();
Example<Label> example = new ListExample<>(factory.generateOutput(trueVal));
example.add(new Feature("noop", 1d));
Prediction<Label> prediction = new Prediction<>(factory.generateOutput(predVal), 0, example);
return prediction;
}

public static ImmutableOutputInfo<Label> mkDomain(List<Prediction<Label>> predictions) {
// MutableLabelInfo info = new MutableLabelInfo();
// FIXME hack call package private ctor of MutableLabelInfo
// TODO just make that public
nezda marked this conversation as resolved.
Show resolved Hide resolved
final MutableLabelInfo info;
try {
Constructor<MutableLabelInfo> ctor = MutableLabelInfo.class.getDeclaredConstructor();
ctor.setAccessible(true);
info = ctor.newInstance();
} catch (NoSuchMethodException | InvocationTargetException | InstantiationException | IllegalAccessException e) {
throw new RuntimeException(e);
}
for (Prediction<Label> p : predictions) {
info.observe(p.getExample().getOutput());
info.observe(p.getOutput()); // TODO? LN added
}
return info.generateImmutableOutputInfo();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,23 @@

package org.tribuo.multilabel.evaluation;

import java.util.Arrays;
import java.util.List;

import org.junit.jupiter.api.Test;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.multilabel.MultiLabel;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.tribuo.multilabel.IndependentMultiLabelTest.singleLabelConfusionMatrix;
import static org.tribuo.multilabel.Utils.getUnknown;
import static org.tribuo.multilabel.Utils.label;
import static org.tribuo.multilabel.Utils.mkDomain;
import static org.tribuo.multilabel.Utils.mkPrediction;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class MultiLabelConfusionMatrixTest {

Expand Down Expand Up @@ -158,6 +159,11 @@ public void testSingleLabel() {
assertEquals(1, cm.support(c));

assertEquals(4, cm.support());

System.out.println("new toString()");
System.out.println(cm);
System.out.println("\nsingleLabelConfusionMatrix");
System.out.println(singleLabelConfusionMatrix(predictions));
}

@Test
Expand Down Expand Up @@ -231,6 +237,11 @@ public void testMultiLabel() {
assertEquals(1, cm.support(c));

assertEquals(5, cm.support());

System.out.println("new toString()");
System.out.println(cm);
System.out.println("\nsingleLabelConfusionMatrix");
System.out.println(singleLabelConfusionMatrix(predictions));
}


Expand Down