Skip to content

Commit

Permalink
Merge pull request #84 from rtrox/rtrox/header-fix
Browse files Browse the repository at this point in the history
set custom user-agent on all requests
  • Loading branch information
showwin authored Nov 28, 2022
2 parents 6baf9c5 + 5c0ffd2 commit 6b57ad2
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 8 deletions.
2 changes: 1 addition & 1 deletion speedtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type fullOutput struct {
type outputTime time.Time

func main() {
kingpin.Version("1.2.1")
kingpin.Version(speedtest.Version())
kingpin.Parse()

user, err := speedtest.FetchUserInfo()
Expand Down
4 changes: 0 additions & 4 deletions speedtest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ func (client *Speedtest) FetchServerListContext(ctx context.Context, user *User)
return Servers{}, err
}

req.Header.Set("User-Agent", "Go-http-client/1.1") // Request could be rejected if not initialized

resp, err := client.doer.Do(req)
if err != nil {
return Servers{}, err
Expand All @@ -105,8 +103,6 @@ func (client *Speedtest) FetchServerListContext(ctx context.Context, user *User)
return Servers{}, err
}

req.Header.Set("User-Agent", "Go-http-client/1.1") // Request could be rejected if not initialized

resp, err = client.doer.Do(req)
if err != nil {
return Servers{}, err
Expand Down
41 changes: 40 additions & 1 deletion speedtest/speedtest.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,37 @@
package speedtest

import "net/http"
import (
"fmt"
"net/http"
)

var (
version = "1.2.1"
defaultUserAgent = fmt.Sprintf("showwin/speedtest-go %s", version)
)

// Speedtest is a speedtest client.
type Speedtest struct {
doer *http.Client
}

type userAgentTransport struct {
T http.RoundTripper
UserAgent string
}

func newUserAgentTransport(T http.RoundTripper, UserAgent string) *userAgentTransport {
if T == nil {
T = http.DefaultTransport
}
return &userAgentTransport{T, UserAgent}
}

func (uat *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Add("User-Agent", uat.UserAgent)
return uat.T.RoundTrip(req)
}

// Option is a function that can be passed to New to modify the Client.
type Option func(*Speedtest)

Expand All @@ -17,11 +42,21 @@ func WithDoer(doer *http.Client) Option {
}
}

// WithUserAgent adds the passed "User-Agent" header to all requests.
// To use with a custom Doer, "WithDoer" must be passed before WithUserAgent:
// `New(WithDoer(myDoer), WithUserAgent(myUserAgent))`
func WithUserAgent(UserAgent string) Option {
return func(s *Speedtest) {
s.doer.Transport = newUserAgentTransport(s.doer.Transport, UserAgent)
}
}

// New creates a new speedtest client.
func New(opts ...Option) *Speedtest {
s := &Speedtest{
doer: http.DefaultClient,
}
WithUserAgent(defaultUserAgent)(s)

for _, opt := range opts {
opt(s)
Expand All @@ -30,4 +65,8 @@ func New(opts ...Option) *Speedtest {
return s
}

func Version() string {
return version
}

var defaultClient = New()
37 changes: 37 additions & 0 deletions speedtest/speedtest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package speedtest

import (
"net/http"
"net/http/httptest"
"testing"
)

Expand All @@ -22,5 +23,41 @@ func TestNew(t *testing.T) {
t.Error("doer is not the same")
}
})
}

func TestUserAgent(t *testing.T) {
testServer := func(expectedUserAgent string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.UserAgent() == "" {
t.Error("Did not receive User-Agent header")
} else if r.UserAgent() != expectedUserAgent {
t.Errorf("Incorrect User-Agent header: %s, expected: %s", r.UserAgent(), expectedUserAgent)
}
}))
}

t.Run("DefaultUserAgent", func(t *testing.T) {
c := New()
s := testServer(defaultUserAgent)
c.doer.Get(s.URL)
})

t.Run("CustomUserAgent", func(t *testing.T) {
testAgent := "asdf1234"
s := testServer(testAgent)
c := New(WithUserAgent(testAgent))
c.doer.Get(s.URL)
})

// Test that With
t.Run("CustomUserAgentAndDoer", func(t *testing.T) {
testAgent := "asdf2345"
doer := &http.Client{}
s := testServer(testAgent)
c := New(WithDoer(doer), WithUserAgent(testAgent))
if c.doer != doer {
t.Error("doer is not the same")
}
c.doer.Get(s.URL)
})
}
2 changes: 0 additions & 2 deletions speedtest/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ func (client *Speedtest) FetchUserInfoContext(ctx context.Context) (*User, error
return nil, err
}

req.Header.Set("User-Agent", "Go-http-client/1.1") // Request could be rejected if not initialized

resp, err := client.doer.Do(req)
if err != nil {
return nil, err
Expand Down

0 comments on commit 6b57ad2

Please sign in to comment.