Skip to content

Commit

Permalink
Handle async server errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Jarema committed Apr 28, 2022
1 parent c55d5e5 commit 474e3ae
Showing 1 changed file with 111 additions and 5 deletions.
116 changes: 111 additions & 5 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
//! # Ok(())
//! # }

use futures::TryFutureExt;
use futures_util::future::FutureExt;
use futures_util::select;
use futures_util::stream::Stream;
Expand Down Expand Up @@ -207,6 +208,7 @@ pub(crate) enum ServerOp {
Info(Box<ServerInfo>),
Ping,
Pong,
Error(String),
Message {
sid: u64,
subject: String,
Expand Down Expand Up @@ -299,6 +301,17 @@ impl Connection {
return Ok(Some(ServerOp::Pong));
}

if self.buffer.starts_with(b"-ERR") {
if let Some(len) = self.buffer.find(b"\r\n") {
let line = std::str::from_utf8(&self.buffer[5..len])
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
let error_message = line.trim_matches('\'').to_string();
self.buffer.advance(len + 2);

return Ok(Some(ServerOp::Error(error_message)));
}
}

if self.buffer.starts_with(b"INFO ") {
if let Some(len) = self.buffer.find(b"\r\n") {
let line = std::str::from_utf8(&self.buffer[5..len])
Expand Down Expand Up @@ -633,6 +646,7 @@ impl ConnectionHandler {
pub async fn process(
&mut self,
mut receiver: mpsc::Receiver<Command>,
errors: mpsc::Sender<String>,
) -> Result<(), io::Error> {
loop {
select! {
Expand All @@ -649,7 +663,7 @@ impl ConnectionHandler {

maybe_op_result = self.connection.read_op().fuse() => {
match maybe_op_result {
Ok(Some(server_op)) => if let Err(err) = self.handle_server_op(server_op).await {
Ok(Some(server_op)) => if let Err(err) = self.handle_server_op(server_op, &errors).await {
println!("error handling operation {}", err);
}
Ok(None) => {
Expand All @@ -668,11 +682,26 @@ impl ConnectionHandler {
Ok(())
}

async fn handle_server_op(&mut self, server_op: ServerOp) -> Result<(), io::Error> {
async fn handle_server_op(
&mut self,
server_op: ServerOp,
errors: &mpsc::Sender<String>,
) -> Result<(), io::Error> {
match server_op {
ServerOp::Ping => {
self.connection.write_op(ClientOp::Pong).await?;
}
ServerOp::Error(error) => {
errors
.send(error)
.map_err(|err| {
io::Error::new(
ErrorKind::Other,
"failed to send error message to the errors stream",
)
})
.await?;
}
ServerOp::Message {
sid,
subject,
Expand Down Expand Up @@ -847,23 +876,69 @@ impl ConnectionHandler {
/// Client is a `Clonable` handle to NATS connection.
/// Client should not be created directly. Instead, one of two methods can be used:
/// [connect] and [ConnectOptions::connect]
#[derive(Clone)]
pub struct Client {
sender: mpsc::Sender<Command>,
errors: Option<mpsc::Receiver<String>>,
subscription_context: Arc<Mutex<SubscriptionContext>>,
}

impl Clone for Client {
fn clone(&self) -> Self {
Client {
sender: self.sender.clone(),
errors: None,
subscription_context: self.subscription_context.clone(),
}
}
}

impl Client {
pub(crate) fn new(
sender: mpsc::Sender<Command>,
errors: Option<mpsc::Receiver<String>>,
subscription_context: Arc<Mutex<SubscriptionContext>>,
) -> Client {
Client {
sender,
errors,
subscription_context,
}
}

/// Returns stream of asynchronous errors received from NATS server.
///
/// # Examples
/// ```
/// # use futures_util::StreamExt;
/// # #[tokio::main]
/// # async fn main() -> std::io::Result<()> {
/// let mut nc = async_nats::connect("demo.nats.io").await?;
///
/// let mut errs = nc.errors_stream().await?;
/// tokio::spawn({
/// async move {
/// if let Some(err) = errs.next().await {
/// println!("received error: {}", err);
/// };
/// }
/// });
/// # Ok(())
/// # }
///
/// ```
pub async fn errors_stream(&mut self) -> io::Result<Errors> {
let errors = self.errors.take();
errors.map_or_else(
|| {
Err(io::Error::new(
ErrorKind::AlreadyExists,
"errors stream already consumerd or used on cloned Client",
))
},
|errors| Ok(Errors::new(errors)),
)
}

pub async fn publish(&mut self, subject: String, payload: Bytes) -> Result<(), Error> {
self.sender
.send(Command::Publish {
Expand Down Expand Up @@ -1004,7 +1079,8 @@ pub async fn connect_with_options<A: ToServerAddrs>(

// TODO make channel size configurable
let (sender, receiver) = mpsc::channel(128);
let client = Client::new(sender.clone(), subscription_context);
let (errors_tx, errors_rx) = mpsc::channel(128);
let client = Client::new(sender.clone(), Some(errors_rx), subscription_context);
let connect_info = ConnectInfo {
tls_required: options.tls_required,
// FIXME(tp): have optional name
Expand Down Expand Up @@ -1058,7 +1134,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
}
});

task::spawn(async move { connection_handler.process(receiver).await });
task::spawn(async move { connection_handler.process(receiver, errors_tx).await });

Ok(client)
}
Expand Down Expand Up @@ -1207,6 +1283,24 @@ impl Stream for Subscriber {
}
}

#[derive(Debug)]
pub struct Errors {
receiver: tokio::sync::mpsc::Receiver<String>,
}

impl Errors {
fn new(receiver: tokio::sync::mpsc::Receiver<String>) -> Errors {
Errors { receiver }
}
}

impl Stream for Errors {
type Item = String;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.receiver.poll_recv(cx)
}
}

/// Info to construct a CONNECT message.
#[derive(Clone, Debug, Serialize)]
#[doc(hidden)]
Expand Down Expand Up @@ -1417,3 +1511,15 @@ impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
(**self).to_server_addrs()
}
}

#[derive(Clone)]
pub(crate) enum Authorization {
/// No authentication.
None,

/// Authenticate using a token.
Token(String),

/// Authenticate using a username and password.
UsernamePassword(String, String),
}

0 comments on commit 474e3ae

Please sign in to comment.