Skip to content
This repository has been archived by the owner on Jan 11, 2023. It is now read-only.

Commit

Permalink
separate authArgs validation
Browse files Browse the repository at this point in the history
  • Loading branch information
CecileRobertMichon committed May 7, 2018
1 parent 54c8145 commit e23310d
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 151 deletions.
4 changes: 4 additions & 0 deletions cmd/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ func (dc *deployCmd) validate(cmd *cobra.Command, args []string) error {
return fmt.Errorf(fmt.Sprintf("--location does not match api model location"))
}

if err = dc.authArgs.validateAuthArgs(); err != nil {
return fmt.Errorf("%s", err)
}

dc.client, err = dc.authArgs.getClient()
if err != nil {
return fmt.Errorf("failed to get client: %s", err.Error())
Expand Down
21 changes: 14 additions & 7 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,31 +72,38 @@ func addAuthFlags(authArgs *authArgs, f *flag.FlagSet) {
f.StringVar(&authArgs.language, "language", "en-us", "language to return error messages in")
}

func (authArgs *authArgs) getClient() (*armhelpers.AzureClient, error) {
func (authArgs *authArgs) validateAuthArgs() error {
authArgs.ClientID, _ = uuid.FromString(authArgs.rawClientID)
authArgs.SubscriptionID, _ = uuid.FromString(authArgs.rawSubscriptionID)

if authArgs.AuthMethod == "client_secret" {
if authArgs.ClientID.String() == "00000000-0000-0000-0000-000000000000" || authArgs.ClientSecret == "" {
return nil, fmt.Errorf(`--client-id and --client-secret must be specified when --auth-method="client_secret"`)
return fmt.Errorf(`--client-id and --client-secret must be specified when --auth-method="client_secret"`)
}
// try parse the UUID
} else if authArgs.AuthMethod == "client_certificate" {
if authArgs.ClientID.String() == "00000000-0000-0000-0000-000000000000" || authArgs.CertificatePath == "" || authArgs.PrivateKeyPath == "" {
return nil, fmt.Errorf(`--client-id and --certificate-path, and --private-key-path must be specified when --auth-method="client_certificate"`)
return fmt.Errorf(`--client-id and --certificate-path, and --private-key-path must be specified when --auth-method="client_certificate"`)
}
}

if authArgs.SubscriptionID.String() == "00000000-0000-0000-0000-000000000000" {
return nil, fmt.Errorf("--subscription-id is required (and must be a valid UUID)")
return fmt.Errorf("--subscription-id is required (and must be a valid UUID)")
}

env, err := azure.EnvironmentFromName(authArgs.RawAzureEnvironment)
_, err := azure.EnvironmentFromName(authArgs.RawAzureEnvironment)
if err != nil {
return nil, fmt.Errorf("failed to parse --azure-env as a valid target Azure cloud environment")
return fmt.Errorf("failed to parse --azure-env as a valid target Azure cloud environment")
}
return nil
}

func (authArgs *authArgs) getClient() (*armhelpers.AzureClient, error) {
var client *armhelpers.AzureClient
env, err := azure.EnvironmentFromName(authArgs.RawAzureEnvironment)
if err != nil {
return nil, err
}
switch authArgs.AuthMethod {
case "device":
client, err = armhelpers.NewAzureClientWithDeviceAuth(env, authArgs.SubscriptionID.String())
Expand All @@ -105,7 +112,7 @@ func (authArgs *authArgs) getClient() (*armhelpers.AzureClient, error) {
case "client_certificate":
client, err = armhelpers.NewAzureClientWithClientCertificateFile(env, authArgs.SubscriptionID.String(), authArgs.ClientID.String(), authArgs.CertificatePath, authArgs.PrivateKeyPath)
default:
return nil, fmt.Errorf("--auth-method: ERROR: method unsupported. method=%q.", authArgs.AuthMethod)
return nil, fmt.Errorf("--auth-method: ERROR: method unsupported. method=%q", authArgs.AuthMethod)
}
if err != nil {
return nil, err
Expand Down
4 changes: 4 additions & 0 deletions cmd/scale.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ func (sc *scaleCmd) validate(cmd *cobra.Command, args []string) {
log.Fatal("--new-node-count must be specified")
}

if err = sc.authArgs.validateAuthArgs(); err != nil {
log.Fatal("%s", err)
}

if sc.client, err = sc.authArgs.getClient(); err != nil {
log.Error("Failed to get client:", err)
}
Expand Down
16 changes: 7 additions & 9 deletions cmd/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@ func (uc *upgradeCmd) validate(cmd *cobra.Command) error {
if uc.location == "" {
cmd.Usage()
return fmt.Errorf("--location must be specified")
} else {
uc.location = helpers.NormalizeAzureRegion(uc.location)
}
uc.location = helpers.NormalizeAzureRegion(uc.location)

if uc.timeoutInMinutes != -1 {
timeout := time.Duration(uc.timeoutInMinutes) * time.Minute
Expand All @@ -108,13 +107,13 @@ func (uc *upgradeCmd) validate(cmd *cobra.Command) error {
return fmt.Errorf("--deployment-dir must be specified")
}

if uc.client, err = uc.authArgs.getClient(); err != nil {
return fmt.Errorf("Failed to get client: %s", err)
if err = uc.authArgs.validateAuthArgs(); err != nil {
return fmt.Errorf("%s", err)
}
return nil
}

func (uc *upgradeCmd) load(cmd *cobra.Command) error {
func (uc *upgradeCmd) loadCluster(cmd *cobra.Command) error {
var err error
_, err = uc.client.EnsureResourceGroup(uc.resourceGroupName, uc.location, nil)
if err != nil {
Expand Down Expand Up @@ -167,9 +166,8 @@ func (uc *upgradeCmd) load(cmd *cobra.Command) error {
return fmt.Errorf("version %s is not supported", uc.upgradeVersion)
}

uc.client, err = uc.authArgs.getClient()
if err != nil {
return fmt.Errorf("failed to get client") // TODO: cleanup
if uc.client, err = uc.authArgs.getClient(); err != nil {
return fmt.Errorf("Failed to get client: %s", err)
}

// Read name suffix to identify nodes in the resource group that belong
Expand Down Expand Up @@ -204,7 +202,7 @@ func (uc *upgradeCmd) run(cmd *cobra.Command, args []string) error {
log.Fatalf("error validating upgrade command: %v", err)
}

err = uc.load(cmd)
err = uc.loadCluster(cmd)
if err != nil {
log.Fatalf("error loading existing cluster: %v", err)
}
Expand Down
270 changes: 135 additions & 135 deletions cmd/upgrade_test.go
Original file line number Diff line number Diff line change
@@ -1,135 +1,135 @@
package cmd

import (
"fmt"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/spf13/cobra"
)

var _ = Describe("the upgrade command", func() {

It("should create an upgrade command", func() {
output := newUpgradeCmd()

Expect(output.Use).Should(Equal(upgradeName))
Expect(output.Short).Should(Equal(upgradeShortDescription))
Expect(output.Long).Should(Equal(upgradeLongDescription))
Expect(output.Flags().Lookup("location")).NotTo(BeNil())
Expect(output.Flags().Lookup("resource-group")).NotTo(BeNil())
Expect(output.Flags().Lookup("deployment-dir")).NotTo(BeNil())
Expect(output.Flags().Lookup("upgrade-version")).NotTo(BeNil())
})

It("should validate an upgrade command", func() {
r := &cobra.Command{}

cases := []struct {
uc *upgradeCmd
expectedErr error
}{
{
uc: &upgradeCmd{
resourceGroupName: "",
deploymentDirectory: "_output/test",
upgradeVersion: "1.8.9",
location: "centralus",
timeoutInMinutes: 60,
authArgs: authArgs{
rawSubscriptionID: "99999999-0000-0000-0000-000000000000",
},
},
expectedErr: fmt.Errorf("--resource-group must be specified"),
},
{
uc: &upgradeCmd{
resourceGroupName: "test",
deploymentDirectory: "_output/test",
upgradeVersion: "1.8.9",
location: "",
timeoutInMinutes: 60,
authArgs: authArgs{
rawSubscriptionID: "99999999-0000-0000-0000-000000000000",
},
},
expectedErr: fmt.Errorf("--location must be specified"),
},
{
uc: &upgradeCmd{
resourceGroupName: "test",
deploymentDirectory: "_output/test",
upgradeVersion: "",
location: "southcentralus",
timeoutInMinutes: 60,
authArgs: authArgs{
rawSubscriptionID: "99999999-0000-0000-0000-000000000000",
},
},
expectedErr: fmt.Errorf("--upgrade-version must be specified"),
},
{
uc: &upgradeCmd{
resourceGroupName: "test",
deploymentDirectory: "",
upgradeVersion: "1.9.0",
location: "southcentralus",
timeoutInMinutes: 60,
authArgs: authArgs{
rawSubscriptionID: "99999999-0000-0000-0000-000000000000",
},
},
expectedErr: fmt.Errorf("--deployment-dir must be specified"),
},
{
uc: &upgradeCmd{
resourceGroupName: "test",
deploymentDirectory: "",
upgradeVersion: "1.9.0",
location: "southcentralus",
timeoutInMinutes: 60,
authArgs: authArgs{
rawSubscriptionID: "99999999-0000-0000-0000-000000000000",
},
},
expectedErr: fmt.Errorf("--deployment-dir must be specified"),
},
{
uc: &upgradeCmd{
resourceGroupName: "test",
deploymentDirectory: "_output/mydir",
upgradeVersion: "1.9.0",
location: "southcentralus",
authArgs: authArgs{},
},
expectedErr: fmt.Errorf("Failed to get client: --subscription-id is required (and must be a valid UUID)"),
},
{
uc: &upgradeCmd{
resourceGroupName: "test",
deploymentDirectory: "_output/mydir",
upgradeVersion: "1.9.0",
location: "southcentralus",
authArgs: authArgs{
rawSubscriptionID: "99999999-0000-0000-0000-000000000000",
RawAzureEnvironment: "AzurePublicCloud",
AuthMethod: "device",
},
},
expectedErr: nil,
},
}

for _, c := range cases {
err := c.uc.validate(r)
if c.expectedErr != nil && err != nil {
Expect(err.Error()).To(Equal(c.expectedErr.Error()))
} else {
Expect(err).To(BeNil())
Expect(c.expectedErr).To(BeNil())
}
}

})

})
package cmd

import (
"fmt"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/spf13/cobra"
)

var _ = Describe("the upgrade command", func() {

It("should create an upgrade command", func() {
output := newUpgradeCmd()

Expect(output.Use).Should(Equal(upgradeName))
Expect(output.Short).Should(Equal(upgradeShortDescription))
Expect(output.Long).Should(Equal(upgradeLongDescription))
Expect(output.Flags().Lookup("location")).NotTo(BeNil())
Expect(output.Flags().Lookup("resource-group")).NotTo(BeNil())
Expect(output.Flags().Lookup("deployment-dir")).NotTo(BeNil())
Expect(output.Flags().Lookup("upgrade-version")).NotTo(BeNil())
})

It("should validate an upgrade command", func() {
r := &cobra.Command{}

cases := []struct {
uc *upgradeCmd
expectedErr error
}{
{
uc: &upgradeCmd{
resourceGroupName: "",
deploymentDirectory: "_output/test",
upgradeVersion: "1.8.9",
location: "centralus",
timeoutInMinutes: 60,
authArgs: authArgs{
rawSubscriptionID: "99999999-0000-0000-0000-000000000000",
},
},
expectedErr: fmt.Errorf("--resource-group must be specified"),
},
{
uc: &upgradeCmd{
resourceGroupName: "test",
deploymentDirectory: "_output/test",
upgradeVersion: "1.8.9",
location: "",
timeoutInMinutes: 60,
authArgs: authArgs{
rawSubscriptionID: "99999999-0000-0000-0000-000000000000",
},
},
expectedErr: fmt.Errorf("--location must be specified"),
},
{
uc: &upgradeCmd{
resourceGroupName: "test",
deploymentDirectory: "_output/test",
upgradeVersion: "",
location: "southcentralus",
timeoutInMinutes: 60,
authArgs: authArgs{
rawSubscriptionID: "99999999-0000-0000-0000-000000000000",
},
},
expectedErr: fmt.Errorf("--upgrade-version must be specified"),
},
{
uc: &upgradeCmd{
resourceGroupName: "test",
deploymentDirectory: "",
upgradeVersion: "1.9.0",
location: "southcentralus",
timeoutInMinutes: 60,
authArgs: authArgs{
rawSubscriptionID: "99999999-0000-0000-0000-000000000000",
},
},
expectedErr: fmt.Errorf("--deployment-dir must be specified"),
},
{
uc: &upgradeCmd{
resourceGroupName: "test",
deploymentDirectory: "",
upgradeVersion: "1.9.0",
location: "southcentralus",
timeoutInMinutes: 60,
authArgs: authArgs{
rawSubscriptionID: "99999999-0000-0000-0000-000000000000",
},
},
expectedErr: fmt.Errorf("--deployment-dir must be specified"),
},
{
uc: &upgradeCmd{
resourceGroupName: "test",
deploymentDirectory: "_output/mydir",
upgradeVersion: "1.9.0",
location: "southcentralus",
authArgs: authArgs{},
},
expectedErr: fmt.Errorf("--subscription-id is required (and must be a valid UUID)"),
},
{
uc: &upgradeCmd{
resourceGroupName: "test",
deploymentDirectory: "_output/mydir",
upgradeVersion: "1.9.0",
location: "southcentralus",
authArgs: authArgs{
rawSubscriptionID: "99999999-0000-0000-0000-000000000000",
RawAzureEnvironment: "AzurePublicCloud",
AuthMethod: "device",
},
},
expectedErr: nil,
},
}

for _, c := range cases {
err := c.uc.validate(r)
if c.expectedErr != nil && err != nil {
Expect(err.Error()).To(Equal(c.expectedErr.Error()))
} else {
Expect(err).To(BeNil())
Expect(c.expectedErr).To(BeNil())
}
}

})

})

0 comments on commit e23310d

Please sign in to comment.