Skip to content

Commit

Permalink
update inFlight cache to avoid race condition on volume operation
Browse files Browse the repository at this point in the history
  • Loading branch information
AndyXiangLi authored and wongma7 committed Jul 1, 2021
1 parent 8546261 commit 8e3b578
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 68 deletions.
180 changes: 125 additions & 55 deletions pkg/driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,27 +99,21 @@ func newControllerService(driverOptions *DriverOptions) controllerService {

func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
klog.V(4).Infof("CreateVolume: called with args %+v", *req)
volName := req.GetName()
if len(volName) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume name not provided")
if err := validateCreateVolumeRequest(req); err != nil {
return nil, err
}

volSizeBytes, err := getVolSizeBytes(req)
if err != nil {
return nil, err
}
volName := req.GetName()

volCaps := req.GetVolumeCapabilities()
if len(volCaps) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume capabilities not provided")
}

if !isValidVolumeCapabilities(volCaps) {
modes := util.GetAccessModes(volCaps)
stringModes := strings.Join(*modes, ", ")
errString := "Volume capabilities " + stringModes + " not supported. Only AccessModes[ReadWriteOnce] supported."
return nil, status.Error(codes.InvalidArgument, errString)
// check if a request is already in-flight
if ok := d.inFlight.Insert(volName); !ok {
msg := fmt.Sprintf("Create volume request for %s is already in progress", volName)
return nil, status.Error(codes.Aborted, msg)
}
defer d.inFlight.Delete(volName)

disk, err := d.cloud.GetDiskByName(ctx, volName, volSizeBytes)
if err != nil {
Expand Down Expand Up @@ -217,13 +211,6 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol
return newCreateVolumeResponse(disk), nil
}

// check if a request is already in-flight because the CreateVolume API is not idempotent
if ok := d.inFlight.Insert(req.String()); !ok {
msg := fmt.Sprintf("Create volume request for %s is already in progress", volName)
return nil, status.Error(codes.Aborted, msg)
}
defer d.inFlight.Delete(req.String())

// create a new volume
zone := pickAvailabilityZone(req.GetAccessibilityRequirements())
outpostArn := getOutpostArn(req.GetAccessibilityRequirements())
Expand Down Expand Up @@ -264,12 +251,40 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol
return newCreateVolumeResponse(disk), nil
}

func validateCreateVolumeRequest(req *csi.CreateVolumeRequest) error {
volName := req.GetName()
if len(volName) == 0 {
return status.Error(codes.InvalidArgument, "Volume name not provided")
}

volCaps := req.GetVolumeCapabilities()
if len(volCaps) == 0 {
return status.Error(codes.InvalidArgument, "Volume capabilities not provided")
}

if !isValidVolumeCapabilities(volCaps) {
modes := util.GetAccessModes(volCaps)
stringModes := strings.Join(*modes, ", ")
errString := "Volume capabilities " + stringModes + " not supported. Only AccessModes[ReadWriteOnce] supported."
return status.Error(codes.InvalidArgument, errString)
}
return nil
}

func (d *controllerService) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
klog.V(4).Infof("DeleteVolume: called with args: %+v", *req)
if err := validateDeleteVolumeRequest(req); err != nil {
return nil, err
}

volumeID := req.GetVolumeId()
if len(volumeID) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID not provided")

// check if a request is already in-flight
if ok := d.inFlight.Insert(volumeID); !ok {
msg := fmt.Sprintf(internal.VolumeOperationAlreadyExistsErrorMsg, volumeID)
return nil, status.Error(codes.Aborted, msg)
}
defer d.inFlight.Delete(volumeID)

if _, err := d.cloud.DeleteDisk(ctx, volumeID); err != nil {
if err == cloud.ErrNotFound {
Expand All @@ -282,30 +297,21 @@ func (d *controllerService) DeleteVolume(ctx context.Context, req *csi.DeleteVol
return &csi.DeleteVolumeResponse{}, nil
}

func validateDeleteVolumeRequest(req *csi.DeleteVolumeRequest) error {
if len(req.GetVolumeId()) == 0 {
return status.Error(codes.InvalidArgument, "Volume ID not provided")
}
return nil
}

func (d *controllerService) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
klog.V(4).Infof("ControllerPublishVolume: called with args %+v", *req)
volumeID := req.GetVolumeId()
if len(volumeID) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID not provided")
if err := validateControllerPublishVolumeRequest(req); err != nil {
return nil, err
}

volumeID := req.GetVolumeId()
nodeID := req.GetNodeId()
if len(nodeID) == 0 {
return nil, status.Error(codes.InvalidArgument, "Node ID not provided")
}

volCap := req.GetVolumeCapability()
if volCap == nil {
return nil, status.Error(codes.InvalidArgument, "Volume capability not provided")
}

caps := []*csi.VolumeCapability{volCap}
if !isValidVolumeCapabilities(caps) {
modes := util.GetAccessModes(caps)
stringModes := strings.Join(*modes, ", ")
errString := "Volume capabilities " + stringModes + " not supported. Only AccessModes[ReadWriteOnce] supported."
return nil, status.Error(codes.InvalidArgument, errString)
}

if !d.cloud.IsExistInstance(ctx, nodeID) {
return nil, status.Errorf(codes.NotFound, "Instance %q not found", nodeID)
Expand Down Expand Up @@ -333,17 +339,38 @@ func (d *controllerService) ControllerPublishVolume(ctx context.Context, req *cs
return &csi.ControllerPublishVolumeResponse{PublishContext: pvInfo}, nil
}

func validateControllerPublishVolumeRequest(req *csi.ControllerPublishVolumeRequest) error {
if len(req.GetVolumeId()) == 0 {
return status.Error(codes.InvalidArgument, "Volume ID not provided")
}

if len(req.GetNodeId()) == 0 {
return status.Error(codes.InvalidArgument, "Node ID not provided")
}

volCap := req.GetVolumeCapability()
if volCap == nil {
return status.Error(codes.InvalidArgument, "Volume capability not provided")
}

caps := []*csi.VolumeCapability{volCap}
if !isValidVolumeCapabilities(caps) {
modes := util.GetAccessModes(caps)
stringModes := strings.Join(*modes, ", ")
errString := "Volume capabilities " + stringModes + " not supported. Only AccessModes[ReadWriteOnce] supported."
return status.Error(codes.InvalidArgument, errString)
}
return nil
}

func (d *controllerService) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
klog.V(4).Infof("ControllerUnpublishVolume: called with args %+v", *req)
volumeID := req.GetVolumeId()
if len(volumeID) == 0 {
return nil, status.Error(codes.InvalidArgument, "Volume ID not provided")
if err := validateControllerUnpublishVolumeRequest(req); err != nil {
return nil, err
}

volumeID := req.GetVolumeId()
nodeID := req.GetNodeId()
if len(nodeID) == 0 {
return nil, status.Error(codes.InvalidArgument, "Node ID not provided")
}

if err := d.cloud.DetachDisk(ctx, volumeID, nodeID); err != nil {
if err == cloud.ErrNotFound {
Expand All @@ -356,6 +383,18 @@ func (d *controllerService) ControllerUnpublishVolume(ctx context.Context, req *
return &csi.ControllerUnpublishVolumeResponse{}, nil
}

func validateControllerUnpublishVolumeRequest(req *csi.ControllerUnpublishVolumeRequest) error {
if len(req.GetVolumeId()) == 0 {
return status.Error(codes.InvalidArgument, "Volume ID not provided")
}

if len(req.GetNodeId()) == 0 {
return status.Error(codes.InvalidArgument, "Node ID not provided")
}

return nil
}

func (d *controllerService) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
klog.V(4).Infof("ControllerGetCapabilities: called with args %+v", *req)
var caps []*csi.ControllerServiceCapability
Expand Down Expand Up @@ -489,15 +528,20 @@ func isValidVolumeContext(volContext map[string]string) bool {

func (d *controllerService) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) {
klog.V(4).Infof("CreateSnapshot: called with args %+v", req)
snapshotName := req.GetName()
if len(snapshotName) == 0 {
return nil, status.Error(codes.InvalidArgument, "Snapshot name not provided")
if err := validateCreateSnapshotRequest(req); err != nil {
return nil, err
}

snapshotName := req.GetName()
volumeID := req.GetSourceVolumeId()
if len(volumeID) == 0 {
return nil, status.Error(codes.InvalidArgument, "Snapshot volume source ID not provided")

// check if a request is already in-flight
if ok := d.inFlight.Insert(snapshotName); !ok {
msg := fmt.Sprintf(internal.VolumeOperationAlreadyExistsErrorMsg, snapshotName)
return nil, status.Error(codes.Aborted, msg)
}
defer d.inFlight.Delete(snapshotName)

snapshot, err := d.cloud.GetSnapshotByName(ctx, snapshotName)
if err != nil && err != cloud.ErrNotFound {
klog.Errorf("Error looking for the snapshot %s: %v", snapshotName, err)
Expand Down Expand Up @@ -535,12 +579,31 @@ func (d *controllerService) CreateSnapshot(ctx context.Context, req *csi.CreateS
return newCreateSnapshotResponse(snapshot)
}

func validateCreateSnapshotRequest(req *csi.CreateSnapshotRequest) error {
if len(req.GetName()) == 0 {
return status.Error(codes.InvalidArgument, "Snapshot name not provided")
}

if len(req.GetSourceVolumeId()) == 0 {
return status.Error(codes.InvalidArgument, "Snapshot volume source ID not provided")
}
return nil
}

func (d *controllerService) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) {
klog.V(4).Infof("DeleteSnapshot: called with args %+v", req)
if err := validateDeleteSnapshotRequest(req); err != nil {
return nil, err
}

snapshotID := req.GetSnapshotId()
if len(snapshotID) == 0 {
return nil, status.Error(codes.InvalidArgument, "Snapshot ID not provided")

// check if a request is already in-flight
if ok := d.inFlight.Insert(snapshotID); !ok {
msg := fmt.Sprintf("DeleteSnapshot for Snapshot %s is already in progress", snapshotID)
return nil, status.Error(codes.Aborted, msg)
}
defer d.inFlight.Delete(snapshotID)

if _, err := d.cloud.DeleteSnapshot(ctx, snapshotID); err != nil {
if err == cloud.ErrNotFound {
Expand All @@ -553,6 +616,13 @@ func (d *controllerService) DeleteSnapshot(ctx context.Context, req *csi.DeleteS
return &csi.DeleteSnapshotResponse{}, nil
}

func validateDeleteSnapshotRequest(req *csi.DeleteSnapshotRequest) error {
if len(req.GetSnapshotId()) == 0 {
return status.Error(codes.InvalidArgument, "Snapshot ID not provided")
}
return nil
}

func (d *controllerService) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) {
klog.V(4).Infof("ListSnapshots: called with args %+v", req)
var snapshots []*cloud.Snapshot
Expand Down
Loading

0 comments on commit 8e3b578

Please sign in to comment.