Skip to content

Commit

Permalink
Improve handling of connection read errors (#66)
Browse files Browse the repository at this point in the history
If the reader() goroutine encounters an unexpected error when reading a packet
a series of unwinding takes place:

- The reader() goroutine shuts down. A deferred function runs which closes the
  connection.
- Close() sends a MessageQuit message to the processMessages() goroutine.
- The processMessages() goroutine shuts down. A deferred function runs which
  closes the results channels for any pending requests.
- These pending request handlers receive a nil *PacketResponse because their
  response channel has been closed. They then return a not-very-helpful error
  string: "ldap: channel closed".

This patch updates the reader() goroutine to set a closeErr value on the conn
when it encounters an unexpected error reading a packet from the server. The
processMessages() deferred function checks for this closeErr when it is
shutting down due to the connection closing and sends this error in the
*PacketResponse values to the pending request handlers *before* closing those
results channels. This allows for the error which caused the shutdown to be
bubbled up to all pending request calls.

Docker-DCO-1.1-Signed-off-by: Josh Hawn <[email protected]> (github: jlhawn)
  • Loading branch information
jlhawn authored and liggitt committed Jun 14, 2016
1 parent 1dc79ce commit 8a8cb05
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 3 deletions.
14 changes: 13 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type Conn struct {
conn net.Conn
isTLS bool
isClosing bool
closeErr error
isStartingTLS bool
Debug debugging
chanConfirm chan bool
Expand Down Expand Up @@ -298,6 +299,11 @@ func (l *Conn) processMessages() {
log.Printf("ldap: recovered panic in processMessages: %v", err)
}
for messageID, channel := range l.chanResults {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l.isClosing && l.closeErr != nil {
channel <- &PacketResponse{Error: l.closeErr}
}
l.Debug.Printf("Closing channel for MessageID %d", messageID)
close(channel)
delete(l.chanResults, messageID)
Expand All @@ -324,15 +330,20 @@ func (l *Conn) processMessages() {
case MessageRequest:
// Add to message list and write to network
l.Debug.Printf("Sending message %d", message.MessageID)
l.chanResults[message.MessageID] = message.Channel

buf := message.Packet.Bytes()
_, err := l.conn.Write(buf)
if err != nil {
l.Debug.Printf("Error Sending Message: %s", err.Error())
message.Channel <- &PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}
close(message.Channel)
break
}

// Only add to chanResults if we were able to
// successfully write the message.
l.chanResults[message.MessageID] = message.Channel

// Add timeout if defined
if l.requestTimeout > 0 {
go func() {
Expand Down Expand Up @@ -397,6 +408,7 @@ func (l *Conn) reader() {
if err != nil {
// A read error is expected here if we are closing the connection...
if !l.isClosing {
l.closeErr = fmt.Errorf("unable to read LDAP response packet: %s", err)
l.Debug.Printf("reader error: %s", err.Error())
}
return
Expand Down
77 changes: 75 additions & 2 deletions error_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package ldap

import (
"errors"
"net"
"strings"
"testing"
"time"

"gopkg.in/asn1-ber.v1"
)
Expand All @@ -16,14 +20,83 @@ func TestNilPacket(t *testing.T) {

// Test for nil result
kids := []*ber.Packet{
&ber.Packet{}, // Unused
nil, // Can't be nil
{}, // Unused
nil, // Can't be nil
}
pack := &ber.Packet{Children: kids}
code, _ = getLDAPResultCode(pack)

if code != ErrorUnexpectedResponse {
t.Errorf("Should have an 'ErrorUnexpectedResponse' error in nil packets, got: %v", code)
}
}

// TestConnReadErr tests that an unexpected error reading from underlying
// connection bubbles up to the goroutine which makes a request.
func TestConnReadErr(t *testing.T) {
conn := &signalErrConn{
signals: make(chan error),
}

ldapConn := NewConn(conn, false)
ldapConn.Start()

// Make a dummy search request.
searchReq := NewSearchRequest("dc=example,dc=com", ScopeWholeSubtree, DerefAlways, 0, 0, false, "(objectClass=*)", nil, nil)

expectedError := errors.New("this is the error you are looking for")

// Send the signal after a short amount of time.
time.AfterFunc(10*time.Millisecond, func() { conn.signals <- expectedError })

// This should block until the underlyiny conn gets the error signal
// which should bubble up through the reader() goroutine, close the
// connection, and
_, err := ldapConn.Search(searchReq)
if err == nil || !strings.Contains(err.Error(), expectedError.Error()) {
t.Errorf("not the expected error: %s", err)
}
}

// signalErrConn is a helful type used with TestConnReadErr. It implements the
// net.Conn interface to be used as a connection for the test. Most methods are
// no-ops but the Read() method blocks until it receives a signal which it
// returns as an error.
type signalErrConn struct {
signals chan error
}

// Read blocks until an error is sent on the internal signals channel. That
// error is returned.
func (c *signalErrConn) Read(b []byte) (n int, err error) {
return 0, <-c.signals
}

func (c *signalErrConn) Write(b []byte) (n int, err error) {
return len(b), nil
}

func (c *signalErrConn) Close() error {
close(c.signals)
return nil
}

func (c *signalErrConn) LocalAddr() net.Addr {
return (*net.TCPAddr)(nil)
}

func (c *signalErrConn) RemoteAddr() net.Addr {
return (*net.TCPAddr)(nil)
}

func (c *signalErrConn) SetDeadline(t time.Time) error {
return nil
}

func (c *signalErrConn) SetReadDeadline(t time.Time) error {
return nil
}

func (c *signalErrConn) SetWriteDeadline(t time.Time) error {
return nil
}

0 comments on commit 8a8cb05

Please sign in to comment.