diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index bd899caffd..c6ea7afd62 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -19,6 +19,8 @@ package cloud import ( "errors" "fmt" + "math/rand" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" @@ -78,16 +80,26 @@ var ( ErrAlreadyExists = errors.New("Resource already exists") ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + +// Disk represents a EBS volume type Disk struct { - VolumeID string - CapacityGiB int64 + VolumeID string + CapacityGiB int64 + AvailabilityZone string } +// DiskOptions represents parameters to create an EBS volume type DiskOptions struct { CapacityBytes int64 Tags map[string]string VolumeType string IOPSPerGB int64 + // the availability zone to create volume in + // if nil a random zone will be used + AvailabilityZone *string } // EC2 abstracts aws.EC2 to facilitate its mocking. @@ -98,6 +110,9 @@ type EC2 interface { DetachVolume(input *ec2.DetachVolumeInput) (*ec2.VolumeAttachment, error) AttachVolume(input *ec2.AttachVolumeInput) (*ec2.VolumeAttachment, error) DescribeInstances(input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) + + // Get all the zones for current region + DescribeAvailabilityZones(input *ec2.DescribeAvailabilityZonesInput) (*ec2.DescribeAvailabilityZonesOutput, error) } type Cloud interface { @@ -156,8 +171,10 @@ func (c *cloud) GetMetadata() MetadataService { } func (c *cloud) CreateDisk(volumeName string, diskOptions *DiskOptions) (*Disk, error) { - var createType string - var iops int64 + var ( + createType string + iops int64 + ) capacityGiB := util.BytesToGiB(diskOptions.CapacityBytes) switch diskOptions.VolumeType { @@ -187,9 +204,21 @@ func (c *cloud) CreateDisk(volumeName string, diskOptions *DiskOptions) (*Disk, Tags: tags, } - m := c.GetMetadata() + var ( + zone string + err error + ) + if diskOptions.AvailabilityZone == nil { + zone, err = c.pickRandomAvailabilityZone() + if err != nil { + return nil, err + } + } else { + zone = *diskOptions.AvailabilityZone + } + request := &ec2.CreateVolumeInput{ - AvailabilityZone: aws.String(m.GetAvailabilityZone()), + AvailabilityZone: aws.String(zone), Size: aws.Int64(capacityGiB), VolumeType: aws.String(createType), TagSpecifications: []*ec2.TagSpecification{&tagSpec}, @@ -213,7 +242,7 @@ func (c *cloud) CreateDisk(volumeName string, diskOptions *DiskOptions) (*Disk, return nil, fmt.Errorf("disk size was not returned by CreateVolume") } - return &Disk{CapacityGiB: size, VolumeID: volumeID}, nil + return &Disk{CapacityGiB: size, VolumeID: volumeID, AvailabilityZone: zone}, nil } func (c *cloud) DeleteDisk(volumeID string) (bool, error) { @@ -433,3 +462,17 @@ func (c *cloud) getInstance(nodeID string) (*ec2.Instance, error) { return instances[0], nil } + +func (c *cloud) pickRandomAvailabilityZone() (string, error) { + output, err := c.ec2.DescribeAvailabilityZones(&ec2.DescribeAvailabilityZonesInput{}) + if err != nil { + return "", err + } + + var zones []string + for _, zone := range output.AvailabilityZones { + zones = append(zones, *zone.ZoneName) + } + + return zones[rand.Int()%len(zones)], nil +} diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go index fd7e53b71f..ca34de162e 100644 --- a/pkg/cloud/cloud_test.go +++ b/pkg/cloud/cloud_test.go @@ -42,8 +42,23 @@ func TestCreateDisk(t *testing.T) { name: "success: normal", volumeName: "vol-test-name", diskOptions: &DiskOptions{ - CapacityBytes: util.GiBToBytes(1), - Tags: map[string]string{VolumeNameTagKey: "vol-test"}, + CapacityBytes: util.GiBToBytes(1), + Tags: map[string]string{VolumeNameTagKey: "vol-test"}, + AvailabilityZone: nil, + }, + expDisk: &Disk{ + VolumeID: "vol-test", + CapacityGiB: 1, + }, + expErr: nil, + }, + { + name: "success: normal with provided zone", + volumeName: "vol-test-name", + diskOptions: &DiskOptions{ + CapacityBytes: util.GiBToBytes(1), + Tags: map[string]string{VolumeNameTagKey: "vol-test"}, + AvailabilityZone: stringPtr("us-west-2"), }, expDisk: &Disk{ VolumeID: "vol-test", @@ -55,8 +70,9 @@ func TestCreateDisk(t *testing.T) { name: "fail: CreateVolume returned an error", volumeName: "vol-test-name-error", diskOptions: &DiskOptions{ - CapacityBytes: util.GiBToBytes(1), - Tags: map[string]string{VolumeNameTagKey: "vol-test"}, + CapacityBytes: util.GiBToBytes(1), + Tags: map[string]string{VolumeNameTagKey: "vol-test"}, + AvailabilityZone: nil, }, expErr: fmt.Errorf("CreateVolume generic error"), }, @@ -78,6 +94,24 @@ func TestCreateDisk(t *testing.T) { mockEC2.EXPECT().CreateVolume(gomock.Any()).Return(vol, tc.expErr) + if tc.diskOptions.AvailabilityZone == nil { + describeAvailabilityZonesResp := &ec2.DescribeAvailabilityZonesOutput{ + AvailabilityZones: []*ec2.AvailabilityZone{ + &ec2.AvailabilityZone{ + ZoneName: aws.String("us-west-2a"), + }, + &ec2.AvailabilityZone{ + ZoneName: aws.String("us-west-2b"), + }, + &ec2.AvailabilityZone{ + ZoneName: aws.String("us-west-2c"), + }, + }, + } + + mockEC2.EXPECT().DescribeAvailabilityZones(gomock.Any()).Return(describeAvailabilityZonesResp, nil) + } + disk, err := c.CreateDisk(tc.volumeName, tc.diskOptions) if err != nil { if tc.expErr == nil { @@ -369,3 +403,7 @@ func newDescribeInstancesOutput(nodeID string) *ec2.DescribeInstancesOutput { }}, } } + +func stringPtr(str string) *string { + return &str +} diff --git a/pkg/cloud/mocks/mock_ec2.go b/pkg/cloud/mocks/mock_ec2.go index 4559ba77cd..4ba61f5a1f 100644 --- a/pkg/cloud/mocks/mock_ec2.go +++ b/pkg/cloud/mocks/mock_ec2.go @@ -5,10 +5,9 @@ package mocks import ( - reflect "reflect" - ec2 "github.com/aws/aws-sdk-go/service/ec2" gomock "github.com/golang/mock/gomock" + reflect "reflect" ) // MockEC2 is a mock of EC2 interface @@ -73,6 +72,19 @@ func (mr *MockEC2MockRecorder) DeleteVolume(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteVolume", reflect.TypeOf((*MockEC2)(nil).DeleteVolume), arg0) } +// DescribeAvailabilityZones mocks base method +func (m *MockEC2) DescribeAvailabilityZones(arg0 *ec2.DescribeAvailabilityZonesInput) (*ec2.DescribeAvailabilityZonesOutput, error) { + ret := m.ctrl.Call(m, "DescribeAvailabilityZones", arg0) + ret0, _ := ret[0].(*ec2.DescribeAvailabilityZonesOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DescribeAvailabilityZones indicates an expected call of DescribeAvailabilityZones +func (mr *MockEC2MockRecorder) DescribeAvailabilityZones(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeAvailabilityZones", reflect.TypeOf((*MockEC2)(nil).DescribeAvailabilityZones), arg0) +} + // DescribeInstances mocks base method func (m *MockEC2) DescribeInstances(arg0 *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { ret := m.ctrl.Call(m, "DescribeInstances", arg0) diff --git a/pkg/cloud/mocks/mock_ec2metadata.go b/pkg/cloud/mocks/mock_ec2metadata.go index 4e2e57353e..f54a7098cb 100644 --- a/pkg/cloud/mocks/mock_ec2metadata.go +++ b/pkg/cloud/mocks/mock_ec2metadata.go @@ -5,10 +5,9 @@ package mocks import ( - reflect "reflect" - ec2metadata "github.com/aws/aws-sdk-go/aws/ec2metadata" gomock "github.com/golang/mock/gomock" + reflect "reflect" ) // MockEC2Metadata is a mock of EC2Metadata interface diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 4fdcfb4bbd..9920c23f54 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -68,24 +68,23 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) } } - if disk == nil { - opts := &cloud.DiskOptions{ - CapacityBytes: volSizeBytes, - Tags: map[string]string{cloud.VolumeNameTagKey: volName}, - } - newDisk, err := d.cloud.CreateDisk(volName, opts) - if err != nil { - return nil, status.Errorf(codes.Internal, "Could not create volume %q: %v", volName, err) - } - disk = newDisk + // volume exists already + if disk != nil { + return newCreateVolumeResponse(disk), nil } - return &csi.CreateVolumeResponse{ - Volume: &csi.Volume{ - Id: disk.VolumeID, - CapacityBytes: util.GiBToBytes(disk.CapacityGiB), - }, - }, nil + // create a new volume + zone := pickAvailabilityZone(req.GetAccessibilityRequirements()) + opts := &cloud.DiskOptions{ + CapacityBytes: volSizeBytes, + AvailabilityZone: zone, + Tags: map[string]string{cloud.VolumeNameTagKey: volName}, + } + disk, err = d.cloud.CreateDisk(volName, opts) + if err != nil { + return nil, status.Errorf(codes.Internal, "Could not create volume %q: %v", volName, err) + } + return newCreateVolumeResponse(disk), nil } func (d *Driver) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { @@ -253,3 +252,37 @@ func (d *Driver) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequ func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { return nil, status.Error(codes.Unimplemented, "") } + +// pickAvailabilityZone selects 1 zone given topology requirement. +func pickAvailabilityZone(requirement *csi.TopologyRequirement) *string { + if requirement == nil { + return nil + } + for _, topology := range requirement.GetPreferred() { + zone, exists := topology.GetSegments()[topologyKey] + if exists { + return &zone + } + } + for _, topology := range requirement.GetRequisite() { + zone, exists := topology.GetSegments()[topologyKey] + if exists { + return &zone + } + } + return nil +} + +func newCreateVolumeResponse(disk *cloud.Disk) *csi.CreateVolumeResponse { + return &csi.CreateVolumeResponse{ + Volume: &csi.Volume{ + Id: disk.VolumeID, + CapacityBytes: util.GiBToBytes(disk.CapacityGiB), + AccessibleTopology: []*csi.Topology{ + &csi.Topology{ + Segments: map[string]string{topologyKey: disk.AvailabilityZone}, + }, + }, + }, + } +} diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 0fd887a486..deab7a805e 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -148,7 +148,7 @@ func TestCreateVolume(t *testing.T) { t.Fatalf("Could not get error status code from error: %v", srvErr) } if srvErr.Code() != tc.expErrCode { - t.Fatalf("Expected error code %d, got %d", tc.expErrCode, srvErr.Code()) + t.Fatalf("Expected error code %d, got %d message %s", tc.expErrCode, srvErr.Code(), srvErr.Message()) } continue } @@ -235,3 +235,75 @@ func TestDeleteVolume(t *testing.T) { } } } + +func TestPickAvailabilityZone(t *testing.T) { + expZone := "us-west-2b" + testCases := []struct { + name string + requirement *csi.TopologyRequirement + expZone *string + }{ + { + name: "Pick from preferred", + requirement: &csi.TopologyRequirement{ + Requisite: []*csi.Topology{ + &csi.Topology{ + Segments: map[string]string{topologyKey: expZone}, + }, + }, + Preferred: []*csi.Topology{ + &csi.Topology{ + Segments: map[string]string{topologyKey: expZone}, + }, + }, + }, + expZone: stringPtr(expZone), + }, + { + name: "Pick from requisite", + requirement: &csi.TopologyRequirement{ + Requisite: []*csi.Topology{ + &csi.Topology{ + Segments: map[string]string{topologyKey: expZone}, + }, + }, + }, + expZone: stringPtr(expZone), + }, + { + name: "Pick from empty topology", + requirement: &csi.TopologyRequirement{ + Preferred: []*csi.Topology{&csi.Topology{}}, + Requisite: []*csi.Topology{&csi.Topology{}}, + }, + expZone: nil, + }, + + { + name: "Topology Requirement is nil", + requirement: nil, + expZone: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := pickAvailabilityZone(tc.requirement) + if tc.expZone == nil { + if actual != nil { + t.Fatalf("Expected zone to be nil, got %v", actual) + } + } else { + if *actual != *tc.expZone { + t.Fatalf("Expected zone %v, got zone: %v", tc.expZone, actual) + + } + } + }) + } + +} + +func stringPtr(str string) *string { + return &str +} diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 4098283b83..2c99b6ae56 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -31,6 +31,7 @@ import ( const ( driverName = "com.amazon.aws.csi.ebs" vendorVersion = "0.0.1" // FIXME + topologyKey = driverName + "/zone" ) type Driver struct { diff --git a/pkg/driver/node.go b/pkg/driver/node.go index 0ef23b8961..c02eef3920 100644 --- a/pkg/driver/node.go +++ b/pkg/driver/node.go @@ -190,8 +190,14 @@ func (d *Driver) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabi func (d *Driver) NodeGetInfo(ctx context.Context, req *csi.NodeGetInfoRequest) (*csi.NodeGetInfoResponse, error) { glog.V(4).Infof("NodeGetInfo: called with args %#v", req) m := d.cloud.GetMetadata() + + topology := &csi.Topology{ + Segments: map[string]string{topologyKey: m.GetAvailabilityZone()}, + } + return &csi.NodeGetInfoResponse{ - NodeId: m.GetInstanceID(), + NodeId: m.GetInstanceID(), + AccessibleTopology: topology, }, nil }