diff --git a/Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java b/Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java index 029a75cc6..dc8d3c007 100644 --- a/Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java +++ b/Classification/Core/src/main/java/org/tribuo/classification/evaluation/ConfusionMetrics.java @@ -16,12 +16,12 @@ package org.tribuo.classification.evaluation; +import java.util.logging.Logger; + import org.tribuo.classification.Classifiable; import org.tribuo.evaluation.metrics.EvaluationMetric.Average; import org.tribuo.evaluation.metrics.MetricTarget; -import java.util.logging.Logger; - /** * Static functions for computing classification metrics based on a {@link ConfusionMatrix}. */ @@ -60,7 +60,7 @@ public static > double accuracy(T label, ConfusionMatr double support = cm.support(label); // handle div-by-zero if (support == 0d) { - logger.warning("No predictions: accuracy ill-defined"); + logger.warning("No predictions for " + label + ": accuracy ill-defined"); return Double.NaN; } return cm.tp(label) / cm.support(label); diff --git a/Classification/Core/src/test/java/org/tribuo/classification/evaluation/LabelConfusionMatrixTest.java b/Classification/Core/src/test/java/org/tribuo/classification/evaluation/LabelConfusionMatrixTest.java index 25b5572cd..38cef80c0 100644 --- a/Classification/Core/src/test/java/org/tribuo/classification/evaluation/LabelConfusionMatrixTest.java +++ b/Classification/Core/src/test/java/org/tribuo/classification/evaluation/LabelConfusionMatrixTest.java @@ -16,18 +16,18 @@ package org.tribuo.classification.evaluation; +import java.util.Arrays; +import java.util.List; + +import org.junit.jupiter.api.Test; import org.tribuo.ImmutableOutputInfo; import org.tribuo.Prediction; import org.tribuo.classification.Label; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.tribuo.classification.Utils.label; import static org.tribuo.classification.Utils.mkDomain; import static org.tribuo.classification.Utils.mkPrediction; -import static org.junit.jupiter.api.Assertions.assertEquals; public class LabelConfusionMatrixTest { @@ -38,7 +38,8 @@ public void testMulticlass() { mkPrediction("a", "a"), mkPrediction("c", "b"), mkPrediction("b", "b"), - mkPrediction("b", "c") + mkPrediction("b", "c"), + mkPrediction("a", "b") ); ImmutableOutputInfo