Skip to content

Commit

Permalink
feat: add initial search async function with channel #341
Browse files Browse the repository at this point in the history
  • Loading branch information
t2y committed Jun 5, 2023
1 parent 7279710 commit a9daeeb
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 0 deletions.
3 changes: 3 additions & 0 deletions v3/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"crypto/tls"
"time"
)
Expand Down Expand Up @@ -32,6 +33,8 @@ type Client interface {
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)

Search(*SearchRequest) (*SearchResult, error)
SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response
SearchWithChannel(ctx context.Context, searchRequest *SearchRequest, channelSize int) <-chan *SearchSingleResult
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error)
}
29 changes: 29 additions & 0 deletions v3/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,35 @@ func ExampleConn_Search() {
}
}

// This example demonstrates how to search with channel
func ExampleConn_SearchAsync() {
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))
if err != nil {
log.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
"dc=example,dc=com", // The base dn to search
ScopeWholeSubtree, NeverDerefAliases, 0, 0, false,
"(&(objectClass=organizationalPerson))", // The filter to apply
[]string{"dn", "cn"}, // A list attributes to retrieve
nil,
)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

r := l.SearchAsync(ctx, searchRequest, 64)
for r.Next() {
entry := r.Entry()
fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN)
}
if err := r.Err(); err != nil {
log.Fatal(err)
}
}

// This example demonstrates how to search with channel
func ExampleConn_SearchWithChannel() {
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))
Expand Down
172 changes: 172 additions & 0 deletions v3/response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package ldap

import (
"context"
"errors"
"fmt"

ber "github.com/go-asn1-ber/asn1-ber"
)

// Response defines an interface to get data from an LDAP server
type Response interface {
Entry() *Entry
Referral() string
Controls() []Control
Err() error
Next() bool
}

type searchResponse struct {
conn *Conn
ch chan *SearchSingleResult

entry *Entry
referral string
controls []Control
err error
}

// Entry returns an entry from the given search request
func (r *searchResponse) Entry() *Entry {
return r.entry
}

// Referral returns a referral from the given search request
func (r *searchResponse) Referral() string {
return r.referral
}

// Controls returns controls from the given search request
func (r *searchResponse) Controls() []Control {
return r.controls
}

// Err returns an error when the given search request was failed
func (r *searchResponse) Err() error {
return r.err
}

// Next returns whether next data exist or not
func (r *searchResponse) Next() bool {
res := <-r.ch
if res == nil {
return false
}
r.err = res.Error
if r.err != nil {
return false
}
r.err = r.conn.GetLastError()
if r.err != nil {
return false
}
r.entry = res.Entry
r.referral = res.Referral
r.controls = res.Controls
return true
}

func (r *searchResponse) searchAsync(
ctx context.Context, searchRequest *SearchRequest, bufferSize int) {
if bufferSize > 0 {
r.ch = make(chan *SearchSingleResult, bufferSize)
} else {
r.ch = make(chan *SearchSingleResult)
}
go func() {
defer func() {
close(r.ch)
if err := recover(); err != nil {
r.conn.err = fmt.Errorf("ldap: recovered panic in searchAsync: %v", err)
}
}()

if r.conn.IsClosing() {
return
}

packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID"))
// encode search request
err := searchRequest.appendTo(packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
r.conn.Debug.PrintPacket(packet)

msgCtx, err := r.conn.sendMessage(packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
defer r.conn.finishMessage(msgCtx)

foundSearchSingleResultDone := false
for !foundSearchSingleResultDone {
select {
case <-ctx.Done():
r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error())
return
default:
r.conn.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-msgCtx.responses
if !ok {
err := NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
r.ch <- &SearchSingleResult{Error: err}
return
}
packet, err = packetResponse.ReadPacket()
r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}

if r.conn.Debug {
if err := addLDAPDescriptions(packet); err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
ber.PrintPacket(packet)
}

switch packet.Children[1].Tag {
case ApplicationSearchResultEntry:
r.ch <- &SearchSingleResult{
Entry: &Entry{
DN: packet.Children[1].Children[0].Value.(string),
Attributes: unpackAttributes(packet.Children[1].Children[1].Children),
},
}

case ApplicationSearchResultDone:
if err := GetLDAPError(packet); err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
if len(packet.Children) == 3 {
result := &SearchSingleResult{}
for _, child := range packet.Children[2].Children {
decodedChild, err := DecodeControl(child)
if err != nil {
werr := fmt.Errorf("failed to decode child control: %w", err)
r.ch <- &SearchSingleResult{Error: werr}
return
}
result.Controls = append(result.Controls, decodedChild)
}
r.ch <- result
}
foundSearchSingleResultDone = true

case ApplicationSearchResultReference:
ref := packet.Children[1].Children[0].Value.(string)
r.ch <- &SearchSingleResult{Referral: ref}
}
}
}
r.conn.Debug.Printf("%d: returning", msgCtx.id)
}()
}
11 changes: 11 additions & 0 deletions v3/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,17 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
}
}

// SearchAsync performs a search request and returns all search results asynchronously.
// This means you get all results until an error happens (or the search successfully finished),
// e.g. for size / time limited requests all are recieved until the limit is reached.
// To stop the search, call cancel function returned context.
func (l *Conn) SearchAsync(
ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response {
r := &searchResponse{conn: l}
r.searchAsync(ctx, searchRequest, bufferSize)
return r
}

// SearchWithChannel performs a search request and returns all search results
// via the returned channel as soon as they are received. This means you get
// all results until an error happens (or the search successfully finished),
Expand Down

0 comments on commit a9daeeb

Please sign in to comment.