diff --git a/webui/react/src/components/CompareHyperparameters.tsx b/webui/react/src/components/CompareHyperparameters.tsx index 9f5b20529254..0ff91ab0cde4 100644 --- a/webui/react/src/components/CompareHyperparameters.tsx +++ b/webui/react/src/components/CompareHyperparameters.tsx @@ -12,7 +12,7 @@ import ExperimentVisualizationFilters, { VisualizationFilters, } from 'pages/ExperimentDetails/ExperimentVisualization/ExperimentVisualizationFilters'; import { TrialMetricData } from 'pages/TrialDetails/useTrialMetrics'; -import { ExperimentWithTrial, TrialItem } from 'types'; +import { ExperimentWithTrial, FlatRun, TrialItem, XOR } from 'types'; import CompareHeatMaps from './CompareHeatMaps'; import { @@ -23,17 +23,22 @@ import CompareParallelCoordinates from './CompareParallelCoordinates'; import CompareScatterPlots from './CompareScatterPlots'; import css from './HpParallelCoordinates.module.scss'; -interface Props { +interface BaseProps { projectId: number; - selectedExperiments: ExperimentWithTrial[]; - trials: TrialItem[]; metricData: TrialMetricData; } +type Props = XOR< + { selectedExperiments: ExperimentWithTrial[]; trials: TrialItem[] }, + { selectedRuns: FlatRun[] } +> & + BaseProps; + export const NO_DATA_MESSAGE = 'No data available.'; const CompareHyperparameters: React.FC = ({ selectedExperiments, + selectedRuns, trials, projectId, metricData, @@ -42,9 +47,14 @@ const CompareHyperparameters: React.FC = ({ const fullHParams: string[] = useMemo(() => { const hpParams = new Set(); - trials.forEach((trial) => Object.keys(trial.hyperparameters).forEach((hp) => hpParams.add(hp))); + trials?.forEach((trial) => + Object.keys(trial.hyperparameters).forEach((hp) => hpParams.add(hp)), + ); + selectedRuns?.forEach((run) => + Object.keys(run.hyperparameters ?? {}).forEach((hp) => hpParams.add(hp)), + ); return Array.from(hpParams); - }, [trials]); + }, [selectedRuns, trials]); const settingsConfig = useMemo( () => settingsConfigForCompareHyperparameters(fullHParams, projectId), @@ -120,11 +130,11 @@ const CompareHyperparameters: React.FC = ({ return ; } - if (trials.length === 0) { + if ((trials ?? selectedRuns).length === 0) { return ; } - if (selectedExperiments.length !== 0 && metrics.length === 0) { + if ((selectedExperiments ?? selectedRuns).length !== 0 && metrics.length === 0) { return (
= ({
- {selectedExperiments.length > 0 && ( + {(selectedExperiments ?? selectedRuns).length > 0 && ( <> Parallel Coordinates - + {selectedRuns ? ( + + ) : ( + + )} Scatter Plots - + {selectedRuns ? ( + + ) : ( + + )} Heat Maps - + {selectedRuns ? ( + + ) : ( + + )} )}
diff --git a/webui/react/src/components/CompareParallelCoordinates.tsx b/webui/react/src/components/CompareParallelCoordinates.tsx index 16569835b3dc..8fe4ac4ae453 100644 --- a/webui/react/src/components/CompareParallelCoordinates.tsx +++ b/webui/react/src/components/CompareParallelCoordinates.tsx @@ -9,16 +9,17 @@ import { useGlasbey } from 'hooks/useGlasbey'; import { TrialMetricData } from 'pages/TrialDetails/useTrialMetrics'; import { ExperimentWithTrial, + FlatRun, HpTrialData, Hyperparameter, HyperparameterType, Primitive, - Range, Scale, TrialItem, XAxisDomain, + XOR, } from 'types'; -import { defaultNumericRange, getNumericRange, updateRange } from 'utils/chart'; +import { getNumericRange } from 'utils/chart'; import { flattenObject, isPrimitive } from 'utils/data'; import { metricToKey, metricToStr } from 'utils/metric'; import { numericSorter } from 'utils/sort'; @@ -28,28 +29,37 @@ import css from './HpParallelCoordinates.module.scss'; export const COMPARE_PARALLEL_COORDINATES = 'compare-parallel-coordinates'; -interface Props { +interface BaseProps { projectId: number; - selectedExperiments: ExperimentWithTrial[]; - trials: TrialItem[]; metricData: TrialMetricData; settings: CompareHyperparametersSettings; fullHParams: string[]; } +type Props = XOR< + { selectedExperiments: ExperimentWithTrial[]; trials: TrialItem[] }, + { selectedRuns: FlatRun[] } +> & + BaseProps; + const CompareParallelCoordinates: React.FC = ({ selectedExperiments, trials, settings, metricData, fullHParams, + selectedRuns, }: Props) => { const [chartData, setChartData] = useState(); const [hermesCreatedFilters, setHermesCreatedFilters] = useState({}); const { metrics, data, isLoaded, setScale } = metricData; - const colorMap = useGlasbey(selectedExperiments.map((e) => e.experiment.id)); + const colorMap = useGlasbey( + selectedExperiments + ? selectedExperiments.map((e) => e.experiment.id) + : selectedRuns.map((r) => r.id), + ); const selectedScale = settings.scale; useEffect(() => { @@ -61,7 +71,7 @@ const CompareParallelCoordinates: React.FC = ({ const experimentHyperparameters = useMemo(() => { const hpMap: Record = {}; - selectedExperiments.forEach((exp) => { + selectedExperiments?.forEach((exp) => { const hps = Object.keys(exp.experiment.hyperparameters); hps.forEach((hp) => (hpMap[hp] = exp.experiment.hyperparameters[hp])); }); @@ -140,67 +150,99 @@ const CompareParallelCoordinates: React.FC = ({ useEffect(() => { if (!selectedMetric) return; - const trialMetricsMap: Record = {}; - const trialHpMap: Record> = {}; - - const trialHpdata: Record = {}; - let trialMetricRange: Range = defaultNumericRange(true); - - trials?.forEach((trial) => { - const expId = trial.experimentId; - const key = metricToKey(selectedMetric); - - // Choose the final metric value for each trial - const metricValue = data?.[trial.id]?.[key]?.data?.[XAxisDomain.Batches]?.at(-1)?.[1]; - - if (!metricValue) return; - trialMetricsMap[expId] = metricValue; - - trialMetricRange = updateRange(trialMetricRange, metricValue); - const flatHParams = { - ...trial?.hyperparameters, - ...flattenObject(trial?.hyperparameters || {}), - }; - - Object.keys(flatHParams).forEach((hpKey) => { - const hpValue = flatHParams[hpKey]; - trialHpMap[hpKey] = trialHpMap[hpKey] ?? {}; - trialHpMap[hpKey][expId] = isPrimitive(hpValue) - ? (hpValue as Primitive) - : JSON.stringify(hpValue); + const metricsMap: Record = {}; + const hpMap: Record> = {}; + + const hpData: Record = {}; + + if (trials) { + trials.forEach((trial) => { + const expId = trial.experimentId; + const key = metricToKey(selectedMetric); + + // Choose the final metric value for each trial + const metricValue = data?.[trial.id]?.[key]?.data?.[XAxisDomain.Batches]?.at(-1)?.[1]; + + if (!metricValue) return; + metricsMap[expId] = metricValue; + + const flatHParams = { + ...trial?.hyperparameters, + ...flattenObject(trial?.hyperparameters || {}), + }; + + Object.keys(flatHParams).forEach((hpKey) => { + const hpValue = flatHParams[hpKey]; + hpMap[hpKey] = hpMap[hpKey] ?? {}; + hpMap[hpKey][expId] = isPrimitive(hpValue) + ? (hpValue as Primitive) + : JSON.stringify(hpValue); + }); }); - }); + } else if (selectedRuns) { + selectedRuns.forEach((run) => { + const key = metricToKey(selectedMetric); + + // Choose the final metric value for each trial + const metricValue = data?.[run.id]?.[key]?.data?.[XAxisDomain.Batches]?.at(-1)?.[1]; - const trialIds = Object.keys(trialMetricsMap) + if (!metricValue) return; + metricsMap[run.id] = metricValue; + + const flatHParams = { + ...run.hyperparameters, + ...flattenObject(run.hyperparameters || {}), + }; + + Object.keys(flatHParams).forEach((hpKey) => { + const hpValue = flatHParams[hpKey]; + hpMap[hpKey] = hpMap[hpKey] ?? {}; + hpMap[hpKey][run.id] = isPrimitive(hpValue) + ? (hpValue as Primitive) + : JSON.stringify(hpValue); + }); + }); + } + + const trialIds = Object.keys(metricsMap) .map((id) => parseInt(id)) .sort(numericSorter); - Object.keys(trialHpMap).forEach((hpKey) => { - trialHpdata[hpKey] = trialIds.map((trialId) => trialHpMap[hpKey][trialId]); + Object.keys(hpMap).forEach((hpKey) => { + hpData[hpKey] = trialIds.map((trialId) => hpMap[hpKey][trialId]); }); const metricKey = metricToStr(selectedMetric); - const metricValues = trialIds.map((id) => trialMetricsMap[id]); - trialHpdata[metricKey] = metricValues; + const metricValues = trialIds.map((id) => metricsMap[id]); + hpData[metricKey] = metricValues; const metricRange = getNumericRange(metricValues); setChartData({ - data: trialHpdata, + data: hpData, metricRange, metricValues, trialIds, }); - }, [selectedExperiments, selectedMetric, fullHParams, metricData, selectedScale, trials, data]); + }, [ + selectedExperiments, + selectedMetric, + fullHParams, + metricData, + selectedScale, + trials, + data, + selectedRuns, + ]); if (!isLoaded) { return ; } - if (trials.length === 0) { + if ((trials ?? selectedRuns).length === 0) { return ; } - if (!chartData || (selectedExperiments.length !== 0 && metrics.length === 0)) { + if (!chartData || ((selectedExperiments ?? selectedRuns).length !== 0 && metrics.length === 0)) { return (
= ({ return (
- {selectedExperiments.length > 0 && ( + {(selectedExperiments ?? selectedRuns).length > 0 && (
& + BaseProps; + interface HpMetricData { hpLabels: Record; hpLogScales: Record; hpValues: Record; metricValues: Record; - trialIds: number[]; + recordIds: number[]; } const CompareScatterPlots: React.FC = ({ @@ -47,6 +54,7 @@ const CompareScatterPlots: React.FC = ({ settings, metricData, selectedExperiments, + selectedRuns, }: Props) => { const baseRef = useRef(null); const [chartData, setChartData] = useState(); @@ -92,7 +100,7 @@ const CompareScatterPlots: React.FC = ({ null, null, null, - chartData?.trialIds || [], + chartData?.recordIds || [], ], ], options: { @@ -107,7 +115,7 @@ const CompareScatterPlots: React.FC = ({ cursor: { drag: { setScale: false, x: false, y: false } }, title, }, - tooltipLabels: [xLabel, yLabel, null, null, null, 'trial ID'], + tooltipLabels: [xLabel, yLabel, null, null, null, 'record ID'], }; return acc; }, {}); @@ -139,7 +147,7 @@ const CompareScatterPlots: React.FC = ({ const experimentHyperparameters = useMemo(() => { const hpMap: Record = {}; - selectedExperiments.forEach((exp) => { + selectedExperiments?.forEach((exp) => { const hps = Object.keys(exp.experiment.hyperparameters); hps.forEach((hp) => (hpMap[hp] = exp.experiment.hyperparameters[hp])); }); @@ -149,7 +157,7 @@ const CompareScatterPlots: React.FC = ({ useEffect(() => { if (!selectedMetric) return; - const trialIds: number[] = []; + const recordIds: number[] = []; const hpTrialMap: Record< string, Record @@ -160,22 +168,25 @@ const CompareScatterPlots: React.FC = ({ const hpLabelMap: Record = {}; const hpLogScaleMap: Record = {}; - trials.forEach((trial) => { - const trialId = trial.id; - trialIds.push(trialId); + const recordHyperparameters: [number, TrialHyperparameters][] = selectedRuns + ? selectedRuns.flatMap((run) => (run.hyperparameters ? [[run.id, run.hyperparameters]] : [])) + : trials.map((trial) => [trial.id, trial.hyperparameters]); + + recordHyperparameters?.forEach(([recordId, recordHp]) => { + recordIds.push(recordId); - const flatHParams = flattenObject(trial.hyperparameters); + const flatHParams = flattenObject(recordHp); fullHParams.forEach((hParam: string) => { /** * TODO: filtering NaN, +/- Infinity for now, but handle it later with * dynamic min/max ranges via uPlot.Scales. */ const key = metricToKey(selectedMetric); - const trialMetric = data?.[trial.id]?.[key]?.data?.[XAxisDomain.Batches]?.at(-1)?.[1]; + const trialMetric = data?.[recordId]?.[key]?.data?.[XAxisDomain.Batches]?.at(-1)?.[1]; hpTrialMap[hParam] = hpTrialMap[hParam] || {}; - hpTrialMap[hParam][trialId] = hpTrialMap[hParam][trialId] || {}; - hpTrialMap[hParam][trialId] = { + hpTrialMap[hParam][recordId] = hpTrialMap[hParam][recordId] || {}; + hpTrialMap[hParam][recordId] = { hp: flatHParams[hParam], metric: trialMetric, }; @@ -185,8 +196,8 @@ const CompareScatterPlots: React.FC = ({ hpMetricMap[hParam] = []; hpValueMap[hParam] = []; hpLabelMap[hParam] = []; - trialIds.forEach((trialId) => { - const map = hpTrialMap[hParam]?.[trialId] || {}; + recordIds.forEach((recordId) => { + const map = hpTrialMap[hParam]?.[recordId] || {}; const hpValue = isBoolean(map.hp) ? map.hp.toString() : map.hp; if (isString(hpValue)) { @@ -211,9 +222,9 @@ const CompareScatterPlots: React.FC = ({ hpLogScales: hpLogScaleMap, hpValues: hpValueMap, metricValues: hpMetricMap, - trialIds, + recordIds, }); - }, [fullHParams, experimentHyperparameters, selectedMetric, trials, data]); + }, [fullHParams, experimentHyperparameters, selectedMetric, trials, data, selectedRuns]); if (!metricsLoaded || !chartData) { return ; diff --git a/webui/react/src/components/RunComparisonView.tsx b/webui/react/src/components/RunComparisonView.tsx index 110f44644438..0133e11cc3ae 100644 --- a/webui/react/src/components/RunComparisonView.tsx +++ b/webui/react/src/components/RunComparisonView.tsx @@ -5,10 +5,10 @@ import Pivot, { PivotProps } from 'hew/Pivot'; import SplitPane, { Pane } from 'hew/SplitPane'; import React, { useMemo } from 'react'; -//import CompareHyperparameters from 'components/CompareHyperparameters'; +import CompareHyperparameters from 'components/CompareHyperparameters'; import useMobile from 'hooks/useMobile'; import useScrollbarWidth from 'hooks/useScrollbarWidth'; -//import { TrialsComparisonTable } from 'pages/ExperimentDetails/TrialsComparisonModal'; +import { TrialsComparisonTable } from 'pages/ExperimentDetails/TrialsComparisonModal'; import { useRunMetrics } from 'pages/FlatRuns/useRunMetrics'; import { FlatRun } from 'types'; @@ -30,7 +30,7 @@ const RunComparisonView: React.FC = ({ initialWidth, onWidthChange, fixedColumnsCount, - //projectId, + projectId, selectedRuns, }) => { const scrollbarWidth = useScrollbarWidth(); @@ -50,25 +50,24 @@ const RunComparisonView: React.FC = ({ key: 'metrics', label: 'Metrics', }, - // { - // children: ( - // - // ), - // key: 'hyperparameters', - // label: 'Hyperparameters', - // }, - // { - // children: , - // key: 'details', - // label: 'Details', - // }, + { + children: ( + + ), + key: 'hyperparameters', + label: 'Hyperparameters', + }, + { + children: , + key: 'details', + label: 'Details', + }, ]; - }, [metricData, selectedRuns]); + }, [metricData, projectId, selectedRuns]); const leftPane = open && !hasPinnedColumns ? (