Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add search asynchronously with context #440

Merged
merged 10 commits into from
Jun 30, 2023
29 changes: 29 additions & 0 deletions examples_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -50,6 +51,34 @@ func ExampleConn_Search() {
}
}

// This example demonstrates how to search with channel
func ExampleConn_SearchWithChannel() {
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))
if err != nil {
log.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
"dc=example,dc=com", // The base dn to search
ScopeWholeSubtree, NeverDerefAliases, 0, 0, false,
"(&(objectClass=organizationalPerson))", // The filter to apply
[]string{"dn", "cn"}, // A list attributes to retrieve
nil,
)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ch := l.SearchWithChannel(ctx, searchRequest)
for res := range ch {
if res.Error != nil {
log.Fatalf("Error searching: %s", res.Error)
}
fmt.Printf("%s has DN %s\n", res.Entry.GetAttributeValue("cn"), res.Entry.DN)
}
}

// This example demonstrates how to start a TLS connection
func ExampleConn_StartTLS() {
l, err := DialURL("ldap://ldap.example.com:389")
Expand Down
62 changes: 62 additions & 0 deletions ldap_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"crypto/tls"
"testing"

Expand Down Expand Up @@ -344,3 +345,64 @@ func TestEscapeDN(t *testing.T) {
})
}
}

func TestSearchWithChannel(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
nil)

srs := make([]*Entry, 0)
ctx := context.Background()
for sr := range l.SearchWithChannel(ctx, searchRequest) {
if sr.Error != nil {
t.Fatal(err)
}
srs = append(srs, sr.Entry)
}

t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
}

func TestSearchWithChannelAndCancel(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
nil)

cancelNum := 10
srs := make([]*Entry, 0)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for sr := range l.SearchWithChannel(ctx, searchRequest) {
if sr.Error != nil {
t.Fatal(err)
}
srs = append(srs, sr.Entry)
if len(srs) == cancelNum {
cancel()
}
}
if len(srs) > cancelNum+2 {
// The cancel process is asynchronous,
// so a few entries after it canceled might be received
t.Errorf("Got entries %d, expected less than %d", len(srs), cancelNum+2)
}
t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
}
128 changes: 128 additions & 0 deletions search.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"errors"
"fmt"
"reflect"
Expand Down Expand Up @@ -375,6 +376,28 @@ func (s *SearchResult) appendTo(r *SearchResult) {
r.Controls = append(r.Controls, s.Controls...)
}

// SearchSingleResult holds the server's single response to a search request
type SearchSingleResult struct {
// Entry is the returned entry
Entry *Entry
// Referral is the returned referral
Referral string
// Controls are the returned controls
Controls []Control
// Error is set when the search request was failed
Error error
}

// Print outputs a human-readable description
func (s *SearchSingleResult) Print() {
s.Entry.Print()
}

// PrettyPrint outputs a human-readable description with indenting
func (s *SearchSingleResult) PrettyPrint(indent int) {
s.Entry.PrettyPrint(indent)
}

// SearchRequest represents a search request to send to the server
type SearchRequest struct {
BaseDN string
Expand Down Expand Up @@ -559,6 +582,111 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
}
}

// SearchWithChannel performs a search request and returns all search results
// via the returned channel as soon as they are received. This means you get
// all results until an error happens (or the search successfully finished),
// e.g. for size / time limited requests all are recieved via the channel
// until the limit is reached.
func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest) chan *SearchSingleResult {
ch := make(chan *SearchSingleResult)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am conflicted about the channels, already with the first pull request. I would like to leave as many options open to the developer, such as determining the channel size. A suggestion and I would ask for general feedback here:

An additional argument for the function channelSize. If channelSize is 0, a normal channel without a certain size is created like now. If the value is greater than 0, a buffered channel with the defined size is created, e.g.

var chan *SearchSingleResult
if channelSize > 0 {
    chan = make(chan *SearchSingleResult, channelSize)
} else {
    chan = make(chan *SearchSingleResult)
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense. Also, SearchWithPaging takes pagingSize, so it seems it's no wonder API SearchWithChannel takes channelSize.

func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I confirmed a deadlock issue using a channel with zero buffer size sometimes. Asynchronous is difficult.

Copy link
Contributor Author

@t2y t2y Jun 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it occurs running with the below order.

caller

  • 1: result <- ch
  • 4: cancel() // never cancel when the search is blocked to write the next result
  • 5: result <- ch // block

search

  • 2: ch <- result
  • 3: ch <- result // block

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess another deadlock issue with zero buffer size is here: https://github.com/go-ldap/ldap/actions/runs/5167123452/jobs/9307800176?pr=440

client

3 defer l.Close() //block

search

1: ch <- result
2: ch <- result // block

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To unlock the mutex, defer l.finishMessage(msgCtx) must be called. Got it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is not a bug on go-ldap code. So I fixed on caller's code.

go func() {
defer close(ch)
if l.IsClosing() {
return
}

packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
// encode search request
err := searchRequest.appendTo(packet)
if err != nil {
ch <- &SearchSingleResult{Error: err}
return
}
l.Debug.PrintPacket(packet)

msgCtx, err := l.sendMessage(packet)
if err != nil {
ch <- &SearchSingleResult{Error: err}
return
}
defer l.finishMessage(msgCtx)

foundSearchSingleResultDone := false
for !foundSearchSingleResultDone {
select {
case <-ctx.Done():
l.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error())
return
default:
l.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-msgCtx.responses
if !ok {
err := NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
ch <- &SearchSingleResult{Error: err}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These writes do not account for a closed channel, which would cause a panic. Since this goroutines does not have a recover in the deferred function, the program would crash.

If we stay with this design, I suggest to add a warning to the function's description to not close the channel outside the library.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good review! I changed the returned channel to receive-only. By doing this, the compiler rejects when the caller closes it.

$ go run main.go 
# command-line-arguments
./main.go:62:10: invalid operation: cannot close receive-only channel ch (variable of type <-chan *ldap.SearchSingleResult)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In any case, I learned recover is needed to avoid panic from https://github.com/go-ldap/ldap/actions/runs/5166940415/jobs/9307492985.

return
}
packet, err = packetResponse.ReadPacket()
l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil {
ch <- &SearchSingleResult{Error: err}
return
}

if l.Debug {
if err := addLDAPDescriptions(packet); err != nil {
ch <- &SearchSingleResult{Error: err}
return
}
ber.PrintPacket(packet)
}

switch packet.Children[1].Tag {
case ApplicationSearchResultEntry:
entry := new(Entry)
entry.DN = packet.Children[1].Children[0].Value.(string)
for _, child := range packet.Children[1].Children[1].Children {
cpuschma marked this conversation as resolved.
Show resolved Hide resolved
attr := new(EntryAttribute)
attr.Name = child.Children[0].Value.(string)
for _, value := range child.Children[1].Children {
attr.Values = append(attr.Values, value.Value.(string))
attr.ByteValues = append(attr.ByteValues, value.ByteValue)
}
entry.Attributes = append(entry.Attributes, attr)
}
ch <- &SearchSingleResult{Entry: entry}

case ApplicationSearchResultDone:
if err := GetLDAPError(packet); err != nil {
ch <- &SearchSingleResult{Error: err}
return
}
if len(packet.Children) == 3 {
result := &SearchSingleResult{}
for _, child := range packet.Children[2].Children {
decodedChild, err := DecodeControl(child)
if err != nil {
werr := fmt.Errorf("failed to decode child control: %w", err)
ch <- &SearchSingleResult{Error: werr}
return
}
result.Controls = append(result.Controls, decodedChild)
}
ch <- result
}
foundSearchSingleResultDone = true

case ApplicationSearchResultReference:
ref := packet.Children[1].Children[0].Value.(string)
ch <- &SearchSingleResult{Referral: ref}
}
}
}
l.Debug.Printf("%d: returning", msgCtx.id)
}()
return ch
}

// unpackAttributes will extract all given LDAP attributes and it's values
// from the ber.Packet
func unpackAttributes(children []*ber.Packet) []*EntryAttribute {
Expand Down
29 changes: 29 additions & 0 deletions v3/examples_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -50,6 +51,34 @@ func ExampleConn_Search() {
}
}

// This example demonstrates how to search with channel
func ExampleConn_SearchWithChannel() {
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))
if err != nil {
log.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
"dc=example,dc=com", // The base dn to search
ScopeWholeSubtree, NeverDerefAliases, 0, 0, false,
"(&(objectClass=organizationalPerson))", // The filter to apply
[]string{"dn", "cn"}, // A list attributes to retrieve
nil,
)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ch := l.SearchWithChannel(ctx, searchRequest)
for res := range ch {
if res.Error != nil {
log.Fatalf("Error searching: %s", res.Error)
}
fmt.Printf("%s has DN %s\n", res.Entry.GetAttributeValue("cn"), res.Entry.DN)
}
}

// This example demonstrates how to start a TLS connection
func ExampleConn_StartTLS() {
l, err := DialURL("ldap://ldap.example.com:389")
Expand Down
62 changes: 62 additions & 0 deletions v3/ldap_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"crypto/tls"
"testing"

Expand Down Expand Up @@ -344,3 +345,64 @@ func TestEscapeDN(t *testing.T) {
})
}
}

func TestSearchWithChannel(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
nil)

srs := make([]*Entry, 0)
ctx := context.Background()
for sr := range l.SearchWithChannel(ctx, searchRequest) {
if sr.Error != nil {
t.Fatal(err)
}
srs = append(srs, sr.Entry)
}

t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
}

func TestSearchWithChannelAndCancel(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
}
defer l.Close()

searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
nil)

cancelNum := 10
srs := make([]*Entry, 0)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for sr := range l.SearchWithChannel(ctx, searchRequest) {
if sr.Error != nil {
t.Fatal(err)
}
srs = append(srs, sr.Entry)
if len(srs) == cancelNum {
cancel()
}
}
if len(srs) > cancelNum+2 {
// The cancel process is asynchronous,
// so a few entries after it canceled might be received
t.Errorf("Got entries %d, expected less than %d", len(srs), cancelNum+2)
}
t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
}
Loading