From 893dbaae34b56c01fbd482840e9b794944f90ca9 Mon Sep 17 00:00:00 2001 From: Kyle Simpson Date: Mon, 8 Aug 2022 14:36:27 +0100 Subject: [PATCH] Driver: Prune `SsrcState` after timeout/disconnect (#145) `SsrcState` objects are created on a per-user basis when "receive" is enabled, but were previously never destroyed. This PR adds some shared dashmaps for the WS task to communicate SSRC-to-ID mappings to the UDP Rx task, as well as any disconnections. Additionally, decoder state is pruned a default 1 minute after a user last speaks. This was tested using `cargo make ready` and via `examples/serenity/voice_receive/`. Closes #133 --- src/config.rs | 17 ++++++++ src/driver/connection/mod.rs | 20 +++++++-- src/driver/tasks/message/udp_rx.rs | 8 ++++ src/driver/tasks/udp_rx.rs | 66 ++++++++++++++++++++++++++---- src/driver/tasks/ws.rs | 42 ++++++++++--------- 5 files changed, 122 insertions(+), 31 deletions(-) diff --git a/src/config.rs b/src/config.rs index f89ae5c0a..897056215 100644 --- a/src/config.rs +++ b/src/config.rs @@ -49,6 +49,13 @@ pub struct Config { /// [user speaking events]: crate::events::CoreEvent::SpeakingUpdate pub decode_mode: DecodeMode, + #[cfg(all(feature = "driver", feature = "receive"))] + /// Configures the amount of time after a user/SSRC is inactive before their decoder state + /// should be removed. + /// + /// Defaults to 1 minute. + pub decode_state_timeout: Duration, + #[cfg(feature = "gateway")] /// Configures the amount of time to wait for Discord to reply with connection information /// if [`Call::join`]/[`join_gateway`] are used. @@ -155,6 +162,8 @@ impl Default for Config { crypto_mode: CryptoMode::Normal, #[cfg(all(feature = "driver", feature = "receive"))] decode_mode: DecodeMode::Decrypt, + #[cfg(all(feature = "driver", feature = "receive"))] + decode_state_timeout: Duration::from_secs(60), #[cfg(feature = "gateway")] gateway_timeout: Some(Duration::from_secs(10)), #[cfg(feature = "driver")] @@ -198,6 +207,14 @@ impl Config { self } + #[cfg(feature = "receive")] + /// Sets this `Config`'s received packet decoder cleanup timer. + #[must_use] + pub fn decode_state_timeout(mut self, decode_state_timeout: Duration) -> Self { + self.decode_state_timeout = decode_state_timeout; + self + } + /// Sets this `Config`'s audio mixing channel count. #[must_use] pub fn mix_mode(mut self, mix_mode: MixMode) -> Self { diff --git a/src/driver/connection/mod.rs b/src/driver/connection/mod.rs index 4c2f1d36d..442fd8899 100644 --- a/src/driver/connection/mod.rs +++ b/src/driver/connection/mod.rs @@ -3,7 +3,10 @@ pub mod error; #[cfg(feature = "receive")] use super::tasks::udp_rx; use super::{ - tasks::{message::*, ws as ws_task}, + tasks::{ + message::*, + ws::{self as ws_task, AuxNetwork}, + }, Config, CryptoMode, }; @@ -21,6 +24,8 @@ use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPa use error::{Error, Result}; use flume::Sender; use socket2::Socket; +#[cfg(feature = "receive")] +use std::sync::Arc; use std::{net::IpAddr, str::FromStr}; use tokio::{net::UdpSocket, spawn, time::timeout}; use tracing::{debug, info, instrument}; @@ -217,15 +222,21 @@ impl Connection { .mixer .send(MixerMessage::SetConn(mix_conn, ready.ssrc))?; - spawn(ws_task::runner( - interconnect.clone(), + #[cfg(feature = "receive")] + let ssrc_tracker = Arc::new(SsrcTracker::default()); + + let ws_state = AuxNetwork::new( ws_msg_rx, client, ssrc, hello.heartbeat_interval, idx, info.clone(), - )); + #[cfg(feature = "receive")] + ssrc_tracker.clone(), + ); + + spawn(ws_task::runner(interconnect.clone(), ws_state)); #[cfg(feature = "receive")] spawn(udp_rx::runner( @@ -234,6 +245,7 @@ impl Connection { cipher, config.clone(), udp_rx, + ssrc_tracker, )); Ok(Connection { diff --git a/src/driver/tasks/message/udp_rx.rs b/src/driver/tasks/message/udp_rx.rs index 12a6d0c31..202dd54f2 100644 --- a/src/driver/tasks/message/udp_rx.rs +++ b/src/driver/tasks/message/udp_rx.rs @@ -2,8 +2,16 @@ use super::Interconnect; use crate::driver::Config; +use dashmap::{DashMap, DashSet}; +use serenity_voice_model::id::UserId; pub enum UdpRxMessage { SetConfig(Config), ReplaceInterconnect(Interconnect), } + +#[derive(Debug, Default)] +pub struct SsrcTracker { + pub disconnected_users: DashSet, + pub user_ssrc_map: DashMap, +} diff --git a/src/driver/tasks/udp_rx.rs b/src/driver/tasks/udp_rx.rs index 3a3d1fc50..ce276a025 100644 --- a/src/driver/tasks/udp_rx.rs +++ b/src/driver/tasks/udp_rx.rs @@ -22,8 +22,8 @@ use discortp::{ PacketSize, }; use flume::Receiver; -use std::{collections::HashMap, convert::TryInto}; -use tokio::{net::UdpSocket, select}; +use std::{collections::HashMap, convert::TryInto, sync::Arc, time::Duration}; +use tokio::{net::UdpSocket, select, time::Instant}; use tracing::{error, instrument, trace, warn}; use xsalsa20poly1305::XSalsa20Poly1305 as Cipher; @@ -33,6 +33,8 @@ struct SsrcState { decoder: OpusDecoder, last_seq: u16, decode_size: PacketDecodeSize, + prune_time: Instant, + disconnected: bool, } #[derive(Clone, Copy, Debug, Eq, PartialEq)] @@ -84,13 +86,21 @@ enum SpeakingDelta { } impl SsrcState { - fn new(pkt: &RtpPacket<'_>) -> Self { + fn new(pkt: &RtpPacket<'_>, state_timeout: Duration) -> Self { Self { silent_frame_count: 5, // We do this to make the first speech packet fire an event. decoder: OpusDecoder::new(SAMPLE_RATE, Channels::Stereo) .expect("Failed to create new Opus decoder for source."), last_seq: pkt.get_sequence().into(), decode_size: PacketDecodeSize::TwentyMillis, + prune_time: Instant::now() + state_timeout, + disconnected: false, + } + } + + fn refresh_timer(&mut self, state_timeout: Duration) { + if !self.disconnected { + self.prune_time = Instant::now() + state_timeout; } } @@ -236,21 +246,23 @@ impl SsrcState { struct UdpRx { cipher: Cipher, decoder_map: HashMap, - #[allow(dead_code)] config: Config, packet_buffer: [u8; VOICE_PACKET_MAX], rx: Receiver, + ssrc_signalling: Arc, udp_socket: UdpSocket, } impl UdpRx { #[instrument(skip(self))] async fn run(&mut self, interconnect: &mut Interconnect) { + let mut cleanup_time = Instant::now(); + loop { select! { Ok((len, _addr)) = self.udp_socket.recv_from(&mut self.packet_buffer[..]) => { self.process_udp_message(interconnect, len); - } + }, msg = self.rx.recv_async() => { match msg { Ok(UdpRxMessage::ReplaceInterconnect(i)) => { @@ -261,7 +273,41 @@ impl UdpRx { }, Err(flume::RecvError::Disconnected) => break, } - } + }, + _ = tokio::time::sleep_until(cleanup_time) => { + // periodic cleanup. + let now = Instant::now(); + + // check ssrc map to see if the WS task has informed us of any disconnects. + loop { + // This is structured in an odd way to prevent deadlocks. + // while-let seemed to keep the dashmap iter() alive for block scope, rather than + // just the initialiser. + let id = { + if let Some(id) = self.ssrc_signalling.disconnected_users.iter().next().map(|v| *v.key()) { + id + } else { + break; + } + }; + + let _ = self.ssrc_signalling.disconnected_users.remove(&id); + if let Some((_, ssrc)) = self.ssrc_signalling.user_ssrc_map.remove(&id) { + if let Some(state) = self.decoder_map.get_mut(&ssrc) { + // don't cleanup immediately: leave for later cycle + // this is key with reorder/jitter buffers where we may + // still need to decode post disconnect for ~0.2s. + state.prune_time = now + Duration::from_secs(1); + state.disconnected = true; + } + } + } + + // now remove all dead ssrcs. + self.decoder_map.retain(|_, v| v.prune_time > now); + + cleanup_time = now + Duration::from_secs(5); + }, } } } @@ -310,7 +356,11 @@ impl UdpRx { let entry = self .decoder_map .entry(rtp.get_ssrc()) - .or_insert_with(|| SsrcState::new(&rtp)); + .or_insert_with(|| SsrcState::new(&rtp, self.config.decode_state_timeout)); + + // Only do this on RTP, rather than RTCP -- this pins decoder state liveness + // to *speech* rather than just presence. + entry.refresh_timer(self.config.decode_state_timeout); if let Ok((delta, audio)) = entry.process( &rtp, @@ -396,6 +446,7 @@ pub(crate) async fn runner( cipher: Cipher, config: Config, udp_socket: UdpSocket, + ssrc_signalling: Arc, ) { trace!("UDP receive handle started."); @@ -405,6 +456,7 @@ pub(crate) async fn runner( config, packet_buffer: [0u8; VOICE_PACKET_MAX], rx, + ssrc_signalling, udp_socket, }; diff --git a/src/driver/tasks/ws.rs b/src/driver/tasks/ws.rs index 07bc7960c..49b4eeb01 100644 --- a/src/driver/tasks/ws.rs +++ b/src/driver/tasks/ws.rs @@ -13,6 +13,8 @@ use crate::{ }; use flume::Receiver; use rand::random; +#[cfg(feature = "receive")] +use std::sync::Arc; use std::time::Duration; use tokio::{ select, @@ -21,7 +23,7 @@ use tokio::{ use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; use tracing::{debug, info, instrument, trace, warn}; -struct AuxNetwork { +pub(crate) struct AuxNetwork { rx: Receiver, ws_client: WsStream, dont_send: bool, @@ -34,6 +36,9 @@ struct AuxNetwork { attempt_idx: usize, info: ConnectionInfo, + + #[cfg(feature = "receive")] + ssrc_signalling: Arc, } impl AuxNetwork { @@ -44,6 +49,7 @@ impl AuxNetwork { heartbeat_interval: f64, attempt_idx: usize, info: ConnectionInfo, + #[cfg(feature = "receive")] ssrc_signalling: Arc, ) -> Self { Self { rx: evt_rx, @@ -58,6 +64,9 @@ impl AuxNetwork { attempt_idx, info, + + #[cfg(feature = "receive")] + ssrc_signalling, } } @@ -186,6 +195,11 @@ impl AuxNetwork { fn process_ws(&mut self, interconnect: &Interconnect, value: GatewayEvent) { match value { GatewayEvent::Speaking(ev) => { + #[cfg(feature = "receive")] + if let Some(user_id) = &ev.user_id { + self.ssrc_signalling.user_ssrc_map.insert(*user_id, ev.ssrc); + } + drop(interconnect.events.send(EventMessage::FireCoreEvent( CoreContext::SpeakingStateUpdate(ev), ))); @@ -194,6 +208,11 @@ impl AuxNetwork { debug!("Received discontinued ClientConnect: {:?}", ev); }, GatewayEvent::ClientDisconnect(ev) => { + #[cfg(feature = "receive")] + { + self.ssrc_signalling.disconnected_users.insert(ev.user_id); + } + drop(interconnect.events.send(EventMessage::FireCoreEvent( CoreContext::ClientDisconnect(ev), ))); @@ -217,26 +236,9 @@ impl AuxNetwork { } } -#[instrument(skip(interconnect, ws_client))] -pub(crate) async fn runner( - mut interconnect: Interconnect, - evt_rx: Receiver, - ws_client: WsStream, - ssrc: u32, - heartbeat_interval: f64, - attempt_idx: usize, - info: ConnectionInfo, -) { +#[instrument(skip(interconnect, aux))] +pub(crate) async fn runner(mut interconnect: Interconnect, mut aux: AuxNetwork) { trace!("WS thread started."); - let mut aux = AuxNetwork::new( - evt_rx, - ws_client, - ssrc, - heartbeat_interval, - attempt_idx, - info, - ); - aux.run(&mut interconnect).await; trace!("WS thread finished."); }