Skip to content

Commit

Permalink
Driver: Migrate to tokio_tungstenite (#138)
Browse files Browse the repository at this point in the history
This places songbird, serenity, and twilight onto the same WS library, hopefully reducing the compile overhead for everyone.

Tested using `cargo make ready` and by running `examples/voice`.

Closes #129.
  • Loading branch information
FelixMcFelix committed Nov 19, 2023
1 parent 13946b4 commit 76c9851
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 187 deletions.
26 changes: 12 additions & 14 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,10 @@ serde_json = "1"
tracing = { version = "0.1", features = ["log"] }
tracing-futures = "0.2"

[dependencies.once_cell]
version = "1"
optional = true

[dependencies.async-trait]
optional = true
version = "0.1"

[dependencies.async-tungstenite]
default-features = false
features = ["tokio-runtime"]
optional = true
version = "0.17"

[dependencies.audiopus]
optional = true
version = "0.3.0-rc.0"
Expand All @@ -59,6 +49,10 @@ version = "0.10"
[dependencies.futures]
version = "0.3"

[dependencies.once_cell]
version = "1"
optional = true

[dependencies.parking_lot]
optional = true
version = "0.12"
Expand Down Expand Up @@ -127,6 +121,10 @@ optional = true
version = "1.0"
default-features = false

[dependencies.tokio-tungstenite]
optional = true
version = "0.17"

[dependencies.tokio-util]
optional = true
version = "0.7"
Expand Down Expand Up @@ -184,7 +182,6 @@ gateway = [
]
driver = [
"async-trait",
"async-tungstenite",
"audiopus",
"byteorder",
"discortp",
Expand All @@ -201,7 +198,6 @@ driver = [
"symphonia",
"symphonia-core",
"rusty_pool",
"tokio-util",
"tokio/fs",
"tokio/io-util",
"tokio/macros",
Expand All @@ -210,13 +206,15 @@ driver = [
"tokio/rt",
"tokio/sync",
"tokio/time",
"tokio-tungstenite",
"tokio-util",
"typemap_rev",
"url",
"uuid",
"xsalsa20poly1305",
]
rustls = ["async-tungstenite/tokio-rustls-webpki-roots", "reqwest/rustls-tls", "rustls-marker"]
native = ["async-tungstenite/tokio-native-tls", "native-marker", "reqwest/native-tls"]
rustls = ["tokio-tungstenite/rustls-tls-webpki-roots", "reqwest/rustls-tls", "rustls-marker"]
native = ["tokio-tungstenite/native-tls", "native-marker", "reqwest/native-tls"]
serenity-rustls = ["serenity/rustls_backend", "rustls", "gateway", "serenity-deps"]
serenity-native = ["serenity/native_tls_backend", "native", "gateway", "serenity-deps"]
twilight-rustls = ["twilight", "twilight-gateway/rustls-native-roots", "rustls", "gateway"]
Expand Down
21 changes: 3 additions & 18 deletions src/driver/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
Event as GatewayEvent,
ProtocolData,
},
ws::{self, ReceiverExt, SenderExt, WsStream},
ws::WsStream,
ConnectionInfo,
};
use discortp::discord::{IpDiscoveryPacket, IpDiscoveryType, MutableIpDiscoveryPacket};
Expand All @@ -24,12 +24,6 @@ use tracing::{debug, info, instrument};
use url::Url;
use xsalsa20poly1305::{aead::NewAead, XSalsa20Poly1305 as Cipher};

#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
use ws::create_rustls_client;

#[cfg(feature = "native-marker")]
use ws::create_native_tls_client;

pub(crate) struct Connection {
pub(crate) info: ConnectionInfo,
pub(crate) ssrc: u32,
Expand Down Expand Up @@ -58,11 +52,7 @@ impl Connection {
) -> Result<Connection> {
let url = generate_url(&mut info.endpoint)?;

#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
let mut client = create_rustls_client(url).await?;

#[cfg(feature = "native-marker")]
let mut client = create_native_tls_client(url).await?;
let mut client = WsStream::connect(url).await?;

let mut hello = None;
let mut ready = None;
Expand Down Expand Up @@ -241,12 +231,7 @@ impl Connection {

// Thread may have died, we want to send to prompt a clean exit
// (if at all possible) and then proceed as normal.

#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
let mut client = create_rustls_client(url).await?;

#[cfg(feature = "native-marker")]
let mut client = create_native_tls_client(url).await?;
let mut client = WsStream::connect(url).await?;

client
.send_json(&GatewayEvent::from(Resume {
Expand Down
4 changes: 2 additions & 2 deletions src/driver/tasks/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ use crate::{
FromPrimitive,
SpeakingState,
},
ws::{Error as WsError, ReceiverExt, SenderExt, WsStream},
ws::{Error as WsError, WsStream},
ConnectionInfo,
};
use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
use flume::Receiver;
use rand::random;
use std::time::Duration;
use tokio::{
select,
time::{sleep_until, Instant},
};
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
use tracing::{debug, info, instrument, trace, warn};

struct AuxNetwork {
Expand Down
2 changes: 1 addition & 1 deletion src/events/context/data/disconnect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
model::{CloseCode as VoiceCloseCode, FromPrimitive},
ws::Error as WsError,
};
use async_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;

/// Voice connection details gathered at termination or failure.
///
Expand Down
208 changes: 56 additions & 152 deletions src/ws.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,71 @@
use crate::{error::JsonError, model::Event};

use async_trait::async_trait;
use async_tungstenite::{
self as tungstenite,
tokio::ConnectStream,
tungstenite::{error::Error as TungsteniteError, protocol::CloseFrame, Message},
use futures::{SinkExt, StreamExt, TryStreamExt};
use tokio::{
net::TcpStream,
time::{timeout, Duration},
};
use tokio_tungstenite::{
tungstenite::{
error::Error as TungsteniteError,
protocol::{CloseFrame, WebSocketConfig as Config},
Message,
},
MaybeTlsStream,
WebSocketStream,
};
use futures::{SinkExt, StreamExt, TryStreamExt};
use tokio::time::{timeout, Duration};
use tracing::instrument;
use url::Url;

pub struct WsStream(WebSocketStream<MaybeTlsStream<TcpStream>>);

impl WsStream {
#[instrument]
pub(crate) async fn connect(url: Url) -> Result<Self> {
let (stream, _) = tokio_tungstenite::connect_async_with_config::<Url>(
url,
Some(Config {
max_message_size: None,
max_frame_size: None,
max_send_queue: None,
..Default::default()
}),
)
.await?;

Ok(Self(stream))
}

pub(crate) async fn recv_json(&mut self) -> Result<Option<Event>> {
const TIMEOUT: Duration = Duration::from_millis(500);

let ws_message = match timeout(TIMEOUT, self.0.next()).await {
Ok(Some(Ok(v))) => Some(v),
Ok(Some(Err(e))) => return Err(e.into()),
Ok(None) | Err(_) => None,
};

convert_ws_message(ws_message)
}

pub type WsStream = WebSocketStream<ConnectStream>;
pub(crate) async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>> {
convert_ws_message(self.0.try_next().await?)
}

pub(crate) async fn send_json(&mut self, value: &Event) -> Result<()> {
Ok(crate::json::to_string(value)
.map(Message::Text)
.map_err(Error::from)
.map(|m| self.0.send(m))?
.await?)
}
}

pub type Result<T> = std::result::Result<T, Error>;

#[derive(Debug)]
pub enum Error {
Json(JsonError),
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
Tls(RustlsError),

/// The discord voice gateway does not support or offer zlib compression.
/// As a result, only text messages are expected.
Expand All @@ -36,80 +82,12 @@ impl From<JsonError> for Error {
}
}

#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
impl From<RustlsError> for Error {
fn from(e: RustlsError) -> Error {
Error::Tls(e)
}
}

impl From<TungsteniteError> for Error {
fn from(e: TungsteniteError) -> Error {
Error::Ws(e)
}
}

use futures::stream::SplitSink;
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
use std::{
error::Error as StdError,
fmt::{Display, Formatter, Result as FmtResult},
io::Error as IoError,
};
use url::Url;

#[async_trait]
pub trait ReceiverExt {
async fn recv_json(&mut self) -> Result<Option<Event>>;
async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>>;
}

#[async_trait]
pub trait SenderExt {
async fn send_json(&mut self, value: &Event) -> Result<()>;
}

#[async_trait]
impl ReceiverExt for WsStream {
async fn recv_json(&mut self) -> Result<Option<Event>> {
const TIMEOUT: Duration = Duration::from_millis(500);

let ws_message = match timeout(TIMEOUT, self.next()).await {
Ok(Some(Ok(v))) => Some(v),
Ok(Some(Err(e))) => return Err(e.into()),
Ok(None) | Err(_) => None,
};

convert_ws_message(ws_message)
}

async fn recv_json_no_timeout(&mut self) -> Result<Option<Event>> {
convert_ws_message(self.try_next().await?)
}
}

#[async_trait]
impl SenderExt for SplitSink<WsStream, Message> {
async fn send_json(&mut self, value: &Event) -> Result<()> {
Ok(crate::json::to_string(value)
.map(Message::Text)
.map_err(Error::from)
.map(|m| self.send(m))?
.await?)
}
}

#[async_trait]
impl SenderExt for WsStream {
async fn send_json(&mut self, value: &Event) -> Result<()> {
Ok(crate::json::to_string(value)
.map(Message::Text)
.map_err(Error::from)
.map(|m| self.send(m))?
.await?)
}
}

#[inline]
pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Event>> {
Ok(match message {
Expand All @@ -125,77 +103,3 @@ pub(crate) fn convert_ws_message(message: Option<Message>) -> Result<Option<Even
_ => None,
})
}

/// An error that occured while connecting over rustls
#[derive(Debug)]
#[non_exhaustive]
#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
pub enum RustlsError {
/// An error with the handshake in tungstenite
HandshakeError,
/// Standard IO error happening while creating the tcp stream
Io(IoError),
}

#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
impl From<IoError> for RustlsError {
fn from(e: IoError) -> Self {
RustlsError::Io(e)
}
}

#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
impl Display for RustlsError {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self {
RustlsError::HandshakeError =>
f.write_str("TLS handshake failed when making the websocket connection"),
RustlsError::Io(inner) => Display::fmt(&inner, f),
}
}
}

#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
impl StdError for RustlsError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match self {
RustlsError::Io(inner) => Some(inner),
_ => None,
}
}
}

#[cfg(all(feature = "rustls-marker", not(feature = "native-marker")))]
#[instrument]
pub(crate) async fn create_rustls_client(url: Url) -> Result<WsStream> {
let (stream, _) = tungstenite::tokio::connect_async_with_config::<Url>(
url,
Some(tungstenite::tungstenite::protocol::WebSocketConfig {
max_message_size: None,
max_frame_size: None,
max_send_queue: None,
..Default::default()
}),
)
.await
.map_err(|_| RustlsError::HandshakeError)?;

Ok(stream)
}

#[cfg(feature = "native-marker")]
#[instrument]
pub(crate) async fn create_native_tls_client(url: Url) -> Result<WsStream> {
let (stream, _) = tungstenite::tokio::connect_async_with_config::<Url>(
url,
Some(tungstenite::tungstenite::protocol::WebSocketConfig {
max_message_size: None,
max_frame_size: None,
max_send_queue: None,
..Default::default()
}),
)
.await?;

Ok(stream)
}

0 comments on commit 76c9851

Please sign in to comment.