Skip to content

Commit

Permalink
update oauth2 method, adding id to client assertion
Browse files Browse the repository at this point in the history
Signed-off-by: Tien Nguyen <[email protected]>
  • Loading branch information
duytiennguyen-okta committed Dec 5, 2024
1 parent 964f066 commit 3a618dd
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 244 deletions.
86 changes: 0 additions & 86 deletions .generator/templates/cache_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
package okta

import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"

"github.com/jarcoal/httpmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_Create_Cache_Key(t *testing.T) {
Expand Down Expand Up @@ -74,86 +71,3 @@ func Test_Cache_Cleared_Successful(t *testing.T) {
found = myCache.Has(cacheKey)
assert.False(t, found, "cache was not cleared")
}

func TestOAuthTokensAlwaysCached(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
WithCache(false)
cfg, err := NewConfiguration(
WithCache(false),
WithOrgUrl("https://testing.oktapreview.com"),
WithAuthorizationMode("PrivateKey"),
WithClientId("abc"),
WithPrivateKey(`
-----BEGIN RSA PRIVATE KEY-----
MIIBOgIBAAJBAKj34GkxFhD90vcNLYLInFEX6Ppy1tPf9Cnzj4p4WGeKLs1Pt8Qu
KUpRKfFLfRYC9AIKjbJTWit+CqvjWYzvQwECAwEAAQJAIJLixBy2qpFoS4DSmoEm
o3qGy0t6z09AIJtH+5OeRV1be+N4cDYJKffGzDa88vQENZiRm0GRq6a+HPGQMd2k
TQIhAKMSvzIBnni7ot/OSie2TmJLY4SwTQAevXysE2RbFDYdAiEBCUEaRQnMnbp7
9mxDXDf6AU0cN/RPBjb9qSHDcWZHGzUCIG2Es59z8ugGrDY+pxLQnwfotadxd+Uy
v/Ow5T0q5gIJAiEAyS4RaI9YG8EWx/2w0T67ZUVAw8eOMB6BIUg0Xcu+3okCIBOs
/5OiPgoTdSy7bcF9IGpSE8ZgGKzgYQVZeN97YE00
-----END RSA PRIVATE KEY-----
`),
WithScopes(([]string{"okta.users.read"})),
)
require.NoError(t, err, "Creating a new config should not error")

client := NewAPIClient(cfg)

accessToken := RequestAccessToken{
TokenType: "Bearer",
ExpiresIn: 3600,
AccessToken: "xyz",
Scope: "okta.users.read",
}
httpmockTokenURLRegex := `=~^https://testing\.oktapreview\.com/oauth2/v1/token\?client_assertion=.*\z`
jsonResp, err := httpmock.NewJsonResponder(200, accessToken)
require.NoError(t, err)
httpmock.RegisterResponder("POST", httpmockTokenURLRegex, jsonResp)

adminConsole := Application{}
adminConsole.SetId("abc123")
adminConsole.SetStatus("ACTIVE")
adminConsole.SetLabel("Okta Admin Console")
apps1 := []*Application{
&adminConsole,
}
jsonResp, err = httpmock.NewJsonResponder(200, apps1)
require.NoError(t, err)
httpmockAdminConsoleRegex := `=~^https://testing\.oktapreview\.com/api/v1/apps?.*q\=Okta\+Admin\+Console.*\z`
httpmock.RegisterResponder("GET", httpmockAdminConsoleRegex, jsonResp)

dashboard := Application{}
adminConsole.SetId("def456")
adminConsole.SetStatus("ACTIVE")
adminConsole.SetLabel("Okta Dashboard")
apps2 := []*Application{
&dashboard,
}
jsonResp, err = httpmock.NewJsonResponder(200, apps2)
require.NoError(t, err)
httpmockDashboardRegex := `=~^https://testing\.oktapreview\.com/api/v1/apps?.*q\=Okta\+Dashboard.*\z`
httpmock.RegisterResponder("GET", httpmockDashboardRegex, jsonResp)

_, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Admin Console").Execute()
require.NoError(t, err)
_, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Admin Console").Execute()
require.NoError(t, err)

_, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Dashboard").Execute()
require.NoError(t, err)
_, _, err = client.ApplicationAPI.ListApplications(cfg.Context).Limit(1).Filter("status eq ACTIVE").Q("Okta Dashboard").Execute()
require.NoError(t, err)

info := httpmock.GetCallCountInfo()
totalCalls := httpmock.GetTotalCallCount()

assert.Equal(t, 5, totalCalls, fmt.Sprintf("there should only be 5 API calls in this test but there were %d calls", totalCalls))
// Tokens from requests should be cached.
require.True(t, info[fmt.Sprintf("POST %s", httpmockTokenURLRegex)] == 1, "tokens endpoint should only be called once")

// But all other requests should not be cached.
require.True(t, info[fmt.Sprintf("GET %s", httpmockAdminConsoleRegex)] == 2)
require.True(t, info[fmt.Sprintf("GET %s", httpmockDashboardRegex)] == 2)
}
77 changes: 47 additions & 30 deletions .generator/templates/client.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ var (
)

const (
VERSION = "{{{packageVersion}}}"
VERSION = "{{{packageVersion}}}"
AccessTokenCacheKey = "OKTA_ACCESS_TOKEN"
DpopAccessTokenNonce = "DPOP_OKTA_ACCESS_TOKEN_NONCE"
DpopAccessTokenPrivateKey = "DPOP_OKTA_ACCESS_TOKEN_PRIVATE_KEY"
Expand All @@ -59,9 +59,9 @@ const (
// APIClient manages communication with the {{appName}} API v{{version}}
// In most cases there should be only one, shared, APIClient.
type APIClient struct {
cfg *Configuration
common service // Reuse a single struct instead of allocating one for each service on the heap.
cache Cache
cfg *Configuration
common service // Reuse a single struct instead of allocating one for each service on the heap.
cache Cache
tokenCache *goCache.Cache
freshcache bool
Expand Down Expand Up @@ -196,7 +196,7 @@ func (a *PrivateKeyAuth) Authorize(method, URL string) error {
return err
}

accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff)
accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff, a.clientId, a.privateKeySigner)
if err != nil {
return err
}
Expand Down Expand Up @@ -287,7 +287,7 @@ func (a *JWTAuth) Authorize(method, URL string) error {
}
}
} else {
accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, a.clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff)
accessToken, nonce, privateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, a.clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff, "", nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -408,7 +408,7 @@ func (a *JWKAuth) Authorize(method, URL string) error {
return err
}

accessToken, nonce, dpopPrivateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff)
accessToken, nonce, dpopPrivateKey, err := getAccessTokenForPrivateKey(a.httpClient, a.orgURL, clientAssertion, a.userAgent, a.scopes, a.maxRetries, a.maxBackoff, "", nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -446,16 +446,16 @@ func convertJWKToPrivateKey(jwks, encryptionType string) (string, error) {
pair := it.Pair()
key := pair.Value.(jwk.Key)
var rawkey interface{} // This is the raw key, like *rsa.PrivateKey or *ecdsa.PrivateKey
err := key.Raw(&rawkey);
err := key.Raw(&rawkey)
if err != nil {
return "",err
return "", err
}

switch encryptionType {
case "RSA":
rsaPrivateKey, ok := rawkey.(*rsa.PrivateKey)
if !ok {
return "",fmt.Errorf("expected rsa key, got %T", rawkey)
return "", fmt.Errorf("expected rsa key, got %T", rawkey)
}
return string(privateKeyToBytes(rsaPrivateKey)), nil
default:
Expand Down Expand Up @@ -514,26 +514,25 @@ func createClientAssertion(orgURL, clientID string, privateKeySinger jose.Signer
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * time.Duration(1))),
Issuer: clientID,
Audience: orgURL + "/oauth2/v1/token",
ID: uuid.New().String(),
}
jwtBuilder := jwt.Signed(privateKeySinger).Claims(claims)
return jwtBuilder.CompactSerialize()
}

func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertion, userAgent string, scopes []string, maxRetries int32, maxBackoff int64) (*RequestAccessToken, string, *rsa.PrivateKey, error) {
var tokenRequestBuff io.ReadWriter
func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertion, userAgent string, scopes []string, maxRetries int32, maxBackoff int64, clientID string, signer jose.Signer) (*RequestAccessToken, string, *rsa.PrivateKey, error) {
query := url.Values{}
tokenRequestURL := orgURL + "/oauth2/v1/token"

query.Add("grant_type", "client_credentials")
query.Add("scope", strings.Join(scopes, " "))
query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
query.Add("client_assertion", clientAssertion)
tokenRequestURL += "?" + query.Encode()
tokenRequest, err := http.NewRequest("POST", tokenRequestURL, tokenRequestBuff)

tokenRequest, err := http.NewRequest("POST", tokenRequestURL, strings.NewReader(query.Encode()))
if err != nil {
return nil, "", nil, err
}

tokenRequest.Header.Add("Accept", "application/json")
tokenRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded")
tokenRequest.Header.Add("User-Agent", userAgent)
Expand All @@ -552,14 +551,20 @@ func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertio
if err != nil {
return nil, "", nil, err
}

respBody, err := io.ReadAll(tokenResponse.Body)
origResp := io.NopCloser(bytes.NewBuffer(respBody))
tokenResponse.Body = origResp
var accessToken *RequestAccessToken

newClientAssertion, err := createClientAssertion(orgURL, clientID, signer)
if err != nil {
return nil, "", nil, err
}

if tokenResponse.StatusCode >= 300 {
if strings.Contains(string(respBody), "invalid_dpop_proof") {
return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, "", maxRetries, maxBackoff)
return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, "", maxRetries, maxBackoff, newClientAssertion, strings.Join(scopes, " "), clientID, signer)
} else {
return nil, "", nil, err
}
Expand All @@ -572,7 +577,7 @@ func getAccessTokenForPrivateKey(httpClient *http.Client, orgURL, clientAssertio
return accessToken, "", nil, nil
}

func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *http.Client, orgURL, nonce string, maxRetries int32, maxBackoff int64) (*RequestAccessToken, string, *rsa.PrivateKey, error) {
func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *http.Client, orgURL, nonce string, maxRetries int32, maxBackoff int64, clientAssertion string, scopes string, clientID string, signer jose.Signer) (*RequestAccessToken, string, *rsa.PrivateKey, error) {
privateKey, err := generatePrivateKey(2048)
if err != nil {
return nil, "", nil, err
Expand All @@ -581,7 +586,19 @@ func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *htt
if err != nil {
return nil, "", nil, err
}
newClientAssertion, err := createClientAssertion(orgURL, clientID, signer)
if err != nil {
return nil, "", nil, err
}

query := url.Values{}
query.Add("grant_type", "client_credentials")
query.Add("scope", scopes)
query.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
query.Add("client_assertion", newClientAssertion)
tokenRequest.Body = io.NopCloser(strings.NewReader(query.Encode()))
tokenRequest.Header.Set("DPoP", dpopJWT)

bOff := &oktaBackoff{
ctx: context.TODO(),
maxRetries: maxRetries,
Expand All @@ -603,9 +620,9 @@ func getAccessTokenForDpopPrivateKey(tokenRequest *http.Request, httpClient *htt
}

if tokenResponse.StatusCode >= 300 {
if strings.Contains(string(respBody), "use_dpop_nonce") {
if strings.Contains(string(respBody), "use_dpop_nonce") {
newNonce := tokenResponse.Header.Get("Dpop-Nonce")
return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, newNonce, maxRetries, maxBackoff)
return getAccessTokenForDpopPrivateKey(tokenRequest, httpClient, orgURL, newNonce, maxRetries, maxBackoff, clientAssertion, scopes, clientID, signer)
} else {
return nil, "", nil, err
}
Expand Down Expand Up @@ -780,9 +797,9 @@ func (c *APIClient) GetConfig() *Configuration {
}

type formFile struct {
fileBytes []byte
fileName string
formFileName string
fileBytes []byte
fileName string
formFileName string
}

// prepareRequest build the request
Expand Down Expand Up @@ -836,11 +853,11 @@ func (c *APIClient) prepareRequest(
w.Boundary()
part, err := w.CreateFormFile(formFile.formFileName, filepath.Base(formFile.fileName))
if err != nil {
return nil, err
return nil, err
}
_, err = part.Write(formFile.fileBytes)
if err != nil {
return nil, err
return nil, err
}
}
}
Expand Down Expand Up @@ -879,7 +896,7 @@ func (c *APIClient) prepareRequest(
URL.Scheme = c.cfg.Scheme
}

var urlWithoutQuery = *URL
urlWithoutQuery := *URL

// Adding Query Param
query := URL.Query()
Expand Down Expand Up @@ -1103,7 +1120,7 @@ func (c *APIClient) RefreshNext() *APIClient {
return c
}

func (c *APIClient) do(ctx context.Context, req *http.Request)(*http.Response, error){
func (c *APIClient) do(ctx context.Context, req *http.Request) (*http.Response, error) {
cacheKey := CreateCacheKey(req)
if req.Method != http.MethodGet {
c.cache.Delete(cacheKey)
Expand Down Expand Up @@ -1343,9 +1360,9 @@ func (e GenericOpenAPIError) Model() interface{} {

// Okta Backoff
type oktaBackoff struct {
retryCount, maxRetries int32
backoffDuration time.Duration
ctx context.Context
retryCount, maxRetries int32
backoffDuration time.Duration
ctx context.Context
}

// NextBackOff returns the duration to wait before retrying the operation,
Expand Down Expand Up @@ -1456,7 +1473,7 @@ func generateDpopJWT(privateKey *rsa.PrivateKey, httpMethod, URL, nonce, accessT
return "", err
}
key := jose.SigningKey{Algorithm: jose.RS256, Key: privateKey}
var signerOpts = jose.SignerOptions{}
signerOpts := jose.SignerOptions{}
signerOpts.WithType("dpop+jwt")
signerOpts.WithHeader("jwk", set)
rsaSigner, err := jose.NewSigner(key, &signerOpts)
Expand Down
Loading

0 comments on commit 3a618dd

Please sign in to comment.