Skip to content

Commit

Permalink
OCM-9704 | feat: persist refreshed tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
tylercreller committed Aug 9, 2024
1 parent ac85896 commit 518383e
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 10 deletions.
7 changes: 7 additions & 0 deletions cmd/logs/install/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ func run(cmd *cobra.Command, argv []string) {
os.Exit(0)
}
printLog(logResponse.Body(), spin)

err = r.OCMClient.KeepTokensAlive()
if err != nil {
r.Reporter.Errorf(fmt.Sprintf("Failed to keep tokens alive for polling: %v", err))
os.Exit(1)
}

return false
})
if err != nil {
Expand Down
7 changes: 7 additions & 0 deletions cmd/logs/uninstall/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ func run(cmd *cobra.Command, argv []string) {
os.Exit(0)
}
printLog(logResponse.Body(), spin)

err = r.OCMClient.KeepTokensAlive()
if err != nil {
r.Reporter.Errorf(fmt.Sprintf("Failed to keep tokens alive for polling: %v", err))
os.Exit(1)
}

return false
})
if err != nil {
Expand Down
16 changes: 16 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,22 @@ func (c *Config) Connection() (connection *sdk.Connection, err error) {
return
}

func PersistTokens(cfg *Config, accessToken string, refreshToken string) error {
var err error
activeCfg := cfg

if activeCfg == nil {
// Load the configuration if none is provided
activeCfg, err = Load()
if err != nil {
return err
}
}
activeCfg.AccessToken = accessToken
activeCfg.RefreshToken = refreshToken
return Save(activeCfg)
}

// IsKeyringManaged returns the keyring name and a boolean indicating if the config is managed by the keyring.
func IsKeyringManaged() (keyring string, ok bool) {
keyring = os.Getenv(properties.KeyringEnvKey)
Expand Down
55 changes: 46 additions & 9 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ var _ = Describe("Config", Ordered, func() {

BeforeAll(func() {
tmpdir, err = os.MkdirTemp("/tmp", ".ocm-config-*")
Expect(err).To(BeNil())
Expect(err).NotTo(HaveOccurred())
os.Setenv("OCM_CONFIG", tmpdir+"/ocm_config.json")
})

Expand All @@ -101,7 +101,7 @@ var _ = Describe("Config", Ordered, func() {
Save(cfg)

myconf, err := Load()
Expect(err).To(BeNil())
Expect(err).NotTo(HaveOccurred())
Expect(myconf.URL).To(Equal(url))
})
})
Expand All @@ -117,10 +117,47 @@ var _ = Describe("Config", Ordered, func() {

It("Saves and loads config", func() {
myconf, err := Load()
Expect(err).To(BeNil())
Expect(err).NotTo(HaveOccurred())
Expect(myconf).To(BeNil())
})
})

When("Persisting tokens", Ordered, func() {
var tmpdir string
var err error

BeforeAll(func() {
tmpdir, err = os.MkdirTemp("/tmp", ".ocm-config-*")
Expect(err).NotTo(HaveOccurred())
os.Setenv("OCM_CONFIG", tmpdir+"/ocm_config.json")
})

AfterAll(func() {
os.Setenv("OCM_CONFIG", "")
})

It("Uses existing config and saves", func() {
cfg := &Config{}
err := PersistTokens(cfg, "foo", "bar")
Expect(err).NotTo(HaveOccurred())

myconf, err := Load()
Expect(err).NotTo(HaveOccurred())
Expect(myconf.AccessToken).To(Equal("foo"))
Expect(myconf.RefreshToken).To(Equal("bar"))
})

It("Loads config and saves", func() {
err := PersistTokens(nil, "foo", "bar")
Expect(err).NotTo(HaveOccurred())

myconf, err := Load()
Expect(err).NotTo(HaveOccurred())
Expect(myconf.AccessToken).To(Equal("foo"))
Expect(myconf.RefreshToken).To(Equal("bar"))
})
})

})
var _ = Describe("Config Keyring", func() {
When("Load()", func() {
Expand All @@ -141,7 +178,7 @@ var _ = Describe("Config Keyring", func() {
GetConfigFromKeyring = mockSpy.MockGetConfigFromKeyring

cfg, err := Load()
Expect(err).To(BeNil())
Expect(err).NotTo(HaveOccurred())
Expect(cfg).ToNot(BeNil())
Expect(cfg.AccessToken).To(Equal("access_token"))
Expect(mockSpy.calledGet).To(BeTrue())
Expand All @@ -152,7 +189,7 @@ var _ = Describe("Config Keyring", func() {
GetConfigFromKeyring = mockSpy.MockGetConfigFromKeyring

cfg, err := Load()
Expect(err).To(BeNil())
Expect(err).NotTo(HaveOccurred())
Expect(cfg).To(BeNil())
Expect(mockSpy.calledGet).To(BeTrue())
})
Expand All @@ -163,7 +200,7 @@ var _ = Describe("Config Keyring", func() {
GetConfigFromKeyring = mockSpy.MockGetConfigFromKeyring

cfg, err := Load()
Expect(err).To(BeNil())
Expect(err).NotTo(HaveOccurred())
Expect(cfg).To(BeNil())
Expect(mockSpy.calledGet).To(BeTrue())
})
Expand Down Expand Up @@ -200,7 +237,7 @@ var _ = Describe("Config Keyring", func() {
UpsertConfigToKeyring = mockSpy.MockUpsertConfigToKeyring

err := Save(data)
Expect(err).To(BeNil())
Expect(err).NotTo(HaveOccurred())
Expect(mockSpy.calledUpsert).To(BeTrue())
})

Expand Down Expand Up @@ -234,7 +271,7 @@ var _ = Describe("Config Keyring", func() {
RemoveConfigFromKeyring = mockSpy.MockRemoveConfigFromKeyring

err := Remove()
Expect(err).To(BeNil())
Expect(err).NotTo(HaveOccurred())
Expect(mockSpy.calledRemove).To(BeTrue())
})

Expand All @@ -258,7 +295,7 @@ func generateInvalidConfigBytes() []byte {
func generateConfigBytes(config Config) []byte {
data := &config
jsonData, err := json.Marshal(data)
Expect(err).To(BeNil())
Expect(err).NotTo(HaveOccurred())

return jsonData
}
27 changes: 26 additions & 1 deletion pkg/ocm/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,21 @@ func (b *ClientBuilder) Build() (result *Client, err error) {
if err != nil {
return
}
_, _, err = conn.Tokens(10 * time.Minute)
accessToken, refreshToken, err := conn.Tokens(10 * time.Minute)
if err != nil {
if strings.Contains(err.Error(), "invalid_grant") {
return nil, fmt.Errorf("your authorization token needs to be updated. " +
"Please login again using rosa login")
}
return nil, fmt.Errorf("error creating connection. Not able to get authentication token: %s", err)
}

// Persist tokens in the configuration file, the SDK may have refreshed them
err = config.PersistTokens(b.cfg, accessToken, refreshToken)
if err != nil {
return nil, fmt.Errorf("error creating connection. Can't persist tokens to config: %s", err)
}

return &Client{
ocm: conn,
}, nil
Expand All @@ -172,3 +179,21 @@ func (c *Client) GetConnectionURL() string {
func (c *Client) GetConnectionTokens(expiresIn ...time.Duration) (string, string, error) {
return c.ocm.Tokens(expiresIn...)
}

func (c *Client) KeepTokensAlive() error {
if c.ocm == nil {
return fmt.Errorf("Connection is nil")
}

accessToken, refreshToken, err := c.GetConnectionTokens(10 * time.Minute)
if err != nil {
return fmt.Errorf("Can't get new tokens: %v", err)
}

err = config.PersistTokens(nil, accessToken, refreshToken)
if err != nil {
return fmt.Errorf("Can't persist tokens to config: %v", err)
}

return nil
}
115 changes: 115 additions & 0 deletions pkg/ocm/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package ocm

import (
"net/http"
"os"
"time"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/onsi/gomega/ghttp"
sdk "github.com/openshift-online/ocm-sdk-go"
"github.com/openshift-online/ocm-sdk-go/logging"
. "github.com/openshift-online/ocm-sdk-go/testing"

"github.com/openshift/rosa/pkg/config"
)

var _ = Describe("OCM Client", Ordered, func() {
When("Keeping tokens alive", Ordered, func() {
var ssoServer, apiServer *ghttp.Server
var ocmClient *Client
var tmpdir string
var err error
accessToken := MakeTokenString("Bearer", 15*time.Minute)
refreshToken := MakeTokenString("Refresh", 15*time.Minute)
newAccessToken := MakeTokenString("Bearer", 15*time.Minute)
newRefreshToken := MakeTokenString("Refresh", 15*time.Minute)

BeforeAll(func() {
tmpdir, err = os.MkdirTemp("/tmp", ".ocm-config-*")
Expect(err).To(BeNil())
os.Setenv("OCM_CONFIG", tmpdir+"/ocm_config.json")
})

AfterAll(func() {
os.Setenv("OCM_CONFIG", "")
})

BeforeEach(func() {
// Create the servers:
ssoServer = MakeTCPServer()
apiServer = MakeTCPServer()
apiServer.SetAllowUnhandledRequests(true)
apiServer.SetUnhandledRequestStatusCode(http.StatusInternalServerError)

// Prepare the server:
ssoServer.AppendHandlers(
RespondWithAccessAndRefreshTokens(newAccessToken, newRefreshToken),
)
// Prepare the logger:
logger, err := logging.NewGoLoggerBuilder().
Debug(false).
Build()
Expect(err).NotTo(HaveOccurred())
// Set up the connection with the fake config
connection, err := sdk.NewConnectionBuilder().
Logger(logger).
Tokens(accessToken, refreshToken).
URL(apiServer.URL()).
TokenURL(ssoServer.URL()).
Build()
Expect(err).NotTo(HaveOccurred())

ocmClient = NewClientWithConnection(connection)
config.Save(&config.Config{
AccessToken: accessToken,
RefreshToken: refreshToken,
})
})

AfterEach(func() {
ssoServer.Close()
apiServer.Close()
})

It("Fails with inability to get tokens", func() {
config.Save(&config.Config{})
connection, _ := sdk.NewConnectionBuilder().
Tokens(refreshToken).
URL(apiServer.URL()).
TokenURL(ssoServer.URL()).
Build()
ssoServer.Reset()
ssoServer.AllowUnhandledRequests = true
ssoServer.UnhandledRequestStatusCode = http.StatusInternalServerError
ocmClient = NewClientWithConnection(connection)
err = ocmClient.KeepTokensAlive()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("Can't get new tokens"))
})

It("Fails without a valid connection", func() {
ocmClient = NewClientWithConnection(nil)
err = ocmClient.KeepTokensAlive()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("Connection is nil"))
})

It("Persists updated tokens", func() {
myconf, err := config.Load()
Expect(err).NotTo(HaveOccurred())
Expect(myconf).NotTo(BeNil())
Expect(myconf.AccessToken).To(Equal(accessToken))
Expect(myconf.RefreshToken).To(Equal(refreshToken))

err = ocmClient.KeepTokensAlive()
Expect(err).NotTo(HaveOccurred())

myconf, err = config.Load()
Expect(err).NotTo(HaveOccurred())
Expect(myconf.AccessToken).To(Equal(newAccessToken))
Expect(myconf.RefreshToken).To(Equal(newRefreshToken))
})
})
})

0 comments on commit 518383e

Please sign in to comment.