diff --git a/speedtest.go b/speedtest.go index 8d7b9fe..9d717f4 100644 --- a/speedtest.go +++ b/speedtest.go @@ -28,7 +28,7 @@ type fullOutput struct { type outputTime time.Time func main() { - kingpin.Version("1.2.1") + kingpin.Version(speedtest.Version()) kingpin.Parse() user, err := speedtest.FetchUserInfo() diff --git a/speedtest/server.go b/speedtest/server.go index 2687966..5701788 100644 --- a/speedtest/server.go +++ b/speedtest/server.go @@ -88,8 +88,6 @@ func (client *Speedtest) FetchServerListContext(ctx context.Context, user *User) return Servers{}, err } - req.Header.Set("User-Agent", "Go-http-client/1.1") // Request could be rejected if not initialized - resp, err := client.doer.Do(req) if err != nil { return Servers{}, err @@ -105,8 +103,6 @@ func (client *Speedtest) FetchServerListContext(ctx context.Context, user *User) return Servers{}, err } - req.Header.Set("User-Agent", "Go-http-client/1.1") // Request could be rejected if not initialized - resp, err = client.doer.Do(req) if err != nil { return Servers{}, err diff --git a/speedtest/speedtest.go b/speedtest/speedtest.go index fe923ba..d4b2d74 100644 --- a/speedtest/speedtest.go +++ b/speedtest/speedtest.go @@ -1,12 +1,37 @@ package speedtest -import "net/http" +import ( + "fmt" + "net/http" +) + +var ( + version = "1.2.1" + defaultUserAgent = fmt.Sprintf("showwin/speedtest-go %s", version) +) // Speedtest is a speedtest client. type Speedtest struct { doer *http.Client } +type userAgentTransport struct { + T http.RoundTripper + UserAgent string +} + +func newUserAgentTransport(T http.RoundTripper, UserAgent string) *userAgentTransport { + if T == nil { + T = http.DefaultTransport + } + return &userAgentTransport{T, UserAgent} +} + +func (uat *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Add("User-Agent", uat.UserAgent) + return uat.T.RoundTrip(req) +} + // Option is a function that can be passed to New to modify the Client. type Option func(*Speedtest) @@ -17,11 +42,21 @@ func WithDoer(doer *http.Client) Option { } } +// WithUserAgent adds the passed "User-Agent" header to all requests. +// To use with a custom Doer, "WithDoer" must be passed before WithUserAgent: +// `New(WithDoer(myDoer), WithUserAgent(myUserAgent))` +func WithUserAgent(UserAgent string) Option { + return func(s *Speedtest) { + s.doer.Transport = newUserAgentTransport(s.doer.Transport, UserAgent) + } +} + // New creates a new speedtest client. func New(opts ...Option) *Speedtest { s := &Speedtest{ doer: http.DefaultClient, } + WithUserAgent(defaultUserAgent)(s) for _, opt := range opts { opt(s) @@ -30,4 +65,8 @@ func New(opts ...Option) *Speedtest { return s } +func Version() string { + return version +} + var defaultClient = New() diff --git a/speedtest/speedtest_test.go b/speedtest/speedtest_test.go index 6d502d3..1e1b151 100644 --- a/speedtest/speedtest_test.go +++ b/speedtest/speedtest_test.go @@ -2,6 +2,7 @@ package speedtest import ( "net/http" + "net/http/httptest" "testing" ) @@ -22,5 +23,41 @@ func TestNew(t *testing.T) { t.Error("doer is not the same") } }) +} + +func TestUserAgent(t *testing.T) { + testServer := func(expectedUserAgent string) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.UserAgent() == "" { + t.Error("Did not receive User-Agent header") + } else if r.UserAgent() != expectedUserAgent { + t.Errorf("Incorrect User-Agent header: %s, expected: %s", r.UserAgent(), expectedUserAgent) + } + })) + } + + t.Run("DefaultUserAgent", func(t *testing.T) { + c := New() + s := testServer(defaultUserAgent) + c.doer.Get(s.URL) + }) + + t.Run("CustomUserAgent", func(t *testing.T) { + testAgent := "asdf1234" + s := testServer(testAgent) + c := New(WithUserAgent(testAgent)) + c.doer.Get(s.URL) + }) + // Test that With + t.Run("CustomUserAgentAndDoer", func(t *testing.T) { + testAgent := "asdf2345" + doer := &http.Client{} + s := testServer(testAgent) + c := New(WithDoer(doer), WithUserAgent(testAgent)) + if c.doer != doer { + t.Error("doer is not the same") + } + c.doer.Get(s.URL) + }) } diff --git a/speedtest/user.go b/speedtest/user.go index 46d8919..2bcc79e 100644 --- a/speedtest/user.go +++ b/speedtest/user.go @@ -45,8 +45,6 @@ func (client *Speedtest) FetchUserInfoContext(ctx context.Context) (*User, error return nil, err } - req.Header.Set("User-Agent", "Go-http-client/1.1") // Request could be rejected if not initialized - resp, err := client.doer.Do(req) if err != nil { return nil, err