Skip to content

Commit

Permalink
DXCDT-420: Add unit tests to cover internal auth functionality (#707)
Browse files Browse the repository at this point in the history
Add unit tests to cover internal auth functionality
  • Loading branch information
ewanharris authored Apr 3, 2023
1 parent 927259a commit 7dce9c9
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 6 deletions.
8 changes: 4 additions & 4 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ var credentials = &Credentials{
}

// WaitUntilUserLogsIn waits until the user is logged in on the browser.
func WaitUntilUserLogsIn(ctx context.Context, state State) (Result, error) {
func WaitUntilUserLogsIn(ctx context.Context, httpClient *http.Client, state State) (Result, error) {
t := time.NewTicker(state.IntervalDuration())
for {
select {
Expand All @@ -69,7 +69,7 @@ func WaitUntilUserLogsIn(ctx context.Context, state State) (Result, error) {
"grant_type": []string{"urn:ietf:params:oauth:grant-type:device_code"},
"device_code": []string{state.DeviceCode},
}
r, err := http.PostForm(credentials.OauthTokenEndpoint, data)
r, err := httpClient.PostForm(credentials.OauthTokenEndpoint, data)
if err != nil {
return Result{}, fmt.Errorf("cannot get device code: %w", err)
}
Expand Down Expand Up @@ -119,7 +119,7 @@ func WaitUntilUserLogsIn(ctx context.Context, state State) (Result, error) {
// GetDeviceCode kicks-off the device authentication flow by requesting
// a device code from Auth0. The returned state contains the
// URI for the next step of the flow.
func GetDeviceCode(ctx context.Context, additionalScopes []string) (State, error) {
func GetDeviceCode(ctx context.Context, httpClient *http.Client, additionalScopes []string) (State, error) {
a := credentials

data := url.Values{
Expand All @@ -140,7 +140,7 @@ func GetDeviceCode(ctx context.Context, additionalScopes []string) (State, error

request.Header.Set("Content-Type", "application/x-www-form-urlencoded")

response, err := http.DefaultClient.Do(request)
response, err := httpClient.Do(request)
if err != nil {
return State{}, fmt.Errorf("failed to send the request: %w", err)
}
Expand Down
226 changes: 226 additions & 0 deletions internal/auth/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
package auth

import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestWaitUntilUserLogsIn(t *testing.T) {
state := State{
"1234",
"12345",
"https://example.com/12345",
1000,
1,
}

t.Run("successfully waits and handles response", func(t *testing.T) {
counter := 0
tokenResponse := `{
"access_token": "Zm9v.eyJhdWQiOiBbImh0dHBzOi8vYXV0aDAtY2xpLXRlc3QudXMuYXV0aDAuY29tL2FwaS92Mi8iXX0",
"id_token": "id-token-here",
"refresh_token": "refresh-token-here",
"scope": "scope-here",
"token_type": "token-type-here",
"expires_in": 1000
}`
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if counter < 1 {
io.WriteString(w, `{
"error": "authorization_pending",
"error_description": "still pending auth"
}`)
} else {
io.WriteString(w, tokenResponse)
}
counter++
}))

defer ts.Close()

parsedURL, err := url.Parse(ts.URL)
assert.NoError(t, err)
u := url.URL{Scheme: "https", Host: parsedURL.Host, Path: "/oauth/token"}
credentials.OauthTokenEndpoint = u.String()

result, err := WaitUntilUserLogsIn(context.Background(), ts.Client(), state)

assert.NoError(t, err)
assert.Equal(t, "auth0-cli-test", result.Tenant)
assert.Equal(t, "auth0-cli-test.us.auth0.com", result.Domain)
})

testCases := []struct {
name string
httpStatus int
response string
expect string
}{
{
name: "handle malformed JSON",
httpStatus: http.StatusOK,
response: "foo",
expect: "cannot decode response: invalid character 'o' in literal false (expecting 'a')",
},
{
name: "should pass through authorization server errors",
httpStatus: http.StatusOK,
response: "{\"error\": \"slow_down\", \"error_description\": \"slow down!\"}",
expect: "slow down!",
},
{
name: "should error if can't parse tenant info",
httpStatus: http.StatusOK,
response: "{\"access_token\": \"bad.token\"}",
expect: "cannot parse tenant from the given access token: illegal base64 data at input byte 4",
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(testCase.httpStatus)
if testCase.response != "" {
io.WriteString(w, testCase.response)
}
}))

defer ts.Close()

parsedURL, err := url.Parse(ts.URL)
assert.NoError(t, err)
u := url.URL{Scheme: "https", Host: parsedURL.Host, Path: "/oauth/token"}
credentials.OauthTokenEndpoint = u.String()

_, err = WaitUntilUserLogsIn(context.Background(), ts.Client(), state)

assert.EqualError(t, err, testCase.expect)
})
}
}

func TestGetDeviceCode(t *testing.T) {
t.Run("successfully retrieve state from response", func(t *testing.T) {
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
io.WriteString(w, `{
"device_code": "device-code-here",
"user_code": "user-code-here",
"verification_uri_complete": "verification-uri-here",
"expires_in": 1000,
"interval": 1
}`)
}))

defer ts.Close()

parsedURL, err := url.Parse(ts.URL)
assert.NoError(t, err)
u := url.URL{Scheme: "https", Host: parsedURL.Host, Path: "/oauth/device/code"}
credentials.DeviceCodeEndpoint = u.String()

state, err := GetDeviceCode(context.Background(), ts.Client(), []string{})

assert.NoError(t, err)
assert.Equal(t, "device-code-here", state.DeviceCode)
assert.Equal(t, "user-code-here", state.UserCode)
assert.Equal(t, "verification-uri-here", state.VerificationURI)
assert.Equal(t, 1000, state.ExpiresIn)
assert.Equal(t, 1, state.Interval)
assert.Equal(t, time.Duration(4000000000), state.IntervalDuration())
})

testCases := []struct {
name string
httpStatus int
response string
expect string
}{
{
name: "handle HTTP status errors",
httpStatus: http.StatusNotFound,
response: "Test response return",
expect: "received a 404 response: Test response return",
},
{
name: "handle bad JSON response",
httpStatus: http.StatusOK,
response: "foo",
expect: "failed to decode the response: invalid character 'o' in literal false (expecting 'a')",
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(testCase.httpStatus)
if testCase.response != "" {
io.WriteString(w, testCase.response)
}
}))

defer ts.Close()

parsedURL, err := url.Parse(ts.URL)
assert.NoError(t, err)
u := url.URL{Scheme: "https", Host: parsedURL.Host, Path: "/oauth/device/code"}
credentials.DeviceCodeEndpoint = u.String()

_, err = GetDeviceCode(context.Background(), ts.Client(), []string{})

assert.EqualError(t, err, testCase.expect)
})
}
}

func TestParseTenant(t *testing.T) {
t.Run("Successfully parse tenant and domain", func(t *testing.T) {
tenant, domain, err := parseTenant("Zm9v.eyJhdWQiOiBbImh0dHBzOi8vYXV0aDAtY2xpLXRlc3QudXMuYXV0aDAuY29tL2FwaS92Mi8iXX0")
assert.NoError(t, err)
assert.Equal(t, "auth0-cli-test", tenant)
assert.Equal(t, "auth0-cli-test.us.auth0.com", domain)
})

testCases := []struct {
name string
accessToken string
err string
}{
{
name: "bad base64 encoding",
accessToken: "bad.token.foo",
err: "illegal base64 data at input byte 4",
},
{
name: "bad json encoding",
accessToken: "Zm9v.Zm9v", // foo encoded in base64
err: "invalid character 'o' in literal false (expecting 'a')",
},
{
name: "invalid URL in aud array",
accessToken: "Zm9v.eyJhdWQiOiBbIjpleGFtcGxlLmNvbSJdfQ",
err: "parse \":example.com\": missing protocol scheme",
},
{
name: "no matching URL aud array",
accessToken: "Zm9v.eyJhdWQiOiBbImh0dHBzOi8vZXhhbXBsZXMuY29tIl19",
err: "audience not found for /api/v2/",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
tenant, domain, err := parseTenant(testCase.accessToken)
assert.EqualError(t, err, testCase.err)
assert.Equal(t, "", tenant)
assert.Equal(t, "", domain)
})
}
}
5 changes: 3 additions & 2 deletions internal/cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cli
import (
"context"
"fmt"
"net/http"
"strings"

"github.com/pkg/browser"
Expand Down Expand Up @@ -155,7 +156,7 @@ func loginCmd(cli *cli) *cobra.Command {
// RunLoginAsUser runs the login flow guiding the user through the process
// by showing the login instructions, opening the browser.
func RunLoginAsUser(ctx context.Context, cli *cli, additionalScopes []string) (Tenant, error) {
state, err := auth.GetDeviceCode(ctx, additionalScopes)
state, err := auth.GetDeviceCode(ctx, http.DefaultClient, additionalScopes)
if err != nil {
return Tenant{}, fmt.Errorf("failed to get the device code: %w", err)
}
Expand Down Expand Up @@ -184,7 +185,7 @@ func RunLoginAsUser(ctx context.Context, cli *cli, additionalScopes []string) (T

var result auth.Result
err = ansi.Spinner("Waiting for the login to complete in the browser", func() error {
result, err = auth.WaitUntilUserLogsIn(ctx, state)
result, err = auth.WaitUntilUserLogsIn(ctx, http.DefaultClient, state)
return err
})
if err != nil {
Expand Down

0 comments on commit 7dce9c9

Please sign in to comment.