shared_client: Bump request id #12
base: master

Changes from all commits: fe97f44, 022c597, f027334, 75c2156, 67f6b75, 5724e90
@@ -4,6 +4,7 @@
package dns

import (
	"container/heap"
	"context"
	"errors"
	"fmt"
@@ -138,6 +139,161 @@ func (c *SharedClient) ExchangeShared(m *Msg) (r *Msg, rtt time.Duration, err er
	return c.ExchangeSharedContext(context.Background(), m)
}
// waiter is a single waiter in the deadline heap
type waiter struct {
	id       uint16
	index    int32
	start    time.Time
	deadline time.Time
	ch       chan sharedClientResponse
}
// waiterQueue is a collection of waiters, implements container/heap
// so that the waiter with the closest deadline is on the top of the heap.
// Map is used as an index on the DNS request ID for quick access.
type waiterQueue struct {
	waiters []*waiter
	index   map[uint16]*waiter
}
// Len from heap.Interface
func (wq waiterQueue) Len() int { return len(wq.waiters) }

// Less from heap.Interface
func (wq waiterQueue) Less(i, j int) bool {
	return wq.waiters[j].deadline.After(wq.waiters[i].deadline)
}

// Swap from heap.Interface
func (wq waiterQueue) Swap(i, j int) {
	wq.waiters[i], wq.waiters[j] = wq.waiters[j], wq.waiters[i]
	wq.waiters[i].index = int32(i)
	wq.waiters[j].index = int32(j)
}

// Push from heap.Interface
func (wq *waiterQueue) Push(x any) {
	wtr := x.(*waiter)
	wtr.index = int32(len(wq.waiters))
	wq.waiters = append(wq.waiters, wtr)
	wq.index[wtr.id] = wtr
}

// Pop from heap.Interface
func (wq *waiterQueue) Pop() any {
	old := wq.waiters
	n := len(old)
	wtr := old[n-1]
	old[n-1] = nil // avoid memory leak
	wtr.index = -1 // for safety
	wq.waiters = old[:n-1]

	delete(wq.index, wtr.id)
	return wtr
}

// newWaiterQueue returns a new waiterQueue
func newWaiterQueue() *waiterQueue {
	wq := waiterQueue{
		waiters: make([]*waiter, 0, 1),
		index:   make(map[uint16]*waiter),
	}
	heap.Init(&wq)
	return &wq
}

func (wq *waiterQueue) Insert(id uint16, start, deadline time.Time, ch chan sharedClientResponse) {
	if wq.Exists(id) {
		panic("waiterQueue: entry exists")
	}

	heap.Push(wq, &waiter{id, -1, start, deadline, ch})
}

func (wq *waiterQueue) Exists(id uint16) bool {
	_, exists := wq.index[id]
	return exists
}

func (wq *waiterQueue) FailAll(err error) {
	for wq.Len() > 0 {
		wtr := heap.Pop(wq).(*waiter)
		wtr.ch <- sharedClientResponse{nil, 0, err}
		close(wtr.ch)
	}
}

func (wq *waiterQueue) Dequeue(id uint16) *waiter {
	wtr, ok := wq.index[id]
	if !ok {
		return nil // not found
	}
	if wtr.id != id {
		panic(fmt.Sprintf("waiterQueue: invalid id %d != %d", wtr.id, id))
	}
	if wtr.index < 0 {
		panic(fmt.Sprintf("waiterQueue: invalid index: %d", wtr.index))
	}

	heap.Remove(wq, int(wtr.index))

	return wtr
}
// GetTimeout returns the time from now till the earliest deadline
func (wq *waiterQueue) GetTimeout() time.Duration {
	// return 10 minutes if there are no waiters
	if wq.Len() == 0 {
		return 10 * time.Minute
	}
	return time.Until(wq.waiters[0].deadline)
}
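For intuition, a minimal usage sketch of the queue above (illustrative only; the IDs, durations, and channel sizes are made up): waiters pop in deadline order regardless of insertion order, and the heap top drives GetTimeout.

	wq := newWaiterQueue()
	now := time.Now()
	wq.Insert(1, now, now.Add(3*time.Second), make(chan sharedClientResponse, 1))
	wq.Insert(2, now, now.Add(1*time.Second), make(chan sharedClientResponse, 1))
	wq.Insert(3, now, now.Add(2*time.Second), make(chan sharedClientResponse, 1))

	// The heap top is id 2 (earliest deadline), so GetTimeout() is ~1s.
	_ = wq.GetTimeout()

	// Dequeue removes by ID via the map index and heap.Remove,
	// keeping the remaining waiters ordered.
	wtr := wq.Dequeue(2)
	_ = wtr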
// errTimeout is an error representing a request timeout.
// Implements net.Error
type errTimeout struct{}

func (e errTimeout) Timeout() bool { return true }

// Temporary is deprecated. Return false.
func (e errTimeout) Temporary() bool { return false }

func (e errTimeout) Error() string {
	return "request timed out"
}

var netErrorTimeout errTimeout
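Because errTimeout satisfies net.Error, a caller that gets the error back in sharedClientResponse can distinguish a deadline expiry from other failures in the usual way. A sketch (the isTimeout helper is hypothetical; "errors" and "net" imports assumed):

	// isTimeout reports whether err came from a waiter deadline expiry.
	func isTimeout(err error) bool {
		var netErr net.Error
		return errors.As(err, &netErr) && netErr.Timeout()
	}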
|
||
// Expired sends a timeout response to all timed out waiters | ||
func (wq *waiterQueue) Expired() { | ||
now := time.Now() | ||
for wq.Len() > 0 { | ||
if wq.waiters[0].deadline.After(now) { | ||
break | ||
} | ||
wtr := heap.Pop(wq).(*waiter) | ||
wtr.ch <- sharedClientResponse{nil, 0, netErrorTimeout} | ||
close(wtr.ch) | ||
} | ||
} | ||
|
||
// Respond passes the DNS response to the waiter | ||
func (wq *waiterQueue) Respond(resp sharedClientResponse) { | ||
if resp.err != nil { | ||
// ReadMsg failed, but we cannot match it to a request, | ||
// so complete all pending requests. | ||
wq.FailAll(resp.err) | ||
} else if resp.msg != nil { | ||
wtr := wq.Dequeue(resp.msg.Id) | ||
if wtr != nil { | ||
resp.rtt = time.Since(wtr.start) | ||
wtr.ch <- resp | ||
close(wtr.ch) | ||
} | ||
} | ||
} | ||
|
||
// handler is started when the connection is dialed | ||
func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan request) { | ||
defer wg.Done() | ||
|
@@ -197,53 +353,76 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
		}
	}()
-	type waiter struct {
-		ch    chan sharedClientResponse
-		start time.Time
-	}
-	waitingResponses := make(map[uint16]waiter)
+	waiters := newWaiterQueue()
	defer func() {
		conn.Close()
		close(receiverTrigger)
-		// Drain responses send by receive loop to allow it to exit.
+		// Drain responses sent by receive loop to allow it to exit.
		// It may be repeatedly reading after an i/o timeout, for example.
		// This range will be done only after the receive loop has returned
		// and closed 'responses' in its defer function.
		for range responses {
		}
-		for _, waiter := range waitingResponses {
-			waiter.ch <- sharedClientResponse{nil, 0, net.ErrClosed}
-			close(waiter.ch)
-		}
+		// Now cancel all the remaining waiters
+		waiters.FailAll(net.ErrClosed)

+		// Drain requests in case they come in while we are closing
+		// down. This loop is done only after 'requests' channel is closed in
+		// SharedClient.close() and it is not possible for new requests or timeouts
+		// to be sent on those closed channels.
Review comment: I don't follow this comment. Sending on closed channel panics - surely that's not how we prevent the senders from sending requests? I assume the real check from preventing senders is that
		for req := range requests {
			req.ch <- sharedClientResponse{nil, 0, net.ErrClosed}
			close(req.ch)
		}
	}()
+	deadline := time.NewTimer(0)
	for {
+		// update timer
+		deadline.Reset(waiters.GetTimeout())
Review comment: How do we know that the timer is drained here?

Reply: Thanks for noting this, will have to work around this...
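One common workaround for the drain problem raised above, sketched as a drop-in for the Reset call (before Go 1.23, calling Reset on a timer that has already fired but was not drained leaves a stale value in deadline.C, which would make the next select wake up immediately):

		if !deadline.Stop() {
			// Timer already fired; remove the stale value if nobody
			// has read it yet, without blocking when it was consumed.
			select {
			case <-deadline.C:
			default:
			}
		}
		deadline.Reset(waiters.GetTimeout())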
		select {
+		case <-deadline.C:
+			waiters.Expired()

		case req, ok := <-requests:
			if !ok {
				// 'requests' is closed when SharedClient is recycled, which happens
				// after all responses (or errors) have been received and there are no
				// more requests to be sent.
				return
			}
-			start := time.Now()
			// Check if we already have a request with the same id.
			// Due to the birthday paradox and the fact that ID is uint16,
			// it's likely to happen with a small number (~200) of concurrent requests,
			// which would result in a goroutine leak as we would never close req.ch.
-			if _, ok := waitingResponses[req.msg.Id]; ok {
-				req.ch <- sharedClientResponse{nil, 0, fmt.Errorf("duplicate request id %d", req.msg.Id)}
-				close(req.ch)
-				continue
-			}
+			if waiters.Exists(req.msg.Id) {
+				duplicate := true
+				for n := 0; n < 5; n++ {
+					// Try a new ID
+					id := Id()
+					if duplicate = waiters.Exists(id); !duplicate {
+						req.msg.Id = id
+						break
+					}
+				}
+				if duplicate {
+					req.ch <- sharedClientResponse{nil, 0, fmt.Errorf("duplicate request id %d", req.msg.Id)}
+					close(req.ch)
+					continue
+				}
+			}
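The ~200 figure in the comment above follows from the birthday bound: with n outstanding requests drawing IDs uniformly from 2^16 values, P(collision) ≈ 1 - e^(-n(n-1)/(2·65536)), which is already about 26% at n = 200. A quick sketch to reproduce the number (hypothetical helper; "math" import assumed):

	// approxCollision is the birthday-bound collision probability for n
	// concurrent requests over the 65536 possible uint16 DNS IDs.
	func approxCollision(n int) float64 {
		return 1 - math.Exp(-float64(n*(n-1))/(2*65536))
	}

	// approxCollision(200) ≈ 0.26, so duplicates are far from rare. Each
	// individual retry collides with probability ~n/65536 (~0.3% here),
	// so the bounded five-attempt loop above almost always finds a free ID.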
+			start := time.Now()
			err := client.SendContext(req.ctx, req.msg, conn, start)
			if err != nil {
				req.ch <- sharedClientResponse{nil, 0, err}
				close(req.ch)
			} else {
-				waitingResponses[req.msg.Id] = waiter{req.ch, start}
+				deadline := time.Now().Add(client.getTimeoutForRequest(client.readTimeout()))
+				waiters.Insert(req.msg.Id, start, deadline, req.ch)
				// Wake up the receiver that may be waiting to receive again
				triggerReceiver()

@@ -255,22 +434,7 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
				// nothing can be received any more
				return
			}
-			if resp.err != nil {
-				// ReadMsg failed, but we cannot match it to a request,
-				// so complete all pending requests.
-				for _, waiter := range waitingResponses {
-					waiter.ch <- sharedClientResponse{nil, 0, resp.err}
-					close(waiter.ch)
-				}
-				waitingResponses = make(map[uint16]waiter)
-			} else if resp.msg != nil {
-				if waiter, ok := waitingResponses[resp.msg.Id]; ok {
-					delete(waitingResponses, resp.msg.Id)
-					resp.rtt = time.Since(waiter.start)
-					waiter.ch <- resp
-					close(waiter.ch)
-				}
-			}
+			waiters.Respond(resp)
		}
	}
}
@@ -294,21 +458,20 @@ func (c *SharedClient) ExchangeSharedContext(ctx context.Context, m *Msg) (r *Ms
	timeout := c.getTimeoutForRequest(c.Client.writeTimeout())
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
-	respCh := make(chan sharedClientResponse)
+	// Response channel is buffered with capacity of one. This guarantees that the handler can always send
+	// one response, even if we time out below and never actually receive the response.
+	respCh := make(chan sharedClientResponse, 1)
	select {
	case c.requests <- request{ctx: ctx, msg: m, ch: respCh}:
	case <-ctx.Done():
		// request was not sent, no cleanup to do
		return nil, 0, ctx.Err()
	}
-	// Since c.requests is unbuffered, the handler is guaranteed to eventually close 'respCh'
-	select {
-	case resp := <-respCh:
-		return resp.msg, resp.rtt, resp.err
-	// This is just fail-safe mechanism in case there is another similar issue
-	case <-time.After(time.Minute):
-		return nil, 0, fmt.Errorf("timeout waiting for response")
-	}
+	// Handler eventually responds to every request, possibly after the request times out.
+	resp := <-respCh

+	return resp.msg, resp.rtt, resp.err
}
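The capacity-1 buffer is what makes the handler's single send safe: whichever of Respond, Expired, FailAll, or the error paths fires, its one send completes without waiting for the caller, and the channel is then closed, so neither side can deadlock. A minimal sketch of the property being relied on:

	ch := make(chan sharedClientResponse, 1)
	ch <- sharedClientResponse{nil, 0, net.ErrClosed} // never blocks: buffer has room
	close(ch)

	resp := <-ch // the receiver can pick the response up at any later time
	_ = resp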
// close closes and waits for the close to finish.
Review comment: unrelated nit: Given the sheer amount of highly concurrent code in shared client, we should also run with go test -race in this workflow.