diff --git a/lib/srv/desktop/rdp/rdpclient/client.go b/lib/srv/desktop/rdp/rdpclient/client.go index ffd27289ce85d..c02203b6c4f01 100644 --- a/lib/srv/desktop/rdp/rdpclient/client.go +++ b/lib/srv/desktop/rdp/rdpclient/client.go @@ -108,6 +108,12 @@ func init() { } // Client is the RDP client. +// Its lifecycle is: +// +// ``` +// rdpc := New() // creates client +// rdpc.Run() // starts rdp and waits for the duration of the connection +// ``` type Client struct { cfg Config @@ -115,7 +121,7 @@ type Client struct { clientWidth, clientHeight uint16 username string - // handle allows the rust code to call back into the client + // handle allows the rust code to call back into the client. handle cgo.Handle // RDP client on the Rust side. @@ -155,11 +161,21 @@ func New(ctx context.Context, cfg Config) (*Client, error) { if err := c.readClientSize(); err != nil { return nil, trace.Wrap(err) } + return c, nil +} + +// Run starts the rdp client and blocks until the client disconnects, +// then runs the cleanup. +func (c *Client) Run(ctx context.Context) error { + defer c.close() + if err := c.connect(ctx); err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } c.start() - return c, nil + c.wg.Wait() + + return nil } func (c *Client) readClientUsername() error { @@ -240,7 +256,6 @@ func (c *Client) start() { c.wg.Add(1) go func() { defer c.wg.Done() - defer c.Close() defer c.cfg.Log.Info("RDP output streaming finished") // C.read_rdp_output blocks for the duration of the RDP connection and @@ -260,7 +275,6 @@ func (c *Client) start() { c.wg.Add(1) go func() { defer c.wg.Done() - defer c.Close() defer c.cfg.Log.Info("TDP input streaming finished") // Remember mouse coordinates to send them with all CGOPointer events. var mouseX, mouseY uint32 @@ -465,24 +479,21 @@ func (c *Client) sharedDirectoryAcknowledge(ack tdp.SharedDirectoryAcknowledge) return C.ErrCodeSuccess } -// Wait blocks until the client disconnects and runs the cleanup. -func (c *Client) Wait() error { - c.wg.Wait() - // Let the Rust side free its data. - C.free_rdp(c.rustClient) - return nil -} - -// Close shuts down the client and closes any existing connections. -// It is safe to call multiple times, from multiple goroutines. -// Calls other than the first one are no-ops. -func (c *Client) Close() { +// close frees the memory of the cgo.Handle, +// closes the RDP client connection, +// and frees the Rust client. +func (c *Client) close() { c.closeOnce.Do(func() { - c.handle.Delete() - + // Close the RDP client if err := C.close_rdp(c.rustClient); err != C.ErrCodeSuccess { c.cfg.Log.Warningf("failed to close the RDP client") } + + // Let the Rust side free its data + C.free_rdp(c.rustClient) + + // Release the memory of the cgo.Handle + c.handle.Delete() }) } diff --git a/lib/srv/desktop/rdp/rdpclient/client_nop.go b/lib/srv/desktop/rdp/rdpclient/client_nop.go index d89f34ccd2b6a..d32fd4adfccf6 100644 --- a/lib/srv/desktop/rdp/rdpclient/client_nop.go +++ b/lib/srv/desktop/rdp/rdpclient/client_nop.go @@ -38,15 +38,12 @@ func New(ctx context.Context, cfg Config) (*Client, error) { return &Client{}, errors.New("the real rdpclient.Client implementation was not included in this build") } -// Wait blocks until the client disconnects and runs the cleanup. -func (c *Client) Wait() error { +// Run starts the rdp client and blocks until the client disconnects, +// then runs the cleanup. +func (c *Client) Run(ctx context.Context) error { return errors.New("the real rdpclient.Client implementation was not included in this build") } -// Close shuts down the client and closes any existing connections. -func (c *Client) Close() { -} - // GetClientLastActive returns the time of the last recorded activity. func (c *Client) GetClientLastActive() time.Time { return time.Now().UTC() diff --git a/lib/srv/desktop/windows_server.go b/lib/srv/desktop/windows_server.go index eb6b2422193d0..b3b7a95f301ec 100644 --- a/lib/srv/desktop/windows_server.go +++ b/lib/srv/desktop/windows_server.go @@ -871,17 +871,21 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger, monitorCfg.DisconnectExpiredCert = identity.Expires } + // UpdateClientActivity before starting monitor to + // be doubly sure that the client isn't disconnected + // due to an idle timeout before its had the chance to + // call StartAndWait() + rdpc.UpdateClientActivity() if err := srv.StartMonitor(monitorCfg); err != nil { // if we can't establish a connection monitor then we can't enforce RBAC. // consider this a connection failure and return an error // (in the happy path, rdpc remains open until Wait() completes) - rdpc.Close() s.onSessionStart(ctx, sw, &identity, sessionStartTime, windowsUser, string(sessionID), desktop, err) return trace.Wrap(err) } s.onSessionStart(ctx, sw, &identity, sessionStartTime, windowsUser, string(sessionID), desktop, nil) - err = rdpc.Wait() + err = rdpc.Run(ctx) s.onSessionEnd(ctx, sw, &identity, sessionStartTime, recordSession, windowsUser, string(sessionID), desktop) return trace.Wrap(err)