From b4f3e1110ff0ac274f0d19821b72fa2d04d1ddf4 Mon Sep 17 00:00:00 2001
From: Jarno Rajahalme
Date: Thu, 25 Apr 2024 18:19:59 +0200
Subject: [PATCH] shared_client: Clean up waiters after timeouts

Tell handler to delete waiters after request times out.

Signed-off-by: Jarno Rajahalme
---
 shared_client.go      | 216 ++++++++++++++++++++++++++++++++++--------
 shared_client_test.go | 134 ++++++++++++++++++++++++++
 2 files changed, 308 insertions(+), 42 deletions(-)

diff --git a/shared_client.go b/shared_client.go
index d6fe38a5e..98d498b7f 100644
--- a/shared_client.go
+++ b/shared_client.go
@@ -4,6 +4,7 @@
 package dns
 
 import (
+	"container/heap"
 	"context"
 	"errors"
 	"fmt"
@@ -138,6 +139,161 @@ func (c *SharedClient) ExchangeShared(m *Msg) (r *Msg, rtt time.Duration, err er
 	return c.ExchangeSharedContext(context.Background(), m)
 }
 
+// waiter is a single waiter in the deadline heap
+type waiter struct {
+	id       uint16
+	index    int32
+	start    time.Time
+	deadline time.Time
+	ch       chan sharedClientResponse
+}
+
+// waiterQueue is a collection of waiters, implements container/heap
+// so that the waiter with the closest deadline is on the top of the heap.
+// Map is used as an index on the DNS request ID for quick access.
+type waiterQueue struct {
+	waiters []*waiter
+	index   map[uint16]*waiter
+}
+
+// Len from heap.Interface
+func (wq waiterQueue) Len() int { return len(wq.waiters) }
+
+// Less from heap.Interface
+func (wq waiterQueue) Less(i, j int) bool {
+	return wq.waiters[j].deadline.After(wq.waiters[i].deadline)
+}
+
+// Swap from heap.Interface
+func (wq waiterQueue) Swap(i, j int) {
+	wq.waiters[i], wq.waiters[j] = wq.waiters[j], wq.waiters[i]
+	wq.waiters[i].index = int32(i)
+	wq.waiters[j].index = int32(j)
+}
+
+// Push from heap.Interface
+func (wq *waiterQueue) Push(x any) {
+	wtr := x.(*waiter)
+	wtr.index = int32(len(wq.waiters))
+	wq.waiters = append(wq.waiters, wtr)
+	wq.index[wtr.id] = wtr
+}
+
+// Pop from heap.Interface
+func (wq *waiterQueue) Pop() any {
+	old := wq.waiters
+	n := len(old)
+	wtr := old[n-1]
+	old[n-1] = nil // avoid memory leak
+	wtr.index = -1 // for safety
+	wq.waiters = old[:n-1]
+
+	delete(wq.index, wtr.id)
+	return wtr
+}
+
+// newWaiterQueue returns a new waiterQueue
+func newWaiterQueue() *waiterQueue {
+	wq := waiterQueue{
+		waiters: make([]*waiter, 0, 1),
+		index:   make(map[uint16]*waiter),
+	}
+	heap.Init(&wq)
+	return &wq
+}
+
+func (wq *waiterQueue) Insert(id uint16, start, deadline time.Time, ch chan sharedClientResponse) {
+	if wq.Exists(id) {
+		panic("waiterQueue: entry exists")
+	}
+
+	heap.Push(wq, &waiter{id, -1, start, deadline, ch})
+}
+
+func (wq *waiterQueue) Exists(id uint16) bool {
+	_, exists := wq.index[id]
+	return exists
+}
+
+func (wq *waiterQueue) FailAll(err error) {
+	for wq.Len() > 0 {
+		wtr := heap.Pop(wq).(*waiter)
+		wtr.ch <- sharedClientResponse{nil, 0, err}
+		close(wtr.ch)
+	}
+}
+
+func (wq *waiterQueue) Dequeue(id uint16) *waiter {
+	wtr, ok := wq.index[id]
+	if !ok {
+		return nil // not found
+	}
+	if wtr.id != id {
+		panic(fmt.Sprintf("waiterQueue: invalid id %d != %d", wtr.id, id))
+	}
+	if wtr.index < 0 {
+		panic(fmt.Sprintf("waiterQueue: invalid index: %d", wtr.index))
+	}
+
+	heap.Remove(wq, int(wtr.index))
+
+	return wtr
+}
+
+// GetTimeout returns the time from now until the earliest deadline
+func (wq *waiterQueue) GetTimeout() time.Duration {
+	// return 10 minutes if there are no waiters
+	if wq.Len() == 0 {
+		return 10 * time.Minute
+	}
+	return time.Until(wq.waiters[0].deadline)
+}
+
+// errTimeout is an error representing a request timeout.
+// Implements net.Error
+type errTimeout struct {
+}
+
+func (e errTimeout) Timeout() bool { return true }
+
+// Temporary is deprecated. Return false.
+func (e errTimeout) Temporary() bool { return false }
+
+func (e errTimeout) Error() string {
+	return "request timed out"
+}
+
+var netErrorTimeout errTimeout
+
+// Expired sends a timeout response to all timed out waiters
+func (wq *waiterQueue) Expired() {
+	now := time.Now()
+	for wq.Len() > 0 {
+		if wq.waiters[0].deadline.After(now) {
+			break
+		}
+		wtr := heap.Pop(wq).(*waiter)
+		wtr.ch <- sharedClientResponse{nil, 0, netErrorTimeout}
+		close(wtr.ch)
+	}
+}
+
+// Respond passes the DNS response to the waiter
+func (wq *waiterQueue) Respond(resp sharedClientResponse) {
+	if resp.err != nil {
+		// ReadMsg failed, but we cannot match it to a request,
+		// so complete all pending requests.
+		wq.FailAll(resp.err)
+	} else if resp.msg != nil {
+		wtr := wq.Dequeue(resp.msg.Id)
+		if wtr != nil {
+			resp.rtt = time.Since(wtr.start)
+			wtr.ch <- resp
+			close(wtr.ch)
+		}
+	}
+}
+
 // handler is started when the connection is dialed
 func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan request) {
 	defer wg.Done()
@@ -197,11 +353,7 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
 		}
 	}()
 
-	type waiter struct {
-		ch    chan sharedClientResponse
-		start time.Time
-	}
-	waitingResponses := make(map[uint16]waiter)
+	waiters := newWaiterQueue()
 	defer func() {
 		conn.Close()
 		close(receiverTrigger)
@@ -214,10 +366,7 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
 		}
 
 		// Now cancel all the remaining waiters
-		for _, waiter := range waitingResponses {
-			waiter.ch <- sharedClientResponse{nil, 0, net.ErrClosed}
-			close(waiter.ch)
-		}
+		waiters.FailAll(net.ErrClosed)
 
 		// Drain requests in case they come in while we are closing
 		// down. This loop is done only after 'requests' channel is closed in
@@ -229,8 +378,14 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
 		}
 	}()
 
+	deadline := time.NewTimer(0)
 	for {
+		// update timer
+		deadline.Reset(waiters.GetTimeout())
 		select {
+		case <-deadline.C:
+			waiters.Expired()
+
 		case req, ok := <-requests:
 			if !ok {
 				// 'requests' is closed when SharedClient is recycled, which happens
@@ -243,10 +398,11 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
 			// Due to birthday paradox and the fact that ID is uint16
 			// it's likely to happen with small number (~200) of concurrent requests
 			// which would result in goroutine leak as we would never close req.ch
-			if _, duplicate := waitingResponses[req.msg.Id]; duplicate {
+			if waiters.Exists(req.msg.Id) {
 				// find next available ID
+				duplicate := true
 				for id := req.msg.Id + 1; id != req.msg.Id; id++ {
-					if _, duplicate = waitingResponses[id]; !duplicate {
+					if duplicate = waiters.Exists(id); !duplicate {
 						req.msg.Id = id
 						break
 					}
@@ -264,7 +420,8 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
 				req.ch <- sharedClientResponse{nil, 0, err}
 				close(req.ch)
 			} else {
-				waitingResponses[req.msg.Id] = waiter{req.ch, start}
+				deadline := time.Now().Add(client.getTimeoutForRequest(client.readTimeout()))
+				waiters.Insert(req.msg.Id, start, deadline, req.ch)
 
 				// Wake up the receiver that may be waiting to receive again
 				triggerReceiver()
@@ -276,22 +433,7 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
 				// nothing can be received any more
 				return
 			}
-			if resp.err != nil {
-				// ReadMsg failed, but we cannot match it to a request,
-				// so complete all pending requests.
-				for _, waiter := range waitingResponses {
-					waiter.ch <- sharedClientResponse{nil, 0, resp.err}
-					close(waiter.ch)
-				}
-				waitingResponses = make(map[uint16]waiter)
-			} else if resp.msg != nil {
-				if waiter, ok := waitingResponses[resp.msg.Id]; ok {
-					delete(waitingResponses, resp.msg.Id)
-					resp.rtt = time.Since(waiter.start)
-					waiter.ch <- resp
-					close(waiter.ch)
-				}
-			}
+			waiters.Respond(resp)
 		}
 	}
 }
@@ -325,20 +467,10 @@ func (c *SharedClient) ExchangeSharedContext(ctx context.Context, m *Msg) (r *Ms
 		return nil, 0, ctx.Err()
 	}
 
-	// Since c.requests is unbuffered, the handler is guaranteed to eventually close 'respCh',
-	// unless the response never arrives.
-	// Quit when ctx is done so that each request times out individually in case there is no
-	// response at all.
-	readTimeout := c.getTimeoutForRequest(c.Client.readTimeout())
-	ctx2, cancel2 := context.WithTimeout(ctx, readTimeout)
-	defer cancel2()
-	select {
-	case resp := <-respCh:
-		return resp.msg, resp.rtt, resp.err
-	case <-ctx2.Done():
-		// TODO: handler now has a stale waiter that remains if the response never arrives.
-		return nil, 0, ctx2.Err()
-	}
+	// Handler eventually responds to every request, possibly after the request times out.
+	resp := <-respCh
+
+	return resp.msg, resp.rtt, resp.err
 }
 
 // close closes and waits for the close to finish.
diff --git a/shared_client_test.go b/shared_client_test.go
index 4b3f9ed6d..963dc0ab7 100644
--- a/shared_client_test.go
+++ b/shared_client_test.go
@@ -6,7 +6,9 @@ package dns
 import (
 	"context"
 	"crypto/tls"
+	"errors"
 	"fmt"
+	"io"
 	"net"
 	"strconv"
 	"strings"
@@ -19,6 +21,138 @@ var (
 	clients = NewSharedClients()
 )
 
+func TestSharedWaiterQueue(t *testing.T) {
+	wq := newWaiterQueue()
+	if wq == nil {
+		t.Fatal("waiterQueue: nil")
+	}
+
+	// Add first entry
+	start := time.Now()
+	deadline := start.Add(50 * time.Millisecond)
+
+	ch1 := make(chan sharedClientResponse, 1)
+	wq.Insert(1, start, deadline, ch1)
+	if l := wq.Len(); l != 1 || len(wq.index) != 1 || len(wq.waiters) != 1 {
+		t.Errorf("waiterQueue: invalid length (%d != 1)", l)
+	}
+	if wq.waiters[0].start != start || wq.waiters[0].deadline != deadline {
+		t.Error("waiterQueue: invalid start or deadline")
+	}
+	if wq.Exists(1) != true {
+		t.Error("waiterQueue: Exists failed")
+	}
+	if wq.Exists(2) != false {
+		t.Error("waiterQueue: Exists2 failed")
+	}
+	timeout := wq.GetTimeout()
+	if timeout <= 0 || timeout > 50*time.Millisecond {
+		t.Error("waiterQueue: invalid timeout")
+	}
+
+	// second entry has an earlier deadline, so it should go to the front
+	start = start.Add(-1 * time.Second)
+	deadline = start.Add(10 * time.Millisecond)
+
+	ch2 := make(chan sharedClientResponse, 1)
+	wq.Insert(2, start, deadline, ch2)
+	if l := wq.Len(); l != 2 || len(wq.index) != 2 || len(wq.waiters) != 2 {
+		t.Errorf("waiterQueue: invalid length (%d != 2)", l)
+	}
+	if wq.waiters[0].start != start || wq.waiters[0].deadline != deadline {
+		t.Errorf("waiterQueue: invalid start or deadline 2: %v", wq.waiters)
+	}
+
+	select {
+	case resp, _ := <-ch1:
+		t.Errorf("waiterQueue: unexpected response before Expired 1: %v", resp)
+	case resp, _ := <-ch2:
+		t.Errorf("waiterQueue: unexpected response before Expired 2: %v", resp)
+	default:
+	}
+
+	// Handle expired entries
+	wq.Expired()
+
+	// get the expired entry response
+	resp, ok := <-ch2
+	if !ok {
+		t.Error("waiterQueue: no response")
+	}
+	var neterr net.Error
+	if !errors.As(resp.err, &neterr) || !neterr.Timeout() {
+		t.Errorf("waiterQueue: error is not a timeout: %s", resp.err)
+	}
+
+	// Check that ch1 did not get anything
+	select {
+	case resp, ok := <-ch1:
+		t.Errorf("waiterQueue: unexpected response1: %v (ok: %v)", resp, ok)
+	default:
+	}
+
+	// third entry has a later deadline, so it should go to the back
+	start = time.Now()
+	deadline = start.Add(100 * time.Millisecond)
+
+	ch3 := make(chan sharedClientResponse, 1)
+	wq.Insert(3, start, deadline, ch3)
+	if l := wq.Len(); l != 2 || len(wq.index) != 2 || len(wq.waiters) != 2 {
+		t.Errorf("waiterQueue: invalid length (%d != 2)", l)
+	}
+	if wq.waiters[1].start != start || wq.waiters[1].deadline != deadline {
+		t.Errorf("waiterQueue: invalid start or deadline 3: %v", wq.waiters)
+	}
+
+	// Respond to the 3rd entry (uses wq.Dequeue())
+	resp = sharedClientResponse{&Msg{MsgHdr: MsgHdr{Id: 3}}, 0, nil}
+	wq.Respond(resp)
+	if l := wq.Len(); l != 1 || len(wq.index) != 1 || len(wq.waiters) != 1 {
+		t.Errorf("waiterQueue: invalid length (%d != 1)", l)
+	}
+
+	select {
+	// Check that ch1 did not get anything
+	case resp, ok := <-ch1:
+		t.Errorf("waiterQueue: unexpected response1: %v (ok: %v)", resp, ok)
+	case resp, ok := <-ch3:
+		// Expecting a successful response
+		if !ok || resp.err != nil || resp.rtt <= 0 {
+			t.Errorf("waiterQueue: wrong response: %v", resp)
+		}
+	default:
+		t.Error("waiterQueue: should have received a response on ch3")
+	}
+
+	// FailAll
+	wq.FailAll(io.EOF)
+	if l := wq.Len(); l != 0 || len(wq.index) != 0 || len(wq.waiters) != 0 {
+		t.Errorf("waiterQueue: invalid length (%d != 0)", l)
+	}
+
+	select {
+	case resp, ok := <-ch1:
+		if !ok {
+			t.Errorf("waiterQueue: unexpected closure of ch1: %v (ok: %v)", resp, ok)
+		}
+		if resp.err != io.EOF {
+			t.Errorf("waiterQueue: unexpected error value on ch1: %s", resp.err)
+		}
+	default:
+		t.Errorf("ch1 should have received an error")
+	}
+
+	select {
+	case resp, ok := <-ch1:
+		if ok {
+			t.Errorf("waiterQueue: unexpected response ch1: %v (ok: %v)", resp, ok)
+		}
+	default:
+		t.Errorf("ch1 should have been closed")
+	}
+
+}
+
 func TestSharedClientSync(t *testing.T) {
 	HandleFunc("miek.nl.", HelloServer)
 	defer HandleRemove("miek.nl.")
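
Caller-side sketch (illustrative only, not part of the patch): after this change the handler owns the per-request read deadline and answers every request, at the latest with a timeout error that implements net.Error, so ExchangeSharedContext no longer needs its own timeout context. The helper name exchangeOnce, the *SharedClient variable c, and the query name below are assumptions for illustration; written as it would appear inside package dns, with the error check mirroring TestSharedWaiterQueue:

	// exchangeOnce issues a single query over an already-established shared
	// client and distinguishes the handler-generated timeout from other errors.
	func exchangeOnce(ctx context.Context, c *SharedClient) (*Msg, error) {
		m := new(Msg)
		m.SetQuestion("example.org.", TypeA)

		// The handler always completes the waiter: with a response, a read
		// error, net.ErrClosed on shutdown, or errTimeout once the deadline
		// kept in the waiterQueue heap expires.
		r, _, err := c.ExchangeSharedContext(ctx, m)
		if err != nil {
			var nerr net.Error
			if errors.As(err, &nerr) && nerr.Timeout() {
				return nil, fmt.Errorf("shared DNS query timed out: %w", err)
			}
			return nil, err
		}
		return r, nil
	}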