From 4fcf792c3d7efddc9cd261c6a9120531d9002bdc Mon Sep 17 00:00:00 2001 From: kangalio Date: Sat, 8 Apr 2023 18:27:08 +0200 Subject: [PATCH] Remove `ShardManagerMonitor` and other gateway cleanup (#2372) The bulk of this commit is removing `ShardManagerMonitor`. It was just a background task that received `ShardManagerMessage`'s and called the respective `ShardManager` function. Now, you can just call the respective `ShardManager` function directly. --- src/client/bridge/gateway/event.rs | 6 - src/client/bridge/gateway/mod.rs | 51 +--- src/client/bridge/gateway/shard_manager.rs | 87 +++--- .../bridge/gateway/shard_manager_monitor.rs | 112 -------- src/client/bridge/gateway/shard_messenger.rs | 37 ++- src/client/bridge/gateway/shard_queuer.rs | 20 +- src/client/bridge/gateway/shard_runner.rs | 249 +++++++----------- .../bridge/gateway/shard_runner_message.rs | 6 + src/client/bridge/voice/mod.rs | 4 +- src/client/context.rs | 7 +- src/client/dispatch.rs | 19 -- src/client/mod.rs | 24 +- src/gateway/mod.rs | 32 --- 13 files changed, 189 insertions(+), 465 deletions(-) delete mode 100644 src/client/bridge/gateway/shard_manager_monitor.rs diff --git a/src/client/bridge/gateway/event.rs b/src/client/bridge/gateway/event.rs index 2ee53ee4689..a1ef5cb3194 100644 --- a/src/client/bridge/gateway/event.rs +++ b/src/client/bridge/gateway/event.rs @@ -3,12 +3,6 @@ use super::ShardId; use crate::gateway::ConnectionStage; -#[allow(clippy::enum_variant_names)] -#[derive(Clone, Debug)] -pub(crate) enum ClientEvent { - ShardStageUpdate(ShardStageUpdateEvent), -} - /// An event denoting that a shard's connection stage was changed. /// /// # Examples diff --git a/src/client/bridge/gateway/mod.rs b/src/client/bridge/gateway/mod.rs index a8978d1ac29..484e2259356 100644 --- a/src/client/bridge/gateway/mod.rs +++ b/src/client/bridge/gateway/mod.rs @@ -17,7 +17,7 @@ //! ### [`ShardQueuer`] //! //! The shard queuer is a light wrapper around an mpsc receiver that receives -//! [`ShardManagerMessage`]s. It should be run in its own thread so it can receive messages to +//! [`ShardQueuerMessage`]s. It should be run in its own thread so it can receive messages to //! start shards in a queue. //! //! Refer to [its documentation][`ShardQueuer`] for more information. @@ -43,7 +43,6 @@ pub mod event; mod shard_manager; -mod shard_manager_monitor; mod shard_messenger; mod shard_queuer; mod shard_runner; @@ -53,7 +52,6 @@ use std::fmt; use std::time::Duration as StdDuration; pub use self::shard_manager::{ShardManager, ShardManagerOptions}; -pub use self::shard_manager_monitor::{ShardManagerError, ShardManagerMonitor}; pub use self::shard_messenger::ShardMessenger; pub use self::shard_queuer::ShardQueuer; pub use self::shard_runner::{ShardRunner, ShardRunnerOptions}; @@ -61,54 +59,7 @@ pub use self::shard_runner_message::{ChunkGuildFilter, ShardRunnerMessage}; use crate::gateway::ConnectionStage; use crate::model::event::Event; -/// A message either for a [`ShardManager`] or a [`ShardRunner`]. -#[derive(Debug)] -pub enum ShardClientMessage { - /// A message intended to be worked with by a [`ShardManager`]. - Manager(ShardManagerMessage), - /// A message intended to be worked with by a [`ShardRunner`]. - Runner(Box), -} - -/// A message for a [`ShardManager`] relating to an operation with a shard. -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub enum ShardManagerMessage { - /// Indicator that a [`ShardManagerMonitor`] should restart a shard. - Restart(ShardId), - /// An update from a shard runner, - ShardUpdate { id: ShardId, latency: Option, stage: ConnectionStage }, - /// Indicator that a [`ShardManagerMonitor`] should fully shutdown a shard without bringing it - /// back up. - Shutdown(ShardId, u16), - /// Indicator that a [`ShardManagerMonitor`] should fully shutdown all shards and end its - /// monitoring process for the [`ShardManager`]. - ShutdownAll, - /// Indicator that a [`ShardManager`] has initiated a shutdown, and for the component that - /// receives this to also shutdown with no further action taken. - ShutdownInitiated, - /// Indicator that a [`ShardRunner`] has finished the shutdown of a shard, allowing it to move - /// toward the next one. - ShutdownFinished(ShardId), - /// Indicator that a shard sent invalid authentication (a bad token) when identifying with the - /// gateway. Emitted when a shard receives an [`InvalidAuthentication`] Error - /// - /// [`InvalidAuthentication`]: crate::gateway::GatewayError::InvalidAuthentication - ShardInvalidAuthentication, - /// Indicator that a shard provided undocumented gateway intents. Emitted when a shard received - /// an [`InvalidGatewayIntents`] error. - /// - /// [`InvalidGatewayIntents`]: crate::gateway::GatewayError::InvalidGatewayIntents - ShardInvalidGatewayIntents, - /// If a connection has been established but privileged gateway intents were provided without - /// enabling them prior. Emitted when a shard received a [`DisallowedGatewayIntents`] error. - /// - /// [`DisallowedGatewayIntents`]: crate::gateway::GatewayError::DisallowedGatewayIntents - ShardDisallowedGatewayIntents, -} - /// A message to be sent to the [`ShardQueuer`]. -/// -/// This should usually be wrapped in a [`ShardClientMessage`]. #[derive(Clone, Debug)] pub enum ShardQueuerMessage { /// Message to start a shard, where the 0-index element is the ID of the Shard to start and the diff --git a/src/client/bridge/gateway/shard_manager.rs b/src/client/bridge/gateway/shard_manager.rs index 446d1767d38..f803638cb32 100644 --- a/src/client/bridge/gateway/shard_manager.rs +++ b/src/client/bridge/gateway/shard_manager.rs @@ -1,8 +1,9 @@ use std::collections::{HashMap, VecDeque}; use std::sync::Arc; +use std::time::Duration; use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender}; -use futures::StreamExt; +use futures::{SinkExt, StreamExt}; #[cfg(feature = "framework")] use once_cell::sync::OnceCell; use tokio::sync::{Mutex, RwLock}; @@ -10,14 +11,7 @@ use tokio::time::timeout; use tracing::{info, instrument, warn}; use typemap_rev::TypeMap; -use super::{ - ShardId, - ShardManagerMessage, - ShardManagerMonitor, - ShardQueuer, - ShardQueuerMessage, - ShardRunnerInfo, -}; +use super::{ShardId, ShardQueuer, ShardQueuerMessage, ShardRunnerInfo}; #[cfg(feature = "cache")] use crate::cache::Cache; #[cfg(feature = "voice")] @@ -25,7 +19,7 @@ use crate::client::bridge::voice::VoiceGatewayManager; use crate::client::{EventHandler, RawEventHandler}; #[cfg(feature = "framework")] use crate::framework::Framework; -use crate::gateway::PresenceData; +use crate::gateway::{ConnectionStage, GatewayError, PresenceData}; use crate::http::Http; use crate::internal::prelude::*; use crate::internal::tokio::spawn_named; @@ -103,7 +97,7 @@ use crate::model::gateway::GatewayIntents; /// [`Client`]: crate::Client #[derive(Debug)] pub struct ShardManager { - monitor_tx: Sender, + return_value_tx: Sender>, /// The shard runners currently managed. /// /// **Note**: It is highly unrecommended to mutate this yourself unless you need to. Instead @@ -117,6 +111,7 @@ pub struct ShardManager { shard_total: u32, shard_queuer: Sender, shard_shutdown: Receiver, + shard_shutdown_send: Sender, gateway_intents: GatewayIntents, } @@ -124,13 +119,25 @@ impl ShardManager { /// Creates a new shard manager, returning both the manager and a monitor for usage in a /// separate thread. #[must_use] - pub fn new(opt: ShardManagerOptions) -> (Arc>, ShardManagerMonitor) { - let (thread_tx, thread_rx) = mpsc::unbounded(); + pub fn new(opt: ShardManagerOptions) -> (Arc>, Receiver>) { + let (return_value_tx, return_value_rx) = mpsc::unbounded(); let (shard_queue_tx, shard_queue_rx) = mpsc::unbounded(); let runners = Arc::new(Mutex::new(HashMap::new())); let (shutdown_send, shutdown_recv) = mpsc::unbounded(); + let manager = Arc::new(Mutex::new(Self { + return_value_tx, + shard_index: opt.shard_index, + shard_init: opt.shard_init, + shard_queuer: shard_queue_tx, + shard_total: opt.shard_total, + shard_shutdown: shutdown_recv, + shard_shutdown_send: shutdown_send, + runners: Arc::clone(&runners), + gateway_intents: opt.intents, + })); + let mut shard_queuer = ShardQueuer { data: opt.data, event_handlers: opt.event_handlers, @@ -138,9 +145,9 @@ impl ShardManager { #[cfg(feature = "framework")] framework: opt.framework, last_start: None, - manager_tx: thread_tx.clone(), + manager: Arc::clone(&manager), queue: VecDeque::new(), - runners: Arc::clone(&runners), + runners, rx: shard_queue_rx, #[cfg(feature = "voice")] voice_manager: opt.voice_manager, @@ -156,22 +163,7 @@ impl ShardManager { shard_queuer.run().await; }); - let manager = Arc::new(Mutex::new(Self { - monitor_tx: thread_tx, - shard_index: opt.shard_index, - shard_init: opt.shard_init, - shard_queuer: shard_queue_tx, - shard_total: opt.shard_total, - shard_shutdown: shutdown_recv, - runners, - gateway_intents: opt.intents, - })); - - (Arc::clone(&manager), ShardManagerMonitor { - rx: thread_rx, - manager, - shutdown: shutdown_send, - }) + (Arc::clone(&manager), return_value_rx) } /// Returns whether the shard manager contains either an active instance of a shard runner @@ -324,7 +316,6 @@ impl ShardManager { } drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown)); - drop(self.monitor_tx.unbounded_send(ShardManagerMessage::ShutdownInitiated)); } #[instrument(skip(self))] @@ -341,6 +332,37 @@ impl ShardManager { pub fn intents(&self) -> GatewayIntents { self.gateway_intents } + + pub async fn return_with_value(&mut self, ret: Result<(), GatewayError>) { + if let Err(e) = self.return_value_tx.send(ret).await { + tracing::warn!("failed to send return value: {}", e); + } + } + + pub fn shutdown_finished(&self, id: ShardId) { + if let Err(e) = self.shard_shutdown_send.unbounded_send(id) { + tracing::warn!("failed to notify about finished shutdown: {}", e); + } + } + + pub async fn restart_shard(&mut self, id: ShardId) { + self.restart(id).await; + if let Err(e) = self.shard_shutdown_send.unbounded_send(id) { + tracing::warn!("failed to notify about finished shutdown: {}", e); + } + } + + pub async fn shard_update( + &self, + id: ShardId, + latency: Option, + stage: ConnectionStage, + ) { + if let Some(runner) = self.runners.lock().await.get_mut(&id) { + runner.latency = latency; + runner.stage = stage; + } + } } impl Drop for ShardManager { @@ -352,7 +374,6 @@ impl Drop for ShardManager { /// [`ShardRunner`]: super::ShardRunner fn drop(&mut self) { drop(self.shard_queuer.unbounded_send(ShardQueuerMessage::Shutdown)); - drop(self.monitor_tx.unbounded_send(ShardManagerMessage::ShutdownInitiated)); } } diff --git a/src/client/bridge/gateway/shard_manager_monitor.rs b/src/client/bridge/gateway/shard_manager_monitor.rs deleted file mode 100644 index 53bd37ee6e2..00000000000 --- a/src/client/bridge/gateway/shard_manager_monitor.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::sync::Arc; - -use futures::channel::mpsc::{UnboundedReceiver as Receiver, UnboundedSender as Sender}; -use futures::StreamExt; -use tokio::sync::Mutex; -use tracing::{debug, instrument, warn}; - -use super::{ShardManager, ShardManagerMessage}; -use crate::client::bridge::gateway::ShardId; - -/// The shard manager monitor monitors the shard manager and performs actions on it as received. -/// -/// The monitor is essentially responsible for running in its own task and receiving -/// [`ShardManagerMessage`]s, such as whether to shutdown a shard or shutdown everything entirely. -#[derive(Debug)] -pub struct ShardManagerMonitor { - /// An clone of the Arc to the manager itself. - pub manager: Arc>, - /// The mpsc Receiver channel to receive shard manager messages over. - pub rx: Receiver, - /// The mpsc Sender channel to inform the manager that a shard has just properly shut down - pub shutdown: Sender, -} -#[derive(Debug)] -pub enum ShardManagerError { - /// Returned when a shard received an [`InvalidAuthentication`] error. An invalid token has - /// been specified. - /// - /// [`InvalidAuthentication`]: crate::gateway::GatewayError::InvalidAuthentication - InvalidToken, - /// Returned when a shard received an [`InvalidGatewayIntents`] error. - /// - /// [`InvalidGatewayIntents`]: crate::gateway::GatewayError::InvalidGatewayIntents - InvalidGatewayIntents, - /// Returned when a shard received a [`DisallowedGatewayIntents`] error. - /// - /// [`DisallowedGatewayIntents`]: crate::gateway::GatewayError::DisallowedGatewayIntents - DisallowedGatewayIntents, -} - -type Result = std::result::Result; - -impl ShardManagerMonitor { - /// "Runs" the monitor, waiting for messages over the Receiver. - /// - /// This should be called in its own thread due to its blocking, looped nature. - /// - /// This will continue running until either: - /// - a [`ShardManagerMessage::ShutdownAll`] has been received - /// - an error is returned while receiving a message from the channel (probably indicating that - /// the shard manager should stop anyway) - #[instrument(skip(self))] - pub async fn run(&mut self) -> Result<()> { - debug!("Starting shard manager worker"); - - while let Some(value) = self.rx.next().await { - match value { - ShardManagerMessage::Restart(shard_id) => { - self.manager.lock().await.restart(shard_id).await; - drop(self.shutdown.unbounded_send(shard_id)); - }, - ShardManagerMessage::ShardUpdate { - id, - latency, - stage, - } => { - let manager = self.manager.lock().await; - let mut runners = manager.runners.lock().await; - - if let Some(runner) = runners.get_mut(&id) { - runner.latency = latency; - runner.stage = stage; - } - }, - ShardManagerMessage::Shutdown(shard_id, code) => { - self.manager.lock().await.shutdown(shard_id, code).await; - drop(self.shutdown.unbounded_send(shard_id)); - }, - ShardManagerMessage::ShutdownAll => { - self.manager.lock().await.shutdown_all().await; - - break; - }, - ShardManagerMessage::ShutdownInitiated => break, - ShardManagerMessage::ShutdownFinished(shard_id) => { - if let Err(why) = self.shutdown.unbounded_send(shard_id) { - warn!( - "[ShardMonitor] Could not forward Shutdown signal to ShardManager for shard {}: {:#?}", - shard_id, - why - ); - } - }, - ShardManagerMessage::ShardInvalidAuthentication => { - self.manager.lock().await.shutdown_all().await; - return Err(ShardManagerError::InvalidToken); - }, - - ShardManagerMessage::ShardInvalidGatewayIntents => { - self.manager.lock().await.shutdown_all().await; - return Err(ShardManagerError::InvalidGatewayIntents); - }, - ShardManagerMessage::ShardDisallowedGatewayIntents => { - self.manager.lock().await.shutdown_all().await; - return Err(ShardManagerError::DisallowedGatewayIntents); - }, - } - } - - Ok(()) - } -} diff --git a/src/client/bridge/gateway/shard_messenger.rs b/src/client/bridge/gateway/shard_messenger.rs index c1185d15d09..e2a8b0650f4 100644 --- a/src/client/bridge/gateway/shard_messenger.rs +++ b/src/client/bridge/gateway/shard_messenger.rs @@ -1,10 +1,10 @@ -use futures::channel::mpsc::{TrySendError, UnboundedSender as Sender}; +use futures::channel::mpsc::UnboundedSender as Sender; use tokio_tungstenite::tungstenite::Message; #[cfg(feature = "collector")] use super::CollectorCallback; -use super::{ChunkGuildFilter, ShardClientMessage, ShardRunnerMessage}; -use crate::gateway::{ActivityData, InterMessage}; +use super::{ChunkGuildFilter, ShardRunnerMessage}; +use crate::gateway::ActivityData; use crate::model::prelude::*; /// A lightweight wrapper around an mpsc sender. @@ -16,7 +16,7 @@ use crate::model::prelude::*; /// [`ShardRunner`]: super::ShardRunner #[derive(Clone, Debug)] pub struct ShardMessenger { - pub(crate) tx: Sender, + pub(crate) tx: Sender, } impl ShardMessenger { @@ -27,7 +27,7 @@ impl ShardMessenger { /// [`Client`]: crate::Client #[inline] #[must_use] - pub const fn new(tx: Sender) -> Self { + pub const fn new(tx: Sender) -> Self { Self { tx, } @@ -111,12 +111,12 @@ impl ShardMessenger { filter: ChunkGuildFilter, nonce: Option, ) { - drop(self.send_to_shard(ShardRunnerMessage::ChunkGuild { + self.send_to_shard(ShardRunnerMessage::ChunkGuild { guild_id, limit, filter, nonce, - })); + }); } /// Sets the user's current activity, if any. @@ -149,7 +149,7 @@ impl ShardMessenger { /// # } /// ``` pub fn set_activity(&self, activity: Option) { - drop(self.send_to_shard(ShardRunnerMessage::SetActivity(activity))); + self.send_to_shard(ShardRunnerMessage::SetActivity(activity)); } /// Sets the user's full presence information. @@ -188,7 +188,7 @@ impl ShardMessenger { status = OnlineStatus::Invisible; } - drop(self.send_to_shard(ShardRunnerMessage::SetPresence(activity, status))); + self.send_to_shard(ShardRunnerMessage::SetPresence(activity, status)); } /// Sets the user's current online status. @@ -232,12 +232,12 @@ impl ShardMessenger { online_status = OnlineStatus::Invisible; } - drop(self.send_to_shard(ShardRunnerMessage::SetStatus(online_status))); + self.send_to_shard(ShardRunnerMessage::SetStatus(online_status)); } /// Shuts down the websocket by attempting to cleanly close the connection. pub fn shutdown_clean(&self) { - drop(self.send_to_shard(ShardRunnerMessage::Close(1000, None))); + self.send_to_shard(ShardRunnerMessage::Close(1000, None)); } /// Sends a raw message over the WebSocket. @@ -247,23 +247,20 @@ impl ShardMessenger { /// You should only use this if you know what you're doing. If you're wanting to, for example, /// send a presence update, prefer the usage of the [`Self::set_presence`] method. pub fn websocket_message(&self, message: Message) { - drop(self.send_to_shard(ShardRunnerMessage::Message(message))); + self.send_to_shard(ShardRunnerMessage::Message(message)); } /// Sends a message to the shard. - /// - /// # Errors - /// - /// Returns a [`TrySendError`] if the shard's receiver was closed. #[inline] - pub fn send_to_shard(&self, msg: ShardRunnerMessage) -> Result<(), TrySendError> { - // TODO: don't propagate send error but handle here directly via a tracing::warn - self.tx.unbounded_send(InterMessage::Client(ShardClientMessage::Runner(Box::new(msg)))) + pub fn send_to_shard(&self, msg: ShardRunnerMessage) { + if let Err(e) = self.tx.unbounded_send(msg) { + tracing::warn!("failed to send ShardRunnerMessage to shard: {}", e); + } } #[cfg(feature = "collector")] pub fn add_collector(&self, collector: CollectorCallback) { - drop(self.send_to_shard(ShardRunnerMessage::AddCollector(collector))); + self.send_to_shard(ShardRunnerMessage::AddCollector(collector)); } } diff --git a/src/client/bridge/gateway/shard_queuer.rs b/src/client/bridge/gateway/shard_queuer.rs index ffa3af450fa..e470216bff4 100644 --- a/src/client/bridge/gateway/shard_queuer.rs +++ b/src/client/bridge/gateway/shard_queuer.rs @@ -1,7 +1,7 @@ use std::collections::{HashMap, VecDeque}; use std::sync::Arc; -use futures::channel::mpsc::{UnboundedReceiver as Receiver, UnboundedSender as Sender}; +use futures::channel::mpsc::UnboundedReceiver as Receiver; use futures::StreamExt; #[cfg(feature = "framework")] use once_cell::sync::OnceCell; @@ -11,9 +11,8 @@ use tracing::{debug, info, instrument, warn}; use typemap_rev::TypeMap; use super::{ - ShardClientMessage, ShardId, - ShardManagerMessage, + ShardManager, ShardMessenger, ShardQueuerMessage, ShardRunner, @@ -22,12 +21,13 @@ use super::{ }; #[cfg(feature = "cache")] use crate::cache::Cache; +use crate::client::bridge::gateway::ShardRunnerMessage; #[cfg(feature = "voice")] use crate::client::bridge::voice::VoiceGatewayManager; use crate::client::{EventHandler, RawEventHandler}; #[cfg(feature = "framework")] use crate::framework::Framework; -use crate::gateway::{ConnectionStage, InterMessage, PresenceData, Shard}; +use crate::gateway::{ConnectionStage, PresenceData, Shard}; use crate::http::Http; use crate::internal::prelude::*; use crate::internal::tokio::spawn_named; @@ -59,10 +59,8 @@ pub struct ShardQueuer { /// /// This is used to determine how long to wait between shard IDENTIFYs. pub last_start: Option, - /// A copy of the sender channel to communicate with the [`ShardManagerMonitor`]. - /// - /// [`ShardManagerMonitor`]: super::ShardManagerMonitor - pub manager_tx: Sender, + /// A copy of the [`ShardManager`] to communicate with it. + pub manager: Arc>, /// The shards that are queued for booting. /// /// This will typically be filled with previously failed boots. @@ -187,7 +185,7 @@ impl ShardQueuer { raw_event_handlers: self.raw_event_handlers.clone(), #[cfg(feature = "framework")] framework: self.framework.get().map(Arc::clone), - manager_tx: self.manager_tx.clone(), + manager: Arc::clone(&self.manager), #[cfg(feature = "voice")] voice_manager: self.voice_manager.clone(), shard, @@ -241,9 +239,7 @@ impl ShardQueuer { info!("Shutting down shard {}", shard_id); if let Some(runner) = self.runners.lock().await.get(&shard_id) { - let shutdown = ShardManagerMessage::Shutdown(shard_id, code); - let client_msg = ShardClientMessage::Manager(shutdown); - let msg = InterMessage::Client(client_msg); + let msg = ShardRunnerMessage::Shutdown(shard_id, code); if let Err(why) = runner.runner_tx.tx.unbounded_send(msg) { warn!( diff --git a/src/client/bridge/gateway/shard_runner.rs b/src/client/bridge/gateway/shard_runner.rs index 4f2c6c52b68..bdc0d9388a8 100644 --- a/src/client/bridge/gateway/shard_runner.rs +++ b/src/client/bridge/gateway/shard_runner.rs @@ -2,28 +2,29 @@ use std::borrow::Cow; use std::sync::Arc; use futures::channel::mpsc::{self, UnboundedReceiver as Receiver, UnboundedSender as Sender}; -use tokio::sync::RwLock; +use tokio::sync::{Mutex, RwLock}; use tokio_tungstenite::tungstenite; use tokio_tungstenite::tungstenite::error::Error as TungsteniteError; use tokio_tungstenite::tungstenite::protocol::frame::CloseFrame; use tracing::{debug, error, info, instrument, trace, warn}; use typemap_rev::TypeMap; -use super::event::{ClientEvent, ShardStageUpdateEvent}; +use super::event::ShardStageUpdateEvent; #[cfg(feature = "collector")] use super::CollectorCallback; -use super::{ShardClientMessage, ShardId, ShardManagerMessage, ShardRunnerMessage}; +use super::{ShardId, ShardManager, ShardRunnerMessage}; #[cfg(feature = "cache")] use crate::cache::Cache; #[cfg(feature = "voice")] use crate::client::bridge::voice::VoiceGatewayManager; -use crate::client::dispatch::{dispatch_client, dispatch_model}; +use crate::client::dispatch::dispatch_model; use crate::client::{Context, EventHandler, RawEventHandler}; #[cfg(feature = "framework")] use crate::framework::Framework; -use crate::gateway::{GatewayError, InterMessage, ReconnectType, Shard, ShardAction}; +use crate::gateway::{GatewayError, ReconnectType, Shard, ShardAction}; use crate::http::Http; use crate::internal::prelude::*; +use crate::internal::tokio::spawn_named; use crate::model::event::{Event, GatewayEvent}; /// A runner for managing a [`Shard`] and its respective WebSocket client. @@ -33,11 +34,11 @@ pub struct ShardRunner { raw_event_handlers: Vec>, #[cfg(feature = "framework")] framework: Option>, - manager_tx: Sender, + manager: Arc>, // channel to receive messages from the shard manager and dispatches - runner_rx: Receiver, + runner_rx: Receiver, // channel to send messages to the shard runner from the shard manager - runner_tx: Sender, + runner_tx: Sender, pub(crate) shard: Shard, #[cfg(feature = "voice")] voice_manager: Option>, @@ -61,7 +62,7 @@ impl ShardRunner { raw_event_handlers: opt.raw_event_handlers, #[cfg(feature = "framework")] framework: opt.framework, - manager_tx: opt.manager_tx, + manager: opt.manager, shard: opt.shard, #[cfg(feature = "voice")] voice_manager: opt.voice_manager, @@ -116,15 +117,19 @@ impl ShardRunner { let post = self.shard.stage(); if post != pre { - self.update_manager(); - - let e = ClientEvent::ShardStageUpdate(ShardStageUpdateEvent { - new: post, - old: pre, - shard_id: ShardId(self.shard.shard_info().id), - }); - - dispatch_client(e, self.make_context(), self.event_handlers.clone()).await; + self.update_manager().await; + + for event_handler in self.event_handlers.clone() { + let context = self.make_context(); + let event = ShardStageUpdateEvent { + new: post, + old: pre, + shard_id: ShardId(self.shard.shard_info().id), + }; + spawn_named("dispatch::event_handler::shard_stage_update", async move { + event_handler.shard_stage_update(context, event).await; + }); + } } match action { @@ -181,7 +186,7 @@ impl ShardRunner { } /// Clones the internal copy of the Sender to the shard runner. - pub(super) fn runner_tx(&self) -> Sender { + pub(super) fn runner_tx(&self) -> Sender { self.runner_tx.clone() } @@ -245,14 +250,7 @@ impl ShardRunner { } // Inform the manager that shutdown for this shard has finished. - if let Err(why) = self.manager_tx.unbounded_send(ShardManagerMessage::ShutdownFinished(id)) - { - warn!( - "[ShardRunner {:?}] Could not send ShutdownFinished: {:#?}", - self.shard.shard_info(), - why, - ); - } + self.manager.lock().await.shutdown_finished(id); false } @@ -274,88 +272,51 @@ impl ShardRunner { // This always returns true, except in the case that the shard manager asked the runner to // shutdown. #[instrument(skip(self))] - async fn handle_rx_value(&mut self, value: InterMessage) -> bool { - match value { - InterMessage::Client(value) => match value { - ShardClientMessage::Manager(msg) => match msg { - ShardManagerMessage::Restart(id) => self.checked_shutdown(id, 4000).await, - ShardManagerMessage::Shutdown(id, code) => { - self.checked_shutdown(id, code).await - }, - ShardManagerMessage::ShutdownAll => { - // This variant should never be received. - warn!( - "[ShardRunner {:?}] Received a ShutdownAll?", - self.shard.shard_info(), - ); - - true - }, - ShardManagerMessage::ShardUpdate { - .. - } - | ShardManagerMessage::ShutdownInitiated - | ShardManagerMessage::ShutdownFinished(_) => { - // nb: not sent here - true - }, - ShardManagerMessage::ShardDisallowedGatewayIntents - | ShardManagerMessage::ShardInvalidAuthentication - | ShardManagerMessage::ShardInvalidGatewayIntents => { - // These variants should never be received. - warn!("[ShardRunner {:?}] Received a ShardError?", self.shard.shard_info(),); - - true - }, - }, - ShardClientMessage::Runner(msg) => match *msg { - ShardRunnerMessage::ChunkGuild { - guild_id, - limit, - filter, - nonce, - } => self - .shard - .chunk_guild(guild_id, limit, filter, nonce.as_deref()) - .await - .is_ok(), - ShardRunnerMessage::Close(code, reason) => { - let reason = reason.unwrap_or_default(); - let close = CloseFrame { - code: code.into(), - reason: Cow::from(reason), - }; - self.shard.client.close(Some(close)).await.is_ok() - }, - ShardRunnerMessage::Message(msg) => self.shard.client.send(msg).await.is_ok(), - ShardRunnerMessage::SetActivity(activity) => { - // To avoid a clone of `activity`, we do a little bit of trickery here: - // - // First, we obtain a reference to the current presence of the shard, and - // create a new presence tuple of the new activity we received over the - // channel as well as the online status that the shard already had. - // - // We then (attempt to) send the websocket message with the status update, - // expressively returning: - // - whether the message successfully sent - // - the original activity we received over the channel - self.shard.set_activity(activity); - self.shard.update_presence().await.is_ok() - }, - ShardRunnerMessage::SetPresence(activity, status) => { - self.shard.set_presence(activity, status); - self.shard.update_presence().await.is_ok() - }, - ShardRunnerMessage::SetStatus(status) => { - self.shard.set_status(status); - self.shard.update_presence().await.is_ok() - }, - #[cfg(feature = "collector")] - ShardRunnerMessage::AddCollector(collector) => { - self.collectors.push(collector); - true - }, - }, + async fn handle_rx_value(&mut self, msg: ShardRunnerMessage) -> bool { + match msg { + ShardRunnerMessage::Restart(id) => self.checked_shutdown(id, 4000).await, + ShardRunnerMessage::Shutdown(id, code) => self.checked_shutdown(id, code).await, + ShardRunnerMessage::ChunkGuild { + guild_id, + limit, + filter, + nonce, + } => self.shard.chunk_guild(guild_id, limit, filter, nonce.as_deref()).await.is_ok(), + ShardRunnerMessage::Close(code, reason) => { + let reason = reason.unwrap_or_default(); + let close = CloseFrame { + code: code.into(), + reason: Cow::from(reason), + }; + self.shard.client.close(Some(close)).await.is_ok() + }, + ShardRunnerMessage::Message(msg) => self.shard.client.send(msg).await.is_ok(), + ShardRunnerMessage::SetActivity(activity) => { + // To avoid a clone of `activity`, we do a little bit of trickery here: + // + // First, we obtain a reference to the current presence of the shard, and + // create a new presence tuple of the new activity we received over the + // channel as well as the online status that the shard already had. + // + // We then (attempt to) send the websocket message with the status update, + // expressively returning: + // - whether the message successfully sent + // - the original activity we received over the channel + self.shard.set_activity(activity); + self.shard.update_presence().await.is_ok() + }, + ShardRunnerMessage::SetPresence(activity, status) => { + self.shard.set_presence(activity, status); + self.shard.update_presence().await.is_ok() + }, + ShardRunnerMessage::SetStatus(status) => { + self.shard.set_status(status); + self.shard.update_presence().await.is_ok() + }, + #[cfg(feature = "collector")] + ShardRunnerMessage::AddCollector(collector) => { + self.collectors.push(collector); + true }, } } @@ -456,41 +417,13 @@ impl ShardRunner { Err(why) => { error!("Shard handler received err: {:?}", why); - match why { - Error::Gateway(GatewayError::InvalidAuthentication) => { - if self - .manager_tx - .unbounded_send(ShardManagerMessage::ShardInvalidAuthentication) - .is_err() - { - panic!( - "Failed sending InvalidAuthentication error to the shard manager." - ); - } - - return Err(why); - }, - Error::Gateway(GatewayError::InvalidGatewayIntents) => { - if self - .manager_tx - .unbounded_send(ShardManagerMessage::ShardInvalidGatewayIntents) - .is_err() - { - panic!( - "Failed sending InvalidGatewayIntents error to the shard manager." - ); - } - - return Err(why); - }, - Error::Gateway(GatewayError::DisallowedGatewayIntents) => { - if self - .manager_tx - .unbounded_send(ShardManagerMessage::ShardDisallowedGatewayIntents) - .is_err() - { - panic!("Failed sending DisallowedGatewayIntents error to the shard manager."); - } + match &why { + Error::Gateway( + error @ (GatewayError::InvalidAuthentication + | GatewayError::InvalidGatewayIntents + | GatewayError::DisallowedGatewayIntents), + ) => { + self.manager.lock().await.return_with_value(Err(error.clone())).await; return Err(why); }, @@ -500,7 +433,7 @@ impl ShardRunner { }; if let Ok(GatewayEvent::HeartbeatAck) = event { - self.update_manager(); + self.update_manager().await; } #[cfg(feature = "voice")] @@ -520,15 +453,11 @@ impl ShardRunner { #[instrument(skip(self))] async fn request_restart(&mut self) -> Result<()> { - self.update_manager(); + self.update_manager().await; debug!("[ShardRunner {:?}] Requesting restart", self.shard.shard_info(),); let shard_id = ShardId(self.shard.shard_info().id); - let msg = ShardManagerMessage::Restart(shard_id); - - if let Err(error) = self.manager_tx.unbounded_send(msg) { - warn!("Error sending request restart: {:?}", error); - } + self.manager.lock().await.restart_shard(shard_id).await; #[cfg(feature = "voice")] if let Some(voice_manager) = &self.voice_manager { @@ -539,12 +468,16 @@ impl ShardRunner { } #[instrument(skip(self))] - fn update_manager(&self) { - drop(self.manager_tx.unbounded_send(ShardManagerMessage::ShardUpdate { - id: ShardId(self.shard.shard_info().id), - latency: self.shard.latency(), - stage: self.shard.stage(), - })); + async fn update_manager(&self) { + self.manager + .lock() + .await + .shard_update( + ShardId(self.shard.shard_info().id), + self.shard.latency(), + self.shard.stage(), + ) + .await; } } @@ -555,7 +488,7 @@ pub struct ShardRunnerOptions { pub raw_event_handlers: Vec>, #[cfg(feature = "framework")] pub framework: Option>, - pub manager_tx: Sender, + pub manager: Arc>, pub shard: Shard, #[cfg(feature = "voice")] pub voice_manager: Option>, diff --git a/src/client/bridge/gateway/shard_runner_message.rs b/src/client/bridge/gateway/shard_runner_message.rs index c290ff8d560..30852d6abb9 100644 --- a/src/client/bridge/gateway/shard_runner_message.rs +++ b/src/client/bridge/gateway/shard_runner_message.rs @@ -2,6 +2,7 @@ use tokio_tungstenite::tungstenite::Message; #[cfg(feature = "collector")] use super::CollectorCallback; +use super::ShardId; use crate::gateway::ActivityData; pub use crate::gateway::ChunkGuildFilter; use crate::model::id::GuildId; @@ -10,6 +11,11 @@ use crate::model::user::OnlineStatus; /// A message to send from a shard over a WebSocket. #[derive(Debug)] pub enum ShardRunnerMessage { + /// Indicator that a shard should be restarted. + Restart(ShardId), + /// Indicator that a shard should be fully shutdown without bringing it + /// back up. + Shutdown(ShardId, u16), /// Indicates that the client is to send a member chunk message. ChunkGuild { /// The IDs of the [`Guild`] to chunk. diff --git a/src/client/bridge/voice/mod.rs b/src/client/bridge/voice/mod.rs index b94fe655015..fb8737c3661 100644 --- a/src/client/bridge/voice/mod.rs +++ b/src/client/bridge/voice/mod.rs @@ -1,7 +1,7 @@ use async_trait::async_trait; use futures::channel::mpsc::UnboundedSender as Sender; -use crate::gateway::InterMessage; +use crate::client::bridge::gateway::ShardRunnerMessage; use crate::model::id::{GuildId, UserId}; use crate::model::voice::VoiceState; @@ -22,7 +22,7 @@ pub trait VoiceGatewayManager: Send + Sync { /// active shard. /// /// [`Ready`]: crate::model::event::Event - async fn register_shard(&self, shard_id: u32, sender: Sender); + async fn register_shard(&self, shard_id: u32, sender: Sender); /// Handler fired in response to a disconnect, reconnection, or rebalance. /// diff --git a/src/client/context.rs b/src/client/context.rs index 4f48073df1f..9c6c5c8d697 100644 --- a/src/client/context.rs +++ b/src/client/context.rs @@ -5,13 +5,12 @@ use futures::channel::mpsc::UnboundedSender as Sender; use tokio::sync::RwLock; use typemap_rev::TypeMap; +use super::bridge::gateway::ShardRunnerMessage; #[cfg(feature = "cache")] pub use crate::cache::Cache; #[cfg(feature = "gateway")] use crate::client::bridge::gateway::ShardMessenger; use crate::gateway::ActivityData; -#[cfg(feature = "gateway")] -use crate::gateway::InterMessage; use crate::http::Http; use crate::model::prelude::*; @@ -59,7 +58,7 @@ impl Context { #[cfg(all(feature = "cache", feature = "gateway"))] pub(crate) fn new( data: Arc>, - runner_tx: Sender, + runner_tx: Sender, shard_id: u32, http: Arc, cache: Arc, @@ -86,7 +85,7 @@ impl Context { #[cfg(all(not(feature = "cache"), feature = "gateway"))] pub(crate) fn new( data: Arc>, - runner_tx: Sender, + runner_tx: Sender, shard_id: u32, http: Arc, ) -> Context { diff --git a/src/client/dispatch.rs b/src/client/dispatch.rs index 453c4e0566c..a4fce0989b5 100644 --- a/src/client/dispatch.rs +++ b/src/client/dispatch.rs @@ -2,8 +2,6 @@ use std::sync::Arc; use tracing::debug; -#[cfg(feature = "gateway")] -use super::bridge::gateway::event::ClientEvent; #[cfg(feature = "gateway")] use super::event_handler::{EventHandler, RawEventHandler}; use super::{Context, FullEvent}; @@ -65,23 +63,6 @@ pub(crate) async fn dispatch_model<'rec>( } } -pub(crate) async fn dispatch_client<'rec>( - event: ClientEvent, - context: Context, - event_handlers: Vec>, -) { - match event { - ClientEvent::ShardStageUpdate(event) => { - for event_handler in event_handlers { - let (context, event) = (context.clone(), event.clone()); - spawn_named("dispatch::event_handler::shard_stage_update", async move { - event_handler.shard_stage_update(context, event).await; - }); - } - }, - } -} - /// Updates the cache with the incoming event data and builds the full event data out of it. /// /// Can return a secondary [`FullEvent`] for "virtual" events like [`FullEvent::CacheReady`] or diff --git a/src/client/mod.rs b/src/client/mod.rs index f4313723292..695a8c103d7 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -27,7 +27,9 @@ use std::future::IntoFuture; use std::ops::Range; use std::sync::Arc; +use futures::channel::mpsc::UnboundedReceiver as Receiver; use futures::future::BoxFuture; +use futures::StreamExt as _; #[cfg(feature = "framework")] use once_cell::sync::OnceCell; use tokio::sync::{Mutex, RwLock}; @@ -35,12 +37,7 @@ use tracing::{debug, error, info, instrument}; use typemap_rev::{TypeMap, TypeMapKey}; #[cfg(feature = "gateway")] -use self::bridge::gateway::{ - ShardManager, - ShardManagerError, - ShardManagerMonitor, - ShardManagerOptions, -}; +use self::bridge::gateway::{ShardManager, ShardManagerOptions}; #[cfg(feature = "voice")] use self::bridge::voice::VoiceGatewayManager; pub use self::context::Context; @@ -378,7 +375,7 @@ impl IntoFuture for ClientBuilder { #[cfg(feature = "framework")] let framework_cell = Arc::new(OnceCell::new()); - let (shard_manager, shard_manager_worker) = ShardManager::new(ShardManagerOptions { + let (shard_manager, shard_manager_ret_value) = ShardManager::new(ShardManagerOptions { data: Arc::clone(&data), event_handlers, raw_event_handlers, @@ -400,7 +397,7 @@ impl IntoFuture for ClientBuilder { let client = Client { data, shard_manager, - shard_manager_worker, + shard_manager_return_value: shard_manager_ret_value, #[cfg(feature = "voice")] voice_manager, ws_url, @@ -634,7 +631,7 @@ pub struct Client { /// # } /// ``` pub shard_manager: Arc>, - shard_manager_worker: ShardManagerMonitor, + shard_manager_return_value: Receiver>, /// The voice manager for the client. /// /// This is an ergonomic structure for interfacing over shards' voice @@ -958,14 +955,7 @@ impl Client { } } - if let Err(why) = self.shard_manager_worker.run().await { - let err = match why { - ShardManagerError::DisallowedGatewayIntents => { - GatewayError::DisallowedGatewayIntents - }, - ShardManagerError::InvalidGatewayIntents => GatewayError::InvalidGatewayIntents, - ShardManagerError::InvalidToken => GatewayError::InvalidAuthentication, - }; + if let Some(Err(err)) = self.shard_manager_return_value.next().await { return Err(Error::Gateway(err)); } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index e687327ffb8..1b16b0e789d 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -49,14 +49,10 @@ use std::fmt; #[cfg(feature = "http")] use reqwest::IntoUrl; use reqwest::Url; -#[cfg(feature = "client")] -use tokio_tungstenite::tungstenite; pub use self::error::Error as GatewayError; pub use self::shard::Shard; pub use self::ws::WsClient; -#[cfg(feature = "client")] -use crate::client::bridge::gateway::{ShardClientMessage, ShardRunnerMessage}; #[cfg(feature = "http")] use crate::internal::prelude::*; use crate::model::gateway::{Activity, ActivityType}; @@ -218,34 +214,6 @@ impl fmt::Display for ConnectionStage { } } -/// A message to be passed around within the library. -/// -/// As a user you usually don't need to worry about this, but when working with the lower-level -/// internals of the [`client`], [`gateway`], and [`voice`] modules it may be necessary. -/// -/// [`client`]: crate::client -/// [`gateway`]: crate::gateway -/// [`voice`]: crate::model::voice -#[derive(Debug)] -#[non_exhaustive] -pub enum InterMessage { - #[cfg(feature = "client")] - Client(ShardClientMessage), -} - -impl InterMessage { - /// Constructs a custom message which will send the given `value` over the WebSocket. - /// - /// This is simply sugar for constructing and nesting [`ShardRunnerMessage::Message`]. - #[cfg(feature = "client")] - #[must_use] - pub fn json(value: String) -> Self { - Self::Client(ShardClientMessage::Runner(Box::new(ShardRunnerMessage::Message( - tungstenite::Message::Text(value), - )))) - } -} - #[derive(Debug)] #[non_exhaustive] pub enum ShardAction {