diff --git a/cmd/main.go b/cmd/main.go index f46ee60b0f..3fbf3c84ae 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -20,7 +20,6 @@ import ( "flag" "github.com/golang/glog" - "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver" ) @@ -28,12 +27,10 @@ func main() { var endpoint = flag.String("endpoint", "unix://tmp/csi.sock", "CSI Endpoint") flag.Parse() - cloud, err := cloud.NewCloud() + drv, err := driver.NewDriver(*endpoint) if err != nil { glog.Fatalln(err) } - - drv := driver.NewDriver(cloud, nil, *endpoint) if err := drv.Run(); err != nil { glog.Fatalln(err) } diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 05dc451ff5..f697842371 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -28,6 +28,21 @@ import ( "google.golang.org/grpc/status" ) +var ( + // EBS volume capability. It could only be read/write on a single node at any given time. + volumeCaps = []csi.VolumeCapability_AccessMode{ + { + Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, + }, + } + + // controller capability + controllerCaps = []csi.ControllerServiceCapability_RPC_Type{ + csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, + csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, + } +) + func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { glog.V(4).Infof("CreateVolume: called with args %+v", *req) volName := req.GetName() @@ -200,7 +215,7 @@ func (d *Driver) ControllerUnpublishVolume(ctx context.Context, req *csi.Control func (d *Driver) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) { glog.V(4).Infof("ControllerGetCapabilities: called with args %+v", *req) var caps []*csi.ControllerServiceCapability - for _, cap := range d.controllerCaps { + for _, cap := range controllerCaps { c := &csi.ControllerServiceCapability{ Type: &csi.ControllerServiceCapability_Rpc{ Rpc: &csi.ControllerServiceCapability_RPC{ @@ -253,7 +268,7 @@ func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.Valida func (d *Driver) isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool { hasSupport := func(cap *csi.VolumeCapability) bool { - for _, c := range d.volumeCaps { + for _, c := range volumeCaps { if c.GetMode() == cap.AccessMode.GetMode() { return true } diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 52e94e3390..82fd03b6f1 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -219,7 +219,12 @@ func TestCreateVolume(t *testing.T) { for _, tc := range testCases { t.Logf("Test case: %s", tc.name) - awsDriver := NewDriver(cloud.NewFakeCloudProvider(), NewFakeMounter(), "") + awsDriver := &Driver{ + endpoint: "", + nodeID: "", + cloud: cloud.NewFakeCloudProvider(), + mounter: NewFakeMounter(), + } resp, err := awsDriver.CreateVolume(context.TODO(), tc.req) if err != nil { @@ -298,7 +303,12 @@ func TestDeleteVolume(t *testing.T) { for _, tc := range testCases { t.Logf("Test case: %s", tc.name) - awsDriver := NewDriver(cloud.NewFakeCloudProvider(), NewFakeMounter(), "") + awsDriver := &Driver{ + endpoint: "", + nodeID: "", + cloud: cloud.NewFakeCloudProvider(), + mounter: NewFakeMounter(), + } _, err := awsDriver.DeleteVolume(context.TODO(), tc.req) if err != nil { srvErr, ok := status.FromError(err) diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 8eff745fa6..a14c7f4f7c 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -46,35 +46,33 @@ type Driver struct { srv *grpc.Server mounter *mount.SafeFormatAndMount - - volumeCaps []csi.VolumeCapability_AccessMode - controllerCaps []csi.ControllerServiceCapability_RPC_Type - nodeCaps []csi.NodeServiceCapability_RPC_Type } -func NewDriver(cloud cloud.Cloud, mounter *mount.SafeFormatAndMount, endpoint string) *Driver { - glog.Infof("Driver: %v", driverName) - if mounter == nil { - mounter = newSafeMounter() +func NewDriver(endpoint string) (*Driver, error) { + glog.Infof("Driver: %v Version: %v", driverName, vendorVersion) + + cloud, err := cloud.NewCloud() + if err != nil { + return nil, err } + m := cloud.GetMetadata() return &Driver{ endpoint: endpoint, nodeID: m.GetInstanceID(), cloud: cloud, - mounter: mounter, - volumeCaps: []csi.VolumeCapability_AccessMode{ - { - Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER, - }, - }, - controllerCaps: []csi.ControllerServiceCapability_RPC_Type{ - csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME, - csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME, - }, - nodeCaps: []csi.NodeServiceCapability_RPC_Type{ - csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME, - }, + mounter: newSafeMounter(), + }, nil +} + +// NewMockDriver creates a new mock driver used for testing +func NewMockDriver(endpoint string) *Driver { + cloud := cloud.NewFakeCloudProvider() + return &Driver{ + endpoint: endpoint, + nodeID: cloud.GetMetadata().GetInstanceID(), + cloud: cloud, + mounter: NewFakeMounter(), } } diff --git a/pkg/driver/node.go b/pkg/driver/node.go index 7bec2c145b..952c0c1dcc 100644 --- a/pkg/driver/node.go +++ b/pkg/driver/node.go @@ -32,6 +32,13 @@ const ( defaultFsType = "ext4" ) +var ( + // node capability + nodeCaps = []csi.NodeServiceCapability_RPC_Type{ + csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME, + } +) + func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) { glog.V(4).Infof("NodeStageVolume: called with args %+v", *req) volumeID := req.GetVolumeId() @@ -189,7 +196,7 @@ func (d *Driver) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeS func (d *Driver) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) { glog.V(4).Infof("NodeGetCapabilities: called with args %+v", *req) var caps []*csi.NodeServiceCapability - for _, cap := range d.nodeCaps { + for _, cap := range nodeCaps { c := &csi.NodeServiceCapability{ Type: &csi.NodeServiceCapability_Rpc{ Rpc: &csi.NodeServiceCapability_RPC{ diff --git a/tests/sanity/sanity_test.go b/tests/sanity/sanity_test.go index a7f3fc1ea2..4cccd5ac9d 100644 --- a/tests/sanity/sanity_test.go +++ b/tests/sanity/sanity_test.go @@ -25,7 +25,6 @@ import ( sanity "github.com/kubernetes-csi/csi-test/pkg/sanity" - "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver" ) @@ -44,7 +43,7 @@ func TestSanity(t *testing.T) { } var _ = BeforeSuite(func() { - ebsDriver = driver.NewDriver(cloud.NewFakeCloudProvider(), driver.NewFakeMounter(), endpoint) + ebsDriver = driver.NewMockDriver(endpoint) go func() { Expect(ebsDriver.Run()).NotTo(HaveOccurred()) }()