Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle async server errors #397

Merged
merged 1 commit into from
May 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 63 additions & 14 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ use futures_util::future::FutureExt;
use futures_util::select;
use futures_util::stream::Stream;
use futures_util::StreamExt;
use tls::TlsOptions;

use std::cmp;
use std::collections::HashMap;
Expand Down Expand Up @@ -645,7 +646,7 @@ impl Connection {
/// Maintains a list of servers and establishes connections.
pub(crate) struct Connector {
server_addrs: Vec<ServerAddr>,
options: ConnectOptions,
options: TlsOptions,
}

impl Connector {
Expand Down Expand Up @@ -755,14 +756,20 @@ pub(crate) struct ConnectionHandler {
connection: Connection,
connector: Connector,
subscriptions: HashMap<u64, Subscription>,
events: mpsc::Sender<ServerEvent>,
}

impl ConnectionHandler {
pub(crate) fn new(connection: Connection, connector: Connector) -> ConnectionHandler {
pub(crate) fn new(
connection: Connection,
connector: Connector,
events: mpsc::Sender<ServerEvent>,
) -> ConnectionHandler {
ConnectionHandler {
connection,
connector,
subscriptions: HashMap::new(),
events,
}
}

Expand All @@ -789,12 +796,13 @@ impl ConnectionHandler {
println!("error handling operation {}", err);
}
Ok(None) => {
if let Err(err) = self.handle_reconnect().await {
if let Err(err) = self.handle_disconnect().await {
println!("error handling operation {}", err);
} else {
}
}
Err(err) => {
if let Err(err) = self.handle_reconnect().await {
if let Err(err) = self.handle_disconnect().await {
println!("error handling operation {}", err);
}
},
Expand All @@ -813,6 +821,9 @@ impl ConnectionHandler {
ServerOp::Ping => {
self.connection.write_op(ClientOp::Pong).await?;
}
ServerOp::Error(error) => {
self.events.try_send(ServerEvent::Error(error)).ok();
}
ServerOp::Message {
sid,
subject,
Expand Down Expand Up @@ -887,12 +898,12 @@ impl ConnectionHandler {
}
Command::Ping => {
if let Err(err) = self.connection.write_op(ClientOp::Ping).await {
self.handle_reconnect().await?;
self.handle_disconnect().await?;
}
}
Command::Flush { result } => {
if let Err(err) = self.connection.flush().await {
if let Err(err) = self.handle_reconnect().await {
if let Err(err) = self.handle_disconnect().await {
result.send(Err(err)).map_err(|_| {
io::Error::new(io::ErrorKind::Other, "one shot failed to be received")
})?;
Expand Down Expand Up @@ -956,7 +967,7 @@ impl ConnectionHandler {
})
.await
{
self.handle_reconnect().await?;
self.handle_disconnect().await?;
println!("Sending Publish failed with {:?}", err);
}
}
Expand All @@ -966,14 +977,21 @@ impl ConnectionHandler {
.write_op(ClientOp::Connect(connect_info.clone()))
.await
{
self.handle_reconnect().await?;
self.handle_disconnect().await?;
}
}
}

Ok(())
}

async fn handle_disconnect(&mut self) -> io::Result<()> {
self.events.try_send(ServerEvent::Disconnect).ok();
self.handle_reconnect().await?;
self.events.try_send(ServerEvent::Reconnect).ok();
Ok(())
}

async fn handle_reconnect(&mut self) -> Result<(), io::Error> {
let (_, connection) = self.connector.connect().await?;
self.connection = connection;
Expand All @@ -988,7 +1006,7 @@ impl ConnectionHandler {
.await
.unwrap();
}

self.events.try_send(ServerEvent::Reconnect).ok();
Ok(())
}
}
Expand Down Expand Up @@ -1187,19 +1205,34 @@ pub async fn connect_with_options<A: ToServerAddrs>(
addrs: A,
options: ConnectOptions,
) -> Result<Client, io::Error> {
let tls_required = options.tls_required;
let ping_interval = options.ping_interval;
let flush_interval = options.flush_interval;

let tls_options = TlsOptions {
tls_required: options.tls_required,
certificates: options.certificates,
client_key: options.client_key,
client_cert: options.client_cert,
tls_client_config: options.tls_client_config,
};

let mut connector = Connector {
server_addrs: addrs.to_server_addrs()?.into_iter().collect(),
options: options.clone(),
options: tls_options,
};

let (_, connection) = connector.try_connect().await?;
let mut connection_handler = ConnectionHandler::new(connection, connector);
let (events_tx, mut events_rx) = mpsc::channel(128);

let mut connection_handler = ConnectionHandler::new(connection, connector, events_tx);

// TODO make channel size configurable
let (sender, receiver) = mpsc::channel(128);

let client = Client::new(sender.clone());
let mut connect_info = ConnectInfo {
tls_required: options.tls_required,
tls_required,
// FIXME(tp): have optional name
name: Some("beta-rust-client".to_string()),
pedantic: false,
Expand Down Expand Up @@ -1250,7 +1283,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
let sender = sender.clone();
async move {
loop {
tokio::time::sleep(options.ping_interval).await;
tokio::time::sleep(ping_interval).await;
match sender.send(Command::Ping).await {
Ok(()) => {}
Err(_) => return,
Expand All @@ -1261,19 +1294,35 @@ pub async fn connect_with_options<A: ToServerAddrs>(

tokio::spawn(async move {
loop {
tokio::time::sleep(options.flush_interval).await;
tokio::time::sleep(flush_interval).await;
match sender.send(Command::TryFlush).await {
Ok(()) => {}
Err(_) => return,
}
}
});

task::spawn(async move {
while let Some(event) = events_rx.recv().await {
match event {
ServerEvent::Reconnect => options.reconnect_callback.call().await,
ServerEvent::Disconnect => options.disconnect_callback.call().await,
ServerEvent::Error(error) => options.error_callback.call(error).await,
}
}
});

task::spawn(async move { connection_handler.process(receiver).await });

Ok(client)
}

pub(crate) enum ServerEvent {
Reconnect,
Disconnect,
Error(ServerError),
}

/// Connects to NATS with default config.
///
/// Returns clonable [Client].
Expand Down
Loading