diff --git a/rpc/Cargo.toml b/rpc/Cargo.toml index a2c81265..02850e72 100644 --- a/rpc/Cargo.toml +++ b/rpc/Cargo.toml @@ -18,11 +18,13 @@ serde1 = ["trace/serde", "serde", "serde/derive"] tokio1 = ["tokio"] [dependencies] +derivative = "1" fnv = "1.0" futures-preview = { version = "0.3.0-alpha.18" } humantime = "1.0" log = "0.4" pin-utils = "0.1.0-alpha.4" +raii-counter = "0.2" rand = "0.7" tokio-timer = "0.3.0-alpha.4" trace = { package = "tarpc-trace", version = "0.2", path = "../trace" } diff --git a/rpc/src/client/mod.rs b/rpc/src/client/mod.rs index ee563680..e7dce91c 100644 --- a/rpc/src/client/mod.rs +++ b/rpc/src/client/mod.rs @@ -103,7 +103,6 @@ where } /// Settings that control the behavior of the client. -#[non_exhaustive] #[derive(Clone, Debug)] pub struct Config { /// The number of requests that can be in flight at once. diff --git a/rpc/src/context.rs b/rpc/src/context.rs index 83da870f..968e2dae 100644 --- a/rpc/src/context.rs +++ b/rpc/src/context.rs @@ -17,7 +17,6 @@ use trace::{self, TraceId}; /// be different for each request in scope. #[derive(Clone, Copy, Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] pub struct Context { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. diff --git a/rpc/src/lib.rs b/rpc/src/lib.rs index 8da8efcb..757007ed 100644 --- a/rpc/src/lib.rs +++ b/rpc/src/lib.rs @@ -4,7 +4,6 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -#![feature(weak_counts, non_exhaustive, trait_alias)] #![deny(missing_docs, missing_debug_implementations)] //! An RPC framework providing client and server. @@ -31,7 +30,7 @@ pub mod server; pub mod transport; pub(crate) mod util; -pub use crate::{client::Client, server::Server, transport::Transport}; +pub use crate::{client::Client, server::Server, transport::sealed::Transport}; use futures::task::Poll; use std::{io, time::SystemTime}; @@ -39,7 +38,6 @@ use std::{io, time::SystemTime}; /// A message from a client to a server. #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] pub enum ClientMessage { /// A request initiated by a user. The server responds to a request by invoking a /// service-provided request handler. The handler completes with a [`response`](Response), which @@ -65,7 +63,6 @@ pub enum ClientMessage { /// A request from a client to a server. #[derive(Clone, Copy, Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. pub context: context::Context, @@ -78,7 +75,6 @@ pub struct Request { /// A response from a server to a client. #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] pub struct Response { /// The ID of the request being responded to. pub request_id: u64, @@ -89,7 +85,6 @@ pub struct Response { /// An error response from a server to a client. #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] pub struct ServerError { #[cfg_attr( feature = "serde1", diff --git a/rpc/src/server/filter.rs b/rpc/src/server/filter.rs index 7dd17582..ae762127 100644 --- a/rpc/src/server/filter.rs +++ b/rpc/src/server/filter.rs @@ -8,6 +8,7 @@ use crate::{ server::{self, Channel}, util::Compact, }; +use derivative::Derivative; use fnv::FnvHashMap; use futures::{ channel::mpsc, @@ -19,6 +20,7 @@ use futures::{ }; use log::{debug, info, trace}; use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use raii_counter::WeakCounter; use std::sync::{Arc, Weak}; use std::{ collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin, @@ -28,42 +30,59 @@ use std::{ #[derive(Debug)] pub struct ChannelFilter where - K: Eq + Hash, + K: Eq + Hash + fmt::Debug, { listener: Fuse, channels_per_key: u32, dropped_keys: mpsc::UnboundedReceiver, dropped_keys_tx: mpsc::UnboundedSender, - key_counts: FnvHashMap>>, + key_counts: FnvHashMap>, keymaker: F, } /// A channel that is tracked by a ChannelFilter. #[derive(Debug)] -pub struct TrackedChannel { +pub struct TrackedChannel { inner: C, - tracker: Arc>, + tracker: Tracker, } -impl TrackedChannel { +impl TrackedChannel { unsafe_pinned!(inner: C); } -#[derive(Debug)] -struct Tracker { - key: Option, +#[derive(Clone, Derivative)] +#[derivative(Debug)] +struct Tracker { + key: Option>, + #[derivative(Debug = "ignore")] + counter: raii_counter::Counter, dropped_keys: mpsc::UnboundedSender, } -impl Drop for Tracker { +impl Drop for Tracker { fn drop(&mut self) { - // Don't care if the listener is dropped. - let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap()); + if self.counter.count() <= 1 { + // Don't care if the listener is dropped. + let _ = self + .dropped_keys + .unbounded_send(Arc::try_unwrap(self.key.take().unwrap()).unwrap()); + } } } +#[derive(Clone, Derivative)] +#[derivative(Debug)] +struct TrackerPrototype { + key: Weak, + #[derivative(Debug = "ignore")] + counter: WeakCounter, + dropped_keys: mpsc::UnboundedSender, +} + impl Stream for TrackedChannel where + K: fmt::Debug, C: Stream, { type Item = ::Item; @@ -75,6 +94,7 @@ where impl Sink for TrackedChannel where + K: fmt::Debug, C: Sink, { type Error = C::Error; @@ -96,7 +116,7 @@ where } } -impl AsRef for TrackedChannel { +impl AsRef for TrackedChannel { fn as_ref(&self) -> &C { &self.inner } @@ -104,6 +124,7 @@ impl AsRef for TrackedChannel { impl Channel for TrackedChannel where + K: fmt::Debug, C: Channel, { type Req = C::Req; @@ -122,7 +143,7 @@ where } } -impl TrackedChannel { +impl TrackedChannel { /// Returns the inner channel. pub fn get_ref(&self) -> &C { &self.inner @@ -136,19 +157,19 @@ impl TrackedChannel { impl ChannelFilter where - K: fmt::Display + Eq + Hash + Clone, + K: fmt::Display + fmt::Debug + Eq + Hash + Clone, { unsafe_pinned!(listener: Fuse); unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver); unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender); - unsafe_unpinned!(key_counts: FnvHashMap>>); + unsafe_unpinned!(key_counts: FnvHashMap>); unsafe_unpinned!(channels_per_key: u32); unsafe_unpinned!(keymaker: F); } impl ChannelFilter where - K: Eq + Hash, + K: fmt::Debug + Eq + Hash, S: Stream, F: Fn(&S::Item) -> K, { @@ -169,7 +190,7 @@ where impl ChannelFilter where S: Stream, - K: fmt::Display + Eq + Hash + Clone + Unpin, + K: fmt::Display + fmt::Debug + Eq + Hash + Clone + Unpin, F: Fn(&S::Item) -> K, { fn handle_new_channel( @@ -182,7 +203,7 @@ where trace!( "[{}] Opening channel ({}/{}) channels for key.", key, - Arc::strong_count(&tracker), + tracker.counter.count(), self.as_mut().channels_per_key() ); @@ -192,22 +213,28 @@ where }) } - fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result>, K> { + fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result, K> { let channels_per_key = self.channels_per_key; let dropped_keys = self.dropped_keys_tx.clone(); let key_counts = &mut self.as_mut().key_counts(); match key_counts.entry(key.clone()) { Entry::Vacant(vacant) => { - let tracker = Arc::new(Tracker { + let key = Arc::new(key); + let counter = WeakCounter::new(); + + vacant.insert(TrackerPrototype { + key: Arc::downgrade(&key), + counter: counter.clone(), + dropped_keys: dropped_keys.clone(), + }); + Ok(Tracker { key: Some(key), + counter: counter.upgrade(), dropped_keys, - }); - - vacant.insert(Arc::downgrade(&tracker)); - Ok(tracker) + }) } - Entry::Occupied(mut o) => { - let count = o.get().strong_count(); + Entry::Occupied(o) => { + let count = o.get().counter.count(); if count >= channels_per_key.try_into().unwrap() { info!( "[{}] Opened max channels from key ({}/{}).", @@ -215,15 +242,16 @@ where ); Err(key) } else { - Ok(o.get().upgrade().unwrap_or_else(|| { - let tracker = Arc::new(Tracker { - key: Some(key), - dropped_keys, - }); - - *o.get_mut() = Arc::downgrade(&tracker); - tracker - })) + let TrackerPrototype { + key, + counter, + dropped_keys, + } = o.get().clone(); + Ok(Tracker { + counter: counter.upgrade(), + key: Some(key.upgrade().unwrap()), + dropped_keys, + }) } } } @@ -255,7 +283,7 @@ where impl Stream for ChannelFilter where S: Stream, - K: fmt::Display + Eq + Hash + Clone + Unpin, + K: fmt::Display + fmt::Debug + Eq + Hash + Clone + Unpin, F: Fn(&S::Item) -> K, { type Item = TrackedChannel; @@ -296,10 +324,12 @@ fn ctx() -> Context<'static> { #[test] fn tracker_drop() { use assert_matches::assert_matches; + use raii_counter::Counter; let (tx, mut rx) = mpsc::unbounded(); Tracker { - key: Some(1), + key: Some(Arc::new(1)), + counter: Counter::new(), dropped_keys: tx, }; assert_matches!(rx.try_next(), Ok(Some(1))); @@ -309,15 +339,17 @@ fn tracker_drop() { fn tracked_channel_stream() { use assert_matches::assert_matches; use pin_utils::pin_mut; + use raii_counter::Counter; let (chan_tx, chan) = mpsc::unbounded(); let (dropped_keys, _) = mpsc::unbounded(); let channel = TrackedChannel { inner: chan, - tracker: Arc::new(Tracker { - key: Some(1), + tracker: Tracker { + key: Some(Arc::new(1)), + counter: Counter::new(), dropped_keys, - }), + }, }; chan_tx.unbounded_send("test").unwrap(); @@ -329,15 +361,17 @@ fn tracked_channel_stream() { fn tracked_channel_sink() { use assert_matches::assert_matches; use pin_utils::pin_mut; + use raii_counter::Counter; let (chan, mut chan_rx) = mpsc::unbounded(); let (dropped_keys, _) = mpsc::unbounded(); let channel = TrackedChannel { inner: chan, - tracker: Arc::new(Tracker { - key: Some(1), + tracker: Tracker { + key: Some(Arc::new(1)), + counter: Counter::new(), dropped_keys, - }), + }, }; pin_mut!(channel); @@ -359,12 +393,12 @@ fn channel_filter_increment_channels_for_key() { let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap(); - assert_eq!(Arc::strong_count(&tracker1), 1); + assert_eq!(tracker1.counter.count(), 1); let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap(); - assert_eq!(Arc::strong_count(&tracker1), 2); + assert_eq!(tracker1.counter.count(), 2); assert_matches!(filter.increment_channels_for_key("key"), Err("key")); drop(tracker2); - assert_eq!(Arc::strong_count(&tracker1), 1); + assert_eq!(tracker1.counter.count(), 1); } #[test] @@ -383,20 +417,20 @@ fn channel_filter_handle_new_channel() { .as_mut() .handle_new_channel(TestChannel { key: "key" }) .unwrap(); - assert_eq!(Arc::strong_count(&channel1.tracker), 1); + assert_eq!(channel1.tracker.counter.count(), 1); let channel2 = filter .as_mut() .handle_new_channel(TestChannel { key: "key" }) .unwrap(); - assert_eq!(Arc::strong_count(&channel1.tracker), 2); + assert_eq!(channel1.tracker.counter.count(), 2); assert_matches!( filter.handle_new_channel(TestChannel { key: "key" }), Err("key") ); drop(channel2); - assert_eq!(Arc::strong_count(&channel1.tracker), 1); + assert_eq!(channel1.tracker.counter.count(), 1); } #[test] @@ -417,14 +451,14 @@ fn channel_filter_poll_listener() { .unwrap(); let channel1 = assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c); - assert_eq!(Arc::strong_count(&channel1.tracker), 1); + assert_eq!(channel1.tracker.counter.count(), 1); new_channels .unbounded_send(TestChannel { key: "key" }) .unwrap(); let _channel2 = assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c); - assert_eq!(Arc::strong_count(&channel1.tracker), 2); + assert_eq!(channel1.tracker.counter.count(), 2); new_channels .unbounded_send(TestChannel { key: "key" }) @@ -432,7 +466,7 @@ fn channel_filter_poll_listener() { let key = assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k); assert_eq!(key, "key"); - assert_eq!(Arc::strong_count(&channel1.tracker), 2); + assert_eq!(channel1.tracker.counter.count(), 2); } #[test] diff --git a/rpc/src/server/mod.rs b/rpc/src/server/mod.rs index d419b251..26b4d925 100644 --- a/rpc/src/server/mod.rs +++ b/rpc/src/server/mod.rs @@ -49,7 +49,6 @@ impl Default for Server { } /// Settings that control the behavior of the server. -#[non_exhaustive] #[derive(Clone, Debug)] pub struct Config { /// The number of responses per client that can be buffered server-side before being sent. @@ -134,7 +133,7 @@ where /// Enforces channel per-key limits. fn max_channels_per_key(self, n: u32, keymaker: KF) -> filter::ChannelFilter where - K: fmt::Display + Eq + Hash + Clone + Unpin, + K: fmt::Display + fmt::Debug + Eq + Hash + Clone + Unpin, KF: Fn(&C) -> K, { ChannelFilter::new(self, n, keymaker) diff --git a/rpc/src/transport/mod.rs b/rpc/src/transport/mod.rs index e934aef9..911f9552 100644 --- a/rpc/src/transport/mod.rs +++ b/rpc/src/transport/mod.rs @@ -14,6 +14,17 @@ use std::io; pub mod channel; -/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages. -pub trait Transport = - Stream> + Sink; +pub(crate) mod sealed { + use super::*; + + /// A bidirectional stream ([`Sink`] + [`Stream`]) of messages. + pub trait Transport: + Stream> + Sink + { + } + + impl Transport for T where + T: Stream> + Sink + ?Sized + { + } +}