diff --git a/client.go b/client.go index b438d254..5799f39b 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "time" ) @@ -32,6 +33,7 @@ type Client interface { PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) Search(*SearchRequest) (*SearchResult, error) + SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error) } diff --git a/examples_test.go b/examples_test.go index 1de6c9be..61f16197 100644 --- a/examples_test.go +++ b/examples_test.go @@ -51,8 +51,8 @@ func ExampleConn_Search() { } } -// This example demonstrates how to search with channel -func ExampleConn_SearchWithChannel() { +// This example demonstrates how to search asynchronously +func ExampleConn_SearchAsync() { l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) if err != nil { log.Fatal(err) @@ -70,12 +70,13 @@ func ExampleConn_SearchWithChannel() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ch := l.SearchWithChannel(ctx, searchRequest, 64) - for res := range ch { - if res.Error != nil { - log.Fatalf("Error searching: %s", res.Error) - } - fmt.Printf("%s has DN %s\n", res.Entry.GetAttributeValue("cn"), res.Entry.DN) + r := l.SearchAsync(ctx, searchRequest, 64) + for r.Next() { + entry := r.Entry() + fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN) + } + if err := r.Err(); err != nil { + log.Fatal(err) } } diff --git a/ldap_test.go b/ldap_test.go index bbeecc9d..5b96e039 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -3,6 +3,7 @@ package ldap import ( "context" "crypto/tls" + "log" "testing" ber "github.com/go-asn1-ber/asn1-ber" @@ -346,7 +347,7 @@ func TestEscapeDN(t *testing.T) { } } -func TestSearchWithChannel(t *testing.T) { +func TestSearchAsync(t *testing.T) { l, err := DialURL(ldapServer) if err != nil { t.Fatal(err) @@ -362,17 +363,18 @@ func TestSearchWithChannel(t *testing.T) { srs := make([]*Entry, 0) ctx := context.Background() - for sr := range l.SearchWithChannel(ctx, searchRequest, 64) { - if sr.Error != nil { - t.Fatal(err) - } - srs = append(srs, sr.Entry) + r := l.SearchAsync(ctx, searchRequest, 64) + for r.Next() { + srs = append(srs, r.Entry()) + } + if err := r.Err(); err != nil { + log.Fatal(err) } - t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) + t.Logf("TestSearcAsync: %s -> num of entries = %d", searchRequest.Filter, len(srs)) } -func TestSearchWithChannelAndCancel(t *testing.T) { +func TestSearchAsyncAndCancel(t *testing.T) { l, err := DialURL(ldapServer) if err != nil { t.Fatal(err) @@ -390,22 +392,21 @@ func TestSearchWithChannelAndCancel(t *testing.T) { srs := make([]*Entry, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ch := l.SearchWithChannel(ctx, searchRequest, 0) - for i := 0; i < 10; i++ { - sr := <-ch - if sr.Error != nil { - t.Fatal(err) - } - srs = append(srs, sr.Entry) + r := l.SearchAsync(ctx, searchRequest, 0) + for r.Next() { + srs = append(srs, r.Entry()) if len(srs) == cancelNum { cancel() } } - for range ch { - t.Log("Consume all entries from the channel to prevent blocking by the connection") + if err := r.Err(); err != nil { + log.Fatal(err) } - if len(srs) != cancelNum { - t.Errorf("Got entries %d, expected %d", len(srs), cancelNum) + + if len(srs) > cancelNum+3 { + // the cancellation process is asynchronous, + // so it might get some entries after calling cancel() + t.Errorf("Got entries %d, expected < %d", len(srs), cancelNum+3) } - t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) + t.Logf("TestSearchAsyncAndCancel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) } diff --git a/response.go b/response.go new file mode 100644 index 00000000..81d97d9b --- /dev/null +++ b/response.go @@ -0,0 +1,182 @@ +package ldap + +import ( + "context" + "errors" + "fmt" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// Response defines an interface to get data from an LDAP server +type Response interface { + Entry() *Entry + Referral() string + Controls() []Control + Err() error + Next() bool +} + +type searchResponse struct { + conn *Conn + ch chan *SearchSingleResult + + entry *Entry + referral string + controls []Control + err error +} + +// Entry returns an entry from the given search request +func (r *searchResponse) Entry() *Entry { + return r.entry +} + +// Referral returns a referral from the given search request +func (r *searchResponse) Referral() string { + return r.referral +} + +// Controls returns controls from the given search request +func (r *searchResponse) Controls() []Control { + return r.controls +} + +// Err returns an error when the given search request was failed +func (r *searchResponse) Err() error { + return r.err +} + +// Next returns whether next data exist or not +func (r *searchResponse) Next() bool { + res, ok := <-r.ch + if !ok { + return false + } + if res == nil { + return false + } + r.err = res.Error + if r.err != nil { + return false + } + r.err = r.conn.GetLastError() + if r.err != nil { + return false + } + r.entry = res.Entry + r.referral = res.Referral + r.controls = res.Controls + return true +} + +func (r *searchResponse) start(ctx context.Context, searchRequest *SearchRequest) { + go func() { + defer func() { + close(r.ch) + if err := recover(); err != nil { + r.conn.err = fmt.Errorf("ldap: recovered panic in searchResponse: %v", err) + } + }() + + if r.conn.IsClosing() { + return + } + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID")) + // encode search request + err := searchRequest.appendTo(packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + r.conn.Debug.PrintPacket(packet) + + msgCtx, err := r.conn.sendMessage(packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + defer r.conn.finishMessage(msgCtx) + + foundSearchSingleResultDone := false + for !foundSearchSingleResultDone { + select { + case <-ctx.Done(): + r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error()) + return + default: + r.conn.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + err := NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + r.ch <- &SearchSingleResult{Error: err} + return + } + packet, err = packetResponse.ReadPacket() + r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + + if r.conn.Debug { + if err := addLDAPDescriptions(packet); err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + ber.PrintPacket(packet) + } + + switch packet.Children[1].Tag { + case ApplicationSearchResultEntry: + r.ch <- &SearchSingleResult{ + Entry: &Entry{ + DN: packet.Children[1].Children[0].Value.(string), + Attributes: unpackAttributes(packet.Children[1].Children[1].Children), + }, + } + + case ApplicationSearchResultDone: + if err := GetLDAPError(packet); err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + if len(packet.Children) == 3 { + result := &SearchSingleResult{} + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + werr := fmt.Errorf("failed to decode child control: %w", err) + r.ch <- &SearchSingleResult{Error: werr} + return + } + result.Controls = append(result.Controls, decodedChild) + } + r.ch <- result + } + foundSearchSingleResultDone = true + + case ApplicationSearchResultReference: + ref := packet.Children[1].Children[0].Value.(string) + r.ch <- &SearchSingleResult{Referral: ref} + } + } + } + r.conn.Debug.Printf("%d: returning", msgCtx.id) + }() +} + +func newSearchResponse(conn *Conn, bufferSize int) *searchResponse { + var ch chan *SearchSingleResult + if bufferSize > 0 { + ch = make(chan *SearchSingleResult, bufferSize) + } else { + ch = make(chan *SearchSingleResult) + } + return &searchResponse{ + conn: conn, + ch: ch, + } +} diff --git a/search.go b/search.go index d2961947..3d8d9e70 100644 --- a/search.go +++ b/search.go @@ -582,114 +582,15 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } } -// SearchWithChannel performs a search request and returns all search results -// via the returned channel as soon as they are received. This means you get -// all results until an error happens (or the search successfully finished), -// e.g. for size / time limited requests all are recieved via the channel -// until the limit is reached. -func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest, channelSize int) <-chan *SearchSingleResult { - var ch chan *SearchSingleResult - if channelSize > 0 { - ch = make(chan *SearchSingleResult, channelSize) - } else { - ch = make(chan *SearchSingleResult) - } - go func() { - defer func() { - close(ch) - if err := recover(); err != nil { - l.err = fmt.Errorf("ldap: recovered panic in SearchWithChannel: %v", err) - } - }() - - if l.IsClosing() { - return - } - - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - // encode search request - err := searchRequest.appendTo(packet) - if err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) - if err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - defer l.finishMessage(msgCtx) - - foundSearchSingleResultDone := false - for !foundSearchSingleResultDone { - select { - case <-ctx.Done(): - l.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error()) - return - default: - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - err := NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - ch <- &SearchSingleResult{Error: err} - return - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) - if err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - ber.PrintPacket(packet) - } - - switch packet.Children[1].Tag { - case ApplicationSearchResultEntry: - ch <- &SearchSingleResult{ - Entry: &Entry{ - DN: packet.Children[1].Children[0].Value.(string), - Attributes: unpackAttributes(packet.Children[1].Children[1].Children), - }, - } - - case ApplicationSearchResultDone: - if err := GetLDAPError(packet); err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - if len(packet.Children) == 3 { - result := &SearchSingleResult{} - for _, child := range packet.Children[2].Children { - decodedChild, err := DecodeControl(child) - if err != nil { - werr := fmt.Errorf("failed to decode child control: %w", err) - ch <- &SearchSingleResult{Error: werr} - return - } - result.Controls = append(result.Controls, decodedChild) - } - ch <- result - } - foundSearchSingleResultDone = true - - case ApplicationSearchResultReference: - ref := packet.Children[1].Children[0].Value.(string) - ch <- &SearchSingleResult{Referral: ref} - } - } - } - l.Debug.Printf("%d: returning", msgCtx.id) - }() - return ch +// SearchAsync performs a search request and returns all search results asynchronously. +// This means you get all results until an error happens (or the search successfully finished), +// e.g. for size / time limited requests all are recieved until the limit is reached. +// To stop the search, call cancel function returned context. +func (l *Conn) SearchAsync( + ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response { + r := newSearchResponse(l, bufferSize) + r.start(ctx, searchRequest) + return r } // unpackAttributes will extract all given LDAP attributes and it's values diff --git a/v3/client.go b/v3/client.go index cef2d91b..5799f39b 100644 --- a/v3/client.go +++ b/v3/client.go @@ -34,7 +34,6 @@ type Client interface { Search(*SearchRequest) (*SearchResult, error) SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response - SearchWithChannel(ctx context.Context, searchRequest *SearchRequest, channelSize int) <-chan *SearchSingleResult SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error) } diff --git a/v3/examples_test.go b/v3/examples_test.go index 46898abb..61f16197 100644 --- a/v3/examples_test.go +++ b/v3/examples_test.go @@ -51,7 +51,7 @@ func ExampleConn_Search() { } } -// This example demonstrates how to search with channel +// This example demonstrates how to search asynchronously func ExampleConn_SearchAsync() { l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) if err != nil { @@ -80,34 +80,6 @@ func ExampleConn_SearchAsync() { } } -// This example demonstrates how to search with channel -func ExampleConn_SearchWithChannel() { - l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) - if err != nil { - log.Fatal(err) - } - defer l.Close() - - searchRequest := NewSearchRequest( - "dc=example,dc=com", // The base dn to search - ScopeWholeSubtree, NeverDerefAliases, 0, 0, false, - "(&(objectClass=organizationalPerson))", // The filter to apply - []string{"dn", "cn"}, // A list attributes to retrieve - nil, - ) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ch := l.SearchWithChannel(ctx, searchRequest, 64) - for res := range ch { - if res.Error != nil { - log.Fatalf("Error searching: %s", res.Error) - } - fmt.Printf("%s has DN %s\n", res.Entry.GetAttributeValue("cn"), res.Entry.DN) - } -} - // This example demonstrates how to start a TLS connection func ExampleConn_StartTLS() { l, err := DialURL("ldap://ldap.example.com:389") diff --git a/v3/ldap_test.go b/v3/ldap_test.go index bbeecc9d..5b96e039 100644 --- a/v3/ldap_test.go +++ b/v3/ldap_test.go @@ -3,6 +3,7 @@ package ldap import ( "context" "crypto/tls" + "log" "testing" ber "github.com/go-asn1-ber/asn1-ber" @@ -346,7 +347,7 @@ func TestEscapeDN(t *testing.T) { } } -func TestSearchWithChannel(t *testing.T) { +func TestSearchAsync(t *testing.T) { l, err := DialURL(ldapServer) if err != nil { t.Fatal(err) @@ -362,17 +363,18 @@ func TestSearchWithChannel(t *testing.T) { srs := make([]*Entry, 0) ctx := context.Background() - for sr := range l.SearchWithChannel(ctx, searchRequest, 64) { - if sr.Error != nil { - t.Fatal(err) - } - srs = append(srs, sr.Entry) + r := l.SearchAsync(ctx, searchRequest, 64) + for r.Next() { + srs = append(srs, r.Entry()) + } + if err := r.Err(); err != nil { + log.Fatal(err) } - t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) + t.Logf("TestSearcAsync: %s -> num of entries = %d", searchRequest.Filter, len(srs)) } -func TestSearchWithChannelAndCancel(t *testing.T) { +func TestSearchAsyncAndCancel(t *testing.T) { l, err := DialURL(ldapServer) if err != nil { t.Fatal(err) @@ -390,22 +392,21 @@ func TestSearchWithChannelAndCancel(t *testing.T) { srs := make([]*Entry, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ch := l.SearchWithChannel(ctx, searchRequest, 0) - for i := 0; i < 10; i++ { - sr := <-ch - if sr.Error != nil { - t.Fatal(err) - } - srs = append(srs, sr.Entry) + r := l.SearchAsync(ctx, searchRequest, 0) + for r.Next() { + srs = append(srs, r.Entry()) if len(srs) == cancelNum { cancel() } } - for range ch { - t.Log("Consume all entries from the channel to prevent blocking by the connection") + if err := r.Err(); err != nil { + log.Fatal(err) } - if len(srs) != cancelNum { - t.Errorf("Got entries %d, expected %d", len(srs), cancelNum) + + if len(srs) > cancelNum+3 { + // the cancellation process is asynchronous, + // so it might get some entries after calling cancel() + t.Errorf("Got entries %d, expected < %d", len(srs), cancelNum+3) } - t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) + t.Logf("TestSearchAsyncAndCancel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) } diff --git a/v3/response.go b/v3/response.go index 3bfef84e..81d97d9b 100644 --- a/v3/response.go +++ b/v3/response.go @@ -49,7 +49,10 @@ func (r *searchResponse) Err() error { // Next returns whether next data exist or not func (r *searchResponse) Next() bool { - res := <-r.ch + res, ok := <-r.ch + if !ok { + return false + } if res == nil { return false } @@ -67,18 +70,12 @@ func (r *searchResponse) Next() bool { return true } -func (r *searchResponse) searchAsync( - ctx context.Context, searchRequest *SearchRequest, bufferSize int) { - if bufferSize > 0 { - r.ch = make(chan *SearchSingleResult, bufferSize) - } else { - r.ch = make(chan *SearchSingleResult) - } +func (r *searchResponse) start(ctx context.Context, searchRequest *SearchRequest) { go func() { defer func() { close(r.ch) if err := recover(); err != nil { - r.conn.err = fmt.Errorf("ldap: recovered panic in searchAsync: %v", err) + r.conn.err = fmt.Errorf("ldap: recovered panic in searchResponse: %v", err) } }() @@ -170,3 +167,16 @@ func (r *searchResponse) searchAsync( r.conn.Debug.Printf("%d: returning", msgCtx.id) }() } + +func newSearchResponse(conn *Conn, bufferSize int) *searchResponse { + var ch chan *SearchSingleResult + if bufferSize > 0 { + ch = make(chan *SearchSingleResult, bufferSize) + } else { + ch = make(chan *SearchSingleResult) + } + return &searchResponse{ + conn: conn, + ch: ch, + } +} diff --git a/v3/search.go b/v3/search.go index 2d8e13ad..afac768c 100644 --- a/v3/search.go +++ b/v3/search.go @@ -378,7 +378,7 @@ func (s *SearchResult) appendTo(r *SearchResult) { r.Controls = append(r.Controls, s.Controls...) } -// SearchSingleResult holds the server's single response to a search request +// SearchSingleResult holds the server's single entry response to a search request type SearchSingleResult struct { // Entry is the returned entry Entry *Entry @@ -590,121 +590,11 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { // To stop the search, call cancel function returned context. func (l *Conn) SearchAsync( ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response { - r := &searchResponse{conn: l} - r.searchAsync(ctx, searchRequest, bufferSize) + r := newSearchResponse(l, bufferSize) + r.start(ctx, searchRequest) return r } -// SearchWithChannel performs a search request and returns all search results -// via the returned channel as soon as they are received. This means you get -// all results until an error happens (or the search successfully finished), -// e.g. for size / time limited requests all are recieved via the channel -// until the limit is reached. -func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest, channelSize int) <-chan *SearchSingleResult { - var ch chan *SearchSingleResult - if channelSize > 0 { - ch = make(chan *SearchSingleResult, channelSize) - } else { - ch = make(chan *SearchSingleResult) - } - go func() { - defer func() { - close(ch) - if err := recover(); err != nil { - l.err = fmt.Errorf("ldap: recovered panic in SearchWithChannel: %v", err) - } - }() - - if l.IsClosing() { - return - } - - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - // encode search request - err := searchRequest.appendTo(packet) - if err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) - if err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - defer l.finishMessage(msgCtx) - - foundSearchSingleResultDone := false - for !foundSearchSingleResultDone { - select { - case <-ctx.Done(): - l.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error()) - return - default: - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - err := NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - ch <- &SearchSingleResult{Error: err} - return - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) - if err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - ber.PrintPacket(packet) - } - - switch packet.Children[1].Tag { - case ApplicationSearchResultEntry: - ch <- &SearchSingleResult{ - Entry: &Entry{ - DN: packet.Children[1].Children[0].Value.(string), - Attributes: unpackAttributes(packet.Children[1].Children[1].Children), - }, - } - - case ApplicationSearchResultDone: - if err := GetLDAPError(packet); err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - if len(packet.Children) == 3 { - result := &SearchSingleResult{} - for _, child := range packet.Children[2].Children { - decodedChild, err := DecodeControl(child) - if err != nil { - werr := fmt.Errorf("failed to decode child control: %w", err) - ch <- &SearchSingleResult{Error: werr} - return - } - result.Controls = append(result.Controls, decodedChild) - } - ch <- result - } - foundSearchSingleResultDone = true - - case ApplicationSearchResultReference: - ref := packet.Children[1].Children[0].Value.(string) - ch <- &SearchSingleResult{Referral: ref} - } - } - } - l.Debug.Printf("%d: returning", msgCtx.id) - }() - return ch -} - // unpackAttributes will extract all given LDAP attributes and it's values // from the ber.Packet func unpackAttributes(children []*ber.Packet) []*EntryAttribute {