diff --git a/conn.go b/conn.go index 6ed7b5e6..25a691cb 100644 --- a/conn.go +++ b/conn.go @@ -87,6 +87,7 @@ const ( type Conn struct { // requestTimeout is loaded atomically // so we need to ensure 64-bit alignment on 32-bit platforms. + // https://github.com/go-ldap/ldap/pull/199 requestTimeout int64 conn net.Conn isTLS bool @@ -281,9 +282,7 @@ func (l *Conn) Close() { // SetTimeout sets the time after a request is sent that a MessageTimeout triggers func (l *Conn) SetTimeout(timeout time.Duration) { - if timeout > 0 { - atomic.StoreInt64(&l.requestTimeout, int64(timeout)) - } + atomic.StoreInt64(&l.requestTimeout, int64(timeout)) } // Returns the next available messageID @@ -486,20 +485,26 @@ func (l *Conn) processMessages() { l.messageContexts[message.MessageID] = message.Context // Add timeout if defined - requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout)) - if requestTimeout > 0 { + if l.requestTimeout > 0 { go func() { + timer := time.NewTimer(time.Duration(l.requestTimeout)) defer func() { if err := recover(); err != nil { logger.Printf("ldap: recovered panic in RequestTimeout: %v", err) } + + timer.Stop() }() - time.Sleep(requestTimeout) - timeoutMessage := &messagePacket{ - Op: MessageTimeout, - MessageID: message.MessageID, + + select { + case <-timer.C: + timeoutMessage := &messagePacket{ + Op: MessageTimeout, + MessageID: message.MessageID, + } + l.sendProcessMessage(timeoutMessage) + case <-message.Context.done: } - l.sendProcessMessage(timeoutMessage) }() } case MessageResponse: diff --git a/v3/conn.go b/v3/conn.go index 6ed7b5e6..25a691cb 100644 --- a/v3/conn.go +++ b/v3/conn.go @@ -87,6 +87,7 @@ const ( type Conn struct { // requestTimeout is loaded atomically // so we need to ensure 64-bit alignment on 32-bit platforms. + // https://github.com/go-ldap/ldap/pull/199 requestTimeout int64 conn net.Conn isTLS bool @@ -281,9 +282,7 @@ func (l *Conn) Close() { // SetTimeout sets the time after a request is sent that a MessageTimeout triggers func (l *Conn) SetTimeout(timeout time.Duration) { - if timeout > 0 { - atomic.StoreInt64(&l.requestTimeout, int64(timeout)) - } + atomic.StoreInt64(&l.requestTimeout, int64(timeout)) } // Returns the next available messageID @@ -486,20 +485,26 @@ func (l *Conn) processMessages() { l.messageContexts[message.MessageID] = message.Context // Add timeout if defined - requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout)) - if requestTimeout > 0 { + if l.requestTimeout > 0 { go func() { + timer := time.NewTimer(time.Duration(l.requestTimeout)) defer func() { if err := recover(); err != nil { logger.Printf("ldap: recovered panic in RequestTimeout: %v", err) } + + timer.Stop() }() - time.Sleep(requestTimeout) - timeoutMessage := &messagePacket{ - Op: MessageTimeout, - MessageID: message.MessageID, + + select { + case <-timer.C: + timeoutMessage := &messagePacket{ + Op: MessageTimeout, + MessageID: message.MessageID, + } + l.sendProcessMessage(timeoutMessage) + case <-message.Context.done: } - l.sendProcessMessage(timeoutMessage) }() } case MessageResponse: