diff --git a/src/main/java/net/schmizz/sshj/connection/channel/SocketStreamCopyMonitor.java b/src/main/java/net/schmizz/sshj/connection/channel/SocketStreamCopyMonitor.java index a6c524df5..630d080dd 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/SocketStreamCopyMonitor.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/SocketStreamCopyMonitor.java @@ -32,7 +32,7 @@ private SocketStreamCopyMonitor(Runnable r) { setDaemon(true); } - private static Closeable wrapSocket(final Socket socket) { + public static Closeable wrapSocket(final Socket socket) { return new Closeable() { @Override public void close() diff --git a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java index 9ea10bf51..421951b4e 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java @@ -16,12 +16,12 @@ package net.schmizz.sshj.connection.channel.direct; import net.schmizz.concurrent.Event; +import net.schmizz.sshj.common.IOUtils; import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.StreamCopier; import net.schmizz.sshj.connection.Connection; -import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor; -import net.schmizz.sshj.transport.TransportException; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -112,11 +112,15 @@ public LocalPortForwarder(Connection conn, Parameters parameters, ServerSocket s this.serverSocket = serverSocket; } - protected DirectTCPIPChannel openChannel(Socket socket) - throws TransportException, ConnectionException { - final DirectTCPIPChannel chan = new DirectTCPIPChannel(conn, socket, parameters); - chan.open(); - return chan; + private void startChannel(Socket socket) throws IOException { + DirectTCPIPChannel chan = new DirectTCPIPChannel(conn, socket, parameters); + try { + chan.open(); + chan.start(); + } catch (IOException e) { + IOUtils.closeQuietly(chan, SocketStreamCopyMonitor.wrapSocket(socket)); + throw e; + } } /** @@ -130,7 +134,7 @@ public void listen() while (!Thread.currentThread().isInterrupted()) { final Socket socket = serverSocket.accept(); log.debug("Got connection from {}", socket.getRemoteSocketAddress()); - openChannel(socket).start(); + startChannel(socket); } log.debug("Interrupted!"); }