diff --git a/speedtest/request.go b/speedtest/request.go index 77ec4cc..25de202 100644 --- a/speedtest/request.go +++ b/speedtest/request.go @@ -271,7 +271,7 @@ func (s *Server) PingTestContext(ctx context.Context) error { return err } - resp, err := http.DefaultClient.Do(req) + resp, err := s.doer.Do(req) if err != nil { return err } diff --git a/speedtest/server.go b/speedtest/server.go index bdfa730..55ae2f6 100644 --- a/speedtest/server.go +++ b/speedtest/server.go @@ -30,6 +30,8 @@ type Server struct { Latency time.Duration `json:"latency"` DLSpeed float64 `json:"dl_speed"` ULSpeed float64 `json:"ul_speed"` + + doer *http.Client } // ServerList list of Server @@ -61,18 +63,23 @@ func (b ByDistance) Less(i, j int) bool { } // FetchServerList retrieves a list of available servers -func FetchServerList(user *User) (ServerList, error) { +func (client *Speedtest) FetchServerList(user *User) (ServerList, error) { return FetchServerListContext(context.Background(), user) } +// FetchServerList retrieves a list of available servers +func FetchServerList(user *User) (ServerList, error) { + return defaultClient.FetchServerList(user) +} + // FetchServerListContext retrieves a list of available servers, observing the given context. -func FetchServerListContext(ctx context.Context, user *User) (ServerList, error) { +func (client *Speedtest) FetchServerListContext(ctx context.Context, user *User) (ServerList, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, speedTestServersUrl, nil) if err != nil { return ServerList{}, err } - resp, err := http.DefaultClient.Do(req) + resp, err := client.doer.Do(req) if err != nil { return ServerList{}, err } @@ -85,7 +92,7 @@ func FetchServerListContext(ctx context.Context, user *User) (ServerList, error) return ServerList{}, err } - resp, err = http.DefaultClient.Do(req) + resp, err = client.doer.Do(req) if err != nil { return ServerList{}, err } @@ -101,6 +108,11 @@ func FetchServerListContext(ctx context.Context, user *User) (ServerList, error) return list, err } + // set doer of server + for _, s := range list.Servers { + s.doer = client.doer + } + // Calculate distance for i := range list.Servers { server := list.Servers[i] @@ -121,6 +133,11 @@ func FetchServerListContext(ctx context.Context, user *User) (ServerList, error) return list, nil } +// FetchServerListContext retrieves a list of available servers, observing the given context. +func FetchServerListContext(ctx context.Context, user *User) (ServerList, error) { + return defaultClient.FetchServerListContext(ctx, user) +} + func distance(lat1 float64, lon1 float64, lat2 float64, lon2 float64) float64 { radius := 6378.137 diff --git a/speedtest/server_test.go b/speedtest/server_test.go index 6537682..bd0af95 100644 --- a/speedtest/server_test.go +++ b/speedtest/server_test.go @@ -9,7 +9,10 @@ func TestFetchServerList(t *testing.T) { Lon: "138.44", Isp: "Hello", } - serverList, err := FetchServerList(&user) + + client := New() + + serverList, err := client.FetchServerList(&user) if err != nil { t.Errorf(err.Error()) } diff --git a/speedtest/speedtest.go b/speedtest/speedtest.go new file mode 100644 index 0000000..fe923ba --- /dev/null +++ b/speedtest/speedtest.go @@ -0,0 +1,33 @@ +package speedtest + +import "net/http" + +// Speedtest is a speedtest client. +type Speedtest struct { + doer *http.Client +} + +// Option is a function that can be passed to New to modify the Client. +type Option func(*Speedtest) + +// WithDoer sets the http.Client used to make requests. +func WithDoer(doer *http.Client) Option { + return func(s *Speedtest) { + s.doer = doer + } +} + +// New creates a new speedtest client. +func New(opts ...Option) *Speedtest { + s := &Speedtest{ + doer: http.DefaultClient, + } + + for _, opt := range opts { + opt(s) + } + + return s +} + +var defaultClient = New() diff --git a/speedtest/speedtest_test.go b/speedtest/speedtest_test.go new file mode 100644 index 0000000..6d502d3 --- /dev/null +++ b/speedtest/speedtest_test.go @@ -0,0 +1,26 @@ +package speedtest + +import ( + "net/http" + "testing" +) + +func TestNew(t *testing.T) { + t.Run("DefaultDoer", func(t *testing.T) { + c := New() + + if c.doer == nil { + t.Error("doer is nil by") + } + }) + + t.Run("CustomDoer", func(t *testing.T) { + doer := &http.Client{} + + c := New(WithDoer(doer)) + if c.doer != doer { + t.Error("doer is not the same") + } + }) + +} diff --git a/speedtest/user.go b/speedtest/user.go index fe75501..4b4c6db 100644 --- a/speedtest/user.go +++ b/speedtest/user.go @@ -24,18 +24,23 @@ type Users struct { } // FetchUserInfo returns information about caller determined by speedtest.net -func FetchUserInfo() (*User, error) { +func (client *Speedtest) FetchUserInfo() (*User, error) { return FetchUserInfoContext(context.Background()) } +// FetchUserInfo returns information about caller determined by speedtest.net +func FetchUserInfo() (*User, error) { + return defaultClient.FetchUserInfo() +} + // FetchUserInfoContext returns information about caller determined by speedtest.net, observing the given context. -func FetchUserInfoContext(ctx context.Context) (*User, error) { +func (client *Speedtest) FetchUserInfoContext(ctx context.Context) (*User, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, speedTestConfigUrl, nil) if err != nil { return nil, err } - resp, err := http.DefaultClient.Do(req) + resp, err := client.doer.Do(req) if err != nil { return nil, err } @@ -57,6 +62,11 @@ func FetchUserInfoContext(ctx context.Context) (*User, error) { return &users.Users[0], nil } +// FetchUserInfoContext returns information about caller determined by speedtest.net, observing the given context. +func FetchUserInfoContext(ctx context.Context) (*User, error) { + return defaultClient.FetchUserInfoContext(ctx) +} + // String representation of User func (u *User) String() string { return fmt.Sprintf("%s, (%s) [%s, %s]", u.IP, u.Isp, u.Lat, u.Lon) diff --git a/speedtest/user_test.go b/speedtest/user_test.go index 8dc8829..be534ad 100644 --- a/speedtest/user_test.go +++ b/speedtest/user_test.go @@ -7,7 +7,9 @@ import ( ) func TestFetchUserInfo(t *testing.T) { - user, err := FetchUserInfo() + client := New() + + user, err := client.FetchUserInfo() if err != nil { t.Errorf(err.Error()) }