Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SftpClient: support uploading file from Stream. #271

Merged
merged 4 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ class LocalForward : IDisposable
class SftpClient : IDisposable
{
// Note: umask is applied on the server.
const UnixFilePermissions DefaultCreateDirectoryPermissions; // = '-rw-rw-rw-'.
const UnixFilePermissions DefaultCreateFilePermissions; // = '-rwxrwxrwx'.
const UnixFilePermissions DefaultCreateDirectoryPermissions; // = '-rwxrwxrwx'.
const UnixFilePermissions DefaultCreateFilePermissions; // = '-rw-rw-rw-'.

// The SftpClient owns the connection.
SftpClient(string destination, ILoggerFactory? loggerFactory = null, SftpClientOptions? options = null);
Expand Down Expand Up @@ -181,7 +181,9 @@ class SftpClient : IDisposable
IAsyncEnumerable<T> GetDirectoryEntriesAsync<T>(string path, SftpFileEntryTransform<T> transform, EnumerationOptions? options = null);

ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, CancellationToken cancellationToken);
ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, bool overwrite = false, UnixFilePermissions? createPermissions, CancellationToken cancellationToken = default);
ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, bool overwrite = false, UnixFilePermissions? createPermissions = null, CancellationToken cancellationToken = default);
ValueTask UploadFileAsync(Stream source, string remoteFilePath, CancellationToken cancellationToken);
ValueTask UploadFileAsync(Stream source, string remoteFilePath, bool overwrite = false, UnixFilePermissions createPermissions = DefaultCreateFilePermissions, CancellationToken cancellationToken = default);
ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteDirPath, CancellationToken cancellationToken = default);
ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteDirPath, UploadEntriesOptions? options, CancellationToken cancellationToken = default);

Expand Down
149 changes: 104 additions & 45 deletions src/Tmds.Ssh/SftpChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ async ValueTask CopyBuffer(ValueTask previousCopy, long offset, int length)
}

buffer = ArrayPool<byte>.Shared.Rent(length);
bytesRead = await sourceFile.ReadAtAsync(buffer, sourceFile.Position + offset, cancellationToken).ConfigureAwait(false);
bytesRead = await sourceFile.ReadAtAsync(buffer.AsMemory(0, length), sourceFile.Position + offset, cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
{
break;
Expand Down Expand Up @@ -782,78 +782,137 @@ private static UnixFilePermissions GetPermissionsForFile(SafeFileHandle fileHand

public async ValueTask UploadFileAsync(string localPath, string remotePath, long? length, bool overwrite, UnixFilePermissions? permissions, CancellationToken cancellationToken)
{
using SafeFileHandle localFile = File.OpenHandle(localPath, FileMode.Open, FileAccess.Read, FileShare.Read);
using FileStream localFile = new FileStream(localPath, FileMode.Open, FileAccess.Read, FileShare.Read, bufferSize: 0);

permissions ??= GetPermissionsForFile(localFile);
permissions ??= GetPermissionsForFile(localFile.SafeFileHandle);

using SftpFile remoteFile = (await OpenFileCoreAsync(remotePath, (overwrite ? SftpOpenFlags.OpenOrCreate : SftpOpenFlags.CreateNew) | SftpOpenFlags.Write, permissions.Value, SftpClient.DefaultFileOpenOptions, cancellationToken).ConfigureAwait(false))!;
await UploadFileAsync(localFile, remotePath, length, overwrite, permissions.Value, cancellationToken).ConfigureAwait(false);
}

length ??= RandomAccess.GetLength(localFile);
public async ValueTask UploadFileAsync(Stream source, string remotePath, long? length, bool overwrite, UnixFilePermissions permissions, CancellationToken cancellationToken)
{
using SftpFile remoteFile = (await OpenFileCoreAsync(remotePath, (overwrite ? SftpOpenFlags.OpenOrCreate : SftpOpenFlags.CreateNew) | SftpOpenFlags.Write, permissions, SftpClient.DefaultFileOpenOptions, cancellationToken).ConfigureAwait(false))!;

ValueTask previous = default;
// Pipeline the writes when the source is a sync, seekable Stream.
bool pipelineSyncWrites = source.CanSeek && IsSyncStream(source);

CancellationTokenSource? breakLoop = length > 0 ? new() : null;
if (!pipelineSyncWrites)
{
await source.CopyToAsync(remoteFile, GetMaxWritePayload(remoteFile.Handle)).ConfigureAwait(false);

for (long offset = 0; offset < length; offset += GetMaxWritePayload(remoteFile.Handle))
await remoteFile.CloseAsync(cancellationToken).ConfigureAwait(false);
}
else
{
Debug.Assert(breakLoop is not null);
if (!breakLoop.IsCancellationRequested)
length ??= source.Length;
if (length == 0)
{
await s_uploadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
previous = CopyBuffer(previous, offset, GetMaxWritePayload(remoteFile.Handle));
return;
}
}

await previous.ConfigureAwait(false);
ValueTask previous = default;
long startOffset = source.Position;
long bytesSuccesfullyWritten = 0;
CancellationTokenSource breakLoop = new();
int maxWritePayload = GetMaxWritePayload(remoteFile.Handle);

await remoteFile.CloseAsync(cancellationToken).ConfigureAwait(false);
for (long offset = 0; offset < length; offset += maxWritePayload)
{
if (!breakLoop.IsCancellationRequested)
{
await s_uploadBufferSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
int copyLength = (int)Math.Min((long)maxWritePayload, length.Value - offset);
previous = CopyBuffer(previous, offset, copyLength);
}
}

async ValueTask CopyBuffer(ValueTask previousCopy, long offset, int length)
{
bool ignorePositionUpdateException = false;
try
{
byte[]? buffer = null;
await previous.ConfigureAwait(false);

await remoteFile.CloseAsync(cancellationToken).ConfigureAwait(false);
}
catch
{
ignorePositionUpdateException = true;

throw;
}
finally
{
// Set the position to what was succesfully written.
try
{
if (breakLoop.IsCancellationRequested)
source.Position = startOffset + bytesSuccesfullyWritten;
}
catch when (ignorePositionUpdateException)
{ }
}

async ValueTask CopyBuffer(ValueTask previousCopy, long offset, int length)
{
try
{
byte[]? buffer = null;
try
{
return;
}
if (breakLoop.IsCancellationRequested)
{
return;
}

buffer = ArrayPool<byte>.Shared.Rent(length);
do
buffer = ArrayPool<byte>.Shared.Rent(length);
int remaining = length;
long readOffset = startOffset + offset;
do
{
int bytesRead;
lock (breakLoop) // Ensure only one thread is reading the Stream concurrently.
{
source.Position = readOffset;
bytesRead = source.Read(buffer.AsSpan(length - remaining, remaining));
}
if (bytesRead == 0)
{
throw new IOException("Unexpected end of file. The source was truncated during the upload.");
}
remaining -= bytesRead;
readOffset += bytesRead;
} while (remaining > 0);

await remoteFile.WriteAtAsync(buffer.AsMemory(0, length), offset, cancellationToken).ConfigureAwait(false);
}
catch
{
int bytesRead = RandomAccess.Read(localFile, buffer.AsSpan(0, length), offset);
if (bytesRead == 0)
length = 0; // Assume nothing was written succesfully.
breakLoop.Cancel();
throw;
}
finally
{
if (buffer != null)
{
break;
ArrayPool<byte>.Shared.Return(buffer);
}
await remoteFile.WriteAtAsync(buffer.AsMemory(0, bytesRead), offset, cancellationToken).ConfigureAwait(false);
length -= bytesRead;
offset += bytesRead;
} while (length > 0);
}
catch
{
breakLoop.Cancel();
throw;
s_uploadBufferSemaphore.Release();
}
}
finally
{
if (buffer != null)
{
ArrayPool<byte>.Shared.Return(buffer);
}
s_uploadBufferSemaphore.Release();
await previousCopy.ConfigureAwait(false);

// Update with our length after the previous write completed succesfully.
bytesSuccesfullyWritten += length;
}
}
finally
{
await previousCopy.ConfigureAwait(false);
}
}
}

// Consider it okay to do sync operation on these types of streams.
private static bool IsSyncStream(Stream stream)
=> stream is MemoryStream or FileStream;

private IAsyncEnumerable<T> GetDirectoryEntriesAsync<T>(string path, SftpFileEntryTransform<T> transform, EnumerationOptions options)
=> new SftpFileSystemEnumerable<T>(this, path, transform, options);

Expand Down Expand Up @@ -1087,7 +1146,7 @@ private async ValueTask DownloadFileAsync(string remotePath, string? localPath,

Debug.Assert(destination is not null);

bool writeSync = destination is FileStream or MemoryStream;
bool writeSync = IsSyncStream(destination);

ValueTask previous = default;
CancellationTokenSource? breakLoop = length > 0 ? new() : null;
Expand Down
25 changes: 17 additions & 8 deletions src/Tmds.Ssh/SftpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)

internal async ValueTask OpenAsync(CancellationToken cancellationToken)
{
await GetChannelAsync(cancellationToken);
await GetChannelAsync(cancellationToken).ConfigureAwait(false);
}

internal ValueTask<SftpChannel> GetChannelAsync(CancellationToken cancellationToken, bool explicitConnect = false)
Expand Down Expand Up @@ -330,7 +330,7 @@ public ValueTask CreateDirectoryAsync(string path, CancellationToken cancellatio
public async ValueTask CreateDirectoryAsync(string path, bool createParents = false, UnixFilePermissions permissions = DefaultCreateDirectoryPermissions, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.CreateDirectoryAsync(path, createParents, permissions, cancellationToken);
await channel.CreateDirectoryAsync(path, createParents, permissions, cancellationToken).ConfigureAwait(false);
}

public ValueTask CreateNewDirectoryAsync(string path, CancellationToken cancellationToken)
Expand All @@ -339,7 +339,7 @@ public ValueTask CreateNewDirectoryAsync(string path, CancellationToken cancella
public async ValueTask CreateNewDirectoryAsync(string path, bool createParents = false, UnixFilePermissions permissions = DefaultCreateDirectoryPermissions, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.CreateNewDirectoryAsync(path, createParents, permissions, cancellationToken);
await channel.CreateNewDirectoryAsync(path, createParents, permissions, cancellationToken).ConfigureAwait(false);
}

public ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteDirPath, CancellationToken cancellationToken = default)
Expand All @@ -348,7 +348,7 @@ public ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteD
public async ValueTask UploadDirectoryEntriesAsync(string localDirPath, string remoteDirPath, UploadEntriesOptions? options, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.UploadDirectoryEntriesAsync(localDirPath, remoteDirPath, options, cancellationToken);
await channel.UploadDirectoryEntriesAsync(localDirPath, remoteDirPath, options, cancellationToken).ConfigureAwait(false);
}

public ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, CancellationToken cancellationToken)
Expand All @@ -357,7 +357,16 @@ public ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, Ca
public async ValueTask UploadFileAsync(string localFilePath, string remoteFilePath, bool overwrite = false, UnixFilePermissions? createPermissions = default, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.UploadFileAsync(localFilePath, remoteFilePath, length: null, overwrite, createPermissions, cancellationToken);
await channel.UploadFileAsync(localFilePath, remoteFilePath, length: null, overwrite, createPermissions, cancellationToken).ConfigureAwait(false);
}

public ValueTask UploadFileAsync(Stream source, string remoteFilePath, CancellationToken cancellationToken)
=> UploadFileAsync(source, remoteFilePath, overwrite: false, createPermissions: DefaultCreateFilePermissions, cancellationToken);

public async ValueTask UploadFileAsync(Stream source, string remoteFilePath, bool overwrite = false, UnixFilePermissions createPermissions = DefaultCreateFilePermissions, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.UploadFileAsync(source, remoteFilePath, length: null, overwrite, createPermissions, cancellationToken).ConfigureAwait(false);
}

public ValueTask DownloadDirectoryEntriesAsync(string remoteDirPath, string localDirPath, CancellationToken cancellationToken = default)
Expand All @@ -366,7 +375,7 @@ public ValueTask DownloadDirectoryEntriesAsync(string remoteDirPath, string loca
public async ValueTask DownloadDirectoryEntriesAsync(string remoteDirPath, string localDirPath, DownloadEntriesOptions? options, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.DownloadDirectoryEntriesAsync(remoteDirPath, localDirPath, options, cancellationToken);
await channel.DownloadDirectoryEntriesAsync(remoteDirPath, localDirPath, options, cancellationToken).ConfigureAwait(false);
}

public ValueTask DownloadFileAsync(string remoteFilePath, string localFilePath, CancellationToken cancellationToken)
Expand All @@ -375,13 +384,13 @@ public ValueTask DownloadFileAsync(string remoteFilePath, string localFilePath,
public async ValueTask DownloadFileAsync(string remoteFilePath, string localFilePath, bool overwrite = false, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.DownloadFileAsync(remoteFilePath, localFilePath, overwrite, cancellationToken);
await channel.DownloadFileAsync(remoteFilePath, localFilePath, overwrite, cancellationToken).ConfigureAwait(false);
}

public async ValueTask DownloadFileAsync(string remoteFilePath, Stream destination, CancellationToken cancellationToken = default)
{
var channel = await GetChannelAsync(cancellationToken).ConfigureAwait(false);
await channel.DownloadFileAsync(remoteFilePath, destination, cancellationToken);
await channel.DownloadFileAsync(remoteFilePath, destination, cancellationToken).ConfigureAwait(false);
}

private ObjectDisposedException NewObjectDisposedException()
Expand Down
Loading