Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DXCDT-293: Access token management for client credentials #537

Merged
merged 11 commits into from
Dec 2, 2022
47 changes: 43 additions & 4 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"time"

"github.com/joeshaw/envdecode"
"golang.org/x/oauth2/clientcredentials"
)

const (
Expand Down Expand Up @@ -65,7 +66,7 @@ type Result struct {
Domain string
RefreshToken string
AccessToken string
ExpiresIn int64
ExpiresAt time.Time
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making this change enabled me to reuse the Result type for client credential access tokens. The associated arithmetic gets slightly shuffled as a result.

}

type State struct {
Expand Down Expand Up @@ -170,9 +171,11 @@ func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) {
return Result{
RefreshToken: res.RefreshToken,
AccessToken: res.AccessToken,
ExpiresIn: res.ExpiresIn,
Tenant: ten,
Domain: domain,
ExpiresAt: time.Now().Add(
time.Duration(res.ExpiresIn) * time.Second,
),
Tenant: ten,
Domain: domain,
}, nil
}
}
Expand Down Expand Up @@ -249,3 +252,39 @@ func parseTenant(accessToken string) (tenant, domain string, err error) {
}
return "", "", fmt.Errorf("audience not found for %s", audiencePath)
}

// ClientCredentials encapsulates all data to facilitate access token creation with client credentials (client ID and client secret)
type ClientCredentials struct {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would've used the clientcredentials.Config type but I wanted to encapsulate the token URL formulation logic within this function.

ClientID string
ClientSecret string
Domain string
}

// GetAccessTokenFromClientCreds generates an access token from client credentials
func GetAccessTokenFromClientCreds(args ClientCredentials) (Result, error) {
u, err := url.Parse("https://" + args.Domain)
if err != nil {
return Result{}, err
}

credsConfig := &clientcredentials.Config{
ClientID: args.ClientID,
ClientSecret: args.ClientSecret,
TokenURL: u.String() + "/oauth/token",
EndpointParams: url.Values{
"client_id": {args.ClientID},
"scope": {strings.Join(RequiredScopesMin(), " ")},
"audience": {u.String() + "/api/v2/"},
},
}

resp, err := credsConfig.Token(context.Background())
if err != nil {
return Result{}, err
}

return Result{
AccessToken: resp.AccessToken,
ExpiresAt: resp.Expiry,
}, nil
}
126 changes: 67 additions & 59 deletions internal/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,54 @@ type cli struct {
config config
}

func (t *Tenant) authenticatedWithClientCredentials() bool {
return t.ClientID != "" && t.ClientSecret != ""
}

func (t *Tenant) authenticatedWithDeviceCodeFlow() bool {
return t.ClientID == "" && t.ClientSecret == ""
}

func (t *Tenant) hasExpiredToken() bool {
return time.Now().Add(accessTokenExpThreshold).After(t.ExpiresAt)
}

func (t *Tenant) regenerateAccessToken(ctx context.Context, c *cli) error {
if t.authenticatedWithClientCredentials() {
token, err := auth.GetAccessTokenFromClientCreds(auth.ClientCredentials{
ClientID: t.ClientID,
ClientSecret: t.ClientSecret,
Domain: t.Domain,
})
if err != nil {
return err
}

t.AccessToken = token.AccessToken
t.ExpiresAt = token.ExpiresAt
}

if t.authenticatedWithDeviceCodeFlow() {
tokenRetriever := &auth.TokenRetriever{
Authenticator: c.authenticator,
Secrets: &auth.Keyring{},
Client: http.DefaultClient,
}

tokenResponse, err := tokenRetriever.Refresh(ctx, t.Domain)
if err != nil {
return err
}

t.AccessToken = tokenResponse.AccessToken
t.ExpiresAt = time.Now().Add(
time.Duration(tokenResponse.ExpiresIn) * time.Second,
)
}

return nil
}

// isLoggedIn encodes the domain logic for determining whether or not we're
// logged in. This might check our config storage, or just in memory.
func (c *cli) isLoggedIn() bool {
Expand Down Expand Up @@ -133,28 +181,18 @@ func (c *cli) setup(ctx context.Context) error {
return err
}

var (
m *management.Management
ua = fmt.Sprintf("%v/%v", userAgent, strings.TrimPrefix(buildinfo.Version, "v"))
)

if t.ClientID != "" && t.ClientSecret != "" {
m, err = management.New(t.Domain,
management.WithClientCredentials(t.ClientID, t.ClientSecret),
management.WithUserAgent(ua),
)
} else {
m, err = management.New(t.Domain,
management.WithStaticToken(t.AccessToken),
management.WithUserAgent(ua),
)
}
userAgent := fmt.Sprintf("%v/%v", userAgent, strings.TrimPrefix(buildinfo.Version, "v"))

api, err := management.New(
t.Domain,
management.WithStaticToken(t.AccessToken),
management.WithUserAgent(userAgent),
)
if err != nil {
return err
}

c.api = auth0.NewAPI(m)
c.api = auth0.NewAPI(api)
return nil
}

Expand All @@ -168,57 +206,27 @@ func (c *cli) prepareTenant(ctx context.Context) (Tenant, error) {
return Tenant{}, err
}

if t.ClientID != "" && t.ClientSecret != "" {
if t.AccessToken == "" || (scopesChanged(t) && t.authenticatedWithDeviceCodeFlow()) {
return RunLogin(ctx, c, true)
}

if !t.hasExpiredToken() {
return t, nil
}

if t.AccessToken == "" || scopesChanged(t) {
t, err = RunLogin(ctx, c, true)
if err != nil {
return Tenant{}, err
}
} else if isExpired(t.ExpiresAt, accessTokenExpThreshold) {
// check if the stored access token is expired:
// use the refresh token to get a new access token:
tr := &auth.TokenRetriever{
Authenticator: c.authenticator,
Secrets: &auth.Keyring{},
Client: http.DefaultClient,
}
if err := t.regenerateAccessToken(ctx, c); err != nil {
// Ask and guide the user through the login process.
c.renderer.Errorf("failed to renew access token, %s", err)
return RunLogin(ctx, c, true)
}

// NOTE(cyx): this code will have to be adapted to instead
// maybe take the clientID/secret as additional params, or
// something similar.
res, err := tr.Refresh(ctx, t.Domain)
if err != nil {
// ask and guide the user through the login process:
c.renderer.Errorf("failed to renew access token, %s", err)
t, err = RunLogin(ctx, c, true)
if err != nil {
return Tenant{}, err
}
} else {
// persist the updated tenant with renewed access token
t.AccessToken = res.AccessToken
t.ExpiresAt = time.Now().Add(
time.Duration(res.ExpiresIn) * time.Second,
)

err = c.addTenant(t)
if err != nil {
return Tenant{}, err
}
}
if err := c.addTenant(t); err != nil {
return Tenant{}, fmt.Errorf("unexpected error adding tenant to config: %w", err)
}

return t, nil
}

// isExpired is true if now() + a threshold is after the given date.
func isExpired(t time.Time, threshold time.Duration) bool {
return time.Now().Add(threshold).After(t)
}

// scopesChanged compare the tenant scopes
// with the currently required scopes.
func scopesChanged(t Tenant) bool {
Expand Down
50 changes: 29 additions & 21 deletions internal/cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,35 @@ import (
"github.com/auth0/auth0-cli/internal/display"
)

func TestIsExpired(t *testing.T) {
t.Run("is expired", func(t *testing.T) {
d := time.Date(2021, 01, 01, 10, 30, 30, 0, time.UTC)
if want, got := true, isExpired(d, 1*time.Minute); want != got {
t.Fatalf("wanted: %v, got %v", want, got)
}
})

t.Run("expired because of the threshold", func(t *testing.T) {
d := time.Now().Add(-2 * time.Minute)
if want, got := true, isExpired(d, 5*time.Minute); want != got {
t.Fatalf("wanted: %v, got %v", want, got)
}
})

t.Run("is not expired", func(t *testing.T) {
d := time.Now().Add(10 * time.Minute)
if want, got := false, isExpired(d, 5*time.Minute); want != got {
t.Fatalf("wanted: %v, got %v", want, got)
}
})
func TestTenant_HasExpiredToken(t *testing.T) {
var testCases = []struct {
name string
givenTime time.Time
expectedTokenToBeExpired bool
}{
{
name: "is expired",
givenTime: time.Date(2021, 01, 01, 10, 30, 30, 0, time.UTC),
expectedTokenToBeExpired: true,
},
{
name: "expired because of the threshold",
givenTime: time.Now().Add(-2 * time.Minute),
expectedTokenToBeExpired: true,
},
{
name: "is not expired",
givenTime: time.Now().Add(10 * time.Minute),
expectedTokenToBeExpired: false,
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
tenant := Tenant{ExpiresAt: testCase.givenTime}
assert.Equal(t, testCase.expectedTokenToBeExpired, tenant.hasExpiredToken())
})
}
}

// TODO(cyx): think about whether we should extract this function in the
Expand Down
7 changes: 2 additions & 5 deletions internal/cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cli
import (
"context"
"fmt"
"time"

"github.com/pkg/browser"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -109,10 +108,8 @@ func RunLogin(ctx context.Context, cli *cli, expired bool) (Tenant, error) {
Name: result.Tenant,
Domain: result.Domain,
AccessToken: result.AccessToken,
ExpiresAt: time.Now().Add(
time.Duration(result.ExpiresIn) * time.Second,
),
Scopes: auth.RequiredScopes(),
ExpiresAt: result.ExpiresAt,
Scopes: auth.RequiredScopes(),
}

err = cli.addTenant(tenant)
Expand Down
14 changes: 13 additions & 1 deletion internal/cli/tenants.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/spf13/cobra"

"github.com/auth0/auth0-cli/internal/auth"
"github.com/auth0/auth0-cli/internal/prompt"
)

Expand Down Expand Up @@ -188,14 +189,25 @@ func addTenantCmd(cli *cli) *cobra.Command {
return err
}

token, err := auth.GetAccessTokenFromClientCreds(auth.ClientCredentials{
ClientID: inputs.ClientID,
ClientSecret: inputs.ClientSecret,
Domain: inputs.Domain,
})
if err != nil {
return err
}

t := Tenant{
Domain: inputs.Domain,
AccessToken: token.AccessToken,
ExpiresAt: token.ExpiresAt,
ClientID: inputs.ClientID,
ClientSecret: inputs.ClientSecret,
}

if err := cli.addTenant(t); err != nil {
return err
return fmt.Errorf("unexpected error when attempting to save tenant data: %w", err)
}

cli.renderer.Infof("Tenant added successfully: %s", t.Domain)
Expand Down