Skip to content

Commit

Permalink
Rework auth errors handling during reconnect
Browse files Browse the repository at this point in the history
Move the detection/handling down to sendConnect when we know we
got an -ERR from the server. With the use of a boolean we know
if we should abort the reconnect attempts in the doReconnect loop.

Signed-off-by: Ivan Kozlovic <[email protected]>
  • Loading branch information
kozlovic committed Aug 7, 2019
1 parent 8f27558 commit 0b78893
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 35 deletions.
86 changes: 51 additions & 35 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ type Conn struct {
ps *parseState
ptmr *time.Timer
pout int
ar bool // abort reconnect

// New style response handler
respSub string // The wildcard subject
Expand Down Expand Up @@ -1676,6 +1677,17 @@ func (nc *Conn) sendConnect() error {
if strings.HasPrefix(proto, _ERR_OP_) {
// Remove -ERR, trim spaces and quotes, and convert to lower case.
proto = normalizeErr(proto)

// Check if this is an auth error
if authErr := checkAuthError(strings.ToLower(proto)); authErr != nil {
// This will schedule an async error if we are in reconnect,
// and keep track of the auth error for the current server.
// If we have got the same error twice, this sets nc.ar to true to
// indicate that the reconnect should be aborted (will be checked
// in doReconnect()).
nc.processAuthError(authErr)
}

return errors.New("nats: " + proto)
}

Expand Down Expand Up @@ -1849,19 +1861,11 @@ func (nc *Conn) doReconnect(err error) {

// Process connect logic
if nc.err = nc.processConnectInit(); nc.err != nil {
// If we have a lastErr recorded for this server
// do the normal processing here. We might get closed.
if nc.current.lastErr != nil {
// Remove possible "nats: " prefix
errStr := strings.TrimPrefix(nc.err.Error(), "nats: ")
nc.mu.Unlock()
nc.processErr(errStr)
nc.mu.Lock()
if nc.isClosed() {
break
}
// Check if we should abort reconnect. If so, break out
// of the loop and connection will be closed.
if nc.ar {
break
}

nc.status = RECONNECTING
// Reset the buffered writer to the pending buffer
// (was set to a buffered writer on nc.conn in createConn)
Expand Down Expand Up @@ -2265,35 +2269,24 @@ func (nc *Conn) processPermissionsViolation(err string) {
nc.mu.Unlock()
}

// processAuthorizationViolation is called when the server signals a user
// authorization violation.
func (nc *Conn) processAuthorizationViolation(err string) {
nc.processAuthError(ErrAuthorization)
}

// processAuthenticationExpired is called when the server signals a user
// authorization has expired.
func (nc *Conn) processAuthenticationExpired(err string) {
nc.processAuthError(ErrAuthExpired)
}

// processAuthError generally processing for auth errors. We want to do retries
// unless we get the same error again. This allows us for instance to swap credentials
// and have the app reconnect, but if nothing is changing we should bail.
func (nc *Conn) processAuthError(err error) {
nc.mu.Lock()
// This function will return true if the connection should be closed, false otherwise.
// Connection lock is held on entry
func (nc *Conn) processAuthError(err error) bool {
nc.err = err
if nc.Opts.AsyncErrorCB != nil {
if !nc.initc && nc.Opts.AsyncErrorCB != nil {
nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, nil, err) })
}
// We should give up if we tried twice on this server and got the
// same error.
if nc.current.lastErr == err {
defer nc.Close()
nc.ar = true
} else {
nc.current.lastErr = err
}
nc.mu.Unlock()
return nc.ar
}

// flusher is a separate Go routine that will process flush requests for the write
Expand Down Expand Up @@ -2463,6 +2456,18 @@ func (nc *Conn) LastError() error {
return err
}

// Check if the given error string is an auth error, and if so returns
// the corresponding ErrXXX error, nil otherwise
func checkAuthError(e string) error {
if strings.HasPrefix(e, AUTHORIZATION_ERR) {
return ErrAuthorization
}
if strings.HasPrefix(e, AUTHENTICATION_EXPIRED_ERR) {
return ErrAuthExpired
}
return nil
}

// processErr processes any error messages from the server and
// sets the connection's lastError.
func (nc *Conn) processErr(ie string) {
Expand All @@ -2471,19 +2476,24 @@ func (nc *Conn) processErr(ie string) {
// convert to lower case.
e := strings.ToLower(ne)

close := false

// FIXME(dlc) - process Slow Consumer signals special.
if e == STALE_CONNECTION {
nc.processOpErr(ErrStaleConnection)
} else if strings.HasPrefix(e, PERMISSIONS_ERR) {
nc.processPermissionsViolation(ne)
} else if strings.HasPrefix(e, AUTHORIZATION_ERR) {
nc.processAuthorizationViolation(ne)
} else if strings.HasPrefix(e, AUTHENTICATION_EXPIRED_ERR) {
nc.processAuthenticationExpired(ne)
} else if authErr := checkAuthError(e); authErr != nil {
nc.mu.Lock()
close = nc.processAuthError(authErr)
nc.mu.Unlock()
} else {
close = true
nc.mu.Lock()
nc.err = errors.New("nats: " + ne)
nc.mu.Unlock()
}
if close {
nc.Close()
}
}
Expand Down Expand Up @@ -3619,8 +3629,14 @@ func (nc *Conn) close(status Status, doCBs bool, err error) {
nc.stopPingTimer()
nc.ptmr = nil

// Go ahead and make sure we have flushed the outbound
if nc.conn != nil {
// Need to close and set tcp conn to nil if reconnect loop has stopped,
// otherwise we would incorrectly invoke Disconnect handler (if set)
// down below.
if nc.ar && nc.conn != nil {
nc.conn.Close()
nc.conn = nil
} else if nc.conn != nil {
// Go ahead and make sure we have flushed the outbound
nc.bw.Flush()
defer nc.conn.Close()
}
Expand Down
61 changes: 61 additions & 0 deletions nats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2023,3 +2023,64 @@ func BenchmarkNextMsgNoTimeout(b *testing.B) {
}
}
}

func TestAuthErrorOnReconnect(t *testing.T) {
// This is a bit of an artificial test, but it is to demonstrate
// that if the client is disconnected from a server (not due to an auth error),
// it will still correctly stop the reconnection logic if it gets twice an
// auth error from the same server.

o1 := natsserver.DefaultTestOptions
o1.Port = -1
s1 := RunServerWithOptions(&o1)
defer s1.Shutdown()

o2 := natsserver.DefaultTestOptions
o2.Port = -1
o2.Username = "ivan"
o2.Password = "pwd"
s2 := RunServerWithOptions(&o2)
defer s2.Shutdown()

dch := make(chan bool)
cch := make(chan bool)

urls := fmt.Sprintf("nats://%s:%d, nats://%s:%d", o1.Host, o1.Port, o2.Host, o2.Port)
nc, err := Connect(urls,
ReconnectWait(25*time.Millisecond),
MaxReconnects(-1),
DontRandomize(),
DisconnectErrHandler(func(_ *Conn, e error) {
dch <- true
}),
ClosedHandler(func(_ *Conn) {
cch <- true
}))
if err != nil {
t.Fatalf("Expected to connect, got err: %v\n", err)
}
defer nc.Close()

s1.Shutdown()

// wait for disconnect
if e := WaitTime(dch, 5*time.Second); e != nil {
t.Fatal("Did not receive a disconnect callback message")
}

// Wait for ClosedCB
if e := WaitTime(cch, 5*time.Second); e != nil {
reconnects := nc.Stats().Reconnects
t.Fatalf("Did not receive a closed callback message, #reconnects: %v", reconnects)
}

// We should have stopped after 2 reconnects.
if reconnects := nc.Stats().Reconnects; reconnects != 2 {
t.Fatalf("Expected 2 reconnects, got %v", reconnects)
}

// Expect connection to be closed...
if !nc.IsClosed() {
t.Fatalf("Wrong status: %d\n", nc.Status())
}
}

0 comments on commit 0b78893

Please sign in to comment.