Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Back-merging main into v1 #543

Merged
merged 1 commit into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"time"

"github.com/joeshaw/envdecode"
"golang.org/x/oauth2/clientcredentials"
)

const (
Expand Down Expand Up @@ -65,7 +66,7 @@ type Result struct {
Domain string
RefreshToken string
AccessToken string
ExpiresIn int64
ExpiresAt time.Time
}

type State struct {
Expand Down Expand Up @@ -170,9 +171,11 @@ func (a *Authenticator) Wait(ctx context.Context, state State) (Result, error) {
return Result{
RefreshToken: res.RefreshToken,
AccessToken: res.AccessToken,
ExpiresIn: res.ExpiresIn,
Tenant: ten,
Domain: domain,
ExpiresAt: time.Now().Add(
time.Duration(res.ExpiresIn) * time.Second,
),
Tenant: ten,
Domain: domain,
}, nil
}
}
Expand Down Expand Up @@ -249,3 +252,39 @@ func parseTenant(accessToken string) (tenant, domain string, err error) {
}
return "", "", fmt.Errorf("audience not found for %s", audiencePath)
}

// ClientCredentials encapsulates all data to facilitate access token creation with client credentials (client ID and client secret)
type ClientCredentials struct {
ClientID string
ClientSecret string
Domain string
}

// GetAccessTokenFromClientCreds generates an access token from client credentials
func GetAccessTokenFromClientCreds(args ClientCredentials) (Result, error) {
u, err := url.Parse("https://" + args.Domain)
if err != nil {
return Result{}, err
}

credsConfig := &clientcredentials.Config{
ClientID: args.ClientID,
ClientSecret: args.ClientSecret,
TokenURL: u.String() + "/oauth/token",
EndpointParams: url.Values{
"client_id": {args.ClientID},
"scope": {strings.Join(RequiredScopesMin(), " ")},
"audience": {u.String() + "/api/v2/"},
},
}

resp, err := credsConfig.Token(context.Background())
if err != nil {
return Result{}, err
}

return Result{
AccessToken: resp.AccessToken,
ExpiresAt: resp.Expiry,
}, nil
}
126 changes: 67 additions & 59 deletions internal/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,54 @@ type cli struct {
config config
}

func (t *Tenant) authenticatedWithClientCredentials() bool {
return t.ClientID != "" && t.ClientSecret != ""
}

func (t *Tenant) authenticatedWithDeviceCodeFlow() bool {
return t.ClientID == "" && t.ClientSecret == ""
}

func (t *Tenant) hasExpiredToken() bool {
return time.Now().Add(accessTokenExpThreshold).After(t.ExpiresAt)
}

func (t *Tenant) regenerateAccessToken(ctx context.Context, c *cli) error {
if t.authenticatedWithClientCredentials() {
token, err := auth.GetAccessTokenFromClientCreds(auth.ClientCredentials{
ClientID: t.ClientID,
ClientSecret: t.ClientSecret,
Domain: t.Domain,
})
if err != nil {
return err
}

t.AccessToken = token.AccessToken
t.ExpiresAt = token.ExpiresAt
}

if t.authenticatedWithDeviceCodeFlow() {
tokenRetriever := &auth.TokenRetriever{
Authenticator: c.authenticator,
Secrets: &auth.Keyring{},
Client: http.DefaultClient,
}

tokenResponse, err := tokenRetriever.Refresh(ctx, t.Domain)
if err != nil {
return err
}

t.AccessToken = tokenResponse.AccessToken
t.ExpiresAt = time.Now().Add(
time.Duration(tokenResponse.ExpiresIn) * time.Second,
)
}

return nil
}

// isLoggedIn encodes the domain logic for determining whether or not we're
// logged in. This might check our config storage, or just in memory.
func (c *cli) isLoggedIn() bool {
Expand Down Expand Up @@ -133,28 +181,18 @@ func (c *cli) setup(ctx context.Context) error {
return err
}

var (
m *management.Management
ua = fmt.Sprintf("%v/%v", userAgent, strings.TrimPrefix(buildinfo.Version, "v"))
)

if t.ClientID != "" && t.ClientSecret != "" {
m, err = management.New(t.Domain,
management.WithClientCredentials(t.ClientID, t.ClientSecret),
management.WithUserAgent(ua),
)
} else {
m, err = management.New(t.Domain,
management.WithStaticToken(t.AccessToken),
management.WithUserAgent(ua),
)
}
userAgent := fmt.Sprintf("%v/%v", userAgent, strings.TrimPrefix(buildinfo.Version, "v"))

api, err := management.New(
t.Domain,
management.WithStaticToken(t.AccessToken),
management.WithUserAgent(userAgent),
)
if err != nil {
return err
}

c.api = auth0.NewAPI(m)
c.api = auth0.NewAPI(api)
return nil
}

Expand All @@ -168,57 +206,27 @@ func (c *cli) prepareTenant(ctx context.Context) (Tenant, error) {
return Tenant{}, err
}

if t.ClientID != "" && t.ClientSecret != "" {
if t.AccessToken == "" || (scopesChanged(t) && t.authenticatedWithDeviceCodeFlow()) {
return RunLogin(ctx, c, true)
}

if !t.hasExpiredToken() {
return t, nil
}

if t.AccessToken == "" || scopesChanged(t) {
t, err = RunLogin(ctx, c, true)
if err != nil {
return Tenant{}, err
}
} else if isExpired(t.ExpiresAt, accessTokenExpThreshold) {
// check if the stored access token is expired:
// use the refresh token to get a new access token:
tr := &auth.TokenRetriever{
Authenticator: c.authenticator,
Secrets: &auth.Keyring{},
Client: http.DefaultClient,
}
if err := t.regenerateAccessToken(ctx, c); err != nil {
// Ask and guide the user through the login process.
c.renderer.Errorf("failed to renew access token, %s", err)
return RunLogin(ctx, c, true)
}

// NOTE(cyx): this code will have to be adapted to instead
// maybe take the clientID/secret as additional params, or
// something similar.
res, err := tr.Refresh(ctx, t.Domain)
if err != nil {
// ask and guide the user through the login process:
c.renderer.Errorf("failed to renew access token, %s", err)
t, err = RunLogin(ctx, c, true)
if err != nil {
return Tenant{}, err
}
} else {
// persist the updated tenant with renewed access token
t.AccessToken = res.AccessToken
t.ExpiresAt = time.Now().Add(
time.Duration(res.ExpiresIn) * time.Second,
)

err = c.addTenant(t)
if err != nil {
return Tenant{}, err
}
}
if err := c.addTenant(t); err != nil {
return Tenant{}, fmt.Errorf("unexpected error adding tenant to config: %w", err)
}

return t, nil
}

// isExpired is true if now() + a threshold is after the given date.
func isExpired(t time.Time, threshold time.Duration) bool {
return time.Now().Add(threshold).After(t)
}

// scopesChanged compare the tenant scopes
// with the currently required scopes.
func scopesChanged(t Tenant) bool {
Expand Down
50 changes: 29 additions & 21 deletions internal/cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,35 @@ import (
"github.com/auth0/auth0-cli/internal/display"
)

func TestIsExpired(t *testing.T) {
t.Run("is expired", func(t *testing.T) {
d := time.Date(2021, 01, 01, 10, 30, 30, 0, time.UTC)
if want, got := true, isExpired(d, 1*time.Minute); want != got {
t.Fatalf("wanted: %v, got %v", want, got)
}
})

t.Run("expired because of the threshold", func(t *testing.T) {
d := time.Now().Add(-2 * time.Minute)
if want, got := true, isExpired(d, 5*time.Minute); want != got {
t.Fatalf("wanted: %v, got %v", want, got)
}
})

t.Run("is not expired", func(t *testing.T) {
d := time.Now().Add(10 * time.Minute)
if want, got := false, isExpired(d, 5*time.Minute); want != got {
t.Fatalf("wanted: %v, got %v", want, got)
}
})
func TestTenant_HasExpiredToken(t *testing.T) {
var testCases = []struct {
name string
givenTime time.Time
expectedTokenToBeExpired bool
}{
{
name: "is expired",
givenTime: time.Date(2021, 01, 01, 10, 30, 30, 0, time.UTC),
expectedTokenToBeExpired: true,
},
{
name: "expired because of the threshold",
givenTime: time.Now().Add(-2 * time.Minute),
expectedTokenToBeExpired: true,
},
{
name: "is not expired",
givenTime: time.Now().Add(10 * time.Minute),
expectedTokenToBeExpired: false,
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
tenant := Tenant{ExpiresAt: testCase.givenTime}
assert.Equal(t, testCase.expectedTokenToBeExpired, tenant.hasExpiredToken())
})
}
}

// TODO(cyx): think about whether we should extract this function in the
Expand Down
7 changes: 2 additions & 5 deletions internal/cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cli
import (
"context"
"fmt"
"time"

"github.com/pkg/browser"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -109,10 +108,8 @@ func RunLogin(ctx context.Context, cli *cli, expired bool) (Tenant, error) {
Name: result.Tenant,
Domain: result.Domain,
AccessToken: result.AccessToken,
ExpiresAt: time.Now().Add(
time.Duration(result.ExpiresIn) * time.Second,
),
Scopes: auth.RequiredScopes(),
ExpiresAt: result.ExpiresAt,
Scopes: auth.RequiredScopes(),
}

err = cli.addTenant(tenant)
Expand Down
14 changes: 13 additions & 1 deletion internal/cli/tenants.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/spf13/cobra"

"github.com/auth0/auth0-cli/internal/auth"
"github.com/auth0/auth0-cli/internal/prompt"
)

Expand Down Expand Up @@ -188,14 +189,25 @@ func addTenantCmd(cli *cli) *cobra.Command {
return err
}

token, err := auth.GetAccessTokenFromClientCreds(auth.ClientCredentials{
ClientID: inputs.ClientID,
ClientSecret: inputs.ClientSecret,
Domain: inputs.Domain,
})
if err != nil {
return err
}

t := Tenant{
Domain: inputs.Domain,
AccessToken: token.AccessToken,
ExpiresAt: token.ExpiresAt,
ClientID: inputs.ClientID,
ClientSecret: inputs.ClientSecret,
}

if err := cli.addTenant(t); err != nil {
return err
return fmt.Errorf("unexpected error when attempting to save tenant data: %w", err)
}

cli.renderer.Infof("Tenant added successfully: %s", t.Domain)
Expand Down