Skip to content

Commit

Permalink
Added ability to reboot and reconnect.
Browse files Browse the repository at this point in the history
  • Loading branch information
bitbound committed Sep 6, 2023
1 parent e0e1e23 commit 9128a24
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 65 deletions.
31 changes: 16 additions & 15 deletions Agent/Services/AgentHubConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ public async Task Connect()
_heartbeatTimer.Start();

await _hubConnection.SendAsync("CheckForPendingSriptRuns");
await _hubConnection.SendAsync("CheckForPendingRemoteControlSessions");

break;
}
Expand Down Expand Up @@ -377,13 +378,13 @@ public async Task RestartScreenCaster(string[] viewerIds, string sessionId, stri
return;
}
await _appLauncher.RestartScreenCaster(
viewerIds,
sessionId,
accessKey,
userConnectionId,
requesterName,
orgName,
orgId,
viewerIds,
sessionId,
accessKey,
userConnectionId,
requesterName,
orgName,
orgId,
_hubConnection);
}
catch (Exception ex)
Expand All @@ -393,10 +394,10 @@ await _appLauncher.RestartScreenCaster(
}

public async Task RunScript(
Guid savedScriptId,
int scriptRunId,
string initiator,
ScriptInputType scriptInputType,
Guid savedScriptId,
int scriptRunId,
string initiator,
ScriptInputType scriptInputType,
string authToken)
{
try
Expand Down Expand Up @@ -607,7 +608,7 @@ private void RegisterMessageHandlers()

// TODO: Replace all these parameters with a single DTO per method.
_hubConnection.On<string, string, string, string, string, string, string, int>(
nameof(ChangeWindowsSession),
nameof(ChangeWindowsSession),
ChangeWindowsSession);

_hubConnection.On<string, string, string, string, bool, string>(nameof(SendChatMessage), SendChatMessage);
Expand All @@ -619,7 +620,7 @@ private void RegisterMessageHandlers()
_hubConnection.On<ScriptingShell, string, string, string, string>(nameof(ExecuteCommand), ExecuteCommand);

_hubConnection.On<ScriptingShell, string, string, string, string>(nameof(ExecuteCommandFromApi), ExecuteCommandFromApi);

_hubConnection.On<string>(nameof(GetLogs), GetLogs);

_hubConnection.On<string, int, CompletionIntent, bool?, string>(nameof(GetPowerShellCompletions), GetPowerShellCompletions);
Expand All @@ -631,13 +632,13 @@ private void RegisterMessageHandlers()
_hubConnection.On<Guid, string, string, string, string, string>(nameof(RemoteControl), RemoteControl);

_hubConnection.On<string[], string, string, string, string, string, string>(
nameof(RestartScreenCaster),
nameof(RestartScreenCaster),
RestartScreenCaster);

_hubConnection.On<Guid, int, string, ScriptInputType, string>(nameof(RunScript), RunScript);

_hubConnection.On<string, string[], string, string>(
nameof(TransferFileFromBrowserToAgent),
nameof(TransferFileFromBrowserToAgent),
TransferFileFromBrowserToAgent);

_hubConnection.On(nameof(TriggerHeartbeat), TriggerHeartbeat);
Expand Down
90 changes: 68 additions & 22 deletions Server/Hubs/AgentHub.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Immense.RemoteControl.Server.Hubs;
using Immense.RemoteControl.Server.Services;
using Immense.SimpleMessenger;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Caching.Memory;
Expand Down Expand Up @@ -26,8 +27,9 @@ public class AgentHub : Hub<IAgentHubClient>
private readonly ICircuitManager _circuitManager;
private readonly IDataService _dataService;
private readonly IExpiringTokenService _expiringTokenService;
private readonly IMessenger _messenger;
private readonly ILogger<AgentHub> _logger;
private readonly IMessenger _messenger;
private readonly IRemoteControlSessionCache _remoteControlSessions;
private readonly IAgentHubSessionCache _serviceSessionCache;
private readonly IHubContext<ViewerHub> _viewerHubContext;

Expand All @@ -37,6 +39,7 @@ public AgentHub(IDataService dataService,
IHubContext<ViewerHub> viewerHubContext,
ICircuitManager circuitManager,
IExpiringTokenService expiringTokenService,
IRemoteControlSessionCache remoteControlSessionCache,
IMessenger messenger,
ILogger<AgentHub> logger)
{
Expand All @@ -46,6 +49,7 @@ public AgentHub(IDataService dataService,
_appConfig = appConfig;
_circuitManager = circuitManager;
_expiringTokenService = expiringTokenService;
_remoteControlSessions = remoteControlSessionCache;
_messenger = messenger;
_logger = logger;
}
Expand All @@ -57,7 +61,7 @@ private Device? Device
{
get
{
if (Context.Items["Device"] is Device device)
if (Context.Items["Device"] is Device device)
{
return device;
}
Expand Down Expand Up @@ -94,6 +98,46 @@ await Clients.Caller.SendChatMessage(
}
}

public async Task CheckForPendingRemoteControlSessions()
{
try
{
if (Device is null)
{
return;
}

_logger.LogDebug(
"Checking for pending remote control sessions for device {deviceId}.",
Device.ID);

var waitingSessions = _remoteControlSessions
.Sessions
.OfType<RemoteControlSessionEx>()
.Where(x => x.DeviceId == Device.ID);

foreach (var session in waitingSessions)
{
_logger.LogDebug(
"Restarting remote control session {sessionId}.",
session.UnattendedSessionId);

session.AgentConnectionId = Context.ConnectionId;
await Clients.Caller.RestartScreenCaster(
session.ViewerList.ToArray(),
$"{session.UnattendedSessionId}",
session.AccessKey,
session.UserConnectionId,
session.RequesterName,
session.OrganizationName,
session.OrganizationId);
}
}
catch (Exception ex)
{
_logger.LogError(ex, "Error while checking for pending remote control sessions.");
}
}

public async Task CheckForPendingScriptRuns()
{
Expand Down Expand Up @@ -142,32 +186,31 @@ public async Task<bool> DeviceCameOnline(DeviceClientDto device)
}

var result = await _dataService.AddOrUpdateDevice(device);
if (result.IsSuccess)
if (!result.IsSuccess)
{
Device = result.Value;
// Organization wasn't found.
return false;
}

_serviceSessionCache.AddOrUpdateByConnectionId(Context.ConnectionId, Device);
Device = result.Value;

var userIDs = _circuitManager.Connections.Select(x => x.User.Id);
_serviceSessionCache.AddOrUpdateByConnectionId(Context.ConnectionId, Device);

var filteredUserIDs = _dataService.FilterUsersByDevicePermission(userIDs, Device.ID);
var userIDs = _circuitManager.Connections.Select(x => x.User.Id);

var connections = _circuitManager.Connections
.Where(x => x.User.OrganizationID == Device.OrganizationID &&
filteredUserIDs.Contains(x.User.Id));
var filteredUserIDs = _dataService.FilterUsersByDevicePermission(userIDs, Device.ID);

foreach (var connection in connections)
{
var message = new DeviceStateChangedMessage(Device);
await _messenger.Send(message, connection.ConnectionId);
}
return true;
}
else
var connections = _circuitManager.Connections
.Where(x => x.User.OrganizationID == Device.OrganizationID &&
filteredUserIDs.Contains(x.User.Id));

foreach (var connection in connections)
{
// Organization wasn't found.
return false;
var message = new DeviceStateChangedMessage(Device);
await _messenger.Send(message, connection.ConnectionId);
}

return true;
}
catch (Exception ex)
{
Expand Down Expand Up @@ -227,7 +270,6 @@ public async Task DeviceHeartbeat(DeviceClientDto device)
await CheckForPendingScriptRuns();
}


public Task DisplayMessage(string consoleMessage, string popupMessage, string className, string requesterId)
{
var message = new DisplayNotificationMessage(consoleMessage, popupMessage, className);
Expand Down Expand Up @@ -310,6 +352,7 @@ public void ScriptResultViaApi(string commandID, string requestID)
{
ApiScriptResults.Set(requestID, commandID, DateTimeOffset.Now.AddHours(1));
}

public Task SendConnectionFailedToViewers(List<string> viewerIDs)
{
return _viewerHubContext.Clients.Clients(viewerIDs).SendAsync("ConnectionFailed");
Expand All @@ -320,6 +363,7 @@ public Task SendLogs(string logChunk, string requesterConnectionId)
var message = new ReceiveLogsMessage(logChunk);
return _messenger.Send(message, requesterConnectionId);
}

public void SetServerVerificationToken(string verificationToken)
{
if (Device is null)
Expand All @@ -329,11 +373,13 @@ public void SetServerVerificationToken(string verificationToken)
Device.ServerVerificationToken = verificationToken;
_dataService.SetServerVerificationToken(Device.ID, verificationToken);
}

public Task TransferCompleted(string transferId, string requesterId)
{
var message = new TransferCompleteMessage(transferId);
return _messenger.Send(message, requesterId);
}

private async Task<bool> CheckForDeviceBan(params string[] deviceIdNameOrIPs)
{
foreach (var device in deviceIdNameOrIPs)
Expand All @@ -352,7 +398,7 @@ private async Task<bool> CheckForDeviceBan(params string[] deviceIdNameOrIPs)
return true;
}
}

return false;
}
}
49 changes: 26 additions & 23 deletions Server/Services/RcImplementations/HubEventHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,28 @@
using Immense.RemoteControl.Server.Models;
using Immense.RemoteControl.Shared.Enums;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Build.Framework;
using Microsoft.Extensions.Logging;
using NuGet.Protocol.Core.Types;
using Remotely.Server.Hubs;
using Remotely.Server.Models;
using Remotely.Shared.Enums;
using Remotely.Shared.Interfaces;
using Remotely.Shared.Models;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace Remotely.Server.Services.RcImplementations;

public class HubEventHandler : IHubEventHandler
{
private readonly IHubContext<AgentHub, IAgentHubClient> _serviceHub;
private readonly IAgentHubSessionCache _agentCache;
private readonly ILogger<HubEventHandler> _logger;

public HubEventHandler(
IHubContext<AgentHub, IAgentHubClient> serviceHub,
IAgentHubSessionCache agentHubSessionCache,
ILogger<HubEventHandler> logger)
{
_serviceHub = serviceHub;
_agentCache = agentHubSessionCache;
_logger = logger;
}

Expand All @@ -38,7 +33,7 @@ public Task ChangeWindowsSession(RemoteControlSession session, string viewerConn
{
_logger.LogError("Event should have been for RemoteControlSessionEx.");
return Task.CompletedTask;
}
}

return _serviceHub.Clients
.Client(ex.AgentConnectionId)
Expand Down Expand Up @@ -123,30 +118,38 @@ public Task NotifySessionChanged(RemoteControlSession session, SessionSwitchReas
ex.OrganizationId);
}

public Task RestartScreenCaster(RemoteControlSession session, HashSet<string> viewerList)
public async Task RestartScreenCaster(RemoteControlSession session)
{

if (session is not RemoteControlSessionEx ex)
if (session is not RemoteControlSessionEx sessionEx)
{
_logger.LogError("Event should have been for RemoteControlSessionEx.");
return Task.CompletedTask;
return;
}

if (ex.RequireConsent)
if (sessionEx.RequireConsent)
{
// Don't restart if consent wasn't granted on the first request.
return Task.CompletedTask;
return;
}

return _serviceHub.Clients
.Client(ex.AgentConnectionId)
if (!_agentCache.TryGetConnectionId(sessionEx.DeviceId, out var agentConnectionId))
{

return;
}

sessionEx.AgentConnectionId = agentConnectionId;

await _serviceHub.Clients
.Client(sessionEx.AgentConnectionId)
.RestartScreenCaster(
viewerList.ToArray(),
$"{ex.UnattendedSessionId}",
ex.AccessKey,
ex.UserConnectionId,
ex.RequesterName,
ex.OrganizationName,
ex.OrganizationId);
session.ViewerList.ToArray(),
$"{sessionEx.UnattendedSessionId}",
sessionEx.AccessKey,
sessionEx.UserConnectionId,
sessionEx.RequesterName,
sessionEx.OrganizationName,
sessionEx.OrganizationId);
}
}
Loading

0 comments on commit 9128a24

Please sign in to comment.