Skip to content

Commit

Permalink
Merge pull request #276 from TheNeikos/feature/add_keep_alive
Browse files Browse the repository at this point in the history
Add keep alive
  • Loading branch information
TheNeikos authored Apr 4, 2024
2 parents c0fc750 + edd71b6 commit c04592a
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 7 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ debug = ["winnow/debug"]

[dependencies]
futures = "0.3.30"
futures-timer = "3.0.3"
mqtt-format = { version = "0.5.0", path = "mqtt-format", features = [
"yoke",
"mqttv5",
Expand Down
19 changes: 18 additions & 1 deletion cloudmqtt-bin/src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//

use std::time::Duration;

use clap::Parser;
use cloudmqtt::client::connect::MqttClientConnector;
use cloudmqtt::client::send::Publish;
Expand Down Expand Up @@ -47,7 +49,7 @@ async fn main() {
connection,
client_id,
cloudmqtt::client::connect::CleanStart::Yes,
cloudmqtt::keep_alive::KeepAlive::Disabled,
cloudmqtt::keep_alive::KeepAlive::Seconds(5.try_into().unwrap()),
);

let client = MqttClient::new_with_default_handlers();
Expand All @@ -69,5 +71,20 @@ async fn main() {

client.ping().await.unwrap().response().await;

tokio::time::sleep(Duration::from_secs(3)).await;

client
.publish(Publish {
topic: "foo/bar".try_into().unwrap(),
qos: cloudmqtt::qos::QualityOfService::AtMostOnce,
retain: false,
payload: vec![123].try_into().unwrap(),
on_packet_recv: None,
})
.await
.unwrap();

tokio::time::sleep(Duration::from_secs(20)).await;

println!("Sent message! Bye");
}
82 changes: 77 additions & 5 deletions src/client/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//

use std::time::Duration;

use futures::select;
use futures::FutureExt;
use futures::SinkExt;
use futures::StreamExt;
Expand All @@ -13,6 +16,7 @@ use tokio_util::codec::FramedWrite;
use super::MqttClient;
use crate::bytes::MqttBytes;
use crate::client::state::OutstandingPackets;
use crate::client::state::TransportWriter;
use crate::client::ConnectState;
use crate::client::SessionState;
use crate::client_identifier::ProposedClientIdentifier;
Expand Down Expand Up @@ -213,6 +217,9 @@ impl MqttClient {
});
}

let (sender, heartbeat_receiver) = futures::channel::mpsc::channel(1);
let conn_write = TransportWriter::new(conn_write, sender);

let (conn_read_sender, conn_read_recv) = futures::channel::oneshot::channel();

let connect_client_state = ConnectState {
Expand All @@ -222,6 +229,15 @@ impl MqttClient {
retain_available: connack.properties.retain_available().map(|ra| ra.0),
maximum_packet_size: connack.properties.maximum_packet_size().map(|mps| mps.0),
topic_alias_maximum: connack.properties.topic_alias_maximum().map(|tam| tam.0),
keep_alive: connack
.properties
.server_keep_alive()
.map(|ska| {
std::num::NonZeroU16::try_from(ska.0)
.map(KeepAlive::Seconds)
.unwrap_or(KeepAlive::Disabled)
})
.unwrap_or(connector.keep_alive),
conn_write,
conn_read_recv,
next_packet_identifier: std::num::NonZeroU16::MIN,
Expand Down Expand Up @@ -257,6 +273,8 @@ impl MqttClient {
};
}

let keep_alive = connect_client_state.keep_alive;

inner.connection_state = Some(connect_client_state);
inner.session_state = Some(SessionState {
client_identifier,
Expand All @@ -267,11 +285,34 @@ impl MqttClient {
crate::packets::connack::ConnackPropertiesView::try_from(maybe_connack)
.expect("An already matched value suddenly changed?");

let background_task = crate::client::receive::handle_background_receiving(
inner_clone,
conn_read,
conn_read_sender,
)
let background_task = async move {
let receiving_inner = inner_clone.clone();
let receiving = crate::client::receive::handle_background_receiving(
receiving_inner,
conn_read,
conn_read_sender,
);

let heartbeat_inner = inner_clone;

let heartbeat = if let KeepAlive::Seconds(time) = keep_alive {
handle_heartbeats(
heartbeat_receiver,
Duration::from_secs(time.get().into()),
heartbeat_inner,
)
.left_future()
} else {
tracing::info!(
"Keep Alive is disabled, will not send PingReq packets automatically"
);
futures::future::ok(()).right_future()
};

tokio::try_join!(receiving, heartbeat)
.map(drop)
.map_err(drop)
}
.boxed();

return Ok(Connected {
Expand All @@ -285,3 +326,34 @@ impl MqttClient {
todo!()
}
}

async fn handle_heartbeats(
mut heartbeat_receiver: futures::channel::mpsc::Receiver<()>,
duration: Duration,
heartbeat_inner: std::sync::Arc<futures::lock::Mutex<super::InnerClient>>,
) -> Result<(), ()> {
let mut timeout = futures_timer::Delay::new(duration).fuse();
loop {
select! {
heartbeat = heartbeat_receiver.next() => match heartbeat {
None => break,
Some(_) => {
timeout = futures_timer::Delay::new(duration).fuse();
},
},
_ = timeout => {
let mut inner = heartbeat_inner.lock().await;
let inner = &mut *inner;
let Some(conn_state) = inner.connection_state.as_mut() else {
todo!();
};

// We make sure that this won't deadlock in the send method
conn_state.conn_write.send(
mqtt_format::v5::packets::MqttPacket::Pingreq(mqtt_format::v5::packets::pingreq::MPingreq)
).await.unwrap();
}
}
}
Ok(())
}
37 changes: 36 additions & 1 deletion src/client/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,63 @@

use std::num::NonZeroU16;

use futures::SinkExt;
use tokio_util::codec::FramedRead;
use tokio_util::codec::FramedWrite;

use crate::codecs::MqttPacketCodec;
use crate::codecs::MqttPacketCodecError;
use crate::keep_alive::KeepAlive;
use crate::packet_identifier::PacketIdentifier;
use crate::string::MqttString;
use crate::transport::MqttConnection;

pub(super) struct TransportWriter {
conn: FramedWrite<tokio::io::WriteHalf<MqttConnection>, MqttPacketCodec>,
notify: futures::channel::mpsc::Sender<()>,
}

impl TransportWriter {
pub(super) fn new(
conn: FramedWrite<tokio::io::WriteHalf<MqttConnection>, MqttPacketCodec>,
notify: futures::channel::mpsc::Sender<()>,
) -> Self {
Self { conn, notify }
}

pub(super) async fn send(
&mut self,
packet: mqtt_format::v5::packets::MqttPacket<'_>,
) -> Result<(), MqttPacketCodecError> {
self.conn.send(packet).await?;
if let Err(e) = self.notify.try_send(()) {
if e.is_full() {
// This is fine, we are already notifying of a send
}
if e.is_disconnected() {
todo!("Could not send to heartbeat!?")
}
}

Ok(())
}
}

pub(super) struct ConnectState {
pub(super) session_present: bool,
pub(super) receive_maximum: Option<NonZeroU16>,
pub(super) maximum_qos: Option<mqtt_format::v5::qos::MaximumQualityOfService>,
pub(super) retain_available: Option<bool>,
pub(super) topic_alias_maximum: Option<u16>,
pub(super) maximum_packet_size: Option<u32>,
pub(super) conn_write: FramedWrite<tokio::io::WriteHalf<MqttConnection>, MqttPacketCodec>,
pub(super) conn_write: TransportWriter,

pub(super) conn_read_recv: futures::channel::oneshot::Receiver<
FramedRead<tokio::io::ReadHalf<MqttConnection>, MqttPacketCodec>,
>,

pub(super) next_packet_identifier: std::num::NonZeroU16,
pub(crate) keep_alive: KeepAlive,
}

pub(super) struct SessionState {
Expand Down
1 change: 1 addition & 0 deletions src/keep_alive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use std::num::NonZeroU16;
use std::time::Duration;

#[derive(Debug, Clone, Copy)]
pub enum KeepAlive {
Disabled,
Seconds(NonZeroU16),
Expand Down

0 comments on commit c04592a

Please sign in to comment.