Skip to content

Commit

Permalink
Store access token in keyring
Browse files Browse the repository at this point in the history
  • Loading branch information
sergiught committed Feb 13, 2023
1 parent ee3048f commit ac25ef7
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 20 deletions.
4 changes: 2 additions & 2 deletions internal/cli/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ func apiCmdRun(cli *cli, inputs *apiCmdInputs) func(cmd *cobra.Command, args []s
return err
}

bearerToken := cli.config.Tenants[cli.tenant].AccessToken
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bearerToken))
accessToken := getAccessToken(cli.config.Tenants[cli.tenant])
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("%s/%s", userAgent, strings.TrimPrefix(buildinfo.Version, "v")))

Expand Down
19 changes: 17 additions & 2 deletions internal/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ func (t *Tenant) regenerateAccessToken(ctx context.Context, c *cli) error {
)
}

err := keyring.StoreAccessToken(t.Domain, t.AccessToken)
if err != nil {
t.AccessToken = ""
}

return nil
}

Expand Down Expand Up @@ -208,7 +213,7 @@ func (c *cli) setup(ctx context.Context) error {

api, err := management.New(
t.Domain,
management.WithStaticToken(t.AccessToken),
management.WithStaticToken(getAccessToken(t)),
management.WithUserAgent(userAgent),
)
if err != nil {
Expand All @@ -219,6 +224,15 @@ func (c *cli) setup(ctx context.Context) error {
return nil
}

func getAccessToken(t Tenant) string {
accessToken, err := keyring.GetAccessToken(t.Domain)
if err == nil && accessToken != "" {
return accessToken
}

return t.AccessToken
}

// prepareTenant loads the tenant, refreshing its token if necessary.
// The tenant access token needs a refresh if:
// 1. The tenant scopes are different than the currently required scopes.
Expand All @@ -234,7 +248,8 @@ func (c *cli) prepareTenant(ctx context.Context) (Tenant, error) {
return RunLoginAsUser(ctx, c, t.additionalRequestedScopes())
}

if t.AccessToken != "" && !t.hasExpiredToken() {
accessToken := getAccessToken(t)
if accessToken != "" && !t.hasExpiredToken() {
return t, nil
}

Expand Down
28 changes: 28 additions & 0 deletions internal/cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/olekukonko/tablewriter"
"github.com/stretchr/testify/assert"
"github.com/zalando/go-keyring"

"github.com/auth0/auth0-cli/internal/auth"
"github.com/auth0/auth0-cli/internal/display"
Expand Down Expand Up @@ -146,3 +147,30 @@ func TestTenant_AdditionalRequestedScopes(t *testing.T) {
})
}
}

func TestGetAccessToken(t *testing.T) {
mockTenantDomain := "mock-tenant.com"

t.Run("return empty string if no keyring and no access token on tenant struct", func(t *testing.T) {
assert.Equal(t, "", getAccessToken(Tenant{Domain: mockTenantDomain, AccessToken: ""}))
})

t.Run("returns access token on tenant struct if no keyring", func(t *testing.T) {
mockAccessToken := "this is the access token"

assert.Equal(t, mockAccessToken, getAccessToken(Tenant{Domain: mockTenantDomain, AccessToken: mockAccessToken}))
})

t.Run("returns chunked access token if set on the keyring", func(t *testing.T) {
accessTokenChunks := []string{"access-token-chunk0", "access-token-chunk1"}

keyring.MockInit()
err := keyring.Set("Auth0 CLI Access Token 0", mockTenantDomain, accessTokenChunks[0])
assert.NoError(t, err)
err = keyring.Set("Auth0 CLI Access Token 1", mockTenantDomain, accessTokenChunks[1])
assert.NoError(t, err)

assert.Equal(t, strings.Join(accessTokenChunks, ""), getAccessToken(Tenant{Domain: mockTenantDomain, AccessToken: ""}))
assert.Equal(t, strings.Join(accessTokenChunks, ""), getAccessToken(Tenant{Domain: mockTenantDomain, AccessToken: "even if this is set for some reason"}))
})
}
36 changes: 23 additions & 13 deletions internal/cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,17 +196,22 @@ func RunLoginAsUser(ctx context.Context, cli *cli, additionalScopes []string) (T
cli.renderer.Infof("Tenant: %s", result.Domain)
cli.renderer.Newline()

tenant := Tenant{
Name: result.Tenant,
Domain: result.Domain,
ExpiresAt: result.ExpiresAt,
Scopes: append(auth.RequiredScopes, additionalScopes...),
}

if err := keyring.StoreRefreshToken(result.Domain, result.RefreshToken); err != nil {
cli.renderer.Warnf("Could not store the refresh token to the keyring: %s", err)
cli.renderer.Warnf("Could not store the access token and the refresh token to the keyring: %s", err)
cli.renderer.Warnf("Expect to login again when your access token expires.")
}

tenant := Tenant{
Name: result.Tenant,
Domain: result.Domain,
AccessToken: result.AccessToken,
ExpiresAt: result.ExpiresAt,
Scopes: append(auth.RequiredScopes, additionalScopes...),
if err := keyring.StoreAccessToken(result.Domain, result.AccessToken); err != nil {
// In case we don't have a keyring, we want the
// access token to be saved in the config file.
tenant.AccessToken = result.AccessToken
}

err = cli.addTenant(tenant)
Expand Down Expand Up @@ -266,16 +271,21 @@ func RunLoginAsMachine(ctx context.Context, inputs LoginInputs, cli *cli, cmd *c
"Ensure that the provided client-id, client-secret and domain are correct. \n\nerror: %w\n", err)
}

t := Tenant{
Domain: inputs.Domain,
ExpiresAt: token.ExpiresAt,
ClientID: inputs.ClientID,
}

if err = keyring.StoreClientSecret(inputs.Domain, inputs.ClientSecret); err != nil {
cli.renderer.Warnf("Could not store the client secret to the keyring: %s", err)
cli.renderer.Warnf("Could not store the client secret and the access token to the keyring: %s", err)
cli.renderer.Warnf("Expect to login again when your access token expires.")
}

t := Tenant{
Domain: inputs.Domain,
AccessToken: token.AccessToken,
ExpiresAt: token.ExpiresAt,
ClientID: inputs.ClientID,
if err := keyring.StoreAccessToken(inputs.Domain, token.AccessToken); err != nil {
// In case we don't have a keyring, we want the
// access token to be saved in the config file.
t.AccessToken = token.AccessToken
}

if err = cli.addTenant(t); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/cli/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ var (
IsRequired: true,
}

// Purposefully not setting the Help value on the Flag because overridden where appropriate.
// Purposefully not setting the Help value on the Flag because overridden where appropriate.
organizationNumber = Flag{
Name: "Number",
LongForm: "number",
Expand Down
65 changes: 63 additions & 2 deletions internal/keyring/keyring.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@ import (
)

const (
secretRefreshToken = "Auth0 CLI Refresh Token"
secretClientSecret = "Auth0 CLI Client Secret"
secretRefreshToken = "Auth0 CLI Refresh Token"
secretClientSecret = "Auth0 CLI Client Secret"
secretAccessToken = "Auth0 CLI Access Token"
secretAccessTokenChunkSizeInBytes = 2048

// Access tokens have no size limit, but should be smaller than (10*2048) bytes.
// The max number of loops safeguards against infinite loops, however unlikely.
secretAccessTokenMaxChunks = 50
)

// StoreRefreshToken stores a tenant's refresh token in the system keyring.
Expand Down Expand Up @@ -49,9 +55,64 @@ func DeleteSecretsForTenant(tenant string) error {
}
}

for i := 0; i < secretAccessTokenMaxChunks; i++ {
if err := keyring.Delete(fmt.Sprintf("%s %d", secretAccessToken, i), tenant); err != nil {
if !errors.Is(err, keyring.ErrNotFound) {
multiErrors = append(multiErrors, fmt.Sprintf("failed to delete access token from keyring: %s", err))
}
}
}

if len(multiErrors) == 0 {
return nil
}

return errors.New(strings.Join(multiErrors, ", "))
}

func StoreAccessToken(tenant, value string) error {
chunks := chunk(value, secretAccessTokenChunkSizeInBytes)

for i := 0; i < len(chunks); i++ {
err := keyring.Set(fmt.Sprintf("%s %d", secretAccessToken, i), tenant, chunks[i])
if err != nil {
return err
}
}

return nil
}

func GetAccessToken(tenant string) (string, error) {
var accessToken string

for i := 0; i < secretAccessTokenMaxChunks; i++ {
a, err := keyring.Get(fmt.Sprintf("%s %d", secretAccessToken, i), tenant)
if err == keyring.ErrNotFound {
return accessToken, nil
}
if err != nil {
return "", err
}
accessToken += a
}

return accessToken, nil
}

func chunk(slice string, chunkSize int) []string {
var chunks []string
for i := 0; i < len(slice); i += chunkSize {
end := i + chunkSize

// necessary check to avoid slicing beyond
// slice capacity
if end > len(slice) {
end = len(slice)
}

chunks = append(chunks, slice[i:end])
}

return chunks
}
40 changes: 40 additions & 0 deletions internal/keyring/keyring_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package keyring

import (
"fmt"
"math/rand"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/zalando/go-keyring"
Expand Down Expand Up @@ -71,4 +74,41 @@ func TestSecrets(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, expectedRefreshToken, actualRefreshToken)
})

t.Run("it successfully stores an access token", func(t *testing.T) {
keyring.MockInit()

expectedAccessToken := randomStringOfLength((2048 * 5) + 1) // Some arbitrarily long random string
err := StoreAccessToken(testTenantName, expectedAccessToken)
assert.NoError(t, err)

actualAccessToken, err := GetAccessToken(testTenantName)
assert.NoError(t, err)
assert.Equal(t, expectedAccessToken, actualAccessToken)
})

t.Run("it successfully retrieves an access token split up into multiple chunks", func(t *testing.T) {
keyring.MockInit()

err := keyring.Set(fmt.Sprintf("%s %d", secretAccessToken, 0), testTenantName, "chunk0")
assert.NoError(t, err)
err = keyring.Set(fmt.Sprintf("%s %d", secretAccessToken, 1), testTenantName, "chunk1")
assert.NoError(t, err)
err = keyring.Set(fmt.Sprintf("%s %d", secretAccessToken, 2), testTenantName, "chunk2")
assert.NoError(t, err)

actualAccessToken, err := GetAccessToken(testTenantName)
assert.NoError(t, err)
assert.Equal(t, "chunk0chunk1chunk2", actualAccessToken)
})
}

func randomStringOfLength(length int) string {
var seededRand = rand.New(rand.NewSource(time.Now().UnixNano()))
charset := "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[seededRand.Intn(len(charset))]
}
return string(b)
}

0 comments on commit ac25ef7

Please sign in to comment.