From a9daeebe787c8b355e1522e764fc436e78cc98ab Mon Sep 17 00:00:00 2001 From: Tetsuya Morimoto Date: Mon, 5 Jun 2023 11:48:15 +0900 Subject: [PATCH] feat: add initial search async function with channel #341 --- v3/client.go | 3 + v3/examples_test.go | 29 ++++++++ v3/response.go | 172 ++++++++++++++++++++++++++++++++++++++++++++ v3/search.go | 11 +++ 4 files changed, 215 insertions(+) create mode 100644 v3/response.go diff --git a/v3/client.go b/v3/client.go index b438d254..cef2d91b 100644 --- a/v3/client.go +++ b/v3/client.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "time" ) @@ -32,6 +33,8 @@ type Client interface { PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) Search(*SearchRequest) (*SearchResult, error) + SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response + SearchWithChannel(ctx context.Context, searchRequest *SearchRequest, channelSize int) <-chan *SearchSingleResult SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error) } diff --git a/v3/examples_test.go b/v3/examples_test.go index 1de6c9be..46898abb 100644 --- a/v3/examples_test.go +++ b/v3/examples_test.go @@ -51,6 +51,35 @@ func ExampleConn_Search() { } } +// This example demonstrates how to search with channel +func ExampleConn_SearchAsync() { + l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + "dc=example,dc=com", // The base dn to search + ScopeWholeSubtree, NeverDerefAliases, 0, 0, false, + "(&(objectClass=organizationalPerson))", // The filter to apply + []string{"dn", "cn"}, // A list attributes to retrieve + nil, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r := l.SearchAsync(ctx, searchRequest, 64) + for r.Next() { + entry := r.Entry() + fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN) + } + if err := r.Err(); err != nil { + log.Fatal(err) + } +} + // This example demonstrates how to search with channel func ExampleConn_SearchWithChannel() { l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) diff --git a/v3/response.go b/v3/response.go new file mode 100644 index 00000000..3bfef84e --- /dev/null +++ b/v3/response.go @@ -0,0 +1,172 @@ +package ldap + +import ( + "context" + "errors" + "fmt" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// Response defines an interface to get data from an LDAP server +type Response interface { + Entry() *Entry + Referral() string + Controls() []Control + Err() error + Next() bool +} + +type searchResponse struct { + conn *Conn + ch chan *SearchSingleResult + + entry *Entry + referral string + controls []Control + err error +} + +// Entry returns an entry from the given search request +func (r *searchResponse) Entry() *Entry { + return r.entry +} + +// Referral returns a referral from the given search request +func (r *searchResponse) Referral() string { + return r.referral +} + +// Controls returns controls from the given search request +func (r *searchResponse) Controls() []Control { + return r.controls +} + +// Err returns an error when the given search request was failed +func (r *searchResponse) Err() error { + return r.err +} + +// Next returns whether next data exist or not +func (r *searchResponse) Next() bool { + res := <-r.ch + if res == nil { + return false + } + r.err = res.Error + if r.err != nil { + return false + } + r.err = r.conn.GetLastError() + if r.err != nil { + return false + } + r.entry = res.Entry + r.referral = res.Referral + r.controls = res.Controls + return true +} + +func (r *searchResponse) searchAsync( + ctx context.Context, searchRequest *SearchRequest, bufferSize int) { + if bufferSize > 0 { + r.ch = make(chan *SearchSingleResult, bufferSize) + } else { + r.ch = make(chan *SearchSingleResult) + } + go func() { + defer func() { + close(r.ch) + if err := recover(); err != nil { + r.conn.err = fmt.Errorf("ldap: recovered panic in searchAsync: %v", err) + } + }() + + if r.conn.IsClosing() { + return + } + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID")) + // encode search request + err := searchRequest.appendTo(packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + r.conn.Debug.PrintPacket(packet) + + msgCtx, err := r.conn.sendMessage(packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + defer r.conn.finishMessage(msgCtx) + + foundSearchSingleResultDone := false + for !foundSearchSingleResultDone { + select { + case <-ctx.Done(): + r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error()) + return + default: + r.conn.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + err := NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + r.ch <- &SearchSingleResult{Error: err} + return + } + packet, err = packetResponse.ReadPacket() + r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + + if r.conn.Debug { + if err := addLDAPDescriptions(packet); err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + ber.PrintPacket(packet) + } + + switch packet.Children[1].Tag { + case ApplicationSearchResultEntry: + r.ch <- &SearchSingleResult{ + Entry: &Entry{ + DN: packet.Children[1].Children[0].Value.(string), + Attributes: unpackAttributes(packet.Children[1].Children[1].Children), + }, + } + + case ApplicationSearchResultDone: + if err := GetLDAPError(packet); err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + if len(packet.Children) == 3 { + result := &SearchSingleResult{} + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + werr := fmt.Errorf("failed to decode child control: %w", err) + r.ch <- &SearchSingleResult{Error: werr} + return + } + result.Controls = append(result.Controls, decodedChild) + } + r.ch <- result + } + foundSearchSingleResultDone = true + + case ApplicationSearchResultReference: + ref := packet.Children[1].Children[0].Value.(string) + r.ch <- &SearchSingleResult{Referral: ref} + } + } + } + r.conn.Debug.Printf("%d: returning", msgCtx.id) + }() +} diff --git a/v3/search.go b/v3/search.go index f3edcd70..2d8e13ad 100644 --- a/v3/search.go +++ b/v3/search.go @@ -584,6 +584,17 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } } +// SearchAsync performs a search request and returns all search results asynchronously. +// This means you get all results until an error happens (or the search successfully finished), +// e.g. for size / time limited requests all are recieved until the limit is reached. +// To stop the search, call cancel function returned context. +func (l *Conn) SearchAsync( + ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response { + r := &searchResponse{conn: l} + r.searchAsync(ctx, searchRequest, bufferSize) + return r +} + // SearchWithChannel performs a search request and returns all search results // via the returned channel as soon as they are received. This means you get // all results until an error happens (or the search successfully finished),