Skip to content

Commit

Permalink
Merge pull request #2258 from cristianoveiga/ocm-9613-allow-billing-a…
Browse files Browse the repository at this point in the history
…ccount-update

OCM-9613 | feat: Allow billing account update via the cluster edit command
  • Loading branch information
openshift-merge-bot[bot] authored Oct 1, 2024
2 parents 3300a71 + 8c796d2 commit cd5b932
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 33 deletions.
23 changes: 8 additions & 15 deletions cmd/create/cluster/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ func initFlags(cmd *cobra.Command) {
&args.billingAccount,
"billing-account",
"",
"Account ID used for billing subscriptions purchased via the AWS marketplace",
"Account ID used for billing subscriptions purchased through the AWS console for ROSA",
)

flags.BoolVar(
Expand Down Expand Up @@ -1115,8 +1115,9 @@ func run(cmd *cobra.Command, _ []string) {
if !isHcpBillingTechPreview {

if billingAccount != "" && !ocm.IsValidAWSAccount(billingAccount) {
r.Reporter.Errorf("Billing account is invalid. Run the command again with a valid billing account. %s",
listBillingAccountMessage)
r.Reporter.Errorf("Provided billing account number %s is not valid. "+
"Rerun the command with a valid billing account number. %s",
billingAccount, listBillingAccountMessage)
os.Exit(1)
}

Expand Down Expand Up @@ -1155,20 +1156,20 @@ func run(cmd *cobra.Command, _ []string) {
billingAccount = aws.ParseOption(billingAccount)
}

err := validateBillingAccount(billingAccount)
err := ocm.ValidateBillingAccount(billingAccount)
if err != nil {
r.Reporter.Errorf("%v", err)
os.Exit(1)
}

// Get contract info
contracts, isContractEnabled := GetBillingAccountContracts(cloudAccounts, billingAccount)
contracts, isContractEnabled := ocm.GetBillingAccountContracts(cloudAccounts, billingAccount)

if billingAccount != awsCreator.AccountID {
r.Reporter.Infof(
"The selected AWS billing account is a different account than your AWS infrastructure account." +
"The AWS billing account you selected is different from your AWS infrastructure account. " +
"The AWS billing account will be charged for subscription usage. " +
"The AWS infrastructure account will be used for managing the cluster.",
"The AWS infrastructure account contains the ROSA infrastructure.",
)
} else {
r.Reporter.Infof("Using '%s' as billing account.",
Expand Down Expand Up @@ -3387,14 +3388,6 @@ func clusterConfigFor(
return clusterConfig, nil
}

func validateBillingAccount(billingAccount string) error {
if billingAccount == "" || !ocm.IsValidAWSAccount(billingAccount) {
return fmt.Errorf("billing account is invalid. Run the command again with a valid billing account. %s",
listBillingAccountMessage)
}
return nil
}

func provideBillingAccount(billingAccounts []string, accountID string, r *rosa.Runtime) (string, error) {
if !helper.ContainsPrefix(billingAccounts, accountID) {
return "", fmt.Errorf("A billing account is required for Hosted Control Plane clusters. %s",
Expand Down
17 changes: 8 additions & 9 deletions cmd/create/cluster/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ var _ = Describe("Validate cloud accounts", func() {
Dimensions(v1.NewContractDimension().Name("control_plane").Value("4")))
cloudAccount, err := mockCloudAccount.Build()
Expect(err).NotTo(HaveOccurred())
_, isContractEnabled := GetBillingAccountContracts([]*v1.CloudAccount{cloudAccount}, "1234567")
_, isContractEnabled := ocm.GetBillingAccountContracts([]*v1.CloudAccount{cloudAccount}, "1234567")
Expect(isContractEnabled).To(Equal(true))
})

Expand All @@ -287,7 +287,7 @@ var _ = Describe("Validate cloud accounts", func() {
" | Number of clusters: |'4' | \n" +
" +---------------------+----------------+ \n"

contractDisplay := GenerateContractDisplay(mockContract)
contractDisplay := ocm.GenerateContractDisplay(mockContract)

Expect(contractDisplay).To(Equal(expected))
})
Expand Down Expand Up @@ -435,24 +435,23 @@ var _ = Describe("validateBillingAccount()", func() {

It("OK: valid billing account", func() {
validBillingAccount := "123456789012"
err := validateBillingAccount(validBillingAccount)
err := ocm.ValidateBillingAccount(validBillingAccount)
Expect(err).NotTo(HaveOccurred())
})

It("KO: fails to validate a wrong billing account", func() {
wrongBillingAccount := "123"
err := validateBillingAccount(wrongBillingAccount)
err := ocm.ValidateBillingAccount(wrongBillingAccount)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("billing account is invalid. Run the command again with a valid billing account." +
" To see the list of billing account options, you can use interactive mode by passing '-i'."))
Expect(err.Error()).To(Equal("Provided billing account number 123 is not valid. " +
"Rerun the command with a valid billing account number"))
})

It("KO: fails to validate an empty billing account", func() {
wrongBillingAccount := ""
err := validateBillingAccount(wrongBillingAccount)
err := ocm.ValidateBillingAccount(wrongBillingAccount)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("billing account is invalid. Run the command again with a valid billing account." +
" To see the list of billing account options, you can use interactive mode by passing '-i'."))
Expect(err.Error()).To(Equal("A billing account number is required"))
})

})
Expand Down
84 changes: 83 additions & 1 deletion cmd/edit/cluster/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ var args struct {
// Audit log forwarding
auditLogRoleARN string

// HCP options:
billingAccount string

// Other options
additionalAllowedPrincipals []string

Expand Down Expand Up @@ -176,6 +179,13 @@ func init() {
)

clusterRegistryConfigArgs = clusterregistryconfig.AddClusterRegistryConfigFlags(Cmd)

flags.StringVar(
&args.billingAccount,
"billing-account",
"",
"Account ID used for billing subscriptions purchased through the AWS console for ROSA",
)
}

func run(cmd *cobra.Command, _ []string) {
Expand All @@ -192,7 +202,7 @@ func run(cmd *cobra.Command, _ []string) {
"additional-trust-bundle-file", "additional-allowed-principals", "audit-log-arn",
"registry-config-allowed-registries", "registry-config-blocked-registries",
"registry-config-insecure-registries", "allowed-registries-for-import",
"registry-config-platform-allowlist", "registry-config-additional-trusted-ca"} {
"registry-config-platform-allowlist", "registry-config-additional-trusted-ca", "billing-account"} {
if cmd.Flags().Changed(flag) {
changedFlags = true
break
Expand Down Expand Up @@ -709,6 +719,78 @@ func run(cmd *cobra.Command, _ []string) {
}
}

var billingAccount string
if cmd.Flags().Changed("billing-account") {
billingAccount = args.billingAccount

if billingAccount != "" && !aws.IsHostedCP(cluster) {
r.Reporter.Errorf("Billing accounts are only supported for Hosted Control Plane clusters")
os.Exit(1)
}
if billingAccount != "" && !ocm.IsValidAWSAccount(billingAccount) {
r.Reporter.Errorf("Provided billing account number %s is not valid. "+
"Rerun the command with a valid billing account number", billingAccount)
os.Exit(1)
}
} else {
billingAccount = cluster.AWS().BillingAccountID()
}

if interactive.Enabled() && aws.IsHostedCP(cluster) {
cloudAccounts, err := r.OCMClient.GetBillingAccounts()
if err != nil {
r.Reporter.Errorf("%s", err)
os.Exit(1)
}

billingAccounts := ocm.GenerateBillingAccountsList(cloudAccounts)
if len(billingAccounts) > 0 {
billingAccount, err = interactive.GetOption(interactive.Input{
Question: "Update billing account",
Help: cmd.Flags().Lookup("billing-account").Usage,
Default: billingAccount,
DefaultMessage: fmt.Sprintf("current = '%s'", cluster.AWS().BillingAccountID()),
Required: true,
Options: billingAccounts,
})

if err != nil {
r.Reporter.Errorf("Expected a valid billing account: '%s'", err)
os.Exit(1)
}

billingAccount = aws.ParseOption(billingAccount)
}

err = ocm.ValidateBillingAccount(billingAccount)
if err != nil {
r.Reporter.Errorf("%v", err)
os.Exit(1)
}

// Get contract info
contracts, isContractEnabled := ocm.GetBillingAccountContracts(cloudAccounts, billingAccount)

if billingAccount != r.Creator.AccountID {
r.Reporter.Infof(
"The AWS billing account you selected is different from your AWS infrastructure account. " +
"The AWS billing account will be charged for subscription usage. " +
"The AWS infrastructure account contains the ROSA infrastructure.",
)
}

if isContractEnabled && len(contracts) > 0 {
//currently, an AWS account will have only one ROSA HCP active contract at a time
contractDisplay := ocm.GenerateContractDisplay(contracts[0])
r.Reporter.Infof(contractDisplay)
}
}

// sets the billing account only if it has changed
if billingAccount != "" && billingAccount != cluster.AWS().BillingAccountID() {
clusterConfig.BillingAccount = billingAccount
}

r.Reporter.Debugf("Updating cluster '%s'", clusterKey)
err = r.OCMClient.UpdateCluster(cluster.ID(), r.Creator, clusterConfig)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
- name: registry-config-allowed-registries-for-import
- name: registry-config-platform-allowlist
- name: registry-config-additional-trusted-ca
- name: billing-account
19 changes: 12 additions & 7 deletions pkg/interactive/interactive.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ import (
)

type Input struct {
Question string
Help string
Options []string
Default interface{}
Required bool
Validators []Validator
Question string
Help string
Options []string
Default interface{}
DefaultMessage string
Required bool
Validators []Validator
}

// Gets string input from the command line
Expand Down Expand Up @@ -179,7 +180,11 @@ func GetOption(input Input) (a string, err error) {
}
defaultMessage := ""
if dflt != "" {
defaultMessage = fmt.Sprintf("default = '%s'", dflt)
if input.DefaultMessage != "" {
defaultMessage = input.DefaultMessage
} else {
defaultMessage = fmt.Sprintf("default = '%s'", dflt)
}
}
question := input.Question
optionalMessage := ""
Expand Down
48 changes: 48 additions & 0 deletions pkg/ocm/billing.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,51 @@ func HasValidContracts(cloudAccount *v1.CloudAccount) bool {
func IsValidAWSAccount(account string) bool {
return awsAccountRegexp.MatchString(account)
}

func ValidateBillingAccount(billingAccount string) error {
if billingAccount == "" {
return fmt.Errorf("A billing account number is required")
}
if !IsValidAWSAccount(billingAccount) {
return fmt.Errorf("Provided billing account number %s is not valid. "+
"Rerun the command with a valid billing account number", billingAccount)
}
return nil
}

func GenerateContractDisplay(contract *v1.Contract) string {
format := "Jan 02, 2006"
dimensions := contract.Dimensions()

numberOfVCPUs, numberOfClusters := GetNumsOfVCPUsAndClusters(dimensions)

contractDisplay := fmt.Sprintf(`
+---------------------+----------------+
| Start Date |%s |
| End Date |%s |
| Number of vCPUs: |'%s' |
| Number of clusters: |'%s' |
+---------------------+----------------+
`,
contract.StartDate().Format(format),
contract.EndDate().Format(format),
strconv.Itoa(numberOfVCPUs),
strconv.Itoa(numberOfClusters),
)

return contractDisplay
}

func GetBillingAccountContracts(cloudAccounts []*v1.CloudAccount,
billingAccount string) ([]*v1.Contract, bool) {
var contracts []*v1.Contract
for _, account := range cloudAccounts {
if account.CloudAccountID() == billingAccount {
contracts = account.Contracts()
if HasValidContracts(account) {
return contracts, true
}
}
}
return contracts, false
}
5 changes: 4 additions & 1 deletion pkg/ocm/clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ func (c *Client) UpdateCluster(clusterKey string, creator *aws.Creator, config S
clusterBuilder.RegistryConfig(registryConfigBuilder)
}

if config.AuditLogRoleARN != nil || config.AdditionalAllowedPrincipals != nil {
if config.AuditLogRoleARN != nil || config.AdditionalAllowedPrincipals != nil || config.BillingAccount != "" {
awsBuilder := cmv1.NewAWS()
if config.AdditionalAllowedPrincipals != nil {
awsBuilder = awsBuilder.AdditionalAllowedPrincipals(config.AdditionalAllowedPrincipals...)
Expand All @@ -656,6 +656,9 @@ func (c *Client) UpdateCluster(clusterKey string, creator *aws.Creator, config S
auditLogBuiler := cmv1.NewAuditLog().RoleArn(*config.AuditLogRoleARN)
awsBuilder = awsBuilder.AuditLog(auditLogBuiler)
}
if config.BillingAccount != "" {
awsBuilder.BillingAccountID(config.BillingAccount)
}
clusterBuilder.AWS(awsBuilder)
}

Expand Down

0 comments on commit cd5b932

Please sign in to comment.