Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gateway: Fix repeat joins on same channel from stalling #47

Merged
merged 2 commits into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 81 additions & 29 deletions src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ enum Return {
/// [`Driver`]: struct@Driver
#[derive(Clone, Debug)]
pub struct Call {
connection: Option<(ChannelId, ConnectionProgress, Return)>,
connection: Option<(ConnectionProgress, Return)>,

#[cfg(feature = "driver-core")]
/// The internal controller of the voice connection monitor thread.
Expand Down Expand Up @@ -132,12 +132,12 @@ impl Call {
#[instrument(skip(self))]
fn do_connect(&mut self) {
match &self.connection {
Some((_, ConnectionProgress::Complete(c), Return::Info(tx))) => {
Some((ConnectionProgress::Complete(c), Return::Info(tx))) => {
// It's okay if the receiver hung up.
let _ = tx.send(c.clone());
},
#[cfg(feature = "driver-core")]
Some((_, ConnectionProgress::Complete(c), Return::Conn(tx))) => {
Some((ConnectionProgress::Complete(c), Return::Conn(tx))) => {
self.driver.raw_connect(c.clone(), tx.clone());
},
_ => {},
Expand Down Expand Up @@ -171,6 +171,31 @@ impl Call {
self.self_deaf
}

async fn should_actually_join<F, G>(
&mut self,
completion_generator: F,
tx: &Sender<G>,
channel_id: ChannelId,
) -> JoinResult<bool>
where
F: FnOnce(&Self) -> G,
{
Ok(if let Some(conn) = &self.connection {
if conn.0.in_progress() {
self.leave().await?;
true
} else if conn.0.channel_id() == channel_id {
let _ = tx.send(completion_generator(&self));
false
} else {
// not in progress, and/or a channel change.
true
}
} else {
true
})
}

#[cfg(feature = "driver-core")]
/// Connect or switch to the given voice channel by its Id.
///
Expand All @@ -190,13 +215,20 @@ impl Call {
) -> JoinResult<RecvFut<'static, ConnectionResult<()>>> {
let (tx, rx) = flume::unbounded();

self.connection = Some((
channel_id,
ConnectionProgress::new(self.guild_id, self.user_id),
Return::Conn(tx),
));
let do_conn = self
.should_actually_join(|_| Ok(()), &tx, channel_id)
.await?;

if do_conn {
self.connection = Some((
ConnectionProgress::new(self.guild_id, self.user_id, channel_id),
Return::Conn(tx),
));

self.update().await.map(|_| rx.into_recv_async())
self.update().await.map(|_| rx.into_recv_async())
} else {
Ok(rx.into_recv_async())
}
}

/// Join the selected voice channel, *without* running/starting an RTP
Expand All @@ -221,21 +253,32 @@ impl Call {
) -> JoinResult<RecvFut<'static, ConnectionInfo>> {
let (tx, rx) = flume::unbounded();

self.connection = Some((
channel_id,
ConnectionProgress::new(self.guild_id, self.user_id),
Return::Info(tx),
));

self.update().await.map(|_| rx.into_recv_async())
let do_conn = self
.should_actually_join(
|call| call.connection.as_ref().unwrap().0.info().unwrap(),
&tx,
channel_id,
)
.await?;

if do_conn {
self.connection = Some((
ConnectionProgress::new(self.guild_id, self.user_id, channel_id),
Return::Info(tx),
));

self.update().await.map(|_| rx.into_recv_async())
} else {
Ok(rx.into_recv_async())
}
}

/// Returns the current voice connection details for this Call,
/// if available.
#[instrument(skip(self))]
pub fn current_connection(&self) -> Option<&ConnectionInfo> {
match &self.connection {
Some((_, progress, _)) => progress.get_connection_info(),
Some((progress, _)) => progress.get_connection_info(),
_ => None,
}
}
Expand All @@ -252,13 +295,17 @@ impl Call {
/// [`standalone`]: Call::standalone
#[instrument(skip(self))]
pub async fn leave(&mut self) -> JoinResult<()> {
self.leave_local();

// Only send an update if we were in a voice channel.
self.update().await
}

fn leave_local(&mut self) {
self.connection = None;

#[cfg(feature = "driver-core")]
self.driver.leave();

self.update().await
}

/// Sets whether the current connection is to be muted.
Expand Down Expand Up @@ -294,7 +341,7 @@ impl Call {
/// [`standalone`]: Call::standalone
#[instrument(skip(self, token))]
pub fn update_server(&mut self, endpoint: String, token: String) {
let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() {
let try_conn = if let Some((ref mut progress, _)) = self.connection.as_mut() {
progress.apply_server_update(endpoint, token)
} else {
false
Expand All @@ -312,15 +359,20 @@ impl Call {
///
/// [`standalone`]: Call::standalone
#[instrument(skip(self))]
pub fn update_state(&mut self, session_id: String) {
let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() {
progress.apply_state_update(session_id)
pub fn update_state(&mut self, session_id: String, channel_id: Option<ChannelId>) {
if let Some(channel_id) = channel_id {
let try_conn = if let Some((ref mut progress, _)) = self.connection.as_mut() {
progress.apply_state_update(session_id, channel_id)
} else {
false
};

if try_conn {
self.do_connect();
}
} else {
false
};

if try_conn {
self.do_connect();
// Likely that we were disconnected by an admin.
self.leave_local();
}
}

Expand All @@ -335,7 +387,7 @@ impl Call {
let map = json!({
"op": 4,
"d": {
"channel_id": self.connection.as_ref().map(|c| c.0.0),
"channel_id": self.connection.as_ref().map(|c| c.0.channel_id().0),
"guild_id": self.guild_id.0,
"self_deaf": self.self_deaf,
"self_mute": self.self_mute,
Expand Down
69 changes: 64 additions & 5 deletions src/info.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::id::{GuildId, UserId};
use crate::id::{ChannelId, GuildId, UserId};
use std::fmt;

#[derive(Clone, Debug)]
Expand All @@ -8,8 +8,9 @@ pub(crate) enum ConnectionProgress {
}

impl ConnectionProgress {
pub fn new(guild_id: GuildId, user_id: UserId) -> Self {
pub(crate) fn new(guild_id: GuildId, user_id: UserId, channel_id: ChannelId) -> Self {
ConnectionProgress::Incomplete(Partial {
channel_id,
guild_id,
user_id,
..Default::default()
Expand All @@ -24,7 +25,46 @@ impl ConnectionProgress {
}
}

pub(crate) fn apply_state_update(&mut self, session_id: String) -> bool {
pub(crate) fn in_progress(&self) -> bool {
matches!(self, ConnectionProgress::Incomplete(_))
}

pub(crate) fn channel_id(&self) -> ChannelId {
match self {
ConnectionProgress::Complete(conn_info) => conn_info
.channel_id
.expect("All code paths MUST set channel_id for local tracking."),
ConnectionProgress::Incomplete(part) => part.channel_id,
}
}

pub(crate) fn guild_id(&self) -> GuildId {
match self {
ConnectionProgress::Complete(conn_info) => conn_info.guild_id,
ConnectionProgress::Incomplete(part) => part.guild_id,
}
}

pub(crate) fn user_id(&self) -> UserId {
match self {
ConnectionProgress::Complete(conn_info) => conn_info.user_id,
ConnectionProgress::Incomplete(part) => part.user_id,
}
}

pub(crate) fn info(&self) -> Option<ConnectionInfo> {
match self {
ConnectionProgress::Complete(conn_info) => Some(conn_info.clone()),
_ => None,
}
}

pub(crate) fn apply_state_update(&mut self, session_id: String, channel_id: ChannelId) -> bool {
if self.channel_id() != channel_id {
// Likely that the bot was moved to a different channel by an admin.
*self = ConnectionProgress::new(self.guild_id(), self.user_id(), channel_id);
}

use ConnectionProgress::*;
match self {
Complete(c) => {
Expand All @@ -33,7 +73,7 @@ impl ConnectionProgress {
should_reconn
},
Incomplete(i) => i
.apply_state_update(session_id)
.apply_state_update(session_id, channel_id)
.map(|info| {
*self = Complete(info);
})
Expand Down Expand Up @@ -66,6 +106,11 @@ impl ConnectionProgress {
/// with the Songbird driver, lavalink, or other system.
#[derive(Clone)]
pub struct ConnectionInfo {
/// ID of the voice channel being joined, if it is known.
///
/// This is not needed to establish a connection, but can be useful
/// for book-keeping.
pub channel_id: Option<ChannelId>,
/// URL of the voice websocket gateway server assigned to this call.
pub endpoint: String,
/// ID of the target voice channel's parent guild.
Expand All @@ -83,6 +128,7 @@ pub struct ConnectionInfo {
impl fmt::Debug for ConnectionInfo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ConnectionInfo")
.field("channel_id", &self.channel_id)
.field("endpoint", &self.endpoint)
.field("guild_id", &self.guild_id)
.field("session_id", &self.session_id)
Expand All @@ -94,6 +140,7 @@ impl fmt::Debug for ConnectionInfo {

#[derive(Clone, Default)]
pub(crate) struct Partial {
pub channel_id: ChannelId,
pub endpoint: Option<String>,
pub guild_id: GuildId,
pub session_id: Option<String>,
Expand All @@ -104,6 +151,7 @@ pub(crate) struct Partial {
impl fmt::Debug for Partial {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Partial")
.field("channel_id", &self.channel_id)
.field("endpoint", &self.endpoint)
.field("session_id", &self.session_id)
.field("token_is_some", &self.token.is_some())
Expand All @@ -119,6 +167,7 @@ impl Partial {
let token = self.token.take().unwrap();

Some(ConnectionInfo {
channel_id: Some(self.channel_id),
endpoint,
session_id,
token,
Expand All @@ -130,7 +179,17 @@ impl Partial {
}
}

fn apply_state_update(&mut self, session_id: String) -> Option<ConnectionInfo> {
fn apply_state_update(
&mut self,
session_id: String,
channel_id: ChannelId,
) -> Option<ConnectionInfo> {
if self.channel_id != channel_id {
self.endpoint = None;
self.token = None;
}

self.channel_id = channel_id;
self.session_id = Some(session_id);

self.finalise()
Expand Down
10 changes: 8 additions & 2 deletions src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,10 @@ impl Songbird {

if let Some(call) = call {
let mut handler = call.lock().await;
handler.update_state(v.0.session_id.clone());
handler.update_state(
v.0.session_id.clone(),
v.0.channel_id.clone().map(Into::into),
);
}
},
_ => {},
Expand Down Expand Up @@ -390,7 +393,10 @@ impl VoiceGatewayManager for Songbird {

if let Some(call) = self.get(guild_id) {
let mut handler = call.lock().await;
handler.update_state(voice_state.session_id.clone());
handler.update_state(
voice_state.session_id.clone(),
voice_state.channel_id.clone().map(Into::into),
);
}
}
}
Expand Down