diff --git a/chrysom/acquire.go b/chrysom/acquire.go new file mode 100644 index 0000000..325a761 --- /dev/null +++ b/chrysom/acquire.go @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 + +package chrysom + +import ( + "net/http" +) + +// Acquirer adds an authorization header and value to a given http request. +type Acquirer interface { + AddAuth(*http.Request) error +} diff --git a/chrysom/acquire_test.go b/chrysom/acquire_test.go new file mode 100644 index 0000000..a08346b --- /dev/null +++ b/chrysom/acquire_test.go @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: 2024 Comcast Cable Communications Management, LLC +// SPDX-License-Identifier: Apache-2.0 +package chrysom + +import ( + "net/http" + + "github.com/stretchr/testify/mock" +) + +type MockAquirer struct { + mock.Mock +} + +func (m *MockAquirer) AddAuth(req *http.Request) error { + args := m.Called(req) + + return args.Error(0) +} diff --git a/chrysom/basicClient.go b/chrysom/basicClient.go index d4339d2..3b3e848 100644 --- a/chrysom/basicClient.go +++ b/chrysom/basicClient.go @@ -13,7 +13,6 @@ import ( "net/http" "github.com/xmidt-org/ancla/model" - "github.com/xmidt-org/bascule/acquire" "github.com/xmidt-org/sallust" "go.uber.org/zap" ) @@ -25,6 +24,7 @@ var ( ErrItemDataEmpty = errors.New("data field in item is required") ErrUndefinedIntervalTicker = errors.New("interval ticker is nil. Can't listen for updates") ErrAuthAcquirerFailure = errors.New("failed acquiring auth token") + ErrAuthAcquirerNil = errors.New("auth aquirer is nil") ErrBadRequest = errors.New("argus rejected the request as invalid") ) @@ -58,23 +58,17 @@ type BasicClientConfig struct { // Auth provides the mechanism to add auth headers to outgoing requests. // (Optional) If not provided, no auth headers are added. - Auth Auth + Auth Acquirer } // BasicClient is the client used to make requests to Argus. type BasicClient struct { client *http.Client - auth acquire.Acquirer + auth Acquirer storeBaseURL string bucket string } -// Auth contains authorization data for requests to Argus. -type Auth struct { - JWT acquire.RemoteBearerTokenAcquirerOptions - Basic string -} - type response struct { Body []byte ArgusErrorHeader string @@ -99,13 +93,9 @@ func NewBasicClient(config BasicClientConfig) (*BasicClient, error) { return nil, err } - tokenAcquirer, err := buildTokenAcquirer(config.Auth) - if err != nil { - return nil, err - } clientStore := &BasicClient{ client: config.HTTPClient, - auth: tokenAcquirer, + auth: config.Auth, bucket: config.Bucket, storeBaseURL: config.Address + storeAPIPath, } @@ -210,7 +200,10 @@ func (c *BasicClient) sendRequest(ctx context.Context, owner, method, url string if err != nil { return response{}, fmt.Errorf(errWrappedFmt, errNewRequestFailure, err.Error()) } - err = acquire.AddAuth(r, c.auth) + if c.auth == nil { + return response{}, ErrAuthAcquirerNil + } + err = c.auth.AddAuth(r) if err != nil { return response{}, fmt.Errorf(errWrappedFmt, ErrAuthAcquirerFailure, err.Error()) } @@ -234,10 +227,6 @@ func (c *BasicClient) sendRequest(ctx context.Context, owner, method, url string return sqResp, nil } -func isEmpty(options acquire.RemoteBearerTokenAcquirerOptions) bool { - return len(options.AuthURL) < 1 || options.Buffer == 0 || options.Timeout == 0 -} - // translateNonSuccessStatusCode returns as specific error // for known Argus status codes. func translateNonSuccessStatusCode(code int) error { @@ -251,15 +240,6 @@ func translateNonSuccessStatusCode(code int) error { } } -func buildTokenAcquirer(auth Auth) (acquire.Acquirer, error) { - if !isEmpty(auth.JWT) { - return acquire.NewRemoteBearerTokenAcquirer(auth.JWT) - } else if len(auth.Basic) > 0 { - return acquire.NewFixedAuthAcquirer(auth.Basic) - } - return &acquire.DefaultAcquirer{}, nil -} - func validateBasicConfig(config *BasicClientConfig) error { if config.Address == "" { return ErrAddressEmpty diff --git a/chrysom/basicClient_test.go b/chrysom/basicClient_test.go index 738657e..8fe1a7d 100644 --- a/chrysom/basicClient_test.go +++ b/chrysom/basicClient_test.go @@ -17,6 +17,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/xmidt-org/ancla/model" ) @@ -24,8 +25,9 @@ import ( const failingURL = "nowhere://" var ( - _ Pusher = &BasicClient{} - _ Reader = &BasicClient{} + _ Pusher = &BasicClient{} + _ Reader = &BasicClient{} + errFails = errors.New("fails") ) func TestValidateBasicConfig(t *testing.T) { @@ -45,7 +47,6 @@ func TestValidateBasicConfig(t *testing.T) { allDefinedCaseConfig := &BasicClientConfig{ HTTPClient: myAmazingClient, Address: "http://legit-argus-hostname.io", - Auth: Auth{}, Bucket: "amazing-bucket", } @@ -102,10 +103,10 @@ func TestSendRequest(t *testing.T) { Method string URL string Body []byte - AcquirerFails bool ClientDoFails bool ExpectedResponse response ExpectedErr error + MockError error } tcs := []testCase{ @@ -114,19 +115,21 @@ func TestSendRequest(t *testing.T) { Method: "what method?", URL: "http://argus-hostname.io", ExpectedErr: errNewRequestFailure, + MockError: nil, }, { - Description: "Auth acquirer fails", - Method: http.MethodGet, - URL: "http://argus-hostname.io", - AcquirerFails: true, - ExpectedErr: ErrAuthAcquirerFailure, + Description: "Auth acquirer fails", + Method: http.MethodGet, + URL: "http://argus-hostname.io", + MockError: errFails, + ExpectedErr: ErrAuthAcquirerFailure, }, { Description: "Client Do fails", Method: http.MethodPut, ClientDoFails: true, ExpectedErr: errDoRequestFailure, + MockError: nil, }, { Description: "Happy path", @@ -138,6 +141,7 @@ func TestSendRequest(t *testing.T) { Code: http.StatusOK, Body: []byte("testing"), }, + MockError: nil, }, } for _, tc := range tcs { @@ -145,6 +149,7 @@ func TestSendRequest(t *testing.T) { assert := assert.New(t) require := require.New(t) + acquirer := new(MockAquirer) echoHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { assert.Equal(tc.Owner, r.Header.Get(ItemOwnerHeaderKey)) rw.WriteHeader(http.StatusOK) @@ -162,9 +167,8 @@ func TestSendRequest(t *testing.T) { Bucket: "bucket-name", }) - if tc.AcquirerFails { - client.auth = acquirerFunc(failAcquirer) - } + acquirer.On("AddAuth", mock.Anything).Return(tc.MockError) + client.auth = acquirer var URL = server.URL if tc.ClientDoFails { @@ -186,63 +190,70 @@ func TestSendRequest(t *testing.T) { func TestGetItems(t *testing.T) { type testCase struct { - Description string - ResponsePayload []byte - ResponseCode int - ShouldMakeRequestFail bool - ShouldDoRequestFail bool - ExpectedErr error - ExpectedOutput Items + Description string + ResponsePayload []byte + ResponseCode int + ShouldDoRequestFail bool + ExpectedErr error + ExpectedOutput Items + MockError error } tcs := []testCase{ { - Description: "Make request fails", - ShouldMakeRequestFail: true, - ExpectedErr: ErrAuthAcquirerFailure, + Description: "Make request fails", + ExpectedErr: ErrAuthAcquirerFailure, + MockError: errFails, }, { Description: "Do request fails", ShouldDoRequestFail: true, ExpectedErr: errDoRequestFailure, + MockError: nil, }, { Description: "Unauthorized", ResponseCode: http.StatusForbidden, ExpectedErr: ErrFailedAuthentication, + MockError: nil, }, { Description: "Bad request", ResponseCode: http.StatusBadRequest, ExpectedErr: ErrBadRequest, + MockError: nil, }, { Description: "Other non-success", ResponseCode: http.StatusInternalServerError, ExpectedErr: errNonSuccessResponse, + MockError: nil, }, { Description: "Payload unmarshal error", ResponseCode: http.StatusOK, ResponsePayload: []byte("[{}"), ExpectedErr: errJSONUnmarshal, + MockError: nil, }, { Description: "Happy path", ResponseCode: http.StatusOK, ResponsePayload: getItemsValidPayload(), ExpectedOutput: getItemsHappyOutput(), + MockError: nil, }, } for _, tc := range tcs { t.Run(tc.Description, func(t *testing.T) { var ( - assert = assert.New(t) - require = require.New(t) - bucket = "bucket-name" - owner = "owner-name" + assert = assert.New(t) + require = require.New(t) + bucket = "bucket-name" + owner = "owner-name" + acquirer = new(MockAquirer) ) server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -262,9 +273,8 @@ func TestGetItems(t *testing.T) { require.Nil(err) - if tc.ShouldMakeRequestFail { - client.auth = acquirerFunc(failAcquirer) - } + acquirer.On("AddAuth", mock.Anything).Return(tc.MockError) + client.auth = acquirer if tc.ShouldDoRequestFail { client.storeBaseURL = failingURL @@ -282,16 +292,16 @@ func TestGetItems(t *testing.T) { func TestPushItem(t *testing.T) { type testCase struct { - Description string - Item model.Item - Owner string - ResponseCode int - ShouldEraseBucket bool - ShouldRespNonSuccess bool - ShouldMakeRequestFail bool - ShouldDoRequestFail bool - ExpectedErr error - ExpectedOutput PushResult + Description string + Item model.Item + Owner string + ResponseCode int + ShouldEraseBucket bool + ShouldRespNonSuccess bool + ShouldDoRequestFail bool + ExpectedErr error + ExpectedOutput PushResult + MockError error } validItem := model.Item{ @@ -308,53 +318,61 @@ func TestPushItem(t *testing.T) { Description: "Item ID Missing", Item: model.Item{Data: map[string]interface{}{}}, ExpectedErr: ErrItemIDEmpty, + MockError: nil, }, { Description: "Item Data missing", Item: model.Item{ID: validItem.ID}, ExpectedErr: ErrItemDataEmpty, + MockError: nil, }, { - Description: "Make request fails", - Item: validItem, - ShouldMakeRequestFail: true, - ExpectedErr: ErrAuthAcquirerFailure, + Description: "Make request fails", + Item: validItem, + ExpectedErr: ErrAuthAcquirerFailure, + MockError: errFails, }, { Description: "Do request fails", Item: validItem, ShouldDoRequestFail: true, ExpectedErr: errDoRequestFailure, + MockError: nil, }, { Description: "Unauthorized", Item: validItem, ResponseCode: http.StatusForbidden, ExpectedErr: ErrFailedAuthentication, + MockError: nil, }, { Description: "Bad request", Item: validItem, ResponseCode: http.StatusBadRequest, ExpectedErr: ErrBadRequest, + MockError: nil, }, { Description: "Other non-success", Item: validItem, ResponseCode: http.StatusInternalServerError, ExpectedErr: errNonSuccessResponse, + MockError: nil, }, { Description: "Create success", Item: validItem, ResponseCode: http.StatusCreated, ExpectedOutput: CreatedPushResult, + MockError: nil, }, { Description: "Update success", Item: validItem, ResponseCode: http.StatusOK, ExpectedOutput: UpdatedPushResult, + MockError: nil, }, { Description: "Update success with owner", @@ -362,16 +380,18 @@ func TestPushItem(t *testing.T) { ResponseCode: http.StatusOK, Owner: "owner-name", ExpectedOutput: UpdatedPushResult, + MockError: nil, }, } for _, tc := range tcs { t.Run(tc.Description, func(t *testing.T) { var ( - assert = assert.New(t) - require = require.New(t) - bucket = "bucket-name" - id = "252f10c83610ebca1a059c0bae8255eba2f95be4d1d7bcfa89d7248a82d9f111" + assert = assert.New(t) + require = require.New(t) + bucket = "bucket-name" + id = "252f10c83610ebca1a059c0bae8255eba2f95be4d1d7bcfa89d7248a82d9f111" + acquirer = new(MockAquirer) ) server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -395,9 +415,8 @@ func TestPushItem(t *testing.T) { Bucket: bucket, }) - if tc.ShouldMakeRequestFail { - client.auth = acquirerFunc(failAcquirer) - } + acquirer.On("AddAuth", mock.Anything).Return(tc.MockError) + client.auth = acquirer if tc.ShouldDoRequestFail { client.storeBaseURL = failingURL @@ -421,54 +440,60 @@ func TestPushItem(t *testing.T) { func TestRemoveItem(t *testing.T) { type testCase struct { - Description string - ResponsePayload []byte - ResponseCode int - Owner string - ShouldRespNonSuccess bool - ShouldMakeRequestFail bool - ShouldDoRequestFail bool - ExpectedErr error - ExpectedOutput model.Item + Description string + ResponsePayload []byte + ResponseCode int + Owner string + ShouldRespNonSuccess bool + ShouldDoRequestFail bool + ExpectedErr error + ExpectedOutput model.Item + MockError error } tcs := []testCase{ { - Description: "Make request fails", - ShouldMakeRequestFail: true, - ExpectedErr: ErrAuthAcquirerFailure, + Description: "Make request fails", + ExpectedErr: ErrAuthAcquirerFailure, + MockError: errFails, }, { Description: "Do request fails", ShouldDoRequestFail: true, ExpectedErr: errDoRequestFailure, + MockError: nil, }, { Description: "Unauthorized", ResponseCode: http.StatusForbidden, ExpectedErr: ErrFailedAuthentication, + MockError: nil, }, { Description: "Bad request", ResponseCode: http.StatusBadRequest, ExpectedErr: ErrBadRequest, + MockError: nil, }, { Description: "Other non-success", ResponseCode: http.StatusInternalServerError, ExpectedErr: errNonSuccessResponse, + MockError: nil, }, { Description: "Unmarshal failure", ResponseCode: http.StatusOK, ResponsePayload: []byte("{{}"), ExpectedErr: errJSONUnmarshal, + MockError: nil, }, { Description: "Succcess", ResponseCode: http.StatusOK, ResponsePayload: getRemoveItemValidPayload(), ExpectedOutput: getRemoveItemHappyOutput(), + MockError: nil, }, } @@ -479,7 +504,8 @@ func TestRemoveItem(t *testing.T) { require = require.New(t) bucket = "bucket-name" // nolint:gosec - id = "7e8c5f378b4addbaebc70897c4478cca06009e3e360208ebd073dbee4b3774e7" + id = "7e8c5f378b4addbaebc70897c4478cca06009e3e360208ebd073dbee4b3774e7" + acquirer = new(MockAquirer) ) server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { assert.Equal(fmt.Sprintf("%s/%s/%s", storeAPIPath, bucket, id), r.URL.Path) @@ -494,9 +520,8 @@ func TestRemoveItem(t *testing.T) { Bucket: bucket, }) - if tc.ShouldMakeRequestFail { - client.auth = acquirerFunc(failAcquirer) - } + acquirer.On("AddAuth", mock.Anything).Return(tc.MockError) + client.auth = acquirer if tc.ShouldDoRequestFail { client.storeBaseURL = failingURL @@ -548,16 +573,6 @@ func TestTranslateStatusCode(t *testing.T) { } } -func failAcquirer() (string, error) { - return "", errors.New("always fail") -} - -type acquirerFunc func() (string, error) - -func (a acquirerFunc) Acquire() (string, error) { - return a() -} - func getRemoveItemValidPayload() []byte { return []byte(` { diff --git a/go.mod b/go.mod index 2ed92be..86d0c30 100644 --- a/go.mod +++ b/go.mod @@ -7,13 +7,11 @@ toolchain go1.22.9 require ( github.com/aws/aws-sdk-go v1.54.19 github.com/go-kit/kit v0.13.0 - github.com/golang-jwt/jwt v3.2.2+incompatible github.com/lestrrat-go/jwx/v2 v2.1.2 github.com/prometheus/client_golang v1.19.1 github.com/prometheus/client_model v0.6.1 github.com/spf13/cast v1.6.0 github.com/stretchr/testify v1.9.0 - github.com/xmidt-org/bascule v0.11.6 github.com/xmidt-org/httpaux v0.4.0 github.com/xmidt-org/sallust v0.2.2 github.com/xmidt-org/touchstone v0.1.5 diff --git a/go.sum b/go.sum index fa68183..0fbe9c1 100644 --- a/go.sum +++ b/go.sum @@ -19,8 +19,6 @@ github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= @@ -65,8 +63,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/xmidt-org/bascule v0.11.6 h1:i46FAI97XPMt3OKraiNyKa+mt36AhLO8iuInAypXKNM= -github.com/xmidt-org/bascule v0.11.6/go.mod h1:BXb5PEm/tjqdiEGsd+phm+fItMJx+Huv6LTCEU/zTzg= github.com/xmidt-org/httpaux v0.4.0 h1:cAL/MzIBpSsv4xZZeq/Eu1J5M3vfNe49xr41mP3COKU= github.com/xmidt-org/httpaux v0.4.0/go.mod h1:UypqZwuZV1nn8D6+K1JDb+im9IZrLNg/2oO/Bgiybxc= github.com/xmidt-org/sallust v0.2.2 h1:MrINLEr7cMj6ENx/O76fvpfd5LNGYnk7OipZAGXPYA0= diff --git a/jwtAcquireParser.go b/jwtAcquireParser.go deleted file mode 100644 index a27ed32..0000000 --- a/jwtAcquireParser.go +++ /dev/null @@ -1,76 +0,0 @@ -// SPDX-FileCopyrightText: 2022 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package ancla - -import ( - "errors" - "time" - - "github.com/golang-jwt/jwt" - "github.com/spf13/cast" - "github.com/xmidt-org/bascule/acquire" -) - -type jwtAcquireParserType string - -const ( - simpleType jwtAcquireParserType = "simple" - rawType jwtAcquireParserType = "raw" -) - -var ( - errMissingExpClaim = errors.New("missing exp claim in jwt") - errUnexpectedCasting = errors.New("unexpected casting error") -) - -type jwtAcquireParser struct { - token acquire.TokenParser - expiration acquire.ParseExpiration -} - -func rawTokenParser(data []byte) (string, error) { - return string(data), nil -} - -func rawTokenExpirationParser(data []byte) (time.Time, error) { - p := jwt.Parser{SkipClaimsValidation: true} - token, _, err := p.ParseUnverified(string(data), jwt.MapClaims{}) - if err != nil { - return time.Time{}, err - } - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - return time.Time{}, errUnexpectedCasting - } - expVal, ok := claims["exp"] - if !ok { - return time.Time{}, errMissingExpClaim - } - - exp, err := cast.ToInt64E(expVal) - if err != nil { - return time.Time{}, err - } - return time.Unix(exp, 0), nil -} - -func newJWTAcquireParser(pType jwtAcquireParserType) (jwtAcquireParser, error) { - if pType == "" { - pType = simpleType - } - if pType != simpleType && pType != rawType { - return jwtAcquireParser{}, errors.New("only 'simple' or 'raw' are supported as jwt acquire parser types") - } - // nil defaults are fine (bascule/acquire will use the simple - // default parsers internally). - var ( - tokenParser acquire.TokenParser - expirationParser acquire.ParseExpiration - ) - if pType == rawType { - tokenParser = rawTokenParser - expirationParser = rawTokenExpirationParser - } - return jwtAcquireParser{expiration: expirationParser, token: tokenParser}, nil -} diff --git a/jwtAcquireParser_test.go b/jwtAcquireParser_test.go deleted file mode 100644 index fed2a25..0000000 --- a/jwtAcquireParser_test.go +++ /dev/null @@ -1,94 +0,0 @@ -// SPDX-FileCopyrightText: 2022 Comcast Cable Communications Management, LLC -// SPDX-License-Identifier: Apache-2.0 - -package ancla - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestNewJWTAcquireParser(t *testing.T) { - tcs := []struct { - Description string - ParserType jwtAcquireParserType - ShouldFail bool - }{ - { - Description: "Default", - }, - { - Description: "Invalid type", - ParserType: "advanced", - ShouldFail: true, - }, - { - Description: "Simple", - ParserType: simpleType, - }, - { - Description: "Raw", - ParserType: rawType, - }, - } - - for _, tc := range tcs { - t.Run(tc.Description, func(t *testing.T) { - assert := assert.New(t) - p, err := newJWTAcquireParser(tc.ParserType) - if tc.ShouldFail { - assert.NotNil(err) - assert.Nil(p.expiration) - assert.Nil(p.token) - } else { - assert.Nil(err) - if tc.ParserType == rawType { - assert.NotNil(p.expiration) - assert.NotNil(p.token) - } - } - }) - } -} - -func TestRawTokenParser(t *testing.T) { - assert := assert.New(t) - payload := []byte("eyJhbGciOiJSUzI1NiIsImtpZCI6ImRldmVsb3BtZW50IiwidHlwIjoiSldUIn0.eyJhbGxvd2VkUmVzb3VyY2VzIjp7ImFsbG93ZWRQYXJ0bmVycyI6WyJjb21jYXN0Il19LCJhdWQiOiJYTWlEVCIsImNhcGFiaWxpdGllcyI6WyJ4MTppc3N1ZXI6dGVzdDouKjphbGwiLCJ4MTppc3N1ZXI6dWk6YWxsIl0sImV4cCI6MTYyMjE1Nzk4MSwiaWF0IjoxNjIyMDcxNTgxLCJpc3MiOiJkZXZlbG9wbWVudCIsImp0aSI6ImN4ZmkybTZDWnJjaFNoZ1Nzdi1EM3ciLCJuYmYiOjE2MjIwNzE1NjYsInBhcnRuZXItaWQiOiJjb21jYXN0Iiwic3ViIjoiY2xpZW50LXN1cHBsaWVkIiwidHJ1c3QiOjEwMDB9.7QzRWJgxGs1cEZunMOewYCnEDiq2CTDh5R5F47PYhkMVb2KxSf06PRRGN-rQSWPhhBbev1fGgu63mr3yp_VDmdVvHR2oYiKyxP2skJTSzfQmiRyLMYY5LcLn3BObyQxU8EnLhnqGIjpORW0L5Dd4QsaZmXRnkC73yGnJx4XCx0I") - token, err := rawTokenParser(payload) - assert.Equal(string(payload), token) - assert.Nil(err) -} - -func TestRawExpirationParser(t *testing.T) { - tcs := []struct { - Description string - Payload []byte - ShouldFail bool - ExpectedTime time.Time - }{ - { - Description: "Not a JWT", - Payload: []byte("xyz==abcNotAJWT"), - ShouldFail: true, - }, - { - Description: "A jwt", - Payload: []byte("eyJhbGciOiJSUzI1NiIsImtpZCI6ImRldmVsb3BtZW50IiwidHlwIjoiSldUIn0.eyJhbGxvd2VkUmVzb3VyY2VzIjp7ImFsbG93ZWRQYXJ0bmVycyI6WyJjb21jYXN0Il19LCJhdWQiOiJYTWlEVCIsImNhcGFiaWxpdGllcyI6WyJ4MTppc3N1ZXI6dGVzdDouKjphbGwiLCJ4MTppc3N1ZXI6dWk6YWxsIl0sImV4cCI6MTYyMjE1Nzk4MSwiaWF0IjoxNjIyMDcxNTgxLCJpc3MiOiJkZXZlbG9wbWVudCIsImp0aSI6ImN4ZmkybTZDWnJjaFNoZ1Nzdi1EM3ciLCJuYmYiOjE2MjIwNzE1NjYsInBhcnRuZXItaWQiOiJjb21jYXN0Iiwic3ViIjoiY2xpZW50LXN1cHBsaWVkIiwidHJ1c3QiOjEwMDB9.7QzRWJgxGs1cEZunMOewYCnEDiq2CTDh5R5F47PYhkMVb2KxSf06PRRGN-rQSWPhhBbev1fGgu63mr3yp_VDmdVvHR2oYiKyxP2skJTSzfQmiRyLMYY5LcLn3BObyQxU8EnLhnqGIjpORW0L5Dd4QsaZmXRnkC73yGnJx4XCx0I"), - ExpectedTime: time.Unix(1622157981, 0), - }, - } - - for _, tc := range tcs { - assert := assert.New(t) - exp, err := rawTokenExpirationParser(tc.Payload) - if tc.ShouldFail { - assert.NotNil(err) - assert.Empty(exp) - } else { - assert.Nil(err) - assert.Equal(tc.ExpectedTime, exp) - } - } -} diff --git a/service.go b/service.go index b0edad3..3f6cab1 100644 --- a/service.go +++ b/service.go @@ -42,13 +42,6 @@ type Service interface { type Config struct { BasicClientConfig chrysom.BasicClientConfig - // JWTParserType establishes which parser type will be used by the JWT token - // acquirer used by Argus. Options include 'simple' and 'raw'. - // Simple: parser assumes token payloads have the following structure: https://github.com/xmidt-org/bascule/blob/c011b128d6b95fa8358228535c63d1945347adaa/acquire/bearer.go#L77 - // Raw: parser assumes all of the token payload == JWT token - // (Optional). Defaults to 'simple' - JWTParserType jwtAcquireParserType - // DisablePartnerIDs, if true, will allow webhooks to register without // checking the validity of the partnerIDs in the request DisablePartnerIDs bool @@ -69,7 +62,6 @@ type ClientService struct { // NewService builds the Argus client service from the given configuration. func NewService(cfg Config) (*ClientService, error) { - prepArgusBasicClientConfig(&cfg) basic, err := chrysom.NewBasicClient(cfg.BasicClientConfig) if err != nil { return nil, fmt.Errorf("failed to create chrysom basic client: %v", err) @@ -133,16 +125,6 @@ func (s *ClientService) GetAll(ctx context.Context) ([]Register, error) { return iws, nil } -func prepArgusBasicClientConfig(cfg *Config) error { - p, err := newJWTAcquireParser(cfg.JWTParserType) - if err != nil { - return err - } - cfg.BasicClientConfig.Auth.JWT.GetToken = p.token - cfg.BasicClientConfig.Auth.JWT.GetExpiration = p.expiration - return nil -} - func prepArgusListenerConfig(cfg *chrysom.ListenerConfig, metrics chrysom.Measures, watches ...Watch) { watches = append(watches, webhookListSizeWatch(metrics.WebhookListSizeGauge)) cfg.Listener = chrysom.ListenerFunc(func(ctx context.Context, items chrysom.Items) { @@ -160,8 +142,10 @@ func prepArgusListenerConfig(cfg *chrysom.ListenerConfig, metrics chrysom.Measur type ServiceIn struct { fx.In + Config Config Client *http.Client + Auth chrysom.Acquirer } func ProvideService() fx.Option {