From e91005855cc26517a5775ecfae40626329850f7b Mon Sep 17 00:00:00 2001 From: Artem Vorotnikov Date: Tue, 1 Oct 2019 16:33:14 +0300 Subject: [PATCH] Remove remaining feature flags --- .travis.yml | 2 +- rpc/Cargo.toml | 1 + rpc/src/client/channel.rs | 3 + rpc/src/client/mod.rs | 4 +- rpc/src/context.rs | 4 +- rpc/src/lib.rs | 15 +++-- rpc/src/server/filter.rs | 111 +++++++++++++++++++++++-------------- rpc/src/server/mod.rs | 4 +- rpc/src/server/testing.rs | 2 + rpc/src/server/throttle.rs | 4 ++ rpc/src/transport/mod.rs | 17 +++++- tarpc/src/lib.rs | 1 - 12 files changed, 113 insertions(+), 55 deletions(-) diff --git a/.travis.yml b/.travis.yml index a4ded2f4..b29bf187 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,6 @@ language: rust rust: - - nightly + - beta sudo: false cache: cargo diff --git a/rpc/Cargo.toml b/rpc/Cargo.toml index a2c81265..531fb0f4 100644 --- a/rpc/Cargo.toml +++ b/rpc/Cargo.toml @@ -23,6 +23,7 @@ futures-preview = { version = "0.3.0-alpha.18" } humantime = "1.0" log = "0.4" pin-utils = "0.1.0-alpha.4" +raii-counter = "0.2" rand = "0.7" tokio-timer = "0.3.0-alpha.4" trace = { package = "tarpc-trace", version = "0.2", path = "../trace" } diff --git a/rpc/src/client/channel.rs b/rpc/src/client/channel.rs index e07c1079..70a44750 100644 --- a/rpc/src/client/channel.rs +++ b/rpc/src/client/channel.rs @@ -396,7 +396,9 @@ where context: context::Context { deadline: dispatch_request.ctx.deadline, trace_context: dispatch_request.ctx.trace_context, + _non_exhaustive: (), }, + _non_exhaustive: (), }); self.as_mut().transport().start_send(request)?; self.as_mut().in_flight_requests().insert( @@ -798,6 +800,7 @@ mod tests { Response { request_id: 0, message: Ok("hello".into()), + _non_exhaustive: (), }, ); block_on(dispatch).unwrap(); diff --git a/rpc/src/client/mod.rs b/rpc/src/client/mod.rs index ee563680..4d490e57 100644 --- a/rpc/src/client/mod.rs +++ b/rpc/src/client/mod.rs @@ -103,7 +103,6 @@ where } /// Settings that control the behavior of the client. -#[non_exhaustive] #[derive(Clone, Debug)] pub struct Config { /// The number of requests that can be in flight at once. @@ -114,6 +113,8 @@ pub struct Config { /// `pending_requests_buffer` controls the size of the channel clients use /// to communicate with the request dispatch task. pub pending_request_buffer: usize, + #[doc(hidden)] + _non_exhaustive: (), } impl Default for Config { @@ -121,6 +122,7 @@ impl Default for Config { Config { max_in_flight_requests: 1_000, pending_request_buffer: 100, + _non_exhaustive: (), } } } diff --git a/rpc/src/context.rs b/rpc/src/context.rs index 83da870f..2b2ef5dc 100644 --- a/rpc/src/context.rs +++ b/rpc/src/context.rs @@ -17,7 +17,6 @@ use trace::{self, TraceId}; /// be different for each request in scope. #[derive(Clone, Copy, Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] pub struct Context { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. @@ -36,6 +35,8 @@ pub struct Context { /// include the same `trace_id` as that included on the original request. This way, /// users can trace related actions across a distributed system. pub trace_context: trace::Context, + #[doc(hidden)] + pub(crate) _non_exhaustive: (), } #[cfg(feature = "serde1")] @@ -49,6 +50,7 @@ pub fn current() -> Context { Context { deadline: SystemTime::now() + Duration::from_secs(10), trace_context: trace::Context::new_root(), + _non_exhaustive: (), } } diff --git a/rpc/src/lib.rs b/rpc/src/lib.rs index 8da8efcb..6d45aa37 100644 --- a/rpc/src/lib.rs +++ b/rpc/src/lib.rs @@ -4,7 +4,6 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. -#![feature(weak_counts, non_exhaustive, trait_alias)] #![deny(missing_docs, missing_debug_implementations)] //! An RPC framework providing client and server. @@ -31,7 +30,7 @@ pub mod server; pub mod transport; pub(crate) mod util; -pub use crate::{client::Client, server::Server, transport::Transport}; +pub use crate::{client::Client, server::Server, transport::sealed::Transport}; use futures::task::Poll; use std::{io, time::SystemTime}; @@ -39,7 +38,6 @@ use std::{io, time::SystemTime}; /// A message from a client to a server. #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] pub enum ClientMessage { /// A request initiated by a user. The server responds to a request by invoking a /// service-provided request handler. The handler completes with a [`response`](Response), which @@ -60,12 +58,13 @@ pub enum ClientMessage { /// The ID of the request to cancel. request_id: u64, }, + #[doc(hidden)] + _NonExhaustive, } /// A request from a client to a server. #[derive(Clone, Copy, Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. pub context: context::Context, @@ -73,23 +72,25 @@ pub struct Request { pub id: u64, /// The request body. pub message: T, + #[doc(hidden)] + _non_exhaustive: (), } /// A response from a server to a client. #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] pub struct Response { /// The ID of the request being responded to. pub request_id: u64, /// The response body, or an error if the request failed. pub message: Result, + #[doc(hidden)] + _non_exhaustive: (), } /// An error response from a server to a client. #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -#[non_exhaustive] pub struct ServerError { #[cfg_attr( feature = "serde1", @@ -103,6 +104,8 @@ pub struct ServerError { pub kind: io::ErrorKind, /// A message describing more detail about the error that occurred. pub detail: Option, + #[doc(hidden)] + _non_exhaustive: (), } impl From for io::Error { diff --git a/rpc/src/server/filter.rs b/rpc/src/server/filter.rs index 7dd17582..c59daa38 100644 --- a/rpc/src/server/filter.rs +++ b/rpc/src/server/filter.rs @@ -19,6 +19,7 @@ use futures::{ }; use log::{debug, info, trace}; use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use raii_counter::{Counter, WeakCounter}; use std::sync::{Arc, Weak}; use std::{ collections::hash_map::Entry, convert::TryInto, fmt, hash::Hash, marker::Unpin, pin::Pin, @@ -34,7 +35,7 @@ where channels_per_key: u32, dropped_keys: mpsc::UnboundedReceiver, dropped_keys_tx: mpsc::UnboundedSender, - key_counts: FnvHashMap>>, + key_counts: FnvHashMap>, keymaker: F, } @@ -42,26 +43,41 @@ where #[derive(Debug)] pub struct TrackedChannel { inner: C, - tracker: Arc>, + tracker: Tracker, } impl TrackedChannel { unsafe_pinned!(inner: C); } -#[derive(Debug)] +#[derive(Clone, Debug)] struct Tracker { - key: Option, + key: Option>, + counter: Counter, dropped_keys: mpsc::UnboundedSender, } impl Drop for Tracker { fn drop(&mut self) { - // Don't care if the listener is dropped. - let _ = self.dropped_keys.unbounded_send(self.key.take().unwrap()); + if self.counter.count() <= 1 { + // Don't care if the listener is dropped. + match Arc::try_unwrap(self.key.take().unwrap()) { + Ok(key) => { + let _ = self.dropped_keys.unbounded_send(key); + } + _ => unreachable!(), + } + } } } +#[derive(Clone, Debug)] +struct TrackerPrototype { + key: Weak, + counter: WeakCounter, + dropped_keys: mpsc::UnboundedSender, +} + impl Stream for TrackedChannel where C: Stream, @@ -141,7 +157,7 @@ where unsafe_pinned!(listener: Fuse); unsafe_pinned!(dropped_keys: mpsc::UnboundedReceiver); unsafe_pinned!(dropped_keys_tx: mpsc::UnboundedSender); - unsafe_unpinned!(key_counts: FnvHashMap>>); + unsafe_unpinned!(key_counts: FnvHashMap>); unsafe_unpinned!(channels_per_key: u32); unsafe_unpinned!(keymaker: F); } @@ -182,7 +198,7 @@ where trace!( "[{}] Opening channel ({}/{}) channels for key.", key, - Arc::strong_count(&tracker), + tracker.counter.count(), self.as_mut().channels_per_key() ); @@ -192,22 +208,28 @@ where }) } - fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result>, K> { + fn increment_channels_for_key(mut self: Pin<&mut Self>, key: K) -> Result, K> { let channels_per_key = self.channels_per_key; let dropped_keys = self.dropped_keys_tx.clone(); let key_counts = &mut self.as_mut().key_counts(); match key_counts.entry(key.clone()) { Entry::Vacant(vacant) => { - let tracker = Arc::new(Tracker { + let key = Arc::new(key); + let counter = WeakCounter::new(); + + vacant.insert(TrackerPrototype { + key: Arc::downgrade(&key), + counter: counter.clone(), + dropped_keys: dropped_keys.clone(), + }); + Ok(Tracker { key: Some(key), + counter: counter.upgrade(), dropped_keys, - }); - - vacant.insert(Arc::downgrade(&tracker)); - Ok(tracker) + }) } - Entry::Occupied(mut o) => { - let count = o.get().strong_count(); + Entry::Occupied(o) => { + let count = o.get().counter.count(); if count >= channels_per_key.try_into().unwrap() { info!( "[{}] Opened max channels from key ({}/{}).", @@ -215,15 +237,16 @@ where ); Err(key) } else { - Ok(o.get().upgrade().unwrap_or_else(|| { - let tracker = Arc::new(Tracker { - key: Some(key), - dropped_keys, - }); - - *o.get_mut() = Arc::downgrade(&tracker); - tracker - })) + let TrackerPrototype { + key, + counter, + dropped_keys, + } = o.get().clone(); + Ok(Tracker { + counter: counter.upgrade(), + key: Some(key.upgrade().unwrap()), + dropped_keys, + }) } } } @@ -296,10 +319,12 @@ fn ctx() -> Context<'static> { #[test] fn tracker_drop() { use assert_matches::assert_matches; + use raii_counter::Counter; let (tx, mut rx) = mpsc::unbounded(); Tracker { - key: Some(1), + key: Some(Arc::new(1)), + counter: Counter::new(), dropped_keys: tx, }; assert_matches!(rx.try_next(), Ok(Some(1))); @@ -309,15 +334,17 @@ fn tracker_drop() { fn tracked_channel_stream() { use assert_matches::assert_matches; use pin_utils::pin_mut; + use raii_counter::Counter; let (chan_tx, chan) = mpsc::unbounded(); let (dropped_keys, _) = mpsc::unbounded(); let channel = TrackedChannel { inner: chan, - tracker: Arc::new(Tracker { - key: Some(1), + tracker: Tracker { + key: Some(Arc::new(1)), + counter: Counter::new(), dropped_keys, - }), + }, }; chan_tx.unbounded_send("test").unwrap(); @@ -329,15 +356,17 @@ fn tracked_channel_stream() { fn tracked_channel_sink() { use assert_matches::assert_matches; use pin_utils::pin_mut; + use raii_counter::Counter; let (chan, mut chan_rx) = mpsc::unbounded(); let (dropped_keys, _) = mpsc::unbounded(); let channel = TrackedChannel { inner: chan, - tracker: Arc::new(Tracker { - key: Some(1), + tracker: Tracker { + key: Some(Arc::new(1)), + counter: Counter::new(), dropped_keys, - }), + }, }; pin_mut!(channel); @@ -359,12 +388,12 @@ fn channel_filter_increment_channels_for_key() { let filter = ChannelFilter::new(listener, 2, |chan: &TestChannel| chan.key); pin_mut!(filter); let tracker1 = filter.as_mut().increment_channels_for_key("key").unwrap(); - assert_eq!(Arc::strong_count(&tracker1), 1); + assert_eq!(tracker1.counter.count(), 1); let tracker2 = filter.as_mut().increment_channels_for_key("key").unwrap(); - assert_eq!(Arc::strong_count(&tracker1), 2); + assert_eq!(tracker1.counter.count(), 2); assert_matches!(filter.increment_channels_for_key("key"), Err("key")); drop(tracker2); - assert_eq!(Arc::strong_count(&tracker1), 1); + assert_eq!(tracker1.counter.count(), 1); } #[test] @@ -383,20 +412,20 @@ fn channel_filter_handle_new_channel() { .as_mut() .handle_new_channel(TestChannel { key: "key" }) .unwrap(); - assert_eq!(Arc::strong_count(&channel1.tracker), 1); + assert_eq!(channel1.tracker.counter.count(), 1); let channel2 = filter .as_mut() .handle_new_channel(TestChannel { key: "key" }) .unwrap(); - assert_eq!(Arc::strong_count(&channel1.tracker), 2); + assert_eq!(channel1.tracker.counter.count(), 2); assert_matches!( filter.handle_new_channel(TestChannel { key: "key" }), Err("key") ); drop(channel2); - assert_eq!(Arc::strong_count(&channel1.tracker), 1); + assert_eq!(channel1.tracker.counter.count(), 1); } #[test] @@ -417,14 +446,14 @@ fn channel_filter_poll_listener() { .unwrap(); let channel1 = assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c); - assert_eq!(Arc::strong_count(&channel1.tracker), 1); + assert_eq!(channel1.tracker.counter.count(), 1); new_channels .unbounded_send(TestChannel { key: "key" }) .unwrap(); let _channel2 = assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Ok(c))) => c); - assert_eq!(Arc::strong_count(&channel1.tracker), 2); + assert_eq!(channel1.tracker.counter.count(), 2); new_channels .unbounded_send(TestChannel { key: "key" }) @@ -432,7 +461,7 @@ fn channel_filter_poll_listener() { let key = assert_matches!(filter.as_mut().poll_listener(&mut ctx()), Poll::Ready(Some(Err(k))) => k); assert_eq!(key, "key"); - assert_eq!(Arc::strong_count(&channel1.tracker), 2); + assert_eq!(channel1.tracker.counter.count(), 2); } #[test] diff --git a/rpc/src/server/mod.rs b/rpc/src/server/mod.rs index d419b251..27d430b5 100644 --- a/rpc/src/server/mod.rs +++ b/rpc/src/server/mod.rs @@ -49,7 +49,6 @@ impl Default for Server { } /// Settings that control the behavior of the server. -#[non_exhaustive] #[derive(Clone, Debug)] pub struct Config { /// The number of responses per client that can be buffered server-side before being sent. @@ -307,6 +306,7 @@ where } => { self.as_mut().cancel_request(&trace_context, request_id); } + ClientMessage::_NonExhaustive => unreachable!(), }, None => return Poll::Ready(None), } @@ -582,9 +582,11 @@ where "Response did not complete before deadline of {}s.", format_rfc3339(self.deadline) )), + _non_exhaustive: (), }) } }, + _non_exhaustive: (), }); *self.as_mut().state() = RespState::PollReady; } diff --git a/rpc/src/server/testing.rs b/rpc/src/server/testing.rs index 5ba0455b..17a5efe1 100644 --- a/rpc/src/server/testing.rs +++ b/rpc/src/server/testing.rs @@ -85,9 +85,11 @@ impl FakeChannel>, Response> { context: context::Context { deadline: SystemTime::UNIX_EPOCH, trace_context: Default::default(), + _non_exhaustive: (), }, id, message, + _non_exhaustive: (), })); } } diff --git a/rpc/src/server/throttle.rs b/rpc/src/server/throttle.rs index cedb219e..7c2ee79d 100644 --- a/rpc/src/server/throttle.rs +++ b/rpc/src/server/throttle.rs @@ -66,7 +66,9 @@ where message: Err(ServerError { kind: io::ErrorKind::WouldBlock, detail: Some("Server throttled the request.".into()), + _non_exhaustive: (), }), + _non_exhaustive: (), })?; } None => return Poll::Ready(None), @@ -315,6 +317,7 @@ fn throttler_start_send() { .start_send(Response { request_id: 0, message: Ok(1), + _non_exhaustive: (), }) .unwrap(); assert!(throttler.inner.in_flight_requests.is_empty()); @@ -323,6 +326,7 @@ fn throttler_start_send() { Some(&Response { request_id: 0, message: Ok(1), + _non_exhaustive: () }) ); } diff --git a/rpc/src/transport/mod.rs b/rpc/src/transport/mod.rs index e934aef9..911f9552 100644 --- a/rpc/src/transport/mod.rs +++ b/rpc/src/transport/mod.rs @@ -14,6 +14,17 @@ use std::io; pub mod channel; -/// A bidirectional stream ([`Sink`] + [`Stream`]) of messages. -pub trait Transport = - Stream> + Sink; +pub(crate) mod sealed { + use super::*; + + /// A bidirectional stream ([`Sink`] + [`Stream`]) of messages. + pub trait Transport: + Stream> + Sink + { + } + + impl Transport for T where + T: Stream> + Sink + ?Sized + { + } +} diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 41f8f2ee..b2efce78 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -203,7 +203,6 @@ //! items expanded by a `service!` invocation. #![deny(missing_docs, missing_debug_implementations)] -#![feature(external_doc)] pub use rpc::*; /// The main macro that creates RPC services.