diff --git a/api/server/sdk/volume_migrate.go b/api/server/sdk/volume_migrate.go
index c68e523bb..d46edc8b8 100644
--- a/api/server/sdk/volume_migrate.go
+++ b/api/server/sdk/volume_migrate.go
@@ -162,6 +162,28 @@ func (s *VolumeServer) volumeMigrate(
 	}, nil
 }
 
+func (s *VolumeServer) checkMigrationPermissions(ctx context.Context, taskId string) error {
+	// Inspect migration to get VolumeIds
+	resp, err := s.driver(ctx).CloudMigrateStatus(&api.CloudMigrateStatusRequest{
+		TaskId: taskId,
+	})
+	if err != nil {
+		return status.Errorf(codes.Internal, "Failed to get migration information : %v", err)
+	}
+
+	// Check that the user has access to all volumes being migrated
+	for _, cluster := range resp.Info {
+		for _, migrateInfo := range cluster.List {
+			if err := checkAccessFromDriverForVolumeId(ctx, s.driver(ctx),
+				migrateInfo.GetLocalVolumeId(), api.Ownership_Read); err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
 // Cancel or stop a ongoing migration
 func (s *VolumeServer) Cancel(
 	ctx context.Context,
@@ -170,12 +192,16 @@ func (s *VolumeServer) Cancel(
 	if s.cluster() == nil || s.driver(ctx) == nil {
 		return nil, status.Error(codes.Unavailable, "Resource has not been initialized")
 	}
-
 	if req.GetRequest() == nil {
 		return nil, status.Errorf(codes.InvalidArgument, "Must supply valid request")
 	} else if len(req.GetRequest().GetTaskId()) == 0 {
 		return nil, status.Errorf(codes.InvalidArgument, "Must supply valid Task ID")
 	}
+
+	// Check if the user has access to all volumes associated with the TaskID
+	if err := s.checkMigrationPermissions(ctx, req.GetRequest().GetTaskId()); err != nil {
+		return nil, err
+	}
 	err := s.driver(ctx).CloudMigrateCancel(req.GetRequest())
 	if err != nil {
 		return nil, status.Errorf(codes.Internal, "Cannot stop migration for %s : %v",
@@ -184,6 +210,62 @@ func (s *VolumeServer) Cancel(
 	return &api.SdkCloudMigrateCancelResponse{}, nil
 }
 
+// filterStatusResponseForPermissions alters the response object to only return objects
+// that the caller has access to. It may look involved, but it minimizes driver calls.
+func (s *VolumeServer) filterStatusResponseForPermissions(
+	ctx context.Context,
+	resp *api.CloudMigrateStatusResponse) (*api.CloudMigrateStatusResponse, error) {
+	allVolIds := make([]string, 0)
+
+	// get all volume ids to inspect
+	for _, cluster := range resp.Info {
+		for _, migrateInfo := range cluster.List {
+			allVolIds = append(allVolIds, migrateInfo.GetLocalVolumeId())
+		}
+	}
+
+	// When no volume ids are found, exit early
+	if len(allVolIds) == 0 {
+		return resp, nil
+	}
+
+	// get all volumes from a single inspect call
+	allVols, err := s.driver(ctx).Inspect(allVolIds)
+	if err != nil {
+		return nil, err
+	}
+
+	// check which volumes we have access to
+	volAccessPermitted := make(map[string]bool)
+	for _, vol := range allVols {
+		if vol.IsPermitted(ctx, api.Ownership_Read) {
+			volAccessPermitted[vol.Id] = true
+		}
+	}
+
+	// Generate new response with permitted migrate info based
+	// on which volume ids we have access to
+	var filteredResp api.CloudMigrateStatusResponse
+	filteredResp.Info = make(map[string]*api.CloudMigrateInfoList)
+	for clusterId, cluster := range resp.Info {
+		filteredCluster := api.CloudMigrateInfoList{}
+		filteredCluster.List = make([]*api.CloudMigrateInfo, 0)
+
+		for _, migrateInfo := range cluster.List {
+			if found := volAccessPermitted[migrateInfo.GetLocalVolumeId()]; found {
+				filteredCluster.List = append(filteredCluster.List, migrateInfo)
+			}
+		}
+
+		// Do not return clusters left empty after filtering
+		if len(filteredCluster.List) > 0 {
+			filteredResp.Info[clusterId] = &filteredCluster
+		}
+	}
+
+	return &filteredResp, nil
+}
+
 // Status of ongoing migration
 func (s *VolumeServer) Status(
 	ctx context.Context,
@@ -197,6 +279,13 @@ func (s *VolumeServer) Status(
 	if err != nil {
 		return nil, status.Errorf(codes.Internal, "Cannot get status of migration : %v", err)
 	}
+
+	// Filter out volumes we don't have access to
+	resp, err = s.filterStatusResponseForPermissions(ctx, resp)
+	if err != nil {
+		return nil, err
+	}
+
 	return &api.SdkCloudMigrateStatusResponse{
 		Result: resp,
 	}, nil
diff --git a/api/server/sdk/volume_migrate_test.go b/api/server/sdk/volume_migrate_test.go
index 503da45dc..2b9ff65dc 100644
--- a/api/server/sdk/volume_migrate_test.go
+++ b/api/server/sdk/volume_migrate_test.go
@@ -301,14 +301,23 @@ func TestVolumeMigrate_CancelSuccess(t *testing.T) {
 	s := newTestServer(t)
 	defer s.Stop()
 
+	taskId := "1"
 	req := &api.SdkCloudMigrateCancelRequest{
 		Request: &api.CloudMigrateCancelRequest{
-			TaskId: "1"},
+			TaskId: taskId,
+		},
 	}
 
+	resp := &api.CloudMigrateStatusResponse{}
+	s.MockDriver().EXPECT().
+		CloudMigrateStatus(&api.CloudMigrateStatusRequest{
+			TaskId: taskId,
+		}).
+		Return(resp, nil)
+
 	s.MockDriver().EXPECT().
 		CloudMigrateCancel(&api.CloudMigrateCancelRequest{
-			TaskId: "1",
+			TaskId: taskId,
 		}).
 		Return(nil)
 	// Setup client
@@ -361,9 +370,10 @@ func TestVolumeMigrate_StatusSucess(t *testing.T) {
 	req := &api.SdkCloudMigrateStatusRequest{
 		Request: &api.CloudMigrateStatusRequest{},
 	}
+	vId := "VID"
 	info := &api.CloudMigrateInfo{
 		ClusterId:       "Source",
-		LocalVolumeId:   "VID",
+		LocalVolumeId:   vId,
 		LocalVolumeName: "VNAME",
 		RemoteVolumeId:  "RID",
 		CloudbackupId:   "CBKUPID",
@@ -383,6 +393,14 @@ func TestVolumeMigrate_StatusSucess(t *testing.T) {
 	s.MockDriver().EXPECT().
 		CloudMigrateStatus(&api.CloudMigrateStatusRequest{}).
 		Return(resp, nil)
+
+	inspectResp := &api.Volume{
+		Id: vId,
+	}
+	s.MockDriver().EXPECT().
+		Inspect([]string{vId}).
+		Return([]*api.Volume{inspectResp}, nil)
+
 	// Setup client
 	c := api.NewOpenStorageMigrateClient(s.Conn())
 	r, err := c.Status(context.Background(), req)
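Note: the batch-filtering pattern in filterStatusResponseForPermissions (collect every volume ID, issue a single Inspect, then rebuild the response from a permission map) can be illustrated in isolation. The sketch below is a minimal, self-contained Go approximation of that pattern, not the openstorage API: MigrateInfo, Volume.Permitted, and the injected inspect function are hypothetical stand-ins for the SDK types and the driver's Inspect call.

```go
package main

import "fmt"

// Hypothetical stand-ins for the openstorage types; they only
// illustrate the shape of the data being filtered.
type MigrateInfo struct{ LocalVolumeID string }

type Volume struct {
	ID        string
	Permitted bool // stands in for vol.IsPermitted(ctx, api.Ownership_Read)
}

// filterByPermission mirrors the structure of filterStatusResponseForPermissions:
// one batched lookup up front, then a rebuild pass driven by a permission map.
func filterByPermission(
	clusters map[string][]MigrateInfo,
	inspect func(ids []string) []Volume,
) map[string][]MigrateInfo {
	// 1. Collect every local volume ID across all clusters.
	allIDs := make([]string, 0)
	for _, list := range clusters {
		for _, mi := range list {
			allIDs = append(allIDs, mi.LocalVolumeID)
		}
	}
	if len(allIDs) == 0 {
		return clusters // nothing to filter
	}

	// 2. A single batched inspect instead of one driver call per volume.
	permitted := make(map[string]bool)
	for _, vol := range inspect(allIDs) {
		if vol.Permitted {
			permitted[vol.ID] = true
		}
	}

	// 3. Rebuild the map, dropping clusters left empty after filtering.
	filtered := make(map[string][]MigrateInfo)
	for clusterID, list := range clusters {
		kept := make([]MigrateInfo, 0)
		for _, mi := range list {
			if permitted[mi.LocalVolumeID] {
				kept = append(kept, mi)
			}
		}
		if len(kept) > 0 {
			filtered[clusterID] = kept
		}
	}
	return filtered
}

func main() {
	clusters := map[string][]MigrateInfo{
		"src-cluster": {{LocalVolumeID: "vol-1"}, {LocalVolumeID: "vol-2"}},
	}
	// Fake driver: only vol-1 is readable by the caller.
	inspect := func(ids []string) []Volume {
		vols := make([]Volume, 0, len(ids))
		for _, id := range ids {
			vols = append(vols, Volume{ID: id, Permitted: id == "vol-1"})
		}
		return vols
	}
	fmt.Println(filterByPermission(clusters, inspect))
	// Output: map[src-cluster:[{vol-1}]]
}
```

The point of the permission map is that access is decided once per volume ID even if the same volume appears under several clusters, which keeps the cost at one Inspect call regardless of how large the status response is.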