Skip to content

Commit

Permalink
Gateway: Fix repeat joins on same channel from stalling (#47)
Browse files Browse the repository at this point in the history
Joining a channel returns a future which fires on receipt of two messages from discord (by locally storing a channel). However, joining this same channel again after a success returns only *one* such message, causing the command to hang until another join fires or the channel is left. This alters internal behaviour to correctly cancel an in-progress connection attempt, or return success with known data if such a connection is present.

This introduces a breaking change on `Call::update_state` to include the target `ChannelId`. The reason for this is that although the `ChannelId` of a target channel was being stored, server admins may move or kick a bot from its voice channel. This changes the true channel, and may accidentally trigger a "double join" elsewhere.

This fix was tested by using an example to have a bot join its channel twice, to do so in a channel it had been moved to, and to move from a channel it had been moved to.
  • Loading branch information
FelixMcFelix committed May 10, 2021
1 parent 9e202f6 commit 95dd19e
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 36 deletions.
110 changes: 81 additions & 29 deletions src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ enum Return {
/// [`Driver`]: struct@Driver
#[derive(Clone, Debug)]
pub struct Call {
connection: Option<(ChannelId, ConnectionProgress, Return)>,
connection: Option<(ConnectionProgress, Return)>,

#[cfg(feature = "driver-core")]
/// The internal controller of the voice connection monitor thread.
Expand Down Expand Up @@ -132,12 +132,12 @@ impl Call {
#[instrument(skip(self))]
fn do_connect(&mut self) {
match &self.connection {
Some((_, ConnectionProgress::Complete(c), Return::Info(tx))) => {
Some((ConnectionProgress::Complete(c), Return::Info(tx))) => {
// It's okay if the receiver hung up.
let _ = tx.send(c.clone());
},
#[cfg(feature = "driver-core")]
Some((_, ConnectionProgress::Complete(c), Return::Conn(tx))) => {
Some((ConnectionProgress::Complete(c), Return::Conn(tx))) => {
self.driver.raw_connect(c.clone(), tx.clone());
},
_ => {},
Expand Down Expand Up @@ -171,6 +171,31 @@ impl Call {
self.self_deaf
}

async fn should_actually_join<F, G>(
&mut self,
completion_generator: F,
tx: &Sender<G>,
channel_id: ChannelId,
) -> JoinResult<bool>
where
F: FnOnce(&Self) -> G,
{
Ok(if let Some(conn) = &self.connection {
if conn.0.in_progress() {
self.leave().await?;
true
} else if conn.0.channel_id() == channel_id {
let _ = tx.send(completion_generator(&self));
false
} else {
// not in progress, and/or a channel change.
true
}
} else {
true
})
}

#[cfg(feature = "driver-core")]
/// Connect or switch to the given voice channel by its Id.
///
Expand All @@ -190,13 +215,20 @@ impl Call {
) -> JoinResult<RecvFut<'static, ConnectionResult<()>>> {
let (tx, rx) = flume::unbounded();

self.connection = Some((
channel_id,
ConnectionProgress::new(self.guild_id, self.user_id),
Return::Conn(tx),
));
let do_conn = self
.should_actually_join(|_| Ok(()), &tx, channel_id)
.await?;

if do_conn {
self.connection = Some((
ConnectionProgress::new(self.guild_id, self.user_id, channel_id),
Return::Conn(tx),
));

self.update().await.map(|_| rx.into_recv_async())
self.update().await.map(|_| rx.into_recv_async())
} else {
Ok(rx.into_recv_async())
}
}

/// Join the selected voice channel, *without* running/starting an RTP
Expand All @@ -221,21 +253,32 @@ impl Call {
) -> JoinResult<RecvFut<'static, ConnectionInfo>> {
let (tx, rx) = flume::unbounded();

self.connection = Some((
channel_id,
ConnectionProgress::new(self.guild_id, self.user_id),
Return::Info(tx),
));

self.update().await.map(|_| rx.into_recv_async())
let do_conn = self
.should_actually_join(
|call| call.connection.as_ref().unwrap().0.info().unwrap(),
&tx,
channel_id,
)
.await?;

if do_conn {
self.connection = Some((
ConnectionProgress::new(self.guild_id, self.user_id, channel_id),
Return::Info(tx),
));

self.update().await.map(|_| rx.into_recv_async())
} else {
Ok(rx.into_recv_async())
}
}

/// Returns the current voice connection details for this Call,
/// if available.
#[instrument(skip(self))]
pub fn current_connection(&self) -> Option<&ConnectionInfo> {
match &self.connection {
Some((_, progress, _)) => progress.get_connection_info(),
Some((progress, _)) => progress.get_connection_info(),
_ => None,
}
}
Expand Down Expand Up @@ -265,13 +308,17 @@ impl Call {
/// [`standalone`]: Call::standalone
#[instrument(skip(self))]
pub async fn leave(&mut self) -> JoinResult<()> {
self.leave_local();

// Only send an update if we were in a voice channel.
self.update().await
}

fn leave_local(&mut self) {
self.connection = None;

#[cfg(feature = "driver-core")]
self.driver.leave();

self.update().await
}

/// Sets whether the current connection is to be muted.
Expand Down Expand Up @@ -307,7 +354,7 @@ impl Call {
/// [`standalone`]: Call::standalone
#[instrument(skip(self, token))]
pub fn update_server(&mut self, endpoint: String, token: String) {
let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() {
let try_conn = if let Some((ref mut progress, _)) = self.connection.as_mut() {
progress.apply_server_update(endpoint, token)
} else {
false
Expand All @@ -325,15 +372,20 @@ impl Call {
///
/// [`standalone`]: Call::standalone
#[instrument(skip(self))]
pub fn update_state(&mut self, session_id: String) {
let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() {
progress.apply_state_update(session_id)
pub fn update_state(&mut self, session_id: String, channel_id: Option<ChannelId>) {
if let Some(channel_id) = channel_id {
let try_conn = if let Some((ref mut progress, _)) = self.connection.as_mut() {
progress.apply_state_update(session_id, channel_id)
} else {
false
};

if try_conn {
self.do_connect();
}
} else {
false
};

if try_conn {
self.do_connect();
// Likely that we were disconnected by an admin.
self.leave_local();
}
}

Expand All @@ -348,7 +400,7 @@ impl Call {
let map = json!({
"op": 4,
"d": {
"channel_id": self.connection.as_ref().map(|c| c.0.0),
"channel_id": self.connection.as_ref().map(|c| c.0.channel_id().0),
"guild_id": self.guild_id.0,
"self_deaf": self.self_deaf,
"self_mute": self.self_mute,
Expand Down
69 changes: 64 additions & 5 deletions src/info.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::id::{GuildId, UserId};
use crate::id::{ChannelId, GuildId, UserId};
use std::fmt;

#[derive(Clone, Debug)]
Expand All @@ -8,8 +8,9 @@ pub(crate) enum ConnectionProgress {
}

impl ConnectionProgress {
pub fn new(guild_id: GuildId, user_id: UserId) -> Self {
pub(crate) fn new(guild_id: GuildId, user_id: UserId, channel_id: ChannelId) -> Self {
ConnectionProgress::Incomplete(Partial {
channel_id,
guild_id,
user_id,
..Default::default()
Expand All @@ -24,7 +25,46 @@ impl ConnectionProgress {
}
}

pub(crate) fn apply_state_update(&mut self, session_id: String) -> bool {
pub(crate) fn in_progress(&self) -> bool {
matches!(self, ConnectionProgress::Incomplete(_))
}

pub(crate) fn channel_id(&self) -> ChannelId {
match self {
ConnectionProgress::Complete(conn_info) => conn_info
.channel_id
.expect("All code paths MUST set channel_id for local tracking."),
ConnectionProgress::Incomplete(part) => part.channel_id,
}
}

pub(crate) fn guild_id(&self) -> GuildId {
match self {
ConnectionProgress::Complete(conn_info) => conn_info.guild_id,
ConnectionProgress::Incomplete(part) => part.guild_id,
}
}

pub(crate) fn user_id(&self) -> UserId {
match self {
ConnectionProgress::Complete(conn_info) => conn_info.user_id,
ConnectionProgress::Incomplete(part) => part.user_id,
}
}

pub(crate) fn info(&self) -> Option<ConnectionInfo> {
match self {
ConnectionProgress::Complete(conn_info) => Some(conn_info.clone()),
_ => None,
}
}

pub(crate) fn apply_state_update(&mut self, session_id: String, channel_id: ChannelId) -> bool {
if self.channel_id() != channel_id {
// Likely that the bot was moved to a different channel by an admin.
*self = ConnectionProgress::new(self.guild_id(), self.user_id(), channel_id);
}

use ConnectionProgress::*;
match self {
Complete(c) => {
Expand All @@ -33,7 +73,7 @@ impl ConnectionProgress {
should_reconn
},
Incomplete(i) => i
.apply_state_update(session_id)
.apply_state_update(session_id, channel_id)
.map(|info| {
*self = Complete(info);
})
Expand Down Expand Up @@ -66,6 +106,11 @@ impl ConnectionProgress {
/// with the Songbird driver, lavalink, or other system.
#[derive(Clone)]
pub struct ConnectionInfo {
/// ID of the voice channel being joined, if it is known.
///
/// This is not needed to establish a connection, but can be useful
/// for book-keeping.
pub channel_id: Option<ChannelId>,
/// URL of the voice websocket gateway server assigned to this call.
pub endpoint: String,
/// ID of the target voice channel's parent guild.
Expand All @@ -83,6 +128,7 @@ pub struct ConnectionInfo {
impl fmt::Debug for ConnectionInfo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ConnectionInfo")
.field("channel_id", &self.channel_id)
.field("endpoint", &self.endpoint)
.field("guild_id", &self.guild_id)
.field("session_id", &self.session_id)
Expand All @@ -94,6 +140,7 @@ impl fmt::Debug for ConnectionInfo {

#[derive(Clone, Default)]
pub(crate) struct Partial {
pub channel_id: ChannelId,
pub endpoint: Option<String>,
pub guild_id: GuildId,
pub session_id: Option<String>,
Expand All @@ -104,6 +151,7 @@ pub(crate) struct Partial {
impl fmt::Debug for Partial {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Partial")
.field("channel_id", &self.channel_id)
.field("endpoint", &self.endpoint)
.field("session_id", &self.session_id)
.field("token_is_some", &self.token.is_some())
Expand All @@ -119,6 +167,7 @@ impl Partial {
let token = self.token.take().unwrap();

Some(ConnectionInfo {
channel_id: Some(self.channel_id),
endpoint,
session_id,
token,
Expand All @@ -130,7 +179,17 @@ impl Partial {
}
}

fn apply_state_update(&mut self, session_id: String) -> Option<ConnectionInfo> {
fn apply_state_update(
&mut self,
session_id: String,
channel_id: ChannelId,
) -> Option<ConnectionInfo> {
if self.channel_id != channel_id {
self.endpoint = None;
self.token = None;
}

self.channel_id = channel_id;
self.session_id = Some(session_id);

self.finalise()
Expand Down
10 changes: 8 additions & 2 deletions src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,10 @@ impl Songbird {

if let Some(call) = call {
let mut handler = call.lock().await;
handler.update_state(v.0.session_id.clone());
handler.update_state(
v.0.session_id.clone(),
v.0.channel_id.clone().map(Into::into),
);
}
},
_ => {},
Expand Down Expand Up @@ -390,7 +393,10 @@ impl VoiceGatewayManager for Songbird {

if let Some(call) = self.get(guild_id) {
let mut handler = call.lock().await;
handler.update_state(voice_state.session_id.clone());
handler.update_state(
voice_state.session_id.clone(),
voice_state.channel_id.clone().map(Into::into),
);
}
}
}
Expand Down

0 comments on commit 95dd19e

Please sign in to comment.