diff --git a/cmd/csi-addons/main.go b/cmd/csi-addons/main.go index 17f04a68b..c2a3e769f 100644 --- a/cmd/csi-addons/main.go +++ b/cmd/csi-addons/main.go @@ -37,15 +37,16 @@ const ( // command contains the parsed arguments that were passed while running the // executable. type command struct { - endpoint string - stagingPath string - operation string - persistentVolume string - drivername string - secret string - cidrs string - clusterid string - legacy bool + endpoint string + stagingPath string + operation string + persistentVolume string + volumeGroupReplicationContent string + drivername string + secret string + cidrs string + clusterid string + legacy bool } // cmd is the single instance of the command struct, used inside main(). @@ -58,6 +59,7 @@ func init() { flag.StringVar(&cmd.stagingPath, "stagingpath", stagingPath, "staging path") flag.StringVar(&cmd.operation, "operation", "", "csi-addons operation") flag.StringVar(&cmd.persistentVolume, "persistentvolume", "", "name of the PersistentVolume") + flag.StringVar(&cmd.volumeGroupReplicationContent, "volumegroupreplicationcontent", "", "name of the VolumeGroupReplicationContent") flag.StringVar(&cmd.drivername, "drivername", "", "name of the CSI driver") flag.StringVar(&cmd.secret, "secret", "", "kubernetes secret in the format `namespace/name`") flag.StringVar(&cmd.cidrs, "cidrs", "", "comma separated list of cidrs") diff --git a/cmd/csi-addons/replication.go b/cmd/csi-addons/replication.go index 264fde0a0..9d8cca551 100644 --- a/cmd/csi-addons/replication.go +++ b/cmd/csi-addons/replication.go @@ -36,6 +36,7 @@ type VolumeReplicationBase struct { secretName string secretNamespace string volumeID string + groupID string } func (rep *VolumeReplicationBase) Init(c *command) error { @@ -59,21 +60,38 @@ func (rep *VolumeReplicationBase) Init(c *command) error { return errors.New("secret name is not set") } - pv, err := getKubernetesClient().CoreV1().PersistentVolumes().Get(context.Background(), c.persistentVolume, metav1.GetOptions{}) - if err != nil { - return fmt.Errorf("failed to get pv %q", c.persistentVolume) - } - - if pv.Spec.CSI == nil { - return fmt.Errorf("pv %q is not a CSI volume", c.persistentVolume) + if c.persistentVolume != "" && c.volumeGroupReplicationContent != "" { + return errors.New("only one of persistentVolume or volumeGroupReplicationContent should be set") } - if pv.Spec.CSI.VolumeHandle == "" { - return errors.New("volume ID is not set") + if c.persistentVolume != "" { + pv, err := getKubernetesClient().CoreV1().PersistentVolumes().Get(context.Background(), c.persistentVolume, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("failed to get pv %q", c.persistentVolume) + } + + if pv.Spec.CSI == nil { + return fmt.Errorf("pv %q is not a CSI volume", c.persistentVolume) + } + + if pv.Spec.CSI.VolumeHandle == "" { + return errors.New("volume ID is not set") + } + rep.volumeID = pv.Spec.CSI.VolumeHandle + return nil + } else if c.volumeGroupReplicationContent != "" { + vgrc, err := getVolumeReplicationClient().getVolumeGroupReplicationContent(context.Background(), c.volumeGroupReplicationContent, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("failed to get VolumeGroupReplicationContent %q", c.volumeGroupReplicationContent) + } + if vgrc.Spec.VolumeGroupReplicationHandle == "" { + return errors.New("volume group ID is not set") + } + rep.groupID = vgrc.Spec.VolumeGroupReplicationHandle + return nil } - rep.volumeID = pv.Spec.CSI.VolumeHandle - return nil + return errors.New("either persistentVolume or volumeGroupReplicationContent should be set") } // EnableVolumeReplication executes the EnableVolumeReplication operation. @@ -81,6 +99,30 @@ type EnableVolumeReplication struct { VolumeReplicationBase } +func (v VolumeReplicationBase) setReplicationSource(req *proto.ReplicationSource) error { + switch { + case req == nil: + return errors.New("replication source is not set") + case v.volumeID != "" && v.groupID != "": + return errors.New("only one of volumeID or groupID should be set") + case v.volumeID != "": + req.Type = &proto.ReplicationSource_Volume{ + Volume: &proto.ReplicationSource_VolumeSource{ + VolumeId: v.volumeID, + }, + } + return nil + case v.groupID != "": + req.Type = &proto.ReplicationSource_VolumeGroup{ + VolumeGroup: &proto.ReplicationSource_VolumeGroupSource{ + VolumeGroupId: v.groupID, + }, + } + return nil + } + return errors.New("both volumeID and groupID is not set") +} + var _ = registerOperation("EnableVolumeReplication", &EnableVolumeReplication{}) func (rep *EnableVolumeReplication) Execute() error { @@ -91,10 +133,12 @@ func (rep *EnableVolumeReplication) Execute() error { req := &proto.EnableVolumeReplicationRequest{ SecretName: rep.secretName, SecretNamespace: rep.secretNamespace, - VolumeId: rep.volumeID, } - - _, err := rs.EnableVolumeReplication(context.TODO(), req) + err := rep.setReplicationSource(req.ReplicationSource) + if err != nil { + return err + } + _, err = rs.EnableVolumeReplication(context.TODO(), req) if err != nil { return err } @@ -119,10 +163,13 @@ func (rep *DisableVolumeReplication) Execute() error { req := &proto.DisableVolumeReplicationRequest{ SecretName: rep.secretName, SecretNamespace: rep.secretNamespace, - VolumeId: rep.volumeID, + } + err := rep.setReplicationSource(req.ReplicationSource) + if err != nil { + return err } - _, err := rs.DisableVolumeReplication(context.TODO(), req) + _, err = rs.DisableVolumeReplication(context.TODO(), req) if err != nil { return err } @@ -147,10 +194,12 @@ func (rep *PromoteVolume) Execute() error { req := &proto.PromoteVolumeRequest{ SecretName: rep.secretName, SecretNamespace: rep.secretNamespace, - VolumeId: rep.volumeID, } - - _, err := rs.PromoteVolume(context.TODO(), req) + err := rep.setReplicationSource(req.ReplicationSource) + if err != nil { + return err + } + _, err = rs.PromoteVolume(context.TODO(), req) if err != nil { return err } @@ -175,10 +224,12 @@ func (rep *DemoteVolume) Execute() error { req := &proto.DemoteVolumeRequest{ SecretName: rep.secretName, SecretNamespace: rep.secretNamespace, - VolumeId: rep.volumeID, } - - _, err := rs.DemoteVolume(context.TODO(), req) + err := rep.setReplicationSource(req.ReplicationSource) + if err != nil { + return err + } + _, err = rs.DemoteVolume(context.TODO(), req) if err != nil { return err } @@ -203,10 +254,12 @@ func (rep *ResyncVolume) Execute() error { req := &proto.ResyncVolumeRequest{ SecretName: rep.secretName, SecretNamespace: rep.secretNamespace, - VolumeId: rep.volumeID, } - - _, err := rs.ResyncVolume(context.TODO(), req) + err := rep.setReplicationSource(req.ReplicationSource) + if err != nil { + return err + } + _, err = rs.ResyncVolume(context.TODO(), req) if err != nil { return err } @@ -231,7 +284,10 @@ func (rep *GetVolumeReplicationInfo) Execute() error { req := &proto.GetVolumeReplicationInfoRequest{ SecretName: rep.secretName, SecretNamespace: rep.secretNamespace, - VolumeId: rep.volumeID, + } + err := rep.setReplicationSource(req.ReplicationSource) + if err != nil { + return err } res, err := rs.GetVolumeReplicationInfo(context.TODO(), req) diff --git a/cmd/csi-addons/replicationClient.go b/cmd/csi-addons/replicationClient.go new file mode 100644 index 000000000..ec41862ad --- /dev/null +++ b/cmd/csi-addons/replicationClient.go @@ -0,0 +1,70 @@ +/* +Copyright 2024 The Ceph-CSI Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "context" + + replicationv1alpha1 "github.com/csi-addons/kubernetes-csi-addons/apis/replication.storage/v1alpha1" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/serializer" + "k8s.io/client-go/kubernetes/scheme" + "k8s.io/client-go/rest" +) + +type replicationClient struct { + restClient *rest.RESTClient +} + +func getVolumeReplicationClient() *replicationClient { + config, err := rest.InClusterConfig() + if err != nil { + panic(err.Error()) + } + scheme, err := replicationv1alpha1.SchemeBuilder.Build() + if err != nil { + panic(err.Error()) + } + + crdConfig := *config + crdConfig.ContentConfig.GroupVersion = &replicationv1alpha1.GroupVersion + crdConfig.APIPath = "/apis" + crdConfig.NegotiatedSerializer = serializer.NewCodecFactory(scheme) + crdConfig.UserAgent = rest.DefaultKubernetesUserAgent() + + restClient, err := rest.UnversionedRESTClientFor(&crdConfig) + if err != nil { + panic(err) + } + + return &replicationClient{restClient: restClient} +} + +func (r *replicationClient) getVolumeGroupReplicationContent(ctx context.Context, name string, opts metav1.GetOptions) (*replicationv1alpha1.VolumeGroupReplicationContent, error) { + result := replicationv1alpha1.VolumeGroupReplicationContent{} + err := r.restClient. + Get(). + Namespace(""). + Resource("volumegroupreplicationcontents"). + Name(name). + VersionedParams(&opts, scheme.ParameterCodec). + Do(ctx). + Into(&result) + + return &result, err +} diff --git a/controllers/replication.storage/finalizers.go b/controllers/replication.storage/finalizers.go index c1d6d233a..35d280c64 100644 --- a/controllers/replication.storage/finalizers.go +++ b/controllers/replication.storage/finalizers.go @@ -31,6 +31,7 @@ import ( const ( volumeReplicationFinalizer = "replication.storage.openshift.io" pvcReplicationFinalizer = "replication.storage.openshift.io/pvc-protection" + vgrReplicationFinalizer = "replication.storage.openshift.io/vgr-protection" ) // addFinalizerToVR adds the VR finalizer on the VolumeReplication instance. @@ -94,3 +95,34 @@ func (r *VolumeReplicationReconciler) removeFinalizerFromPVC(logger logr.Logger, return nil } + +// addFinalizerToVGR adds the VR finalizer on the VolumeGroupReplication. +func (r *VolumeReplicationReconciler) addFinalizerToVGR(logger logr.Logger, vgr *replicationv1alpha1.VolumeGroupReplication) error { + if !slices.Contains(vgr.ObjectMeta.Finalizers, vgrReplicationFinalizer) { + logger.Info("adding finalizer to VolumeGroupReplication object", "Finalizer", vgrReplicationFinalizer) + vgr.ObjectMeta.Finalizers = append(vgr.ObjectMeta.Finalizers, vgrReplicationFinalizer) + if err := r.Client.Update(context.TODO(), vgr); err != nil { + return fmt.Errorf("failed to add finalizer (%s) to VolumeGroupReplication resource"+ + " (%s/%s) %w", + vgrReplicationFinalizer, vgr.Namespace, vgr.Name, err) + } + } + + return nil +} + +// removeFinalizerFromVGR removes the VR finalizer on VolumeGroupReplication. +func (r *VolumeReplicationReconciler) removeFinalizerFromVGR(logger logr.Logger, vgr *replicationv1alpha1.VolumeGroupReplication, +) error { + if slices.Contains(vgr.ObjectMeta.Finalizers, vgrReplicationFinalizer) { + logger.Info("removing finalizer from VolumeGroupReplication object", "Finalizer", vgrReplicationFinalizer) + vgr.ObjectMeta.Finalizers = util.RemoveFromSlice(vgr.ObjectMeta.Finalizers, vgrReplicationFinalizer) + if err := r.Client.Update(context.TODO(), vgr); err != nil { + return fmt.Errorf("failed to remove finalizer (%s) from VolumeGroupReplication resource"+ + " (%s/%s), %w", + vgrReplicationFinalizer, vgr.Namespace, vgr.Name, err) + } + } + + return nil +} diff --git a/controllers/replication.storage/replication/replication.go b/controllers/replication.storage/replication/replication.go index 772257e2b..090f9bde0 100644 --- a/controllers/replication.storage/replication/replication.go +++ b/controllers/replication.storage/replication/replication.go @@ -17,6 +17,8 @@ limitations under the License. package replication import ( + "errors" + "github.com/csi-addons/kubernetes-csi-addons/internal/client" "google.golang.org/grpc/codes" @@ -38,6 +40,7 @@ type Response struct { // CommonRequestParameters holds the common parameters across replication operations. type CommonRequestParameters struct { VolumeID string + GroupID string ReplicationID string Parameters map[string]string SecretName string @@ -45,9 +48,25 @@ type CommonRequestParameters struct { Replication client.VolumeReplication } +func (r *Replication) getID() (string, error) { + switch { + case r.Params.VolumeID != "" && r.Params.GroupID != "": + return "", errors.New("VolumeID and GroupID cannot be provided together") + case r.Params.VolumeID != "": + return r.Params.VolumeID, nil + case r.Params.GroupID != "": + return r.Params.GroupID, nil + } + return "", errors.New("VolumeID or GroupID must be provided") +} + func (r *Replication) Enable() *Response { + id, err := r.getID() + if err != nil { + return &Response{Error: err} + } resp, err := r.Params.Replication.EnableVolumeReplication( - r.Params.VolumeID, + id, r.Params.ReplicationID, r.Params.SecretName, r.Params.SecretNamespace, @@ -58,8 +77,12 @@ func (r *Replication) Enable() *Response { } func (r *Replication) Disable() *Response { + id, err := r.getID() + if err != nil { + return &Response{Error: err} + } resp, err := r.Params.Replication.DisableVolumeReplication( - r.Params.VolumeID, + id, r.Params.ReplicationID, r.Params.SecretName, r.Params.SecretNamespace, @@ -70,8 +93,12 @@ func (r *Replication) Disable() *Response { } func (r *Replication) Promote() *Response { + id, err := r.getID() + if err != nil { + return &Response{Error: err} + } resp, err := r.Params.Replication.PromoteVolume( - r.Params.VolumeID, + id, r.Params.ReplicationID, r.Force, r.Params.SecretName, @@ -83,8 +110,12 @@ func (r *Replication) Promote() *Response { } func (r *Replication) Demote() *Response { + id, err := r.getID() + if err != nil { + return &Response{Error: err} + } resp, err := r.Params.Replication.DemoteVolume( - r.Params.VolumeID, + id, r.Params.ReplicationID, r.Params.SecretName, r.Params.SecretNamespace, @@ -95,8 +126,12 @@ func (r *Replication) Demote() *Response { } func (r *Replication) Resync() *Response { + id, err := r.getID() + if err != nil { + return &Response{Error: err} + } resp, err := r.Params.Replication.ResyncVolume( - r.Params.VolumeID, + id, r.Params.ReplicationID, r.Force, r.Params.SecretName, @@ -108,8 +143,12 @@ func (r *Replication) Resync() *Response { } func (r *Replication) GetInfo() *Response { + id, err := r.getID() + if err != nil { + return &Response{Error: err} + } resp, err := r.Params.Replication.GetVolumeReplicationInfo( - r.Params.VolumeID, + id, r.Params.ReplicationID, r.Params.SecretName, r.Params.SecretNamespace, diff --git a/controllers/replication.storage/volumereplication_controller.go b/controllers/replication.storage/volumereplication_controller.go index e61925e52..75e2ace6e 100644 --- a/controllers/replication.storage/volumereplication_controller.go +++ b/controllers/replication.storage/volumereplication_controller.go @@ -18,6 +18,7 @@ package controllers import ( "context" + stderrors "errors" "fmt" "slices" "time" @@ -49,10 +50,11 @@ import ( ) const ( - pvcDataSource = "PersistentVolumeClaim" - volumeReplicationClass = "VolumeReplicationClass" - volumeReplication = "VolumeReplication" - defaultScheduleTime = time.Hour + pvcDataSource = "PersistentVolumeClaim" + volumeGroupReplicationDataSource = "VolumeGroupReplication" + volumeReplicationClass = "VolumeReplicationClass" + volumeReplication = "VolumeReplication" + defaultScheduleTime = time.Hour ) var ( @@ -72,10 +74,13 @@ type VolumeReplicationReconciler struct { Replication grpcClient.VolumeReplication } -// +kubebuilder:rbac:groups=replication.storage.openshift.io,resources=volumereplications,verbs=get;list;watch;update -// +kubebuilder:rbac:groups=replication.storage.openshift.io,resources=volumereplications/status,verbs=update -// +kubebuilder:rbac:groups=replication.storage.openshift.io,resources=volumereplications/finalizers,verbs=update -// +kubebuilder:rbac:groups=replication.storage.openshift.io,resources=volumereplicationclasses,verbs=get;list;watch +//+kubebuilder:rbac:groups=replication.storage.openshift.io,resources=volumereplications,verbs=get;list;watch;update +//+kubebuilder:rbac:groups=replication.storage.openshift.io,resources=volumereplications/status,verbs=update +//+kubebuilder:rbac:groups=replication.storage.openshift.io,resources=volumereplications/finalizers,verbs=update +//+kubebuilder:rbac:groups=replication.storage.openshift.io,resources=volumereplicationclasses,verbs=get;list;watch +//+kubebuilder:rbac:groups=replication.storage.openshift.io,resources=volumegroupreplications,verbs=get;list;watch +//+kubebuilder:rbac:groups=replication.storage.openshift.io,resources=volumegroupreplications/finalizers,verbs=update +//+kubebuilder:rbac:groups=replication.storage.openshift.io,resources=volumegroupreplicationcontents,verbs=get;list;watch //+kubebuilder:rbac:groups=core,resources=persistentvolumeclaims/finalizers,verbs=update //+kubebuilder:rbac:groups=core,resources=persistentvolumeclaims,verbs=get;list;watch @@ -134,10 +139,16 @@ func (r *VolumeReplicationReconciler) Reconcile(ctx context.Context, req ctrl.Re secretNamespace := vrcObj.Spec.Parameters[prefixedReplicationSecretNamespaceKey] var ( + // var for pvc replication volumeHandle string pvc *corev1.PersistentVolumeClaim pv *corev1.PersistentVolume pvErr error + // var for volume group replication + groupHandle string + vgrc *replicationv1alpha1.VolumeGroupReplicationContent + vgr *replicationv1alpha1.VolumeGroupReplication + vgrErr error ) replicationHandle := instance.Spec.ReplicationHandle @@ -158,6 +169,20 @@ func (r *VolumeReplicationReconciler) Reconcile(ctx context.Context, req ctrl.Re } volumeHandle = pv.Spec.CSI.VolumeHandle + logger.Info("volume handle", "VolumeHandleName", volumeHandle) + case volumeGroupReplicationDataSource: + vgr, vgrc, vgrErr = r.getVolumeGroupReplicationDataSource(logger, nameSpacedName) + if vgrErr != nil { + logger.Error(vgrErr, "failed to get VolumeGroupReplication", "VGRName", instance.Spec.DataSource.Name) + setFailureCondition(instance) + uErr := r.updateReplicationStatus(instance, logger, getCurrentReplicationState(instance), vgrErr.Error()) + if uErr != nil { + logger.Error(uErr, "failed to update volumeReplication status", "VRName", instance.Name) + } + return ctrl.Result{}, vgrErr + } + groupHandle = vgrc.Spec.VolumeGroupReplicationHandle + logger.Info("Group handle", "GroupHandleName", groupHandle) default: err = fmt.Errorf("unsupported datasource kind") logger.Error(err, "given kind not supported", "Kind", instance.Spec.DataSource.Kind) @@ -170,12 +195,11 @@ func (r *VolumeReplicationReconciler) Reconcile(ctx context.Context, req ctrl.Re return ctrl.Result{}, nil } - logger.Info("volume handle", "VolumeHandleName", volumeHandle) if replicationHandle != "" { logger.Info("Replication handle", "ReplicationHandleName", replicationHandle) } - replicationClient, err := r.getReplicationClient(ctx, vrcObj.Spec.Provisioner) + replicationClient, err := r.getReplicationClient(ctx, vrcObj.Spec.Provisioner, instance.Spec.DataSource.Kind) if err != nil { logger.Error(err, "Failed to get ReplicationClient") @@ -187,6 +211,7 @@ func (r *VolumeReplicationReconciler) Reconcile(ctx context.Context, req ctrl.Re instance: instance, commonRequestParameters: replication.CommonRequestParameters{ VolumeID: volumeHandle, + GroupID: groupHandle, ReplicationID: replicationHandle, Parameters: parameters, SecretName: secretName, @@ -203,17 +228,32 @@ func (r *VolumeReplicationReconciler) Reconcile(ctx context.Context, req ctrl.Re return reconcile.Result{}, err } + switch instance.Spec.DataSource.Kind { + case pvcDataSource: + err = r.annotatePVCWithOwner(ctx, logger, req.Name, pvc) + if err != nil { + logger.Error(err, "Failed to annotate PVC owner") + return ctrl.Result{}, err + } - err = r.annotatePVCWithOwner(ctx, logger, req.Name, pvc) - if err != nil { - logger.Error(err, "Failed to annotate PVC owner") - return ctrl.Result{}, err - } + if err = r.addFinalizerToPVC(logger, pvc); err != nil { + logger.Error(err, "Failed to add PersistentVolumeClaim finalizer") - if err = r.addFinalizerToPVC(logger, pvc); err != nil { - logger.Error(err, "Failed to add PersistentVolumeClaim finalizer") + return reconcile.Result{}, err + } + case volumeGroupReplicationDataSource: + err = r.annotateVolumeGroupReplicationWithOwner(ctx, logger, req.Name, vgr) + if err != nil { + logger.Error(err, "Failed to annotate VolumeGroupReplication owner") - return reconcile.Result{}, err + return ctrl.Result{}, err + } + + if err = r.addFinalizerToVGR(logger, vgr); err != nil { + logger.Error(err, "Failed to add VolumeGroupReplication finalizer") + + return reconcile.Result{}, err + } } } else { if slices.Contains(instance.GetFinalizers(), volumeReplicationFinalizer) { @@ -223,19 +263,32 @@ func (r *VolumeReplicationReconciler) Reconcile(ctx context.Context, req ctrl.Re return ctrl.Result{}, err } + switch instance.Spec.DataSource.Kind { + case pvcDataSource: + if err = r.removeOwnerFromPVCAnnotation(ctx, logger, pvc); err != nil { + logger.Error(err, "Failed to remove VolumeReplication annotation from PersistentVolumeClaim") - if err = r.removeOwnerFromPVCAnnotation(ctx, logger, pvc); err != nil { - logger.Error(err, "Failed to remove VolumeReplication annotation from PersistentVolumeClaim") + return reconcile.Result{}, err + } - return reconcile.Result{}, err - } + if err = r.removeFinalizerFromPVC(logger, pvc); err != nil { + logger.Error(err, "Failed to remove PersistentVolumeClaim finalizer") - if err = r.removeFinalizerFromPVC(logger, pvc); err != nil { - logger.Error(err, "Failed to remove PersistentVolumeClaim finalizer") + return reconcile.Result{}, err + } + case volumeGroupReplicationDataSource: + if err = r.removeOwnerFromVGRAnnotation(ctx, logger, vgr); err != nil { + logger.Error(err, "Failed to remove VolumeReplication annotation from VolumeGroupReplication") - return reconcile.Result{}, err - } + return reconcile.Result{}, err + } + if err = r.removeFinalizerFromVGR(logger, vgr); err != nil { + logger.Error(err, "Failed to remove VolumeGroupReplication finalizer") + + return reconcile.Result{}, err + } + } // once all finalizers have been removed, the object will be // deleted if err = r.removeFinalizerFromVR(logger, instance); err != nil { @@ -454,7 +507,7 @@ func getInfoReconcileInterval(parameters map[string]string, logger logr.Logger) return scheduleTime / 2 } -func (r *VolumeReplicationReconciler) getReplicationClient(ctx context.Context, driverName string) (grpcClient.VolumeReplication, error) { +func (r *VolumeReplicationReconciler) getReplicationClient(ctx context.Context, driverName, dataSource string) (grpcClient.VolumeReplication, error) { conn, err := r.Connpool.GetLeaderByDriver(ctx, r.Client, driverName) if err != nil { return nil, fmt.Errorf("no leader for the ControllerService of driver %q", driverName) @@ -468,7 +521,11 @@ func (r *VolumeReplicationReconciler) getReplicationClient(ctx context.Context, // validate of VOLUME_REPLICATION capability is enabled by the storage driver. if cap.GetVolumeReplication().GetType() == identity.Capability_VolumeReplication_VOLUME_REPLICATION { - return grpcClient.NewReplicationClient(conn.Client, r.Timeout), nil + if dataSource == pvcDataSource { + return grpcClient.NewVolumeReplicationClient(conn.Client, r.Timeout), nil + } else if dataSource == volumeGroupReplicationDataSource { + return grpcClient.NewVolumeGroupReplicationClient(conn.Client, r.Timeout), nil + } } } @@ -758,3 +815,83 @@ func getCurrentTime() *metav1.Time { return &metav1NowTime } + +// annotateVolumeGroupReplicationWithOwner will add the VolumeReplication details to the VGR annotations. +func (r *VolumeReplicationReconciler) annotateVolumeGroupReplicationWithOwner(ctx context.Context, logger logr.Logger, reqOwnerName string, vgr *replicationv1alpha1.VolumeGroupReplication) error { + if vgr.ObjectMeta.Annotations == nil { + vgr.ObjectMeta.Annotations = map[string]string{} + } + + currentOwnerName := vgr.ObjectMeta.Annotations[replicationv1alpha1.VolumeReplicationNameAnnotation] + if currentOwnerName == "" { + logger.Info("setting owner on VGR annotation", "Name", vgr.Name, "owner", reqOwnerName) + vgr.ObjectMeta.Annotations[replicationv1alpha1.VolumeReplicationNameAnnotation] = reqOwnerName + err := r.Update(ctx, vgr) + if err != nil { + logger.Error(err, "Failed to update VGR annotation", "Name", vgr.Name) + + return fmt.Errorf("failed to update VGR %q annotation for VolumeReplication: %w", + vgr.Name, err) + } + + return nil + } + + if currentOwnerName != reqOwnerName { + logger.Info("cannot change the owner of vgr", + "VGR name", vgr.Name, + "current owner", currentOwnerName, + "requested owner", reqOwnerName) + + return fmt.Errorf("VGR %q not owned by VolumeReplication %q", + vgr.Name, reqOwnerName) + } + + return nil +} + +func (r *VolumeReplicationReconciler) getVolumeGroupReplicationDataSource(logger logr.Logger, req types.NamespacedName) (*replicationv1alpha1.VolumeGroupReplication, *replicationv1alpha1.VolumeGroupReplicationContent, error) { + volumeGroupReplication := &replicationv1alpha1.VolumeGroupReplication{} + err := r.Client.Get(context.TODO(), req, volumeGroupReplication) + if err != nil { + if errors.IsNotFound(err) { + logger.Error(err, "VolumeGroupReplication not found", "VolumeGroupReplication Name", req.Name) + } + + return nil, nil, err + } + vgrcName := volumeGroupReplication.Spec.VolumeGroupReplicationContentName + if vgrcName == "" { + logger.Error(err, "VolumeGroupReplicationContentName is empty", "VolumeGroupReplication Name", req.Name) + + return nil, nil, stderrors.New("VolumeGroupReplicationContentName is empty") + } + + vgrcReq := types.NamespacedName{Name: vgrcName} + volumeGroupReplicationContent := &replicationv1alpha1.VolumeGroupReplicationContent{} + err = r.Client.Get(context.TODO(), vgrcReq, volumeGroupReplicationContent) + if err != nil { + if errors.IsNotFound(err) { + logger.Error(err, "VolumeGroupReplicationContent not found", "VolumeGroupReplicationContent Name", vgrcName) + } + + return nil, nil, err + } + + return volumeGroupReplication, volumeGroupReplicationContent, nil +} + +// removeOwnerFromVGRAnnotation removes the VolumeReplication owner from the VGR annotations. +func (r *VolumeReplicationReconciler) removeOwnerFromVGRAnnotation(ctx context.Context, logger logr.Logger, vgr *replicationv1alpha1.VolumeGroupReplication) error { + if _, ok := vgr.ObjectMeta.Annotations[replicationv1alpha1.VolumeReplicationNameAnnotation]; ok { + logger.Info("removing owner annotation from VolumeGroupReplication object", "Annotation", replicationv1alpha1.VolumeReplicationNameAnnotation) + delete(vgr.ObjectMeta.Annotations, replicationv1alpha1.VolumeReplicationNameAnnotation) + if err := r.Client.Update(ctx, vgr); err != nil { + return fmt.Errorf("failed to remove annotation %q from VolumeGroupReplication "+ + "%q %w", + replicationv1alpha1.VolumeReplicationNameAnnotation, vgr.Name, err) + } + } + + return nil +} diff --git a/internal/client/fake/replication-client.go b/internal/client/fake/replication.go similarity index 100% rename from internal/client/fake/replication-client.go rename to internal/client/fake/replication.go diff --git a/internal/client/group-replication.go b/internal/client/group-replication.go new file mode 100644 index 000000000..0f3b16b0f --- /dev/null +++ b/internal/client/group-replication.go @@ -0,0 +1,183 @@ +/* +Copyright 2024 The Kubernetes-CSI-Addons Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "context" + "time" + + "github.com/csi-addons/kubernetes-csi-addons/internal/proto" + + "google.golang.org/grpc" +) + +type volumeGroupReplicationClient struct { + client proto.ReplicationClient + timeout time.Duration +} + +var _ VolumeReplication = &volumeGroupReplicationClient{} + +// NewVolumeGroupReplicationClient returns VolumeReplication interface which has the RPC +// calls for volume group replication. +func NewVolumeGroupReplicationClient(cc *grpc.ClientConn, timeout time.Duration) VolumeReplication { + return &volumeGroupReplicationClient{client: proto.NewReplicationClient(cc), timeout: timeout} +} + +// EnableVolumeReplication RPC call to enable the volume group replication. +func (rc *volumeGroupReplicationClient) EnableVolumeReplication(groupID, replicationID string, + secretName, secretNamespace string, parameters map[string]string) (*proto.EnableVolumeReplicationResponse, error) { + req := &proto.EnableVolumeReplicationRequest{ + ReplicationSource: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_VolumeGroup{ + VolumeGroup: &proto.ReplicationSource_VolumeGroupSource{ + VolumeGroupId: groupID, + }, + }, + }, + ReplicationId: replicationID, + Parameters: parameters, + SecretName: secretName, + SecretNamespace: secretNamespace, + } + + createCtx, cancel := context.WithTimeout(context.Background(), rc.timeout) + defer cancel() + resp, err := rc.client.EnableVolumeReplication(createCtx, req) + + return resp, err +} + +// DisableVolumeReplication RPC call to disable the volume group replication. +func (rc *volumeGroupReplicationClient) DisableVolumeReplication(groupID, replicationID string, + secretName, secretNamespace string, parameters map[string]string) (*proto.DisableVolumeReplicationResponse, error) { + req := &proto.DisableVolumeReplicationRequest{ + ReplicationSource: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_VolumeGroup{ + VolumeGroup: &proto.ReplicationSource_VolumeGroupSource{ + VolumeGroupId: groupID, + }, + }, + }, + ReplicationId: replicationID, + Parameters: parameters, + SecretName: secretName, + SecretNamespace: secretNamespace, + } + + createCtx, cancel := context.WithTimeout(context.Background(), rc.timeout) + defer cancel() + resp, err := rc.client.DisableVolumeReplication(createCtx, req) + + return resp, err +} + +// PromoteVolume RPC call to promote the volume group. +func (rc *volumeGroupReplicationClient) PromoteVolume(groupID, replicationID string, + force bool, secretName, secretNamespace string, parameters map[string]string) (*proto.PromoteVolumeResponse, error) { + req := &proto.PromoteVolumeRequest{ + ReplicationSource: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_VolumeGroup{ + VolumeGroup: &proto.ReplicationSource_VolumeGroupSource{ + VolumeGroupId: groupID, + }, + }, + }, + ReplicationId: replicationID, + Force: force, + Parameters: parameters, + SecretName: secretName, + SecretNamespace: secretNamespace, + } + + createCtx, cancel := context.WithTimeout(context.Background(), rc.timeout) + defer cancel() + resp, err := rc.client.PromoteVolume(createCtx, req) + + return resp, err +} + +// DemoteVolume RPC call to demote the volume group. +func (rc *volumeGroupReplicationClient) DemoteVolume(groupID, replicationID string, + secretName, secretNamespace string, parameters map[string]string) (*proto.DemoteVolumeResponse, error) { + req := &proto.DemoteVolumeRequest{ + ReplicationSource: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_VolumeGroup{ + VolumeGroup: &proto.ReplicationSource_VolumeGroupSource{ + VolumeGroupId: groupID, + }, + }, + }, + ReplicationId: replicationID, + Parameters: parameters, + SecretName: secretName, + SecretNamespace: secretNamespace, + } + createCtx, cancel := context.WithTimeout(context.Background(), rc.timeout) + defer cancel() + resp, err := rc.client.DemoteVolume(createCtx, req) + + return resp, err +} + +// ResyncVolume RPC call to resync the volume group. +func (rc *volumeGroupReplicationClient) ResyncVolume(groupID, replicationID string, force bool, + secretName, secretNamespace string, parameters map[string]string) (*proto.ResyncVolumeResponse, error) { + req := &proto.ResyncVolumeRequest{ + ReplicationSource: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_VolumeGroup{ + VolumeGroup: &proto.ReplicationSource_VolumeGroupSource{ + VolumeGroupId: groupID, + }, + }, + }, + ReplicationId: replicationID, + Parameters: parameters, + Force: force, + SecretName: secretName, + SecretNamespace: secretNamespace, + } + + createCtx, cancel := context.WithTimeout(context.Background(), rc.timeout) + defer cancel() + resp, err := rc.client.ResyncVolume(createCtx, req) + + return resp, err +} + +// GetVolumeReplicationInfo RPC call to get volume group replication info. +func (rc *volumeGroupReplicationClient) GetVolumeReplicationInfo(groupID, replicationID, + secretName, secretNamespace string) (*proto.GetVolumeReplicationInfoResponse, error) { + req := &proto.GetVolumeReplicationInfoRequest{ + ReplicationSource: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_VolumeGroup{ + VolumeGroup: &proto.ReplicationSource_VolumeGroupSource{ + VolumeGroupId: groupID, + }, + }, + }, + ReplicationId: replicationID, + SecretName: secretName, + SecretNamespace: secretNamespace, + } + + createCtx, cancel := context.WithTimeout(context.Background(), rc.timeout) + defer cancel() + resp, err := rc.client.GetVolumeReplicationInfo(createCtx, req) + + return resp, err +} diff --git a/internal/client/replication.go b/internal/client/replication.go new file mode 100644 index 000000000..5031dff74 --- /dev/null +++ b/internal/client/replication.go @@ -0,0 +1,40 @@ +/* +Copyright 2022 The Kubernetes-CSI-Addons Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "github.com/csi-addons/kubernetes-csi-addons/internal/proto" +) + +// VolumeReplication holds the methods required for replication. +type VolumeReplication interface { + // EnableVolumeReplication RPC call to enable the volume replication. + EnableVolumeReplication(id, replicationID string, secretName, secretNamespace string, parameters map[string]string) (*proto.EnableVolumeReplicationResponse, error) + // DisableVolumeReplication RPC call to disable the volume replication. + DisableVolumeReplication(id, replicationID string, secretName, secretNamespace string, parameters map[string]string) (*proto.DisableVolumeReplicationResponse, error) + // PromoteVolume RPC call to promote the volume. + PromoteVolume(id, replicationID string, force bool, secretName, secretNamespace string, parameters map[string]string) (*proto. + PromoteVolumeResponse, error) + // DemoteVolume RPC call to demote the volume. + DemoteVolume(id, replicationID string, secretName, secretNamespace string, parameters map[string]string) (*proto. + DemoteVolumeResponse, error) + // ResyncVolume RPC call to resync the volume. + ResyncVolume(id, replicationID string, force bool, secretName, secretNamespace string, parameters map[string]string) (*proto. + ResyncVolumeResponse, error) + // GetVolumeReplicationInfo RPC call to get volume replication info. + GetVolumeReplicationInfo(id, replicationID, secretName, secretNamespace string) (*proto.GetVolumeReplicationInfoResponse, error) +} diff --git a/internal/client/replication-client.go b/internal/client/volume-replication.go similarity index 63% rename from internal/client/replication-client.go rename to internal/client/volume-replication.go index 98de25914..57f21401f 100644 --- a/internal/client/replication-client.go +++ b/internal/client/volume-replication.go @@ -25,41 +25,30 @@ import ( "google.golang.org/grpc" ) -type replicationClient struct { +type volumeReplicationClient struct { client proto.ReplicationClient timeout time.Duration } -// VolumeReplication holds the methods required for volume replication. -type VolumeReplication interface { - // EnableVolumeReplication RPC call to enable the volume replication. - EnableVolumeReplication(volumeID, replicationID string, secretName, secretNamespace string, parameters map[string]string) (*proto.EnableVolumeReplicationResponse, error) - // DisableVolumeReplication RPC call to disable the volume replication. - DisableVolumeReplication(volumeID, replicationID string, secretName, secretNamespace string, parameters map[string]string) (*proto.DisableVolumeReplicationResponse, error) - // PromoteVolume RPC call to promote the volume. - PromoteVolume(volumeID, replicationID string, force bool, secretName, secretNamespace string, parameters map[string]string) (*proto. - PromoteVolumeResponse, error) - // DemoteVolume RPC call to demote the volume. - DemoteVolume(volumeID, replicationID string, secretName, secretNamespace string, parameters map[string]string) (*proto. - DemoteVolumeResponse, error) - // ResyncVolume RPC call to resync the volume. - ResyncVolume(volumeID, replicationID string, force bool, secretName, secretNamespace string, parameters map[string]string) (*proto. - ResyncVolumeResponse, error) - // GetVolumeReplicationInfo RPC call to get volume replication info. - GetVolumeReplicationInfo(volumeID, replicationID, secretName, secretNamespace string) (*proto.GetVolumeReplicationInfoResponse, error) -} +var _ VolumeReplication = &volumeReplicationClient{} // NewReplicationClient returns VolumeReplication interface which has the RPC // calls for replication. -func NewReplicationClient(cc *grpc.ClientConn, timeout time.Duration) VolumeReplication { - return &replicationClient{client: proto.NewReplicationClient(cc), timeout: timeout} +func NewVolumeReplicationClient(cc *grpc.ClientConn, timeout time.Duration) VolumeReplication { + return &volumeReplicationClient{client: proto.NewReplicationClient(cc), timeout: timeout} } // EnableVolumeReplication RPC call to enable the volume replication. -func (rc *replicationClient) EnableVolumeReplication(volumeID, replicationID string, +func (rc *volumeReplicationClient) EnableVolumeReplication(volumeID, replicationID string, secretName, secretNamespace string, parameters map[string]string) (*proto.EnableVolumeReplicationResponse, error) { req := &proto.EnableVolumeReplicationRequest{ - VolumeId: volumeID, + ReplicationSource: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_Volume{ + Volume: &proto.ReplicationSource_VolumeSource{ + VolumeId: volumeID, + }, + }, + }, ReplicationId: replicationID, Parameters: parameters, SecretName: secretName, @@ -74,10 +63,16 @@ func (rc *replicationClient) EnableVolumeReplication(volumeID, replicationID str } // DisableVolumeReplication RPC call to disable the volume replication. -func (rc *replicationClient) DisableVolumeReplication(volumeID, replicationID string, +func (rc *volumeReplicationClient) DisableVolumeReplication(volumeID, replicationID string, secretName, secretNamespace string, parameters map[string]string) (*proto.DisableVolumeReplicationResponse, error) { req := &proto.DisableVolumeReplicationRequest{ - VolumeId: volumeID, + ReplicationSource: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_Volume{ + Volume: &proto.ReplicationSource_VolumeSource{ + VolumeId: volumeID, + }, + }, + }, ReplicationId: replicationID, Parameters: parameters, SecretName: secretName, @@ -92,10 +87,16 @@ func (rc *replicationClient) DisableVolumeReplication(volumeID, replicationID st } // PromoteVolume RPC call to promote the volume. -func (rc *replicationClient) PromoteVolume(volumeID, replicationID string, +func (rc *volumeReplicationClient) PromoteVolume(volumeID, replicationID string, force bool, secretName, secretNamespace string, parameters map[string]string) (*proto.PromoteVolumeResponse, error) { req := &proto.PromoteVolumeRequest{ - VolumeId: volumeID, + ReplicationSource: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_Volume{ + Volume: &proto.ReplicationSource_VolumeSource{ + VolumeId: volumeID, + }, + }, + }, ReplicationId: replicationID, Force: force, Parameters: parameters, @@ -111,10 +112,16 @@ func (rc *replicationClient) PromoteVolume(volumeID, replicationID string, } // DemoteVolume RPC call to demote the volume. -func (rc *replicationClient) DemoteVolume(volumeID, replicationID string, +func (rc *volumeReplicationClient) DemoteVolume(volumeID, replicationID string, secretName, secretNamespace string, parameters map[string]string) (*proto.DemoteVolumeResponse, error) { req := &proto.DemoteVolumeRequest{ - VolumeId: volumeID, + ReplicationSource: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_Volume{ + Volume: &proto.ReplicationSource_VolumeSource{ + VolumeId: volumeID, + }, + }, + }, ReplicationId: replicationID, Parameters: parameters, SecretName: secretName, @@ -128,10 +135,16 @@ func (rc *replicationClient) DemoteVolume(volumeID, replicationID string, } // ResyncVolume RPC call to resync the volume. -func (rc *replicationClient) ResyncVolume(volumeID, replicationID string, force bool, +func (rc *volumeReplicationClient) ResyncVolume(volumeID, replicationID string, force bool, secretName, secretNamespace string, parameters map[string]string) (*proto.ResyncVolumeResponse, error) { req := &proto.ResyncVolumeRequest{ - VolumeId: volumeID, + ReplicationSource: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_Volume{ + Volume: &proto.ReplicationSource_VolumeSource{ + VolumeId: volumeID, + }, + }, + }, ReplicationId: replicationID, Parameters: parameters, Force: force, @@ -147,10 +160,16 @@ func (rc *replicationClient) ResyncVolume(volumeID, replicationID string, force } // GetVolumeReplicationInfo RPC call to get volume replication info. -func (rc *replicationClient) GetVolumeReplicationInfo(volumeID, replicationID, +func (rc *volumeReplicationClient) GetVolumeReplicationInfo(volumeID, replicationID, secretName, secretNamespace string) (*proto.GetVolumeReplicationInfoResponse, error) { req := &proto.GetVolumeReplicationInfoRequest{ - VolumeId: volumeID, + ReplicationSource: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_Volume{ + Volume: &proto.ReplicationSource_VolumeSource{ + VolumeId: volumeID, + }, + }, + }, ReplicationId: replicationID, SecretName: secretName, SecretNamespace: secretNamespace, diff --git a/internal/client/replication-client_test.go b/internal/client/volume-replication_test.go similarity index 100% rename from internal/client/replication-client_test.go rename to internal/client/volume-replication_test.go diff --git a/internal/sidecar/service/volumereplication.go b/internal/sidecar/service/volumereplication.go index 4dc25ead6..bcab6656a 100644 --- a/internal/sidecar/service/volumereplication.go +++ b/internal/sidecar/service/volumereplication.go @@ -17,6 +17,7 @@ package service import ( "context" + "errors" kube "github.com/csi-addons/kubernetes-csi-addons/internal/kubernetes" "github.com/csi-addons/kubernetes-csi-addons/internal/proto" @@ -68,17 +69,12 @@ func (rs *ReplicationServer) EnableVolumeReplication( Parameters: req.GetParameters(), Secrets: data, } - if req.VolumeId != "" { - // setting repReq.VolumeId for backward compatibility for volume replication of a given volume - repReq.VolumeId = req.GetVolumeId() // nolint:staticcheck - repReq.ReplicationSource = &csiReplication.ReplicationSource{ - Type: &csiReplication.ReplicationSource_Volume{ - Volume: &csiReplication.ReplicationSource_VolumeSource{ - VolumeId: req.GetVolumeId(), - }, - }, - } + err = setReplicationSource(repReq.ReplicationSource, req.GetReplicationSource()) + if err != nil { + klog.Errorf("Failed to set replication source: %v", err) + return nil, status.Error(codes.Internal, err.Error()) } + _, err = rs.controllerClient.EnableVolumeReplication(ctx, repReq) if err != nil { @@ -106,16 +102,10 @@ func (rs *ReplicationServer) DisableVolumeReplication( Parameters: req.GetParameters(), Secrets: data, } - if req.GetVolumeId() != "" { - // setting repReq.VolumeId for backward compatibility for volume replication of a given volume - repReq.VolumeId = req.GetVolumeId() // nolint:staticcheck - repReq.ReplicationSource = &csiReplication.ReplicationSource{ - Type: &csiReplication.ReplicationSource_Volume{ - Volume: &csiReplication.ReplicationSource_VolumeSource{ - VolumeId: req.GetVolumeId(), - }, - }, - } + err = setReplicationSource(repReq.ReplicationSource, req.GetReplicationSource()) + if err != nil { + klog.Errorf("Failed to set replication source: %v", err) + return nil, status.Error(codes.Internal, err.Error()) } _, err = rs.controllerClient.DisableVolumeReplication(ctx, repReq) @@ -145,16 +135,10 @@ func (rs *ReplicationServer) PromoteVolume( Force: req.GetForce(), Secrets: data, } - if req.GetVolumeId() != "" { - // setting repReq.VolumeId for backward compatibility for volume replication of a given volume - repReq.VolumeId = req.GetVolumeId() // nolint:staticcheck - repReq.ReplicationSource = &csiReplication.ReplicationSource{ - Type: &csiReplication.ReplicationSource_Volume{ - Volume: &csiReplication.ReplicationSource_VolumeSource{ - VolumeId: req.GetVolumeId(), - }, - }, - } + err = setReplicationSource(repReq.ReplicationSource, req.GetReplicationSource()) + if err != nil { + klog.Errorf("Failed to set replication source: %v", err) + return nil, status.Error(codes.Internal, err.Error()) } _, err = rs.controllerClient.PromoteVolume(ctx, repReq) @@ -184,17 +168,12 @@ func (rs *ReplicationServer) DemoteVolume( Force: req.GetForce(), Secrets: data, } - if req.GetVolumeId() != "" { - // setting repReq.VolumeId for backward compatibility for volume replication of a given volume - repReq.VolumeId = req.GetVolumeId() // nolint:staticcheck - repReq.ReplicationSource = &csiReplication.ReplicationSource{ - Type: &csiReplication.ReplicationSource_Volume{ - Volume: &csiReplication.ReplicationSource_VolumeSource{ - VolumeId: req.GetVolumeId(), - }, - }, - } + err = setReplicationSource(repReq.ReplicationSource, req.GetReplicationSource()) + if err != nil { + klog.Errorf("Failed to set replication source: %v", err) + return nil, status.Error(codes.Internal, err.Error()) } + _, err = rs.controllerClient.DemoteVolume(ctx, repReq) if err != nil { klog.Errorf("Failed to demote volume: %v", err) @@ -222,16 +201,10 @@ func (rs *ReplicationServer) ResyncVolume( Force: req.GetForce(), Secrets: data, } - if req.GetVolumeId() != "" { - // setting repReq.VolumeId for backward compatibility for volume replication of a given volume - repReq.VolumeId = req.GetVolumeId() // nolint:staticcheck - repReq.ReplicationSource = &csiReplication.ReplicationSource{ - Type: &csiReplication.ReplicationSource_Volume{ - Volume: &csiReplication.ReplicationSource_VolumeSource{ - VolumeId: req.GetVolumeId(), - }, - }, - } + err = setReplicationSource(repReq.ReplicationSource, req.GetReplicationSource()) + if err != nil { + klog.Errorf("Failed to set replication source: %v", err) + return nil, status.Error(codes.Internal, err.Error()) } resp, err := rs.controllerClient.ResyncVolume(ctx, repReq) @@ -261,16 +234,10 @@ func (rs *ReplicationServer) GetVolumeReplicationInfo( ReplicationId: req.GetReplicationId(), Secrets: data, } - if req.GetVolumeId() != "" { - // setting repReq.VolumeId for backward compatibility for volume replication of a given volume - repReq.VolumeId = req.GetVolumeId() // nolint:staticcheck - repReq.ReplicationSource = &csiReplication.ReplicationSource{ - Type: &csiReplication.ReplicationSource_Volume{ - Volume: &csiReplication.ReplicationSource_VolumeSource{ - VolumeId: req.GetVolumeId(), - }, - }, - } + err = setReplicationSource(repReq.ReplicationSource, req.GetReplicationSource()) + if err != nil { + klog.Errorf("Failed to set replication source: %v", err) + return nil, status.Error(codes.Internal, err.Error()) } resp, err := rs.controllerClient.GetVolumeReplicationInfo(ctx, repReq) @@ -290,3 +257,28 @@ func (rs *ReplicationServer) GetVolumeReplicationInfo( LastSyncBytes: resp.GetLastSyncBytes(), }, nil } + +// setReplicationSource sets the replication source for the given ReplicationSource. +func setReplicationSource(src *csiReplication.ReplicationSource, req *proto.ReplicationSource) error { + if src == nil { + src = &csiReplication.ReplicationSource{} + } + + switch { + case req == nil: + return errors.New("replication source is required") + case req.GetVolume() == nil && req.GetVolumeGroup() == nil: + return errors.New("either volume or volume group is required") + case req.GetVolume() != nil: + src.Type = &csiReplication.ReplicationSource_Volume{Volume: &csiReplication.ReplicationSource_VolumeSource{ + VolumeId: req.GetVolume().GetVolumeId(), + }} + return nil + case req.GetVolumeGroup() != nil: + src.Type = &csiReplication.ReplicationSource_Volumegroup{Volumegroup: &csiReplication.ReplicationSource_VolumeGroupSource{ + VolumeGroupId: req.GetVolumeGroup().GetVolumeGroupId(), + }} + return nil + } + return errors.New("either volume or volume group is required") +} diff --git a/internal/sidecar/service/volumereplication_test.go b/internal/sidecar/service/volumereplication_test.go new file mode 100644 index 000000000..44ae40656 --- /dev/null +++ b/internal/sidecar/service/volumereplication_test.go @@ -0,0 +1,99 @@ +/* +Copyright 2024 The Kubernetes-CSI-Addons Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package service + +import ( + "testing" + + "github.com/csi-addons/kubernetes-csi-addons/internal/proto" + csiReplication "github.com/csi-addons/spec/lib/go/replication" +) + +func Test_setReplicationSource(t *testing.T) { + type args struct { + src *csiReplication.ReplicationSource + req *proto.ReplicationSource + } + volID := "volumeID" + groupID := "groupID" + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "set replication source when request is nil", + args: args{ + src: &csiReplication.ReplicationSource{}, + req: nil, + }, + wantErr: true, + }, + { + name: "set replication source when request is not set", + args: args{ + src: &csiReplication.ReplicationSource{}, + req: &proto.ReplicationSource{}, + }, + wantErr: true, + }, + { + name: "set replication source when volume is set", + args: args{ + src: &csiReplication.ReplicationSource{}, + req: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_Volume{ + Volume: &proto.ReplicationSource_VolumeSource{ + VolumeId: volID, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "set replication source when volume group is set", + args: args{ + src: &csiReplication.ReplicationSource{}, + req: &proto.ReplicationSource{ + Type: &proto.ReplicationSource_VolumeGroup{ + VolumeGroup: &proto.ReplicationSource_VolumeGroupSource{ + VolumeGroupId: groupID, + }, + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := setReplicationSource(tt.args.src, tt.args.req); (err != nil) != tt.wantErr { + t.Errorf("setReplicationSource() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.args.req.GetVolume() != nil { + if tt.args.req.GetVolume().GetVolumeId() != volID { + t.Errorf("setReplicationSource() got = %v volumeID, expected = %v volumeID", tt.args.req.GetVolume().GetVolumeId(), volID) + } + } + if tt.args.req.GetVolumeGroup() != nil { + if tt.args.req.GetVolumeGroup().GetVolumeGroupId() != groupID { + t.Errorf("setReplicationSource() got = %v groupID, expected = %v volumeID", tt.args.req.GetVolumeGroup().GetVolumeGroupId(), groupID) + } + } + }) + } +}