diff --git a/pkg/geoipupdate/database/http_reader.go b/pkg/geoipupdate/database/http_reader.go index e2f93428..78677016 100644 --- a/pkg/geoipupdate/database/http_reader.go +++ b/pkg/geoipupdate/database/http_reader.go @@ -14,8 +14,6 @@ import ( "strconv" "time" - "github.com/cenkalti/backoff/v4" - "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/internal" "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/vars" ) @@ -71,39 +69,7 @@ func NewHTTPReader( // It's the responsibility of the Writer to close the io.ReadCloser // included in the response after consumption. func (r *HTTPReader) Read(ctx context.Context, editionID, hash string) (*ReadResult, error) { - var result *ReadResult - var err error - - // RetryFor value of 0 means that no retries should be performed. - // Max zero retries has to be set to achieve that - // because the backoff never stops if MaxElapsedTime is zero. - exp := backoff.NewExponentialBackOff() - exp.MaxElapsedTime = r.retryFor - b := backoff.BackOff(exp) - if exp.MaxElapsedTime == 0 { - b = backoff.WithMaxRetries(exp, 0) - } - err = backoff.RetryNotify( - func() error { - result, err = r.get(ctx, editionID, hash) - if err == nil { - return nil - } - - var httpErr internal.HTTPError - if errors.As(err, &httpErr) && httpErr.StatusCode >= 400 && httpErr.StatusCode < 500 { - return backoff.Permanent(err) - } - - return err - }, - b, - func(err error, d time.Duration) { - if r.verbose { - log.Printf("Couldn't download %s, retrying in %v: %v", editionID, d, err) - } - }, - ) + result, err := r.get(ctx, editionID, hash) if err != nil { return nil, fmt.Errorf("getting update for %s: %w", editionID, err) } diff --git a/pkg/geoipupdate/geoip_updater.go b/pkg/geoipupdate/geoip_updater.go index 8aa9e7d1..e80f03ec 100644 --- a/pkg/geoipupdate/geoip_updater.go +++ b/pkg/geoipupdate/geoip_updater.go @@ -5,7 +5,6 @@ package geoipupdate import ( "context" "encoding/json" - "errors" "fmt" "log" "os" @@ -13,7 +12,6 @@ import ( "time" "github.com/cenkalti/backoff/v4" - "golang.org/x/net/http2" "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/database" "github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/internal" @@ -147,14 +145,16 @@ func (c *Client) downloadEdition( var edition *database.ReadResult err = backoff.RetryNotify( func() error { - edition, err = r.Read(ctx, editionID, editionHash) - if err != nil { + if edition, err = r.Read(ctx, editionID, editionHash); err != nil { + if internal.IsTemporaryError(err) { + return err + } + return backoff.Permanent(err) } if err = w.Write(edition); err != nil { - streamErr := http2.StreamError{} - if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeInternal { + if internal.IsTemporaryError(err) { return err } diff --git a/pkg/geoipupdate/internal/errors.go b/pkg/geoipupdate/internal/errors.go index 9194c85b..8d280412 100644 --- a/pkg/geoipupdate/internal/errors.go +++ b/pkg/geoipupdate/internal/errors.go @@ -2,7 +2,10 @@ package internal import ( + "errors" "fmt" + + "golang.org/x/net/http2" ) // HTTPError is an error from performing an HTTP request. @@ -14,3 +17,19 @@ type HTTPError struct { func (h HTTPError) Error() string { return fmt.Sprintf("received HTTP status code: %d: %s", h.StatusCode, h.Body) } + +// IsTemporaryError returns true if the error is temporary. +func IsTemporaryError(err error) bool { + var httpErr HTTPError + if errors.As(err, &httpErr) { + isPermanent := httpErr.StatusCode >= 400 && httpErr.StatusCode < 500 + return !isPermanent + } + + var streamErr http2.StreamError + if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeInternal { + return true + } + + return false +}