Skip to content

Commit

Permalink
client: Handle sending/receiving in separate goroutines
Browse files Browse the repository at this point in the history
Changes the TTRPC client logic so that sending and receiving with the
server are in completely independent goroutines, with shared state
guarded by a mutex. Previously, sending/receiving were tied together by
reliance on a coordinator goroutine. This led to issues where if the
server was not reading from the connection, the client could get stuck
sending a request, causing the client to not read responses from the
server. See [1] for more details.

The new design sets up separate sending/receiving goroutines. These
share state in the form of the set of active calls that have been made
to the server. This state is encapsulated in the callMap type and access
is guarded by a mutex.

The main event loop in `run` previously handled a lot of state
management for the client. Now that most state is tracked by the
callMap, it mostly exists to notice when the client is closed and take
appropriate action to clean up.

Also did some minor code cleanup. For instance, the code was previously
written to support multiple receiver goroutines, though this was not
actually used. I've removed this for now, since the code is simpler this
way, and it's easy to add back if we actually need it in the future.

[1] containerd#72

Signed-off-by: Kevin Parsons <[email protected]>
  • Loading branch information
kevpar committed Oct 14, 2021
1 parent 77fc8a4 commit 4f0aeb5
Showing 1 changed file with 116 additions and 75 deletions.
191 changes: 116 additions & 75 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,72 +194,131 @@ type message struct {
err error
}

type receiver struct {
wg *sync.WaitGroup
messages chan *message
err error
// callMap provides access to a map of active calls, guarded by a mutex.
type callMap struct {
m sync.Mutex
activeCalls map[uint32]*callRequest
closeErr error
}

func (r *receiver) run(ctx context.Context, c *channel) {
defer r.wg.Done()
// newCallMap returns a new callMap with an empty set of active calls.
func newCallMap() *callMap {
return &callMap{
activeCalls: make(map[uint32]*callRequest),
}
}

for {
select {
case <-ctx.Done():
r.err = ctx.Err()
return
default:
mh, p, err := c.recv()
if err != nil {
_, ok := status.FromError(err)
if !ok {
// treat all errors that are not an rpc status as terminal.
// all others poison the connection.
r.err = filterCloseErr(err)
return
}
}
select {
case r.messages <- &message{
messageHeader: mh,
p: p[:mh.Length],
err: err,
}:
case <-ctx.Done():
r.err = ctx.Err()
return
}
}
// set adds a call entry to the map with the given streamID key.
func (cm *callMap) set(streamID uint32, cr *callRequest) error {
cm.m.Lock()
defer cm.m.Unlock()
if cm.closeErr != nil {
return cm.closeErr
}
cm.activeCalls[streamID] = cr
return nil
}

// get looks up the call entry for the given streamID key, then removes it
// from the map and returns it.
func (cm *callMap) get(streamID uint32) (cr *callRequest, ok bool, err error) {
cm.m.Lock()
defer cm.m.Unlock()
if cm.closeErr != nil {
return nil, false, cm.closeErr
}
cr, ok = cm.activeCalls[streamID]
if ok {
delete(cm.activeCalls, streamID)
}
return
}

// abort sends the given error to each active call, and clears the map.
// Once abort has been called, any subsequent calls to the callMap will return the error passed to abort.
func (cm *callMap) abort(err error) error {
cm.m.Lock()
defer cm.m.Unlock()
if cm.closeErr != nil {
return cm.closeErr
}
for streamID, call := range cm.activeCalls {
call.errs <- err
delete(cm.activeCalls, streamID)
}
cm.closeErr = err
return nil
}

func (c *Client) run() {
var (
streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
incoming = make(chan *message)
receiversDone = make(chan struct{})
wg sync.WaitGroup
waiters = newCallMap()
receiverDone = make(chan struct{})
)

// broadcast the shutdown error to the remaining waiters.
abortWaiters := func(wErr error) {
for _, waiter := range waiters {
waiter.errs <- wErr
// Sender goroutine
// Receives calls from dispatch, adds them to the set of active calls, and sends them
// to the server.
go func() {
var streamID uint32 = 1
for {
select {
case <-c.ctx.Done():
return
case call := <-c.calls:
id := streamID
streamID += 2 // enforce odd client initiated request ids
if err := waiters.set(id, call); err != nil {
call.errs <- err // errs is buffered so should not block.
continue
}
if err := c.send(id, messageTypeRequest, call.req); err != nil {
call.errs <- err // errs is buffered so should not block.
waiters.get(id) // remove from waiters set
}
}
}
}
recv := &receiver{
wg: &wg,
messages: incoming,
}
wg.Add(1)
}()

// Receiver goroutine
// Receives responses from the server, looks up the call info in the set of active calls,
// and notifies the caller of the response.
go func() {
wg.Wait()
close(receiversDone)
defer close(receiverDone)
for {
select {
case <-c.ctx.Done():
c.setError(c.ctx.Err())
return
default:
mh, p, err := c.channel.recv()
if err != nil {
_, ok := status.FromError(err)
if !ok {
// treat all errors that are not an rpc status as terminal.
// all others poison the connection.
c.setError(filterCloseErr(err))
return
}
}
msg := &message{
messageHeader: mh,
p: p[:mh.Length],
err: err,
}
call, ok, err := waiters.get(mh.StreamID)
if err != nil {
logrus.Errorf("ttrpc: failed to look up active call: %s", err)
continue
}
if !ok {
logrus.Errorf("ttrpc: received message for unknown channel %v", mh.StreamID)
continue
}
call.errs <- c.recv(call.resp, msg)
}
}
}()
go recv.run(c.ctx, c.channel)

defer func() {
c.conn.Close()
Expand All @@ -269,32 +328,14 @@ func (c *Client) run() {

for {
select {
case call := <-calls:
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
call.errs <- err
continue
}

waiters[streamID] = call
streamID += 2 // enforce odd client initiated request ids
case msg := <-incoming:
call, ok := waiters[msg.StreamID]
if !ok {
logrus.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
continue
}

call.errs <- c.recv(call.resp, msg)
delete(waiters, msg.StreamID)
case <-receiversDone:
// all the receivers have exited
if recv.err != nil {
c.setError(recv.err)
}
case <-receiverDone:
// The receiver has exited.
// don't return out, let the close of the context trigger the abort of waiters
c.Close()
case <-c.ctx.Done():
abortWaiters(c.error())
// Abort all active calls. This will also prevent any new calls from being added
// to waiters.
waiters.abort(c.error())
return
}
}
Expand Down

0 comments on commit 4f0aeb5

Please sign in to comment.