From 9467ca8ac7b7b7785c96f049a422ed1d16e639b4 Mon Sep 17 00:00:00 2001 From: Aeneas Rekkas Date: Sat, 2 Jan 2016 14:09:29 +0100 Subject: [PATCH] all: refactoring, renaming, more tests --- authorize.go | 153 +++++++++++++++------------- authorize_test.go | 220 ++++++++++++++++++++++++++++++++++++---- config.go | 22 ++-- errors.go | 85 +++++++++------- errors_test.go | 9 +- helper.go | 2 +- internal/generator.go | 51 ++++++++++ rand/bytes.go | 2 +- rand/bytes_test.go | 4 +- session/session.go | 40 +++++--- session/session_test.go | 29 ++++++ storage.go | 7 -- 12 files changed, 456 insertions(+), 168 deletions(-) create mode 100644 internal/generator.go create mode 100644 session/session_test.go diff --git a/authorize.go b/authorize.go index abae42452..04ca37010 100644 --- a/authorize.go +++ b/authorize.go @@ -12,60 +12,18 @@ import ( // Authorize request information type AuthorizeRequest struct { - Types []string - Client Client - Scopes []string - RedirectURI string - State string - Expiration int32 - Code *generator.Token + ResponseTypes []string + Client Client + Scopes []string + RedirectURI string + State string + ExpiresIn int32 + Code *generator.Token } -type ScopeStrategy interface { -} - -// redirectFromValues extracts the redirect_uri from values. -// * rfc6749 3.1. Authorization Endpoint -// * rfc6749 3.1.2. Redirection Endpoint -func (c *Config) redirectFromValues(values url.Values) (string, error) { - // rfc6749 3.1. Authorization Endpoint - // The endpoint URI MAY include an "application/x-www-form-urlencoded" formatted (per Appendix B) query component - redirectURI, err := url.QueryUnescape(values.Get("redirect_uri")) - if err != nil { - return "", errors.Wrap(ErrInvalidRequest, 0) - } +const minStateLength = 8 - // rfc6749 3.1.2. Redirection Endpoint - // "The redirection endpoint URI MUST be an absolute URI as defined by [RFC3986] Section 4.3" - if !(isValidURL(redirectURI)) { - return "", errors.Wrap(ErrInvalidRequest, 0) - } - - return redirectURI, nil -} - -// redirectFromClient looks up if redirect and client are matching. -// * rfc6749 10.6. Authorization Code Redirection URI Manipulation -// * rfc6819 4.4.1.7. Threat: Authorization "code" Leakage through Counterfeit Client -func (c *Config) redirectFromClient(redirectURI string, client Client) (string, error) { - // rfc6749 10.6. Authorization Code Redirection URI Manipulation - // The authorization server MUST require public clients and SHOULD require confidential clients - // to register their redirection URIs. If a redirection URI is provided - // in the request, the authorization server MUST validate it against the - // registered value. - // - // rfc6819 4.4.1.7. Threat: Authorization "code" Leakage through Counterfeit Client - // The authorization server may also enforce the usage and validation - // of pre-registered redirect URIs (see Section 5.2.3.5). - if redirectURI == "" && len(client.GetRedirectURIs()) == 1 { - if isValidURL(client.GetRedirectURIs()[0]) { - return client.GetRedirectURIs()[0], nil - } - } else if stringInSlice(redirectURI, client.GetRedirectURIs()) { - return redirectURI, nil - } - - return "", errors.New(ErrInvalidRequest) +type ScopeStrategy interface { } // NewAuthorizeRequest returns an AuthorizeRequest. This method makes rfc6749 compliant @@ -78,25 +36,25 @@ func (c *Config) redirectFromClient(redirectURI string, client Client) (string, // It also introduces countermeasures described in rfc6819: // * rfc6819 4.4.1.7. Threat: Authorization "code" Leakage through Counterfeit Client // * rfc6819 4.4.1.8. Threat: CSRF Attack against redirect-uri -func (c *Config) NewAuthorizeRequest(_ context.Context, r *http.Request, store Storage) (*AuthorizeRequest, error) { +func (c *Config) NewAuthorizeRequest(_ context.Context, r *http.Request) (*AuthorizeRequest, error) { if err := r.ParseForm(); err != nil { - return nil, errors.Wrap(ErrInvalidRequest, 0) + return nil, errors.New(ErrInvalidRequest) } - redirectURI, err := c.redirectFromValues(r.Form) + redirectURI, err := redirectFromValues(r.Form) if err != nil { - return nil, errors.Wrap(ErrInvalidRequest, 0) + return nil, errors.New(ErrInvalidRequest) } client, err := c.Store.GetClient(r.Form.Get("client_id")) if err != nil { - return nil, errors.Wrap(ErrInvalidClient, 0) + return nil, errors.New(ErrInvalidClient) } // * rfc6749 10.6. Authorization Code Redirection URI Manipulation // * rfc6819 4.4.1.7. Threat: Authorization "code" Leakage through Counterfeit Client - if redirectURI, err = c.redirectFromClient(redirectURI, client); err != nil { - return nil, errors.Wrap(ErrInvalidRequest, 0) + if redirectURI, err = redirectFromClient(redirectURI, client); err != nil { + return nil, errors.New(ErrInvalidRequest) } // rfc6749 3.1.1. Response Type @@ -111,36 +69,45 @@ func (c *Config) NewAuthorizeRequest(_ context.Context, r *http.Request, store S // response-char = "_" / DIGIT / ALPHA responseTypes := removeEmpty(strings.Split(r.Form.Get("response_type"), " ")) if !areResponseTypesValid(c, responseTypes) { - return nil, errors.Wrap(ErrUnsupportedGrantType, 0) + return nil, errors.New(ErrUnsupportedResponseType) } // rfc6819 4.4.1.8. Threat: CSRF Attack against redirect-uri // The "state" parameter should be used to link the authorization // request with the redirect URI used to deliver the access token (Section 5.3.5). + // + // https://tools.ietf.org/html/rfc6819#section-4.4.1.8 + // The "state" parameter should not be guessable state := r.Form.Get("state") if state == "" { - return nil, errors.Wrap(ErrInvalidRequest, 0) + return nil, errors.New(ErrInvalidState) + } else if len(state) < minStateLength { + // We're assuming that using less then 6 characters for the state can not be considered "unguessable" + return nil, errors.New(ErrInvalidState) } + // Generate the auth token code, err := c.AuthorizeCodeGenerator.Generate() - if state == "" { - return nil, errors.Wrap(ErrServerError, 0) + if err != nil { + return nil, errors.New(ErrServerError) } + // Remove empty items from arrays scopes := removeEmpty(strings.Split(r.Form.Get("scope"), " ")) + return &AuthorizeRequest{ - Types: responseTypes, - Client: client, - Scopes: scopes, - State: state, - Expiration: c.Lifetime, - RedirectURI: redirectURI, - Code: code, + ResponseTypes: responseTypes, + Client: client, + Scopes: scopes, + State: state, + ExpiresIn: c.Lifetime, + RedirectURI: redirectURI, + Code: code, }, nil } func (c *Config) WriteAuthError(rw http.ResponseWriter, req *http.Request, err error) { - redirectURI, err := c.redirectFromValues(req.Form) + redirectURI, err := redirectFromValues(req.Form) if err != nil { http.Error(rw, errInvalidRequestName, http.StatusBadRequest) return @@ -154,7 +121,7 @@ func (c *Config) WriteAuthError(rw http.ResponseWriter, req *http.Request, err e // * rfc6749 10.6. Authorization Code Redirection URI Manipulation // * rfc6819 4.4.1.7. Threat: Authorization "code" Leakage through Counterfeit Client - if redirectURI, err = c.redirectFromClient(redirectURI, client); err != nil { + if redirectURI, err = redirectFromClient(redirectURI, client); err != nil { http.Error(rw, errInvalidRequestName, http.StatusBadRequest) return } @@ -172,3 +139,47 @@ func (c *Config) WriteAuthError(rw http.ResponseWriter, req *http.Request, err e rw.Header().Add("Location", redir.String()) rw.WriteHeader(http.StatusFound) } + +// redirectFromValues extracts the redirect_uri from values. +// * rfc6749 3.1. Authorization Endpoint +// * rfc6749 3.1.2. Redirection Endpoint +func redirectFromValues(values url.Values) (string, error) { + // rfc6749 3.1. Authorization Endpoint + // The endpoint URI MAY include an "application/x-www-form-urlencoded" formatted (per Appendix B) query component + redirectURI, err := url.QueryUnescape(values.Get("redirect_uri")) + if err != nil { + return "", errors.Wrap(ErrInvalidRequest, 0) + } + + // rfc6749 3.1.2. Redirection Endpoint + // "The redirection endpoint URI MUST be an absolute URI as defined by [RFC3986] Section 4.3" + if !(isValidURL(redirectURI)) { + return "", errors.Wrap(ErrInvalidRequest, 0) + } + + return redirectURI, nil +} + +// redirectFromClient looks up if redirect and client are matching. +// * rfc6749 10.6. Authorization Code Redirection URI Manipulation +// * rfc6819 4.4.1.7. Threat: Authorization "code" Leakage through Counterfeit Client +func redirectFromClient(redirectURI string, client Client) (string, error) { + // rfc6749 10.6. Authorization Code Redirection URI Manipulation + // The authorization server MUST require public clients and SHOULD require confidential clients + // to register their redirection URIs. If a redirection URI is provided + // in the request, the authorization server MUST validate it against the + // registered value. + // + // rfc6819 4.4.1.7. Threat: Authorization "code" Leakage through Counterfeit Client + // The authorization server may also enforce the usage and validation + // of pre-registered redirect URIs (see Section 5.2.3.5). + if redirectURI == "" && len(client.GetRedirectURIs()) == 1 { + if isValidURL(client.GetRedirectURIs()[0]) { + return client.GetRedirectURIs()[0], nil + } + } else if stringInSlice(redirectURI, client.GetRedirectURIs()) { + return redirectURI, nil + } + + return "", errors.New(ErrInvalidRequest) +} diff --git a/authorize_test.go b/authorize_test.go index adf266fc9..e16b7402a 100644 --- a/authorize_test.go +++ b/authorize_test.go @@ -3,8 +3,10 @@ package fosite import ( "github.com/golang/mock/gomock" . "github.com/ory-am/fosite/client" + "github.com/ory-am/fosite/generator" . "github.com/ory-am/fosite/internal" "github.com/stretchr/testify/assert" + "github.com/vektra/errors" "golang.org/x/net/context" "net/http" "net/url" @@ -17,7 +19,6 @@ import ( // rfc6749 3.1.2. Redirection Endpoint // "The redirection endpoint URI MUST be an absolute URI as defined by [RFC3986] Section 4.3" func TestGetRedirectURI(t *testing.T) { - cf := &Config{} for _, c := range []struct { in string isError bool @@ -27,7 +28,7 @@ func TestGetRedirectURI(t *testing.T) { } { values := url.Values{} values.Set("redirect_uri", c.in) - res, err := cf.redirectFromValues(values) + res, err := redirectFromValues(values) assert.Equal(t, c.isError, err != nil, "%s", err) if err == nil { assert.Equal(t, c.expected, res) @@ -50,7 +51,6 @@ func TestDoesClientWhiteListRedirect(t *testing.T) { var err error var redir string - cf := &Config{} for k, c := range []struct { client Client url string @@ -90,7 +90,7 @@ func TestDoesClientWhiteListRedirect(t *testing.T) { isError: true, }, } { - redir, err = cf.redirectFromClient(c.url, c.client) + redir, err = redirectFromClient(c.url, c.client) assert.Equal(t, c.isError, err != nil, "%d: %s", k, err) assert.Equal(t, c.expected, redir) } @@ -99,29 +99,215 @@ func TestDoesClientWhiteListRedirect(t *testing.T) { func TestNewAuthorizeRequest(t *testing.T) { ctrl := gomock.NewController(t) store := NewMockStorage(ctrl) + gen := NewMockGenerator(ctrl) defer ctrl.Finish() for k, c := range []struct { - conf *Config - r *http.Request - isError bool - mock func() + desc string + conf *Config + r *http.Request + query url.Values + expectedError error + mock func() + expect *AuthorizeRequest }{ + /* empty request */ { - conf: &Config{}, - r: &http.Request{ - Header: http.Header{"": []string{""}}, - Form: url.Values{}, - PostForm: url.Values{}, + desc: "empty request fails", + conf: &Config{Store: store}, + r: &http.Request{}, + expectedError: ErrInvalidRequest, + mock: func() {}, + }, + /* invalid redirect uri */ + { + desc: "invalid redirect uri fails", + conf: &Config{Store: store}, + query: url.Values{"redirect_uri": []string{"invalid"}}, + expectedError: ErrInvalidRequest, + mock: func() {}, + }, + /* invalid client */ + { + desc: "invalid client uri fails", + conf: &Config{Store: store}, + query: url.Values{"redirect_uri": []string{"http://foo.bar/cb"}}, + expectedError: ErrInvalidClient, + mock: func() { + store.EXPECT().GetClient(gomock.Any()).Return(nil, errors.New("foo")) + }, + }, + /* redirect client mismatch */ + { + desc: "client and request redirects mismatch", + conf: &Config{Store: store}, + query: url.Values{ + "redirect_uri": []string{"http://foo.bar/cb"}, + "client_id": []string{"1234"}, }, + expectedError: ErrInvalidRequest, mock: func() { - //store.EXPECT().GetClient() + store.EXPECT().GetClient("1234").Return(&SecureClient{RedirectURIs: []string{"invalid"}}, nil) + }, + }, + /* no response type */ + { + desc: "no response type", + conf: &Config{Store: store}, + query: url.Values{ + "redirect_uri": []string{"http://foo.bar/cb"}, + "client_id": []string{"1234"}, + }, + expectedError: ErrUnsupportedResponseType, + mock: func() { + store.EXPECT().GetClient("1234").Return(&SecureClient{RedirectURIs: []string{"http://foo.bar/cb"}}, nil) + }, + }, + /* invalid response type */ + { + desc: "invalid response type", + conf: &Config{Store: store}, + query: url.Values{ + "redirect_uri": []string{"http://foo.bar/cb"}, + "client_id": []string{"1234"}, + "response_type": []string{"foo"}, + }, + expectedError: ErrUnsupportedResponseType, + mock: func() { + store.EXPECT().GetClient("1234").Return(&SecureClient{RedirectURIs: []string{"http://foo.bar/cb"}}, nil) + }, + }, + /* invalid response type */ + { + desc: "invalid response type", + conf: &Config{Store: store}, + query: url.Values{ + "redirect_uri": []string{"http://foo.bar/cb"}, + "client_id": []string{"1234"}, + "response_type": []string{"foo"}, + }, + expectedError: ErrUnsupportedResponseType, + mock: func() { + store.EXPECT().GetClient("1234").Return(&SecureClient{RedirectURIs: []string{"http://foo.bar/cb"}}, nil) + }, + }, + /* unsupported response type */ + { + desc: "unspported response type", + conf: &Config{Store: store, AllowedResponseTypes: []string{"code"}}, + query: url.Values{ + "redirect_uri": []string{"http://foo.bar/cb"}, + "client_id": []string{"1234"}, + "response_type": []string{"code token"}, + }, + expectedError: ErrUnsupportedResponseType, + mock: func() { + store.EXPECT().GetClient("1234").Return(&SecureClient{RedirectURIs: []string{"http://foo.bar/cb"}}, nil) + }, + }, + /* unsupported response type */ + { + desc: "unspported response type", + conf: &Config{Store: store, AllowedResponseTypes: []string{"code"}}, + query: url.Values{ + "redirect_uri": []string{"http://foo.bar/cb"}, + "client_id": []string{"1234"}, + "response_type": []string{"foo"}, + }, + expectedError: ErrUnsupportedResponseType, + mock: func() { + store.EXPECT().GetClient("1234").Return(&SecureClient{RedirectURIs: []string{"http://foo.bar/cb"}}, nil) + }, + }, + /* no state */ + { + desc: "no state", + conf: &Config{Store: store, AllowedResponseTypes: []string{"code"}}, + query: url.Values{ + "redirect_uri": []string{"http://foo.bar/cb"}, + "client_id": []string{"1234"}, + "response_type": []string{"code"}, + }, + expectedError: ErrInvalidState, + mock: func() { + store.EXPECT().GetClient("1234").Return(&SecureClient{RedirectURIs: []string{"http://foo.bar/cb"}}, nil) + }, + }, + /* short state */ + { + desc: "short state", + conf: &Config{Store: store, AllowedResponseTypes: []string{"code"}}, + query: url.Values{ + "redirect_uri": []string{"http://foo.bar/cb"}, + "client_id": []string{"1234"}, + "response_type": []string{"code"}, + "state": []string{"short"}, + }, + expectedError: ErrInvalidState, + mock: func() { + store.EXPECT().GetClient("1234").Return(&SecureClient{RedirectURIs: []string{"http://foo.bar/cb"}}, nil) + }, + }, + /* code gen fails */ + { + desc: "code gen fails", + conf: &Config{Store: store, AuthorizeCodeGenerator: gen, AllowedResponseTypes: []string{"code"}}, + query: url.Values{ + "redirect_uri": []string{"http://foo.bar/cb"}, + "client_id": []string{"1234"}, + "response_type": []string{"code"}, + "state": []string{"strong-state"}, + }, + expectedError: ErrServerError, + mock: func() { + gen.EXPECT().Generate().Return(nil, errors.New("")) + store.EXPECT().GetClient("1234").Return(&SecureClient{RedirectURIs: []string{"http://foo.bar/cb"}}, nil) + }, + }, + /* success case */ + { + desc: "should pass", + conf: &Config{ + Store: store, + AuthorizeCodeGenerator: gen, + AllowedResponseTypes: []string{"code", "token"}, + Lifetime: 3600, + }, + query: url.Values{ + "redirect_uri": []string{"http://foo.bar/cb"}, + "client_id": []string{"1234"}, + "response_type": []string{"code token"}, + "state": []string{"strong-state"}, + "scope": []string{"foo bar"}, + }, + mock: func() { + gen.EXPECT().Generate().Return(&generator.Token{Key: "foo", Signature: "bar"}, nil) + store.EXPECT().GetClient("1234").Return(&SecureClient{RedirectURIs: []string{"http://foo.bar/cb"}}, nil) + }, + expect: &AuthorizeRequest{ + RedirectURI: "http://foo.bar/cb", + Client: &SecureClient{RedirectURIs: []string{"http://foo.bar/cb"}}, + ResponseTypes: []string{"code", "token"}, + State: "strong-state", + Scopes: []string{"foo", "bar"}, + ExpiresIn: 3600, + Code: &generator.Token{Key: "foo", Signature: "bar"}, }, - isError: true, }, } { c.mock() - _, err := c.conf.NewAuthorizeRequest(context.Background(), c.r, store) - assert.Equal(t, c.isError, err != nil, "%d: %s", k, err) + if c.r == nil { + c.r = &http.Request{Header: http.Header{}} + if c.query != nil { + c.r.URL = &url.URL{RawQuery: c.query.Encode()} + } + } + + ar, err := c.conf.NewAuthorizeRequest(context.Background(), c.r) + assert.Equal(t, c.expectedError == nil, err == nil, "%d: %s\n%s", k, c.desc, err) + if c.expectedError != nil { + assert.Equal(t, err.Error(), c.expectedError.Error(), "%d: %s\n%s", k, c.desc, err) + } + assert.Equal(t, c.expect, ar, "%d: %s\n", k, c.desc) } } diff --git a/config.go b/config.go index 47d4f20c4..8365f7ddd 100644 --- a/config.go +++ b/config.go @@ -3,20 +3,20 @@ package fosite import "github.com/ory-am/fosite/generator" type Config struct { - AllowedAuthorizeResponseTypes []string - AllowedTokenResponseTypes []string - Lifetime int32 - Store Storage - Entropy int32 - AuthorizeCodeGenerator generator.Generator + AllowedResponseTypes []string + AllowedTokenResponseTypes []string + Lifetime int32 + Store Storage + Entropy int32 + AuthorizeCodeGenerator generator.Generator } func NewDefaultConfig() *Config { return &Config{ - AllowedAuthorizeResponseTypes: []string{"code", "token", "id_token"}, - AllowedTokenResponseTypes: []string{}, - Lifetime: 3600, - Entropy: 128, - AuthorizeCodeGenerator: &generator.CryptoGenerator{}, + AllowedResponseTypes: []string{"code", "token", "id_token"}, + AllowedTokenResponseTypes: []string{}, + Lifetime: 3600, + Entropy: 128, + AuthorizeCodeGenerator: &generator.CryptoGenerator{}, } } diff --git a/errors.go b/errors.go index c82bcb7e8..ad2936e85 100644 --- a/errors.go +++ b/errors.go @@ -3,30 +3,32 @@ package fosite import "github.com/go-errors/errors" var ( - ErrInvalidRequest = errors.New("The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed.") - ErrUnauthorizedClient = errors.New("The client is not authorized to request a token using this method.") - ErrAccessDenied = errors.New("The resource owner or authorization server denied the request.") - ErrUnsupportedResponseType = errors.New("The authorization server does not support obtaining a token using this method.") - ErrInvalidScope = errors.New("The requested scope is invalid, unknown, or malformed.") - ErrServerError = errors.New("The authorization server encountered an unexpected condition that prevented it from fulfilling the request.") - ErrTemporarilyUnavailable = errors.New("The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server.") - ErrUnsupportedGrantType = errors.New("The authorization grant type is not supported by the authorization server.") - ErrInvalidGrant = errors.New("The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client.") - ErrInvalidClient = errors.New("Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method).") + ErrInvalidRequest = errors.New("The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed") + ErrUnauthorizedClient = errors.New("The client is not authorized to request a token using this method") + ErrAccessDenied = errors.New("The resource owner or authorization server denied the request") + ErrUnsupportedResponseType = errors.New("The authorization server does not support obtaining a token using this method") + ErrInvalidScope = errors.New("The requested scope is invalid, unknown, or malformed") + ErrServerError = errors.New("The authorization server encountered an unexpected condition that prevented it from fulfilling the request") + ErrTemporarilyUnavailable = errors.New("The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server") + ErrUnsupportedGrantType = errors.New("The authorization grant type is not supported by the authorization server") + ErrInvalidGrant = errors.New("The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client") + ErrInvalidClient = errors.New("Client authentication failed (e.g., unknown client, no client authentication included, or unsupported authentication method)") + ErrInvalidState = errors.Errorf("The state is missing or has less than %d characters and is therefore considered too weak", minStateLength) ) const ( - errInvalidRequestName = "invalid_request" - errUnauthorizedClientName = "unauthorized_client" - errAccessDeniedName = "acccess_denied" + errInvalidRequestName = "invalid_request" + errUnauthorizedClientName = "unauthorized_client" + errAccessDeniedName = "acccess_denied" errUnsupportedResponseTypeName = "unsupported_response_type" - errInvalidScopeName = "invalid_scope" - errServerErrorName = "server_error" - errTemporarilyUnavailableName = "errTemporarilyUnavailableName" - errUnsupportedGrantTypeName = "unsupported_grant_type" - errInvalidGrantName = "invalid_grant" - errInvalidClientName = "invalid_client" - errInvalidError = "invalid_error" + errInvalidScopeName = "invalid_scope" + errServerErrorName = "server_error" + errTemporarilyUnavailableName = "errTemporarilyUnavailableName" + errUnsupportedGrantTypeName = "unsupported_grant_type" + errInvalidGrantName = "invalid_grant" + errInvalidClientName = "invalid_client" + errInvalidError = "invalid_error" + errInvalidState = "invalid_state" ) type RFC6749Error struct { @@ -35,72 +37,77 @@ type RFC6749Error struct { Hint string `json:"-"` } -func ErrorToRFC6749(err error) (*RFC6749Error) { +func ErrorToRFC6749(err error) *RFC6749Error { ge, ok := err.(*errors.Error) if !ok { return &RFC6749Error{ - Name: errInvalidError, + Name: errInvalidError, Description: "The error is unrecognizable.", - Hint: err.Error(), + Hint: err.Error(), } } if errors.Is(ge, ErrInvalidRequest) { return &RFC6749Error{ - Name: errInvalidRequestName, + Name: errInvalidRequestName, Description: ge.Error(), - Hint: "Make sure that the various parameters are correct, be aware of case sensitivity and trim your parameters. Make sure that the client you are using has exactly whitelisted the redirect_uri you specified.", + Hint: "Make sure that the various parameters are correct, be aware of case sensitivity and trim your parameters. Make sure that the client you are using has exactly whitelisted the redirect_uri you specified.", } } else if errors.Is(ge, ErrUnauthorizedClient) { return &RFC6749Error{ - Name: errUnauthorizedClientName, + Name: errUnauthorizedClientName, Description: ge.Error(), - Hint: "Make sure that client id and secret are correctly specified and that the client exists.", + Hint: "Make sure that client id and secret are correctly specified and that the client exists.", } } else if errors.Is(ge, ErrAccessDenied) { return &RFC6749Error{ - Name: errAccessDeniedName, + Name: errAccessDeniedName, Description: ge.Error(), - Hint: "Make sure that the request you are making is valid. Maybe the credential or request parameters you are using are limited in scope or otherwise restricted.", + Hint: "Make sure that the request you are making is valid. Maybe the credential or request parameters you are using are limited in scope or otherwise restricted.", } } else if errors.Is(ge, ErrUnsupportedResponseType) { return &RFC6749Error{ - Name: errUnsupportedResponseTypeName, + Name: errUnsupportedResponseTypeName, Description: ge.Error(), } } else if errors.Is(ge, ErrInvalidScope) { return &RFC6749Error{ - Name: errInvalidScopeName, + Name: errInvalidScopeName, Description: ge.Error(), } } else if errors.Is(ge, ErrServerError) { return &RFC6749Error{ - Name: errServerErrorName, + Name: errServerErrorName, Description: ge.Error(), } } else if errors.Is(ge, ErrTemporarilyUnavailable) { return &RFC6749Error{ - Name: errTemporarilyUnavailableName, + Name: errTemporarilyUnavailableName, Description: ge.Error(), } } else if errors.Is(ge, ErrUnsupportedGrantType) { return &RFC6749Error{ - Name: errUnsupportedGrantTypeName, + Name: errUnsupportedGrantTypeName, Description: ge.Error(), } } else if errors.Is(ge, ErrInvalidGrant) { return &RFC6749Error{ - Name: errInvalidGrantName, + Name: errInvalidGrantName, Description: ge.Error(), } } else if errors.Is(ge, ErrInvalidClient) { return &RFC6749Error{ - Name: errInvalidClientName, + Name: errInvalidClientName, + Description: ge.Error(), + } + } else if errors.Is(ge, ErrInvalidState) { + return &RFC6749Error{ + Name: errInvalidState, Description: ge.Error(), } } return &RFC6749Error{ - Name: errInvalidError, + Name: errInvalidError, Description: "The error is unrecognizable.", - Hint: ge.Error(), + Hint: ge.Error(), } -} \ No newline at end of file +} diff --git a/errors_test.go b/errors_test.go index a64881f9c..15b386399 100644 --- a/errors_test.go +++ b/errors_test.go @@ -1,10 +1,10 @@ package fosite import ( - "testing" - "github.com/stretchr/testify/assert" - "github.com/go-errors/errors" native "errors" + "github.com/go-errors/errors" + "github.com/stretchr/testify/assert" + "testing" ) func TestErrorToRFC6749(t *testing.T) { @@ -21,4 +21,5 @@ func TestErrorToRFC6749(t *testing.T) { assert.Equal(t, errUnsupportedGrantTypeName, ErrorToRFC6749(errors.New(ErrUnsupportedGrantType)).Name) assert.Equal(t, errInvalidGrantName, ErrorToRFC6749(errors.New(ErrInvalidGrant)).Name) assert.Equal(t, errInvalidClientName, ErrorToRFC6749(errors.New(ErrInvalidClient)).Name) -} \ No newline at end of file + assert.Equal(t, errInvalidState, ErrorToRFC6749(errors.New(ErrInvalidState)).Name) +} diff --git a/helper.go b/helper.go index d8f6d2de4..c0c8a33ee 100644 --- a/helper.go +++ b/helper.go @@ -10,7 +10,7 @@ func areResponseTypesValid(c *Config, responseTypes []string) bool { return false } for _, responseType := range responseTypes { - if !stringInSlice(responseType, c.AllowedAuthorizeResponseTypes) { + if !stringInSlice(responseType, c.AllowedResponseTypes) { return false } } diff --git a/internal/generator.go b/internal/generator.go new file mode 100644 index 000000000..1e948ec4a --- /dev/null +++ b/internal/generator.go @@ -0,0 +1,51 @@ +// Automatically generated by MockGen. DO NOT EDIT! +// Source: generator/generator.go + +package internal + +import ( + gomock "github.com/golang/mock/gomock" + . "github.com/ory-am/fosite/generator" +) + +// Mock of Generator interface +type MockGenerator struct { + ctrl *gomock.Controller + recorder *_MockGeneratorRecorder +} + +// Recorder for MockGenerator (not exported) +type _MockGeneratorRecorder struct { + mock *MockGenerator +} + +func NewMockGenerator(ctrl *gomock.Controller) *MockGenerator { + mock := &MockGenerator{ctrl: ctrl} + mock.recorder = &_MockGeneratorRecorder{mock} + return mock +} + +func (_m *MockGenerator) EXPECT() *_MockGeneratorRecorder { + return _m.recorder +} + +func (_m *MockGenerator) Generate() (*Token, error) { + ret := _m.ctrl.Call(_m, "Generate") + ret0, _ := ret[0].(*Token) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (_mr *_MockGeneratorRecorder) Generate() *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "Generate") +} + +func (_m *MockGenerator) ValidateSignature(token *Token) error { + ret := _m.ctrl.Call(_m, "ValidateSignature", token) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockGeneratorRecorder) ValidateSignature(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "ValidateSignature", arg0) +} diff --git a/rand/bytes.go b/rand/bytes.go index 1a5c692a4..e3f33a4de 100644 --- a/rand/bytes.go +++ b/rand/bytes.go @@ -22,4 +22,4 @@ func RandomBytes(entropy, tries int) ([]byte, error) { } return bytes, nil -} \ No newline at end of file +} diff --git a/rand/bytes_test.go b/rand/bytes_test.go index 727fa5033..baedfd9ae 100644 --- a/rand/bytes_test.go +++ b/rand/bytes_test.go @@ -1,8 +1,8 @@ package rand import ( - "testing" "github.com/stretchr/testify/assert" + "testing" ) func TestRandomBytes(t *testing.T) { @@ -22,4 +22,4 @@ func TestPseudoRandomness(t *testing.T) { assert.False(t, ok) results[string(bytes)] = true } -} \ No newline at end of file +} diff --git a/session/session.go b/session/session.go index de627e52e..1ab26f45e 100644 --- a/session/session.go +++ b/session/session.go @@ -7,7 +7,7 @@ import ( ) // Session defines a authorize flow session which will be persisted and passed to the token endpoint (Authorize Code Flow). -type Session interface { +type AuthorizeSession interface { // SetExtra sets extra information that you want to be persisted. Ignore this if you have // your own session management or do not need additional persistent states. SetExtra(extra interface{}) error @@ -24,6 +24,9 @@ type Session interface { // GetResponseTypes returns the scope for this authorize session. GetScopes() []string + // GetUser returns the user for this authorize session. + GetUserID() string + // GetResponseTypes returns the redirect_uri for this authorize session. GetRedirectURI() string @@ -37,9 +40,9 @@ type Session interface { GetCodeSignature() string } -// JSONSession uses json.Marshal and json.Unmarshall to store extra information. It is recommended to use this +// defaultSession uses json.Marshal and json.Unmarshall to store extra information. It is recommended to use this // implementation. -type JSONSession struct { +type defaultSession struct { extra []byte responseTypes []string clientID string @@ -47,22 +50,25 @@ type JSONSession struct { redirectURI string state string signature string + userID string ar *fosite.AuthorizeRequest } -func NewJSONSession(ar *fosite.AuthorizeRequest) *JSONSession { - return &JSONSession{ +func NewAuthorizeSession(ar *fosite.AuthorizeRequest, userID string) AuthorizeSession { + return &defaultSession{ ar: ar, signature: ar.Code.Signature, extra: []byte{}, - responseTypes: ar.Types, + responseTypes: ar.ResponseTypes, clientID: ar.Client.GetID(), state: ar.State, + scopes: ar.Scopes, redirectURI: ar.RedirectURI, + userID: userID, } } -func (s *JSONSession) SetExtra(extra interface{}) error { +func (s *defaultSession) SetExtra(extra interface{}) error { result, err := json.Marshal(extra) if err != nil { return errors.New(err) @@ -71,33 +77,37 @@ func (s *JSONSession) SetExtra(extra interface{}) error { return nil } -func (s *JSONSession) WriteExtra(to interface{}) error { +func (s *defaultSession) WriteExtra(to interface{}) error { if err := json.Unmarshal(s.extra, to); err != nil { return errors.New(err) } return nil } -func (s *JSONSession) GetResponseTypes() []string { +func (s *defaultSession) GetResponseTypes() []string { return s.responseTypes } -func (s *JSONSession) GetClientID() string { +func (s *defaultSession) GetClientID() string { return s.clientID } -func (s *JSONSession) GetScopes() []string { +func (s *defaultSession) GetScopes() []string { return s.scopes } -func (s *JSONSession) GetRedirectURI() string { +func (s *defaultSession) GetRedirectURI() string { return s.redirectURI } -func (s *JSONSession) GetState() string { +func (s *defaultSession) GetState() string { return s.state } -func (s *JSONSession) GetCodeSignature() string { - return s.GetCodeSignature() +func (s *defaultSession) GetCodeSignature() string { + return s.signature +} + +func (s *defaultSession) GetUserID() string { + return s.userID } diff --git a/session/session_test.go b/session/session_test.go new file mode 100644 index 000000000..0b40a3698 --- /dev/null +++ b/session/session_test.go @@ -0,0 +1,29 @@ +package session + +import ( + "github.com/ory-am/fosite" + "github.com/ory-am/fosite/client" + "github.com/ory-am/fosite/generator" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestNewAuthorizeSession(t *testing.T) { + ar := &fosite.AuthorizeRequest{ + ResponseTypes: []string{"code token"}, + Client: &client.SecureClient{ID: "client"}, + Scopes: []string{"email id_token"}, + RedirectURI: "https://foo.bar/cb", + State: "randomState", + ExpiresIn: 30, + Code: &generator.Token{Key: "key", Signature: "sig"}, + } + as := NewAuthorizeSession(ar, "1234") + + assert.Equal(t, ar.Client.GetID(), as.GetClientID()) + assert.Equal(t, ar.Code.Signature, as.GetCodeSignature()) + assert.Equal(t, ar.RedirectURI, as.GetRedirectURI()) + assert.Equal(t, ar.ResponseTypes, as.GetResponseTypes()) + assert.Equal(t, ar.Scopes, as.GetScopes()) + assert.Equal(t, "1234", as.GetUserID()) +} diff --git a/storage.go b/storage.go index d6c08d309..9df4121d3 100644 --- a/storage.go +++ b/storage.go @@ -39,10 +39,3 @@ type Storage interface { // RemoveRefresh revokes or deletes refresh AccessData. // RemoveRefresh(token string) error } - -// Manager defines an optional but recommended API for your fosite storage implementation. This API is not -// consumed by fosite itself. You don not need to implement this library, it is merely a good practice guide. -type Manager interface { - // StoreClient stores a client or returns an error. - StoreClient(Client) error -}