Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix request timeout race condition #465

Merged
merged 1 commit into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,10 @@ func (l *Conn) Close() (err error) {
l.chanMessage <- &messagePacket{Op: MessageQuit}

timeoutCtx := context.Background()
if l.requestTimeout > 0 {
requestTimeout := l.getTimeout()
if requestTimeout > 0 {
var cancelFunc context.CancelFunc
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.requestTimeout))
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(requestTimeout))
defer cancelFunc()
}
select {
Expand All @@ -316,6 +317,10 @@ func (l *Conn) SetTimeout(timeout time.Duration) {
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}

func (l *Conn) getTimeout() int64 {
return atomic.LoadInt64(&l.requestTimeout)
}

// Returns the next available messageID
func (l *Conn) nextMessageID() int64 {
if messageID, ok := <-l.chanMessageID; ok {
Expand Down Expand Up @@ -486,7 +491,7 @@ func (l *Conn) processMessages() {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l.IsClosing() && l.closeErr.Load() != nil {
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.requestTimeout))
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout()))
}
l.Debug.Printf("Closing channel for MessageID %d", messageID)
close(msgCtx.responses)
Expand Down Expand Up @@ -514,7 +519,7 @@ func (l *Conn) processMessages() {
_, err := l.conn.Write(buf)
if err != nil {
l.Debug.Printf("Error Sending Message: %s", err.Error())
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.requestTimeout))
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout()))
close(message.Context.responses)
break
}
Expand All @@ -524,9 +529,9 @@ func (l *Conn) processMessages() {
l.messageContexts[message.MessageID] = message.Context

// Add timeout if defined
if l.requestTimeout > 0 {
if l.getTimeout() > 0 {
go func() {
timer := time.NewTimer(time.Duration(l.requestTimeout))
timer := time.NewTimer(time.Duration(l.getTimeout()))
defer func() {
if err := recover(); err != nil {
l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err)
Expand All @@ -549,7 +554,7 @@ func (l *Conn) processMessages() {
case MessageResponse:
l.Debug.Printf("Receiving message %d", message.MessageID)
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.requestTimeout))
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout()))
} else {
l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing())
l.Debug.PrintPacket(message.Packet)
Expand All @@ -559,7 +564,7 @@ func (l *Conn) processMessages() {
// All reads will return immediately
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.requestTimeout))
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout()))
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
}
Expand Down
26 changes: 26 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,32 @@ func TestInvalidStateCloseDeadlock(t *testing.T) {
conn.Close()
}

func TestRequestTimeoutDeadlock(t *testing.T) {
// The do-nothing server that accepts requests and does nothing
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts.Close()
c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
if err != nil {
t.Fatalf("error connecting to localhost tcp: %v", err)
}

// Create an Ldap connection
conn := NewConn(c, false)
conn.Start()
// trigger a race condition on accessing request timeout
n := 3
for i := 0; i < n; i++ {
go func() {
conn.SetTimeout(time.Millisecond)
}()
}

// Attempt to close the connection when the message handler is
// blocked or inactive
conn.Close()
}

// TestInvalidStateSendResponseDeadlock tests that we do not enter deadlock when the
// message handler is blocked or inactive.
func TestInvalidStateSendResponseDeadlock(t *testing.T) {
Expand Down
21 changes: 13 additions & 8 deletions v3/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,9 @@ func (l *Conn) Close() (err error) {
l.chanMessage <- &messagePacket{Op: MessageQuit}

timeoutCtx := context.Background()
if l.requestTimeout > 0 {
if l.getTimeout() > 0 {
var cancelFunc context.CancelFunc
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.requestTimeout))
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.getTimeout()))
defer cancelFunc()
}
select {
Expand All @@ -316,6 +316,10 @@ func (l *Conn) SetTimeout(timeout time.Duration) {
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}

func (l *Conn) getTimeout() int64 {
return atomic.LoadInt64(&l.requestTimeout)
}

// Returns the next available messageID
func (l *Conn) nextMessageID() int64 {
if messageID, ok := <-l.chanMessageID; ok {
Expand Down Expand Up @@ -486,7 +490,7 @@ func (l *Conn) processMessages() {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l.IsClosing() && l.closeErr.Load() != nil {
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.requestTimeout))
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout()))
}
l.Debug.Printf("Closing channel for MessageID %d", messageID)
close(msgCtx.responses)
Expand Down Expand Up @@ -514,7 +518,7 @@ func (l *Conn) processMessages() {
_, err := l.conn.Write(buf)
if err != nil {
l.Debug.Printf("Error Sending Message: %s", err.Error())
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.requestTimeout))
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout()))
close(message.Context.responses)
break
}
Expand All @@ -524,9 +528,10 @@ func (l *Conn) processMessages() {
l.messageContexts[message.MessageID] = message.Context

// Add timeout if defined
if l.requestTimeout > 0 {
requestTimeout := l.getTimeout()
if requestTimeout > 0 {
go func() {
timer := time.NewTimer(time.Duration(l.requestTimeout))
timer := time.NewTimer(time.Duration(requestTimeout))
defer func() {
if err := recover(); err != nil {
l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err)
Expand All @@ -549,7 +554,7 @@ func (l *Conn) processMessages() {
case MessageResponse:
l.Debug.Printf("Receiving message %d", message.MessageID)
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.requestTimeout))
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout()))
} else {
l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing())
l.Debug.PrintPacket(message.Packet)
Expand All @@ -559,7 +564,7 @@ func (l *Conn) processMessages() {
// All reads will return immediately
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.requestTimeout))
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout()))
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
}
Expand Down
26 changes: 26 additions & 0 deletions v3/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,32 @@ func TestUnresponsiveConnection(t *testing.T) {
}
}

func TestRequestTimeoutDeadlock(t *testing.T) {
// The do-nothing server that accepts requests and does nothing
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts.Close()
c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
if err != nil {
t.Fatalf("error connecting to localhost tcp: %v", err)
}

// Create an Ldap connection
conn := NewConn(c, false)
conn.Start()
// trigger a race condition on accessing request timeout
n := 3
for i := 0; i < n; i++ {
go func() {
conn.SetTimeout(time.Millisecond)
}()
}

// Attempt to close the connection when the message handler is
// blocked or inactive
conn.Close()
}

// TestInvalidStateCloseDeadlock tests that we do not enter deadlock when the
// message handler is blocked or inactive.
func TestInvalidStateCloseDeadlock(t *testing.T) {
Expand Down