Skip to content

Commit

Permalink
Add Websockets support
Browse files Browse the repository at this point in the history
Signed-off-by: Tomasz Pietrek <[email protected]>
  • Loading branch information
Jarema authored Oct 30, 2024
1 parent 1fbcfdd commit e40abcf
Show file tree
Hide file tree
Showing 8 changed files with 306 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ jobs:

- name: Install msrv Rust on ubuntu-latest
id: install-rust
uses: dtolnay/rust-toolchain@1.70.0
uses: dtolnay/rust-toolchain@1.79.0
- name: Cache the build artifacts
uses: Swatinem/rust-cache@v2
with:
Expand Down
10 changes: 6 additions & 4 deletions async-nats/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "async-nats"
authors = ["Tomasz Pietrek <[email protected]>", "Casper Beyer <[email protected]>"]
version = "0.37.0"
edition = "2021"
rust = "1.74.0"
rust = "1.79.0"
description = "A async Rust NATS client"
license = "Apache-2.0"
documentation = "https://docs.rs/async-nats"
Expand Down Expand Up @@ -41,6 +41,8 @@ ring = { version = "0.17", optional = true }
rand = "0.8"
webpki = { package = "rustls-webpki", version = "0.102" }
portable-atomic = "1"
tokio-websockets = { version = "0.10", features = ["client", "rand", "rustls-native-roots"], optional = true }
pin-project = "1.0"

[dev-dependencies]
ring = "0.17"
Expand All @@ -57,13 +59,13 @@ jsonschema = "0.17.1"
# for -Z minimal-versions
num = "0.4.1"


[features]
default = ["server_2_10", "ring"]
# Enables Service API for the client.
service = []
aws-lc-rs = ["dep:aws-lc-rs", "tokio-rustls/aws-lc-rs"]
ring = ["dep:ring", "tokio-rustls/ring"]
websockets = ["dep:tokio-websockets"]
aws-lc-rs = ["dep:aws-lc-rs", "tokio-rustls/aws-lc-rs", "tokio-websockets/aws-lc-rs"]
ring = ["dep:ring", "tokio-rustls/ring", "tokio-websockets/ring"]
fips = ["aws-lc-rs", "tokio-rustls/fips"]
# All experimental features are part of this feature flag.
experimental = ["service"]
Expand Down
110 changes: 110 additions & 0 deletions async-nats/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::task::{Context, Poll};

#[cfg(feature = "websockets")]
use {
futures::{SinkExt, StreamExt},
pin_project::pin_project,
tokio::io::ReadBuf,
tokio_websockets::WebSocketStream,
};

use bytes::{Buf, Bytes, BytesMut};
use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite};

Expand Down Expand Up @@ -683,6 +691,108 @@ impl Connection {
}
}

#[cfg(feature = "websockets")]
#[pin_project]
pub(crate) struct WebSocketAdapter<T> {
#[pin]
pub(crate) inner: WebSocketStream<T>,
pub(crate) read_buf: BytesMut,
}

#[cfg(feature = "websockets")]
impl<T> WebSocketAdapter<T> {
pub(crate) fn new(inner: WebSocketStream<T>) -> Self {
Self {
inner,
read_buf: BytesMut::new(),
}
}
}

#[cfg(feature = "websockets")]
impl<T> AsyncRead for WebSocketAdapter<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let mut this = self.project();

loop {
// If we have data in the read buffer, let's move it to the output buffer.
if !this.read_buf.is_empty() {
let len = std::cmp::min(buf.remaining(), this.read_buf.len());
buf.put_slice(&this.read_buf.split_to(len));
return Poll::Ready(Ok(()));
}

match this.inner.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(message))) => {
this.read_buf.extend_from_slice(message.as_payload());
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)));
}
Poll::Ready(None) => {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"WebSocket closed",
)));
}
Poll::Pending => {
return Poll::Pending;
}
}
}
}
}

#[cfg(feature = "websockets")]
impl<T> AsyncWrite for WebSocketAdapter<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let mut this = self.project();

let data = buf.to_vec();
match this.inner.poll_ready_unpin(cx) {
Poll::Ready(Ok(())) => match this
.inner
.start_send_unpin(tokio_websockets::Message::binary(data))
{
Ok(()) => Poll::Ready(Ok(buf.len())),
Err(e) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))),
},
Poll::Ready(Err(e)) => {
Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)))
}
Poll::Pending => Poll::Pending,
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project()
.inner
.poll_flush_unpin(cx)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.project()
.inner
.poll_close_unpin(cx)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
}
}

#[cfg(test)]
mod read_op {
use std::sync::Arc;
Expand Down
95 changes: 78 additions & 17 deletions async-nats/src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use crate::auth::Auth;
use crate::client::Statistics;
use crate::connection::Connection;
use crate::connection::State;
#[cfg(feature = "websockets")]
use crate::connection::WebSocketAdapter;
use crate::options::CallbackArg1;
use crate::tls;
use crate::AuthError;
Expand Down Expand Up @@ -168,7 +170,11 @@ impl Connector {
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Dns, err))?;
for socket_addr in socket_addrs {
match self
.try_connect_to(&socket_addr, server_addr.tls_required(), server_addr.host())
.try_connect_to(
&socket_addr,
server_addr.tls_required(),
server_addr.clone(),
)
.await
{
Ok((server_info, mut connection)) => {
Expand Down Expand Up @@ -321,22 +327,76 @@ impl Connector {
&self,
socket_addr: &SocketAddr,
tls_required: bool,
tls_host: &str,
server_addr: ServerAddr,
) -> Result<(ServerInfo, Connection), ConnectError> {
let tcp_stream = tokio::time::timeout(
self.options.connection_timeout,
TcpStream::connect(socket_addr),
)
.await
.map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))??;

tcp_stream.set_nodelay(true)?;
let mut connection = match server_addr.scheme() {
#[cfg(feature = "websockets")]
"ws" => {
let ws = tokio::time::timeout(
self.options.connection_timeout,
tokio_websockets::client::Builder::new()
.uri(format!("{}://{}", server_addr.scheme(), socket_addr).as_str())
.map_err(|err| {
ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err)
})?
.connect(),
)
.await
.map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))?
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Io, err))?;

let mut connection = Connection::new(
Box::new(tcp_stream),
self.options.read_buffer_capacity.into(),
self.connect_stats.clone(),
);
let con = WebSocketAdapter::new(ws.0);
Connection::new(Box::new(con), 0, self.connect_stats.clone())
}
#[cfg(feature = "websockets")]
"wss" => {
let domain = webpki::types::ServerName::try_from(server_addr.host())
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?;
let tls_config =
Arc::new(tls::config_tls(&self.options).await.map_err(|err| {
ConnectError::with_source(crate::ConnectErrorKind::Tls, err)
})?);
let tls_connector = tokio_rustls::TlsConnector::from(tls_config);
let ws = tokio::time::timeout(
self.options.connection_timeout,
tokio_websockets::client::Builder::new()
.connector(&tokio_websockets::Connector::Rustls(tls_connector))
.uri(
format!(
"{}://{}:{}",
server_addr.scheme(),
domain.to_str(),
server_addr.port()
)
.as_str(),
)
.map_err(|err| {
ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err)
})?
.connect(),
)
.await
.map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))?
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Io, err))?;
let con = WebSocketAdapter::new(ws.0);
Connection::new(Box::new(con), 0, self.connect_stats.clone())
}
_ => {
let tcp_stream = tokio::time::timeout(
self.options.connection_timeout,
TcpStream::connect(socket_addr),
)
.await
.map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))??;
tcp_stream.set_nodelay(true)?;

Connection::new(
Box::new(tcp_stream),
self.options.read_buffer_capacity.into(),
self.connect_stats.clone(),
)
}
};

let tls_connection = |connection: Connection| async {
let tls_config = Arc::new(
Expand All @@ -346,7 +406,7 @@ impl Connector {
);
let tls_connector = tokio_rustls::TlsConnector::from(tls_config);

let domain = webpki::types::ServerName::try_from(tls_host)
let domain = webpki::types::ServerName::try_from(server_addr.host())
.map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?;

let tls_stream = tls_connector
Expand All @@ -363,7 +423,7 @@ impl Connector {
// If `tls_first` was set, establish TLS connection before getting INFO.
// There is no point in checking if tls is required, because
// the connection has to be be upgraded to TLS anyway as it's different flow.
if self.options.tls_first {
if self.options.tls_first && !server_addr.is_websocket() {
connection = tls_connection(connection).await?;
}

Expand All @@ -386,6 +446,7 @@ impl Connector {

// If `tls_first` was not set, establish TLS connection if it is required.
if !self.options.tls_first
&& !server_addr.is_websocket()
&& (self.options.tls_required || info.tls_required || tls_required)
{
connection = tls_connection(connection).await?;
Expand Down
14 changes: 13 additions & 1 deletion async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1455,7 +1455,11 @@ impl FromStr for ServerAddr {
impl ServerAddr {
/// Check if the URL is a valid NATS server address.
pub fn from_url(url: Url) -> io::Result<Self> {
if url.scheme() != "nats" && url.scheme() != "tls" {
if url.scheme() != "nats"
&& url.scheme() != "tls"
&& url.scheme() != "ws"
&& url.scheme() != "wss"
{
return Err(std::io::Error::new(
ErrorKind::InvalidInput,
format!("invalid scheme for NATS server URL: {}", url.scheme()),
Expand All @@ -1480,6 +1484,10 @@ impl ServerAddr {
self.0.username() != ""
}

pub fn scheme(&self) -> &str {
self.0.scheme()
}

/// Returns the host.
pub fn host(&self) -> &str {
match self.0.host() {
Expand All @@ -1493,6 +1501,10 @@ impl ServerAddr {
}
}

pub fn is_websocket(&self) -> bool {
self.0.scheme() == "ws" || self.0.scheme() == "wss"
}

/// Returns the port.
pub fn port(&self) -> u16 {
self.0.port().unwrap_or(4222)
Expand Down
5 changes: 5 additions & 0 deletions async-nats/tests/configs/ws.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
jetstream {}
websocket {
port: 8444
no_tls: true
}
15 changes: 15 additions & 0 deletions async-nats/tests/configs/ws_tls.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
authorization {
user: derek
password: porkchop
timeout: 1
}

websocket {
tls {

cert_file: "./tests/configs/certs/server-cert.pem"
key_file: "./tests/configs/certs/server-key.pem"
ca_file: "./tests/configs/certs/rootCA.pem"
}
port: 8445
}
Loading

0 comments on commit e40abcf

Please sign in to comment.