From 3c9a1885768138e12e5020f355315c6e819a278d Mon Sep 17 00:00:00 2001 From: "Ashton G." Date: Wed, 7 Aug 2024 15:40:17 -0400 Subject: [PATCH] fix: prevent multiple calls to time-series on compare view select (#9805) (cherry picked from commit f7e18fc7474977db7a7922195d1d3a9c0760b5ee) --- .../CompareHyperparameters.test.mock.tsx | 14 -- webui/react/src/components/CompareMetrics.tsx | 9 +- .../components/ComparisonView.test.mock.tsx | 14 -- webui/react/src/components/ComparisonView.tsx | 159 ++++++++++-------- webui/react/src/hooks/useMetricNames.ts | 25 +-- webui/react/src/hooks/useMetrics.ts | 100 ++++++----- 6 files changed, 162 insertions(+), 159 deletions(-) diff --git a/webui/react/src/components/CompareHyperparameters.test.mock.tsx b/webui/react/src/components/CompareHyperparameters.test.mock.tsx index fcca16cc70e..75847a6542d 100644 --- a/webui/react/src/components/CompareHyperparameters.test.mock.tsx +++ b/webui/react/src/components/CompareHyperparameters.test.mock.tsx @@ -108,20 +108,6 @@ export const METRIC_DATA: RunMetricData = { }, ], scale: 'linear', - selectedMetrics: [ - { - group: 'training', - name: 'loss', - }, - { - group: 'validation', - name: 'accuracy', - }, - { - group: 'validation', - name: 'validation_loss', - }, - ], setScale: (): Scale => { return Scale.Linear; }, diff --git a/webui/react/src/components/CompareMetrics.tsx b/webui/react/src/components/CompareMetrics.tsx index 946ddf1a65d..c9fc7dc7787 100644 --- a/webui/react/src/components/CompareMetrics.tsx +++ b/webui/react/src/components/CompareMetrics.tsx @@ -1,6 +1,5 @@ import { ChartGrid, ChartsProps } from 'hew/LineChart'; import { Loadable, Loaded, NotLoaded } from 'hew/utils/loadable'; -import _ from 'lodash'; import React, { useCallback, useMemo, useState } from 'react'; import MetricBadgeTag from 'components/MetricBadgeTag'; @@ -116,7 +115,7 @@ const CompareMetrics: React.FC = ({ ); const chartsProps: Loadable = useMemo(() => { - const { metricHasData, metrics, isLoaded, selectedMetrics } = metricData; + const { metricHasData, metrics, isLoaded } = metricData; const { chartProps, chartedMetrics } = selectedRuns ? calculateRunsChartProps(metricData, selectedRuns, xAxis, colorMap) : calculateExperimentChartProps(metricData, selectedExperiments, trials, xAxis, colorMap); @@ -134,11 +133,7 @@ const CompareMetrics: React.FC = ({ if (!isLoaded) { // When trial metrics hasn't loaded metric names or individual trial metrics. return NotLoaded; - } else if (!chartDataIsLoaded || !_.isEqual(selectedMetrics, metrics)) { - // In some cases the selectedMetrics returned may not be up to date - // with the metrics selected by the user. In this case we want to - // show a loading state until the metrics match. - + } else if (!chartDataIsLoaded) { // returns the chartProps with a NotLoaded series which enables // the ChartGrid to show a spinner for the loading charts. return Loaded(chartProps.map((chartProps) => ({ ...chartProps, series: NotLoaded }))); diff --git a/webui/react/src/components/ComparisonView.test.mock.tsx b/webui/react/src/components/ComparisonView.test.mock.tsx index e4b1ebf8c0d..908b9454530 100644 --- a/webui/react/src/components/ComparisonView.test.mock.tsx +++ b/webui/react/src/components/ComparisonView.test.mock.tsx @@ -57,20 +57,6 @@ export const METRIC_DATA: RunMetricData = { }, ], scale: 'linear', - selectedMetrics: [ - { - group: 'training', - name: 'loss', - }, - { - group: 'validation', - name: 'accuracy', - }, - { - group: 'validation', - name: 'validation_loss', - }, - ], setScale: (): Scale => { return Scale.Linear; }, diff --git a/webui/react/src/components/ComparisonView.tsx b/webui/react/src/components/ComparisonView.tsx index 1b9dfc65f73..c5d5d018f06 100644 --- a/webui/react/src/components/ComparisonView.tsx +++ b/webui/react/src/components/ComparisonView.tsx @@ -40,6 +40,88 @@ type Props = XOR<{ experimentSelection: SelectionType }, { runSelection: Selecti const SELECTION_LIMIT = 50; +interface TabsProps { + colorMap: MapOfIdsToColors; + loadableSelectedExperiments: Loadable; + loadableSelectedRuns: Loadable; + projectId: number; +} + +const Tabs = ({ + colorMap, + loadableSelectedExperiments, + loadableSelectedRuns, + projectId, +}: TabsProps) => { + const selectedExperiments: ExperimentWithTrial[] | undefined = Loadable.getOrElse( + undefined, + loadableSelectedExperiments, + ); + + const selectedRuns: FlatRun[] | undefined = Loadable.getOrElse(undefined, loadableSelectedRuns); + + const trials = useMemo(() => { + return selectedExperiments?.flatMap((exp) => (exp.bestTrial ? [exp.bestTrial] : [])) ?? []; + }, [selectedExperiments]); + + const experiments = useMemo( + () => selectedExperiments?.map((exp) => exp.experiment) ?? [], + [selectedExperiments], + ); + + const metricData = useMetrics(selectedRuns ?? trials ?? []); + + const tabs: PivotProps['items'] = useMemo(() => { + return [ + { + children: selectedRuns ? ( + + ) : ( + + ), + key: 'metrics', + label: 'Metrics', + }, + { + children: selectedRuns ? ( + + ) : ( + + ), + key: 'hyperparameters', + label: 'Hyperparameters', + }, + { + children: selectedRuns ? ( + + ) : ( + + ), + key: 'details', + label: 'Details', + }, + ]; + }, [selectedRuns, metricData, selectedExperiments, trials, colorMap, projectId, experiments]); + + return ; +}; + const ComparisonView: React.FC = ({ children, colorMap, @@ -82,11 +164,6 @@ const ComparisonView: React.FC = ({ } }, [experimentSelection, open]); - const selectedExperiments: ExperimentWithTrial[] | undefined = Loadable.getOrElse( - undefined, - loadableSelectedExperiments, - ); - const loadableSelectedRuns = useAsync(async () => { if ( !open || @@ -112,71 +189,10 @@ const ComparisonView: React.FC = ({ } }, [open, runSelection]); - const selectedRuns: FlatRun[] | undefined = Loadable.getOrElse(undefined, loadableSelectedRuns); - const minWidths: [number, number] = useMemo(() => { return [fixedColumnsCount * MIN_COLUMN_WIDTH + scrollbarWidth, 100]; }, [fixedColumnsCount, scrollbarWidth]); - const trials = useMemo(() => { - return selectedExperiments?.flatMap((exp) => (exp.bestTrial ? [exp.bestTrial] : [])) ?? []; - }, [selectedExperiments]); - - const experiments = useMemo( - () => selectedExperiments?.map((exp) => exp.experiment) ?? [], - [selectedExperiments], - ); - - const metricData = useMetrics(selectedRuns ?? trials ?? []); - - const tabs: PivotProps['items'] = useMemo(() => { - return [ - { - children: selectedRuns ? ( - - ) : ( - - ), - key: 'metrics', - label: 'Metrics', - }, - { - children: selectedRuns ? ( - - ) : ( - - ), - key: 'hyperparameters', - label: 'Hyperparameters', - }, - { - children: selectedRuns ? ( - - ) : ( - - ), - key: 'details', - label: 'Details', - }, - ]; - }, [selectedRuns, metricData, selectedExperiments, trials, colorMap, projectId, experiments]); - const leftPane = open && !hasPinnedColumns ? ( @@ -191,7 +207,7 @@ const ComparisonView: React.FC = ({ ); } - if (selectedExperiments === undefined) { + if (loadableSelectedExperiments.isNotLoaded) { return ; } } @@ -201,7 +217,7 @@ const ComparisonView: React.FC = ({ ); } - if (selectedRuns === undefined) { + if (loadableSelectedRuns.isNotLoaded) { return ; } } @@ -210,7 +226,12 @@ const ComparisonView: React.FC = ({ {isSelectionLimitReached && ( )} - + ); }; diff --git a/webui/react/src/hooks/useMetricNames.ts b/webui/react/src/hooks/useMetricNames.ts index 2ea031898eb..fa0759c15f5 100644 --- a/webui/react/src/hooks/useMetricNames.ts +++ b/webui/react/src/hooks/useMetricNames.ts @@ -1,6 +1,6 @@ import { Loadable, Loaded, NotLoaded } from 'hew/utils/loadable'; import _ from 'lodash'; -import { useEffect, useState } from 'react'; +import { useEffect, useMemo, useRef, useState } from 'react'; import { V1ExpMetricNamesResponse } from 'services/api-ts-sdk'; import { detApi } from 'services/apiConfig'; @@ -8,31 +8,33 @@ import { readStream } from 'services/utils'; import { Metric, XAxisDomain } from 'types'; import { metricKeyToMetric, metricSorter, metricToKey } from 'utils/metric'; -import usePrevious from './usePrevious'; - -// DO NOT pass a raw object for experimentIds param -// That causes unwanted API call const useMetricNames = ( experimentIds: number[], errorHandler?: (e: unknown) => void, quickPoll?: boolean, ): Loadable => { const [metrics, setMetrics] = useState>(NotLoaded); - const previousExpIds = usePrevious(experimentIds, []); + // Do not replace with usePrevious here -- it restarts the stream erroneously; + const idsRef = useRef(experimentIds); + const curExperimentIds = useMemo(() => { + return _.isEqual(experimentIds, idsRef.current) ? idsRef.current : experimentIds; + }, [experimentIds]); useEffect(() => { - if (experimentIds.length === 0) { + const previousExpIds = idsRef.current; + if (curExperimentIds.length === 0) { setMetrics(Loaded([])); return; } - if (!_.isEqual(experimentIds, previousExpIds)) setMetrics(NotLoaded); + if (curExperimentIds !== previousExpIds) setMetrics(NotLoaded); + const canceler = new AbortController(); // We do not want to plot any x-axis metric values as y-axis data const xAxisMetrics = Object.values(XAxisDomain).map((v) => v.toLowerCase()); readStream( - detApi.StreamingInternal.expMetricNames(experimentIds, quickPoll ? 5 : undefined, { + detApi.StreamingInternal.expMetricNames(curExperimentIds, quickPoll ? 5 : undefined, { signal: canceler.signal, }), (event: V1ExpMetricNamesResponse) => { @@ -79,7 +81,10 @@ const useMetricNames = ( errorHandler, ); return () => canceler.abort(); - }, [experimentIds, previousExpIds, errorHandler, quickPoll]); + }, [curExperimentIds, errorHandler, quickPoll]); + useEffect(() => { + idsRef.current = experimentIds; + }); return metrics; }; diff --git a/webui/react/src/hooks/useMetrics.ts b/webui/react/src/hooks/useMetrics.ts index 2f19b890050..0625e5f7444 100644 --- a/webui/react/src/hooks/useMetrics.ts +++ b/webui/react/src/hooks/useMetrics.ts @@ -1,12 +1,11 @@ import { makeToast } from 'hew/Toast'; import { Loadable, Loaded, NotLoaded } from 'hew/utils/loadable'; import _ from 'lodash'; -import { useCallback, useEffect, useMemo, useState } from 'react'; +import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { terminalRunStates } from 'constants/states'; import useMetricNames from 'hooks/useMetricNames'; import usePolling from 'hooks/usePolling'; -import usePrevious from 'hooks/usePrevious'; import { timeSeries } from 'services/api'; import { FlatRun, @@ -16,6 +15,7 @@ import { Scale, Serie, TrialDetails, + TrialSummary, XAxisDomain, } from 'types'; import handleError, { ErrorType } from 'utils/error'; @@ -35,7 +35,6 @@ export interface RunMetricData { scale: Scale; setScale: React.Dispatch>; metricHasData: Record; - selectedMetrics: Metric[]; } const summarizedMetricToSeries = ( @@ -44,7 +43,6 @@ const summarizedMetricToSeries = ( ): { data: Record; metricHasData: Record; - selectedMetrics: Metric[]; } => { const rawBatchValuesMap: Record = {}; const rawBatchTimesMap: Record = {}; @@ -94,9 +92,12 @@ const summarizedMetricToSeries = ( (xAxis) => (recordData?.[key]?.data?.[xAxis]?.length ?? 0) > 0, ); }); - return { data: recordData, metricHasData, selectedMetrics }; + return { data: recordData, metricHasData }; }; +// consistent reference to prevent extra calls +const EMPTY_METRICS: Metric[] = []; + export const useMetrics = (records: (TrialDetails | FlatRun | undefined)[]): RunMetricData => { const recordsAllTerminated = records?.every((record) => terminalRunStates.has(record?.state ?? RunState.Active), @@ -135,65 +136,75 @@ export const useMetrics = (records: (TrialDetails | FlatRun | undefined)[]): Run handleMetricNamesError, recordsAllNonTerminal, ); - const metricNamesLoaded = Loadable.isLoaded(loadableMetrics); const metrics = useMemo(() => { - return Loadable.getOrElse([], loadableMetrics); + return loadableMetrics + .map((m) => (m.length === 0 ? EMPTY_METRICS : m)) + .getOrElse(EMPTY_METRICS); }, [loadableMetrics]); - const [loadableData, setLoadableData] = - useState>>>(NotLoaded); - const [metricHasData, setMetricHasData] = useState>({}); const [scale, setScale] = useState(Scale.Linear); - const [selectedMetrics, setSelectedMetrics] = useState([]); - - const previousRecords = usePrevious(records, []); + const [curResponse, setCurResponse] = useState>(NotLoaded); + + // don't replace this with usePrevious -- we need the ref to prevent + // fetchRecordSummary from regenerating when the previous value is detected + const recordIdsRef = useRef([]); + const curRecordIds = useMemo(() => { + const ids = records + ?.map((t) => t?.id) + .filter((i: T): i is Exclude => i !== undefined); + return _.isEqual(ids, recordIdsRef.current) ? recordIdsRef.current : ids; + }, [records]); const fetchRecordSummary = useCallback(async () => { - // If the record ids have not changed then we do not need to - // show the loading state again. - if (!_.isEqual(_.map(previousRecords, 'id'), _.map(records, 'id'))) setLoadableData(NotLoaded); - - if (records.length === 0) { + if (loadableMetrics.isNotLoaded) return; + if (curRecordIds.length === 0) { // If there are no trials selected then // no data is available. - setMetricHasData({}); - setLoadableData(Loaded({})); + setCurResponse(Loaded([])); return; } - if (records.length > 0) { + if (curRecordIds.length > 0) { + const prevRecordIds = recordIdsRef.current; + // If the record ids have not changed then we do not need to + // show the loading state again. + if (curRecordIds !== prevRecordIds) setCurResponse(NotLoaded); try { - const metricsHaveData: Record = {}; const response = await timeSeries({ maxDatapoints: screen.width > 1600 ? 1500 : 1000, metrics, startBatches: 0, - trialIds: records?.map((t) => t?.id || 0).filter((i) => i > 0), + trialIds: curRecordIds, + }); + setCurResponse((prev) => { + const loadedData = Loaded(response); + return _.isEqual(loadedData, prev) ? prev : loadedData; }); + } catch (e) { + makeToast({ severity: 'Error', title: 'Error fetching metrics' }); + } + } + }, [loadableMetrics, metrics, curRecordIds]); + useEffect(() => { + recordIdsRef.current = curRecordIds; + }); + const requestData = useMemo(() => { + return Loadable.all([curResponse, loadableMetrics]); + }, [curResponse, loadableMetrics]); + const [metricHasData, data] = useMemo(() => { + return requestData + .map(([response, metrics]) => { + const metricsHaveData: Record = {}; const newData: Record> = {}; response.forEach((r) => { - const { - data: recordData, - metricHasData, - selectedMetrics: s, - } = summarizedMetricToSeries(r?.metrics, metrics); + const { data: recordData, metricHasData } = summarizedMetricToSeries(r?.metrics, metrics); Object.keys(metricHasData).forEach((key) => { metricsHaveData[key] ||= metricHasData[key]; }); newData[r.id] = recordData; - setSelectedMetrics((prev) => (_.isEqual(selectedMetrics, s) ? prev : s)); }); - setLoadableData((prev) => - _.isEqual(Loadable.getOrElse([], prev), newData) ? prev : Loaded(newData), - ); - // Wait until the metric names are loaded - // to determine if trials have data for any metric - if (Loadable.isLoaded(loadableMetrics)) { - setMetricHasData(metricsHaveData); - } - } catch (e) { - makeToast({ severity: 'Error', title: 'Error fetching metrics' }); - } - } - }, [loadableMetrics, metrics, selectedMetrics, records, previousRecords]); + return [metricsHaveData, newData] as const; + }) + .getOrElse([{}, {}]); + }, [requestData]); const fetchAll = useCallback(async () => { await Promise.allSettled([fetchRecordSummary()]); @@ -212,12 +223,11 @@ export const useMetrics = (records: (TrialDetails | FlatRun | undefined)[]): Run } return { - data: Loadable.getOrElse({}, loadableData), - isLoaded: metricNamesLoaded && Loadable.isLoaded(loadableData), + data, + isLoaded: requestData.isLoaded, metricHasData, metrics, scale, - selectedMetrics, setScale, }; };