From 655a9ab96d51cc787b898d999fe115b0d0f88b4a Mon Sep 17 00:00:00 2001 From: Ashton G Date: Wed, 17 Apr 2024 15:38:07 -0400 Subject: [PATCH] fix: report errors from deletecheckpoints endpoint + improve feedback (#9178) --- master/internal/api_checkpoint.go | 71 +++++++++++++++++-- .../checkpoints/postgres_checkpoints.go | 27 +++++-- .../postgres_checkpoints_intg_test.go | 10 ++- .../react/src/components/CheckpointModal.tsx | 12 ++-- .../ExperimentCheckpoints.tsx | 17 +++-- webui/react/src/services/api.ts | 6 ++ webui/react/src/services/apiConfig.ts | 12 +++- 7 files changed, 127 insertions(+), 28 deletions(-) diff --git a/master/internal/api_checkpoint.go b/master/internal/api_checkpoint.go index 4c78b4f42f8c..f96b231c7793 100644 --- a/master/internal/api_checkpoint.go +++ b/master/internal/api_checkpoint.go @@ -198,6 +198,55 @@ func (a *apiServer) checkpointsRBACEditCheck( return exps, groupCUUIDsByEIDs, nil } +func makeRegisteredCheckpointErrorMessage( + ctx context.Context, + baseMessageFormat string, + checkpointMap map[uuid.UUID]checkpoints.ModelInfo, +) (*string, error) { + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + var modelIDs []int + for _, v := range checkpointMap { + modelIDs = append(modelIDs, v.ID) + } + var models []struct { + ID int + } + modelQuery := internaldb.Bun(). + NewSelect(). + Model(&models). + Table("models"). + ColumnExpr("id"). + Where("id IN (?)", bun.In(modelIDs)) + if modelQuery, err = modelauth.AuthZProvider.Get(). + FilterReadableModelsQuery(ctx, *curUser, modelQuery); err != nil { + return nil, err + } + err = modelQuery.Scan(ctx) + if err != nil { + return nil, err + } + accessibleModels := make(map[int]bool, len(models)) + for _, v := range models { + accessibleModels[v.ID] = true + } + var checkpointMsgs []string + for k, v := range checkpointMap { + if _, ok := accessibleModels[v.ID]; ok { + msg := fmt.Sprintf("%v, registered to %v (model #%d), version %d", k, v.Name, v.ID, v.Version) + checkpointMsgs = append(checkpointMsgs, msg) + } else { + msg := fmt.Sprintf("%v, registered to an unknown model", k) + checkpointMsgs = append(checkpointMsgs, msg) + } + } + retVal := fmt.Sprintf(baseMessageFormat, strings.Join(checkpointMsgs, ", ")) + + return &retVal, nil +} + func (a *apiServer) PatchCheckpoints( ctx context.Context, req *apiv1.PatchCheckpointsRequest, @@ -222,9 +271,15 @@ func (a *apiServer) PatchCheckpoints( return nil, err } if len(registeredCheckpointUUIDs) > 0 { - return nil, status.Errorf(codes.InvalidArgument, - "this subset of checkpoints provided are in the model registry and cannot be deleted: %v.", - registeredCheckpointUUIDs) + errMsg, err := makeRegisteredCheckpointErrorMessage( + ctx, + "this subset of checkpoints provided are in the model registry and cannot be patched: %v.", + registeredCheckpointUUIDs, + ) + if err != nil { + return nil, err + } + return nil, status.Errorf(codes.InvalidArgument, *errMsg) } err = internaldb.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { @@ -327,9 +382,15 @@ func (a *apiServer) CheckpointsRemoveFiles( return nil, err } if len(registeredCheckpointUUIDs) > 0 { - return nil, status.Errorf(codes.InvalidArgument, + errMsg, err := makeRegisteredCheckpointErrorMessage( + ctx, "this subset of checkpoints provided are in the model registry and cannot be deleted: %v.", - registeredCheckpointUUIDs) + registeredCheckpointUUIDs, + ) + if err != nil { + return nil, err + } + return nil, status.Errorf(codes.InvalidArgument, *errMsg) } taskSpec := *a.m.taskSpec diff --git a/master/internal/checkpoints/postgres_checkpoints.go b/master/internal/checkpoints/postgres_checkpoints.go index 197f2796f8ec..261b8abeef60 100644 --- a/master/internal/checkpoints/postgres_checkpoints.go +++ b/master/internal/checkpoints/postgres_checkpoints.go @@ -54,25 +54,40 @@ func GetModelIDsAssociatedWithCheckpoint(ctx context.Context, ckptUUID uuid.UUID return modelIDs, nil } +// ModelInfo is a struct containing info used for locating models. +type ModelInfo struct { + ID int + Version int + Name string +} + // GetRegisteredCheckpoints gets the checkpoints in // the model registrys from the list of checkpoints provided. -func GetRegisteredCheckpoints(ctx context.Context, checkpoints []uuid.UUID) (map[uuid.UUID]bool, error) { +func GetRegisteredCheckpoints(ctx context.Context, checkpoints []uuid.UUID) (map[uuid.UUID]ModelInfo, error) { var checkpointIDRows []struct { - ID uuid.UUID + ID uuid.UUID + ModelID int + ModelVersion int + ModelName string } if err := db.Bun().NewRaw(` - SELECT DISTINCT(mv.checkpoint_uuid) as ID FROM model_versions AS mv - WHERE mv.checkpoint_uuid IN (SELECT UNNEST(?::uuid[]));`, + SELECT DISTINCT(mv.checkpoint_uuid) as ID, mv.model_id as model_id, mv.version as model_version, m.name as model_name + FROM model_versions AS mv LEFT JOIN models as m on mv.model_id=m.id WHERE mv.checkpoint_uuid + IN (SELECT UNNEST(?::uuid[]));`, pgdialect.Array(checkpoints)).Scan(ctx, &checkpointIDRows); err != nil { return nil, fmt.Errorf( "filtering checkpoint uuids by those registered in the model registry: %w", err) } - checkpointIDs := make(map[uuid.UUID]bool, len(checkpointIDRows)) + checkpointIDs := make(map[uuid.UUID]ModelInfo, len(checkpointIDRows)) for _, cRow := range checkpointIDRows { - checkpointIDs[cRow.ID] = true + checkpointIDs[cRow.ID] = ModelInfo{ + ID: cRow.ModelID, + Version: cRow.ModelVersion, + Name: cRow.ModelName, + } } return checkpointIDs, nil diff --git a/master/internal/checkpoints/postgres_checkpoints_intg_test.go b/master/internal/checkpoints/postgres_checkpoints_intg_test.go index 309bf4b342ec..8d4636eff6c2 100644 --- a/master/internal/checkpoints/postgres_checkpoints_intg_test.go +++ b/master/internal/checkpoints/postgres_checkpoints_intg_test.go @@ -213,7 +213,7 @@ func TestGetRegisteredCheckpoints(t *testing.T) { Name: "checkpoint 1", Comment: "empty", } - _, err = db.InsertModelVersion(ctx, pmdl.Id, retCkpt1.Uuid, mv1.Name, mv1.Comment, + mv1mdl, err := db.InsertModelVersion(ctx, pmdl.Id, retCkpt1.Uuid, mv1.Name, mv1.Comment, emptyMetadata, strings.Join(mv1.Labels, ","), mv1.Notes, user.ID, ) require.NoError(t, err) @@ -233,8 +233,12 @@ func TestGetRegisteredCheckpoints(t *testing.T) { require.NoError(t, err) checkpoints := []uuid.UUID{checkpoint1.UUID, checkpoint3.UUID} - expectedRegisteredCheckpoints := make(map[uuid.UUID]bool) - expectedRegisteredCheckpoints[checkpoint1.UUID] = true + expectedRegisteredCheckpoints := make(map[uuid.UUID]ModelInfo) + expectedRegisteredCheckpoints[checkpoint1.UUID] = ModelInfo{ + ID: int(pmdl.Id), + Version: int(mv1mdl.Version), + Name: pmdl.Name, + } dCheckpointsInRegistry, err := GetRegisteredCheckpoints(ctx, checkpoints) require.NoError(t, err) require.Equal(t, expectedRegisteredCheckpoints, dCheckpointsInRegistry) diff --git a/webui/react/src/components/CheckpointModal.tsx b/webui/react/src/components/CheckpointModal.tsx index 50029f717245..adc45c958927 100644 --- a/webui/react/src/components/CheckpointModal.tsx +++ b/webui/react/src/components/CheckpointModal.tsx @@ -6,8 +6,7 @@ import useConfirm from 'hew/useConfirm'; import React, { useCallback, useMemo } from 'react'; import { paths } from 'routes/utils'; -import { detApi } from 'services/apiConfig'; -import { readStream } from 'services/utils'; +import { deleteCheckpoints } from 'services/api'; import { CheckpointState, CheckpointStorageType, @@ -86,12 +85,17 @@ const CheckpointModalComponent: React.FC = ({ const handleDelete = useCallback(async () => { if (!checkpoint?.uuid) return; - await readStream(detApi.Checkpoint.deleteCheckpoints({ checkpointUuids: [checkpoint.uuid] })); + try { + await deleteCheckpoints({ checkpointUuids: [checkpoint.uuid] }); + } catch (e) { + // modal error handling overwrites error message + handleError(e); + } }, [checkpoint]); const onClickDelete = useCallback(() => { const content = `Are you sure you want to request checkpoint deletion for batch -${checkpoint?.totalBatches}. This action may complete or fail without further notification.`; +${checkpoint?.totalBatches}? This action may complete or fail without further notification.`; confirm({ content, diff --git a/webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.tsx b/webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.tsx index 519251dcb538..edc87ba6a7bb 100644 --- a/webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.tsx +++ b/webui/react/src/pages/ExperimentDetails/ExperimentCheckpoints.tsx @@ -24,11 +24,9 @@ import { useCheckpointFlow } from 'hooks/useCheckpointFlow'; import { useFetchModels } from 'hooks/useFetchModels'; import usePolling from 'hooks/usePolling'; import { useSettings } from 'hooks/useSettings'; -import { getExperimentCheckpoints } from 'services/api'; +import { deleteCheckpoints, getExperimentCheckpoints } from 'services/api'; import { Checkpointv1SortBy, Checkpointv1State } from 'services/api-ts-sdk'; -import { detApi } from 'services/apiConfig'; import { encodeCheckpointState } from 'services/decoder'; -import { readStream } from 'services/utils'; import { checkpointAction, CheckpointAction, @@ -128,12 +126,13 @@ const ExperimentCheckpoints: React.FC = ({ experiment, pageRef }: Props) [registerModal], ); - const handleDelete = useCallback((checkpoints: string[]) => { - readStream( - detApi.Checkpoint.deleteCheckpoints({ - checkpointUuids: checkpoints, - }), - ); + const handleDelete = useCallback(async (checkpointUuids: string[]) => { + try { + await deleteCheckpoints({ checkpointUuids }); + } catch (e) { + // confirm modal overwrites error message + handleError(e); + } }, []); const handleDeleteCheckpoint = useCallback( diff --git a/webui/react/src/services/api.ts b/webui/react/src/services/api.ts index 90738a6a27b2..5ceb41216b0d 100644 --- a/webui/react/src/services/api.ts +++ b/webui/react/src/services/api.ts @@ -845,6 +845,12 @@ export const launchTensorBoard = generateDetApi< Type.CommandResponse >(Config.launchTensorBoard); +export const deleteCheckpoints = generateDetApi< + Api.V1DeleteCheckpointsRequest, + Api.V1DeleteCheckpointsResponse, + Api.V1DeleteCheckpointsResponse +>(Config.deleteCheckpoints); + export const openOrCreateTensorBoard = async ( params: Service.LaunchTensorBoardParams, ): Promise => { diff --git a/webui/react/src/services/apiConfig.ts b/webui/react/src/services/apiConfig.ts index 708cf2e7db2d..c356b27773d1 100644 --- a/webui/react/src/services/apiConfig.ts +++ b/webui/react/src/services/apiConfig.ts @@ -28,7 +28,7 @@ const generateApiConfig = (apiConfig?: Api.ConfigurationParameters) => { const config = updatedApiConfigParams(apiConfig); return { Auth: new Api.AuthenticationApi(config), - Checkpoint: Api.CheckpointsApiFetchParamCreator(config), + Checkpoint: new Api.CheckpointsApi(config), Cluster: new Api.ClusterApi(config), Commands: new Api.CommandsApi(config), Experiments: new Api.ExperimentsApi(config), @@ -1952,3 +1952,13 @@ export const updateJobQueue: DetApi< postProcess: identity, request: (params: Api.V1UpdateJobQueueRequest) => detApi.Internal.updateJobQueue(params), }; + +export const deleteCheckpoints: DetApi< + Api.V1DeleteCheckpointsRequest, + Api.V1DeleteCheckpointsResponse, + Api.V1DeleteCheckpointsResponse +> = { + name: 'deleteCheckpoints', + postProcess: identity, + request: (params, options) => detApi.Checkpoint.deleteCheckpoints(params, options), +};