diff --git a/src/handler.rs b/src/handler.rs index 05f962c36..7ff56536f 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -32,7 +32,7 @@ enum Return { /// [`Driver`]: struct@Driver #[derive(Clone, Debug)] pub struct Call { - connection: Option<(ChannelId, ConnectionProgress, Return)>, + connection: Option<(ConnectionProgress, Return)>, #[cfg(feature = "driver-core")] /// The internal controller of the voice connection monitor thread. @@ -132,12 +132,12 @@ impl Call { #[instrument(skip(self))] fn do_connect(&mut self) { match &self.connection { - Some((_, ConnectionProgress::Complete(c), Return::Info(tx))) => { + Some((ConnectionProgress::Complete(c), Return::Info(tx))) => { // It's okay if the receiver hung up. let _ = tx.send(c.clone()); }, #[cfg(feature = "driver-core")] - Some((_, ConnectionProgress::Complete(c), Return::Conn(tx))) => { + Some((ConnectionProgress::Complete(c), Return::Conn(tx))) => { self.driver.raw_connect(c.clone(), tx.clone()); }, _ => {}, @@ -171,6 +171,31 @@ impl Call { self.self_deaf } + async fn should_actually_join( + &mut self, + completion_generator: F, + tx: &Sender, + channel_id: ChannelId, + ) -> JoinResult + where + F: FnOnce(&Self) -> G, + { + Ok(if let Some(conn) = &self.connection { + if conn.0.in_progress() { + self.leave().await?; + true + } else if conn.0.channel_id() == channel_id { + let _ = tx.send(completion_generator(&self)); + false + } else { + // not in progress, and/or a channel change. + true + } + } else { + true + }) + } + #[cfg(feature = "driver-core")] /// Connect or switch to the given voice channel by its Id. /// @@ -190,13 +215,20 @@ impl Call { ) -> JoinResult>> { let (tx, rx) = flume::unbounded(); - self.connection = Some(( - channel_id, - ConnectionProgress::new(self.guild_id, self.user_id), - Return::Conn(tx), - )); + let do_conn = self + .should_actually_join(|_| Ok(()), &tx, channel_id) + .await?; + + if do_conn { + self.connection = Some(( + ConnectionProgress::new(self.guild_id, self.user_id, channel_id), + Return::Conn(tx), + )); - self.update().await.map(|_| rx.into_recv_async()) + self.update().await.map(|_| rx.into_recv_async()) + } else { + Ok(rx.into_recv_async()) + } } /// Join the selected voice channel, *without* running/starting an RTP @@ -221,13 +253,24 @@ impl Call { ) -> JoinResult> { let (tx, rx) = flume::unbounded(); - self.connection = Some(( - channel_id, - ConnectionProgress::new(self.guild_id, self.user_id), - Return::Info(tx), - )); - - self.update().await.map(|_| rx.into_recv_async()) + let do_conn = self + .should_actually_join( + |call| call.connection.as_ref().unwrap().0.info().unwrap(), + &tx, + channel_id, + ) + .await?; + + if do_conn { + self.connection = Some(( + ConnectionProgress::new(self.guild_id, self.user_id, channel_id), + Return::Info(tx), + )); + + self.update().await.map(|_| rx.into_recv_async()) + } else { + Ok(rx.into_recv_async()) + } } /// Returns the current voice connection details for this Call, @@ -235,7 +278,7 @@ impl Call { #[instrument(skip(self))] pub fn current_connection(&self) -> Option<&ConnectionInfo> { match &self.connection { - Some((_, progress, _)) => progress.get_connection_info(), + Some((progress, _)) => progress.get_connection_info(), _ => None, } } @@ -252,13 +295,17 @@ impl Call { /// [`standalone`]: Call::standalone #[instrument(skip(self))] pub async fn leave(&mut self) -> JoinResult<()> { + self.leave_local(); + // Only send an update if we were in a voice channel. + self.update().await + } + + fn leave_local(&mut self) { self.connection = None; #[cfg(feature = "driver-core")] self.driver.leave(); - - self.update().await } /// Sets whether the current connection is to be muted. @@ -294,7 +341,7 @@ impl Call { /// [`standalone`]: Call::standalone #[instrument(skip(self, token))] pub fn update_server(&mut self, endpoint: String, token: String) { - let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() { + let try_conn = if let Some((ref mut progress, _)) = self.connection.as_mut() { progress.apply_server_update(endpoint, token) } else { false @@ -312,15 +359,20 @@ impl Call { /// /// [`standalone`]: Call::standalone #[instrument(skip(self))] - pub fn update_state(&mut self, session_id: String) { - let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() { - progress.apply_state_update(session_id) + pub fn update_state(&mut self, session_id: String, channel_id: Option) { + if let Some(channel_id) = channel_id { + let try_conn = if let Some((ref mut progress, _)) = self.connection.as_mut() { + progress.apply_state_update(session_id, channel_id) + } else { + false + }; + + if try_conn { + self.do_connect(); + } } else { - false - }; - - if try_conn { - self.do_connect(); + // Likely that we were disconnected by an admin. + self.leave_local(); } } @@ -335,7 +387,7 @@ impl Call { let map = json!({ "op": 4, "d": { - "channel_id": self.connection.as_ref().map(|c| c.0.0), + "channel_id": self.connection.as_ref().map(|c| c.0.channel_id().0), "guild_id": self.guild_id.0, "self_deaf": self.self_deaf, "self_mute": self.self_mute, diff --git a/src/info.rs b/src/info.rs index 1adbe05e8..e33383d56 100644 --- a/src/info.rs +++ b/src/info.rs @@ -1,4 +1,4 @@ -use crate::id::{GuildId, UserId}; +use crate::id::{ChannelId, GuildId, UserId}; use std::fmt; #[derive(Clone, Debug)] @@ -8,8 +8,9 @@ pub(crate) enum ConnectionProgress { } impl ConnectionProgress { - pub fn new(guild_id: GuildId, user_id: UserId) -> Self { + pub(crate) fn new(guild_id: GuildId, user_id: UserId, channel_id: ChannelId) -> Self { ConnectionProgress::Incomplete(Partial { + channel_id, guild_id, user_id, ..Default::default() @@ -24,7 +25,46 @@ impl ConnectionProgress { } } - pub(crate) fn apply_state_update(&mut self, session_id: String) -> bool { + pub(crate) fn in_progress(&self) -> bool { + matches!(self, ConnectionProgress::Incomplete(_)) + } + + pub(crate) fn channel_id(&self) -> ChannelId { + match self { + ConnectionProgress::Complete(conn_info) => conn_info + .channel_id + .expect("All code paths MUST set channel_id for local tracking."), + ConnectionProgress::Incomplete(part) => part.channel_id, + } + } + + pub(crate) fn guild_id(&self) -> GuildId { + match self { + ConnectionProgress::Complete(conn_info) => conn_info.guild_id, + ConnectionProgress::Incomplete(part) => part.guild_id, + } + } + + pub(crate) fn user_id(&self) -> UserId { + match self { + ConnectionProgress::Complete(conn_info) => conn_info.user_id, + ConnectionProgress::Incomplete(part) => part.user_id, + } + } + + pub(crate) fn info(&self) -> Option { + match self { + ConnectionProgress::Complete(conn_info) => Some(conn_info.clone()), + _ => None, + } + } + + pub(crate) fn apply_state_update(&mut self, session_id: String, channel_id: ChannelId) -> bool { + if self.channel_id() != channel_id { + // Likely that the bot was moved to a different channel by an admin. + *self = ConnectionProgress::new(self.guild_id(), self.user_id(), channel_id); + } + use ConnectionProgress::*; match self { Complete(c) => { @@ -33,7 +73,7 @@ impl ConnectionProgress { should_reconn }, Incomplete(i) => i - .apply_state_update(session_id) + .apply_state_update(session_id, channel_id) .map(|info| { *self = Complete(info); }) @@ -66,6 +106,11 @@ impl ConnectionProgress { /// with the Songbird driver, lavalink, or other system. #[derive(Clone)] pub struct ConnectionInfo { + /// ID of the voice channel being joined, if it is known. + /// + /// This is not needed to establish a connection, but can be useful + /// for book-keeping. + pub channel_id: Option, /// URL of the voice websocket gateway server assigned to this call. pub endpoint: String, /// ID of the target voice channel's parent guild. @@ -83,6 +128,7 @@ pub struct ConnectionInfo { impl fmt::Debug for ConnectionInfo { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ConnectionInfo") + .field("channel_id", &self.channel_id) .field("endpoint", &self.endpoint) .field("guild_id", &self.guild_id) .field("session_id", &self.session_id) @@ -94,6 +140,7 @@ impl fmt::Debug for ConnectionInfo { #[derive(Clone, Default)] pub(crate) struct Partial { + pub channel_id: ChannelId, pub endpoint: Option, pub guild_id: GuildId, pub session_id: Option, @@ -104,6 +151,7 @@ pub(crate) struct Partial { impl fmt::Debug for Partial { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Partial") + .field("channel_id", &self.channel_id) .field("endpoint", &self.endpoint) .field("session_id", &self.session_id) .field("token_is_some", &self.token.is_some()) @@ -119,6 +167,7 @@ impl Partial { let token = self.token.take().unwrap(); Some(ConnectionInfo { + channel_id: Some(self.channel_id), endpoint, session_id, token, @@ -130,7 +179,17 @@ impl Partial { } } - fn apply_state_update(&mut self, session_id: String) -> Option { + fn apply_state_update( + &mut self, + session_id: String, + channel_id: ChannelId, + ) -> Option { + if self.channel_id != channel_id { + self.endpoint = None; + self.token = None; + } + + self.channel_id = channel_id; self.session_id = Some(session_id); self.finalise() diff --git a/src/manager.rs b/src/manager.rs index c25297eb1..f12bbc195 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -351,7 +351,10 @@ impl Songbird { if let Some(call) = call { let mut handler = call.lock().await; - handler.update_state(v.0.session_id.clone()); + handler.update_state( + v.0.session_id.clone(), + v.0.channel_id.clone().map(Into::into), + ); } }, _ => {}, @@ -390,7 +393,10 @@ impl VoiceGatewayManager for Songbird { if let Some(call) = self.get(guild_id) { let mut handler = call.lock().await; - handler.update_state(voice_state.session_id.clone()); + handler.update_state( + voice_state.session_id.clone(), + voice_state.channel_id.clone().map(Into::into), + ); } } }