Skip to content

Commit

Permalink
Add search asynchronously with context (#440)
Browse files Browse the repository at this point in the history
* feat: add search with channels inspired by #319

* refactor: fix to check proper test results #319

* refactor: fix to use unpackAttributes() for Attributes #319

* refactor: returns receive-only channel to prevent closing it from the caller #319

* refactor: pass channelSize to be able to controll buffered channel by the caller #319

* fix: recover an asynchronouse closing timing issue #319

* fix: consume all entries from the channel to prevent blocking by the connection #319

* feat: add initial search async function with channel #341

* feat: provide search async function and drop search with channels #319 #341

* refactor: lock when to call GetLastError since it might be in communication
  • Loading branch information
t2y authored Jun 30, 2023
1 parent cdb0754 commit 7778a1c
Show file tree
Hide file tree
Showing 12 changed files with 632 additions and 0 deletions.
2 changes: 2 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"crypto/tls"
"time"
)
Expand Down Expand Up @@ -32,6 +33,7 @@ type Client interface {
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)

Search(*SearchRequest) (*SearchResult, error)
SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error)
}
2 changes: 2 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ func (l *Conn) nextMessageID() int64 {
// GetLastError returns the last recorded error from goroutines like processMessages and reader.
// Only the last recorded error will be returned.
func (l *Conn) GetLastError() error {
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
return l.err
}

Expand Down
30 changes: 30 additions & 0 deletions examples_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -50,6 +51,35 @@ func ExampleConn_Search() {
}
}

// This example demonstrates how to search asynchronously
func ExampleConn_SearchAsync() {
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))
if err != nil {
log.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
"dc=example,dc=com", // The base dn to search
ScopeWholeSubtree, NeverDerefAliases, 0, 0, false,
"(&(objectClass=organizationalPerson))", // The filter to apply
[]string{"dn", "cn"}, // A list attributes to retrieve
nil,
)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

r := l.SearchAsync(ctx, searchRequest, 64)
for r.Next() {
entry := r.Entry()
fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN)
}
if err := r.Err(); err != nil {
log.Fatal(err)
}
}

// This example demonstrates how to start a TLS connection
func ExampleConn_StartTLS() {
l, err := DialURL("ldap://ldap.example.com:389")
Expand Down
66 changes: 66 additions & 0 deletions ldap_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package ldap

import (
"context"
"crypto/tls"
"log"
"testing"

ber "github.com/go-asn1-ber/asn1-ber"
Expand Down Expand Up @@ -344,3 +346,67 @@ func TestEscapeDN(t *testing.T) {
})
}
}

func TestSearchAsync(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
nil)

srs := make([]*Entry, 0)
ctx := context.Background()
r := l.SearchAsync(ctx, searchRequest, 64)
for r.Next() {
srs = append(srs, r.Entry())
}
if err := r.Err(); err != nil {
log.Fatal(err)
}

t.Logf("TestSearcAsync: %s -> num of entries = %d", searchRequest.Filter, len(srs))
}

func TestSearchAsyncAndCancel(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
nil)

cancelNum := 10
srs := make([]*Entry, 0)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
r := l.SearchAsync(ctx, searchRequest, 0)
for r.Next() {
srs = append(srs, r.Entry())
if len(srs) == cancelNum {
cancel()
}
}
if err := r.Err(); err != nil {
log.Fatal(err)
}

if len(srs) > cancelNum+3 {
// the cancellation process is asynchronous,
// so it might get some entries after calling cancel()
t.Errorf("Got entries %d, expected < %d", len(srs), cancelNum+3)
}
t.Logf("TestSearchAsyncAndCancel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
}
182 changes: 182 additions & 0 deletions response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package ldap

import (
"context"
"errors"
"fmt"

ber "github.com/go-asn1-ber/asn1-ber"
)

// Response defines an interface to get data from an LDAP server
type Response interface {
Entry() *Entry
Referral() string
Controls() []Control
Err() error
Next() bool
}

type searchResponse struct {
conn *Conn
ch chan *SearchSingleResult

entry *Entry
referral string
controls []Control
err error
}

// Entry returns an entry from the given search request
func (r *searchResponse) Entry() *Entry {
return r.entry
}

// Referral returns a referral from the given search request
func (r *searchResponse) Referral() string {
return r.referral
}

// Controls returns controls from the given search request
func (r *searchResponse) Controls() []Control {
return r.controls
}

// Err returns an error when the given search request was failed
func (r *searchResponse) Err() error {
return r.err
}

// Next returns whether next data exist or not
func (r *searchResponse) Next() bool {
res, ok := <-r.ch
if !ok {
return false
}
if res == nil {
return false
}
r.err = res.Error
if r.err != nil {
return false
}
r.err = r.conn.GetLastError()
if r.err != nil {
return false
}
r.entry = res.Entry
r.referral = res.Referral
r.controls = res.Controls
return true
}

func (r *searchResponse) start(ctx context.Context, searchRequest *SearchRequest) {
go func() {
defer func() {
close(r.ch)
if err := recover(); err != nil {
r.conn.err = fmt.Errorf("ldap: recovered panic in searchResponse: %v", err)
}
}()

if r.conn.IsClosing() {
return
}

packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID"))
// encode search request
err := searchRequest.appendTo(packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
r.conn.Debug.PrintPacket(packet)

msgCtx, err := r.conn.sendMessage(packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
defer r.conn.finishMessage(msgCtx)

foundSearchSingleResultDone := false
for !foundSearchSingleResultDone {
select {
case <-ctx.Done():
r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error())
return
default:
r.conn.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-msgCtx.responses
if !ok {
err := NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
r.ch <- &SearchSingleResult{Error: err}
return
}
packet, err = packetResponse.ReadPacket()
r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}

if r.conn.Debug {
if err := addLDAPDescriptions(packet); err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
ber.PrintPacket(packet)
}

switch packet.Children[1].Tag {
case ApplicationSearchResultEntry:
r.ch <- &SearchSingleResult{
Entry: &Entry{
DN: packet.Children[1].Children[0].Value.(string),
Attributes: unpackAttributes(packet.Children[1].Children[1].Children),
},
}

case ApplicationSearchResultDone:
if err := GetLDAPError(packet); err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
if len(packet.Children) == 3 {
result := &SearchSingleResult{}
for _, child := range packet.Children[2].Children {
decodedChild, err := DecodeControl(child)
if err != nil {
werr := fmt.Errorf("failed to decode child control: %w", err)
r.ch <- &SearchSingleResult{Error: werr}
return
}
result.Controls = append(result.Controls, decodedChild)
}
r.ch <- result
}
foundSearchSingleResultDone = true

case ApplicationSearchResultReference:
ref := packet.Children[1].Children[0].Value.(string)
r.ch <- &SearchSingleResult{Referral: ref}
}
}
}
r.conn.Debug.Printf("%d: returning", msgCtx.id)
}()
}

func newSearchResponse(conn *Conn, bufferSize int) *searchResponse {
var ch chan *SearchSingleResult
if bufferSize > 0 {
ch = make(chan *SearchSingleResult, bufferSize)
} else {
ch = make(chan *SearchSingleResult)
}
return &searchResponse{
conn: conn,
ch: ch,
}
}
34 changes: 34 additions & 0 deletions search.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"errors"
"fmt"
"reflect"
Expand Down Expand Up @@ -375,6 +376,28 @@ func (s *SearchResult) appendTo(r *SearchResult) {
r.Controls = append(r.Controls, s.Controls...)
}

// SearchSingleResult holds the server's single response to a search request
type SearchSingleResult struct {
// Entry is the returned entry
Entry *Entry
// Referral is the returned referral
Referral string
// Controls are the returned controls
Controls []Control
// Error is set when the search request was failed
Error error
}

// Print outputs a human-readable description
func (s *SearchSingleResult) Print() {
s.Entry.Print()
}

// PrettyPrint outputs a human-readable description with indenting
func (s *SearchSingleResult) PrettyPrint(indent int) {
s.Entry.PrettyPrint(indent)
}

// SearchRequest represents a search request to send to the server
type SearchRequest struct {
BaseDN string
Expand Down Expand Up @@ -559,6 +582,17 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
}
}

// SearchAsync performs a search request and returns all search results asynchronously.
// This means you get all results until an error happens (or the search successfully finished),
// e.g. for size / time limited requests all are recieved until the limit is reached.
// To stop the search, call cancel function returned context.
func (l *Conn) SearchAsync(
ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response {
r := newSearchResponse(l, bufferSize)
r.start(ctx, searchRequest)
return r
}

// unpackAttributes will extract all given LDAP attributes and it's values
// from the ber.Packet
func unpackAttributes(children []*ber.Packet) []*EntryAttribute {
Expand Down
Loading

0 comments on commit 7778a1c

Please sign in to comment.