From b19836abe774a2196de32aa305e6acddcde1737b Mon Sep 17 00:00:00 2001 From: Brian Joerger Date: Thu, 11 Aug 2022 09:57:43 -0700 Subject: [PATCH] Update TestForward (#15321) --- lib/sshutils/x11/forward_test.go | 110 ++++++++++++------------------- 1 file changed, 43 insertions(+), 67 deletions(-) diff --git a/lib/sshutils/x11/forward_test.go b/lib/sshutils/x11/forward_test.go index 2d5c8a3b87088..52b85d465479c 100644 --- a/lib/sshutils/x11/forward_test.go +++ b/lib/sshutils/x11/forward_test.go @@ -16,7 +16,6 @@ package x11 import ( "context" - "io" "net" "testing" @@ -26,85 +25,62 @@ import ( func TestForward(t *testing.T) { ctx := context.Background() - // Create a fake client display. In practice, the display - // set in $DISPLAY is used to connect to the client display. - fakeClientDisplay, err := net.Listen("tcp", ":0") - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, fakeClientDisplay.Close()) - }) + // Open a dual sided connection on a new tcp listener + openConn := func(t *testing.T) (clientConn net.Conn, serverConn net.Conn) { + l, err := net.Listen("tcp", ":0") + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, l.Close()) }) - // Handle connections to the client XServer - echoMsg := "msg" - go func() { - for { - localConn, err := fakeClientDisplay.Accept() + serverErrC := make(chan error) + serverConnC := make(chan net.Conn) + go func() { + serverConn, err := l.Accept() if err != nil { - // listener is closed, test is done. - return + serverErrC <- err + close(serverConnC) } + serverConnC <- serverConn + close(serverErrC) + }() - go func() { - defer localConn.Close() - - // read request and expect what was written to server - bytes, err := io.ReadAll(localConn) - require.NoError(t, err) - require.Equal(t, echoMsg, string(bytes)) - }() - } - }() - - // Create a fake XServer proxy just like the one in sshserver. - sl, serverDisplay, err := OpenNewXServerListener(DefaultDisplayOffset, DefaultMaxDisplays, 0) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, sl.Close()) - }) + clientConn, err = net.Dial("tcp", l.Addr().String()) + require.NoError(t, err) + t.Cleanup(func() { clientConn.Close() }) - // Handle connection to XServer proxy - go func() { - for { - serverConn, err := sl.Accept() - if err != nil { - // listener is closed, test is done. - return - } + serverConn = <-serverConnC + require.NoError(t, <-serverErrC) + t.Cleanup(func() { serverConn.Close() }) - go func() { - defer serverConn.Close() + return clientConn, serverConn + } - clientConn, err := net.Dial("tcp", fakeClientDisplay.Addr().String()) - if err != nil { - // fakeClientDisplay is closed, test is done. - return - } + cConn1, sConn1 := openConn(t) + cConn2, sConn2 := openConn(t) - clientXConn, ok := clientConn.(*net.TCPConn) - require.True(t, ok) - defer clientConn.Close() + // Start forwarding between connections so that we get + // this flow: cConn1 -> sConn1 -> cConn2 -> sConn2. + serverConnToForward, ok := sConn1.(*net.TCPConn) + require.True(t, ok) + clientConnToForward, ok := cConn2.(*net.TCPConn) + require.True(t, ok) - err = Forward(ctx, clientXConn, serverConn) - require.NoError(t, err) - }() - } + forwardErrC := make(chan error, 1) + go func() { + forwardErrC <- Forward(ctx, serverConnToForward, clientConnToForward) }() - // Create a fake XServer request to the XServer proxy - xreq, err := serverDisplay.Dial() - require.NoError(t, err) - _, err = xreq.Write([]byte(echoMsg)) + // Write a msg to client connection 1, which should propagate to server connection 2. + message := "msg" + _, err := cConn1.Write([]byte(message)) require.NoError(t, err) - // Create a second request simultaneously - xreq2, err := serverDisplay.Dial() - require.NoError(t, err) - _, err = xreq2.Write([]byte(echoMsg)) + buf := make([]byte, len(message)) + _, err = sConn2.Read(buf) require.NoError(t, err) - xreq2.Close() + require.Equal(t, message, string(buf)) - // Close XServer requests, forwarding should stop as soon as - // the open connection has been read and forwarded fully. - xreq.Close() - xreq2.Close() + // Fowarding should stop once the other sides of the forwarded connections are closed. + require.NoError(t, cConn1.Close()) + require.NoError(t, sConn2.Close()) + require.NoError(t, <-forwardErrC) }