Skip to content

Commit

Permalink
Support multiple async runtimes but not only tokio
Browse files Browse the repository at this point in the history
Solves #26.

changelog: changed
breaking: running with tokio needs additional setups now
  • Loading branch information
kezhuw committed May 10, 2024
1 parent c71c44d commit c5d560f
Show file tree
Hide file tree
Showing 10 changed files with 153 additions and 135 deletions.
13 changes: 12 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,18 @@ rustls = { version = "0.23.2", optional = true }
rustls-pemfile = { version = "2", optional = true }
webpki-roots = { version = "0.26.1", optional = true }
derive-where = "1.2.7"
tokio-rustls = "0.26.0"
fastrand = "2.0.2"
tracing = "0.1.40"
rsasl = { version = "2.0.1", default-features = false, features = ["provider", "config_builder", "registry_static", "std"], optional = true }
md5 = { version = "0.7.0", optional = true }
hex = { version = "0.4.3", optional = true }
linkme = { version = "0.2", optional = true }
async-io = "2.3.2"
futures = "0.3.30"
async-net = "2.0.0"
futures-rustls = "0.26.0"
futures-lite = "2.3.0"
asyncs = "0.2.0"

[dev-dependencies]
test-log = { version = "0.2.15", features = ["log", "trace"] }
Expand All @@ -59,9 +64,15 @@ assert_matches = "1.5.0"
tempfile = "3.6.0"
rcgen = { version = "0.12.1", features = ["default", "x509-parser"] }
serial_test = "3.0.0"
asyncs = { version = "0.2.0", features = ["test"] }
blocking = "1.6.0"

[package.metadata.cargo-all-features]
skip_optional_dependencies = true

[package.metadata.docs.rs]
all-features = true

[profile.dev]
# Need this for linkme crate to work for spawns in macOS
lto = "thin"
16 changes: 8 additions & 8 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use std::time::Duration;

use const_format::formatcp;
use either::{Either, Left, Right};
use futures::channel::mpsc;
use ignore_result::Ignore;
use thiserror::Error;
use tokio::sync::mpsc;
use tracing::instrument;

pub use self::watcher::{OneshotWatcher, PersistentWatcher, StateWatcher};
Expand Down Expand Up @@ -322,9 +322,9 @@ impl Client {

fn send_marshalled_request(&self, request: MarshalledRequest) -> StateReceiver {
let (operation, receiver) = SessionOperation::new_marshalled(request).with_responser();
if let Err(mpsc::error::SendError(operation)) = self.requester.send(operation) {
if let Err(err) = self.requester.unbounded_send(operation) {
let state = self.state();
operation.responser.send(Err(state.to_error()));
err.into_inner().responser.send(Err(state.to_error()));
}
receiver
}
Expand Down Expand Up @@ -514,7 +514,7 @@ impl Client {

// TODO: move these to session side so to eliminate owned Client and String.
fn delete_background(self, path: String) {
tokio::spawn(async move {
asyncs::spawn(async move {
self.delete_foreground(&path).await;
});
}
Expand All @@ -524,7 +524,7 @@ impl Client {
}

fn delete_ephemeral_background(self, prefix: String, unique: bool) {
tokio::spawn(async move {
asyncs::spawn(async move {
let (parent, tree, name) = util::split_path(&prefix);
let mut children = Self::retry_on_connection_loss(|| self.list_children(parent)).await?;
if unique {
Expand Down Expand Up @@ -1673,13 +1673,13 @@ impl Connector {
let mut buf = Vec::with_capacity(4096);
let mut depot = Depot::new();
let conn = session.start(&mut endpoints, &mut buf, &mut depot).await?;
let (sender, receiver) = mpsc::unbounded_channel();
let (sender, receiver) = mpsc::unbounded();
let session_info = session.session.clone();
let session_timeout = session.session_timeout;
let mut state_watcher = StateWatcher::new(state_receiver);
// Consume all state changes so far.
state_watcher.state();
tokio::spawn(async move {
asyncs::spawn(async move {
session.serve(endpoints, conn, buf, depot, receiver).await;
});
let client =
Expand Down Expand Up @@ -2270,7 +2270,7 @@ mod tests {
.is_equal_to(Error::BadArguments(&"directory node must not be sequential"));
}

#[test_log::test(tokio::test)]
#[test_log::test(asyncs::test)]
async fn session_last_zxid_seen() {
use testcontainers::clients::Cli as DockerCli;
use testcontainers::core::{Healthcheck, WaitFor};
Expand Down
25 changes: 12 additions & 13 deletions src/deadline.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use std::time::{Duration, Instant};

use tokio::time::{self, Instant, Sleep};
use async_io::Timer;
use futures::future::{Fuse, FusedFuture, FutureExt};

pub struct Deadline {
sleep: Option<Sleep>,
timer: Fuse<Timer>,
deadline: Option<Instant>,
}

impl Deadline {
pub fn never() -> Self {
Self { sleep: None }
Self { timer: Timer::never().fuse(), deadline: None }
}

pub fn until(deadline: Instant) -> Self {
Self { sleep: Some(time::sleep_until(deadline)) }
Self { timer: Timer::at(deadline).fuse(), deadline: Some(deadline) }
}

pub fn elapsed(&self) -> bool {
self.sleep.as_ref().map(|f| f.is_elapsed()).unwrap_or(false)
self.timer.is_terminated()
}

/// Remaining timeout.
pub fn timeout(&self) -> Duration {
match self.sleep.as_ref() {
match self.deadline.as_ref() {
None => Duration::MAX,
Some(sleep) => sleep.deadline().saturating_duration_since(Instant::now()),
Some(deadline) => deadline.saturating_duration_since(Instant::now()),
}
}
}
Expand All @@ -35,10 +37,7 @@ impl Future for Deadline {
type Output = ();

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.sleep.is_none() {
return Poll::Pending;
}
let sleep = unsafe { self.map_unchecked_mut(|deadline| deadline.sleep.as_mut().unwrap_unchecked()) };
sleep.poll(cx)
let timer = unsafe { self.map_unchecked_mut(|deadline| &mut deadline.timer) };
timer.poll(cx).map(|_| ())
}
}
6 changes: 4 additions & 2 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::fmt::{self, Display, Formatter};
use std::time::Duration;

use async_io::Timer;

use crate::chroot::Chroot;
use crate::error::Error;
use crate::util::{Ref, ToRef};
Expand Down Expand Up @@ -219,7 +221,7 @@ impl IterableEndpoints {
async fn delay(&self, index: Index, max_delay: Duration) {
let timeout = max_delay.min(Self::timeout(index, self.endpoints.len()));
if timeout != Duration::ZERO {
tokio::time::sleep(timeout).await;
Timer::after(timeout).await;
}
}

Expand Down Expand Up @@ -336,7 +338,7 @@ mod tests {
);
}

#[tokio::test]
#[asyncs::test]
async fn test_iterable_endpoints_next() {
use std::time::Duration;

Expand Down
48 changes: 26 additions & 22 deletions src/session/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use async_io::Timer;
use async_net::TcpStream;
use asyncs::select;
use bytes::buf::BufMut;
use futures::io::BufReader;
use futures::prelude::*;
use futures_lite::AsyncReadExt;
use ignore_result::Ignore;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
use tokio::net::TcpStream;
use tokio::{select, time};
use tracing::{debug, trace};

#[cfg(feature = "tls")]
mod tls {
pub use std::sync::Arc;

pub use futures_rustls::client::TlsStream;
pub use futures_rustls::TlsConnector;
pub use rustls::pki_types::ServerName;
pub use rustls::ClientConfig;
pub use tokio_rustls::client::TlsStream;
pub use tokio_rustls::TlsConnector;
}
#[cfg(feature = "tls")]
use tls::*;
Expand Down Expand Up @@ -51,7 +54,7 @@ pub trait AsyncReadToBuf: AsyncReadExt {
impl<T> AsyncReadToBuf for T where T: AsyncReadExt {}

impl AsyncRead for Connection {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_read(cx, buf),
#[cfg(feature = "tls")]
Expand Down Expand Up @@ -85,11 +88,11 @@ impl AsyncWrite for Connection {
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
match self.get_mut() {
Self::Raw(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Raw(stream) => Pin::new(stream).poll_close(cx),
#[cfg(feature = "tls")]
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_close(cx),
}
}
}
Expand All @@ -99,7 +102,7 @@ pub struct ConnReader<'a> {
}

impl AsyncRead for ConnReader<'_> {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>> {
Pin::new(&mut self.get_mut().conn).poll_read(cx, buf)
}
}
Expand All @@ -121,8 +124,8 @@ impl AsyncWrite for ConnWriter<'_> {
Pin::new(&mut self.get_mut().conn).poll_flush(cx)
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.get_mut().conn).poll_shutdown(cx)
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.get_mut().conn).poll_close(cx)
}
}

Expand All @@ -142,13 +145,14 @@ impl Connection {
Self::Tls(stream)
}

pub async fn command(self, cmd: &str) -> Result<String> {
let mut stream = BufStream::new(self);
stream.write_all(cmd.as_bytes()).await?;
stream.flush().await?;
pub async fn command(mut self, cmd: &str) -> Result<String> {
// let mut stream = BufStream::new(self);
self.write_all(cmd.as_bytes()).await?;
self.flush().await?;
let mut line = String::new();
stream.read_line(&mut line).await?;
stream.shutdown().await.ignore();
let mut reader = BufReader::new(self);
reader.read_line(&mut line).await?;
reader.close().await.ignore();
Ok(line)
}

Expand Down Expand Up @@ -212,7 +216,7 @@ impl Connector {
}
select! {
_ = unsafe { Pin::new_unchecked(deadline) } => Err(Error::new(ErrorKind::TimedOut, "deadline exceed")),
_ = time::sleep(self.timeout) => Err(Error::new(ErrorKind::TimedOut, format!("connection timeout{:?} exceed", self.timeout))),
_ = Timer::after(self.timeout) => Err(Error::new(ErrorKind::TimedOut, format!("connection timeout{:?} exceed", self.timeout))),
r = TcpStream::connect((endpoint.host, endpoint.port)) => {
match r {
Err(err) => Err(err),
Expand Down Expand Up @@ -255,10 +259,10 @@ impl Connector {
"fails to contact writable server from endpoints {:?}",
endpoints.endpoints()
);
time::sleep(timeout).await;
Timer::after(timeout).await;
timeout = max_timeout.min(timeout * 2);
} else {
time::sleep(Duration::from_millis(5)).await;
Timer::after(Duration::from_millis(5)).await;
}
}
None
Expand All @@ -273,7 +277,7 @@ mod tests {
use crate::deadline::Deadline;
use crate::endpoint::EndpointRef;

#[tokio::test]
#[asyncs::test]
async fn raw() {
let connector = Connector::new();
let endpoint = EndpointRef::new("host1", 2181, true);
Expand Down
2 changes: 1 addition & 1 deletion src/session/depot.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::collections::VecDeque;
use std::io::IoSlice;

use futures_lite::io::AsyncWriteExt;
use hashbrown::HashMap;
use strum::IntoEnumIterator;
use tokio::io::AsyncWriteExt;
use tracing::debug;

use super::request::{MarshalledRequest, OpStat, Operation, SessionOperation, StateResponser};
Expand Down
Loading

0 comments on commit c5d560f

Please sign in to comment.