Skip to content

Commit

Permalink
Refactor the code
Browse files Browse the repository at this point in the history
1. Split `NewDriver` into two methods `NewDriver` and `NewMockDriver`.
So that one only takes in endpoint and serves as production uses and
the other create driver with mock dependencies
2. Move controller capability, node capability and volume capability out
of driver struct since they are constant that should never change for a
specific version of driver
  • Loading branch information
Cheng Pan committed Dec 3, 2018
1 parent 9a17b67 commit 714a580
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 32 deletions.
5 changes: 1 addition & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,17 @@ import (
"flag"

"github.com/golang/glog"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver"
)

func main() {
var endpoint = flag.String("endpoint", "unix://tmp/csi.sock", "CSI Endpoint")
flag.Parse()

cloud, err := cloud.NewCloud()
drv, err := driver.NewDriver(*endpoint)
if err != nil {
glog.Fatalln(err)
}

drv := driver.NewDriver(cloud, nil, *endpoint)
if err := drv.Run(); err != nil {
glog.Fatalln(err)
}
Expand Down
19 changes: 17 additions & 2 deletions pkg/driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,21 @@ import (
"google.golang.org/grpc/status"
)

var (
// EBS volume capability. It could only be read/write on a single node at any given time.
volumeCaps = []csi.VolumeCapability_AccessMode{
{
Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
},
}

// controller capability
controllerCaps = []csi.ControllerServiceCapability_RPC_Type{
csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME,
csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME,
}
)

func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
glog.V(4).Infof("CreateVolume: called with args %+v", *req)
volName := req.GetName()
Expand Down Expand Up @@ -200,7 +215,7 @@ func (d *Driver) ControllerUnpublishVolume(ctx context.Context, req *csi.Control
func (d *Driver) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
glog.V(4).Infof("ControllerGetCapabilities: called with args %+v", *req)
var caps []*csi.ControllerServiceCapability
for _, cap := range d.controllerCaps {
for _, cap := range controllerCaps {
c := &csi.ControllerServiceCapability{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Expand Down Expand Up @@ -253,7 +268,7 @@ func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.Valida

func (d *Driver) isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool {
hasSupport := func(cap *csi.VolumeCapability) bool {
for _, c := range d.volumeCaps {
for _, c := range volumeCaps {
if c.GetMode() == cap.AccessMode.GetMode() {
return true
}
Expand Down
14 changes: 12 additions & 2 deletions pkg/driver/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,12 @@ func TestCreateVolume(t *testing.T) {

for _, tc := range testCases {
t.Logf("Test case: %s", tc.name)
awsDriver := NewDriver(cloud.NewFakeCloudProvider(), NewFakeMounter(), "")
awsDriver := &Driver{
endpoint: "",
nodeID: "",
cloud: cloud.NewFakeCloudProvider(),
mounter: NewFakeMounter(),
}

resp, err := awsDriver.CreateVolume(context.TODO(), tc.req)
if err != nil {
Expand Down Expand Up @@ -298,7 +303,12 @@ func TestDeleteVolume(t *testing.T) {

for _, tc := range testCases {
t.Logf("Test case: %s", tc.name)
awsDriver := NewDriver(cloud.NewFakeCloudProvider(), NewFakeMounter(), "")
awsDriver := &Driver{
endpoint: "",
nodeID: "",
cloud: cloud.NewFakeCloudProvider(),
mounter: NewFakeMounter(),
}
_, err := awsDriver.DeleteVolume(context.TODO(), tc.req)
if err != nil {
srvErr, ok := status.FromError(err)
Expand Down
40 changes: 19 additions & 21 deletions pkg/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,35 +46,33 @@ type Driver struct {
srv *grpc.Server

mounter *mount.SafeFormatAndMount

volumeCaps []csi.VolumeCapability_AccessMode
controllerCaps []csi.ControllerServiceCapability_RPC_Type
nodeCaps []csi.NodeServiceCapability_RPC_Type
}

func NewDriver(cloud cloud.Cloud, mounter *mount.SafeFormatAndMount, endpoint string) *Driver {
glog.Infof("Driver: %v", driverName)
if mounter == nil {
mounter = newSafeMounter()
func NewDriver(endpoint string) (*Driver, error) {
glog.Infof("Driver: %v Version: %v", driverName, vendorVersion)

cloud, err := cloud.NewCloud()
if err != nil {
return nil, err
}

m := cloud.GetMetadata()
return &Driver{
endpoint: endpoint,
nodeID: m.GetInstanceID(),
cloud: cloud,
mounter: mounter,
volumeCaps: []csi.VolumeCapability_AccessMode{
{
Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
},
},
controllerCaps: []csi.ControllerServiceCapability_RPC_Type{
csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME,
csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME,
},
nodeCaps: []csi.NodeServiceCapability_RPC_Type{
csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME,
},
mounter: newSafeMounter(),
}, nil
}

// NewMockDriver creates a new mock driver used for testing
func NewMockDriver(endpoint string) *Driver {
cloud := cloud.NewFakeCloudProvider()
return &Driver{
endpoint: endpoint,
nodeID: cloud.GetMetadata().GetInstanceID(),
cloud: cloud,
mounter: NewFakeMounter(),
}
}

Expand Down
9 changes: 8 additions & 1 deletion pkg/driver/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ const (
defaultFsType = "ext4"
)

var (
// node capability
nodeCaps = []csi.NodeServiceCapability_RPC_Type{
csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME,
}
)

func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) {
glog.V(4).Infof("NodeStageVolume: called with args %+v", *req)
volumeID := req.GetVolumeId()
Expand Down Expand Up @@ -189,7 +196,7 @@ func (d *Driver) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeS
func (d *Driver) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) {
glog.V(4).Infof("NodeGetCapabilities: called with args %+v", *req)
var caps []*csi.NodeServiceCapability
for _, cap := range d.nodeCaps {
for _, cap := range nodeCaps {
c := &csi.NodeServiceCapability{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Expand Down
3 changes: 1 addition & 2 deletions tests/sanity/sanity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (

sanity "github.com/kubernetes-csi/csi-test/pkg/sanity"

"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver"
)

Expand All @@ -44,7 +43,7 @@ func TestSanity(t *testing.T) {
}

var _ = BeforeSuite(func() {
ebsDriver = driver.NewDriver(cloud.NewFakeCloudProvider(), driver.NewFakeMounter(), endpoint)
ebsDriver = driver.NewMockDriver(endpoint)
go func() {
Expect(ebsDriver.Run()).NotTo(HaveOccurred())
}()
Expand Down

0 comments on commit 714a580

Please sign in to comment.