diff --git a/README.md b/README.md index 302c315..0f29e04 100644 --- a/README.md +++ b/README.md @@ -127,8 +127,8 @@ class LocalForward : IDisposable class SftpClient : IDisposable { // Note: umask is applied on the server. - const UnixFilePermissions DefaultCreateDirectoryPermissions; // = '-rw-rw-rw-'. - const UnixFilePermissions DefaultCreateFilePermissions; // = '-rwxrwxrwx'. + const UnixFilePermissions DefaultCreateDirectoryPermissions; // = '-rwxrwxrwx'. + const UnixFilePermissions DefaultCreateFilePermissions; // = '-rw-rw-rw-'. // The SftpClient owns the connection. SftpClient(string destination, ILoggerFactory? loggerFactory = null, SftpClientOptions? options = null); @@ -181,7 +181,9 @@ class SftpClient : IDisposable IAsyncEnumerable GetDirectoryEntriesAsync(string path, SftpFileEntryTransform transform, EnumerationOptions? options = null); ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, CancellationToken cancellationToken); - ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, bool overwrite = false, UnixFilePermissions? createPermissions, CancellationToken cancellationToken = default); + ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, bool overwrite = false, UnixFilePermissions? createPermissions = null, CancellationToken cancellationToken = default); + ValueTask UploadFileAsync(Stream source, string remoteFilePath, CancellationToken cancellationToken); + ValueTask UploadFileAsync(Stream source, string remoteFilePath, bool overwrite = false, UnixFilePermissions createPermissions = DefaultCreateFilePermissions, CancellationToken cancellationToken = default); ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteDirPath, CancellationToken cancellationToken = default); ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteDirPath, UploadEntriesOptions? options, CancellationToken cancellationToken = default); diff --git a/src/Tmds.Ssh/SftpChannel.cs b/src/Tmds.Ssh/SftpChannel.cs index d83ea38..d5e5669 100644 --- a/src/Tmds.Ssh/SftpChannel.cs +++ b/src/Tmds.Ssh/SftpChannel.cs @@ -287,7 +287,7 @@ async ValueTask CopyBuffer(ValueTask previousCopy, long offset, int length) } buffer = ArrayPool.Shared.Rent(length); - bytesRead = await sourceFile.ReadAtAsync(buffer, sourceFile.Position + offset, cancellationToken).ConfigureAwait(false); + bytesRead = await sourceFile.ReadAtAsync(buffer.AsMemory(0, length), sourceFile.Position + offset, cancellationToken).ConfigureAwait(false); if (bytesRead == 0) { break; @@ -782,78 +782,137 @@ private static UnixFilePermissions GetPermissionsForFile(SafeFileHandle fileHand public async ValueTask UploadFileAsync(string localPath, string remotePath, long? length, bool overwrite, UnixFilePermissions? permissions, CancellationToken cancellationToken) { - using SafeFileHandle localFile = File.OpenHandle(localPath, FileMode.Open, FileAccess.Read, FileShare.Read); + using FileStream localFile = new FileStream(localPath, FileMode.Open, FileAccess.Read, FileShare.Read, bufferSize: 0); - permissions ??= GetPermissionsForFile(localFile); + permissions ??= GetPermissionsForFile(localFile.SafeFileHandle); - using SftpFile remoteFile = (await OpenFileCoreAsync(remotePath, (overwrite ? SftpOpenFlags.OpenOrCreate : SftpOpenFlags.CreateNew) | SftpOpenFlags.Write, permissions.Value, SftpClient.DefaultFileOpenOptions, cancellationToken).ConfigureAwait(false))!; + await UploadFileAsync(localFile, remotePath, length, overwrite, permissions.Value, cancellationToken).ConfigureAwait(false); + } - length ??= RandomAccess.GetLength(localFile); + public async ValueTask UploadFileAsync(Stream source, string remotePath, long? length, bool overwrite, UnixFilePermissions permissions, CancellationToken cancellationToken) + { + using SftpFile remoteFile = (await OpenFileCoreAsync(remotePath, (overwrite ? SftpOpenFlags.OpenOrCreate : SftpOpenFlags.CreateNew) | SftpOpenFlags.Write, permissions, SftpClient.DefaultFileOpenOptions, cancellationToken).ConfigureAwait(false))!; - ValueTask previous = default; + // Pipeline the writes when the source is a sync, seekable Stream. + bool pipelineSyncWrites = source.CanSeek && IsSyncStream(source); - CancellationTokenSource? breakLoop = length > 0 ? new() : null; + if (!pipelineSyncWrites) + { + await source.CopyToAsync(remoteFile, GetMaxWritePayload(remoteFile.Handle)).ConfigureAwait(false); - for (long offset = 0; offset < length; offset += GetMaxWritePayload(remoteFile.Handle)) + await remoteFile.CloseAsync(cancellationToken).ConfigureAwait(false); + } + else { - Debug.Assert(breakLoop is not null); - if (!breakLoop.IsCancellationRequested) + length ??= source.Length; + if (length == 0) { - await s_uploadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); - previous = CopyBuffer(previous, offset, GetMaxWritePayload(remoteFile.Handle)); + return; } - } - await previous.ConfigureAwait(false); + ValueTask previous = default; + long startOffset = source.Position; + long bytesSuccesfullyWritten = 0; + CancellationTokenSource breakLoop = new(); + int maxWritePayload = GetMaxWritePayload(remoteFile.Handle); - await remoteFile.CloseAsync(cancellationToken).ConfigureAwait(false); + for (long offset = 0; offset < length; offset += maxWritePayload) + { + if (!breakLoop.IsCancellationRequested) + { + await s_uploadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + int copyLength = (int)Math.Min((long)maxWritePayload, length.Value - offset); + previous = CopyBuffer(previous, offset, copyLength); + } + } - async ValueTask CopyBuffer(ValueTask previousCopy, long offset, int length) - { + bool ignorePositionUpdateException = false; try { - byte[]? buffer = null; + await previous.ConfigureAwait(false); + + await remoteFile.CloseAsync(cancellationToken).ConfigureAwait(false); + } + catch + { + ignorePositionUpdateException = true; + + throw; + } + finally + { + // Set the position to what was succesfully written. try { - if (breakLoop.IsCancellationRequested) + source.Position = startOffset + bytesSuccesfullyWritten; + } + catch when (ignorePositionUpdateException) + { } + } + + async ValueTask CopyBuffer(ValueTask previousCopy, long offset, int length) + { + try + { + byte[]? buffer = null; + try { - return; - } + if (breakLoop.IsCancellationRequested) + { + return; + } - buffer = ArrayPool.Shared.Rent(length); - do + buffer = ArrayPool.Shared.Rent(length); + int remaining = length; + long readOffset = startOffset + offset; + do + { + int bytesRead; + lock (breakLoop) // Ensure only one thread is reading the Stream concurrently. + { + source.Position = readOffset; + bytesRead = source.Read(buffer.AsSpan(length - remaining, remaining)); + } + if (bytesRead == 0) + { + throw new IOException("Unexpected end of file. The source was truncated during the upload."); + } + remaining -= bytesRead; + readOffset += bytesRead; + } while (remaining > 0); + + await remoteFile.WriteAtAsync(buffer.AsMemory(0, length), offset, cancellationToken).ConfigureAwait(false); + } + catch { - int bytesRead = RandomAccess.Read(localFile, buffer.AsSpan(0, length), offset); - if (bytesRead == 0) + length = 0; // Assume nothing was written succesfully. + breakLoop.Cancel(); + throw; + } + finally + { + if (buffer != null) { - break; + ArrayPool.Shared.Return(buffer); } - await remoteFile.WriteAtAsync(buffer.AsMemory(0, bytesRead), offset, cancellationToken).ConfigureAwait(false); - length -= bytesRead; - offset += bytesRead; - } while (length > 0); - } - catch - { - breakLoop.Cancel(); - throw; + s_uploadBufferSemaphore.Release(); + } } finally { - if (buffer != null) - { - ArrayPool.Shared.Return(buffer); - } - s_uploadBufferSemaphore.Release(); + await previousCopy.ConfigureAwait(false); + + // Update with our length after the previous write completed succesfully. + bytesSuccesfullyWritten += length; } } - finally - { - await previousCopy.ConfigureAwait(false); - } } } + // Consider it okay to do sync operation on these types of streams. + private static bool IsSyncStream(Stream stream) + => stream is MemoryStream or FileStream; + private IAsyncEnumerable GetDirectoryEntriesAsync(string path, SftpFileEntryTransform transform, EnumerationOptions options) => new SftpFileSystemEnumerable(this, path, transform, options); @@ -1087,7 +1146,7 @@ private async ValueTask DownloadFileAsync(string remotePath, string? localPath, Debug.Assert(destination is not null); - bool writeSync = destination is FileStream or MemoryStream; + bool writeSync = IsSyncStream(destination); ValueTask previous = default; CancellationTokenSource? breakLoop = length > 0 ? new() : null; diff --git a/src/Tmds.Ssh/SftpClient.cs b/src/Tmds.Ssh/SftpClient.cs index d61fb8c..8bec47c 100644 --- a/src/Tmds.Ssh/SftpClient.cs +++ b/src/Tmds.Ssh/SftpClient.cs @@ -97,7 +97,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) internal async ValueTask OpenAsync(CancellationToken cancellationToken) { - await GetChannelAsync(cancellationToken); + await GetChannelAsync(cancellationToken).ConfigureAwait(false); } internal ValueTask GetChannelAsync(CancellationToken cancellationToken, bool explicitConnect = false) @@ -330,7 +330,7 @@ public ValueTask CreateDirectoryAsync(string path, CancellationToken cancellatio public async ValueTask CreateDirectoryAsync(string path, bool createParents = false, UnixFilePermissions permissions = DefaultCreateDirectoryPermissions, CancellationToken cancellationToken = default) { var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false); - await channel.CreateDirectoryAsync(path, createParents, permissions, cancellationToken); + await channel.CreateDirectoryAsync(path, createParents, permissions, cancellationToken).ConfigureAwait(false); } public ValueTask CreateNewDirectoryAsync(string path, CancellationToken cancellationToken) @@ -339,7 +339,7 @@ public ValueTask CreateNewDirectoryAsync(string path, CancellationToken cancella public async ValueTask CreateNewDirectoryAsync(string path, bool createParents = false, UnixFilePermissions permissions = DefaultCreateDirectoryPermissions, CancellationToken cancellationToken = default) { var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false); - await channel.CreateNewDirectoryAsync(path, createParents, permissions, cancellationToken); + await channel.CreateNewDirectoryAsync(path, createParents, permissions, cancellationToken).ConfigureAwait(false); } public ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteDirPath, CancellationToken cancellationToken = default) @@ -348,7 +348,7 @@ public ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteD public async ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteDirPath, UploadEntriesOptions? options, CancellationToken cancellationToken = default) { var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false); - await channel.UploadDirectoryEntriesAsync(localDirPath, remoteDirPath, options, cancellationToken); + await channel.UploadDirectoryEntriesAsync(localDirPath, remoteDirPath, options, cancellationToken).ConfigureAwait(false); } public ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, CancellationToken cancellationToken) @@ -357,7 +357,16 @@ public ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, Ca public async ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, bool overwrite = false, UnixFilePermissions? createPermissions = default, CancellationToken cancellationToken = default) { var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false); - await channel.UploadFileAsync(localFilePath, remoteFilePath, length: null, overwrite, createPermissions, cancellationToken); + await channel.UploadFileAsync(localFilePath, remoteFilePath, length: null, overwrite, createPermissions, cancellationToken).ConfigureAwait(false); + } + + public ValueTask UploadFileAsync(Stream source, string remoteFilePath, CancellationToken cancellationToken) + => UploadFileAsync(source, remoteFilePath, overwrite: false, createPermissions: DefaultCreateFilePermissions, cancellationToken); + + public async ValueTask UploadFileAsync(Stream source, string remoteFilePath, bool overwrite = false, UnixFilePermissions createPermissions = DefaultCreateFilePermissions, CancellationToken cancellationToken = default) + { + var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false); + await channel.UploadFileAsync(source, remoteFilePath, length: null, overwrite, createPermissions, cancellationToken).ConfigureAwait(false); } public ValueTask DownloadDirectoryEntriesAsync(string remoteDirPath, string localDirPath, CancellationToken cancellationToken = default) @@ -366,7 +375,7 @@ public ValueTask DownloadDirectoryEntriesAsync(string remoteDirPath, string loca public async ValueTask DownloadDirectoryEntriesAsync(string remoteDirPath, string localDirPath, DownloadEntriesOptions? options, CancellationToken cancellationToken = default) { var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false); - await channel.DownloadDirectoryEntriesAsync(remoteDirPath, localDirPath, options, cancellationToken); + await channel.DownloadDirectoryEntriesAsync(remoteDirPath, localDirPath, options, cancellationToken).ConfigureAwait(false); } public ValueTask DownloadFileAsync(string remoteFilePath, string localFilePath, CancellationToken cancellationToken) @@ -375,13 +384,13 @@ public ValueTask DownloadFileAsync(string remoteFilePath, string localFilePath, public async ValueTask DownloadFileAsync(string remoteFilePath, string localFilePath, bool overwrite = false, CancellationToken cancellationToken = default) { var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false); - await channel.DownloadFileAsync(remoteFilePath, localFilePath, overwrite, cancellationToken); + await channel.DownloadFileAsync(remoteFilePath, localFilePath, overwrite, cancellationToken).ConfigureAwait(false); } public async ValueTask DownloadFileAsync(string remoteFilePath, Stream destination, CancellationToken cancellationToken = default) { var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false); - await channel.DownloadFileAsync(remoteFilePath, destination, cancellationToken); + await channel.DownloadFileAsync(remoteFilePath, destination, cancellationToken).ConfigureAwait(false); } private ObjectDisposedException NewObjectDisposedException() diff --git a/test/Tmds.Ssh.Tests/SftpClientTests.cs b/test/Tmds.Ssh.Tests/SftpClientTests.cs index ee3ed09..783236f 100644 --- a/test/Tmds.Ssh.Tests/SftpClientTests.cs +++ b/test/Tmds.Ssh.Tests/SftpClientTests.cs @@ -708,15 +708,44 @@ public async Task UploadDownloadFile(int fileSize) } } - [Fact] - public async Task DownloadFileToStream() + [InlineData(0)] + [InlineData(10)] + [InlineData(10 * MultiPacketSize)] // Ensure some pipelined writing. + [Theory] + public async Task UploadDownloadFileWithStream(int size) { using var sftpClient = await _sshServer.CreateSftpClientAsync(); - var (sourceFileName, sourceData) = await CreateRemoteFileWithRandomDataAsync(sftpClient, length: 100); + + byte[] sourceData = new byte[size]; + Random.Shared.NextBytes(sourceData); + MemoryStream uploadStream = new MemoryStream(sourceData); + + string remotePath = $"/tmp/{Path.GetRandomFileName()}"; + await sftpClient.UploadFileAsync(uploadStream, remotePath); + Assert.Equal(sourceData.Length, uploadStream.Position); await using var downloadStream = new MemoryStream(); - await sftpClient.DownloadFileAsync(sourceFileName, downloadStream); + await sftpClient.DownloadFileAsync(remotePath, downloadStream); + Assert.Equal(sourceData, downloadStream.ToArray()); + } + + [InlineData(0)] + [InlineData(10)] + [InlineData(10 * MultiPacketSize)] + [Theory] + public async Task UploadDownloadFileWithAsyncStream(int size) + { + using var sftpClient = await _sshServer.CreateSftpClientAsync(); + + byte[] sourceData = new byte[size]; + Random.Shared.NextBytes(sourceData); + Stream uploadStream = new NonSeekableAsyncStream(sourceData); + + string remotePath = $"/tmp/{Path.GetRandomFileName()}"; + await sftpClient.UploadFileAsync(uploadStream, remotePath); + await using var downloadStream = new NonSeekableAsyncStream(); + await sftpClient.DownloadFileAsync(remotePath, downloadStream); Assert.Equal(sourceData, downloadStream.ToArray()); } @@ -1219,4 +1248,61 @@ public async Task AutoReconnect(bool autoReconnect) await Assert.ThrowsAsync(() => client.GetFullPathAsync("").AsTask()); } } + + sealed class NonSeekableAsyncStream : Stream + { + private readonly MemoryStream _innerStream = new(); + + public NonSeekableAsyncStream() + { + _innerStream = new(); + } + + public NonSeekableAsyncStream(byte[] data) + { + _innerStream = new(data); + } + + public byte[] ToArray() + => _innerStream.ToArray(); + + public override bool CanRead => true; + + public override bool CanSeek => false; + + public override bool CanWrite => true; + + public override long Length => throw new NotImplementedException(); + + public override long Position + { + get => throw new NotImplementedException(); + set => throw new NotImplementedException(); + } + + public override void Flush() + { + throw new NotImplementedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return _innerStream.Read(buffer, offset, count); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + _innerStream.Write(buffer, offset, count); + } + } }