Skip to content

Commit

Permalink
Support multi method authentication. (#234)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds authored Oct 4, 2024
1 parent f9ae42a commit cfa7caf
Show file tree
Hide file tree
Showing 10 changed files with 256 additions and 82 deletions.
11 changes: 11 additions & 0 deletions src/Tmds.Ssh/AuthResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// This file is part of Tmds.Ssh which is released under MIT.
// See file LICENSE for full license details.

namespace Tmds.Ssh;

enum AuthResult
{
Failure,
Success,
Partial
}
1 change: 1 addition & 0 deletions src/Tmds.Ssh/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ static class Constants
public const int MaxKeyLength = 1024; // Arbitrary limit, may be increased.
public const int MaxMPIntLength = 1024; // Arbitrary limit, may be increased.
public const int MaxBannerPackets = 1024; // Abitrary limit
public const int MaxPartialAuths = 16; // Abitrary limit
}
6 changes: 6 additions & 0 deletions src/Tmds.Ssh/SshClientLogger.cs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,12 @@ public static void PacketSend(this ILogger<SshClient> logger, ReadOnlyPacket pac
Message = "Connection aborted")]
public static partial void ConnectionAborted(this ILogger<SshClient> logger, Exception exception);

[LoggerMessage(
EventId = 29,
Level = LogLevel.Information,
Message = "Auth {AuthMethod} method partial success, continue with: {AllowedMethods}")]
public static partial void PartialSuccessAuth(this ILogger<SshClient> logger, Name authMethod, Name[] allowedMethods);

struct PacketPayload // TODO: implement ISpanFormattable
{
private ReadOnlyPacket _packet;
Expand Down
84 changes: 68 additions & 16 deletions src/Tmds.Ssh/SshSession.Authentication.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,46 +26,98 @@ private async Task AuthenticateAsync(SshConnection connection, CancellationToken

UserAuthContext context = new UserAuthContext(connection, _settings.UserName, _settings.PublicKeyAcceptedAlgorithms, _settings.MinimumRSAKeySize, Logger);

bool authSuccess;

int partialAuthAttempts = 0;
// Try credentials.
foreach (var credential in _settings.CredentialsOrDefault)
List<Credential> credentials = new(_settings.CredentialsOrDefault);
for (int i = 0; i < credentials.Count; i++)
{
Credential credential = credentials[i];

AuthResult authResult = AuthResult.Failure;
bool? methodAccepted;
Name method;
if (credential is PasswordCredential passwordCredential)
{
authSuccess = await PasswordAuth.TryAuthenticate(passwordCredential, context, ConnectionInfo, Logger, ct).ConfigureAwait(false);
if (TryMethod(AlgorithmNames.Password))
{
authResult = await PasswordAuth.TryAuthenticate(passwordCredential, context, ConnectionInfo, Logger, ct).ConfigureAwait(false);
}
}
else if (credential is PrivateKeyCredential keyCredential)
{
authSuccess = await PublicKeyAuth.TryAuthenticate(keyCredential, context, ConnectionInfo, Logger, ct).ConfigureAwait(false);
if (TryMethod(AlgorithmNames.PublicKey))
{
authResult = await PublicKeyAuth.TryAuthenticate(keyCredential, context, ConnectionInfo, Logger, ct).ConfigureAwait(false);
}
}
else if (credential is KerberosCredential kerberosCredential)
{
authSuccess = await GssApiAuth.TryAuthenticate(kerberosCredential, context, ConnectionInfo, Logger, ct).ConfigureAwait(false);
if (TryMethod(AlgorithmNames.GssApiWithMic))
{
authResult = await GssApiAuth.TryAuthenticate(kerberosCredential, context, ConnectionInfo, Logger, ct).ConfigureAwait(false);
}
}
else
{
throw new NotImplementedException("Unsupported credential type: " + credential.GetType().FullName);
}

if (authSuccess)
// We didn't try the method, skip to the next credential.
if (methodAccepted == false)
{
continue;
}

if (authResult == AuthResult.Success)
{
return;
}

if (authResult == AuthResult.Failure)
{
// If we didn't know if the method was accepted before, check the context which was updated by SSH_MSG_USERAUTH_FAILURE.
if (methodAccepted == null)
{
methodAccepted = context.IsMethodAccepted(method);
}
// Don't try a failed credential again if it matched an accepted method.
if (methodAccepted == true || methodAccepted == null)
{
credentials.RemoveAt(i);
i--;
}
}

if (authResult == AuthResult.Partial)
{
// The same method may be requested multiple times with partial auth.
// The client should be providing different credentials when the same method is tried again.
// Move the current credential to the back of the list so we don't try it as the first one once more.
credentials.RemoveAt(i);
credentials.Add(credential);

// Start over (but limit the amount of times we want to start over to avoid an infinite loop).
if (++partialAuthAttempts == Constants.MaxPartialAuths)
{
break;
}
else
{
i = -1;
}
}

bool TryMethod(Name credentialMethod)
{
method = credentialMethod;
methodAccepted = context.IsMethodAccepted(method);
return methodAccepted != false;
}
}

throw new ConnectFailedException(ConnectFailedReason.AuthenticationFailed, "Authentication failed.", ConnectionInfo);
}

private static Name GetAuthenticationMethod(Credential credential)
=> credential switch
{
PasswordCredential => AlgorithmNames.Password,
PrivateKeyCredential => AlgorithmNames.PublicKey,
KerberosCredential => AlgorithmNames.GssApiWithMic,
_ => throw new NotImplementedException("Unsupported credential type: " + credential.GetType().FullName)
};

private static Packet CreateServiceRequestMessage(SequencePool sequencePool)
{
using var packet = sequencePool.RentPacket();
Expand Down
51 changes: 26 additions & 25 deletions src/Tmds.Ssh/UserAuthContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// See file LICENSE for full license details.

using Microsoft.Extensions.Logging;
using System.Diagnostics;

namespace Tmds.Ssh;

Expand All @@ -11,6 +12,7 @@ sealed class UserAuthContext
private readonly ILogger<SshClient> _logger;
private int _bannerPacketCount = 0;
private Name[]? _allowedAuthentications;
private bool _wasPartial;
private Name _currentMethod;

public UserAuthContext(SshConnection connection, string userName, List<Name> publicKeyAcceptedAlgorithms, int minimumRSAKeySize, ILogger<SshClient> logger)
Expand Down Expand Up @@ -92,50 +94,47 @@ public ValueTask SendPacketAsync(Packet packet, CancellationToken ct)
return _connection.SendPacketAsync(packet, ct);
}

public async Task<bool> ReceiveAuthIsSuccesfullAsync(CancellationToken ct)
public async Task<AuthResult> ReceiveAuthResultAsync(CancellationToken ct)
{
using Packet response = await ReceivePacketAsync(ct).ConfigureAwait(false);
bool isSuccess = IsAuthSuccesfull(response);
return isSuccess;
return GetAuthResult(response);
}

private bool IsAuthSuccesfull(ReadOnlyPacket packet)
private AuthResult GetAuthResult(ReadOnlyPacket packet)
{
var reader = packet.GetReader();
MessageId b = reader.ReadMessageId();
switch (b)
{
case MessageId.SSH_MSG_USERAUTH_SUCCESS:
return true;
return AuthResult.Success;
case MessageId.SSH_MSG_USERAUTH_FAILURE:
return false;
return _wasPartial ? AuthResult.Partial : AuthResult.Failure;
default:
ThrowHelper.ThrowProtocolUnexpectedValue();
return false;
return AuthResult.Failure;
}
}

public bool TryStartAuth(Name method)
public bool? IsMethodAccepted(Name method)
{
if (!_currentMethod.IsEmpty)
if (_allowedAuthentications == null)
{
throw new InvalidOperationException("Already authenticating using an other method.");
return null;
}
return Array.IndexOf(_allowedAuthentications, method) >= 0;
}

bool shouldStart = !SkipMethod(method);

if (shouldStart)
public void StartAuth(Name method)
{
if (!_currentMethod.IsEmpty)
{
_currentMethod = method;
throw new InvalidOperationException("Already authenticating using an other method.");
}

return shouldStart;
}
Debug.Assert(IsMethodAccepted(method) != false);

public bool SkipMethod(Name method)
{
bool isAllowed = _allowedAuthentications == null ? true : Array.IndexOf(_allowedAuthentications, method) >= 0;
return !isAllowed;
_currentMethod = method;
}

private void ParseAuthFail(ReadOnlyPacket packet)
Expand All @@ -148,13 +147,15 @@ boolean partial success
*/
reader.ReadMessageId(MessageId.SSH_MSG_USERAUTH_FAILURE);
_allowedAuthentications = reader.ReadNameList();
bool partial_success = reader.ReadBoolean();
_wasPartial = reader.ReadBoolean();

_logger.AuthMethodFailed(_currentMethod, _allowedAuthentications);

if (partial_success)
if (_wasPartial)
{
_logger.PartialSuccessAuth(_currentMethod, _allowedAuthentications);
}
else
{
throw new NotImplementedException("Partial success auth is not implemented.");
_logger.AuthMethodFailed(_currentMethod, _allowedAuthentications);
}
}
}
15 changes: 6 additions & 9 deletions src/Tmds.Ssh/UserAuthentication.GssApiAuth.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,9 @@ public sealed class GssApiAuth
// Kerberos - 1.2.840.113554.1.2.2 - This is DER encoding of the OID.
private static readonly byte[] KRB5_OID = [0x06, 0x09, 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x12, 0x01, 0x02, 0x02];

public static async Task<bool> TryAuthenticate(KerberosCredential credential, UserAuthContext context, SshConnectionInfo connectionInfo, ILogger<SshClient> logger, CancellationToken ct)
public static async Task<AuthResult> TryAuthenticate(KerberosCredential credential, UserAuthContext context, SshConnectionInfo connectionInfo, ILogger<SshClient> logger, CancellationToken ct)
{
if (!context.TryStartAuth(AlgorithmNames.GssApiWithMic))
{
return false;
}
context.StartAuth(AlgorithmNames.GssApiWithMic);

// RFC uses hostbased SPN format "service@host" but Windows SSPI needs the service/host format.
// .NET converts this format to the hostbased format expected by GSSAPI for us.
Expand All @@ -37,7 +34,7 @@ public static async Task<bool> TryAuthenticate(KerberosCredential credential, Us
bool isOidSuccess = await TryStageOid(context, logger, context.UserName, ct).ConfigureAwait(false);
if (!isOidSuccess)
{
return false;
return AuthResult.Failure;
}

var negotiateOptions = new NegotiateAuthenticationClientOptions()
Expand All @@ -61,7 +58,7 @@ public static async Task<bool> TryAuthenticate(KerberosCredential credential, Us
bool isAuthSuccess = await TryStageAuthentication(context, logger, authContext, ct).ConfigureAwait(false);
if (!isAuthSuccess)
{
return false;
return AuthResult.Failure;
}

try
Expand All @@ -75,10 +72,10 @@ public static async Task<bool> TryAuthenticate(KerberosCredential credential, Us
catch (MissingMethodException)
{
// Remove once .NET 8 is no longer the minimum and we get rid of reflection.
return false;
return AuthResult.Failure;
}

return await context.ReceiveAuthIsSuccesfullAsync(ct).ConfigureAwait(false);
return await context.ReceiveAuthResultAsync(ct).ConfigureAwait(false);
}

private static async Task<bool> TryStageOid(UserAuthContext context, ILogger logger, string userName, CancellationToken ct)
Expand Down
15 changes: 4 additions & 11 deletions src/Tmds.Ssh/UserAuthentication.PasswordAuth.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@ partial class UserAuthentication
// https://datatracker.ietf.org/doc/html/rfc4252 - Password Authentication Method: "password"
public sealed class PasswordAuth
{
public static async Task<bool> TryAuthenticate(PasswordCredential passwordCredential, UserAuthContext context, SshConnectionInfo connectionInfo, ILogger<SshClient> logger, CancellationToken ct)
public static async Task<AuthResult> TryAuthenticate(PasswordCredential passwordCredential, UserAuthContext context, SshConnectionInfo connectionInfo, ILogger<SshClient> logger, CancellationToken ct)
{
string? password = passwordCredential.GetPassword();
if (password is not null)
{
if (!context.TryStartAuth(AlgorithmNames.Password))
{
return false;
}
context.StartAuth(AlgorithmNames.Password);

logger.PasswordAuth();

Expand All @@ -27,14 +24,10 @@ public static async Task<bool> TryAuthenticate(PasswordCredential passwordCreden
await context.SendPacketAsync(userAuthMsg.Move(), ct).ConfigureAwait(false);
}

bool success = await context.ReceiveAuthIsSuccesfullAsync(ct).ConfigureAwait(false);
if (success)
{
return true;
}
return await context.ReceiveAuthResultAsync(ct).ConfigureAwait(false);
}

return false;
return AuthResult.Failure;
}

private static Packet CreatePasswordRequestMessage(SequencePool sequencePool, string userName, string password)
Expand Down
Loading

0 comments on commit cfa7caf

Please sign in to comment.