Skip to content

Commit

Permalink
Fix request context cancel race when reading response (#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
eamonnotoole authored Nov 13, 2024
1 parent 51d5120 commit 62299cb
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 45 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/trivy.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 Hewlett Packard Enterprise Development LP
# Copyright 2022-2024 Hewlett Packard Enterprise Development LP
name: Trivy
on:
pull_request:
Expand All @@ -11,7 +11,7 @@ jobs:
uses: actions/checkout@v2

- name: Run Trivy vulnerability scanner (go.mod)
uses: aquasecurity/trivy-action@master
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: 'fs'
hide-progress: false
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module github.com/hewlettpackard/hpegl-provider-lib

go 1.21
go 1.22.1

toolchain go1.22.5

require (
Expand Down
2 changes: 1 addition & 1 deletion pkg/token/httpclient/httpclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type Client struct {

// New creates a new identity Client object
func New(identityServiceURL string, vendedServiceClient bool, passedInToken string) *Client {
client := &http.Client{Timeout: 10 * time.Second}
client := &http.Client{Timeout: 120 * time.Second}
identityServiceURL = strings.TrimRight(identityServiceURL, "/")

return &Client{
Expand Down
38 changes: 28 additions & 10 deletions pkg/token/identitytoken/identitytoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,27 @@ func GenerateToken(
return "", err
}

resp, err := tokenutil.DoRetries(ctx, func(reqCtx context.Context) (*http.Request, *http.Response, error) {
req, errReq := http.NewRequestWithContext(reqCtx, http.MethodPost, url, strings.NewReader(string(b)))
if errReq != nil {
return nil, nil, errReq
}
req.Header.Set("Content-Type", "application/json")
resp, errResp := httpClient.Do(req)

return req, resp, errResp
}, retryLimit)
// Create a slice of cancel functions to be returned by the retries
cancelFuncs := make([]context.CancelFunc, 0)

resp, err := tokenutil.DoRetries(
ctx,
&cancelFuncs,
func(reqCtx context.Context) (*http.Request, *http.Response, error) {
req, errReq := http.NewRequestWithContext(reqCtx, http.MethodPost, url, strings.NewReader(string(b)))
if errReq != nil {
return nil, nil, errReq
}
req.Header.Set("Content-Type", "application/json")
respFromDo, errResp := httpClient.Do(req)

return req, respFromDo, errResp
},
retryLimit,
)
// Defer execution of cancelFuncs
defer executeCancelFuncs(&cancelFuncs)

if err != nil {
return "", err
}
Expand All @@ -91,3 +102,10 @@ func GenerateToken(

return token.AccessToken, nil
}

// executeCancelFuncs executes all cancel functions in the slice
func executeCancelFuncs(cancelFuncs *[]context.CancelFunc) {
for _, cancel := range *cancelFuncs {
cancel()
}
}
46 changes: 32 additions & 14 deletions pkg/token/issuertoken/issuertoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,32 @@ func GenerateToken(
return "", err
}

// Create a slice of cancel functions to be returned by the retries
cancelFuncs := make([]context.CancelFunc, 0)

// Execute the request, with retries
resp, err := tokenutil.DoRetries(ctx, func(reqCtx context.Context) (*http.Request, *http.Response, error) {
// Create the request
req, errReq := createRequest(reqCtx, params, clientURL)
if errReq != nil {
return nil, nil, errReq
}
// Close the request after use, i.e. don't reuse the TCP connection
req.Close = true

// Execute the request
resp, errResp := httpClient.Do(req)

return req, resp, errResp
}, retryLimit)
resp, err := tokenutil.DoRetries(
ctx,
&cancelFuncs,
func(reqCtx context.Context) (*http.Request, *http.Response, error) {
// Create the request
req, errReq := createRequest(reqCtx, params, clientURL)
if errReq != nil {
return nil, nil, errReq
}
// Close the request after use, i.e. don't reuse the TCP connection
req.Close = true

// Execute the request
respFromDo, errResp := httpClient.Do(req)

return req, respFromDo, errResp
},
retryLimit,
)
// Defer execution of cancel functions
defer executeCancelFuncs(&cancelFuncs)

if err != nil {
return "", err
}
Expand All @@ -80,6 +91,13 @@ func GenerateToken(
return token.AccessToken, nil
}

// executeCancelFuncs executes all cancel functions in the slice
func executeCancelFuncs(cancelFuncs *[]context.CancelFunc) {
for _, cancel := range *cancelFuncs {
cancel()
}
}

// createRequest creates a new http request
func createRequest(ctx context.Context, params url.Values, clientURL string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, clientURL, strings.NewReader(params.Encode()))
Expand Down
11 changes: 9 additions & 2 deletions pkg/token/token-util/token-util.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,12 @@ func DecodeAccessToken(rawToken string) (Token, error) {
return token, nil
}

func DoRetries(ctx context.Context, call func(ctx context.Context) (*http.Request, *http.Response, error), retries int) (*http.Response, error) {
func DoRetries(
ctx context.Context,
cancelFuncs *[]context.CancelFunc,
call func(ctx context.Context) (*http.Request, *http.Response, error),
retries int,
) (*http.Response, error) {
var req *http.Request
var resp *http.Response
var err error
Expand All @@ -91,7 +96,9 @@ func DoRetries(ctx context.Context, call func(ctx context.Context) (*http.Reques

// Create a new context with a timeout
ctxWithTimeout, cancel := createContextWithTimeout(ctx)
defer cancel()

// Add the cancel function to the list of cancel functions
*cancelFuncs = append(*cancelFuncs, cancel)

// Execute the request
req, resp, err = call(ctxWithTimeout)
Expand Down
24 changes: 9 additions & 15 deletions pkg/token/token-util/token-util_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// (C) Copyright 2021 Hewlett Packard Enterprise Development LP
// (C) Copyright 2021-2024 Hewlett Packard Enterprise Development LP

package tokenutil

Expand Down Expand Up @@ -202,17 +202,6 @@ func TestDoRetries(t *testing.T) {
},
err: errLimitExceeded,
},
{
name: "Context cancelled",
ctx: context.Background(),
call: func(ctx context.Context) (*http.Request, *http.Response, error) {
req := &http.Request{}
req = req.WithContext(ctx)

return req, nil, context.Canceled
},
err: context.Canceled,
},
{
name: "no url",
ctx: context.Background(),
Expand All @@ -226,13 +215,16 @@ func TestDoRetries(t *testing.T) {
for _, testcase := range testcases {
tc := testcase
t.Run(tc.name, func(t *testing.T) {
resp, err := DoRetries(tc.ctx, tc.call, 1) // nolint: bodyclose
cancelFuncs := make([]context.CancelFunc, 0)
resp, err := DoRetries(tc.ctx, &cancelFuncs, tc.call, 2) // nolint: bodyclose
if tc.err != nil {
assert.EqualError(t, err, tc.err.Error())
if tc.err == errLimitExceeded {
assert.Equal(t, 1, totalRetries)
assert.Equal(t, 2, totalRetries)
assert.Equal(t, 2, len(cancelFuncs))
} else {
assert.Equal(t, 0, totalRetries)
assert.Equal(t, 1, len(cancelFuncs))
}

totalRetries = 0
Expand All @@ -243,8 +235,10 @@ func TestDoRetries(t *testing.T) {
// only 429, 500 and 502 status codes should retry
if tc.responseStatus == http.StatusForbidden {
assert.Equal(t, 0, totalRetries)
assert.Equal(t, 1, len(cancelFuncs))
} else {
assert.Equal(t, 1, totalRetries)
assert.Equal(t, 2, totalRetries)
assert.Equal(t, 2, len(cancelFuncs))
}

totalRetries = 0
Expand Down

0 comments on commit 62299cb

Please sign in to comment.