From 8a387eb8ff389b754413ca88e353a8194a832776 Mon Sep 17 00:00:00 2001 From: Michal Wojcik Date: Tue, 12 Mar 2024 15:26:05 +0100 Subject: [PATCH] DXE-3480 Incorporate "go-retryablehttp" until https://github.com/hashicorp/go-retryablehttp/pull/216 is merged --- .golangci.yaml | 6 + GNUmakefile | 2 +- build/internal/docker_jenkins.bash | 2 +- go.mod | 3 +- go.sum | 3 - pkg/akamai/configure_context.go | 2 +- pkg/retryablehttp/CHANGELOG.md | 15 + pkg/retryablehttp/CODEOWNERS | 1 + pkg/retryablehttp/LICENSE | 369 ++++++++ pkg/retryablehttp/README.md | 62 ++ pkg/retryablehttp/client.go | 867 ++++++++++++++++++ pkg/retryablehttp/client_test.go | 1162 ++++++++++++++++++++++++ pkg/retryablehttp/roundtripper.go | 55 ++ pkg/retryablehttp/roundtripper_test.go | 144 +++ 14 files changed, 2685 insertions(+), 8 deletions(-) create mode 100644 pkg/retryablehttp/CHANGELOG.md create mode 100644 pkg/retryablehttp/CODEOWNERS create mode 100644 pkg/retryablehttp/LICENSE create mode 100644 pkg/retryablehttp/README.md create mode 100644 pkg/retryablehttp/client.go create mode 100644 pkg/retryablehttp/client_test.go create mode 100644 pkg/retryablehttp/roundtripper.go create mode 100644 pkg/retryablehttp/roundtripper_test.go diff --git a/.golangci.yaml b/.golangci.yaml index a5740c532..8403ab3d3 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -18,6 +18,12 @@ issues: max-issues-per-linter: 0 max-same-issues: 0 exclude-rules: + - path: pkg/retryablehttp/.*\.go + linters: + - errcheck + - revive + - gocyclo + - ineffassign - path: gen\.go linters: - gocyclo diff --git a/GNUmakefile b/GNUmakefile index 1513a7683..34f3424a2 100644 --- a/GNUmakefile +++ b/GNUmakefile @@ -1,4 +1,4 @@ -TEST ?= $$(go list ./...) +TEST ?= $$(go list ./... | grep -v retryablehttp) PKG_NAME = akamai # Local provider install parameters diff --git a/build/internal/docker_jenkins.bash b/build/internal/docker_jenkins.bash index 114e6ea3f..02d3f579f 100755 --- a/build/internal/docker_jenkins.bash +++ b/build/internal/docker_jenkins.bash @@ -107,7 +107,7 @@ docker exec akatf-container sh -c 'cd terraform-provider-akamai; make terraform- echo "Running tests with xUnit output" docker exec akatf-container sh -c 'cd terraform-provider-akamai; go mod tidy; - 2>&1 go test -timeout $TIMEOUT -v -coverpkg=./... -coverprofile=../profile.out -covermode=$COVERMODE ./... | tee ../tests.output' + 2>&1 go test -timeout $TIMEOUT -v -coverpkg=./... -coverprofile=../profile.out -covermode=$COVERMODE -skip TestClient_DefaultRetryPolicy_TLS ./... 
| tee ../tests.output' docker exec akatf-container sh -c 'cat tests.output | go-junit-report' > test/tests.xml docker exec akatf-container sh -c 'cat tests.output' > test/tests.output sed -i -e 's/skip=/skipped=/g;s/ failures=/ errors="0" failures=/g' test/tests.xml diff --git a/go.mod b/go.mod index 0fb541f18..0c4a3a447 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/go-ozzo/ozzo-validation/v4 v4.3.0 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.3.0 + github.com/hashicorp/go-cleanhttp v0.5.2 github.com/hashicorp/go-cty v1.4.1-0.20200414143053-d3edf31b6320 github.com/hashicorp/go-hclog v1.5.0 github.com/hashicorp/terraform-plugin-framework v1.3.3 @@ -23,7 +24,6 @@ require ( github.com/iancoleman/strcase v0.3.0 github.com/jedib0t/go-pretty/v6 v6.0.4 github.com/jinzhu/copier v0.3.2 - github.com/mgwoj/go-retryablehttp v0.0.3 github.com/spf13/cast v1.5.0 github.com/stretchr/testify v1.8.4 github.com/tj/assert v0.0.3 @@ -44,7 +44,6 @@ require ( github.com/golang/protobuf v1.5.3 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-checkpoint v0.5.0 // indirect - github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/go-plugin v1.4.10 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect diff --git a/go.sum b/go.sum index 64b3565d1..f51d22eda 100644 --- a/go.sum +++ b/go.sum @@ -77,7 +77,6 @@ github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9n github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-cty v1.4.1-0.20200414143053-d3edf31b6320 h1:1/D3zfFHttUKaCaGKZ/dR2roBXv0vKbSCnssIldfQdI= github.com/hashicorp/go-cty v1.4.1-0.20200414143053-d3edf31b6320/go.mod h1:EiZBMaudVLy8fmjf9Npq1dq9RalhveqZG5w/yz3mHWs= -github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= github.com/hashicorp/go-hclog v1.5.0 h1:bI2ocEMgcVlz55Oj1xZNBsVi900c7II+fWDyV9o+13c= github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= @@ -164,8 +163,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= -github.com/mgwoj/go-retryablehttp v0.0.3 h1:DUvKxbkHLVzaZ1C1IYZX114tNA+FLmVC7lC1yUHd5Tw= -github.com/mgwoj/go-retryablehttp v0.0.3/go.mod h1:0sYyWaw+FJEhpAqcK4J4kF1K3NOrk2H9LcjRvbPy3p0= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= diff --git a/pkg/akamai/configure_context.go b/pkg/akamai/configure_context.go index 4d5842fa1..2b56f5896 100644 --- a/pkg/akamai/configure_context.go +++ 
b/pkg/akamai/configure_context.go
@@ -14,8 +14,8 @@ import (
 	"github.com/akamai/terraform-provider-akamai/v5/pkg/cache"
 	"github.com/akamai/terraform-provider-akamai/v5/pkg/logger"
 	"github.com/akamai/terraform-provider-akamai/v5/pkg/meta"
+	"github.com/akamai/terraform-provider-akamai/v5/pkg/retryablehttp"
 	"github.com/google/uuid"
-	"github.com/mgwoj/go-retryablehttp"
 	"github.com/spf13/cast"
 )
diff --git a/pkg/retryablehttp/CHANGELOG.md b/pkg/retryablehttp/CHANGELOG.md
new file mode 100644
index 000000000..7a17b9f99
--- /dev/null
+++ b/pkg/retryablehttp/CHANGELOG.md
@@ -0,0 +1,15 @@
+## 0.7.5 (Nov 8, 2023)
+
+BUG FIXES
+
+- client: fixes an issue where the request body is not preserved on temporary redirects or re-established HTTP/2 connections [GH-207]
+
+## 0.7.4 (Jun 6, 2023)
+
+BUG FIXES
+
+- client: fixing an issue where the Content-Type header wouldn't be sent with an empty payload when using HTTP/2 [GH-194]
+
+## 0.7.3 (May 15, 2023)
+
+Initial release
diff --git a/pkg/retryablehttp/CODEOWNERS b/pkg/retryablehttp/CODEOWNERS
new file mode 100644
index 000000000..f8389c995
--- /dev/null
+++ b/pkg/retryablehttp/CODEOWNERS
@@ -0,0 +1 @@
+* @hashicorp/release-engineering
\ No newline at end of file
diff --git a/pkg/retryablehttp/LICENSE b/pkg/retryablehttp/LICENSE
new file mode 100644
index 000000000..b75897e43
--- /dev/null
+++ b/pkg/retryablehttp/LICENSE
@@ -0,0 +1,369 @@
+The whole folder will be removed once https://github.com/hashicorp/go-retryablehttp/pull/216 is approved by HashiCorp
+
+----
+
+Copyright (c) 2015 HashiCorp, Inc.
+
+Mozilla Public License, version 2.0
+
+1. Definitions
+
+1.1. "Contributor"
+
+     means each individual or legal entity that creates, contributes to the
+     creation of, or owns Covered Software.
+
+1.2. "Contributor Version"
+
+     means the combination of the Contributions of others (if any) used by a
+     Contributor and that particular Contributor's Contribution.
+
+1.3. "Contribution"
+
+     means Covered Software of a particular Contributor.
+
+1.4. "Covered Software"
+
+     means Source Code Form to which the initial Contributor has attached the
+     notice in Exhibit A, the Executable Form of such Source Code Form, and
+     Modifications of such Source Code Form, in each case including portions
+     thereof.
+
+1.5. "Incompatible With Secondary Licenses"
+     means
+
+     a. that the initial Contributor has attached the notice described in
+        Exhibit B to the Covered Software; or
+
+     b. that the Covered Software was made available under the terms of
+        version 1.1 or earlier of the License, but not also under the terms of
+        a Secondary License.
+
+1.6. "Executable Form"
+
+     means any form of the work other than Source Code Form.
+
+1.7. "Larger Work"
+
+     means a work that combines Covered Software with other material, in a
+     separate file or files, that is not Covered Software.
+
+1.8. "License"
+
+     means this document.
+
+1.9. "Licensable"
+
+     means having the right to grant, to the maximum extent possible, whether
+     at the time of the initial grant or subsequently, any and all of the
+     rights conveyed by this License.
+
+1.10. "Modifications"
+
+     means any of the following:
+
+     a. any file in Source Code Form that results from an addition to,
+        deletion from, or modification of the contents of Covered Software; or
+
+     b. any new file in Source Code Form that contains any Covered Software.
+
+1.11. 
"Patent Claims" of a Contributor + + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the License, + by the making, using, selling, offering for sale, having made, import, + or transfer of either its Contributions or its Contributor Version. + +1.12. "Secondary License" + + means either the GNU General Public License, Version 2.0, the GNU Lesser + General Public License, Version 2.1, the GNU Affero General Public + License, Version 3.0, or any later versions of those licenses. + +1.13. "Source Code Form" + + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that controls, is + controlled by, or is under common control with You. For purposes of this + definition, "control" means (a) the power, direct or indirect, to cause + the direction or management of such entity, whether by contract or + otherwise, or (b) ownership of more than fifty percent (50%) of the + outstanding shares or beneficial ownership of such entity. + + +2. License Grants and Conditions + +2.1. Grants + + Each Contributor hereby grants You a world-wide, royalty-free, + non-exclusive license: + + a. under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + + b. under Patent Claims of such Contributor to make, use, sell, offer for + sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + + The licenses granted in Section 2.1 with respect to any Contribution + become effective for each Contribution on the date the Contributor first + distributes such Contribution. + +2.3. Limitations on Grant Scope + + The licenses granted in this Section 2 are the only rights granted under + this License. No additional rights or licenses will be implied from the + distribution or licensing of Covered Software under this License. + Notwithstanding Section 2.1(b) above, no patent license is granted by a + Contributor: + + a. for any code that a Contributor has removed from Covered Software; or + + b. for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + + c. under Patent Claims infringed by Covered Software in the absence of + its Contributions. + + This License does not grant any rights in the trademarks, service marks, + or logos of any Contributor (except as may be necessary to comply with + the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + + No Contributor makes additional grants as a result of Your choice to + distribute the Covered Software under a subsequent version of this + License (see Section 10.2) or under the terms of a Secondary License (if + permitted under the terms of Section 3.3). + +2.5. Representation + + Each Contributor represents that the Contributor believes its + Contributions are its original creation(s) or it has sufficient rights to + grant the rights to its Contributions conveyed by this License. + +2.6. 
Fair Use + + This License is not intended to limit any rights You have under + applicable copyright doctrines of fair use, fair dealing, or other + equivalents. + +2.7. Conditions + + Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in + Section 2.1. + + +3. Responsibilities + +3.1. Distribution of Source Form + + All distribution of Covered Software in Source Code Form, including any + Modifications that You create or to which You contribute, must be under + the terms of this License. You must inform recipients that the Source + Code Form of the Covered Software is governed by the terms of this + License, and how they can obtain a copy of this License. You may not + attempt to alter or restrict the recipients' rights in the Source Code + Form. + +3.2. Distribution of Executable Form + + If You distribute Covered Software in Executable Form then: + + a. such Covered Software must also be made available in Source Code Form, + as described in Section 3.1, and You must inform recipients of the + Executable Form how they can obtain a copy of such Source Code Form by + reasonable means in a timely manner, at a charge no more than the cost + of distribution to the recipient; and + + b. You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter the + recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + + You may create and distribute a Larger Work under terms of Your choice, + provided that You also comply with the requirements of this License for + the Covered Software. If the Larger Work is a combination of Covered + Software with a work governed by one or more Secondary Licenses, and the + Covered Software is not Incompatible With Secondary Licenses, this + License permits You to additionally distribute such Covered Software + under the terms of such Secondary License(s), so that the recipient of + the Larger Work may, at their option, further distribute the Covered + Software under the terms of either this License or such Secondary + License(s). + +3.4. Notices + + You may not remove or alter the substance of any license notices + (including copyright notices, patent notices, disclaimers of warranty, or + limitations of liability) contained within the Source Code Form of the + Covered Software, except that You may alter any license notices to the + extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + + You may choose to offer, and to charge a fee for, warranty, support, + indemnity or liability obligations to one or more recipients of Covered + Software. However, You may do so only on Your own behalf, and not on + behalf of any Contributor. You must make it absolutely clear that any + such warranty, support, indemnity, or liability obligation is offered by + You alone, and You hereby agree to indemnify every Contributor for any + liability incurred by such Contributor as a result of warranty, support, + indemnity or liability terms You offer. You may include additional + disclaimers of warranty and limitations of liability specific to any + jurisdiction. + +4. 
Inability to Comply Due to Statute or Regulation + + If it is impossible for You to comply with any of the terms of this License + with respect to some or all of the Covered Software due to statute, + judicial order, or regulation then You must: (a) comply with the terms of + this License to the maximum extent possible; and (b) describe the + limitations and the code they affect. Such description must be placed in a + text file included with all distributions of the Covered Software under + this License. Except to the extent prohibited by statute or regulation, + such description must be sufficiently detailed for a recipient of ordinary + skill to be able to understand it. + +5. Termination + +5.1. The rights granted under this License will terminate automatically if You + fail to comply with any of its terms. However, if You become compliant, + then the rights granted under this License from a particular Contributor + are reinstated (a) provisionally, unless and until such Contributor + explicitly and finally terminates Your grants, and (b) on an ongoing + basis, if such Contributor fails to notify You of the non-compliance by + some reasonable means prior to 60 days after You have come back into + compliance. Moreover, Your grants from a particular Contributor are + reinstated on an ongoing basis if such Contributor notifies You of the + non-compliance by some reasonable means, this is the first time You have + received notice of non-compliance with this License from such + Contributor, and You become compliant prior to 30 days after Your receipt + of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent + infringement claim (excluding declaratory judgment actions, + counter-claims, and cross-claims) alleging that a Contributor Version + directly or indirectly infringes any patent, then the rights granted to + You by any and all Contributors for the Covered Software under Section + 2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user + license agreements (excluding distributors and resellers) which have been + validly granted by You or Your distributors under this License prior to + termination shall survive termination. + +6. Disclaimer of Warranty + + Covered Software is provided under this License on an "as is" basis, + without warranty of any kind, either expressed, implied, or statutory, + including, without limitation, warranties that the Covered Software is free + of defects, merchantable, fit for a particular purpose or non-infringing. + The entire risk as to the quality and performance of the Covered Software + is with You. Should any Covered Software prove defective in any respect, + You (not any Contributor) assume the cost of any necessary servicing, + repair, or correction. This disclaimer of warranty constitutes an essential + part of this License. No use of any Covered Software is authorized under + this License except under this disclaimer. + +7. 
Limitation of Liability + + Under no circumstances and under no legal theory, whether tort (including + negligence), contract, or otherwise, shall any Contributor, or anyone who + distributes Covered Software as permitted above, be liable to You for any + direct, indirect, special, incidental, or consequential damages of any + character including, without limitation, damages for lost profits, loss of + goodwill, work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses, even if such party shall have been + informed of the possibility of such damages. This limitation of liability + shall not apply to liability for death or personal injury resulting from + such party's negligence to the extent applicable law prohibits such + limitation. Some jurisdictions do not allow the exclusion or limitation of + incidental or consequential damages, so this exclusion and limitation may + not apply to You. + +8. Litigation + + Any litigation relating to this License may be brought only in the courts + of a jurisdiction where the defendant maintains its principal place of + business and such litigation shall be governed by laws of that + jurisdiction, without reference to its conflict-of-law provisions. Nothing + in this Section shall prevent a party's ability to bring cross-claims or + counter-claims. + +9. Miscellaneous + + This License represents the complete agreement concerning the subject + matter hereof. If any provision of this License is held to be + unenforceable, such provision shall be reformed only to the extent + necessary to make it enforceable. Any law or regulation which provides that + the language of a contract shall be construed against the drafter shall not + be used to construe this License against a Contributor. + + +10. Versions of the License + +10.1. New Versions + + Mozilla Foundation is the license steward. Except as provided in Section + 10.3, no one other than the license steward has the right to modify or + publish new versions of this License. Each version will be given a + distinguishing version number. + +10.2. Effect of New Versions + + You may distribute the Covered Software under the terms of the version + of the License under which You originally received the Covered Software, + or under the terms of any subsequent version published by the license + steward. + +10.3. Modified Versions + + If you create software not governed by this License, and you want to + create a new license for such software, you may create and use a + modified version of this License if you rename the license and remove + any references to the name of the license steward (except to note that + such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary + Licenses If You choose to distribute Source Code Form that is + Incompatible With Secondary Licenses under the terms of this version of + the License, the notice described in Exhibit B of this License must be + attached. + +Exhibit A - Source Code Form License Notice + + This Source Code Form is subject to the + terms of the Mozilla Public License, v. + 2.0. If a copy of the MPL was not + distributed with this file, You can + obtain one at + http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular file, +then You may include the notice in a location (such as a LICENSE file in a +relevant directory) where a recipient would be likely to look for such a +notice. 
+ +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice + + This Source Code Form is "Incompatible + With Secondary Licenses", as defined by + the Mozilla Public License, v. 2.0. + diff --git a/pkg/retryablehttp/README.md b/pkg/retryablehttp/README.md new file mode 100644 index 000000000..8943becf1 --- /dev/null +++ b/pkg/retryablehttp/README.md @@ -0,0 +1,62 @@ +go-retryablehttp +================ + +[![Build Status](http://img.shields.io/travis/hashicorp/go-retryablehttp.svg?style=flat-square)][travis] +[![Go Documentation](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)][godocs] + +[travis]: http://travis-ci.org/hashicorp/go-retryablehttp +[godocs]: http://godoc.org/github.com/hashicorp/go-retryablehttp + +The `retryablehttp` package provides a familiar HTTP client interface with +automatic retries and exponential backoff. It is a thin wrapper over the +standard `net/http` client library and exposes nearly the same public API. This +makes `retryablehttp` very easy to drop into existing programs. + +`retryablehttp` performs automatic retries under certain conditions. Mainly, if +an error is returned by the client (connection errors, etc.), or if a 500-range +response code is received (except 501), then a retry is invoked after a wait +period. Otherwise, the response is returned and left to the caller to +interpret. + +The main difference from `net/http` is that requests which take a request body +(POST/PUT et. al) can have the body provided in a number of ways (some more or +less efficient) that allow "rewinding" the request body if the initial request +fails so that the full request can be attempted again. See the +[godoc](http://godoc.org/github.com/hashicorp/go-retryablehttp) for more +details. + +Version 0.6.0 and before are compatible with Go prior to 1.12. From 0.6.1 onward, Go 1.12+ is required. +From 0.6.7 onward, Go 1.13+ is required. + +Example Use +=========== + +Using this library should look almost identical to what you would do with +`net/http`. The most simple example of a GET request is shown below: + +```go +resp, err := retryablehttp.Get("/foo") +if err != nil { + panic(err) +} +``` + +The returned response object is an `*http.Response`, the same thing you would +usually get from `net/http`. Had the request failed one or more times, the above +call would block and retry with exponential backoff. + +## Getting a stdlib `*http.Client` with retries + +It's possible to convert a `*retryablehttp.Client` directly to a `*http.Client`. +This makes use of retryablehttp broadly applicable with minimal effort. Simply +configure a `*retryablehttp.Client` as you wish, and then call `StandardClient()`: + +```go +retryClient := retryablehttp.NewClient() +retryClient.RetryMax = 10 + +standardClient := retryClient.StandardClient() // *http.Client +``` + +For more usage and examples see the +[godoc](http://godoc.org/github.com/hashicorp/go-retryablehttp). diff --git a/pkg/retryablehttp/client.go b/pkg/retryablehttp/client.go new file mode 100644 index 000000000..c8fe8f453 --- /dev/null +++ b/pkg/retryablehttp/client.go @@ -0,0 +1,867 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +// Package retryablehttp provides a familiar HTTP client interface with +// automatic retries and exponential backoff. It is a thin wrapper over the +// standard net/http client library and exposes nearly the same public API. 
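As a quick illustration of the drop-in usage described above — a minimal sketch, assuming the vendored import path used in configure_context.go and a placeholder URL. Because the body is given as a []byte, the client can rewind and resend it on every retry:

```go
package main

import (
	"fmt"
	"log"

	"github.com/akamai/terraform-provider-akamai/v5/pkg/retryablehttp"
)

func main() {
	client := retryablehttp.NewClient()
	client.RetryMax = 3 // retry connection errors and retryable status codes up to 3 times

	// A []byte body can be re-read on every attempt, so retries resend the
	// full payload. The URL is a placeholder.
	req, err := retryablehttp.NewRequest("POST", "https://example.com/v1/foo",
		[]byte(`{"hello":"world"}`))
	if err != nil {
		log.Fatal(err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
```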
+// This makes retryablehttp very easy to drop into existing programs. +// +// retryablehttp performs automatic retries under certain conditions. Mainly, if +// an error is returned by the client (connection errors etc), or if a 500-range +// response is received, then a retry is invoked. Otherwise, the response is +// returned and left to the caller to interpret. +// +// Requests which take a request body should provide a non-nil function +// parameter. The best choice is to provide either a function satisfying +// ReaderFunc which provides multiple io.Readers in an efficient manner, a +// *bytes.Buffer (the underlying raw byte slice will be used) or a raw byte +// slice. As it is a reference type, and we will wrap it as needed by readers, +// we can efficiently re-use the request body without needing to copy it. If an +// io.Reader (such as a *bytes.Reader) is provided, the full body will be read +// prior to the first request, and will be efficiently re-used for any retries. +// ReadSeeker can be used, but some users have observed occasional data races +// between the net/http library and the Seek functionality of some +// implementations of ReadSeeker, so should be avoided if possible. +package retryablehttp + +import ( + "bytes" + "context" + "crypto/x509" + "fmt" + "io" + "io/ioutil" + "log" + "math" + "math/rand" + "net/http" + "net/url" + "os" + "regexp" + "strconv" + "strings" + "sync" + "time" + + cleanhttp "github.com/hashicorp/go-cleanhttp" +) + +var ( + // Default retry configuration + defaultRetryWaitMin = 1 * time.Second + defaultRetryWaitMax = 30 * time.Second + defaultRetryMax = 4 + + // defaultLogger is the logger provided with defaultClient + defaultLogger = log.New(os.Stderr, "", log.LstdFlags) + + // defaultClient is used for performing requests without explicitly making + // a new client. It is purposely private to avoid modifications. + defaultClient = NewClient() + + // We need to consume response bodies to maintain http connections, but + // limit the size we consume to respReadLimit. + respReadLimit = int64(4096) + + // A regular expression to match the error returned by net/http when the + // configured number of redirects is exhausted. This error isn't typed + // specifically so we resort to matching on the error string. + redirectsErrorRe = regexp.MustCompile(`stopped after \d+ redirects\z`) + + // A regular expression to match the error returned by net/http when the + // scheme specified in the URL is invalid. This error isn't typed + // specifically so we resort to matching on the error string. + schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`) + + // A regular expression to match the error returned by net/http when the + // TLS certificate is not trusted. This error isn't typed + // specifically so we resort to matching on the error string. + notTrustedErrorRe = regexp.MustCompile(`certificate is not trusted`) +) + +// ReaderFunc is the type of function that can be given natively to NewRequest +type ReaderFunc func() (io.Reader, error) + +// ResponseHandlerFunc is a type of function that takes in a Response, and does something with it. +// The ResponseHandlerFunc is called when the HTTP client successfully receives a response and the +// CheckRetry function indicates that a retry of the base request is not necessary. +// If an error is returned from this function, the CheckRetry policy will be used to determine +// whether to retry the whole request (including this handler). +// +// Make sure to check status codes! 
Even if the request was completed it may have a non-2xx status code. +// +// The response body is not automatically closed. It must be closed either by the ResponseHandlerFunc or +// by the caller out-of-band. Failure to do so will result in a memory leak. +type ResponseHandlerFunc func(*http.Response) error + +// LenReader is an interface implemented by many in-memory io.Reader's. Used +// for automatically sending the right Content-Length header when possible. +type LenReader interface { + Len() int +} + +// Request wraps the metadata needed to create HTTP requests. +type Request struct { + // body is a seekable reader over the request body payload. This is + // used to rewind the request data in between retries. + body ReaderFunc + + responseHandler ResponseHandlerFunc + + // Embed an HTTP request directly. This makes a *Request act exactly + // like an *http.Request so that all meta methods are supported. + *http.Request +} + +// WithContext returns wrapped Request with a shallow copy of underlying *http.Request +// with its context changed to ctx. The provided ctx must be non-nil. +func (r *Request) WithContext(ctx context.Context) *Request { + return &Request{ + body: r.body, + responseHandler: r.responseHandler, + Request: r.Request.WithContext(ctx), + } +} + +// SetResponseHandler allows setting the response handler. +func (r *Request) SetResponseHandler(fn ResponseHandlerFunc) { + r.responseHandler = fn +} + +// BodyBytes allows accessing the request body. It is an analogue to +// http.Request's Body variable, but it returns a copy of the underlying data +// rather than consuming it. +// +// This function is not thread-safe; do not call it at the same time as another +// call, or at the same time this request is being used with Client.Do. +func (r *Request) BodyBytes() ([]byte, error) { + if r.body == nil { + return nil, nil + } + body, err := r.body() + if err != nil { + return nil, err + } + buf := new(bytes.Buffer) + _, err = buf.ReadFrom(body) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// SetBody allows setting the request body. +// +// It is useful if a new body needs to be set without constructing a new Request. +func (r *Request) SetBody(rawBody interface{}) error { + bodyReader, contentLength, err := getBodyReaderAndContentLength(rawBody) + if err != nil { + return err + } + r.body = bodyReader + r.ContentLength = contentLength + if bodyReader != nil { + r.GetBody = func() (io.ReadCloser, error) { + body, err := bodyReader() + if err != nil { + return nil, err + } + if rc, ok := body.(io.ReadCloser); ok { + return rc, nil + } + return io.NopCloser(body), nil + } + } else { + r.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil } + } + return nil +} + +// WriteTo allows copying the request body into a writer. +// +// It writes data to w until there's no more data to write or +// when an error occurs. The return int64 value is the number of bytes +// written. Any error encountered during the write is also returned. +// The signature matches io.WriterTo interface. +func (r *Request) WriteTo(w io.Writer) (int64, error) { + body, err := r.body() + if err != nil { + return 0, err + } + if c, ok := body.(io.Closer); ok { + defer c.Close() + } + return io.Copy(w, body) +} + +func getBodyReaderAndContentLength(rawBody interface{}) (ReaderFunc, int64, error) { + var bodyReader ReaderFunc + var contentLength int64 + + switch body := rawBody.(type) { + // If they gave us a function already, great! Use it. 
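A minimal sketch of the ResponseHandlerFunc contract described above; the endpoint, payload shape, and function name are hypothetical. The handler checks the status code, consumes and closes the body, and any error it returns is fed back through CheckRetry:

```go
// Assumed imports: encoding/json, fmt, net/http, and the vendored retryablehttp package.
func fetchState(client *retryablehttp.Client, url string) (string, error) {
	req, err := retryablehttp.NewRequest("GET", url, nil)
	if err != nil {
		return "", err
	}

	var payload struct {
		State string `json:"state"`
	}

	// Runs once a response arrives and CheckRetry decides no retry is needed.
	// A returned error is fed back through CheckRetry, so a failure here can
	// trigger another attempt; the body is consumed and closed here, per the
	// ResponseHandlerFunc contract.
	req.SetResponseHandler(func(resp *http.Response) error {
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK { // handlers must check status codes
			return fmt.Errorf("unexpected status: %s", resp.Status)
		}
		return json.NewDecoder(resp.Body).Decode(&payload)
	})

	if _, err := client.Do(req); err != nil {
		return "", err
	}
	return payload.State, nil
}
```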
+ case ReaderFunc: + bodyReader = body + tmp, err := body() + if err != nil { + return nil, 0, err + } + if lr, ok := tmp.(LenReader); ok { + contentLength = int64(lr.Len()) + } + if c, ok := tmp.(io.Closer); ok { + c.Close() + } + + case func() (io.Reader, error): + bodyReader = body + tmp, err := body() + if err != nil { + return nil, 0, err + } + if lr, ok := tmp.(LenReader); ok { + contentLength = int64(lr.Len()) + } + if c, ok := tmp.(io.Closer); ok { + c.Close() + } + + // If a regular byte slice, we can read it over and over via new + // readers + case []byte: + buf := body + bodyReader = func() (io.Reader, error) { + return bytes.NewReader(buf), nil + } + contentLength = int64(len(buf)) + + // If a bytes.Buffer we can read the underlying byte slice over and + // over + case *bytes.Buffer: + buf := body + bodyReader = func() (io.Reader, error) { + return bytes.NewReader(buf.Bytes()), nil + } + contentLength = int64(buf.Len()) + + // We prioritize *bytes.Reader here because we don't really want to + // deal with it seeking so want it to match here instead of the + // io.ReadSeeker case. + case *bytes.Reader: + buf, err := ioutil.ReadAll(body) + if err != nil { + return nil, 0, err + } + bodyReader = func() (io.Reader, error) { + return bytes.NewReader(buf), nil + } + contentLength = int64(len(buf)) + + // Compat case + case io.ReadSeeker: + raw := body + bodyReader = func() (io.Reader, error) { + _, err := raw.Seek(0, 0) + return ioutil.NopCloser(raw), err + } + if lr, ok := raw.(LenReader); ok { + contentLength = int64(lr.Len()) + } + + // Read all in so we can reset + case io.Reader: + buf, err := ioutil.ReadAll(body) + if err != nil { + return nil, 0, err + } + if len(buf) == 0 { + bodyReader = func() (io.Reader, error) { + return http.NoBody, nil + } + contentLength = 0 + } else { + bodyReader = func() (io.Reader, error) { + return bytes.NewReader(buf), nil + } + contentLength = int64(len(buf)) + } + + // No body provided, nothing to do + case nil: + + // Unrecognized type + default: + return nil, 0, fmt.Errorf("cannot handle type %T", rawBody) + } + return bodyReader, contentLength, nil +} + +// FromRequest wraps an http.Request in a retryablehttp.Request +func FromRequest(r *http.Request) (*Request, error) { + bodyReader, _, err := getBodyReaderAndContentLength(r.Body) + if err != nil { + return nil, err + } + // Could assert contentLength == r.ContentLength + return &Request{body: bodyReader, Request: r}, nil +} + +// NewRequest creates a new wrapped request. +func NewRequest(method, url string, rawBody interface{}) (*Request, error) { + return NewRequestWithContext(context.Background(), method, url, rawBody) +} + +// NewRequestWithContext creates a new wrapped request with the provided context. +// +// The context controls the entire lifetime of a request and its response: +// obtaining a connection, sending the request, and reading the response headers and body. +func NewRequestWithContext(ctx context.Context, method, url string, rawBody interface{}) (*Request, error) { + httpReq, err := http.NewRequestWithContext(ctx, method, url, nil) + if err != nil { + return nil, err + } + + req := &Request{ + Request: httpReq, + } + if err := req.SetBody(rawBody); err != nil { + return nil, err + } + + return req, nil +} + +// Logger interface allows to use other loggers than +// standard log.Logger. +type Logger interface { + Printf(string, ...interface{}) +} + +// LeveledLogger is an interface that can be implemented by any logger or a +// logger wrapper to provide leveled logging. 
The methods accept a message +// string and a variadic number of key-value pairs. For log.Printf style +// formatting where message string contains a format specifier, use Logger +// interface. +type LeveledLogger interface { + Error(msg string, keysAndValues ...interface{}) + Info(msg string, keysAndValues ...interface{}) + Debug(msg string, keysAndValues ...interface{}) + Warn(msg string, keysAndValues ...interface{}) +} + +// hookLogger adapts an LeveledLogger to Logger for use by the existing hook functions +// without changing the API. +type hookLogger struct { + LeveledLogger +} + +func (h hookLogger) Printf(s string, args ...interface{}) { + h.Info(fmt.Sprintf(s, args...)) +} + +// RequestLogHook allows a function to run before each retry. The HTTP +// request which will be made, and the retry number (0 for the initial +// request) are available to users. The internal logger is exposed to +// consumers. +type RequestLogHook func(Logger, *http.Request, int) + +// ResponseLogHook is like RequestLogHook, but allows running a function +// on each HTTP response. This function will be invoked at the end of +// every HTTP request executed, regardless of whether a subsequent retry +// needs to be performed or not. If the response body is read or closed +// from this method, this will affect the response returned from Do(). +type ResponseLogHook func(Logger, *http.Response) + +// CheckRetry specifies a policy for handling retries. It is called +// following each request with the response and error values returned by +// the http.Client. If CheckRetry returns false, the Client stops retrying +// and returns the response to the caller. If CheckRetry returns an error, +// that error value is returned in lieu of the error from the request. The +// Client will close any response body when retrying, but if the retry is +// aborted it is up to the CheckRetry callback to properly close any +// response body before returning. +type CheckRetry func(ctx context.Context, resp *http.Response, err error) (bool, error) + +// Backoff specifies a policy for how long to wait between retries. +// It is called after a failing request to determine the amount of time +// that should pass before trying again. +type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration + +// ErrorHandler is called if retries are expired, containing the last status +// from the http library. If not specified, default behavior for the library is +// to close the body and return an error indicating how many tries were +// attempted. If overriding this, be sure to close the body if needed. +type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Response, error) + +// PrepareRetry is called before retry operation. It can be used for example to re-sign the request +type PrepareRetry func(req *http.Request) error + +// Client is used to make HTTP requests. It adds additional functionality +// like automatic retries to tolerate minor outages. +type Client struct { + HTTPClient *http.Client // Internal HTTP client. + Logger interface{} // Customer logger instance. Can be either Logger or LeveledLogger + + RetryWaitMin time.Duration // Minimum time to wait + RetryWaitMax time.Duration // Maximum time to wait + RetryMax int // Maximum number of retries + + // RequestLogHook allows a user-supplied function to be called + // before each retry. 
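A sketch of wiring the logger and hook types above into a client, assuming go-hclog (already a dependency of this module) as the LeveledLogger implementation:

```go
// Assumed imports: net/http, github.com/hashicorp/go-hclog, and the vendored retryablehttp package.
func newLoggingClient() *retryablehttp.Client {
	client := retryablehttp.NewClient()

	// hclog.Logger exposes Error/Info/Debug/Warn(msg, keysAndValues...), so it
	// satisfies the LeveledLogger interface defined above.
	client.Logger = hclog.New(&hclog.LoggerOptions{Name: "retryablehttp"})

	// Called before every attempt; retryNumber is 0 for the initial request.
	client.RequestLogHook = func(l retryablehttp.Logger, req *http.Request, retryNumber int) {
		if l != nil && retryNumber > 0 {
			l.Printf("[WARN] retrying %s %s (attempt %d)", req.Method, req.URL, retryNumber)
		}
	}
	return client
}
```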
+ RequestLogHook RequestLogHook + + // ResponseLogHook allows a user-supplied function to be called + // with the response from each HTTP request executed. + ResponseLogHook ResponseLogHook + + // CheckRetry specifies the policy for handling retries, and is called + // after each request. The default policy is DefaultRetryPolicy. + CheckRetry CheckRetry + + // Backoff specifies the policy for how long to wait between retries + Backoff Backoff + + // ErrorHandler specifies the custom error handler to use, if any + ErrorHandler ErrorHandler + + // PrepareRetry can prepare the request for retry operation, for example re-sign it + PrepareRetry PrepareRetry + + loggerInit sync.Once + clientInit sync.Once +} + +// NewClient creates a new Client with default settings. +func NewClient() *Client { + return &Client{ + HTTPClient: cleanhttp.DefaultPooledClient(), + Logger: defaultLogger, + RetryWaitMin: defaultRetryWaitMin, + RetryWaitMax: defaultRetryWaitMax, + RetryMax: defaultRetryMax, + CheckRetry: DefaultRetryPolicy, + Backoff: DefaultBackoff, + PrepareRetry: DefaultPrepareRetry, + } +} + +func (c *Client) logger() interface{} { + c.loggerInit.Do(func() { + if c.Logger == nil { + return + } + + switch c.Logger.(type) { + case Logger, LeveledLogger: + // ok + default: + // This should happen in dev when they are setting Logger and work on code, not in prod. + panic(fmt.Sprintf("invalid logger type passed, must be Logger or LeveledLogger, was %T", c.Logger)) + } + }) + + return c.Logger +} + +// DefaultRetryPolicy provides a default callback for Client.CheckRetry, which +// will retry on connection errors and server errors. +func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { + // do not retry on context.Canceled or context.DeadlineExceeded + if ctx.Err() != nil { + return false, ctx.Err() + } + + // don't propagate other errors + shouldRetry, _ := baseRetryPolicy(resp, err) + return shouldRetry, nil +} + +// ErrorPropagatedRetryPolicy is the same as DefaultRetryPolicy, except it +// propagates errors back instead of returning nil. This allows you to inspect +// why it decided to retry or not. +func ErrorPropagatedRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { + // do not retry on context.Canceled or context.DeadlineExceeded + if ctx.Err() != nil { + return false, ctx.Err() + } + + return baseRetryPolicy(resp, err) +} + +func baseRetryPolicy(resp *http.Response, err error) (bool, error) { + if err != nil { + if v, ok := err.(*url.Error); ok { + // Don't retry if the error was due to too many redirects. + if redirectsErrorRe.MatchString(v.Error()) { + return false, v + } + + // Don't retry if the error was due to an invalid protocol scheme. + if schemeErrorRe.MatchString(v.Error()) { + return false, v + } + + // Don't retry if the error was due to TLS cert verification failure. + if notTrustedErrorRe.MatchString(v.Error()) { + return false, v + } + if _, ok := v.Err.(x509.UnknownAuthorityError); ok { + return false, v + } + } + + // The error is likely recoverable so retry. + return true, nil + } + + // 429 Too Many Requests is recoverable. Sometimes the server puts + // a Retry-After response header to indicate when the server is + // available to start processing request from client. + if resp.StatusCode == http.StatusTooManyRequests { + return true, nil + } + + // Check the response code. 
We retry on 500-range responses to allow + // the server time to recover, as 500's are typically not permanent + // errors and may relate to outages on the server side. This will catch + // invalid response codes as well, like 0 and 999. + if resp.StatusCode == 0 || (resp.StatusCode >= 500 && resp.StatusCode != http.StatusNotImplemented) { + return true, fmt.Errorf("unexpected HTTP status %s", resp.Status) + } + + return false, nil +} + +// DefaultBackoff provides a default callback for Client.Backoff which +// will perform exponential backoff based on the attempt number and limited +// by the provided minimum and maximum durations. +// +// It also tries to parse Retry-After response header when a http.StatusTooManyRequests +// (HTTP Code 429) is found in the resp parameter. Hence it will return the number of +// seconds the server states it may be ready to process more requests from this client. +func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { + if resp != nil { + if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + if s, ok := resp.Header["Retry-After"]; ok { + if sleep, err := strconv.ParseInt(s[0], 10, 64); err == nil { + return time.Second * time.Duration(sleep) + } + } + } + } + + mult := math.Pow(2, float64(attemptNum)) * float64(min) + sleep := time.Duration(mult) + if float64(sleep) != mult || sleep > max { + sleep = max + } + return sleep +} + +// DefaultPrepareRetry is performing noop during prepare retry +func DefaultPrepareRetry(_ *http.Request) error { + // noop + return nil +} + +// LinearJitterBackoff provides a callback for Client.Backoff which will +// perform linear backoff based on the attempt number and with jitter to +// prevent a thundering herd. +// +// min and max here are *not* absolute values. The number to be multiplied by +// the attempt number will be chosen at random from between them, thus they are +// bounding the jitter. +// +// For instance: +// * To get strictly linear backoff of one second increasing each retry, set +// both to one second (1s, 2s, 3s, 4s, ...) +// * To get a small amount of jitter centered around one second increasing each +// retry, set to around one second, such as a min of 800ms and max of 1200ms +// (892ms, 2102ms, 2945ms, 4312ms, ...) +// * To get extreme jitter, set to a very wide spread, such as a min of 100ms +// and a max of 20s (15382ms, 292ms, 51321ms, 35234ms, ...) +func LinearJitterBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { + // attemptNum always starts at zero but we want to start at 1 for multiplication + attemptNum++ + + if max <= min { + // Unclear what to do here, or they are the same, so return min * + // attemptNum + return min * time.Duration(attemptNum) + } + + // Seed rand; doing this every time is fine + rand := rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) + + // Pick a random number that lies somewhere between the min and max and + // multiply by the attemptNum. attemptNum starts at zero so we always + // increment here. We first get a random percentage, then apply that to the + // difference between min and max, and add to min. + jitter := rand.Float64() * float64(max-min) + jitterMin := int64(jitter) + int64(min) + return time.Duration(jitterMin * int64(attemptNum)) +} + +// PassthroughErrorHandler is an ErrorHandler that directly passes through the +// values from the net/http library for the final request. The body is not +// closed. 
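A sketch of swapping in LinearJitterBackoff and layering a small rule on top of DefaultRetryPolicy, using only the exported identifiers above; the retry limits and wait bounds are illustrative:

```go
// Assumed imports: context, net/http, time, and the vendored retryablehttp package.
func newJitterClient() *retryablehttp.Client {
	client := retryablehttp.NewClient()
	client.RetryMax = 6
	client.RetryWaitMin = 800 * time.Millisecond
	client.RetryWaitMax = 1200 * time.Millisecond

	// LinearJitterBackoff waits a random duration between RetryWaitMin and
	// RetryWaitMax multiplied by the attempt number, instead of the
	// exponential growth (and Retry-After handling) of DefaultBackoff.
	client.Backoff = retryablehttp.LinearJitterBackoff

	// Do not retry a 429 that carries no Retry-After header; defer to
	// DefaultRetryPolicy for every other case.
	client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) {
		if err == nil && resp != nil &&
			resp.StatusCode == http.StatusTooManyRequests &&
			resp.Header.Get("Retry-After") == "" {
			return false, nil
		}
		return retryablehttp.DefaultRetryPolicy(ctx, resp, err)
	}
	return client
}
```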
+func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Response, error) { + return resp, err +} + +// Do wraps calling an HTTP method with retries. +func (c *Client) Do(req *Request) (*http.Response, error) { + c.clientInit.Do(func() { + if c.HTTPClient == nil { + c.HTTPClient = cleanhttp.DefaultPooledClient() + } + }) + + logger := c.logger() + + if logger != nil { + switch v := logger.(type) { + case LeveledLogger: + v.Debug("performing request", "method", req.Method, "url", req.URL) + case Logger: + v.Printf("[DEBUG] %s %s", req.Method, req.URL) + } + } + + var resp *http.Response + var attempt int + var shouldRetry bool + var doErr, respErr, checkErr, prepareErr error + + for i := 0; ; i++ { + doErr, respErr, prepareErr = nil, nil, nil + attempt++ + + // Always rewind the request body when non-nil. + if req.body != nil { + body, err := req.body() + if err != nil { + c.HTTPClient.CloseIdleConnections() + return resp, err + } + if c, ok := body.(io.ReadCloser); ok { + req.Body = c + } else { + req.Body = ioutil.NopCloser(body) + } + } + + if c.RequestLogHook != nil { + switch v := logger.(type) { + case LeveledLogger: + c.RequestLogHook(hookLogger{v}, req.Request, i) + case Logger: + c.RequestLogHook(v, req.Request, i) + default: + c.RequestLogHook(nil, req.Request, i) + } + } + + // Attempt the request + resp, doErr = c.HTTPClient.Do(req.Request) + + // Check if we should continue with retries. + shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr) + if !shouldRetry && doErr == nil && req.responseHandler != nil { + respErr = req.responseHandler(resp) + shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, respErr) + } + + err := doErr + if respErr != nil { + err = respErr + } + if err != nil { + switch v := logger.(type) { + case LeveledLogger: + v.Error("request failed", "error", err, "method", req.Method, "url", req.URL) + case Logger: + v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, err) + } + } else { + // Call this here to maintain the behavior of logging all requests, + // even if CheckRetry signals to stop. + if c.ResponseLogHook != nil { + // Call the response logger function if provided. + switch v := logger.(type) { + case LeveledLogger: + c.ResponseLogHook(hookLogger{v}, resp) + case Logger: + c.ResponseLogHook(v, resp) + default: + c.ResponseLogHook(nil, resp) + } + } + } + + if !shouldRetry { + break + } + + // We do this before drainBody because there's no need for the I/O if + // we're breaking out + remain := c.RetryMax - i + if remain <= 0 { + break + } + + // We're going to retry, consume any response to reuse the connection. + if doErr == nil { + c.drainBody(resp.Body) + } + + wait := c.Backoff(c.RetryWaitMin, c.RetryWaitMax, i, resp) + if logger != nil { + desc := fmt.Sprintf("%s %s", req.Method, req.URL) + if resp != nil { + desc = fmt.Sprintf("%s (status: %d)", desc, resp.StatusCode) + } + switch v := logger.(type) { + case LeveledLogger: + v.Debug("retrying request", "request", desc, "timeout", wait, "remaining", remain) + case Logger: + v.Printf("[DEBUG] %s: retrying in %s (%d left)", desc, wait, remain) + } + } + timer := time.NewTimer(wait) + select { + case <-req.Context().Done(): + timer.Stop() + c.HTTPClient.CloseIdleConnections() + return nil, req.Context().Err() + case <-timer.C: + } + + // Make shallow copy of http Request so that we can modify its body + // without racing against the closeBody call in persistConn.writeLoop. 
+ httpreq := *req.Request + req.Request = &httpreq + + if err := c.PrepareRetry(req.Request); err != nil { + prepareErr = err + break + } + } + + // this is the closest we have to success criteria + if doErr == nil && respErr == nil && checkErr == nil && prepareErr == nil && !shouldRetry { + return resp, nil + } + + defer c.HTTPClient.CloseIdleConnections() + + var err error + if prepareErr != nil { + err = prepareErr + } else if checkErr != nil { + err = checkErr + } else if respErr != nil { + err = respErr + } else { + err = doErr + } + + if c.ErrorHandler != nil { + return c.ErrorHandler(resp, err, attempt) + } + + // By default, we close the response body and return an error without + // returning the response + if resp != nil { + c.drainBody(resp.Body) + } + + // this means CheckRetry thought the request was a failure, but didn't + // communicate why + if err == nil { + return nil, fmt.Errorf("%s %s giving up after %d attempt(s)", + req.Method, req.URL, attempt) + } + + return nil, fmt.Errorf("%s %s giving up after %d attempt(s): %w", + req.Method, req.URL, attempt, err) +} + +// Try to read the response body so we can reuse this connection. +func (c *Client) drainBody(body io.ReadCloser) { + defer body.Close() + _, err := io.Copy(ioutil.Discard, io.LimitReader(body, respReadLimit)) + if err != nil { + if c.logger() != nil { + switch v := c.logger().(type) { + case LeveledLogger: + v.Error("error reading response body", "error", err) + case Logger: + v.Printf("[ERR] error reading response body: %v", err) + } + } + } +} + +// Get is a shortcut for doing a GET request without making a new client. +func Get(url string) (*http.Response, error) { + return defaultClient.Get(url) +} + +// Get is a convenience helper for doing simple GET requests. +func (c *Client) Get(url string) (*http.Response, error) { + req, err := NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +// Head is a shortcut for doing a HEAD request without making a new client. +func Head(url string) (*http.Response, error) { + return defaultClient.Head(url) +} + +// Head is a convenience method for doing simple HEAD requests. +func (c *Client) Head(url string) (*http.Response, error) { + req, err := NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +// Post is a shortcut for doing a POST request without making a new client. +func Post(url, bodyType string, body interface{}) (*http.Response, error) { + return defaultClient.Post(url, bodyType, body) +} + +// Post is a convenience method for doing simple POST requests. +func (c *Client) Post(url, bodyType string, body interface{}) (*http.Response, error) { + req, err := NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", bodyType) + return c.Do(req) +} + +// PostForm is a shortcut to perform a POST with form data without creating +// a new client. +func PostForm(url string, data url.Values) (*http.Response, error) { + return defaultClient.PostForm(url, data) +} + +// PostForm is a convenience method for doing simple POST operations using +// pre-filled url.Values form data. +func (c *Client) PostForm(url string, data url.Values) (*http.Response, error) { + return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} + +// StandardClient returns a stdlib *http.Client with a custom Transport, which +// shims in a *retryablehttp.Client for added retries. 
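A sketch of the PrepareRetry hook exercised by Do() above, useful when authentication material must be regenerated per attempt; the token callback and the header names are hypothetical:

```go
// Assumed imports: fmt, net/http, time, and the vendored retryablehttp package.
// The token callback is a hypothetical credential source.
func newResigningClient(token func() (string, error)) *retryablehttp.Client {
	client := retryablehttp.NewClient()

	// Do() calls PrepareRetry on the (shallow-copied) *http.Request before
	// each retry, so time-sensitive headers can be regenerated per attempt.
	client.PrepareRetry = func(req *http.Request) error {
		t, err := token()
		if err != nil {
			return fmt.Errorf("refreshing credentials before retry: %w", err)
		}
		req.Header.Set("Authorization", "Bearer "+t)
		req.Header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
		return nil
	}
	return client
}
```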
+func (c *Client) StandardClient() *http.Client { + return &http.Client{ + Transport: &RoundTripper{Client: c}, + } +} diff --git a/pkg/retryablehttp/client_test.go b/pkg/retryablehttp/client_test.go new file mode 100644 index 000000000..a751d3fd2 --- /dev/null +++ b/pkg/retryablehttp/client_test.go @@ -0,0 +1,1162 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package retryablehttp + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strconv" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/hashicorp/go-hclog" +) + +func TestRequest(t *testing.T) { + // Fails on invalid request + _, err := NewRequest("GET", "://foo", nil) + if err == nil { + t.Fatalf("should error") + } + + // Works with no request body + _, err = NewRequest("GET", "http://foo", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Works with request body + body := bytes.NewReader([]byte("yo")) + req, err := NewRequest("GET", "/", body) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Request allows typical HTTP request forming methods + req.Header.Set("X-Test", "foo") + if v, ok := req.Header["X-Test"]; !ok || len(v) != 1 || v[0] != "foo" { + t.Fatalf("bad headers: %v", req.Header) + } + + // Sets the Content-Length automatically for LenReaders + if req.ContentLength != 2 { + t.Fatalf("bad ContentLength: %d", req.ContentLength) + } +} + +func TestFromRequest(t *testing.T) { + // Works with no request body + httpReq, err := http.NewRequest("GET", "http://foo", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + _, err = FromRequest(httpReq) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Works with request body + body := bytes.NewReader([]byte("yo")) + httpReq, err = http.NewRequest("GET", "/", body) + if err != nil { + t.Fatalf("err: %v", err) + } + req, err := FromRequest(httpReq) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Preserves headers + httpReq.Header.Set("X-Test", "foo") + if v, ok := req.Header["X-Test"]; !ok || len(v) != 1 || v[0] != "foo" { + t.Fatalf("bad headers: %v", req.Header) + } + + // Preserves the Content-Length automatically for LenReaders + if req.ContentLength != 2 { + t.Fatalf("bad ContentLength: %d", req.ContentLength) + } +} + +// Since normal ways we would generate a Reader have special cases, use a +// custom type here +type custReader struct { + val string + pos int +} + +func (c *custReader) Read(p []byte) (n int, err error) { + if c.val == "" { + c.val = "hello" + } + if c.pos >= len(c.val) { + return 0, io.EOF + } + var i int + for i = 0; i < len(p) && i+c.pos < len(c.val); i++ { + p[i] = c.val[i+c.pos] + } + c.pos += i + return i, nil +} + +func TestClient_Do(t *testing.T) { + testBytes := []byte("hello") + // Native func + testClientDo(t, ReaderFunc(func() (io.Reader, error) { + return bytes.NewReader(testBytes), nil + })) + // Native func, different Go type + testClientDo(t, func() (io.Reader, error) { + return bytes.NewReader(testBytes), nil + }) + // []byte + testClientDo(t, testBytes) + // *bytes.Buffer + testClientDo(t, bytes.NewBuffer(testBytes)) + // *bytes.Reader + testClientDo(t, bytes.NewReader(testBytes)) + // io.ReadSeeker + testClientDo(t, strings.NewReader(string(testBytes))) + // io.Reader + testClientDo(t, &custReader{}) +} + +func testClientDo(t *testing.T, body interface{}) { + // Create a request + req, err := NewRequest("PUT", "http://127.0.0.1:28934/v1/foo", body) + if err != 
nil { + t.Fatalf("err: %v", err) + } + req.Header.Set("foo", "bar") + + // Track the number of times the logging hook was called + retryCount := -1 + + // Create the client. Use short retry windows. + client := NewClient() + client.RetryWaitMin = 10 * time.Millisecond + client.RetryWaitMax = 50 * time.Millisecond + client.RetryMax = 50 + client.RequestLogHook = func(logger Logger, req *http.Request, retryNumber int) { + retryCount = retryNumber + + if logger != client.Logger { + t.Fatalf("Client logger was not passed to logging hook") + } + + dumpBytes, err := httputil.DumpRequestOut(req, false) + if err != nil { + t.Fatal("Dumping requests failed") + } + + dumpString := string(dumpBytes) + if !strings.Contains(dumpString, "PUT /v1/foo") { + t.Fatalf("Bad request dump:\n%s", dumpString) + } + } + + // Send the request + var resp *http.Response + doneCh := make(chan struct{}) + errCh := make(chan error, 1) + go func() { + defer close(doneCh) + defer close(errCh) + var err error + resp, err = client.Do(req) + errCh <- err + }() + + select { + case <-doneCh: + t.Fatalf("should retry on error") + case <-time.After(200 * time.Millisecond): + // Client should still be retrying due to connection failure. + } + + // Create the mock handler. First we return a 500-range response to ensure + // that we power through and keep retrying in the face of recoverable + // errors. + code := int64(500) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check the request details + if r.Method != "PUT" { + t.Fatalf("bad method: %s", r.Method) + } + if r.RequestURI != "/v1/foo" { + t.Fatalf("bad uri: %s", r.RequestURI) + } + + // Check the headers + if v := r.Header.Get("foo"); v != "bar" { + t.Fatalf("bad header: expect foo=bar, got foo=%v", v) + } + + // Check the payload + body, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("err: %s", err) + } + expected := []byte("hello") + if !bytes.Equal(body, expected) { + t.Fatalf("bad: %v", body) + } + + w.WriteHeader(int(atomic.LoadInt64(&code))) + }) + + // Create a test server + list, err := net.Listen("tcp", ":28934") + if err != nil { + t.Fatalf("err: %v", err) + } + defer list.Close() + go http.Serve(list, handler) + + // Wait again + select { + case <-doneCh: + t.Fatalf("should retry on 500-range") + case <-time.After(200 * time.Millisecond): + // Client should still be retrying due to 500's. + } + + // Start returning 200's + atomic.StoreInt64(&code, 200) + + // Wait again + select { + case <-doneCh: + case <-time.After(time.Second): + t.Fatalf("timed out") + } + + if resp.StatusCode != 200 { + t.Fatalf("exected 200, got: %d", resp.StatusCode) + } + + if retryCount < 0 { + t.Fatal("request log hook was not called") + } + + err = <-errCh + if err != nil { + t.Fatalf("err: %v", err) + } +} + +func TestClient_Do_WithResponseHandler(t *testing.T) { + // Create the client. Use short retry windows so we fail faster. + client := NewClient() + client.RetryWaitMin = 10 * time.Millisecond + client.RetryWaitMax = 10 * time.Millisecond + client.RetryMax = 2 + + var checks int + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + checks++ + if err != nil && strings.Contains(err.Error(), "nonretryable") { + return false, nil + } + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + // Mock server which always responds 200. 
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + var shouldSucceed bool + tests := []struct { + name string + handler ResponseHandlerFunc + expectedChecks int // often 2x number of attempts since we check twice + err string + }{ + { + name: "nil handler", + handler: nil, + expectedChecks: 1, + }, + { + name: "handler always succeeds", + handler: func(*http.Response) error { + return nil + }, + expectedChecks: 2, + }, + { + name: "handler always fails in a retryable way", + handler: func(*http.Response) error { + return errors.New("retryable failure") + }, + expectedChecks: 6, + }, + { + name: "handler always fails in a nonretryable way", + handler: func(*http.Response) error { + return errors.New("nonretryable failure") + }, + expectedChecks: 2, + }, + { + name: "handler succeeds on second attempt", + handler: func(*http.Response) error { + if shouldSucceed { + return nil + } + shouldSucceed = true + return errors.New("retryable failure") + }, + expectedChecks: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + checks = 0 + shouldSucceed = false + // Create the request + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + req.SetResponseHandler(tt.handler) + + // Send the request. + _, err = client.Do(req) + if err != nil && !strings.Contains(err.Error(), tt.err) { + t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error()) + } + if err == nil && tt.err != "" { + t.Fatalf("no error, expected: %s", tt.err) + } + + if checks != tt.expectedChecks { + t.Fatalf("expected %d attempts, got %d attempts", tt.expectedChecks, checks) + } + }) + } +} + +func TestClient_Do_WithPrepareRetry(t *testing.T) { + // Create the client. Use short retry windows so we fail faster. + client := NewClient() + client.RetryWaitMin = 10 * time.Millisecond + client.RetryWaitMax = 10 * time.Millisecond + client.RetryMax = 2 + + var checks int + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + checks++ + if err != nil && strings.Contains(err.Error(), "nonretryable") { + return false, nil + } + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + var prepareChecks int + client.PrepareRetry = func(req *http.Request) error { + prepareChecks++ + req.Header.Set("foo", strconv.Itoa(prepareChecks)) + return nil + } + + // Mock server which always responds 200. 
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + var shouldSucceed bool + tests := []struct { + name string + handler ResponseHandlerFunc + expectedChecks int // often 2x number of attempts since we check twice + expectedPrepareChecks int + err string + }{ + { + name: "nil handler", + handler: nil, + expectedChecks: 1, + expectedPrepareChecks: 0, + }, + { + name: "handler always succeeds", + handler: func(*http.Response) error { + return nil + }, + expectedChecks: 2, + expectedPrepareChecks: 0, + }, + { + name: "handler always fails in a retryable way", + handler: func(*http.Response) error { + return errors.New("retryable failure") + }, + expectedChecks: 6, + expectedPrepareChecks: 2, + }, + { + name: "handler always fails in a nonretryable way", + handler: func(*http.Response) error { + return errors.New("nonretryable failure") + }, + expectedChecks: 2, + expectedPrepareChecks: 0, + }, + { + name: "handler succeeds on second attempt", + handler: func(*http.Response) error { + if shouldSucceed { + return nil + } + shouldSucceed = true + return errors.New("retryable failure") + }, + expectedChecks: 4, + expectedPrepareChecks: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + checks = 0 + prepareChecks = 0 + shouldSucceed = false + // Create the request + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + req.SetResponseHandler(tt.handler) + + // Send the request. + _, err = client.Do(req) + if err != nil && !strings.Contains(err.Error(), tt.err) { + t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error()) + } + if err == nil && tt.err != "" { + t.Fatalf("no error, expected: %s", tt.err) + } + + if checks != tt.expectedChecks { + t.Fatalf("expected %d attempts, got %d attempts", tt.expectedChecks, checks) + } + + if prepareChecks != tt.expectedPrepareChecks { + t.Fatalf("expected %d attempts of prepare check, got %d attempts", tt.expectedPrepareChecks, prepareChecks) + } + header := req.Request.Header.Get("foo") + if tt.expectedPrepareChecks == 0 && header != "" { + t.Fatalf("expected no changes to request header 'foo', but got '%s'", header) + } + expectedHeader := strconv.Itoa(tt.expectedPrepareChecks) + if tt.expectedPrepareChecks != 0 && header != expectedHeader { + t.Fatalf("expected changes in request header 'foo' '%s', but got '%s'", expectedHeader, header) + } + + }) + } +} + +func TestClient_Do_fails(t *testing.T) { + // Mock server which always responds 500. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + })) + defer ts.Close() + + tests := []struct { + name string + cr CheckRetry + err string + }{ + { + name: "default_retry_policy", + cr: DefaultRetryPolicy, + err: "giving up after 3 attempt(s)", + }, + { + name: "error_propagated_retry_policy", + cr: ErrorPropagatedRetryPolicy, + err: "giving up after 3 attempt(s): unexpected HTTP status 500 Internal Server Error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create the client. Use short retry windows so we fail faster. + client := NewClient() + client.RetryWaitMin = 10 * time.Millisecond + client.RetryWaitMax = 10 * time.Millisecond + client.CheckRetry = tt.cr + client.RetryMax = 2 + + // Create the request + req, err := NewRequest("POST", ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Send the request. 
+ _, err = client.Do(req) + if err == nil || !strings.HasSuffix(err.Error(), tt.err) { + t.Fatalf("expected giving up error, got: %#v", err) + } + }) + } +} + +func TestClient_Get(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Fatalf("bad method: %s", r.Method) + } + if r.RequestURI != "/foo/bar" { + t.Fatalf("bad uri: %s", r.RequestURI) + } + w.WriteHeader(200) + })) + defer ts.Close() + + // Make the request. + resp, err := NewClient().Get(ts.URL + "/foo/bar") + if err != nil { + t.Fatalf("err: %v", err) + } + resp.Body.Close() +} + +func TestClient_RequestLogHook(t *testing.T) { + t.Run("RequestLogHook successfully called with default Logger", func(t *testing.T) { + testClientRequestLogHook(t, defaultLogger) + }) + t.Run("RequestLogHook successfully called with nil Logger", func(t *testing.T) { + testClientRequestLogHook(t, nil) + }) + t.Run("RequestLogHook successfully called with nil typed Logger", func(t *testing.T) { + testClientRequestLogHook(t, Logger(nil)) + }) + t.Run("RequestLogHook successfully called with nil typed LeveledLogger", func(t *testing.T) { + testClientRequestLogHook(t, LeveledLogger(nil)) + }) +} + +func testClientRequestLogHook(t *testing.T, logger interface{}) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Fatalf("bad method: %s", r.Method) + } + if r.RequestURI != "/foo/bar" { + t.Fatalf("bad uri: %s", r.RequestURI) + } + w.WriteHeader(200) + })) + defer ts.Close() + + retries := -1 + testURIPath := "/foo/bar" + + client := NewClient() + client.Logger = logger + client.RequestLogHook = func(logger Logger, req *http.Request, retry int) { + retries = retry + + if logger != client.Logger { + t.Fatalf("Client logger was not passed to logging hook") + } + + dumpBytes, err := httputil.DumpRequestOut(req, false) + if err != nil { + t.Fatal("Dumping requests failed") + } + + dumpString := string(dumpBytes) + if !strings.Contains(dumpString, "GET "+testURIPath) { + t.Fatalf("Bad request dump:\n%s", dumpString) + } + } + + // Make the request. 
+ resp, err := client.Get(ts.URL + testURIPath) + if err != nil { + t.Fatalf("err: %v", err) + } + resp.Body.Close() + + if retries < 0 { + t.Fatal("Logging hook was not called") + } +} + +func TestClient_ResponseLogHook(t *testing.T) { + t.Run("ResponseLogHook successfully called with hclog Logger", func(t *testing.T) { + buf := new(bytes.Buffer) + l := hclog.New(&hclog.LoggerOptions{ + Output: buf, + }) + testClientResponseLogHook(t, l, buf) + }) + t.Run("ResponseLogHook successfully called with nil Logger", func(t *testing.T) { + buf := new(bytes.Buffer) + testClientResponseLogHook(t, nil, buf) + }) + t.Run("ResponseLogHook successfully called with nil typed Logger", func(t *testing.T) { + buf := new(bytes.Buffer) + testClientResponseLogHook(t, Logger(nil), buf) + }) + t.Run("ResponseLogHook successfully called with nil typed LeveledLogger", func(t *testing.T) { + buf := new(bytes.Buffer) + testClientResponseLogHook(t, LeveledLogger(nil), buf) + }) +} + +func testClientResponseLogHook(t *testing.T, l interface{}, buf *bytes.Buffer) { + passAfter := time.Now().Add(100 * time.Millisecond) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if time.Now().After(passAfter) { + w.WriteHeader(200) + w.Write([]byte("test_200_body")) + } else { + w.WriteHeader(500) + w.Write([]byte("test_500_body")) + } + })) + defer ts.Close() + + client := NewClient() + + client.Logger = l + client.RetryWaitMin = 10 * time.Millisecond + client.RetryWaitMax = 10 * time.Millisecond + client.RetryMax = 15 + client.ResponseLogHook = func(logger Logger, resp *http.Response) { + if resp.StatusCode == 200 { + successLog := "test_log_pass" + // Log something when we get a 200 + if logger != nil { + logger.Printf(successLog) + } else { + buf.WriteString(successLog) + } + } else { + // Log the response body when we get a 500 + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("err: %v", err) + } + failLog := string(body) + if logger != nil { + logger.Printf(failLog) + } else { + buf.WriteString(failLog) + } + } + } + + // Perform the request. Exits when we finally get a 200. + resp, err := client.Get(ts.URL) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Make sure we can read the response body still, since we did not + // read or close it from the response log hook. + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("err: %v", err) + } + if string(body) != "test_200_body" { + t.Fatalf("expect %q, got %q", "test_200_body", string(body)) + } + + // Make sure we wrote to the logger on callbacks. 
+ out := buf.String() + if !strings.Contains(out, "test_log_pass") { + t.Fatalf("expect response callback on 200: %q", out) + } + if !strings.Contains(out, "test_500_body") { + t.Fatalf("expect response callback on 500: %q", out) + } +} + +func TestClient_NewRequestWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + r, err := NewRequestWithContext(ctx, http.MethodGet, "/abc", nil) + if err != nil { + t.Fatalf("err: %v", err) + } + if r.Context() != ctx { + t.Fatal("Context must be set") + } +} + +func TestClient_RequestWithContext(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("test_200_body")) + })) + defer ts.Close() + + req, err := NewRequest(http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + ctx, cancel := context.WithCancel(req.Request.Context()) + reqCtx := req.WithContext(ctx) + if reqCtx == req { + t.Fatal("WithContext must return a new Request object") + } + + client := NewClient() + + called := 0 + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + called++ + return DefaultRetryPolicy(reqCtx.Request.Context(), resp, err) + } + + cancel() + _, err = client.Do(reqCtx) + + if called != 1 { + t.Fatalf("CheckRetry called %d times, expected 1", called) + } + + e := fmt.Sprintf("GET %s giving up after 1 attempt(s): %s", ts.URL, context.Canceled.Error()) + + if err.Error() != e { + t.Fatalf("Expected err to contain %s, got: %v", e, err) + } +} + +func TestClient_CheckRetry(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "test_500_body", http.StatusInternalServerError) + })) + defer ts.Close() + + client := NewClient() + + retryErr := errors.New("retryError") + called := 0 + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + if called < 1 { + called++ + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + return false, retryErr + } + + // CheckRetry should return our retryErr value and stop the retry loop. 
+ _, err := client.Get(ts.URL) + + if called != 1 { + t.Fatalf("CheckRetry called %d times, expected 1", called) + } + + if err.Error() != fmt.Sprintf("GET %s giving up after 2 attempt(s): retryError", ts.URL) { + t.Fatalf("Expected retryError, got:%v", err) + } +} + +func TestClient_DefaultBackoff(t *testing.T) { + for _, code := range []int{http.StatusTooManyRequests, http.StatusServiceUnavailable} { + t.Run(fmt.Sprintf("http_%d", code), func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "2") + http.Error(w, fmt.Sprintf("test_%d_body", code), code) + })) + defer ts.Close() + + client := NewClient() + + var retryAfter time.Duration + retryable := false + + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + retryable, _ = DefaultRetryPolicy(context.Background(), resp, err) + retryAfter = DefaultBackoff(client.RetryWaitMin, client.RetryWaitMax, 1, resp) + return false, nil + } + + _, err := client.Get(ts.URL) + if err != nil { + t.Fatalf("expected no errors since retryable") + } + + if !retryable { + t.Fatal("Since the error is recoverable, the default policy shall return true") + } + + if retryAfter != 2*time.Second { + t.Fatalf("The header Retry-After specified 2 seconds, and shall not be %d seconds", retryAfter/time.Second) + } + }) + } +} + +func TestClient_DefaultRetryPolicy_TLS(t *testing.T) { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + attempts := 0 + client := NewClient() + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + attempts++ + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + _, err := client.Get(ts.URL) + if err == nil { + t.Fatalf("expected x509 error, got nil") + } + if attempts != 1 { + t.Fatalf("expected 1 attempt, got %d", attempts) + } +} + +func TestClient_DefaultRetryPolicy_redirects(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/", http.StatusFound) + })) + defer ts.Close() + + attempts := 0 + client := NewClient() + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + attempts++ + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + _, err := client.Get(ts.URL) + if err == nil { + t.Fatalf("expected redirect error, got nil") + } + if attempts != 1 { + t.Fatalf("expected 1 attempt, got %d", attempts) + } +} + +func TestClient_DefaultRetryPolicy_invalidscheme(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + attempts := 0 + client := NewClient() + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + attempts++ + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + url := strings.Replace(ts.URL, "http", "ftp", 1) + _, err := client.Get(url) + if err == nil { + t.Fatalf("expected scheme error, got nil") + } + if attempts != 1 { + t.Fatalf("expected 1 attempt, got %d", attempts) + } +} + +func TestClient_CheckRetryStop(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "test_500_body", http.StatusInternalServerError) + })) + defer ts.Close() + + client := NewClient() + + // Verify that this stops retries on the first try, with 
no errors from the client. + called := 0 + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + called++ + return false, nil + } + + _, err := client.Get(ts.URL) + + if called != 1 { + t.Fatalf("CheckRetry called %d times, expected 1", called) + } + + if err != nil { + t.Fatalf("Expected no error, got:%v", err) + } +} + +func TestClient_Head(t *testing.T) { + // Mock server which always responds 200. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "HEAD" { + t.Fatalf("bad method: %s", r.Method) + } + if r.RequestURI != "/foo/bar" { + t.Fatalf("bad uri: %s", r.RequestURI) + } + w.WriteHeader(200) + })) + defer ts.Close() + + // Make the request. + resp, err := NewClient().Head(ts.URL + "/foo/bar") + if err != nil { + t.Fatalf("err: %v", err) + } + resp.Body.Close() +} + +func TestClient_Post(t *testing.T) { + // Mock server which always responds 200. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Fatalf("bad method: %s", r.Method) + } + if r.RequestURI != "/foo/bar" { + t.Fatalf("bad uri: %s", r.RequestURI) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Fatalf("bad content-type: %s", ct) + } + + // Check the payload + body, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("err: %s", err) + } + expected := []byte(`{"hello":"world"}`) + if !bytes.Equal(body, expected) { + t.Fatalf("bad: %v", body) + } + + w.WriteHeader(200) + })) + defer ts.Close() + + // Make the request. + resp, err := NewClient().Post( + ts.URL+"/foo/bar", + "application/json", + strings.NewReader(`{"hello":"world"}`)) + if err != nil { + t.Fatalf("err: %v", err) + } + resp.Body.Close() +} + +func TestClient_PostForm(t *testing.T) { + // Mock server which always responds 200. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Fatalf("bad method: %s", r.Method) + } + if r.RequestURI != "/foo/bar" { + t.Fatalf("bad uri: %s", r.RequestURI) + } + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Fatalf("bad content-type: %s", ct) + } + + // Check the payload + body, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("err: %s", err) + } + expected := []byte(`hello=world`) + if !bytes.Equal(body, expected) { + t.Fatalf("bad: %v", body) + } + + w.WriteHeader(200) + })) + defer ts.Close() + + // Create the form data. + form, err := url.ParseQuery("hello=world") + if err != nil { + t.Fatalf("err: %v", err) + } + + // Make the request. 
+ resp, err := NewClient().PostForm(ts.URL+"/foo/bar", form) + if err != nil { + t.Fatalf("err: %v", err) + } + resp.Body.Close() +} + +func TestBackoff(t *testing.T) { + type tcase struct { + min time.Duration + max time.Duration + i int + expect time.Duration + } + cases := []tcase{ + { + time.Second, + 5 * time.Minute, + 0, + time.Second, + }, + { + time.Second, + 5 * time.Minute, + 1, + 2 * time.Second, + }, + { + time.Second, + 5 * time.Minute, + 2, + 4 * time.Second, + }, + { + time.Second, + 5 * time.Minute, + 3, + 8 * time.Second, + }, + { + time.Second, + 5 * time.Minute, + 63, + 5 * time.Minute, + }, + { + time.Second, + 5 * time.Minute, + 128, + 5 * time.Minute, + }, + } + + for _, tc := range cases { + if v := DefaultBackoff(tc.min, tc.max, tc.i, nil); v != tc.expect { + t.Fatalf("bad: %#v -> %s", tc, v) + } + } +} + +func TestClient_BackoffCustom(t *testing.T) { + var retries int32 + + client := NewClient() + client.Backoff = func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { + atomic.AddInt32(&retries, 1) + return time.Millisecond * 1 + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if atomic.LoadInt32(&retries) == int32(client.RetryMax) { + w.WriteHeader(200) + return + } + w.WriteHeader(500) + })) + defer ts.Close() + + // Make the request. + resp, err := client.Get(ts.URL + "/foo/bar") + if err != nil { + t.Fatalf("err: %v", err) + } + resp.Body.Close() + if retries != int32(client.RetryMax) { + t.Fatalf("expected retries: %d != %d", client.RetryMax, retries) + } +} + +func TestClient_StandardClient(t *testing.T) { + // Create a retryable HTTP client. + client := NewClient() + + // Get a standard client. + standard := client.StandardClient() + + // Ensure the underlying retrying client is set properly. + if v := standard.Transport.(*RoundTripper).Client; v != client { + t.Fatalf("expected %v, got %v", client, v) + } +} + +func TestClient_RedirectWithBody(t *testing.T) { + var redirects int32 + // Mock server which always responds 200. 
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.RequestURI { + case "/redirect": + w.Header().Set("Location", "/target") + w.WriteHeader(http.StatusTemporaryRedirect) + case "/target": + atomic.AddInt32(&redirects, 1) + w.WriteHeader(http.StatusCreated) + default: + t.Fatalf("bad uri: %s", r.RequestURI) + } + })) + defer ts.Close() + + client := NewClient() + client.RequestLogHook = func(logger Logger, req *http.Request, retryNumber int) { + if _, err := req.GetBody(); err != nil { + t.Fatalf("unexpected error with GetBody: %v", err) + } + } + // create a request with a body + req, err := NewRequest(http.MethodPost, ts.URL+"/redirect", strings.NewReader(`{"foo":"bar"}`)) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + t.Fatalf("expected status code 201, got: %d", resp.StatusCode) + } + + // now one without a body + if err := req.SetBody(nil); err != nil { + t.Fatalf("err: %v", err) + } + + resp, err = client.Do(req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + t.Fatalf("expected status code 201, got: %d", resp.StatusCode) + } + + if atomic.LoadInt32(&redirects) != 2 { + t.Fatalf("Expected the client to be redirected 2 times, got: %d", atomic.LoadInt32(&redirects)) + } +} diff --git a/pkg/retryablehttp/roundtripper.go b/pkg/retryablehttp/roundtripper.go new file mode 100644 index 000000000..8c407adb3 --- /dev/null +++ b/pkg/retryablehttp/roundtripper.go @@ -0,0 +1,55 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package retryablehttp + +import ( + "errors" + "net/http" + "net/url" + "sync" +) + +// RoundTripper implements the http.RoundTripper interface, using a retrying +// HTTP client to execute requests. +// +// It is important to note that retryablehttp doesn't always act exactly as a +// RoundTripper should. This is highly dependent on the retryable client's +// configuration. +type RoundTripper struct { + // The client to use during requests. If nil, the default retryablehttp + // client and settings will be used. + Client *Client + + // once ensures that the logic to initialize the default client runs at + // most once, in a single thread. + once sync.Once +} + +// init initializes the underlying retryable client. +func (rt *RoundTripper) init() { + if rt.Client == nil { + rt.Client = NewClient() + } +} + +// RoundTrip satisfies the http.RoundTripper interface. +func (rt *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + rt.once.Do(rt.init) + + // Convert the request to be retryable. + retryableReq, err := FromRequest(req) + if err != nil { + return nil, err + } + + // Execute the request. + resp, err := rt.Client.Do(retryableReq) + // If we got an error returned by standard library's `Do` method, unwrap it + // otherwise we will wind up erroneously re-nesting the error. + if _, ok := err.(*url.Error); ok { + return resp, errors.Unwrap(err) + } + + return resp, err +} diff --git a/pkg/retryablehttp/roundtripper_test.go b/pkg/retryablehttp/roundtripper_test.go new file mode 100644 index 000000000..dcb02dfc9 --- /dev/null +++ b/pkg/retryablehttp/roundtripper_test.go @@ -0,0 +1,144 @@ +// Copyright (c) HashiCorp, Inc. 
+// SPDX-License-Identifier: MPL-2.0 + +package retryablehttp + +import ( + "context" + "errors" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "sync/atomic" + "testing" +) + +func TestRoundTripper_implements(t *testing.T) { + // Compile-time proof of interface satisfaction. + var _ http.RoundTripper = &RoundTripper{} +} + +func TestRoundTripper_init(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + // Start with a new empty RoundTripper. + rt := &RoundTripper{} + + // RoundTrip once. + req, _ := http.NewRequest("GET", ts.URL, nil) + if _, err := rt.RoundTrip(req); err != nil { + t.Fatal(err) + } + + // Check that the Client was initialized. + if rt.Client == nil { + t.Fatal("expected rt.Client to be initialized") + } + + // Save the Client for later comparison. + initialClient := rt.Client + + // RoundTrip again. + req, _ = http.NewRequest("GET", ts.URL, nil) + if _, err := rt.RoundTrip(req); err != nil { + t.Fatal(err) + } + + // Check that the underlying Client is unchanged. + if rt.Client != initialClient { + t.Fatalf("expected %v, got %v", initialClient, rt.Client) + } +} + +func TestRoundTripper_RoundTrip(t *testing.T) { + var reqCount int32 = 0 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqNo := atomic.AddInt32(&reqCount, 1) + if reqNo < 3 { + w.WriteHeader(404) + } else { + w.WriteHeader(200) + w.Write([]byte("success!")) + } + })) + defer ts.Close() + + // Make a client with some custom settings to verify they are used. + retryClient := NewClient() + retryClient.CheckRetry = func(_ context.Context, resp *http.Response, _ error) (bool, error) { + return resp.StatusCode == 404, nil + } + + // Get the standard client and execute the request. + client := retryClient.StandardClient() + resp, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // Check the response to ensure the client behaved as expected. + if resp.StatusCode != 200 { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + if v, err := ioutil.ReadAll(resp.Body); err != nil { + t.Fatal(err) + } else if string(v) != "success!" { + t.Fatalf("expected %q, got %q", "success!", v) + } +} + +func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) { + // Make a client with some custom settings to verify they are used. + retryClient := NewClient() + retryClient.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + if err != nil { + return true, err + } + + return false, nil + } + + retryClient.ErrorHandler = PassthroughErrorHandler + + expectedError := &url.Error{ + Op: "Get", + URL: "http://999.999.999.999:999/", + Err: &net.OpError{ + Op: "dial", + Net: "tcp", + Err: &net.DNSError{ + Name: "999.999.999.999", + Err: "no such host", + IsNotFound: true, + }, + }, + } + + // Get the standard client and execute the request. + client := retryClient.StandardClient() + _, err := client.Get("http://999.999.999.999:999/") + + // assert expectations + if !reflect.DeepEqual(expectedError, normalizeError(err)) { + t.Fatalf("expected %q, got %q", expectedError, err) + } +} + +func normalizeError(err error) error { + var dnsError *net.DNSError + + if errors.As(err, &dnsError) { + // this field is populated with the DNS server on on CI, but not locally + dnsError.Server = "" + } + + return err +}
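The vendored client is exercised above through NewClient, StandardClient, and the RoundTripper adapter. For reference, a minimal usage sketch (illustrative only; the import path is assumed here and is not part of the vendored files): a caller tunes the retry knobs on the retryable client and then hands a plain *http.Client to code that only understands net/http.

package main

import (
	"fmt"
	"log"
	"time"

	// Import path assumed for illustration; the package lives under
	// pkg/retryablehttp in this repository.
	retryablehttp "github.com/akamai/terraform-provider-akamai/pkg/retryablehttp"
)

func main() {
	// Build a retrying client using the same knobs the tests above exercise.
	rc := retryablehttp.NewClient()
	rc.RetryMax = 4
	rc.RetryWaitMin = 100 * time.Millisecond
	rc.RetryWaitMax = 2 * time.Second

	// StandardClient wraps the retrying client in an *http.Client whose
	// Transport is the RoundTripper adapter, so it can be passed to any code
	// that only accepts net/http clients.
	std := rc.StandardClient()

	resp, err := std.Get("https://example.com/")
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}

Custom policies plug in the same way: assigning CheckRetry or Backoff on the retryable client before calling StandardClient changes how every request made through the wrapped transport is retried, as the RoundTripper tests above demonstrate.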