From 068b9595a696a4589f8aed9120c0f0af41d2c980 Mon Sep 17 00:00:00 2001 From: jesse-amano-hpe Date: Wed, 31 Jul 2024 14:05:41 -0700 Subject: [PATCH] feat: checkpoint view for flat runs [ET-658] (#9769) Co-authored-by: Ashton Galloway --- .../ExperimentCheckpoints.settings.ts | 101 +++-- .../ExperimentCheckpoints.tsx | 20 +- .../ExperimentDetails/TrialCheckpoints.tsx | 410 ++++++++++++++++++ webui/react/src/pages/TrialDetails.tsx | 12 + webui/react/src/services/api.ts | 6 + webui/react/src/services/apiConfig.ts | 20 + 6 files changed, 524 insertions(+), 45 deletions(-) create mode 100644 webui/react/src/pages/ExperimentDetails/TrialCheckpoints.tsx diff --git a/webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.settings.ts b/webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.settings.ts index 4cc6ca0d462..c8235baccd7 100644 --- a/webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.settings.ts +++ b/webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.settings.ts @@ -1,10 +1,11 @@ -import { array, boolean, literal, number, string, undefined as undefinedType, union } from 'io-ts'; +import { array, boolean, keyof, number, string, undefined as undefinedType, union } from 'io-ts'; import { InteractiveTableSettings } from 'components/Table/InteractiveTable'; import { MINIMUM_PAGE_SIZE } from 'components/Table/Table'; import { SettingsConfig } from 'hooks/useSettings'; import { Checkpointv1SortBy } from 'services/api-ts-sdk'; import { CheckpointState } from 'types'; +import valueof from 'utils/valueof'; export type CheckpointColumnName = | 'action' @@ -48,14 +49,14 @@ export const configForExperiment = (id: number): SettingsConfig => ({ skipUrlEncoding: true, storageKey: 'columns', type: array( - union([ - literal('action'), - literal('uuid'), - literal('state'), - literal('searcherMetric'), - literal('totalBatches'), - literal('checkpoint'), - ]), + keyof({ + action: null, + checkpoint: null, + searcherMetric: null, + state: null, + totalBatches: null, + uuid: null, + }), ), }, columnWidths: { @@ -78,32 +79,12 @@ export const configForExperiment = (id: number): SettingsConfig => ({ sortKey: { defaultValue: Checkpointv1SortBy.UUID, storageKey: 'sortKey', - type: union([ - literal(Checkpointv1SortBy.BATCHNUMBER), - literal(Checkpointv1SortBy.ENDTIME), - literal(Checkpointv1SortBy.SEARCHERMETRIC), - literal(Checkpointv1SortBy.STATE), - literal(Checkpointv1SortBy.TRIALID), - literal(Checkpointv1SortBy.UNSPECIFIED), - literal(Checkpointv1SortBy.UUID), - ]), + type: valueof(Checkpointv1SortBy), }, state: { defaultValue: undefined, storageKey: 'state', - type: union([ - undefinedType, - array( - union([ - literal(CheckpointState.Active), - literal(CheckpointState.Completed), - literal(CheckpointState.Deleted), - literal(CheckpointState.PartiallyDeleted), - literal(CheckpointState.Error), - literal(CheckpointState.Unspecified), - ]), - ), - ]), + type: union([undefinedType, array(valueof(CheckpointState))]), }, tableLimit: { defaultValue: MINIMUM_PAGE_SIZE, @@ -118,3 +99,61 @@ export const configForExperiment = (id: number): SettingsConfig => ({ }, storagePath: `${id}-checkpoints`, }); + +export const configForTrial = (id: number): SettingsConfig => ({ + settings: { + columns: { + defaultValue: DEFAULT_COLUMNS, + skipUrlEncoding: true, + storageKey: 'columns', + type: array( + keyof({ + action: null, + checkpoint: null, + searcherMetric: null, + state: null, + totalBatches: null, + uuid: null, + }), + ), + }, + columnWidths: { + defaultValue: DEFAULT_COLUMNS.map((col: CheckpointColumnName) => DEFAULT_COLUMN_WIDTHS[col]), + skipUrlEncoding: true, + storageKey: 'columnWidths', + type: array(number), + }, + row: { + defaultValue: undefined, + skipUrlEncoding: true, + storageKey: 'row', + type: union([undefinedType, array(string)]), + }, + sortDesc: { + defaultValue: true, + storageKey: 'sortDesc', + type: boolean, + }, + sortKey: { + defaultValue: Checkpointv1SortBy.UUID, + storageKey: 'sortKey', + type: valueof(Checkpointv1SortBy), + }, + state: { + defaultValue: undefined, + storageKey: 'state', + type: union([undefinedType, array(valueof(CheckpointState))]), + }, + tableLimit: { + defaultValue: MINIMUM_PAGE_SIZE, + storageKey: 'tableLimit', + type: number, + }, + tableOffset: { + defaultValue: 0, + storageKey: 'tableOffset', + type: number, + }, + }, + storagePath: `trial-${id}-checkpoints`, +}); diff --git a/webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.tsx b/webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.tsx index 3b1e47e417f..23edfa766eb 100644 --- a/webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.tsx +++ b/webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.tsx @@ -283,9 +283,9 @@ const ExperimentCheckpoints: React.FC = ({ experiment, pageRef }: Props) { signal: canceler.signal }, ); setTotal(response.pagination.total ?? 0); - if (!isEqual(response.checkpoints, checkpoints)) { - setCheckpoints(response.checkpoints); - } + setCheckpoints((cps) => { + return isEqual(response.checkpoints, cps) ? cps : response.checkpoints; + }); } catch (e) { handleError(e, { publicSubject: `Unable to fetch ${f_flat_runs ? 'search' : 'experiment'} ${experiment.id} checkpoints.`, @@ -295,7 +295,7 @@ const ExperimentCheckpoints: React.FC = ({ experiment, pageRef }: Props) } finally { setIsLoading(false); } - }, [f_flat_runs, settings, experiment.id, canceler.signal, checkpoints]); + }, [f_flat_runs, settings, experiment.id, canceler.signal]); const submitBatchAction = useCallback( async (action: CheckpointAction) => { @@ -319,21 +319,13 @@ const ExperimentCheckpoints: React.FC = ({ experiment, pageRef }: Props) [dropDownOnTrigger, fetchExperimentCheckpoints, settings.row], ); - const { stopPolling } = usePolling(fetchExperimentCheckpoints, { rerunOnNewFn: true }); - - // Get new trials based on changes to the pagination, sorter and filters. - useEffect(() => { - setIsLoading(true); - fetchExperimentCheckpoints(); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); + usePolling(fetchExperimentCheckpoints, { rerunOnNewFn: true }); useEffect(() => { return () => { canceler.abort(); - stopPolling(); }; - }, [canceler, stopPolling]); + }, [canceler]); const handleTableRowSelect = useCallback( (rowKeys?: Key[]) => { diff --git a/webui/react/src/pages/ExperimentDetails/TrialCheckpoints.tsx b/webui/react/src/pages/ExperimentDetails/TrialCheckpoints.tsx new file mode 100644 index 00000000000..80d3b6a75fe --- /dev/null +++ b/webui/react/src/pages/ExperimentDetails/TrialCheckpoints.tsx @@ -0,0 +1,410 @@ +import { FilterDropdownProps } from 'antd/es/table/interface'; +import Button from 'hew/Button'; +import Icon from 'hew/Icon'; +import { useModal } from 'hew/Modal'; +import useConfirm from 'hew/useConfirm'; +import { isEqual } from 'lodash'; +import React, { Key, useCallback, useEffect, useMemo, useState } from 'react'; + +import ActionDropdown from 'components/ActionDropdown'; +import Badge, { BadgeType } from 'components/Badge'; +import ModelCreateModal from 'components/ModelCreateModal'; +import RegisterCheckpointModal from 'components/RegisterCheckpointModal'; +import Section from 'components/Section'; +import InteractiveTable, { ContextMenuProps } from 'components/Table/InteractiveTable'; +import SkeletonTable from 'components/Table/SkeletonTable'; +import { + defaultRowClassName, + getFullPaginationConfig, + HumanReadableNumberRenderer, +} from 'components/Table/Table'; +import TableBatch from 'components/Table/TableBatch'; +import TableFilterDropdown from 'components/Table/TableFilterDropdown'; +import { useCheckpointFlow } from 'hooks/useCheckpointFlow'; +import useFeature from 'hooks/useFeature'; +import { useFetchModels } from 'hooks/useFetchModels'; +import usePolling from 'hooks/usePolling'; +import { useSettings } from 'hooks/useSettings'; +import { deleteCheckpoints, getTrialCheckpoints } from 'services/api'; +import { Checkpointv1SortBy, Checkpointv1State } from 'services/api-ts-sdk'; +import { encodeCheckpointState } from 'services/decoder'; +import { + CheckpointAction, + checkpointAction, + CheckpointState, + CoreApiGenericCheckpoint, + ExperimentBase, + RecordKey, + TrialDetails, +} from 'types'; +import { canActionCheckpoint, getActionsForCheckpointsUnion } from 'utils/checkpoint'; +import { ensureArray } from 'utils/data'; +import handleError, { DetError, ErrorLevel, ErrorType } from 'utils/error'; +import { validateDetApiEnum, validateDetApiEnumList } from 'utils/service'; +import { pluralizer } from 'utils/string'; + +import { configForTrial, Settings } from './ExperimentCheckpoints.settings'; +import { columns as defaultColumns } from './ExperimentCheckpoints.table'; + +interface Props { + experiment: ExperimentBase; + trial: TrialDetails; + pageRef: React.RefObject; +} + +const batchActions = [checkpointAction.Register, checkpointAction.Delete]; + +const TrialCheckpoints: React.FC = ({ experiment, trial, pageRef }: Props) => { + const confirm = useConfirm(); + const [total, setTotal] = useState(0); + const [isLoading, setIsLoading] = useState(true); + const [checkpoints, setCheckpoints] = useState(); + const [selectedCheckpoints, setSelectedCheckpoints] = useState(); + const [selectedModelName, setSelectedModelName] = useState(); + const [canceler] = useState(new AbortController()); + const models = useFetchModels(); + const f_flat_runs = useFeature().isOn('flat_runs'); + + const config = useMemo(() => configForTrial(trial.id), [trial.id]); + const { settings, updateSettings } = useSettings(config); + + const [checkpoint, setCheckpoint] = useState(); + const { checkpointModalComponents, openCheckpoint } = useCheckpointFlow({ + checkpoint: checkpoint, + config: experiment.config, + models, + title: `Checkpoint ${checkpoint?.uuid}`, + }); + + const modelCreateModal = useModal(ModelCreateModal); + const registerModal = useModal(RegisterCheckpointModal); + + const handleOnCloseCreateModel = useCallback( + (modelName?: string) => { + if (modelName) { + setSelectedModelName(modelName); + registerModal.open(); + } + }, + [setSelectedModelName, registerModal], + ); + + const clearSelected = useCallback(() => { + updateSettings({ row: undefined }); + }, [updateSettings]); + + const handleStateFilterApply = useCallback( + (states: string[]) => { + updateSettings({ + row: undefined, + state: states.length !== 0 ? (states as CheckpointState[]) : undefined, + tableOffset: 0, + }); + }, + [updateSettings], + ); + + const handleStateFilterReset = useCallback(() => { + updateSettings({ row: undefined, state: undefined, tableOffset: 0 }); + }, [updateSettings]); + + const stateFilterDropdown = useCallback( + (filterProps: FilterDropdownProps) => { + return ( + + ); + }, + [handleStateFilterApply, handleStateFilterReset, settings.state], + ); + + const handleRegisterCheckpoint = useCallback( + (checkpoints: string[]) => { + setSelectedCheckpoints(checkpoints); + registerModal.open(); + }, + [registerModal], + ); + + const handleDelete = useCallback(async (checkpointUuids: string[]) => { + try { + await deleteCheckpoints({ checkpointUuids }); + } catch (e) { + if (e instanceof DetError && e.type === ErrorType.Server) { + e.silent = false; + } + // confirm modal overwrites error message + handleError(e); + } + }, []); + + const handleDeleteCheckpoint = useCallback( + (checkpoints: string[]) => { + const content = `Are you sure you want to request checkpoint deletion for ${ + checkpoints.length + } + ${pluralizer( + checkpoints.length, + 'checkpoint', + )}? This action may complete or fail without further notification.`; + + confirm({ + content, + danger: true, + okText: 'Request Delete', + onConfirm: () => handleDelete(checkpoints), + onError: handleError, + title: 'Confirm Checkpoint Deletion', + }); + }, + [confirm, handleDelete], + ); + + const dropDownOnTrigger = useCallback( + (checkpoints: string | string[]) => { + const checkpointsArr = ensureArray(checkpoints); + return { + [checkpointAction.Register]: () => handleRegisterCheckpoint(checkpointsArr), + [checkpointAction.Delete]: () => handleDeleteCheckpoint(checkpointsArr), + }; + }, + [handleDeleteCheckpoint, handleRegisterCheckpoint], + ); + + const CheckpointActionDropdown: React.FC> = + useCallback( + ({ record, children }) => { + return ( + + actionOrder={batchActions} + danger={{ [checkpointAction.Delete]: true }} + disabled={{ + [checkpointAction.Register]: !canActionCheckpoint(checkpointAction.Register, record), + [checkpointAction.Delete]: !canActionCheckpoint(checkpointAction.Delete, record), + }} + id={record.uuid} + isContextMenu + kind="checkpoint" + onError={handleError} + onTrigger={dropDownOnTrigger(record.uuid)}> + {children} + + ); + }, + [dropDownOnTrigger], + ); + + const handleOpenCheckpoint = useCallback( + (checkpoint: CoreApiGenericCheckpoint) => { + setCheckpoint(checkpoint); + openCheckpoint(); + }, + [openCheckpoint], + ); + + const columns = useMemo(() => { + const actionRenderer = (_: string, record: CoreApiGenericCheckpoint): React.ReactNode => ( + + actionOrder={batchActions} + danger={{ [checkpointAction.Delete]: true }} + disabled={{ + [checkpointAction.Register]: !canActionCheckpoint(checkpointAction.Register, record), + [checkpointAction.Delete]: !canActionCheckpoint(checkpointAction.Delete, record), + }} + id={record.uuid} + kind="checkpoint" + onError={handleError} + onTrigger={dropDownOnTrigger(record.uuid)} + /> + ); + + const checkpointRenderer = (_: string, record: CoreApiGenericCheckpoint): React.ReactNode => { + return ( +