Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement MFA authentication APIs #331

Merged
merged 9 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions authentication/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,20 @@ package authentication
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"reflect"
"strings"
"time"

"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"

"github.com/auth0/go-auth0/authentication/oauth"
"github.com/auth0/go-auth0/internal/client"
"github.com/auth0/go-auth0/internal/idtokenvalidator"
)
Expand Down Expand Up @@ -112,6 +120,7 @@ func (u *UserInfoResponse) UnmarshalJSON(b []byte) error {
// Authentication is the auth client.
type Authentication struct {
Database *Database
MFA *MFA
OAuth *OAuth
Passwordless *Passwordless

Expand Down Expand Up @@ -172,6 +181,7 @@ func New(ctx context.Context, domain string, options ...Option) (*Authentication

a.common.authentication = a
a.Database = (*Database)(&a.common)
a.MFA = (*MFA)(&a.common)
a.OAuth = (*OAuth)(&a.common)
a.Passwordless = (*Passwordless)(&a.common)

Expand Down Expand Up @@ -214,3 +224,131 @@ func (a *Authentication) UserInfo(ctx context.Context, accessToken string, opts
err = a.Request(ctx, "GET", a.URI("userinfo"), nil, &user, opts...)
return
}

// Helper for adding values to a url.Values instance if they are not empty.
func addIfNotEmpty(key string, value string, qs url.Values) {
if value != "" {
qs.Set(key, value)
}
}

// Helper for enforcing that required values are set.
func check(errors *[]string, key string, c bool) {
if !c {
*errors = append(*errors, key)
}
}

// Helper for adding client authentication into a url.Values instance.
func (a *Authentication) addClientAuthenticationToURLValues(params oauth.ClientAuthentication, body url.Values, required bool) error {
clientID := params.ClientID
if params.ClientID == "" {
clientID = a.clientID
}
body.Set("client_id", clientID)

clientSecret := params.ClientSecret
if params.ClientSecret == "" {
clientSecret = a.clientSecret
}

switch {
case a.clientAssertionSigningKey != "" && a.clientAssertionSigningAlg != "":
clientAssertion, err := createClientAssertion(
a.clientAssertionSigningAlg,
a.clientAssertionSigningKey,
clientID,
a.url.JoinPath("/").String(),
)
if err != nil {
return err
}

body.Set("client_assertion", clientAssertion)
body.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
break
case params.ClientAssertion != "" && params.ClientAssertionType != "":
body.Set("client_assertion", params.ClientAssertion)
body.Set("client_assertion_type", params.ClientAssertionType)
Copy link
Contributor

@Widcket Widcket Dec 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check that params.ClientAssertionType is not empty? Like e.g.

case params.ClientAssertion != "" && params.ClientAssertionType != "":

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that makes sense, will fix

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in 716bdad and also took the opportunity to extend addClientAuthenticationToClientAuthStruct to allow specifying if a client secret/client assertion is required

break
case clientSecret != "":
body.Set("client_secret", clientSecret)
break
}

if required && (body.Get("client_secret") == "" && body.Get("client_assertion") == "") {
return errors.New("client_secret or client_assertion is required but not provided")
}

return nil
}

// Helper for adding client authentication to an oauth.ClientAuthentication struct.
func (a *Authentication) addClientAuthenticationToClientAuthStruct(params *oauth.ClientAuthentication, required bool) error {
if params.ClientID == "" {
params.ClientID = a.clientID
}

if a.clientAssertionSigningKey != "" && a.clientAssertionSigningAlg != "" {
clientAssertion, err := createClientAssertion(
a.clientAssertionSigningAlg,
a.clientAssertionSigningKey,
params.ClientID,
a.url.JoinPath("/").String(),
)
if err != nil {
return err
}

params.ClientAssertion = clientAssertion
params.ClientAssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
} else if params.ClientSecret == "" && a.clientSecret != "" {
params.ClientSecret = a.clientSecret
}

if required && (params.ClientSecret == "" && params.ClientAssertion == "") {
return errors.New("client_secret or client_assertion is required but not provided")
}

return nil
}

func determineAlg(alg string) (jwa.SignatureAlgorithm, error) {
switch alg {
case "RS256":
return jwa.RS256, nil
default:
return "", fmt.Errorf("Unsupported client assertion algorithm \"%s\" provided", alg)
}
}

func createClientAssertion(clientAssertionSigningAlg, clientAssertionSigningKey, clientID, domain string) (string, error) {
alg, err := determineAlg(clientAssertionSigningAlg)
if err != nil {
return "", err
}

key, err := jwk.ParseKey([]byte(clientAssertionSigningKey), jwk.WithPEM(true))
if err != nil {
return "", err
}

token, err := jwt.NewBuilder().
IssuedAt(time.Now()).
Subject(clientID).
JwtID(uuid.New().String()).
Issuer(clientID).
Audience([]string{domain}).
Expiration(time.Now().Add(2 * time.Minute)).
Build()
if err != nil {
return "", err
}

b, err := jwt.Sign(token, jwt.WithKey(alg, key))
if err != nil {
return "", err
}

return string(b), nil
}
4 changes: 3 additions & 1 deletion authentication/authentication_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ func newError(response *http.Response) error {
// If that happens we still want to display the correct code.
if apiError.Status() == 0 {
apiError.StatusCode = response.StatusCode
apiError.Err = http.StatusText(response.StatusCode)
if apiError.Err == "" {
apiError.Err = http.StatusText(response.StatusCode)
}
}

return apiError
Expand Down
10 changes: 8 additions & 2 deletions authentication/http_recordings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"

"github.com/auth0/go-auth0/authentication/oauth"
"github.com/auth0/go-auth0/authentication/mfa"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -189,7 +189,13 @@ func redactTokens(t *testing.T, i *cassette.Interaction) {
return
}

tokenSet := &oauth.TokenSet{}
if i.Response.Code >= http.StatusBadRequest {
return
}

// We use mfa.VerifyWithRecoveryCodeResponse here as we don't want to lose the RecoveryCode
// property when anonymizing the tokenset
tokenSet := &mfa.VerifyWithRecoveryCodeResponse{}

err := json.Unmarshal([]byte(i.Response.Body), tokenSet)
require.NoError(t, err)
Expand Down
137 changes: 137 additions & 0 deletions authentication/mfa.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package authentication

import (
"context"
"fmt"
"net/url"
"strings"

"github.com/auth0/go-auth0/authentication/mfa"
"github.com/auth0/go-auth0/authentication/oauth"
)

// MFA exposes requesting an MFA challenge and verifying MFA methods.
type MFA manager

// Challenge requests a challenge for multi-factor authentication (MFA) based on the challenge types supported by the application and user.
//
// See: https://auth0.com/docs/api/authentication#challenge-request
func (m *MFA) Challenge(ctx context.Context, body mfa.ChallengeRequest, opts ...RequestOption) (c *mfa.ChallengeResponse, err error) {
missing := []string{}
check(&missing, "ClientID", (body.ClientID != "" || m.authentication.clientID != ""))
check(&missing, "MFAToken", body.MFAToken != "")
check(&missing, "ChallengeType", body.ChallengeType != "")

if len(missing) > 0 {
return nil, fmt.Errorf("Missing required fields: %s", strings.Join(missing, ", "))
}

err = m.authentication.addClientAuthenticationToClientAuthStruct(&body.ClientAuthentication, false)

if err != nil {
return nil, err
}

err = m.authentication.Request(ctx, "POST", m.authentication.URI("mfa", "challenge"), body, &c, opts...)

if err != nil {
return nil, err
}

return
}

// VerifyWithOTP verifies an MFA challenge using a one-time password (OTP).
//
// See: https://auth0.com/docs/api/authentication#verify-with-one-time-password-otp-
func (m *MFA) VerifyWithOTP(ctx context.Context, body mfa.VerifyWithOTPRequest, opts ...RequestOption) (t *oauth.TokenSet, err error) {
missing := []string{}
check(&missing, "ClientID", (body.ClientID != "" || m.authentication.clientID != ""))
check(&missing, "MFAToken", body.MFAToken != "")
check(&missing, "OTP", body.OTP != "")

if len(missing) > 0 {
return nil, fmt.Errorf("Missing required fields: %s", strings.Join(missing, ", "))
}

data := url.Values{
"mfa_token": []string{body.MFAToken},
"grant_type": []string{"http://auth0.com/oauth/grant-type/mfa-otp"},
"otp": []string{body.OTP},
}

err = m.authentication.addClientAuthenticationToURLValues(body.ClientAuthentication, data, true)

if err != nil {
return nil, err
}

err = m.authentication.Request(ctx, "POST", m.authentication.URI("oauth", "token"), data, &t, opts...)

return
}

// VerifyWithOOB verifies an MFA challenge using an out-of-band challenge (OOB), either push notification,
// SMS, or voice.
//
// See: https://auth0.com/docs/api/authentication#verify-with-out-of-band-oob-
func (m *MFA) VerifyWithOOB(ctx context.Context, body mfa.VerifyWithOOBRequest, opts ...RequestOption) (t *oauth.TokenSet, err error) {
missing := []string{}
check(&missing, "ClientID", (body.ClientID != "" || m.authentication.clientID != ""))
check(&missing, "MFAToken", body.MFAToken != "")
check(&missing, "OOBCode", body.OOBCode != "")

if len(missing) > 0 {
return nil, fmt.Errorf("Missing required fields: %s", strings.Join(missing, ", "))
}

data := url.Values{
"mfa_token": []string{body.MFAToken},
"grant_type": []string{"http://auth0.com/oauth/grant-type/mfa-oob"},
"oob_code": []string{body.OOBCode},
}

if body.BindingCode != "" {
data.Set("binding_code", body.BindingCode)
}

err = m.authentication.addClientAuthenticationToURLValues(body.ClientAuthentication, data, true)

if err != nil {
return nil, err
}

err = m.authentication.Request(ctx, "POST", m.authentication.URI("oauth", "token"), data, &t, opts...)

return
}

// VerifyWithRecoveryCode verifies an MFA challenge using a recovery code.
//
// See: https://auth0.com/docs/api/authentication#verify-with-recovery-code
func (m *MFA) VerifyWithRecoveryCode(ctx context.Context, body mfa.VerifyWithRecoveryCodeRequest, opts ...RequestOption) (t *mfa.VerifyWithRecoveryCodeResponse, err error) {
missing := []string{}
check(&missing, "ClientID", (body.ClientID != "" || m.authentication.clientID != ""))
check(&missing, "MFAToken", body.MFAToken != "")
check(&missing, "RecoveryCode", body.RecoveryCode != "")

if len(missing) > 0 {
return nil, fmt.Errorf("Missing required fields: %s", strings.Join(missing, ", "))
}

data := url.Values{
"mfa_token": []string{body.MFAToken},
"grant_type": []string{"http://auth0.com/oauth/grant-type/mfa-recovery-code"},
"recovery_code": []string{body.RecoveryCode},
}

err = m.authentication.addClientAuthenticationToURLValues(body.ClientAuthentication, data, true)

if err != nil {
return nil, err
}

err = m.authentication.Request(ctx, "POST", m.authentication.URI("oauth", "token"), data, &t, opts...)

return
}
Loading
Loading