Skip to content

Commit

Permalink
RemoteProcess: Write EOF when StandardInputStream gets Closed/Dispose…
Browse files Browse the repository at this point in the history
…d. (#262)
  • Loading branch information
tmds authored Dec 7, 2024
1 parent f8e6526 commit ad9c9f4
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 10 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ class RemoteProcess : IDisposable
ValueTask WriteAsync(string value, CancellationToken cancellationToken = default);
ValueTask WriteLineAsync(ReadOnlyMemory<char> buffer, CancellationToken cancellationToken = default);
ValueTask WriteLineAsync(string? value, CancellationToken cancellationToken = default);
Stream StandardInputStream { get; }
void WriteEof();
Stream StandardInputStream { get; } // Disposing/Closing the Stream calls WriteEof.
StreamWriter StandardInputWriter { get; }

// Wait for the remote process to exit.
Expand Down
2 changes: 1 addition & 1 deletion src/Tmds.Ssh/ISshChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ interface ISshChannel
CancellationToken cancellationToken = default);

ValueTask WriteAsync(ReadOnlyMemory<byte> data, CancellationToken cancellationToken = default);
void WriteEof();
void WriteEof(bool noThrow = false);

Exception CreateCloseException();
}
18 changes: 18 additions & 0 deletions src/Tmds.Ssh/RemoteProcess.cs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,16 @@ private enum ReadMode

private bool HasExited { get => _readMode == ReadMode.Exited; } // delays exit until it was read by the user.

private void WriteEof(bool noThrow)
{
_channel.WriteEof(noThrow);
}

public void WriteEof()
{
WriteEof(noThrow: false);
}

public ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
ThrowIfDisposed();
Expand Down Expand Up @@ -640,6 +650,14 @@ public override Task FlushAsync(CancellationToken cancellationToken = default)
{
throw new IOException($"Unable to transport data: {ex.Message}.", ex);
}

}

public override void Close()
{
// The base Stream class calls Close for implementing Dispose.
// We mustn't throw to avoid throwing on Dispose.
_process.WriteEof(noThrow: true);
}
}
}
20 changes: 13 additions & 7 deletions src/Tmds.Ssh/SshChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,20 @@ string data
}
}

public void WriteEof()
public void WriteEof(bool noThrow)
{
ThrowIfDisposed();
ThrowIfAborted();
ThrowIfEofSent();

_eofSent = true;
TrySendEofMessage();
if (!noThrow)
{
ThrowIfDisposed();
ThrowIfAborted();
ThrowIfEofSent();
}

if (!_eofSent)
{
_eofSent = true;
TrySendEofMessage();
}
}

private void ThrowIfEofSent()
Expand Down
41 changes: 40 additions & 1 deletion test/Tmds.Ssh.Tests/RemoteProcessTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public async Task WriteAndRead(WriteApi writeApi)

byte[] buffer = new byte[512];
(bool isError, int bytesRead) = await process.ReadAsync(buffer, buffer);
Assert.False(false);
Assert.False(isError);
Assert.Equal(helloWorldBytes, buffer.AsSpan(0, bytesRead).ToArray());
}

Expand Down Expand Up @@ -421,4 +421,43 @@ public static IEnumerable<object[]> NewlineTestData
}
}
}

public enum EofApi
{
WriteEof,
StandardInputStreamClose,
StandardInputStreamDispose,
}

[Theory]
[InlineData(EofApi.WriteEof)]
[InlineData(EofApi.StandardInputStreamClose)]
[InlineData(EofApi.StandardInputStreamDispose)]
public async Task WriteEof(EofApi eofApi)
{
using var client = await _sshServer.CreateClientAsync();
using var process = await client.ExecuteAsync("cat");

switch (eofApi)
{
case EofApi.WriteEof:
process.WriteEof();
break;
case EofApi.StandardInputStreamClose:
process.StandardInputStream.Close();
break;
case EofApi.StandardInputStreamDispose:
process.StandardInputStream.Dispose();
break;
}

// Verify that Disposing the Stream after sending EOF does NOT throw.
Stream s = process.StandardInputStream;
s.Dispose();

byte[] buffer = new byte[512];
(bool isError, int bytesRead) = await process.ReadAsync(buffer, buffer);
Assert.False(isError);
Assert.Equal(0, bytesRead);
}
}

0 comments on commit ad9c9f4

Please sign in to comment.