diff --git a/src/driver/config.rs b/src/driver/config.rs index c5349b628..dfcaed2c0 100644 --- a/src/driver/config.rs +++ b/src/driver/config.rs @@ -1,10 +1,81 @@ -use super::CryptoMode; +use super::{CryptoMode, DecodeMode}; /// Configuration for the inner Driver. /// -/// At present, this cannot be changed. -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug)] pub struct Config { /// Selected tagging mode for voice packet encryption. - pub crypto_mode: Option, + /// + /// Defaults to [`CryptoMode::Normal`]. + /// + /// Changes to this field will not immediately apply if the + /// driver is actively connected, but will apply to subsequent + /// sessions. + /// + /// [`CryptoMode::Normal`]: enum.CryptoMode.html#variant.Normal + pub crypto_mode: CryptoMode, + /// Configures whether decoding and decryption occur for all received packets. + /// + /// If receiving voice packets, generally you should choose [`DecodeMode::Decode`]. + /// [`DecodeMode::Decrypt`] is intended for users running their own selective decoding, + /// who rely upon [user speaking events], or who need to inspect Opus packets. + /// If you're certain you will never need any RT(C)P events, then consider [`DecodeMode::Pass`]. + /// + /// Defaults to [`DecodeMode::Decrypt`]. This is due to per-packet decoding costs, + /// which most users will not want to pay, but allowing speaking events which are commonly used. + /// + /// [`DecodeMode::Decode`]: enum.DecodeMode.html#variant.Decode + /// [`DecodeMode::Decrypt`]: enum.DecodeMode.html#variant.Decrypt + /// [`DecodeMode::Pass`]: enum.DecodeMode.html#variant.Pass + /// [user speaking events]: ../events/enum.CoreEvent.html#variant.SpeakingUpdate + pub decode_mode: DecodeMode, + /// Number of concurrently active tracks to allocate memory for. + /// + /// This should be set at, or just above, the maximum number of tracks + /// you expect your bot will play at the same time. 
Exceeding the size of + /// the internal queue will trigger a larger memory allocation and copy, + /// possibly causing the mixer thread to miss a packet deadline. + /// + /// Defaults to `1`. + /// + /// Changes to this field in a running driver will only ever increase + /// the capacity of the track store. + pub preallocated_tracks: usize, +} + +impl Default for Config { + fn default() -> Self { + Self { + crypto_mode: CryptoMode::Normal, + decode_mode: DecodeMode::Decrypt, + preallocated_tracks: 1, + } + } +} + +impl Config { + /// Sets this `Config`'s chosen cryptographic tagging scheme. + pub fn crypto_mode(mut self, crypto_mode: CryptoMode) -> Self { + self.crypto_mode = crypto_mode; + self + } + + /// Sets this `Config`'s received packet decryption/decoding behaviour. + pub fn decode_mode(mut self, decode_mode: DecodeMode) -> Self { + self.decode_mode = decode_mode; + self + } + + /// Sets this `Config`'s number of tracks to preallocate. + pub fn preallocated_tracks(mut self, preallocated_tracks: usize) -> Self { + self.preallocated_tracks = preallocated_tracks; + self + } + + /// This is used to prevent changes which would invalidate the current session. 
+ pub(crate) fn make_safe(&mut self, previous: &Config, connected: bool) { + if connected { + self.crypto_mode = previous.crypto_mode; + } + } } diff --git a/src/driver/connection/mod.rs b/src/driver/connection/mod.rs index ee5a416f6..b2b8190fe 100644 --- a/src/driver/connection/mod.rs +++ b/src/driver/connection/mod.rs @@ -41,8 +41,6 @@ impl Connection { interconnect: &Interconnect, config: &Config, ) -> Result { - let crypto_mode = config.crypto_mode.unwrap_or(CryptoMode::Normal); - let url = generate_url(&mut info.endpoint)?; #[cfg(all(feature = "rustls", not(feature = "native")))] @@ -95,7 +93,7 @@ impl Connection { let ready = ready.expect("Ready packet expected in connection initialisation, but not found."); - if !has_valid_mode(&ready.modes, crypto_mode) { + if !has_valid_mode(&ready.modes, config.crypto_mode) { return Err(Error::CryptoModeUnavailable); } @@ -147,14 +145,14 @@ impl Connection { protocol: "udp".into(), data: ProtocolData { address, - mode: crypto_mode.to_request_str().into(), + mode: config.crypto_mode.to_request_str().into(), port: view.get_port(), }, })) .await?; } - let cipher = init_cipher(&mut client, crypto_mode).await?; + let cipher = init_cipher(&mut client, config.crypto_mode).await?; info!("Connected to: {}", info.endpoint); @@ -169,6 +167,7 @@ impl Connection { let mix_conn = MixerConnection { cipher: cipher.clone(), + crypto_state: config.crypto_mode.into(), udp_rx: udp_receiver_msg_tx, udp_tx: udp_sender_msg_tx, }; @@ -193,7 +192,7 @@ impl Connection { interconnect.clone(), udp_receiver_msg_rx, cipher, - crypto_mode, + config.clone(), udp_rx, )); tokio::spawn(udp_tx::runner(udp_sender_msg_rx, ssrc, udp_tx)); diff --git a/src/driver/crypto.rs b/src/driver/crypto.rs index e7a306d55..cfbc81304 100644 --- a/src/driver/crypto.rs +++ b/src/driver/crypto.rs @@ -1,38 +1,223 @@ //! Encryption schemes supported by Discord's secure RTP negotiation. 
+use byteorder::{NetworkEndian, WriteBytesExt}; +use discortp::{rtp::RtpPacket, MutablePacket}; +use rand::Rng; +use std::num::Wrapping; +use xsalsa20poly1305::{ + aead::{AeadInPlace, Error as CryptoError}, + Nonce, + Tag, + XSalsa20Poly1305 as Cipher, + NONCE_SIZE, + TAG_SIZE, +}; /// Variants of the XSalsa20Poly1305 encryption scheme. -/// -/// At present, only `Normal` is supported or selectable. #[derive(Clone, Copy, Debug, Eq, PartialEq)] #[non_exhaustive] -pub enum Mode { +pub enum CryptoMode { /// The RTP header is used as the source of nonce bytes for the packet. /// /// Equivalent to a nonce of at most 48b (6B) at no extra packet overhead: /// the RTP sequence number and timestamp are the varying quantities. Normal, /// An additional random 24B suffix is used as the source of nonce bytes for the packet. + /// This is regenerated randomly for each packet. /// /// Full nonce width of 24B (192b), at an extra 24B per packet (~1.2 kB/s). Suffix, - /// An additional random 24B suffix is used as the source of nonce bytes for the packet. + /// An additional random 4B suffix is used as the source of nonce bytes for the packet. + /// This nonce value increments by `1` with each packet. /// /// Nonce width of 4B (32b), at an extra 4B per packet (~0.2 kB/s). Lite, } -impl Mode { +impl From for CryptoMode { + fn from(val: CryptoState) -> Self { + use CryptoState::*; + match val { + Normal => CryptoMode::Normal, + Suffix => CryptoMode::Suffix, + Lite(_) => CryptoMode::Lite, + } + } +} + +impl CryptoMode { /// Returns the name of a mode as it will appear during negotiation. pub fn to_request_str(self) -> &'static str { - use Mode::*; + use CryptoMode::*; match self { Normal => "xsalsa20_poly1305", Suffix => "xsalsa20_poly1305_suffix", Lite => "xsalsa20_poly1305_lite", } } + + /// Returns the number of bytes each nonce is stored as within + /// a packet. 
+ pub fn nonce_size(self) -> usize { + use CryptoMode::*; + match self { + Normal => RtpPacket::minimum_packet_size(), + Suffix => NONCE_SIZE, + Lite => 4, + } + } + + /// Returns the number of bytes occupied by the encryption scheme + /// which fall before the payload. + pub fn payload_prefix_len(self) -> usize { + TAG_SIZE + } + + /// Returns the number of bytes occupied by the encryption scheme + /// which fall after the payload. + pub fn payload_suffix_len(self) -> usize { + use CryptoMode::*; + match self { + Normal => 0, + Suffix | Lite => self.nonce_size(), + } + } + + /// Calculates the number of additional bytes required compared + /// to an unencrypted payload. + pub fn payload_overhead(self) -> usize { + self.payload_prefix_len() + self.payload_suffix_len() + } + + /// Extracts the byte slice in a packet used as the nonce, and the remaining mutable + /// portion of the packet. + fn nonce_slice<'a>(self, header: &'a [u8], body: &'a mut [u8]) -> (&'a [u8], &'a mut [u8]) { + use CryptoMode::*; + match self { + Normal => (header, body), + Suffix | Lite => { + let len = body.len(); + let (body_left, nonce_loc) = body.split_at_mut(len - self.payload_suffix_len()); + (&nonce_loc[..self.nonce_size()], body_left) + }, + } + } + + /// Decrypts a Discord RT(C)P packet using the given key. + /// + /// If successful, this returns the number of bytes to be ignored from the + /// start and end of the packet payload. 
+ #[inline] + pub(crate) fn decrypt_in_place( + self, + packet: &mut impl MutablePacket, + cipher: &Cipher, + ) -> Result<(usize, usize), CryptoError> { + let header_len = packet.packet().len() - packet.payload().len(); + let (header, body) = packet.packet_mut().split_at_mut(header_len); + let (slice_to_use, body_remaining) = self.nonce_slice(header, body); + + let mut nonce = Nonce::default(); + let nonce_slice = if slice_to_use.len() == NONCE_SIZE { + Nonce::from_slice(&slice_to_use[..NONCE_SIZE]) + } else { + let max_bytes_avail = self.nonce_size().min(slice_to_use.len()); // clamp both sides: the source slice may be longer than nonce_size() (e.g. an RTP header with CSRC entries) + nonce[..max_bytes_avail].copy_from_slice(&slice_to_use[..max_bytes_avail]); + &nonce + }; + + let body_start = self.payload_prefix_len(); + let body_tail = self.payload_suffix_len(); + if body_remaining.len() < body_start { return Err(CryptoError); } // truncated packet: no room for the tag, so fail instead of panicking in split_at_mut + let (tag_bytes, data_bytes) = body_remaining.split_at_mut(body_start); + let tag = Tag::from_slice(tag_bytes); + + Ok(cipher + .decrypt_in_place_detached(nonce_slice, b"", data_bytes, tag) + .map(|_| (body_start, body_tail))?) + } + + /// Encrypts a Discord RT(C)P packet using the given key. + /// + /// Use of this requires that the input packet has had a nonce generated in the correct location, + /// and `payload_len` specifies the number of bytes after the header including this nonce. + #[inline] + pub fn encrypt_in_place( + self, + packet: &mut impl MutablePacket, + cipher: &Cipher, + payload_len: usize, + ) -> Result<(), CryptoError> { + let header_len = packet.packet().len() - packet.payload().len(); + let (header, body) = packet.packet_mut().split_at_mut(header_len); + let (slice_to_use, body_remaining) = self.nonce_slice(header, &mut body[..payload_len]); + + let mut nonce = Nonce::default(); + let nonce_slice = if slice_to_use.len() == NONCE_SIZE { + Nonce::from_slice(&slice_to_use[..NONCE_SIZE]) + } else { + nonce[..self.nonce_size()].copy_from_slice(slice_to_use); + &nonce + }; + + // body_remaining is now correctly truncated by this point. + // the true_payload to encrypt follows after the first TAG_SIZE bytes. 
+ let tag = + cipher.encrypt_in_place_detached(nonce_slice, b"", &mut body_remaining[TAG_SIZE..])?; + body_remaining[..TAG_SIZE].copy_from_slice(&tag[..]); + + Ok(()) + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[non_exhaustive] +pub(crate) enum CryptoState { + Normal, + Suffix, + Lite(Wrapping<u32>), +} + +impl From<CryptoMode> for CryptoState { + fn from(val: CryptoMode) -> Self { + use CryptoMode::*; + match val { + Normal => CryptoState::Normal, + Suffix => CryptoState::Suffix, + Lite => CryptoState::Lite(Wrapping(rand::random::<u32>())), + } + } } -// TODO: implement encrypt + decrypt + nonce selection for each. -// This will probably need some research into correct handling of -// padding, reported length, SRTP profiles, and so on. +impl CryptoState { + /// Writes packet nonce into the body, if required, returning the new length. + pub fn write_packet_nonce( + &mut self, + packet: &mut impl MutablePacket, + payload_end: usize, + ) -> usize { + let mode = self.kind(); + let endpoint = payload_end + mode.payload_suffix_len(); + + use CryptoState::*; + match self { + Suffix => { + rand::thread_rng().fill(&mut packet.payload_mut()[payload_end..endpoint]); + }, + Lite(i) => { + (&mut packet.payload_mut()[payload_end..endpoint]) + .write_u32::<NetworkEndian>(i.0) + .expect( + "Nonce size is guaranteed to be sufficient to write u32 for lite tagging.", + ); + *i += Wrapping(1); // `i` borrows the stored counter, so the increment persists to the next packet + }, + _ => {}, + } + + endpoint + } + + pub fn kind(&self) -> CryptoMode { // the stateless mode corresponding to this state + CryptoMode::from(*self) + } +} diff --git a/src/driver/decode_mode.rs b/src/driver/decode_mode.rs new file mode 100644 index 000000000..7003e779f --- /dev/null +++ b/src/driver/decode_mode.rs @@ -0,0 +1,32 @@ +/// Decode behaviour for received RTP packets within the driver. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[non_exhaustive] +pub enum DecodeMode { + /// Packets received from Discord are handed over to events without any + /// changes applied. + /// + /// No CPU work involved. 
+ /// + /// *BEWARE: this will almost certainly break [user speaking events]. + /// Silent frame detection only works if extensions can be parsed or + /// are not present, as they are encrypted. + /// This event requires such functionality.* + /// + /// [user speaking events]: ../events/enum.CoreEvent.html#variant.SpeakingUpdate + Pass, + /// Decrypts the body of each received packet. + /// + /// Small per-packet CPU use. + Decrypt, + /// Decrypts and decodes each received packet, correctly accounting for losses. + /// + /// Larger per-packet CPU use. + Decode, +} + +impl DecodeMode { + /// Returns whether this mode will decrypt received packets. + pub fn should_decrypt(self) -> bool { + self != DecodeMode::Pass + } +} diff --git a/src/driver/mod.rs b/src/driver/mod.rs index cd148bc4d..792083467 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -11,11 +11,13 @@ mod config; pub(crate) mod connection; mod crypto; +mod decode_mode; pub(crate) mod tasks; pub use config::Config; use connection::error::Result; -pub use crypto::Mode as CryptoMode; +pub use crypto::*; +pub use decode_mode::DecodeMode; use crate::{ events::EventData, @@ -187,6 +189,13 @@ impl Driver { self.send(CoreMessage::SetTrack(None)) } + /// Sets the configuration for this driver. + #[instrument(skip(self))] + pub fn set_config(&mut self, config: Config) { + self.config = config.clone(); + self.send(CoreMessage::SetConfig(config)) + } + /// Attach a global event handler to an audio context. Global events may receive /// any [`EventContext`]. 
/// diff --git a/src/driver/tasks/message/core.rs b/src/driver/tasks/message/core.rs index 3c5c01793..270beecf2 100644 --- a/src/driver/tasks/message/core.rs +++ b/src/driver/tasks/message/core.rs @@ -1,5 +1,5 @@ use crate::{ - driver::connection::error::Error, + driver::{connection::error::Error, Config}, events::EventData, tracks::Track, Bitrate, @@ -16,6 +16,7 @@ pub enum CoreMessage { AddTrack(Track), SetBitrate(Bitrate), AddEvent(EventData), + SetConfig(Config), Mute(bool), Reconnect, FullReconnect, diff --git a/src/driver/tasks/message/mixer.rs b/src/driver/tasks/message/mixer.rs index 4c2eec573..260f4008b 100644 --- a/src/driver/tasks/message/mixer.rs +++ b/src/driver/tasks/message/mixer.rs @@ -1,11 +1,16 @@ use super::{Interconnect, UdpRxMessage, UdpTxMessage, WsMessage}; -use crate::{tracks::Track, Bitrate}; +use crate::{ + driver::{Config, CryptoState}, + tracks::Track, + Bitrate, +}; use flume::Sender; use xsalsa20poly1305::XSalsa20Poly1305 as Cipher; pub(crate) struct MixerConnection { pub cipher: Cipher, + pub crypto_state: CryptoState, pub udp_rx: Sender, pub udp_tx: Sender, } @@ -20,13 +25,17 @@ impl Drop for MixerConnection { pub(crate) enum MixerMessage { AddTrack(Track), SetTrack(Option), + SetBitrate(Bitrate), + SetConfig(Config), SetMute(bool), + SetConn(MixerConnection, u32), + Ws(Option>), DropConn, + ReplaceInterconnect(Interconnect), RebuildEncoder, - Ws(Option>), Poison, } diff --git a/src/driver/tasks/message/udp_rx.rs b/src/driver/tasks/message/udp_rx.rs index 91e740d35..453415d75 100644 --- a/src/driver/tasks/message/udp_rx.rs +++ b/src/driver/tasks/message/udp_rx.rs @@ -1,6 +1,8 @@ use super::Interconnect; +use crate::driver::Config; pub(crate) enum UdpRxMessage { + SetConfig(Config), ReplaceInterconnect(Interconnect), Poison, diff --git a/src/driver/tasks/mixer.rs b/src/driver/tasks/mixer.rs index 3fa5d1d3e..2d27bffcf 100644 --- a/src/driver/tasks/mixer.rs +++ b/src/driver/tasks/mixer.rs @@ -1,4 +1,4 @@ -use super::{error::Result, 
message::*}; +use super::{error::Result, message::*, Config}; use crate::{ constants::*, tracks::{PlayMode, Track}, @@ -13,7 +13,6 @@ use audiopus::{ use discortp::{ rtp::{MutableRtpPacket, RtpPacket}, MutablePacket, - Packet, }; use flume::{Receiver, Sender, TryRecvError}; use rand::random; @@ -21,11 +20,12 @@ use spin_sleep::SpinSleeper; use std::time::Instant; use tokio::runtime::Handle; use tracing::{error, instrument}; -use xsalsa20poly1305::{aead::AeadInPlace, Nonce, TAG_SIZE}; +use xsalsa20poly1305::TAG_SIZE; struct Mixer { async_handle: Handle, bitrate: Bitrate, + config: Config, conn_active: Option, deadline: Instant, encoder: OpusEncoder, @@ -53,6 +53,7 @@ impl Mixer { mix_rx: Receiver, async_handle: Handle, interconnect: Interconnect, + config: Config, ) -> Self { let bitrate = DEFAULT_BITRATE; let encoder = new_encoder(bitrate) @@ -70,9 +71,12 @@ impl Mixer { rtp.set_sequence(random::().into()); rtp.set_timestamp(random::().into()); + let tracks = Vec::with_capacity(1.max(config.preallocated_tracks)); + Self { async_handle, bitrate, + config, conn_active: None, deadline: Instant::now(), encoder, @@ -84,7 +88,7 @@ impl Mixer { silence_frames: 0, sleeper: Default::default(), soft_clip, - tracks: vec![], + tracks, ws: None, } } @@ -137,6 +141,8 @@ impl Mixer { (Blame: VOICE_PACKET_MAX?)", ); rtp.set_ssrc(ssrc); + rtp.set_sequence(random::().into()); + rtp.set_timestamp(random::().into()); self.deadline = Instant::now(); Ok(()) }, @@ -160,6 +166,23 @@ impl Mixer { self.rebuild_tracks() }, + Ok(SetConfig(new_config)) => { + self.config = new_config.clone(); + + if self.tracks.capacity() < self.config.preallocated_tracks { + self.tracks + .reserve(self.config.preallocated_tracks - self.tracks.len()); + } + + if let Some(conn) = &self.conn_active { + conn_failure |= conn + .udp_rx + .send(UdpRxMessage::SetConfig(new_config)) + .is_err(); + } + + Ok(()) + }, Ok(RebuildEncoder) => match new_encoder(self.bitrate) { Ok(encoder) => { self.encoder = encoder; @@ 
-449,38 +472,38 @@ impl Mixer { .as_mut() .expect("Shouldn't be mixing packets without access to a cipher + UDP dest."); - let mut nonce = Nonce::default(); let index = { let mut rtp = MutableRtpPacket::new(&mut self.packet[..]).expect( "FATAL: Too few bytes in self.packet for RTP header.\ (Blame: VOICE_PACKET_MAX?)", ); - let pkt = rtp.packet(); - let rtp_len = RtpPacket::minimum_packet_size(); - nonce[..rtp_len].copy_from_slice(&pkt[..rtp_len]); - let payload = rtp.payload_mut(); + let crypto_mode = conn.crypto_state.kind(); let payload_len = if opus_frame.is_empty() { - self.encoder - .encode_float(&buffer[..STEREO_FRAME_SIZE], &mut payload[TAG_SIZE..])? + let total_payload_space = payload.len() - crypto_mode.payload_suffix_len(); + self.encoder.encode_float( + &buffer[..STEREO_FRAME_SIZE], + &mut payload[TAG_SIZE..total_payload_space], + )? } else { let len = opus_frame.len(); payload[TAG_SIZE..TAG_SIZE + len].clone_from_slice(opus_frame); len }; - let final_payload_size = TAG_SIZE + payload_len; + let final_payload_size = conn + .crypto_state + .write_packet_nonce(&mut rtp, TAG_SIZE + payload_len); - let tag = conn.cipher.encrypt_in_place_detached( - &nonce, - b"", - &mut payload[TAG_SIZE..final_payload_size], + conn.crypto_state.kind().encrypt_in_place( + &mut rtp, + &conn.cipher, + final_payload_size, )?; - payload[..TAG_SIZE].copy_from_slice(&tag[..]); - rtp_len + final_payload_size + RtpPacket::minimum_packet_size() + final_payload_size }; // TODO: This is dog slow, don't do this. 
@@ -509,8 +532,9 @@ pub(crate) fn runner( interconnect: Interconnect, mix_rx: Receiver, async_handle: Handle, + config: Config, ) { - let mut mixer = Mixer::new(mix_rx, async_handle, interconnect); + let mut mixer = Mixer::new(mix_rx, async_handle, interconnect, config); mixer.run(); } diff --git a/src/driver/tasks/mod.rs b/src/driver/tasks/mod.rs index 2e0b2d032..fe0257cda 100644 --- a/src/driver/tasks/mod.rs +++ b/src/driver/tasks/mod.rs @@ -23,7 +23,7 @@ pub(crate) fn start(config: Config, rx: Receiver, tx: Sender) -> Interconnect { +fn start_internals(core: Sender, config: Config) -> Interconnect { let (evt_tx, evt_rx) = flume::unbounded(); let (mix_tx, mix_rx) = flume::unbounded(); @@ -44,7 +44,7 @@ fn start_internals(core: Sender) -> Interconnect { let handle = Handle::current(); std::thread::spawn(move || { info!("Mixer started."); - mixer::runner(ic, mix_rx, handle); + mixer::runner(ic, mix_rx, handle, config); info!("Mixer finished."); }); @@ -52,13 +52,23 @@ fn start_internals(core: Sender) -> Interconnect { } #[instrument(skip(rx, tx))] -async fn runner(config: Config, rx: Receiver, tx: Sender) { +async fn runner(mut config: Config, rx: Receiver, tx: Sender) { + let mut next_config: Option = None; let mut connection = None; - let mut interconnect = start_internals(tx); + let mut interconnect = start_internals(tx, config.clone()); loop { match rx.recv_async().await { Ok(CoreMessage::ConnectWithResult(info, tx)) => { + config = if let Some(new_config) = next_config.take() { + let _ = interconnect + .mixer + .send(MixerMessage::SetConfig(new_config.clone())); + new_config + } else { + config + }; + connection = match Connection::new(info, &interconnect, &config).await { Ok(connection) => { // Other side may not be listening: this is fine. 
@@ -87,6 +97,13 @@ async fn runner(config: Config, rx: Receiver, tx: Sender { let _ = interconnect.mixer.send(MixerMessage::SetBitrate(b)); }, + Ok(CoreMessage::SetConfig(mut new_config)) => { + next_config = Some(new_config.clone()); + + new_config.make_safe(&config, connection.is_some()); + + let _ = interconnect.mixer.send(MixerMessage::SetConfig(new_config)); + }, Ok(CoreMessage::AddEvent(evt)) => { let _ = interconnect.events.send(EventMessage::AddGlobalEvent(evt)); }, diff --git a/src/driver/tasks/udp_rx.rs b/src/driver/tasks/udp_rx.rs index 263ef7617..937457398 100644 --- a/src/driver/tasks/udp_rx.rs +++ b/src/driver/tasks/udp_rx.rs @@ -2,13 +2,16 @@ use super::{ error::{Error, Result}, message::*, }; -use crate::{constants::*, driver::CryptoMode, events::CoreContext}; +use crate::{ + constants::*, + driver::{Config, DecodeMode}, + events::CoreContext, +}; use audiopus::{coder::Decoder as OpusDecoder, Channels}; use discortp::{ demux::{self, DemuxedMut}, rtp::{RtpExtensionPacket, RtpPacket}, FromPacket, - MutablePacket, Packet, PacketSize, }; @@ -16,7 +19,7 @@ use flume::Receiver; use std::collections::HashMap; use tokio::net::udp::RecvHalf; use tracing::{error, info, instrument, warn}; -use xsalsa20poly1305::{aead::AeadInPlace, Nonce, Tag, XSalsa20Poly1305 as Cipher, TAG_SIZE}; +use xsalsa20poly1305::XSalsa20Poly1305 as Cipher; #[derive(Debug)] struct SsrcState { @@ -46,19 +49,38 @@ impl SsrcState { &mut self, pkt: RtpPacket<'_>, data_offset: usize, - ) -> Result<(SpeakingDelta, Vec)> { + data_trailer: usize, + decode_mode: DecodeMode, + decrypted: bool, + ) -> Result<(SpeakingDelta, Option>)> { let new_seq: u16 = pkt.get_sequence().into(); + let payload_len = pkt.payload().len(); let extensions = pkt.get_extension() != 0; let seq_delta = new_seq.wrapping_sub(self.last_seq); Ok(if seq_delta >= (1 << 15) { // Overflow, reordered (previously missing) packet. 
+ (SpeakingDelta::Same, Some(vec![])) } else { self.last_seq = new_seq; let missed_packets = seq_delta.saturating_sub(1); - let (audio, pkt_size) = - self.scan_and_decode(&pkt.payload()[data_offset..], extensions, missed_packets)?; + + // Note: we still need to handle this for non-decoded. + // This is mainly because packet events and speaking events can be handed to the + // user. + let (audio, pkt_size) = if decode_mode.should_decrypt() && decrypted { + self.scan_and_decode( + &pkt.payload()[data_offset..payload_len - data_trailer], + extensions, + missed_packets, + decode_mode == DecodeMode::Decode, + )? + } else { + // The latter part is an upper bound, as we cannot determine + // how long packet extensions are. + // Without decryption, speaking detection is thus broken. + (None, payload_len.saturating_sub(data_offset + data_trailer)) // unvalidated packets may be shorter than the crypto framing + }; let delta = if pkt_size == SILENT_FRAME.len() { // Frame is silent. @@ -91,8 +113,8 @@ impl SsrcState { data: &[u8], extension: bool, missed_packets: u16, - ) -> Result<(Vec, usize)> { - let mut out = vec![0; STEREO_FRAME_SIZE]; + decode: bool, + ) -> Result<(Option>, usize)> { let start = if extension { RtpExtensionPacket::new(data) .map(|pkt| pkt.packet_size()) @@ -104,26 +126,34 @@ impl SsrcState { Ok(0) }?; - for _ in 0..missed_packets { - let missing_frame: Option<&[u8]> = None; - if let Err(e) = self.decoder.decode(missing_frame, &mut out[..], false) { - warn!("Issue while decoding for missed packet: {:?}.", e); + let pkt = if decode { + let mut out = vec![0; STEREO_FRAME_SIZE]; + + for _ in 0..missed_packets { + let missing_frame: Option<&[u8]> = None; + if let Err(e) = self.decoder.decode(missing_frame, &mut out[..], false) { + warn!("Issue while decoding for missed packet: {:?}.", e); + } } - } - let audio_len = self - .decoder - .decode(Some(&data[start..]), &mut out[..], false) - .map_err(|e| { - error!("Failed to decode received packet: {:?}.", e); - e - })?; + let audio_len = self + .decoder + 
.decode(Some(&data[start..]), &mut out[..], false) + .map_err(|e| { + error!("Failed to decode received packet: {:?}.", e); + e + })?; - // Decoding to stereo: audio_len refers to sample count irrespective of channel count. - // => multiply by number of channels. - out.truncate(2 * audio_len); + // Decoding to stereo: audio_len refers to sample count irrespective of channel count. + // => multiply by number of channels. + out.truncate(2 * audio_len); - Ok((out, data.len() - start)) + Some(out) + } else { + None + }; + + Ok((pkt, data.len() - start)) } } @@ -131,7 +161,7 @@ struct UdpRx { cipher: Cipher, decoder_map: HashMap, #[allow(dead_code)] - mode: CryptoMode, // In future, this will allow crypto mode selection. + config: Config, packet_buffer: [u8; VOICE_PACKET_MAX], rx: Receiver, udp_socket: RecvHalf, @@ -150,7 +180,10 @@ impl UdpRx { match msg { Ok(ReplaceInterconnect(i)) => { *interconnect = i; - } + }, + Ok(SetConfig(c)) => { + self.config = c; + }, Ok(Poison) | Err(_) => break, } } @@ -166,6 +199,7 @@ impl UdpRx { // For simplicity, we nominate the mixing context to rebuild the event // context if it fails (hence, the `let _ =` statements.), as it will try to // make contact every 20ms. 
+ let crypto_mode = self.config.crypto_mode; let packet = &mut self.packet_buffer[..len]; match demux::demux_mut(packet) { @@ -175,15 +209,40 @@ impl UdpRx { return; } - let rtp_body_start = - decrypt_in_place(&mut rtp, &self.cipher).expect("RTP decryption failed."); + let packet_data = if self.config.decode_mode.should_decrypt() { + let out = crypto_mode + .decrypt_in_place(&mut rtp, &self.cipher) + .map(|(s, t)| (s, t, true)); + + if let Err(e) = out { + warn!("RTP decryption failed: {:?}", e); + } + + out.ok() + } else { + None + }; + + let (rtp_body_start, rtp_body_tail, decrypted) = packet_data.unwrap_or_else(|| { + ( + crypto_mode.payload_prefix_len(), + crypto_mode.payload_suffix_len(), + false, + ) + }); let entry = self .decoder_map .entry(rtp.get_ssrc()) .or_insert_with(|| SsrcState::new(rtp.to_immutable())); - if let Ok((delta, audio)) = entry.process(rtp.to_immutable(), rtp_body_start) { + if let Ok((delta, audio)) = entry.process( + rtp.to_immutable(), + rtp_body_start, + rtp_body_tail, + self.config.decode_mode, + decrypted, + ) { match delta { SpeakingDelta::Start => { let _ = interconnect.events.send(EventMessage::FireCoreEvent( @@ -209,25 +268,40 @@ impl UdpRx { audio, packet: rtp.from_packet(), payload_offset: rtp_body_start, + payload_end_pad: rtp_body_tail, }, )); } else { - warn!("RTP decoding/decrytion failed."); + warn!("RTP decoding/processing failed."); } }, DemuxedMut::Rtcp(mut rtcp) => { - let rtcp_body_start = decrypt_in_place(&mut rtcp, &self.cipher); + let packet_data = if self.config.decode_mode.should_decrypt() { + let out = crypto_mode.decrypt_in_place(&mut rtcp, &self.cipher); - if let Ok(start) = rtcp_body_start { - let _ = interconnect.events.send(EventMessage::FireCoreEvent( - CoreContext::RtcpPacket { - packet: rtcp.from_packet(), - payload_offset: start, - }, - )); + if let Err(e) = out { + warn!("RTCP decryption failed: {:?}", e); + } + + out.ok() } else { - warn!("RTCP decryption failed."); - } + None + }; + + let (start, 
tail) = packet_data.unwrap_or_else(|| { + ( + crypto_mode.payload_prefix_len(), + crypto_mode.payload_suffix_len(), + ) + }); + + let _ = interconnect.events.send(EventMessage::FireCoreEvent( + CoreContext::RtcpPacket { + packet: rtcp.from_packet(), + payload_offset: start, + payload_end_pad: tail, + }, + )); }, DemuxedMut::FailedParse(t) => { warn!("Failed to parse message of type {:?}.", t); @@ -244,7 +318,7 @@ pub(crate) async fn runner( mut interconnect: Interconnect, rx: Receiver, cipher: Cipher, - mode: CryptoMode, + config: Config, udp_socket: RecvHalf, ) { info!("UDP receive handle started."); @@ -252,7 +326,7 @@ pub(crate) async fn runner( let mut state = UdpRx { cipher, decoder_map: Default::default(), - mode, + config, packet_buffer: [0u8; VOICE_PACKET_MAX], rx, udp_socket, @@ -263,23 +337,6 @@ pub(crate) async fn runner( info!("UDP receive handle stopped."); } -#[inline] -fn decrypt_in_place(packet: &mut impl MutablePacket, cipher: &Cipher) -> Result { - // Applies discord's cheapest. - // In future, might want to make a choice... - let header_len = packet.packet().len() - packet.payload().len(); - let mut nonce = Nonce::default(); - nonce[..header_len].copy_from_slice(&packet.packet()[..header_len]); - - let data = packet.payload_mut(); - let (tag_bytes, data_bytes) = data.split_at_mut(TAG_SIZE); - let tag = Tag::from_slice(tag_bytes); - - Ok(cipher - .decrypt_in_place_detached(&nonce, b"", data_bytes, tag) - .map(|_| TAG_SIZE)?) -} - #[inline] fn rtp_valid(packet: RtpPacket<'_>) -> bool { packet.get_version() == RTP_VERSION && packet.get_payload_type() == RTP_PROFILE_TYPE diff --git a/src/events/context.rs b/src/events/context.rs index 004465fb2..1acf088f9 100644 --- a/src/events/context.rs +++ b/src/events/context.rs @@ -42,13 +42,15 @@ pub enum EventContext<'a> { /// if `audio.len() == 0`, then this packet arrived out-of-order. VoicePacket { /// Decoded audio from this packet. - audio: &'a Vec, + audio: &'a Option>, /// Raw RTP packet data. 
/// /// Includes the SSRC (i.e., sender) of this packet. packet: &'a Rtp, - /// Byte index into the packet for where the payload begins. + /// Byte index into the packet body (after headers) for where the payload begins. payload_offset: usize, + /// Number of bytes at the end of the packet to discard. + payload_end_pad: usize, }, /// Telemetry/statistics packet, received from another stream (detailed in `packet`). /// `payload_offset` contains the true payload location within the raw packet's `payload()`, @@ -56,8 +58,10 @@ pub enum EventContext<'a> { RtcpPacket { /// Raw RTCP packet data. packet: &'a Rtcp, - /// Byte index into the packet for where the payload begins. + /// Byte index into the packet body (after headers) for where the payload begins. payload_offset: usize, + /// Number of bytes at the end of the packet to discard. + payload_end_pad: usize, }, /// Fired whenever a client connects to a call for the first time, allowing SSRC/UserID /// matching. @@ -74,13 +78,15 @@ pub(crate) enum CoreContext { speaking: bool, }, VoicePacket { - audio: Vec, + audio: Option>, packet: Rtp, payload_offset: usize, + payload_end_pad: usize, }, RtcpPacket { packet: Rtcp, payload_offset: usize, + payload_end_pad: usize, }, ClientConnect(ClientConnect), ClientDisconnect(ClientDisconnect), @@ -100,17 +106,21 @@ impl<'a> CoreContext { audio, packet, payload_offset, + payload_end_pad, } => EventContext::VoicePacket { audio, packet, payload_offset: *payload_offset, + payload_end_pad: *payload_end_pad, }, RtcpPacket { packet, payload_offset, + payload_end_pad, } => EventContext::RtcpPacket { packet, payload_offset: *payload_offset, + payload_end_pad: *payload_end_pad, }, ClientConnect(evt) => EventContext::ClientConnect(*evt), ClientDisconnect(evt) => EventContext::ClientDisconnect(*evt), diff --git a/src/handler.rs b/src/handler.rs index 3ecb089cf..ee56c25e9 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,5 +1,8 @@ #[cfg(feature = "driver")] -use crate::{driver::Driver, 
error::ConnectionResult}; +use crate::{ + driver::{Config, Driver}, + error::ConnectionResult, +}; use crate::{ error::{JoinError, JoinResult}, id::{ChannelId, GuildId, UserId}, @@ -59,6 +62,19 @@ impl Call { Self::new_raw(guild_id, Some(ws), user_id) } + #[cfg(feature = "driver")] + /// Creates a new Call, configuring the driver as specified. + #[inline] + #[instrument] + pub fn from_driver_config( + guild_id: GuildId, + ws: Shard, + user_id: UserId, + config: Config, + ) -> Self { + Self::new_raw_cfg(guild_id, Some(ws), user_id, config) + } + /// Creates a new, standalone Call which is not connected via /// WebSocket to the Gateway. /// @@ -73,6 +89,18 @@ impl Call { Self::new_raw(guild_id, None, user_id) } + #[cfg(feature = "driver")] + /// Creates a new standalone Call, configuring the driver as specified. + #[inline] + #[instrument] + pub fn standalone_from_driver_config( + guild_id: GuildId, + user_id: UserId, + config: Config, + ) -> Self { + Self::new_raw_cfg(guild_id, None, user_id, config) + } + fn new_raw(guild_id: GuildId, ws: Option, user_id: UserId) -> Self { Call { connection: None, @@ -86,6 +114,19 @@ impl Call { } } + #[cfg(feature = "driver")] + fn new_raw_cfg(guild_id: GuildId, ws: Option, user_id: UserId, config: Config) -> Self { + Call { + connection: None, + driver: Driver::new(config), + guild_id, + self_deaf: false, + self_mute: false, + user_id, + ws, + } + } + #[instrument(skip(self))] fn do_connect(&mut self) { match &self.connection { diff --git a/src/manager.rs b/src/manager.rs index 754397587..8b1dbd255 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -1,5 +1,5 @@ #[cfg(feature = "driver")] -use crate::error::ConnectionResult; +use crate::{driver::Config, error::ConnectionResult}; use crate::{ error::{JoinError, JoinResult}, id::{ChannelId, GuildId, UserId}, @@ -47,6 +47,9 @@ pub struct Songbird { client_data: PRwLock, calls: PRwLock>>>, sharder: Sharder, + + #[cfg(feature = "driver")] + driver_config: PRwLock>, } impl Songbird { 
@@ -61,6 +64,9 @@ impl Songbird { client_data: Default::default(), calls: Default::default(), sharder: Sharder::Serenity(Default::default()), + + #[cfg(feature = "driver")] + driver_config: Default::default(), }) } @@ -84,6 +90,9 @@ impl Songbird { }), calls: Default::default(), sharder: Sharder::Twilight(cluster), + + #[cfg(feature = "driver")] + driver_config: Default::default(), }) } @@ -133,7 +142,18 @@ impl Songbird { .get_shard(shard) .expect("Failed to get shard handle: shard_count incorrect?"); - Arc::new(Mutex::new(Call::new(guild_id, shard_handle, info.user_id))) + #[cfg(feature = "driver")] + let call = Call::from_driver_config( + guild_id, + shard_handle, + info.user_id, + self.driver_config.read().clone().unwrap_or_default(), + ); + + #[cfg(not(feature = "driver"))] + let call = Call::new(guild_id, shard_handle, info.user_id); + + Arc::new(Mutex::new(call)) }) .clone() }) @@ -347,6 +367,20 @@ impl VoiceGatewayManager for Songbird { } } +#[cfg(feature = "driver")] +impl Songbird { + /// Sets a shared configuration for all drivers created from this + /// manager. + /// + /// Changes made here will apply to new Call and Driver instances only. + /// + /// Requires the `"driver"` feature. + pub fn set_config(&self, new_config: Config) { + let mut config = self.driver_config.write(); + *config = Some(new_config); + } +} + #[inline] fn shard_id(guild_id: u64, shard_count: u64) -> u64 { (guild_id >> 22) % shard_count