Skip to content

Commit

Permalink
feat: provide search async function and drop search with channels go-…
Browse files Browse the repository at this point in the history
  • Loading branch information
t2y committed Jun 22, 2023
1 parent a9daeeb commit 2f623f0
Show file tree
Hide file tree
Showing 10 changed files with 267 additions and 308 deletions.
2 changes: 2 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ldap

import (
"context"
"crypto/tls"
"time"
)
Expand Down Expand Up @@ -32,6 +33,7 @@ type Client interface {
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)

Search(*SearchRequest) (*SearchResult, error)
SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error)
}
17 changes: 9 additions & 8 deletions examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ func ExampleConn_Search() {
}
}

// This example demonstrates how to search with channel
func ExampleConn_SearchWithChannel() {
// This example demonstrates how to search asynchronously
func ExampleConn_SearchAsync() {
l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389))
if err != nil {
log.Fatal(err)
Expand All @@ -70,12 +70,13 @@ func ExampleConn_SearchWithChannel() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ch := l.SearchWithChannel(ctx, searchRequest, 64)
for res := range ch {
if res.Error != nil {
log.Fatalf("Error searching: %s", res.Error)
}
fmt.Printf("%s has DN %s\n", res.Entry.GetAttributeValue("cn"), res.Entry.DN)
r := l.SearchAsync(ctx, searchRequest, 64)
for r.Next() {
entry := r.Entry()
fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN)
}
if err := r.Err(); err != nil {
log.Fatal(err)
}
}

Expand Down
41 changes: 21 additions & 20 deletions ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ldap
import (
"context"
"crypto/tls"
"log"
"testing"

ber "github.com/go-asn1-ber/asn1-ber"
Expand Down Expand Up @@ -346,7 +347,7 @@ func TestEscapeDN(t *testing.T) {
}
}

func TestSearchWithChannel(t *testing.T) {
func TestSearchAsync(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
Expand All @@ -362,17 +363,18 @@ func TestSearchWithChannel(t *testing.T) {

srs := make([]*Entry, 0)
ctx := context.Background()
for sr := range l.SearchWithChannel(ctx, searchRequest, 64) {
if sr.Error != nil {
t.Fatal(err)
}
srs = append(srs, sr.Entry)
r := l.SearchAsync(ctx, searchRequest, 64)
for r.Next() {
srs = append(srs, r.Entry())
}
if err := r.Err(); err != nil {
log.Fatal(err)
}

t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
t.Logf("TestSearcAsync: %s -> num of entries = %d", searchRequest.Filter, len(srs))
}

func TestSearchWithChannelAndCancel(t *testing.T) {
func TestSearchAsyncAndCancel(t *testing.T) {
l, err := DialURL(ldapServer)
if err != nil {
t.Fatal(err)
Expand All @@ -390,22 +392,21 @@ func TestSearchWithChannelAndCancel(t *testing.T) {
srs := make([]*Entry, 0)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch := l.SearchWithChannel(ctx, searchRequest, 0)
for i := 0; i < 10; i++ {
sr := <-ch
if sr.Error != nil {
t.Fatal(err)
}
srs = append(srs, sr.Entry)
r := l.SearchAsync(ctx, searchRequest, 0)
for r.Next() {
srs = append(srs, r.Entry())
if len(srs) == cancelNum {
cancel()
}
}
for range ch {
t.Log("Consume all entries from the channel to prevent blocking by the connection")
if err := r.Err(); err != nil {
log.Fatal(err)
}
if len(srs) != cancelNum {
t.Errorf("Got entries %d, expected %d", len(srs), cancelNum)

if len(srs) > cancelNum+3 {
// the cancellation process is asynchronous,
// so it might get some entries after calling cancel()
t.Errorf("Got entries %d, expected < %d", len(srs), cancelNum+3)
}
t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
t.Logf("TestSearchAsyncAndCancel: %s -> num of entries = %d", searchRequest.Filter, len(srs))
}
182 changes: 182 additions & 0 deletions response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package ldap

import (
"context"
"errors"
"fmt"

ber "github.com/go-asn1-ber/asn1-ber"
)

// Response defines an interface to get data from an LDAP server
type Response interface {
Entry() *Entry
Referral() string
Controls() []Control
Err() error
Next() bool
}

type searchResponse struct {
conn *Conn
ch chan *SearchSingleResult

entry *Entry
referral string
controls []Control
err error
}

// Entry returns an entry from the given search request
func (r *searchResponse) Entry() *Entry {
return r.entry
}

// Referral returns a referral from the given search request
func (r *searchResponse) Referral() string {
return r.referral
}

// Controls returns controls from the given search request
func (r *searchResponse) Controls() []Control {
return r.controls
}

// Err returns an error when the given search request was failed
func (r *searchResponse) Err() error {
return r.err
}

// Next returns whether next data exist or not
func (r *searchResponse) Next() bool {
res, ok := <-r.ch
if !ok {
return false
}
if res == nil {
return false
}
r.err = res.Error
if r.err != nil {
return false
}
r.err = r.conn.GetLastError()
if r.err != nil {
return false
}
r.entry = res.Entry
r.referral = res.Referral
r.controls = res.Controls
return true
}

func (r *searchResponse) start(ctx context.Context, searchRequest *SearchRequest) {
go func() {
defer func() {
close(r.ch)
if err := recover(); err != nil {
r.conn.err = fmt.Errorf("ldap: recovered panic in searchResponse: %v", err)
}
}()

if r.conn.IsClosing() {
return
}

packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID"))
// encode search request
err := searchRequest.appendTo(packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
r.conn.Debug.PrintPacket(packet)

msgCtx, err := r.conn.sendMessage(packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
defer r.conn.finishMessage(msgCtx)

foundSearchSingleResultDone := false
for !foundSearchSingleResultDone {
select {
case <-ctx.Done():
r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error())
return
default:
r.conn.Debug.Printf("%d: waiting for response", msgCtx.id)
packetResponse, ok := <-msgCtx.responses
if !ok {
err := NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
r.ch <- &SearchSingleResult{Error: err}
return
}
packet, err = packetResponse.ReadPacket()
r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet)
if err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}

if r.conn.Debug {
if err := addLDAPDescriptions(packet); err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
ber.PrintPacket(packet)
}

switch packet.Children[1].Tag {
case ApplicationSearchResultEntry:
r.ch <- &SearchSingleResult{
Entry: &Entry{
DN: packet.Children[1].Children[0].Value.(string),
Attributes: unpackAttributes(packet.Children[1].Children[1].Children),
},
}

case ApplicationSearchResultDone:
if err := GetLDAPError(packet); err != nil {
r.ch <- &SearchSingleResult{Error: err}
return
}
if len(packet.Children) == 3 {
result := &SearchSingleResult{}
for _, child := range packet.Children[2].Children {
decodedChild, err := DecodeControl(child)
if err != nil {
werr := fmt.Errorf("failed to decode child control: %w", err)
r.ch <- &SearchSingleResult{Error: werr}
return
}
result.Controls = append(result.Controls, decodedChild)
}
r.ch <- result
}
foundSearchSingleResultDone = true

case ApplicationSearchResultReference:
ref := packet.Children[1].Children[0].Value.(string)
r.ch <- &SearchSingleResult{Referral: ref}
}
}
}
r.conn.Debug.Printf("%d: returning", msgCtx.id)
}()
}

func newSearchResponse(conn *Conn, bufferSize int) *searchResponse {
var ch chan *SearchSingleResult
if bufferSize > 0 {
ch = make(chan *SearchSingleResult, bufferSize)
} else {
ch = make(chan *SearchSingleResult)
}
return &searchResponse{
conn: conn,
ch: ch,
}
}
Loading

0 comments on commit 2f623f0

Please sign in to comment.