Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

shared_client: Bump request id #12

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
go: [ 1.17.x, 1.18.x ]
go: [ 1.18.x, 1.19.x ]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated nit: Given the sheer amount of highly concurrent code in shared client, we should also run with go test -race in this workflow

steps:

- name: Set up Go
Expand Down
243 changes: 203 additions & 40 deletions shared_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package dns

import (
"container/heap"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -138,6 +139,161 @@ func (c *SharedClient) ExchangeShared(m *Msg) (r *Msg, rtt time.Duration, err er
return c.ExchangeSharedContext(context.Background(), m)
}

// waiter is a single waiter in the deadline heap.
// One waiter corresponds to one in-flight DNS request; 'ch' receives exactly
// one response (or error) and is then closed by whichever path completes the
// request (Respond, Expired, or FailAll).
type waiter struct {
	id       uint16                    // DNS request ID, unique among pending requests
	index    int32                     // position in waiterQueue.waiters; -1 when not queued
	start    time.Time                 // send time of the request; used to compute rtt
	deadline time.Time                 // timeout instant; the heap ordering key
	ch       chan sharedClientResponse // result channel; closed after a single send
}

// waiterQueue is a collection of waiters, implements container/heap
// so that the waiter with the closest deadline is on the top of the heap.
// Map is used as an index on the DNS request ID for quick access.
type waiterQueue struct {
	waiters []*waiter          // heap ordered by deadline; waiters[0] has the earliest
	index   map[uint16]*waiter // request ID -> waiter, kept in sync with 'waiters'
}

// Len reports the number of pending waiters (heap.Interface).
func (wq waiterQueue) Len() int {
	return len(wq.waiters)
}

// Less orders waiters so the earliest deadline sorts first (heap.Interface).
func (wq waiterQueue) Less(i, j int) bool {
	return wq.waiters[i].deadline.Before(wq.waiters[j].deadline)
}

// Swap exchanges two waiters and refreshes their cached heap positions
// (heap.Interface).
func (wq waiterQueue) Swap(i, j int) {
	wi, wj := wq.waiters[i], wq.waiters[j]
	wq.waiters[i], wq.waiters[j] = wj, wi
	wi.index = int32(j)
	wj.index = int32(i)
}

// Push appends a new waiter and records it in the ID index (heap.Interface).
// Called by heap.Push; do not call directly.
func (wq *waiterQueue) Push(x any) {
	w := x.(*waiter)
	w.index = int32(len(wq.waiters))
	wq.index[w.id] = w
	wq.waiters = append(wq.waiters, w)
}

// Pop removes and returns the last waiter, dropping it from the ID index
// (heap.Interface). Called by heap.Pop; do not call directly.
func (wq *waiterQueue) Pop() any {
	last := len(wq.waiters) - 1
	w := wq.waiters[last]
	wq.waiters[last] = nil // drop the reference so the GC can reclaim the waiter
	wq.waiters = wq.waiters[:last]

	w.index = -1 // mark as no longer queued, for safety checks in Dequeue
	delete(wq.index, w.id)
	return w
}

// newWaiterQueue returns an empty, ready-to-use waiterQueue.
//
// Note: heap.Init is intentionally not called here — an empty slice is
// already a valid heap, so initializing it would be a no-op.
func newWaiterQueue() *waiterQueue {
	return &waiterQueue{
		waiters: make([]*waiter, 0, 1),
		index:   make(map[uint16]*waiter),
	}
}

// Insert queues a waiter for the DNS request with the given ID.
// Panics if a waiter with the same ID is already queued; callers must check
// Exists (and pick a fresh ID) before inserting.
func (wq *waiterQueue) Insert(id uint16, start, deadline time.Time, ch chan sharedClientResponse) {
	if wq.Exists(id) {
		panic("waiterQueue: entry exists")
	}

	// Keyed literal so this stays correct if waiter's fields are ever
	// reordered. heap.Push fixes 'index' up to the real heap position.
	heap.Push(wq, &waiter{
		id:       id,
		index:    -1,
		start:    start,
		deadline: deadline,
		ch:       ch,
	})
}

// Exists reports whether a waiter with the given request ID is queued.
func (wq *waiterQueue) Exists(id uint16) bool {
	_, ok := wq.index[id]
	return ok
}

// FailAll completes every queued waiter with the given error and closes
// its response channel.
func (wq *waiterQueue) FailAll(err error) {
	for wq.Len() != 0 {
		w := heap.Pop(wq).(*waiter)
		w.ch <- sharedClientResponse{nil, 0, err}
		close(w.ch)
	}
}

// Dequeue removes and returns the waiter with the given request ID,
// or nil if no such waiter is queued.
func (wq *waiterQueue) Dequeue(id uint16) *waiter {
	w, found := wq.index[id]
	if !found {
		return nil // not found
	}

	// Consistency checks: the indexed waiter must carry the same ID and a
	// valid heap position; anything else indicates queue corruption.
	switch {
	case w.id != id:
		panic(fmt.Sprintf("waiterQueue: invalid id %d != %d", w.id, id))
	case w.index < 0:
		panic(fmt.Sprintf("waiterQueue: invalid index: %d", w.index))
	}

	heap.Remove(wq, int(w.index))
	return w
}

// GetTimeout returns the duration from now until the earliest deadline,
// or a long idle timeout when no waiters are queued.
func (wq *waiterQueue) GetTimeout() time.Duration {
	// Nothing pending: wake up only occasionally.
	const idleTimeout = 10 * time.Minute
	if wq.Len() == 0 {
		return idleTimeout
	}
	// The heap root always holds the closest deadline.
	return time.Until(wq.waiters[0].deadline)
}

// errTimeout is an error representing a request timeout.
// It implements net.Error so callers can detect timeouts via the usual
// Timeout() accessor.
type errTimeout struct {
}

// Compile-time check that errTimeout satisfies net.Error.
var _ net.Error = errTimeout{}

// Timeout reports that this error is a timeout (always true).
func (e errTimeout) Timeout() bool { return true }

// Temporary is deprecated. Return false.
func (e errTimeout) Temporary() bool { return false }

// Error returns the timeout error message.
func (e errTimeout) Error() string {
	return "request timed out"
}

// netErrorTimeout is the shared sentinel value sent to timed-out waiters.
var netErrorTimeout errTimeout

// Expired sends a timeout response to every waiter whose deadline has
// passed, closing their channels.
func (wq *waiterQueue) Expired() {
	now := time.Now()
	// Waiters are heap-ordered by deadline, so we can stop at the first
	// waiter that is still within its deadline.
	for wq.Len() > 0 && !wq.waiters[0].deadline.After(now) {
		w := heap.Pop(wq).(*waiter)
		w.ch <- sharedClientResponse{nil, 0, netErrorTimeout}
		close(w.ch)
	}
}

// Respond delivers a DNS response (or a receive error) to the matching waiter.
func (wq *waiterQueue) Respond(resp sharedClientResponse) {
	switch {
	case resp.err != nil:
		// ReadMsg failed, but we cannot match the error to a single
		// request, so complete all pending requests with it.
		wq.FailAll(resp.err)
	case resp.msg != nil:
		w := wq.Dequeue(resp.msg.Id)
		if w == nil {
			return // no matching waiter (late or unsolicited response)
		}
		resp.rtt = time.Since(w.start)
		w.ch <- resp
		close(w.ch)
	}
}

// handler is started when the connection is dialed
func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan request) {
defer wg.Done()
Expand Down Expand Up @@ -197,53 +353,76 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
}
}()

type waiter struct {
ch chan sharedClientResponse
start time.Time
}
waitingResponses := make(map[uint16]waiter)
waiters := newWaiterQueue()
defer func() {
conn.Close()
close(receiverTrigger)

// Drain responses send by receive loop to allow it to exit.
// Drain responses sent by receive loop to allow it to exit.
// It may be repeatedly reading after an i/o timeout, for example.
// This range will be done only after the receive loop has returned
// and closed 'responses' in its defer function.
for range responses {
}

for _, waiter := range waitingResponses {
waiter.ch <- sharedClientResponse{nil, 0, net.ErrClosed}
close(waiter.ch)
// Now cancel all the remaining waiters
waiters.FailAll(net.ErrClosed)

// Drain requests in case they come in while we are closing
// down. This loop is done only after 'requests' channel is closed in
// SharedClient.close() and it is not possible for new requests or timeouts
// to be sent on those closed channels.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow this comment. Sending on closed channel panics - surely that's not how we prevent the senders from sending requests? I assume the real check from preventing senders is that conn is nil, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SharedClient.close() (note: not a close(<channel>)) is called when the shared client can no longer be used for new requests. So there is nothing sent to the closed channel. The point of the comment is that the range loop on the channel completes only after the channel is closed (by the side sending to the channel), so we are guaranteed to send replies to all requests received on this channel.

for req := range requests {
req.ch <- sharedClientResponse{nil, 0, net.ErrClosed}
close(req.ch)
}
}()

deadline := time.NewTimer(0)
for {
// update timer
deadline.Reset(waiters.GetTimeout())
Copy link
Member

@gandro gandro Apr 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reset is not safe to use when the timer is not drained https://pkg.go.dev/time#Timer.Reset

How do we know that the timer is drained here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for noting this, will have to work around this...

select {
case <-deadline.C:
waiters.Expired()

case req, ok := <-requests:
if !ok {
// 'requests' is closed when SharedClient is recycled, which happens
// after responses (or errors) have been received and there are no more
// requests to be sent.
return
}
start := time.Now()

// Check if we already have a request with the same id
// Due to birthday paradox and the fact that ID is uint16
// it's likely to happen with small number (~200) of concurrent requests
// which would result in goroutine leak as we would never close req.ch
if _, ok := waitingResponses[req.msg.Id]; ok {
req.ch <- sharedClientResponse{nil, 0, fmt.Errorf("duplicate request id %d", req.msg.Id)}
close(req.ch)
continue
if waiters.Exists(req.msg.Id) {
duplicate := true
for n := 0; n < 5; n++ {
// Try a new ID
id := Id()
if duplicate = waiters.Exists(id); !duplicate {
req.msg.Id = id
break
}
}
if duplicate {
req.ch <- sharedClientResponse{nil, 0, fmt.Errorf("duplicate request id %d", req.msg.Id)}
close(req.ch)
continue
}
}

start := time.Now()
err := client.SendContext(req.ctx, req.msg, conn, start)
if err != nil {
req.ch <- sharedClientResponse{nil, 0, err}
close(req.ch)
} else {
waitingResponses[req.msg.Id] = waiter{req.ch, start}
deadline := time.Now().Add(client.getTimeoutForRequest(client.readTimeout()))
waiters.Insert(req.msg.Id, start, deadline, req.ch)

// Wake up the receiver that may be waiting to receive again
triggerReceiver()
Expand All @@ -255,22 +434,7 @@ func handler(wg *sync.WaitGroup, client *Client, conn *Conn, requests chan reque
// nothing can be received any more
return
}
if resp.err != nil {
// ReadMsg failed, but we cannot match it to a request,
// so complete all pending requests.
for _, waiter := range waitingResponses {
waiter.ch <- sharedClientResponse{nil, 0, resp.err}
close(waiter.ch)
}
waitingResponses = make(map[uint16]waiter)
} else if resp.msg != nil {
if waiter, ok := waitingResponses[resp.msg.Id]; ok {
delete(waitingResponses, resp.msg.Id)
resp.rtt = time.Since(waiter.start)
waiter.ch <- resp
close(waiter.ch)
}
}
waiters.Respond(resp)
}
}
}
Expand All @@ -294,21 +458,20 @@ func (c *SharedClient) ExchangeSharedContext(ctx context.Context, m *Msg) (r *Ms
timeout := c.getTimeoutForRequest(c.Client.writeTimeout())
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
respCh := make(chan sharedClientResponse)
// Response channel is buffered with capacity of one. This guarantees that the handler can always send
// one response, even if we time out below and never actually receive the response.
respCh := make(chan sharedClientResponse, 1)
select {
case c.requests <- request{ctx: ctx, msg: m, ch: respCh}:
case <-ctx.Done():
// request was not sent, no cleanup to do
return nil, 0, ctx.Err()
}

// Since c.requests is unbuffered, the handler is guaranteed to eventually close 'respCh'
select {
case resp := <-respCh:
return resp.msg, resp.rtt, resp.err
// This is just fail-safe mechanism in case there is another similar issue
case <-time.After(time.Minute):
return nil, 0, fmt.Errorf("timeout waiting for response")
}
// Handler eventually responds to every request, possibly after the request times out.
resp := <-respCh

return resp.msg, resp.rtt, resp.err
}

// close closes and waits for the close to finish.
Expand Down
Loading
Loading