Skip to content

Commit

Permalink
Update TestForward (gravitational#15321)
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger authored Aug 11, 2022
1 parent 66a428b commit b19836a
Showing 1 changed file with 43 additions and 67 deletions.
110 changes: 43 additions & 67 deletions lib/sshutils/x11/forward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package x11

import (
"context"
"io"
"net"
"testing"

Expand All @@ -26,85 +25,62 @@ import (
func TestForward(t *testing.T) {
ctx := context.Background()

// Create a fake client display. In practice, the display
// set in $DISPLAY is used to connect to the client display.
fakeClientDisplay, err := net.Listen("tcp", ":0")
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, fakeClientDisplay.Close())
})
// Open a dual sided connection on a new tcp listener
openConn := func(t *testing.T) (clientConn net.Conn, serverConn net.Conn) {
l, err := net.Listen("tcp", ":0")
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, l.Close()) })

// Handle connections to the client XServer
echoMsg := "msg"
go func() {
for {
localConn, err := fakeClientDisplay.Accept()
serverErrC := make(chan error)
serverConnC := make(chan net.Conn)
go func() {
serverConn, err := l.Accept()
if err != nil {
// listener is closed, test is done.
return
serverErrC <- err
close(serverConnC)
}
serverConnC <- serverConn
close(serverErrC)
}()

go func() {
defer localConn.Close()

// read request and expect what was written to server
bytes, err := io.ReadAll(localConn)
require.NoError(t, err)
require.Equal(t, echoMsg, string(bytes))
}()
}
}()

// Create a fake XServer proxy just like the one in sshserver.
sl, serverDisplay, err := OpenNewXServerListener(DefaultDisplayOffset, DefaultMaxDisplays, 0)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, sl.Close())
})
clientConn, err = net.Dial("tcp", l.Addr().String())
require.NoError(t, err)
t.Cleanup(func() { clientConn.Close() })

// Handle connection to XServer proxy
go func() {
for {
serverConn, err := sl.Accept()
if err != nil {
// listener is closed, test is done.
return
}
serverConn = <-serverConnC
require.NoError(t, <-serverErrC)
t.Cleanup(func() { serverConn.Close() })

go func() {
defer serverConn.Close()
return clientConn, serverConn
}

clientConn, err := net.Dial("tcp", fakeClientDisplay.Addr().String())
if err != nil {
// fakeClientDisplay is closed, test is done.
return
}
cConn1, sConn1 := openConn(t)
cConn2, sConn2 := openConn(t)

clientXConn, ok := clientConn.(*net.TCPConn)
require.True(t, ok)
defer clientConn.Close()
// Start forwarding between connections so that we get
// this flow: cConn1 -> sConn1 -> cConn2 -> sConn2.
serverConnToForward, ok := sConn1.(*net.TCPConn)
require.True(t, ok)
clientConnToForward, ok := cConn2.(*net.TCPConn)
require.True(t, ok)

err = Forward(ctx, clientXConn, serverConn)
require.NoError(t, err)
}()
}
forwardErrC := make(chan error, 1)
go func() {
forwardErrC <- Forward(ctx, serverConnToForward, clientConnToForward)
}()

// Create a fake XServer request to the XServer proxy
xreq, err := serverDisplay.Dial()
require.NoError(t, err)
_, err = xreq.Write([]byte(echoMsg))
// Write a msg to client connection 1, which should propagate to server connection 2.
message := "msg"
_, err := cConn1.Write([]byte(message))
require.NoError(t, err)

// Create a second request simultaneously
xreq2, err := serverDisplay.Dial()
require.NoError(t, err)
_, err = xreq2.Write([]byte(echoMsg))
buf := make([]byte, len(message))
_, err = sConn2.Read(buf)
require.NoError(t, err)
xreq2.Close()
require.Equal(t, message, string(buf))

// Close XServer requests, forwarding should stop as soon as
// the open connection has been read and forwarded fully.
xreq.Close()
xreq2.Close()
// Fowarding should stop once the other sides of the forwarded connections are closed.
require.NoError(t, cConn1.Close())
require.NoError(t, sConn2.Close())
require.NoError(t, <-forwardErrC)
}

0 comments on commit b19836a

Please sign in to comment.