diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 426e448ef..1168cdb0c 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -14,6 +14,7 @@ import ( "time" "github.com/joeshaw/envdecode" + "golang.org/x/oauth2/clientcredentials" ) const ( @@ -65,7 +66,7 @@ type Result struct { Domain string RefreshToken string AccessToken string - ExpiresIn int64 + ExpiresAt time.Time } type State struct { @@ -170,9 +171,11 @@ func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) { return Result{ RefreshToken: res.RefreshToken, AccessToken: res.AccessToken, - ExpiresIn: res.ExpiresIn, - Tenant: ten, - Domain: domain, + ExpiresAt: time.Now().Add( + time.Duration(res.ExpiresIn) * time.Second, + ), + Tenant: ten, + Domain: domain, }, nil } } @@ -249,3 +252,39 @@ func parseTenant(accessToken string) (tenant, domain string, err error) { } return "", "", fmt.Errorf("audience not found for %s", audiencePath) } + +// ClientCredentials encapsulates all data to facilitate access token creation with client credentials (client ID and client secret) +type ClientCredentials struct { + ClientID string + ClientSecret string + Domain string +} + +// GetAccessTokenFromClientCreds generates an access token from client credentials +func GetAccessTokenFromClientCreds(args ClientCredentials) (Result, error) { + u, err := url.Parse("https://" + args.Domain) + if err != nil { + return Result{}, err + } + + credsConfig := &clientcredentials.Config{ + ClientID: args.ClientID, + ClientSecret: args.ClientSecret, + TokenURL: u.String() + "/oauth/token", + EndpointParams: url.Values{ + "client_id": {args.ClientID}, + "scope": {strings.Join(RequiredScopesMin(), " ")}, + "audience": {u.String() + "/api/v2/"}, + }, + } + + resp, err := credsConfig.Token(context.Background()) + if err != nil { + return Result{}, err + } + + return Result{ + AccessToken: resp.AccessToken, + ExpiresAt: resp.Expiry, + }, nil +} diff --git a/internal/cli/cli.go b/internal/cli/cli.go index e082f80ad..e30837d9d 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -94,6 +94,54 @@ type cli struct { config config } +func (t *Tenant) authenticatedWithClientCredentials() bool { + return t.ClientID != "" && t.ClientSecret != "" +} + +func (t *Tenant) authenticatedWithDeviceCodeFlow() bool { + return t.ClientID == "" && t.ClientSecret == "" +} + +func (t *Tenant) hasExpiredToken() bool { + return time.Now().Add(accessTokenExpThreshold).After(t.ExpiresAt) +} + +func (t *Tenant) regenerateAccessToken(ctx context.Context, c *cli) error { + if t.authenticatedWithClientCredentials() { + token, err := auth.GetAccessTokenFromClientCreds(auth.ClientCredentials{ + ClientID: t.ClientID, + ClientSecret: t.ClientSecret, + Domain: t.Domain, + }) + if err != nil { + return err + } + + t.AccessToken = token.AccessToken + t.ExpiresAt = token.ExpiresAt + } + + if t.authenticatedWithDeviceCodeFlow() { + tokenRetriever := &auth.TokenRetriever{ + Authenticator: c.authenticator, + Secrets: &auth.Keyring{}, + Client: http.DefaultClient, + } + + tokenResponse, err := tokenRetriever.Refresh(ctx, t.Domain) + if err != nil { + return err + } + + t.AccessToken = tokenResponse.AccessToken + t.ExpiresAt = time.Now().Add( + time.Duration(tokenResponse.ExpiresIn) * time.Second, + ) + } + + return nil +} + // isLoggedIn encodes the domain logic for determining whether or not we're // logged in. This might check our config storage, or just in memory. func (c *cli) isLoggedIn() bool { @@ -133,28 +181,18 @@ func (c *cli) setup(ctx context.Context) error { return err } - var ( - m *management.Management - ua = fmt.Sprintf("%v/%v", userAgent, strings.TrimPrefix(buildinfo.Version, "v")) - ) - - if t.ClientID != "" && t.ClientSecret != "" { - m, err = management.New(t.Domain, - management.WithClientCredentials(t.ClientID, t.ClientSecret), - management.WithUserAgent(ua), - ) - } else { - m, err = management.New(t.Domain, - management.WithStaticToken(t.AccessToken), - management.WithUserAgent(ua), - ) - } + userAgent := fmt.Sprintf("%v/%v", userAgent, strings.TrimPrefix(buildinfo.Version, "v")) + api, err := management.New( + t.Domain, + management.WithStaticToken(t.AccessToken), + management.WithUserAgent(userAgent), + ) if err != nil { return err } - c.api = auth0.NewAPI(m) + c.api = auth0.NewAPI(api) return nil } @@ -168,57 +206,27 @@ func (c *cli) prepareTenant(ctx context.Context) (Tenant, error) { return Tenant{}, err } - if t.ClientID != "" && t.ClientSecret != "" { + if t.AccessToken == "" || (scopesChanged(t) && t.authenticatedWithDeviceCodeFlow()) { + return RunLogin(ctx, c, true) + } + + if !t.hasExpiredToken() { return t, nil } - if t.AccessToken == "" || scopesChanged(t) { - t, err = RunLogin(ctx, c, true) - if err != nil { - return Tenant{}, err - } - } else if isExpired(t.ExpiresAt, accessTokenExpThreshold) { - // check if the stored access token is expired: - // use the refresh token to get a new access token: - tr := &auth.TokenRetriever{ - Authenticator: c.authenticator, - Secrets: &auth.Keyring{}, - Client: http.DefaultClient, - } + if err := t.regenerateAccessToken(ctx, c); err != nil { + // Ask and guide the user through the login process. + c.renderer.Errorf("failed to renew access token, %s", err) + return RunLogin(ctx, c, true) + } - // NOTE(cyx): this code will have to be adapted to instead - // maybe take the clientID/secret as additional params, or - // something similar. - res, err := tr.Refresh(ctx, t.Domain) - if err != nil { - // ask and guide the user through the login process: - c.renderer.Errorf("failed to renew access token, %s", err) - t, err = RunLogin(ctx, c, true) - if err != nil { - return Tenant{}, err - } - } else { - // persist the updated tenant with renewed access token - t.AccessToken = res.AccessToken - t.ExpiresAt = time.Now().Add( - time.Duration(res.ExpiresIn) * time.Second, - ) - - err = c.addTenant(t) - if err != nil { - return Tenant{}, err - } - } + if err := c.addTenant(t); err != nil { + return Tenant{}, fmt.Errorf("unexpected error adding tenant to config: %w", err) } return t, nil } -// isExpired is true if now() + a threshold is after the given date. -func isExpired(t time.Time, threshold time.Duration) bool { - return time.Now().Add(threshold).After(t) -} - // scopesChanged compare the tenant scopes // with the currently required scopes. func scopesChanged(t Tenant) bool { diff --git a/internal/cli/cli_test.go b/internal/cli/cli_test.go index 290217f0b..e8bbae59c 100644 --- a/internal/cli/cli_test.go +++ b/internal/cli/cli_test.go @@ -16,27 +16,35 @@ import ( "github.com/auth0/auth0-cli/internal/display" ) -func TestIsExpired(t *testing.T) { - t.Run("is expired", func(t *testing.T) { - d := time.Date(2021, 01, 01, 10, 30, 30, 0, time.UTC) - if want, got := true, isExpired(d, 1*time.Minute); want != got { - t.Fatalf("wanted: %v, got %v", want, got) - } - }) - - t.Run("expired because of the threshold", func(t *testing.T) { - d := time.Now().Add(-2 * time.Minute) - if want, got := true, isExpired(d, 5*time.Minute); want != got { - t.Fatalf("wanted: %v, got %v", want, got) - } - }) - - t.Run("is not expired", func(t *testing.T) { - d := time.Now().Add(10 * time.Minute) - if want, got := false, isExpired(d, 5*time.Minute); want != got { - t.Fatalf("wanted: %v, got %v", want, got) - } - }) +func TestTenant_HasExpiredToken(t *testing.T) { + var testCases = []struct { + name string + givenTime time.Time + expectedTokenToBeExpired bool + }{ + { + name: "is expired", + givenTime: time.Date(2021, 01, 01, 10, 30, 30, 0, time.UTC), + expectedTokenToBeExpired: true, + }, + { + name: "expired because of the threshold", + givenTime: time.Now().Add(-2 * time.Minute), + expectedTokenToBeExpired: true, + }, + { + name: "is not expired", + givenTime: time.Now().Add(10 * time.Minute), + expectedTokenToBeExpired: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + tenant := Tenant{ExpiresAt: testCase.givenTime} + assert.Equal(t, testCase.expectedTokenToBeExpired, tenant.hasExpiredToken()) + }) + } } // TODO(cyx): think about whether we should extract this function in the diff --git a/internal/cli/login.go b/internal/cli/login.go index c5f7d598a..2bdac5460 100644 --- a/internal/cli/login.go +++ b/internal/cli/login.go @@ -3,7 +3,6 @@ package cli import ( "context" "fmt" - "time" "github.com/pkg/browser" "github.com/spf13/cobra" @@ -109,10 +108,8 @@ func RunLogin(ctx context.Context, cli *cli, expired bool) (Tenant, error) { Name: result.Tenant, Domain: result.Domain, AccessToken: result.AccessToken, - ExpiresAt: time.Now().Add( - time.Duration(result.ExpiresIn) * time.Second, - ), - Scopes: auth.RequiredScopes(), + ExpiresAt: result.ExpiresAt, + Scopes: auth.RequiredScopes(), } err = cli.addTenant(tenant) diff --git a/internal/cli/tenants.go b/internal/cli/tenants.go index f199d62c9..b028da9b0 100644 --- a/internal/cli/tenants.go +++ b/internal/cli/tenants.go @@ -6,6 +6,7 @@ import ( "github.com/spf13/cobra" + "github.com/auth0/auth0-cli/internal/auth" "github.com/auth0/auth0-cli/internal/prompt" ) @@ -188,14 +189,25 @@ func addTenantCmd(cli *cli) *cobra.Command { return err } + token, err := auth.GetAccessTokenFromClientCreds(auth.ClientCredentials{ + ClientID: inputs.ClientID, + ClientSecret: inputs.ClientSecret, + Domain: inputs.Domain, + }) + if err != nil { + return err + } + t := Tenant{ Domain: inputs.Domain, + AccessToken: token.AccessToken, + ExpiresAt: token.ExpiresAt, ClientID: inputs.ClientID, ClientSecret: inputs.ClientSecret, } if err := cli.addTenant(t); err != nil { - return err + return fmt.Errorf("unexpected error when attempting to save tenant data: %w", err) } cli.renderer.Infof("Tenant added successfully: %s", t.Domain)