diff --git a/modify.go b/modify.go index f6fe3859..8b379558 100644 --- a/modify.go +++ b/modify.go @@ -135,6 +135,8 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error { type ModifyResult struct { // Controls are the returned controls Controls []Control + // Referral is the returned referral + Referral string } // ModifyWithResult performs the ModifyRequest and returns the result @@ -157,9 +159,14 @@ func (l *Conn) ModifyWithResult(modifyRequest *ModifyRequest) (*ModifyResult, er switch packet.Children[1].Tag { case ApplicationModifyResponse: - err := GetLDAPError(packet) - if err != nil { - return nil, err + if err = GetLDAPError(packet); err != nil { + if referral, referralErr := getReferral(err, packet); referralErr != nil { + return result, referralErr + } else { + result.Referral = referral + } + + return result, err } if len(packet.Children) == 3 { for _, child := range packet.Children[2].Children { diff --git a/passwdmodify.go b/passwdmodify.go index 62a11084..e776e3b3 100644 --- a/passwdmodify.go +++ b/passwdmodify.go @@ -95,15 +95,13 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa result := &PasswordModifyResult{} if packet.Children[1].Tag == ApplicationExtendedResponse { - err := GetLDAPError(packet) - if err != nil { - if IsErrorWithCode(err, LDAPResultReferral) { - for _, child := range packet.Children[1].Children { - if child.Tag == 3 { - result.Referral = child.Children[0].Value.(string) - } - } + if err = GetLDAPError(packet); err != nil { + if referral, referralErr := getReferral(err, packet); referralErr != nil { + return result, referralErr + } else { + result.Referral = referral } + return result, err } } else { @@ -112,10 +110,10 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa extendedResponse := packet.Children[1] for _, child := range extendedResponse.Children { - if child.Tag == 11 { + if child.Tag == ber.TagEmbeddedPDV { passwordModifyResponseValue := ber.DecodePacket(child.Data.Bytes()) if len(passwordModifyResponseValue.Children) == 1 { - if passwordModifyResponseValue.Children[0].Tag == 0 { + if passwordModifyResponseValue.Children[0].Tag == ber.TagEOC { result.GeneratedPassword = ber.DecodeString(passwordModifyResponseValue.Children[0].Data.Bytes()) } } diff --git a/request.go b/request.go index 4ea31e90..adc3b1c2 100644 --- a/request.go +++ b/request.go @@ -2,6 +2,7 @@ package ldap import ( "errors" + "fmt" ber "github.com/go-asn1-ber/asn1-ber" ) @@ -69,3 +70,29 @@ func (l *Conn) readPacket(msgCtx *messageContext) (*ber.Packet, error) { } return packet, nil } + +func getReferral(err error, packet *ber.Packet) (referral string, e error) { + if !IsErrorWithCode(err, LDAPResultReferral) { + return "", nil + } + + if len(packet.Children) < 2 { + return "", fmt.Errorf("ldap: returned error indicates the packet contains a referral but it doesn't have sufficient child nodes: %w", err) + } + + if packet.Children[1].Tag != ber.TagObjectDescriptor { + return "", fmt.Errorf("ldap: returned error indicates the packet contains a referral but the relevant child node isn't an object descriptor: %w", err) + } + + var ok bool + + for _, child := range packet.Children[1].Children { + if child.Tag == ber.TagBitString && len(child.Children) >= 1 { + if referral, ok = child.Children[0].Value.(string); ok { + return referral, nil + } + } + } + + return "", fmt.Errorf("ldap: returned error indicates the packet contains a referral but the referral couldn't be decoded: %w", err) +} diff --git a/v3/modify.go b/v3/modify.go index f6fe3859..8b379558 100644 --- a/v3/modify.go +++ b/v3/modify.go @@ -135,6 +135,8 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error { type ModifyResult struct { // Controls are the returned controls Controls []Control + // Referral is the returned referral + Referral string } // ModifyWithResult performs the ModifyRequest and returns the result @@ -157,9 +159,14 @@ func (l *Conn) ModifyWithResult(modifyRequest *ModifyRequest) (*ModifyResult, er switch packet.Children[1].Tag { case ApplicationModifyResponse: - err := GetLDAPError(packet) - if err != nil { - return nil, err + if err = GetLDAPError(packet); err != nil { + if referral, referralErr := getReferral(err, packet); referralErr != nil { + return result, referralErr + } else { + result.Referral = referral + } + + return result, err } if len(packet.Children) == 3 { for _, child := range packet.Children[2].Children { diff --git a/v3/passwdmodify.go b/v3/passwdmodify.go index 62a11084..e776e3b3 100644 --- a/v3/passwdmodify.go +++ b/v3/passwdmodify.go @@ -95,15 +95,13 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa result := &PasswordModifyResult{} if packet.Children[1].Tag == ApplicationExtendedResponse { - err := GetLDAPError(packet) - if err != nil { - if IsErrorWithCode(err, LDAPResultReferral) { - for _, child := range packet.Children[1].Children { - if child.Tag == 3 { - result.Referral = child.Children[0].Value.(string) - } - } + if err = GetLDAPError(packet); err != nil { + if referral, referralErr := getReferral(err, packet); referralErr != nil { + return result, referralErr + } else { + result.Referral = referral } + return result, err } } else { @@ -112,10 +110,10 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa extendedResponse := packet.Children[1] for _, child := range extendedResponse.Children { - if child.Tag == 11 { + if child.Tag == ber.TagEmbeddedPDV { passwordModifyResponseValue := ber.DecodePacket(child.Data.Bytes()) if len(passwordModifyResponseValue.Children) == 1 { - if passwordModifyResponseValue.Children[0].Tag == 0 { + if passwordModifyResponseValue.Children[0].Tag == ber.TagEOC { result.GeneratedPassword = ber.DecodeString(passwordModifyResponseValue.Children[0].Data.Bytes()) } } diff --git a/v3/request.go b/v3/request.go index 4ea31e90..adc3b1c2 100644 --- a/v3/request.go +++ b/v3/request.go @@ -2,6 +2,7 @@ package ldap import ( "errors" + "fmt" ber "github.com/go-asn1-ber/asn1-ber" ) @@ -69,3 +70,29 @@ func (l *Conn) readPacket(msgCtx *messageContext) (*ber.Packet, error) { } return packet, nil } + +func getReferral(err error, packet *ber.Packet) (referral string, e error) { + if !IsErrorWithCode(err, LDAPResultReferral) { + return "", nil + } + + if len(packet.Children) < 2 { + return "", fmt.Errorf("ldap: returned error indicates the packet contains a referral but it doesn't have sufficient child nodes: %w", err) + } + + if packet.Children[1].Tag != ber.TagObjectDescriptor { + return "", fmt.Errorf("ldap: returned error indicates the packet contains a referral but the relevant child node isn't an object descriptor: %w", err) + } + + var ok bool + + for _, child := range packet.Children[1].Children { + if child.Tag == ber.TagBitString && len(child.Children) >= 1 { + if referral, ok = child.Children[0].Value.(string); ok { + return referral, nil + } + } + } + + return "", fmt.Errorf("ldap: returned error indicates the packet contains a referral but the referral couldn't be decoded: %w", err) +}