Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add search asynchronously with context #440

Merged
merged 10 commits into from
Jun 30, 2023
2 changes: 2 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"crypto/tls"
"time"
)
Expand Down Expand Up @@ -32,6 +33,7 @@ type Client interface {
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)

Search(*SearchRequest) (*SearchResult, error)
SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error)
}
2 changes: 2 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ func (l *Conn) nextMessageID() int64 {
// GetLastError returns the last recorded error from goroutines like processMessages and reader.
// Only the last recorded error will be returned.
func (l *Conn) GetLastError() error {
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
return l.err
}

Expand Down
30 changes: 30 additions & 0 deletions examples_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -50,6 +51,35 @@ func ExampleConn_Search() {
}
}

// This example demonstrates how to search asynchronously
func ExampleConn_SearchAsync() {
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))
if err != nil {
log.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
"dc=example,dc=com", // The base dn to search
ScopeWholeSubtree, NeverDerefAliases, 0, 0, false,
"(&(objectClass=organizationalPerson))", // The filter to apply
[]string{"dn", "cn"}, // A list attributes to retrieve
nil,
)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

r := l.SearchAsync(ctx, searchRequest, 64)
for r.Next() {
entry := r.Entry()
fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN)
}
if err := r.Err(); err != nil {
log.Fatal(err)
}
}

// This example demonstrates how to start a TLS connection
func ExampleConn_StartTLS() {
l, err := DialURL("ldap://ldap.example.com:389")
Expand Down
66 changes: 66 additions & 0 deletions ldap_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package ldap

import (
"context"
"crypto/tls"
"log"
"testing"

ber "github.com/go-asn1-ber/asn1-ber"
Expand Down Expand Up @@ -344,3 +346,67 @@ func TestEscapeDN(t *testing.T) {
})
}
}

func TestSearchAsync(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
nil)

srs := make([]*Entry, 0)
ctx := context.Background()
r := l.SearchAsync(ctx, searchRequest, 64)
for r.Next() {
srs = append(srs, r.Entry())
}
if err := r.Err(); err != nil {
log.Fatal(err)
}

t.Logf("TestSearcAsync: %s -> num of entries = %d", searchRequest.Filter, len(srs))
}

func TestSearchAsyncAndCancel(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
nil)

cancelNum := 10
srs := make([]*Entry, 0)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
r := l.SearchAsync(ctx, searchRequest, 0)
for r.Next() {
srs = append(srs, r.Entry())
if len(srs) == cancelNum {
cancel()
}
}
if err := r.Err(); err != nil {
log.Fatal(err)
}

if len(srs) > cancelNum+3 {
// the cancellation process is asynchronous,
// so it might get some entries after calling cancel()
t.Errorf("Got entries %d, expected < %d", len(srs), cancelNum+3)
}
t.Logf("TestSearchAsyncAndCancel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
}
182 changes: 182 additions & 0 deletions response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package ldap

import (
"context"
"errors"
"fmt"

ber "github.com/go-asn1-ber/asn1-ber"
)

// Response defines an interface to get data from an LDAP server
type Response interface {
Entry() *Entry
Referral() string
Controls() []Control
Err() error
Next() bool
}

type searchResponse struct {
conn *Conn
ch chan *SearchSingleResult

entry *Entry
referral string
controls []Control
err error
}

// Entry returns an entry from the given search request
func (r *searchResponse) Entry() *Entry {
return r.entry
}

// Referral returns a referral from the given search request
func (r *searchResponse) Referral() string {
return r.referral
}

// Controls returns controls from the given search request
func (r *searchResponse) Controls() []Control {
return r.controls
}

// Err returns an error when the given search request was failed
func (r *searchResponse) Err() error {
return r.err
}

// Next returns whether next data exist or not
func (r *searchResponse) Next() bool {
res, ok := <-r.ch
if !ok {
return false
}
if res == nil {
return false
}
r.err = res.Error
if r.err != nil {
return false
}
r.err = r.conn.GetLastError()
if r.err != nil {
return false
}
r.entry = res.Entry
r.referral = res.Referral
r.controls = res.Controls
return true
}

func (r *searchResponse) start(ctx context.Context, searchRequest *SearchRequest) {
go func() {
defer func() {
close(r.ch)
if err := recover(); err != nil {
r.conn.err = fmt.Errorf("ldap: recovered panic in searchResponse: %v", err)
}
}()

if r.conn.IsClosing() {
return
}

packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID"))
// encode search request
err := searchRequest.appendTo(packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
r.conn.Debug.PrintPacket(packet)

msgCtx, err := r.conn.sendMessage(packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
defer r.conn.finishMessage(msgCtx)

foundSearchSingleResultDone := false
for !foundSearchSingleResultDone {
select {
case <-ctx.Done():
r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error())
return
default:
r.conn.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-msgCtx.responses
if !ok {
err := NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
r.ch <- &SearchSingleResult{Error: err}
return
}
packet, err = packetResponse.ReadPacket()
r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}

if r.conn.Debug {
if err := addLDAPDescriptions(packet); err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
ber.PrintPacket(packet)
}

switch packet.Children[1].Tag {
case ApplicationSearchResultEntry:
r.ch <- &SearchSingleResult{
Entry: &Entry{
DN: packet.Children[1].Children[0].Value.(string),
Attributes: unpackAttributes(packet.Children[1].Children[1].Children),
},
}

case ApplicationSearchResultDone:
if err := GetLDAPError(packet); err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
if len(packet.Children) == 3 {
result := &SearchSingleResult{}
for _, child := range packet.Children[2].Children {
decodedChild, err := DecodeControl(child)
if err != nil {
werr := fmt.Errorf("failed to decode child control: %w", err)
r.ch <- &SearchSingleResult{Error: werr}
return
}
result.Controls = append(result.Controls, decodedChild)
}
r.ch <- result
}
foundSearchSingleResultDone = true

case ApplicationSearchResultReference:
ref := packet.Children[1].Children[0].Value.(string)
r.ch <- &SearchSingleResult{Referral: ref}
}
}
}
r.conn.Debug.Printf("%d: returning", msgCtx.id)
}()
}

func newSearchResponse(conn *Conn, bufferSize int) *searchResponse {
var ch chan *SearchSingleResult
if bufferSize > 0 {
ch = make(chan *SearchSingleResult, bufferSize)
} else {
ch = make(chan *SearchSingleResult)
}
return &searchResponse{
conn: conn,
ch: ch,
}
}
34 changes: 34 additions & 0 deletions search.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"errors"
"fmt"
"reflect"
Expand Down Expand Up @@ -375,6 +376,28 @@ func (s *SearchResult) appendTo(r *SearchResult) {
r.Controls = append(r.Controls, s.Controls...)
}

// SearchSingleResult holds the server's single response to a search request
type SearchSingleResult struct {
// Entry is the returned entry
Entry *Entry
// Referral is the returned referral
Referral string
// Controls are the returned controls
Controls []Control
// Error is set when the search request was failed
Error error
}

// Print outputs a human-readable description
func (s *SearchSingleResult) Print() {
s.Entry.Print()
}

// PrettyPrint outputs a human-readable description with indenting
func (s *SearchSingleResult) PrettyPrint(indent int) {
s.Entry.PrettyPrint(indent)
}

// SearchRequest represents a search request to send to the server
type SearchRequest struct {
BaseDN string
Expand Down Expand Up @@ -559,6 +582,17 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
}
}

// SearchAsync performs a search request and returns all search results asynchronously.
// This means you get all results until an error happens (or the search successfully finished),
// e.g. for size / time limited requests all are recieved until the limit is reached.
// To stop the search, call cancel function returned context.
func (l *Conn) SearchAsync(
ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response {
r := newSearchResponse(l, bufferSize)
r.start(ctx, searchRequest)
return r
}

// unpackAttributes will extract all given LDAP attributes and it's values
// from the ber.Packet
func unpackAttributes(children []*ber.Packet) []*EntryAttribute {
Expand Down
Loading