diff --git a/speedtest.go b/speedtest.go index 10a1b79..cad3ef9 100644 --- a/speedtest.go +++ b/speedtest.go @@ -11,6 +11,7 @@ import ( "os" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -149,9 +150,12 @@ func main() { SourceInterface: *source, }) + blocker := sync.WaitGroup{} packetLossAnalyzerCtx, packetLossAnalyzerCancel := context.WithTimeout(context.Background(), time.Second*40) taskManager.Run("Packet Loss Analyzer", func(task *Task) { + blocker.Add(1) go func() { + defer blocker.Done() err = analyzer.RunWithContext(packetLossAnalyzerCtx, server.Host, func(packetLoss *transport.PLoss) { server.PacketLoss = *packetLoss }) @@ -211,6 +215,7 @@ func main() { time.Sleep(time.Second * 30) } packetLossAnalyzerCancel() + blocker.Wait() if !*jsonOutput { taskManager.Println(server.PacketLoss.String()) } diff --git a/speedtest/data_manager.go b/speedtest/data_manager.go index 04546ac..7ae2a48 100644 --- a/speedtest/data_manager.go +++ b/speedtest/data_manager.go @@ -88,7 +88,8 @@ type DataManager struct { rateCaptureFrequency time.Duration nThread int - running bool + running bool + runningRW sync.RWMutex download *TestDirection upload *TestDirection @@ -114,11 +115,13 @@ func (dm *DataManager) NewDataDirection(testType int) *TestDirection { } func NewDataManager() *DataManager { + r := bytes.Repeat([]byte{0xAA}, readChunkSize) // uniformly distributed sequence of bits ret := &DataManager{ nThread: runtime.NumCPU(), captureTime: time.Second * 15, rateCaptureFrequency: time.Millisecond * 50, Snapshot: &Snapshot{}, + repeatByte: &r, } ret.download = ret.NewDataDirection(typeDownload) ret.upload = ret.NewDataDirection(typeUpload) @@ -169,6 +172,14 @@ func (dm *DataManager) RegisterDownloadHandler(fn func()) *TestDirection { return dm.download } +func (td *TestDirection) GetTotalDataVolume() int64 { + return atomic.LoadInt64(&td.totalDataVolume) +} + +func (td *TestDirection) AddTotalDataVolume(delta int64) int64 { + return atomic.AddInt64(&td.totalDataVolume, delta) +} + func (td *TestDirection) Start(cancel context.CancelFunc, mainRequestHandlerIndex int) { if len(td.fns) == 0 { panic("empty task stack") @@ -200,7 +211,9 @@ func (td *TestDirection) Start(cancel context.CancelFunc, mainRequestHandlerInde once.Do(func() { stopCapture <- true close(stopCapture) + td.manager.runningRW.Lock() td.manager.running = false + td.manager.runningRW.Unlock() cancel() dbg.Println("FuncGroup: Stop") }) @@ -212,7 +225,10 @@ func (td *TestDirection) Start(cancel context.CancelFunc, mainRequestHandlerInde go func() { defer wg.Done() for { - if !td.manager.running { + td.manager.runningRW.RLock() + running := td.manager.running + td.manager.runningRW.RUnlock() + if !running { return } td.fns[mainRequestHandlerIndex]() @@ -232,7 +248,10 @@ func (td *TestDirection) Start(cancel context.CancelFunc, mainRequestHandlerInde go func() { defer wg.Done() for { - if !td.manager.running { + td.manager.runningRW.RLock() + running := td.manager.running + td.manager.runningRW.RUnlock() + if !running { return } td.fns[t]() @@ -255,14 +274,14 @@ func (td *TestDirection) rateCapture() chan bool { for { select { case <-t.C: - newTotalDataVolume := td.totalDataVolume + newTotalDataVolume := td.GetTotalDataVolume() deltaDataVolume := newTotalDataVolume - prevTotalDataVolume prevTotalDataVolume = newTotalDataVolume if deltaDataVolume != 0 { td.RateSequence = append(td.RateSequence, deltaDataVolume) } // anyway we update the measuring instrument - globalAvg := (float64(td.totalDataVolume)) / float64(time.Since(sTime).Milliseconds()) * 1000 + globalAvg := (float64(td.GetTotalDataVolume())) / float64(time.Since(sTime).Milliseconds()) * 1000 if td.welford.Update(globalAvg, float64(deltaDataVolume)) { go td.closeFunc() } @@ -290,19 +309,19 @@ func (dm *DataManager) NewChunk() Chunk { } func (dm *DataManager) AddTotalDownload(value int64) { - atomic.AddInt64(&dm.download.totalDataVolume, value) + dm.download.AddTotalDataVolume(value) } func (dm *DataManager) AddTotalUpload(value int64) { - atomic.AddInt64(&dm.upload.totalDataVolume, value) + dm.upload.AddTotalDataVolume(value) } func (dm *DataManager) GetTotalDownload() int64 { - return dm.download.totalDataVolume + return dm.download.GetTotalDataVolume() } func (dm *DataManager) GetTotalUpload() int64 { - return dm.upload.totalDataVolume + return dm.upload.GetTotalDataVolume() } func (dm *DataManager) SetRateCaptureFrequency(duration time.Duration) Manager { @@ -337,7 +356,7 @@ func (dm *DataManager) Reset() { func (dm *DataManager) GetAvgDownloadRate() float64 { unit := float64(dm.captureTime / time.Millisecond) - return float64(dm.download.totalDataVolume*8/1000) / unit + return float64(dm.download.GetTotalDataVolume()*8/1000) / unit } func (dm *DataManager) GetEWMADownloadRate() float64 { @@ -349,7 +368,7 @@ func (dm *DataManager) GetEWMADownloadRate() float64 { func (dm *DataManager) GetAvgUploadRate() float64 { unit := float64(dm.captureTime / time.Millisecond) - return float64(dm.upload.totalDataVolume*8/1000) / unit + return float64(dm.upload.GetTotalDataVolume()*8/1000) / unit } func (dm *DataManager) GetEWMAUploadRate() float64 { @@ -405,14 +424,17 @@ func (dc *DataChunk) DownloadHandler(r io.Reader) error { defer blackHolePool.Put(bufP) readSize := 0 for { - if !dc.manager.running { + dc.manager.runningRW.RLock() + running := dc.manager.running + dc.manager.runningRW.RUnlock() + if !running { return nil } readSize, dc.err = r.Read(*bufP) rs := int64(readSize) dc.remainOrDiscardSize += rs - atomic.AddInt64(&dc.manager.download.totalDataVolume, rs) + dc.manager.download.AddTotalDataVolume(rs) if dc.err != nil { if dc.err == io.EOF { return nil @@ -434,12 +456,6 @@ func (dc *DataChunk) UploadHandler(size int64) Chunk { dc.ContentLength = size dc.remainOrDiscardSize = size dc.dateType = typeUpload - - if dc.manager.repeatByte == nil { - r := bytes.Repeat([]byte{0xAA}, readChunkSize) // uniformly distributed sequence of bits - dc.manager.repeatByte = &r - } - dc.startTime = time.Now() return dc } @@ -453,7 +469,10 @@ func (dc *DataChunk) WriteTo(w io.Writer) (written int64, err error) { nw := 0 nr := readChunkSize for { - if !dc.manager.running || dc.remainOrDiscardSize <= 0 { + dc.manager.runningRW.RLock() + running := dc.manager.running + dc.manager.runningRW.RUnlock() + if !running || dc.remainOrDiscardSize <= 0 { dc.endTime = time.Now() return written, io.EOF } diff --git a/speedtest/data_manager_test.go b/speedtest/data_manager_test.go index 6f59db6..dbd34d3 100644 --- a/speedtest/data_manager_test.go +++ b/speedtest/data_manager_test.go @@ -35,7 +35,7 @@ func TestDataManager_AddTotalDownload(t *testing.T) { }() } wg.Wait() - if dmp.download.totalDataVolume != 43521000000 { + if dmp.download.GetTotalDataVolume() != 43521000000 { t.Fatal() } } diff --git a/speedtest/server.go b/speedtest/server.go index 19a3c50..7ab64bb 100644 --- a/speedtest/server.go +++ b/speedtest/server.go @@ -291,14 +291,15 @@ func (s *Speedtest) FetchServerListContext(ctx context.Context) (Servers, error) wg.Add(1) go func(gs *Server) { var latency []int64 + var errPing error if s.config.PingMode == TCP { - latency, err = gs.TCPPing(pCtx, 1, time.Millisecond, nil) + latency, errPing = gs.TCPPing(pCtx, 1, time.Millisecond, nil) } else if s.config.PingMode == ICMP { - latency, err = gs.ICMPPing(pCtx, 4*time.Second, 1, time.Millisecond, nil) + latency, errPing = gs.ICMPPing(pCtx, 4*time.Second, 1, time.Millisecond, nil) } else { - latency, err = gs.HTTPPing(pCtx, 1, time.Millisecond, nil) + latency, errPing = gs.HTTPPing(pCtx, 1, time.Millisecond, nil) } - if err != nil || len(latency) < 1 { + if errPing != nil || len(latency) < 1 { gs.Latency = PingTimeout } else { gs.Latency = time.Duration(latency[0]) * time.Nanosecond