From cbb356cf85443af1276655621cebf219e7a8fdb1 Mon Sep 17 00:00:00 2001 From: Luke Nezda Date: Tue, 6 Apr 2021 06:41:02 -0500 Subject: [PATCH 1/8] WIP prettyToString(ConfusionMatrix) and labelConfusionMatrixToString(ConfusionMatrix) - seems correct except testMultiLabelConfusionMatrixToStrings() labelConfusionMatrixToString has off-diagonal values? --- .../multilabel/IndependentMultiLabelTest.java | 103 ++++++++++++++++-- .../MultiLabelConfusionMatrixTest.java | 25 ++++- 2 files changed, 114 insertions(+), 14 deletions(-) diff --git a/MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java b/MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java index b90105698..728859eea 100644 --- a/MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java +++ b/MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java @@ -16,23 +16,29 @@ package org.tribuo.multilabel; -import com.oracle.labs.mlrg.olcut.util.Pair; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; import org.tribuo.Dataset; +import org.tribuo.ImmutableOutputInfo; import org.tribuo.Model; import org.tribuo.Prediction; +import org.tribuo.classification.evaluation.ClassifierEvaluation; +import org.tribuo.classification.evaluation.ConfusionMatrix; import org.tribuo.classification.sgd.linear.LinearSGDTrainer; import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer; import org.tribuo.multilabel.baseline.IndependentMultiLabelTrainer; +import org.tribuo.multilabel.evaluation.MultiLabelEvaluator; import org.tribuo.multilabel.example.MultiLabelDataGenerator; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; import org.tribuo.test.Helpers; - -import java.util.List; -import java.util.Map; -import java.util.logging.Level; -import java.util.logging.Logger; +import com.oracle.labs.mlrg.olcut.util.Pair; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -67,4 +73,83 @@ public void testIndependentBinaryPredictions() { Helpers.testModelSerialization(model,MultiLabel.class); } + @Test + public void testMultiLabelConfusionMatrixToStrings() { + Dataset train = MultiLabelDataGenerator.generateTrainData(); + Dataset test = MultiLabelDataGenerator.generateTestData(); + + IndependentMultiLabelTrainer trainer = new IndependentMultiLabelTrainer( + new LogisticRegressionTrainer()); + Model model = trainer.train(train); + + ClassifierEvaluation evaluation = new MultiLabelEvaluator() + .evaluate(model, test); + + System.out.println(evaluation); + + // MultiLabelConfusionMatrix toString() hard to interpret + final ConfusionMatrix mcm = evaluation.getConfusionMatrix(); + + System.out.println("original"); + System.out.println(mcm); + + System.out.println("\npretty"); + System.out.println(prettyToString(mcm)); + + System.out.println("\nlabelConfusionMatrixToString"); + System.out.println(labelConfusionMatrixToString(mcm)); + } + + public static String prettyToString(ConfusionMatrix mcmObject) { + return mcmObject.getDomain().getDomain().stream() + .map(multiLabel -> { + final int tp = (int) mcmObject.tp(multiLabel); + final int fn = (int) mcmObject.fn(multiLabel); + final int fp = (int) mcmObject.fp(multiLabel); + final int tn = (int) mcmObject.tn(multiLabel); + return multiLabel + "\n" + + String.format("[tn: %,d fn: %,d]\n", tn, fn) + + String.format("[fp: %,d tp: %,d]", fp, tp); + } + ).collect(Collectors.joining("\n")); + } + + public static String labelConfusionMatrixToString(ConfusionMatrix mcmObject) { + ImmutableOutputInfo domain = mcmObject.getDomain(); + List labelOrder = new ArrayList<>(domain.getDomain()); + + StringBuilder sb = new StringBuilder(); + + int maxLen = Integer.MIN_VALUE; + for (MultiLabel multiLabel : labelOrder) { + maxLen = Math.max(multiLabel.getLabelString().length(), maxLen); + maxLen = Math.max(String.format(" %,d", (int) mcmObject.support(multiLabel)).length(), maxLen); + } + + String trueLabelFormat = String.format("%%-%ds", maxLen + 2); + String predictedLabelFormat = String.format("%%%ds", maxLen + 2); + String countFormat = String.format("%%,%dd", maxLen + 2); + + // + // Empty spot in first row for labels on subsequent rows. + sb.append(String.format(trueLabelFormat, "")); + + // + // Labels across the top for predicted. + for (MultiLabel multiLabel : labelOrder) { + sb.append(String.format(predictedLabelFormat, multiLabel.getLabelString())); + } + sb.append('\n'); + + for (MultiLabel trueLabel : labelOrder) { + sb.append(String.format(trueLabelFormat, trueLabel.getLabelString())); + for (MultiLabel predictedLabel : labelOrder) { + int confusion = (int) mcmObject.confusion(predictedLabel, trueLabel); + sb.append(String.format(countFormat, confusion)); + } + sb.append('\n'); + } + + return sb.toString(); + } } diff --git a/MultiLabel/Core/src/test/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrixTest.java b/MultiLabel/Core/src/test/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrixTest.java index 4a766a09c..7c80a0672 100644 --- a/MultiLabel/Core/src/test/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrixTest.java +++ b/MultiLabel/Core/src/test/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrixTest.java @@ -16,22 +16,23 @@ package org.tribuo.multilabel.evaluation; +import java.util.Arrays; +import java.util.List; + +import org.junit.jupiter.api.Test; import org.tribuo.ImmutableOutputInfo; import org.tribuo.Prediction; import org.tribuo.math.la.DenseMatrix; +import org.tribuo.multilabel.IndependentMultiLabelTest; import org.tribuo.multilabel.MultiLabel; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.tribuo.multilabel.Utils.getUnknown; import static org.tribuo.multilabel.Utils.label; import static org.tribuo.multilabel.Utils.mkDomain; import static org.tribuo.multilabel.Utils.mkPrediction; -import static org.junit.jupiter.api.Assertions.assertEquals; public class MultiLabelConfusionMatrixTest { @@ -158,6 +159,13 @@ public void testSingleLabel() { assertEquals(1, cm.support(c)); assertEquals(4, cm.support()); + + System.out.println("original"); + System.out.println(cm); + System.out.println("\npretty"); + System.out.println(IndependentMultiLabelTest.prettyToString(cm)); + System.out.println("\nlabelConfusionMatrixToString"); + System.out.println(IndependentMultiLabelTest.labelConfusionMatrixToString(cm)); } @Test @@ -231,6 +239,13 @@ public void testMultiLabel() { assertEquals(1, cm.support(c)); assertEquals(5, cm.support()); + + System.out.println("original"); + System.out.println(cm); + System.out.println("\npretty"); + System.out.println(IndependentMultiLabelTest.prettyToString(cm)); + System.out.println("\nlabelConfusionMatrixToString"); + System.out.println(IndependentMultiLabelTest.labelConfusionMatrixToString(cm)); } From b44933462296ce6e4a7aa41779315b7aee13a155 Mon Sep 17 00:00:00 2001 From: Luke Nezda Date: Wed, 14 Apr 2021 06:06:20 -0500 Subject: [PATCH 2/8] make MultiLabelConfusionMatrix.toString() easier for humans to read and interpret --- .../evaluation/MultiLabelConfusionMatrix.java | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/MultiLabel/Core/src/main/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrix.java b/MultiLabel/Core/src/main/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrix.java index c46b11b3e..5285afbcc 100644 --- a/MultiLabel/Core/src/main/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrix.java +++ b/MultiLabel/Core/src/main/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrix.java @@ -16,6 +16,11 @@ package org.tribuo.multilabel.evaluation; +import java.util.List; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + import org.tribuo.ImmutableOutputInfo; import org.tribuo.Model; import org.tribuo.Prediction; @@ -25,10 +30,6 @@ import org.tribuo.multilabel.MultiLabel; import org.tribuo.multilabel.MultiLabelFactory; -import java.util.List; -import java.util.Set; -import java.util.function.Function; - /** * A {@link ConfusionMatrix} which accepts {@link MultiLabel}s. * @@ -158,15 +159,18 @@ public double confusion(MultiLabel predicted, MultiLabel truth) { @Override public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("["); - for (int i = 0; i < mcm.length; i++) { - DenseMatrix cm = mcm[i]; - sb.append(cm.toString()); - sb.append("\n"); - } - sb.append("]"); - return sb.toString(); + return getDomain().getDomain().stream() + .map(multiLabel -> { + final int tp = (int) tp(multiLabel); + final int fn = (int) fn(multiLabel); + final int fp = (int) fp(multiLabel); + final int tn = (int) tn(multiLabel); + return String.join("\n", + multiLabel.toString(), + String.format(" [tn: %,d fn: %,d]", tn, fn), + String.format(" [fp: %,d tp: %,d]", fp, tp)); + } + ).collect(Collectors.joining("\n")); } static ConfusionMatrixTuple tabulate(ImmutableOutputInfo domain, List> predictions) { From 670ebb2a8a83d35b887c6e62268632726fd52cfc Mon Sep 17 00:00:00 2001 From: Luke Nezda Date: Wed, 14 Apr 2021 06:08:00 -0500 Subject: [PATCH 3/8] remove sketch version of IndependentMultiLabelTest.prettyToString and usages --- .../multilabel/IndependentMultiLabelTest.java | 20 +------------------ .../MultiLabelConfusionMatrixTest.java | 8 ++------ 2 files changed, 3 insertions(+), 25 deletions(-) diff --git a/MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java b/MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java index 728859eea..f678659b0 100644 --- a/MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java +++ b/MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java @@ -21,7 +21,6 @@ import java.util.Map; import java.util.logging.Level; import java.util.logging.Logger; -import java.util.stream.Collectors; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; @@ -90,30 +89,13 @@ public void testMultiLabelConfusionMatrixToStrings() { // MultiLabelConfusionMatrix toString() hard to interpret final ConfusionMatrix mcm = evaluation.getConfusionMatrix(); - System.out.println("original"); + System.out.println("new toString()"); System.out.println(mcm); - System.out.println("\npretty"); - System.out.println(prettyToString(mcm)); - System.out.println("\nlabelConfusionMatrixToString"); System.out.println(labelConfusionMatrixToString(mcm)); } - public static String prettyToString(ConfusionMatrix mcmObject) { - return mcmObject.getDomain().getDomain().stream() - .map(multiLabel -> { - final int tp = (int) mcmObject.tp(multiLabel); - final int fn = (int) mcmObject.fn(multiLabel); - final int fp = (int) mcmObject.fp(multiLabel); - final int tn = (int) mcmObject.tn(multiLabel); - return multiLabel + "\n" - + String.format("[tn: %,d fn: %,d]\n", tn, fn) - + String.format("[fp: %,d tp: %,d]", fp, tp); - } - ).collect(Collectors.joining("\n")); - } - public static String labelConfusionMatrixToString(ConfusionMatrix mcmObject) { ImmutableOutputInfo domain = mcmObject.getDomain(); List labelOrder = new ArrayList<>(domain.getDomain()); diff --git a/MultiLabel/Core/src/test/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrixTest.java b/MultiLabel/Core/src/test/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrixTest.java index 7c80a0672..e8ca9c406 100644 --- a/MultiLabel/Core/src/test/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrixTest.java +++ b/MultiLabel/Core/src/test/java/org/tribuo/multilabel/evaluation/MultiLabelConfusionMatrixTest.java @@ -160,10 +160,8 @@ public void testSingleLabel() { assertEquals(4, cm.support()); - System.out.println("original"); + System.out.println("new toString()"); System.out.println(cm); - System.out.println("\npretty"); - System.out.println(IndependentMultiLabelTest.prettyToString(cm)); System.out.println("\nlabelConfusionMatrixToString"); System.out.println(IndependentMultiLabelTest.labelConfusionMatrixToString(cm)); } @@ -240,10 +238,8 @@ public void testMultiLabel() { assertEquals(5, cm.support()); - System.out.println("original"); + System.out.println("new toString()"); System.out.println(cm); - System.out.println("\npretty"); - System.out.println(IndependentMultiLabelTest.prettyToString(cm)); System.out.println("\nlabelConfusionMatrixToString"); System.out.println(IndependentMultiLabelTest.labelConfusionMatrixToString(cm)); } From a44689dd531ce34734626ba8c1a7acdc5f7adf67 Mon Sep 17 00:00:00 2001 From: Luke Nezda Date: Fri, 16 Apr 2021 08:21:18 -0500 Subject: [PATCH 4/8] attempt to fix labelConfusionMatrixToString --- .../multilabel/IndependentMultiLabelTest.java | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java b/MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java index f678659b0..8be2d56d2 100644 --- a/MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java +++ b/MultiLabel/Core/src/test/java/org/tribuo/multilabel/IndependentMultiLabelTest.java @@ -94,6 +94,9 @@ public void testMultiLabelConfusionMatrixToStrings() { System.out.println("\nlabelConfusionMatrixToString"); System.out.println(labelConfusionMatrixToString(mcm)); + + System.out.println("\npredictions"); + evaluation.getPredictions().forEach(System.out::println); } public static String labelConfusionMatrixToString(ConfusionMatrix mcmObject) { @@ -123,10 +126,16 @@ public static String labelConfusionMatrixToString(ConfusionMatrix mc } sb.append('\n'); - for (MultiLabel trueLabel : labelOrder) { - sb.append(String.format(trueLabelFormat, trueLabel.getLabelString())); - for (MultiLabel predictedLabel : labelOrder) { + for (MultiLabel predictedLabel : labelOrder) { + sb.append(String.format(trueLabelFormat, predictedLabel.getLabelString())); + for (MultiLabel trueLabel : labelOrder) { int confusion = (int) mcmObject.confusion(predictedLabel, trueLabel); + int fp = (int) mcmObject.fp(trueLabel); + if (confusion > 0 && !trueLabel.equals(predictedLabel) && fp == 0) { + // not actual confusion - fp == 0 + // FIXME likely incomplete, wrong for more involved example + confusion = 0; + } sb.append(String.format(countFormat, confusion)); } sb.append('\n'); From 6a5d55b2e6c36ac61d960de618205c44df1b4fba Mon Sep 17 00:00:00 2001 From: Luke Nezda Date: Fri, 16 Apr 2021 08:23:39 -0500 Subject: [PATCH 5/8] include label there are no predictions for in warning log message --- .../tribuo/classification/evaluation/ConfusionMetrics.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java b/Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java index 029a75cc6..dc8d3c007 100644 --- a/Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java +++ b/Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java @@ -16,12 +16,12 @@ package org.tribuo.classification.evaluation; +import java.util.logging.Logger; + import org.tribuo.classification.Classifiable; import org.tribuo.evaluation.metrics.EvaluationMetric.Average; import org.tribuo.evaluation.metrics.MetricTarget; -import java.util.logging.Logger; - /** * Static functions for computing classification metrics based on a {@link ConfusionMatrix}. */ @@ -60,7 +60,7 @@ public static > double accuracy(T label, ConfusionMatr double support = cm.support(label); // handle div-by-zero if (support == 0d) { - logger.warning("No predictions: accuracy ill-defined"); + logger.warning("No predictions for " + label + ": accuracy ill-defined"); return Double.NaN; } return cm.tp(label) / cm.support(label); From 35735aa50a6cc9c60b36a39e9a9f45194839ea71 Mon Sep 17 00:00:00 2001 From: Luke Nezda Date: Thu, 22 Apr 2021 10:29:24 -0500 Subject: [PATCH 6/8] WIP utility singleLabelConfusionMatrix(List> predictions) --- .../evaluation/LabelConfusionMatrixTest.java | 25 ++-- .../multilabel/IndependentMultiLabelTest.java | 117 +++++++++++------- .../MultiLabelConfusionMatrixTest.java | 10 +- 3 files changed, 88 insertions(+), 64 deletions(-) diff --git a/Classification/Core/src/test/java/org/tribuo/classification/evaluation/LabelConfusionMatrixTest.java b/Classification/Core/src/test/java/org/tribuo/classification/evaluation/LabelConfusionMatrixTest.java index 25b5572cd..38cef80c0 100644 --- a/Classification/Core/src/test/java/org/tribuo/classification/evaluation/LabelConfusionMatrixTest.java +++ b/Classification/Core/src/test/java/org/tribuo/classification/evaluation/LabelConfusionMatrixTest.java @@ -16,18 +16,18 @@ package org.tribuo.classification.evaluation; +import java.util.Arrays; +import java.util.List; + +import org.junit.jupiter.api.Test; import org.tribuo.ImmutableOutputInfo; import org.tribuo.Prediction; import org.tribuo.classification.Label; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.tribuo.classification.Utils.label; import static org.tribuo.classification.Utils.mkDomain; import static org.tribuo.classification.Utils.mkPrediction; -import static org.junit.jupiter.api.Assertions.assertEquals; public class LabelConfusionMatrixTest { @@ -38,7 +38,8 @@ public void testMulticlass() { mkPrediction("a", "a"), mkPrediction("c", "b"), mkPrediction("b", "b"), - mkPrediction("b", "c") + mkPrediction("b", "c"), + mkPrediction("a", "b") ); ImmutableOutputInfo