Skip to content

Commit

Permalink
Refactor the code
Browse files Browse the repository at this point in the history
1. Split `NewDriver` into two methods `NewDriver` and `NewFakeDriver`.
So that one only takes in endpoint and serves as production uses and
the other create driver with mock dependencies
2. Move controller capability, node capability and volume capability out
of driver struct since they are constant that should never change for a
specific version of driver
  • Loading branch information
Cheng Pan committed Dec 6, 2018
1 parent 9a17b67 commit 1a65009
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 37 deletions.
5 changes: 1 addition & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,17 @@ import (
"flag"

"github.com/golang/glog"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver"
)

func main() {
var endpoint = flag.String("endpoint", "unix://tmp/csi.sock", "CSI Endpoint")
flag.Parse()

cloud, err := cloud.NewCloud()
drv, err := driver.NewDriver(*endpoint)
if err != nil {
glog.Fatalln(err)
}

drv := driver.NewDriver(cloud, nil, *endpoint)
if err := drv.Run(); err != nil {
glog.Fatalln(err)
}
Expand Down
21 changes: 19 additions & 2 deletions pkg/driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,23 @@ import (
"google.golang.org/grpc/status"
)

var (
// volumeCaps represents how the volume could be accessed.
// It is SINGLE_NODE_WRITER since EBS volume could only be
// attached to a single node at any given time.
volumeCaps = []csi.VolumeCapability_AccessMode{
{
Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
},
}

// controllerCaps represents the capability of controller service
controllerCaps = []csi.ControllerServiceCapability_RPC_Type{
csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME,
csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME,
}
)

func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
glog.V(4).Infof("CreateVolume: called with args %+v", *req)
volName := req.GetName()
Expand Down Expand Up @@ -200,7 +217,7 @@ func (d *Driver) ControllerUnpublishVolume(ctx context.Context, req *csi.Control
func (d *Driver) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
glog.V(4).Infof("ControllerGetCapabilities: called with args %+v", *req)
var caps []*csi.ControllerServiceCapability
for _, cap := range d.controllerCaps {
for _, cap := range controllerCaps {
c := &csi.ControllerServiceCapability{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Expand Down Expand Up @@ -253,7 +270,7 @@ func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.Valida

func (d *Driver) isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool {
hasSupport := func(cap *csi.VolumeCapability) bool {
for _, c := range d.volumeCaps {
for _, c := range volumeCaps {
if c.GetMode() == cap.AccessMode.GetMode() {
return true
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/driver/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func TestCreateVolume(t *testing.T) {

for _, tc := range testCases {
t.Logf("Test case: %s", tc.name)
awsDriver := NewDriver(cloud.NewFakeCloudProvider(), NewFakeMounter(), "")
awsDriver := NewFakeDriver("")

resp, err := awsDriver.CreateVolume(context.TODO(), tc.req)
if err != nil {
Expand Down Expand Up @@ -298,7 +298,7 @@ func TestDeleteVolume(t *testing.T) {

for _, tc := range testCases {
t.Logf("Test case: %s", tc.name)
awsDriver := NewDriver(cloud.NewFakeCloudProvider(), NewFakeMounter(), "")
awsDriver := NewFakeDriver("")
_, err := awsDriver.DeleteVolume(context.TODO(), tc.req)
if err != nil {
srvErr, ok := status.FromError(err)
Expand Down
31 changes: 9 additions & 22 deletions pkg/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,36 +46,23 @@ type Driver struct {
srv *grpc.Server

mounter *mount.SafeFormatAndMount

volumeCaps []csi.VolumeCapability_AccessMode
controllerCaps []csi.ControllerServiceCapability_RPC_Type
nodeCaps []csi.NodeServiceCapability_RPC_Type
}

func NewDriver(cloud cloud.Cloud, mounter *mount.SafeFormatAndMount, endpoint string) *Driver {
glog.Infof("Driver: %v", driverName)
if mounter == nil {
mounter = newSafeMounter()
func NewDriver(endpoint string) (*Driver, error) {
glog.Infof("Driver: %v Version: %v", driverName, vendorVersion)

cloud, err := cloud.NewCloud()
if err != nil {
return nil, err
}

m := cloud.GetMetadata()
return &Driver{
endpoint: endpoint,
nodeID: m.GetInstanceID(),
cloud: cloud,
mounter: mounter,
volumeCaps: []csi.VolumeCapability_AccessMode{
{
Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
},
},
controllerCaps: []csi.ControllerServiceCapability_RPC_Type{
csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME,
csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME,
},
nodeCaps: []csi.NodeServiceCapability_RPC_Type{
csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME,
},
}
mounter: newSafeMounter(),
}, nil
}

func (d *Driver) Run() error {
Expand Down
16 changes: 15 additions & 1 deletion pkg/driver/fakes.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ limitations under the License.

package driver

import "k8s.io/kubernetes/pkg/util/mount"
import (
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
"k8s.io/kubernetes/pkg/util/mount"
)

func NewFakeMounter() *mount.SafeFormatAndMount {
return &mount.SafeFormatAndMount{
Expand All @@ -28,3 +31,14 @@ func NewFakeMounter() *mount.SafeFormatAndMount {
}

}

// NewFakeDriver creates a new mock driver used for testing
func NewFakeDriver(endpoint string) *Driver {
cloud := cloud.NewFakeCloudProvider()
return &Driver{
endpoint: endpoint,
nodeID: cloud.GetMetadata().GetInstanceID(),
cloud: cloud,
mounter: NewFakeMounter(),
}
}
9 changes: 8 additions & 1 deletion pkg/driver/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ const (
defaultFsType = "ext4"
)

var (
// nodeCaps represents the capability of node service.
nodeCaps = []csi.NodeServiceCapability_RPC_Type{
csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME,
}
)

func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) {
glog.V(4).Infof("NodeStageVolume: called with args %+v", *req)
volumeID := req.GetVolumeId()
Expand Down Expand Up @@ -189,7 +196,7 @@ func (d *Driver) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeS
func (d *Driver) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) {
glog.V(4).Infof("NodeGetCapabilities: called with args %+v", *req)
var caps []*csi.NodeServiceCapability
for _, cap := range d.nodeCaps {
for _, cap := range nodeCaps {
c := &csi.NodeServiceCapability{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Expand Down
5 changes: 2 additions & 3 deletions tests/integration/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ func TestIntegration(t *testing.T) {
var _ = BeforeSuite(func() {
// Run CSI Driver in its own goroutine
var err error
ebs, err = cloud.NewCloud()
Expect(err).To(BeNil(), "Set up Cloud client failed with error")
drv = driver.NewDriver(ebs, nil, endpoint)
drv, err = driver.NewDriver(endpoint)
Expect(err).To(BeNil())
go func() {
err := drv.Run()
Expect(err).To(BeNil())
Expand Down
3 changes: 1 addition & 2 deletions tests/sanity/sanity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (

sanity "github.com/kubernetes-csi/csi-test/pkg/sanity"

"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver"
)

Expand All @@ -44,7 +43,7 @@ func TestSanity(t *testing.T) {
}

var _ = BeforeSuite(func() {
ebsDriver = driver.NewDriver(cloud.NewFakeCloudProvider(), driver.NewFakeMounter(), endpoint)
ebsDriver = driver.NewFakeDriver(endpoint)
go func() {
Expect(ebsDriver.Run()).NotTo(HaveOccurred())
}()
Expand Down

0 comments on commit 1a65009

Please sign in to comment.