diff --git a/internal/cli/api.go b/internal/cli/api.go index afa84c3da..41a59896b 100644 --- a/internal/cli/api.go +++ b/internal/cli/api.go @@ -141,8 +141,8 @@ func apiCmdRun(cli *cli, inputs *apiCmdInputs) func(cmd *cobra.Command, args []s return err } - bearerToken := cli.config.Tenants[cli.tenant].AccessToken - request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bearerToken)) + accessToken := getAccessToken(cli.config.Tenants[cli.tenant]) + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) request.Header.Set("Content-Type", "application/json") request.Header.Set("User-Agent", fmt.Sprintf("%s/%s", userAgent, strings.TrimPrefix(buildinfo.Version, "v"))) diff --git a/internal/cli/cli.go b/internal/cli/cli.go index cd819bc4e..087d68e45 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -162,6 +162,11 @@ func (t *Tenant) regenerateAccessToken(ctx context.Context, c *cli) error { ) } + err := keyring.StoreAccessToken(t.Domain, t.AccessToken) + if err != nil { + t.AccessToken = "" + } + return nil } @@ -208,7 +213,7 @@ func (c *cli) setup(ctx context.Context) error { api, err := management.New( t.Domain, - management.WithStaticToken(t.AccessToken), + management.WithStaticToken(getAccessToken(t)), management.WithUserAgent(userAgent), ) if err != nil { @@ -219,6 +224,15 @@ func (c *cli) setup(ctx context.Context) error { return nil } +func getAccessToken(t Tenant) string { + accessToken, err := keyring.GetAccessToken(t.Domain) + if err == nil && accessToken != "" { + return accessToken + } + + return t.AccessToken +} + // prepareTenant loads the tenant, refreshing its token if necessary. // The tenant access token needs a refresh if: // 1. The tenant scopes are different than the currently required scopes. @@ -234,7 +248,8 @@ func (c *cli) prepareTenant(ctx context.Context) (Tenant, error) { return RunLoginAsUser(ctx, c, t.additionalRequestedScopes()) } - if t.AccessToken != "" && !t.hasExpiredToken() { + accessToken := getAccessToken(t) + if accessToken != "" && !t.hasExpiredToken() { return t, nil } diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 7918bc67b..6575f0a28 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -12,6 +12,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/olekukonko/tablewriter" "github.com/stretchr/testify/assert" + "github.com/zalando/go-keyring" "github.com/auth0/auth0-cli/internal/auth" "github.com/auth0/auth0-cli/internal/display" @@ -146,3 +147,30 @@ func TestTenant_AdditionalRequestedScopes(t *testing.T) { }) } } + +func TestGetAccessToken(t *testing.T) { + mockTenantDomain := "mock-tenant.com" + + t.Run("return empty string if no keyring and no access token on tenant struct", func(t *testing.T) { + assert.Equal(t, "", getAccessToken(Tenant{Domain: mockTenantDomain, AccessToken: ""})) + }) + + t.Run("returns access token on tenant struct if no keyring", func(t *testing.T) { + mockAccessToken := "this is the access token" + + assert.Equal(t, mockAccessToken, getAccessToken(Tenant{Domain: mockTenantDomain, AccessToken: mockAccessToken})) + }) + + t.Run("returns chunked access token if set on the keyring", func(t *testing.T) { + accessTokenChunks := []string{"access-token-chunk0", "access-token-chunk1"} + + keyring.MockInit() + err := keyring.Set("Auth0 CLI Access Token 0", mockTenantDomain, accessTokenChunks[0]) + assert.NoError(t, err) + err = keyring.Set("Auth0 CLI Access Token 1", mockTenantDomain, accessTokenChunks[1]) + assert.NoError(t, err) + + assert.Equal(t, strings.Join(accessTokenChunks, ""), getAccessToken(Tenant{Domain: mockTenantDomain, AccessToken: ""})) + assert.Equal(t, strings.Join(accessTokenChunks, ""), getAccessToken(Tenant{Domain: mockTenantDomain, AccessToken: "even if this is set for some reason"})) + }) +} diff --git a/internal/cli/login.go b/internal/cli/login.go index 07aeedd9d..d683bbe19 100644 --- a/internal/cli/login.go +++ b/internal/cli/login.go @@ -196,17 +196,22 @@ func RunLoginAsUser(ctx context.Context, cli *cli, additionalScopes []string) (T cli.renderer.Infof("Tenant: %s", result.Domain) cli.renderer.Newline() + tenant := Tenant{ + Name: result.Tenant, + Domain: result.Domain, + ExpiresAt: result.ExpiresAt, + Scopes: append(auth.RequiredScopes, additionalScopes...), + } + if err := keyring.StoreRefreshToken(result.Domain, result.RefreshToken); err != nil { - cli.renderer.Warnf("Could not store the refresh token to the keyring: %s", err) + cli.renderer.Warnf("Could not store the access token and the refresh token to the keyring: %s", err) cli.renderer.Warnf("Expect to login again when your access token expires.") } - tenant := Tenant{ - Name: result.Tenant, - Domain: result.Domain, - AccessToken: result.AccessToken, - ExpiresAt: result.ExpiresAt, - Scopes: append(auth.RequiredScopes, additionalScopes...), + if err := keyring.StoreAccessToken(result.Domain, result.AccessToken); err != nil { + // In case we don't have a keyring, we want the + // access token to be saved in the config file. + tenant.AccessToken = result.AccessToken } err = cli.addTenant(tenant) @@ -266,16 +271,21 @@ func RunLoginAsMachine(ctx context.Context, inputs LoginInputs, cli *cli, cmd *c "Ensure that the provided client-id, client-secret and domain are correct. \n\nerror: %w\n", err) } + t := Tenant{ + Domain: inputs.Domain, + ExpiresAt: token.ExpiresAt, + ClientID: inputs.ClientID, + } + if err = keyring.StoreClientSecret(inputs.Domain, inputs.ClientSecret); err != nil { - cli.renderer.Warnf("Could not store the client secret to the keyring: %s", err) + cli.renderer.Warnf("Could not store the client secret and the access token to the keyring: %s", err) cli.renderer.Warnf("Expect to login again when your access token expires.") } - t := Tenant{ - Domain: inputs.Domain, - AccessToken: token.AccessToken, - ExpiresAt: token.ExpiresAt, - ClientID: inputs.ClientID, + if err := keyring.StoreAccessToken(inputs.Domain, token.AccessToken); err != nil { + // In case we don't have a keyring, we want the + // access token to be saved in the config file. + t.AccessToken = token.AccessToken } if err = cli.addTenant(t); err != nil { diff --git a/internal/cli/organizations.go b/internal/cli/organizations.go index 57bb2a2b0..9473311c8 100644 --- a/internal/cli/organizations.go +++ b/internal/cli/organizations.go @@ -78,7 +78,7 @@ var ( IsRequired: true, } - // Purposefully not setting the Help value on the Flag because overridden where appropriate. + // Purposefully not setting the Help value on the Flag because overridden where appropriate. organizationNumber = Flag{ Name: "Number", LongForm: "number", diff --git a/internal/keyring/keyring.go b/internal/keyring/keyring.go index e59e80502..5905daef4 100644 --- a/internal/keyring/keyring.go +++ b/internal/keyring/keyring.go @@ -9,8 +9,14 @@ import ( ) const ( - secretRefreshToken = "Auth0 CLI Refresh Token" - secretClientSecret = "Auth0 CLI Client Secret" + secretRefreshToken = "Auth0 CLI Refresh Token" + secretClientSecret = "Auth0 CLI Client Secret" + secretAccessToken = "Auth0 CLI Access Token" + secretAccessTokenChunkSizeInBytes = 2048 + + // Access tokens have no size limit, but should be smaller than (10*2048) bytes. + // The max number of loops safeguards against infinite loops, however unlikely. + secretAccessTokenMaxChunks = 50 ) // StoreRefreshToken stores a tenant's refresh token in the system keyring. @@ -49,9 +55,64 @@ func DeleteSecretsForTenant(tenant string) error { } } + for i := 0; i < secretAccessTokenMaxChunks; i++ { + if err := keyring.Delete(fmt.Sprintf("%s %d", secretAccessToken, i), tenant); err != nil { + if !errors.Is(err, keyring.ErrNotFound) { + multiErrors = append(multiErrors, fmt.Sprintf("failed to delete access token from keyring: %s", err)) + } + } + } + if len(multiErrors) == 0 { return nil } return errors.New(strings.Join(multiErrors, ", ")) } + +func StoreAccessToken(tenant, value string) error { + chunks := chunk(value, secretAccessTokenChunkSizeInBytes) + + for i := 0; i < len(chunks); i++ { + err := keyring.Set(fmt.Sprintf("%s %d", secretAccessToken, i), tenant, chunks[i]) + if err != nil { + return err + } + } + + return nil +} + +func GetAccessToken(tenant string) (string, error) { + var accessToken string + + for i := 0; i < secretAccessTokenMaxChunks; i++ { + a, err := keyring.Get(fmt.Sprintf("%s %d", secretAccessToken, i), tenant) + if err == keyring.ErrNotFound { + return accessToken, nil + } + if err != nil { + return "", err + } + accessToken += a + } + + return accessToken, nil +} + +func chunk(slice string, chunkSize int) []string { + var chunks []string + for i := 0; i < len(slice); i += chunkSize { + end := i + chunkSize + + // necessary check to avoid slicing beyond + // slice capacity + if end > len(slice) { + end = len(slice) + } + + chunks = append(chunks, slice[i:end]) + } + + return chunks +} diff --git a/internal/keyring/keyring_test.go b/internal/keyring/keyring_test.go index 9c8d84cfb..363b646eb 100644 --- a/internal/keyring/keyring_test.go +++ b/internal/keyring/keyring_test.go @@ -1,7 +1,10 @@ package keyring import ( + "fmt" + "math/rand" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/zalando/go-keyring" @@ -71,4 +74,41 @@ func TestSecrets(t *testing.T) { assert.NoError(t, err) assert.Equal(t, expectedRefreshToken, actualRefreshToken) }) + + t.Run("it successfully stores an access token", func(t *testing.T) { + keyring.MockInit() + + expectedAccessToken := randomStringOfLength((2048 * 5) + 1) // Some arbitrarily long random string + err := StoreAccessToken(testTenantName, expectedAccessToken) + assert.NoError(t, err) + + actualAccessToken, err := GetAccessToken(testTenantName) + assert.NoError(t, err) + assert.Equal(t, expectedAccessToken, actualAccessToken) + }) + + t.Run("it successfully retrieves an access token split up into multiple chunks", func(t *testing.T) { + keyring.MockInit() + + err := keyring.Set(fmt.Sprintf("%s %d", secretAccessToken, 0), testTenantName, "chunk0") + assert.NoError(t, err) + err = keyring.Set(fmt.Sprintf("%s %d", secretAccessToken, 1), testTenantName, "chunk1") + assert.NoError(t, err) + err = keyring.Set(fmt.Sprintf("%s %d", secretAccessToken, 2), testTenantName, "chunk2") + assert.NoError(t, err) + + actualAccessToken, err := GetAccessToken(testTenantName) + assert.NoError(t, err) + assert.Equal(t, "chunk0chunk1chunk2", actualAccessToken) + }) +} + +func randomStringOfLength(length int) string { + var seededRand = rand.New(rand.NewSource(time.Now().UnixNano())) + charset := "abcdefghijklmnopqrstuvwxyz0123456789" + b := make([]byte, length) + for i := range b { + b[i] = charset[seededRand.Intn(len(charset))] + } + return string(b) }