diff --git a/Agent/Interfaces/IAppLauncher.cs b/Agent/Interfaces/IAppLauncher.cs index a7f540d35..d53fba75b 100644 --- a/Agent/Interfaces/IAppLauncher.cs +++ b/Agent/Interfaces/IAppLauncher.cs @@ -8,5 +8,5 @@ public interface IAppLauncher { Task LaunchChatService(string pipeName, string userConnectionId, string requesterName, string orgName, string orgId, HubConnection hubConnection); Task LaunchRemoteControl(int targetSessionId, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId, HubConnection hubConnection); - Task RestartScreenCaster(List viewerIDs, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId, HubConnection hubConnection, int targetSessionID = -1); + Task RestartScreenCaster(string[] viewerIds, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId, HubConnection hubConnection, int targetSessionID = -1); } diff --git a/Agent/Program.cs b/Agent/Program.cs index 027bd27ed..af954741b 100644 --- a/Agent/Program.cs +++ b/Agent/Program.cs @@ -72,6 +72,8 @@ private static async Task Init(IServiceProvider services) { SetSas(services, logger); } + + // TODO: Move this to a BackgroundService. await services.GetRequiredService().BeginChecking(); await services.GetRequiredService().Connect(); } @@ -91,34 +93,34 @@ private static void RegisterServices(IServiceCollection services) services.AddSingleton(); services.AddSingleton(); services.AddHostedService(services => services.GetRequiredService()); - services.AddScoped(); + services.AddSingleton(); services.AddTransient(); services.AddTransient(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); - services.AddScoped(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); if (OperatingSystem.IsWindows()) { - services.AddScoped(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); } else if (OperatingSystem.IsLinux()) { - services.AddScoped(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); } else if (OperatingSystem.IsMacOS()) { - services.AddScoped(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); } diff --git a/Agent/Services/AgentHubConnection.cs b/Agent/Services/AgentHubConnection.cs index 39d81a61c..0a803f7ef 100644 --- a/Agent/Services/AgentHubConnection.cs +++ b/Agent/Services/AgentHubConnection.cs @@ -7,6 +7,7 @@ using Remotely.Agent.Interfaces; using Remotely.Shared; using Remotely.Shared.Enums; +using Remotely.Shared.Interfaces; using Remotely.Shared.Models; using Remotely.Shared.Services; using System; @@ -23,7 +24,7 @@ namespace Remotely.Agent.Services; -public interface IAgentHubConnection +public interface IAgentHubConnection : IAgentHubClient { bool IsConnected { get; } @@ -34,22 +35,21 @@ public interface IAgentHubConnection public class AgentHubConnection : IAgentHubConnection, IDisposable { private readonly IAppLauncher _appLauncher; + private readonly IHostApplicationLifetime _appLifetime; private readonly IChatClientService _chatService; private readonly IConfigService _configService; private readonly IDeviceInformationService _deviceInfoService; + private readonly IFileLogsManager _fileLogsManager; private readonly IHttpClientFactory _httpFactory; - private readonly IWakeOnLanService _wakeOnLanService; private readonly ILogger _logger; - private readonly IFileLogsManager _fileLogsManager; - private readonly IHostApplicationLifetime _appLifetime; - private readonly IScriptingShellFactory _scriptingShellFactory; private readonly IScriptExecutor _scriptExecutor; + private readonly IScriptingShellFactory _scriptingShellFactory; private readonly IUninstaller _uninstaller; private readonly IUpdater _updater; - + private readonly IWakeOnLanService _wakeOnLanService; private ConnectionInfo? _connectionInfo; - private HubConnection? _hubConnection; private Timer? _heartbeatTimer; + private HubConnection? _hubConnection; private bool _isServerVerified; public AgentHubConnection( @@ -84,6 +84,26 @@ public AgentHubConnection( public bool IsConnected => _hubConnection?.State == HubConnectionState.Connected; + public async Task ChangeWindowsSession(string viewerConnectionId, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId, int targetSessionId) + { + try + { + EnsureHubConnection(); + + if (!_isServerVerified) + { + _logger.LogWarning("Session change attempted before server was verified."); + return; + } + + await _appLauncher.RestartScreenCaster(new[] { viewerConnectionId }, sessionId, accessKey, userConnectionId, requesterName, orgName, orgId, _hubConnection, targetSessionId); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while handling ChangeWindowsSession."); + } + } + public async Task Connect() { using var throttle = new SemaphoreSlim(1, 1); @@ -172,373 +192,459 @@ public async Task Connect() } } + public async Task DeleteLogs() + { + try + { + await _fileLogsManager.DeleteLogs(_appLifetime.ApplicationStopping); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while deleting logs."); + } + } + public void Dispose() { GC.SuppressFinalize(this); _heartbeatTimer?.Dispose(); } - public async Task SendHeartbeat() + public async Task ExecuteCommand(ScriptingShell shell, string command, string authToken, string senderUsername, string senderConnectionId) { try { - if (_connectionInfo is null || _hubConnection is null) + EnsureHubConnection(); + + if (!_isServerVerified) { + _logger.LogWarning( + "Command attempted before server was verified. Shell: {shell}. Command: {command}. Sender: {senderConnectionID}", + shell, + command, + senderConnectionId); return; } - if (string.IsNullOrWhiteSpace(_connectionInfo.OrganizationID)) + await _scriptExecutor.RunCommandFromTerminal( + shell, + command, + authToken, + senderUsername, + senderConnectionId, + ScriptInputType.Terminal, + TimeSpan.FromSeconds(30), + _hubConnection) + .ConfigureAwait(false); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while executing command."); + } + } + + public async Task ExecuteCommandFromApi(ScriptingShell shell, string authToken, string requestID, string command, string senderUsername) + { + try + { + EnsureHubConnection(); + + if (!_isServerVerified) { - _logger.LogError("Organization ID is not set. Please set it in the config file."); + _logger.LogWarning( + "Command attempted before server was verified. Shell: {shell}. Command: {command}. Sender: {senderUsername}", + shell, + command, + senderUsername); return; } - var currentInfo = await _deviceInfoService.CreateDevice(_connectionInfo.DeviceID, _connectionInfo.OrganizationID); - await _hubConnection.SendAsync("DeviceHeartbeat", currentInfo); + await _scriptExecutor + .RunCommandFromApi(shell, requestID, command, senderUsername, authToken, _hubConnection) + .ConfigureAwait(false); } catch (Exception ex) { - _logger.LogWarning(ex, "Error while sending heartbeat."); + _logger.LogError(ex, "Error while executing command from API."); } } - private async Task CheckForServerMigration() + public async Task GetLogs(string senderConnectionId) { - if (_connectionInfo is null || _hubConnection is null) + try { - return false; - } + EnsureHubConnection(); - var serverUrl = await _hubConnection.InvokeAsync("GetServerUrl"); + if (!await _fileLogsManager.AnyLogsExist(_appLifetime.ApplicationStopping)) + { + var message = "There are no log entries written."; + await _hubConnection.InvokeAsync("SendLogs", message, senderConnectionId).ConfigureAwait(false); + return; + } - if (Uri.TryCreate(serverUrl, UriKind.Absolute, out var serverUri) && - Uri.TryCreate(_connectionInfo.Host, UriKind.Absolute, out var savedUri) && - serverUri.Host != savedUri.Host) + await foreach (var chunk in _fileLogsManager.ReadAllBytes(_appLifetime.ApplicationStopping)) + { + var lines = Encoding.UTF8.GetString(chunk); + await _hubConnection.InvokeAsync("SendLogs", lines, senderConnectionId).ConfigureAwait(false); + } + } + catch (Exception ex) { - _connectionInfo.Host = serverUrl.Trim().TrimEnd('/'); - _connectionInfo.ServerVerificationToken = null; - _configService.SaveConnectionInfo(_connectionInfo); - await _hubConnection.DisposeAsync(); - return true; + _logger.LogError(ex, "Error while retrieving logs."); } - return false; } - private async void HeartbeatTimer_Elapsed(object? sender, ElapsedEventArgs e) + public async Task GetPowerShellCompletions(string inputText, int currentIndex, CompletionIntent intent, bool? forward, string senderConnectionId) { - await SendHeartbeat(); + try + { + EnsureHubConnection(); + var session = _scriptingShellFactory.GetOrCreatePsCoreShell(senderConnectionId); + var completion = session.GetCompletions(inputText, currentIndex, forward); + var completionModel = completion.ToPwshCompletion(); + await _hubConnection + .InvokeAsync("ReturnPowerShellCompletions", completionModel, intent, senderConnectionId) + .ConfigureAwait(false); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while getting PowerShell completions."); + } } - private async Task HubConnection_Reconnected(string? arg) + public Task InvokeCtrlAltDel() { - if (_connectionInfo is null || _hubConnection is null) + try { - return; - } - - _logger.LogInformation("Reconnected to server."); - await _updater.CheckForUpdates(); + if (!OperatingSystem.IsWindows()) + { + return Task.CompletedTask; + } - var device = await _deviceInfoService.CreateDevice(_connectionInfo.DeviceID, $"{_connectionInfo.OrganizationID}"); + if (!_isServerVerified) + { + _logger.LogWarning("CtrlAltDel attempted before server was verified."); + return Task.CompletedTask; + } - if (!await _hubConnection.InvokeAsync("DeviceCameOnline", device)) - { - await Connect(); - return; + User32.SendSAS(false); } - - if (await CheckForServerMigration()) + catch (Exception ex) { - await Connect(); - return; + _logger.LogError(ex, "Error while invoking CtrlAltDel."); } + return Task.CompletedTask; } - private void RegisterMessageHandlers() + public async Task ReinstallAgent() { - if (_hubConnection is null) + try { - throw new InvalidOperationException("Hub connection is null."); + await _updater.InstallLatestVersion(); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while reinstalling agent."); } + } - _hubConnection.On("ChangeWindowsSession", async (string viewerConnectionId, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId, int targetSessionID) => + public async Task RemoteControl(Guid sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId) + { + try { - try - { - if (!_isServerVerified) - { - _logger.LogWarning("Session change attempted before server was verified."); - return; - } + EnsureHubConnection(); - await _appLauncher.RestartScreenCaster(new List() { viewerConnectionId }, sessionId, accessKey, userConnectionId, requesterName, orgName, orgId, _hubConnection, targetSessionID); - } - catch (Exception ex) + if (!_isServerVerified) { - _logger.LogError(ex, "Error while handling ChangeWindowsSession."); + _logger.LogWarning("Remote control attempted before server was verified."); + return; } - }); + await _appLauncher.LaunchRemoteControl(-1, $"{sessionId}", accessKey, userConnectionId, requesterName, orgName, orgId, _hubConnection); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while starting remote control."); + } + } - _hubConnection.On("Chat", async (string senderName, string message, string orgName, string orgId, bool disconnected, string senderConnectionID) => + public async Task RestartScreenCaster(string[] viewerIds, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId) + { + try { - try - { - if (!_isServerVerified) - { - _logger.LogWarning("Chat attempted before server was verified."); - return; - } + EnsureHubConnection(); - await _chatService.SendMessage(senderName, message, orgName, orgId, disconnected, senderConnectionID, _hubConnection); - } - catch (Exception ex) + if (!_isServerVerified) { - _logger.LogError(ex, "Error while handling chat message."); + _logger.LogWarning("Remote control attempted before server was verified."); + return; } - }); + await _appLauncher.RestartScreenCaster( + viewerIds, + sessionId, + accessKey, + userConnectionId, + requesterName, + orgName, + orgId, + _hubConnection); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while restarting screen caster."); + } + } - _hubConnection.On("CtrlAltDel", () => + public async Task RunScript( + Guid savedScriptId, + int scriptRunId, + string initiator, + ScriptInputType scriptInputType, + string authToken) + { + try { if (!_isServerVerified) { - _logger.LogWarning("CtrlAltDel attempted before server was verified."); + _logger.LogWarning( + "Script run attempted before server was verified. Script ID: {savedScriptId}. Initiator: {initiator}", + savedScriptId, + initiator); return; } - User32.SendSAS(false); - }); - _hubConnection.On("DeleteLogs", () => - { - _fileLogsManager.DeleteLogs(_appLifetime.ApplicationStopping); - }); + await _scriptExecutor.RunScript(savedScriptId, scriptRunId, initiator, scriptInputType, authToken); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while running script."); + } + } - _hubConnection.On("ExecuteCommand", ((ScriptingShell shell, string command, string authToken, string senderUsername, string senderConnectionID) => + public async Task SendChatMessage(string senderName, string message, string orgName, string orgId, bool disconnected, string senderConnectionId) + { + try { - try - { - if (!_isServerVerified) - { - _logger.LogWarning( - "Command attempted before server was verified. Shell: {shell}. Command: {command}. Sender: {senderConnectionID}", - shell, - command, - senderConnectionID); - return; - } + EnsureHubConnection(); - _ = _scriptExecutor.RunCommandFromTerminal(shell, - command, - authToken, - senderUsername, - senderConnectionID, - ScriptInputType.Terminal, - TimeSpan.FromSeconds(30), - _hubConnection); - } - catch (Exception ex) + if (!_isServerVerified) { - _logger.LogError(ex, "Error while executing command."); + _logger.LogWarning("Chat attempted before server was verified."); + return; } - })); - _hubConnection.On("ExecuteCommandFromApi", ( - ScriptingShell shell, - string authToken, - string requestID, - string command, - string senderUsername) => + await _chatService + .SendMessage(senderName, message, orgName, orgId, disconnected, senderConnectionId, _hubConnection) + .ConfigureAwait(false); + } + catch (Exception ex) { - try - { - if (!_isServerVerified) - { - _logger.LogWarning( - "Command attempted before server was verified. Shell: {shell}. Command: {command}. Sender: {senderUsername}", - shell, - command, - senderUsername); - return; - } + _logger.LogError(ex, "Error while handling chat message."); + } + } - _ = _scriptExecutor.RunCommandFromApi(shell, requestID, command, senderUsername, authToken, _hubConnection); + public async Task SendHeartbeat() + { + try + { + if (_connectionInfo is null || _hubConnection is null) + { + return; } - catch (Exception ex) + + if (string.IsNullOrWhiteSpace(_connectionInfo.OrganizationID)) { - _logger.LogError(ex, "Error while executing command from API."); + _logger.LogError("Organization ID is not set. Please set it in the config file."); + return; } - }); + var currentInfo = await _deviceInfoService.CreateDevice(_connectionInfo.DeviceID, _connectionInfo.OrganizationID); + await _hubConnection + .SendAsync("DeviceHeartbeat", currentInfo) + .ConfigureAwait(false); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Error while sending heartbeat."); + } + } - _hubConnection.On("GetLogs", async (string senderConnectionId) => + public async Task TransferFileFromBrowserToAgent(string transferId, string[] fileIds, string requesterId, string expiringToken) + { + try { - try - { - if (!await _fileLogsManager.AnyLogsExist(_appLifetime.ApplicationStopping)) - { - var message = "There are no log entries written."; - await _hubConnection.InvokeAsync("SendLogs", message, senderConnectionId).ConfigureAwait(false); - return; - } + EnsureHubConnection(); - await foreach (var chunk in _fileLogsManager.ReadAllBytes(_appLifetime.ApplicationStopping)) - { - var lines = Encoding.UTF8.GetString(chunk); - await _hubConnection.InvokeAsync("SendLogs", lines, senderConnectionId).ConfigureAwait(false); - } - } - catch (Exception ex) + if (!_isServerVerified) { - _logger.LogError(ex, "Error while retrieving logs."); + _logger.LogWarning("File upload attempted before server was verified."); + return; } - }); + _logger.LogInformation("File upload started by {requesterID}.", requesterId); - _hubConnection.On("GetPowerShellCompletions", async (string inputText, int currentIndex, CompletionIntent intent, bool? forward, string senderConnectionId) => - { - try - { - var session = _scriptingShellFactory.GetOrCreatePsCoreShell(senderConnectionId); - var completion = session.GetCompletions(inputText, currentIndex, forward); - var completionModel = completion.ToPwshCompletion(); - await _hubConnection.InvokeAsync("ReturnPowerShellCompletions", completionModel, intent, senderConnectionId).ConfigureAwait(false); - } - catch (Exception ex) + var sharedFilePath = Directory.CreateDirectory(Path.Combine(Path.GetTempPath(), "RemotelySharedFiles")).FullName; + + foreach (var fileID in fileIds) { - _logger.LogError(ex, "Error while getting PowerShell completions."); + var url = $"{_connectionInfo?.Host}/API/FileSharing/{fileID}"; + using var client = _httpFactory.CreateClient(); + client.DefaultRequestHeaders.Add(AppConstants.ExpiringTokenHeaderName, expiringToken); + using var response = await client.GetAsync(url); + + var filename = response.Content.Headers.ContentDisposition?.FileName ?? Path.GetRandomFileName(); + var invalidChars = Path.GetInvalidFileNameChars().ToHashSet(); + var legalChars = filename.ToCharArray().Where(x => !invalidChars.Contains(x)); + + filename = new string(legalChars.ToArray()); + + using var rs = await response.Content.ReadAsStreamAsync(); + using var fs = new FileStream(Path.Combine(sharedFilePath, filename), FileMode.Create); + rs.CopyTo(fs); } - }); + await _hubConnection.SendAsync("TransferCompleted", transferId, requesterId); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while transfering file from browser to agent."); + } + } + public Task TriggerHeartbeat() => SendHeartbeat(); - _hubConnection.On("ReinstallAgent", async () => + public Task UninstallAgent() + { + try { - try - { - await _updater.InstallLatestVersion(); - } - catch (Exception ex) - { - _logger.LogError(ex, "Error while reinstalling agent."); - } - }); + _uninstaller.UninstallAgent(); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while uninstalling agent."); + } + return Task.CompletedTask; + } - _hubConnection.On("UninstallAgent", () => + public async Task WakeDevice(string macAddress) + { + try { - try - { - _uninstaller.UninstallAgent(); - } - catch (Exception ex) - { - _logger.LogError(ex, "Error while uninstalling agent."); - } - }); + _logger.LogInformation( + "Received request to wake device with MAC address {macAddress}.", + macAddress); + await _wakeOnLanService.WakeDevice(macAddress); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while waking device."); + } + } - _hubConnection.On("RemoteControl", async (string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId) => + private async Task CheckForServerMigration() + { + if (_connectionInfo is null || _hubConnection is null) { - try - { - if (!_isServerVerified) - { - _logger.LogWarning("Remote control attempted before server was verified."); - return; - } - await _appLauncher.LaunchRemoteControl(-1, sessionId, accessKey, userConnectionId, requesterName, orgName, orgId, _hubConnection); - } - catch (Exception ex) - { - _logger.LogError(ex, "Error while starting remote control."); - } - }); + return false; + } - _hubConnection.On("RestartScreenCaster", async (List viewerIDs, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId) => + var serverUrl = await _hubConnection.InvokeAsync("GetServerUrl"); + + if (Uri.TryCreate(serverUrl, UriKind.Absolute, out var serverUri) && + Uri.TryCreate(_connectionInfo.Host, UriKind.Absolute, out var savedUri) && + serverUri.Host != savedUri.Host) { - try - { - if (!_isServerVerified) - { - _logger.LogWarning("Remote control attempted before server was verified."); - return; - } - await _appLauncher.RestartScreenCaster(viewerIDs, sessionId, accessKey, userConnectionId, requesterName, orgName, orgId, _hubConnection); - } - catch (Exception ex) - { - _logger.LogError(ex, "Error while restarting screen caster."); - } - }); + _connectionInfo.Host = serverUrl.Trim().TrimEnd('/'); + _connectionInfo.ServerVerificationToken = null; + _configService.SaveConnectionInfo(_connectionInfo); + await _hubConnection.DisposeAsync(); + return true; + } + return false; + } + [MemberNotNull(nameof(_hubConnection))] + private void EnsureHubConnection() + { + if (_hubConnection is null || _hubConnection.State != HubConnectionState.Connected) + { + throw new InvalidOperationException("Hub connection is not established."); + } + } + private async void HeartbeatTimer_Elapsed(object? sender, ElapsedEventArgs e) + { + await SendHeartbeat(); + } - _hubConnection.On("RunScript", (Guid savedScriptId, int scriptRunId, string initiator, ScriptInputType scriptInputType, string authToken) => + private async Task HubConnection_Reconnected(string? arg) + { + if (_connectionInfo is null || _hubConnection is null) { - try - { - if (!_isServerVerified) - { - _logger.LogWarning( - "Script run attempted before server was verified. Script ID: {savedScriptId}. Initiator: {initiator}", - savedScriptId, - initiator); - return; - } + return; + } + + _logger.LogInformation("Reconnected to server."); + await _updater.CheckForUpdates(); - _ = _scriptExecutor.RunScript(savedScriptId, scriptRunId, initiator, scriptInputType, authToken); + var device = await _deviceInfoService.CreateDevice(_connectionInfo.DeviceID, $"{_connectionInfo.OrganizationID}"); - } - catch (Exception ex) - { - _logger.LogError(ex, "Error while running script."); - } - }); + if (!await _hubConnection.InvokeAsync("DeviceCameOnline", device)) + { + await Connect(); + return; + } - _hubConnection.On("TransferFileFromBrowserToAgent", async (string transferID, List fileIDs, string requesterID, string expiringToken) => + if (await CheckForServerMigration()) { - try - { - if (!_isServerVerified) - { - _logger.LogWarning("File upload attempted before server was verified."); - return; - } + await Connect(); + return; + } + } + + private void RegisterMessageHandlers() + { + if (_hubConnection is null) + { + throw new InvalidOperationException("Hub connection is null."); + } - _logger.LogInformation("File upload started by {requesterID}.", requesterID); + // TODO: Replace all these parameters with a single DTO per method. + _hubConnection.On( + nameof(ChangeWindowsSession), + ChangeWindowsSession); - var sharedFilePath = Directory.CreateDirectory(Path.Combine(Path.GetTempPath(), "RemotelySharedFiles")).FullName; + _hubConnection.On(nameof(SendChatMessage), SendChatMessage); - foreach (var fileID in fileIDs) - { - var url = $"{_connectionInfo?.Host}/API/FileSharing/{fileID}"; - using var client = _httpFactory.CreateClient(); - client.DefaultRequestHeaders.Add(AppConstants.ExpiringTokenHeaderName, expiringToken); - using var response = await client.GetAsync(url); + _hubConnection.On(nameof(InvokeCtrlAltDel), InvokeCtrlAltDel); - var filename = response.Content.Headers.ContentDisposition?.FileName ?? Path.GetRandomFileName(); - var invalidChars = Path.GetInvalidFileNameChars().ToHashSet(); - var legalChars = filename.ToCharArray().Where(x => !invalidChars.Contains(x)); + _hubConnection.On(nameof(DeleteLogs), DeleteLogs); - filename = new string(legalChars.ToArray()); + _hubConnection.On(nameof(ExecuteCommand), ExecuteCommand); - using var rs = await response.Content.ReadAsStreamAsync(); - using var fs = new FileStream(Path.Combine(sharedFilePath, filename), FileMode.Create); - rs.CopyTo(fs); - } - await _hubConnection.SendAsync("TransferCompleted", transferID, requesterID); - } - catch (Exception ex) - { - _logger.LogError(ex, "Error while transfering file from browser to agent."); - } - }); + _hubConnection.On(nameof(ExecuteCommandFromApi), ExecuteCommandFromApi); + + _hubConnection.On(nameof(GetLogs), GetLogs); - _hubConnection.On("TriggerHeartbeat", SendHeartbeat); + _hubConnection.On(nameof(GetPowerShellCompletions), GetPowerShellCompletions); - _hubConnection.On("WakeDevice", async (string macAddress) => - { - _logger.LogInformation( - "Received request to wake device with MAC address {macAddress}.", - macAddress); - await _wakeOnLanService.WakeDevice(macAddress); - }); + _hubConnection.On(nameof(ReinstallAgent), ReinstallAgent); + + _hubConnection.On(nameof(UninstallAgent), UninstallAgent); + + _hubConnection.On(nameof(RemoteControl), RemoteControl); + + _hubConnection.On( + nameof(RestartScreenCaster), + RestartScreenCaster); + + _hubConnection.On(nameof(RunScript), RunScript); + + _hubConnection.On( + nameof(TransferFileFromBrowserToAgent), + TransferFileFromBrowserToAgent); + + _hubConnection.On(nameof(TriggerHeartbeat), TriggerHeartbeat); + + _hubConnection.On(nameof(WakeDevice), WakeDevice); } private async Task VerifyServer() @@ -577,7 +683,7 @@ private class RetryPolicy : IRetryPolicy { private readonly ILogger _logger; - public RetryPolicy(ILogger logger) + public RetryPolicy(ILogger logger) { _logger = logger; } diff --git a/Agent/Services/CpuUtilizationSampler.cs b/Agent/Services/CpuUtilizationSampler.cs index 13f4da357..b5fe4c292 100644 --- a/Agent/Services/CpuUtilizationSampler.cs +++ b/Agent/Services/CpuUtilizationSampler.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; +using Remotely.Shared.Services; using System; using System.Collections.Generic; using System.Diagnostics; @@ -18,11 +19,15 @@ public interface ICpuUtilizationSampler : IHostedService internal class CpuUtilizationSampler : BackgroundService, ICpuUtilizationSampler { private readonly HashSet _ignoredProcesses = new(); + private readonly IElevationDetector _elevationDetector; private readonly ILogger _logger; private double _currentUtilization; - public CpuUtilizationSampler(ILogger logger) + public CpuUtilizationSampler( + IElevationDetector elevationDetector, + ILogger logger) { + _elevationDetector = elevationDetector; _logger = logger; } @@ -58,6 +63,16 @@ private async Task GetCpuUtilization(CancellationToken cancelToken) var utilizations = new Dictionary>(); var processes = Process.GetProcesses(); + // If we're on Windows and not running in an elevated process, + // don't try to get CPU utilization for session 0 processes. It + // will throw. + if (OperatingSystem.IsWindows() && + !_elevationDetector.IsElevated()) + { + + processes = processes.Where(x => x.SessionId != 0).ToArray(); + } + foreach (var proc in processes) { if (cancelToken.IsCancellationRequested) diff --git a/Agent/Services/Linux/AppLauncherLinux.cs b/Agent/Services/Linux/AppLauncherLinux.cs index 53e3106c7..f1b588776 100644 --- a/Agent/Services/Linux/AppLauncherLinux.cs +++ b/Agent/Services/Linux/AppLauncherLinux.cs @@ -175,7 +175,7 @@ await hubConnection.SendAsync("DisplayMessage", } } - public async Task RestartScreenCaster(List viewerIDs, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId, HubConnection hubConnection, int targetSessionID = -1) + public async Task RestartScreenCaster(string[] viewerIds, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId, HubConnection hubConnection, int targetSessionID = -1) { try { @@ -192,7 +192,7 @@ public async Task RestartScreenCaster(List viewerIDs, string sessionId, } catch (Exception ex) { - await hubConnection.SendAsync("SendConnectionFailedToViewers", viewerIDs); + await hubConnection.SendAsync("SendConnectionFailedToViewers", viewerIds); _logger.LogError(ex, "Error while restarting screen caster."); throw; } diff --git a/Agent/Services/MacOS/AppLauncherMac.cs b/Agent/Services/MacOS/AppLauncherMac.cs index 2f8521008..238ba8615 100644 --- a/Agent/Services/MacOS/AppLauncherMac.cs +++ b/Agent/Services/MacOS/AppLauncherMac.cs @@ -19,7 +19,7 @@ public async Task LaunchRemoteControl(int targetSessionId, string sessionId, str await hubConnection.SendAsync("DisplayMessage", "Feature under development.", "Feature is under development.", "bg-warning", userConnectionId); } - public async Task RestartScreenCaster(List viewerIDs, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId, HubConnection hubConnection, int targetSessionID = -1) + public async Task RestartScreenCaster(string[] viewerIds, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId, HubConnection hubConnection, int targetSessionID = -1) { await hubConnection.SendAsync("DisplayMessage", "Feature under development.", "Feature is under development.", "bg-warning", userConnectionId); } diff --git a/Agent/Services/Windows/AppLauncherWin.cs b/Agent/Services/Windows/AppLauncherWin.cs index 335637fd5..cfef5df84 100644 --- a/Agent/Services/Windows/AppLauncherWin.cs +++ b/Agent/Services/Windows/AppLauncherWin.cs @@ -159,7 +159,7 @@ await hubConnection.SendAsync("DisplayMessage", userConnectionId); } } - public async Task RestartScreenCaster(List viewerIDs, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId, HubConnection hubConnection, int targetSessionID = -1) + public async Task RestartScreenCaster(string[] viewerIds, string sessionId, string accessKey, string userConnectionId, string requesterName, string orgName, string orgId, HubConnection hubConnection, int targetSessionID = -1) { try { @@ -179,7 +179,7 @@ public async Task RestartScreenCaster(List viewerIDs, string sessionId, $" --org-id \"{orgId}\"" + $" --session-id \"{sessionId}\"" + $" --access-key \"{accessKey}\"" + - $" --viewers {string.Join(",", viewerIDs)}", + $" --viewers {string.Join(",", viewerIds)}", targetSessionId: targetSessionID, forceConsoleSession: Shlwapi.IsOS(OsType.OS_ANYSERVER) && targetSessionID == -1, @@ -190,7 +190,7 @@ public async Task RestartScreenCaster(List viewerIDs, string sessionId, if (!result) { _logger.LogWarning("Failed to relaunch screen caster."); - await hubConnection.SendAsync("SendConnectionFailedToViewers", viewerIDs); + await hubConnection.SendAsync("SendConnectionFailedToViewers", viewerIds); await hubConnection.SendAsync("DisplayMessage", "Remote control failed to start on target device.", "Failed to start remote control.", @@ -209,12 +209,12 @@ await hubConnection.SendAsync("DisplayMessage", $" --org-id \"{orgId}\"" + $" --session-id \"{sessionId}\"" + $" --access-key \"{accessKey}\"" + - $" --viewers {string.Join(",", viewerIDs)}"); + $" --viewers {string.Join(",", viewerIds)}"); } } catch (Exception ex) { - await hubConnection.SendAsync("SendConnectionFailedToViewers", viewerIDs); + await hubConnection.SendAsync("SendConnectionFailedToViewers", viewerIds); _logger.LogError(ex, "Error while restarting screen caster."); throw; } diff --git a/Server/API/AgentUpdateController.cs b/Server/API/AgentUpdateController.cs index 9bd6900cb..ce47d82a0 100644 --- a/Server/API/AgentUpdateController.cs +++ b/Server/API/AgentUpdateController.cs @@ -6,6 +6,7 @@ using Remotely.Server.Hubs; using Remotely.Server.RateLimiting; using Remotely.Server.Services; +using Remotely.Shared.Interfaces; using System; using System.IO; using System.Linq; @@ -18,7 +19,7 @@ namespace Remotely.Server.API; [ApiController] public class AgentUpdateController : ControllerBase { - private readonly IHubContext _agentHubContext; + private readonly IHubContext _agentHubContext; private readonly ILogger _logger; private readonly IApplicationConfig _appConfig; private readonly IWebHostEnvironment _hostEnv; @@ -27,7 +28,7 @@ public class AgentUpdateController : ControllerBase public AgentUpdateController(IWebHostEnvironment hostingEnv, IApplicationConfig appConfig, IAgentHubSessionCache serviceSessionCache, - IHubContext agentHubContext, + IHubContext agentHubContext, ILogger logger) { _hostEnv = hostingEnv; @@ -111,8 +112,7 @@ private async Task CheckForDeviceBan(string deviceIp) var bannedDevices = _serviceSessionCache.GetAllDevices().Where(x => x.PublicIP == deviceIp); var connectionIds = _serviceSessionCache.GetConnectionIdsByDeviceIds(bannedDevices.Select(x => x.ID)); - - await _agentHubContext.Clients.Clients(connectionIds).SendAsync("UninstallAgent"); + await _agentHubContext.Clients.Clients(connectionIds).UninstallAgent(); return true; } diff --git a/Server/API/RemoteControlController.cs b/Server/API/RemoteControlController.cs index 4fc065aea..fe339e07f 100644 --- a/Server/API/RemoteControlController.cs +++ b/Server/API/RemoteControlController.cs @@ -14,6 +14,7 @@ using Microsoft.Extensions.Logging; using Remotely.Server.Extensions; using Remotely.Shared.Entities; +using Remotely.Shared.Interfaces; // For more information on enabling Web API for empty projects, visit https://go.microsoft.com/fwlink/?LinkID=397860 @@ -23,7 +24,7 @@ namespace Remotely.Server.API; [ApiController] public class RemoteControlController : ControllerBase { - private readonly IHubContext _serviceHub; + private readonly IHubContext _agentHub; private readonly IRemoteControlSessionCache _remoteControlSessionCache; private readonly IAgentHubSessionCache _serviceSessionCache; private readonly IApplicationConfig _appConfig; @@ -37,7 +38,7 @@ public RemoteControlController( SignInManager signInManager, IDataService dataService, IRemoteControlSessionCache remoteControlSessionCache, - IHubContext serviceHub, + IHubContext agentHub, IAgentHubSessionCache serviceSessionCache, IOtpProvider otpProvider, IHubEventHandler hubEvents, @@ -45,7 +46,7 @@ public RemoteControlController( ILogger logger) { _dataService = dataService; - _serviceHub = serviceHub; + _agentHub = agentHub; _remoteControlSessionCache = remoteControlSessionCache; _serviceSessionCache = serviceSessionCache; _appConfig = appConfig; @@ -179,7 +180,7 @@ private async Task InitiateRemoteControl(string deviceID, string return BadRequest("Failed to resolve organization name."); } - await _serviceHub.Clients.Client(serviceConnectionId).SendAsync("RemoteControl", + await _agentHub.Clients.Client(serviceConnectionId).RemoteControl( sessionId, accessKey, HttpContext.Connection.Id, diff --git a/Server/API/ScriptingController.cs b/Server/API/ScriptingController.cs index f9cf6cb9c..59c6f2710 100644 --- a/Server/API/ScriptingController.cs +++ b/Server/API/ScriptingController.cs @@ -13,6 +13,7 @@ using Remotely.Shared; using Remotely.Server.Extensions; using Remotely.Shared.Entities; +using Remotely.Shared.Interfaces; namespace Remotely.Server.API; @@ -20,7 +21,7 @@ namespace Remotely.Server.API; [Route("api/[controller]")] public class ScriptingController : ControllerBase { - private readonly IHubContext _agentHubContext; + private readonly IHubContext _agentHubContext; private readonly IDataService _dataService; private readonly IAgentHubSessionCache _serviceSessionCache; @@ -32,7 +33,7 @@ public ScriptingController(UserManager userManager, IDataService dataService, IAgentHubSessionCache serviceSessionCache, IExpiringTokenService expiringTokenService, - IHubContext agentHub) + IHubContext agentHub) { _dataService = dataService; _serviceSessionCache = serviceSessionCache; @@ -61,7 +62,6 @@ public async Task> ExecuteCommand(string mode, string command = await sr.ReadToEndAsync(); } - var userID = string.Empty; if (Request.HttpContext.User.Identity?.IsAuthenticated == true) { var username = Request.HttpContext.User.Identity.Name; @@ -97,7 +97,12 @@ public async Task> ExecuteCommand(string mode, string var authToken = _expiringTokenService.GetToken(Time.Now.AddMinutes(AppConstants.ScriptRunExpirationMinutes)); // TODO: Replace with new invoke capability in .NET 7. - await _agentHubContext.Clients.Client(connectionId).SendAsync("ExecuteCommandFromApi", shell, authToken, requestID, command, User?.Identity?.Name); + await _agentHubContext.Clients.Client(connectionId).ExecuteCommandFromApi( + shell, + authToken, + requestID, + command, + User?.Identity?.Name ?? "API Key"); var success = await WaitHelper.WaitForAsync(() => AgentHub.ApiScriptResults.TryGetValue(requestID, out _), TimeSpan.FromSeconds(30)); if (!success) diff --git a/Server/Hubs/AgentHub.cs b/Server/Hubs/AgentHub.cs index a4c42404d..6d19fb5eb 100644 --- a/Server/Hubs/AgentHub.cs +++ b/Server/Hubs/AgentHub.cs @@ -8,6 +8,7 @@ using Remotely.Shared.Dtos; using Remotely.Shared.Entities; using Remotely.Shared.Enums; +using Remotely.Shared.Interfaces; using Remotely.Shared.Models; using Remotely.Shared.Utilities; using System; @@ -17,7 +18,7 @@ namespace Remotely.Server.Hubs; -public class AgentHub : Hub +public class AgentHub : Hub { private readonly IApplicationConfig _appConfig; private readonly ICircuitManager _circuitManager; @@ -77,7 +78,13 @@ public Task Chat(string message, bool disconnected, string browserConnectionId) } else { - return Clients.Caller.SendAsync("Chat", string.Empty, string.Empty, string.Empty, true, browserConnectionId); + return Clients.Caller.SendChatMessage( + senderName: string.Empty, + message: string.Empty, + orgName: string.Empty, + orgId: string.Empty, + disconnected: true, + senderConnectionId: browserConnectionId); } } @@ -91,12 +98,17 @@ public async Task CheckForPendingScriptRuns() var authToken = _expiringTokenService.GetToken(Time.Now.AddMinutes(AppConstants.ScriptRunExpirationMinutes)); var scriptRuns = await _dataService.GetPendingScriptRuns(Device.ID); + foreach (var run in scriptRuns) { - await Clients.Caller.SendAsync("RunScript", - run.SavedScriptId, + if (run.SavedScriptId is null) + { + continue; + } + await Clients.Caller.RunScript( + run.SavedScriptId.Value, run.Id, - run.Initiator, + run.Initiator ?? "Unknown Initiator", run.InputType, authToken); } @@ -106,7 +118,7 @@ public async Task DeviceCameOnline(DeviceClientDto device) { try { - if (CheckForDeviceBan(device.ID, device.DeviceName)) + if (await CheckForDeviceBan(device.ID, device.DeviceName)) { return false; } @@ -118,7 +130,7 @@ public async Task DeviceCameOnline(DeviceClientDto device) } device.PublicIP = $"{ip}"; - if (CheckForDeviceBan(device.PublicIP)) + if (await CheckForDeviceBan(device.PublicIP)) { return false; } @@ -161,7 +173,7 @@ public async Task DeviceCameOnline(DeviceClientDto device) public async Task DeviceHeartbeat(DeviceClientDto device) { - if (CheckForDeviceBan(device.ID, device.DeviceName)) + if (await CheckForDeviceBan(device.ID, device.DeviceName)) { return; } @@ -173,7 +185,7 @@ public async Task DeviceHeartbeat(DeviceClientDto device) } device.PublicIP = $"{ip}"; - if (CheckForDeviceBan(device.PublicIP)) + if (await CheckForDeviceBan(device.PublicIP)) { return; } @@ -308,7 +320,7 @@ public Task TransferCompleted(string transferID, string requesterID) { return _circuitManager.InvokeOnConnection(requesterID, CircuitEventName.TransferCompleted, transferID); } - private bool CheckForDeviceBan(params string[] deviceIdNameOrIPs) + private async Task CheckForDeviceBan(params string[] deviceIdNameOrIPs) { foreach (var device in deviceIdNameOrIPs) { @@ -322,7 +334,7 @@ private bool CheckForDeviceBan(params string[] deviceIdNameOrIPs) { _logger.LogWarning("Device ID/name/IP ({device}) is banned. Sending uninstall command.", device); - _ = Clients.Caller.SendAsync("UninstallAgent"); + await Clients.Caller.UninstallAgent(); return true; } } diff --git a/Server/Hubs/CircuitConnection.cs b/Server/Hubs/CircuitConnection.cs index e9e4886b3..1c07dc2cc 100644 --- a/Server/Hubs/CircuitConnection.cs +++ b/Server/Hubs/CircuitConnection.cs @@ -14,6 +14,7 @@ using Remotely.Shared; using Remotely.Shared.Entities; using Remotely.Shared.Enums; +using Remotely.Shared.Interfaces; using Remotely.Shared.Utilities; using System; using System.Collections.Concurrent; @@ -54,7 +55,6 @@ public interface ICircuitConnection Task UninstallAgents(string[] deviceIDs); Task UpdateTags(string deviceID, string tags); - Task UploadFiles(List fileIDs, string transferID, string[] deviceIDs); /// /// Sends a Wake-On-LAN request for the specified device to its peer devices. @@ -73,7 +73,7 @@ public interface ICircuitConnection public class CircuitConnection : CircuitHandler, ICircuitConnection { - private readonly IHubContext _agentHubContext; + private readonly IHubContext _agentHubContext; private readonly IApplicationConfig _appConfig; private readonly IClientAppState _appState; private readonly IAuthService _authService; @@ -92,7 +92,7 @@ public CircuitConnection( IAuthService authService, IDataService dataService, IClientAppState appState, - IHubContext agentHubContext, + IHubContext agentHubContext, IApplicationConfig appConfig, ICircuitManager circuitManager, IToastService toastService, @@ -154,10 +154,10 @@ public Task DeleteRemoteLogs(string deviceId) deviceId, User?.UserName); - return _agentHubContext.Clients.Client(key).SendAsync("DeleteLogs"); + return _agentHubContext.Clients.Client(key).DeleteLogs(); } - public Task ExecuteCommandOnAgent(ScriptingShell shell, string command, string[] deviceIDs) + public async Task ExecuteCommandOnAgent(ScriptingShell shell, string command, string[] deviceIDs) { deviceIDs = _dataService.FilterDeviceIdsByUserPermission(deviceIDs, User); var connections = GetActiveConnectionsForUserOrg(deviceIDs); @@ -170,17 +170,12 @@ public Task ExecuteCommandOnAgent(ScriptingShell shell, string command, string[] var authTokenForUploadingResults = _expiringTokenService.GetToken(Time.Now.AddMinutes(5)); - foreach (var connection in connections) - { - _agentHubContext.Clients.Client(connection).SendAsync("ExecuteCommand", - shell, - command, - authTokenForUploadingResults, - User.UserName, - ConnectionId); - } - - return Task.CompletedTask; + await _agentHubContext.Clients.Clients(connections).ExecuteCommand( + shell, + command, + authTokenForUploadingResults, + $"{User.UserName}", + ConnectionId); } public Task GetPowerShellCompletions(string inputText, int currentIndex, CompletionIntent intent, bool? forward) @@ -197,7 +192,12 @@ public Task GetPowerShellCompletions(string inputText, int currentIndex, Complet return Task.CompletedTask; } - return _agentHubContext.Clients.Client(key).SendAsync("GetPowerShellCompletions", inputText, currentIndex, intent, forward, ConnectionId); + return _agentHubContext.Clients.Client(key).GetPowerShellCompletions( + inputText, + currentIndex, + intent, + forward, + ConnectionId); } public Task GetRemoteLogs(string deviceId) @@ -209,7 +209,7 @@ public Task GetRemoteLogs(string deviceId) return Task.CompletedTask; } - return _agentHubContext.Clients.Client(key).SendAsync("GetLogs", ConnectionId); + return _agentHubContext.Clients.Client(key).GetLogs(ConnectionId); } public Task InvokeCircuitEvent(CircuitEventName eventName, params object[] args) @@ -244,16 +244,12 @@ public override async Task OnCircuitOpenedAsync(Circuit circuit, CancellationTok await base.OnCircuitOpenedAsync(circuit, cancellationToken); } - public Task ReinstallAgents(string[] deviceIDs) + public async Task ReinstallAgents(string[] deviceIDs) { deviceIDs = _dataService.FilterDeviceIdsByUserPermission(deviceIDs, User); var connections = GetActiveConnectionsForUserOrg(deviceIDs); - foreach (var connection in connections) - { - _agentHubContext.Clients.Client(connection).SendAsync("ReinstallAgent"); - } + await _agentHubContext.Clients.Clients(connections).ReinstallAgent(); _dataService.RemoveDevices(deviceIDs); - return Task.CompletedTask; } public async Task> RemoteControl(string deviceId, bool viewOnly) @@ -266,7 +262,7 @@ public async Task> RemoteControl(string deviceId, "bg-warning")); return Result.Fail("Device is not online."); } - + if (!_dataService.DoesUserHaveAccessToDevice(deviceId, User)) { @@ -326,11 +322,11 @@ public async Task> RemoteControl(string deviceId, return Result.Fail(orgResult.Reason); } - await _agentHubContext.Clients.Client(serviceConnectionId).SendAsync("RemoteControl", + await _agentHubContext.Clients.Client(serviceConnectionId).RemoteControl( sessionId, accessKey, ConnectionId, - User.UserOptions?.DisplayName, + $"{User.UserOptions?.DisplayName}", orgResult.Value, User.OrganizationID); @@ -349,10 +345,10 @@ public Task RemoveDevices(string[] deviceIDs) } public async Task RunScript( - IEnumerable deviceIds, - Guid savedScriptId, - int scriptRunId, - ScriptInputType scriptInputType, + IEnumerable deviceIds, + Guid savedScriptId, + int scriptRunId, + ScriptInputType scriptInputType, bool runAsHostedService) { var username = string.Empty; @@ -365,14 +361,19 @@ public async Task RunScript( username = User.UserName; deviceIds = _dataService.FilterDeviceIdsByUserPermission(deviceIds.ToArray(), User); } - + var authToken = _expiringTokenService.GetToken(Time.Now.AddMinutes(AppConstants.ScriptRunExpirationMinutes)); var connectionIds = _agentSessionCache.GetConnectionIdsByDeviceIds(deviceIds).ToArray(); if (connectionIds.Any()) { - await _agentHubContext.Clients.Clients(connectionIds).SendAsync("RunScript", savedScriptId, scriptRunId, username, scriptInputType, authToken); + await _agentHubContext.Clients.Clients(connectionIds).RunScript( + savedScriptId, + scriptRunId, + $"{username}", + scriptInputType, + authToken); } } @@ -404,8 +405,8 @@ public async Task SendChat(string message, string deviceId) return; } - await _agentHubContext.Clients.Client(connectionId).SendAsync("Chat", - User.UserOptions?.DisplayName ?? User.UserName, + await _agentHubContext.Clients.Client(connectionId).SendChatMessage( + User.UserOptions?.DisplayName ?? $"{User.UserName}", message, orgResult.Value, User.OrganizationID, @@ -419,7 +420,7 @@ public async Task TransferFileFromBrowserToAgent(string deviceId, string t { return false; } - + if (!_dataService.DoesUserHaveAccessToDevice(deviceId, User)) { _logger.LogWarning("User {username} does not have access to device ID {deviceId} and attempted file upload.", @@ -431,12 +432,13 @@ public async Task TransferFileFromBrowserToAgent(string deviceId, string t var authToken = _expiringTokenService.GetToken(Time.Now.AddMinutes(5)); - await _agentHubContext.Clients.Client(connectionId).SendAsync( - "TransferFileFromBrowserToAgent", - transferId, - fileIds, - ConnectionId, - authToken); + await _agentHubContext.Clients + .Client(connectionId) + .TransferFileFromBrowserToAgent( + transferId, + fileIds, + ConnectionId, + authToken); return true; } @@ -450,17 +452,14 @@ public async Task TriggerHeartbeat(string deviceId) return; } - await _agentHubContext.Clients.Client(connectionId).SendAsync("TriggerHeartbeat"); + await _agentHubContext.Clients.Client(connectionId).TriggerHeartbeat(); } public async Task UninstallAgents(string[] deviceIDs) { deviceIDs = _dataService.FilterDeviceIdsByUserPermission(deviceIDs, User); var connections = GetActiveConnectionsForUserOrg(deviceIDs); - foreach (var connection in connections) - { - await _agentHubContext.Clients.Client(connection).SendAsync("UninstallAgent"); - } + await _agentHubContext.Clients.Clients(connections).UninstallAgent(); _dataService.RemoveDevices(deviceIDs); } @@ -484,29 +483,13 @@ public Task UpdateTags(string deviceID, string tags) return Task.CompletedTask; } - public Task UploadFiles(List fileIDs, string transferID, string[] deviceIDs) - { - _logger.LogInformation( - "File transfer started by {userName}. File transfer IDs: {fileIds}.", - User.UserName, - string.Join(", ", fileIDs)); - - deviceIDs = _dataService.FilterDeviceIdsByUserPermission(deviceIDs, User); - var connections = GetActiveConnectionsForUserOrg(deviceIDs); - foreach (var connection in connections) - { - _agentHubContext.Clients.Client(connection).SendAsync("UploadFiles", transferID, fileIDs, ConnectionId); - } - return Task.CompletedTask; - } - public async Task WakeDevice(Device device) { try { if (!_dataService.DoesUserHaveAccessToDevice(device.ID, User.Id)) { - return Result.Fail("Unauthorized.") ; + return Result.Fail("Unauthorized."); } var availableDevices = _agentSessionCache @@ -658,7 +641,7 @@ private async Task SendWakeCommand(Device deviceToWake, IEnumerable peer peerDevice.DeviceName, peerDevice.ID, User.UserName); - await _agentHubContext.Clients.Client(connectionId).SendAsync("WakeDevice", mac); + await _agentHubContext.Clients.Client(connectionId).WakeDevice(mac); } } } diff --git a/Server/Pages/ServerConfig.razor.cs b/Server/Pages/ServerConfig.razor.cs index 0f1cccee1..83f6fb586 100644 --- a/Server/Pages/ServerConfig.razor.cs +++ b/Server/Pages/ServerConfig.razor.cs @@ -11,6 +11,7 @@ using Remotely.Server.Services; using Remotely.Shared.Entities; using Remotely.Shared.Enums; +using Remotely.Shared.Interfaces; using System; using System.Collections.Generic; using System.ComponentModel.DataAnnotations; @@ -144,7 +145,7 @@ public partial class ServerConfig : AuthComponentBase [Inject] - private IHubContext AgentHubContext { get; init; } = null!; + private IHubContext AgentHubContext { get; init; } = null!; [Inject] private IConfiguration Configuration { get; init; } = null!; @@ -468,7 +469,7 @@ private async Task UpdateAllDevices() var agentConnections = ServiceSessionCache.GetConnectionIdsByDeviceIds(OutdatedDevices); - await AgentHubContext.Clients.Clients(agentConnections).SendAsync("ReinstallAgent"); + await AgentHubContext.Clients.Clients(agentConnections).ReinstallAgent(); ToastService.ShowToast("Update command sent."); } } diff --git a/Server/Services/RcImplementations/HubEventHandler.cs b/Server/Services/RcImplementations/HubEventHandler.cs index f72073222..499cdcd32 100644 --- a/Server/Services/RcImplementations/HubEventHandler.cs +++ b/Server/Services/RcImplementations/HubEventHandler.cs @@ -8,10 +8,12 @@ using Remotely.Server.Hubs; using Remotely.Server.Models; using Remotely.Shared.Enums; +using Remotely.Shared.Interfaces; using Remotely.Shared.Models; using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -19,11 +21,11 @@ namespace Remotely.Server.Services.RcImplementations; public class HubEventHandler : IHubEventHandler { - private readonly IHubContext _serviceHub; + private readonly IHubContext _serviceHub; private readonly ILogger _logger; public HubEventHandler( - IHubContext serviceHub, + IHubContext serviceHub, ILogger logger) { _serviceHub = serviceHub; @@ -40,9 +42,9 @@ public Task ChangeWindowsSession(RemoteControlSession session, string viewerConn return _serviceHub.Clients .Client(ex.AgentConnectionId) - .SendAsync("ChangeWindowsSession", + .ChangeWindowsSession( viewerConnectionId, - ex.UnattendedSessionId, + $"{ex.UnattendedSessionId}", ex.AccessKey, ex.UserConnectionId, ex.RequesterUserName, @@ -59,7 +61,7 @@ public Task InvokeCtrlAltDel(RemoteControlSession session, string viewerConnecti return Task.CompletedTask; } - return _serviceHub.Clients.Client(ex.AgentConnectionId).SendAsync("CtrlAltDel"); + return _serviceHub.Clients.Client(ex.AgentConnectionId).InvokeCtrlAltDel(); } public Task NotifyDesktopSessionAdded(RemoteControlSession sessionInfo) @@ -105,9 +107,9 @@ public Task NotifySessionChanged(RemoteControlSession session, SessionSwitchReas return _serviceHub.Clients .Client(ex.AgentConnectionId) - .SendAsync("RestartScreenCaster", - ex.ViewerList, - ex.UnattendedSessionId, + .RestartScreenCaster( + ex.ViewerList.ToArray(), + $"{ex.UnattendedSessionId}", ex.AccessKey, ex.UserConnectionId, ex.RequesterUserName, @@ -126,9 +128,9 @@ public Task RestartScreenCaster(RemoteControlSession session, HashSet vi return _serviceHub.Clients .Client(ex.AgentConnectionId) - .SendAsync("RestartScreenCaster", - viewerList, - ex.UnattendedSessionId, + .RestartScreenCaster( + viewerList.ToArray(), + $"{ex.UnattendedSessionId}", ex.AccessKey, ex.UserConnectionId, ex.RequesterName, diff --git a/Shared/Interfaces/IAgentHubClient.cs b/Shared/Interfaces/IAgentHubClient.cs new file mode 100644 index 000000000..78c283bcf --- /dev/null +++ b/Shared/Interfaces/IAgentHubClient.cs @@ -0,0 +1,92 @@ +using Remotely.Shared.Enums; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Remotely.Shared.Interfaces; +public interface IAgentHubClient +{ + Task ChangeWindowsSession( + string viewerConnectionId, + string sessionId, + string accessKey, + string userConnectionId, + string requesterName, + string orgName, + string orgId, + int targetSessionId); + + Task SendChatMessage( + string senderName, + string message, + string orgName, + string orgId, + bool disconnected, + string senderConnectionId); + + Task InvokeCtrlAltDel(); + + Task DeleteLogs(); + + Task ExecuteCommand( + ScriptingShell shell, + string command, + string authToken, + string senderUsername, + string senderConnectionId); + + Task ExecuteCommandFromApi(ScriptingShell shell, + string authToken, + string requestID, + string command, + string senderUsername); + + Task GetLogs(string senderConnectionId); + + Task GetPowerShellCompletions( + string inputText, + int currentIndex, + CompletionIntent intent, + bool? forward, + string senderConnectionId); + + Task ReinstallAgent(); + + Task UninstallAgent(); + + Task RemoteControl( + Guid sessionId, + string accessKey, + string userConnectionId, + string requesterName, + string orgName, + string orgId); + + Task RestartScreenCaster( + string[] viewerIds, + string sessionId, + string accessKey, + string userConnectionId, + string requesterName, + string orgName, + string orgId); + + Task RunScript( + Guid savedScriptId, + int scriptRunId, + string initiator, + ScriptInputType scriptInputType, + string authToken); + + Task TransferFileFromBrowserToAgent( + string transferId, + string[] fileIds, + string requesterId, + string expiringToken); + + Task TriggerHeartbeat(); + + Task WakeDevice(string macAddress); +} diff --git a/Tests/Server.Tests/AgentHubTests.cs b/Tests/Server.Tests/AgentHubTests.cs index e5a7283bf..474c1c06a 100644 --- a/Tests/Server.Tests/AgentHubTests.cs +++ b/Tests/Server.Tests/AgentHubTests.cs @@ -7,17 +7,12 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using Moq; using Remotely.Server.Hubs; -using Remotely.Server.Models; using Remotely.Server.Services; -using Remotely.Shared.Dtos; -using Remotely.Shared.Models; -using System; using System.Collections.Generic; -using System.Linq; using System.Security.Claims; -using System.Text; using System.Threading; using System.Threading.Tasks; +using Remotely.Shared.Interfaces; namespace Remotely.Tests; @@ -52,14 +47,14 @@ public async Task DeviceCameOnline_BannedByName() expiringTokenService.Object, logger.Object); - var hubClients = new Mock(); - var caller = new Mock(); + var hubClients = new Mock>(); + var caller = new Mock(); hubClients.Setup(x => x.Caller).Returns(caller.Object); hub.Clients = hubClients.Object; Assert.IsFalse(await hub.DeviceCameOnline(_testData.Org1Device1.ToDto())); hubClients.Verify(x => x.Caller, Times.Once); - caller.Verify(x => x.SendCoreAsync("UninstallAgent", It.IsAny(), It.IsAny()), Times.Once); + caller.Verify(x => x.UninstallAgent(), Times.Once); } // TODO: Checking of device ban should be pulled out into @@ -89,14 +84,14 @@ public async Task DeviceCameOnline_BannedById() expiringTokenService.Object, logger.Object); - var hubClients = new Mock(); - var caller = new Mock(); + var hubClients = new Mock>(); + var caller = new Mock(); hubClients.Setup(x => x.Caller).Returns(caller.Object); hub.Clients = hubClients.Object; Assert.IsFalse(await hub.DeviceCameOnline(_testData.Org1Device1.ToDto())); hubClients.Verify(x => x.Caller, Times.Once); - caller.Verify(x => x.SendCoreAsync("UninstallAgent", It.IsAny(), It.IsAny()), Times.Once); + caller.Verify(x => x.UninstallAgent(), Times.Once); } [TestCleanup] diff --git a/Tests/Server.Tests/CircuitConnectionTests.cs b/Tests/Server.Tests/CircuitConnectionTests.cs index 37a104085..77a5f5bd8 100644 --- a/Tests/Server.Tests/CircuitConnectionTests.cs +++ b/Tests/Server.Tests/CircuitConnectionTests.cs @@ -12,6 +12,7 @@ using Remotely.Server.Tests.Mocks; using Remotely.Shared.Dtos; using Remotely.Shared.Extensions; +using Remotely.Shared.Interfaces; using Remotely.Shared.Models; using Remotely.Tests; using System; @@ -30,7 +31,7 @@ public class CircuitConnectionTests private IDataService _dataService; private Mock _authService; private Mock _clientAppState; - private HubContextFixture _agentHubContextFixture; + private HubContextFixture _agentHubContextFixture; private Mock _appConfig; private Mock _circuitManager; private Mock _toastService; @@ -50,7 +51,7 @@ public async Task Init() _dataService = IoCActivator.ServiceProvider.GetRequiredService(); _authService = new Mock(); _clientAppState = new Mock(); - _agentHubContextFixture = new HubContextFixture(); + _agentHubContextFixture = new HubContextFixture(); _appConfig = new Mock(); _circuitManager = new Mock(); _toastService = new Mock(); @@ -174,11 +175,7 @@ public async Task WakeDevice_GivenMatchingPeerByIp_UsesCorrectPeer() _agentHubContextFixture.SingleClientProxyMock .Verify(x => - x.SendCoreAsync( - "WakeDevice", - new object[] { macAddress }, - default), - Times.Once); + x.WakeDevice(macAddress), Times.Once); _agentHubContextFixture.SingleClientProxyMock.VerifyNoOtherCalls(); _agentHubContextFixture.HubContextMock.VerifyNoOtherCalls(); @@ -247,11 +244,7 @@ public async Task WakeDevice_GivenMatchingPeerByGroupId_UsesCorrectPeer() _agentHubContextFixture.SingleClientProxyMock .Verify(x => - x.SendCoreAsync( - "WakeDevice", - new object[] { macAddress }, - default), - Times.Once); + x.WakeDevice(macAddress), Times.Once); _agentHubContextFixture.SingleClientProxyMock.VerifyNoOtherCalls(); _agentHubContextFixture.HubContextMock.VerifyNoOtherCalls(); @@ -387,11 +380,7 @@ public async Task WakeDevices_GivenPeerIpMatches_UsesCorrectPeer() _agentHubContextFixture.SingleClientProxyMock .Verify(x => - x.SendCoreAsync( - "WakeDevice", - new object[] { macAddress }, - default), - Times.Once); + x.WakeDevice(macAddress), Times.Once); _agentHubContextFixture.SingleClientProxyMock.VerifyNoOtherCalls(); _agentHubContextFixture.HubContextMock.VerifyNoOtherCalls(); @@ -468,11 +457,7 @@ public async Task WakeDevices_GivenMatchingPeerByGroupId_UsesCorrectPeer() _agentHubContextFixture.SingleClientProxyMock .Verify(x => - x.SendCoreAsync( - "WakeDevice", - new object[] { macAddress }, - default), - Times.Once); + x.WakeDevice(macAddress), Times.Once); _agentHubContextFixture.SingleClientProxyMock.VerifyNoOtherCalls(); _agentHubContextFixture.HubContextMock.VerifyNoOtherCalls(); diff --git a/Tests/Server.Tests/Mocks/HubContextFixture.cs b/Tests/Server.Tests/Mocks/HubContextFixture.cs index 85221ed50..e71f816d4 100644 --- a/Tests/Server.Tests/Mocks/HubContextFixture.cs +++ b/Tests/Server.Tests/Mocks/HubContextFixture.cs @@ -9,17 +9,19 @@ namespace Remotely.Server.Tests.Mocks; -public class HubContextFixture - where T : Hub +public class HubContextFixture + where THub : Hub + where THubClient : class { public HubContextFixture() { - HubContextMock = new Mock>(); - HubClientsMock = new Mock(); + + HubContextMock = new Mock>(); + HubClientsMock = new Mock>(); GroupManagerMock = new Mock(); - SingleClientProxyMock = new Mock(); - ClientProxyMock = new Mock(); - + SingleClientProxyMock = new Mock(); + ClientProxyMock = new Mock(); + HubContextMock .Setup(x => x.Clients) .Returns(HubClientsMock.Object); @@ -37,9 +39,9 @@ public HubContextFixture() .Returns(ClientProxyMock.Object); } - public Mock> HubContextMock { get; } - public Mock HubClientsMock { get; } + public Mock> HubContextMock { get; } + public Mock> HubClientsMock { get; } public Mock GroupManagerMock { get; } - public Mock SingleClientProxyMock { get; } - public Mock ClientProxyMock { get; } + public Mock SingleClientProxyMock { get; } + public Mock ClientProxyMock { get; } } diff --git a/submodules/Immense.RemoteControl b/submodules/Immense.RemoteControl index d571422a0..c2788d676 160000 --- a/submodules/Immense.RemoteControl +++ b/submodules/Immense.RemoteControl @@ -1 +1 @@ -Subproject commit d571422a0e282fd009769bf14c7f6adf8867b0d8 +Subproject commit c2788d6768e0f51ba1801fe774aa9b39d6254164