diff --git a/client.go b/client.go index cab851245..26c3dd2a9 100644 --- a/client.go +++ b/client.go @@ -194,72 +194,131 @@ type message struct { err error } -type receiver struct { - wg *sync.WaitGroup - messages chan *message - err error +// callMap provides access to a map of active calls, guarded by a mutex. +type callMap struct { + m sync.Mutex + activeCalls map[uint32]*callRequest + closeErr error } -func (r *receiver) run(ctx context.Context, c *channel) { - defer r.wg.Done() +// newCallMap returns a new callMap with an empty set of active calls. +func newCallMap() *callMap { + return &callMap{ + activeCalls: make(map[uint32]*callRequest), + } +} - for { - select { - case <-ctx.Done(): - r.err = ctx.Err() - return - default: - mh, p, err := c.recv() - if err != nil { - _, ok := status.FromError(err) - if !ok { - // treat all errors that are not an rpc status as terminal. - // all others poison the connection. - r.err = filterCloseErr(err) - return - } - } - select { - case r.messages <- &message{ - messageHeader: mh, - p: p[:mh.Length], - err: err, - }: - case <-ctx.Done(): - r.err = ctx.Err() - return - } - } +// set adds a call entry to the map with the given streamID key. +func (cm *callMap) set(streamID uint32, cr *callRequest) error { + cm.m.Lock() + defer cm.m.Unlock() + if cm.closeErr != nil { + return cm.closeErr } + cm.activeCalls[streamID] = cr + return nil +} + +// get looks up the call entry for the given streamID key, then removes it +// from the map and returns it. +func (cm *callMap) get(streamID uint32) (cr *callRequest, ok bool, err error) { + cm.m.Lock() + defer cm.m.Unlock() + if cm.closeErr != nil { + return nil, false, cm.closeErr + } + cr, ok = cm.activeCalls[streamID] + if ok { + delete(cm.activeCalls, streamID) + } + return +} + +// abort sends the given error to each active call, and clears the map. +// Once abort has been called, any subsequent calls to the callMap will return the error passed to abort. +func (cm *callMap) abort(err error) error { + cm.m.Lock() + defer cm.m.Unlock() + if cm.closeErr != nil { + return cm.closeErr + } + for streamID, call := range cm.activeCalls { + call.errs <- err + delete(cm.activeCalls, streamID) + } + cm.closeErr = err + return nil } func (c *Client) run() { var ( - streamID uint32 = 1 - waiters = make(map[uint32]*callRequest) - calls = c.calls - incoming = make(chan *message) - receiversDone = make(chan struct{}) - wg sync.WaitGroup + waiters = newCallMap() + receiverDone = make(chan struct{}) ) - // broadcast the shutdown error to the remaining waiters. - abortWaiters := func(wErr error) { - for _, waiter := range waiters { - waiter.errs <- wErr + // Sender goroutine + // Receives calls from dispatch, adds them to the set of active calls, and sends them + // to the server. + go func() { + var streamID uint32 = 1 + for { + select { + case <-c.ctx.Done(): + return + case call := <-c.calls: + id := streamID + streamID += 2 // enforce odd client initiated request ids + if err := waiters.set(id, call); err != nil { + call.errs <- err // errs is buffered so should not block. + continue + } + if err := c.send(id, messageTypeRequest, call.req); err != nil { + call.errs <- err // errs is buffered so should not block. + waiters.get(id) // remove from waiters set + } + } } - } - recv := &receiver{ - wg: &wg, - messages: incoming, - } - wg.Add(1) + }() + // Receiver goroutine + // Receives responses from the server, looks up the call info in the set of active calls, + // and notifies the caller of the response. go func() { - wg.Wait() - close(receiversDone) + defer close(receiverDone) + for { + select { + case <-c.ctx.Done(): + c.setError(c.ctx.Err()) + return + default: + mh, p, err := c.channel.recv() + if err != nil { + _, ok := status.FromError(err) + if !ok { + // treat all errors that are not an rpc status as terminal. + // all others poison the connection. + c.setError(filterCloseErr(err)) + return + } + } + msg := &message{ + messageHeader: mh, + p: p[:mh.Length], + err: err, + } + call, ok, err := waiters.get(mh.StreamID) + if err != nil { + logrus.Errorf("ttrpc: failed to look up active call: %s", err) + continue + } + if !ok { + logrus.Errorf("ttrpc: received message for unknown channel %v", mh.StreamID) + continue + } + call.errs <- c.recv(call.resp, msg) + } + } }() - go recv.run(c.ctx, c.channel) defer func() { c.conn.Close() @@ -269,32 +328,14 @@ func (c *Client) run() { for { select { - case call := <-calls: - if err := c.send(streamID, messageTypeRequest, call.req); err != nil { - call.errs <- err - continue - } - - waiters[streamID] = call - streamID += 2 // enforce odd client initiated request ids - case msg := <-incoming: - call, ok := waiters[msg.StreamID] - if !ok { - logrus.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID) - continue - } - - call.errs <- c.recv(call.resp, msg) - delete(waiters, msg.StreamID) - case <-receiversDone: - // all the receivers have exited - if recv.err != nil { - c.setError(recv.err) - } + case <-receiverDone: + // The receiver has exited. // don't return out, let the close of the context trigger the abort of waiters c.Close() case <-c.ctx.Done(): - abortWaiters(c.error()) + // Abort all active calls. This will also prevent any new calls from being added + // to waiters. + waiters.abort(c.error()) return } }