diff --git a/speedtest.go b/speedtest.go index 2cff6f1..ea55cbd 100644 --- a/speedtest.go +++ b/speedtest.go @@ -36,14 +36,14 @@ func main() { showUser(user) } - serverList, err := speedtest.FetchServerList(user) + servers, err := speedtest.FetchServers(user) checkError(err) if *showList { - showServerList(serverList) + showServerList(servers) return } - targets, err := serverList.FindServer(*serverIds) + targets, err := servers.FindServer(*serverIds) checkError(err) startTest(targets, *savingMode, *jsonOutput) @@ -140,8 +140,8 @@ func showUser(user *speedtest.User) { } } -func showServerList(serverList speedtest.ServerList) { - for _, s := range serverList.Servers { +func showServerList(servers speedtest.Servers) { + for _, s := range servers { fmt.Printf("[%4s] %8.2fkm ", s.ID, s.Distance) fmt.Printf(s.Name + " (" + s.Country + ") by " + s.Sponsor + "\n") } diff --git a/speedtest/server.go b/speedtest/server.go index bdfa730..895fe27 100644 --- a/speedtest/server.go +++ b/speedtest/server.go @@ -2,6 +2,7 @@ package speedtest import ( "context" + "encoding/json" "encoding/xml" "errors" "fmt" @@ -12,9 +13,17 @@ import ( "time" ) -const speedTestServersUrl = "https://www.speedtest.net/speedtest-servers-static.php" +const speedTestServersUrl = "https://www.speedtest.net/api/js/servers?engine=js&limit=10" const speedTestServersAlternativeUrl = "https://www.speedtest.net/speedtest-servers-static.php" +type PayloadType int + +const ( + JSONPayload PayloadType = iota + XMLPayload +) + + // Server information type Server struct { URL string `xml:"url,attr" json:"url"` @@ -60,50 +69,68 @@ func (b ByDistance) Less(i, j int) bool { return b.Servers[i].Distance < b.Servers[j].Distance } -// FetchServerList retrieves a list of available servers -func FetchServerList(user *User) (ServerList, error) { +// FetchServers retrieves a list of available servers +func FetchServers(user *User) (Servers, error) { return FetchServerListContext(context.Background(), user) } // FetchServerListContext retrieves a list of available servers, observing the given context. -func FetchServerListContext(ctx context.Context, user *User) (ServerList, error) { +func FetchServerListContext(ctx context.Context, user *User) (Servers, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, speedTestServersUrl, nil) if err != nil { - return ServerList{}, err + return Servers{}, err } resp, err := http.DefaultClient.Do(req) if err != nil { - return ServerList{}, err + return Servers{}, err } + payloadType := JSONPayload + if resp.ContentLength == 0 { resp.Body.Close() req, err = http.NewRequestWithContext(ctx, http.MethodGet, speedTestServersAlternativeUrl, nil) if err != nil { - return ServerList{}, err + return Servers{}, err } resp, err = http.DefaultClient.Do(req) if err != nil { - return ServerList{}, err + return Servers{}, err } + + payloadType = XMLPayload } defer resp.Body.Close() - // Decode xml - decoder := xml.NewDecoder(resp.Body) + var servers Servers - var list ServerList - if err := decoder.Decode(&list); err != nil { - return list, err + switch payloadType { + case JSONPayload: + // Decode xml + decoder := json.NewDecoder(resp.Body) + + if err := decoder.Decode(&servers); err != nil { + return servers, err + } + case XMLPayload: + var list ServerList + // Decode xml + decoder := xml.NewDecoder(resp.Body) + + if err := decoder.Decode(&list); err != nil { + return servers, err + } + servers = list.Servers + default: + return servers, fmt.Errorf("response payload decoding not implemented") } // Calculate distance - for i := range list.Servers { - server := list.Servers[i] + for _, server := range servers { sLat, _ := strconv.ParseFloat(server.Lat, 64) sLon, _ := strconv.ParseFloat(server.Lon, 64) uLat, _ := strconv.ParseFloat(user.Lat, 64) @@ -112,13 +139,13 @@ func FetchServerListContext(ctx context.Context, user *User) (ServerList, error) } // Sort by distance - sort.Sort(ByDistance{list.Servers}) + sort.Sort(ByDistance{servers}) - if len(list.Servers) <= 0 { - return list, errors.New("unable to retrieve server list") + if len(servers) <= 0 { + return servers, errors.New("unable to retrieve server list") } - return list, nil + return servers, nil } func distance(lat1 float64, lon1 float64, lat2 float64, lon2 float64) float64 { @@ -134,15 +161,15 @@ func distance(lat1 float64, lon1 float64, lat2 float64, lon2 float64) float64 { } // FindServer finds server by serverID -func (l *ServerList) FindServer(serverID []int) (Servers, error) { +func (l Servers) FindServer(serverID []int) (Servers, error) { servers := Servers{} - if len(l.Servers) <= 0 { + if len(l) <= 0 { return servers, errors.New("no servers available") } for _, sid := range serverID { - for _, s := range l.Servers { + for _, s := range l { id, _ := strconv.Atoi(s.ID) if sid == id { servers = append(servers, s) @@ -151,7 +178,7 @@ func (l *ServerList) FindServer(serverID []int) (Servers, error) { } if len(servers) == 0 { - servers = append(servers, l.Servers[0]) + servers = append(servers, l[0]) } return servers, nil diff --git a/speedtest/server_test.go b/speedtest/server_test.go index 6537682..adc3e01 100644 --- a/speedtest/server_test.go +++ b/speedtest/server_test.go @@ -9,7 +9,7 @@ func TestFetchServerList(t *testing.T) { Lon: "138.44", Isp: "Hello", } - serverList, err := FetchServerList(&user) + serverList, err := FetchServers(&user) if err != nil { t.Errorf(err.Error()) }