Skip to content

Commit

Permalink
oauth2: add test coverage to exercise the transactional support in th…
Browse files Browse the repository at this point in the history
…e RefreshTokenGrantHandler's PopulateTokenEndpointResponse method.

Signed-off-by: Amir Aslaminejad <[email protected]>
  • Loading branch information
aaslamin authored and aeneasr committed Dec 23, 2018
1 parent 03f7bc8 commit b38d7c8
Showing 1 changed file with 323 additions and 0 deletions.
323 changes: 323 additions & 0 deletions handler/oauth2/flow_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@
package oauth2

import (
"context"
"fmt"
"net/url"
"testing"
"time"

"github.com/golang/mock/gomock"
"github.com/ory/fosite/internal"

"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -274,3 +279,321 @@ func TestRefreshFlow_PopulateTokenEndpointResponse(t *testing.T) {
})
}
}

func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) {
var mockTransactional *internal.MockTransactional
var mockRevocationStore *internal.MockTokenRevocationStorage
request := fosite.NewAccessRequest(&fosite.DefaultSession{})
response := fosite.NewAccessResponse()
propagatedContext := context.Background()

// some storage implementation that has support for transactions, notice the embedded type `storage.Transactional`
type transactionalStore struct {
storage.Transactional
TokenRevocationStorage
}

for _, testCase := range []struct {
description string
setup func()
expectError error
}{
{
description: "transaction should be committed successfully if no errors occur",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(request, nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeAccessToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeRefreshToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
CreateAccessTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(nil).
Times(1)
mockTransactional.
EXPECT().
Commit(propagatedContext).
Return(nil).
Times(1)
},
},
{
description: "transaction should be rolled back if call to `GetRefreshTokenSession` results in an error",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(nil, fosite.ErrNotFound).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
},
{
description: "transaction should be rolled back if call to `RevokeAccessToken` results in an error",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(request, nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeAccessToken(propagatedContext, gomock.Any()).
Return(errors.New("Whoops, a nasty database error occurred!")).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
},
{
description: "transaction should be rolled back if call to `RevokeRefreshToken` results in an error",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(request, nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeAccessToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeRefreshToken(propagatedContext, gomock.Any()).
Return(errors.New("Whoops, a nasty database error occurred!")).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
},
{
description: "transaction should be rolled back if call to `CreateAccessTokenSession` results in an error",
setup: func() {
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(request, nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeAccessToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeRefreshToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
CreateAccessTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(errors.New("Whoops, a nasty database error occurred!")).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
},
{
description: "transaction should be rolled back if call to `CreateRefreshTokenSession` results in an error",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(request, nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeAccessToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeRefreshToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
CreateAccessTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(errors.New("Whoops, a nasty database error occurred!")).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
},
{
description: "should result in a server error if transaction cannot be created",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(nil, errors.New("Could not create transaction!")).
Times(1)
},
expectError: fosite.ErrServerError,
},
{
description: "should result in a server error if transaction cannot be rolled back",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(nil, fosite.ErrNotFound).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(errors.New("Could not rollback transaction!")).
Times(1)
},
expectError: fosite.ErrServerError,
},
{
description: "should result in a server error if transaction cannot be committed",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(request, nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeAccessToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeRefreshToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
CreateAccessTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(nil).
Times(1)
mockTransactional.
EXPECT().
Commit(propagatedContext).
Return(errors.New("Could not commit transaction!")).
Times(1)
},
expectError: fosite.ErrServerError,
},
} {
t.Run(fmt.Sprintf("scenario=%s", testCase.description), func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockTransactional = internal.NewMockTransactional(ctrl)
mockRevocationStore = internal.NewMockTokenRevocationStorage(ctrl)
testCase.setup()

handler := RefreshTokenGrantHandler{
// Notice how we are passing in a store that has support for transactions!
TokenRevocationStorage: transactionalStore{
mockTransactional,
mockRevocationStore,
},
AccessTokenStrategy: &hmacshaStrategy,
RefreshTokenStrategy: &hmacshaStrategy,
AccessTokenLifespan: time.Hour,
ScopeStrategy: fosite.HierarchicScopeStrategy,
AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy,
}

if err := handler.PopulateTokenEndpointResponse(propagatedContext, request, response); testCase.expectError != nil {
assert.EqualError(t, errors.Cause(err), testCase.expectError.Error())
}
})
}
}

0 comments on commit b38d7c8

Please sign in to comment.