Skip to content

Commit

Permalink
Implement Named Pipe connections. Fixes #454
Browse files Browse the repository at this point in the history
  • Loading branch information
bgrainger committed Jun 8, 2018
1 parent d3d4994 commit f1a73d3
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 47 deletions.
16 changes: 16 additions & 0 deletions docs/content/connection-options.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@ These are the basic options that need to be defined to connect to a MySQL databa
<td></td>
<td>(Optional) The case-sensitive name of the initial database to use. This may be required if the MySQL user account only has access rights to particular databases on the server.</td>
</tr>
<tr>
<td>Protocol, ConnectionProtocol, Connection Protocol</td>
<td>Socket</td>
<td>How to connect to the MySQL Server. This option has the following values:
<ul>
<li><b>Socket</b> (default): Use TCP/IP sockets.</li>
<li><b>Unix</b>: Use a Unix socket.</li>
<li><b>Pipe</b>: Use a Windows named pipe.</li>
</ul>
</td>
</tr>
<tr>
<td>Pipe, PipeName, Pipe Name</td>
<td>MYSQL</td>
<td>The name of the Windows named pipe to use to connect to the server. You must also set <code>ConnectionProtocol=pipe</code> to used named pipes.</td>
</tr>
</table>

SSL/TLS Options
Expand Down
2 changes: 1 addition & 1 deletion src/MySqlConnector/Core/ConnectionPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ private ConnectionPool(ConnectionSettings cs)
m_hostSessions[hostName] = 0;
}

m_loadBalancer = cs.ConnectionType != ConnectionType.Tcp ? null :
m_loadBalancer = cs.ConnectionProtocol != MySqlConnectionProtocol.Sockets ? null :
cs.HostNames.Count == 1 || cs.LoadBalance == MySqlLoadBalance.FailOver ? FailOverLoadBalancer.Instance :
cs.LoadBalance == MySqlLoadBalance.Random ? RandomLoadBalancer.Instance :
cs.LoadBalance == MySqlLoadBalance.LeastConnections ? new LeastConnectionsLoadBalancer(this) :
Expand Down
23 changes: 18 additions & 5 deletions src/MySqlConnector/Core/ConnectionSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,31 @@ public ConnectionSettings(MySqlConnectionStringBuilder csb)
ConnectionStringBuilder = csb;
ConnectionString = csb.ConnectionString;

// Base Options
if (!Utility.IsWindows() && (csb.Server.StartsWith("/", StringComparison.Ordinal) || csb.Server.StartsWith("./", StringComparison.Ordinal)))
if (csb.ConnectionProtocol == MySqlConnectionProtocol.UnixSocket || (!Utility.IsWindows() && (csb.Server.StartsWith("/", StringComparison.Ordinal) || csb.Server.StartsWith("./", StringComparison.Ordinal))))
{
if (!File.Exists(csb.Server))
throw new MySqlException("Cannot find Unix Socket at " + csb.Server);
ConnectionType = ConnectionType.Unix;
ConnectionProtocol = MySqlConnectionProtocol.UnixSocket;
UnixSocket = Path.GetFullPath(csb.Server);
}
else if (csb.ConnectionProtocol == MySqlConnectionProtocol.NamedPipe)
{
ConnectionProtocol = MySqlConnectionProtocol.NamedPipe;
HostNames = (csb.Server == "." || string.Equals(csb.Server, "localhost", StringComparison.OrdinalIgnoreCase)) ? s_localhostPipeServer : new[] { csb.Server };
PipeName = csb.PipeName;
}
else if (csb.ConnectionProtocol == MySqlConnectionProtocol.SharedMemory)
{
throw new NotSupportedException("Shared Memory connections are not supported");
}
else
{
ConnectionType = ConnectionType.Tcp;
ConnectionProtocol = MySqlConnectionProtocol.Sockets;
HostNames = csb.Server.Split(',');
LoadBalance = csb.LoadBalance;
Port = (int) csb.Port;
}

UserID = csb.UserID;
Password = csb.Password;
Database = csb.Database;
Expand Down Expand Up @@ -97,10 +107,11 @@ private static MySqlGuidFormat GetEffectiveGuidFormat(MySqlGuidFormat guidFormat

// Base Options
public string ConnectionString { get; }
public ConnectionType ConnectionType { get; }
public MySqlConnectionProtocol ConnectionProtocol { get; }
public IReadOnlyList<string> HostNames { get; }
public MySqlLoadBalance LoadBalance { get; }
public int Port { get; }
public string PipeName { get; }
public string UnixSocket { get; }
public string UserID { get; }
public string Password { get; }
Expand Down Expand Up @@ -163,5 +174,7 @@ public int ConnectionTimeoutMilliseconds
return m_connectionTimeoutMilliseconds.Value;
}
}

static readonly string[] s_localhostPipeServer = { "." };
}
}
18 changes: 0 additions & 18 deletions src/MySqlConnector/Core/ConnectionType.cs

This file was deleted.

65 changes: 56 additions & 9 deletions src/MySqlConnector/Core/ServerSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
using System.Diagnostics;
using System.Globalization;
using System.IO;
#if !NETSTANDARD1_3
using System.IO.Pipes;
#endif
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
Expand Down Expand Up @@ -242,10 +245,12 @@ public async Task ConnectAsync(ConnectionSettings cs, ILoadBalancer loadBalancer
shouldRetrySsl = (sslProtocols == SslProtocols.None || (sslProtocols & SslProtocols.Tls12) == SslProtocols.Tls12) && Utility.IsWindows();

var connected = false;
if (cs.ConnectionType == ConnectionType.Tcp)
if (cs.ConnectionProtocol == MySqlConnectionProtocol.Sockets)
connected = await OpenTcpSocketAsync(cs, loadBalancer, ioBehavior, cancellationToken).ConfigureAwait(false);
else if (cs.ConnectionType == ConnectionType.Unix)
else if (cs.ConnectionProtocol == MySqlConnectionProtocol.UnixSocket)
connected = await OpenUnixSocketAsync(cs, ioBehavior, cancellationToken).ConfigureAwait(false);
else if (cs.ConnectionProtocol == MySqlConnectionProtocol.NamedPipe)
connected = await OpenNamedPipeAsync(cs, ioBehavior, cancellationToken).ConfigureAwait(false);
if (!connected)
{
lock (m_lock)
Expand All @@ -254,7 +259,7 @@ public async Task ConnectAsync(ConnectionSettings cs, ILoadBalancer loadBalancer
throw new MySqlException("Unable to connect to any of the specified MySQL hosts.");
}

var byteHandler = new SocketByteHandler(m_socket);
var byteHandler = m_socket != null ? (IByteHandler) new SocketByteHandler(m_socket) : new StreamByteHandler(m_stream);
m_payloadHandler = new StandardPayloadHandler(byteHandler);

payload = await ReceiveAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -755,7 +760,7 @@ private async Task<bool> OpenTcpSocketAsync(ConnectionSettings cs, ILoadBalancer
HostName = hostName;
m_tcpClient = tcpClient;
m_socket = m_tcpClient.Client;
m_networkStream = m_tcpClient.GetStream();
m_stream = m_tcpClient.GetStream();
m_socket.SetKeepAlive(cs.Keepalive);
lock (m_lock)
m_state = State.Connected;
Expand Down Expand Up @@ -806,7 +811,7 @@ private async Task<bool> OpenUnixSocketAsync(ConnectionSettings cs, IOBehavior i
if (socket.Connected)
{
m_socket = socket;
m_networkStream = new NetworkStream(socket);
m_stream = new NetworkStream(socket);

lock (m_lock)
m_state = State.Connected;
Expand All @@ -816,6 +821,48 @@ private async Task<bool> OpenUnixSocketAsync(ConnectionSettings cs, IOBehavior i
return false;
}

private Task<bool> OpenNamedPipeAsync(ConnectionSettings cs, IOBehavior ioBehavior, CancellationToken cancellationToken)
{
#if NETSTANDARD1_3
throw new NotSupportedException("Named pipe connections are not supported in netstandard1.3");
#else
if (Log.IsInfoEnabled())
Log.Info("Session{0} connecting to NamedPipe '{1}' on Server '{2}'", m_logArguments[0], cs.PipeName, cs.HostNames[0]);
var namedPipeStream = new NamedPipeClientStream(cs.HostNames[0], cs.PipeName, PipeDirection.InOut, PipeOptions.Asynchronous);
try
{
using (cancellationToken.Register(() => namedPipeStream.Dispose()))
{
try
{
namedPipeStream.Connect();
}
catch (ObjectDisposedException ex) when (cancellationToken.IsCancellationRequested)
{
m_logArguments[1] = cs.PipeName;
Log.Info("Session{0} connect timeout expired connecting to named pipe '{1}'", m_logArguments);
throw new MySqlException("Connect Timeout expired.", ex);
}
}
}
catch (IOException)
{
namedPipeStream.Dispose();
}

if (namedPipeStream.IsConnected)
{
m_stream = namedPipeStream;

lock (m_lock)
m_state = State.Connected;
return Task.FromResult(true);
}

return Task.FromResult(false);
#endif
}

private async Task InitSslAsync(ProtocolCapabilities serverCapabilities, ConnectionSettings cs, SslProtocols sslProtocols, IOBehavior ioBehavior, CancellationToken cancellationToken)
{
Log.Info("Session{0} initializing TLS connection", m_logArguments);
Expand Down Expand Up @@ -936,9 +983,9 @@ bool ValidateRemoteCertificate(object rcbSender, X509Certificate rcbCertificate,

SslStream sslStream;
if (clientCertificates == null)
sslStream = new SslStream(m_networkStream, false, ValidateRemoteCertificate);
sslStream = new SslStream(m_stream, false, ValidateRemoteCertificate);
else
sslStream = new SslStream(m_networkStream, false, ValidateRemoteCertificate, ValidateLocalCertificate);
sslStream = new SslStream(m_stream, false, ValidateRemoteCertificate, ValidateLocalCertificate);

var checkCertificateRevocation = cs.SslMode == MySqlSslMode.VerifyFull;

Expand Down Expand Up @@ -1056,7 +1103,7 @@ private void ShutdownSocket()
{
Log.Info("Session{0} closing stream/socket", m_logArguments);
Utility.Dispose(ref m_payloadHandler);
Utility.Dispose(ref m_networkStream);
Utility.Dispose(ref m_stream);
SafeDispose(ref m_tcpClient);
SafeDispose(ref m_socket);
#if NET45
Expand Down Expand Up @@ -1259,7 +1306,7 @@ private enum State
State m_state;
TcpClient m_tcpClient;
Socket m_socket;
NetworkStream m_networkStream;
Stream m_stream;
SslStream m_sslStream;
X509Certificate2 m_clientCertificate;
IPayloadHandler m_payloadHandler;
Expand Down
4 changes: 1 addition & 3 deletions src/MySqlConnector/MySql.Data.MySqlClient/MySqlConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,7 @@ public override string ConnectionString

public override ConnectionState State => m_connectionState;

public override string DataSource => (GetConnectionSettings().ConnectionType == ConnectionType.Tcp
? string.Join(",", GetConnectionSettings().HostNames)
: GetConnectionSettings().UnixSocket) ?? "";
public override string DataSource => GetConnectionSettings().ConnectionStringBuilder.Server;

public override string ServerVersion => m_session.ServerVersion.OriginalString;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
namespace MySql.Data.MySqlClient
{
/// <summary>
/// Specifies the type of connection to make to the server.
/// </summary>
public enum MySqlConnectionProtocol
{
/// <summary>
/// TCP/IP connection.
/// </summary>
Sockets = 1,
Socket = 1,
Tcp = 1,

/// <summary>
/// Named pipe connection. Only works on Windows.
/// </summary>
Pipe = 2,
NamedPipe = 2,

/// <summary>
/// Unix domain socket connection. Only works on Unix/Linux.
/// </summary>
UnixSocket = 3,
Unix = 3,

/// <summary>
/// Shared memory connection. Not currently supported.
/// </summary>
SharedMemory = 4,
Memory = 4
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ public MySqlLoadBalance LoadBalance
set => MySqlConnectionStringOption.LoadBalance.SetValue(this, value);
}

public MySqlConnectionProtocol ConnectionProtocol
{
get => MySqlConnectionStringOption.ConnectionProtocol.GetValue(this);
set => MySqlConnectionStringOption.ConnectionProtocol.SetValue(this, value);
}

public string PipeName
{
get => MySqlConnectionStringOption.PipeName.GetValue(this);
set => MySqlConnectionStringOption.PipeName.SetValue(this, value);
}

// SSL/TLS Options
public MySqlSslMode SslMode
{
Expand Down Expand Up @@ -289,6 +301,8 @@ internal abstract class MySqlConnectionStringOption
public static readonly MySqlConnectionStringOption<string> Password;
public static readonly MySqlConnectionStringOption<string> Database;
public static readonly MySqlConnectionStringOption<MySqlLoadBalance> LoadBalance;
public static readonly MySqlConnectionStringOption<MySqlConnectionProtocol> ConnectionProtocol;
public static readonly MySqlConnectionStringOption<string> PipeName;

// SSL/TLS Options
public static readonly MySqlConnectionStringOption<MySqlSslMode> SslMode;
Expand Down Expand Up @@ -377,6 +391,14 @@ static MySqlConnectionStringOption()
keys: new[] { "LoadBalance", "Load Balance" },
defaultValue: MySqlLoadBalance.RoundRobin));

AddOption(ConnectionProtocol = new MySqlConnectionStringOption<MySqlConnectionProtocol>(
keys: new[] { "ConnectionProtocol", "Connection Protocol", "Protocol" },
defaultValue: MySqlConnectionProtocol.Socket));

AddOption(PipeName = new MySqlConnectionStringOption<string>(
keys: new[] { "PipeName", "Pipe", "Pipe Name" },
defaultValue: "MYSQL"));

// SSL/TLS Options
AddOption(SslMode = new MySqlConnectionStringOption<MySqlSslMode>(
keys: new[] { "SSL Mode", "SslMode" },
Expand Down Expand Up @@ -538,14 +560,16 @@ private static T ChangeType(object objectValue)
return (T) (object) false;
}

if ((typeof(T) == typeof(MySqlLoadBalance) || typeof(T) == typeof(MySqlSslMode) || typeof(T) == typeof(MySqlDateTimeKind) || typeof(T) == typeof(MySqlGuidFormat)) && objectValue is string enumString)
if ((typeof(T) == typeof(MySqlLoadBalance) || typeof(T) == typeof(MySqlSslMode) || typeof(T) == typeof(MySqlDateTimeKind) || typeof(T) == typeof(MySqlGuidFormat) || typeof(T) == typeof(MySqlConnectionProtocol)) && objectValue is string enumString)
{
foreach (var val in Enum.GetValues(typeof(T)))
try
{
return (T) Enum.Parse(typeof(T), enumString, ignoreCase: true);
}
catch (Exception ex)
{
if (string.Equals(enumString, val.ToString(), StringComparison.OrdinalIgnoreCase))
return (T) val;
throw new InvalidOperationException("Value '{0}' not supported for option '{1}'.".FormatInvariant(objectValue, typeof(T).Name), ex);
}
throw new InvalidOperationException("Value '{0}' not supported for option '{1}'.".FormatInvariant(objectValue, typeof(T).Name));
}

return (T) Convert.ChangeType(objectValue, typeof(T), CultureInfo.InvariantCulture);
Expand Down
22 changes: 17 additions & 5 deletions src/MySqlConnector/Protocol/Serialization/StreamByteHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@ public StreamByteHandler(Stream stream)

public ValueTask<int> ReadBytesAsync(ArraySegment<byte> buffer, IOBehavior ioBehavior)
{
return (ioBehavior == IOBehavior.Asynchronous) ?
new ValueTask<int>(DoReadBytesAsync(buffer)) : DoReadBytesSync(buffer);
return ioBehavior == IOBehavior.Asynchronous ? new ValueTask<int>(DoReadBytesAsync(buffer)) :
RemainingTimeout <= 0 ? ValueTaskExtensions.FromException<int>(MySqlException.CreateForTimeout()) :
m_stream.CanTimeout ? DoReadBytesSync(buffer) :
DoReadBytesSyncOverAsync(buffer);

ValueTask<int> DoReadBytesSync(ArraySegment<byte> buffer_)
{
if (RemainingTimeout <= 0)
return ValueTaskExtensions.FromException<int>(MySqlException.CreateForTimeout());

m_stream.ReadTimeout = RemainingTimeout == Constants.InfiniteTimeout ? Timeout.Infinite : RemainingTimeout;
var startTime = RemainingTimeout == Constants.InfiniteTimeout ? 0 : Environment.TickCount;
int bytesRead;
Expand All @@ -48,6 +47,19 @@ ValueTask<int> DoReadBytesSync(ArraySegment<byte> buffer_)
return new ValueTask<int>(bytesRead);
}

ValueTask<int> DoReadBytesSyncOverAsync(ArraySegment<byte> buffer_)
{
try
{
// handle timeout by setting a timer to close the stream in the background
return new ValueTask<int>(DoReadBytesAsync(buffer_).GetAwaiter().GetResult());
}
catch (Exception ex)
{
return ValueTaskExtensions.FromException<int>(ex);
}
}

async Task<int> DoReadBytesAsync(ArraySegment<byte> buffer_)
{
var startTime = RemainingTimeout == Constants.InfiniteTimeout ? 0 : Environment.TickCount;
Expand Down
Loading

0 comments on commit f1a73d3

Please sign in to comment.