diff --git a/src/NATS.Client/Connection.cs b/src/NATS.Client/Connection.cs index 8d531ffa0..70884bf3c 100644 --- a/src/NATS.Client/Connection.cs +++ b/src/NATS.Client/Connection.cs @@ -1177,7 +1177,7 @@ internal bool connect(Srv s, out Exception exToThrow) { exToThrow = null; - NATSConnectionException natsAuthEx = null; + String lastAuthExMessage = null; for(var i = 0; i < 6; i++) //Precaution to not end up in server returning ExTypeA, ExTypeB, ExTypeA etc. { @@ -1196,16 +1196,18 @@ internal bool connect(Srv s, out Exception exToThrow) } catch (NATSConnectionException ex) { - if (!ex.IsAuthenticationOrAuthorizationError()) + string message = ex.Message.ToLower(); + if (!NATSException.IsAuthenticationOrAuthorizationError(message, true)) { throw; } ScheduleErrorEvent(s, ex); - if (natsAuthEx == null || !natsAuthEx.Message.Equals(ex.Message, StringComparison.OrdinalIgnoreCase)) + // avoiding double the same + if (lastAuthExMessage == null || !lastAuthExMessage.Equals(message)) { - natsAuthEx = ex; + lastAuthExMessage = message; continue; } @@ -2448,9 +2450,6 @@ public Exception LastError // sets the connection's lastError. internal void processErr(MemoryStream errorStream) { - bool invokeDelegates = false; - Exception ex = null; - string s = getNormalizedError(errorStream); if (IC.STALE_CONNECTION.Equals(s)) @@ -2466,7 +2465,9 @@ internal void processErr(MemoryStream errorStream) } else { - ex = new NATSException("Error from processErr(): " + s); + NATSException ex = new NATSException("Error from processErr(): " + s); + bool invokeDelegates = false; + lock (mu) { lastEx = ex; @@ -2478,6 +2479,11 @@ internal void processErr(MemoryStream errorStream) } close(ConnState.CLOSED, invokeDelegates, ex); + + if (NATSException.IsAuthenticationOrAuthorizationError(s)) + { + processReconnect(); + } } } diff --git a/src/NATS.Client/Exceptions.cs b/src/NATS.Client/Exceptions.cs index b7f22f0a3..4077c44fc 100644 --- a/src/NATS.Client/Exceptions.cs +++ b/src/NATS.Client/Exceptions.cs @@ -24,6 +24,14 @@ public class NATSException : Exception public NATSException() : base() { } public NATSException(string err) : base (err) {} public NATSException(string err, Exception innerEx) : base(err, innerEx) { } + + public static bool IsAuthenticationOrAuthorizationError(string message, bool alreadyLowered = false) + { + string lowerMessage = alreadyLowered ? message : message.ToLower(); + return lowerMessage.Contains("user authentication") + || lowerMessage.Contains("authorization violation") + || lowerMessage.Contains("authentication expired"); + } } /// @@ -33,14 +41,6 @@ public class NATSConnectionException : NATSException { public NATSConnectionException(string err) : base(err) { } public NATSConnectionException(string err, Exception innerEx) : base(err, innerEx) { } - - public bool IsAuthenticationOrAuthorizationError() - { - string lowerMessage = Message.ToLower(); - return lowerMessage.Contains("user authentication") - || lowerMessage.Contains("authorization violation") - || lowerMessage.Contains("authentication expired"); - } } /// diff --git a/src/Tests/IntegrationTests/TestAuthorization.cs b/src/Tests/IntegrationTests/TestAuthorization.cs index 5856c7211..069691fa7 100644 --- a/src/Tests/IntegrationTests/TestAuthorization.cs +++ b/src/Tests/IntegrationTests/TestAuthorization.cs @@ -18,7 +18,9 @@ using System.Threading; using NATS.Client; using NATS.Client.Internals; +using UnitTests; using Xunit; +using Xunit.Abstractions; namespace IntegrationTests { @@ -27,7 +29,15 @@ namespace IntegrationTests /// public class TestAuthorization : TestSuite { - public TestAuthorization(AuthorizationSuiteContext context) : base(context) {} + private readonly ITestOutputHelper output; + + public TestAuthorization(ITestOutputHelper output, AuthorizationSuiteContext context) : base(context) + { + this.output = output; + Console.SetOut(new TestBase.ConsoleWriter(output)); + } + + // public TestAuthorization(AuthorizationSuiteContext context) : base(context) {} int hitDisconnect; @@ -261,23 +271,34 @@ public void TestRealUserAuthenticationExpired() string credsFile = Path.GetTempFileName(); File.WriteAllText(credsFile, cred); - CountdownEvent userAuthenticationExpired = new CountdownEvent(1); + CountdownEvent userAuthenticationExpiredCde = new CountdownEvent(1); + CountdownEvent reconnectCde = new CountdownEvent(1); using (NATSServer.CreateWithConfig(Context.Server3.Port, "operatorJnatsTest.conf")) { var opts = Context.GetTestOptionsWithDefaultTimeout(Context.Server3.Port); opts.SetUserCredentials(credsFile); + opts.MaxReconnect = 1; opts.DisconnectedEventHandler += (sender, e) => { if (e.Error.ToString().Contains("user authentication expired")) { - userAuthenticationExpired.Signal(); + userAuthenticationExpiredCde.Signal(); + } + }; + opts.ReconnectedEventHandler += (sender, e) => + { + if (userAuthenticationExpiredCde.IsSet) + { + reconnectCde.Signal(); } }; - IConnection c = Context.ConnectionFactory.CreateConnection(opts); - userAuthenticationExpired.Wait(wait); - Assert.True(userAuthenticationExpired.IsSet); + IConnection c = Context.ConnectionFactory.CreateConnection(opts, true); + userAuthenticationExpiredCde.Wait(wait); + Assert.True(userAuthenticationExpiredCde.IsSet); + reconnectCde.Wait(wait); + Assert.True(reconnectCde.IsSet); } }