diff --git a/.travis.yml b/.travis.yml index bb899b2a..107aa786 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,6 +9,7 @@ go: - "1.10.x" - "1.11.x" - "1.12.x" + - "1.13.x" - tip git: diff --git a/add.go b/add.go index 19bce1b7..e2cb9b06 100644 --- a/add.go +++ b/add.go @@ -10,10 +10,9 @@ package ldap import ( - "errors" "log" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) // Attribute represents an LDAP attribute @@ -45,20 +44,26 @@ type AddRequest struct { Controls []Control } -func (a AddRequest) encode() *ber.Packet { - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request") - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.DN, "DN")) +func (req *AddRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) attributes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes") - for _, attribute := range a.Attributes { + for _, attribute := range req.Attributes { attributes.AppendChild(attribute.encode()) } - request.AppendChild(attributes) - return request + pkt.AppendChild(attributes) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil } // Attribute adds an attribute with the given type and values -func (a *AddRequest) Attribute(attrType string, attrVals []string) { - a.Attributes = append(a.Attributes, Attribute{Type: attrType, Vals: attrVals}) +func (req *AddRequest) Attribute(attrType string, attrVals []string) { + req.Attributes = append(req.Attributes, Attribute{Type: attrType, Vals: attrVals}) } // NewAddRequest returns an AddRequest for the given DN, with no attributes @@ -72,39 +77,17 @@ func NewAddRequest(dn string, controls []Control) *AddRequest { // Add performs the given AddRequest func (l *Conn) Add(addRequest *AddRequest) error { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - packet.AppendChild(addRequest.encode()) - if len(addRequest.Controls) > 0 { - packet.AppendChild(encodeControls(addRequest.Controls)) - } - - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(addRequest) if err != nil { return err } defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return err } - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return err - } - ber.PrintPacket(packet) - } - if packet.Children[1].Tag == ApplicationAddResponse { err := GetLDAPError(packet) if err != nil { @@ -113,7 +96,5 @@ func (l *Conn) Add(addRequest *AddRequest) error { } else { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) } - - l.Debug.Printf("%d: returning", msgCtx.id) return nil } diff --git a/bind.go b/bind.go index 59c3f5ef..7b5e657a 100644 --- a/bind.go +++ b/bind.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) // SimpleBindRequest represents a username/password bind operation @@ -35,13 +35,18 @@ func NewSimpleBindRequest(username string, password string, controls []Control) } } -func (bindRequest *SimpleBindRequest) encode() *ber.Packet { - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") - request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, bindRequest.Username, "User Name")) - request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, bindRequest.Password, "Password")) +func (req *SimpleBindRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Username, "User Name")) + pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.Password, "Password")) - return request + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil } // SimpleBind performs the simple bind operation defined in the given request @@ -50,41 +55,17 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client")) } - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - encodedBindRequest := simpleBindRequest.encode() - packet.AppendChild(encodedBindRequest) - if len(simpleBindRequest.Controls) > 0 { - packet.AppendChild(encodeControls(simpleBindRequest.Controls)) - } - - if l.Debug { - ber.PrintPacket(packet) - } - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(simpleBindRequest) if err != nil { return nil, err } defer l.finishMessage(msgCtx) - packetResponse, ok := <-msgCtx.responses - if !ok { - return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return nil, err } - if l.Debug { - if err = addLDAPDescriptions(packet); err != nil { - return nil, err - } - ber.PrintPacket(packet) - } - result := &SimpleBindResult{ Controls: make([]Control, 0), } @@ -133,3 +114,39 @@ func (l *Conn) UnauthenticatedBind(username string) error { _, err := l.SimpleBind(req) return err } + +var externalBindRequest = requestFunc(func(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version")) + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name")) + + saslAuth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication") + saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "EXTERNAL", "SASL Mech")) + saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "SASL Cred")) + + pkt.AppendChild(saslAuth) + + envelope.AppendChild(pkt) + + return nil +}) + +// ExternalBind performs SASL/EXTERNAL authentication. +// +// Use ldap.DialURL("ldapi://") to connect to the Unix socket before ExternalBind. +// +// See https://tools.ietf.org/html/rfc4422#appendix-A +func (l *Conn) ExternalBind() error { + msgCtx, err := l.doRequest(externalBindRequest) + if err != nil { + return err + } + defer l.finishMessage(msgCtx) + + packet, err := l.readPacket(msgCtx) + if err != nil { + return err + } + + return GetLDAPError(packet) +} diff --git a/client.go b/client.go index c7f41f6f..619677c7 100644 --- a/client.go +++ b/client.go @@ -8,21 +8,23 @@ import ( // Client knows how to interact with an LDAP server type Client interface { Start() - StartTLS(config *tls.Config) error + StartTLS(*tls.Config) error Close() SetTimeout(time.Duration) Bind(username, password string) error - SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) + UnauthenticatedBind(username string) error + SimpleBind(*SimpleBindRequest) (*SimpleBindResult, error) + ExternalBind() error - Add(addRequest *AddRequest) error - Del(delRequest *DelRequest) error - Modify(modifyRequest *ModifyRequest) error - ModifyDN(modifyDNRequest *ModifyDNRequest) error + Add(*AddRequest) error + Del(*DelRequest) error + Modify(*ModifyRequest) error + ModifyDN(*ModifyDNRequest) error Compare(dn, attribute, value string) (bool, error) - PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) + PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) - Search(searchRequest *SearchRequest) (*SearchResult, error) + Search(*SearchRequest) (*SearchResult, error) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) } diff --git a/compare.go b/compare.go index 5b5013cb..83694d82 100644 --- a/compare.go +++ b/compare.go @@ -20,53 +20,50 @@ package ldap import ( - "errors" "fmt" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) -// Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise -// false with any error that occurs if any. -func (l *Conn) Compare(dn, attribute, value string) (bool, error) { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) +// CompareRequest represents an LDAP CompareRequest operation. +type CompareRequest struct { + DN string + Attribute string + Value string +} - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request") - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, dn, "DN")) +func (req *CompareRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) ava := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "AttributeValueAssertion") - ava.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "AttributeDesc")) - ava.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "AssertionValue")) - request.AppendChild(ava) - packet.AppendChild(request) + ava.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Attribute, "AttributeDesc")) + ava.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Value, "AssertionValue")) - l.Debug.PrintPacket(packet) + pkt.AppendChild(ava) - msgCtx, err := l.sendMessage(packet) + envelope.AppendChild(pkt) + + return nil +} + +// Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise +// false with any error that occurs if any. +func (l *Conn) Compare(dn, attribute, value string) (bool, error) { + msgCtx, err := l.doRequest(&CompareRequest{ + DN: dn, + Attribute: attribute, + Value: value}) if err != nil { return false, err } defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return false, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return false, err } - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return false, err - } - ber.PrintPacket(packet) - } - if packet.Children[1].Tag == ApplicationCompareResponse { err := GetLDAPError(packet) diff --git a/conn.go b/conn.go index c20471fc..ab9bd4f9 100644 --- a/conn.go +++ b/conn.go @@ -11,7 +11,7 @@ import ( "sync/atomic" "time" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) const ( @@ -140,7 +140,6 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { // or ldap:// specified as protocol. On success a new Conn for the connection // is returned. func DialURL(addr string) (*Conn, error) { - lurl, err := url.Parse(addr) if err != nil { return nil, NewError(ErrorNetwork, err) @@ -154,6 +153,11 @@ func DialURL(addr string) (*Conn, error) { } switch lurl.Scheme { + case "ldapi": + if lurl.Path == "" || lurl.Path == "/" { + lurl.Path = "/var/run/slapd/ldapi" + } + return Dial("unix", lurl.Path) case "ldap": if port == "" { port = DefaultLdapPort @@ -490,11 +494,13 @@ func (l *Conn) reader() { // A read error is expected here if we are closing the connection... if !l.IsClosing() { l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err)) - l.Debug.Printf("reader error: %s", err.Error()) + l.Debug.Printf("reader error: %s", err) } return } - addLDAPDescriptions(packet) + if err := addLDAPDescriptions(packet); err != nil { + l.Debug.Printf("descriptions error: %s", err) + } if len(packet.Children) == 0 { l.Debug.Printf("Received bad ldap packet") continue diff --git a/debug.go b/debug.go index 7279fc25..61628e3a 100644 --- a/debug.go +++ b/debug.go @@ -3,20 +3,26 @@ package ldap import ( "log" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) // debugging type // - has a Printf method to write the debug output type debugging bool -// write debug output +// Enable controls debugging mode. +func (debug *debugging) Enable(b bool) { + *debug = debugging(b) +} + +// Printf writes debug output. func (debug debugging) Printf(format string, args ...interface{}) { if debug { log.Printf(format, args...) } } +// PrintPacket dumps a packet. func (debug debugging) PrintPacket(packet *ber.Packet) { if debug { ber.PrintPacket(packet) diff --git a/del.go b/del.go index 6f78beb1..0e7767b2 100644 --- a/del.go +++ b/del.go @@ -6,10 +6,9 @@ package ldap import ( - "errors" "log" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) // DelRequest implements an LDAP deletion request @@ -20,15 +19,20 @@ type DelRequest struct { Controls []Control } -func (d DelRequest) encode() *ber.Packet { - request := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationDelRequest, d.DN, "Del Request") - request.Data.Write([]byte(d.DN)) - return request +func (req *DelRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationDelRequest, req.DN, "Del Request") + pkt.Data.Write([]byte(req.DN)) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil } // NewDelRequest creates a delete request for the given DN and controls -func NewDelRequest(DN string, - Controls []Control) *DelRequest { +func NewDelRequest(DN string, Controls []Control) *DelRequest { return &DelRequest{ DN: DN, Controls: Controls, @@ -37,39 +41,17 @@ func NewDelRequest(DN string, // Del executes the given delete request func (l *Conn) Del(delRequest *DelRequest) error { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - packet.AppendChild(delRequest.encode()) - if len(delRequest.Controls) > 0 { - packet.AppendChild(encodeControls(delRequest.Controls)) - } - - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(delRequest) if err != nil { return err } defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return err } - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return err - } - ber.PrintPacket(packet) - } - if packet.Children[1].Tag == ApplicationDelResponse { err := GetLDAPError(packet) if err != nil { @@ -78,7 +60,5 @@ func (l *Conn) Del(delRequest *DelRequest) error { } else { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) } - - l.Debug.Printf("%d: returning", msgCtx.id) return nil } diff --git a/error.go b/error.go index 639ed824..53dedb95 100644 --- a/error.go +++ b/error.go @@ -3,7 +3,7 @@ package ldap import ( "fmt" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) // LDAP Result Codes @@ -196,7 +196,9 @@ func (e *Error) Error() string { func GetLDAPError(packet *ber.Packet) error { if packet == nil { return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty packet")} - } else if len(packet.Children) >= 2 { + } + + if len(packet.Children) >= 2 { response := packet.Children[1] if response == nil { return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet")} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..8b38566b --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module gopkg.in/ldap.v3 + +require gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..dd2f5e15 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d h1:TxyelI5cVkbREznMhfzycHdkp5cLA7DpE+GKjSslYhM= +gopkg.in/asn1-ber.v1 v1.0.0-20181015200546-f715ec2f112d/go.mod h1:cuepJuh7vyXfUyUwEgHQXw849cJrilpS5NeIjOWESAw= diff --git a/ldap.go b/ldap.go index d7666676..5b694bf3 100644 --- a/ldap.go +++ b/ldap.go @@ -1,12 +1,11 @@ package ldap import ( - "errors" "fmt" "io/ioutil" "os" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) // LDAP Application Codes @@ -87,7 +86,7 @@ var BeheraPasswordPolicyErrorMap = map[int8]string{ func addLDAPDescriptions(packet *ber.Packet) (err error) { defer func() { if r := recover(); r != nil { - err = NewError(ErrorDebugging, errors.New("ldap: cannot process packet to add descriptions")) + err = NewError(ErrorDebugging, fmt.Errorf("ldap: cannot process packet to add descriptions: %s", r)) } }() packet.Description = "LDAP Response" @@ -271,6 +270,9 @@ func addRequestDescriptions(packet *ber.Packet) error { func addDefaultLDAPResponseDescriptions(packet *ber.Packet) error { err := GetLDAPError(packet) + if err == nil { + return nil + } packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[err.(*Error).ResultCode] + ")" packet.Children[1].Children[1].Description = "Matched DN (" + err.(*Error).MatchedDN + ")" packet.Children[1].Children[2].Description = "Error Message" diff --git a/moddn.go b/moddn.go index 803279d2..889a82ac 100644 --- a/moddn.go +++ b/moddn.go @@ -11,10 +11,9 @@ package ldap import ( - "errors" "log" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) // ModifyDNRequest holds the request to modify a DN @@ -46,50 +45,34 @@ func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *Modi } } -func (m ModifyDNRequest) encode() *ber.Packet { - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request") - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.DN, "DN")) - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.NewRDN, "New RDN")) - request.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, m.DeleteOldRDN, "Delete old RDN")) - if m.NewSuperior != "" { - request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, m.NewSuperior, "New Superior")) +func (req *ModifyDNRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.NewRDN, "New RDN")) + pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.DeleteOldRDN, "Delete old RDN")) + if req.NewSuperior != "" { + pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.NewSuperior, "New Superior")) } - return request + + envelope.AppendChild(pkt) + + return nil } // ModifyDN renames the given DN and optionally move to another base (when the "newSup" argument // to NewModifyDNRequest() is not ""). func (l *Conn) ModifyDN(m *ModifyDNRequest) error { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - packet.AppendChild(m.encode()) - - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(m) if err != nil { return err } defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return NewError(ErrorNetwork, errors.New("ldap: channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return err } - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return err - } - ber.PrintPacket(packet) - } - if packet.Children[1].Tag == ApplicationModifyDNResponse { err := GetLDAPError(packet) if err != nil { @@ -98,7 +81,5 @@ func (l *Conn) ModifyDN(m *ModifyDNRequest) error { } else { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) } - - l.Debug.Printf("%d: returning", msgCtx.id) return nil } diff --git a/modify.go b/modify.go index d83e6221..7e09b507 100644 --- a/modify.go +++ b/modify.go @@ -26,10 +26,9 @@ package ldap import ( - "errors" "log" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) // Change operation choices @@ -84,40 +83,43 @@ type ModifyRequest struct { } // Add appends the given attribute to the list of changes to be made -func (m *ModifyRequest) Add(attrType string, attrVals []string) { - m.appendChange(AddAttribute, attrType, attrVals) +func (req *ModifyRequest) Add(attrType string, attrVals []string) { + req.appendChange(AddAttribute, attrType, attrVals) } // Delete appends the given attribute to the list of changes to be made -func (m *ModifyRequest) Delete(attrType string, attrVals []string) { - m.appendChange(DeleteAttribute, attrType, attrVals) +func (req *ModifyRequest) Delete(attrType string, attrVals []string) { + req.appendChange(DeleteAttribute, attrType, attrVals) } // Replace appends the given attribute to the list of changes to be made -func (m *ModifyRequest) Replace(attrType string, attrVals []string) { - m.appendChange(ReplaceAttribute, attrType, attrVals) +func (req *ModifyRequest) Replace(attrType string, attrVals []string) { + req.appendChange(ReplaceAttribute, attrType, attrVals) } -func (m *ModifyRequest) appendChange(operation uint, attrType string, attrVals []string) { - m.Changes = append(m.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}}) +func (req *ModifyRequest) appendChange(operation uint, attrType string, attrVals []string) { + req.Changes = append(req.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}}) } -func (m ModifyRequest) encode() *ber.Packet { - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request") - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.DN, "DN")) +func (req *ModifyRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN")) changes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Changes") - for _, change := range m.Changes { + for _, change := range req.Changes { changes.AppendChild(change.encode()) } - request.AppendChild(changes) - return request + pkt.AppendChild(changes) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil } // NewModifyRequest creates a modify request for the given DN -func NewModifyRequest( - dn string, - controls []Control, -) *ModifyRequest { +func NewModifyRequest(dn string, controls []Control) *ModifyRequest { return &ModifyRequest{ DN: dn, Controls: controls, @@ -126,39 +128,17 @@ func NewModifyRequest( // Modify performs the ModifyRequest func (l *Conn) Modify(modifyRequest *ModifyRequest) error { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - packet.AppendChild(modifyRequest.encode()) - if len(modifyRequest.Controls) > 0 { - packet.AppendChild(encodeControls(modifyRequest.Controls)) - } - - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(modifyRequest) if err != nil { return err } defer l.finishMessage(msgCtx) - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return err } - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return err - } - ber.PrintPacket(packet) - } - if packet.Children[1].Tag == ApplicationModifyResponse { err := GetLDAPError(packet) if err != nil { @@ -167,7 +147,5 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error { } else { log.Printf("Unexpected Response: %d", packet.Children[1].Tag) } - - l.Debug.Printf("%d: returning", msgCtx.id) return nil } diff --git a/passwdmodify.go b/passwdmodify.go index 06bc21db..bfaceff3 100644 --- a/passwdmodify.go +++ b/passwdmodify.go @@ -6,10 +6,9 @@ package ldap import ( - "errors" "fmt" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) const ( @@ -36,25 +35,28 @@ type PasswordModifyResult struct { Referral string } -func (r *PasswordModifyRequest) encode() (*ber.Packet, error) { - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Password Modify Extended Operation") - request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, passwordModifyOID, "Extended Request Name: Password Modify OID")) +func (req *PasswordModifyRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Password Modify Extended Operation") + pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, passwordModifyOID, "Extended Request Name: Password Modify OID")) + extendedRequestValue := ber.Encode(ber.ClassContext, ber.TypePrimitive, 1, nil, "Extended Request Value: Password Modify Request") passwordModifyRequestValue := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Password Modify Request") - if r.UserIdentity != "" { - passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, r.UserIdentity, "User Identity")) + if req.UserIdentity != "" { + passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.UserIdentity, "User Identity")) } - if r.OldPassword != "" { - passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 1, r.OldPassword, "Old Password")) + if req.OldPassword != "" { + passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 1, req.OldPassword, "Old Password")) } - if r.NewPassword != "" { - passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 2, r.NewPassword, "New Password")) + if req.NewPassword != "" { + passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 2, req.NewPassword, "New Password")) } - extendedRequestValue.AppendChild(passwordModifyRequestValue) - request.AppendChild(extendedRequestValue) - return request, nil + pkt.AppendChild(extendedRequestValue) + + envelope.AppendChild(pkt) + + return nil } // NewPasswordModifyRequest creates a new PasswordModifyRequest @@ -84,46 +86,18 @@ func NewPasswordModifyRequest(userIdentity string, oldPassword string, newPasswo // PasswordModify performs the modification request func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - - encodedPasswordModifyRequest, err := passwordModifyRequest.encode() - if err != nil { - return nil, err - } - packet.AppendChild(encodedPasswordModifyRequest) - - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(passwordModifyRequest) if err != nil { return nil, err } defer l.finishMessage(msgCtx) - result := &PasswordModifyResult{} - - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + packet, err := l.readPacket(msgCtx) if err != nil { return nil, err } - if packet == nil { - return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve message")) - } - - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return nil, err - } - ber.PrintPacket(packet) - } + result := &PasswordModifyResult{} if packet.Children[1].Tag == ApplicationExtendedResponse { err := GetLDAPError(packet) diff --git a/request.go b/request.go new file mode 100644 index 00000000..814e29fe --- /dev/null +++ b/request.go @@ -0,0 +1,66 @@ +package ldap + +import ( + "errors" + + ber "gopkg.in/asn1-ber.v1" +) + +var ( + errRespChanClosed = errors.New("ldap: response channel closed") + errCouldNotRetMsg = errors.New("ldap: could not retrieve message") +) + +type request interface { + appendTo(*ber.Packet) error +} + +type requestFunc func(*ber.Packet) error + +func (f requestFunc) appendTo(p *ber.Packet) error { + return f(p) +} + +func (l *Conn) doRequest(req request) (*messageContext, error) { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + if err := req.appendTo(packet); err != nil { + return nil, err + } + + if l.Debug { + ber.PrintPacket(packet) + } + + msgCtx, err := l.sendMessage(packet) + if err != nil { + return nil, err + } + l.Debug.Printf("%d: returning", msgCtx.id) + return msgCtx, nil +} + +func (l *Conn) readPacket(msgCtx *messageContext) (*ber.Packet, error) { + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + return nil, NewError(ErrorNetwork, errRespChanClosed) + } + packet, err := packetResponse.ReadPacket() + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + return nil, err + } + + if packet == nil { + return nil, NewError(ErrorNetwork, errCouldNotRetMsg) + } + + if l.Debug { + if err = addLDAPDescriptions(packet); err != nil { + return nil, err + } + ber.PrintPacket(packet) + } + return packet, nil +} diff --git a/search.go b/search.go index 3aa6dac0..51eb7dc6 100644 --- a/search.go +++ b/search.go @@ -61,7 +61,7 @@ import ( "sort" "strings" - "gopkg.in/asn1-ber.v1" + ber "gopkg.in/asn1-ber.v1" ) // scope choices @@ -246,27 +246,33 @@ type SearchRequest struct { Controls []Control } -func (s *SearchRequest) encode() (*ber.Packet, error) { - request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request") - request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, s.BaseDN, "Base DN")) - request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(s.Scope), "Scope")) - request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(s.DerefAliases), "Deref Aliases")) - request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(s.SizeLimit), "Size Limit")) - request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(s.TimeLimit), "Time Limit")) - request.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, s.TypesOnly, "Types Only")) +func (req *SearchRequest) appendTo(envelope *ber.Packet) error { + pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request") + pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.BaseDN, "Base DN")) + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(req.Scope), "Scope")) + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(req.DerefAliases), "Deref Aliases")) + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(req.SizeLimit), "Size Limit")) + pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(req.TimeLimit), "Time Limit")) + pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.TypesOnly, "Types Only")) // compile and encode filter - filterPacket, err := CompileFilter(s.Filter) + filterPacket, err := CompileFilter(req.Filter) if err != nil { - return nil, err + return err } - request.AppendChild(filterPacket) + pkt.AppendChild(filterPacket) // encode attributes attributesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes") - for _, attribute := range s.Attributes { + for _, attribute := range req.Attributes { attributesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute")) } - request.AppendChild(attributesPacket) - return request, nil + pkt.AppendChild(attributesPacket) + + envelope.AppendChild(pkt) + if len(req.Controls) > 0 { + envelope.AppendChild(encodeControls(req.Controls)) + } + + return nil } // NewSearchRequest creates a new search request @@ -366,22 +372,7 @@ func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) // Search performs the given search request func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - // encode search request - encodedSearchRequest, err := searchRequest.encode() - if err != nil { - return nil, err - } - packet.AppendChild(encodedSearchRequest) - // encode search controls - if len(searchRequest.Controls) > 0 { - packet.AppendChild(encodeControls(searchRequest.Controls)) - } - - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) + msgCtx, err := l.doRequest(searchRequest) if err != nil { return nil, err } @@ -392,26 +383,12 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { Referrals: make([]string, 0), Controls: make([]Control, 0)} - foundSearchResultDone := false - for !foundSearchResultDone { - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - return nil, NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + for { + packet, err := l.readPacket(msgCtx) if err != nil { return nil, err } - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - return nil, err - } - ber.PrintPacket(packet) - } - switch packet.Children[1].Tag { case 4: entry := new(Entry) @@ -440,11 +417,9 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { result.Controls = append(result.Controls, decodedChild) } } - foundSearchResultDone = true + return result, nil case 19: result.Referrals = append(result.Referrals, packet.Children[1].Children[0].Value.(string)) } } - l.Debug.Printf("%d: returning", msgCtx.id) - return result, nil }