From ef4da407f50b69bc8ce71063cf5a2f41be84651b Mon Sep 17 00:00:00 2001 From: Hossein Ahmadian-Yazdi Date: Mon, 25 Nov 2019 15:11:22 -0500 Subject: [PATCH] Match our Fork with go-ldap/ldap (#4) * unified request flow && external binding to LDAP (#232) * unified request flow && external binding to LDAP * fix debug mode * go.mod was added * Prevent negative waitgroup panic by `Add`ing first. (#237) * Use github for module name. (#239) * Rename asn1 ber dependency. (#243) * Rename asn1.ber dependency. * go mod tidy * Update travic CI to new asn1 ber * Update travis go_import_path * Update README (#245) * Update go.mod (#241) I believe that this (plus a new tag) is what is necessary for native go modules support -- right now go modules complains that the version tag is >= v2 but the module doesn't claim to be a version >= v2. * Versioned v3 according to Go wiki, maintaining backward compatibility (#247) * Moved v3 to subfolder to allow for versioning >2 with go modules. Reverted top level go.mod to fix backward compatibility * Updated readme to include directions on Go Modules, including the rationale --- .travis.yml | 7 +- README.md | 37 ++-- add.go | 53 ++--- bind.go | 83 ++++--- client.go | 18 +- compare.go | 55 +++-- conn.go | 16 +- conn_test.go | 2 +- control.go | 2 +- control_test.go | 2 +- debug.go | 10 +- del.go | 48 ++--- dn.go | 2 +- error.go | 6 +- error_test.go | 2 +- filter.go | 2 +- filter_test.go | 2 +- go.mod | 5 + go.sum | 2 + ldap.go | 5 +- moddn.go | 47 ++-- modify.go | 70 +++--- passwdmodify.go | 64 ++---- request.go | 66 ++++++ search.go | 75 +++---- v3/add.go | 100 +++++++++ v3/bind.go | 152 +++++++++++++ v3/client.go | 30 +++ v3/compare.go | 80 +++++++ v3/conn.go | 522 +++++++++++++++++++++++++++++++++++++++++++++ v3/conn_test.go | 336 +++++++++++++++++++++++++++++ v3/control.go | 499 +++++++++++++++++++++++++++++++++++++++++++ v3/control_test.go | 123 +++++++++++ v3/debug.go | 30 +++ v3/del.go | 64 ++++++ v3/dn.go | 247 +++++++++++++++++++++ v3/dn_test.go | 209 ++++++++++++++++++ v3/doc.go | 4 + v3/error.go | 236 ++++++++++++++++++++ v3/error_test.go | 141 ++++++++++++ v3/example_test.go | 342 +++++++++++++++++++++++++++++ v3/filter.go | 465 ++++++++++++++++++++++++++++++++++++++++ v3/filter_test.go | 254 ++++++++++++++++++++++ v3/go.mod | 5 + v3/ldap.go | 340 +++++++++++++++++++++++++++++ v3/ldap_test.go | 326 ++++++++++++++++++++++++++++ v3/moddn.go | 85 ++++++++ v3/moddn_test.go | 73 +++++++ v3/modify.go | 151 +++++++++++++ v3/passwdmodify.go | 131 ++++++++++++ v3/request.go | 66 ++++++ v3/search.go | 425 ++++++++++++++++++++++++++++++++++++ v3/search_test.go | 31 +++ 53 files changed, 5797 insertions(+), 351 deletions(-) create mode 100644 go.mod create mode 100644 go.sum create mode 100644 request.go create mode 100644 v3/add.go create mode 100644 v3/bind.go create mode 100644 v3/client.go create mode 100644 v3/compare.go create mode 100644 v3/conn.go create mode 100644 v3/conn_test.go create mode 100644 v3/control.go create mode 100644 v3/control_test.go create mode 100644 v3/debug.go create mode 100644 v3/del.go create mode 100644 v3/dn.go create mode 100644 v3/dn_test.go create mode 100644 v3/doc.go create mode 100644 v3/error.go create mode 100644 v3/error_test.go create mode 100644 v3/example_test.go create mode 100644 v3/filter.go create mode 100644 v3/filter_test.go create mode 100644 v3/go.mod create mode 100644 v3/ldap.go create mode 100644 v3/ldap_test.go create mode 100644 v3/moddn.go create mode 100644 v3/moddn_test.go create mode 100644 v3/modify.go create mode 100644 v3/passwdmodify.go create mode 100644 v3/request.go create mode 100644 v3/search.go create mode 100644 v3/search_test.go diff --git a/.travis.yml b/.travis.yml index bb899b2a..f2538fd5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,18 +9,19 @@ go: - "1.10.x" - "1.11.x" - "1.12.x" + - "1.13.x" - tip git: - depth: 1 + depth: 1 matrix: fast_finish: true allow_failures: - go: tip -go_import_path: gopkg.in/ldap.v3 +go_import_path: github.com/go-ldap/ldap install: - - go get gopkg.in/asn1-ber.v1 + - go get github.com/go-asn1-ber/asn1-ber - go get code.google.com/p/go.tools/cmd/cover || go get golang.org/x/tools/cmd/cover - go get github.com/golang/lint/golint || go get golang.org/x/lint/golint || true - go build -v ./... diff --git a/README.md b/README.md index 25cf730b..c61e4883 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,8 @@ -[![GoDoc](https://godoc.org/gopkg.in/ldap.v3?status.svg)](https://godoc.org/gopkg.in/ldap.v3) +[![GoDoc](https://godoc.org/github.com/go-ldap/ldap?status.svg)](https://godoc.org/github.com/go-ldap/ldap) [![Build Status](https://travis-ci.org/go-ldap/ldap.svg)](https://travis-ci.org/go-ldap/ldap) # Basic LDAP v3 functionality for the GO programming language. -## Install - -For the latest version use: - - go get gopkg.in/ldap.v3 - -Import the latest version with: - - import "gopkg.in/ldap.v3" - -## Required Libraries: - - - gopkg.in/asn1-ber.v1 - ## Features: - Connecting to LDAP server (non-TLS, TLS, STARTTLS) @@ -34,6 +20,27 @@ Import the latest version with: - search - modify +## Go Modules: + +`go get github.com/go-ldap/ldap/v3` + +As go-ldap was v2+ when Go Modules came out, updating to Go Modules would be considered a breaking change. + +To maintain backwards compatability, we ultimately decided to use subfolders (as v3 was already a branch). +Whilst this duplicates the code, we can move toward implementing a backwards-compatible versioning system that allows for code reuse. +The alternative would be to increment the version number, however we believe that this would confuse users as v3 is in line with LDAPv3 (RFC-4511) +https://tools.ietf.org/html/rfc4511 + + +For more info, please visit the pull request that updated to modules. +https://github.com/go-ldap/ldap/pull/247 + +To install with `GOMODULE111=off`, use `go get github.com/go-ldap/ldap` +https://golang.org/cmd/go/#hdr-Legacy_GOPATH_go_get + +As always, we are looking for contributors with great ideas on how to best move forward. + + ## Contributing: Bug reports and pull requests are welcome! diff --git a/add.go b/add.go index 19bce1b7..a3e6b881 100644 --- a/add.go +++ b/add.go @@ -10,10 +10,9 @@ package ldap import ( - "errors" "log" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // Attribute represents an LDAP attribute @@ -45,20 +44,26 @@ type AddRequest struct { Controls []Control } -func (a AddRequest) encode() *ber.Packet { - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request") - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.DN, "DN")) +func (req *AddRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) attributes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes") - for _, attribute := range a.Attributes { + for _, attribute := range req.Attributes { attributes.AppendChild(attribute.encode()) } - request.AppendChild(attributes) - return request + pkt.AppendChild(attributes) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil } // Attribute adds an attribute with the given type and values -func (a *AddRequest) Attribute(attrType string, attrVals []string) { - a.Attributes = append(a.Attributes, Attribute{Type: attrType, Vals: attrVals}) +func (req *AddRequest) Attribute(attrType string, attrVals []string) { + req.Attributes = append(req.Attributes, Attribute{Type: attrType, Vals: attrVals}) } // NewAddRequest returns an AddRequest for the given DN, with no attributes @@ -72,39 +77,17 @@ func NewAddRequest(dn string, controls []Control) *AddRequest { // Add performs the given AddRequest func (l *Conn) Add(addRequest *AddRequest) error { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - packet.AppendChild(addRequest.encode()) - if len(addRequest.Controls) > 0 { - packet.AppendChild(encodeControls(addRequest.Controls)) - } - - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(addRequest) if err != nil { return err } defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return err } - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return err - } - ber.PrintPacket(packet) - } - if packet.Children[1].Tag == ApplicationAddResponse { err := GetLDAPError(packet) if err != nil { @@ -113,7 +96,5 @@ func (l *Conn) Add(addRequest *AddRequest) error { } else { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) } - - l.Debug.Printf("%d: returning", msgCtx.id) return nil } diff --git a/bind.go b/bind.go index 59c3f5ef..15307f50 100644 --- a/bind.go +++ b/bind.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // SimpleBindRequest represents a username/password bind operation @@ -35,13 +35,18 @@ func NewSimpleBindRequest(username string, password string, controls []Control) } } -func (bindRequest *SimpleBindRequest) encode() *ber.Packet { - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") - request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, bindRequest.Username, "User Name")) - request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, bindRequest.Password, "Password")) +func (req *SimpleBindRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Username, "User Name")) + pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.Password, "Password")) - return request + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil } // SimpleBind performs the simple bind operation defined in the given request @@ -50,41 +55,17 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client")) } - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - encodedBindRequest := simpleBindRequest.encode() - packet.AppendChild(encodedBindRequest) - if len(simpleBindRequest.Controls) > 0 { - packet.AppendChild(encodeControls(simpleBindRequest.Controls)) - } - - if l.Debug { - ber.PrintPacket(packet) - } - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(simpleBindRequest) if err != nil { return nil, err } defer l.finishMessage(msgCtx) - packetResponse, ok := <-msgCtx.responses - if !ok { - return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return nil, err } - if l.Debug { - if err = addLDAPDescriptions(packet); err != nil { - return nil, err - } - ber.PrintPacket(packet) - } - result := &SimpleBindResult{ Controls: make([]Control, 0), } @@ -133,3 +114,39 @@ func (l *Conn) UnauthenticatedBind(username string) error { _, err := l.SimpleBind(req) return err } + +var externalBindRequest = requestFunc(func(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name")) + + saslAuth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication") + saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "EXTERNAL", "SASL Mech")) + saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "SASL Cred")) + + pkt.AppendChild(saslAuth) + + envelope.AppendChild(pkt) + + return nil +}) + +// ExternalBind performs SASL/EXTERNAL authentication. +// +// Use ldap.DialURL("ldapi://") to connect to the Unix socket before ExternalBind. +// +// See https://tools.ietf.org/html/rfc4422#appendix-A +func (l *Conn) ExternalBind() error { + msgCtx, err := l.doRequest(externalBindRequest) + if err != nil { + return err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return err + } + + return GetLDAPError(packet) +} diff --git a/client.go b/client.go index c7f41f6f..619677c7 100644 --- a/client.go +++ b/client.go @@ -8,21 +8,23 @@ import ( // Client knows how to interact with an LDAP server type Client interface { Start() - StartTLS(config *tls.Config) error + StartTLS(*tls.Config) error Close() SetTimeout(time.Duration) Bind(username, password string) error - SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) + UnauthenticatedBind(username string) error + SimpleBind(*SimpleBindRequest) (*SimpleBindResult, error) + ExternalBind() error - Add(addRequest *AddRequest) error - Del(delRequest *DelRequest) error - Modify(modifyRequest *ModifyRequest) error - ModifyDN(modifyDNRequest *ModifyDNRequest) error + Add(*AddRequest) error + Del(*DelRequest) error + Modify(*ModifyRequest) error + ModifyDN(*ModifyDNRequest) error Compare(dn, attribute, value string) (bool, error) - PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) + PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) - Search(searchRequest *SearchRequest) (*SearchResult, error) + Search(*SearchRequest) (*SearchResult, error) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) } diff --git a/compare.go b/compare.go index 5b5013cb..04a2e17c 100644 --- a/compare.go +++ b/compare.go @@ -20,53 +20,50 @@ package ldap import ( - "errors" "fmt" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) -// Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise -// false with any error that occurs if any. -func (l *Conn) Compare(dn, attribute, value string) (bool, error) { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) +// CompareRequest represents an LDAP CompareRequest operation. +type CompareRequest struct { + DN string + Attribute string + Value string +} - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request") - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, dn, "DN")) +func (req *CompareRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) ava := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "AttributeValueAssertion") - ava.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "AttributeDesc")) - ava.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "AssertionValue")) - request.AppendChild(ava) - packet.AppendChild(request) + ava.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Attribute, "AttributeDesc")) + ava.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Value, "AssertionValue")) - l.Debug.PrintPacket(packet) + pkt.AppendChild(ava) - msgCtx, err := l.sendMessage(packet) + envelope.AppendChild(pkt) + + return nil +} + +// Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise +// false with any error that occurs if any. +func (l *Conn) Compare(dn, attribute, value string) (bool, error) { + msgCtx, err := l.doRequest(&CompareRequest{ + DN: dn, + Attribute: attribute, + Value: value}) if err != nil { return false, err } defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return false, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return false, err } - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return false, err - } - ber.PrintPacket(packet) - } - if packet.Children[1].Tag == ApplicationCompareResponse { err := GetLDAPError(packet) diff --git a/conn.go b/conn.go index c20471fc..e6051170 100644 --- a/conn.go +++ b/conn.go @@ -11,7 +11,7 @@ import ( "sync/atomic" "time" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) const ( @@ -140,7 +140,6 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { // or ldap:// specified as protocol. On success a new Conn for the connection // is returned. func DialURL(addr string) (*Conn, error) { - lurl, err := url.Parse(addr) if err != nil { return nil, NewError(ErrorNetwork, err) @@ -154,6 +153,11 @@ func DialURL(addr string) (*Conn, error) { } switch lurl.Scheme { + case "ldapi": + if lurl.Path == "" || lurl.Path == "/" { + lurl.Path = "/var/run/slapd/ldapi" + } + return Dial("unix", lurl.Path) case "ldap": if port == "" { port = DefaultLdapPort @@ -187,9 +191,9 @@ func NewConn(conn net.Conn, isTLS bool) *Conn { // Start initializes goroutines to read responses and process messages func (l *Conn) Start() { + l.wgClose.Add(1) go l.reader() go l.processMessages() - l.wgClose.Add(1) } // IsClosing returns whether or not we're currently closing. @@ -490,11 +494,13 @@ func (l *Conn) reader() { // A read error is expected here if we are closing the connection... if !l.IsClosing() { l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err)) - l.Debug.Printf("reader error: %s", err.Error()) + l.Debug.Printf("reader error: %s", err) } return } - addLDAPDescriptions(packet) + if err := addLDAPDescriptions(packet); err != nil { + l.Debug.Printf("descriptions error: %s", err) + } if len(packet.Children) == 0 { l.Debug.Printf("Received bad ldap packet") continue diff --git a/conn_test.go b/conn_test.go index 488754d1..f815a967 100644 --- a/conn_test.go +++ b/conn_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - "gopkg.in/asn1-ber.v1" + "github.com/go-asn1-ber/asn1-ber" ) func TestUnresponsiveConnection(t *testing.T) { diff --git a/control.go b/control.go index 3f181912..463fe3a3 100644 --- a/control.go +++ b/control.go @@ -4,7 +4,7 @@ import ( "fmt" "strconv" - "gopkg.in/asn1-ber.v1" + "github.com/go-asn1-ber/asn1-ber" ) const ( diff --git a/control_test.go b/control_test.go index 0d59d10c..309d7f27 100644 --- a/control_test.go +++ b/control_test.go @@ -7,7 +7,7 @@ import ( "runtime" "testing" - ber "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) func TestControlPaging(t *testing.T) { diff --git a/debug.go b/debug.go index 7279fc25..2c0b30c8 100644 --- a/debug.go +++ b/debug.go @@ -3,20 +3,26 @@ package ldap import ( "log" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // debugging type // - has a Printf method to write the debug output type debugging bool -// write debug output +// Enable controls debugging mode. +func (debug *debugging) Enable(b bool) { + *debug = debugging(b) +} + +// Printf writes debug output. func (debug debugging) Printf(format string, args ...interface{}) { if debug { log.Printf(format, args...) } } +// PrintPacket dumps a packet. func (debug debugging) PrintPacket(packet *ber.Packet) { if debug { ber.PrintPacket(packet) diff --git a/del.go b/del.go index 6f78beb1..49811d34 100644 --- a/del.go +++ b/del.go @@ -6,10 +6,9 @@ package ldap import ( - "errors" "log" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // DelRequest implements an LDAP deletion request @@ -20,15 +19,20 @@ type DelRequest struct { Controls []Control } -func (d DelRequest) encode() *ber.Packet { - request := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationDelRequest, d.DN, "Del Request") - request.Data.Write([]byte(d.DN)) - return request +func (req *DelRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationDelRequest, req.DN, "Del Request") + pkt.Data.Write([]byte(req.DN)) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil } // NewDelRequest creates a delete request for the given DN and controls -func NewDelRequest(DN string, - Controls []Control) *DelRequest { +func NewDelRequest(DN string, Controls []Control) *DelRequest { return &DelRequest{ DN: DN, Controls: Controls, @@ -37,39 +41,17 @@ func NewDelRequest(DN string, // Del executes the given delete request func (l *Conn) Del(delRequest *DelRequest) error { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - packet.AppendChild(delRequest.encode()) - if len(delRequest.Controls) > 0 { - packet.AppendChild(encodeControls(delRequest.Controls)) - } - - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(delRequest) if err != nil { return err } defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return err } - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return err - } - ber.PrintPacket(packet) - } - if packet.Children[1].Tag == ApplicationDelResponse { err := GetLDAPError(packet) if err != nil { @@ -78,7 +60,5 @@ func (l *Conn) Del(delRequest *DelRequest) error { } else { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) } - - l.Debug.Printf("%d: returning", msgCtx.id) return nil } diff --git a/dn.go b/dn.go index f89e73a9..9d32e7fa 100644 --- a/dn.go +++ b/dn.go @@ -48,7 +48,7 @@ import ( "fmt" "strings" - "gopkg.in/asn1-ber.v1" + "github.com/go-asn1-ber/asn1-ber" ) // AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514 diff --git a/error.go b/error.go index 639ed824..b1fda2d8 100644 --- a/error.go +++ b/error.go @@ -3,7 +3,7 @@ package ldap import ( "fmt" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // LDAP Result Codes @@ -196,7 +196,9 @@ func (e *Error) Error() string { func GetLDAPError(packet *ber.Packet) error { if packet == nil { return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty packet")} - } else if len(packet.Children) >= 2 { + } + + if len(packet.Children) >= 2 { response := packet.Children[1] if response == nil { return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet")} diff --git a/error_test.go b/error_test.go index 816bf8a1..c547e454 100644 --- a/error_test.go +++ b/error_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "gopkg.in/asn1-ber.v1" + "github.com/go-asn1-ber/asn1-ber" ) // TestNilPacket tests that nil packets don't cause a panic. diff --git a/filter.go b/filter.go index 4cc4207b..a3875506 100644 --- a/filter.go +++ b/filter.go @@ -8,7 +8,7 @@ import ( "strings" "unicode/utf8" - "gopkg.in/asn1-ber.v1" + "github.com/go-asn1-ber/asn1-ber" ) // Filter choices diff --git a/filter_test.go b/filter_test.go index 7a9194f7..45796325 100644 --- a/filter_test.go +++ b/filter_test.go @@ -4,7 +4,7 @@ import ( "strings" "testing" - "gopkg.in/asn1-ber.v1" + "github.com/go-asn1-ber/asn1-ber" ) type compileTest struct { diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..9b350ff9 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module github.com/go-ldap/ldap + +require github.com/go-asn1-ber/asn1-ber v1.3.1 + +go 1.13 diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..c8b9085b --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/go-asn1-ber/asn1-ber v1.3.1 h1:gvPdv/Hr++TRFCl0UbPFHC54P9N9jgsRPnmnr419Uck= +github.com/go-asn1-ber/asn1-ber v1.3.1/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= diff --git a/ldap.go b/ldap.go index e6155d50..7dbc951a 100644 --- a/ldap.go +++ b/ldap.go @@ -1,12 +1,11 @@ package ldap import ( - "errors" "fmt" "io/ioutil" "os" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // LDAP Application Codes @@ -87,7 +86,7 @@ var BeheraPasswordPolicyErrorMap = map[int8]string{ func addLDAPDescriptions(packet *ber.Packet) (err error) { defer func() { if r := recover(); r != nil { - err = NewError(ErrorDebugging, errors.New("ldap: cannot process packet to add descriptions")) + err = NewError(ErrorDebugging, fmt.Errorf("ldap: cannot process packet to add descriptions: %s", r)) } }() packet.Description = "LDAP Response" diff --git a/moddn.go b/moddn.go index 803279d2..ca48b2b4 100644 --- a/moddn.go +++ b/moddn.go @@ -11,10 +11,9 @@ package ldap import ( - "errors" "log" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // ModifyDNRequest holds the request to modify a DN @@ -46,50 +45,34 @@ func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *Modi } } -func (m ModifyDNRequest) encode() *ber.Packet { - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request") - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.DN, "DN")) - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.NewRDN, "New RDN")) - request.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, m.DeleteOldRDN, "Delete old RDN")) - if m.NewSuperior != "" { - request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, m.NewSuperior, "New Superior")) +func (req *ModifyDNRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.NewRDN, "New RDN")) + pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.DeleteOldRDN, "Delete old RDN")) + if req.NewSuperior != "" { + pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.NewSuperior, "New Superior")) } - return request + + envelope.AppendChild(pkt) + + return nil } // ModifyDN renames the given DN and optionally move to another base (when the "newSup" argument // to NewModifyDNRequest() is not ""). func (l *Conn) ModifyDN(m *ModifyDNRequest) error { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - packet.AppendChild(m.encode()) - - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(m) if err != nil { return err } defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return NewError(ErrorNetwork, errors.New("ldap: channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return err } - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return err - } - ber.PrintPacket(packet) - } - if packet.Children[1].Tag == ApplicationModifyDNResponse { err := GetLDAPError(packet) if err != nil { @@ -98,7 +81,5 @@ func (l *Conn) ModifyDN(m *ModifyDNRequest) error { } else { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) } - - l.Debug.Printf("%d: returning", msgCtx.id) return nil } diff --git a/modify.go b/modify.go index d83e6221..88f65c1b 100644 --- a/modify.go +++ b/modify.go @@ -26,10 +26,9 @@ package ldap import ( - "errors" "log" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // Change operation choices @@ -84,40 +83,43 @@ type ModifyRequest struct { } // Add appends the given attribute to the list of changes to be made -func (m *ModifyRequest) Add(attrType string, attrVals []string) { - m.appendChange(AddAttribute, attrType, attrVals) +func (req *ModifyRequest) Add(attrType string, attrVals []string) { + req.appendChange(AddAttribute, attrType, attrVals) } // Delete appends the given attribute to the list of changes to be made -func (m *ModifyRequest) Delete(attrType string, attrVals []string) { - m.appendChange(DeleteAttribute, attrType, attrVals) +func (req *ModifyRequest) Delete(attrType string, attrVals []string) { + req.appendChange(DeleteAttribute, attrType, attrVals) } // Replace appends the given attribute to the list of changes to be made -func (m *ModifyRequest) Replace(attrType string, attrVals []string) { - m.appendChange(ReplaceAttribute, attrType, attrVals) +func (req *ModifyRequest) Replace(attrType string, attrVals []string) { + req.appendChange(ReplaceAttribute, attrType, attrVals) } -func (m *ModifyRequest) appendChange(operation uint, attrType string, attrVals []string) { - m.Changes = append(m.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}}) +func (req *ModifyRequest) appendChange(operation uint, attrType string, attrVals []string) { + req.Changes = append(req.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}}) } -func (m ModifyRequest) encode() *ber.Packet { - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request") - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.DN, "DN")) +func (req *ModifyRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) changes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Changes") - for _, change := range m.Changes { + for _, change := range req.Changes { changes.AppendChild(change.encode()) } - request.AppendChild(changes) - return request + pkt.AppendChild(changes) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil } // NewModifyRequest creates a modify request for the given DN -func NewModifyRequest( - dn string, - controls []Control, -) *ModifyRequest { +func NewModifyRequest(dn string, controls []Control) *ModifyRequest { return &ModifyRequest{ DN: dn, Controls: controls, @@ -126,39 +128,17 @@ func NewModifyRequest( // Modify performs the ModifyRequest func (l *Conn) Modify(modifyRequest *ModifyRequest) error { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - packet.AppendChild(modifyRequest.encode()) - if len(modifyRequest.Controls) > 0 { - packet.AppendChild(encodeControls(modifyRequest.Controls)) - } - - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(modifyRequest) if err != nil { return err } defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return err } - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return err - } - ber.PrintPacket(packet) - } - if packet.Children[1].Tag == ApplicationModifyResponse { err := GetLDAPError(packet) if err != nil { @@ -167,7 +147,5 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error { } else { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) } - - l.Debug.Printf("%d: returning", msgCtx.id) return nil } diff --git a/passwdmodify.go b/passwdmodify.go index 06bc21db..135554d9 100644 --- a/passwdmodify.go +++ b/passwdmodify.go @@ -6,10 +6,9 @@ package ldap import ( - "errors" "fmt" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) const ( @@ -36,25 +35,28 @@ type PasswordModifyResult struct { Referral string } -func (r *PasswordModifyRequest) encode() (*ber.Packet, error) { - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Password Modify Extended Operation") - request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, passwordModifyOID, "Extended Request Name: Password Modify OID")) +func (req *PasswordModifyRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Password Modify Extended Operation") + pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, passwordModifyOID, "Extended Request Name: Password Modify OID")) + extendedRequestValue := ber.Encode(ber.ClassContext, ber.TypePrimitive, 1, nil, "Extended Request Value: Password Modify Request") passwordModifyRequestValue := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Password Modify Request") - if r.UserIdentity != "" { - passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, r.UserIdentity, "User Identity")) + if req.UserIdentity != "" { + passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.UserIdentity, "User Identity")) } - if r.OldPassword != "" { - passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 1, r.OldPassword, "Old Password")) + if req.OldPassword != "" { + passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 1, req.OldPassword, "Old Password")) } - if r.NewPassword != "" { - passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 2, r.NewPassword, "New Password")) + if req.NewPassword != "" { + passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 2, req.NewPassword, "New Password")) } - extendedRequestValue.AppendChild(passwordModifyRequestValue) - request.AppendChild(extendedRequestValue) - return request, nil + pkt.AppendChild(extendedRequestValue) + + envelope.AppendChild(pkt) + + return nil } // NewPasswordModifyRequest creates a new PasswordModifyRequest @@ -84,46 +86,18 @@ func NewPasswordModifyRequest(userIdentity string, oldPassword string, newPasswo // PasswordModify performs the modification request func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - - encodedPasswordModifyRequest, err := passwordModifyRequest.encode() - if err != nil { - return nil, err - } - packet.AppendChild(encodedPasswordModifyRequest) - - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(passwordModifyRequest) if err != nil { return nil, err } defer l.finishMessage(msgCtx) - result := &PasswordModifyResult{} - - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return nil, err } - if packet == nil { - return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve message")) - } - - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return nil, err - } - ber.PrintPacket(packet) - } + result := &PasswordModifyResult{} if packet.Children[1].Tag == ApplicationExtendedResponse { err := GetLDAPError(packet) diff --git a/request.go b/request.go new file mode 100644 index 00000000..4790bbfe --- /dev/null +++ b/request.go @@ -0,0 +1,66 @@ +package ldap + +import ( + "errors" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +var ( + errRespChanClosed = errors.New("ldap: response channel closed") + errCouldNotRetMsg = errors.New("ldap: could not retrieve message") +) + +type request interface { + appendTo(*ber.Packet) error +} + +type requestFunc func(*ber.Packet) error + +func (f requestFunc) appendTo(p *ber.Packet) error { + return f(p) +} + +func (l *Conn) doRequest(req request) (*messageContext, error) { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + if err := req.appendTo(packet); err != nil { + return nil, err + } + + if l.Debug { + ber.PrintPacket(packet) + } + + msgCtx, err := l.sendMessage(packet) + if err != nil { + return nil, err + } + l.Debug.Printf("%d: returning", msgCtx.id) + return msgCtx, nil +} + +func (l *Conn) readPacket(msgCtx *messageContext) (*ber.Packet, error) { + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + return nil, NewError(ErrorNetwork, errRespChanClosed) + } + packet, err := packetResponse.ReadPacket() + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + return nil, err + } + + if packet == nil { + return nil, NewError(ErrorNetwork, errCouldNotRetMsg) + } + + if l.Debug { + if err = addLDAPDescriptions(packet); err != nil { + return nil, err + } + ber.PrintPacket(packet) + } + return packet, nil +} diff --git a/search.go b/search.go index 3aa6dac0..186246c0 100644 --- a/search.go +++ b/search.go @@ -61,7 +61,7 @@ import ( "sort" "strings" - "gopkg.in/asn1-ber.v1" + ber "github.com/go-asn1-ber/asn1-ber" ) // scope choices @@ -246,27 +246,33 @@ type SearchRequest struct { Controls []Control } -func (s *SearchRequest) encode() (*ber.Packet, error) { - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request") - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, s.BaseDN, "Base DN")) - request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(s.Scope), "Scope")) - request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(s.DerefAliases), "Deref Aliases")) - request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(s.SizeLimit), "Size Limit")) - request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(s.TimeLimit), "Time Limit")) - request.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, s.TypesOnly, "Types Only")) +func (req *SearchRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.BaseDN, "Base DN")) + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(req.Scope), "Scope")) + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(req.DerefAliases), "Deref Aliases")) + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(req.SizeLimit), "Size Limit")) + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(req.TimeLimit), "Time Limit")) + pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.TypesOnly, "Types Only")) // compile and encode filter - filterPacket, err := CompileFilter(s.Filter) + filterPacket, err := CompileFilter(req.Filter) if err != nil { - return nil, err + return err } - request.AppendChild(filterPacket) + pkt.AppendChild(filterPacket) // encode attributes attributesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes") - for _, attribute := range s.Attributes { + for _, attribute := range req.Attributes { attributesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute")) } - request.AppendChild(attributesPacket) - return request, nil + pkt.AppendChild(attributesPacket) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil } // NewSearchRequest creates a new search request @@ -366,22 +372,7 @@ func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) // Search performs the given search request func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - // encode search request - encodedSearchRequest, err := searchRequest.encode() - if err != nil { - return nil, err - } - packet.AppendChild(encodedSearchRequest) - // encode search controls - if len(searchRequest.Controls) > 0 { - packet.AppendChild(encodeControls(searchRequest.Controls)) - } - - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(searchRequest) if err != nil { return nil, err } @@ -392,26 +383,12 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { Referrals: make([]string, 0), Controls: make([]Control, 0)} - foundSearchResultDone := false - for !foundSearchResultDone { - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + for { + packet, err := l.readPacket(msgCtx) if err != nil { return nil, err } - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return nil, err - } - ber.PrintPacket(packet) - } - switch packet.Children[1].Tag { case 4: entry := new(Entry) @@ -440,11 +417,9 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { result.Controls = append(result.Controls, decodedChild) } } - foundSearchResultDone = true + return result, nil case 19: result.Referrals = append(result.Referrals, packet.Children[1].Children[0].Value.(string)) } } - l.Debug.Printf("%d: returning", msgCtx.id) - return result, nil } diff --git a/v3/add.go b/v3/add.go new file mode 100644 index 00000000..a3e6b881 --- /dev/null +++ b/v3/add.go @@ -0,0 +1,100 @@ +// +// https://tools.ietf.org/html/rfc4511 +// +// AddRequest ::= [APPLICATION 8] SEQUENCE { +// entry LDAPDN, +// attributes AttributeList } +// +// AttributeList ::= SEQUENCE OF attribute Attribute + +package ldap + +import ( + "log" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// Attribute represents an LDAP attribute +type Attribute struct { + // Type is the name of the LDAP attribute + Type string + // Vals are the LDAP attribute values + Vals []string +} + +func (a *Attribute) encode() *ber.Packet { + seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute") + seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.Type, "Type")) + set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue") + for _, value := range a.Vals { + set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals")) + } + seq.AppendChild(set) + return seq +} + +// AddRequest represents an LDAP AddRequest operation +type AddRequest struct { + // DN identifies the entry being added + DN string + // Attributes list the attributes of the new entry + Attributes []Attribute + // Controls hold optional controls to send with the request + Controls []Control +} + +func (req *AddRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) + attributes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes") + for _, attribute := range req.Attributes { + attributes.AppendChild(attribute.encode()) + } + pkt.AppendChild(attributes) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil +} + +// Attribute adds an attribute with the given type and values +func (req *AddRequest) Attribute(attrType string, attrVals []string) { + req.Attributes = append(req.Attributes, Attribute{Type: attrType, Vals: attrVals}) +} + +// NewAddRequest returns an AddRequest for the given DN, with no attributes +func NewAddRequest(dn string, controls []Control) *AddRequest { + return &AddRequest{ + DN: dn, + Controls: controls, + } + +} + +// Add performs the given AddRequest +func (l *Conn) Add(addRequest *AddRequest) error { + msgCtx, err := l.doRequest(addRequest) + if err != nil { + return err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return err + } + + if packet.Children[1].Tag == ApplicationAddResponse { + err := GetLDAPError(packet) + if err != nil { + return err + } + } else { + log.Printf("Unexpected Response: %d", packet.Children[1].Tag) + } + return nil +} diff --git a/v3/bind.go b/v3/bind.go new file mode 100644 index 00000000..15307f50 --- /dev/null +++ b/v3/bind.go @@ -0,0 +1,152 @@ +package ldap + +import ( + "errors" + "fmt" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// SimpleBindRequest represents a username/password bind operation +type SimpleBindRequest struct { + // Username is the name of the Directory object that the client wishes to bind as + Username string + // Password is the credentials to bind with + Password string + // Controls are optional controls to send with the bind request + Controls []Control + // AllowEmptyPassword sets whether the client allows binding with an empty password + // (normally used for unauthenticated bind). + AllowEmptyPassword bool +} + +// SimpleBindResult contains the response from the server +type SimpleBindResult struct { + Controls []Control +} + +// NewSimpleBindRequest returns a bind request +func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest { + return &SimpleBindRequest{ + Username: username, + Password: password, + Controls: controls, + AllowEmptyPassword: false, + } +} + +func (req *SimpleBindRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Username, "User Name")) + pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.Password, "Password")) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil +} + +// SimpleBind performs the simple bind operation defined in the given request +func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) { + if simpleBindRequest.Password == "" && !simpleBindRequest.AllowEmptyPassword { + return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client")) + } + + msgCtx, err := l.doRequest(simpleBindRequest) + if err != nil { + return nil, err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return nil, err + } + + result := &SimpleBindResult{ + Controls: make([]Control, 0), + } + + if len(packet.Children) == 3 { + for _, child := range packet.Children[2].Children { + decodedChild, decodeErr := DecodeControl(child) + if decodeErr != nil { + return nil, fmt.Errorf("failed to decode child control: %s", decodeErr) + } + result.Controls = append(result.Controls, decodedChild) + } + } + + err = GetLDAPError(packet) + return result, err +} + +// Bind performs a bind with the given username and password. +// +// It does not allow unauthenticated bind (i.e. empty password). Use the UnauthenticatedBind method +// for that. +func (l *Conn) Bind(username, password string) error { + req := &SimpleBindRequest{ + Username: username, + Password: password, + AllowEmptyPassword: false, + } + _, err := l.SimpleBind(req) + return err +} + +// UnauthenticatedBind performs an unauthenticated bind. +// +// A username may be provided for trace (e.g. logging) purpose only, but it is normally not +// authenticated or otherwise validated by the LDAP server. +// +// See https://tools.ietf.org/html/rfc4513#section-5.1.2 . +// See https://tools.ietf.org/html/rfc4513#section-6.3.1 . +func (l *Conn) UnauthenticatedBind(username string) error { + req := &SimpleBindRequest{ + Username: username, + Password: "", + AllowEmptyPassword: true, + } + _, err := l.SimpleBind(req) + return err +} + +var externalBindRequest = requestFunc(func(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name")) + + saslAuth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication") + saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "EXTERNAL", "SASL Mech")) + saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "SASL Cred")) + + pkt.AppendChild(saslAuth) + + envelope.AppendChild(pkt) + + return nil +}) + +// ExternalBind performs SASL/EXTERNAL authentication. +// +// Use ldap.DialURL("ldapi://") to connect to the Unix socket before ExternalBind. +// +// See https://tools.ietf.org/html/rfc4422#appendix-A +func (l *Conn) ExternalBind() error { + msgCtx, err := l.doRequest(externalBindRequest) + if err != nil { + return err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return err + } + + return GetLDAPError(packet) +} diff --git a/v3/client.go b/v3/client.go new file mode 100644 index 00000000..619677c7 --- /dev/null +++ b/v3/client.go @@ -0,0 +1,30 @@ +package ldap + +import ( + "crypto/tls" + "time" +) + +// Client knows how to interact with an LDAP server +type Client interface { + Start() + StartTLS(*tls.Config) error + Close() + SetTimeout(time.Duration) + + Bind(username, password string) error + UnauthenticatedBind(username string) error + SimpleBind(*SimpleBindRequest) (*SimpleBindResult, error) + ExternalBind() error + + Add(*AddRequest) error + Del(*DelRequest) error + Modify(*ModifyRequest) error + ModifyDN(*ModifyDNRequest) error + + Compare(dn, attribute, value string) (bool, error) + PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) + + Search(*SearchRequest) (*SearchResult, error) + SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) +} diff --git a/v3/compare.go b/v3/compare.go new file mode 100644 index 00000000..04a2e17c --- /dev/null +++ b/v3/compare.go @@ -0,0 +1,80 @@ +// File contains Compare functionality +// +// https://tools.ietf.org/html/rfc4511 +// +// CompareRequest ::= [APPLICATION 14] SEQUENCE { +// entry LDAPDN, +// ava AttributeValueAssertion } +// +// AttributeValueAssertion ::= SEQUENCE { +// attributeDesc AttributeDescription, +// assertionValue AssertionValue } +// +// AttributeDescription ::= LDAPString +// -- Constrained to +// -- [RFC4512] +// +// AttributeValue ::= OCTET STRING +// + +package ldap + +import ( + "fmt" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// CompareRequest represents an LDAP CompareRequest operation. +type CompareRequest struct { + DN string + Attribute string + Value string +} + +func (req *CompareRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) + + ava := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "AttributeValueAssertion") + ava.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Attribute, "AttributeDesc")) + ava.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Value, "AssertionValue")) + + pkt.AppendChild(ava) + + envelope.AppendChild(pkt) + + return nil +} + +// Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise +// false with any error that occurs if any. +func (l *Conn) Compare(dn, attribute, value string) (bool, error) { + msgCtx, err := l.doRequest(&CompareRequest{ + DN: dn, + Attribute: attribute, + Value: value}) + if err != nil { + return false, err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return false, err + } + + if packet.Children[1].Tag == ApplicationCompareResponse { + err := GetLDAPError(packet) + + switch { + case IsErrorWithCode(err, LDAPResultCompareTrue): + return true, nil + case IsErrorWithCode(err, LDAPResultCompareFalse): + return false, nil + default: + return false, err + } + } + return false, fmt.Errorf("unexpected Response: %d", packet.Children[1].Tag) +} diff --git a/v3/conn.go b/v3/conn.go new file mode 100644 index 00000000..e6051170 --- /dev/null +++ b/v3/conn.go @@ -0,0 +1,522 @@ +package ldap + +import ( + "crypto/tls" + "errors" + "fmt" + "log" + "net" + "net/url" + "sync" + "sync/atomic" + "time" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +const ( + // MessageQuit causes the processMessages loop to exit + MessageQuit = 0 + // MessageRequest sends a request to the server + MessageRequest = 1 + // MessageResponse receives a response from the server + MessageResponse = 2 + // MessageFinish indicates the client considers a particular message ID to be finished + MessageFinish = 3 + // MessageTimeout indicates the client-specified timeout for a particular message ID has been reached + MessageTimeout = 4 +) + +const ( + // DefaultLdapPort default ldap port for pure TCP connection + DefaultLdapPort = "389" + // DefaultLdapsPort default ldap port for SSL connection + DefaultLdapsPort = "636" +) + +// PacketResponse contains the packet or error encountered reading a response +type PacketResponse struct { + // Packet is the packet read from the server + Packet *ber.Packet + // Error is an error encountered while reading + Error error +} + +// ReadPacket returns the packet or an error +func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) { + if (pr == nil) || (pr.Packet == nil && pr.Error == nil) { + return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response")) + } + return pr.Packet, pr.Error +} + +type messageContext struct { + id int64 + // close(done) should only be called from finishMessage() + done chan struct{} + // close(responses) should only be called from processMessages(), and only sent to from sendResponse() + responses chan *PacketResponse +} + +// sendResponse should only be called within the processMessages() loop which +// is also responsible for closing the responses channel. +func (msgCtx *messageContext) sendResponse(packet *PacketResponse) { + select { + case msgCtx.responses <- packet: + // Successfully sent packet to message handler. + case <-msgCtx.done: + // The request handler is done and will not receive more + // packets. + } +} + +type messagePacket struct { + Op int + MessageID int64 + Packet *ber.Packet + Context *messageContext +} + +type sendMessageFlags uint + +const ( + startTLS sendMessageFlags = 1 << iota +) + +// Conn represents an LDAP Connection +type Conn struct { + // requestTimeout is loaded atomically + // so we need to ensure 64-bit alignment on 32-bit platforms. + requestTimeout int64 + conn net.Conn + isTLS bool + closing uint32 + closeErr atomic.Value + isStartingTLS bool + Debug debugging + chanConfirm chan struct{} + messageContexts map[int64]*messageContext + chanMessage chan *messagePacket + chanMessageID chan int64 + wgClose sync.WaitGroup + outstandingRequests uint + messageMutex sync.Mutex +} + +var _ Client = &Conn{} + +// DefaultTimeout is a package-level variable that sets the timeout value +// used for the Dial and DialTLS methods. +// +// WARNING: since this is a package-level variable, setting this value from +// multiple places will probably result in undesired behaviour. +var DefaultTimeout = 60 * time.Second + +// Dial connects to the given address on the given network using net.Dial +// and then returns a new Conn for the connection. +func Dial(network, addr string) (*Conn, error) { + c, err := net.DialTimeout(network, addr, DefaultTimeout) + if err != nil { + return nil, NewError(ErrorNetwork, err) + } + conn := NewConn(c, false) + conn.Start() + return conn, nil +} + +// DialTLS connects to the given address on the given network using tls.Dial +// and then returns a new Conn for the connection. +func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { + c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config) + if err != nil { + return nil, NewError(ErrorNetwork, err) + } + conn := NewConn(c, true) + conn.Start() + return conn, nil +} + +// DialURL connects to the given ldap URL vie TCP using tls.Dial or net.Dial if ldaps:// +// or ldap:// specified as protocol. On success a new Conn for the connection +// is returned. +func DialURL(addr string) (*Conn, error) { + lurl, err := url.Parse(addr) + if err != nil { + return nil, NewError(ErrorNetwork, err) + } + + host, port, err := net.SplitHostPort(lurl.Host) + if err != nil { + // we asume that error is due to missing port + host = lurl.Host + port = "" + } + + switch lurl.Scheme { + case "ldapi": + if lurl.Path == "" || lurl.Path == "/" { + lurl.Path = "/var/run/slapd/ldapi" + } + return Dial("unix", lurl.Path) + case "ldap": + if port == "" { + port = DefaultLdapPort + } + return Dial("tcp", net.JoinHostPort(host, port)) + case "ldaps": + if port == "" { + port = DefaultLdapsPort + } + tlsConf := &tls.Config{ + ServerName: host, + } + return DialTLS("tcp", net.JoinHostPort(host, port), tlsConf) + } + + return nil, NewError(ErrorNetwork, fmt.Errorf("Unknown scheme '%s'", lurl.Scheme)) +} + +// NewConn returns a new Conn using conn for network I/O. +func NewConn(conn net.Conn, isTLS bool) *Conn { + return &Conn{ + conn: conn, + chanConfirm: make(chan struct{}), + chanMessageID: make(chan int64), + chanMessage: make(chan *messagePacket, 10), + messageContexts: map[int64]*messageContext{}, + requestTimeout: 0, + isTLS: isTLS, + } +} + +// Start initializes goroutines to read responses and process messages +func (l *Conn) Start() { + l.wgClose.Add(1) + go l.reader() + go l.processMessages() +} + +// IsClosing returns whether or not we're currently closing. +func (l *Conn) IsClosing() bool { + return atomic.LoadUint32(&l.closing) == 1 +} + +// setClosing sets the closing value to true +func (l *Conn) setClosing() bool { + return atomic.CompareAndSwapUint32(&l.closing, 0, 1) +} + +// Close closes the connection. +func (l *Conn) Close() { + l.messageMutex.Lock() + defer l.messageMutex.Unlock() + + if l.setClosing() { + l.Debug.Printf("Sending quit message and waiting for confirmation") + l.chanMessage <- &messagePacket{Op: MessageQuit} + <-l.chanConfirm + close(l.chanMessage) + + l.Debug.Printf("Closing network connection") + if err := l.conn.Close(); err != nil { + log.Println(err) + } + + l.wgClose.Done() + } + l.wgClose.Wait() +} + +// SetTimeout sets the time after a request is sent that a MessageTimeout triggers +func (l *Conn) SetTimeout(timeout time.Duration) { + if timeout > 0 { + atomic.StoreInt64(&l.requestTimeout, int64(timeout)) + } +} + +// Returns the next available messageID +func (l *Conn) nextMessageID() int64 { + if messageID, ok := <-l.chanMessageID; ok { + return messageID + } + return 0 +} + +// StartTLS sends the command to start a TLS session and then creates a new TLS Client +func (l *Conn) StartTLS(config *tls.Config) error { + if l.isTLS { + return NewError(ErrorNetwork, errors.New("ldap: already encrypted")) + } + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS") + request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")) + packet.AppendChild(request) + l.Debug.PrintPacket(packet) + + msgCtx, err := l.sendMessageWithFlags(packet, startTLS) + if err != nil { + return err + } + defer l.finishMessage(msgCtx) + + l.Debug.Printf("%d: waiting for response", msgCtx.id) + + packetResponse, ok := <-msgCtx.responses + if !ok { + return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + } + packet, err = packetResponse.ReadPacket() + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + return err + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + l.Close() + return err + } + ber.PrintPacket(packet) + } + + if err := GetLDAPError(packet); err == nil { + conn := tls.Client(l.conn, config) + + if connErr := conn.Handshake(); connErr != nil { + l.Close() + return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr)) + } + + l.isTLS = true + l.conn = conn + } else { + return err + } + go l.reader() + + return nil +} + +// TLSConnectionState returns the client's TLS connection state. +// The return values are their zero values if StartTLS did +// not succeed. +func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) { + tc, ok := l.conn.(*tls.Conn) + if !ok { + return + } + return tc.ConnectionState(), true +} + +func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) { + return l.sendMessageWithFlags(packet, 0) +} + +func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) { + if l.IsClosing() { + return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) + } + l.messageMutex.Lock() + l.Debug.Printf("flags&startTLS = %d", flags&startTLS) + if l.isStartingTLS { + l.messageMutex.Unlock() + return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase")) + } + if flags&startTLS != 0 { + if l.outstandingRequests != 0 { + l.messageMutex.Unlock() + return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests")) + } + l.isStartingTLS = true + } + l.outstandingRequests++ + + l.messageMutex.Unlock() + + responses := make(chan *PacketResponse) + messageID := packet.Children[0].Value.(int64) + message := &messagePacket{ + Op: MessageRequest, + MessageID: messageID, + Packet: packet, + Context: &messageContext{ + id: messageID, + done: make(chan struct{}), + responses: responses, + }, + } + l.sendProcessMessage(message) + return message.Context, nil +} + +func (l *Conn) finishMessage(msgCtx *messageContext) { + close(msgCtx.done) + + if l.IsClosing() { + return + } + + l.messageMutex.Lock() + l.outstandingRequests-- + if l.isStartingTLS { + l.isStartingTLS = false + } + l.messageMutex.Unlock() + + message := &messagePacket{ + Op: MessageFinish, + MessageID: msgCtx.id, + } + l.sendProcessMessage(message) +} + +func (l *Conn) sendProcessMessage(message *messagePacket) bool { + l.messageMutex.Lock() + defer l.messageMutex.Unlock() + if l.IsClosing() { + return false + } + l.chanMessage <- message + return true +} + +func (l *Conn) processMessages() { + defer func() { + if err := recover(); err != nil { + log.Printf("ldap: recovered panic in processMessages: %v", err) + } + for messageID, msgCtx := range l.messageContexts { + // If we are closing due to an error, inform anyone who + // is waiting about the error. + if l.IsClosing() && l.closeErr.Load() != nil { + msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}) + } + l.Debug.Printf("Closing channel for MessageID %d", messageID) + close(msgCtx.responses) + delete(l.messageContexts, messageID) + } + close(l.chanMessageID) + close(l.chanConfirm) + }() + + var messageID int64 = 1 + for { + select { + case l.chanMessageID <- messageID: + messageID++ + case message := <-l.chanMessage: + switch message.Op { + case MessageQuit: + l.Debug.Printf("Shutting down - quit message received") + return + case MessageRequest: + // Add to message list and write to network + l.Debug.Printf("Sending message %d", message.MessageID) + + buf := message.Packet.Bytes() + _, err := l.conn.Write(buf) + if err != nil { + l.Debug.Printf("Error Sending Message: %s", err.Error()) + message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}) + close(message.Context.responses) + break + } + + // Only add to messageContexts if we were able to + // successfully write the message. + l.messageContexts[message.MessageID] = message.Context + + // Add timeout if defined + requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout)) + if requestTimeout > 0 { + go func() { + defer func() { + if err := recover(); err != nil { + log.Printf("ldap: recovered panic in RequestTimeout: %v", err) + } + }() + time.Sleep(requestTimeout) + timeoutMessage := &messagePacket{ + Op: MessageTimeout, + MessageID: message.MessageID, + } + l.sendProcessMessage(timeoutMessage) + }() + } + case MessageResponse: + l.Debug.Printf("Receiving message %d", message.MessageID) + if msgCtx, ok := l.messageContexts[message.MessageID]; ok { + msgCtx.sendResponse(&PacketResponse{message.Packet, nil}) + } else { + log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing()) + ber.PrintPacket(message.Packet) + } + case MessageTimeout: + // Handle the timeout by closing the channel + // All reads will return immediately + if msgCtx, ok := l.messageContexts[message.MessageID]; ok { + l.Debug.Printf("Receiving message timeout for %d", message.MessageID) + msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")}) + delete(l.messageContexts, message.MessageID) + close(msgCtx.responses) + } + case MessageFinish: + l.Debug.Printf("Finished message %d", message.MessageID) + if msgCtx, ok := l.messageContexts[message.MessageID]; ok { + delete(l.messageContexts, message.MessageID) + close(msgCtx.responses) + } + } + } + } +} + +func (l *Conn) reader() { + cleanstop := false + defer func() { + if err := recover(); err != nil { + log.Printf("ldap: recovered panic in reader: %v", err) + } + if !cleanstop { + l.Close() + } + }() + + for { + if cleanstop { + l.Debug.Printf("reader clean stopping (without closing the connection)") + return + } + packet, err := ber.ReadPacket(l.conn) + if err != nil { + // A read error is expected here if we are closing the connection... + if !l.IsClosing() { + l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err)) + l.Debug.Printf("reader error: %s", err) + } + return + } + if err := addLDAPDescriptions(packet); err != nil { + l.Debug.Printf("descriptions error: %s", err) + } + if len(packet.Children) == 0 { + l.Debug.Printf("Received bad ldap packet") + continue + } + l.messageMutex.Lock() + if l.isStartingTLS { + cleanstop = true + } + l.messageMutex.Unlock() + message := &messagePacket{ + Op: MessageResponse, + MessageID: packet.Children[0].Value.(int64), + Packet: packet, + } + if !l.sendProcessMessage(message) { + return + } + } +} diff --git a/v3/conn_test.go b/v3/conn_test.go new file mode 100644 index 00000000..f815a967 --- /dev/null +++ b/v3/conn_test.go @@ -0,0 +1,336 @@ +package ldap + +import ( + "bytes" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "runtime" + "sync" + "testing" + "time" + + "github.com/go-asn1-ber/asn1-ber" +) + +func TestUnresponsiveConnection(t *testing.T) { + // The do-nothing server that accepts requests and does nothing + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer ts.Close() + c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error connecting to localhost tcp: %v", err) + } + + // Create an Ldap connection + conn := NewConn(c, false) + conn.SetTimeout(time.Millisecond) + conn.Start() + defer conn.Close() + + // Mock a packet + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, conn.nextMessageID(), "MessageID")) + bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + packet.AppendChild(bindRequest) + + // Send packet and test response + msgCtx, err := conn.sendMessage(packet) + if err != nil { + t.Fatalf("error sending message: %v", err) + } + defer conn.finishMessage(msgCtx) + + packetResponse, ok := <-msgCtx.responses + if !ok { + t.Fatalf("no PacketResponse in response channel") + } + packet, err = packetResponse.ReadPacket() + if err == nil { + t.Fatalf("expected timeout error") + } + if err.Error() != "ldap: connection timed out" { + t.Fatalf("unexpected error: %v", err) + } +} + +// TestFinishMessage tests that we do not enter deadlock when a goroutine makes +// a request but does not handle all responses from the server. +func TestFinishMessage(t *testing.T) { + ptc := newPacketTranslatorConn() + defer ptc.Close() + + conn := NewConn(ptc, false) + conn.Start() + + // Test sending 5 different requests in series. Ensure that we can + // get a response packet from the underlying connection and also + // ensure that we can gracefully ignore unhandled responses. + for i := 0; i < 5; i++ { + t.Logf("serial request %d", i) + // Create a message and make sure we can receive responses. + msgCtx := testSendRequest(t, ptc, conn) + testReceiveResponse(t, ptc, msgCtx) + + // Send a few unhandled responses and finish the message. + testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5) + t.Logf("serial request %d done", i) + } + + // Test sending 5 different requests in parallel. + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + t.Logf("parallel request %d", i) + // Create a message and make sure we can receive responses. + msgCtx := testSendRequest(t, ptc, conn) + testReceiveResponse(t, ptc, msgCtx) + + // Send a few unhandled responses and finish the message. + testSendUnhandledResponsesAndFinish(t, ptc, conn, msgCtx, 5) + t.Logf("parallel request %d done", i) + }(i) + } + wg.Wait() + + // We cannot run Close() in a defer because t.FailNow() will run it and + // it will block if the processMessage Loop is in a deadlock. + conn.Close() +} + +func testSendRequest(t *testing.T, ptc *packetTranslatorConn, conn *Conn) (msgCtx *messageContext) { + var msgID int64 + runWithTimeout(t, time.Second, func() { + msgID = conn.nextMessageID() + }) + + requestPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + requestPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgID, "MessageID")) + + var err error + + runWithTimeout(t, time.Second, func() { + msgCtx, err = conn.sendMessage(requestPacket) + if err != nil { + t.Fatalf("unable to send request message: %s", err) + } + }) + + // We should now be able to get this request packet out from the other + // side. + runWithTimeout(t, time.Second, func() { + if _, err = ptc.ReceiveRequest(); err != nil { + t.Fatalf("unable to receive request packet: %s", err) + } + }) + + return msgCtx +} + +func testReceiveResponse(t *testing.T, ptc *packetTranslatorConn, msgCtx *messageContext) { + // Send a mock response packet. + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID")) + + runWithTimeout(t, time.Second, func() { + if err := ptc.SendResponse(responsePacket); err != nil { + t.Fatalf("unable to send response packet: %s", err) + } + }) + + // We should be able to receive the packet from the connection. + runWithTimeout(t, time.Second, func() { + if _, ok := <-msgCtx.responses; !ok { + t.Fatal("response channel closed") + } + }) +} + +func testSendUnhandledResponsesAndFinish(t *testing.T, ptc *packetTranslatorConn, conn *Conn, msgCtx *messageContext, numResponses int) { + // Send a mock response packet. + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, msgCtx.id, "MessageID")) + + // Send extra responses but do not attempt to receive them on the + // client side. + for i := 0; i < numResponses; i++ { + runWithTimeout(t, time.Second, func() { + if err := ptc.SendResponse(responsePacket); err != nil { + t.Fatalf("unable to send response packet: %s", err) + } + }) + } + + // Finally, attempt to finish this message. + runWithTimeout(t, time.Second, func() { + conn.finishMessage(msgCtx) + }) +} + +func runWithTimeout(t *testing.T, timeout time.Duration, f func()) { + done := make(chan struct{}) + go func() { + f() + close(done) + }() + + select { + case <-done: // Success! + case <-time.After(timeout): + _, file, line, _ := runtime.Caller(1) + t.Fatalf("%s:%d timed out", file, line) + } +} + +// packetTranslatorConn is a helpful type which can be used with various tests +// in this package. It implements the net.Conn interface to be used as an +// underlying connection for a *ldap.Conn. Most methods are no-ops but the +// Read() and Write() methods are able to translate ber-encoded packets for +// testing LDAP requests and responses. +// +// Test cases can simulate an LDAP server sending a response by calling the +// SendResponse() method with a ber-encoded LDAP response packet. Test cases +// can simulate an LDAP server receiving a request from a client by calling the +// ReceiveRequest() method which returns a ber-encoded LDAP request packet. +type packetTranslatorConn struct { + lock sync.Mutex + isClosed bool + + responseCond sync.Cond + requestCond sync.Cond + + responseBuf bytes.Buffer + requestBuf bytes.Buffer +} + +var errPacketTranslatorConnClosed = errors.New("connection closed") + +func newPacketTranslatorConn() *packetTranslatorConn { + conn := &packetTranslatorConn{} + conn.responseCond = sync.Cond{L: &conn.lock} + conn.requestCond = sync.Cond{L: &conn.lock} + + return conn +} + +// Read is called by the reader() loop to receive response packets. It will +// block until there are more packet bytes available or this connection is +// closed. +func (c *packetTranslatorConn) Read(b []byte) (n int, err error) { + c.lock.Lock() + defer c.lock.Unlock() + + for !c.isClosed { + // Attempt to read data from the response buffer. If it fails + // with an EOF, wait and try again. + n, err = c.responseBuf.Read(b) + if err != io.EOF { + return n, err + } + + c.responseCond.Wait() + } + + return 0, errPacketTranslatorConnClosed +} + +// SendResponse writes the given response packet to the response buffer for +// this connection, signalling any goroutine waiting to read a response. +func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error { + c.lock.Lock() + defer c.lock.Unlock() + + if c.isClosed { + return errPacketTranslatorConnClosed + } + + // Signal any goroutine waiting to read a response. + defer c.responseCond.Broadcast() + + // Writes to the buffer should always succeed. + c.responseBuf.Write(packet.Bytes()) + + return nil +} + +// Write is called by the processMessages() loop to send request packets. +func (c *packetTranslatorConn) Write(b []byte) (n int, err error) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.isClosed { + return 0, errPacketTranslatorConnClosed + } + + // Signal any goroutine waiting to read a request. + defer c.requestCond.Broadcast() + + // Writes to the buffer should always succeed. + return c.requestBuf.Write(b) +} + +// ReceiveRequest attempts to read a request packet from this connection. It +// will block until it is able to read a full request packet or until this +// connection is closed. +func (c *packetTranslatorConn) ReceiveRequest() (*ber.Packet, error) { + c.lock.Lock() + defer c.lock.Unlock() + + for !c.isClosed { + // Attempt to parse a request packet from the request buffer. + // If it fails with an unexpected EOF, wait and try again. + requestReader := bytes.NewReader(c.requestBuf.Bytes()) + packet, err := ber.ReadPacket(requestReader) + switch err { + case io.EOF, io.ErrUnexpectedEOF: + c.requestCond.Wait() + case nil: + // Advance the request buffer by the number of bytes + // read to decode the request packet. + c.requestBuf.Next(c.requestBuf.Len() - requestReader.Len()) + return packet, nil + default: + return nil, err + } + } + + return nil, errPacketTranslatorConnClosed +} + +// Close closes this connection causing Read() and Write() calls to fail. +func (c *packetTranslatorConn) Close() error { + c.lock.Lock() + defer c.lock.Unlock() + + c.isClosed = true + c.responseCond.Broadcast() + c.requestCond.Broadcast() + + return nil +} + +func (c *packetTranslatorConn) LocalAddr() net.Addr { + return (*net.TCPAddr)(nil) +} + +func (c *packetTranslatorConn) RemoteAddr() net.Addr { + return (*net.TCPAddr)(nil) +} + +func (c *packetTranslatorConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *packetTranslatorConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *packetTranslatorConn) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/v3/control.go b/v3/control.go new file mode 100644 index 00000000..463fe3a3 --- /dev/null +++ b/v3/control.go @@ -0,0 +1,499 @@ +package ldap + +import ( + "fmt" + "strconv" + + "github.com/go-asn1-ber/asn1-ber" +) + +const ( + // ControlTypePaging - https://www.ietf.org/rfc/rfc2696.txt + ControlTypePaging = "1.2.840.113556.1.4.319" + // ControlTypeBeheraPasswordPolicy - https://tools.ietf.org/html/draft-behera-ldap-password-policy-10 + ControlTypeBeheraPasswordPolicy = "1.3.6.1.4.1.42.2.27.8.5.1" + // ControlTypeVChuPasswordMustChange - https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00 + ControlTypeVChuPasswordMustChange = "2.16.840.1.113730.3.4.4" + // ControlTypeVChuPasswordWarning - https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00 + ControlTypeVChuPasswordWarning = "2.16.840.1.113730.3.4.5" + // ControlTypeManageDsaIT - https://tools.ietf.org/html/rfc3296 + ControlTypeManageDsaIT = "2.16.840.1.113730.3.4.2" + + // ControlTypeMicrosoftNotification - https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx + ControlTypeMicrosoftNotification = "1.2.840.113556.1.4.528" + // ControlTypeMicrosoftShowDeleted - https://msdn.microsoft.com/en-us/library/aa366989(v=vs.85).aspx + ControlTypeMicrosoftShowDeleted = "1.2.840.113556.1.4.417" +) + +// ControlTypeMap maps controls to text descriptions +var ControlTypeMap = map[string]string{ + ControlTypePaging: "Paging", + ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft", + ControlTypeManageDsaIT: "Manage DSA IT", + ControlTypeMicrosoftNotification: "Change Notification - Microsoft", + ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft", +} + +// Control defines an interface controls provide to encode and describe themselves +type Control interface { + // GetControlType returns the OID + GetControlType() string + // Encode returns the ber packet representation + Encode() *ber.Packet + // String returns a human-readable description + String() string +} + +// ControlString implements the Control interface for simple controls +type ControlString struct { + ControlType string + Criticality bool + ControlValue string +} + +// GetControlType returns the OID +func (c *ControlString) GetControlType() string { + return c.ControlType +} + +// Encode returns the ber packet representation +func (c *ControlString) Encode() *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.ControlType, "Control Type ("+ControlTypeMap[c.ControlType]+")")) + if c.Criticality { + packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.Criticality, "Criticality")) + } + if c.ControlValue != "" { + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, string(c.ControlValue), "Control Value")) + } + return packet +} + +// String returns a human-readable description +func (c *ControlString) String() string { + return fmt.Sprintf("Control Type: %s (%q) Criticality: %t Control Value: %s", ControlTypeMap[c.ControlType], c.ControlType, c.Criticality, c.ControlValue) +} + +// ControlPaging implements the paging control described in https://www.ietf.org/rfc/rfc2696.txt +type ControlPaging struct { + // PagingSize indicates the page size + PagingSize uint32 + // Cookie is an opaque value returned by the server to track a paging cursor + Cookie []byte +} + +// GetControlType returns the OID +func (c *ControlPaging) GetControlType() string { + return ControlTypePaging +} + +// Encode returns the ber packet representation +func (c *ControlPaging) Encode() *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypePaging, "Control Type ("+ControlTypeMap[ControlTypePaging]+")")) + + p2 := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value (Paging)") + seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Search Control Value") + seq.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(c.PagingSize), "Paging Size")) + cookie := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Cookie") + cookie.Value = c.Cookie + cookie.Data.Write(c.Cookie) + seq.AppendChild(cookie) + p2.AppendChild(seq) + + packet.AppendChild(p2) + return packet +} + +// String returns a human-readable description +func (c *ControlPaging) String() string { + return fmt.Sprintf( + "Control Type: %s (%q) Criticality: %t PagingSize: %d Cookie: %q", + ControlTypeMap[ControlTypePaging], + ControlTypePaging, + false, + c.PagingSize, + c.Cookie) +} + +// SetCookie stores the given cookie in the paging control +func (c *ControlPaging) SetCookie(cookie []byte) { + c.Cookie = cookie +} + +// ControlBeheraPasswordPolicy implements the control described in https://tools.ietf.org/html/draft-behera-ldap-password-policy-10 +type ControlBeheraPasswordPolicy struct { + // Expire contains the number of seconds before a password will expire + Expire int64 + // Grace indicates the remaining number of times a user will be allowed to authenticate with an expired password + Grace int64 + // Error indicates the error code + Error int8 + // ErrorString is a human readable error + ErrorString string +} + +// GetControlType returns the OID +func (c *ControlBeheraPasswordPolicy) GetControlType() string { + return ControlTypeBeheraPasswordPolicy +} + +// Encode returns the ber packet representation +func (c *ControlBeheraPasswordPolicy) Encode() *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeBeheraPasswordPolicy, "Control Type ("+ControlTypeMap[ControlTypeBeheraPasswordPolicy]+")")) + + return packet +} + +// String returns a human-readable description +func (c *ControlBeheraPasswordPolicy) String() string { + return fmt.Sprintf( + "Control Type: %s (%q) Criticality: %t Expire: %d Grace: %d Error: %d, ErrorString: %s", + ControlTypeMap[ControlTypeBeheraPasswordPolicy], + ControlTypeBeheraPasswordPolicy, + false, + c.Expire, + c.Grace, + c.Error, + c.ErrorString) +} + +// ControlVChuPasswordMustChange implements the control described in https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00 +type ControlVChuPasswordMustChange struct { + // MustChange indicates if the password is required to be changed + MustChange bool +} + +// GetControlType returns the OID +func (c *ControlVChuPasswordMustChange) GetControlType() string { + return ControlTypeVChuPasswordMustChange +} + +// Encode returns the ber packet representation +func (c *ControlVChuPasswordMustChange) Encode() *ber.Packet { + return nil +} + +// String returns a human-readable description +func (c *ControlVChuPasswordMustChange) String() string { + return fmt.Sprintf( + "Control Type: %s (%q) Criticality: %t MustChange: %v", + ControlTypeMap[ControlTypeVChuPasswordMustChange], + ControlTypeVChuPasswordMustChange, + false, + c.MustChange) +} + +// ControlVChuPasswordWarning implements the control described in https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00 +type ControlVChuPasswordWarning struct { + // Expire indicates the time in seconds until the password expires + Expire int64 +} + +// GetControlType returns the OID +func (c *ControlVChuPasswordWarning) GetControlType() string { + return ControlTypeVChuPasswordWarning +} + +// Encode returns the ber packet representation +func (c *ControlVChuPasswordWarning) Encode() *ber.Packet { + return nil +} + +// String returns a human-readable description +func (c *ControlVChuPasswordWarning) String() string { + return fmt.Sprintf( + "Control Type: %s (%q) Criticality: %t Expire: %b", + ControlTypeMap[ControlTypeVChuPasswordWarning], + ControlTypeVChuPasswordWarning, + false, + c.Expire) +} + +// ControlManageDsaIT implements the control described in https://tools.ietf.org/html/rfc3296 +type ControlManageDsaIT struct { + // Criticality indicates if this control is required + Criticality bool +} + +// GetControlType returns the OID +func (c *ControlManageDsaIT) GetControlType() string { + return ControlTypeManageDsaIT +} + +// Encode returns the ber packet representation +func (c *ControlManageDsaIT) Encode() *ber.Packet { + //FIXME + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeManageDsaIT, "Control Type ("+ControlTypeMap[ControlTypeManageDsaIT]+")")) + if c.Criticality { + packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.Criticality, "Criticality")) + } + return packet +} + +// String returns a human-readable description +func (c *ControlManageDsaIT) String() string { + return fmt.Sprintf( + "Control Type: %s (%q) Criticality: %t", + ControlTypeMap[ControlTypeManageDsaIT], + ControlTypeManageDsaIT, + c.Criticality) +} + +// NewControlManageDsaIT returns a ControlManageDsaIT control +func NewControlManageDsaIT(Criticality bool) *ControlManageDsaIT { + return &ControlManageDsaIT{Criticality: Criticality} +} + +// ControlMicrosoftNotification implements the control described in https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx +type ControlMicrosoftNotification struct{} + +// GetControlType returns the OID +func (c *ControlMicrosoftNotification) GetControlType() string { + return ControlTypeMicrosoftNotification +} + +// Encode returns the ber packet representation +func (c *ControlMicrosoftNotification) Encode() *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeMicrosoftNotification, "Control Type ("+ControlTypeMap[ControlTypeMicrosoftNotification]+")")) + + return packet +} + +// String returns a human-readable description +func (c *ControlMicrosoftNotification) String() string { + return fmt.Sprintf( + "Control Type: %s (%q)", + ControlTypeMap[ControlTypeMicrosoftNotification], + ControlTypeMicrosoftNotification) +} + +// NewControlMicrosoftNotification returns a ControlMicrosoftNotification control +func NewControlMicrosoftNotification() *ControlMicrosoftNotification { + return &ControlMicrosoftNotification{} +} + +// ControlMicrosoftShowDeleted implements the control described in https://msdn.microsoft.com/en-us/library/aa366989(v=vs.85).aspx +type ControlMicrosoftShowDeleted struct{} + +// GetControlType returns the OID +func (c *ControlMicrosoftShowDeleted) GetControlType() string { + return ControlTypeMicrosoftShowDeleted +} + +// Encode returns the ber packet representation +func (c *ControlMicrosoftShowDeleted) Encode() *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeMicrosoftShowDeleted, "Control Type ("+ControlTypeMap[ControlTypeMicrosoftShowDeleted]+")")) + + return packet +} + +// String returns a human-readable description +func (c *ControlMicrosoftShowDeleted) String() string { + return fmt.Sprintf( + "Control Type: %s (%q)", + ControlTypeMap[ControlTypeMicrosoftShowDeleted], + ControlTypeMicrosoftShowDeleted) +} + +// NewControlMicrosoftShowDeleted returns a ControlMicrosoftShowDeleted control +func NewControlMicrosoftShowDeleted() *ControlMicrosoftShowDeleted { + return &ControlMicrosoftShowDeleted{} +} + +// FindControl returns the first control of the given type in the list, or nil +func FindControl(controls []Control, controlType string) Control { + for _, c := range controls { + if c.GetControlType() == controlType { + return c + } + } + return nil +} + +// DecodeControl returns a control read from the given packet, or nil if no recognized control can be made +func DecodeControl(packet *ber.Packet) (Control, error) { + var ( + ControlType = "" + Criticality = false + value *ber.Packet + ) + + switch len(packet.Children) { + case 0: + // at least one child is required for control type + return nil, fmt.Errorf("at least one child is required for control type") + + case 1: + // just type, no criticality or value + packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")" + ControlType = packet.Children[0].Value.(string) + + case 2: + packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")" + ControlType = packet.Children[0].Value.(string) + + // Children[1] could be criticality or value (both are optional) + // duck-type on whether this is a boolean + if _, ok := packet.Children[1].Value.(bool); ok { + packet.Children[1].Description = "Criticality" + Criticality = packet.Children[1].Value.(bool) + } else { + packet.Children[1].Description = "Control Value" + value = packet.Children[1] + } + + case 3: + packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")" + ControlType = packet.Children[0].Value.(string) + + packet.Children[1].Description = "Criticality" + Criticality = packet.Children[1].Value.(bool) + + packet.Children[2].Description = "Control Value" + value = packet.Children[2] + + default: + // more than 3 children is invalid + return nil, fmt.Errorf("more than 3 children is invalid for controls") + } + + switch ControlType { + case ControlTypeManageDsaIT: + return NewControlManageDsaIT(Criticality), nil + case ControlTypePaging: + value.Description += " (Paging)" + c := new(ControlPaging) + if value.Value != nil { + valueChildren, err := ber.DecodePacketErr(value.Data.Bytes()) + if err != nil { + return nil, fmt.Errorf("failed to decode data bytes: %s", err) + } + value.Data.Truncate(0) + value.Value = nil + value.AppendChild(valueChildren) + } + value = value.Children[0] + value.Description = "Search Control Value" + value.Children[0].Description = "Paging Size" + value.Children[1].Description = "Cookie" + c.PagingSize = uint32(value.Children[0].Value.(int64)) + c.Cookie = value.Children[1].Data.Bytes() + value.Children[1].Value = c.Cookie + return c, nil + case ControlTypeBeheraPasswordPolicy: + value.Description += " (Password Policy - Behera)" + c := NewControlBeheraPasswordPolicy() + if value.Value != nil { + valueChildren, err := ber.DecodePacketErr(value.Data.Bytes()) + if err != nil { + return nil, fmt.Errorf("failed to decode data bytes: %s", err) + } + value.Data.Truncate(0) + value.Value = nil + value.AppendChild(valueChildren) + } + + sequence := value.Children[0] + + for _, child := range sequence.Children { + if child.Tag == 0 { + //Warning + warningPacket := child.Children[0] + packet, err := ber.DecodePacketErr(warningPacket.Data.Bytes()) + if err != nil { + return nil, fmt.Errorf("failed to decode data bytes: %s", err) + } + val, ok := packet.Value.(int64) + if ok { + if warningPacket.Tag == 0 { + //timeBeforeExpiration + c.Expire = val + warningPacket.Value = c.Expire + } else if warningPacket.Tag == 1 { + //graceAuthNsRemaining + c.Grace = val + warningPacket.Value = c.Grace + } + } + } else if child.Tag == 1 { + // Error + packet, err := ber.DecodePacketErr(child.Data.Bytes()) + if err != nil { + return nil, fmt.Errorf("failed to decode data bytes: %s", err) + } + val, ok := packet.Value.(int8) + if !ok { + // what to do? + val = -1 + } + c.Error = val + child.Value = c.Error + c.ErrorString = BeheraPasswordPolicyErrorMap[c.Error] + } + } + return c, nil + case ControlTypeVChuPasswordMustChange: + c := &ControlVChuPasswordMustChange{MustChange: true} + return c, nil + case ControlTypeVChuPasswordWarning: + c := &ControlVChuPasswordWarning{Expire: -1} + expireStr := ber.DecodeString(value.Data.Bytes()) + + expire, err := strconv.ParseInt(expireStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse value as int: %s", err) + } + c.Expire = expire + value.Value = c.Expire + + return c, nil + case ControlTypeMicrosoftNotification: + return NewControlMicrosoftNotification(), nil + case ControlTypeMicrosoftShowDeleted: + return NewControlMicrosoftShowDeleted(), nil + default: + c := new(ControlString) + c.ControlType = ControlType + c.Criticality = Criticality + if value != nil { + c.ControlValue = value.Value.(string) + } + return c, nil + } +} + +// NewControlString returns a generic control +func NewControlString(controlType string, criticality bool, controlValue string) *ControlString { + return &ControlString{ + ControlType: controlType, + Criticality: criticality, + ControlValue: controlValue, + } +} + +// NewControlPaging returns a paging control +func NewControlPaging(pagingSize uint32) *ControlPaging { + return &ControlPaging{PagingSize: pagingSize} +} + +// NewControlBeheraPasswordPolicy returns a ControlBeheraPasswordPolicy +func NewControlBeheraPasswordPolicy() *ControlBeheraPasswordPolicy { + return &ControlBeheraPasswordPolicy{ + Expire: -1, + Grace: -1, + Error: -1, + } +} + +func encodeControls(controls []Control) *ber.Packet { + packet := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0, nil, "Controls") + for _, control := range controls { + packet.AppendChild(control.Encode()) + } + return packet +} diff --git a/v3/control_test.go b/v3/control_test.go new file mode 100644 index 00000000..309d7f27 --- /dev/null +++ b/v3/control_test.go @@ -0,0 +1,123 @@ +package ldap + +import ( + "bytes" + "fmt" + "reflect" + "runtime" + "testing" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +func TestControlPaging(t *testing.T) { + runControlTest(t, NewControlPaging(0)) + runControlTest(t, NewControlPaging(100)) +} + +func TestControlManageDsaIT(t *testing.T) { + runControlTest(t, NewControlManageDsaIT(true)) + runControlTest(t, NewControlManageDsaIT(false)) +} + +func TestControlMicrosoftNotification(t *testing.T) { + runControlTest(t, NewControlMicrosoftNotification()) +} + +func TestControlMicrosoftShowDeleted(t *testing.T) { + runControlTest(t, NewControlMicrosoftShowDeleted()) +} + +func TestControlString(t *testing.T) { + runControlTest(t, NewControlString("x", true, "y")) + runControlTest(t, NewControlString("x", true, "")) + runControlTest(t, NewControlString("x", false, "y")) + runControlTest(t, NewControlString("x", false, "")) +} + +func runControlTest(t *testing.T, originalControl Control) { + header := "" + if callerpc, _, line, ok := runtime.Caller(1); ok { + if caller := runtime.FuncForPC(callerpc); caller != nil { + header = fmt.Sprintf("%s:%d: ", caller.Name(), line) + } + } + + encodedPacket := originalControl.Encode() + encodedBytes := encodedPacket.Bytes() + + // Decode directly from the encoded packet (ensures Value is correct) + fromPacket, err := DecodeControl(encodedPacket) + if err != nil { + t.Errorf("%sdecoding encoded bytes control failed: %s", header, err) + } + if !bytes.Equal(encodedBytes, fromPacket.Encode().Bytes()) { + t.Errorf("%sround-trip from encoded packet failed", header) + } + if reflect.TypeOf(originalControl) != reflect.TypeOf(fromPacket) { + t.Errorf("%sgot different type decoding from encoded packet: %T vs %T", header, fromPacket, originalControl) + } + + // Decode from the wire bytes (ensures ber-encoding is correct) + pkt, err := ber.DecodePacketErr(encodedBytes) + if err != nil { + t.Errorf("%sdecoding encoded bytes failed: %s", header, err) + } + fromBytes, err := DecodeControl(pkt) + if err != nil { + t.Errorf("%sdecoding control failed: %s", header, err) + } + if !bytes.Equal(encodedBytes, fromBytes.Encode().Bytes()) { + t.Errorf("%sround-trip from encoded bytes failed", header) + } + if reflect.TypeOf(originalControl) != reflect.TypeOf(fromPacket) { + t.Errorf("%sgot different type decoding from encoded bytes: %T vs %T", header, fromBytes, originalControl) + } +} + +func TestDescribeControlManageDsaIT(t *testing.T) { + runAddControlDescriptions(t, NewControlManageDsaIT(false), "Control Type (Manage DSA IT)") + runAddControlDescriptions(t, NewControlManageDsaIT(true), "Control Type (Manage DSA IT)", "Criticality") +} + +func TestDescribeControlPaging(t *testing.T) { + runAddControlDescriptions(t, NewControlPaging(100), "Control Type (Paging)", "Control Value (Paging)") + runAddControlDescriptions(t, NewControlPaging(0), "Control Type (Paging)", "Control Value (Paging)") +} + +func TestDescribeControlMicrosoftNotification(t *testing.T) { + runAddControlDescriptions(t, NewControlMicrosoftNotification(), "Control Type (Change Notification - Microsoft)") +} + +func TestDescribeControlMicrosoftShowDeleted(t *testing.T) { + runAddControlDescriptions(t, NewControlMicrosoftShowDeleted(), "Control Type (Show Deleted Objects - Microsoft)") +} + +func TestDescribeControlString(t *testing.T) { + runAddControlDescriptions(t, NewControlString("x", true, "y"), "Control Type ()", "Criticality", "Control Value") + runAddControlDescriptions(t, NewControlString("x", true, ""), "Control Type ()", "Criticality") + runAddControlDescriptions(t, NewControlString("x", false, "y"), "Control Type ()", "Control Value") + runAddControlDescriptions(t, NewControlString("x", false, ""), "Control Type ()") +} + +func runAddControlDescriptions(t *testing.T, originalControl Control, childDescriptions ...string) { + header := "" + if callerpc, _, line, ok := runtime.Caller(1); ok { + if caller := runtime.FuncForPC(callerpc); caller != nil { + header = fmt.Sprintf("%s:%d: ", caller.Name(), line) + } + } + + encodedControls := encodeControls([]Control{originalControl}) + addControlDescriptions(encodedControls) + encodedPacket := encodedControls.Children[0] + if len(encodedPacket.Children) != len(childDescriptions) { + t.Errorf("%sinvalid number of children: %d != %d", header, len(encodedPacket.Children), len(childDescriptions)) + } + for i, desc := range childDescriptions { + if encodedPacket.Children[i].Description != desc { + t.Errorf("%sdescription not as expected: %s != %s", header, encodedPacket.Children[i].Description, desc) + } + } + +} diff --git a/v3/debug.go b/v3/debug.go new file mode 100644 index 00000000..2c0b30c8 --- /dev/null +++ b/v3/debug.go @@ -0,0 +1,30 @@ +package ldap + +import ( + "log" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// debugging type +// - has a Printf method to write the debug output +type debugging bool + +// Enable controls debugging mode. +func (debug *debugging) Enable(b bool) { + *debug = debugging(b) +} + +// Printf writes debug output. +func (debug debugging) Printf(format string, args ...interface{}) { + if debug { + log.Printf(format, args...) + } +} + +// PrintPacket dumps a packet. +func (debug debugging) PrintPacket(packet *ber.Packet) { + if debug { + ber.PrintPacket(packet) + } +} diff --git a/v3/del.go b/v3/del.go new file mode 100644 index 00000000..49811d34 --- /dev/null +++ b/v3/del.go @@ -0,0 +1,64 @@ +// +// https://tools.ietf.org/html/rfc4511 +// +// DelRequest ::= [APPLICATION 10] LDAPDN + +package ldap + +import ( + "log" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// DelRequest implements an LDAP deletion request +type DelRequest struct { + // DN is the name of the directory entry to delete + DN string + // Controls hold optional controls to send with the request + Controls []Control +} + +func (req *DelRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationDelRequest, req.DN, "Del Request") + pkt.Data.Write([]byte(req.DN)) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil +} + +// NewDelRequest creates a delete request for the given DN and controls +func NewDelRequest(DN string, Controls []Control) *DelRequest { + return &DelRequest{ + DN: DN, + Controls: Controls, + } +} + +// Del executes the given delete request +func (l *Conn) Del(delRequest *DelRequest) error { + msgCtx, err := l.doRequest(delRequest) + if err != nil { + return err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return err + } + + if packet.Children[1].Tag == ApplicationDelResponse { + err := GetLDAPError(packet) + if err != nil { + return err + } + } else { + log.Printf("Unexpected Response: %d", packet.Children[1].Tag) + } + return nil +} diff --git a/v3/dn.go b/v3/dn.go new file mode 100644 index 00000000..9d32e7fa --- /dev/null +++ b/v3/dn.go @@ -0,0 +1,247 @@ +// File contains DN parsing functionality +// +// https://tools.ietf.org/html/rfc4514 +// +// distinguishedName = [ relativeDistinguishedName +// *( COMMA relativeDistinguishedName ) ] +// relativeDistinguishedName = attributeTypeAndValue +// *( PLUS attributeTypeAndValue ) +// attributeTypeAndValue = attributeType EQUALS attributeValue +// attributeType = descr / numericoid +// attributeValue = string / hexstring +// +// ; The following characters are to be escaped when they appear +// ; in the value to be encoded: ESC, one of , leading +// ; SHARP or SPACE, trailing SPACE, and NULL. +// string = [ ( leadchar / pair ) [ *( stringchar / pair ) +// ( trailchar / pair ) ] ] +// +// leadchar = LUTF1 / UTFMB +// LUTF1 = %x01-1F / %x21 / %x24-2A / %x2D-3A / +// %x3D / %x3F-5B / %x5D-7F +// +// trailchar = TUTF1 / UTFMB +// TUTF1 = %x01-1F / %x21 / %x23-2A / %x2D-3A / +// %x3D / %x3F-5B / %x5D-7F +// +// stringchar = SUTF1 / UTFMB +// SUTF1 = %x01-21 / %x23-2A / %x2D-3A / +// %x3D / %x3F-5B / %x5D-7F +// +// pair = ESC ( ESC / special / hexpair ) +// special = escaped / SPACE / SHARP / EQUALS +// escaped = DQUOTE / PLUS / COMMA / SEMI / LANGLE / RANGLE +// hexstring = SHARP 1*hexpair +// hexpair = HEX HEX +// +// where the productions , , , , +// , , , , , , , , +// , , and are defined in [RFC4512]. +// + +package ldap + +import ( + "bytes" + enchex "encoding/hex" + "errors" + "fmt" + "strings" + + "github.com/go-asn1-ber/asn1-ber" +) + +// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514 +type AttributeTypeAndValue struct { + // Type is the attribute type + Type string + // Value is the attribute value + Value string +} + +// RelativeDN represents a relativeDistinguishedName from https://tools.ietf.org/html/rfc4514 +type RelativeDN struct { + Attributes []*AttributeTypeAndValue +} + +// DN represents a distinguishedName from https://tools.ietf.org/html/rfc4514 +type DN struct { + RDNs []*RelativeDN +} + +// ParseDN returns a distinguishedName or an error +func ParseDN(str string) (*DN, error) { + dn := new(DN) + dn.RDNs = make([]*RelativeDN, 0) + rdn := new(RelativeDN) + rdn.Attributes = make([]*AttributeTypeAndValue, 0) + buffer := bytes.Buffer{} + attribute := new(AttributeTypeAndValue) + escaping := false + + unescapedTrailingSpaces := 0 + stringFromBuffer := func() string { + s := buffer.String() + s = s[0 : len(s)-unescapedTrailingSpaces] + buffer.Reset() + unescapedTrailingSpaces = 0 + return s + } + + for i := 0; i < len(str); i++ { + char := str[i] + switch { + case escaping: + unescapedTrailingSpaces = 0 + escaping = false + switch char { + case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\': + buffer.WriteByte(char) + continue + } + // Not a special character, assume hex encoded octet + if len(str) == i+1 { + return nil, errors.New("got corrupted escaped character") + } + + dst := []byte{0} + n, err := enchex.Decode([]byte(dst), []byte(str[i:i+2])) + if err != nil { + return nil, fmt.Errorf("failed to decode escaped character: %s", err) + } else if n != 1 { + return nil, fmt.Errorf("expected 1 byte when un-escaping, got %d", n) + } + buffer.WriteByte(dst[0]) + i++ + case char == '\\': + unescapedTrailingSpaces = 0 + escaping = true + case char == '=': + attribute.Type = stringFromBuffer() + // Special case: If the first character in the value is # the + // following data is BER encoded so we can just fast forward + // and decode. + if len(str) > i+1 && str[i+1] == '#' { + i += 2 + index := strings.IndexAny(str[i:], ",+") + data := str + if index > 0 { + data = str[i : i+index] + } else { + data = str[i:] + } + rawBER, err := enchex.DecodeString(data) + if err != nil { + return nil, fmt.Errorf("failed to decode BER encoding: %s", err) + } + packet, err := ber.DecodePacketErr(rawBER) + if err != nil { + return nil, fmt.Errorf("failed to decode BER packet: %s", err) + } + buffer.WriteString(packet.Data.String()) + i += len(data) - 1 + } + case char == ',' || char == '+': + // We're done with this RDN or value, push it + if len(attribute.Type) == 0 { + return nil, errors.New("incomplete type, value pair") + } + attribute.Value = stringFromBuffer() + rdn.Attributes = append(rdn.Attributes, attribute) + attribute = new(AttributeTypeAndValue) + if char == ',' { + dn.RDNs = append(dn.RDNs, rdn) + rdn = new(RelativeDN) + rdn.Attributes = make([]*AttributeTypeAndValue, 0) + } + case char == ' ' && buffer.Len() == 0: + // ignore unescaped leading spaces + continue + default: + if char == ' ' { + // Track unescaped spaces in case they are trailing and we need to remove them + unescapedTrailingSpaces++ + } else { + // Reset if we see a non-space char + unescapedTrailingSpaces = 0 + } + buffer.WriteByte(char) + } + } + if buffer.Len() > 0 { + if len(attribute.Type) == 0 { + return nil, errors.New("DN ended with incomplete type, value pair") + } + attribute.Value = stringFromBuffer() + rdn.Attributes = append(rdn.Attributes, attribute) + dn.RDNs = append(dn.RDNs, rdn) + } + return dn, nil +} + +// Equal returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch). +// Returns true if they have the same number of relative distinguished names +// and corresponding relative distinguished names (by position) are the same. +func (d *DN) Equal(other *DN) bool { + if len(d.RDNs) != len(other.RDNs) { + return false + } + for i := range d.RDNs { + if !d.RDNs[i].Equal(other.RDNs[i]) { + return false + } + } + return true +} + +// AncestorOf returns true if the other DN consists of at least one RDN followed by all the RDNs of the current DN. +// "ou=widgets,o=acme.com" is an ancestor of "ou=sprockets,ou=widgets,o=acme.com" +// "ou=widgets,o=acme.com" is not an ancestor of "ou=sprockets,ou=widgets,o=foo.com" +// "ou=widgets,o=acme.com" is not an ancestor of "ou=widgets,o=acme.com" +func (d *DN) AncestorOf(other *DN) bool { + if len(d.RDNs) >= len(other.RDNs) { + return false + } + // Take the last `len(d.RDNs)` RDNs from the other DN to compare against + otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):] + for i := range d.RDNs { + if !d.RDNs[i].Equal(otherRDNs[i]) { + return false + } + } + return true +} + +// Equal returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch). +// Relative distinguished names are the same if and only if they have the same number of AttributeTypeAndValues +// and each attribute of the first RDN is the same as the attribute of the second RDN with the same attribute type. +// The order of attributes is not significant. +// Case of attribute types is not significant. +func (r *RelativeDN) Equal(other *RelativeDN) bool { + if len(r.Attributes) != len(other.Attributes) { + return false + } + return r.hasAllAttributes(other.Attributes) && other.hasAllAttributes(r.Attributes) +} + +func (r *RelativeDN) hasAllAttributes(attrs []*AttributeTypeAndValue) bool { + for _, attr := range attrs { + found := false + for _, myattr := range r.Attributes { + if myattr.Equal(attr) { + found = true + break + } + } + if !found { + return false + } + } + return true +} + +// Equal returns true if the AttributeTypeAndValue is equivalent to the specified AttributeTypeAndValue +// Case of the attribute type is not significant +func (a *AttributeTypeAndValue) Equal(other *AttributeTypeAndValue) bool { + return strings.EqualFold(a.Type, other.Type) && a.Value == other.Value +} diff --git a/v3/dn_test.go b/v3/dn_test.go new file mode 100644 index 00000000..511c8343 --- /dev/null +++ b/v3/dn_test.go @@ -0,0 +1,209 @@ +package ldap + +import ( + "reflect" + "testing" +) + +func TestSuccessfulDNParsing(t *testing.T) { + testcases := map[string]DN{ + "": DN{[]*RelativeDN{}}, + "cn=Jim\\2C \\22Hasse Hö\\22 Hansson!,dc=dummy,dc=com": DN{[]*RelativeDN{ + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"cn", "Jim, \"Hasse Hö\" Hansson!"}}}, + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"dc", "dummy"}}}, + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"dc", "com"}}}}}, + "UID=jsmith,DC=example,DC=net": DN{[]*RelativeDN{ + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"UID", "jsmith"}}}, + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"DC", "example"}}}, + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"DC", "net"}}}}}, + "OU=Sales+CN=J. Smith,DC=example,DC=net": DN{[]*RelativeDN{ + &RelativeDN{[]*AttributeTypeAndValue{ + &AttributeTypeAndValue{"OU", "Sales"}, + &AttributeTypeAndValue{"CN", "J. Smith"}}}, + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"DC", "example"}}}, + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"DC", "net"}}}}}, + "1.3.6.1.4.1.1466.0=#04024869": DN{[]*RelativeDN{ + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"1.3.6.1.4.1.1466.0", "Hi"}}}}}, + "1.3.6.1.4.1.1466.0=#04024869,DC=net": DN{[]*RelativeDN{ + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"1.3.6.1.4.1.1466.0", "Hi"}}}, + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"DC", "net"}}}}}, + "CN=Lu\\C4\\8Di\\C4\\87": DN{[]*RelativeDN{ + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"CN", "Lučić"}}}}}, + " CN = Lu\\C4\\8Di\\C4\\87 ": DN{[]*RelativeDN{ + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"CN", "Lučić"}}}}}, + ` A = 1 , B = 2 `: DN{[]*RelativeDN{ + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"A", "1"}}}, + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"B", "2"}}}}}, + ` A = 1 + B = 2 `: DN{[]*RelativeDN{ + &RelativeDN{[]*AttributeTypeAndValue{ + &AttributeTypeAndValue{"A", "1"}, + &AttributeTypeAndValue{"B", "2"}}}}}, + ` \ \ A\ \ = \ \ 1\ \ , \ \ B\ \ = \ \ 2\ \ `: DN{[]*RelativeDN{ + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{" A ", " 1 "}}}, + &RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{" B ", " 2 "}}}}}, + ` \ \ A\ \ = \ \ 1\ \ + \ \ B\ \ = \ \ 2\ \ `: DN{[]*RelativeDN{ + &RelativeDN{[]*AttributeTypeAndValue{ + &AttributeTypeAndValue{" A ", " 1 "}, + &AttributeTypeAndValue{" B ", " 2 "}}}}}, + } + + for test, answer := range testcases { + dn, err := ParseDN(test) + if err != nil { + t.Errorf(err.Error()) + continue + } + if !reflect.DeepEqual(dn, &answer) { + t.Errorf("Parsed DN %s is not equal to the expected structure", test) + t.Logf("Expected:") + for _, rdn := range answer.RDNs { + for _, attribs := range rdn.Attributes { + t.Logf("#%v\n", attribs) + } + } + t.Logf("Actual:") + for _, rdn := range dn.RDNs { + for _, attribs := range rdn.Attributes { + t.Logf("#%v\n", attribs) + } + } + } + } +} + +func TestErrorDNParsing(t *testing.T) { + testcases := map[string]string{ + "*": "DN ended with incomplete type, value pair", + "cn=Jim\\0Test": "failed to decode escaped character: encoding/hex: invalid byte: U+0054 'T'", + "cn=Jim\\0": "got corrupted escaped character", + "DC=example,=net": "DN ended with incomplete type, value pair", + "1=#0402486": "failed to decode BER encoding: encoding/hex: odd length hex string", + "test,DC=example,DC=com": "incomplete type, value pair", + "=test,DC=example,DC=com": "incomplete type, value pair", + } + + for test, answer := range testcases { + _, err := ParseDN(test) + if err == nil { + t.Errorf("Expected %s to fail parsing but succeeded\n", test) + } else if err.Error() != answer { + t.Errorf("Unexpected error on %s:\n%s\nvs.\n%s\n", test, answer, err.Error()) + } + } +} + +func TestDNEqual(t *testing.T) { + testcases := []struct { + A string + B string + Equal bool + }{ + // Exact match + {"", "", true}, + {"o=A", "o=A", true}, + {"o=A", "o=B", false}, + + {"o=A,o=B", "o=A,o=B", true}, + {"o=A,o=B", "o=A,o=C", false}, + + {"o=A+o=B", "o=A+o=B", true}, + {"o=A+o=B", "o=A+o=C", false}, + + // Case mismatch in type is ignored + {"o=A", "O=A", true}, + {"o=A,o=B", "o=A,O=B", true}, + {"o=A+o=B", "o=A+O=B", true}, + + // Case mismatch in value is significant + {"o=a", "O=A", false}, + {"o=a,o=B", "o=A,O=B", false}, + {"o=a+o=B", "o=A+O=B", false}, + + // Multi-valued RDN order mismatch is ignored + {"o=A+o=B", "O=B+o=A", true}, + // Number of RDN attributes is significant + {"o=A+o=B", "O=B+o=A+O=B", false}, + + // Missing values are significant + {"o=A+o=B", "O=B+o=A+O=C", false}, // missing values matter + {"o=A+o=B+o=C", "O=B+o=A", false}, // missing values matter + + // Whitespace tests + // Matching + { + "cn=John Doe, ou=People, dc=sun.com", + "cn=John Doe, ou=People, dc=sun.com", + true, + }, + // Difference in leading/trailing chars is ignored + { + "cn=John Doe, ou=People, dc=sun.com", + "cn=John Doe,ou=People,dc=sun.com", + true, + }, + // Difference in values is significant + { + "cn=John Doe, ou=People, dc=sun.com", + "cn=John Doe, ou=People, dc=sun.com", + false, + }, + } + + for i, tc := range testcases { + a, err := ParseDN(tc.A) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + b, err := ParseDN(tc.B) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + if expected, actual := tc.Equal, a.Equal(b); expected != actual { + t.Errorf("%d: when comparing '%s' and '%s' expected %v, got %v", i, tc.A, tc.B, expected, actual) + continue + } + if expected, actual := tc.Equal, b.Equal(a); expected != actual { + t.Errorf("%d: when comparing '%s' and '%s' expected %v, got %v", i, tc.A, tc.B, expected, actual) + continue + } + } +} + +func TestDNAncestor(t *testing.T) { + testcases := []struct { + A string + B string + Ancestor bool + }{ + // Exact match returns false + {"", "", false}, + {"o=A", "o=A", false}, + {"o=A,o=B", "o=A,o=B", false}, + {"o=A+o=B", "o=A+o=B", false}, + + // Mismatch + {"ou=C,ou=B,o=A", "ou=E,ou=D,ou=B,o=A", false}, + + // Descendant + {"ou=C,ou=B,o=A", "ou=E,ou=C,ou=B,o=A", true}, + } + + for i, tc := range testcases { + a, err := ParseDN(tc.A) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + b, err := ParseDN(tc.B) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + if expected, actual := tc.Ancestor, a.AncestorOf(b); expected != actual { + t.Errorf("%d: when comparing '%s' and '%s' expected %v, got %v", i, tc.A, tc.B, expected, actual) + continue + } + } +} diff --git a/v3/doc.go b/v3/doc.go new file mode 100644 index 00000000..f20d39bc --- /dev/null +++ b/v3/doc.go @@ -0,0 +1,4 @@ +/* +Package ldap provides basic LDAP v3 functionality. +*/ +package ldap diff --git a/v3/error.go b/v3/error.go new file mode 100644 index 00000000..b1fda2d8 --- /dev/null +++ b/v3/error.go @@ -0,0 +1,236 @@ +package ldap + +import ( + "fmt" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// LDAP Result Codes +const ( + LDAPResultSuccess = 0 + LDAPResultOperationsError = 1 + LDAPResultProtocolError = 2 + LDAPResultTimeLimitExceeded = 3 + LDAPResultSizeLimitExceeded = 4 + LDAPResultCompareFalse = 5 + LDAPResultCompareTrue = 6 + LDAPResultAuthMethodNotSupported = 7 + LDAPResultStrongAuthRequired = 8 + LDAPResultReferral = 10 + LDAPResultAdminLimitExceeded = 11 + LDAPResultUnavailableCriticalExtension = 12 + LDAPResultConfidentialityRequired = 13 + LDAPResultSaslBindInProgress = 14 + LDAPResultNoSuchAttribute = 16 + LDAPResultUndefinedAttributeType = 17 + LDAPResultInappropriateMatching = 18 + LDAPResultConstraintViolation = 19 + LDAPResultAttributeOrValueExists = 20 + LDAPResultInvalidAttributeSyntax = 21 + LDAPResultNoSuchObject = 32 + LDAPResultAliasProblem = 33 + LDAPResultInvalidDNSyntax = 34 + LDAPResultIsLeaf = 35 + LDAPResultAliasDereferencingProblem = 36 + LDAPResultInappropriateAuthentication = 48 + LDAPResultInvalidCredentials = 49 + LDAPResultInsufficientAccessRights = 50 + LDAPResultBusy = 51 + LDAPResultUnavailable = 52 + LDAPResultUnwillingToPerform = 53 + LDAPResultLoopDetect = 54 + LDAPResultSortControlMissing = 60 + LDAPResultOffsetRangeError = 61 + LDAPResultNamingViolation = 64 + LDAPResultObjectClassViolation = 65 + LDAPResultNotAllowedOnNonLeaf = 66 + LDAPResultNotAllowedOnRDN = 67 + LDAPResultEntryAlreadyExists = 68 + LDAPResultObjectClassModsProhibited = 69 + LDAPResultResultsTooLarge = 70 + LDAPResultAffectsMultipleDSAs = 71 + LDAPResultVirtualListViewErrorOrControlError = 76 + LDAPResultOther = 80 + LDAPResultServerDown = 81 + LDAPResultLocalError = 82 + LDAPResultEncodingError = 83 + LDAPResultDecodingError = 84 + LDAPResultTimeout = 85 + LDAPResultAuthUnknown = 86 + LDAPResultFilterError = 87 + LDAPResultUserCanceled = 88 + LDAPResultParamError = 89 + LDAPResultNoMemory = 90 + LDAPResultConnectError = 91 + LDAPResultNotSupported = 92 + LDAPResultControlNotFound = 93 + LDAPResultNoResultsReturned = 94 + LDAPResultMoreResultsToReturn = 95 + LDAPResultClientLoop = 96 + LDAPResultReferralLimitExceeded = 97 + LDAPResultInvalidResponse = 100 + LDAPResultAmbiguousResponse = 101 + LDAPResultTLSNotSupported = 112 + LDAPResultIntermediateResponse = 113 + LDAPResultUnknownType = 114 + LDAPResultCanceled = 118 + LDAPResultNoSuchOperation = 119 + LDAPResultTooLate = 120 + LDAPResultCannotCancel = 121 + LDAPResultAssertionFailed = 122 + LDAPResultAuthorizationDenied = 123 + LDAPResultSyncRefreshRequired = 4096 + + ErrorNetwork = 200 + ErrorFilterCompile = 201 + ErrorFilterDecompile = 202 + ErrorDebugging = 203 + ErrorUnexpectedMessage = 204 + ErrorUnexpectedResponse = 205 + ErrorEmptyPassword = 206 +) + +// LDAPResultCodeMap contains string descriptions for LDAP error codes +var LDAPResultCodeMap = map[uint16]string{ + LDAPResultSuccess: "Success", + LDAPResultOperationsError: "Operations Error", + LDAPResultProtocolError: "Protocol Error", + LDAPResultTimeLimitExceeded: "Time Limit Exceeded", + LDAPResultSizeLimitExceeded: "Size Limit Exceeded", + LDAPResultCompareFalse: "Compare False", + LDAPResultCompareTrue: "Compare True", + LDAPResultAuthMethodNotSupported: "Auth Method Not Supported", + LDAPResultStrongAuthRequired: "Strong Auth Required", + LDAPResultReferral: "Referral", + LDAPResultAdminLimitExceeded: "Admin Limit Exceeded", + LDAPResultUnavailableCriticalExtension: "Unavailable Critical Extension", + LDAPResultConfidentialityRequired: "Confidentiality Required", + LDAPResultSaslBindInProgress: "Sasl Bind In Progress", + LDAPResultNoSuchAttribute: "No Such Attribute", + LDAPResultUndefinedAttributeType: "Undefined Attribute Type", + LDAPResultInappropriateMatching: "Inappropriate Matching", + LDAPResultConstraintViolation: "Constraint Violation", + LDAPResultAttributeOrValueExists: "Attribute Or Value Exists", + LDAPResultInvalidAttributeSyntax: "Invalid Attribute Syntax", + LDAPResultNoSuchObject: "No Such Object", + LDAPResultAliasProblem: "Alias Problem", + LDAPResultInvalidDNSyntax: "Invalid DN Syntax", + LDAPResultIsLeaf: "Is Leaf", + LDAPResultAliasDereferencingProblem: "Alias Dereferencing Problem", + LDAPResultInappropriateAuthentication: "Inappropriate Authentication", + LDAPResultInvalidCredentials: "Invalid Credentials", + LDAPResultInsufficientAccessRights: "Insufficient Access Rights", + LDAPResultBusy: "Busy", + LDAPResultUnavailable: "Unavailable", + LDAPResultUnwillingToPerform: "Unwilling To Perform", + LDAPResultLoopDetect: "Loop Detect", + LDAPResultSortControlMissing: "Sort Control Missing", + LDAPResultOffsetRangeError: "Result Offset Range Error", + LDAPResultNamingViolation: "Naming Violation", + LDAPResultObjectClassViolation: "Object Class Violation", + LDAPResultResultsTooLarge: "Results Too Large", + LDAPResultNotAllowedOnNonLeaf: "Not Allowed On Non Leaf", + LDAPResultNotAllowedOnRDN: "Not Allowed On RDN", + LDAPResultEntryAlreadyExists: "Entry Already Exists", + LDAPResultObjectClassModsProhibited: "Object Class Mods Prohibited", + LDAPResultAffectsMultipleDSAs: "Affects Multiple DSAs", + LDAPResultVirtualListViewErrorOrControlError: "Failed because of a problem related to the virtual list view", + LDAPResultOther: "Other", + LDAPResultServerDown: "Cannot establish a connection", + LDAPResultLocalError: "An error occurred", + LDAPResultEncodingError: "LDAP encountered an error while encoding", + LDAPResultDecodingError: "LDAP encountered an error while decoding", + LDAPResultTimeout: "LDAP timeout while waiting for a response from the server", + LDAPResultAuthUnknown: "The auth method requested in a bind request is unknown", + LDAPResultFilterError: "An error occurred while encoding the given search filter", + LDAPResultUserCanceled: "The user canceled the operation", + LDAPResultParamError: "An invalid parameter was specified", + LDAPResultNoMemory: "Out of memory error", + LDAPResultConnectError: "A connection to the server could not be established", + LDAPResultNotSupported: "An attempt has been made to use a feature not supported LDAP", + LDAPResultControlNotFound: "The controls required to perform the requested operation were not found", + LDAPResultNoResultsReturned: "No results were returned from the server", + LDAPResultMoreResultsToReturn: "There are more results in the chain of results", + LDAPResultClientLoop: "A loop has been detected. For example when following referrals", + LDAPResultReferralLimitExceeded: "The referral hop limit has been exceeded", + LDAPResultCanceled: "Operation was canceled", + LDAPResultNoSuchOperation: "Server has no knowledge of the operation requested for cancellation", + LDAPResultTooLate: "Too late to cancel the outstanding operation", + LDAPResultCannotCancel: "The identified operation does not support cancellation or the cancel operation cannot be performed", + LDAPResultAssertionFailed: "An assertion control given in the LDAP operation evaluated to false causing the operation to not be performed", + LDAPResultSyncRefreshRequired: "Refresh Required", + LDAPResultInvalidResponse: "Invalid Response", + LDAPResultAmbiguousResponse: "Ambiguous Response", + LDAPResultTLSNotSupported: "Tls Not Supported", + LDAPResultIntermediateResponse: "Intermediate Response", + LDAPResultUnknownType: "Unknown Type", + LDAPResultAuthorizationDenied: "Authorization Denied", + + ErrorNetwork: "Network Error", + ErrorFilterCompile: "Filter Compile Error", + ErrorFilterDecompile: "Filter Decompile Error", + ErrorDebugging: "Debugging Error", + ErrorUnexpectedMessage: "Unexpected Message", + ErrorUnexpectedResponse: "Unexpected Response", + ErrorEmptyPassword: "Empty password not allowed by the client", +} + +// Error holds LDAP error information +type Error struct { + // Err is the underlying error + Err error + // ResultCode is the LDAP error code + ResultCode uint16 + // MatchedDN is the matchedDN returned if any + MatchedDN string +} + +func (e *Error) Error() string { + return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error()) +} + +// GetLDAPError creates an Error out of a BER packet representing a LDAPResult +// The return is an error object. It can be casted to a Error structure. +// This function returns nil if resultCode in the LDAPResult sequence is success(0). +func GetLDAPError(packet *ber.Packet) error { + if packet == nil { + return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty packet")} + } + + if len(packet.Children) >= 2 { + response := packet.Children[1] + if response == nil { + return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet")} + } + if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 { + resultCode := uint16(response.Children[0].Value.(int64)) + if resultCode == 0 { // No error + return nil + } + return &Error{ResultCode: resultCode, MatchedDN: response.Children[1].Value.(string), + Err: fmt.Errorf("%s", response.Children[2].Value.(string))} + } + } + + return &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("Invalid packet format")} +} + +// NewError creates an LDAP error with the given code and underlying error +func NewError(resultCode uint16, err error) error { + return &Error{ResultCode: resultCode, Err: err} +} + +// IsErrorWithCode returns true if the given error is an LDAP error with the given result code +func IsErrorWithCode(err error, desiredResultCode uint16) bool { + if err == nil { + return false + } + + serverError, ok := err.(*Error) + if !ok { + return false + } + + return serverError.ResultCode == desiredResultCode +} diff --git a/v3/error_test.go b/v3/error_test.go new file mode 100644 index 00000000..c547e454 --- /dev/null +++ b/v3/error_test.go @@ -0,0 +1,141 @@ +package ldap + +import ( + "errors" + "net" + "strings" + "testing" + "time" + + "github.com/go-asn1-ber/asn1-ber" +) + +// TestNilPacket tests that nil packets don't cause a panic. +func TestNilPacket(t *testing.T) { + // Test for nil packet + err := GetLDAPError(nil) + if !IsErrorWithCode(err, ErrorUnexpectedResponse) { + t.Errorf("Should have an 'ErrorUnexpectedResponse' error in nil packets, got: %v", err) + } + + // Test for nil result + kids := []*ber.Packet{ + {}, // Unused + nil, // Can't be nil + } + pack := &ber.Packet{Children: kids} + err = GetLDAPError(pack) + + if !IsErrorWithCode(err, ErrorUnexpectedResponse) { + t.Errorf("Should have an 'ErrorUnexpectedResponse' error in nil packets, got: %v", err) + } +} + +// TestConnReadErr tests that an unexpected error reading from underlying +// connection bubbles up to the goroutine which makes a request. +func TestConnReadErr(t *testing.T) { + conn := &signalErrConn{ + signals: make(chan error), + } + + ldapConn := NewConn(conn, false) + ldapConn.Start() + + // Make a dummy search request. + searchReq := NewSearchRequest("dc=example,dc=com", ScopeWholeSubtree, DerefAlways, 0, 0, false, "(objectClass=*)", nil, nil) + + expectedError := errors.New("this is the error you are looking for") + + // Send the signal after a short amount of time. + time.AfterFunc(10*time.Millisecond, func() { conn.signals <- expectedError }) + + // This should block until the underlying conn gets the error signal + // which should bubble up through the reader() goroutine, close the + // connection, and + _, err := ldapConn.Search(searchReq) + if err == nil || !strings.Contains(err.Error(), expectedError.Error()) { + t.Errorf("not the expected error: %s", err) + } +} + +// TestGetLDAPError tests parsing of result with a error response. +func TestGetLDAPError(t *testing.T) { + diagnosticMessage := "Detailed error message" + bindResponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response") + bindResponse.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(LDAPResultInvalidCredentials), "resultCode")) + bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "dc=example,dc=org", "matchedDN")) + bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, diagnosticMessage, "diagnosticMessage")) + packet := ber.NewSequence("LDAPMessage") + packet.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(0), "messageID")) + packet.AppendChild(bindResponse) + err := GetLDAPError(packet) + if err == nil { + t.Errorf("Did not get error response") + } + + ldapError := err.(*Error) + if ldapError.ResultCode != LDAPResultInvalidCredentials { + t.Errorf("Got incorrect error code in LDAP error; got %v, expected %v", ldapError.ResultCode, LDAPResultInvalidCredentials) + } + if ldapError.Err.Error() != diagnosticMessage { + t.Errorf("Got incorrect error message in LDAP error; got %v, expected %v", ldapError.Err.Error(), diagnosticMessage) + } +} + +// TestGetLDAPErrorSuccess tests parsing of a result with no error (resultCode == 0). +func TestGetLDAPErrorSuccess(t *testing.T) { + bindResponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response") + bindResponse.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(0), "resultCode")) + bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN")) + bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "diagnosticMessage")) + packet := ber.NewSequence("LDAPMessage") + packet.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(0), "messageID")) + packet.AppendChild(bindResponse) + err := GetLDAPError(packet) + if err != nil { + t.Errorf("Successful responses should not produce an error, but got: %v", err) + } +} + +// signalErrConn is a helpful type used with TestConnReadErr. It implements the +// net.Conn interface to be used as a connection for the test. Most methods are +// no-ops but the Read() method blocks until it receives a signal which it +// returns as an error. +type signalErrConn struct { + signals chan error +} + +// Read blocks until an error is sent on the internal signals channel. That +// error is returned. +func (c *signalErrConn) Read(b []byte) (n int, err error) { + return 0, <-c.signals +} + +func (c *signalErrConn) Write(b []byte) (n int, err error) { + return len(b), nil +} + +func (c *signalErrConn) Close() error { + close(c.signals) + return nil +} + +func (c *signalErrConn) LocalAddr() net.Addr { + return (*net.TCPAddr)(nil) +} + +func (c *signalErrConn) RemoteAddr() net.Addr { + return (*net.TCPAddr)(nil) +} + +func (c *signalErrConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *signalErrConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *signalErrConn) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/v3/example_test.go b/v3/example_test.go new file mode 100644 index 00000000..4f87705b --- /dev/null +++ b/v3/example_test.go @@ -0,0 +1,342 @@ +package ldap + +import ( + "crypto/tls" + "fmt" + "log" +) + +// ExampleConn_Bind demonstrates how to bind a connection to an ldap user +// allowing access to restricted attributes that user has access to +func ExampleConn_Bind() { + l, err := Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + err = l.Bind("cn=read-only-admin,dc=example,dc=com", "password") + if err != nil { + log.Fatal(err) + } +} + +// ExampleConn_Search demonstrates how to use the search interface +func ExampleConn_Search() { + l, err := Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + "dc=example,dc=com", // The base dn to search + ScopeWholeSubtree, NeverDerefAliases, 0, 0, false, + "(&(objectClass=organizationalPerson))", // The filter to apply + []string{"dn", "cn"}, // A list attributes to retrieve + nil, + ) + + sr, err := l.Search(searchRequest) + if err != nil { + log.Fatal(err) + } + + for _, entry := range sr.Entries { + fmt.Printf("%s: %v\n", entry.DN, entry.GetAttributeValue("cn")) + } +} + +// ExampleStartTLS demonstrates how to start a TLS connection +func ExampleConn_StartTLS() { + l, err := Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + // Reconnect with TLS + err = l.StartTLS(&tls.Config{InsecureSkipVerify: true}) + if err != nil { + log.Fatal(err) + } + + // Operations via l are now encrypted +} + +// ExampleConn_Compare demonstrates how to compare an attribute with a value +func ExampleConn_Compare() { + l, err := Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + matched, err := l.Compare("cn=user,dc=example,dc=com", "uid", "someuserid") + if err != nil { + log.Fatal(err) + } + + fmt.Println(matched) +} + +func ExampleConn_PasswordModify_admin() { + l, err := Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + err = l.Bind("cn=admin,dc=example,dc=com", "password") + if err != nil { + log.Fatal(err) + } + + passwordModifyRequest := NewPasswordModifyRequest("cn=user,dc=example,dc=com", "", "NewPassword") + _, err = l.PasswordModify(passwordModifyRequest) + + if err != nil { + log.Fatalf("Password could not be changed: %s", err.Error()) + } +} + +func ExampleConn_PasswordModify_generatedPassword() { + l, err := Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + err = l.Bind("cn=user,dc=example,dc=com", "password") + if err != nil { + log.Fatal(err) + } + + passwordModifyRequest := NewPasswordModifyRequest("", "OldPassword", "") + passwordModifyResponse, err := l.PasswordModify(passwordModifyRequest) + if err != nil { + log.Fatalf("Password could not be changed: %s", err.Error()) + } + + generatedPassword := passwordModifyResponse.GeneratedPassword + log.Printf("Generated password: %s\n", generatedPassword) +} + +func ExampleConn_PasswordModify_setNewPassword() { + l, err := Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + err = l.Bind("cn=user,dc=example,dc=com", "password") + if err != nil { + log.Fatal(err) + } + + passwordModifyRequest := NewPasswordModifyRequest("", "OldPassword", "NewPassword") + _, err = l.PasswordModify(passwordModifyRequest) + + if err != nil { + log.Fatalf("Password could not be changed: %s", err.Error()) + } +} + +func ExampleConn_Modify() { + l, err := Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + // Add a description, and replace the mail attributes + modify := NewModifyRequest("cn=user,dc=example,dc=com", nil) + modify.Add("description", []string{"An example user"}) + modify.Replace("mail", []string{"user@example.org"}) + + err = l.Modify(modify) + if err != nil { + log.Fatal(err) + } +} + +// Example User Authentication shows how a typical application can verify a login attempt +func Example_userAuthentication() { + // The username and password we want to check + username := "someuser" + password := "userpassword" + + bindusername := "readonly" + bindpassword := "password" + + l, err := Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + // Reconnect with TLS + err = l.StartTLS(&tls.Config{InsecureSkipVerify: true}) + if err != nil { + log.Fatal(err) + } + + // First bind with a read only user + err = l.Bind(bindusername, bindpassword) + if err != nil { + log.Fatal(err) + } + + // Search for the given username + searchRequest := NewSearchRequest( + "dc=example,dc=com", + ScopeWholeSubtree, NeverDerefAliases, 0, 0, false, + fmt.Sprintf("(&(objectClass=organizationalPerson)(uid=%s))", username), + []string{"dn"}, + nil, + ) + + sr, err := l.Search(searchRequest) + if err != nil { + log.Fatal(err) + } + + if len(sr.Entries) != 1 { + log.Fatal("User does not exist or too many entries returned") + } + + userdn := sr.Entries[0].DN + + // Bind as the user to verify their password + err = l.Bind(userdn, password) + if err != nil { + log.Fatal(err) + } + + // Rebind as the read only user for any further queries + err = l.Bind(bindusername, bindpassword) + if err != nil { + log.Fatal(err) + } +} + +func Example_beherappolicy() { + l, err := Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + controls := []Control{} + controls = append(controls, NewControlBeheraPasswordPolicy()) + bindRequest := NewSimpleBindRequest("cn=admin,dc=example,dc=com", "password", controls) + + r, err := l.SimpleBind(bindRequest) + ppolicyControl := FindControl(r.Controls, ControlTypeBeheraPasswordPolicy) + + var ppolicy *ControlBeheraPasswordPolicy + if ppolicyControl != nil { + ppolicy = ppolicyControl.(*ControlBeheraPasswordPolicy) + } else { + log.Printf("ppolicyControl response not available.\n") + } + if err != nil { + errStr := "ERROR: Cannot bind: " + err.Error() + if ppolicy != nil && ppolicy.Error >= 0 { + errStr += ":" + ppolicy.ErrorString + } + log.Print(errStr) + } else { + logStr := "Login Ok" + if ppolicy != nil { + if ppolicy.Expire >= 0 { + logStr += fmt.Sprintf(". Password expires in %d seconds\n", ppolicy.Expire) + } else if ppolicy.Grace >= 0 { + logStr += fmt.Sprintf(". Password expired, %d grace logins remain\n", ppolicy.Grace) + } + } + log.Print(logStr) + } +} + +func Example_vchuppolicy() { + l, err := Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + l.Debug = true + + bindRequest := NewSimpleBindRequest("cn=admin,dc=example,dc=com", "password", nil) + + r, err := l.SimpleBind(bindRequest) + + passwordMustChangeControl := FindControl(r.Controls, ControlTypeVChuPasswordMustChange) + var passwordMustChange *ControlVChuPasswordMustChange + if passwordMustChangeControl != nil { + passwordMustChange = passwordMustChangeControl.(*ControlVChuPasswordMustChange) + } + + if passwordMustChange != nil && passwordMustChange.MustChange { + log.Printf("Password Must be changed.\n") + } + + passwordWarningControl := FindControl(r.Controls, ControlTypeVChuPasswordWarning) + + var passwordWarning *ControlVChuPasswordWarning + if passwordWarningControl != nil { + passwordWarning = passwordWarningControl.(*ControlVChuPasswordWarning) + } else { + log.Printf("ppolicyControl response not available.\n") + } + if err != nil { + log.Print("ERROR: Cannot bind: " + err.Error()) + } else { + logStr := "Login Ok" + if passwordWarning != nil { + if passwordWarning.Expire >= 0 { + logStr += fmt.Sprintf(". Password expires in %d seconds\n", passwordWarning.Expire) + } + } + log.Print(logStr) + } +} + +// This example demonstrates how to use ControlPaging to manually execute a +// paginated search request instead of using SearchWithPaging. +func ExampleControlPaging_manualPaging() { + conn, err := Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer conn.Close() + + var pageSize uint32 = 32 + searchBase := "dc=example,dc=com" + filter := "(objectClass=group)" + pagingControl := NewControlPaging(pageSize) + attributes := []string{} + controls := []Control{pagingControl} + + for { + request := NewSearchRequest(searchBase, ScopeWholeSubtree, DerefAlways, 0, 0, false, filter, attributes, controls) + response, err := conn.Search(request) + if err != nil { + log.Fatalf("Failed to execute search request: %s", err.Error()) + } + + // [do something with the response entries] + + // In order to prepare the next request, we check if the response + // contains another ControlPaging object and a not-empty cookie and + // copy that cookie into our pagingControl object: + updatedControl := FindControl(response.Controls, ControlTypePaging) + if ctrl, ok := updatedControl.(*ControlPaging); ctrl != nil && ok && len(ctrl.Cookie) != 0 { + pagingControl.SetCookie(ctrl.Cookie) + continue + } + // If no new paging information is available or the cookie is empty, we + // are done with the pagination. + break + } +} diff --git a/v3/filter.go b/v3/filter.go new file mode 100644 index 00000000..a3875506 --- /dev/null +++ b/v3/filter.go @@ -0,0 +1,465 @@ +package ldap + +import ( + "bytes" + hexpac "encoding/hex" + "errors" + "fmt" + "strings" + "unicode/utf8" + + "github.com/go-asn1-ber/asn1-ber" +) + +// Filter choices +const ( + FilterAnd = 0 + FilterOr = 1 + FilterNot = 2 + FilterEqualityMatch = 3 + FilterSubstrings = 4 + FilterGreaterOrEqual = 5 + FilterLessOrEqual = 6 + FilterPresent = 7 + FilterApproxMatch = 8 + FilterExtensibleMatch = 9 +) + +// FilterMap contains human readable descriptions of Filter choices +var FilterMap = map[uint64]string{ + FilterAnd: "And", + FilterOr: "Or", + FilterNot: "Not", + FilterEqualityMatch: "Equality Match", + FilterSubstrings: "Substrings", + FilterGreaterOrEqual: "Greater Or Equal", + FilterLessOrEqual: "Less Or Equal", + FilterPresent: "Present", + FilterApproxMatch: "Approx Match", + FilterExtensibleMatch: "Extensible Match", +} + +// SubstringFilter options +const ( + FilterSubstringsInitial = 0 + FilterSubstringsAny = 1 + FilterSubstringsFinal = 2 +) + +// FilterSubstringsMap contains human readable descriptions of SubstringFilter choices +var FilterSubstringsMap = map[uint64]string{ + FilterSubstringsInitial: "Substrings Initial", + FilterSubstringsAny: "Substrings Any", + FilterSubstringsFinal: "Substrings Final", +} + +// MatchingRuleAssertion choices +const ( + MatchingRuleAssertionMatchingRule = 1 + MatchingRuleAssertionType = 2 + MatchingRuleAssertionMatchValue = 3 + MatchingRuleAssertionDNAttributes = 4 +) + +// MatchingRuleAssertionMap contains human readable descriptions of MatchingRuleAssertion choices +var MatchingRuleAssertionMap = map[uint64]string{ + MatchingRuleAssertionMatchingRule: "Matching Rule Assertion Matching Rule", + MatchingRuleAssertionType: "Matching Rule Assertion Type", + MatchingRuleAssertionMatchValue: "Matching Rule Assertion Match Value", + MatchingRuleAssertionDNAttributes: "Matching Rule Assertion DN Attributes", +} + +// CompileFilter converts a string representation of a filter into a BER-encoded packet +func CompileFilter(filter string) (*ber.Packet, error) { + if len(filter) == 0 || filter[0] != '(' { + return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('")) + } + packet, pos, err := compileFilter(filter, 1) + if err != nil { + return nil, err + } + switch { + case pos > len(filter): + return nil, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter")) + case pos < len(filter): + return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:]))) + } + return packet, nil +} + +// DecompileFilter converts a packet representation of a filter into a string representation +func DecompileFilter(packet *ber.Packet) (ret string, err error) { + defer func() { + if r := recover(); r != nil { + err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter")) + } + }() + ret = "(" + err = nil + childStr := "" + + switch packet.Tag { + case FilterAnd: + ret += "&" + for _, child := range packet.Children { + childStr, err = DecompileFilter(child) + if err != nil { + return + } + ret += childStr + } + case FilterOr: + ret += "|" + for _, child := range packet.Children { + childStr, err = DecompileFilter(child) + if err != nil { + return + } + ret += childStr + } + case FilterNot: + ret += "!" + childStr, err = DecompileFilter(packet.Children[0]) + if err != nil { + return + } + ret += childStr + + case FilterSubstrings: + ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + ret += "=" + for i, child := range packet.Children[1].Children { + if i == 0 && child.Tag != FilterSubstringsInitial { + ret += "*" + } + ret += EscapeFilter(ber.DecodeString(child.Data.Bytes())) + if child.Tag != FilterSubstringsFinal { + ret += "*" + } + } + case FilterEqualityMatch: + ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + ret += "=" + ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())) + case FilterGreaterOrEqual: + ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + ret += ">=" + ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())) + case FilterLessOrEqual: + ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + ret += "<=" + ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())) + case FilterPresent: + ret += ber.DecodeString(packet.Data.Bytes()) + ret += "=*" + case FilterApproxMatch: + ret += ber.DecodeString(packet.Children[0].Data.Bytes()) + ret += "~=" + ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())) + case FilterExtensibleMatch: + attr := "" + dnAttributes := false + matchingRule := "" + value := "" + + for _, child := range packet.Children { + switch child.Tag { + case MatchingRuleAssertionMatchingRule: + matchingRule = ber.DecodeString(child.Data.Bytes()) + case MatchingRuleAssertionType: + attr = ber.DecodeString(child.Data.Bytes()) + case MatchingRuleAssertionMatchValue: + value = ber.DecodeString(child.Data.Bytes()) + case MatchingRuleAssertionDNAttributes: + dnAttributes = child.Value.(bool) + } + } + + if len(attr) > 0 { + ret += attr + } + if dnAttributes { + ret += ":dn" + } + if len(matchingRule) > 0 { + ret += ":" + ret += matchingRule + } + ret += ":=" + ret += EscapeFilter(value) + } + + ret += ")" + return +} + +func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) { + for pos < len(filter) && filter[pos] == '(' { + child, newPos, err := compileFilter(filter, pos+1) + if err != nil { + return pos, err + } + pos = newPos + parent.AppendChild(child) + } + if pos == len(filter) { + return pos, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter")) + } + + return pos + 1, nil +} + +func compileFilter(filter string, pos int) (*ber.Packet, int, error) { + var ( + packet *ber.Packet + err error + ) + + defer func() { + if r := recover(); r != nil { + err = NewError(ErrorFilterCompile, errors.New("ldap: error compiling filter")) + } + }() + newPos := pos + + currentRune, currentWidth := utf8.DecodeRuneInString(filter[newPos:]) + + switch currentRune { + case utf8.RuneError: + return nil, 0, NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", newPos)) + case '(': + packet, newPos, err = compileFilter(filter, pos+currentWidth) + newPos++ + return packet, newPos, err + case '&': + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd]) + newPos, err = compileFilterSet(filter, pos+currentWidth, packet) + return packet, newPos, err + case '|': + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr]) + newPos, err = compileFilterSet(filter, pos+currentWidth, packet) + return packet, newPos, err + case '!': + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot]) + var child *ber.Packet + child, newPos, err = compileFilter(filter, pos+currentWidth) + packet.AppendChild(child) + return packet, newPos, err + default: + const ( + stateReadingAttr = 0 + stateReadingExtensibleMatchingRule = 1 + stateReadingCondition = 2 + ) + + state := stateReadingAttr + + attribute := "" + extensibleDNAttributes := false + extensibleMatchingRule := "" + condition := "" + + for newPos < len(filter) { + remainingFilter := filter[newPos:] + currentRune, currentWidth = utf8.DecodeRuneInString(remainingFilter) + if currentRune == ')' { + break + } + if currentRune == utf8.RuneError { + return packet, newPos, NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", newPos)) + } + + switch state { + case stateReadingAttr: + switch { + // Extensible rule, with only DN-matching + case currentRune == ':' && strings.HasPrefix(remainingFilter, ":dn:="): + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch]) + extensibleDNAttributes = true + state = stateReadingCondition + newPos += 5 + + // Extensible rule, with DN-matching and a matching OID + case currentRune == ':' && strings.HasPrefix(remainingFilter, ":dn:"): + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch]) + extensibleDNAttributes = true + state = stateReadingExtensibleMatchingRule + newPos += 4 + + // Extensible rule, with attr only + case currentRune == ':' && strings.HasPrefix(remainingFilter, ":="): + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch]) + state = stateReadingCondition + newPos += 2 + + // Extensible rule, with no DN attribute matching + case currentRune == ':': + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch]) + state = stateReadingExtensibleMatchingRule + newPos++ + + // Equality condition + case currentRune == '=': + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch]) + state = stateReadingCondition + newPos++ + + // Greater-than or equal + case currentRune == '>' && strings.HasPrefix(remainingFilter, ">="): + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual]) + state = stateReadingCondition + newPos += 2 + + // Less-than or equal + case currentRune == '<' && strings.HasPrefix(remainingFilter, "<="): + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual]) + state = stateReadingCondition + newPos += 2 + + // Approx + case currentRune == '~' && strings.HasPrefix(remainingFilter, "~="): + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterApproxMatch]) + state = stateReadingCondition + newPos += 2 + + // Still reading the attribute name + default: + attribute += fmt.Sprintf("%c", currentRune) + newPos += currentWidth + } + + case stateReadingExtensibleMatchingRule: + switch { + + // Matching rule OID is done + case currentRune == ':' && strings.HasPrefix(remainingFilter, ":="): + state = stateReadingCondition + newPos += 2 + + // Still reading the matching rule oid + default: + extensibleMatchingRule += fmt.Sprintf("%c", currentRune) + newPos += currentWidth + } + + case stateReadingCondition: + // append to the condition + condition += fmt.Sprintf("%c", currentRune) + newPos += currentWidth + } + } + + if newPos == len(filter) { + err = NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter")) + return packet, newPos, err + } + if packet == nil { + err = NewError(ErrorFilterCompile, errors.New("ldap: error parsing filter")) + return packet, newPos, err + } + + switch { + case packet.Tag == FilterExtensibleMatch: + // MatchingRuleAssertion ::= SEQUENCE { + // matchingRule [1] MatchingRuleID OPTIONAL, + // type [2] AttributeDescription OPTIONAL, + // matchValue [3] AssertionValue, + // dnAttributes [4] BOOLEAN DEFAULT FALSE + // } + + // Include the matching rule oid, if specified + if len(extensibleMatchingRule) > 0 { + packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule, MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule])) + } + + // Include the attribute, if specified + if len(attribute) > 0 { + packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute, MatchingRuleAssertionMap[MatchingRuleAssertionType])) + } + + // Add the value (only required child) + encodedString, encodeErr := escapedStringToEncodedBytes(condition) + if encodeErr != nil { + return packet, newPos, encodeErr + } + packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchValue, encodedString, MatchingRuleAssertionMap[MatchingRuleAssertionMatchValue])) + + // Defaults to false, so only include in the sequence if true + if extensibleDNAttributes { + packet.AppendChild(ber.NewBoolean(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionDNAttributes, extensibleDNAttributes, MatchingRuleAssertionMap[MatchingRuleAssertionDNAttributes])) + } + + case packet.Tag == FilterEqualityMatch && condition == "*": + packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute, FilterMap[FilterPresent]) + case packet.Tag == FilterEqualityMatch && strings.Contains(condition, "*"): + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute")) + packet.Tag = FilterSubstrings + packet.Description = FilterMap[uint64(packet.Tag)] + seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") + parts := strings.Split(condition, "*") + for i, part := range parts { + if part == "" { + continue + } + var tag ber.Tag + switch i { + case 0: + tag = FilterSubstringsInitial + case len(parts) - 1: + tag = FilterSubstringsFinal + default: + tag = FilterSubstringsAny + } + encodedString, encodeErr := escapedStringToEncodedBytes(part) + if encodeErr != nil { + return packet, newPos, encodeErr + } + seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, tag, encodedString, FilterSubstringsMap[uint64(tag)])) + } + packet.AppendChild(seq) + default: + encodedString, encodeErr := escapedStringToEncodedBytes(condition) + if encodeErr != nil { + return packet, newPos, encodeErr + } + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute")) + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, encodedString, "Condition")) + } + + newPos += currentWidth + return packet, newPos, err + } +} + +// Convert from "ABC\xx\xx\xx" form to literal bytes for transport +func escapedStringToEncodedBytes(escapedString string) (string, error) { + var buffer bytes.Buffer + i := 0 + for i < len(escapedString) { + currentRune, currentWidth := utf8.DecodeRuneInString(escapedString[i:]) + if currentRune == utf8.RuneError { + return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", i)) + } + + // Check for escaped hex characters and convert them to their literal value for transport. + if currentRune == '\\' { + // http://tools.ietf.org/search/rfc4515 + // \ (%x5C) is not a valid character unless it is followed by two HEX characters due to not + // being a member of UTF1SUBSET. + if i+2 > len(escapedString) { + return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter")) + } + escByte, decodeErr := hexpac.DecodeString(escapedString[i+1 : i+3]) + if decodeErr != nil { + return "", NewError(ErrorFilterCompile, errors.New("ldap: invalid characters for escape in filter")) + } + buffer.WriteByte(escByte[0]) + i += 2 // +1 from end of loop, so 3 total for \xx. + } else { + buffer.WriteRune(currentRune) + } + + i += currentWidth + } + return buffer.String(), nil +} diff --git a/v3/filter_test.go b/v3/filter_test.go new file mode 100644 index 00000000..45796325 --- /dev/null +++ b/v3/filter_test.go @@ -0,0 +1,254 @@ +package ldap + +import ( + "strings" + "testing" + + "github.com/go-asn1-ber/asn1-ber" +) + +type compileTest struct { + filterStr string + + expectedFilter string + expectedType int + expectedErr string +} + +var testFilters = []compileTest{ + compileTest{ + filterStr: "(&(sn=Miller)(givenName=Bob))", + expectedFilter: "(&(sn=Miller)(givenName=Bob))", + expectedType: FilterAnd, + }, + compileTest{ + filterStr: "(|(sn=Miller)(givenName=Bob))", + expectedFilter: "(|(sn=Miller)(givenName=Bob))", + expectedType: FilterOr, + }, + compileTest{ + filterStr: "(!(sn=Miller))", + expectedFilter: "(!(sn=Miller))", + expectedType: FilterNot, + }, + compileTest{ + filterStr: "(sn=Miller)", + expectedFilter: "(sn=Miller)", + expectedType: FilterEqualityMatch, + }, + compileTest{ + filterStr: "(sn=Mill*)", + expectedFilter: "(sn=Mill*)", + expectedType: FilterSubstrings, + }, + compileTest{ + filterStr: "(sn=*Mill)", + expectedFilter: "(sn=*Mill)", + expectedType: FilterSubstrings, + }, + compileTest{ + filterStr: "(sn=*Mill*)", + expectedFilter: "(sn=*Mill*)", + expectedType: FilterSubstrings, + }, + compileTest{ + filterStr: "(sn=*i*le*)", + expectedFilter: "(sn=*i*le*)", + expectedType: FilterSubstrings, + }, + compileTest{ + filterStr: "(sn=Mi*l*r)", + expectedFilter: "(sn=Mi*l*r)", + expectedType: FilterSubstrings, + }, + // substring filters escape properly + compileTest{ + filterStr: `(sn=Mi*함*r)`, + expectedFilter: `(sn=Mi*\ed\95\a8*r)`, + expectedType: FilterSubstrings, + }, + // already escaped substring filters don't get double-escaped + compileTest{ + filterStr: `(sn=Mi*\ed\95\a8*r)`, + expectedFilter: `(sn=Mi*\ed\95\a8*r)`, + expectedType: FilterSubstrings, + }, + compileTest{ + filterStr: "(sn=Mi*le*)", + expectedFilter: "(sn=Mi*le*)", + expectedType: FilterSubstrings, + }, + compileTest{ + filterStr: "(sn=*i*ler)", + expectedFilter: "(sn=*i*ler)", + expectedType: FilterSubstrings, + }, + compileTest{ + filterStr: "(sn>=Miller)", + expectedFilter: "(sn>=Miller)", + expectedType: FilterGreaterOrEqual, + }, + compileTest{ + filterStr: "(sn<=Miller)", + expectedFilter: "(sn<=Miller)", + expectedType: FilterLessOrEqual, + }, + compileTest{ + filterStr: "(sn=*)", + expectedFilter: "(sn=*)", + expectedType: FilterPresent, + }, + compileTest{ + filterStr: "(sn~=Miller)", + expectedFilter: "(sn~=Miller)", + expectedType: FilterApproxMatch, + }, + compileTest{ + filterStr: `(objectGUID='\fc\fe\a3\ab\f9\90N\aaGm\d5I~\d12)`, + expectedFilter: `(objectGUID='\fc\fe\a3\ab\f9\90N\aaGm\d5I~\d12)`, + expectedType: FilterEqualityMatch, + }, + compileTest{ + filterStr: `(objectGUID=абвгдеёжзийклмнопрстуфхцчшщъыьэюя)`, + expectedFilter: `(objectGUID=\d0\b0\d0\b1\d0\b2\d0\b3\d0\b4\d0\b5\d1\91\d0\b6\d0\b7\d0\b8\d0\b9\d0\ba\d0\bb\d0\bc\d0\bd\d0\be\d0\bf\d1\80\d1\81\d1\82\d1\83\d1\84\d1\85\d1\86\d1\87\d1\88\d1\89\d1\8a\d1\8b\d1\8c\d1\8d\d1\8e\d1\8f)`, + expectedType: FilterEqualityMatch, + }, + compileTest{ + filterStr: `(objectGUID=함수목록)`, + expectedFilter: `(objectGUID=\ed\95\a8\ec\88\98\eb\aa\a9\eb\a1\9d)`, + expectedType: FilterEqualityMatch, + }, + compileTest{ + filterStr: `(objectGUID=`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + compileTest{ + filterStr: `(objectGUID=함수목록`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + compileTest{ + filterStr: `((cn=)`, + expectedFilter: ``, + expectedType: 0, + expectedErr: "unexpected end of filter", + }, + compileTest{ + filterStr: `(&(objectclass=inetorgperson)(cn=中文))`, + expectedFilter: `(&(objectclass=inetorgperson)(cn=\e4\b8\ad\e6\96\87))`, + expectedType: 0, + }, + // attr extension + compileTest{ + filterStr: `(memberOf:=foo)`, + expectedFilter: `(memberOf:=foo)`, + expectedType: FilterExtensibleMatch, + }, + // attr+named matching rule extension + compileTest{ + filterStr: `(memberOf:test:=foo)`, + expectedFilter: `(memberOf:test:=foo)`, + expectedType: FilterExtensibleMatch, + }, + // attr+oid matching rule extension + compileTest{ + filterStr: `(cn:1.2.3.4.5:=Fred Flintstone)`, + expectedFilter: `(cn:1.2.3.4.5:=Fred Flintstone)`, + expectedType: FilterExtensibleMatch, + }, + // attr+dn+oid matching rule extension + compileTest{ + filterStr: `(sn:dn:2.4.6.8.10:=Barney Rubble)`, + expectedFilter: `(sn:dn:2.4.6.8.10:=Barney Rubble)`, + expectedType: FilterExtensibleMatch, + }, + // attr+dn extension + compileTest{ + filterStr: `(o:dn:=Ace Industry)`, + expectedFilter: `(o:dn:=Ace Industry)`, + expectedType: FilterExtensibleMatch, + }, + // dn extension + compileTest{ + filterStr: `(:dn:2.4.6.8.10:=Dino)`, + expectedFilter: `(:dn:2.4.6.8.10:=Dino)`, + expectedType: FilterExtensibleMatch, + }, + compileTest{ + filterStr: `(memberOf:1.2.840.113556.1.4.1941:=CN=User1,OU=blah,DC=mydomain,DC=net)`, + expectedFilter: `(memberOf:1.2.840.113556.1.4.1941:=CN=User1,OU=blah,DC=mydomain,DC=net)`, + expectedType: FilterExtensibleMatch, + }, + + // compileTest{ filterStr: "()", filterType: FilterExtensibleMatch }, +} + +var testInvalidFilters = []string{ + `(objectGUID=\zz)`, + `(objectGUID=\a)`, +} + +func TestFilter(t *testing.T) { + // Test Compiler and Decompiler + for _, i := range testFilters { + filter, err := CompileFilter(i.filterStr) + switch { + case err != nil: + if i.expectedErr == "" || !strings.Contains(err.Error(), i.expectedErr) { + t.Errorf("Problem compiling '%s' - '%v' (expected error to contain '%v')", i.filterStr, err, i.expectedErr) + } + case filter.Tag != ber.Tag(i.expectedType): + t.Errorf("%q Expected %q got %q", i.filterStr, FilterMap[uint64(i.expectedType)], FilterMap[uint64(filter.Tag)]) + default: + o, err := DecompileFilter(filter) + if err != nil { + t.Errorf("Problem compiling %s - %s", i.filterStr, err.Error()) + } else if i.expectedFilter != o { + t.Errorf("%q expected, got %q", i.expectedFilter, o) + } + } + } +} + +func TestInvalidFilter(t *testing.T) { + for _, filterStr := range testInvalidFilters { + if _, err := CompileFilter(filterStr); err == nil { + t.Errorf("Problem compiling %s - expected err", filterStr) + } + } +} + +func BenchmarkFilterCompile(b *testing.B) { + b.StopTimer() + filters := make([]string, len(testFilters)) + + // Test Compiler and Decompiler + for idx, i := range testFilters { + filters[idx] = i.filterStr + } + + maxIdx := len(filters) + b.StartTimer() + for i := 0; i < b.N; i++ { + CompileFilter(filters[i%maxIdx]) + } +} + +func BenchmarkFilterDecompile(b *testing.B) { + b.StopTimer() + filters := make([]*ber.Packet, len(testFilters)) + + // Test Compiler and Decompiler + for idx, i := range testFilters { + filters[idx], _ = CompileFilter(i.filterStr) + } + + maxIdx := len(filters) + b.StartTimer() + for i := 0; i < b.N; i++ { + DecompileFilter(filters[i%maxIdx]) + } +} diff --git a/v3/go.mod b/v3/go.mod new file mode 100644 index 00000000..78ef9791 --- /dev/null +++ b/v3/go.mod @@ -0,0 +1,5 @@ +module github.com/go-ldap/ldap/v3 + +require github.com/go-asn1-ber/asn1-ber v1.3.1 + +go 1.13 diff --git a/v3/ldap.go b/v3/ldap.go new file mode 100644 index 00000000..b29c018a --- /dev/null +++ b/v3/ldap.go @@ -0,0 +1,340 @@ +package ldap + +import ( + "fmt" + "io/ioutil" + "os" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// LDAP Application Codes +const ( + ApplicationBindRequest = 0 + ApplicationBindResponse = 1 + ApplicationUnbindRequest = 2 + ApplicationSearchRequest = 3 + ApplicationSearchResultEntry = 4 + ApplicationSearchResultDone = 5 + ApplicationModifyRequest = 6 + ApplicationModifyResponse = 7 + ApplicationAddRequest = 8 + ApplicationAddResponse = 9 + ApplicationDelRequest = 10 + ApplicationDelResponse = 11 + ApplicationModifyDNRequest = 12 + ApplicationModifyDNResponse = 13 + ApplicationCompareRequest = 14 + ApplicationCompareResponse = 15 + ApplicationAbandonRequest = 16 + ApplicationSearchResultReference = 19 + ApplicationExtendedRequest = 23 + ApplicationExtendedResponse = 24 +) + +// ApplicationMap contains human readable descriptions of LDAP Application Codes +var ApplicationMap = map[uint8]string{ + ApplicationBindRequest: "Bind Request", + ApplicationBindResponse: "Bind Response", + ApplicationUnbindRequest: "Unbind Request", + ApplicationSearchRequest: "Search Request", + ApplicationSearchResultEntry: "Search Result Entry", + ApplicationSearchResultDone: "Search Result Done", + ApplicationModifyRequest: "Modify Request", + ApplicationModifyResponse: "Modify Response", + ApplicationAddRequest: "Add Request", + ApplicationAddResponse: "Add Response", + ApplicationDelRequest: "Del Request", + ApplicationDelResponse: "Del Response", + ApplicationModifyDNRequest: "Modify DN Request", + ApplicationModifyDNResponse: "Modify DN Response", + ApplicationCompareRequest: "Compare Request", + ApplicationCompareResponse: "Compare Response", + ApplicationAbandonRequest: "Abandon Request", + ApplicationSearchResultReference: "Search Result Reference", + ApplicationExtendedRequest: "Extended Request", + ApplicationExtendedResponse: "Extended Response", +} + +// Ldap Behera Password Policy Draft 10 (https://tools.ietf.org/html/draft-behera-ldap-password-policy-10) +const ( + BeheraPasswordExpired = 0 + BeheraAccountLocked = 1 + BeheraChangeAfterReset = 2 + BeheraPasswordModNotAllowed = 3 + BeheraMustSupplyOldPassword = 4 + BeheraInsufficientPasswordQuality = 5 + BeheraPasswordTooShort = 6 + BeheraPasswordTooYoung = 7 + BeheraPasswordInHistory = 8 +) + +// BeheraPasswordPolicyErrorMap contains human readable descriptions of Behera Password Policy error codes +var BeheraPasswordPolicyErrorMap = map[int8]string{ + BeheraPasswordExpired: "Password expired", + BeheraAccountLocked: "Account locked", + BeheraChangeAfterReset: "Password must be changed", + BeheraPasswordModNotAllowed: "Policy prevents password modification", + BeheraMustSupplyOldPassword: "Policy requires old password in order to change password", + BeheraInsufficientPasswordQuality: "Password fails quality checks", + BeheraPasswordTooShort: "Password is too short for policy", + BeheraPasswordTooYoung: "Password has been changed too recently", + BeheraPasswordInHistory: "New password is in list of old passwords", +} + +// Adds descriptions to an LDAP Response packet for debugging +func addLDAPDescriptions(packet *ber.Packet) (err error) { + defer func() { + if r := recover(); r != nil { + err = NewError(ErrorDebugging, fmt.Errorf("ldap: cannot process packet to add descriptions: %s", r)) + } + }() + packet.Description = "LDAP Response" + packet.Children[0].Description = "Message ID" + + application := uint8(packet.Children[1].Tag) + packet.Children[1].Description = ApplicationMap[application] + + switch application { + case ApplicationBindRequest: + err = addRequestDescriptions(packet) + case ApplicationBindResponse: + err = addDefaultLDAPResponseDescriptions(packet) + case ApplicationUnbindRequest: + err = addRequestDescriptions(packet) + case ApplicationSearchRequest: + err = addRequestDescriptions(packet) + case ApplicationSearchResultEntry: + packet.Children[1].Children[0].Description = "Object Name" + packet.Children[1].Children[1].Description = "Attributes" + for _, child := range packet.Children[1].Children[1].Children { + child.Description = "Attribute" + child.Children[0].Description = "Attribute Name" + child.Children[1].Description = "Attribute Values" + for _, grandchild := range child.Children[1].Children { + grandchild.Description = "Attribute Value" + } + } + if len(packet.Children) == 3 { + err = addControlDescriptions(packet.Children[2]) + } + case ApplicationSearchResultDone: + err = addDefaultLDAPResponseDescriptions(packet) + case ApplicationModifyRequest: + err = addRequestDescriptions(packet) + case ApplicationModifyResponse: + case ApplicationAddRequest: + err = addRequestDescriptions(packet) + case ApplicationAddResponse: + case ApplicationDelRequest: + err = addRequestDescriptions(packet) + case ApplicationDelResponse: + case ApplicationModifyDNRequest: + err = addRequestDescriptions(packet) + case ApplicationModifyDNResponse: + case ApplicationCompareRequest: + err = addRequestDescriptions(packet) + case ApplicationCompareResponse: + case ApplicationAbandonRequest: + err = addRequestDescriptions(packet) + case ApplicationSearchResultReference: + case ApplicationExtendedRequest: + err = addRequestDescriptions(packet) + case ApplicationExtendedResponse: + } + + return err +} + +func addControlDescriptions(packet *ber.Packet) error { + packet.Description = "Controls" + for _, child := range packet.Children { + var value *ber.Packet + controlType := "" + child.Description = "Control" + switch len(child.Children) { + case 0: + // at least one child is required for control type + return fmt.Errorf("at least one child is required for control type") + + case 1: + // just type, no criticality or value + controlType = child.Children[0].Value.(string) + child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")" + + case 2: + controlType = child.Children[0].Value.(string) + child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")" + // Children[1] could be criticality or value (both are optional) + // duck-type on whether this is a boolean + if _, ok := child.Children[1].Value.(bool); ok { + child.Children[1].Description = "Criticality" + } else { + child.Children[1].Description = "Control Value" + value = child.Children[1] + } + + case 3: + // criticality and value present + controlType = child.Children[0].Value.(string) + child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")" + child.Children[1].Description = "Criticality" + child.Children[2].Description = "Control Value" + value = child.Children[2] + + default: + // more than 3 children is invalid + return fmt.Errorf("more than 3 children for control packet found") + } + + if value == nil { + continue + } + switch controlType { + case ControlTypePaging: + value.Description += " (Paging)" + if value.Value != nil { + valueChildren, err := ber.DecodePacketErr(value.Data.Bytes()) + if err != nil { + return fmt.Errorf("failed to decode data bytes: %s", err) + } + value.Data.Truncate(0) + value.Value = nil + valueChildren.Children[1].Value = valueChildren.Children[1].Data.Bytes() + value.AppendChild(valueChildren) + } + value.Children[0].Description = "Real Search Control Value" + value.Children[0].Children[0].Description = "Paging Size" + value.Children[0].Children[1].Description = "Cookie" + + case ControlTypeBeheraPasswordPolicy: + value.Description += " (Password Policy - Behera Draft)" + if value.Value != nil { + valueChildren, err := ber.DecodePacketErr(value.Data.Bytes()) + if err != nil { + return fmt.Errorf("failed to decode data bytes: %s", err) + } + value.Data.Truncate(0) + value.Value = nil + value.AppendChild(valueChildren) + } + sequence := value.Children[0] + for _, child := range sequence.Children { + if child.Tag == 0 { + //Warning + warningPacket := child.Children[0] + packet, err := ber.DecodePacketErr(warningPacket.Data.Bytes()) + if err != nil { + return fmt.Errorf("failed to decode data bytes: %s", err) + } + val, ok := packet.Value.(int64) + if ok { + if warningPacket.Tag == 0 { + //timeBeforeExpiration + value.Description += " (TimeBeforeExpiration)" + warningPacket.Value = val + } else if warningPacket.Tag == 1 { + //graceAuthNsRemaining + value.Description += " (GraceAuthNsRemaining)" + warningPacket.Value = val + } + } + } else if child.Tag == 1 { + // Error + packet, err := ber.DecodePacketErr(child.Data.Bytes()) + if err != nil { + return fmt.Errorf("failed to decode data bytes: %s", err) + } + val, ok := packet.Value.(int8) + if !ok { + val = -1 + } + child.Description = "Error" + child.Value = val + } + } + } + } + return nil +} + +func addRequestDescriptions(packet *ber.Packet) error { + packet.Description = "LDAP Request" + packet.Children[0].Description = "Message ID" + packet.Children[1].Description = ApplicationMap[uint8(packet.Children[1].Tag)] + if len(packet.Children) == 3 { + return addControlDescriptions(packet.Children[2]) + } + return nil +} + +func addDefaultLDAPResponseDescriptions(packet *ber.Packet) error { + err := GetLDAPError(packet) + if err == nil { + return nil + } + packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[err.(*Error).ResultCode] + ")" + packet.Children[1].Children[1].Description = "Matched DN (" + err.(*Error).MatchedDN + ")" + packet.Children[1].Children[2].Description = "Error Message" + if len(packet.Children[1].Children) > 3 { + packet.Children[1].Children[3].Description = "Referral" + } + if len(packet.Children) == 3 { + return addControlDescriptions(packet.Children[2]) + } + return nil +} + +// DebugBinaryFile reads and prints packets from the given filename +func DebugBinaryFile(fileName string) error { + file, err := ioutil.ReadFile(fileName) + if err != nil { + return NewError(ErrorDebugging, err) + } + ber.PrintBytes(os.Stdout, file, "") + packet, err := ber.DecodePacketErr(file) + if err != nil { + return fmt.Errorf("failed to decode packet: %s", err) + } + if err := addLDAPDescriptions(packet); err != nil { + return err + } + ber.PrintPacket(packet) + + return nil +} + +var hex = "0123456789abcdef" + +func mustEscape(c byte) bool { + return c > 0x7f || c == '(' || c == ')' || c == '\\' || c == '*' || c == 0 +} + +// EscapeFilter escapes from the provided LDAP filter string the special +// characters in the set `()*\` and those out of the range 0 < c < 0x80, +// as defined in RFC4515. +func EscapeFilter(filter string) string { + escape := 0 + for i := 0; i < len(filter); i++ { + if mustEscape(filter[i]) { + escape++ + } + } + if escape == 0 { + return filter + } + buf := make([]byte, len(filter)+escape*2) + for i, j := 0, 0; i < len(filter); i++ { + c := filter[i] + if mustEscape(c) { + buf[j+0] = '\\' + buf[j+1] = hex[c>>4] + buf[j+2] = hex[c&0xf] + j += 3 + } else { + buf[j] = c + j++ + } + } + return string(buf) +} diff --git a/v3/ldap_test.go b/v3/ldap_test.go new file mode 100644 index 00000000..487124b4 --- /dev/null +++ b/v3/ldap_test.go @@ -0,0 +1,326 @@ +package ldap + +import ( + "crypto/tls" + "fmt" + "testing" +) + +var ldapServer = "ldap.itd.umich.edu" +var ldapPort = uint16(389) +var ldapTLSPort = uint16(636) +var baseDN = "dc=umich,dc=edu" +var filter = []string{ + "(cn=cis-fac)", + "(&(owner=*)(cn=cis-fac))", + "(&(objectclass=rfc822mailgroup)(cn=*Computer*))", + "(&(objectclass=rfc822mailgroup)(cn=*Mathematics*))"} +var attributes = []string{ + "cn", + "description"} + +func TestDial(t *testing.T) { + fmt.Printf("TestDial: starting...\n") + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + if err != nil { + t.Errorf(err.Error()) + return + } + defer l.Close() + fmt.Printf("TestDial: finished...\n") +} + +func TestDialTLS(t *testing.T) { + fmt.Printf("TestDialTLS: starting...\n") + l, err := DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true}) + if err != nil { + t.Errorf(err.Error()) + return + } + defer l.Close() + fmt.Printf("TestDialTLS: finished...\n") +} + +func TestStartTLS(t *testing.T) { + fmt.Printf("TestStartTLS: starting...\n") + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + if err != nil { + t.Errorf(err.Error()) + return + } + err = l.StartTLS(&tls.Config{InsecureSkipVerify: true}) + if err != nil { + t.Errorf(err.Error()) + return + } + fmt.Printf("TestStartTLS: finished...\n") +} + +func TestTLSConnectionState(t *testing.T) { + fmt.Printf("TestTLSConnectionState: starting...\n") + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + if err != nil { + t.Errorf(err.Error()) + return + } + err = l.StartTLS(&tls.Config{InsecureSkipVerify: true}) + if err != nil { + t.Errorf(err.Error()) + return + } + + cs, ok := l.TLSConnectionState() + if !ok { + t.Errorf("TLSConnectionState returned ok == false; want true") + } + if cs.Version == 0 || !cs.HandshakeComplete { + t.Errorf("ConnectionState = %#v; expected Version != 0 and HandshakeComplete = true", cs) + } + + fmt.Printf("TestTLSConnectionState: finished...\n") +} + +func TestSearch(t *testing.T) { + fmt.Printf("TestSearch: starting...\n") + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + if err != nil { + t.Errorf(err.Error()) + return + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[0], + attributes, + nil) + + sr, err := l.Search(searchRequest) + if err != nil { + t.Errorf(err.Error()) + return + } + + fmt.Printf("TestSearch: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries)) +} + +func TestSearchStartTLS(t *testing.T) { + fmt.Printf("TestSearchStartTLS: starting...\n") + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + if err != nil { + t.Errorf(err.Error()) + return + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[0], + attributes, + nil) + + sr, err := l.Search(searchRequest) + if err != nil { + t.Errorf(err.Error()) + return + } + + fmt.Printf("TestSearchStartTLS: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries)) + + fmt.Printf("TestSearchStartTLS: upgrading with startTLS\n") + err = l.StartTLS(&tls.Config{InsecureSkipVerify: true}) + if err != nil { + t.Errorf(err.Error()) + return + } + + sr, err = l.Search(searchRequest) + if err != nil { + t.Errorf(err.Error()) + return + } + + fmt.Printf("TestSearchStartTLS: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries)) +} + +func TestSearchWithPaging(t *testing.T) { + fmt.Printf("TestSearchWithPaging: starting...\n") + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + if err != nil { + t.Errorf(err.Error()) + return + } + defer l.Close() + + err = l.UnauthenticatedBind("") + if err != nil { + t.Errorf(err.Error()) + return + } + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[2], + attributes, + nil) + sr, err := l.SearchWithPaging(searchRequest, 5) + if err != nil { + t.Errorf(err.Error()) + return + } + + fmt.Printf("TestSearchWithPaging: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries)) + + searchRequest = NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[2], + attributes, + []Control{NewControlPaging(5)}) + sr, err = l.SearchWithPaging(searchRequest, 5) + if err != nil { + t.Errorf(err.Error()) + return + } + + fmt.Printf("TestSearchWithPaging: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries)) + + searchRequest = NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[2], + attributes, + []Control{NewControlPaging(500)}) + sr, err = l.SearchWithPaging(searchRequest, 5) + if err == nil { + t.Errorf("expected an error when paging size in control in search request doesn't match size given in call, got none") + return + } +} + +func searchGoroutine(t *testing.T, l *Conn, results chan *SearchResult, i int) { + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[i], + attributes, + nil) + sr, err := l.Search(searchRequest) + if err != nil { + t.Errorf(err.Error()) + results <- nil + return + } + results <- sr +} + +func testMultiGoroutineSearch(t *testing.T, TLS bool, startTLS bool) { + fmt.Printf("TestMultiGoroutineSearch: starting...\n") + var l *Conn + var err error + if TLS { + l, err = DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true}) + if err != nil { + t.Errorf(err.Error()) + return + } + defer l.Close() + } else { + l, err = Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + if err != nil { + t.Errorf(err.Error()) + return + } + if startTLS { + fmt.Printf("TestMultiGoroutineSearch: using StartTLS...\n") + err := l.StartTLS(&tls.Config{InsecureSkipVerify: true}) + if err != nil { + t.Errorf(err.Error()) + return + } + + } + } + + results := make([]chan *SearchResult, len(filter)) + for i := range filter { + results[i] = make(chan *SearchResult) + go searchGoroutine(t, l, results[i], i) + } + for i := range filter { + sr := <-results[i] + if sr == nil { + t.Errorf("Did not receive results from goroutine for %q", filter[i]) + } else { + fmt.Printf("TestMultiGoroutineSearch(%d): %s -> num of entries = %d\n", i, filter[i], len(sr.Entries)) + } + } +} + +func TestMultiGoroutineSearch(t *testing.T) { + testMultiGoroutineSearch(t, false, false) + testMultiGoroutineSearch(t, true, true) + testMultiGoroutineSearch(t, false, true) +} + +func TestEscapeFilter(t *testing.T) { + if got, want := EscapeFilter("a\x00b(c)d*e\\f"), `a\00b\28c\29d\2ae\5cf`; got != want { + t.Errorf("Got %s, expected %s", want, got) + } + if got, want := EscapeFilter("Lučić"), `Lu\c4\8di\c4\87`; got != want { + t.Errorf("Got %s, expected %s", want, got) + } +} + +func TestCompare(t *testing.T) { + fmt.Printf("TestCompare: starting...\n") + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + if err != nil { + t.Fatal(err.Error()) + } + defer l.Close() + + dn := "cn=math mich,ou=User Groups,ou=Groups,dc=umich,dc=edu" + attribute := "cn" + value := "math mich" + + sr, err := l.Compare(dn, attribute, value) + if err != nil { + t.Errorf(err.Error()) + return + } + + fmt.Printf("TestCompare: -> %v\n", sr) +} + +func TestMatchDNError(t *testing.T) { + fmt.Printf("TestMatchDNError: starting..\n") + + l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort)) + if err != nil { + t.Fatal(err.Error()) + } + defer l.Close() + + wrongBase := "ou=roups,dc=umich,dc=edu" + + searchRequest := NewSearchRequest( + wrongBase, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[0], + attributes, + nil) + + _, err = l.Search(searchRequest) + + if err == nil { + t.Errorf("Expected Error, got nil") + return + } + + fmt.Printf("TestMatchDNError: err: %s\n", err.Error()) + +} diff --git a/v3/moddn.go b/v3/moddn.go new file mode 100644 index 00000000..ca48b2b4 --- /dev/null +++ b/v3/moddn.go @@ -0,0 +1,85 @@ +// Package ldap - moddn.go contains ModifyDN functionality +// +// https://tools.ietf.org/html/rfc4511 +// ModifyDNRequest ::= [APPLICATION 12] SEQUENCE { +// entry LDAPDN, +// newrdn RelativeLDAPDN, +// deleteoldrdn BOOLEAN, +// newSuperior [0] LDAPDN OPTIONAL } +// +// +package ldap + +import ( + "log" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// ModifyDNRequest holds the request to modify a DN +type ModifyDNRequest struct { + DN string + NewRDN string + DeleteOldRDN bool + NewSuperior string +} + +// NewModifyDNRequest creates a new request which can be passed to ModifyDN(). +// +// To move an object in the tree, set the "newSup" to the new parent entry DN. Use an +// empty string for just changing the object's RDN. +// +// For moving the object without renaming, the "rdn" must be the first +// RDN of the given DN. +// +// A call like +// mdnReq := NewModifyDNRequest("uid=someone,dc=example,dc=org", "uid=newname", true, "") +// will setup the request to just rename uid=someone,dc=example,dc=org to +// uid=newname,dc=example,dc=org. +func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *ModifyDNRequest { + return &ModifyDNRequest{ + DN: dn, + NewRDN: rdn, + DeleteOldRDN: delOld, + NewSuperior: newSup, + } +} + +func (req *ModifyDNRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.NewRDN, "New RDN")) + pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.DeleteOldRDN, "Delete old RDN")) + if req.NewSuperior != "" { + pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.NewSuperior, "New Superior")) + } + + envelope.AppendChild(pkt) + + return nil +} + +// ModifyDN renames the given DN and optionally move to another base (when the "newSup" argument +// to NewModifyDNRequest() is not ""). +func (l *Conn) ModifyDN(m *ModifyDNRequest) error { + msgCtx, err := l.doRequest(m) + if err != nil { + return err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return err + } + + if packet.Children[1].Tag == ApplicationModifyDNResponse { + err := GetLDAPError(packet) + if err != nil { + return err + } + } else { + log.Printf("Unexpected Response: %d", packet.Children[1].Tag) + } + return nil +} diff --git a/v3/moddn_test.go b/v3/moddn_test.go new file mode 100644 index 00000000..832a616c --- /dev/null +++ b/v3/moddn_test.go @@ -0,0 +1,73 @@ +package ldap + +import ( + "log" +) + +// ExampleConn_ModifyDN_renameNoMove shows how to rename an entry without moving it +func ExampleConn_ModifyDN_renameNoMove() { + conn, err := Dial("tcp", "ldap.example.org:389") + if err != nil { + log.Fatalf("Failed to connect: %s\n", err) + } + defer conn.Close() + + _, err = conn.SimpleBind(&SimpleBindRequest{ + Username: "uid=someone,ou=people,dc=example,dc=org", + Password: "MySecretPass", + }) + if err != nil { + log.Fatalf("Failed to bind: %s\n", err) + } + // just rename to uid=new,ou=people,dc=example,dc=org: + req := NewModifyDNRequest("uid=user,ou=people,dc=example,dc=org", "uid=new", true, "") + if err = conn.ModifyDN(req); err != nil { + log.Fatalf("Failed to call ModifyDN(): %s\n", err) + } +} + +// ExampleConn_ModifyDN_renameAndMove shows how to rename an entry and moving it to a new base +func ExampleConn_ModifyDN_renameAndMove() { + conn, err := Dial("tcp", "ldap.example.org:389") + if err != nil { + log.Fatalf("Failed to connect: %s\n", err) + } + defer conn.Close() + + _, err = conn.SimpleBind(&SimpleBindRequest{ + Username: "uid=someone,ou=people,dc=example,dc=org", + Password: "MySecretPass", + }) + if err != nil { + log.Fatalf("Failed to bind: %s\n", err) + } + // rename to uid=new,ou=people,dc=example,dc=org and move to ou=users,dc=example,dc=org -> + // uid=new,ou=users,dc=example,dc=org + req := NewModifyDNRequest("uid=user,ou=people,dc=example,dc=org", "uid=new", true, "ou=users,dc=example,dc=org") + + if err = conn.ModifyDN(req); err != nil { + log.Fatalf("Failed to call ModifyDN(): %s\n", err) + } +} + +// ExampleConn_ModifyDN_moveOnly shows how to move an entry to a new base without renaming the RDN +func ExampleConn_ModifyDN_moveOnly() { + conn, err := Dial("tcp", "ldap.example.org:389") + if err != nil { + log.Fatalf("Failed to connect: %s\n", err) + } + defer conn.Close() + + _, err = conn.SimpleBind(&SimpleBindRequest{ + Username: "uid=someone,ou=people,dc=example,dc=org", + Password: "MySecretPass", + }) + if err != nil { + log.Fatalf("Failed to bind: %s\n", err) + } + // move to ou=users,dc=example,dc=org -> uid=user,ou=users,dc=example,dc=org + req := NewModifyDNRequest("uid=user,ou=people,dc=example,dc=org", "uid=user", true, "ou=users,dc=example,dc=org") + if err = conn.ModifyDN(req); err != nil { + log.Fatalf("Failed to call ModifyDN(): %s\n", err) + } +} diff --git a/v3/modify.go b/v3/modify.go new file mode 100644 index 00000000..88f65c1b --- /dev/null +++ b/v3/modify.go @@ -0,0 +1,151 @@ +// File contains Modify functionality +// +// https://tools.ietf.org/html/rfc4511 +// +// ModifyRequest ::= [APPLICATION 6] SEQUENCE { +// object LDAPDN, +// changes SEQUENCE OF change SEQUENCE { +// operation ENUMERATED { +// add (0), +// delete (1), +// replace (2), +// ... }, +// modification PartialAttribute } } +// +// PartialAttribute ::= SEQUENCE { +// type AttributeDescription, +// vals SET OF value AttributeValue } +// +// AttributeDescription ::= LDAPString +// -- Constrained to +// -- [RFC4512] +// +// AttributeValue ::= OCTET STRING +// + +package ldap + +import ( + "log" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// Change operation choices +const ( + AddAttribute = 0 + DeleteAttribute = 1 + ReplaceAttribute = 2 +) + +// PartialAttribute for a ModifyRequest as defined in https://tools.ietf.org/html/rfc4511 +type PartialAttribute struct { + // Type is the type of the partial attribute + Type string + // Vals are the values of the partial attribute + Vals []string +} + +func (p *PartialAttribute) encode() *ber.Packet { + seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "PartialAttribute") + seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, p.Type, "Type")) + set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue") + for _, value := range p.Vals { + set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals")) + } + seq.AppendChild(set) + return seq +} + +// Change for a ModifyRequest as defined in https://tools.ietf.org/html/rfc4511 +type Change struct { + // Operation is the type of change to be made + Operation uint + // Modification is the attribute to be modified + Modification PartialAttribute +} + +func (c *Change) encode() *ber.Packet { + change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change") + change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(c.Operation), "Operation")) + change.AppendChild(c.Modification.encode()) + return change +} + +// ModifyRequest as defined in https://tools.ietf.org/html/rfc4511 +type ModifyRequest struct { + // DN is the distinguishedName of the directory entry to modify + DN string + // Changes contain the attributes to modify + Changes []Change + // Controls hold optional controls to send with the request + Controls []Control +} + +// Add appends the given attribute to the list of changes to be made +func (req *ModifyRequest) Add(attrType string, attrVals []string) { + req.appendChange(AddAttribute, attrType, attrVals) +} + +// Delete appends the given attribute to the list of changes to be made +func (req *ModifyRequest) Delete(attrType string, attrVals []string) { + req.appendChange(DeleteAttribute, attrType, attrVals) +} + +// Replace appends the given attribute to the list of changes to be made +func (req *ModifyRequest) Replace(attrType string, attrVals []string) { + req.appendChange(ReplaceAttribute, attrType, attrVals) +} + +func (req *ModifyRequest) appendChange(operation uint, attrType string, attrVals []string) { + req.Changes = append(req.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}}) +} + +func (req *ModifyRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) + changes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Changes") + for _, change := range req.Changes { + changes.AppendChild(change.encode()) + } + pkt.AppendChild(changes) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil +} + +// NewModifyRequest creates a modify request for the given DN +func NewModifyRequest(dn string, controls []Control) *ModifyRequest { + return &ModifyRequest{ + DN: dn, + Controls: controls, + } +} + +// Modify performs the ModifyRequest +func (l *Conn) Modify(modifyRequest *ModifyRequest) error { + msgCtx, err := l.doRequest(modifyRequest) + if err != nil { + return err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return err + } + + if packet.Children[1].Tag == ApplicationModifyResponse { + err := GetLDAPError(packet) + if err != nil { + return err + } + } else { + log.Printf("Unexpected Response: %d", packet.Children[1].Tag) + } + return nil +} diff --git a/v3/passwdmodify.go b/v3/passwdmodify.go new file mode 100644 index 00000000..135554d9 --- /dev/null +++ b/v3/passwdmodify.go @@ -0,0 +1,131 @@ +// This file contains the password modify extended operation as specified in rfc 3062 +// +// https://tools.ietf.org/html/rfc3062 +// + +package ldap + +import ( + "fmt" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +const ( + passwordModifyOID = "1.3.6.1.4.1.4203.1.11.1" +) + +// PasswordModifyRequest implements the Password Modify Extended Operation as defined in https://www.ietf.org/rfc/rfc3062.txt +type PasswordModifyRequest struct { + // UserIdentity is an optional string representation of the user associated with the request. + // This string may or may not be an LDAPDN [RFC2253]. + // If no UserIdentity field is present, the request acts up upon the password of the user currently associated with the LDAP session + UserIdentity string + // OldPassword, if present, contains the user's current password + OldPassword string + // NewPassword, if present, contains the desired password for this user + NewPassword string +} + +// PasswordModifyResult holds the server response to a PasswordModifyRequest +type PasswordModifyResult struct { + // GeneratedPassword holds a password generated by the server, if present + GeneratedPassword string + // Referral are the returned referral + Referral string +} + +func (req *PasswordModifyRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Password Modify Extended Operation") + pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, passwordModifyOID, "Extended Request Name: Password Modify OID")) + + extendedRequestValue := ber.Encode(ber.ClassContext, ber.TypePrimitive, 1, nil, "Extended Request Value: Password Modify Request") + passwordModifyRequestValue := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Password Modify Request") + if req.UserIdentity != "" { + passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.UserIdentity, "User Identity")) + } + if req.OldPassword != "" { + passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 1, req.OldPassword, "Old Password")) + } + if req.NewPassword != "" { + passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 2, req.NewPassword, "New Password")) + } + extendedRequestValue.AppendChild(passwordModifyRequestValue) + + pkt.AppendChild(extendedRequestValue) + + envelope.AppendChild(pkt) + + return nil +} + +// NewPasswordModifyRequest creates a new PasswordModifyRequest +// +// According to the RFC 3602: +// userIdentity is a string representing the user associated with the request. +// This string may or may not be an LDAPDN (RFC 2253). +// If userIdentity is empty then the operation will act on the user associated +// with the session. +// +// oldPassword is the current user's password, it can be empty or it can be +// needed depending on the session user access rights (usually an administrator +// can change a user's password without knowing the current one) and the +// password policy (see pwdSafeModify password policy's attribute) +// +// newPassword is the desired user's password. If empty the server can return +// an error or generate a new password that will be available in the +// PasswordModifyResult.GeneratedPassword +// +func NewPasswordModifyRequest(userIdentity string, oldPassword string, newPassword string) *PasswordModifyRequest { + return &PasswordModifyRequest{ + UserIdentity: userIdentity, + OldPassword: oldPassword, + NewPassword: newPassword, + } +} + +// PasswordModify performs the modification request +func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) { + msgCtx, err := l.doRequest(passwordModifyRequest) + if err != nil { + return nil, err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return nil, err + } + + result := &PasswordModifyResult{} + + if packet.Children[1].Tag == ApplicationExtendedResponse { + err := GetLDAPError(packet) + if err != nil { + if IsErrorWithCode(err, LDAPResultReferral) { + for _, child := range packet.Children[1].Children { + if child.Tag == 3 { + result.Referral = child.Children[0].Value.(string) + } + } + } + return result, err + } + } else { + return nil, NewError(ErrorUnexpectedResponse, fmt.Errorf("unexpected Response: %d", packet.Children[1].Tag)) + } + + extendedResponse := packet.Children[1] + for _, child := range extendedResponse.Children { + if child.Tag == 11 { + passwordModifyResponseValue := ber.DecodePacket(child.Data.Bytes()) + if len(passwordModifyResponseValue.Children) == 1 { + if passwordModifyResponseValue.Children[0].Tag == 0 { + result.GeneratedPassword = ber.DecodeString(passwordModifyResponseValue.Children[0].Data.Bytes()) + } + } + } + } + + return result, nil +} diff --git a/v3/request.go b/v3/request.go new file mode 100644 index 00000000..4790bbfe --- /dev/null +++ b/v3/request.go @@ -0,0 +1,66 @@ +package ldap + +import ( + "errors" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +var ( + errRespChanClosed = errors.New("ldap: response channel closed") + errCouldNotRetMsg = errors.New("ldap: could not retrieve message") +) + +type request interface { + appendTo(*ber.Packet) error +} + +type requestFunc func(*ber.Packet) error + +func (f requestFunc) appendTo(p *ber.Packet) error { + return f(p) +} + +func (l *Conn) doRequest(req request) (*messageContext, error) { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + if err := req.appendTo(packet); err != nil { + return nil, err + } + + if l.Debug { + ber.PrintPacket(packet) + } + + msgCtx, err := l.sendMessage(packet) + if err != nil { + return nil, err + } + l.Debug.Printf("%d: returning", msgCtx.id) + return msgCtx, nil +} + +func (l *Conn) readPacket(msgCtx *messageContext) (*ber.Packet, error) { + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + return nil, NewError(ErrorNetwork, errRespChanClosed) + } + packet, err := packetResponse.ReadPacket() + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + return nil, err + } + + if packet == nil { + return nil, NewError(ErrorNetwork, errCouldNotRetMsg) + } + + if l.Debug { + if err = addLDAPDescriptions(packet); err != nil { + return nil, err + } + ber.PrintPacket(packet) + } + return packet, nil +} diff --git a/v3/search.go b/v3/search.go new file mode 100644 index 00000000..186246c0 --- /dev/null +++ b/v3/search.go @@ -0,0 +1,425 @@ +// File contains Search functionality +// +// https://tools.ietf.org/html/rfc4511 +// +// SearchRequest ::= [APPLICATION 3] SEQUENCE { +// baseObject LDAPDN, +// scope ENUMERATED { +// baseObject (0), +// singleLevel (1), +// wholeSubtree (2), +// ... }, +// derefAliases ENUMERATED { +// neverDerefAliases (0), +// derefInSearching (1), +// derefFindingBaseObj (2), +// derefAlways (3) }, +// sizeLimit INTEGER (0 .. maxInt), +// timeLimit INTEGER (0 .. maxInt), +// typesOnly BOOLEAN, +// filter Filter, +// attributes AttributeSelection } +// +// AttributeSelection ::= SEQUENCE OF selector LDAPString +// -- The LDAPString is constrained to +// -- in Section 4.5.1.8 +// +// Filter ::= CHOICE { +// and [0] SET SIZE (1..MAX) OF filter Filter, +// or [1] SET SIZE (1..MAX) OF filter Filter, +// not [2] Filter, +// equalityMatch [3] AttributeValueAssertion, +// substrings [4] SubstringFilter, +// greaterOrEqual [5] AttributeValueAssertion, +// lessOrEqual [6] AttributeValueAssertion, +// present [7] AttributeDescription, +// approxMatch [8] AttributeValueAssertion, +// extensibleMatch [9] MatchingRuleAssertion, +// ... } +// +// SubstringFilter ::= SEQUENCE { +// type AttributeDescription, +// substrings SEQUENCE SIZE (1..MAX) OF substring CHOICE { +// initial [0] AssertionValue, -- can occur at most once +// any [1] AssertionValue, +// final [2] AssertionValue } -- can occur at most once +// } +// +// MatchingRuleAssertion ::= SEQUENCE { +// matchingRule [1] MatchingRuleId OPTIONAL, +// type [2] AttributeDescription OPTIONAL, +// matchValue [3] AssertionValue, +// dnAttributes [4] BOOLEAN DEFAULT FALSE } +// +// + +package ldap + +import ( + "errors" + "fmt" + "sort" + "strings" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// scope choices +const ( + ScopeBaseObject = 0 + ScopeSingleLevel = 1 + ScopeWholeSubtree = 2 +) + +// ScopeMap contains human readable descriptions of scope choices +var ScopeMap = map[int]string{ + ScopeBaseObject: "Base Object", + ScopeSingleLevel: "Single Level", + ScopeWholeSubtree: "Whole Subtree", +} + +// derefAliases +const ( + NeverDerefAliases = 0 + DerefInSearching = 1 + DerefFindingBaseObj = 2 + DerefAlways = 3 +) + +// DerefMap contains human readable descriptions of derefAliases choices +var DerefMap = map[int]string{ + NeverDerefAliases: "NeverDerefAliases", + DerefInSearching: "DerefInSearching", + DerefFindingBaseObj: "DerefFindingBaseObj", + DerefAlways: "DerefAlways", +} + +// NewEntry returns an Entry object with the specified distinguished name and attribute key-value pairs. +// The map of attributes is accessed in alphabetical order of the keys in order to ensure that, for the +// same input map of attributes, the output entry will contain the same order of attributes +func NewEntry(dn string, attributes map[string][]string) *Entry { + var attributeNames []string + for attributeName := range attributes { + attributeNames = append(attributeNames, attributeName) + } + sort.Strings(attributeNames) + + var encodedAttributes []*EntryAttribute + for _, attributeName := range attributeNames { + encodedAttributes = append(encodedAttributes, NewEntryAttribute(attributeName, attributes[attributeName])) + } + return &Entry{ + DN: dn, + Attributes: encodedAttributes, + } +} + +// Entry represents a single search result entry +type Entry struct { + // DN is the distinguished name of the entry + DN string + // Attributes are the returned attributes for the entry + Attributes []*EntryAttribute +} + +// GetAttributeValues returns the values for the named attribute, or an empty list +func (e *Entry) GetAttributeValues(attribute string) []string { + for _, attr := range e.Attributes { + if attr.Name == attribute { + return attr.Values + } + } + return []string{} +} + +// GetRawAttributeValues returns the byte values for the named attribute, or an empty list +func (e *Entry) GetRawAttributeValues(attribute string) [][]byte { + for _, attr := range e.Attributes { + if attr.Name == attribute { + return attr.ByteValues + } + } + return [][]byte{} +} + +// GetAttributeValue returns the first value for the named attribute, or "" +func (e *Entry) GetAttributeValue(attribute string) string { + values := e.GetAttributeValues(attribute) + if len(values) == 0 { + return "" + } + return values[0] +} + +// GetRawAttributeValue returns the first value for the named attribute, or an empty slice +func (e *Entry) GetRawAttributeValue(attribute string) []byte { + values := e.GetRawAttributeValues(attribute) + if len(values) == 0 { + return []byte{} + } + return values[0] +} + +// Print outputs a human-readable description +func (e *Entry) Print() { + fmt.Printf("DN: %s\n", e.DN) + for _, attr := range e.Attributes { + attr.Print() + } +} + +// PrettyPrint outputs a human-readable description indenting +func (e *Entry) PrettyPrint(indent int) { + fmt.Printf("%sDN: %s\n", strings.Repeat(" ", indent), e.DN) + for _, attr := range e.Attributes { + attr.PrettyPrint(indent + 2) + } +} + +// NewEntryAttribute returns a new EntryAttribute with the desired key-value pair +func NewEntryAttribute(name string, values []string) *EntryAttribute { + var bytes [][]byte + for _, value := range values { + bytes = append(bytes, []byte(value)) + } + return &EntryAttribute{ + Name: name, + Values: values, + ByteValues: bytes, + } +} + +// EntryAttribute holds a single attribute +type EntryAttribute struct { + // Name is the name of the attribute + Name string + // Values contain the string values of the attribute + Values []string + // ByteValues contain the raw values of the attribute + ByteValues [][]byte +} + +// Print outputs a human-readable description +func (e *EntryAttribute) Print() { + fmt.Printf("%s: %s\n", e.Name, e.Values) +} + +// PrettyPrint outputs a human-readable description with indenting +func (e *EntryAttribute) PrettyPrint(indent int) { + fmt.Printf("%s%s: %s\n", strings.Repeat(" ", indent), e.Name, e.Values) +} + +// SearchResult holds the server's response to a search request +type SearchResult struct { + // Entries are the returned entries + Entries []*Entry + // Referrals are the returned referrals + Referrals []string + // Controls are the returned controls + Controls []Control +} + +// Print outputs a human-readable description +func (s *SearchResult) Print() { + for _, entry := range s.Entries { + entry.Print() + } +} + +// PrettyPrint outputs a human-readable description with indenting +func (s *SearchResult) PrettyPrint(indent int) { + for _, entry := range s.Entries { + entry.PrettyPrint(indent) + } +} + +// SearchRequest represents a search request to send to the server +type SearchRequest struct { + BaseDN string + Scope int + DerefAliases int + SizeLimit int + TimeLimit int + TypesOnly bool + Filter string + Attributes []string + Controls []Control +} + +func (req *SearchRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.BaseDN, "Base DN")) + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(req.Scope), "Scope")) + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(req.DerefAliases), "Deref Aliases")) + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(req.SizeLimit), "Size Limit")) + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(req.TimeLimit), "Time Limit")) + pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.TypesOnly, "Types Only")) + // compile and encode filter + filterPacket, err := CompileFilter(req.Filter) + if err != nil { + return err + } + pkt.AppendChild(filterPacket) + // encode attributes + attributesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes") + for _, attribute := range req.Attributes { + attributesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute")) + } + pkt.AppendChild(attributesPacket) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil +} + +// NewSearchRequest creates a new search request +func NewSearchRequest( + BaseDN string, + Scope, DerefAliases, SizeLimit, TimeLimit int, + TypesOnly bool, + Filter string, + Attributes []string, + Controls []Control, +) *SearchRequest { + return &SearchRequest{ + BaseDN: BaseDN, + Scope: Scope, + DerefAliases: DerefAliases, + SizeLimit: SizeLimit, + TimeLimit: TimeLimit, + TypesOnly: TypesOnly, + Filter: Filter, + Attributes: Attributes, + Controls: Controls, + } +} + +// SearchWithPaging accepts a search request and desired page size in order to execute LDAP queries to fulfill the +// search request. All paged LDAP query responses will be buffered and the final result will be returned atomically. +// The following four cases are possible given the arguments: +// - given SearchRequest missing a control of type ControlTypePaging: we will add one with the desired paging size +// - given SearchRequest contains a control of type ControlTypePaging that isn't actually a ControlPaging: fail without issuing any queries +// - given SearchRequest contains a control of type ControlTypePaging with pagingSize equal to the size requested: no change to the search request +// - given SearchRequest contains a control of type ControlTypePaging with pagingSize not equal to the size requested: fail without issuing any queries +// A requested pagingSize of 0 is interpreted as no limit by LDAP servers. +func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) { + var pagingControl *ControlPaging + + control := FindControl(searchRequest.Controls, ControlTypePaging) + if control == nil { + pagingControl = NewControlPaging(pagingSize) + searchRequest.Controls = append(searchRequest.Controls, pagingControl) + } else { + castControl, ok := control.(*ControlPaging) + if !ok { + return nil, fmt.Errorf("expected paging control to be of type *ControlPaging, got %v", control) + } + if castControl.PagingSize != pagingSize { + return nil, fmt.Errorf("paging size given in search request (%d) conflicts with size given in search call (%d)", castControl.PagingSize, pagingSize) + } + pagingControl = castControl + } + + searchResult := new(SearchResult) + for { + result, err := l.Search(searchRequest) + l.Debug.Printf("Looking for Paging Control...") + if err != nil { + return searchResult, err + } + if result == nil { + return searchResult, NewError(ErrorNetwork, errors.New("ldap: packet not received")) + } + + for _, entry := range result.Entries { + searchResult.Entries = append(searchResult.Entries, entry) + } + for _, referral := range result.Referrals { + searchResult.Referrals = append(searchResult.Referrals, referral) + } + for _, control := range result.Controls { + searchResult.Controls = append(searchResult.Controls, control) + } + + l.Debug.Printf("Looking for Paging Control...") + pagingResult := FindControl(result.Controls, ControlTypePaging) + if pagingResult == nil { + pagingControl = nil + l.Debug.Printf("Could not find paging control. Breaking...") + break + } + + cookie := pagingResult.(*ControlPaging).Cookie + if len(cookie) == 0 { + pagingControl = nil + l.Debug.Printf("Could not find cookie. Breaking...") + break + } + pagingControl.SetCookie(cookie) + } + + if pagingControl != nil { + l.Debug.Printf("Abandoning Paging...") + pagingControl.PagingSize = 0 + l.Search(searchRequest) + } + + return searchResult, nil +} + +// Search performs the given search request +func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { + msgCtx, err := l.doRequest(searchRequest) + if err != nil { + return nil, err + } + defer l.finishMessage(msgCtx) + + result := &SearchResult{ + Entries: make([]*Entry, 0), + Referrals: make([]string, 0), + Controls: make([]Control, 0)} + + for { + packet, err := l.readPacket(msgCtx) + if err != nil { + return nil, err + } + + switch packet.Children[1].Tag { + case 4: + entry := new(Entry) + entry.DN = packet.Children[1].Children[0].Value.(string) + for _, child := range packet.Children[1].Children[1].Children { + attr := new(EntryAttribute) + attr.Name = child.Children[0].Value.(string) + for _, value := range child.Children[1].Children { + attr.Values = append(attr.Values, value.Value.(string)) + attr.ByteValues = append(attr.ByteValues, value.ByteValue) + } + entry.Attributes = append(entry.Attributes, attr) + } + result.Entries = append(result.Entries, entry) + case 5: + err := GetLDAPError(packet) + if err != nil { + return nil, err + } + if len(packet.Children) == 3 { + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + return nil, fmt.Errorf("failed to decode child control: %s", err) + } + result.Controls = append(result.Controls, decodedChild) + } + } + return result, nil + case 19: + result.Referrals = append(result.Referrals, packet.Children[1].Children[0].Value.(string)) + } + } +} diff --git a/v3/search_test.go b/v3/search_test.go new file mode 100644 index 00000000..4a4b7dab --- /dev/null +++ b/v3/search_test.go @@ -0,0 +1,31 @@ +package ldap + +import ( + "reflect" + "testing" +) + +// TestNewEntry tests that repeated calls to NewEntry return the same value with the same input +func TestNewEntry(t *testing.T) { + dn := "testDN" + attributes := map[string][]string{ + "alpha": {"value"}, + "beta": {"value"}, + "gamma": {"value"}, + "delta": {"value"}, + "epsilon": {"value"}, + } + executedEntry := NewEntry(dn, attributes) + + iteration := 0 + for { + if iteration == 100 { + break + } + testEntry := NewEntry(dn, attributes) + if !reflect.DeepEqual(executedEntry, testEntry) { + t.Fatalf("subsequent calls to NewEntry did not yield the same result:\n\texpected:\n\t%v\n\tgot:\n\t%v\n", executedEntry, testEntry) + } + iteration = iteration + 1 + } +}