Skip to content

Commit

Permalink
Implement MFA authentication APIs (#331)
Browse files Browse the repository at this point in the history
Co-authored-by: Rita Zerrizuela <[email protected]>
  • Loading branch information
ewanharris and Widcket authored Dec 13, 2023
1 parent 8c3c25e commit fc6fe15
Show file tree
Hide file tree
Showing 15 changed files with 837 additions and 140 deletions.
138 changes: 138 additions & 0 deletions authentication/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,20 @@ package authentication
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"reflect"
"strings"
"time"

"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"

"github.com/auth0/go-auth0/authentication/oauth"
"github.com/auth0/go-auth0/internal/client"
"github.com/auth0/go-auth0/internal/idtokenvalidator"
)
Expand Down Expand Up @@ -112,6 +120,7 @@ func (u *UserInfoResponse) UnmarshalJSON(b []byte) error {
// Authentication is the auth client.
type Authentication struct {
Database *Database
MFA *MFA
OAuth *OAuth
Passwordless *Passwordless

Expand Down Expand Up @@ -172,6 +181,7 @@ func New(ctx context.Context, domain string, options ...Option) (*Authentication

a.common.authentication = a
a.Database = (*Database)(&a.common)
a.MFA = (*MFA)(&a.common)
a.OAuth = (*OAuth)(&a.common)
a.Passwordless = (*Passwordless)(&a.common)

Expand Down Expand Up @@ -214,3 +224,131 @@ func (a *Authentication) UserInfo(ctx context.Context, accessToken string, opts
err = a.Request(ctx, "GET", a.URI("userinfo"), nil, &user, opts...)
return
}

// Helper for adding values to a url.Values instance if they are not empty.
func addIfNotEmpty(key string, value string, qs url.Values) {
if value != "" {
qs.Set(key, value)
}
}

// Helper for enforcing that required values are set.
func check(errors *[]string, key string, c bool) {
if !c {
*errors = append(*errors, key)
}
}

// Helper for adding client authentication into a url.Values instance.
func (a *Authentication) addClientAuthenticationToURLValues(params oauth.ClientAuthentication, body url.Values, required bool) error {
clientID := params.ClientID
if params.ClientID == "" {
clientID = a.clientID
}
body.Set("client_id", clientID)

clientSecret := params.ClientSecret
if params.ClientSecret == "" {
clientSecret = a.clientSecret
}

switch {
case a.clientAssertionSigningKey != "" && a.clientAssertionSigningAlg != "":
clientAssertion, err := createClientAssertion(
a.clientAssertionSigningAlg,
a.clientAssertionSigningKey,
clientID,
a.url.JoinPath("/").String(),
)
if err != nil {
return err
}

body.Set("client_assertion", clientAssertion)
body.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
break
case params.ClientAssertion != "" && params.ClientAssertionType != "":
body.Set("client_assertion", params.ClientAssertion)
body.Set("client_assertion_type", params.ClientAssertionType)
break
case clientSecret != "":
body.Set("client_secret", clientSecret)
break
}

if required && (body.Get("client_secret") == "" && body.Get("client_assertion") == "") {
return errors.New("client_secret or client_assertion is required but not provided")
}

return nil
}

// Helper for adding client authentication to an oauth.ClientAuthentication struct.
func (a *Authentication) addClientAuthenticationToClientAuthStruct(params *oauth.ClientAuthentication, required bool) error {
if params.ClientID == "" {
params.ClientID = a.clientID
}

if a.clientAssertionSigningKey != "" && a.clientAssertionSigningAlg != "" {
clientAssertion, err := createClientAssertion(
a.clientAssertionSigningAlg,
a.clientAssertionSigningKey,
params.ClientID,
a.url.JoinPath("/").String(),
)
if err != nil {
return err
}

params.ClientAssertion = clientAssertion
params.ClientAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
} else if params.ClientSecret == "" && a.clientSecret != "" {
params.ClientSecret = a.clientSecret
}

if required && (params.ClientSecret == "" && params.ClientAssertion == "") {
return errors.New("client_secret or client_assertion is required but not provided")
}

return nil
}

func determineAlg(alg string) (jwa.SignatureAlgorithm, error) {
switch alg {
case "RS256":
return jwa.RS256, nil
default:
return "", fmt.Errorf("Unsupported client assertion algorithm \"%s\" provided", alg)
}
}

func createClientAssertion(clientAssertionSigningAlg, clientAssertionSigningKey, clientID, domain string) (string, error) {
alg, err := determineAlg(clientAssertionSigningAlg)
if err != nil {
return "", err
}

key, err := jwk.ParseKey([]byte(clientAssertionSigningKey), jwk.WithPEM(true))
if err != nil {
return "", err
}

token, err := jwt.NewBuilder().
IssuedAt(time.Now()).
Subject(clientID).
JwtID(uuid.New().String()).
Issuer(clientID).
Audience([]string{domain}).
Expiration(time.Now().Add(2 * time.Minute)).
Build()
if err != nil {
return "", err
}

b, err := jwt.Sign(token, jwt.WithKey(alg, key))
if err != nil {
return "", err
}

return string(b), nil
}
4 changes: 3 additions & 1 deletion authentication/authentication_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ func newError(response *http.Response) error {
// If that happens we still want to display the correct code.
if apiError.Status() == 0 {
apiError.StatusCode = response.StatusCode
apiError.Err = http.StatusText(response.StatusCode)
if apiError.Err == "" {
apiError.Err = http.StatusText(response.StatusCode)
}
}

return apiError
Expand Down
10 changes: 8 additions & 2 deletions authentication/http_recordings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"

"github.com/auth0/go-auth0/authentication/oauth"
"github.com/auth0/go-auth0/authentication/mfa"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -189,7 +189,13 @@ func redactTokens(t *testing.T, i *cassette.Interaction) {
return
}

tokenSet := &oauth.TokenSet{}
if i.Response.Code >= http.StatusBadRequest {
return
}

// We use mfa.VerifyWithRecoveryCodeResponse here as we don't want to lose the RecoveryCode
// property when anonymizing the tokenset
tokenSet := &mfa.VerifyWithRecoveryCodeResponse{}

err := json.Unmarshal([]byte(i.Response.Body), tokenSet)
require.NoError(t, err)
Expand Down
137 changes: 137 additions & 0 deletions authentication/mfa.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package authentication

import (
"context"
"fmt"
"net/url"
"strings"

"github.com/auth0/go-auth0/authentication/mfa"
"github.com/auth0/go-auth0/authentication/oauth"
)

// MFA exposes requesting an MFA challenge and verifying MFA methods.
type MFA manager

// Challenge requests a challenge for multi-factor authentication (MFA) based on the challenge types supported by the application and user.
//
// See: https://auth0.com/docs/api/authentication#challenge-request
func (m *MFA) Challenge(ctx context.Context, body mfa.ChallengeRequest, opts ...RequestOption) (c *mfa.ChallengeResponse, err error) {
missing := []string{}
check(&missing, "ClientID", (body.ClientID != "" || m.authentication.clientID != ""))
check(&missing, "MFAToken", body.MFAToken != "")
check(&missing, "ChallengeType", body.ChallengeType != "")

if len(missing) > 0 {
return nil, fmt.Errorf("Missing required fields: %s", strings.Join(missing, ", "))
}

err = m.authentication.addClientAuthenticationToClientAuthStruct(&body.ClientAuthentication, false)

if err != nil {
return nil, err
}

err = m.authentication.Request(ctx, "POST", m.authentication.URI("mfa", "challenge"), body, &c, opts...)

if err != nil {
return nil, err
}

return
}

// VerifyWithOTP verifies an MFA challenge using a one-time password (OTP).
//
// See: https://auth0.com/docs/api/authentication#verify-with-one-time-password-otp-
func (m *MFA) VerifyWithOTP(ctx context.Context, body mfa.VerifyWithOTPRequest, opts ...RequestOption) (t *oauth.TokenSet, err error) {
missing := []string{}
check(&missing, "ClientID", (body.ClientID != "" || m.authentication.clientID != ""))
check(&missing, "MFAToken", body.MFAToken != "")
check(&missing, "OTP", body.OTP != "")

if len(missing) > 0 {
return nil, fmt.Errorf("Missing required fields: %s", strings.Join(missing, ", "))
}

data := url.Values{
"mfa_token": []string{body.MFAToken},
"grant_type": []string{"http://auth0.com/oauth/grant-type/mfa-otp"},
"otp": []string{body.OTP},
}

err = m.authentication.addClientAuthenticationToURLValues(body.ClientAuthentication, data, true)

if err != nil {
return nil, err
}

err = m.authentication.Request(ctx, "POST", m.authentication.URI("oauth", "token"), data, &t, opts...)

return
}

// VerifyWithOOB verifies an MFA challenge using an out-of-band challenge (OOB), either push notification,
// SMS, or voice.
//
// See: https://auth0.com/docs/api/authentication#verify-with-out-of-band-oob-
func (m *MFA) VerifyWithOOB(ctx context.Context, body mfa.VerifyWithOOBRequest, opts ...RequestOption) (t *oauth.TokenSet, err error) {
missing := []string{}
check(&missing, "ClientID", (body.ClientID != "" || m.authentication.clientID != ""))
check(&missing, "MFAToken", body.MFAToken != "")
check(&missing, "OOBCode", body.OOBCode != "")

if len(missing) > 0 {
return nil, fmt.Errorf("Missing required fields: %s", strings.Join(missing, ", "))
}

data := url.Values{
"mfa_token": []string{body.MFAToken},
"grant_type": []string{"http://auth0.com/oauth/grant-type/mfa-oob"},
"oob_code": []string{body.OOBCode},
}

if body.BindingCode != "" {
data.Set("binding_code", body.BindingCode)
}

err = m.authentication.addClientAuthenticationToURLValues(body.ClientAuthentication, data, true)

if err != nil {
return nil, err
}

err = m.authentication.Request(ctx, "POST", m.authentication.URI("oauth", "token"), data, &t, opts...)

return
}

// VerifyWithRecoveryCode verifies an MFA challenge using a recovery code.
//
// See: https://auth0.com/docs/api/authentication#verify-with-recovery-code
func (m *MFA) VerifyWithRecoveryCode(ctx context.Context, body mfa.VerifyWithRecoveryCodeRequest, opts ...RequestOption) (t *mfa.VerifyWithRecoveryCodeResponse, err error) {
missing := []string{}
check(&missing, "ClientID", (body.ClientID != "" || m.authentication.clientID != ""))
check(&missing, "MFAToken", body.MFAToken != "")
check(&missing, "RecoveryCode", body.RecoveryCode != "")

if len(missing) > 0 {
return nil, fmt.Errorf("Missing required fields: %s", strings.Join(missing, ", "))
}

data := url.Values{
"mfa_token": []string{body.MFAToken},
"grant_type": []string{"http://auth0.com/oauth/grant-type/mfa-recovery-code"},
"recovery_code": []string{body.RecoveryCode},
}

err = m.authentication.addClientAuthenticationToURLValues(body.ClientAuthentication, data, true)

if err != nil {
return nil, err
}

err = m.authentication.Request(ctx, "POST", m.authentication.URI("oauth", "token"), data, &t, opts...)

return
}
Loading

0 comments on commit fc6fe15

Please sign in to comment.