Skip to content

Commit

Permalink
support missing count in confusion matrix in model eval panel
Browse files Browse the repository at this point in the history
  • Loading branch information
imanjra committed Dec 19, 2024
1 parent 65ffd8b commit 76a1492
Showing 1 changed file with 33 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import { formatValue, getNumericDifference, useTriggerEvent } from "./utils";
const KEY_COLOR = "#ff6d04";
const COMPARE_KEY_COLOR = "#03a9f4";
const DEFAULT_BAR_CONFIG = { sortBy: "default" };
const NONE_CLASS = "(none)";

export default function Evaluation(props: EvaluationProps) {
const {
Expand All @@ -71,14 +72,22 @@ export default function Evaluation(props: EvaluationProps) {
const [expanded, setExpanded] = React.useState("summary");
const [mode, setMode] = useState("chart");
const [editNoteState, setEditNoteState] = useState({ open: false, note: "" });
const [classPerformanceConfig, setClassPerformanceConfig] =
useState<PLOT_CONFIG_TYPE>({});
const [classPerformanceDialogConfig, setClassPerformanceDialogConfig] =
useState<PLOT_CONFIG_DIALOG_TYPE>(DEFAULT_BAR_CONFIG);
const [confusionMatrixConfig, setConfusionMatrixConfig] =
useState<PLOT_CONFIG_TYPE>({ log: true });
const [confusionMatrixDialogConfig, setConfusionMatrixDialogConfig] =
useState<PLOT_CONFIG_DIALOG_TYPE>(DEFAULT_BAR_CONFIG);
const [
classPerformanceConfig,
setClassPerformanceConfig,
] = useState<PLOT_CONFIG_TYPE>({});
const [
classPerformanceDialogConfig,
setClassPerformanceDialogConfig,
] = useState<PLOT_CONFIG_DIALOG_TYPE>(DEFAULT_BAR_CONFIG);
const [
confusionMatrixConfig,
setConfusionMatrixConfig,
] = useState<PLOT_CONFIG_TYPE>({ log: true });
const [
confusionMatrixDialogConfig,
setConfusionMatrixDialogConfig,
] = useState<PLOT_CONFIG_DIALOG_TYPE>(DEFAULT_BAR_CONFIG);
const [metricMode, setMetricMode] = useState("chart");
const [classMode, setClassMode] = useState("chart");
const [performanceClass, setPerformanceClass] = useState("precision");
Expand Down Expand Up @@ -1656,12 +1665,23 @@ function getMatrix(matrices, config, maskTargets, compareMaskTargets?) {
if (!matrices) return;
const { sortBy = "az", limit } = config;
const parsedLimit = typeof limit === "number" ? limit : undefined;
const classes = matrices[`${sortBy}_classes`].slice(0, parsedLimit);
const matrix = matrices[`${sortBy}_matrix`].slice(0, parsedLimit);
const originalClasses = matrices[`${sortBy}_classes`];
const originalMatrix = matrices[`${sortBy}_matrix`];
const classes = originalClasses.slice(0, parsedLimit);
const matrix = originalMatrix.slice(0, parsedLimit);
const colorscale = matrices[`${sortBy}_colorscale`];
const labels = classes.map((c) => {
return compareMaskTargets?.[c] || maskTargets?.[c] || c;
});
const noneIndex = originalClasses.indexOf(NONE_CLASS);
if (parsedLimit < originalClasses.length) {
classes.push(
compareMaskTargets?.[NONE_CLASS] ||
maskTargets?.[NONE_CLASS] ||
NONE_CLASS
);
matrix.push(originalMatrix[noneIndex]);
}
return { labels, matrix, colorscale };
}

Expand All @@ -1672,9 +1692,8 @@ function getConfigLabel({ config, type, dashed }) {
type === "classPerformance"
? CLASS_PERFORMANCE_SORT_OPTIONS
: CONFUSION_MATRIX_SORT_OPTIONS;
const sortByLabel = sortByLabels.find(
(option) => option.value === sortBy
)?.label;
const sortByLabel = sortByLabels.find((option) => option.value === sortBy)
?.label;
return dashed ? ` - ${sortByLabel}` : sortByLabel;
}

Expand All @@ -1687,7 +1706,7 @@ function useActiveFilter(evaluation, compareEvaluation) {
const { _cls, kwargs } = stage;
if (_cls.endsWith("FilterLabels")) {
const [_, filter] = kwargs;
const filterEq = filter[1].$eq;
const filterEq = filter[1].$eq || [];
const [filterEqLeft, filterEqRight] = filterEq;
if (filterEqLeft === "$$this.label") {
return { type: "label", value: filterEqRight };
Expand Down

0 comments on commit 76a1492

Please sign in to comment.