Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the code #135

Merged
merged 1 commit into from
Dec 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"os"

"github.com/golang/glog"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver"
)

Expand All @@ -42,12 +41,10 @@ func main() {
os.Exit(0)
}

cloud, err := cloud.NewCloud()
drv, err := driver.NewDriver(*endpoint)
if err != nil {
glog.Fatalln(err)
}

drv := driver.NewDriver(cloud, nil, *endpoint)
if err := drv.Run(); err != nil {
glog.Fatalln(err)
}
Expand Down
21 changes: 19 additions & 2 deletions pkg/driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,23 @@ import (
"google.golang.org/grpc/status"
)

var (
// volumeCaps represents how the volume could be accessed.
// It is SINGLE_NODE_WRITER since EBS volume could only be
// attached to a single node at any given time.
volumeCaps = []csi.VolumeCapability_AccessMode{
{
Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
},
}

// controllerCaps represents the capability of controller service
controllerCaps = []csi.ControllerServiceCapability_RPC_Type{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if making these fields global is a better practice. It'd make it difficult to create a driver with custom capabilities for testing, for example.

Another option would be to create functions to set up and return the caps, and use these function in New*Driver().

If we want to keep it like this, we may want to name these vars like defaultVolume... and use these names in the comments (e.g., defaultVolumeCapabilities represents the EBS volume....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if making these fields global is a better practice. It'd make it difficult to create a driver with custom capabilities for testing, for example.

I feel we will never create a CSI EBS driver with a different capability eg, volume capability of MULTI_NODE_MULTI_WRITER since that is limited by EBS; or we will create a driver whose controller service capability is configurable, eg CREATE_DELETE_VOLUME is configurable since we always want dynamic provisioning.

I kept variable name but updated the comments of those variable though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't feel like moving those vars to a global scope would make the code better, even though they might not be ever changed.

I agree that initializing them in NewDriver feels messy, but I'd rather create functions that return the default values and use those functions in NewDriver (if I'm not mistaken GCP driver does something like that).

However, I don't want to block this PR from being merged, so feel free to merge it if you think it's OK as is.

csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME,
csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME,
}
)

func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
glog.V(4).Infof("CreateVolume: called with args %+v", *req)
volName := req.GetName()
Expand Down Expand Up @@ -200,7 +217,7 @@ func (d *Driver) ControllerUnpublishVolume(ctx context.Context, req *csi.Control
func (d *Driver) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) {
glog.V(4).Infof("ControllerGetCapabilities: called with args %+v", *req)
var caps []*csi.ControllerServiceCapability
for _, cap := range d.controllerCaps {
for _, cap := range controllerCaps {
c := &csi.ControllerServiceCapability{
Type: &csi.ControllerServiceCapability_Rpc{
Rpc: &csi.ControllerServiceCapability_RPC{
Expand Down Expand Up @@ -253,7 +270,7 @@ func (d *Driver) ValidateVolumeCapabilities(ctx context.Context, req *csi.Valida

func (d *Driver) isValidVolumeCapabilities(volCaps []*csi.VolumeCapability) bool {
hasSupport := func(cap *csi.VolumeCapability) bool {
for _, c := range d.volumeCaps {
for _, c := range volumeCaps {
if c.GetMode() == cap.AccessMode.GetMode() {
return true
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/driver/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func TestCreateVolume(t *testing.T) {

for _, tc := range testCases {
t.Logf("Test case: %s", tc.name)
awsDriver := NewDriver(cloud.NewFakeCloudProvider(), NewFakeMounter(), "")
awsDriver := NewFakeDriver("")

resp, err := awsDriver.CreateVolume(context.TODO(), tc.req)
if err != nil {
Expand Down Expand Up @@ -298,7 +298,7 @@ func TestDeleteVolume(t *testing.T) {

for _, tc := range testCases {
t.Logf("Test case: %s", tc.name)
awsDriver := NewDriver(cloud.NewFakeCloudProvider(), NewFakeMounter(), "")
awsDriver := NewFakeDriver("")
_, err := awsDriver.DeleteVolume(context.TODO(), tc.req)
if err != nil {
srvErr, ok := status.FromError(err)
Expand Down
31 changes: 9 additions & 22 deletions pkg/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,36 +41,23 @@ type Driver struct {
srv *grpc.Server

mounter *mount.SafeFormatAndMount

volumeCaps []csi.VolumeCapability_AccessMode
controllerCaps []csi.ControllerServiceCapability_RPC_Type
nodeCaps []csi.NodeServiceCapability_RPC_Type
}

func NewDriver(cloud cloud.Cloud, mounter *mount.SafeFormatAndMount, endpoint string) *Driver {
glog.Infof("Driver: %v", driverName)
if mounter == nil {
mounter = newSafeMounter()
func NewDriver(endpoint string) (*Driver, error) {
glog.Infof("Driver: %v Version: %v", driverName, driverVersion)

cloud, err := cloud.NewCloud()
if err != nil {
return nil, err
}

m := cloud.GetMetadata()
return &Driver{
endpoint: endpoint,
nodeID: m.GetInstanceID(),
cloud: cloud,
mounter: mounter,
volumeCaps: []csi.VolumeCapability_AccessMode{
{
Mode: csi.VolumeCapability_AccessMode_SINGLE_NODE_WRITER,
},
},
controllerCaps: []csi.ControllerServiceCapability_RPC_Type{
csi.ControllerServiceCapability_RPC_CREATE_DELETE_VOLUME,
csi.ControllerServiceCapability_RPC_PUBLISH_UNPUBLISH_VOLUME,
},
nodeCaps: []csi.NodeServiceCapability_RPC_Type{
csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME,
},
}
mounter: newSafeMounter(),
}, nil
}

func (d *Driver) Run() error {
Expand Down
16 changes: 15 additions & 1 deletion pkg/driver/fakes.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ limitations under the License.

package driver

import "k8s.io/kubernetes/pkg/util/mount"
import (
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
"k8s.io/kubernetes/pkg/util/mount"
)

func NewFakeMounter() *mount.SafeFormatAndMount {
return &mount.SafeFormatAndMount{
Expand All @@ -28,3 +31,14 @@ func NewFakeMounter() *mount.SafeFormatAndMount {
}

}

// NewFakeDriver creates a new mock driver used for testing
func NewFakeDriver(endpoint string) *Driver {
cloud := cloud.NewFakeCloudProvider()
return &Driver{
endpoint: endpoint,
nodeID: cloud.GetMetadata().GetInstanceID(),
cloud: cloud,
mounter: NewFakeMounter(),
}
}
9 changes: 8 additions & 1 deletion pkg/driver/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ const (
defaultFsType = "ext4"
)

var (
// nodeCaps represents the capability of node service.
nodeCaps = []csi.NodeServiceCapability_RPC_Type{
csi.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME,
}
)

func (d *Driver) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolumeRequest) (*csi.NodeStageVolumeResponse, error) {
glog.V(4).Infof("NodeStageVolume: called with args %+v", *req)
volumeID := req.GetVolumeId()
Expand Down Expand Up @@ -189,7 +196,7 @@ func (d *Driver) NodeGetVolumeStats(ctx context.Context, req *csi.NodeGetVolumeS
func (d *Driver) NodeGetCapabilities(ctx context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) {
glog.V(4).Infof("NodeGetCapabilities: called with args %+v", *req)
var caps []*csi.NodeServiceCapability
for _, cap := range d.nodeCaps {
for _, cap := range nodeCaps {
c := &csi.NodeServiceCapability{
Type: &csi.NodeServiceCapability_Rpc{
Rpc: &csi.NodeServiceCapability_RPC{
Expand Down
5 changes: 2 additions & 3 deletions tests/integration/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ func TestIntegration(t *testing.T) {
var _ = BeforeSuite(func() {
// Run CSI Driver in its own goroutine
var err error
ebs, err = cloud.NewCloud()
Expect(err).To(BeNil(), "Set up Cloud client failed with error")
drv = driver.NewDriver(ebs, nil, endpoint)
drv, err = driver.NewDriver(endpoint)
Expect(err).To(BeNil())
go func() {
err := drv.Run()
Expect(err).To(BeNil())
Expand Down
3 changes: 1 addition & 2 deletions tests/sanity/sanity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (

sanity "github.com/kubernetes-csi/csi-test/pkg/sanity"

"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud"
"github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver"
)

Expand All @@ -44,7 +43,7 @@ func TestSanity(t *testing.T) {
}

var _ = BeforeSuite(func() {
ebsDriver = driver.NewDriver(cloud.NewFakeCloudProvider(), driver.NewFakeMounter(), endpoint)
ebsDriver = driver.NewFakeDriver(endpoint)
go func() {
Expect(ebsDriver.Run()).NotTo(HaveOccurred())
}()
Expand Down