Skip to content

Commit

Permalink
chore: implement the port-forwarding correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
mdelapenya committed Apr 23, 2024
1 parent c2c2866 commit 13a3c6b
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 252 deletions.
230 changes: 80 additions & 150 deletions port_forwarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,13 @@ package testcontainers

import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"time"

"github.com/docker/docker/api/types/container"
"github.com/google/uuid"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"

"github.com/testcontainers/testcontainers-go/internal/core/network"
)
Expand All @@ -28,7 +23,7 @@ const (
)

// sshPassword is a random password generated for the SSHD container.
var sshPassword = uuid.NewString()
var sshPassword = "123456"

// exposeHostPorts performs all the necessary steps to expose the host ports to the container, leveraging
// the SSHD container to create the tunnel, and the container lifecycle hooks to manage the tunnel lifecycle.
Expand Down Expand Up @@ -128,25 +123,11 @@ func exposeHostPorts(ctx context.Context, req *ContainerRequest, p ...int) (Cont
}

// after the container is ready, create the SSH tunnel
// for each exposed port from the host. We are going to
// use an error group to expose all ports in parallel,
// and return an error if any of them fails.
// for each exposed port from the host.
sshdConnectHook = ContainerLifecycleHooks{
PostReadies: []ContainerHook{
func(ctx context.Context, c Container) error {
var errs []error

for _, exposedHostPort := range req.HostAccessPorts {
err := sshdContainer.exposeHostPort(ctx, exposedHostPort)
if err != nil {
errs = append(errs, err)
}
}

if len(errs) > 0 {
return fmt.Errorf("failed to expose host ports: %w", errors.Join(errs...))
}

sshdContainer.exposeHostPort(ctx, req.HostAccessPorts...)
return nil
},
},
Expand Down Expand Up @@ -195,45 +176,42 @@ func newSshdContainer(ctx context.Context, opts ...ContainerCustomizer) (*sshdCo

sshd := &sshdContainer{
DockerContainer: dc,
tunnels: make(map[int]*sshTunnel),
portForwarders: []PortForwarder{},
}

err = sshd.configureSSHSession(ctx)
sshClientConfig, err := configureSSHConfig(ctx, sshd)
if err != nil {
// return the container and the error to the caller to handle it
return sshd, err
}

sshd.sshConfig = sshClientConfig

return sshd, nil
}

// sshdContainer represents the SSHD container type used for the port forwarding container.
// It's an internal type that extends the DockerContainer type, to add the SSH tunneling capabilities.
type sshdContainer struct {
*DockerContainer
port string
sshConfig *ssh.ClientConfig
tunnels map[int]*sshTunnel
port string
sshConfig *ssh.ClientConfig
portForwarders []PortForwarder
}

// Terminate stops the container and closes the SSH session
func (sshdC *sshdContainer) Terminate(ctx context.Context) error {
for _, t := range sshdC.tunnels {
defer t.Close()
for _, pfw := range sshdC.portForwarders {
pfw.Close(ctx)
}

return sshdC.DockerContainer.Terminate(ctx)
}

func (sshdC *sshdContainer) configureSSHSession(ctx context.Context) error {
if sshdC.sshConfig != nil {
// do not configure the SSH session twice
return nil
}

func configureSSHConfig(ctx context.Context, sshdC *sshdContainer) (*ssh.ClientConfig, error) {
mappedPort, err := sshdC.MappedPort(ctx, sshPort)
if err != nil {
return err
return nil, err
}
sshdC.port = mappedPort.Port()

Expand All @@ -244,152 +222,104 @@ func (sshdC *sshdContainer) configureSSHSession(ctx context.Context) error {
Timeout: 30 * time.Second,
}

sshdC.sshConfig = &sshConfig

return nil
return &sshConfig, nil
}

func (sshdC *sshdContainer) exposeHostPort(ctx context.Context, port int) error {
if _, ok := sshdC.tunnels[port]; ok {
// do not expose the same port twice
return nil
}
func (sshdC *sshdContainer) exposeHostPort(ctx context.Context, ports ...int) error {
for _, port := range ports {
pw := NewPortForwarder(fmt.Sprintf("localhost:%s", sshdC.port), sshdC.sshConfig, port, port)
sshdC.portForwarders = append(sshdC.portForwarders, *pw)

// Setup the tunnel, but do not yet start it yet.
tunnel := newSSHTunnel(
fmt.Sprintf("%s@localhost:%s", user, sshdC.port),
sshdC.sshConfig,
"localhost", port, // The destination host and port of the actual server.
)

// use testcontainers logger
tunnel.Log = Logger

err := tunnel.Start(ctx)
if err != nil {
return fmt.Errorf("failed to start the SSH tunnel: %w", err)
go pw.Forward(ctx)
}

sshdC.tunnels[port] = tunnel

return nil
}

type sshEndpoint struct {
Host string
Port int
User string
}

func newSshEndpoint(s string) *sshEndpoint {
endpoint := &sshEndpoint{
Host: s,
}
if parts := strings.Split(endpoint.Host, "@"); len(parts) > 1 {
endpoint.User = parts[0]
endpoint.Host = parts[1]
}
if parts := strings.Split(endpoint.Host, ":"); len(parts) > 1 {
endpoint.Host = parts[0]
endpoint.Port, _ = strconv.Atoi(parts[1])
// continue when all port forwarders have created the connection
for _, pfw := range sshdC.portForwarders {
<-pfw.connectionCreated
}
return endpoint
}

func (endpoint *sshEndpoint) String() string {
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
return nil
}

type sshTunnel struct {
Local *sshEndpoint
Server *sshEndpoint
Remote *sshEndpoint
Config *ssh.ClientConfig
Log Logging
type PortForwarder struct {
sshDAddr string
sshConfig *ssh.ClientConfig
remotePort int
localPort int
connectionCreated chan bool // used to signal that the connection has been created, so the caller can proceed
}

func newSSHTunnel(tunnel string, sshConfig *ssh.ClientConfig, target string, targetPort int) *sshTunnel {
destination := fmt.Sprintf("%s:%d", target, targetPort)

localEndpoint := newSshEndpoint(destination)

server := newSshEndpoint(tunnel)
if server.Port == 0 {
server.Port = 22
}

return &sshTunnel{
Config: sshConfig,
Local: localEndpoint,
Server: server,
Remote: newSshEndpoint(destination),
func NewPortForwarder(sshDAddr string, sshConfig *ssh.ClientConfig, remotePort, localPort int) *PortForwarder {
return &PortForwarder{
sshDAddr: sshDAddr,
sshConfig: sshConfig,
remotePort: remotePort,
localPort: localPort,
connectionCreated: make(chan bool),
}
}

func (tunnel *sshTunnel) logf(fmt string, args ...interface{}) {
if tunnel.Log != nil {
tunnel.Log.Printf(fmt, args...)
}
func (pf *PortForwarder) Close(ctx context.Context) {
close(pf.connectionCreated)
}

func (tunnel *sshTunnel) Close() error {
if tunnel.Log != nil {
tunnel.logf("closing tunnel")
func (pf *PortForwarder) Forward(ctx context.Context) error {
client, err := ssh.Dial("tcp", pf.sshDAddr, pf.sshConfig)
if err != nil {
return fmt.Errorf("error dialing ssh server: %w", err)
}
defer client.Close()

return nil
}

func (tunnel *sshTunnel) Start(ctx context.Context) error {
lcfg := net.ListenConfig{}

listener, err := lcfg.Listen(ctx, "tcp", tunnel.Local.String())
listener, err := client.Listen("tcp", fmt.Sprintf("localhost:%d", pf.remotePort))
if err != nil {
return err
return fmt.Errorf("error listening on remote port: %w", err)
}
defer listener.Close()

tunnel.Local.Port = listener.Addr().(*net.TCPAddr).Port
// signal that the connection has been created
pf.connectionCreated <- true

// close the listener and the client when the context is done
go func() {
<-ctx.Done()
listener.Close()
client.Close()
}()

for {
conn, err := listener.Accept()
remote, err := listener.Accept()
if err != nil {
return err
return fmt.Errorf("error accepting connection: %w", err)
}
tunnel.logf("accepted connection")
tunnel.forward(conn)
}
}

func (tunnel *sshTunnel) forward(localConn net.Conn) error {
serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
if err != nil {
return fmt.Errorf("server dial error: %s", err)
go runTunnel(ctx, remote, pf.localPort)
}
tunnel.logf("connected to %s (1 of 2)\n", tunnel.Server.String())
}

remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
// runTunnel runs a tunnel between two connections; as soon as one connection
// reaches EOF or reports an error, both connections are closed and this
// function returns.
func runTunnel(ctx context.Context, remote net.Conn, port int) {
var dialer net.Dialer
local, err := dialer.DialContext(ctx, "tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
return fmt.Errorf("remote dial error: %s", err)
remote.Close()
return
}
tunnel.logf("connected to %s (2 of 2)\n", tunnel.Remote.String())
defer local.Close()

copyConn := func(writer, reader net.Conn) error {
_, err := io.Copy(writer, reader)
if err != nil {
return fmt.Errorf("io.Copy error: %s", err)
}

return nil
}
defer remote.Close()
done := make(chan struct{}, 2)

errgr := errgroup.Group{}
go func() {
io.Copy(local, remote)
done <- struct{}{}
}()

errgr.Go(func() error {
return copyConn(localConn, remoteConn)
})
errgr.Go(func() error {
return copyConn(remoteConn, localConn)
})
go func() {
io.Copy(remote, local)
done <- struct{}{}
}()

return errgr.Wait()
<-done
}
Loading

0 comments on commit 13a3c6b

Please sign in to comment.