Skip to content

Commit

Permalink
Add HTTP2 INTERNAL_ERROR test
Browse files Browse the repository at this point in the history
  • Loading branch information
marselester committed Jan 25, 2024
1 parent 80f9b55 commit 55ef27a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pkg/geoipupdate/geoip_updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func (c *Client) downloadEdition(

if err = w.Write(edition); err != nil {
streamErr := http2.StreamError{}
if errors.As(err, &streamErr) && streamErr.Code.String() == "INTERNAL_ERROR" {
if errors.As(err, &streamErr) && streamErr.Code == http2.ErrCodeInternal {
return err
}

Expand Down
33 changes: 29 additions & 4 deletions pkg/geoipupdate/geoip_updater_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (
"testing"
"time"

"github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/database"
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"

"github.com/maxmind/geoipupdate/v6/pkg/geoipupdate/database"
)

// TestClientOutput makes sure that the client outputs the result of it's
Expand Down Expand Up @@ -77,6 +79,21 @@ func TestClientOutput(t *testing.T) {
t.Errorf("database %s was not updated", outputDatabases[i].EditionID)
}
}

streamErr := http2.StreamError{
Code: http2.ErrCodeInternal,
}
c.getWriter = func() (database.Writer, error) {
w := mockWriter{
WriteFunc: func(_ *database.ReadResult) error {
return streamErr
},
}

return &w, nil
}
err = c.Run(context.Background())
require.ErrorIs(t, err, streamErr)
}

type mockReader struct {
Expand All @@ -93,10 +110,18 @@ func (mr *mockReader) Read(_ context.Context, _, _ string) (*database.ReadResult
return &res, nil
}

type mockWriter struct{}
type mockWriter struct {
WriteFunc func(*database.ReadResult) error
}

func (w *mockWriter) Write(r *database.ReadResult) error {
if w.WriteFunc != nil {
return w.WriteFunc(r)
}

func (w *mockWriter) Write(_ *database.ReadResult) error { return nil }
func (w mockWriter) GetHash(_ string) (string, error) { return "", nil }
return nil
}
func (w mockWriter) GetHash(_ string) (string, error) { return "", nil }

func afterOrEqual(t1, t2 time.Time) bool {
return t1.After(t2) || t1.Equal(t2)
Expand Down

0 comments on commit 55ef27a

Please sign in to comment.