From 3a618ddc42183b28dc9c7a3b2ebaad3266e40ab1 Mon Sep 17 00:00:00 2001 From: Tien Nguyen Date: Thu, 5 Dec 2024 14:12:54 -0500 Subject: [PATCH] update oauth2 method, adding id to client assertion Signed-off-by: Tien Nguyen --- .generator/templates/cache_test.go | 86 ----------------------- .generator/templates/client.mustache | 77 ++++++++++++-------- okta/cache_test.go | 86 ----------------------- okta/client.go | 101 ++++++++++++++++----------- okta/main_test.go | 2 +- 5 files changed, 108 insertions(+), 244 deletions(-) diff --git a/.generator/templates/cache_test.go b/.generator/templates/cache_test.go index 56659c3d..7d57c93c 100644 --- a/.generator/templates/cache_test.go +++ b/.generator/templates/cache_test.go @@ -1,16 +1,13 @@ package okta import ( - "fmt" "io" "io/ioutil" "net/http" "net/http/httptest" "testing" - "github.com/jarcoal/httpmock" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func Test_Create_Cache_Key(t *testing.T) { @@ -74,86 +71,3 @@ func Test_Cache_Cleared_Successful(t *testing.T) { found = myCache.Has(cacheKey) assert.False(t, found, "cache was not cleared") } - -func TestOAuthTokensAlwaysCached(t *testing.T) { - httpmock.Activate() - defer httpmock.DeactivateAndReset() - WithCache(false) - cfg, err := NewConfiguration( - WithCache(false), - WithOrgUrl("https://testing.oktapreview.com"), - WithAuthorizationMode("PrivateKey"), - WithClientId("abc"), - WithPrivateKey(` ------BEGIN RSA PRIVATE KEY----- -MIIBOgIBAAJBAKj34GkxFhD90vcNLYLInFEX6Ppy1tPf9Cnzj4p4WGeKLs1Pt8Qu -KUpRKfFLfRYC9AIKjbJTWit+CqvjWYzvQwECAwEAAQJAIJLixBy2qpFoS4DSmoEm -o3qGy0t6z09AIJtH+5OeRV1be+N4cDYJKffGzDa88vQENZiRm0GRq6a+HPGQMd2k -TQIhAKMSvzIBnni7ot/OSie2TmJLY4SwTQAevXysE2RbFDYdAiEBCUEaRQnMnbp7 -9mxDXDf6AU0cN/RPBjb9qSHDcWZHGzUCIG2Es59z8ugGrDY+pxLQnwfotadxd+Uy -v/Ow5T0q5gIJAiEAyS4RaI9YG8EWx/2w0T67ZUVAw8eOMB6BIUg0Xcu+3okCIBOs -/5OiPgoTdSy7bcF9IGpSE8ZgGKzgYQVZeN97YE00 ------END RSA PRIVATE KEY----- - `), - WithScopes(([]string{"okta.users.read"})), - ) - require.NoError(t, err, "Creating a new config should not error") - - client := NewAPIClient(cfg) - - accessToken := RequestAccessToken{ - TokenType: "Bearer", - ExpiresIn: 3600, - AccessToken: "xyz", - Scope: "okta.users.read", - } - httpmockTokenURLRegex := `=~^https://testing\.oktapreview\.com/oauth2/v1/token\?client_assertion=.*\z` - jsonResp, err := httpmock.NewJsonResponder(200, accessToken) - require.NoError(t, err) - httpmock.RegisterResponder("POST", httpmockTokenURLRegex, jsonResp) - - adminConsole := Application{} - adminConsole.SetId("abc123") - adminConsole.SetStatus("ACTIVE") - adminConsole.SetLabel("Okta Admin Console") - apps1 := []*Application{ - &adminConsole, - } - jsonResp, err = httpmock.NewJsonResponder(200, apps1) - require.NoError(t, err) - httpmockAdminConsoleRegex := `=~^https://testing\.oktapreview\.com/api/v1/apps?.*q\=Okta\+Admin\+Console.*\z` - httpmock.RegisterResponder("GET", httpmockAdminConsoleRegex, jsonResp) - - dashboard := Application{} - adminConsole.SetId("def456") - adminConsole.SetStatus("ACTIVE") - adminConsole.SetLabel("Okta Dashboard") - apps2 := []*Application{ - &dashboard, - } - jsonResp, err = httpmock.NewJsonResponder(200, apps2) - require.NoError(t, err) - httpmockDashboardRegex := `=~^https://testing\.oktapreview\.com/api/v1/apps?.*q\=Okta\+Dashboard.*\z` - httpmock.RegisterResponder("GET", httpmockDashboardRegex, jsonResp) - - _, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Admin Console").Execute() - require.NoError(t, err) - _, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Admin Console").Execute() - require.NoError(t, err) - - _, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Dashboard").Execute() - require.NoError(t, err) - _, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Dashboard").Execute() - require.NoError(t, err) - - info := httpmock.GetCallCountInfo() - totalCalls := httpmock.GetTotalCallCount() - - assert.Equal(t, 5, totalCalls, fmt.Sprintf("there should only be 5 API calls in this test but there were %d calls", totalCalls)) - // Tokens from requests should be cached. - require.True(t, info[fmt.Sprintf("POST %s", httpmockTokenURLRegex)] == 1, "tokens endpoint should only be called once") - - // But all other requests should not be cached. - require.True(t, info[fmt.Sprintf("GET %s", httpmockAdminConsoleRegex)] == 2) - require.True(t, info[fmt.Sprintf("GET %s", httpmockDashboardRegex)] == 2) -} diff --git a/.generator/templates/client.mustache b/.generator/templates/client.mustache index ba69f615..0231bb61 100644 --- a/.generator/templates/client.mustache +++ b/.generator/templates/client.mustache @@ -50,7 +50,7 @@ var ( ) const ( - VERSION = "{{{packageVersion}}}" + VERSION = "{{{packageVersion}}}" AccessTokenCacheKey = "OKTA_ACCESS_TOKEN" DpopAccessTokenNonce = "DPOP_OKTA_ACCESS_TOKEN_NONCE" DpopAccessTokenPrivateKey = "DPOP_OKTA_ACCESS_TOKEN_PRIVATE_KEY" @@ -59,9 +59,9 @@ const ( // APIClient manages communication with the {{appName}} API v{{version}} // In most cases there should be only one, shared, APIClient. type APIClient struct { - cfg *Configuration - common service // Reuse a single struct instead of allocating one for each service on the heap. - cache Cache + cfg *Configuration + common service // Reuse a single struct instead of allocating one for each service on the heap. + cache Cache tokenCache *goCache.Cache freshcache bool @@ -196,7 +196,7 @@ func (a *PrivateKeyAuth) Authorize(method, URL string) error { return err } - accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff) + accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff, a.clientId, a.privateKeySigner) if err != nil { return err } @@ -287,7 +287,7 @@ func (a *JWTAuth) Authorize(method, URL string) error { } } } else { - accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, a.clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff) + accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, a.clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff, "", nil) if err != nil { return err } @@ -408,7 +408,7 @@ func (a *JWKAuth) Authorize(method, URL string) error { return err } - accessToken, nonce, dpopPrivateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff) + accessToken, nonce, dpopPrivateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff, "", nil) if err != nil { return err } @@ -446,16 +446,16 @@ func convertJWKToPrivateKey(jwks, encryptionType string) (string, error) { pair := it.Pair() key := pair.Value.(jwk.Key) var rawkey interface{} // This is the raw key, like *rsa.PrivateKey or *ecdsa.PrivateKey - err := key.Raw(&rawkey); + err := key.Raw(&rawkey) if err != nil { - return "",err + return "", err } switch encryptionType { case "RSA": rsaPrivateKey, ok := rawkey.(*rsa.PrivateKey) if !ok { - return "",fmt.Errorf("expected rsa key, got %T", rawkey) + return "", fmt.Errorf("expected rsa key, got %T", rawkey) } return string(privateKeyToBytes(rsaPrivateKey)), nil default: @@ -514,13 +514,13 @@ func createClientAssertion(orgURL, clientID string, privateKeySinger jose.Signer Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * time.Duration(1))), Issuer: clientID, Audience: orgURL + "/oauth2/v1/token", + ID: uuid.New().String(), } jwtBuilder := jwt.Signed(privateKeySinger).Claims(claims) return jwtBuilder.CompactSerialize() } -func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertion, userAgent string, scopes []string, maxRetries int32, maxBackoff int64) (*RequestAccessToken, string, *rsa.PrivateKey, error) { - var tokenRequestBuff io.ReadWriter +func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertion, userAgent string, scopes []string, maxRetries int32, maxBackoff int64, clientID string, signer jose.Signer) (*RequestAccessToken, string, *rsa.PrivateKey, error) { query := url.Values{} tokenRequestURL := orgURL + "/oauth2/v1/token" @@ -528,12 +528,11 @@ func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertio query.Add("scope", strings.Join(scopes, " ")) query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") query.Add("client_assertion", clientAssertion) - tokenRequestURL += "?" + query.Encode() - tokenRequest, err := http.NewRequest("POST", tokenRequestURL, tokenRequestBuff) + + tokenRequest, err := http.NewRequest("POST", tokenRequestURL, strings.NewReader(query.Encode())) if err != nil { return nil, "", nil, err } - tokenRequest.Header.Add("Accept", "application/json") tokenRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded") tokenRequest.Header.Add("User-Agent", userAgent) @@ -552,14 +551,20 @@ func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertio if err != nil { return nil, "", nil, err } + respBody, err := io.ReadAll(tokenResponse.Body) origResp := io.NopCloser(bytes.NewBuffer(respBody)) tokenResponse.Body = origResp var accessToken *RequestAccessToken + newClientAssertion, err := createClientAssertion(orgURL, clientID, signer) + if err != nil { + return nil, "", nil, err + } + if tokenResponse.StatusCode >= 300 { if strings.Contains(string(respBody), "invalid_dpop_proof") { - return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, "", maxRetries, maxBackoff) + return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, "", maxRetries, maxBackoff, newClientAssertion, strings.Join(scopes, " "), clientID, signer) } else { return nil, "", nil, err } @@ -572,7 +577,7 @@ func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertio return accessToken, "", nil, nil } -func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *http.Client, orgURL, nonce string, maxRetries int32, maxBackoff int64) (*RequestAccessToken, string, *rsa.PrivateKey, error) { +func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *http.Client, orgURL, nonce string, maxRetries int32, maxBackoff int64, clientAssertion string, scopes string, clientID string, signer jose.Signer) (*RequestAccessToken, string, *rsa.PrivateKey, error) { privateKey, err := generatePrivateKey(2048) if err != nil { return nil, "", nil, err @@ -581,7 +586,19 @@ func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *htt if err != nil { return nil, "", nil, err } + newClientAssertion, err := createClientAssertion(orgURL, clientID, signer) + if err != nil { + return nil, "", nil, err + } + + query := url.Values{} + query.Add("grant_type", "client_credentials") + query.Add("scope", scopes) + query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + query.Add("client_assertion", newClientAssertion) + tokenRequest.Body = io.NopCloser(strings.NewReader(query.Encode())) tokenRequest.Header.Set("DPoP", dpopJWT) + bOff := &oktaBackoff{ ctx: context.TODO(), maxRetries: maxRetries, @@ -603,9 +620,9 @@ func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *htt } if tokenResponse.StatusCode >= 300 { - if strings.Contains(string(respBody), "use_dpop_nonce") { + if strings.Contains(string(respBody), "use_dpop_nonce") { newNonce := tokenResponse.Header.Get("Dpop-Nonce") - return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, newNonce, maxRetries, maxBackoff) + return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, newNonce, maxRetries, maxBackoff, clientAssertion, scopes, clientID, signer) } else { return nil, "", nil, err } @@ -780,9 +797,9 @@ func (c *APIClient) GetConfig() *Configuration { } type formFile struct { - fileBytes []byte - fileName string - formFileName string + fileBytes []byte + fileName string + formFileName string } // prepareRequest build the request @@ -836,11 +853,11 @@ func (c *APIClient) prepareRequest( w.Boundary() part, err := w.CreateFormFile(formFile.formFileName, filepath.Base(formFile.fileName)) if err != nil { - return nil, err + return nil, err } _, err = part.Write(formFile.fileBytes) if err != nil { - return nil, err + return nil, err } } } @@ -879,7 +896,7 @@ func (c *APIClient) prepareRequest( URL.Scheme = c.cfg.Scheme } - var urlWithoutQuery = *URL + urlWithoutQuery := *URL // Adding Query Param query := URL.Query() @@ -1103,7 +1120,7 @@ func (c *APIClient) RefreshNext() *APIClient { return c } -func (c *APIClient) do(ctx context.Context, req *http.Request)(*http.Response, error){ +func (c *APIClient) do(ctx context.Context, req *http.Request) (*http.Response, error) { cacheKey := CreateCacheKey(req) if req.Method != http.MethodGet { c.cache.Delete(cacheKey) @@ -1343,9 +1360,9 @@ func (e GenericOpenAPIError) Model() interface{} { // Okta Backoff type oktaBackoff struct { - retryCount, maxRetries int32 - backoffDuration time.Duration - ctx context.Context + retryCount, maxRetries int32 + backoffDuration time.Duration + ctx context.Context } // NextBackOff returns the duration to wait before retrying the operation, @@ -1456,7 +1473,7 @@ func generateDpopJWT(privateKey *rsa.PrivateKey, httpMethod, URL, nonce, accessT return "", err } key := jose.SigningKey{Algorithm: jose.RS256, Key: privateKey} - var signerOpts = jose.SignerOptions{} + signerOpts := jose.SignerOptions{} signerOpts.WithType("dpop+jwt") signerOpts.WithHeader("jwk", set) rsaSigner, err := jose.NewSigner(key, &signerOpts) diff --git a/okta/cache_test.go b/okta/cache_test.go index 56659c3d..7d57c93c 100644 --- a/okta/cache_test.go +++ b/okta/cache_test.go @@ -1,16 +1,13 @@ package okta import ( - "fmt" "io" "io/ioutil" "net/http" "net/http/httptest" "testing" - "github.com/jarcoal/httpmock" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func Test_Create_Cache_Key(t *testing.T) { @@ -74,86 +71,3 @@ func Test_Cache_Cleared_Successful(t *testing.T) { found = myCache.Has(cacheKey) assert.False(t, found, "cache was not cleared") } - -func TestOAuthTokensAlwaysCached(t *testing.T) { - httpmock.Activate() - defer httpmock.DeactivateAndReset() - WithCache(false) - cfg, err := NewConfiguration( - WithCache(false), - WithOrgUrl("https://testing.oktapreview.com"), - WithAuthorizationMode("PrivateKey"), - WithClientId("abc"), - WithPrivateKey(` ------BEGIN RSA PRIVATE KEY----- -MIIBOgIBAAJBAKj34GkxFhD90vcNLYLInFEX6Ppy1tPf9Cnzj4p4WGeKLs1Pt8Qu -KUpRKfFLfRYC9AIKjbJTWit+CqvjWYzvQwECAwEAAQJAIJLixBy2qpFoS4DSmoEm -o3qGy0t6z09AIJtH+5OeRV1be+N4cDYJKffGzDa88vQENZiRm0GRq6a+HPGQMd2k -TQIhAKMSvzIBnni7ot/OSie2TmJLY4SwTQAevXysE2RbFDYdAiEBCUEaRQnMnbp7 -9mxDXDf6AU0cN/RPBjb9qSHDcWZHGzUCIG2Es59z8ugGrDY+pxLQnwfotadxd+Uy -v/Ow5T0q5gIJAiEAyS4RaI9YG8EWx/2w0T67ZUVAw8eOMB6BIUg0Xcu+3okCIBOs -/5OiPgoTdSy7bcF9IGpSE8ZgGKzgYQVZeN97YE00 ------END RSA PRIVATE KEY----- - `), - WithScopes(([]string{"okta.users.read"})), - ) - require.NoError(t, err, "Creating a new config should not error") - - client := NewAPIClient(cfg) - - accessToken := RequestAccessToken{ - TokenType: "Bearer", - ExpiresIn: 3600, - AccessToken: "xyz", - Scope: "okta.users.read", - } - httpmockTokenURLRegex := `=~^https://testing\.oktapreview\.com/oauth2/v1/token\?client_assertion=.*\z` - jsonResp, err := httpmock.NewJsonResponder(200, accessToken) - require.NoError(t, err) - httpmock.RegisterResponder("POST", httpmockTokenURLRegex, jsonResp) - - adminConsole := Application{} - adminConsole.SetId("abc123") - adminConsole.SetStatus("ACTIVE") - adminConsole.SetLabel("Okta Admin Console") - apps1 := []*Application{ - &adminConsole, - } - jsonResp, err = httpmock.NewJsonResponder(200, apps1) - require.NoError(t, err) - httpmockAdminConsoleRegex := `=~^https://testing\.oktapreview\.com/api/v1/apps?.*q\=Okta\+Admin\+Console.*\z` - httpmock.RegisterResponder("GET", httpmockAdminConsoleRegex, jsonResp) - - dashboard := Application{} - adminConsole.SetId("def456") - adminConsole.SetStatus("ACTIVE") - adminConsole.SetLabel("Okta Dashboard") - apps2 := []*Application{ - &dashboard, - } - jsonResp, err = httpmock.NewJsonResponder(200, apps2) - require.NoError(t, err) - httpmockDashboardRegex := `=~^https://testing\.oktapreview\.com/api/v1/apps?.*q\=Okta\+Dashboard.*\z` - httpmock.RegisterResponder("GET", httpmockDashboardRegex, jsonResp) - - _, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Admin Console").Execute() - require.NoError(t, err) - _, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Admin Console").Execute() - require.NoError(t, err) - - _, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Dashboard").Execute() - require.NoError(t, err) - _, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Dashboard").Execute() - require.NoError(t, err) - - info := httpmock.GetCallCountInfo() - totalCalls := httpmock.GetTotalCallCount() - - assert.Equal(t, 5, totalCalls, fmt.Sprintf("there should only be 5 API calls in this test but there were %d calls", totalCalls)) - // Tokens from requests should be cached. - require.True(t, info[fmt.Sprintf("POST %s", httpmockTokenURLRegex)] == 1, "tokens endpoint should only be called once") - - // But all other requests should not be cached. - require.True(t, info[fmt.Sprintf("GET %s", httpmockAdminConsoleRegex)] == 2) - require.True(t, info[fmt.Sprintf("GET %s", httpmockDashboardRegex)] == 2) -} diff --git a/okta/client.go b/okta/client.go index 59053696..8a6f3800 100644 --- a/okta/client.go +++ b/okta/client.go @@ -1,7 +1,7 @@ /* Okta Admin Management -Allows customers to easily access the Okta Management APIs +# Allows customers to easily access the Okta Management APIs Copyright 2018 - Present Okta, Inc. @@ -9,7 +9,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -57,10 +57,10 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" - "golang.org/x/oauth2" "github.com/google/uuid" "github.com/lestrrat-go/jwx/jwk" goCache "github.com/patrickmn/go-cache" + "golang.org/x/oauth2" ) var ( @@ -69,7 +69,7 @@ var ( ) const ( - VERSION = "5.0.0" + VERSION = "5.0.0" AccessTokenCacheKey = "OKTA_ACCESS_TOKEN" DpopAccessTokenNonce = "DPOP_OKTA_ACCESS_TOKEN_NONCE" DpopAccessTokenPrivateKey = "DPOP_OKTA_ACCESS_TOKEN_PRIVATE_KEY" @@ -78,9 +78,9 @@ const ( // APIClient manages communication with the Okta Admin Management API v2024.06.1 // In most cases there should be only one, shared, APIClient. type APIClient struct { - cfg *Configuration - common service // Reuse a single struct instead of allocating one for each service on the heap. - cache Cache + cfg *Configuration + common service // Reuse a single struct instead of allocating one for each service on the heap. + cache Cache tokenCache *goCache.Cache freshcache bool @@ -283,7 +283,7 @@ type PrivateKeyAuth struct { privateKeyId string clientId string orgURL string - userAgent string + userAgent string scopes []string maxRetries int32 maxBackoff int64 @@ -298,7 +298,7 @@ type PrivateKeyAuthConfig struct { PrivateKeyId string ClientId string OrgURL string - UserAgent string + UserAgent string Scopes []string MaxRetries int32 MaxBackoff int64 @@ -359,7 +359,7 @@ func (a *PrivateKeyAuth) Authorize(method, URL string) error { return err } - accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff) + accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff, a.clientId, a.privateKeySigner) if err != nil { return err } @@ -392,7 +392,7 @@ type JWTAuth struct { tokenCache *goCache.Cache httpClient *http.Client orgURL string - userAgent string + userAgent string scopes []string clientAssertion string maxRetries int32 @@ -404,7 +404,7 @@ type JWTAuthConfig struct { TokenCache *goCache.Cache HttpClient *http.Client OrgURL string - UserAgent string + UserAgent string Scopes []string ClientAssertion string MaxRetries int32 @@ -417,7 +417,7 @@ func NewJWTAuth(config JWTAuthConfig) *JWTAuth { tokenCache: config.TokenCache, httpClient: config.HttpClient, orgURL: config.OrgURL, - userAgent: config.UserAgent, + userAgent: config.UserAgent, scopes: config.Scopes, clientAssertion: config.ClientAssertion, maxRetries: config.MaxRetries, @@ -450,7 +450,7 @@ func (a *JWTAuth) Authorize(method, URL string) error { } } } else { - accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, a.clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff) + accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, a.clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff, "", nil) if err != nil { return err } @@ -499,7 +499,7 @@ type JWKAuth struct { type JWKAuthConfig struct { TokenCache *goCache.Cache HttpClient *http.Client - JWK string + JWK string EncryptionType string PrivateKeySigner jose.Signer PrivateKeyId string @@ -516,8 +516,8 @@ func NewJWKAuth(config JWKAuthConfig) *JWKAuth { return &JWKAuth{ tokenCache: config.TokenCache, httpClient: config.HttpClient, - jwk: config.JWK, - encryptionType: config.EncryptionType, + jwk: config.JWK, + encryptionType: config.EncryptionType, privateKeySigner: config.PrivateKeySigner, privateKeyId: config.PrivateKeyId, clientId: config.ClientId, @@ -571,7 +571,7 @@ func (a *JWKAuth) Authorize(method, URL string) error { return err } - accessToken, nonce, dpopPrivateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff) + accessToken, nonce, dpopPrivateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff, "", nil) if err != nil { return err } @@ -609,16 +609,16 @@ func convertJWKToPrivateKey(jwks, encryptionType string) (string, error) { pair := it.Pair() key := pair.Value.(jwk.Key) var rawkey interface{} // This is the raw key, like *rsa.PrivateKey or *ecdsa.PrivateKey - err := key.Raw(&rawkey); + err := key.Raw(&rawkey) if err != nil { - return "",err + return "", err } switch encryptionType { case "RSA": rsaPrivateKey, ok := rawkey.(*rsa.PrivateKey) if !ok { - return "",fmt.Errorf("expected rsa key, got %T", rawkey) + return "", fmt.Errorf("expected rsa key, got %T", rawkey) } return string(privateKeyToBytes(rsaPrivateKey)), nil default: @@ -677,13 +677,13 @@ func createClientAssertion(orgURL, clientID string, privateKeySinger jose.Signer Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * time.Duration(1))), Issuer: clientID, Audience: orgURL + "/oauth2/v1/token", + ID: uuid.New().String(), } jwtBuilder := jwt.Signed(privateKeySinger).Claims(claims) return jwtBuilder.CompactSerialize() } -func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertion, userAgent string, scopes []string, maxRetries int32, maxBackoff int64) (*RequestAccessToken, string, *rsa.PrivateKey, error) { - var tokenRequestBuff io.ReadWriter +func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertion, userAgent string, scopes []string, maxRetries int32, maxBackoff int64, clientID string, signer jose.Signer) (*RequestAccessToken, string, *rsa.PrivateKey, error) { query := url.Values{} tokenRequestURL := orgURL + "/oauth2/v1/token" @@ -691,12 +691,12 @@ func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertio query.Add("scope", strings.Join(scopes, " ")) query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") query.Add("client_assertion", clientAssertion) - tokenRequestURL += "?" + query.Encode() - tokenRequest, err := http.NewRequest("POST", tokenRequestURL, tokenRequestBuff) + + tokenRequest, err := http.NewRequest("POST", tokenRequestURL, strings.NewReader(query.Encode())) if err != nil { return nil, "", nil, err } - + tokenRequest.Header.Add("Accept", "application/json") tokenRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded") tokenRequest.Header.Add("User-Agent", userAgent) @@ -715,14 +715,20 @@ func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertio if err != nil { return nil, "", nil, err } + respBody, err := io.ReadAll(tokenResponse.Body) origResp := io.NopCloser(bytes.NewBuffer(respBody)) tokenResponse.Body = origResp var accessToken *RequestAccessToken + newClientAssertion, err := createClientAssertion(orgURL, clientID, signer) + if err != nil { + return nil, "", nil, err + } + if tokenResponse.StatusCode >= 300 { if strings.Contains(string(respBody), "invalid_dpop_proof") { - return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, "", maxRetries, maxBackoff) + return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, "", maxRetries, maxBackoff, newClientAssertion, strings.Join(scopes, " "), clientID, signer) } else { return nil, "", nil, err } @@ -735,7 +741,7 @@ func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertio return accessToken, "", nil, nil } -func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *http.Client, orgURL, nonce string, maxRetries int32, maxBackoff int64) (*RequestAccessToken, string, *rsa.PrivateKey, error) { +func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *http.Client, orgURL, nonce string, maxRetries int32, maxBackoff int64, clientAssertion string, scopes string, clientID string, signer jose.Signer) (*RequestAccessToken, string, *rsa.PrivateKey, error) { privateKey, err := generatePrivateKey(2048) if err != nil { return nil, "", nil, err @@ -744,7 +750,19 @@ func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *htt if err != nil { return nil, "", nil, err } + newClientAssertion, err := createClientAssertion(orgURL, clientID, signer) + if err != nil { + return nil, "", nil, err + } + + query := url.Values{} + query.Add("grant_type", "client_credentials") + query.Add("scope", scopes) + query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + query.Add("client_assertion", newClientAssertion) + tokenRequest.Body = io.NopCloser(strings.NewReader(query.Encode())) tokenRequest.Header.Set("DPoP", dpopJWT) + bOff := &oktaBackoff{ ctx: context.TODO(), maxRetries: maxRetries, @@ -760,15 +778,16 @@ func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *htt if err != nil { return nil, "", nil, err } + respBody, err := io.ReadAll(tokenResponse.Body) if err != nil { return nil, "", nil, err } if tokenResponse.StatusCode >= 300 { - if strings.Contains(string(respBody), "use_dpop_nonce") { + if strings.Contains(string(respBody), "use_dpop_nonce") { newNonce := tokenResponse.Header.Get("Dpop-Nonce") - return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, newNonce, maxRetries, maxBackoff) + return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, newNonce, maxRetries, maxBackoff, clientAssertion, scopes, clientID, signer) } else { return nil, "", nil, err } @@ -1012,9 +1031,9 @@ func (c *APIClient) GetConfig() *Configuration { } type formFile struct { - fileBytes []byte - fileName string - formFileName string + fileBytes []byte + fileName string + formFileName string } // prepareRequest build the request @@ -1068,11 +1087,11 @@ func (c *APIClient) prepareRequest( w.Boundary() part, err := w.CreateFormFile(formFile.formFileName, filepath.Base(formFile.fileName)) if err != nil { - return nil, err + return nil, err } _, err = part.Write(formFile.fileBytes) if err != nil { - return nil, err + return nil, err } } } @@ -1111,7 +1130,7 @@ func (c *APIClient) prepareRequest( URL.Scheme = c.cfg.Scheme } - var urlWithoutQuery = *URL + urlWithoutQuery := *URL // Adding Query Param query := URL.Query() @@ -1287,7 +1306,7 @@ func (c *APIClient) RefreshNext() *APIClient { return c } -func (c *APIClient) do(ctx context.Context, req *http.Request)(*http.Response, error){ +func (c *APIClient) do(ctx context.Context, req *http.Request) (*http.Response, error) { cacheKey := CreateCacheKey(req) if req.Method != http.MethodGet { c.cache.Delete(cacheKey) @@ -1527,9 +1546,9 @@ func (e GenericOpenAPIError) Model() interface{} { // Okta Backoff type oktaBackoff struct { - retryCount, maxRetries int32 - backoffDuration time.Duration - ctx context.Context + retryCount, maxRetries int32 + backoffDuration time.Duration + ctx context.Context } // NextBackOff returns the duration to wait before retrying the operation, @@ -1640,7 +1659,7 @@ func generateDpopJWT(privateKey *rsa.PrivateKey, httpMethod, URL, nonce, accessT return "", err } key := jose.SigningKey{Algorithm: jose.RS256, Key: privateKey} - var signerOpts = jose.SignerOptions{} + signerOpts := jose.SignerOptions{} signerOpts.WithType("dpop+jwt") signerOpts.WithHeader("jwk", set) rsaSigner, err := jose.NewSigner(key, &signerOpts) diff --git a/okta/main_test.go b/okta/main_test.go index 83260b7f..03fdcd6c 100644 --- a/okta/main_test.go +++ b/okta/main_test.go @@ -15,7 +15,7 @@ func init() { if err != nil { fmt.Printf("Create new config should not be error %v", err) } - configuration.Debug = true + configuration.Debug = false apiClient = NewAPIClient(configuration) }