Skip to content

Commit

Permalink
fix patchcheckpoints too
Browse files Browse the repository at this point in the history
  • Loading branch information
ashtonG committed Apr 17, 2024
1 parent e7518ca commit 4c8434f
Showing 1 changed file with 64 additions and 39 deletions.
103 changes: 64 additions & 39 deletions master/internal/api_checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,55 @@ func (a *apiServer) checkpointsRBACEditCheck(
return exps, groupCUUIDsByEIDs, nil
}

func makeRegisteredCheckpointErrorMessage(
ctx context.Context,
baseMessageFormat string,
checkpointMap map[uuid.UUID]checkpoints.ModelInfo,
) (*string, error) {
curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
}
var modelIDs []int
for _, v := range checkpointMap {
modelIDs = append(modelIDs, v.ID)
}
var models []struct {
ID int
}
modelQuery := internaldb.Bun().
NewSelect().
Model(&models).
Table("models").
ColumnExpr("id").
Where("id IN (?)", bun.In(modelIDs))
if modelQuery, err = modelauth.AuthZProvider.Get().
FilterReadableModelsQuery(ctx, *curUser, modelQuery); err != nil {
return nil, err
}
err = modelQuery.Scan(ctx)
if err != nil {
return nil, err
}
accessibleModels := make(map[int]bool, len(models))
for _, v := range models {
accessibleModels[v.ID] = true
}
var checkpointMsgs []string
for k, v := range checkpointMap {
if _, ok := accessibleModels[v.ID]; ok {
msg := fmt.Sprintf("%v, registered to %v (model #%d), version %d", k, v.Name, v.ID, v.Version)
checkpointMsgs = append(checkpointMsgs, msg)
} else {
msg := fmt.Sprintf("%v, registered to an unknown model", k)
checkpointMsgs = append(checkpointMsgs, msg)
}
}
retVal := fmt.Sprintf(baseMessageFormat, strings.Join(checkpointMsgs, ", "))

return &retVal, nil
}

func (a *apiServer) PatchCheckpoints(
ctx context.Context,
req *apiv1.PatchCheckpointsRequest,
Expand All @@ -222,9 +271,15 @@ func (a *apiServer) PatchCheckpoints(
return nil, err
}
if len(registeredCheckpointUUIDs) > 0 {
return nil, status.Errorf(codes.InvalidArgument,
"this subset of checkpoints provided are in the model registry and cannot be deleted: %v.",
registeredCheckpointUUIDs)
errMsg, err := makeRegisteredCheckpointErrorMessage(
ctx,
"this subset of checkpoints provided are in the model registry and cannot be patched: %v.",
registeredCheckpointUUIDs,
)
if err != nil {
return nil, err
}
return nil, status.Errorf(codes.InvalidArgument, *errMsg)
}

err = internaldb.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
Expand Down Expand Up @@ -327,45 +382,15 @@ func (a *apiServer) CheckpointsRemoveFiles(
return nil, err
}
if len(registeredCheckpointUUIDs) > 0 {
var modelIDs []int
for _, v := range registeredCheckpointUUIDs {
modelIDs = append(modelIDs, v.ID)
}
var models []struct {
ID int
}
modelQuery := internaldb.Bun().
NewSelect().
Model(&models).
Table("models").
ColumnExpr("id").
Where("id IN (?)", bun.In(modelIDs))
if modelQuery, err = modelauth.AuthZProvider.Get().
FilterReadableModelsQuery(ctx, *curUser, modelQuery); err != nil {
return nil, err
}
err = modelQuery.Scan(ctx)
errMsg, err := makeRegisteredCheckpointErrorMessage(
ctx,
"this subset of checkpoints provided are in the model registry and cannot be deleted: %v.",
registeredCheckpointUUIDs,
)
if err != nil {
return nil, err
}
accessibleModels := make(map[int]bool, len(models))
for _, v := range models {
accessibleModels[v.ID] = true
}
var checkpointMsgs []string
for k, v := range registeredCheckpointUUIDs {
if _, ok := accessibleModels[v.ID]; ok {
msg := fmt.Sprintf("%v, registered to %v (model #%d), version %d", k, v.Name, v.ID, v.Version)
checkpointMsgs = append(checkpointMsgs, msg)
} else {
msg := fmt.Sprintf("%v, registered to an unknown model", k)
checkpointMsgs = append(checkpointMsgs, msg)
}
}
checkpointList := strings.Join(checkpointMsgs, ", ")
return nil, status.Errorf(codes.InvalidArgument,
"this subset of checkpoints provided are in the model registry and cannot be deleted: %v.",
checkpointList)
return nil, status.Errorf(codes.InvalidArgument, *errMsg)
}

taskSpec := *a.m.taskSpec
Expand Down

0 comments on commit 4c8434f

Please sign in to comment.